---
license: mit
language:
- en
base_model:
- lmsys/vicuna-7b-v1.5
---
# Speculative Decoding

This model implements Medusa, an efficient speculative decoding approach that can achieve up to 3x faster inference for large language models. The implementation consists of the base Vicuna-7B model augmented with specialized prediction heads that enable parallel token generation.

## Model Description

- **Model type:** Causal language model with speculative decoding
- **Base model:** Vicuna-7B-v1.5
- **Language:** English
- **License:** MIT

This implementation adds multiple speculative heads on top of the base Vicuna model. Each head predicts a token one position further into the future, so a single forward pass yields several candidate tokens instead of one, enabling faster inference.

## Technical Specifications

- **Base model:** lmsys/vicuna-7b-v1.5
- **Parameters:** ~7B (base model) + speculative heads
- **Context length:** 2048 tokens
- **Speculative heads:** 3
- **Layers per head:** 1
- **Architecture:** Each head uses residual blocks (ResBlock) with SiLU activation
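
The head architecture described above can be sketched in PyTorch as follows. This is an illustrative reconstruction based on the specifications listed here (ResBlock with SiLU, one layer per head, three heads), not the exact implementation from the repository; the `hidden_size` and `vocab_size` values are placeholders, not the real Vicuna-7B dimensions.

```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Residual block with SiLU activation, as described above."""
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

# Illustrative dimensions only (Vicuna-7B uses hidden_size=4096, vocab ~32000).
hidden_size, vocab_size, num_heads = 256, 1000, 3

# One speculative head = one ResBlock followed by a vocabulary projection.
heads = nn.ModuleList(
    nn.Sequential(ResBlock(hidden_size),
                  nn.Linear(hidden_size, vocab_size, bias=False))
    for _ in range(num_heads)
)

# All heads read the same last hidden state of the base model and predict
# tokens at positions t+2, t+3, t+4 respectively.
last_hidden = torch.randn(1, hidden_size)
logits = [head(last_hidden) for head in heads]
```

Because the heads share the base model's hidden state, their cost is small compared to a full forward pass of the 7B backbone.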

## Usage

### Direct Inference

```python
import torch
from model import MedusaModel
from transformers import AutoTokenizer

# Load model (requires the Medusa codebase and compatible weights)
model = MedusaModel.from_pretrained("theharshithh/vicuna-7b-speculative")
model = model.to("cuda")  # move the model to the same device as the inputs
tokenizer = model.get_tokenizer()

prompt = "Human: What is machine learning?\nAssistant:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

# Generate text
generated_text = ""
with torch.no_grad():
    for output in model.medusa_generate(
        input_ids=input_ids,
        temperature=0.7,
        max_steps=512
    ):
        generated_text = output["text"]

print(prompt + generated_text)
```

### ONNX Inference

For accelerated inference, first export the model to ONNX format using the helpers provided in the companion repository:

```python
from inference.infer_onnx import export_model, inference

# Export model to ONNX
export_model(model_checkpoint="theharshithh/vicuna-7b-speculative", save_directory="onnx/")

# Run inference with ONNX model
inference(model_checkpoint="theharshithh/vicuna-7b-speculative", save_directory="onnx/")
```

## Training Details

This model was trained on a processed dataset derived from the [ShareGPT conversations dataset](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/blob/main/ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json). Training was conducted using distributed training across multiple GPUs, with these key hyperparameters:

- **Learning rate:** 1e-3
- **Weight decay:** 0.0
- **Warmup ratio:** 0.1
- **Scheduler:** Cosine
- **Batch size:** 4 (per device)
- **Gradient accumulation steps:** 4
- **Precision:** BF16
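
For reference, these hyperparameters correspond to the following keyword arguments of a Hugging Face `transformers.TrainingArguments` object. This is a sketch of the mapping only; the actual training script lives in the GitHub repository, and the GPU count below is a hypothetical value (the exact number is not stated here).

```python
# Mapping of the hyperparameters above onto standard
# transformers.TrainingArguments field names.
training_args = dict(
    learning_rate=1e-3,
    weight_decay=0.0,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    bf16=True,
)

# The effective global batch size scales with the number of GPUs:
num_gpus = 4  # illustrative; not stated in this card
effective_batch = (
    training_args["per_device_train_batch_size"]
    * training_args["gradient_accumulation_steps"]
    * num_gpus
)
print(effective_batch)  # 64 with 4 GPUs
```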

## Performance

The model achieves significant inference speedups compared to standard autoregressive generation:
- Up to 3x faster inference speed by generating multiple tokens in parallel
- Minimal impact on output quality compared to standard generation

The training loss curves show steady convergence, though longer training would likely yield further improvements.
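
The source of the speedup can be motivated with simple arithmetic: a draft token from head *i* only survives verification if every earlier head's token was also accepted. Under an illustrative independence assumption with a hypothetical per-head acceptance probability `p` (not a measured value for this model), the expected number of tokens emitted per forward pass is:

```python
def expected_tokens_per_step(k, p):
    """Expected tokens per forward pass with k speculative heads,
    assuming each head's draft is accepted independently with
    probability p. The base model always contributes 1 token; head i
    only counts if all earlier heads were also accepted."""
    return 1 + sum(p ** i for i in range(1, k + 1))

# With 3 heads and a hypothetical 80% per-head acceptance rate:
print(expected_tokens_per_step(3, 0.8))  # 2.952, i.e. ~3 tokens per pass
```

This is consistent with the up-to-3x figure: when the heads' drafts are usually accepted (predictable text), nearly four tokens come out of each forward pass; on harder text the acceptance rate drops and the speedup shrinks, matching the variability noted in the Limitations section.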

## Limitations

- Currently optimized for inference with batch size 1
- Performance varies based on text complexity and token predictability
- Training was conducted with limited GPU resources (extended training recommended for optimal results)

## Citation

If you use this model in your research, please cite:

**GitHub Repository:**
```
@misc{medusa-vicuna-7b-speculative-github,
  author = {theharshithh},
  title = {Medusa: Fast LLM Inference with Speculative Decoding},
  year = {2025},
  publisher = {GitHub},
  howpublished = {\url{https://github.com/theharshithh/speculative-decoding}}
}
```

## Repository & Source Code

For implementation details, source code, and further documentation, see the GitHub repository:

- [https://github.com/theharshithh/speculative-decoding](https://github.com/theharshithh/speculative-decoding)