# train_chatbot.py
"""Fine-tune GPT-2 on the DailyDialog dataset to serve as a small personal-assistant chatbot."""
import os

import torch
from datasets import load_dataset
from transformers import (
    GPT2TokenizerFast,
    GPT2LMHeadModel,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# === EDITABLE SETTINGS ===
HF_USERNAME = "hmnshudhmn24"
REPO_ID = f"{HF_USERNAME}/gpt2-personal-assistant"  # Hub repo; only used if you push the model
BASE_MODEL = "gpt2"
OUTPUT_DIR = "./results"
MAX_TRAIN_SAMPLES = 4000  # cap the train split for quick fine-tuning runs
MAX_VAL_SAMPLES = 500     # cap the validation split
EPOCHS = 1
BATCH_SIZE = 4
LEARNING_RATE = 5e-5
# =========================
def prepare_dataset():
    # DailyDialog is a script-based dataset; newer versions of `datasets`
    # may require `trust_remote_code=True` to load it.
    ds = load_dataset("daily_dialog")

    def to_text(ex):
        # Join each multi-turn dialog into one newline-separated training string.
        dialog = ex["dialog"]
        text = "\n".join(dialog)
        return {"text": text}

    ds = ds.map(to_text, remove_columns=ds["train"].column_names)
    # Keep only a small slice of each split so fine-tuning stays fast.
    ds["train"] = ds["train"].select(range(min(MAX_TRAIN_SAMPLES, len(ds["train"]))))
    ds["validation"] = ds["validation"].select(range(min(MAX_VAL_SAMPLES, len(ds["validation"]))))
    return ds
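# Quick sanity check (a sketch of how you might inspect the converted data;
# not part of the training flow):
#
#     ds = prepare_dataset()
#     print(ds)               # DatasetDict with capped "train" / "validation" splits
#     print(ds["train"][0])   # -> {"text": "<turn 1>\n<turn 2>\n..."}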
def main():
    tokenizer = GPT2TokenizerFast.from_pretrained(BASE_MODEL)
    # GPT-2 ships without a padding token; add one so the collator can pad batches.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    model = GPT2LMHeadModel.from_pretrained(BASE_MODEL)
    # Resize the embedding matrix to account for the newly added pad token.
    model.resize_token_embeddings(len(tokenizer))

    ds = prepare_dataset()

    def tokenize_batch(examples):
        return tokenizer(examples["text"], truncation=True, max_length=512)

    tokenized = ds.map(tokenize_batch, batched=True, remove_columns=["text"])
    # mlm=False gives causal-LM labels (inputs shifted right) rather than masked-LM ones.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        overwrite_output_dir=True,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        evaluation_strategy="epoch",  # renamed to `eval_strategy` in recent transformers releases
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        weight_decay=0.01,
        fp16=torch.cuda.is_available(),  # use mixed precision only when a GPU is present
        push_to_hub=False,
        logging_steps=100,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    trainer.train()

    # Save the fine-tuned model and tokenizer side by side for later inference.
    save_path = "./gpt2-personal-assistant"
    os.makedirs(save_path, exist_ok=True)
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)
    print(f"Model and tokenizer saved to {save_path}")
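
    # Optional sketch (not in the original flow): push the fine-tuned weights to the
    # Hub under REPO_ID. Assumes you are logged in via `huggingface-cli login`.
    #
    #     model.push_to_hub(REPO_ID)
    #     tokenizer.push_to_hub(REPO_ID)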
if __name__ == "__main__":
    main()
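
# --- Inference sketch ---
# A minimal, hedged example of chatting with the saved model; the prompt text and
# sampling settings below are assumptions, not part of the original script:
#
#     from transformers import GPT2LMHeadModel, GPT2TokenizerFast
#
#     tok = GPT2TokenizerFast.from_pretrained("./gpt2-personal-assistant")
#     lm = GPT2LMHeadModel.from_pretrained("./gpt2-personal-assistant")
#     prompt = "Hi! What are you up to today?\n"
#     ids = tok(prompt, return_tensors="pt").input_ids
#     out = lm.generate(ids, max_new_tokens=60, do_sample=True, top_p=0.9,
#                       pad_token_id=tok.pad_token_id)
#     print(tok.decode(out[0], skip_special_tokens=True))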