"""Gradio app comparing multimodal models on an online image or PDF.

The user supplies a URL (image or PDF) and a prompt; every model in
``MODELS`` answers the prompt about that input and the per-model inference
latency is shown as a bar chart.
"""

import io
import time
from urllib.parse import urlparse

import fitz  # PyMuPDF for PDF support
import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Display name -> Hugging Face repository ID. The UI creates one output box
# per entry, in this order.
# NOTE(review): confirm each ID loads with AutoProcessor/AutoModel; some
# repositories may ship weights in a non-transformers format.
MODELS = {
    "Pixtral-12B": "mistralai/Pixtral-12B-2409",
    "InternVL-3.5": "OpenGVLab/InternVL3_5-241B-A28B",
    "Aria-7B": "Aria-7B",  # Replace with actual model ID when public
}

# (processor, model) pairs keyed by repository ID, so each model is loaded
# at most once per process.
MODEL_CACHE = {}

# Seconds to wait on a remote server; without a timeout a stalled server
# would block the request indefinitely.
REQUEST_TIMEOUT = 30

# Wikimedia's User-Agent policy asks clients to identify themselves; generic
# library User-Agents may be blocked by some hosts.
HTTP_HEADERS = {"User-Agent": "multimodal-model-comparator/1.0 (gradio demo)"}

# Bounding box the input image is shrunk to (aspect ratio kept) before inference.
MAX_IMAGE_SIZE = (512, 512)

# Upper bound on tokens produced by models answered via ``generate``.
MAX_NEW_TOKENS = 128

# PDF files begin with this signature (the spec allows it within the first 1 KiB).
_PDF_MAGIC = b"%PDF-"


def load_model(model_id):
    """Return the cached ``(processor, model)`` pair for *model_id*.

    The pair is loaded from the Hugging Face Hub on first use and memoized in
    ``MODEL_CACHE``.
    """
    if model_id not in MODEL_CACHE:
        # trust_remote_code runs Python shipped in the model repository; this
        # is acceptable only because the IDs come from the hard-coded MODELS.
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        # float16 is poorly supported for CPU inference, so use it only on GPU.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map="auto",
        )
        MODEL_CACHE[model_id] = (processor, model)
    return MODEL_CACHE[model_id]


def convert_pdf_to_image(pdf_bytes, page_number=0, dpi=150):
    """Render one page of a PDF to an RGB PIL image.

    Args:
        pdf_bytes: Raw PDF file content.
        page_number: Zero-based index of the page to render (default: first).
        dpi: Rendering resolution.

    Returns:
        The rendered page as an RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If the document has no page ``page_number``.
    """
    # The context manager closes the document even if rendering fails.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf_doc:
        if not 0 <= page_number < pdf_doc.page_count:
            raise ValueError(
                f"PDF has {pdf_doc.page_count} page(s); cannot render page {page_number}"
            )
        pix = pdf_doc.load_page(page_number).get_pixmap(dpi=dpi)
        image_bytes = pix.tobytes("png")
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _download(url):
    """Return the body of *url* as bytes.

    Raises:
        ValueError: If the server answers with a status other than 200.
        requests.RequestException: On network errors or timeouts.
    """
    response = requests.get(url, headers=HTTP_HEADERS, timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        raise ValueError(f"Failed to load {url} (HTTP {response.status_code})")
    return response.content


def load_image_from_url(url):
    """Download *url* and decode it as an RGB PIL image.

    Raises:
        ValueError: On a non-200 HTTP status.
        requests.RequestException: On network errors or timeouts.
        PIL.UnidentifiedImageError: If the content is not a decodable image.
    """
    # RGB conversion: palette/RGBA images (e.g. PNG logos) trip up many
    # vision processors.
    return Image.open(io.BytesIO(_download(url))).convert("RGB")


def _load_input_image(url):
    """Download *url* and return it as an RGB image, rendering page 1 of PDFs."""
    content = _download(url)
    # Sniff the bytes as well as the URL path: PDF links often lack a ".pdf"
    # suffix (e.g. arXiv "/pdf/<id>") or carry a query string.
    if _PDF_MAGIC in content[:1024] or urlparse(url).path.lower().endswith(".pdf"):
        return convert_pdf_to_image(content)
    return Image.open(io.BytesIO(content)).convert("RGB")


def _run_model(model_id, image, prompt):
    """Answer *prompt* about *image* with one model; return ``(text, seconds)``.

    Only inference is timed; the (possibly slow, first-call) model load is not.
    """
    processor, model = load_model(model_id)
    start = time.perf_counter()
    with torch.no_grad():
        if hasattr(model, "chat"):
            # NOTE(review): remote-code ``chat`` signatures differ between model
            # families; confirm this call matches every model exposing one.
            text = model.chat(processor.tokenizer, image=image, query=prompt)
        else:
            # Keyword arguments: processors disagree on positional order
            # (several take ``images`` first). Inputs go wherever
            # device_map="auto" placed the model rather than a fixed device.
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
            generated = outputs[0]
            if not getattr(model.config, "is_encoder_decoder", False):
                # Decoder-only models return prompt + completion; keep the completion.
                generated = generated[inputs["input_ids"].shape[-1]:]
            text = processor.decode(generated, skip_special_tokens=True)
    return text, time.perf_counter() - start


def compare_models(input_url, prompt):
    """Ask every model in ``MODELS`` the same *prompt* about the input at *input_url*.

    Args:
        input_url: URL of an image or PDF (the first PDF page is used).
        prompt: Question to ask about the input.

    Returns:
        ``(outputs, latency_data)`` where ``outputs`` is a list with one
        display string per model in ``MODELS`` order (answer or error), and
        ``latency_data`` maps model name -> inference seconds for the models
        that succeeded (empty when nothing ran).
    """
    if not (input_url and input_url.strip()) or not (prompt and prompt.strip()):
        return ["Please provide both image/PDF URL and prompt." for _ in MODELS], {}

    try:
        image = _load_input_image(input_url.strip())
    except Exception as err:  # UI boundary: report the problem instead of crashing
        return [f"❌ Error: {err}" for _ in MODELS], {}
    image.thumbnail(MAX_IMAGE_SIZE)

    results = {}
    latency_data = {}
    for name, model_id in MODELS.items():
        try:
            text, elapsed = _run_model(model_id, image, prompt)
        except Exception as err:  # one broken model must not hide the others
            results[name] = f"❌ Error: {err}"
            continue
        results[name] = f"🧠 {text}\n\n⏱️ {elapsed:.2f}s"
        # Failed models are left out so they don't appear as 0 s bars.
        latency_data[name] = elapsed
    return [results[name] for name in MODELS], latency_data


def plot_latency(latency_data):
    """Return a bar-chart ``Figure`` of per-model latency, or None if there is no data."""
    if not latency_data:
        return None
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.bar(list(latency_data.keys()), list(latency_data.values()))
    ax.set_title("Model Inference Latency (s)")
    ax.set_ylabel("Seconds")
    fig.tight_layout()
    # Drop the figure from pyplot's global registry so repeated requests don't
    # accumulate open figures; the Figure object itself remains renderable.
    plt.close(fig)
    return fig


def build_ui():
    """Build and return the Gradio Blocks app (one output box per model)."""
    with gr.Blocks(title="Multimodal Model Comparator (Online Images)") as demo:
        gr.Markdown(
            """
            # 🌐 Multimodal Model Comparator (Online Images)
            Enter a **URL** for an image or PDF (must be accessible via HTTPS) and provide a question.
            The app compares outputs from **Pixtral-12B**, **InternVL-3.5**, and **Aria-7B** side-by-side.
            _Licenses: Apache 2.0 / MIT — safe for research and demo use._
            """
        )
        with gr.Row():
            url_input = gr.Textbox(
                label="Image or PDF URL", placeholder="https://example.com/sample.jpg"
            )
            prompt_input = gr.Textbox(
                label="Prompt", placeholder="Ask something about the image or PDF..."
            )
        with gr.Row():
            # Derived from MODELS so adding/removing a model needs no UI edits.
            output_boxes = [gr.Textbox(label=f"{name} Output") for name in MODELS]
        latency_plot = gr.Plot(label="Latency Comparison")

        def process(input_url, prompt):
            outputs, latency_data = compare_models(input_url, prompt)
            return (*outputs, plot_latency(latency_data))

        run_button = gr.Button("Run Comparison")
        run_button.click(
            fn=process,
            inputs=[url_input, prompt_input],
            outputs=[*output_boxes, latency_plot],
        )

        gr.Examples(
            examples=[
                [
                    "https://upload.wikimedia.org/wikipedia/commons/9/99/Unofficial_2023_G20_Logo.png",
                    "Describe this image.",
                ],
                [
                    "https://upload.wikimedia.org/wikipedia/commons/3/3f/Fronalpstock_big.jpg",
                    "What mountain scene is this?",
                ],
                ["https://arxiv.org/pdf/1706.03762.pdf", "What is this paper about?"],
            ],
            inputs=[url_input, prompt_input],
        )
    return demo


if __name__ == "__main__":
    demo = build_ui()
    demo.launch()