---
library_name: transformers
license: apache-2.0
base_model: Shekswess/trlm-stage-2-sft-final-2
tags:
- trl
- dpo
- preference-alignment
- reasoning
- generated_from_trainer
model-index:
- name: trlm-stage-3-dpo-final-2
  results: []
---
# Tiny Reasoning Language Model (trlm-135)
![image/png](https://github.com/user-attachments/assets/5f453496-8180-4cf4-94da-26ebbe1159d4)
## Table of Contents
1. [Model Summary](#model-summary)
2. [Post-Training Pipeline](#post-training-pipeline)
3. [How to use](#how-to-use)
4. [Training](#training)
5. [Evaluation](#evaluation)
6. [Limitations](#limitations)
7. [Acknowledgements](#acknowledgements)
8. [License](#license)
---
## Model Summary
The **Tiny Reasoning Language Model (trlm-135)** is a **135M-parameter** research prototype designed to study how small models can learn step-by-step reasoning.
It was built on top of [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) and fine-tuned through a **3-stage pipeline**:
* **[Stage 1 SFT](https://huggingface.co/Shekswess/trlm-stage-1-sft-final-2)**: general instruction tuning (non-reasoning).
* **[Stage 2 SFT](https://huggingface.co/Shekswess/trlm-stage-2-sft-final-2)**: reasoning traces with `<think>` tags.
* **[Stage 3 DPO](https://huggingface.co/Shekswess/trlm-stage-3-dpo-final-2)**: preference alignment for reasoning style.
All of the **code** can be found **[here](https://github.com/Shekswess/tiny-reasoning-language-model/blob/main/README.md)**.
---
## Post-Training Pipeline
<img width="1014" height="563" alt="image" src="https://github.com/user-attachments/assets/195ef389-6aa9-4527-b4f0-bea68c0841ae" />
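The pipeline can be read as three chained training runs, where each stage starts from the checkpoint produced by the one before it. The sketch below encodes that chain using the checkpoint names from the Model Summary. The `Stage` dataclass and `check_chain` helper are hypothetical, for illustration only. The assumption that Stage 2 starts from the Stage 1 checkpoint is inferred rather than stated; Stage 3's starting point matches the `base_model` field in this card's metadata.

```python
from dataclasses import dataclass

BASE = "HuggingFaceTB/SmolLM2-135M-Instruct"


@dataclass(frozen=True)
class Stage:
    """Hypothetical description of one post-training stage."""
    name: str
    method: str     # "sft" or "dpo"
    init_from: str  # checkpoint the stage starts from
    output: str     # checkpoint the stage produces


# Checkpoint names from the Model Summary; the chaining itself is an assumption.
PIPELINE = [
    Stage("stage-1", "sft", BASE, "Shekswess/trlm-stage-1-sft-final-2"),
    Stage("stage-2", "sft", "Shekswess/trlm-stage-1-sft-final-2",
          "Shekswess/trlm-stage-2-sft-final-2"),
    Stage("stage-3", "dpo", "Shekswess/trlm-stage-2-sft-final-2",
          "Shekswess/trlm-stage-3-dpo-final-2"),
]


def check_chain(stages: list[Stage]) -> bool:
    """True if every stage starts from the checkpoint the previous stage produced."""
    return all(prev.output == cur.init_from for prev, cur in zip(stages, stages[1:]))


for s in PIPELINE:
    print(f"{s.name}: {s.method.upper()} {s.init_from} -> {s.output}")
```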
---
## How to use
```bash
pip install -U transformers accelerate
```
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Shekswess/trlm-135m"
# Use the GPU when available (ROCm builds of PyTorch also report "cuda")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Example prompt
prompt = "Give me a brief explanation of gravity in simple terms."
messages = [
    {"role": "user", "content": prompt}
]

# Apply chat template
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
> [!TIP]
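> For reasoning-heavy tasks, pass `do_sample=True` together with `temperature=0.6` and `top_p=0.95` to `model.generate`; these sampling settings have no effect under greedy decoding.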
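Because Stage 2 trains on reasoning traces wrapped in `<think>` tags, it is often handy to separate the reasoning from the final answer in the decoded text. The helper below is a minimal sketch, not part of the model's tooling. It assumes the trace is closed with a matching `</think>` tag, and that the tags survive decoding; if the tokenizer registers them as special tokens (not confirmed here), decode with `skip_special_tokens=False` instead. The name `split_reasoning` is hypothetical.

```python
import re

# Assumes reasoning is wrapped as <think> ... </think>; the closing tag is an assumption.
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def split_reasoning(text: str) -> tuple[str, str]:
    """Split decoded output into (reasoning, answer).

    If no complete <think>...</think> block is found, the reasoning is empty
    and the whole (stripped) text is returned as the answer.
    """
    match = THINK_RE.search(text)
    if match is None:
        return "", text.strip()
    reasoning = match.group(1).strip()
    answer = text[match.end():].strip()
    return reasoning, answer


# Usage with a made-up string (not real model output):
sample = "<think>Objects with mass attract each other.</think>Gravity is the pull between masses."
reasoning, answer = split_reasoning(sample)
print("Reasoning:", reasoning)
print("Answer:", answer)
```

In practice `text` would be the string returned by `tokenizer.decode(...)` in the example above.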
---
## Training
### Model
* **Architecture**: Decoder-only transformer (SmolLM2 backbone, which itself uses the Llama architecture).
* **Parameters**: ~135M.
* **Precision**: Mixed precision (bfloat16) during training.
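The ~135M figure can be reproduced from the Llama-style layer shapes. This sketch assumes SmolLM2-135M's public config values (hidden size 576, 30 layers, 9 query heads, 3 key/value heads, MLP width 1536, vocabulary 49,152, tied input/output embeddings). These numbers are not stated in this card, so check them against the model's `config.json` before relying on the result.

```python
# Assumed config values (SmolLM2-135M public config, not stated in this card).
vocab_size = 49152
hidden = 576
intermediate = 1536
layers = 30
heads = 9
kv_heads = 3
head_dim = hidden // heads  # 64

# Token embeddings; tied with the LM head, so counted only once.
embeddings = vocab_size * hidden

attn = (
    hidden * hidden                     # q_proj
    + 2 * hidden * kv_heads * head_dim  # k_proj and v_proj (grouped-query attention)
    + hidden * hidden                   # o_proj
)
mlp = 3 * hidden * intermediate  # gate_proj, up_proj, down_proj (no biases)
norms = 2 * hidden               # two RMSNorm weight vectors per layer

per_layer = attn + mlp + norms
total = embeddings + layers * per_layer + hidden  # + final RMSNorm

print(f"{total:,} parameters (~{total / 1e6:.1f}M)")
```

Roughly a fifth of the parameters sit in the embedding matrix, which is typical for models this small.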
### Software & Hardware
* **Training Frameworks**: PyTorch (ROCm), Hugging Face Transformers & TRL.
* **Hardware**: AMD MI300X (192GB VRAM, 224GB RAM).
**Special thanks to [@HotAisle](https://x.com/HotAisle)**
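For a sense of scale, a back-of-the-envelope estimate shows how little of the MI300X's 192 GB the model state itself needs. The sketch assumes a parameter count of 134,515,008 (taken from SmolLM2-135M's public config) and the common rule of thumb of about 16 bytes per parameter for mixed-precision training with AdamW. The byte breakdown below is one common accounting, not a measurement of this training run. Activations are excluded; they scale with batch size and sequence length and usually dominate.

```python
# Rough memory estimate for mixed-precision (bf16) training with AdamW.
# Bytes per parameter are a common rule of thumb, not measured values.
params = 134_515_008  # assumed parameter count (SmolLM2-135M config)

BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "AdamW first moment (fp32)": 4,
    "AdamW second moment (fp32)": 4,
}

total_bytes = 0
for name, nbytes in BYTES_PER_PARAM.items():
    size = params * nbytes
    total_bytes += size
    print(f"{name:28s} {size / 2**30:6.2f} GiB")

print(f"{'total (excl. activations)':28s} {total_bytes / 2**30:6.2f} GiB")
```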
### Training Stages
1. **Stage 1 – SFT (non-reasoning)**
   * ~58k samples, everyday conversations & instruction following.
2. **Stage 2 – SFT (reasoning)**
   * ~78k samples with `<think>` segments.
3. **Stage 3 – DPO (alignment)**
   * ~50k preference pairs (chosen vs. rejected reasoning traces).
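Stage 3 optimizes the policy on chosen/rejected pairs with DPO. The function below is a minimal sketch of the standard sigmoid DPO loss for a single pair, the formulation from the original DPO paper and the default `loss_type` in TRL's `DPOTrainer`. It is not taken from this project's training code. Inputs are summed log-probabilities of each completion under the trained policy and the frozen reference model. `beta=0.1` is an assumption (TRL's default); the β used for this model is not stated here.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float,
             beta: float = 0.1) -> float:
    """Sigmoid DPO loss for one preference pair.

    Each argument is the summed log-probability of the chosen or rejected
    completion under the policy or the frozen reference model.
    """
    chosen_ratio = policy_chosen - ref_chosen        # log pi/pi_ref for the chosen answer
    rejected_ratio = policy_rejected - ref_rejected  # log pi/pi_ref for the rejected answer
    margin = beta * (chosen_ratio - rejected_ratio)
    return -log_sigmoid(margin)


# Policy identical to the reference: the margin is 0 and the loss equals log 2.
print(dpo_loss(-15.0, -15.0, -15.0, -15.0))
# Policy prefers the chosen trace more than the reference does: loss drops below log 2.
print(dpo_loss(-10.0, -20.0, -12.0, -18.0))
```

Minimizing this loss widens the gap between how much the policy (relative to the reference) favors chosen over rejected reasoning traces, while β controls how far the policy may drift from the Stage 2 checkpoint.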
---
## Evaluation
Evaluation was done with EleutherAI's `lm-evaluation-harness`:
| **Benchmark** | **Tiny Reasoning Language Model (trlm-135M)** | **SmolLM2-135M-Instruct** | **Difference** |
| -------------------- | ---------------------------- | ------------------------- | ---------------------------- |
| **ARC Challenge** | **40.61** (avg) | 37.3 (avg) | **+3.31** |
| **BBH** | **36.80** (3-shot) | 28.2 (3-shot) | **+8.6** |
| **BoolQ** | **62.17** | – | N/A |
| **GSM8K** | **2.59** (5-shot) | 1.4 (5-shot) | **+1.19** |
| **IFEval** | **35.49** (avg) | 29.9 (avg) | **+5.59** |
| **MMLU** | **34.95** | 29.3 | **+5.65** |
| **PIQA** | **64.91** | 66.3 | **–1.39** |
| **HellaSwag** | – | 40.9 | N/A |
| **MT-Bench** | – | 19.8 | N/A |
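The last column is simply trlm-135M's score minus SmolLM2-135M-Instruct's, in points, with N/A wherever either model lacks a score. The snippet below recomputes it from the table values as a quick consistency check; `SCORES` and `difference` are illustrative names, not part of any evaluation tooling.

```python
# Scores copied from the table above as (trlm-135M, SmolLM2-135M-Instruct); None = missing.
SCORES = {
    "ARC Challenge": (40.61, 37.3),
    "BBH": (36.80, 28.2),
    "BoolQ": (62.17, None),
    "GSM8K": (2.59, 1.4),
    "IFEval": (35.49, 29.9),
    "MMLU": (34.95, 29.3),
    "PIQA": (64.91, 66.3),
    "HellaSwag": (None, 40.9),
    "MT-Bench": (None, 19.8),
}


def difference(trlm, smol):
    """trlm-135M minus SmolLM2-135M-Instruct, rounded to 2 decimals; None if a score is missing."""
    if trlm is None or smol is None:
        return None
    return round(trlm - smol, 2)


for bench, (trlm, smol) in SCORES.items():
    d = difference(trlm, smol)
    label = "N/A" if d is None else f"{d:+.2f}"
    print(f"{bench:14s} {label}")
```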
---
## Limitations
* **Not production-ready**: hallucinations and logical errors are frequent.
* **Small size**: limited general knowledge and reasoning depth.
* **English-only**: multilingual capabilities not explored.
---
## Acknowledgements
- [@HotAisle](https://x.com/HotAisle) for providing the compute resources to train all three stages on an awesome AMD MI300X setup.
- [@mkurman88](https://x.com/mkurman88) for ideas, feedback and code samples.
- [HuggingFaceTB team](https://huggingface.co/HuggingFaceTB) for the SmolLM2-135M-Instruct model and the SmolTalk2 dataset collection.
- [@scottgeng00](https://huggingface.co/scottgeng00) for the OLmO-3-Preference-Mix-Deltas dataset.
- [@eliebakouch](https://x.com/eliebakouch) for help with the tokenization.
---
## License
[Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)
---