# -*- coding: utf-8 -*-
"""
Gradio 5 + ZeroGPU-ready chat.
- Streams tokens with TextIteratorStreamer
Refs:
- ZeroGPU docs (per-call GPU, 60s default, dynamic duration, effect-free on non-ZeroGPU):
https://huggingface.co/docs/hub/en/spaces-zerogpu
- Using GPU Spaces:
https://huggingface.co/docs/hub/en/spaces-gpus
- Gradio ChatInterface (type="messages"):
https://www.gradio.app/docs/gradio/chatinterface
- Transformers chat templating:
https://huggingface.co/docs/transformers/en/chat_templating
- TextIteratorStreamer:
https://huggingface.co/docs/transformers/en/internal/generation_utils
- Meltemi Instruct v1.5 (Zephyr prompt format):
https://huggingface.co/ilsp/Meltemi-7B-Instruct-v1.5
"""
import os, threading, gc, re
from typing import Iterable, List, Dict, Any
import gradio as gr
import spaces
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
HAS_CUDA = bool(torch.cuda.is_available())
IS_ZEROGPU = bool(os.getenv("SPACES_ZERO_GPU"))
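# SPACES_ZERO_GPU is set by the Spaces runtime on ZeroGPU hardware. When present,
# force eager execution (no torch.compile) and allow TF32 matmuls.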
if IS_ZEROGPU:
torch.compiler.set_stance("force_eager")
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
# ------------------------- Config --------------------------------------------
MODEL_ID = "ilsp/Meltemi-7B-v1.5" if HAS_CUDA else "HuggingFaceTB/SmolLM2-135M-Instruct"
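# Note: ilsp/Meltemi-7B-v1.5 is the *base* model (the header links the Instruct
# variant); if its tokenizer ships no chat template, to_model_inputs() below falls
# back to a plain joined prompt.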
SYSTEM_PROMPT = "You are a helpful, concise assistant. Reply in the user's language when possible."
ZEROGPU_DURATION_BASE = 20
ZEROGPU_DURATION_CAP = 60
CPU_MAX_NEW_TOKENS = 128
# Base defaults aim to respect ZeroGPU's 60s window.
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95
TOP_K = 50
REPETITION_PENALTY = 1.05
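# Each sampling default above is also exposed as a slider in the ChatInterface below.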
# ------------------------- Load model/tokenizer ------------------------
# Use bf16 on CUDA, fall back to fp32 on CPU for compatibility.
dtype = torch.bfloat16 if HAS_CUDA else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
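# Some causal-LM tokenizers ship without a pad token; reuse EOS so generate() can pad.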
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id
def clear_cache():
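    """Release cached CUDA memory (if any) and run the Python garbage collector."""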
if HAS_CUDA:
try:
torch.cuda.empty_cache()
except Exception:
pass
gc.collect()
def build_messages(user_text: str, history: List[Dict[str, str]], system_message: str) -> List[Dict[str, str]]:
"""Merge system + history + user into HF messages format."""
msgs: List[Dict[str, str]] = []
    if system_message and not any(m.get("role") == "system" for m in (history or [])):
        msgs.append({"role": "system", "content": system_message})
msgs.extend(history or [])
msgs.append({"role": "user", "content": user_text})
return msgs
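# Example (first turn):
#   build_messages("Hi", [], SYSTEM_PROMPT)
#   -> [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": "Hi"}]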
def to_model_inputs(messages: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
"""
Create tokenized inputs.
- If a chat template exists,
use apply_chat_template(..., add_generation_prompt=True).
- Otherwise fall back to a simple joined text prompt for base models.
"""
if getattr(tokenizer, "chat_template", None):
return tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
)
# Fallback: last user only, no headings that induce lists
print("Chat template not found.")
last_user = next((m["content"] for m in reversed(messages) if m["role"]=="user"), "")
prompt = (messages[0]["content"] + "\n\n" if messages and messages[0]["role"]=="system" else "") + last_user
return tokenizer(prompt, return_tensors="pt")
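# Defensive cleanup for streamed text: if ChatML control tokens leak into the decoded
# stream (e.g. from SmolLM2's template), strip the assistant header/footer and any
# trailing end-of-text marker, e.g. "<|im_start|>assistant\nHello!<|im_end|>" -> "Hello!".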
CHATML_HEAD = re.compile(r'^\s*<\|im_start\|>\s*assistant\s*\n?', re.IGNORECASE)
CHATML_TAIL = re.compile(r'\s*<\|im_end\|>\s*$', re.IGNORECASE)
GEN_TAIL = re.compile(r'\s*(</s>|<\|end_of_text\|>)\s*$', re.IGNORECASE)
def _clean_stream_chunk(text: str) -> str:
t = CHATML_HEAD.sub("", text)
t = CHATML_TAIL.sub("", t)
t = GEN_TAIL.sub("", t)
return t
def estimate_duration(max_new_tokens: int) -> int:
"""
Heuristic duration budget for ZeroGPU (seconds).
Shorter durations improve queue priority.
"""
    secs = int(float(ZEROGPU_DURATION_BASE) * float(max_new_tokens) / float(MAX_NEW_TOKENS))
    # Clamp to [1, cap]: never request a zero-second GPU window for tiny token budgets.
    return max(1, min(secs, int(ZEROGPU_DURATION_CAP)))
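# Worked example with the defaults above: 128 tokens -> 10 s, 256 -> 20 s,
# 512 -> 40 s, 1024 -> capped at 60 s. The `duration=` callable passed to
# @spaces.GPU below receives the same arguments as the wrapped function
# (ZeroGPU dynamic duration); outside ZeroGPU the decorator has no effect.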
@spaces.GPU(duration=lambda message, history, system_message, max_new_tokens, temperature, top_k, top_p, repetition_penalty: estimate_duration(max_new_tokens if max_new_tokens else MAX_NEW_TOKENS))
@torch.inference_mode()
def generate_on_accelerator(message: str, history: List[Dict[str, str]], system_message: str, max_new_tokens: int, temperature: float, top_k: int, top_p: float, repetition_penalty: float) -> Iterable[str]:
"""
ChatInterface callback that works both on ZeroGPU and CPU-only Spaces.
The decorator is effect-free outside ZeroGPU, so this still runs on CPU.
"""
try:
device = torch.device("cuda" if HAS_CUDA else "cpu")
        if model.device.type != device.type: model.to(device)
messages = build_messages(message, history, system_message)
inputs = to_model_inputs(messages)
inputs = {k: v.to(device) for k, v in inputs.items()}
# Reduce generation length on CPU to keep latency reasonable.
eff_max_new_tokens = max_new_tokens if HAS_CUDA else min(max_new_tokens, CPU_MAX_NEW_TOKENS)
        # Stream decoded tokens; skip the prompt and drop special tokens
        # (skip_special_tokens is forwarded to tokenizer.decode()).
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
gen_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=eff_max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
cache_implementation="static", # https://github.com/huggingface/transformers/issues/38501
)
# Run generation in a background thread and yield chunks as they arrive.
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
new_response = ""
for text in streamer:
new_response += text
new_response = _clean_stream_chunk(new_response)
yield new_response
except Exception as e:
        print(e)
        raise gr.Error(str(e))
finally:
        # Move the model back to CPU and free the CUDA cache when a GPU was used.
        if HAS_CUDA: model.to("cpu")
clear_cache()
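# With type="messages", ChatInterface passes `history` as OpenAI-style dicts, e.g.
#   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
# and then the `additional_inputs` values in the order declared below.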
demo = gr.ChatInterface(
fn=generate_on_accelerator,
type="messages", # OpenAI-style histories: [{"role": "...", "content": "..."}]
title=f"{MODEL_ID} Chat",
description=(f"Chat with {MODEL_ID}"),
additional_inputs=[
gr.Textbox(value=SYSTEM_PROMPT, label="System message"),
gr.Slider(minimum=1, maximum=2048, value=MAX_NEW_TOKENS, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=TEMPERATURE, step=0.1, label="Temperature"),
gr.Slider(minimum=0, maximum=360, value=TOP_K, step=1, label="Top-k"),
gr.Slider(minimum=0.1, maximum=1.0, value=TOP_P, step=0.05, label="Top-p"),
gr.Slider(minimum=0.0, maximum=2.0, value=REPETITION_PENALTY, step=0.05, label="Repetition penalty"),
],
examples=[
["Summarize the following paragraph in Greek."],
["Translate this into English: Καλημέρα, τι κάνεις;"],
["Write a short outline about Greek islands."],
],
cache_examples=False,
)
demo.queue().launch()