# tiger-gpt2-chat / app.py
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

# --- Model ID ---
model_id = "xingyu1996/tiger-gpt2"
client = InferenceClient(model_id)

# --- Key change: load the same tokenizer that was used during training ---
tokenizer = AutoTokenizer.from_pretrained("gpt2")
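
# Why decode locally: gpt2 uses byte-level BPE, so one multi-byte character
# (e.g. a Chinese character) spans several tokens, and concatenating the API's
# per-token text fragments can produce garbled partial bytes. Decoding the
# accumulated ids with the training tokenizer avoids this. Illustrative
# sketch only (not executed as part of the app):
#
#   ids = tokenizer.encode("你好")       # several byte-level BPE ids
#   assert tokenizer.decode(ids) == "你好"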

def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    """Stream a completion for `message`, decoding token ids locally."""
    prompt = message
    response_ids = []
    response_text = ""

    # --- Prepare generation parameters ---
    generation_args = {
        "max_new_tokens": max_tokens,
        "stream": True,
        "details": True,  # ask the API to return token IDs (key change)
    }
    if temperature is not None and temperature > 0:
        generation_args["temperature"] = temperature
    if top_p is not None and top_p < 1.0:
        generation_args["top_p"] = top_p
    try:
        # --- Call the API and collect token IDs ---
        for output in client.text_generation(prompt, **generation_args):
            if hasattr(output, 'token'):  # streaming chunk
                # output.token carries the id and text of one generated token
                token_id = output.token.id
                response_ids.append(token_id)
                # Decode the accumulated ids with our own tokenizer
                current_text = tokenizer.decode(response_ids, skip_special_tokens=True)
                response_text = current_text
                yield response_text
            elif hasattr(output, 'generated_text'):  # final output of a non-streaming call
                # If the API returned the complete text directly, use it
                response_text = output.generated_text
                yield response_text
    except Exception as e:
        print(f"Error during inference: {type(e).__name__} - {e}")
        yield f"Sorry, an error occurred during inference: {type(e).__name__} - {str(e)}"

# The rest of the Gradio interface is unchanged
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=512, value=325, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title=f"Inference test: {model_id}",
    description="Enter Chinese text and the model will complete it.",
)

if __name__ == "__main__":
    demo.launch()