Llama 3.2 3B - GSM8K PPO Fine-tuned (Critic)

This repository contains the critic (value model) checkpoint from a PPO (Proximal Policy Optimization) run that fine-tuned Llama 3.2 3B on GSM8K with the veRL framework.

Model Description

  • Base Model: meta-llama/Llama-3.2-3B
  • Training Framework: veRL (Volcano Engine Reinforcement Learning for LLMs)
  • Training Method: PPO (Proximal Policy Optimization)
  • Dataset: GSM8K (Grade School Math 8K)
  • Task: Mathematical reasoning and problem-solving
  • Checkpoint Step: 467
  • Evaluation Score: 0.467
  • Model Type: critic (value model)

Training Details

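This checkpoint comes from a PPO run on GSM8K aimed at improving mathematical reasoning. In PPO, the actor (policy) is optimized with reward-based learning to produce more accurate step-by-step solutions to math word problems, while the critic, saved here, learns to predict the expected return of partial solutions. Its value estimates tell the actor how much better or worse each step turned out than expected.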
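The sketch below illustrates, under standard PPO assumptions, how a critic's value estimates become advantages via Generalized Advantage Estimation (GAE), and how the critic itself is trained against clipped value targets. It is plain Python over per-token lists; the function names, the default gamma/lambda, and the clip range are illustrative assumptions, and veRL's actual implementation (batched tensors, response masks, advantage normalization) may differ.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    gamma: float = 1.0,  # illustrative default, not taken from this run
    lam: float = 0.95,   # illustrative default, not taken from this run
) -> Tuple[List[float], List[float]]:
    """Return (advantages, returns) for one trajectory using GAE.

    rewards[t] is the reward received after token t; values[t] is the
    critic's estimate V(s_t). The value after the final token is taken as 0.
    """
    T = len(rewards)
    advantages = [0.0] * T
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        # TD error: how much better this step went than the critic predicted
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of future TD errors
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    # Regression targets for the critic: advantage plus the old value estimate
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def clipped_value_loss(
    new_values: List[float],
    old_values: List[float],
    returns: List[float],
    clip_range: float = 0.5,  # hypothetical value-clipping range
) -> float:
    """PPO-style clipped value loss, averaged over tokens."""
    total = 0.0
    for v_new, v_old, ret in zip(new_values, old_values, returns):
        # Keep the new prediction within clip_range of the old one
        v_clipped = v_old + max(-clip_range, min(clip_range, v_new - v_old))
        loss_unclipped = (v_new - ret) ** 2
        loss_clipped = (v_clipped - ret) ** 2
        # Pessimistic (larger) of the two losses, as in PPO
        total += 0.5 * max(loss_unclipped, loss_clipped)
    return total / len(new_values)


if __name__ == "__main__":
    # Outcome-only reward: 1.0 at the last token of a correct solution
    adv, ret = compute_gae([0.0, 0.0, 1.0], [0.2, 0.4, 0.6], gamma=1.0, lam=1.0)
    print("advantages:", adv)
    print("returns:", ret)
```

With gamma = lambda = 1 and a single reward at the final token, every return equals that final reward, which is a natural setting when the only signal is end-of-solution correctness.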

The checkpoint was selected automatically as the best-scoring of the checkpoints evaluated at different training steps.
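As a hypothetical illustration of that selection rule, assuming each evaluated step has a single scalar score where higher is better, picking the checkpoint reduces to an argmax over steps. The function name and tie-breaking rule are assumptions, and the scores in the usage line are made up, not results from this run.

```python
from typing import Dict


def select_best_checkpoint(scores_by_step: Dict[int, float]) -> int:
    """Return the training step whose evaluation score is highest.

    Ties are broken in favour of the earlier step (an assumed convention).
    """
    if not scores_by_step:
        raise ValueError("no evaluated checkpoints")
    # Sort key: higher score first, then lower step number
    return min(scores_by_step, key=lambda step: (-scores_by_step[step], step))


if __name__ == "__main__":
    # Made-up scores for illustration only
    hypothetical_scores = {100: 0.31, 200: 0.42, 300: 0.42}
    print(select_best_checkpoint(hypothetical_scores))
```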

Usage

Basic Inference

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the checkpoint in half precision; device_map="auto" places it on available devices
model = AutoModelForCausalLM.from_pretrained(
    "samhitha2601/llama3.2-3b-ppo-critic",
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("samhitha2601/llama3.2-3b-ppo-critic")

# Example GSM8K problem
prompt = """Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

Answer: Let's solve this step by step:"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Sample a step-by-step solution (nucleus sampling with temperature 0.7)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
    do_sample=True,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id  # reuse EOS as the padding token
)
# Decode only the newly generated tokens, not the prompt
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
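To check a generated answer like the one above, one option is a rule-based correctness reward of the kind commonly used for GSM8K: take the number after "####" in the reference solution and compare it with the last number in the model's completion. This is a simplified sketch; the function names and the last-number heuristic are assumptions, and the reward actually used in this run (veRL ships its own GSM8K scorer) may apply different formatting rules.

```python
import re
from typing import Optional


def extract_ground_truth(gsm8k_answer: str) -> str:
    """GSM8K reference solutions end with '#### <number>'."""
    return gsm8k_answer.split("####")[-1].strip().replace(",", "")


def extract_prediction(completion: str) -> Optional[str]:
    """Heuristic: treat the last number in the completion as the answer."""
    numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", completion)
    if not numbers:
        return None
    return numbers[-1].replace(",", "")


def correctness_reward(completion: str, gsm8k_answer: str) -> float:
    """Binary reward: 1.0 if the final number matches the reference, else 0.0."""
    pred = extract_prediction(completion)
    if pred is None:
        return 0.0
    try:
        target = float(extract_ground_truth(gsm8k_answer))
        return 1.0 if float(pred) == target else 0.0
    except ValueError:
        # Unparseable number on either side counts as incorrect
        return 0.0


if __name__ == "__main__":
    # Illustrative reference in GSM8K's '####' format (not the dataset's exact text)
    reference = "16 - 3 - 4 = 9 eggs are sold; 9 * 2 = 18 dollars.\n#### 18"
    completion = "She sells 16 - 3 - 4 = 9 eggs, so she makes 9 * 2 = $18."
    print(correctness_reward(completion, reference))
```

For the Janet problem, 16 - 3 - 4 = 9 eggs are sold at $2 each, so the expected answer is $18 and a completion ending in 18 earns a reward of 1.0.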

Chat Format (if using Instruct variant)

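messages = [
    {"role": "user", "content": "Solve this math problem: If a train travels 60 miles per hour for 2.5 hours, how far does it travel?"}
]

# Format the conversation with the tokenizer's chat template and open the assistant turn
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
# Decode only the reply, skipping the prompt tokens
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))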

Performance

This model was trained with PPO to maximize reward on GSM8K problems. The training targets improvements in:

  • Step-by-step reasoning
  • Arithmetic accuracy
  • Problem decomposition
  • Solution clarity
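One way to quantify these improvements is to compare mean reward on held-out GSM8K problems (the test split has 1,319) before and after training. Assuming a binary correctness reward, mean reward equals accuracy, and its standard error indicates how much of a difference between checkpoints may be noise. The helper below is a hypothetical sketch, not part of veRL.

```python
import math
from typing import List, Tuple


def mean_reward_with_stderr(rewards: List[float]) -> Tuple[float, float]:
    """Return (mean, standard error of the mean) for per-problem rewards."""
    n = len(rewards)
    if n == 0:
        raise ValueError("need at least one reward")
    mean = sum(rewards) / n
    # Unbiased sample variance; zero when only one sample is available
    var = sum((r - mean) ** 2 for r in rewards) / (n - 1) if n > 1 else 0.0
    return mean, math.sqrt(var / n)


if __name__ == "__main__":
    # 1.0 = correct final answer, 0.0 = incorrect (illustrative values)
    print(mean_reward_with_stderr([1.0, 0.0, 1.0, 1.0]))
```

At an accuracy near 0.5 over all 1,319 test problems, the standard error is roughly sqrt(0.25 / 1319) ≈ 0.014, so differences of about one percentage point are hard to distinguish from noise.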

Training Infrastructure

  • Framework: veRL (Volcano Engine Reinforcement Learning for LLMs)
  • Algorithm: PPO
  • Distributed Training: FSDP (Fully Sharded Data Parallel)
  • World Size: 2 (2 GPUs/ranks)
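For a rough sense of what FSDP sharding buys on two GPUs, a common rule of thumb from the ZeRO paper counts about 16 bytes per parameter for mixed-precision Adam training: 16-bit weights and gradients plus fp32 master weights and two optimizer moments. Full sharding divides that across ranks. The sketch applies this estimate under loud assumptions: it ignores activations, the rollout engine, and any CPU offloading, and it covers one model only, whereas a PPO run also has to hold the actor.

```python
def fsdp_bytes_per_rank(
    num_params: float,
    world_size: int,
    bytes_per_param: int = 16,  # rule-of-thumb for mixed-precision Adam
) -> float:
    """Rough per-rank memory for fully sharded model states (no activations)."""
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    # FULL_SHARD splits parameters, gradients, and optimizer state evenly
    return num_params * bytes_per_param / world_size


if __name__ == "__main__":
    params = 3.2e9  # approximate size of Llama 3.2 3B
    for ranks in (1, 2):
        gb = fsdp_bytes_per_rank(params, ranks) / 1e9
        print(f"{ranks} rank(s): ~{gb:.1f} GB of model state per GPU")
```

For Llama 3.2 3B (roughly 3.2 billion parameters) on 2 ranks this gives about 25.6 GB of model state per GPU for the critic alone, versus about 51.2 GB unsharded.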

Limitations

  • Primarily optimized for GSM8K-style math problems
  • May not generalize well to other domains without fine-tuning
  • Reasoning ability is likely limited to problems of GSM8K-level complexity
  • Still susceptible to arithmetic errors on complex calculations
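GSM8K reference solutions mark each intermediate calculation with calculator annotations such as <<16-3-4=9>>. If a completion follows that convention, its arithmetic can be checked mechanically. The hypothetical helper below evaluates each annotated expression with a restricted parser (no eval) and reports those whose stated result is wrong; completions without annotations simply yield no checks. Names and tolerance are assumptions.

```python
import ast
import operator
import re
from typing import List, Tuple

# Only basic arithmetic operators are allowed
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _safe_eval(expr: str) -> float:
    """Evaluate +, -, *, / on numeric literals without calling eval()."""
    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -walk(node.operand)
        raise ValueError(f"unsupported expression: {expr!r}")
    return walk(ast.parse(expr, mode="eval"))


def find_arithmetic_errors(
    solution: str, tol: float = 1e-6
) -> List[Tuple[str, float, float]]:
    """Check each '<<expr=result>>' annotation; return (expr, claimed, actual) for wrong ones."""
    errors = []
    for expr, claimed in re.findall(r"<<([^=<>]+)=([^<>]+)>>", solution):
        try:
            actual = _safe_eval(expr)
            claimed_value = float(claimed)
        except (ValueError, SyntaxError, ZeroDivisionError):
            continue  # skip annotations we cannot parse or evaluate
        if abs(actual - claimed_value) > tol:
            errors.append((expr, claimed_value, actual))
    return errors


if __name__ == "__main__":
    sample = "She sells 16 - 3 - 4 = <<16-3-4=9>>9 eggs and makes 9 * 2 = $<<9*2=20>>20."
    print(find_arithmetic_errors(sample))
```

Here the first annotation is correct and the second is flagged, because 9 * 2 is 18, not 20.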

Citation

If you use this model, please cite:

@misc{llama32-gsm8k-ppo,
  title={Llama 3.2 3B Fine-tuned on GSM8K with PPO},
  author={Your Name},
  year={2025},
  howpublished={\url{https://huggingface.co/samhitha2601/llama3.2-3b-ppo-critic}},
}

Acknowledgments

  • Base Model: Meta AI (Llama 3.2)
  • Dataset: GSM8K by OpenAI
  • Training Framework: veRL
  • Training Method: PPO (Proximal Policy Optimization)