# test_chatbot_2 / app.py
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, TorchAoConfig
from threading import Thread
import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig
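
# TorchAO int8 quantization: the default config below quantizes weights to
# int8 and dynamically quantizes activations to int8 at runtime; the
# commented alternative quantizes weights only.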
quant_config = Int8DynamicActivationInt8WeightConfig()
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)
# checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
checkpoint = "unsloth/gemma-3-4b-it"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32).to(device)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map=device,
    quantization_config=quantization_config,
)
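
# Streaming handler: generation runs on a background thread while a
# TextIteratorStreamer yields decoded text pieces back to this generator.
# It is not wired into the UI below; see the sketch after the module
# docstring for how it could replace `respond`.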
@spaces.GPU(duration=30)
def respond_stream(message, history, system_message, max_tokens, temperature, top_p):
messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]
input_ids = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True
)
gen_kwargs = dict(
inputs=input_ids,
streamer=streamer,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
eos_token_id=tokenizer.eos_token_id,
cache_implementation="static",
)
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
partial = ""
for piece in streamer:
partial += piece
yield partial
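
# Blocking handler used by the ChatInterface below: generates the full
# completion in one call, then decodes only the tokens after the prompt.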
@spaces.GPU(duration=30)
def respond(message, history, system_message, max_tokens, temperature, top_p):
messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]
input_ids = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
eos_token_id=tokenizer.eos_token_id,
cache_implementation="static",
)
gen_ids = outputs[0][input_ids.shape[-1]:]
return tokenizer.decode(gen_ids, skip_special_tokens=True)
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
respond,
type="messages",
additional_inputs=[
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
],
)
with gr.Blocks() as demo:
chatbot.render()
if __name__ == "__main__":
demo.queue().launch()