# dl_course_hw3/src/utils.py
# Last change: commit 57f7dec by katyan010 ("fix font")
from typing import Any, Dict, List, Optional
import torch
from PIL import Image, ImageDraw, ImageFont
import streamlit as st
from transformers import pipeline
from const import WHITE, color_for_label
@st.cache_resource(show_spinner=False)
def get_detector(model_id: str):
    """Create the Hugging Face object-detection pipeline for ``model_id``.

    The result is cached by Streamlit (``st.cache_resource``), so repeated
    calls with the same ``model_id`` reuse one pipeline object instead of
    reloading the model weights.

    Args:
        model_id: Hub model identifier (or local path) passed to ``pipeline``.

    Returns:
        A transformers ``object-detection`` pipeline placed on the first CUDA
        device in float16 when CUDA is available, otherwise on the CPU
        (device ``-1``) in float32.
    """
    if torch.cuda.is_available():
        # First GPU; half precision is only used when running on CUDA.
        target_device, weights_dtype = 0, torch.float16
    else:
        # -1 is the pipeline's convention for "run on CPU".
        target_device, weights_dtype = -1, torch.float32

    return pipeline(
        task="object-detection",
        model=model_id,
        device=target_device,
        torch_dtype=weights_dtype,
    )
@st.cache_data(show_spinner=False, ttl=600)
def run_detection(
    model_id: str,
    image: Image.Image,
    threshold: Optional[float] = None,
) -> List[Dict[str, Any]]:
    """Run object detection on ``image`` and return the raw predictions.

    Results are memoized by Streamlit (``st.cache_data``) for 600 seconds,
    keyed on all arguments, so script reruns with the same inputs skip
    inference.

    Args:
        model_id: Model identifier forwarded to ``get_detector``.
        image: PIL image to analyse.
        threshold: Optional minimum score forwarded to the pipeline call.
            When ``None`` (the default, matching the previous behaviour) the
            pipeline's own default cutoff applies. That cutoff is 0.5 in the
            transformers releases checked; confirm for the pinned version.
            Predictions below it never reach the caller, so pass a lower value
            (e.g. ``0.0``) if the caller filters by score itself, as
            ``draw_boxes`` does.

    Returns:
        List of dicts with ``label``, ``score`` and ``box`` keys, as produced
        by the pipeline.
    """
    detector = get_detector(model_id)
    if threshold is None:
        return detector(image)
    return detector(image, threshold=threshold)
def _get_font() -> Optional[ImageFont.FreeTypeFont]:
    """Return Pillow's default font, or ``None`` if loading it fails.

    The result is only used as the ``font=`` argument of ImageDraw calls.
    There, ``None`` makes Pillow fall back to its own default, so a failure
    here never prevents labels from being drawn.

    NOTE(review): depending on the Pillow version, ``load_default()`` may
    return a bitmap ``ImageFont.ImageFont`` rather than a ``FreeTypeFont``;
    the return annotation may be narrower than reality.
    """
    try:
        default_font = ImageFont.load_default()
    except Exception:
        # Fonts are cosmetic; deliberately never fail annotation over them.
        default_font = None
    return default_font
def _label_extent(
    draw: ImageDraw.ImageDraw,
    text: str,
    font: Optional[ImageFont.FreeTypeFont],
) -> "tuple[int, int, int, int]":
    """Return ``(left, top, width, height)`` of ``text`` drawn at (0, 0).

    ``left``/``top`` are the offsets of the rendered ink from the drawing
    origin. Callers subtract them so the glyphs land exactly where intended.
    If Pillow cannot measure the text, a rough 8 px/char x 16 px estimate is
    returned instead.
    """
    try:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    except Exception:
        # Best effort: keep drawing with an estimated label size.
        return 0, 0, 8 * len(text), 16
    return int(left), int(top), int(right - left), int(bottom - top)


def draw_boxes(
    image: Image.Image,
    predictions: List[Dict[str, Any]],
    threshold: float,
) -> Image.Image:
    """Return a copy of ``image`` with labelled boxes for confident predictions.

    Each kept prediction gets a 3 px outline in ``color_for_label(label)`` and
    a filled tag reading ``"<label> <score>"`` in white. The tag sits just
    above the box, or just inside its top edge when there is no room above.
    It is shifted left if it would run past the right edge of the image.

    Args:
        image: Source image; it is never modified.
        predictions: Dicts with optional ``"score"``, ``"label"`` and ``"box"``
            keys. Box corners are read from ``xmin``/``ymin``/``xmax``/``ymax``
            or from the ``x_min``/``y_min``/``x_max``/``y_max`` spelling;
            missing values default to 0.
        threshold: Predictions whose score is strictly below this are skipped.

    Returns:
        The annotated copy. Inputs that are not RGB/RGBA (e.g. grayscale or
        palette images) are converted to RGB first.
    """
    # color_for_label may return RGB tuples, which Pillow rejects as ink on
    # single-band / palette images; converting keeps drawing from failing.
    if image.mode in ("RGB", "RGBA"):
        annotated = image.copy()
    else:
        annotated = image.convert("RGB")
    draw = ImageDraw.Draw(annotated)
    font = _get_font()
    pad = 2  # px of coloured background around the label text

    for pred in predictions:
        score = float(pred.get("score", 0.0))
        if score < threshold:
            continue

        label = str(pred.get("label", "logo"))
        box = pred.get("box") or {}
        x0 = float(box.get("xmin", box.get("x_min", 0)))
        y0 = float(box.get("ymin", box.get("y_min", 0)))
        x1 = float(box.get("xmax", box.get("x_max", 0)))
        y1 = float(box.get("ymax", box.get("y_max", 0)))
        # Pillow raises ValueError when x1 < x0 or y1 < y0; normalise corners.
        x0, x1 = min(x0, x1), max(x0, x1)
        y0, y1 = min(y0, y1), max(y0, y1)

        color = color_for_label(label)
        draw.rectangle([(x0, y0), (x1, y1)], outline=color, width=3)

        # Label tag, sized from the measured text so the fill covers it.
        text = f"{label} {score:.2f}"
        left, top, text_w, text_h = _label_extent(draw, text, font)
        tag_w = text_w + 2 * pad
        tag_h = text_h + 2 * pad
        # Keep the tag on-canvas horizontally.
        tag_x = max(0, min(int(x0), annotated.width - tag_w))
        tag_y = int(y0) - tag_h
        if tag_y < 0:
            # Box touches the top edge: put the tag inside the box instead.
            tag_y = int(y0)

        draw.rectangle(
            [(tag_x, tag_y), (tag_x + tag_w, tag_y + tag_h)],
            fill=color,
        )
        # Subtract the ink offset so the glyphs start exactly at the padding.
        draw.text(
            (tag_x + pad - left, tag_y + pad - top),
            text,
            fill=WHITE,
            font=font,
        )

    return annotated