import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
from threading import Thread
import re
import time
import torch
import spaces
import subprocess

subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

from io import BytesIO

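# Load the SmolVLM2 processor and model once at startup.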
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    _attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda:0")

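# Streaming chat handler for the Gradio ChatInterface: builds SmolVLM2 chat messages
# from the multimodal input and prior history, then yields text as it is generated.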
@spaces.GPU
def model_inference(input_dict, history, max_tokens):
    text = input_dict["text"]
    images = []
    user_content = []
    media_queue = []

    if history == []:
        text = input_dict["text"].strip()

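        # Queue uploaded files as image or video content blocks based on file extension.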
        for file in input_dict.get("files", []):
            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                media_queue.append({"type": "image", "path": file})
            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                media_queue.append({"type": "video", "path": file})

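        # If the text contains <image>/<video> placeholders, interleave the queued media
        # at those positions; otherwise treat the whole message as plain text. Any media
        # not referenced by a placeholder is appended at the end of the turn.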
if "<image>" in text or "<video>" in text: |
|
|
parts = re.split(r'(<image>|<video>)', text) |
|
|
for part in parts: |
|
|
if part == "<image>" and media_queue: |
|
|
user_content.append(media_queue.pop(0)) |
|
|
elif part == "<video>" and media_queue: |
|
|
user_content.append(media_queue.pop(0)) |
|
|
elif part.strip(): |
|
|
user_content.append({"type": "text", "text": part.strip()}) |
|
|
else: |
|
|
user_content.append({"type": "text", "text": text}) |
|
|
|
|
|
for media in media_queue: |
|
|
user_content.append(media) |
|
|
|
|
|
resulting_messages = [{"role": "user", "content": user_content}] |
|
|
|
|
|
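    # With prior history, rebuild the conversation: first collect previously uploaded
    # media, then re-pair each user turn's text and media with the assistant reply
    # that followed it.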
    elif len(history) > 0:
        resulting_messages = []
        user_content = []
        media_queue = []

        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], tuple):
                file_name = hist["content"][0]
                if file_name.endswith((".png", ".jpg", ".jpeg")):
                    media_queue.append({"type": "image", "path": file_name})
                elif file_name.endswith(".mp4"):
                    media_queue.append({"type": "video", "path": file_name})

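        # Second pass: attach text (with placeholders resolved) to each user turn and
        # close the turn when the matching assistant reply is reached.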
        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], str):
                text = hist["content"]
                parts = re.split(r'(<image>|<video>)', text)

                for part in parts:
                    if part == "<image>" and media_queue:
                        user_content.append(media_queue.pop(0))
                    elif part == "<video>" and media_queue:
                        user_content.append(media_queue.pop(0))
                    elif part.strip():
                        user_content.append({"type": "text", "text": part.strip()})

            elif hist["role"] == "assistant":
                resulting_messages.append({
                    "role": "user",
                    "content": user_content
                })
                resulting_messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": hist["content"]}]
                })
                user_content = []

if text == "" and not images: |
|
|
gr.Error("Please input a query and optionally image(s).") |
|
|
|
|
|
if text == "" and images: |
|
|
gr.Error("Please input a text query along the images(s).") |
|
|
print("resulting_messages", resulting_messages) |
|
|
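    # Apply the chat template and tokenize text and media in a single call.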
    inputs = processor.apply_chat_template(
        resulting_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

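    # Run generation in a background thread so tokens can be yielded to the UI as they stream in.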
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
    generated_text = ""

    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    yield "..."
    buffer = ""

    for new_text in streamer:
        buffer += new_text
        generated_text_without_prompt = buffer
        time.sleep(0.01)
        yield buffer


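# Example prompts shown in the UI; <image> placeholders refer to the attached files in order.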
examples = [
    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
    [{"text": "Which art era do this artpiece <image> and this artpiece <image> belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}],
    [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
    [{"text": "When was this purchase made and how much did it cost?", "files": ["example_images/fiche.jpg"]}],
    [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
    [{"text": "What is happening in the video?", "files": ["example_images/short.mp4"]}],
]

demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM2: The Smollest Video Model Ever 📺",
    description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
    type="messages",
)

demo.launch(debug=True)