import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
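# Dependencies (assumed from the imports above): gradio, transformers, torch,
# plus accelerate, which device_map="auto" below requires.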
# Model loading with optimized settings
MODEL_NAME = "Qwen/Qwen3-0.6B"
cache_dir = "./model_cache"
# Load tokenizer with trust_remote_code for model-specific features
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    cache_dir=cache_dir,
)
# Load model with GPU acceleration and memory optimization
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16,  # FP16 for reduced memory usage
    device_map="auto",          # let accelerate handle device allocation
    cache_dir=cache_dir,
).eval()  # set to evaluation mode (disables dropout)
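# Note: float16 assumes a GPU is available. On a CPU-only Space, a common
# fallback (an assumption, not part of the original config) is
# torch_dtype=torch.float32, at roughly double the memory footprint.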
# Create text generation pipeline (no explicit device needed with device_map)
text_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id,  # reuse EOS for padding; causal LMs often define no pad token
)
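# Optional sketch: Qwen3 is a chat-tuned model, so wrapping the raw prompt in
# the tokenizer's chat template usually improves responses. This helper is an
# editorial addition and is not wired into generate_response below;
# apply_chat_template is the standard transformers API.
def to_chat_prompt(user_message):
    messages = [{"role": "user", "content": user_message}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return the formatted string, not token IDs
        add_generation_prompt=True,  # append the assistant-turn marker
    )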
def generate_response(prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
    """Generate a response with safe defaults and error handling."""
    try:
        response = text_generator(
            prompt,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            truncation=True,  # truncate over-long prompts to the model's context window
            # Note: max_length is omitted; setting it alongside max_new_tokens
            # triggers a transformers warning, and max_new_tokens takes
            # precedence anyway.
        )
        return response[0]["generated_text"]
    except Exception as e:
        return f"⚠️ Model Error: {e}\n\nTry reducing input length or adjusting generation parameters."
# Gradio interface with enhanced UI
with gr.Blocks(theme="soft", title="Qwen3-0.6B Chat Interface") as demo:
    gr.Markdown("# 🧠 Qwen3-0.6B Text-to-Text Chat")
    gr.Markdown("⚡ Optimized for Hugging Face Spaces with GPU acceleration")

    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="User Input",
                placeholder="Ask me anything...",
                lines=5,
            )
            with gr.Accordion("⚙️ Generation Parameters", open=False):
                max_new_tokens = gr.Slider(
                    minimum=32,
                    maximum=1024,  # high ceiling for long-form generation
                    value=256,
                    step=32,
                    label="Max New Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,  # extended range for creative tasks
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p Sampling",
                )
        with gr.Column(scale=2):
            output = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
            submit = gr.Button("💬 Generate Response", variant="primary")

    submit.click(
        fn=generate_response,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=output,
    )

    gr.Examples(
        examples=[
            ["Explain quantum computing in simple terms"],
            ["Write a poem about autumn leaves"],
            ["Solve this math problem: 2x + 5 = 17"],
        ],
        inputs=prompt,
        label="🎯 Example Prompts",
    )
if __name__ == "__main__":
    demo.launch()