Hugging Face Space (status: Sleeping)
| import gradio as gr | |
| from huggingface_hub import InferenceClient, hf_hub_download | |
| from transformers import AutoTokenizer | |
| import os | |
| import torch | |
# --- Model ID served through the Hugging Face Inference API ---
model_id = "xingyu1996/tiger-gpt2"
client = InferenceClient(model_id)
# --- Key point: load the same tokenizer that was used at training time,
# so that token ids streamed back by the API decode correctly. ---
# NOTE(review): assumes the model was trained with the stock "gpt2"
# tokenizer — confirm against the training setup.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    """Stream a completion for *message* from the Inference API.

    Yields the accumulated response text after each received token so the
    Gradio ChatInterface can render the answer incrementally. On failure,
    yields a human-readable error message instead of raising.

    NOTE(review): *history* is accepted (ChatInterface passes it) but is
    not folded into the prompt — every turn is completed independently.
    """
    token_ids = []
    accumulated = ""

    # Request `details` so the API streams token ids rather than plain
    # text fragments; ids are decoded locally with our own tokenizer.
    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "stream": True,
        "details": True,
    }
    if temperature is not None and temperature > 0:
        gen_kwargs["temperature"] = temperature
    if top_p is not None and top_p < 1.0:
        gen_kwargs["top_p"] = top_p

    try:
        # Stream token-by-token output from the Inference API.
        for chunk in client.text_generation(message, **gen_kwargs):
            if hasattr(chunk, 'token'):
                # Streaming chunk: chunk.token carries the id and text.
                token_ids.append(chunk.token.id)
                # Re-decode the whole id sequence each step so multi-byte
                # characters split across tokens are assembled correctly.
                accumulated = tokenizer.decode(token_ids, skip_special_tokens=True)
                yield accumulated
            elif hasattr(chunk, 'generated_text'):
                # Non-streaming fallback: the API returned the final text.
                accumulated = chunk.generated_text
                yield accumulated
    except Exception as e:
        print(f"推理时发生错误: {type(e).__name__} - {e}")
        yield f"抱歉,推理时遇到错误: {type(e).__name__} - {str(e)}"
# Gradio chat UI wiring — generation controls exposed as extra inputs.
_generation_controls = [
    gr.Slider(minimum=1, maximum=512, value=325, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

demo = gr.ChatInterface(
    respond,
    additional_inputs=_generation_controls,
    title=f"推理测试: {model_id}",
    description="输入中文文本,模型将进行补全。",
)
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()