|
|
import gradio as gr |
|
|
from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer |
|
|
from threading import Thread |
|
|
import re |
|
|
import time |
|
|
from PIL import Image |
|
|
import torch |
|
|
import spaces |
|
|
|
|
|
from processor import SmolVLMQwen3Processor |
|
|
|
|
|
# Hub repository id used for both the processor and the model weights.
MODEL_ID = "TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B"

# Register the project-local processor class with AutoProcessor under the
# name "SmolVLMQwen3Processor" before loading from the checkpoint.
AutoProcessor.register("SmolVLMQwen3Processor", SmolVLMQwen3Processor)

processor = AutoProcessor.from_pretrained(MODEL_ID)

# Weights are loaded in bfloat16; no device placement is requested here.
model = Idefics3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)
|
|
|
|
|
|
|
|
def _open_rgb_images(paths):
    """Open every path in *paths* with PIL and convert it to RGB."""
    return [Image.open(path).convert("RGB") for path in paths]


@spaces.GPU
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    """Stream the model's answer to a multimodal chat message.

    Args:
        input_dict: Message from the multimodal textbox; read through its
            ``"text"`` and ``"files"`` keys.
        history: Previous chat turns. Each turn is unpacked as a 2-item
            pair; when the current message has no files, the most recent
            turn whose first item is a non-empty tuple supplies the images.
        decoding_strategy: ``"Greedy"`` or ``"Top P Sampling"``.
        temperature: Sampling temperature (used only when sampling).
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Repetition penalty forwarded to ``generate``.
        top_p: Nucleus-sampling threshold (used only when sampling).

    Yields:
        ``"..."`` as an initial placeholder, then the accumulated generated
        text after every chunk produced by the streamer.

    Raises:
        gr.Error: If the text query is empty.
        ValueError: If ``decoding_strategy`` is not a supported value.
    """
    text = input_dict["text"]
    print(input_dict["files"])

    images = _open_rgb_images(input_dict["files"])

    # No image in this message: reuse the images of the latest turn that
    # carried files.
    if not images and history:
        for turn in reversed(history):
            files, _ = turn
            if isinstance(files, tuple) and len(files) > 0:
                images = _open_rgb_images(files)
                break

    # gr.Error only reaches the UI when it is raised; merely constructing
    # it (as before) let empty queries fall through to generation.
    if text == "":
        if images:
            raise gr.Error("Please input a text query along the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # One image placeholder per supplied image so the prompt matches the
    # number of images handed to the processor.
    content = [{"type": "text", "text": "简短回复问题."}]
    content.extend({"type": "image"} for _ in images)
    content.append({"type": "text", "text": text})
    resulting_messages = [{"role": "user", "content": content}]

    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    # NOTE(review): text-only requests pass images=None, assuming the custom
    # processor follows the usual Idefics3 convention -- confirm.
    inputs = processor(
        text=prompt,
        images=[images] if images else None,
        return_tensors="pt",
    )

    max_new_tokens = int(max_new_tokens)
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "num_return_sequences": 1,
        "no_repeat_ngram_size": 2,
        # Never demand more new tokens than the caller allows.
        "min_new_tokens": min(16, max_new_tokens),
    }

    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        if temperature > 0:
            generation_args["do_sample"] = True
            generation_args["temperature"] = temperature
            generation_args["top_p"] = top_p
        else:
            # Sampling needs a strictly positive temperature; treat 0 as
            # the deterministic limit instead of failing inside generate().
            generation_args["do_sample"] = False
    else:
        raise ValueError(f"Unsupported decoding strategy: {decoding_strategy!r}")

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Extend the arguments built above rather than replacing them, so the
    # decoding controls actually reach model.generate().
    generation_args.update(inputs)
    generation_args["streamer"] = streamer

    # generate() blocks, so it runs in a worker thread while this generator
    # relays text from the streamer.
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    yield "..."
    buffer = ""

    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer

    thread.join()
|
|
|
|
|
|
|
|
# Values for the extra inputs, identical for every example row: decoding
# strategy, temperature, max new tokens, repetition penalty, top-p.
_EXAMPLE_SETTINGS = ("Greedy", 0.4, 512, 1.2, 0.8)

# (question, image file name) pairs; the files live under example_images/.
_EXAMPLE_QUERIES = (
    ("图中的小女孩穿着什么颜色的上衣?", "objects365_v1_00322846.jpg"),
    ("请描述这张图片的内容,并检测其中的苹果", "objects365_v1_00361740.jpg"),
    ("图中是什么交通工具?", "objects365_v1_00357438.jpg"),
    ("图中有几只鸭子?", "objects365_v1_00323167.jpg"),
    ("这是在哪?", "objects365_v1_00363692.jpg"),
)

# Each row: the multimodal message dict followed by the shared settings.
examples = [
    [{"text": question, "files": [f"example_images/{filename}"]}, *_EXAMPLE_SETTINGS]
    for question, filename in _EXAMPLE_QUERIES
]
|
|
# Multimodal input box: free text plus any number of image attachments.
query_box = gr.MultimodalTextbox(
    label="请输入查询文本(附带图片)",
    file_types=["image"],
    file_count="multiple",
)

# Extra controls, declared in the order model_inference receives them after
# the message and the history.
decoding_radio = gr.Radio(
    ["Top P Sampling", "Greedy"],
    value="Greedy",
    label="解码策略",
    info="选择生成文本的方式:采样更随机,贪心更确定。",
)
temperature_slider = gr.Slider(
    minimum=0.0,
    maximum=5.0,
    value=0.4,
    step=0.1,
    interactive=True,
    label="采样温度 (Temperature)",
    info="数值越高,输出越多样化;越低则更保守。",
)
max_tokens_slider = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="最大生成 Token 数",
)
repetition_slider = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.2,
    step=0.01,
    interactive=True,
    label="重复惩罚 (Repetition penalty)",
    info="1.0 表示不做惩罚;数值越大越避免重复。",
)
top_p_slider = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    interactive=True,
    label="Top P",
    info="数值越高,表示会采样更多低概率的 token。",
)

# Chat UI wired to model_inference; example caching is switched off.
demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM2-256M-Married-Qwen3-0.6B: SmolVLM拥抱Qwen3,支持中文问答🤖",
    description="[TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B](https://huggingface.co/TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B) 演示。请上传图片和文本,或尝试下方示例。",
    examples=examples,
    textbox=query_box,
    stop_btn="停止生成",
    multimodal=True,
    additional_inputs=[
        decoding_radio,
        temperature_slider,
        max_tokens_slider,
        repetition_slider,
        top_p_slider,
    ],
    cache_examples=False,
)
|
|
|
|
|
# Start the Gradio server only when this file is executed as a script, so
# importing the module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)
|
|
|