Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import numpy as np | |
| from ultralytics import YOLO | |
| import tensorflow as tf | |
class YoloLeNetOCR:
    """Two-stage OCR pipeline.

    A YOLO detector localizes individual characters in an image; each crop is
    then classified by a LeNet-style CNN. Detections are read left-to-right
    and joined into a string. The label "dot" is emitted literally (callers
    may replace it with "." — see process_dataset_folder).
    """

    def __init__(self,
                 yolo_model_path: str,
                 lenet_model_path: str,
                 image_size=(28, 28),
                 conf_threshold=0.25):
        """Load both models and set inference parameters.

        Args:
            yolo_model_path: Path to the YOLO character-detector weights.
            lenet_model_path: Path to the Keras LeNet classifier (.h5).
            image_size: (height, width) the CNN expects — default (28, 28).
            conf_threshold: Minimum YOLO confidence to keep a detection.
        """
        # YOLO detector for locating character boxes.
        self.detector = YOLO(yolo_model_path)
        # LeNet CNN for classifying each cropped character.
        self.cnn = tf.keras.models.load_model(lenet_model_path)
        # Embedded class names and inverse map (no external pkl needed).
        class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'C', 'dot']
        self.inv_map = {i: label for i, label in enumerate(class_names)}
        # Inference parameters.
        self.image_size = image_size          # (H, W) fed to the CNN
        self.conf_threshold = conf_threshold  # min detection confidence

    def preprocess(self, crop: np.ndarray) -> np.ndarray:
        """Convert a BGR crop to the (1, H, W, 1) float32 tensor LeNet expects."""
        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        # cv2.resize takes dsize as (width, height), but self.image_size is
        # (H, W); swap axes so non-square sizes resize correctly (no-op for
        # the default square size).
        resized = cv2.resize(gray, (self.image_size[1], self.image_size[0]))
        normed = resized.astype(np.float32) / 255.0
        # CNN expects shape (1, H, W, 1).
        return normed.reshape(1, *self.image_size, 1)

    def ocr_image(self, image_path: str) -> str:
        """Run detection + classification on one image.

        Returns the recognized characters left-to-right as a single string
        (the literal token "dot" is used for decimal points).

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        # 1) Load as single-channel grayscale.
        gray0 = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if gray0 is None:
            raise FileNotFoundError(f"Cannot read {image_path}")
        # 2) Convert to BGR by stacking gray into 3 channels (YOLO input).
        img = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)
        # 3) Detect boxes on the grayscale-derived BGR image.
        res = self.detector.predict(source=img, verbose=False)[0]
        boxes = res.boxes.xyxy.cpu().numpy()
        confs = res.boxes.conf.cpu().numpy()
        classes = res.boxes.cls.cpu().numpy().astype(int)
        # Compute the confidence mask once and apply it to both arrays.
        keep = confs >= self.conf_threshold
        boxes = boxes[keep]
        classes = classes[keep]
        if len(boxes) == 0:
            return ""
        # 4) Sort boxes left-to-right by their left edge.
        sort_idx = np.argsort(boxes[:, 0])
        boxes = boxes[sort_idx]
        classes = classes[sort_idx]
        digits = []
        for (x1, y1, x2, y2), cls_idx in zip(boxes, classes):
            x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
            crop = img[
                max(0, y1):min(img.shape[0], y2),
                max(0, x1):min(img.shape[1], x2)
            ]
            # A box clipped to zero area would crash cvtColor — skip it.
            if crop.size == 0:
                continue
            # If YOLO says it's a dot, just append the literal token "dot".
            if self.inv_map[cls_idx] == "dot":
                digits.append("dot")
            else:
                # 5) Preprocess and predict with LeNet.
                inp = self.preprocess(crop)
                preds = self.cnn.predict(inp, verbose=0)
                idx = int(np.argmax(preds, axis=1)[0])
                digits.append(self.inv_map[idx])
        return "".join(digits)

    def process_dataset_folder(self, dataset_dir, output_file=None):
        """
        Process all images in a dataset directory and its subdirectories.

        Args:
            dataset_dir (str): Path to the dataset directory
            output_file (str): Optional path to save results to a text file

        Returns:
            dict: Dictionary mapping image paths to OCR results
        """
        results = {}
        processed = 0
        # Iterate through all subdirectories.
        for root, _, files in os.walk(dataset_dir):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    image_path = os.path.join(root, file)
                    try:
                        ocr_result = self.ocr_image(image_path)
                        # Replace the literal detector token with a period.
                        ocr_result = ocr_result.replace("dot", ".")
                        results[image_path] = ocr_result
                        processed += 1
                        # Print progress every 100 images.
                        if processed % 100 == 0:
                            print(f"Processed {processed} images")
                    except Exception as e:
                        # Best-effort batch run: log the failure and continue.
                        print(f"Error processing {image_path}: {str(e)}")
        # Save results to file if specified.
        if output_file:
            with open(output_file, 'w') as f:
                for img_path, result in results.items():
                    f.write(f"{img_path},{result}\n")
            print(f"Results saved to {output_file}")
        print(f"Total images processed: {processed}")
        return results
# Example usage
if __name__ == "__main__":
    # Build the two-stage OCR engine from the trained model files.
    ocr = YoloLeNetOCR(
        yolo_model_path="Models/res_detect_v4.pt",
        lenet_model_path="Models/lenet_res_v4.h5",
        conf_threshold=0.3
    )

    # Run OCR over every image under new_data/res and persist the results.
    results = ocr.process_dataset_folder(
        dataset_dir="new_data/res",
        output_file="ocr_results.csv"
    )

    # Show the first five recognized strings as a quick sanity check.
    print("\nSample results:")
    for shown, (path, text) in enumerate(results.items()):
        if shown >= 5:
            break
        print(f"{path}: {text}")
| # # ------------------- | |
| # # Example usage | |
| # # ------------------- | |
| # if __name__ == "__main__": | |
| # ocr = YoloLeNetOCR( | |
| # yolo_model_path="Models/res_detect_v3.pt", | |
| # lenet_model_path="Models/lenet_res_v4.h5", | |
| # conf_threshold=0.3 | |
| # ) | |
| # result = ocr.ocr_image("new_data/res") | |
| # result = result.replace("dot", ".") | |
| # print("OCR result:", result) |