import json

import datasets
import gradio as gr
import torch

from backend import (get_message_single, get_message_spam, send_single,
                     send_spam, tokenizer)
from defaults import (ADDRESS_BETTERTRANSFORMER, ADDRESS_VANILLA,
                      defaults_bt_single, defaults_bt_spam,
                      defaults_vanilla_single, defaults_vanilla_spam)


def dispatch_single(
    input_model_single, address_input_vanilla, address_input_bettertransformer
):
    result_vanilla = send_single(input_model_single, address_input_vanilla)
    result_bettertransformer = send_single(
        input_model_single, address_input_bettertransformer
    )

    return result_vanilla, result_bettertransformer
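
# Note: send_single and send_spam come from backend.py in the repository linked in the demo text
# below; they are expected to POST the inputs to the TorchServe inference API of each server and
# to format the measured latency/throughput into the markdown shown in the UI. For reference, a
# raw call to a TorchServe text-classification endpoint looks roughly like the following (the
# model name "my_tc" and the default inference port 8080 are assumptions, not taken from this
# app):
#
#   import requests
#   requests.post("http://<address>:8080/predictions/my_tc", data="A text to classify")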


def dispatch_spam_artif(
    input_n_spam_artif,
    sequence_length,
    padding_ratio,
    address_input_vanilla,
    address_input_bettertransformer,
):
    sequence_length = int(sequence_length)
    input_n_spam_artif = int(input_n_spam_artif)

    # Build an artificial pre-tokenized input: random token ids in [1, vocab_size - 1], so that
    # 0 (the [PAD] id for this tokenizer) never appears among the real tokens.
    inp_tokens = torch.randint(tokenizer.vocab_size - 1, (sequence_length,)) + 1

    # Pad the tail of the sequence according to the requested padding ratio (at least one pad).
    n_pads = max(int(padding_ratio * len(inp_tokens)), 1)
    inp_tokens[-n_pads:] = 0
    inp_tokens[0] = 101  # [CLS] token id in the distilbert-base-uncased vocabulary
    inp_tokens[-n_pads - 1] = 102  # [SEP] token id, placed right before the padding

    # Attend only to the non-padded positions.
    attention_mask = torch.zeros((sequence_length,), dtype=torch.int64)
    attention_mask[:-n_pads] = 1

    str_input = json.dumps(
        {
            "input_ids": inp_tokens.cpu().tolist(),
            "attention_mask": attention_mask.cpu().tolist(),
            "pre_tokenized": True,
        }
    )
    # Replicate the same sequence to form the requested number of requests.
    input_dataset = datasets.Dataset.from_dict(
        {"sentence": [str_input for _ in range(input_n_spam_artif)]}
    )

    result_vanilla = send_spam(input_dataset, address_input_vanilla)
    result_bettertransformer = send_spam(input_dataset, address_input_bettertransformer)

    return result_vanilla, result_bettertransformer


TITLE_IMAGE = """
<div
    style="
        display: block;
        margin-left: auto;
        margin-right: auto;
        width: 50%;
    "
>
    <img src="https://huggingface.co/spaces/fxmarty/bettertransformer-demo/resolve/main/header_center.png"/>
</div>
"""

with gr.Blocks() as demo:
    gr.HTML(TITLE_IMAGE)
    gr.Markdown(
        "## Speed up inference and support more workload with PyTorch's BetterTransformer 🤗"
    )

    gr.Markdown(
        """
**The two AWS instances powering this Space are offline (to save us the $$$). Feel free to reproduce the setup using [this backend code](https://github.com/fxmarty/bettertransformer_demo). The example results are from an AWS EC2 g4dn.xlarge instance with a single NVIDIA T4 GPU.**
"""
    )

    gr.Markdown(
        """
Let's try out [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/) + [TorchServe](https://pytorch.org/serve/)!

BetterTransformer is a stable feature made available with [PyTorch 1.13](https://pytorch.org/blog/PyTorch-1.13-release/) that enables a fastpath execution for encoder attention blocks. Depending on your hardware, batch size, sequence length and padding ratio, it can bring large speedups at inference **at no cost in prediction quality**. As a one-liner, you can convert your 🤗 Transformers models to use BetterTransformer thanks to the integration in the [🤗 Optimum](https://github.com/huggingface/optimum) library:

```
from optimum.bettertransformer import BetterTransformer
better_model = BetterTransformer.transform(model)
```

This Space is a demo of an **end-to-end** deployment of PyTorch eager-mode models, both with and without BetterTransformer. The goal is to see what the benefits of BetterTransformer are, server-side and client-side. The model used is [`distilbert-base-uncased-finetuned-sst-2-english`](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english), and TorchServe is configured to use a maximum batch size of 8. **Beware:** you may be queued if several people use the Space at the same time.

For more details on the TorchServe implementation and how to reproduce it, see [this reference code](https://github.com/fxmarty/bettertransformer_demo). For more details on BetterTransformer, check out the [blog post on PyTorch's Medium](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) and [the Optimum documentation](https://huggingface.co/docs/optimum/bettertransformer/overview)!"""
    )
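
    # The "maximum batch size of 8" mentioned above is configured on the TorchServe side (see the
    # linked backend repository for the actual setup). As a rough sketch, registering a model
    # archive with dynamic batching through TorchServe's management API could look like the
    # following -- the archive name and the delay value here are assumptions, not taken from
    # this app:
    #
    #   curl -X POST "http://localhost:8081/models?url=distilbert.mar&batch_size=8&max_batch_delay=100&initial_workers=1"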

    gr.Markdown(
        """## Single input scenario

Note: BetterTransformer normally shines with batch size > 1 and some padding, so this single-input case is not where it helps the most. Check out the heavy workload scenario below as well.
"""
    )

    address_input_vanilla = gr.Textbox(
        max_lines=1, label="ip vanilla", value=ADDRESS_VANILLA, visible=False
    )
    address_input_bettertransformer = gr.Textbox(
        max_lines=1,
        label="ip bettertransformer",
        value=ADDRESS_BETTERTRANSFORMER,
        visible=False,
    )

    input_model_single = gr.Textbox(
        max_lines=1,
        label="Text",
        value="Expectations were low, enjoyment was high. Although the music was not top level, the story was well-paced.",
    )

    btn_single = gr.Button("Send single text request")

    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
            output_single_vanilla = gr.Markdown(
                label="Output single vanilla",
                value=get_message_single(**defaults_vanilla_single),
            )
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")
            output_single_bt = gr.Markdown(
                label="Output single bt", value=get_message_single(**defaults_bt_single)
            )

    btn_single.click(
        fn=dispatch_single,
        inputs=[
            input_model_single,
            address_input_vanilla,
            address_input_bettertransformer,
        ],
        outputs=[output_single_vanilla, output_single_bt],
    )

    gr.Markdown(
        """
**Beware that the end-to-end latency can be affected by a difference in ping time to the two servers.**

## Heavy workload scenario
"""
    )

    input_n_spam_artif = gr.Number(
        label="Number of sequences to send",
        value=80,
    )
    sequence_length = gr.Number(
        label="Sequence length (in tokens)",
        value=128,
    )
    padding_ratio = gr.Number(
        label="Padding ratio (i.e. the fraction of the input that is padding. In the real world, when batch size > 1 the token sequences are padded with 0 so that all inputs in a batch have the same length.)",
        value=0.7,
    )
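
    # With the default values above (sequence_length=128, padding_ratio=0.7), dispatch_spam_artif
    # builds sequences with n_pads = int(0.7 * 128) = 89 padding tokens, so only 39 positions
    # ([CLS] + 37 random tokens + [SEP]) are actually attended to -- the kind of padded, batched
    # setting where BetterTransformer's handling of padding is expected to help the most.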

    btn_spam_artif = gr.Button("Spam text requests (using artificial data)")

    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### Vanilla Transformers + TorchServe")
            output_spam_vanilla_artif = gr.Markdown(
                label="Output spam vanilla",
                value=get_message_spam(**defaults_vanilla_spam),
            )
        with gr.Column(scale=50):
            gr.Markdown("### BetterTransformer + TorchServe")
            output_spam_bt_artif = gr.Markdown(
                label="Output spam bt", value=get_message_spam(**defaults_bt_spam)
            )

    btn_spam_artif.click(
        fn=dispatch_spam_artif,
        inputs=[
            input_n_spam_artif,
            sequence_length,
            padding_ratio,
            address_input_vanilla,
            address_input_bettertransformer,
        ],
        outputs=[output_spam_vanilla_artif, output_spam_bt_artif],
    )

# Process queued requests one at a time (hence the note above that you may be queued when
# several people use the Space simultaneously).
demo.queue(concurrency_count=1)
demo.launch()