Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
import gradio as gr
|
| 3 |
from transformers import Owlv2Processor, Owlv2ForObjectDetection
|
| 4 |
import spaces
|
|
|
|
| 5 |
|
| 6 |
# Use GPU if available
|
| 7 |
if torch.cuda.is_available():
|
|
@@ -30,12 +31,18 @@ def query_image(img, text_queries, score_threshold):
|
|
| 30 |
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
|
| 31 |
|
| 32 |
result_labels = []
|
|
|
|
| 33 |
for box, score, label in zip(boxes, scores, labels):
|
| 34 |
box = [int(i) for i in box.tolist()]
|
| 35 |
if score < score_threshold:
|
| 36 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
result_labels.append((box, text_queries[label.item()]))
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
description = """
|
|
@@ -56,10 +63,5 @@ demo = gr.Interface(
|
|
| 56 |
outputs=["annotatedimage", "json"],
|
| 57 |
title="Zero-Shot Object Detection with OWLv2",
|
| 58 |
description=description,
|
| 59 |
-
examples=[
|
| 60 |
-
["assets/astronaut.png", "human face, rocket, star-spangled banner, nasa badge", 0.11],
|
| 61 |
-
["assets/coffee.png", "coffee mug, spoon, plate", 0.1],
|
| 62 |
-
["assets/butterflies.jpeg", "orange butterfly", 0.3],
|
| 63 |
-
],
|
| 64 |
)
|
| 65 |
demo.launch()
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
from transformers import Owlv2Processor, Owlv2ForObjectDetection
|
| 4 |
import spaces
|
| 5 |
+
import json
|
| 6 |
|
| 7 |
# Use GPU if available
|
| 8 |
if torch.cuda.is_available():
|
|
|
|
| 31 |
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
|
| 32 |
|
| 33 |
result_labels = []
|
| 34 |
+
boxes_coords = []
|
| 35 |
for box, score, label in zip(boxes, scores, labels):
|
| 36 |
box = [int(i) for i in box.tolist()]
|
| 37 |
if score < score_threshold:
|
| 38 |
continue
|
| 39 |
+
boxes_coords.append({
|
| 40 |
+
"object": text_queries[label.item()],
|
| 41 |
+
"pos": box
|
| 42 |
+
})
|
| 43 |
result_labels.append((box, text_queries[label.item()]))
|
| 44 |
+
print(boxes_coords)
|
| 45 |
+
return [img, result_labels], boxes_coords
|
| 46 |
|
| 47 |
|
| 48 |
description = """
|
|
|
|
| 63 |
outputs=["annotatedimage", "json"],
|
| 64 |
title="Zero-Shot Object Detection with OWLv2",
|
| 65 |
description=description,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
)
|
| 67 |
demo.launch()
|