import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Cache model weights locally so restarts don't re-download them
cache_dir = "./model_cache"
os.makedirs(cache_dir, exist_ok=True)

model_name = "Qwen/Qwen1.5-0.5B"

# Load the tokenizer for the chosen checkpoint
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    trust_remote_code=True
)

# Load the model in half precision on GPU; fall back to float32 on CPU,
# where float16 inference is slow and poorly supported
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    cache_dir=cache_dir,
    trust_remote_code=True
).to(device)

# The checkpoint's default generation settings, used as the base config for generate()
generation_config = GenerationConfig.from_pretrained(model_name)


def generate_text(prompt, temperature, max_new_tokens):
    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            generation_config=generation_config,
            do_sample=True,  # sampling must be enabled for temperature to take effect
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return response.strip()


# Build the Gradio interface
with gr.Blocks(theme="soft", title="Qwen Text Generation") as demo:
    gr.Markdown("# 🧠 Qwen1.5-0.5B Text Generation")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your instruction or question...",
                lines=5
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                label="Creativity (Temperature)"
            )
            max_new_tokens = gr.Slider(
                minimum=50, maximum=1000, value=200, step=50,
                label="Max New Tokens"
            )
            generate_btn = gr.Button("✨ Generate", variant="primary")

        with gr.Column():
            output = gr.Textbox(label="Model Response", lines=10, interactive=False)

    generate_btn.click(
        fn=generate_text,
        inputs=[prompt, temperature, max_new_tokens],
        outputs=output
    )

    gr.Markdown("""
    ### ℹ️ Tips
    - Higher **temperature** = more creative/chaotic responses
    - Lower **temperature** = more deterministic answers
    - Adjust **max tokens** for longer/shorter outputs
    """)

demo.launch()