Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
# Load the BLOOM checkpoint and its matching tokenizer once, at import time.
# Both module-level objects are used by question_bloom for every request.
MODEL_NAME = "bigscience/bloom-1b3"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def post_process_sentence(input_sentence, generated_sentence):
    """Trim a decoded generation to one line and append a list marker.

    Every occurrence of the prompt text is removed from the decoded output.
    If the remainder contains a newline, only its first line is kept;
    otherwise the *whole* decoded output (prompt included) is used.
    Either way the result ends with a newline-dash-space marker, which
    question_bloom later splits on.

    Args:
        input_sentence: the prompt that was fed to the model.
        generated_sentence: the decoded model output.

    Returns:
        The selected text followed by the newline-dash-space marker.
    """
    continuation = generated_sentence.replace(input_sentence, "")
    if "\n" in continuation:
        chosen = continuation.partition("\n")[0]
    else:
        chosen = generated_sentence
    return chosen.replace(" ", " ") + "\n- "
def generate_single(model, tokenizer, input_sentence, max_length=50, top_k=0, temperature=0.7, do_sample=True, seed=42):
    """Generate one continuation of *input_sentence* and post-process it.

    Args:
        model: causal LM exposing ``generate`` (the module-level BLOOM model
            when called from question_bloom).
        tokenizer: tokenizer matching *model*.
        input_sentence: prompt text.
        max_length: number of tokens to generate beyond the prompt.
        top_k: top-k sampling cutoff (in transformers, 0 disables top-k
            filtering).
        temperature: sampling temperature; a value <= 0 switches to greedy
            decoding instead of raising.
        do_sample: sample when truthy, otherwise decode greedily.
        seed: value passed to ``set_seed`` for reproducible sampling.

    Returns:
        The string produced by ``post_process_sentence``.
    """
    # UI sliders may hand us floats. Generation needs integer lengths and
    # top-k plus a float temperature, and seeding needs an int.
    max_length = int(max_length)
    top_k = int(top_k)
    seed = int(seed)
    temperature = float(temperature)
    # Sampling requires a strictly positive temperature. Treat 0 as greedy.
    if temperature <= 0:
        do_sample = False
        temperature = 1.0
    set_seed(seed)
    input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
    output = model.generate(
        input_ids,
        do_sample=do_sample,
        # max_length is a token budget: prompt tokens + requested new tokens.
        # len(input_sentence) would count characters, not tokens.
        max_length=input_ids.shape[-1] + max_length,
        top_k=top_k,
        temperature=temperature,
    )
    generated_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
    return post_process_sentence(input_sentence, generated_sentence)
def question_bloom(input_sentence, max_length, temperature, do_sample=True, top_k=3, seed=42):
    """Gradio callback: run one generation and return the predicted text.

    Uses the module-level ``model`` and ``tokenizer``. The post-processed
    output ends with a newline-dash marker, so the segment just before that
    final marker is returned.
    """
    processed = generate_single(
        model,
        tokenizer,
        input_sentence,
        max_length=max_length,
        temperature=temperature,
        do_sample=do_sample,
        top_k=top_k,
        seed=seed,
    )
    segments = processed.split("\n-")
    return segments[-2]
# Build and launch the Gradio UI. The input components are passed
# positionally to question_bloom in this order: prompt, tokens to generate,
# temperature, do_sample, top_k, seed.
# The top-level gr.Slider / gr.Checkbox components (with ``value=``) replace
# the legacy ``gr.inputs`` namespace and its ``default=`` argument. That
# namespace was deprecated in Gradio 3 and removed in Gradio 4.
demo = gr.Interface(
    question_bloom,
    [
        gr.Textbox(lines=10, label="Input code"),
        gr.Slider(
            minimum=8,
            maximum=256,
            step=1,
            value=8,
            label="Number of tokens to generate",
        ),
        gr.Slider(
            minimum=0,
            maximum=2,
            step=0.1,
            value=0.6,
            label="Temperature",
        ),
        gr.Checkbox(value=True, label="Do Sample"),
        gr.Slider(
            minimum=0,
            maximum=10,
            step=1,
            value=3,
            label="Top K",
        ),
        gr.Slider(
            minimum=0,
            maximum=256,
            step=1,
            value=42,
            label="Random seed for generation",
        ),
    ],
    outputs=gr.Textbox(label="Predicted sentence", lines=10),
)
demo.launch()