---
library_name: transformers
license: apache-2.0
base_model: Shekswess/trlm-stage-2-sft-final-2
tags:
- trl
- dpo
- preference-alignment
- reasoning
- generated_from_trainer
model-index:
- name: trlm-stage-3-dpo-final-2
  results: []
---

# Tiny Reasoning Language Model (trlm-135)

## Table of Contents

1. [Model Summary](#model-summary)
2. [Post-Training Pipeline](#post-training-pipeline)
3. [How to use](#how-to-use)
4. [Training](#training)
5. [Evaluation](#evaluation)
6. [Limitations](#limitations)
7. [Acknowledgements](#acknowledgements)
8. [License](#license)

---

## Model Summary

The **Tiny Reasoning Language Model (trlm-135)** is a **135M parameter** research prototype designed to study how small models can learn step-by-step reasoning.
It was built on top of [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) and fine-tuned through a **3-stage pipeline**:

* **[Stage 1 SFT](https://huggingface.co/Shekswess/trlm-stage-1-sft-final-2)**: general instruction tuning (non-reasoning).
* **[Stage 2 SFT](https://huggingface.co/Shekswess/trlm-stage-2-sft-final-2)**: reasoning traces with `<think>` tags (see the example after this list).
* **[Stage 3 DPO](https://huggingface.co/Shekswess/trlm-stage-3-dpo-final-2)**: preference alignment for reasoning style.
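
For illustration, a Stage 2-style response looks roughly like the sketch below; the trace content here is a hypothetical example, and only the `<think>`-tag format itself comes from the training data:

```text
<think>
The user asks about gravity. Start from everyday experience (things fall),
then name the force and what it acts between, keeping the answer short.
</think>
Gravity is the force that pulls objects toward one another, which is why
things fall to the ground when you let them go.
```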

The **code** for everything can be found **[here](https://github.com/Shekswess/tiny-reasoning-language-model/blob/main/README.md)**.

---

## Post-Training Pipeline

<img width="1014" height="563" alt="Diagram of the three-stage post-training pipeline" src="https://github.com/user-attachments/assets/195ef389-6aa9-4527-b4f0-bea68c0841ae" />

---

## How to use

```bash
pip install -U transformers accelerate
```

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Shekswess/trlm-135m"
device = "cuda"  # or "cpu"

# Load tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
).to(device)

# Example prompt
prompt = "Give me a brief explanation of gravity in simple terms."
messages = [
    {"role": "user", "content": prompt},
]

# Apply chat template
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

> [!TIP]
> For reasoning-heavy tasks, set `temperature=0.6` and `top_p=0.95`.
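
Those settings only take effect with sampling enabled; building on the snippet above, a minimal sketch:

```python
# Sampled generation for reasoning-heavy tasks
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=True,   # required for temperature/top_p to apply
    temperature=0.6,
    top_p=0.95,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```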

---

## Training

### Model

* **Architecture**: Decoder-only transformer (SmolLM2 backbone, which is in fact a Llama 3-based architecture).
* **Parameters**: ~135M.
* **Precision**: mixed precision (bfloat16) during training.

### Software & Hardware

* **Training Frameworks**: PyTorch (ROCm), Hugging Face Transformers & TRL.
* **Hardware**: AMD MI300X (192GB VRAM, 224GB RAM).

**Special thanks to [@HotAisle](https://x.com/HotAisle)**

### Training Stages

1. **Stage 1 – SFT (non-reasoning)**
   * ~58k samples, everyday conversations & instruction following.
2. **Stage 2 – SFT (reasoning)**
   * ~78k samples with `<think>` segments.
3. **Stage 3 – DPO (alignment)**
   * ~50k preference pairs (chosen vs. rejected reasoning traces); see the sketch after this list.
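
For readers who want to see the shape of Stage 3, here is a minimal TRL sketch; the dataset name and hyperparameters are illustrative placeholders, not the exact training configuration:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# Start from the Stage 2 SFT checkpoint
model_name = "Shekswess/trlm-stage-2-sft-final-2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# A preference dataset with "prompt", "chosen", and "rejected" columns
# ("your-org/preference-pairs" is a placeholder name)
dataset = load_dataset("your-org/preference-pairs", split="train")

training_args = DPOConfig(
    output_dir="trlm-stage-3-dpo",
    beta=0.1,                       # illustrative DPO beta
    per_device_train_batch_size=4,
    bf16=True,
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```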

---

## Evaluation

Evaluation was done with `lm-eval-harness`:

| **Benchmark**     | **Tiny Reasoning Language Model (trlm-135M)** | **SmolLM2-135M-Instruct** | **Improvement** |
| ----------------- | --------------------------------------------- | ------------------------- | --------------- |
| **ARC Challenge** | **40.61** (avg)                               | 37.3 (avg)                | **+3.31**       |
| **BBH**           | **36.80** (3-shot)                            | 28.2 (3-shot)             | **+8.60**       |
| **BoolQ**         | **62.17**                                     | –                         | N/A             |
| **GSM8K**         | **2.59** (5-shot)                             | 1.4 (5-shot)              | **+1.19**       |
| **IFEval**        | **35.49** (avg)                               | 29.9 (avg)                | **+5.59**       |
| **MMLU**          | **34.95**                                     | 29.3                      | **+5.65**       |
| **PIQA**          | **64.91**                                     | 66.3                      | **−1.39**       |
| **HellaSwag**     | –                                             | 40.9                      | N/A             |
| **MT-Bench**      | –                                             | 19.8                      | N/A             |
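
A representative `lm-eval-harness` invocation is shown below; the task and few-shot setting match the GSM8K row above, but the exact evaluation configuration may have differed:

```bash
lm_eval \
  --model hf \
  --model_args pretrained=Shekswess/trlm-135m,dtype=bfloat16 \
  --tasks gsm8k \
  --num_fewshot 5 \
  --batch_size 8
```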

---

## Limitations

* **Not production-ready**: hallucinations and logical errors are frequent.
* **Small size**: limited general knowledge and reasoning depth.
* **English-only**: multilingual capabilities were not explored.

---

## Acknowledgements

- [@HotAisle](https://x.com/HotAisle) for providing the compute resources to train all three stages on an awesome AMD MI300X setup.
- [@mkurman88](https://x.com/mkurman88) for ideas, feedback, and code samples.
- [HuggingFaceTB team](https://huggingface.co/HuggingFaceTB) for the SmolLM2-135M-Instruct model and the Smoltalk2 dataset collection.
- [@scottgeng00](https://huggingface.co/scottgeng00) for the OLMo-3-Preference-Mix-Deltas dataset.
- [@eliebakouch](https://x.com/eliebakouch) for help with the tokenization.

---

## License

[Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)

---