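# app.py — Gradio demo for TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B:
# SmolVLM2 paired with Qwen3-0.6B for Chinese visual question answering.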
import gradio as gr
from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer
from threading import Thread
import re
import time
from PIL import Image
import torch
import spaces
from processor import SmolVLMQwen3Processor
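# Register the custom processor class so AutoProcessor can resolve it when loading this checkpoint.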
AutoProcessor.register("SmolVLMQwen3Processor", SmolVLMQwen3Processor)
processor = AutoProcessor.from_pretrained("TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B")
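# Load the weights in bfloat16; flash-attention and the explicit .to("cuda") move are left optional below.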
model = Idefics3ForConditionalGeneration.from_pretrained("TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B",
torch_dtype=torch.bfloat16,
# _attn_implementation="flash_attention_2"
) # .to("cuda")
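# On a ZeroGPU Space, @spaces.GPU allocates a GPU for the duration of each call to this function.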
@spaces.GPU
def model_inference(
input_dict, history, decoding_strategy, temperature, max_new_tokens,
repetition_penalty, top_p
):
    text = input_dict["text"]
    # Open every attached file as an RGB PIL image (empty list when nothing was attached).
    images = [Image.open(path).convert("RGB") for path in input_dict["files"]]
    if not images and history:
        # Nothing attached this turn: reuse the most recent image(s) from the chat history.
        for turn in reversed(history):
            user_content, _ = turn  # (user message, assistant reply); image turns store a tuple of file paths
            if isinstance(user_content, tuple) and len(user_content) > 0:
                images = [Image.open(path).convert("RGB") for path in user_content]
                break
if text == "" and not images:
gr.Error("Please input a query and optionally image(s).")
if text == "" and images:
gr.Error("Please input a text query along the image(s).")
    # Build the chat message: a brief-answer instruction ("简短回复问题." = "Answer the question briefly."),
    # one image placeholder per attached image, then the user's query.
    resulting_messages = [
        {
            "role": "user",
            "content": (
                [{"type": "text", "text": "简短回复问题."}]
                + [{"type": "image"} for _ in images]
                + [{"type": "text", "text": text}]
            ),
        }
    ]
prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[images] if images else None, return_tensors="pt")
# inputs = {k: v.to("cuda") for k, v in inputs.items()}
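    # Base generation settings; strategy-specific sampling flags are added below.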
generation_args = {
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
"num_return_sequences": 1,
"no_repeat_ngram_size": 2,
"min_new_tokens": 16,
}
assert decoding_strategy in [
"Greedy",
"Top P Sampling",
]
if decoding_strategy == "Greedy":
generation_args["do_sample"] = False
elif decoding_strategy == "Top P Sampling":
generation_args["temperature"] = temperature
generation_args["do_sample"] = True
generation_args["top_p"] = top_p
    generation_args.update(inputs)

    # Stream tokens as they are generated; model.generate runs in a background thread.
    # Note: generation_args must not be rebuilt here, or the decoding settings above would be dropped.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args["streamer"] = streamer
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()
    yield "..."  # placeholder shown until the first tokens arrive
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    thread.join()
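# Example queries (in Chinese), each paired with [decoding strategy, temperature, max_new_tokens,
# repetition_penalty, top_p] matching the additional inputs below.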
examples = [
[{"text": "图中的小女孩穿着什么颜色的上衣?",
"files": ["example_images/objects365_v1_00322846.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
[{"text": "请描述这张图片的内容,并检测其中的苹果",
"files": ["example_images/objects365_v1_00361740.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
[{"text": "图中是什么交通工具?",
"files": ["example_images/objects365_v1_00357438.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
[{"text": "图中有几只鸭子?",
"files": ["example_images/objects365_v1_00323167.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
[{"text": "这是在哪?",
"files": ["example_images/objects365_v1_00363692.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
]
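# Chat UI: a multimodal textbox for text + images, with the decoding controls exposed as additional inputs.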
demo = gr.ChatInterface(
fn=model_inference,
title="SmolVLM2-256M-Married-Qwen3-0.6B: SmolVLM拥抱Qwen3,支持中文问答🤖",
description="[TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B](https://huggingface.co/TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B) 演示。请上传图片和文本,或尝试下方示例。",
examples=examples,
textbox=gr.MultimodalTextbox(
label="请输入查询文本(附带图片)",
file_types=["image"],
file_count="multiple"
),
stop_btn="停止生成",
multimodal=True,
additional_inputs=[
gr.Radio(
["Top P Sampling", "Greedy"],
value="Greedy",
label="解码策略",
info="选择生成文本的方式:采样更随机,贪心更确定。"
),
gr.Slider(
minimum=0.0,
maximum=5.0,
value=0.4,
step=0.1,
interactive=True,
label="采样温度 (Temperature)",
info="数值越高,输出越多样化;越低则更保守。"
),
gr.Slider(
minimum=8,
maximum=1024,
value=512,
step=1,
interactive=True,
label="最大生成 Token 数",
),
gr.Slider(
minimum=0.01,
maximum=5.0,
value=1.2,
step=0.01,
interactive=True,
label="重复惩罚 (Repetition penalty)",
info="1.0 表示不做惩罚;数值越大越避免重复。"
),
gr.Slider(
minimum=0.01,
maximum=0.99,
value=0.8,
step=0.01,
interactive=True,
label="Top P",
info="数值越高,表示会采样更多低概率的 token。"
),
],
cache_examples=False
)
demo.launch(debug=True)