# LLaDA-Prometheus

## Model Description
This model is a fine-tuned version of the LLaDA 8B Base model, obtained through a specialized Supervised Fine-Tuning (SFT) process. It discards the complex attention mask design typically associated with block diffusion while preserving full attention. As a result, the model performs block diffusion-style inference efficiently: it leverages the KV cache for streamlined generation and emits an EOS token when the response is complete, exiting the generation process cleanly.
Key innovations:
- Full Attention Preservation: Maintains standard full attention without the overhead of intricate masking.
- Block Diffusion Inference: Enables iterative block-wise generation via KV cache management, ensuring coherent and controlled outputs.
- EOS Handling: Trained to naturally emit EOS tokens at response boundaries.
This approach balances computational efficiency with high-quality generation, making it suitable for tasks requiring structured, multi-step reasoning.
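To make the idea concrete, here is a minimal sketch of one block of confidence-thresholded denoising. Everything in it is a toy stand-in: `denoise_step`, `MASK_ID`, `EOS_ID`, and `VOCAB` are illustrative placeholders rather than this repository's actual API, and the stub returns random logits where the real model would run a full-attention forward pass (with earlier blocks served from the KV cache).

```python
import torch

# Toy constants; illustrative placeholders, not LLaDA's real ids.
MASK_ID, EOS_ID, VOCAB = 0, 2, 32000

def denoise_step(ids: torch.Tensor) -> torch.Tensor:
    # Stub: random logits over the vocabulary for every position.
    # A real model would run one full-attention forward pass here,
    # with already-finalized blocks served from the KV cache.
    return torch.randn(ids.shape[0], ids.shape[1], VOCAB)

def generate_block(prompt_ids: torch.Tensor, block_len: int = 64,
                   threshold: float = 0.9, max_steps: int = 8) -> torch.Tensor:
    # Append a fully masked block after the prompt.
    masks = torch.full((prompt_ids.shape[0], block_len), MASK_ID,
                       dtype=prompt_ids.dtype)
    ids = torch.cat([prompt_ids, masks], dim=-1)
    for _ in range(max_steps):
        conf, pred = denoise_step(ids).softmax(-1).max(-1)
        masked = ids == MASK_ID
        # Unmask only positions the model is confident about; the rest
        # stay masked and are re-predicted on the next step.
        ids = torch.where(masked & (conf >= threshold), pred, ids)
        if not (ids == MASK_ID).any():
            break
    # Fall back to the argmax prediction for any still-masked slots.
    ids = torch.where(ids == MASK_ID, pred, ids)
    return ids[:, -block_len:]
```

The `threshold` knob here plays the same role as the `threshold` argument in the usage example below: higher values unmask fewer tokens per step, trading speed for stability.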
## Usage
To load and use this model with Hugging Face Transformers:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "maomaocun/LLaDA-Prometheus-no-template"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")

prompt = "Can you tell me an engaging short story about a brave young astronaut who discovers an ancient alien civilization on a distant planet? Make it adventurous and heartwarming, with a twist at the end."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]
attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids))

# Stream generated blocks as they are denoised.
for chunk in model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_gen_length=1024,
    block_length=64,
    threshold=0.9,
    streaming=True,
    eos_token_id=tokenizer.eos_token_id,  # pass the token id, not the token string
):
    # Decode the prompt plus the tokens generated so far, truncating at EOS.
    all_generated_ids = torch.cat([input_ids, chunk], dim=-1)
    text = tokenizer.batch_decode(all_generated_ids, skip_special_tokens=False)[0].split(tokenizer.eos_token)[0]
    print(text, end="", flush=True)
```
For fully custom block diffusion-style inference, you can manage the KV cache and block outputs yourself, as sketched below.
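For example, a minimal outer loop (continuing the earlier sketch and reusing its hypothetical `generate_block`, `EOS_ID`, and `VOCAB`) might append one denoised block at a time and stop as soon as a block contains EOS:

```python
import torch

def generate(prompt_ids: torch.Tensor, max_blocks: int = 16,
             block_len: int = 64) -> torch.Tensor:
    seq = prompt_ids
    for _ in range(max_blocks):
        block = generate_block(seq, block_len=block_len)  # from the sketch above
        # With full attention, only the new block's keys/values need
        # computing; a real implementation serves earlier positions
        # from the KV cache instead of recomputing them.
        seq = torch.cat([seq, block], dim=-1)
        if (block == EOS_ID).any():  # the model is trained to emit EOS
            break
    return seq

# Toy invocation against the stub:
print(generate(torch.randint(3, VOCAB, (1, 10))).shape)
```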
## Benchmarks
The following table compares performance across key evaluation benchmarks. Results are reported as accuracy percentages where applicable.
| Model | GSM8K | GPQA | BBH | MATH | HumanEval | MBPP | MMLU-Pro | MMLU-Generate |
|---|---|---|---|---|---|---|---|---|
| LLaDA 8B Base in Pure Diffusion | 69.06 | 31.91 | 44.77 | 30.84 | 32.92 | 40.8 | 24.26 | 65.9 |
| LLaDA 8B Instruct in Pure Diffusion | 77.48 | 29.01 | 51.49 | 22.32 | 38.71 | 39.2 | 36.41 | 65.5 |
| LLaDA-Prometheus in Block Diffusion | 77.4 | 33.03 | 48.74 | 31.94 | 40.24 | 42.0 | 33.45 | 65.53 |
These results demonstrate competitive performance, particularly in code generation (HumanEval, MBPP) and reasoning tasks (BBH, MATH), with gains over the Instruct variant on GPQA, MATH, HumanEval, and MBPP.