Luigi committed on
Commit
7ee8171
·
1 Parent(s): a96d072

Remove evaluation because it crashes with an out-of-memory (OOM) error on my PC

Browse files
Files changed (1) hide show
  1. train.py +3 -8
train.py CHANGED
@@ -12,10 +12,10 @@ from transformers.integrations import WandbCallback
12
  PROJECT_NAME = 'SmolLM2-135M-Instruct-TaiwanChat'
13
  BASE_MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"
14
  DATASET_ID = "yentinglin/TaiwanChat"
15
- N_SAMPLES = 9000
16
  MAX_LEN = 512
17
  VAL_FRACTION = 0.1
18
- PER_DEVICE_TRAIN_BATCH_SIZE=1
19
  NUM_TRAIN_EPOCHS=3
20
 
21
  # Tell wandb which project to use, and that you want to log your model
@@ -45,7 +45,7 @@ raw_stream = load_dataset(
45
  )
46
 
47
  # 2) (Optional) Shuffle the stream with a buffer
48
- shuffled = raw_stream.shuffle(buffer_size=5_000, seed=42)
49
 
50
  # 3) Take exactly N_SAMPLES examples
51
  limited = shuffled.take(N_SAMPLES)
@@ -131,10 +131,6 @@ training_args = TrainingArguments(
131
  bf16=True if device_str == 'xpu' else False,
132
  logging_steps=1000,
133
  save_steps=5000,
134
- eval_strategy="steps",
135
- eval_steps=1000,
136
- load_best_model_at_end=True,
137
- metric_for_best_model="perplexity",
138
  greater_is_better=False,
139
 
140
  # W&B integration
@@ -161,7 +157,6 @@ trainer = Trainer(
161
  model=model,
162
  args=training_args,
163
  train_dataset=tokenized_train,
164
- eval_dataset=tokenized_val,
165
  compute_metrics=compute_metrics,
166
  data_collator=data_collator,
167
  callbacks=[WandbCallback],
 
12
  PROJECT_NAME = 'SmolLM2-135M-Instruct-TaiwanChat'
13
  BASE_MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"
14
  DATASET_ID = "yentinglin/TaiwanChat"
15
+ N_SAMPLES = 3000
16
  MAX_LEN = 512
17
  VAL_FRACTION = 0.1
18
+ PER_DEVICE_TRAIN_BATCH_SIZE=8
19
  NUM_TRAIN_EPOCHS=3
20
 
21
  # Tell wandb which project to use, and that you want to log your model
 
45
  )
46
 
47
  # 2) (Optional) Shuffle the stream with a buffer
48
+ shuffled = raw_stream.shuffle(buffer_size=100, seed=42)
49
 
50
  # 3) Take exactly N_SAMPLES examples
51
  limited = shuffled.take(N_SAMPLES)
 
131
  bf16=True if device_str == 'xpu' else False,
132
  logging_steps=1000,
133
  save_steps=5000,
 
 
 
 
134
  greater_is_better=False,
135
 
136
  # W&B integration
 
157
  model=model,
158
  args=training_args,
159
  train_dataset=tokenized_train,
 
160
  compute_metrics=compute_metrics,
161
  data_collator=data_collator,
162
  callbacks=[WandbCallback],