tiger-gpt2-chat / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# --- Load the model and tokenizer directly ---
model_id = "xingyu1996/tiger-gpt2"
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # use the original GPT-2 tokenizer
model = AutoModelForCausalLM.from_pretrained(model_id)
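# Optional: put the model in inference mode. This disables dropout and,
# together with torch.no_grad() below, avoids gradient bookkeeping.
model.eval()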
def respond(
message,
history: list[tuple[str, str]],
max_tokens,
temperature,
top_p,
):
    # Convert the input text to token IDs
input_ids = tokenizer.encode(message, return_tensors="pt")
    # Prepare generation parameters
    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "do_sample": temperature > 0,
        # GPT-2 has no pad token; reuse EOS to avoid the generate() warning
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature
    if top_p < 1.0:
        gen_kwargs["top_p"] = top_p
    # Generate text
with torch.no_grad():
output_ids = model.generate(input_ids, **gen_kwargs)
# εͺδΏη•™ζ–°η”Ÿζˆηš„ιƒ¨εˆ†
new_tokens = output_ids[0, input_ids.shape[1]:]
    # Decode the generated token IDs back to text
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
return response
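
# Note: `history` is unused in respond(); the app treats each message as a
# standalone completion (see the description below). A minimal sketch of how
# past turns could be folded into the prompt (an assumption, not part of the
# original app):
#
#     context = "".join(f"{user}{bot}" for user, bot in history) + message
#     input_ids = tokenizer.encode(context, return_tensors="pt")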
# The rest of the Gradio interface code is unchanged
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Slider(minimum=1, maximum=512, value=325, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
    title=f"Inference test: {model_id}",
    description="Enter Chinese text and the model will complete it.",
)
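
# Note: recent Gradio releases prefer gr.ChatInterface(..., type="messages"),
# where `history` arrives as {"role": ..., "content": ...} dicts; the
# tuple-style history above matches older Gradio versions.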
if __name__ == "__main__":
demo.launch()
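
# Usage: run `python app.py` and open the printed local URL (Gradio serves on
# http://127.0.0.1:7860 by default). On Hugging Face Spaces the app launches
# automatically.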