# NOTE: web file-viewer residue (page banner, blame hashes, line-number gutter) removed here.
# The captured listing reported: Spaces status "Runtime error"; file size 15,567 bytes.
import torch
from ultralytics import YOLO
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pandas as pd
import os
import cv2
import time
import zipfile
import io
from datetime import datetime
# ===== Optional OCR imports =====
# Both OCR back ends are optional. Each import is attempted independently and
# the outcome is recorded in a module-level flag so the rest of the file can
# check availability before calling into the corresponding functions.
try:
    from license_plate_ocr import extract_license_plate_text

    OCR_AVAILABLE = True
    print("Basic OCR module loaded successfully")
except ImportError as e:
    print(f"Basic OCR module not available: {e}")
    OCR_AVAILABLE = False

try:
    from advanced_ocr import (
        extract_license_plate_text_advanced,
        get_available_models,
        set_ocr_model,
    )

    ADVANCED_OCR_AVAILABLE = True
    print("Advanced OCR module loaded successfully")
except ImportError as e:
    print(f"Advanced OCR module not available: {e}")
    ADVANCED_OCR_AVAILABLE = False
# ===== Model & class names =====
# Detection model, constructed once at import time from the weights file "best.pt".
model = YOLO("best.pt") # make sure best.pt is present
# Maps the integer class id reported by the model to the label stored in the
# "Object" field of each detection record built by yolov8_detect.
class_names = {0: "With Helmet", 1: "Without Helmet", 2: "License Plate"}
def _load_pil_image(image):
    """Normalize *image* to a PIL image.

    Accepts a file path, a NumPy array or a PIL image. Returns ``None`` (after
    printing the reason) when the file is missing, the type is unsupported,
    a dimension is zero, or loading raises.
    """
    try:
        if isinstance(image, str):
            if not os.path.exists(image):
                print(f"Error: Image file not found: {image}")
                return None
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            print(f"Error: Unsupported image type: {type(image)}")
            return None
        if image.size[0] == 0 or image.size[1] == 0:
            print("Error: Image has zero dimensions")
            return None
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
    return image


def _plate_crop_box(detection, index, image_width, image_height):
    """Parse a detection record into a crop box clamped to the image.

    Returns ``(x1, y1, x2, y2)`` or ``None`` (after printing the reason) when
    the record cannot yield a non-empty box. Raises ``ValueError`` when the
    position or dimensions are not integers.
    """
    pos_str = detection["Position"].strip("()")
    if "," not in pos_str:
        print(
            f"Error: Invalid position format for detection {index}: {detection['Position']}"
        )
        return None
    # Split on the bare comma: int() tolerates surrounding whitespace, so both
    # "(1, 2)" and "(1,2)" are accepted.
    x1, y1 = map(int, pos_str.split(","))

    dims_str = detection["Dimensions"]
    if "x" not in dims_str:
        print(
            f"Error: Invalid dimensions format for detection {index}: {detection['Dimensions']}"
        )
        return None
    width, height = map(int, dims_str.split("x"))
    if width <= 0 or height <= 0:
        print(f"Error: Invalid dimensions for detection {index}: {width}x{height}")
        return None

    x2, y2 = x1 + width, y1 + height
    if x1 < 0 or y1 < 0 or x2 > image_width or y2 > image_height:
        print(
            f"Warning: Bounding box extends beyond image boundaries for detection {index}"
        )
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(image_width, x2)
        y2 = min(image_height, y2)
    # Clamping can collapse a box that lies entirely outside the image.
    if x2 <= x1 or y2 <= y1:
        print(
            f"Error: Invalid crop coordinates for detection {index}: ({x1},{y1}) to ({x2},{y2})"
        )
        return None
    return x1, y1, x2, y2


def _read_plate_text(cropped_plate, plate_number, selected_ocr_model):
    """Run OCR on one cropped plate and return the text to store for it.

    The caller must have verified that at least one OCR back end is loaded.
    Never raises: OCR errors are reported as an ``"OCR Failed: ..."`` string.
    """
    try:
        print(
            f"Extracting text from license plate {plate_number} using {selected_ocr_model}..."
        )
        wants_basic = selected_ocr_model == "basic" or not ADVANCED_OCR_AVAILABLE
        if wants_basic and OCR_AVAILABLE:
            plate_text = extract_license_plate_text(cropped_plate)
        else:
            if wants_basic:
                # "basic" was requested but only the advanced back end is
                # loaded; use it rather than calling an undefined function.
                print("Basic OCR not available - falling back to advanced OCR")
            explicit_model = selected_ocr_model not in ("auto", "basic")
            if explicit_model:
                set_ocr_model(selected_ocr_model)
            plate_text = extract_license_plate_text_advanced(
                cropped_plate,
                selected_ocr_model if explicit_model else None,
            )

        cleaned = plate_text.strip() if plate_text else ""
        if cleaned and not plate_text.startswith("Error"):
            print(f"Extracted text: {cleaned}")
            return cleaned
        print(f"No text found in plate {plate_number}")
        return "No text detected"
    except Exception as e:
        print(f"OCR extraction failed for plate {plate_number}: {e}")
        return f"OCR Failed: {str(e)}"


def crop_license_plates(image, detections, extract_text=False, selected_ocr_model="auto"):
    """Crop license plates and (optionally) run OCR on the crops.

    Args:
        image: File path, NumPy array or PIL image to crop from.
        detections: Iterable of detection records with the keys "Object",
            "Confidence", "Position" (``"(x, y)"``) and "Dimensions"
            (``"WxH"``). Only records whose "Object" is "License Plate" are
            processed.
        extract_text: When true, run OCR on every crop if a back end is loaded.
        selected_ocr_model: "auto", "basic", or a model name passed to the
            advanced OCR back end.

    Returns:
        A list of dicts with the keys "image", "confidence", "position",
        "crop_coords" and "text". The list is empty when the image cannot be
        loaded. Records that cannot be parsed or cropped are skipped after
        printing an error.
    """
    cropped_plates = []
    pil_image = _load_pil_image(image)
    if pil_image is None:
        return cropped_plates

    for i, detection in enumerate(detections):
        try:
            if detection["Object"] != "License Plate":
                continue
            box = _plate_crop_box(detection, i, pil_image.width, pil_image.height)
            if box is None:
                continue
            cropped_plate = pil_image.crop(box)
            if cropped_plate.size[0] == 0 or cropped_plate.size[1] == 0:
                print(f"Error: Cropped image has zero dimensions for detection {i}")
                continue

            if not extract_text:
                text = "OCR disabled"
            elif not (OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE):
                text = "OCR not available"
            else:
                text = _read_plate_text(cropped_plate, i + 1, selected_ocr_model)

            x1, y1, x2, y2 = box
            cropped_plates.append(
                {
                    "image": cropped_plate,
                    "confidence": detection["Confidence"],
                    "position": detection["Position"],
                    "crop_coords": f"({x1},{y1}) to ({x2},{y2})",
                    "text": text,
                }
            )
        except ValueError as e:
            print(f"Error parsing coordinates for detection {i}: {e}")
        except Exception as e:
            print(f"Error cropping license plate {i}: {e}")
    return cropped_plates
def _save_as_jpeg(image, path):
    """Save *image* to *path* as a JPEG at quality 95.

    Images whose mode is neither RGB nor L (for example RGBA or palette
    images) are converted to RGB first, because such modes cannot be written
    as JPEG.
    """
    mode = getattr(image, "mode", None)
    if mode is not None and mode not in ("RGB", "L"):
        image = image.convert("RGB")
    image.save(path, quality=95)


def create_download_files(annotated_image, cropped_plates, detections):
    """Write the detection artifacts under ``temp/`` and bundle them in a ZIP.

    Args:
        annotated_image: Image object to save as the annotated JPEG.
        cropped_plates: List of plate dicts with the keys "image",
            "confidence" and "position", and optionally "text".
        detections: List of detection records included in the CSV report.

    Returns:
        ``(zip_path, annotated_path, plate_paths)``. Returns
        ``(None, None, [])`` when the annotated image cannot be saved or an
        unexpected error occurs, and ``(None, annotated_path, plate_paths)``
        when only the ZIP creation fails. Plates that fail to save are
        skipped after printing an error.
    """
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs("temp", exist_ok=True)

        annotated_path = f"temp/annotated_image_{timestamp}.jpg"
        try:
            _save_as_jpeg(annotated_image, annotated_path)
        except Exception as e:
            print(f"Error saving annotated image: {e}")
            return None, None, []

        plate_paths = []
        for number, plate_data in enumerate(cropped_plates, start=1):
            try:
                plate_path = f"temp/license_plate_{number}_{timestamp}.jpg"
                _save_as_jpeg(plate_data["image"], plate_path)
                plate_paths.append(plate_path)
            except Exception as e:
                print(f"Error saving license plate {number}: {e}")

        # The report lists the raw detections followed by one extra row per
        # cropped plate carrying its OCR text.
        report_data = list(detections)
        report_data.extend(
            {
                "Object": f"License Plate {number} - Text",
                "Confidence": plate_data["confidence"],
                "Position": plate_data["position"],
                "Dimensions": "Extracted Text",
                "Text": plate_data.get("text", "N/A"),
            }
            for number, plate_data in enumerate(cropped_plates, start=1)
        )

        report_path = f"temp/detection_report_{timestamp}.csv"
        if report_data:
            try:
                pd.DataFrame(report_data).to_csv(report_path, index=False)
            except Exception as e:
                print(f"Error creating detection report: {e}")
                report_path = None

        zip_path = f"temp/detection_results_{timestamp}.zip"
        try:
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                if os.path.exists(annotated_path):
                    zipf.write(annotated_path, f"annotated_image_{timestamp}.jpg")
                for plate_path in plate_paths:
                    if os.path.exists(plate_path):
                        zipf.write(plate_path, os.path.basename(plate_path))
                # report_path may name a file that was never written (empty report).
                if report_path and os.path.exists(report_path):
                    zipf.write(report_path, f"detection_report_{timestamp}.csv")
        except Exception as e:
            print(f"Error creating ZIP file: {e}")
            return None, annotated_path, plate_paths

        return zip_path, annotated_path, plate_paths
    except Exception as e:
        print(f"Error in create_download_files: {e}")
        return None, None, []
def _boxes_to_detections(boxes):
    """Convert the boxes of one prediction result into detection records."""
    detections = []
    if boxes is None or len(boxes) == 0:
        return detections
    for box, cls, conf in zip(boxes.xyxy, boxes.cls, boxes.conf):
        x1, y1, x2, y2 = box.tolist()
        class_id = int(cls)
        detections.append(
            {
                "Object": class_names.get(class_id, f"Class {class_id}"),
                "Confidence": f"{float(conf):.2f}",
                "Position": f"({int(x1)}, {int(y1)})",
                "Dimensions": f"{int(x2 - x1)}x{int(y2 - y1)}",
            }
        )
    return detections


def _build_stats_text(detections, cropped_plates, has_no_helmet, flag_violation, include_text):
    """Build the human-readable detection summary for a non-empty detection list."""
    counts = pd.DataFrame(detections)["Object"].value_counts().to_dict()
    parts = ["Detection Summary:\n"]
    parts.extend(f"- {obj}: {count}\n" for obj, count in counts.items())
    # NOTE(review): the source this was reconstructed from had lost its
    # indentation; the violation and text sections are assumed to be nested
    # under the cropped-plates check - confirm against the original file.
    if cropped_plates:
        parts.append(f"\nLicense Plates Cropped: {len(cropped_plates)}\n")
        if has_no_helmet:
            parts.append("⚠️ HELMET VIOLATION DETECTED!\n")
        if include_text:
            parts.append("Extracted Text:\n")
            for number, plate_data in enumerate(cropped_plates, start=1):
                text = plate_data.get("text", "No text")
                if flag_violation:
                    parts.append(f"🚨 Violation - Plate {number}: {text}\n")
                else:
                    parts.append(f"- Plate {number}: {text}\n")
    return "".join(parts)


def yolov8_detect(
    image=None,
    image_size=640,
    conf_threshold=0.4,
    iou_threshold=0.5,
    show_stats=True,
    show_confidence=True,
    crop_plates=True,
    extract_text=False,
    ocr_on_no_helmet=False,
    selected_ocr_model="auto",
):
    """Main detection function.

    Runs the module-level model on *image*, builds a detection table and,
    when requested, crops license plates, runs OCR on them and writes the
    downloadable result files.

    Args:
        image: Input image passed to the model and to the plate cropper.
        image_size: Inference size; ``None`` means 640, non-int values are
            converted with ``int()``.
        conf_threshold: Confidence threshold passed to the model.
        iou_threshold: IoU threshold passed to the model.
        show_stats: Include the textual summary in the result.
        show_confidence: Accepted for interface compatibility; not used here.
        crop_plates: Crop detected license plates and create download files.
        extract_text: Run OCR on every cropped plate.
        ocr_on_no_helmet: Run OCR whenever a "Without Helmet" detection exists.
        selected_ocr_model: OCR model selector forwarded to the cropper.

    Returns:
        A 6-tuple ``(annotated_image, detection_table, stats_text,
        license_plate_gallery, download_files, plate_text_output)``. When
        *image* is ``None`` the tuple holds empty outputs and the model is
        not invoked.
    """
    columns = ["Object", "Confidence", "Position", "Dimensions"]
    if image is None:
        # Without this guard the model would be called with no source at all.
        return (None, pd.DataFrame(columns=columns), "", [], None, "No image provided")

    if image_size is None:
        image_size = 640
    if not isinstance(image_size, int):
        image_size = int(image_size)

    results = model.predict(
        image,
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=[image_size, image_size],
    )
    annotated_image = results[0].plot()
    if isinstance(annotated_image, np.ndarray):
        # Treat the plotted array as BGR and convert it for PIL.
        annotated_image = Image.fromarray(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))

    detections = _boxes_to_detections(results[0].boxes)

    cropped_plates = []
    license_plate_gallery = []
    plate_texts = []
    download_files = None

    has_no_helmet = any(d["Object"] == "Without Helmet" for d in detections)
    violation_ocr = ocr_on_no_helmet and has_no_helmet
    should_extract_text = extract_text or violation_ocr
    ocr_available = OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE

    if crop_plates and detections:
        try:
            license_plate_count = sum(1 for d in detections if d["Object"] == "License Plate")
            print(f"Processing {license_plate_count} license plates...")
            if violation_ocr:
                print("⚠️ No helmet detected - OCR will be performed on license plates")

            cropped_plates = crop_license_plates(
                image, detections, should_extract_text, selected_ocr_model
            )
            print(f"Successfully cropped {len(cropped_plates)} license plates")
            license_plate_gallery = [plate_data["image"] for plate_data in cropped_plates]

            if should_extract_text and ocr_available:
                print("Extracting text from license plates...")
                plate_texts = []
                for number, plate_data in enumerate(cropped_plates, start=1):
                    text = plate_data.get("text", "No text detected")
                    print(f"Plate {number} text: {text}")
                    if violation_ocr:
                        plate_texts.append(f"🚨 No Helmet Violation - Plate {number}: {text}")
                    else:
                        plate_texts.append(f"Plate {number}: {text}")
            elif should_extract_text:
                plate_texts = [
                    "OCR not available - install requirements: pip install transformers easyocr"
                ]
            else:
                plate_texts = [
                    f"Plate {number}: Text extraction disabled"
                    for number in range(1, len(cropped_plates) + 1)
                ]

            # detections is non-empty in this branch, so there is always
            # something to package.
            download_files, _, _ = create_download_files(
                annotated_image, cropped_plates, detections
            )
            if download_files is None:
                print("Warning: Could not create download files")
        except Exception as e:
            print(f"Error in license plate processing: {e}")
            cropped_plates = []
            license_plate_gallery = []
            plate_texts = ["Error processing license plates"]
            download_files = None

    stats_text = ""
    if show_stats and detections:
        stats_text = _build_stats_text(
            detections,
            cropped_plates,
            has_no_helmet,
            has_no_helmet and ocr_on_no_helmet,
            should_extract_text and ocr_available,
        )

    detection_table = (
        pd.DataFrame(detections) if detections else pd.DataFrame(columns=columns)
    )
    plate_text_output = (
        "\n".join(plate_texts)
        if plate_texts
        else "No license plates detected or OCR disabled"
    )
    return (
        annotated_image,
        detection_table,
        stats_text,
        license_plate_gallery,
        download_files,
        plate_text_output,
    )
def download_sample_images():
    """Download sample images for testing.

    Fetches ``Sample-Image-1.jpg`` through ``Sample-Image-8.jpg`` from the
    Abs6187/Helmet-Detection GitHub repository, in order, and stores them
    under the relative paths ``sample_1.jpg`` through ``sample_8.jpg``.
    """
    base_url = "https://github.com/Abs6187/Helmet-Detection/blob/main"
    for index in range(1, 9):
        torch.hub.download_url_to_file(
            f"{base_url}/Sample-Image-{index}.jpg?raw=true",
            f"sample_{index}.jpg",
        )
def get_ocr_status():
    """Return OCR availability status.

    Returns:
        A dict with the keys "basic_available" and "advanced_available"
        (the module-level import flags) and "any_available" (true when at
        least one of them is set).
    """
    return {
        "basic_available": OCR_AVAILABLE,
        "advanced_available": ADVANCED_OCR_AVAILABLE,
        "any_available": OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE,
    }