akhaliq (HF Staff) committed
Commit f35bf64 · verified · 1 Parent(s): 21f22c1

Update app.py

Files changed (1)
  1. app.py +68 -192
app.py CHANGED
@@ -1,230 +1,106 @@
  import os
- import time
  import torch
  import gradio as gr
- from typing import List, Dict, Any, Tuple
- from transformers import (
-     AutoTokenizer,
-     AutoModelForCausalLM,
-     TextIteratorStreamer,
- )
  from huggingface_hub import login
- import threading
  import spaces

- """
- Gradio chat app for facebook/MobileLLM-Pro
- Uses the model's chat template when using the "instruct" subfolder
- Streams tokens to the Gradio UI
- Minimal controls: max_new_tokens, temperature, top_p
- Optional HF_TOKEN login via env var or textbox
-
- To run locally:
-     pip install -U gradio transformers accelerate sentencepiece huggingface_hub
-     HF_TOKEN=xxxx python app.py
-
- On Hugging Face Spaces:
- Remove explicit login() call or set HF_TOKEN as a secret
- """
-
  MODEL_ID = "facebook/MobileLLM-Pro"
- DEFAULT_VERSION = "instruct"  # "base" | "instruct"
- DEFAULT_MAX_NEW_TOKENS = 256
- DEFAULT_TEMPERATURE = 0.7
- DEFAULT_TOP_P = 0.95

- # ---- Optional: login to Hugging Face if token is provided ----
- HF_TOKEN = os.getenv("HF_TOKEN")
  if HF_TOKEN:
      try:
          login(token=HF_TOKEN)
-         print("[INFO] Logged in to Hugging Face Hub.")
-     except Exception as e:
-         print(f"[WARN] Could not login to Hugging Face: {e}")
-
-
- def load_model(version: str = DEFAULT_VERSION):
-     """Load tokenizer+model for the selected subfolder (base/instruct)."""
-     print(f"[INFO] Loading {MODEL_ID}:{version} ...")
-     tokenizer = AutoTokenizer.from_pretrained(
-         MODEL_ID, trust_remote_code=True, subfolder=version
      )
-     model = AutoModelForCausalLM.from_pretrained(
          MODEL_ID,
          trust_remote_code=True,
-         subfolder=version,
          torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
          low_cpu_mem_usage=True,
          device_map="auto" if torch.cuda.is_available() else None,
      )
-
-     # Ensure special tokens are set to avoid warnings
-     if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
-         tokenizer.pad_token = tokenizer.eos_token
-
-     model.eval()
-     print("[INFO] Model loaded.")
-     return tokenizer, model
-

  def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
-     """Map Gradio history [(user, assistant), ...] to chat template messages."""
-     messages: List[Dict[str, str]] = []
      for user_msg, bot_msg in history:
          if user_msg:
-             messages.append({"role": "user", "content": user_msg})
          if bot_msg:
-             messages.append({"role": "assistant", "content": bot_msg})
-     return messages
-

  @spaces.GPU(duration=120)
- def generate_stream(
-     message: str,
-     history: List[Tuple[str, str]],
-     version: str,
-     max_new_tokens: int,
-     temperature: float,
-     top_p: float,
-     use_chat_template: bool,
-     state: Dict[str, Any],
- ):
-     """Streaming text generator compatible with gr.ChatInterface.
-
-     Args map to UI controls. `state` holds tokenizer/model between calls.
      """
-     tokenizer = state.get("tokenizer")
-     model = state.get("model")
-
-     # (Re)load model if version changed or not yet loaded
-     if (
-         tokenizer is None
-         or model is None
-         or state.get("version") != version
-     ):
-         tokenizer, model = load_model(version)
-         state["tokenizer"], state["model"], state["version"] = tokenizer, model, version
-
-     device = next(model.parameters()).device
-
-     if use_chat_template and version == "instruct":
-         messages = _history_to_messages(history) + [
-             {"role": "user", "content": message}
-         ]
-         inputs = tokenizer.apply_chat_template(
-             messages,
-             return_tensors="pt",
-             add_generation_prompt=True,
-         ).to(device)
-         input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
-     else:
-         input_ids = tokenizer(
-             message,
-             return_tensors="pt",
-             add_special_tokens=True,
-         )["input_ids"].to(device)

-     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

      gen_kwargs = dict(
          input_ids=input_ids,
-         max_new_tokens=max_new_tokens,
-         do_sample=temperature > 0.0,
-         temperature=max(0.0, float(temperature)),
-         top_p=float(top_p),
-         pad_token_id=tokenizer.pad_token_id,
-         eos_token_id=tokenizer.eos_token_id,
          streamer=streamer,
      )

-     thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
      thread.start()

-     output_text = ""
      for new_text in streamer:
-         output_text += new_text
-         yield output_text
-
-
- with gr.Blocks(title="MobileLLM-Pro Chat") as demo:
-     gr.Markdown("""
-     # facebook/MobileLLM-Pro — Chat Demo
-     - **Version**: choose `instruct` to enable the model's chat template.
-     - **Streaming** is enabled. Use the controls in the right panel.
-     """)
-     gr.Markdown(
-         "<div style='text-align: center;'>Built with <a href='https://huggingface.co/spaces/akhaliq/anycoder'>anycoder</a></div>",
-         elem_id="anycoder_attribution"
-     )
-
-     with gr.Row():
-         with gr.Column(scale=3):
-             chatbot = gr.Chatbot(height=420, label="MobileLLM-Pro")
-             msg = gr.Textbox(placeholder="Ask me anything…", scale=1)
-             submit = gr.Button("Send", variant="primary")
-             clear_btn = gr.Button("Clear chat")
-         with gr.Column(scale=2):
-             version = gr.Dropdown(["base", "instruct"], value=DEFAULT_VERSION, label="Subfolder (version)")
-             use_chat_template = gr.Checkbox(value=True, label="Use chat template (instruct only)")
-             max_new = gr.Slider(32, 1024, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="Max new tokens")
-             temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
-             top_p = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, step=0.01, label="Top-p")
-             hf_token_box = gr.Textbox(value=os.getenv("HF_TOKEN", ""), label="HF_TOKEN (optional)")
-
-     state = gr.State({"tokenizer": None, "model": None, "version": None})
-
-     def _maybe_login(token: str):
-         token = (token or "").strip()
-         if not token:
-             return "(No token provided; skipping login)"
-         try:
-             login(token=token)
-             return "Logged in to Hugging Face Hub."
-         except Exception as e:
-             return f"Login failed: {e}"
-
-     login_btn = gr.Button("Login to HF (optional)")
-     login_status = gr.Markdown()
-     login_btn.click(_maybe_login, inputs=[hf_token_box], outputs=[login_status])
-
-     def user_submit(user_message, chat_history):
-         # Immediately append the user's message so the stream shows inline
-         return "", chat_history + [(user_message, None)]
-
-     def bot_respond(chat_history, version, max_new, temperature, top_p, use_chat_template, state):
-         # The last tuple is (user, None)
-         user_message = chat_history[-1][0] if chat_history else ""
-         partials = generate_stream(
-             user_message,
-             chat_history[:-1],
-             version,
-             int(max_new),
-             float(temperature),
-             float(top_p),
-             bool(use_chat_template),
-             state,
-         )
-         # Stream tokens to the last assistant message slot
-         for chunk in partials:
-             chat_history[-1] = (chat_history[-1][0], chunk)
-             yield chat_history
-
-     msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
-         bot_respond,
-         [chatbot, version, max_new, temperature, top_p, use_chat_template, state],
-         [chatbot],
-     )
-     submit.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
-         bot_respond,
-         [chatbot, version, max_new, temperature, top_p, use_chat_template, state],
-         [chatbot],
-     )
-
-     def clear_chat():
-         return []
-
-     clear_btn.click(clear_chat, outputs=[chatbot])

  if __name__ == "__main__":
-     # For Spaces, Gradio will call `demo.launch()` automatically; locally we launch here.
-     demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
 
  import os
+ import threading
+ from typing import List, Tuple, Dict
+
  import torch
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  from huggingface_hub import login
  import spaces

  MODEL_ID = "facebook/MobileLLM-Pro"
+ SUBFOLDER = "instruct"  # use the chat template
+ MAX_NEW_TOKENS = 256
+ TEMPERATURE = 0.7
+ TOP_P = 0.95

+ # --- Silent Hub auth via env/Space Secret (no UI) ---
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
  if HF_TOKEN:
      try:
+         # No prints; stays silent if token works or fails
          login(token=HF_TOKEN)
+     except Exception:
+         # Stay silent to avoid exposing anything to the UI/logs
+         pass
+
+ # Globals so we only load once
+ _tokenizer = None
+ _model = None
+ _device = None
+
+ def _ensure_loaded():
+     global _tokenizer, _model, _device
+     if _tokenizer is not None and _model is not None:
+         return
+     _tokenizer = AutoTokenizer.from_pretrained(
+         MODEL_ID, trust_remote_code=True, subfolder=SUBFOLDER
      )
+     _model = AutoModelForCausalLM.from_pretrained(
          MODEL_ID,
          trust_remote_code=True,
+         subfolder=SUBFOLDER,
          torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
          low_cpu_mem_usage=True,
          device_map="auto" if torch.cuda.is_available() else None,
      )
+     if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
+         _tokenizer.pad_token = _tokenizer.eos_token
+     _model.eval()
+     _device = next(_model.parameters()).device

  def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
+     msgs: List[Dict[str, str]] = []
      for user_msg, bot_msg in history:
          if user_msg:
+             msgs.append({"role": "user", "content": user_msg})
          if bot_msg:
+             msgs.append({"role": "assistant", "content": bot_msg})
+     return msgs

  @spaces.GPU(duration=120)
+ def generate_stream(message: str, history: List[Tuple[str, str]]):
      """
+     Minimal streaming chat function for gr.ChatInterface.
+     Uses instruct chat template. No token UI. No extra controls.
+     """
+     _ensure_loaded()

+     messages = _history_to_messages(history) + [{"role": "user", "content": message}]
+     inputs = _tokenizer.apply_chat_template(
+         messages,
+         return_tensors="pt",
+         add_generation_prompt=True,
+     )
+     input_ids = inputs["input_ids"] if isinstance(inputs, dict) else inputs
+     input_ids = input_ids.to(_device)

+     streamer = TextIteratorStreamer(_tokenizer, skip_special_tokens=True)
      gen_kwargs = dict(
          input_ids=input_ids,
+         max_new_tokens=MAX_NEW_TOKENS,
+         do_sample=TEMPERATURE > 0.0,
+         temperature=float(TEMPERATURE),
+         top_p=float(TOP_P),
+         pad_token_id=_tokenizer.pad_token_id,
+         eos_token_id=_tokenizer.eos_token_id,
          streamer=streamer,
      )

+     thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
      thread.start()

+     output = ""
      for new_text in streamer:
+         output += new_text
+         yield output
+
+ demo = gr.ChatInterface(
+     fn=generate_stream,
+     chatbot=gr.Chatbot(height=420, label="MobileLLM-Pro"),
+     title="MobileLLM-Pro — Chat",
+     description="Streaming chat with facebook/MobileLLM-Pro (instruct)",
+ )

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
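
For reference, a minimal local smoke test of the rewritten streaming function (not part of the commit). It assumes the updated app.py is importable from the working directory, that HF_TOKEN grants access to the facebook/MobileLLM-Pro weights, and that the spaces.GPU decorator simply passes through when the script runs outside a Space.

    # Hypothetical usage sketch for the new generate_stream (assumptions noted above).
    from app import generate_stream

    last = ""
    for partial in generate_stream("Give me a one-line summary of MobileLLM-Pro.", history=[]):
        last = partial  # each yielded value is the cumulative assistant reply so far
    print(last)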