import os
import threading
from typing import List, Tuple, Dict

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import spaces

MODEL_ID = "facebook/MobileLLM-Pro"
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95

# --- Silent Hub auth via env/Space Secret (no UI) ---
HF_TOKEN = (
    os.getenv("HF_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HUGGINGFACE_TOKEN")
)
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception:
        pass  # stay silent

# Globals so we only load once
_tokenizer = None
_model = None
_device = None


def _ensure_loaded():
    """Lazily load the tokenizer and model on first use; no-op afterwards."""
    global _tokenizer, _model, _device
    if _tokenizer is not None and _model is not None:
        return
    _tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
    )
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # Some checkpoints ship without a pad token; fall back to EOS so
    # generate() does not warn when padding is needed.
    if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
        _tokenizer.pad_token = _tokenizer.eos_token
    _model.eval()
    _device = next(_model.parameters()).device


def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Convert Gradio (user, bot) tuple history into chat-template messages."""
    msgs: List[Dict[str, str]] = []
    for user_msg, bot_msg in history:
        if user_msg:
            msgs.append({"role": "user", "content": user_msg})
        if bot_msg:
            msgs.append({"role": "assistant", "content": bot_msg})
    return msgs


@spaces.GPU(duration=120)
def generate_stream(message: str, history: List[Tuple[str, str]]):
    """
    Minimal streaming chat function for gr.ChatInterface.
    Uses the instruct chat template. No token UI. No extra controls.
    """
    _ensure_loaded()

    messages = _history_to_messages(history) + [{"role": "user", "content": message}]
    inputs = _tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    # apply_chat_template may return a dict or a bare tensor; handle both.
    input_ids = inputs["input_ids"] if isinstance(inputs, dict) else inputs
    input_ids = input_ids.to(_device)

    # IMPORTANT: don't stream the prompt (prevents system/user text from appearing)
    streamer = TextIteratorStreamer(
        _tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,  # <-- key fix
    )

    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=TEMPERATURE > 0.0,
        temperature=float(TEMPERATURE),
        top_p=float(TOP_P),
        pad_token_id=_tokenizer.pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Run generation in a background thread; the streamer then yields text
    # chunks as they are produced, so the UI updates incrementally.
    thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
    thread.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output


with gr.Blocks(title="MobileLLM-Pro — Chat") as demo:
    gr.Markdown(
        """
        # MobileLLM-Pro — Chat
        Streaming chat with facebook/MobileLLM-Pro (instruct)
        """
    )
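    # --- Assumed completion: the source cuts off inside the gr.Markdown call
    # above. The generate_stream docstring says it targets gr.ChatInterface,
    # so a minimal, hedged wiring could look like this; tuple-style history
    # matches the handler's List[Tuple[str, str]] signature.
    gr.ChatInterface(fn=generate_stream)

if __name__ == "__main__":
    # queue() enables streaming responses to multiple clients; launch()
    # serves the app (HF Spaces also picks up the module-level `demo`).
    demo.queue().launch()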