Spaces:
Sleeping
Sleeping
from functools import lru_cache
from typing import Any, Dict, List, Optional

import streamlit as st
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline

from const import WHITE, color_for_label
# Bounded on purpose: every cached pipeline keeps its model weights alive, so
# an unbounded cache would grow with each distinct model the caller selects.
@lru_cache(maxsize=4)
def get_detector(model_id: str):
    """Return an object-detection pipeline for *model_id*, building it once.

    The pipeline is memoized per ``model_id`` so repeated detections with the
    same model reuse the already-loaded instance instead of constructing a new
    pipeline on every call.

    Args:
        model_id: Model identifier forwarded to ``pipeline(model=...)``.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        ``"object-detection"`` task.
    """
    has_cuda = torch.cuda.is_available()
    return pipeline(
        task="object-detection",
        model=model_id,
        # Device index 0 when CUDA is available, otherwise -1 (the pipeline's
        # CPU convention).
        device=0 if has_cuda else -1,
        # Half precision only when CUDA is available; full precision otherwise.
        torch_dtype=torch.float16 if has_cuda else torch.float32,
    )
def run_detection(model_id: str, image: Image.Image) -> List[Dict[str, Any]]:
    """Run the detector for *model_id* on *image* and return its predictions.

    Args:
        model_id: Model identifier passed to ``get_detector``.
        image: Image handed to the detector as-is.

    Returns:
        Whatever the detector returns for the image; ``draw_boxes`` consumes
        it as a list of dicts carrying ``label``, ``score`` and ``box``.
    """
    return get_detector(model_id)(image)
def _get_font() -> Optional[ImageFont.FreeTypeFont]:
    """Return Pillow's built-in default font, or ``None`` if loading fails."""
    try:
        default_font = ImageFont.load_default()
    except Exception:
        # Best-effort: any failure to load the font is reported as "no font"
        # rather than raised to the caller.
        default_font = None
    return default_font
def _box_corners(box: Dict[str, Any]) -> List[float]:
    """Return ``[x0, y0, x1, y1]`` for *box*, ordered so x0 <= x1 and y0 <= y1."""
    xa = float(box.get("xmin", box.get("x_min", 0)))
    ya = float(box.get("ymin", box.get("y_min", 0)))
    xb = float(box.get("xmax", box.get("x_max", 0)))
    yb = float(box.get("ymax", box.get("y_max", 0)))
    return [min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)]


def draw_boxes(
    image: Image.Image,
    predictions: List[Dict[str, Any]],
    threshold: float,
) -> Image.Image:
    """Return a copy of *image* with a labelled box for each kept prediction.

    Predictions whose ``score`` is below *threshold* are skipped. Each kept
    prediction gets an outlined rectangle plus a filled strip containing
    ``"<label> <score>"``. The strip sits directly above the box, or just
    inside its top edge when there is no room above it in the image.

    Args:
        image: Source image; it is copied, never modified.
        predictions: Dicts read via ``score``, ``label`` and ``box`` keys.
            ``box`` is read via ``xmin``/``ymin``/``xmax``/``ymax`` (or the
            ``x_min``-style spellings); missing values default to 0.
        threshold: Minimum score a prediction needs in order to be drawn.

    Returns:
        The annotated copy of *image*.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    font = _get_font()
    pad = 2  # pixels between the label text and the edge of its strip

    for pred in predictions:
        score = float(pred.get("score", 0.0))
        if score < threshold:
            continue

        label = str(pred.get("label", "logo"))
        # `or {}` tolerates an explicit ``"box": None``; ordering the corners
        # avoids Pillow rejecting a rectangle whose second corner precedes
        # the first.
        x0, y0, x1, y1 = _box_corners(pred.get("box") or {})
        color = color_for_label(label)

        # Rectangle
        draw.rectangle([(x0, y0), (x1, y1)], outline=color, width=3)

        # Measure the text once at the origin so the strip and the text are
        # positioned from the same metrics and always overlap.
        text = f"{label} {score:.2f}"
        try:
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        except (AttributeError, ValueError, OSError):
            # Presumably an older Pillow that lacks textbbox or cannot measure
            # this font; fall back to a rough fixed-size estimate.
            left, top, right, bottom = 0, 0, 8 * len(text), 16

        strip_w = (right - left) + 2 * pad
        strip_h = (bottom - top) + 2 * pad
        strip_x = int(x0)
        strip_y = int(y0) - strip_h
        if strip_y < 0:
            # No room above the box: keep the label visible by moving it just
            # inside the box's top edge.
            strip_y = max(int(y0), 0)

        # Label background
        draw.rectangle(
            [(strip_x, strip_y), (strip_x + strip_w, strip_y + strip_h)],
            fill=color,
        )
        # Text (offset by -left/-top so the glyphs land inside the padding)
        draw.text(
            (strip_x + pad - left, strip_y + pad - top),
            text,
            fill=WHITE,
            font=font,
        )

    return annotated