#!/usr/bin/env python3
"""
QLoRA fine-tuning script for the OpenAI OSS 120B model.
Uses the smangrul/ad-copy-generation dataset for advertisement copy generation.
"""

import os
import warnings

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from trl import SFTTrainer

# Suppress warnings
warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)


# Configuration
class Config:
    # Model configuration
    model_name = "microsoft/DialoGPT-medium"  # Replace with actual OpenAI OSS 120B model name
    dataset_name = "smangrul/ad-copy-generation"

    # Training parameters
    output_dir = "./sft_results"
    num_train_epochs = 3
    per_device_train_batch_size = 1
    gradient_accumulation_steps = 4
    optim = "paged_adamw_32bit"
    save_steps = 25
    logging_steps = 25
    learning_rate = 2e-4
    weight_decay = 0.001
    fp16 = False
    bf16 = False
    max_grad_norm = 0.3
    max_steps = -1
    warmup_ratio = 0.03
    group_by_length = True
    lr_scheduler_type = "constant"
    report_to = "tensorboard"

    # QLoRA parameters
    lora_alpha = 16
    lora_dropout = 0.1
    lora_r = 64

    # bitsandbytes parameters
    use_4bit = True
    bnb_4bit_compute_dtype = "float16"
    bnb_4bit_quant_type = "nf4"
    use_nested_quant = False

    # SFT parameters
    max_seq_length = 512
    packing = False
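
# Note (informational): with these defaults the effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 4 = 4 samples per
# optimizer step, and the LoRA update is scaled by lora_alpha / lora_r = 16 / 64 = 0.25.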


def create_bnb_config():
    """Create BitsAndBytesConfig for 4-bit quantization"""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=Config.use_4bit,
        bnb_4bit_quant_type=Config.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=getattr(torch, Config.bnb_4bit_compute_dtype),
        bnb_4bit_use_double_quant=Config.use_nested_quant,
    )
    return bnb_config
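
# NF4 ("NormalFloat4") stores the frozen base weights in 4 bits and dequantizes them to
# bnb_4bit_compute_dtype during matmuls; enabling use_nested_quant additionally quantizes
# the quantization constants themselves for a small extra memory saving.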


def load_model_and_tokenizer():
    """Load model and tokenizer with quantization"""
    print("Loading model and tokenizer...")

    # Create BnB config
    bnb_config = create_bnb_config()

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        Config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        use_auth_token=True,  # If using a gated model
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        Config.model_name,
        trust_remote_code=True,
        use_auth_token=True,  # If using a gated model
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer
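
# Many QLoRA recipes also call peft's prepare_model_for_kbit_training() on the quantized
# model before attaching adapters (casts layer norms to fp32, enables input grads, etc.).
# Left out above to stay close to the original flow; a minimal sketch:
#
#   from peft import prepare_model_for_kbit_training
#   model = prepare_model_for_kbit_training(model)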


def create_peft_config():
    """Create PEFT (LoRA) configuration"""
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=Config.lora_r,
        lora_alpha=Config.lora_alpha,
        lora_dropout=Config.lora_dropout,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )
    return peft_config
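
# Note: the target_modules above match Llama/Mistral-style attention and MLP projection
# names. GPT-2-based models such as the DialoGPT placeholder use fused modules
# (e.g. "c_attn", "c_proj"), so adjust this list to whichever base model you actually load.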


def load_and_prepare_dataset(tokenizer):
    """Load and prepare the dataset"""
    print("Loading dataset...")

    # Load dataset
    dataset = load_dataset(Config.dataset_name, split="train")
    print(f"Dataset loaded: {len(dataset)} samples")

    # Format dataset for chat completion
    def format_prompts(examples):
        texts = []
        for conversation in examples["conversations"]:
            if len(conversation) >= 2:
                user_msg = conversation[0]["value"]
                assistant_msg = conversation[1]["value"]
                # Format as a simple chat template
                text = f"### Human: {user_msg}\n### Assistant: {assistant_msg}{tokenizer.eos_token}"
                texts.append(text)
            elif len(conversation) == 1:
                # Fallback for malformed data with only one turn
                texts.append(
                    f"### Human: Create an advertisement\n### Assistant: {conversation[0]['value']}{tokenizer.eos_token}"
                )
            else:
                # Skip empty conversations entirely
                continue
        return {"text": texts}

    # Apply formatting
    dataset = dataset.map(
        format_prompts,
        batched=True,
        remove_columns=dataset.column_names,
    )

    return dataset
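
# Assumption: the dataset is treated as ShareGPT-style, i.e. each row has a "conversations"
# list of {"from": ..., "value": ...} turns with the user turn first. Verify the actual
# column layout of smangrul/ad-copy-generation and adapt format_prompts if it differs.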


def create_training_arguments():
    """Create training arguments"""
    training_arguments = TrainingArguments(
        output_dir=Config.output_dir,
        num_train_epochs=Config.num_train_epochs,
        per_device_train_batch_size=Config.per_device_train_batch_size,
        gradient_accumulation_steps=Config.gradient_accumulation_steps,
        optim=Config.optim,
        save_steps=Config.save_steps,
        logging_steps=Config.logging_steps,
        learning_rate=Config.learning_rate,
        weight_decay=Config.weight_decay,
        fp16=Config.fp16,
        bf16=Config.bf16,
        max_grad_norm=Config.max_grad_norm,
        max_steps=Config.max_steps,
        warmup_ratio=Config.warmup_ratio,
        group_by_length=Config.group_by_length,
        lr_scheduler_type=Config.lr_scheduler_type,
        report_to=Config.report_to,
        save_strategy="steps",
        evaluation_strategy="no",
        load_best_model_at_end=False,
        push_to_hub=False,
        remove_unused_columns=False,
    )
    return training_arguments
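
# group_by_length batches samples of similar length together to reduce padding waste, and
# "paged_adamw_32bit" selects the bitsandbytes paged AdamW optimizer, which pages optimizer
# state to CPU memory to avoid GPU OOM spikes during training.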


def main():
    """Main fine-tuning function"""
    print("Starting QLoRA fine-tuning of the OpenAI OSS 120B model")

    # Check CUDA availability
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this training script")
    print(f"Using GPU: {torch.cuda.get_device_name()}")
    print(f"Available VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer()

    # Apply PEFT
    peft_config = create_peft_config()
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Load and prepare dataset
    dataset = load_and_prepare_dataset(tokenizer)

    # Create training arguments
    training_arguments = create_training_arguments()

    # Create trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=Config.max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=Config.packing,
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Save model
    print("Saving model...")
    trainer.model.save_pretrained(Config.output_dir)
    tokenizer.save_pretrained(Config.output_dir)
    print("Training completed!")

    # Test the model
    test_model(trainer.model, tokenizer)


def test_model(model, tokenizer):
    """Test the fine-tuned model"""
    print("\nTesting the fine-tuned model...")

    # Test prompts
    test_prompts = [
        "Create an advertisement for a new smartphone with advanced camera features",
        "Write ad copy for an eco-friendly clothing brand targeting young professionals",
        "Generate marketing content for a fitness app with AI personal trainer",
    ]

    for prompt in test_prompts:
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()

        print(f"\nPrompt: {prompt}")
        print(f"Generated: {generated_text}")
        print("-" * 50)
| if __name__ == "__main__": | |
| # Set environment variables | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| main() |