Luigi committed on
Commit 8249056 · 1 Parent(s): bc9cad6

Add padding and truncation on examples to fix max length

Files changed (2)
  1. train.py +28 -22
  2. train_with_unsloth.py +27 -22
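
Background for the fix (an illustration, not part of the commit itself): before this change, each example was tokenized with truncation only, so sequences came out at different lengths and could not be stacked into one batch tensor. A toy sketch of that failure mode and of why padding every example to a fixed max length resolves it:

import torch

# Hypothetical toy batch; before the fix, examples had ragged lengths.
ragged = [[1, 2, 3], [4, 5]]
try:
    torch.tensor(ragged)                  # stacking ragged lists fails
except ValueError as e:
    print("collation error:", e)

# After padding every example to the same max length, stacking works.
padded = [ids + [0] * (3 - len(ids)) for ids in ragged]
print(torch.tensor(padded).shape)         # torch.Size([2, 3])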
train.py CHANGED
@@ -12,7 +12,7 @@ from transformers.integrations import WandbCallback
 PROJECT_NAME='SmolLM2-135M-Instruct-TaiwanChat'
 BASE_MODEL_ID="HuggingFaceTB/SmolLM2-135M-Instruct"
 DATASET_ID="yentinglin/TaiwanChat"
-N_SAMPLES=40000
+N_SAMPLES=100
 MAX_LEN=512
 
 # Tell wandb which project to use, and that you want to log your model
@@ -38,38 +38,44 @@ dataset = load_dataset(DATASET_ID, split=f"train[:{N_SAMPLES}]")
 
 def preprocess_examples(examples):
     chats = examples["messages"]
-    # 1) Render as ChatML with the “assistant:” generation prompt
+    # 1) Render ChatML
     text = tokenizer.apply_chat_template(
-        chats,
-        tokenize=False,
-        add_generation_prompt=True
+        chats, tokenize=False, add_generation_prompt=True
     )
-    # 2) Tokenize
-    toks = tokenizer(text, truncation=True, max_length=MAX_LEN)
-    input_ids = toks["input_ids"]
-    attention_mask = toks["attention_mask"]
-
-    # 3) Build labels that mask all tokens _before_ the assistant turn
-    #    so we only compute loss on the assistant’s response
-    #    Find the index where the assistant prompt token <|im_start|>assistant occurs:
-    role_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>assistant")
-    # find first occurrence
-    try:
-        idx = input_ids.index(role_token_id)
-    except ValueError:
-        idx = 0
-    # +2 to skip the role token and the following newline
-    start_of_reply = idx + 2
 
+    # 2) Tokenize _and_ pad/truncate to MAX_LEN
+    toks = tokenizer(
+        text,
+        truncation=True,
+        padding="max_length",
+        max_length=MAX_LEN,
+    )
+    input_ids = toks["input_ids"]
+    attention_mask= toks["attention_mask"]
+
+    # 3) Find where the assistant reply starts
+    role_id = tokenizer.convert_tokens_to_ids("<|im_start|>assistant")
+    if role_id in input_ids:
+        idx = input_ids.index(role_id)
+        start_of_reply = idx + 2
+    else:
+        start_of_reply = 0
+
+    # 4) Build labels: -100 before reply, then copy the rest
     labels = [-100] * start_of_reply + input_ids[start_of_reply:]
 
+    # 5) Pad or truncate labels to EXACTLY len(input_ids)
+    if len(labels) < len(input_ids):
+        labels += [-100] * (len(input_ids) - len(labels))
+    else:
+        labels = labels[: len(input_ids)]
+
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
         "labels": labels,
     }
 
-
 # Tokenization & Data Collator
 tokenized_ds = dataset.map(
     preprocess_examples,
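
In the patched preprocess_examples, steps 4 and 5 exist to keep labels exactly as long as input_ids once padding="max_length" is applied. A minimal, dependency-free sketch of that alignment logic, using toy token ids and a stand-in role id rather than the real SmolLM2 vocabulary:

MAX_LEN = 8
PAD_ID = 0
ROLE_ID = 42  # stand-in for the "<|im_start|>assistant" token id

def align_labels(input_ids):
    # Step 3: find where the assistant reply starts (+2 skips the role token
    # and the following newline), falling back to 0 if the token is absent.
    start_of_reply = input_ids.index(ROLE_ID) + 2 if ROLE_ID in input_ids else 0
    # Step 4: -100 before the reply, copy the rest (pad ids included, as committed).
    labels = [-100] * start_of_reply + input_ids[start_of_reply:]
    # Step 5: pad or truncate labels to exactly len(input_ids).
    if len(labels) < len(input_ids):
        labels += [-100] * (len(input_ids) - len(labels))
    else:
        labels = labels[: len(input_ids)]
    return labels

# A sequence already padded to MAX_LEN, as padding="max_length" would produce.
input_ids = [7, ROLE_ID, 9, 11, 13, PAD_ID, PAD_ID, PAD_ID]
labels = align_labels(input_ids)
assert len(labels) == len(input_ids) == MAX_LEN
print(labels)  # [-100, -100, -100, 11, 13, 0, 0, 0]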
train_with_unsloth.py CHANGED
@@ -71,38 +71,43 @@ val_ds = splits["test"]
 # Preprocessing Function
 def preprocess_examples(examples):
     chats = examples["messages"]
-    # 1) Render as ChatML with the “assistant:” generation prompt
+    # 1) Render ChatML
     text = tokenizer.apply_chat_template(
-        chats,
-        tokenize=False,
-        add_generation_prompt=True
+        chats, tokenize=False, add_generation_prompt=True
     )
-    # 2) Tokenize
-    toks = tokenizer(text, truncation=True, max_length=MAX_LEN)
-    input_ids = toks["input_ids"]
-    attention_mask = toks["attention_mask"]
-
-    # 3) Build labels that mask all tokens _before_ the assistant turn
-    #    so we only compute loss on the assistant’s response
-    #    Find the index where the assistant prompt token <|im_start|>assistant occurs:
-    role_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>assistant")
-    # find first occurrence
-    try:
-        idx = input_ids.index(role_token_id)
-    except ValueError:
-        idx = 0
-    # +2 to skip the role token and the following newline
-    start_of_reply = idx + 2
 
+    # 2) Tokenize _and_ pad/truncate to MAX_LEN
+    toks = tokenizer(
+        text,
+        truncation=True,
+        padding="max_length",
+        max_length=MAX_LEN,
+    )
+    input_ids = toks["input_ids"]
+    attention_mask= toks["attention_mask"]
+
+    # 3) Find where the assistant reply starts
+    role_id = tokenizer.convert_tokens_to_ids("<|im_start|>assistant")
+    if role_id in input_ids:
+        idx = input_ids.index(role_id)
+        start_of_reply = idx + 2
+    else:
+        start_of_reply = 0
+
+    # 4) Build labels: -100 before reply, then copy the rest
     labels = [-100] * start_of_reply + input_ids[start_of_reply:]
 
+    # 5) Pad or truncate labels to EXACTLY len(input_ids)
+    if len(labels) < len(input_ids):
+        labels += [-100] * (len(input_ids) - len(labels))
+    else:
+        labels = labels[: len(input_ids)]
+
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
         "labels": labels,
     }
-
-
 # Tokenization & Data Collator
 tokenized_train = train_ds.map(
     preprocess_examples, batched=True, remove_columns=train_ds.column_names
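
train_with_unsloth.py receives the same preprocessing change. A hypothetical spot check (not part of either file) that runs the patched tokenization path on a single chat and confirms the fixed-length output; it assumes the SmolLM2 tokenizer can be downloaded and falls back to the eos token if no pad token is configured:

from transformers import AutoTokenizer

MAX_LEN = 512
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
if tokenizer.pad_token is None:  # defensive: padding="max_length" needs a pad token
    tokenizer.pad_token = tokenizer.eos_token

chat = [{"role": "user", "content": "Hello! Please introduce yourself."}]
text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
toks = tokenizer(text, truncation=True, padding="max_length", max_length=MAX_LEN)

# Every example now comes back at exactly MAX_LEN, so batching cannot fail on length.
assert len(toks["input_ids"]) == len(toks["attention_mask"]) == MAX_LEN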