druvx13 committed
Commit 663ba54 · verified · 1 Parent(s): a8c224c

Update app.py

Files changed (1)
  1. app.py +22 -19
app.py CHANGED
@@ -6,31 +6,32 @@ import torch
 MODEL_NAME = "Qwen/Qwen3-0.6B"
 cache_dir = "./model_cache"
 
+# Load tokenizer with trust_remote_code for model-specific features
 tokenizer = AutoTokenizer.from_pretrained(
     MODEL_NAME,
     trust_remote_code=True,
     cache_dir=cache_dir
 )
 
+# Load model with GPU acceleration and memory optimization
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
     trust_remote_code=True,
-    torch_dtype=torch.float16,
-    device_map="auto",
+    torch_dtype=torch.float16,  # FP16 for reduced memory usage
+    device_map="auto",  # Let accelerate handle device allocation
     cache_dir=cache_dir
-).eval()
+).eval()  # Set to evaluation mode
 
-# Create text generation pipeline
+# Create text generation pipeline (no explicit device needed with device_map)
 text_generator = pipeline(
     "text-generation",
     model=model,
     tokenizer=tokenizer,
-    device=0 if torch.cuda.is_available() else -1,
-    pad_token_id=tokenizer.eos_token_id
+    pad_token_id=tokenizer.eos_token_id  # Critical fix for generation stability
 )
 
 def generate_response(prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
-    """Generate response with safe defaults"""
+    """Generate response with safe defaults and error handling"""
     try:
         response = text_generator(
             prompt,
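A note on the load-time changes in the hunk above: with `device_map="auto"`, accelerate places the weights itself, so the pipeline must no longer be handed an explicit `device` argument (recent `transformers` releases warn, or error, when both are given), and `pad_token_id=tokenizer.eos_token_id` supplies a pad token for tokenizers that define none. A minimal sketch of the resulting pattern, assuming `transformers` and `accelerate` are installed:

```python
# Minimal sketch of the device_map + pipeline pattern used in this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "Qwen/Qwen3-0.6B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",  # accelerate decides CPU/GPU placement
).eval()

# No device= argument here: the model is already placed by device_map.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id,  # fallback if the tokenizer has no pad token
)
print(model.device)  # e.g. cuda:0 when a GPU is available
```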
@@ -38,35 +39,36 @@ def generate_response(prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
             temperature=float(temperature),
             top_p=float(top_p),
             do_sample=True,
-            truncation=True
+            truncation=True,
+            max_length=tokenizer.model_max_length  # Prevent overflow
         )
         return response[0]["generated_text"]
     except Exception as e:
-        return f"Error: {str(e)}"
+        return f"⚠️ Model Error: {str(e)}\n\nTry reducing input length or adjusting generation parameters."
 
-# Gradio interface with advanced settings
+# Gradio interface with enhanced UI
 with gr.Blocks(theme="soft", title="Qwen3-0.6B Chat Interface") as demo:
     gr.Markdown("# 🧠 Qwen3-0.6B Text-to-Text Chat")
-    gr.Markdown("Powered by HuggingFace Transformers and Gradio")
+    gr.Markdown("Optimized for HuggingFace Spaces with GPU acceleration")
 
     with gr.Row():
-        with gr.Column():
+        with gr.Column(scale=2):
            prompt = gr.Textbox(
                 label="User Input",
                 placeholder="Ask me anything...",
                 lines=5
             )
-            with gr.Accordion("Advanced Settings", open=False):
+            with gr.Accordion("⚙️ Generation Parameters", open=False):
                 max_new_tokens = gr.Slider(
                     minimum=32,
-                    maximum=512,
+                    maximum=1024,  # Increased max for long-form generation
                     value=256,
                     step=32,
                     label="Max New Tokens"
                 )
                 temperature = gr.Slider(
                     minimum=0.1,
-                    maximum=1.0,
+                    maximum=1.5,  # Extended range for creative tasks
                     value=0.7,
                     step=0.1,
                     label="Temperature"
@@ -79,10 +81,10 @@ with gr.Blocks(theme="soft", title="Qwen3-0.6B Chat Interface") as demo:
                 label="Top-p Sampling"
             )
 
-        with gr.Column():
-            output = gr.Textbox(label="Model Response", lines=10)
+        with gr.Column(scale=2):
+            output = gr.Textbox(label="Model Response", lines=10, show_copy_button=True)
 
-    submit = gr.Button("💬 Generate Response")
+    submit = gr.Button("💬 Generate Response", variant="primary")
     submit.click(
         fn=generate_response,
         inputs=[prompt, max_new_tokens, temperature, top_p],
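The UI changes in this hunk are cosmetic but worth decoding: `scale=2` sets a column's width relative to its siblings in a `Row`, `show_copy_button=True` adds a copy icon to the output box, and `variant="primary"` highlights the button. A self-contained toy showing the same wiring (the component names here are placeholders, not the app's):

```python
import gradio as gr

def shout(text: str) -> str:
    return text.upper()

with gr.Blocks() as toy:
    with gr.Row():
        with gr.Column(scale=2):  # twice the width of the sibling column
            box_in = gr.Textbox(label="In")
        with gr.Column(scale=1):
            box_out = gr.Textbox(label="Out", show_copy_button=True)
    run = gr.Button("Run", variant="primary")  # visually emphasized button
    run.click(fn=shout, inputs=box_in, outputs=box_out)

toy.launch()
```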
@@ -95,7 +97,8 @@ with gr.Blocks(theme="soft", title="Qwen3-0.6B Chat Interface") as demo:
             ["Write a poem about autumn leaves"],
             ["Solve this math problem: 2x + 5 = 17"]
         ],
-        inputs=prompt
+        inputs=prompt,
+        label="🎯 Example Prompts"
     )
 
 if __name__ == "__main__":
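The final hunk lands inside a `gr.Examples(...)` call whose opening line sits outside the diff context: `inputs=prompt` tells Gradio which component each example row should fill, and the new `label` titles the widget. The full call presumably reads roughly as follows (a hedged reconstruction, not the file's exact text):

```python
# Hedged reconstruction of the Examples block touched by the last hunk.
gr.Examples(
    examples=[
        ["Write a poem about autumn leaves"],
        ["Solve this math problem: 2x + 5 = 17"],
    ],
    inputs=prompt,  # each example row populates the prompt Textbox
    label="🎯 Example Prompts",
)
```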
 
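The diff context cuts off at the `if __name__ == "__main__":` guard, so the launch call itself is not shown in this commit. On Spaces the tail of such a script is conventionally just the following (a guess, not part of the diff):

```python
if __name__ == "__main__":
    demo.launch()  # Spaces typically serves Gradio apps on 0.0.0.0:7860
```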