Luigi committed on
Commit
a96d072
·
1 Parent(s): 11139f5

Reduce memory usage

Browse files
Files changed (1) hide show
  1. train.py +47 -16
train.py CHANGED
@@ -7,14 +7,16 @@ from datasets import load_dataset
7
  import torch
8
  import os
9
  import math
10
- import wandb
11
  from transformers.integrations import WandbCallback
12
 
13
  PROJECT_NAME = 'SmolLM2-135M-Instruct-TaiwanChat'
14
  BASE_MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"
15
  DATASET_ID = "yentinglin/TaiwanChat"
16
- N_SAMPLES = 1000
17
- MAX_LEN = 256
 
 
 
18
 
19
  # Tell wandb which project to use, and that you want to log your model
20
  os.environ["WANDB_PROJECT"] = f'{PROJECT_NAME}_LOCAL'
@@ -30,15 +32,31 @@ print(f'Device is {device_str}')
30
 
31
  # Load Model & Tokenizer
32
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
33
- model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
34
- model.to(device_str)
35
 
36
  # Prepare the TaiwanChat Dataset
37
  # Load and split into train/validation
38
- full_ds = load_dataset(DATASET_ID, split=f"train[:{N_SAMPLES}]")
39
- splits = full_ds.train_test_split(test_size=0.1, seed=42)
40
- train_ds = splits['train']
41
- val_ds = splits['test']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # Preprocessing function
44
  def preprocess_examples(examples):
@@ -78,24 +96,37 @@ def preprocess_examples(examples):
78
  "attention_mask": attention_mask,
79
  "labels": labels}
80
 
81
- # Tokenize and collate
82
- tokenized_train = train_ds.map(
83
- preprocess_examples, batched=True, remove_columns=train_ds.column_names
 
 
 
84
  )
85
- tokenized_val = val_ds.map(
86
- preprocess_examples, batched=True, remove_columns=val_ds.column_names
 
 
 
 
87
  )
88
 
89
  data_collator = DataCollatorForLanguageModeling(
90
  tokenizer=tokenizer, mlm=False
91
  )
92
 
 
 
 
 
 
93
  # Define training arguments with evaluation
94
  training_args = TrainingArguments(
 
95
  output_dir=PROJECT_NAME,
96
- per_device_train_batch_size=4,
97
  learning_rate=5e-5,
98
- num_train_epochs=3,
99
  fp16=False if device_str == 'xpu' else True,
100
  bf16=True if device_str == 'xpu' else False,
101
  logging_steps=1000,
 
7
  import torch
8
  import os
9
  import math
 
10
  from transformers.integrations import WandbCallback
11
 
12
  PROJECT_NAME = 'SmolLM2-135M-Instruct-TaiwanChat'
13
  BASE_MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"
14
  DATASET_ID = "yentinglin/TaiwanChat"
15
+ N_SAMPLES = 9000
16
+ MAX_LEN = 512
17
+ VAL_FRACTION = 0.1
18
+ PER_DEVICE_TRAIN_BATCH_SIZE=1
19
+ NUM_TRAIN_EPOCHS=3
20
 
21
  # Tell wandb which project to use, and that you want to log your model
22
  os.environ["WANDB_PROJECT"] = f'{PROJECT_NAME}_LOCAL'
 
32
 
33
  # Load Model & Tokenizer
34
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
35
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, low_cpu_mem_usage=True )
36
+ model.to(device_str, dtype=torch.bfloat16 if device_str == 'xpu' else torch.float16)
37
 
38
  # Prepare the TaiwanChat Dataset
39
  # Load and split into train/validation
40
+ # 1) Load the raw train split as a stream
41
+ raw_stream = load_dataset(
42
+ DATASET_ID,
43
+ split="train", # no slicing here
44
+ streaming=True
45
+ )
46
+
47
+ # 2) (Optional) Shuffle the stream with a buffer
48
+ shuffled = raw_stream.shuffle(buffer_size=5_000, seed=42)
49
+
50
+ # 3) Take exactly N_SAMPLES examples
51
+ limited = shuffled.take(N_SAMPLES)
52
+
53
+ # 4) Split into train / validation
54
+ n_val = int(N_SAMPLES * VAL_FRACTION)
55
+ n_train = N_SAMPLES - n_val
56
+
57
+ train_stream = limited.take(n_train)
58
+ val_stream = limited.skip(n_train).take(n_val)
59
+
60
 
61
  # Preprocessing function
62
  def preprocess_examples(examples):
 
96
  "attention_mask": attention_mask,
97
  "labels": labels}
98
 
99
+ # 5) Tokenize on the fly with a small batch
100
+ tokenized_train = train_stream.map(
101
+ preprocess_examples,
102
+ batched=True,
103
+ batch_size=32, # controls RAM for each map() call
104
+ remove_columns=["messages"] # or whatever your raw column names are
105
  )
106
+
107
+ tokenized_val = val_stream.map(
108
+ preprocess_examples,
109
+ batched=True,
110
+ batch_size=32,
111
+ remove_columns=["messages"]
112
  )
113
 
114
  data_collator = DataCollatorForLanguageModeling(
115
  tokenizer=tokenizer, mlm=False
116
  )
117
 
118
+ # 1) Compute steps_per_epoch from your constants:
119
+ steps_per_epoch = math.ceil(N_SAMPLES / PER_DEVICE_TRAIN_BATCH_SIZE)
120
+ total_steps = steps_per_epoch * NUM_TRAIN_EPOCHS
121
+
122
+
123
  # Define training arguments with evaluation
124
  training_args = TrainingArguments(
125
+ max_steps=total_steps,
126
  output_dir=PROJECT_NAME,
127
+ per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
128
  learning_rate=5e-5,
129
+ num_train_epochs=NUM_TRAIN_EPOCHS,
130
  fp16=False if device_str == 'xpu' else True,
131
  bf16=True if device_str == 'xpu' else False,
132
  logging_steps=1000,