Sunil Sarolkar committed on
Commit 0f38fdf · 1 Parent(s): 31fd9d9

updated image references

Files changed (1)
  1. app.py +51 -58
app.py CHANGED
@@ -1,86 +1,80 @@
  import gradio as gr
  import torch
- from transformers import AutoProcessor, AutoModelForVision2Seq
  from PIL import Image
  import time
- import fitz  # PyMuPDF for PDF support
  import io

- # Define the models you want to compare
  MODELS = {
      "Pixtral-12B": "mistralai/Pixtral-12B-2409",
-     "InternVL-2.5": "OpenGVLab/InternVL2_5-Chat",
      "Aria-7B": "Aria-7B"  # Replace with actual model ID when public
  }

  MODEL_CACHE = {}

- # Load models and processors (lazy loading for faster startup)
  def load_model(model_id):
      if model_id not in MODEL_CACHE:
-         processor = AutoProcessor.from_pretrained(model_id)
-         model = AutoModelForVision2Seq.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
          MODEL_CACHE[model_id] = (processor, model)
      return MODEL_CACHE[model_id]

-
  def convert_pdf_to_image(pdf_bytes):
-     try:
-         pdf_doc = fitz.open(stream=pdf_bytes, filetype="pdf")
-         page = pdf_doc.load_page(0)  # first page only
-         pix = page.get_pixmap(dpi=150)
-         image_bytes = pix.tobytes("png")
-         image = Image.open(io.BytesIO(image_bytes))
-         return image
-     except Exception as e:
-         raise ValueError(f"Failed to convert PDF: {e}")
-
-
- def compare_models(file, prompt):
-     results = {}
-
-     if file is None or not prompt:
-         return {name: "Please provide both image/PDF and prompt." for name in MODELS}, None
-
-     # Determine input type (PDF or image)
-     if isinstance(file, str):
-         image = Image.open(file)
      else:
-         file_bytes = file.read() if hasattr(file, 'read') else file
-         if file.name.endswith('.pdf'):
-             image = convert_pdf_to_image(file_bytes)
-         else:
-             image = Image.open(io.BytesIO(file_bytes))
-
-     image.thumbnail((512, 512))  # optimize

      latency_data = {}

      for name, model_id in MODELS.items():
          try:
              processor, model = load_model(model_id)
              start = time.time()
-
-             inputs = processor(prompt, image, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-             outputs = model.generate(**inputs, max_new_tokens=128)
-             text = processor.decode(outputs[0], skip_special_tokens=True)
-
              elapsed = time.time() - start
              results[name] = f"🧠 {text}\n\n⏱️ {elapsed:.2f}s"
              latency_data[name] = elapsed
-
          except Exception as e:
              results[name] = f"❌ Error: {str(e)}"
              latency_data[name] = 0

-     # Return results and latency chart data
      return [results.get(name, "Model not loaded.") for name in MODELS], latency_data

-
  def plot_latency(latency_data):
      if not latency_data:
          return None
-     import matplotlib.pyplot as plt
      plt.figure(figsize=(6, 3))
      plt.bar(latency_data.keys(), latency_data.values())
      plt.title("Model Inference Latency (s)")
@@ -88,19 +82,18 @@ def plot_latency(latency_data):
      plt.tight_layout()
      return plt

-
  def build_ui():
-     with gr.Blocks(title="Multimodal Model Comparator") as demo:
          gr.Markdown("""
-         # 🤖 Multimodal Model Comparator
-         Upload an **image or PDF document** and enter a question.
-         The app compares outputs from **Pixtral-12B**, **InternVL-2.5**, and **Aria-7B** side-by-side.
-
          _Licenses: Apache 2.0 / MIT — safe for research and demo use._
          """)

          with gr.Row():
-             file_input = gr.File(label="Upload Image or PDF", file_types=[".png", ".jpg", ".jpeg", ".pdf"])
              prompt_input = gr.Textbox(label="Prompt", placeholder="Ask something about the image or PDF...")

          with gr.Row():
@@ -110,25 +103,25 @@ def build_ui():

          latency_plot = gr.Plot(label="Latency Comparison")

-         def process(file, prompt):
-             outputs, latency_data = compare_models(file, prompt)
              plot = plot_latency(latency_data)
              return outputs[0], outputs[1], outputs[2], plot

          run_button = gr.Button("Run Comparison")
-         run_button.click(fn=process, inputs=[file_input, prompt_input], outputs=[pixtral_out, internvl_out, aria_out, latency_plot])

          gr.Examples(
              examples=[
-                 ["sample_image.jpg", "What is shown in this picture?"],
-                 ["chart_example.png", "Describe the trend in this chart."],
              ],
-             inputs=[file_input, prompt_input]
          )

      return demo

-
  if __name__ == "__main__":
      demo = build_ui()
      demo.launch()
 
  import gradio as gr
  import torch
+ from transformers import AutoProcessor, AutoModel
  from PIL import Image
+ import requests
  import time
  import io
+ import fitz  # PyMuPDF for PDF support
+ import matplotlib.pyplot as plt

+ # Define model repository IDs
  MODELS = {
      "Pixtral-12B": "mistralai/Pixtral-12B-2409",
+     "InternVL-3.5": "OpenGVLab/InternVL3_5-241B-A28B",
      "Aria-7B": "Aria-7B"  # Replace with actual model ID when public
  }

  MODEL_CACHE = {}

  def load_model(model_id):
      if model_id not in MODEL_CACHE:
+         processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+         model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
          MODEL_CACHE[model_id] = (processor, model)
      return MODEL_CACHE[model_id]

  def convert_pdf_to_image(pdf_bytes):
+     pdf_doc = fitz.open(stream=pdf_bytes, filetype="pdf")
+     page = pdf_doc.load_page(0)
+     pix = page.get_pixmap(dpi=150)
+     image_bytes = pix.tobytes("png")
+     image = Image.open(io.BytesIO(image_bytes))
+     return image
+
+ def load_image_from_url(url):
+     response = requests.get(url)
+     if response.status_code != 200:
+         raise ValueError(f"Failed to load image from {url}")
+     return Image.open(io.BytesIO(response.content))
+
+ def compare_models(input_url, prompt):
+     if not input_url or not prompt:
+         return {name: "Please provide both image/PDF URL and prompt." for name in MODELS}, None
+
+     # Load image or PDF from URL
+     if input_url.lower().endswith('.pdf'):
+         pdf_data = requests.get(input_url).content
+         image = convert_pdf_to_image(pdf_data)
      else:
+         image = load_image_from_url(input_url)

+     image.thumbnail((512, 512))
      latency_data = {}
+     results = {}

      for name, model_id in MODELS.items():
          try:
              processor, model = load_model(model_id)
              start = time.time()
+             if hasattr(model, 'chat'):
+                 text = model.chat(processor.tokenizer, image=image, query=prompt)
+             else:
+                 inputs = processor(prompt, image, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+                 outputs = model.generate(**inputs, max_new_tokens=128)
+                 text = processor.decode(outputs[0], skip_special_tokens=True)
              elapsed = time.time() - start
              results[name] = f"🧠 {text}\n\n⏱️ {elapsed:.2f}s"
              latency_data[name] = elapsed
          except Exception as e:
              results[name] = f"❌ Error: {str(e)}"
              latency_data[name] = 0

      return [results.get(name, "Model not loaded.") for name in MODELS], latency_data

  def plot_latency(latency_data):
      if not latency_data:
          return None
      plt.figure(figsize=(6, 3))
      plt.bar(latency_data.keys(), latency_data.values())
      plt.title("Model Inference Latency (s)")
      plt.tight_layout()
      return plt

  def build_ui():
+     with gr.Blocks(title="Multimodal Model Comparator (Online Images)") as demo:
          gr.Markdown("""
+         # 🌐 Multimodal Model Comparator (Online Images)
+         Enter a **URL** for an image or PDF (must be accessible via HTTPS) and provide a question.
+         The app compares outputs from **Pixtral-12B**, **InternVL-3.5**, and **Aria-7B** side-by-side.
+
          _Licenses: Apache 2.0 / MIT — safe for research and demo use._
          """)

          with gr.Row():
+             url_input = gr.Textbox(label="Image or PDF URL", placeholder="https://example.com/sample.jpg")
              prompt_input = gr.Textbox(label="Prompt", placeholder="Ask something about the image or PDF...")

          with gr.Row():

          latency_plot = gr.Plot(label="Latency Comparison")

+         def process(input_url, prompt):
+             outputs, latency_data = compare_models(input_url, prompt)
              plot = plot_latency(latency_data)
              return outputs[0], outputs[1], outputs[2], plot

          run_button = gr.Button("Run Comparison")
+         run_button.click(fn=process, inputs=[url_input, prompt_input], outputs=[pixtral_out, internvl_out, aria_out, latency_plot])

          gr.Examples(
              examples=[
+                 ["https://upload.wikimedia.org/wikipedia/commons/9/99/Unofficial_2023_G20_Logo.png", "Describe this image."],
+                 ["https://upload.wikimedia.org/wikipedia/commons/3/3f/Fronalpstock_big.jpg", "What mountain scene is this?"],
+                 ["https://arxiv.org/pdf/1706.03762.pdf", "What is this paper about?"],
              ],
+             inputs=[url_input, prompt_input]
          )

      return demo

  if __name__ == "__main__":
      demo = build_ui()
      demo.launch()
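
For quick testing outside the Gradio UI, here is a minimal standalone sketch of the URL input path this commit introduces: fetch the URL with requests, rasterize the first page with PyMuPDF when it ends in .pdf, otherwise decode the bytes with PIL, then downscale to at most 512 px. It mirrors load_image_from_url and convert_pdf_to_image above, assembled into a single helper; the fetch_as_image name, the timeout, raise_for_status, and the dpi/max_size parameters are illustrative choices, not part of app.py.

import io

import fitz  # PyMuPDF
import requests
from PIL import Image


def fetch_as_image(url, dpi=150, max_size=512):
    """Download a URL and return it as a PIL image (first page only if it is a PDF)."""
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # surface HTTP errors instead of failing later

    if url.lower().endswith(".pdf"):
        pdf_doc = fitz.open(stream=response.content, filetype="pdf")
        pix = pdf_doc.load_page(0).get_pixmap(dpi=dpi)  # rasterize page 1 only
        image = Image.open(io.BytesIO(pix.tobytes("png")))
    else:
        image = Image.open(io.BytesIO(response.content))

    image.thumbnail((max_size, max_size))  # downscale in place, keep aspect ratio
    return image


if __name__ == "__main__":
    img = fetch_as_image("https://arxiv.org/pdf/1706.03762.pdf")
    print(img.size, img.mode)

Unlike load_image_from_url, this sketch raises on HTTP errors and sets a request timeout, which keeps a dead link from stalling the comparison loop.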
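
plot_latency hands the pyplot module back to gr.Plot; the sketch below draws the same bar chart but returns the Figure object explicitly (which gr.Plot also accepts) and saves it to disk, so the chart can be checked without launching the app. The latency_chart name, the ylabel, and the sample timings are illustrative additions, not output from the app.

import matplotlib.pyplot as plt


def latency_chart(latency_data):
    """Build a bar chart of per-model latencies and return the Figure."""
    if not latency_data:
        return None
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.bar(latency_data.keys(), latency_data.values())
    ax.set_title("Model Inference Latency (s)")
    ax.set_ylabel("Seconds")
    fig.tight_layout()
    return fig


if __name__ == "__main__":
    # Made-up example timings purely to exercise the plotting code.
    fig = latency_chart({"Pixtral-12B": 4.2, "InternVL-3.5": 6.8, "Aria-7B": 0.0})
    fig.savefig("latency.png")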