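"""Gradio Space: compare multimodal models side by side on an online image or PDF.

Fetches an image or PDF from a URL, runs the same prompt through several
vision-language models, and shows their outputs and latencies next to each other.
"""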
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image
import requests
import time
import io
import fitz  # PyMuPDF for PDF support
import matplotlib.pyplot as plt

# Define model repository IDs
MODELS = {
    "Pixtral-12B": "mistralai/Pixtral-12B-2409",
    "InternVL-3.5": "OpenGVLab/InternVL3_5-241B-A28B",
    "Aria-7B": "Aria-7B",  # Replace with actual model ID when public
}
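
# NOTE: these are large checkpoints (InternVL3_5-241B-A28B is a 241B-parameter
# MoE with ~28B active parameters); keeping all of them resident assumes a
# multi-GPU backend that device_map="auto" can shard across.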
MODEL_CACHE = {}

def load_model(model_id):
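    """Load a processor/model pair once and cache it for later calls."""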
    if model_id not in MODEL_CACHE:
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        MODEL_CACHE[model_id] = (processor, model)
    return MODEL_CACHE[model_id]

def convert_pdf_to_image(pdf_bytes):
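    """Render the first page of a PDF (raw bytes) to a PIL image."""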
    pdf_doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    page = pdf_doc.load_page(0)
    pix = page.get_pixmap(dpi=150)
    image_bytes = pix.tobytes("png")
    image = Image.open(io.BytesIO(image_bytes))
    return image

def load_image_from_url(url):
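    """Fetch an image over HTTP(S) and return it as a PIL image."""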
    response = requests.get(url, timeout=30)  # timeout so a dead URL can't hang the Space
    if response.status_code != 200:
        raise ValueError(f"Failed to load image from {url}")
    return Image.open(io.BytesIO(response.content))

def compare_models(input_url, prompt):
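    """Run the prompt through every model; return per-model outputs and latencies."""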
    if not input_url or not prompt:
        # Return a list (not a dict) so the caller can unpack by position
        msg = "Please provide both image/PDF URL and prompt."
        return [msg for _ in MODELS], None
    # Load the image or first PDF page from the URL
    if input_url.lower().endswith(".pdf"):
        pdf_data = requests.get(input_url, timeout=30).content
        image = convert_pdf_to_image(pdf_data)
    else:
        image = load_image_from_url(input_url)
    image.thumbnail((512, 512))  # downscale in place, preserving aspect ratio
    latency_data = {}
    results = {}
    for name, model_id in MODELS.items():
        try:
            processor, model = load_model(model_id)
            start = time.time()
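            # Two inference paths: some remote-code models expose a high-level
            # `chat` helper, others go through the plain processor + generate
            # flow. The keyword names passed to `chat` below are a sketch; each
            # model's trust_remote_code chat signature differs, so check the
            # model card before relying on this call.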
            if hasattr(model, "chat"):
                text = model.chat(processor.tokenizer, image=image, query=prompt)
            else:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                # Use keyword args: positional order varies across processors
                inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
                outputs = model.generate(**inputs, max_new_tokens=128)
                text = processor.decode(outputs[0], skip_special_tokens=True)
            elapsed = time.time() - start
            results[name] = f"🧠 {text}\n\n⏱️ {elapsed:.2f}s"
            latency_data[name] = elapsed
        except Exception as e:
            results[name] = f"❌ Error: {str(e)}"
            latency_data[name] = 0
    return [results.get(name, "Model not loaded.") for name in MODELS], latency_data

def plot_latency(latency_data):
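    """Plot per-model latency as a bar chart and return the figure."""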
    if not latency_data:
        return None
    fig = plt.figure(figsize=(6, 3))
    plt.bar(latency_data.keys(), latency_data.values())
    plt.title("Model Inference Latency (s)")
    plt.ylabel("Seconds")
    plt.tight_layout()
    return fig  # return the Figure object for gr.Plot, not the pyplot module

def build_ui():
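    """Assemble the Gradio Blocks interface."""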
    with gr.Blocks(title="Multimodal Model Comparator (Online Images)") as demo:
        gr.Markdown("""
# 🔍 Multimodal Model Comparator (Online Images)
Enter a **URL** for an image or PDF (must be accessible via HTTPS) and provide a question.
The app compares outputs from **Pixtral-12B**, **InternVL-3.5**, and **Aria-7B** side by side.
_Licenses: Apache 2.0 / MIT, safe for research and demo use._
""")
        with gr.Row():
            url_input = gr.Textbox(label="Image or PDF URL", placeholder="https://example.com/sample.jpg")
            prompt_input = gr.Textbox(label="Prompt", placeholder="Ask something about the image or PDF...")
        with gr.Row():
            pixtral_out = gr.Textbox(label="Pixtral Output")
            internvl_out = gr.Textbox(label="InternVL Output")
            aria_out = gr.Textbox(label="Aria Output")
        latency_plot = gr.Plot(label="Latency Comparison")

        def process(input_url, prompt):
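            """Callback: run the comparison and unpack results for the three output boxes."""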
            outputs, latency_data = compare_models(input_url, prompt)
            plot = plot_latency(latency_data)
            return outputs[0], outputs[1], outputs[2], plot

        run_button = gr.Button("Run Comparison")
        run_button.click(
            fn=process,
            inputs=[url_input, prompt_input],
            outputs=[pixtral_out, internvl_out, aria_out, latency_plot],
        )
        gr.Examples(
            examples=[
                ["https://upload.wikimedia.org/wikipedia/commons/9/99/Unofficial_2023_G20_Logo.png", "Describe this image."],
                ["https://upload.wikimedia.org/wikipedia/commons/3/3f/Fronalpstock_big.jpg", "What mountain scene is this?"],
                ["https://arxiv.org/pdf/1706.03762.pdf", "What is this paper about?"],
            ],
            inputs=[url_input, prompt_input],
        )
    return demo

if __name__ == "__main__":
    demo = build_ui()
    demo.launch()