import re
import threading

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
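
# CSS: keep the reasoning panel scrollable and size it to the viewport so the
# streamed chain of thought never pushes the rest of the page out of view.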
CSS = """
.m3d-auto-scroll > * {
    overflow: auto;
}

#reasoning {
    overflow: auto;
    height: calc(100vh - 128px);
    scroll-behavior: smooth;
}
"""
JS = """
() => {
    // auto-scroll the #reasoning panel whenever its content changes
    const block = document.querySelector('#reasoning');
    const observer = new MutationObserver(() => {
        block.scrollTop = block.scrollHeight;
    });
    observer.observe(block, {
        childList: true,
        characterData: true,
        subtree: true,
    });
}
"""
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
print(model.config)  # log the model configuration at startup

tokenizer = AutoTokenizer.from_pretrained(model_name)


def reformat_math(text):
    """Fix MathJax delimiters to use the Gradio syntax.

    This is a workaround to display math formulas in Gradio. For now, I haven't
    found a way to make it work as expected with other latex_delimiters...
    """
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text


@spaces.GPU  # assumed ZeroGPU Space: reserve a GPU for the duration of each call
def chat(prompt, history):
    """Respond to a chat prompt, yielding (answer, reasoning) pairs."""
    message = {
        "role": "user",
        "content": prompt,
    }

    # build the messages list from the chat history plus the new user message
    history = [] if history is None else history
    message_list = history + [message]
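
    # Render the conversation with the model's chat template, then move the
    # tokenized prompt onto the model's device.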
    text = tokenizer.apply_chat_template(
        message_list,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
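
    # Stream the generation: model.generate() runs in a background thread while
    # TextIteratorStreamer yields decoded text chunks as they are produced.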
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,  # don't echo the rendered prompt back into the stream
        skip_special_tokens=True,
    )
    threading.Thread(
        target=model.generate,
        kwargs=dict(
            max_new_tokens=1024 * 128,
            streamer=streamer,
            **model_inputs,
        ),
    ).start()

    buffer = ""
    reasoning = ""
    thinking = False
    reasoning_heading = "# Reasoning\n\n"
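    # DeepSeek-R1 models emit their chain of thought between <think> and
    # </think> tags: route that text to the reasoning panel and everything
    # after the closing tag to the chat answer.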
    for new_text in streamer:
        if not thinking and "<think>" in new_text:
            thinking = True
            continue
        if thinking and "</think>" in new_text:
            thinking = False
            continue

        if thinking:
            reasoning += new_text
            yield (
                "I'm thinking, please wait a moment...",
                reasoning_heading + reasoning,
            )
            continue

        buffer += new_text
        yield reformat_math(buffer), reasoning_heading + reasoning
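

# The chatbot renders the $$...$$ / $...$ delimiters that reformat_math()
# produces, for display and inline math respectively.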
chat_bot = gr.Chatbot(
    latex_delimiters=[
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ],
    scale=1,
    type="messages",
)
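
# Layout: the chat interface sits in the left column; the live reasoning panel
# renders in the right column and is auto-scrolled by the JS above.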
with gr.Blocks(
    theme="davehornik/Tealy",
    js=JS,
    css=CSS,
    fill_height=True,
    title="Reasoning model example",
) as demo:
    reasoning = gr.Markdown(
        "# Reasoning\n\nWhile the model is reasoning, its thoughts will be displayed here.",
        label="Reasoning",
        show_label=True,
        container=True,
        elem_classes="m3d-auto-scroll",
        render=False,
    )
    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=3, variant="compact"):
            gr.ChatInterface(
                chat,
                type="messages",
                chatbot=chat_bot,
                title="Simple conversational AI with reasoning",
                description=(
                    f"We're using **{model_name}**, a large language model "
                    "trained on a mixture of instruction and conversational "
                    "data that can reason about the prompt (the user "
                    "question). When you ask a question, you can watch its "
                    "thoughts in the panel on the right."
                ),
                additional_outputs=[reasoning],
            )
        with gr.Column(variant="compact", elem_id="reasoning"):
            reasoning.render()
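

# queue() makes sure the generator (streaming) responses are served through
# Gradio's queue.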
if __name__ == "__main__":
    demo.queue().launch()