import gradio as gr
from gradio.themes.ocean import Ocean
import torch
import numpy as np
import supervision as sv
from transformers import (
    Qwen3VLForConditionalGeneration,
    Qwen3VLProcessor,
)
import json
import ast
import re
from PIL import Image
from spaces import GPU

# --- Constants and Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"

CATEGORIES = ["Query", "Caption", "Point", "Detect"]
PLACEHOLDERS = {
    "Query": "What's in this image?",
    "Caption": "Select caption length: short, normal, or long",
    "Point": "Select an object from suggestions or enter manually",
    "Detect": "Select an object from suggestions or enter manually",
}

# --- Model Loading ---
# Load Qwen3-VL
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    dtype=DTYPE,
    device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
)


# --- Utility Functions ---
def safe_parse_json(text: str):
    """Safely parse a string that may be JSON or a Python literal."""
    text = text.strip()
    # Remove markdown code fences
    text = re.sub(r"^```(json)?", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        # Fallback to literal_eval for Python-like dictionary/list strings
        return ast.literal_eval(text)
    except Exception:
        return {}
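
# Illustrative inputs for safe_parse_json (a sketch of the two fallback paths,
# not exhaustive): fenced JSON goes through json.loads, while Python-style
# literals fall back to ast.literal_eval.
#   safe_parse_json('```json\n[{"point_2d": [500, 500]}]\n```')  # -> [{"point_2d": [500, 500]}]
#   safe_parse_json("[{'point_2d': [500, 500]}]")                # -> [{'point_2d': [500, 500]}]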


# --- Inference Functions ---
def run_qwen_inference(image: Image.Image, prompt: str):
    """Core function to run inference with the Qwen model."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = qwen_processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.inference_mode():
        generated_ids = qwen_model.generate(
            **inputs,
            max_new_tokens=512,
        )

    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return output_text


@GPU
def get_suggested_objects(image: Image.Image):
    """Get suggested objects in the image using Qwen."""
    if image is None:
        return []
    try:
        # Resize image for faster suggestion generation
        suggest_image = image.copy()
        suggest_image.thumbnail((512, 512))
        prompt = (
            "List the main objects in the image in a Python list format. "
            "For example: ['cat', 'dog', 'table']"
        )
        result_text = run_qwen_inference(suggest_image, prompt)

        # Clean up the output to find the list
        match = re.search(r"\[.*?\]", result_text)
        if match:
            suggested_objects = ast.literal_eval(match.group())
            if isinstance(suggested_objects, list):
                # Return up to 3 suggestions
                return suggested_objects[:3]
        return []
    except Exception as e:
        print(f"Error getting suggestions with Qwen: {e}")
        return []


def annotate_image(image: Image.Image, result: dict):
    """Annotates the image with points or bounding boxes based on model output."""
    if not isinstance(image, Image.Image) or not isinstance(result, dict):
        return image

    original_width, original_height = image.size
    scene_np = np.array(image.copy())

    # Handle Point annotations
    if "points" in result and result["points"]:
        points_list = []
        for point in result.get("points", []):
            x = int(point["x"] * original_width)
            y = int(point["y"] * original_height)
            points_list.append([x, y])

        if not points_list:
            return image

        # sv.KeyPoints expects an array of shape (num_objects, num_keypoints, 2)
        points_array = np.array(points_list, dtype=np.float32).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)

        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        annotated_image_np = vertex_annotator.annotate(
            scene=scene_np, key_points=key_points
        )
        return Image.fromarray(annotated_image_np)

    # Handle Detection annotations
    if "objects" in result and result["objects"]:
        boxes = []
        for obj in result["objects"]:
            x_min = obj["x_min"] * original_width
            y_min = obj["y_min"] * original_height
            x_max = obj["x_max"] * original_width
            y_max = obj["y_max"] * original_height
            boxes.append([x_min, y_min, x_max, y_max])

        if not boxes:
            return image

        detections = sv.Detections(xyxy=np.array(boxes))
        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=4)
        annotated_image_np = box_annotator.annotate(
            scene=scene_np, detections=detections
        )
        return Image.fromarray(annotated_image_np)

    return image
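
# For reference, annotate_image consumes the normalized dictionaries produced by
# process_qwen below (illustrative values only; coordinates are fractions of the
# image width/height):
#   Point:  {"points": [{"x": 0.42, "y": 0.31}]}
#   Detect: {"objects": [{"x_min": 0.10, "y_min": 0.20, "x_max": 0.55, "y_max": 0.80}]}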


@GPU
def process_qwen(image: Image.Image, category: str, prompt: str):
    """Processes the input based on the selected category using the Qwen model."""
    if category == "Query":
        return run_qwen_inference(image, prompt), {}

    elif category == "Caption":
        full_prompt = f"Provide a {prompt} length caption for the image."
        return run_qwen_inference(image, full_prompt), {}

    elif category == "Point":
        full_prompt = (
            f"Provide 2d point coordinates for {prompt}. Report in JSON format like "
            '[{"point_2d": [x, y]}]'
            " where coordinates are from 0 to 1000."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)

        # Normalize the 0-1000 coordinate range to 0-1 fractions
        points_result = {"points": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "point_2d" in item and len(item["point_2d"]) == 2:
                    x, y = item["point_2d"]
                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
        return json.dumps(points_result, indent=2), points_result

    elif category == "Detect":
        full_prompt = (
            f"Provide bounding box coordinates for {prompt}. Report in JSON format like "
            '[{"bbox_2d": [xmin, ymin, xmax, ymax]}]'
            " where coordinates are from 0 to 1000."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)

        # Normalize the 0-1000 coordinate range to 0-1 fractions
        objects_result = {"objects": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
                    xmin, ymin, xmax, ymax = item["bbox_2d"]
                    objects_result["objects"].append(
                        {
                            "x_min": xmin / 1000.0,
                            "y_min": ymin / 1000.0,
                            "x_max": xmax / 1000.0,
                            "y_max": ymax / 1000.0,
                        }
                    )
        return json.dumps(objects_result, indent=2), objects_result

    return "Invalid category", {}


# --- Gradio Interface Logic ---
def on_category_and_image_change(image, category):
    """Generate suggestions when category or image changes."""
    text_box = gr.Textbox(
        value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True
    )

    if category == "Caption":
        return gr.Radio(
            choices=["short", "normal", "long"],
            label="Caption Length",
            value="normal",
            visible=True,
        ), text_box

    if image is None or category not in ["Point", "Detect"]:
        return gr.Radio(choices=[], visible=False), text_box

    suggestions = get_suggested_objects(image)
    if suggestions:
        return gr.Radio(
            choices=suggestions, label="Suggestions", visible=True, interactive=True
        ), text_box
    else:
        return gr.Radio(choices=[], visible=False), text_box


def update_prompt_from_radio(selected_object):
    """Update prompt textbox when a radio option is selected."""
    if selected_object:
        return gr.Textbox(value=selected_object)
    return gr.Textbox(value="")


def process_inputs(image, category, prompt):
    """Main function to handle the user's request."""
    if image is None:
        raise gr.Error("Please upload an image.")

    if not prompt:
        if category == "Caption":
            # Caption can run without a prompt; default to a normal-length caption
            prompt = "normal"
        else:
            raise gr.Error("Please provide a prompt or select a suggestion.")

    # Resize the image to make inference quicker
    image.thumbnail((1024, 1024))

    # Process with Qwen
    qwen_text, qwen_data = process_qwen(image, category, prompt)
    qwen_annotated_image = annotate_image(image, qwen_data)

    return qwen_annotated_image, qwen_text
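
# The full pipeline can also be exercised without the UI (a minimal sketch that
# reuses one of the bundled example images; kept commented out so it does not
# run on import, and the output filename is only illustrative):
#
#   img = Image.open("examples/example_1.jpg")
#   annotated, text = process_inputs(img, "Detect", "car")
#   annotated.save("annotated_example.jpg")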

# --- Gradio UI Layout ---
with gr.Blocks(theme=Ocean()) as demo:
    gr.Markdown("# 👓 Object Understanding with Qwen3-VL")
    gr.Markdown(
        "### Explore object detection, visual grounding, and keypoint detection through natural language prompts."
    )
    gr.Markdown("""
*Powered by [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
""")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")
            category_select = gr.Radio(
                choices=CATEGORIES,
                value=CATEGORIES[0],
                label="Select Task Category",
                interactive=True,
            )
            suggestions_radio = gr.Radio(
                choices=[],
                label="Suggestions",
                visible=False,
                interactive=True,
            )
            prompt_input = gr.Textbox(
                placeholder=PLACEHOLDERS[CATEGORIES[0]],
                label="Prompt",
                lines=2,
            )
            submit_btn = gr.Button("Process Image", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Qwen/Qwen3-VL-4B-Instruct Output")
            qwen_img_output = gr.Image(label="Annotated Image")
            qwen_text_output = gr.Textbox(
                label="Text Output", lines=10, interactive=False
            )

    gr.Examples(
        examples=[
            ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
            ["examples/example_1.jpg", "Detect", "car"],
            ["examples/example_2.JPG", "Point", "the person's face"],
            ["examples/example_2.JPG", "Caption", "short"],
        ],
        inputs=[image_input, category_select, prompt_input],
    )

    # --- Event Listeners ---
    # When the image or category changes, update the suggestions
    category_select.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
    image_input.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )

    # When a suggestion is clicked, update the prompt box
    suggestions_radio.change(
        fn=update_prompt_from_radio,
        inputs=[suggestions_radio],
        outputs=[prompt_input],
    )

    # Main submission action
    submit_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[qwen_img_output, qwen_text_output],
    )

if __name__ == "__main__":
    demo.launch(debug=True)