import re
from typing import List

import solara
from typing_extensions import TypedDict
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
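# Optional sketch (assumption: GPU hardware, which a free Space may not have);
# the inputs built in `response` below would then also need `.to(model.device)`:
# import torch
# if torch.cuda.is_available():
#     model = model.to("cuda")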
class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()
class TextStreamer(BaseStreamer):
    """
    Simple text streamer that forwards decoded text as soon as entire words are formed.
    Adapted from `transformers.TextStreamer`, which prints to stdout; here,
    `on_finalized_text` is overridden to stream into the Solara chat instead.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples (from the original stdout-printing `transformers.TextStreamer`):

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """
    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True
    def put(self, value):
        """
        Receives tokens, decodes them, and forwards them via `on_finalized_text` as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, print until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)
    def end(self):
        """Flushes any remaining cache and signals the end of the stream."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)
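    # How `.generate()` drives this interface: it calls `put()` once with the
    # prompt tokens (skipped above when skip_prompt=True), then once per newly
    # generated token, and finally calls `end()` when generation finishes.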
    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Appends the new text to the last (assistant) message in the reactive
        `messages` list, which triggers a re-render of the chat in Solara."""
        # The original transformers implementation printed to stdout instead:
        # print(text, flush=True, end="" if not stream_end else None)
        messages.value = [
            *messages.value[:-1],
            {
                "role": "assistant",
                "content": messages.value[-1]["content"] + text,
            },
        ]
    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)  # CJK Unified Ideographs
            or (cp >= 0x3400 and cp <= 0x4DBF)  # CJK Unified Ideographs Extension A
            or (cp >= 0x20000 and cp <= 0x2A6DF)  # Extension B
            or (cp >= 0x2A700 and cp <= 0x2B73F)  # Extension C
            or (cp >= 0x2B740 and cp <= 0x2B81F)  # Extension D
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  # Extension E
            or (cp >= 0xF900 and cp <= 0xFAFF)  # CJK Compatibility Ideographs
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  # CJK Compatibility Ideographs Supplement
        ):
            return True

        return False
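# Illustrative walk-through of the word-boundary heuristic in `put` (not executed):
# if the decoded cache reads "An increasing seq", only "An increasing " (up to and
# including the last space) is emitted and `print_len` advances past it; the
# dangling "seq" stays buffered until a later token completes the word, a newline
# flushes the cache, or a trailing CJK character forces an immediate flush.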
# skip_prompt=True keeps the prompt tokens that `.generate()` pushes first from
# being echoed into the chat; only newly generated tokens are streamed.
streamer = TextStreamer(tokenizer, skip_prompt=True)
class MessageDict(TypedDict):
    role: str
    content: str


messages: solara.Reactive[List[MessageDict]] = solara.reactive([])
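# `messages` is module-level reactive state: assigning a *new* list to
# `messages.value` (rather than mutating it in place) is what notifies Solara
# to re-render the components that read it, such as the ChatBox below.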
@solara.component
def Page():
    solara.lab.theme.themes.light.primary = "#0000ff"
    solara.lab.theme.themes.light.secondary = "#0000ff"
    solara.lab.theme.themes.dark.primary = "#0000ff"
    solara.lab.theme.themes.dark.secondary = "#0000ff"
    title = "Qwen2-0.5B"
    with solara.Head():
        solara.Title(f"{title}")
    with solara.Column(align="center"):
        user_message_count = len([m for m in messages.value if m["role"] == "user"])

        def send(message):
            messages.value = [*messages.value, {"role": "user", "content": message}]

        def response(message):
            # Start with an empty assistant message; the streamer appends to it token by token.
            messages.value = [*messages.value, {"role": "assistant", "content": ""}]
            # Note: only the latest user message is sent to the model, so earlier
            # turns are not part of the prompt (no conversation history).
            text = tokenizer.apply_chat_template(
                [{"role": "user", "content": message}],
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = tokenizer(text, return_tensors="pt")
            _ = model.generate(**inputs, streamer=streamer, max_new_tokens=512)

        def result():
            if messages.value != []:
                response(messages.value[-1]["content"])

        # Re-run generation in a background task whenever a new user message arrives.
        result = solara.lab.use_task(result, dependencies=[user_message_count])

        with solara.lab.ChatBox(
            style={
                "position": "fixed",
                "overflow-y": "scroll",
                "scrollbar-width": "none",
                "-ms-overflow-style": "none",
                "top": "0",
                "bottom": "10rem",
                "width": "70%",
            }
        ):
            for item in messages.value:
                with solara.lab.ChatMessage(
                    user=item["role"] == "user",
                    name="User" if item["role"] == "user" else "Qwen2-0.5B-Instruct",
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style="background-color:lightgrey!important;",
                ):
                    # Strip the end-of-turn token before rendering (raw string avoids
                    # an invalid escape sequence warning).
                    solara.Markdown(re.sub(r"<\|im_end\|>", "", item["content"]))
        solara.lab.ChatInput(send_callback=send, style={"position": "fixed", "bottom": "3rem", "width": "70%"})
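# To try this locally (assumption: `solara` and `transformers` are installed):
#     solara run app.py
# Solara serves this module and renders the `Page` component; on a Hugging Face
# Space configured for Solara, the same file is picked up the same way.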