Llama 3.2 3B - GSM8K PPO Fine-tuned (Critic)
This model is the critic (value model) checkpoint from fine-tuning Llama 3.2 3B on GSM8K with PPO (Proximal Policy Optimization) in the veRL framework.
Model Description
- Base Model: meta-llama/Llama-3.2-3B
- Training Framework: veRL (Volcano Engine Reinforcement Learning for LLMs)
- Training Method: PPO (Proximal Policy Optimization)
- Dataset: GSM8K (Grade School Math 8K)
- Task: Mathematical reasoning and problem-solving
- Checkpoint Step: 467
- Evaluation Score: 0.467
- Model Type: critic
Training Details
This checkpoint was trained using PPO on the GSM8K dataset to improve mathematical reasoning capabilities. The model was optimized using reward-based learning to generate more accurate step-by-step solutions to math word problems.
The checkpoint was selected automatically via best-of-n evaluation across multiple training steps, i.e., it is the best-scoring checkpoint among those evaluated.
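In GSM8K PPO setups the reward is typically rule-based: the final number in a sampled solution is compared against the reference answer, which in GSM8K follows a "#### <answer>" marker. A minimal sketch of such a reward function (the name and matching rules here are illustrative, not veRL's internal API):

import re

def gsm8k_reward(solution_text: str, ground_truth_answer: str) -> float:
    """Return 1.0 if the last number in the model's solution matches the
    GSM8K reference answer, else 0.0 (illustrative rule-based reward)."""
    # GSM8K references end with "#### <answer>"; keep only that final answer.
    reference = ground_truth_answer.split("####")[-1].strip().replace(",", "")
    # Treat the last number the model produced as its final answer.
    numbers = re.findall(r"-?\d+\.?\d*", solution_text.replace(",", ""))
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1] == reference else 0.0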
Usage
Basic Inference
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the checkpoint in half precision, spreading it across available devices
model = AutoModelForCausalLM.from_pretrained(
    "samhitha2601/llama3.2-3b-ppo-critic",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("samhitha2601/llama3.2-3b-ppo-critic")

# Example GSM8K problem
prompt = """Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Answer: Let's solve this step by step:"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Sample a step-by-step solution
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
    do_sample=True,
    top_p=0.9,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Chat Format (if using Instruct variant)
messages = [
    {"role": "user", "content": "Solve this math problem: If a train travels 60 miles per hour for 2.5 hours, how far does it travel?"}
]

# The base Llama-3.2-3B tokenizer may not define a chat template; use this path
# only if your checkpoint was built on the Instruct variant.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Performance
This model was trained with PPO to maximize reward on GSM8K problems and shows improvements in the areas below; a quick evaluation sketch follows the list.
- Step-by-step reasoning
- Arithmetic accuracy
- Problem decomposition
- Solution clarity
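To sanity-check a number like the evaluation score above, a simple loop over the GSM8K test split with greedy decoding and an exact-match check on the final number is usually enough. The sketch below assumes the model and tokenizer loaded in the Usage section and the public openai/gsm8k dataset; prompting and parsing choices will shift the exact score.

import re
from datasets import load_dataset

dataset = load_dataset("openai/gsm8k", "main", split="test")

def final_number(text):
    # Take the last number in the text as the model's final answer.
    nums = re.findall(r"-?\d+\.?\d*", text.replace(",", ""))
    return nums[-1] if nums else None

n = 100  # small subset for a quick check
correct = 0
for example in dataset.select(range(n)):
    prompt = f"Question: {example['question']}\nAnswer: Let's solve this step by step:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    # Decode only the newly generated tokens.
    completion = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )
    reference = example["answer"].split("####")[-1].strip().replace(",", "")
    correct += int(final_number(completion) == reference)

print(f"Exact-match accuracy on {n} test problems: {correct / n:.3f}")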
Training Infrastructure
- Framework: veRL (Volcano Engine Reinforcement Learning for LLMs)
- Algorithm: PPO
- Distributed Training: FSDP (Fully Sharded Data Parallel)
- World Size: 2 (2 GPUs/ranks; see the FSDP sketch below)
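For context, "World Size: 2" means training ran on two ranks with FSDP, which shards parameters, gradients, and optimizer state between them. A minimal plain-PyTorch illustration of that wrapping (veRL handles this internally; this is not the actual training script):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

# Launch with: torchrun --nproc_per_node=2 fsdp_wrap_example.py
dist.init_process_group("nccl")
rank = dist.get_rank()  # equals the local GPU index on a single node
torch.cuda.set_device(rank)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
# FSDP shards parameters, gradients, and optimizer state across the 2 ranks.
model = FSDP(model.to(rank))

dist.barrier()
dist.destroy_process_group()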
Limitations
- Primarily optimized for GSM8K-style math problems
- May not generalize well to other domains without fine-tuning
- Mathematical reasoning is limited to the complexity seen in GSM8K
- Still susceptible to arithmetic errors on complex calculations
Citation
If you use this model, please cite:
@misc{llama32-gsm8k-ppo,
  title={Llama 3.2 3B Fine-tuned on GSM8K with PPO},
  author={Your Name},
  year={2025},
  howpublished={\url{https://huggingface.co/samhitha2601/llama3.2-3b-ppo-critic}},
}
Acknowledgments
- Base Model: Meta AI (Llama 3.2)
- Dataset: GSM8K by OpenAI
- Training Framework: veRL
- Training Method: PPO (Proximal Policy Optimization)