import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Shekswess/trlm-135m"
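
# Load the tokenizer once at startup; prefer the GPU when one is available.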
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
device = "cuda" if torch.cuda.is_available() else "cpu"
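
# fp16 roughly halves GPU memory use; fp32 is the safe default on CPU.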
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model.to(device)
model.eval()


def generate_reply(prompt, max_new_tokens, temperature, top_p):
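    """Generate a single completion for `prompt` with the given sampling settings."""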
    if not prompt.strip():
        return ""
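
    # Format the prompt with the model's chat template so generation
    # starts from the expected assistant turn.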
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(text, return_tensors="pt").to(device)
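
    # Pure inference: disable autograd bookkeeping while sampling.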
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id,
        )
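
    # Keep only the newly generated tokens, dropping the echoed prompt.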
    generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
    decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return decoded.strip()
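
# Build the UI: prompt and sampling controls on the left, model output on the right.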
with gr.Blocks() as demo:
    gr.Markdown("# Tiny Reasoning LM (trlm-135m)\nA small 135M-parameter reasoning model by **Shekswess**.")

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                lines=8,
                label="Prompt",
                placeholder="Ask a question or give an instruction…",
            )
            max_new_tokens = gr.Slider(
                minimum=16,
                maximum=256,
                value=128,
                step=8,
                label="Max new tokens",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.8,
                step=0.05,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
            )
            generate_btn = gr.Button("Generate")

        with gr.Column(scale=4):
            output = gr.Textbox(
                lines=12,
                label="Model Output",
            )
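
    # Wire the button to the generation function.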
    generate_btn.click(
        fn=generate_reply,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=[output],
    )


if __name__ == "__main__":
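    # Pass share=True to launch() if you need a temporary public URL when running locally.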
    demo.launch()