import time
from functools import lru_cache

import torch
import gradio as gr
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
# only cache the latest model
@lru_cache(maxsize=1)  # single-entry cache: loading a new model evicts the previous one
def get_model_and_tokenizer(model_id):
    config = AutoConfig.from_pretrained(model_id)
    if config.is_encoder_decoder:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer
# cache up to 32k examples
@lru_cache(maxsize=32768)  # memoize repeated requests with identical arguments
def run_generation(
    text,
    model_id,
    max_new_tokens,
    alpha=0.0,
    top_k=0,
    num_beams=1,
    do_sample=False,
    top_p=0.0,
    seed=0
):
    model, tokenizer = get_model_and_tokenizer(model_id)
    inputs = tokenizer(text, return_tensors='pt')
    if seed:
        torch.manual_seed(seed)

    start = time.time_ns()
    contrastive_ids = model.generate(
        # from the tokenizer
        **inputs,
        # fixed arguments
        num_return_sequences=1,
        early_stopping=True,
        # variable arguments
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=num_beams,
        penalty_alpha=alpha or None,
        top_k=top_k or None,
        top_p=top_p or None,
    )
    end = time.time_ns()

    contrastive_time = (end - start) / 1e6  # nanoseconds -> milliseconds
    contrastive_text = tokenizer.decode(contrastive_ids[0], skip_special_tokens=True)
    return contrastive_text, contrastive_time
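
# Illustrative usage (not executed here; values are this demo's defaults): with both
# `alpha` (penalty_alpha) and `top_k` set, run_generation performs contrastive search;
# with do_sample=True plus top_k or top_p it produces a sampling baseline instead.
#   c_text, c_ms = run_generation("DeepMind Company is", "facebook/opt-125m", 50, alpha=0.6, top_k=6)
#   s_text, s_ms = run_generation("DeepMind Company is", "facebook/opt-125m", 50, top_k=50, do_sample=True, seed=42)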

def generate_beam_search(text, model_id, max_new_tokens, alpha, k, num_beams):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    beam_search_text, beam_search_time = run_generation(text, model_id, max_new_tokens, num_beams=num_beams)
    return contrastive_text, contrastive_time, beam_search_text, beam_search_time


def generate_top_k(text, model_id, max_new_tokens, alpha, k, top_k, seed):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    top_k_text, top_k_time = run_generation(
        text, model_id, max_new_tokens, top_k=top_k, seed=seed, do_sample=True
    )
    return contrastive_text, contrastive_time, top_k_text, top_k_time


def generate_top_p(text, model_id, max_new_tokens, alpha, k, top_p, seed):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    top_p_text, top_p_time = run_generation(
        text, model_id, max_new_tokens, top_p=top_p, seed=seed, do_sample=True
    )
    return contrastive_text, contrastive_time, top_p_text, top_p_time
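
# Illustrative call (not executed here; the arguments mirror the UI defaults below): each
# helper returns (contrastive text, contrastive time, baseline text, baseline time), which
# is exactly what the Gradio outputs are wired to.
#   generate_top_p("DeepMind Company is", "facebook/opt-125m", 50, 0.6, 6, 0.95, 42)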

demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Contrastive Search Generation comparison

        Credits to the contrastive search generation [paper](https://arxiv.org/abs/2202.06417) authors, including
        @[pangpang666](https://huggingface.co/pangpang666) and @[GMFTBY](https://huggingface.co/GMFTBY). Check out the
        follow-up [work](https://arxiv.org/abs/2210.14140), which demonstrates the usefulness of the technique with
        off-the-shelf LLMs, as well as their [HF guest blog post](https://huggingface.co/blog/introducing-csearch).

        From the paper:
        "At each decoding step, the key ideas of contrastive search are (i) the generated output should be selected
        from the set of most probable candidates predicted by the model; and (ii) the generated output should be
        discriminative enough with respect to the previous context. In this way, the generated text can (i) better
        maintain the semantic coherence with respect to the prefix while (ii) avoiding model degeneration."

        🚨 Warnings: 🚨
        - Avoid using large models (> 1GB) in this demo; they take a long time to load and to generate text.
        - Too slow/long queue? Check our
        [colab](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/115_introducing_contrastive_search.ipynb)
        instead.
        """
    )
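
    # For reference (a minimal sketch using the same defaults as the tabs below, not executed
    # here): the contrastive search described above is enabled in transformers by passing both
    # penalty_alpha and top_k to generate(), e.g.
    #   model, tokenizer = get_model_and_tokenizer("facebook/opt-125m")
    #   inputs = tokenizer("DeepMind Company is", return_tensors="pt")
    #   output = model.generate(**inputs, penalty_alpha=0.6, top_k=6, max_new_tokens=50)
    # Beam search uses num_beams > 1, and the sampling baselines use do_sample=True with
    # top_k or top_p.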

    with gr.Tabs():
        with gr.TabItem("vs. Beam Search"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Inputs ✍️")
                    gr.Markdown("General options:")
                    model_id = gr.Text(value="facebook/galactica-1.3b", label="Model Repository")
                    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
                    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
                    gr.Markdown("Contrastive Search options:")
                    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
                    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
                    gr.Markdown("Beam Search options:")
                    num_beams = gr.Slider(value=4, minimum=1, maximum=16, step=1, label="Number of beams")
                    generate_button = gr.Button(value="Generate", label="Generate")
                with gr.Column():
                    gr.Markdown("## Outputs 🤖")
                    gr.Markdown("Contrastive Search generation:")
                    text_contrastive = gr.Textbox(value="", label="")
                    time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
                    gr.Markdown("Beam Search generation:")
                    text_beam_search = gr.Textbox(value="", label="")
                    time_beam_search = gr.Number(value=0.0, precision=1, label="Generation time (ms)")

            # actions
            generate_button.click(
                fn=generate_beam_search,
                inputs=[input_text, model_id, max_new_tokens, alpha, k, num_beams],
                outputs=[text_contrastive, time_contrastive, text_beam_search, time_beam_search]
            )
| with gr.TabItem("vs. Top K Sampling"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("## Inputs βοΈ") | |
| gr.Markdown("General options:") | |
| model_id = gr.Text(value="facebook/opt-125m", label="Model Repository") | |
| input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text") | |
| max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate") | |
| gr.Markdown("Contrastive Search options:") | |
| alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha") | |
| k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K") | |
| gr.Markdown("Sampling options:") | |
| top_k = gr.Slider(value=50, minimum=1, maximum=100, step=1, label="Top K") | |
| seed = gr.Number(value=42, precision=0, label="Seed") | |
| generate_button = gr.Button(value="Generate", label="Generate") | |
| with gr.Column(): | |
| gr.Markdown("## Outputs π€") | |
| gr.Markdown("Contrastive Search generation:") | |
| text_contrastive = gr.Textbox(value="", label="") | |
| time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)") | |
| gr.Markdown("Top K Sampling generation:") | |
| text_top_k = gr.Textbox(value="", label="") | |
| time_top_k = gr.Number(value=0.0, precision=1, label="Generation time (ms)") | |
| # actions | |
| generate_button.click( | |
| fn=generate_top_k, | |
| inputs=[input_text, model_id, max_new_tokens, alpha, k, top_k, seed], | |
| outputs=[text_contrastive, time_contrastive, text_top_k, time_top_k] | |
| ) | |
| with gr.TabItem("vs. Nucleus Sampling"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("## Inputs βοΈ") | |
| gr.Markdown("General options:") | |
| model_id = gr.Text(value="facebook/opt-125m", label="Model Repository") | |
| input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text") | |
| max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate") | |
| gr.Markdown("Contrastive Search options:") | |
| alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha") | |
| k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K") | |
| gr.Markdown("Sampling options:") | |
| top_p = gr.Slider(value=0.95, minimum=0.01, maximum=1.0, step=0.01, label="Top P") | |
| seed = gr.Number(value=42, precision=0, label="Seed") | |
| generate_button = gr.Button(value="Generate", label="Generate") | |
| with gr.Column(): | |
| gr.Markdown("## Outputs π€") | |
| gr.Markdown("Contrastive Search generation:") | |
| text_contrastive = gr.Textbox(value="", label="") | |
| time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)") | |
| gr.Markdown("Nucleus Sampling generation:") | |
| text_top_p = gr.Textbox(value="", label="") | |
| time_top_p = gr.Number(value=0.0, precision=1, label="Generation time (ms)") | |
| # actions | |
| generate_button.click( | |
| fn=generate_top_p, | |
| inputs=[input_text, model_id, max_new_tokens, alpha, k, top_p, seed], | |
| outputs=[text_contrastive, time_contrastive, text_top_p, time_top_p] | |
| ) | |

demo.launch()