Spaces:
Runtime error
Runtime error
| import torch | |
| from ultralytics import YOLO | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| import cv2 | |
| import time | |
| import zipfile | |
| import io | |
| from datetime import datetime | |
# ===== Optional OCR imports =====
# Both OCR backends are optional. Each *_AVAILABLE flag records whether the
# corresponding import succeeded so later code can pick a backend at runtime.
try:
    from license_plate_ocr import extract_license_plate_text
except ImportError as basic_err:
    print(f"Basic OCR module not available: {basic_err}")
    OCR_AVAILABLE = False
else:
    OCR_AVAILABLE = True
    print("Basic OCR module loaded successfully")

try:
    from advanced_ocr import (
        extract_license_plate_text_advanced,
        get_available_models,
        set_ocr_model,
    )
except ImportError as advanced_err:
    print(f"Advanced OCR module not available: {advanced_err}")
    ADVANCED_OCR_AVAILABLE = False
else:
    ADVANCED_OCR_AVAILABLE = True
    print("Advanced OCR module loaded successfully")

# ===== Model & class names =====
# The weights file is loaded once at import time; best.pt must be present.
model = YOLO("best.pt")

# Maps the model's integer class ids to the labels used throughout the app.
class_names = dict(enumerate(["With Helmet", "Without Helmet", "License Plate"]))
def crop_license_plates(image, detections, extract_text=False, selected_ocr_model="auto"):
    """Crop every "License Plate" detection out of *image* and optionally OCR it.

    Args:
        image: A filesystem path, a NumPy array, or a PIL image.
        detections: Iterable of dicts carrying the keys "Object", "Confidence",
            "Position" (formatted like "(x, y)") and "Dimensions" (formatted
            like "WxH").
        extract_text: When true, run OCR on each crop if an OCR backend was
            imported; otherwise the crop's text is a fixed status string.
        selected_ocr_model: "auto", "basic", or a model name that is forwarded
            to the advanced OCR backend.

    Returns:
        A list of dicts with the keys "image", "confidence", "position",
        "crop_coords" and "text". Detections that cannot be parsed or cropped
        are reported with print() and skipped. An unusable *image* yields an
        empty list.
    """
    cropped_plates = []

    # Normalise the input into a PIL image; any failure means nothing to crop.
    try:
        if isinstance(image, str):
            if not os.path.exists(image):
                print(f"Error: Image file not found: {image}")
                return cropped_plates
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            print(f"Error: Unsupported image type: {type(image)}")
            return cropped_plates
        if image.size[0] == 0 or image.size[1] == 0:
            print("Error: Image has zero dimensions")
            return cropped_plates
    except Exception as e:
        print(f"Error loading image: {e}")
        return cropped_plates

    ocr_available = OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE

    for i, detection in enumerate(detections):
        try:
            if detection["Object"] != "License Plate":
                continue

            pos_str = detection["Position"].strip("()")
            if "," not in pos_str:
                print(
                    f"Error: Invalid position format for detection {i}: {detection['Position']}"
                )
                continue
            # Split on the bare comma: int() ignores surrounding whitespace, so
            # both "(12, 34)" and "(12,34)" parse. A wrong number of parts or
            # a non-numeric part raises ValueError, handled below.
            x1, y1 = (int(part) for part in pos_str.split(","))

            dims_str = detection["Dimensions"]
            if "x" not in dims_str:
                print(
                    f"Error: Invalid dimensions format for detection {i}: {detection['Dimensions']}"
                )
                continue
            width, height = (int(part) for part in dims_str.split("x"))
            if width <= 0 or height <= 0:
                print(f"Error: Invalid dimensions for detection {i}: {width}x{height}")
                continue

            x2, y2 = x1 + width, y1 + height
            if x1 < 0 or y1 < 0 or x2 > image.width or y2 > image.height:
                print(
                    f"Warning: Bounding box extends beyond image boundaries for detection {i}"
                )
                # Clamp the box to the image instead of discarding it.
                x1 = max(0, x1)
                y1 = max(0, y1)
                x2 = min(image.width, x2)
                y2 = min(image.height, y2)
            if x2 <= x1 or y2 <= y1:
                print(
                    f"Error: Invalid crop coordinates for detection {i}: ({x1},{y1}) to ({x2},{y2})"
                )
                continue

            cropped_plate = image.crop((x1, y1, x2, y2))
            if cropped_plate.size[0] == 0 or cropped_plate.size[1] == 0:
                print(f"Error: Cropped image has zero dimensions for detection {i}")
                continue

            if not extract_text:
                plate_text = "OCR disabled"
            elif not ocr_available:
                plate_text = "OCR not available"
            else:
                plate_text = _extract_plate_text(cropped_plate, selected_ocr_model, i + 1)

            cropped_plates.append(
                {
                    "image": cropped_plate,
                    "confidence": detection["Confidence"],
                    "position": detection["Position"],
                    "crop_coords": f"({x1},{y1}) to ({x2},{y2})",
                    "text": plate_text,
                }
            )
        except ValueError as e:
            print(f"Error parsing coordinates for detection {i}: {e}")
            continue
        except Exception as e:
            print(f"Error cropping license plate {i}: {e}")
            continue

    return cropped_plates


def _extract_plate_text(cropped_plate, selected_ocr_model, plate_number):
    """Run an imported OCR backend on one crop and return the text to display.

    Must only be called when at least one OCR backend was imported. Never
    raises: any backend failure is returned as an "OCR Failed: ..." string.
    """
    try:
        print(
            f"Extracting text from license plate {plate_number} using {selected_ocr_model}..."
        )
        # Prefer the advanced backend unless "basic" was requested. If "basic"
        # was requested but only the advanced backend is importable, fall back
        # to it rather than calling a function that was never imported.
        use_advanced = ADVANCED_OCR_AVAILABLE and (
            selected_ocr_model != "basic" or not OCR_AVAILABLE
        )
        if use_advanced:
            # "auto" (and the "basic" fallback) let the backend pick its model.
            model_name = (
                None if selected_ocr_model in ("auto", "basic") else selected_ocr_model
            )
            if model_name is not None:
                set_ocr_model(model_name)
            plate_text = extract_license_plate_text_advanced(cropped_plate, model_name)
        else:
            plate_text = extract_license_plate_text(cropped_plate)

        if plate_text and plate_text.strip() and not plate_text.startswith("Error"):
            cleaned = plate_text.strip()
            print(f"Extracted text: {cleaned}")
            return cleaned
        print(f"No text found in plate {plate_number}")
        return "No text detected"
    except Exception as e:
        print(f"OCR extraction failed for plate {plate_number}: {e}")
        return f"OCR Failed: {str(e)}"
def create_download_files(annotated_image, cropped_plates, detections):
    """Write the annotated image, plate crops and a CSV report, then zip them.

    All files are written under a local "temp" directory and share one
    timestamp (second resolution) in their names.

    Args:
        annotated_image: Image object exposing ``save(path, quality=...)``.
        cropped_plates: List of dicts with the keys "image", "confidence",
            "position" and optionally "text".
        detections: List of detection dicts; copied verbatim into the report.

    Returns:
        A ``(zip_path, annotated_path, plate_paths)`` tuple. Returns
        ``(None, None, [])`` if the annotated image cannot be saved or an
        unexpected error occurs, and ``(None, annotated_path, plate_paths)``
        if only the ZIP creation fails.
    """
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs("temp", exist_ok=True)

        annotated_path = f"temp/annotated_image_{timestamp}.jpg"
        try:
            _as_jpeg_compatible(annotated_image).save(annotated_path, quality=95)
        except Exception as e:
            print(f"Error saving annotated image: {e}")
            return None, None, []

        # A plate that fails to save is skipped; the rest are still exported.
        plate_paths = []
        for i, plate_data in enumerate(cropped_plates):
            try:
                plate_path = f"temp/license_plate_{i+1}_{timestamp}.jpg"
                _as_jpeg_compatible(plate_data["image"]).save(plate_path, quality=95)
                plate_paths.append(plate_path)
            except Exception as e:
                print(f"Error saving license plate {i+1}: {e}")
                continue

        # Report = raw detections followed by one extra row per cropped plate
        # carrying its extracted text.
        report_data = list(detections)
        report_data.extend(
            {
                "Object": f"License Plate {i+1} - Text",
                "Confidence": plate_data["confidence"],
                "Position": plate_data["position"],
                "Dimensions": "Extracted Text",
                "Text": plate_data.get("text", "N/A"),
            }
            for i, plate_data in enumerate(cropped_plates)
        )

        report_path = f"temp/detection_report_{timestamp}.csv"
        if report_data:
            try:
                pd.DataFrame(report_data).to_csv(report_path, index=False)
            except Exception as e:
                print(f"Error creating detection report: {e}")
                report_path = None

        zip_path = f"temp/detection_results_{timestamp}.zip"
        try:
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                if os.path.exists(annotated_path):
                    zipf.write(annotated_path, f"annotated_image_{timestamp}.jpg")
                for plate_path in plate_paths:
                    if os.path.exists(plate_path):
                        zipf.write(plate_path, os.path.basename(plate_path))
                # The CSV only exists when there was something to report.
                if report_path and os.path.exists(report_path):
                    zipf.write(report_path, f"detection_report_{timestamp}.csv")
        except Exception as e:
            print(f"Error creating ZIP file: {e}")
            return None, annotated_path, plate_paths

        return zip_path, annotated_path, plate_paths
    except Exception as e:
        print(f"Error in create_download_files: {e}")
        return None, None, []


def _as_jpeg_compatible(img):
    """Return *img*, converted to RGB when its mode cannot be written as JPEG.

    JPEG has no alpha or palette support, so e.g. an RGBA crop taken from a
    PNG upload would otherwise fail to save and be dropped from the export.
    Objects without a ``mode`` attribute are returned unchanged.
    """
    mode = getattr(img, "mode", None)
    if mode is None or mode in ("L", "RGB", "CMYK"):
        return img
    return img.convert("RGB")
def yolov8_detect(
    image=None,
    image_size=640,
    conf_threshold=0.4,
    iou_threshold=0.5,
    show_stats=True,
    show_confidence=True,
    crop_plates=True,
    extract_text=False,
    ocr_on_no_helmet=False,
    selected_ocr_model="auto",
):
    """Run helmet / license-plate detection on one image and build all outputs.

    Args:
        image: Input image forwarded to the model and to the plate cropper.
            ``None`` yields an empty result without running the model.
        image_size: Inference size; ``None`` means 640, non-int values are
            coerced with ``int()``.
        conf_threshold: Confidence threshold forwarded to ``model.predict``.
        iou_threshold: IoU threshold forwarded to ``model.predict``.
        show_stats: Whether to build the textual detection summary.
        show_confidence: Accepted for interface compatibility; not used here.
        crop_plates: Whether to crop plates and build the download ZIP.
        extract_text: Whether to OCR every cropped plate.
        ocr_on_no_helmet: Also OCR plates when a "Without Helmet" detection
            is present, even if *extract_text* is false.
        selected_ocr_model: OCR model selector forwarded to the cropper.

    Returns:
        A 6-tuple ``(annotated_image, detection_table, stats_text,
        license_plate_gallery, download_files, plate_text_output)``.
    """
    table_columns = ["Object", "Confidence", "Position", "Dimensions"]
    no_plate_message = "No license plates detected or OCR disabled"

    # Nothing was supplied: return empty outputs rather than handing None to
    # the model (Ultralytics substitutes its own default source for None,
    # which would report detections unrelated to the user's input).
    if image is None:
        return (None, pd.DataFrame(columns=table_columns), "", [], None, no_plate_message)

    if image_size is None:
        image_size = 640
    if not isinstance(image_size, int):
        image_size = int(image_size)
    imgsz = [image_size, image_size]

    results = model.predict(image, conf=conf_threshold, iou=iou_threshold, imgsz=imgsz)
    annotated_image = results[0].plot()
    if isinstance(annotated_image, np.ndarray):
        # The plotted array is treated as BGR and converted to an RGB PIL image.
        annotated_image = Image.fromarray(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))

    boxes = results[0].boxes
    detections = []
    if boxes is not None and len(boxes) > 0:
        for box, cls, conf in zip(boxes.xyxy, boxes.cls, boxes.conf):
            x1, y1, x2, y2 = box.tolist()
            class_id = int(cls)
            confidence = float(conf)
            label = class_names.get(class_id, f"Class {class_id}")
            # Position / Dimensions use the string formats that
            # crop_license_plates parses back into integers.
            detections.append(
                {
                    "Object": label,
                    "Confidence": f"{confidence:.2f}",
                    "Position": f"({int(x1)}, {int(y1)})",
                    "Dimensions": f"{int(x2 - x1)}x{int(y2 - y1)}",
                }
            )

    cropped_plates = []
    license_plate_gallery = []
    plate_texts = []
    download_files = None

    has_no_helmet = any(d["Object"] == "Without Helmet" for d in detections)
    should_extract_text = extract_text or (ocr_on_no_helmet and has_no_helmet)
    ocr_available = OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE
    violation_ocr = ocr_on_no_helmet and has_no_helmet

    if crop_plates and detections:
        try:
            license_plate_count = len([d for d in detections if d["Object"] == "License Plate"])
            print(f"Processing {license_plate_count} license plates...")
            if violation_ocr:
                print("⚠️ No helmet detected - OCR will be performed on license plates")

            cropped_plates = crop_license_plates(
                image, detections, should_extract_text, selected_ocr_model
            )
            print(f"Successfully cropped {len(cropped_plates)} license plates")
            license_plate_gallery = [plate_data["image"] for plate_data in cropped_plates]

            if should_extract_text and ocr_available:
                print("Extracting text from license plates...")
                plate_texts = []
                for i, plate_data in enumerate(cropped_plates):
                    text = plate_data.get("text", "No text detected")
                    print(f"Plate {i+1} text: {text}")
                    if violation_ocr:
                        plate_texts.append(f"🚨 No Helmet Violation - Plate {i+1}: {text}")
                    else:
                        plate_texts.append(f"Plate {i+1}: {text}")
            elif should_extract_text and not ocr_available:
                plate_texts = [
                    "OCR not available - install requirements: pip install transformers easyocr"
                ]
            elif not should_extract_text:
                plate_texts = [
                    f"Plate {i+1}: Text extraction disabled" for i in range(len(cropped_plates))
                ]

            if cropped_plates or detections:
                download_files, _, _ = create_download_files(
                    annotated_image, cropped_plates, detections
                )
                if download_files is None:
                    print("Warning: Could not create download files")
        except Exception as e:
            # Plate post-processing is best-effort: detection results are
            # still returned even if cropping / OCR / export fails.
            print(f"Error in license plate processing: {e}")
            cropped_plates = []
            license_plate_gallery = []
            plate_texts = ["Error processing license plates"]
            download_files = None

    stats_parts = []
    if show_stats and detections:
        counts = pd.DataFrame(detections)["Object"].value_counts().to_dict()
        stats_parts.append("Detection Summary:\n")
        stats_parts.extend(f"- {obj}: {count}\n" for obj, count in counts.items())
        # NOTE(review): the nesting of this section was reconstructed from a
        # source whose indentation was lost — confirm the violation banner is
        # meant to appear only when plates were cropped.
        if cropped_plates:
            stats_parts.append(f"\nLicense Plates Cropped: {len(cropped_plates)}\n")
            if has_no_helmet:
                stats_parts.append("⚠️ HELMET VIOLATION DETECTED!\n")
            if should_extract_text and ocr_available:
                stats_parts.append("Extracted Text:\n")
                for i, plate_data in enumerate(cropped_plates):
                    text = plate_data.get("text", "No text")
                    if violation_ocr:
                        stats_parts.append(f"🚨 Violation - Plate {i+1}: {text}\n")
                    else:
                        stats_parts.append(f"- Plate {i+1}: {text}\n")
    stats_text = "".join(stats_parts)

    detection_table = (
        pd.DataFrame(detections) if detections else pd.DataFrame(columns=table_columns)
    )
    plate_text_output = "\n".join(plate_texts) if plate_texts else no_plate_message

    return (
        annotated_image,
        detection_table,
        stats_text,
        license_plate_gallery,
        download_files,
        plate_text_output,
    )
def download_sample_images():
    """Download sample images for testing.

    Fetches Sample-Image-1.jpg through Sample-Image-8.jpg from the
    Abs6187/Helmet-Detection GitHub repository and stores them under the
    relative paths sample_1.jpg through sample_8.jpg.
    """
    repo_base = "https://github.com/Abs6187/Helmet-Detection/blob/main"
    for index in range(1, 9):
        torch.hub.download_url_to_file(
            f"{repo_base}/Sample-Image-{index}.jpg?raw=true",
            f"sample_{index}.jpg",
        )
def get_ocr_status():
    """Report which OCR backends were imported successfully.

    Returns:
        A dict with the keys "basic_available", "advanced_available" and
        "any_available" (true when at least one backend is usable).
    """
    basic = OCR_AVAILABLE
    advanced = ADVANCED_OCR_AVAILABLE
    status = {"basic_available": basic, "advanced_available": advanced}
    status["any_available"] = basic or advanced
    return status