# app.py for a Hugging Face Space that runs on ZeroGPU hardware.

import gradio as gr
from gradio.themes.ocean import Ocean
import torch
import numpy as np
import supervision as sv
from transformers import (
    Qwen3VLForConditionalGeneration,
    Qwen3VLProcessor,
)
import json
import ast
import re
from PIL import Image
from spaces import GPU

# --- Constants and Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"

CATEGORIES = ["Query", "Caption", "Point", "Detect"]
PLACEHOLDERS = {
    "Query": "What's in this image?",
    "Caption": "Select caption length: short, normal, or long",
    "Point": "Select an object from suggestions or enter manually",
    "Detect": "Select an object from suggestions or enter manually",
}

# --- Model Loading ---
# Load Qwen3-VL
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    dtype=DTYPE,
    device_map=DEVICE,
).eval()

qwen_processor = Qwen3VLProcessor.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
)
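
# Note: the Point/Detect prompts below assume Qwen3-VL's convention of
# reporting coordinates on a 0-1000 grid relative to the image; parsed
# values are therefore divided by 1000 before being rescaled to pixels.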

# --- Utility Functions ---
def safe_parse_json(text: str):
    """Safely parse a string that may be JSON or a Python literal."""
    text = text.strip()
    # Remove markdown code blocks
    text = re.sub(r"^```(json)?", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        # Fallback to literal_eval for Python-like dictionary/list strings
        return ast.literal_eval(text)
    except Exception:
        return {}
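
# Illustrative example (the exact model output format is an assumption):
#   safe_parse_json('```json\n[{"point_2d": [500, 500]}]\n```')
#   -> [{"point_2d": [500, 500]}]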

# --- Inference Functions ---
@GPU  # Request a ZeroGPU device for the duration of this call
def run_qwen_inference(image: Image.Image, prompt: str):
    """Core function to run inference with the Qwen model."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = qwen_processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.inference_mode():
        generated_ids = qwen_model.generate(
            **inputs,
            max_new_tokens=512,
        )

    # Trim the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return output_text
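
# Example usage (illustrative): run_qwen_inference(image, "Describe this image.")
# returns the decoded reply text with the prompt tokens already trimmed.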

def get_suggested_objects(image: Image.Image):
    """Get suggested objects in the image using Qwen."""
    if image is None:
        return []
    try:
        # Resize image for faster suggestion generation
        suggest_image = image.copy()
        suggest_image.thumbnail((512, 512))
        prompt = "List the main objects in the image in a Python list format. For example: ['cat', 'dog', 'table']"
        result_text = run_qwen_inference(suggest_image, prompt)
        # Clean up the output to find the list
        match = re.search(r'\[.*?\]', result_text)
        if match:
            suggested_objects = ast.literal_eval(match.group())
            if isinstance(suggested_objects, list):
                # Return up to 3 suggestions
                return suggested_objects[:3]
        return []
    except Exception as e:
        print(f"Error getting suggestions with Qwen: {e}")
        return []

def annotate_image(image: Image.Image, result: dict):
    """Annotates the image with points or bounding boxes based on model output."""
    if not isinstance(image, Image.Image) or not isinstance(result, dict):
        return image

    original_width, original_height = image.size
    scene_np = np.array(image.copy())

    # Handle Point annotations
    if "points" in result and result["points"]:
        points_list = []
        for point in result.get("points", []):
            # Scale normalized coordinates back to pixel space
            x = int(point["x"] * original_width)
            y = int(point["y"] * original_height)
            points_list.append([x, y])
        if not points_list:
            return image
        # sv.KeyPoints expects shape (num_objects, num_keypoints, 2)
        points_array = np.array(points_list).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        annotated_image_np = vertex_annotator.annotate(
            scene=scene_np, key_points=key_points
        )
        return Image.fromarray(annotated_image_np)

    # Handle Detection annotations
    if "objects" in result and result["objects"]:
        boxes = []
        for obj in result["objects"]:
            x_min = obj["x_min"] * original_width
            y_min = obj["y_min"] * original_height
            x_max = obj["x_max"] * original_width
            y_max = obj["y_max"] * original_height
            boxes.append([x_min, y_min, x_max, y_max])
        if not boxes:
            return image
        detections = sv.Detections(xyxy=np.array(boxes))
        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=4)
        annotated_image_np = box_annotator.annotate(
            scene=scene_np, detections=detections
        )
        return Image.fromarray(annotated_image_np)

    return image
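
# process_qwen returns a pair: the text shown in the UI and a dict of
# normalized (0-1) coordinates that annotate_image scales to pixel space.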

def process_qwen(image: Image.Image, category: str, prompt: str):
    """Processes the input based on the selected category using the Qwen model."""
    if category == "Query":
        return run_qwen_inference(image, prompt), {}

    elif category == "Caption":
        full_prompt = f"Provide a {prompt} length caption for the image."
        return run_qwen_inference(image, full_prompt), {}

    elif category == "Point":
        full_prompt = (
            f"Provide 2d point coordinates for {prompt}. Report in JSON format like "
            '[{"point_2d": [x, y]}] where coordinates are from 0 to 1000.'
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        points_result = {"points": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "point_2d" in item and len(item["point_2d"]) == 2:
                    x, y = item["point_2d"]
                    # Normalize from the 0-1000 grid to 0-1
                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
        return json.dumps(points_result, indent=2), points_result

    elif category == "Detect":
        full_prompt = (
            f"Provide bounding box coordinates for {prompt}. Report in JSON format like "
            '[{"bbox_2d": [xmin, ymin, xmax, ymax]}] where coordinates are from 0 to 1000.'
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        objects_result = {"objects": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
                    xmin, ymin, xmax, ymax = item["bbox_2d"]
                    objects_result["objects"].append(
                        {
                            "x_min": xmin / 1000.0,
                            "y_min": ymin / 1000.0,
                            "x_max": xmax / 1000.0,
                            "y_max": ymax / 1000.0,
                        }
                    )
        return json.dumps(objects_result, indent=2), objects_result

    return "Invalid category", {}

# --- Gradio Interface Logic ---
def on_category_and_image_change(image, category):
    """Generate suggestions when category or image changes."""
    text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)
    if category == "Caption":
        return (
            gr.Radio(
                choices=["short", "normal", "long"],
                label="Caption Length",
                value="normal",
                visible=True,
            ),
            text_box,
        )
    if image is None or category not in ["Point", "Detect"]:
        return gr.Radio(choices=[], visible=False), text_box

    suggestions = get_suggested_objects(image)
    if suggestions:
        return (
            gr.Radio(choices=suggestions, label="Suggestions", visible=True, interactive=True),
            text_box,
        )
    else:
        return gr.Radio(choices=[], visible=False), text_box


def update_prompt_from_radio(selected_object):
    """Update prompt textbox when a radio option is selected."""
    if selected_object:
        return gr.Textbox(value=selected_object)
    return gr.Textbox(value="")

def process_inputs(image, category, prompt):
    """Main function to handle the user's request."""
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        if category == "Caption":
            # Caption can fall back to a default length when none is selected
            prompt = "normal"
        else:
            raise gr.Error("Please provide a prompt or select a suggestion.")

    # Resize the image to make inference quicker
    image.thumbnail((1024, 1024))

    # Process with Qwen
    qwen_text, qwen_data = process_qwen(image, category, prompt)
    qwen_annotated_image = annotate_image(image, qwen_data)
    return qwen_annotated_image, qwen_text

# --- Gradio UI Layout ---
with gr.Blocks(theme=Ocean()) as demo:
    gr.Markdown("# 👓 Object Understanding with Qwen3-VL")
    gr.Markdown(
        "### Explore object detection, visual grounding, and keypoint detection through natural language prompts."
    )
    gr.Markdown("""
    *Powered by [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")
            category_select = gr.Radio(
                choices=CATEGORIES,
                value=CATEGORIES[0],
                label="Select Task Category",
                interactive=True,
            )
            suggestions_radio = gr.Radio(
                choices=[],
                label="Suggestions",
                visible=False,
                interactive=True,
            )
            prompt_input = gr.Textbox(
                placeholder=PLACEHOLDERS[CATEGORIES[0]],
                label="Prompt",
                lines=2,
            )
            submit_btn = gr.Button("Process Image", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Qwen/Qwen3-VL-4B-Instruct Output")
            qwen_img_output = gr.Image(label="Annotated Image")
            qwen_text_output = gr.Textbox(
                label="Text Output", lines=10, interactive=False
            )

    gr.Examples(
        examples=[
            ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
            ["examples/example_1.jpg", "Detect", "car"],
            ["examples/example_2.JPG", "Point", "the person's face"],
            ["examples/example_2.JPG", "Caption", "short"],
        ],
        inputs=[image_input, category_select, prompt_input],
    )

    # --- Event Listeners ---
    # When image or category changes, update suggestions
    category_select.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
    image_input.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )

    # When a suggestion is clicked, update the prompt box
    suggestions_radio.change(
        fn=update_prompt_from_radio,
        inputs=[suggestions_radio],
        outputs=[prompt_input],
    )

    # Main submission action
    submit_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[qwen_img_output, qwen_text_output],
    )

if __name__ == "__main__":
    demo.launch(debug=True)