Spaces:
Sleeping
Sleeping
import json
import logging
import multiprocessing
import os
import random
import shutil
import signal
import subprocess
import threading
import time
from threading import Lock

import numpy as np
import torch
from flask import Flask, render_template, request, jsonify
from flask_socketio import SocketIO
from PIL import Image
from ultralytics import YOLO

from utils.predictor import Predictor
from utils.helpers import (
    blend_mask_with_image,
    save_mask_as_png,
    convert_mask_to_yolo,
)
# Flask application object and the Socket.IO server wrapped around it
app = Flask(__name__)
socketio = SocketIO(app)

# All runtime folders are resolved against the directory containing this file
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# Working folders for the image currently being annotated
UPLOAD_FOLDERS = {
    key: os.path.join(BASE_DIR, relative)
    for key, relative in (
        ('input', 'static/uploads/input'),
        ('segmented_voids', 'static/uploads/segmented/voids'),
        ('segmented_chips', 'static/uploads/segmented/chips'),
        ('mask_voids', 'static/uploads/mask/voids'),
        ('mask_chips', 'static/uploads/mask/chips'),
        ('automatic_segmented', 'static/uploads/segmented/automatic'),
    )
}

# Images and masks the user chose to keep
HISTORY_FOLDERS = {
    key: os.path.join(BASE_DIR, relative)
    for key, relative in (
        ('images', 'static/history/images'),
        ('masks_chip', 'static/history/masks/chip'),
        ('masks_void', 'static/history/masks/void'),
    )
}

# Training dataset layout, temporary backup area and model storage
DATASET_FOLDERS = {
    key: os.path.join(BASE_DIR, relative)
    for key, relative in (
        ('train_images', 'dataset/train/images'),
        ('train_labels', 'dataset/train/labels'),
        ('val_images', 'dataset/val/images'),
        ('val_labels', 'dataset/val/labels'),
        ('temp_backup', 'temp_backup'),
        ('models', 'models'),
        ('models_old', 'models/old'),
    )
}
# Create every configured directory at import time so later writes find their target
for folder_group in (UPLOAD_FOLDERS, HISTORY_FOLDERS, DATASET_FOLDERS):
    for folder_name, folder_path in folder_group.items():
        os.makedirs(folder_path, exist_ok=True)
        logging.info(f"Ensured folder exists: {folder_name} -> {folder_path}")

# Handle of the running YOLO training subprocess; None while no training is active
training_process = None
def initialize_training_status():
    """Reset the module-level ``training_status`` dict to its idle defaults."""
    global training_status
    training_status = dict(running=False, cancelled=False)
def persist_training_status():
    """Write the current ``training_status`` dict as JSON to ``training_status.json`` under BASE_DIR."""
    status_path = os.path.join(BASE_DIR, 'training_status.json')
    with open(status_path, mode='w') as handle:
        json.dump(training_status, handle)
def load_training_status():
    """Load the persisted training status into the global ``training_status``.

    Reads ``training_status.json`` under BASE_DIR. When the file is missing,
    unreadable, not valid JSON, or does not contain a JSON object, the idle
    defaults ``{'running': False, 'cancelled': False}`` are used instead, so a
    damaged status file cannot prevent the application from starting.
    Keys absent from the file are filled in from the defaults.
    """
    global training_status
    defaults = {'running': False, 'cancelled': False}
    status_path = os.path.join(BASE_DIR, 'training_status.json')
    loaded = None
    try:
        with open(status_path, 'r') as status_file:
            loaded = json.load(status_file)
    except FileNotFoundError:
        # First start (or status never persisted): silently use the defaults
        pass
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses
        logging.warning(f"Ignoring unreadable training status file {status_path}: {exc}")
    if isinstance(loaded, dict):
        training_status = {**defaults, **loaded}
    else:
        training_status = defaults
# Populate the module-level training_status before any handler reads it
load_training_status()
# NOTE(review): presumably switches off the cuDNN scaled-dot-product-attention path;
# torch is already imported at this point, so confirm the variable is still honoured
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"
# Initialize SAM Predictor
# NOTE(review): both paths are relative (not joined onto BASE_DIR) — how they are
# resolved depends on utils.predictor.Predictor; verify against the working directory
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
CHECKPOINT = "checkpoints/sam2.1_hiera_large.pt"
# Use the GPU when CUDA is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
predictor = Predictor(MODEL_CFG, CHECKPOINT, DEVICE)
# Initialize YOLO-seg
# The weights are read from <models>/best.pt at import time
YOLO_CFG = os.path.join(DATASET_FOLDERS['models'], "best.pt")
yolo_model = YOLO(YOLO_CFG)
# Configure logging
# Module-level logging.info() calls executed earlier during import implicitly
# attach a default handler to the root logger. Without force=True (Python 3.8+)
# basicConfig() would then return without applying the level, format or handlers
# below, so nothing would reach app.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(BASE_DIR, "app.log")),  # Log to a file
    ],
    force=True,
)
def index():
    """Render the ``index.html`` template that hosts the main UI."""
    page = render_template('index.html')
    return page
def upload_image():
    """Store an uploaded image and load it into the SAM predictor.

    Expects a multipart upload with a ``file`` part. The file is saved in the
    input upload folder under the base name of the submitted filename, opened
    as RGB and handed to ``predictor.set_image``.

    Returns:
        JSON ``{'image_url': ...}`` with the web path of the stored image, or
        a JSON error with status 400 when no usable file was sent or the file
        cannot be read as an image.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400
    # The filename is chosen by the client: keep only its last path component so
    # a name such as "../../app.py" cannot write outside the input folder
    safe_name = os.path.basename(file.filename.replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        return jsonify({'error': 'No file selected'}), 400
    # Save the uploaded file to the input folder
    input_path = os.path.join(UPLOAD_FOLDERS['input'], safe_name)
    file.save(input_path)
    # Set the uploaded image in the predictor; the context manager closes the file handle
    try:
        with Image.open(input_path) as uploaded:
            image = np.array(uploaded.convert("RGB"))
    except OSError as exc:
        # PIL raises OSError subclasses for files it cannot identify or decode
        logging.error(f"Uploaded file is not a readable image: {exc}")
        return jsonify({'error': 'Uploaded file is not a readable image'}), 400
    predictor.set_image(image)
    # Return a web-accessible URL instead of the file system path
    web_accessible_url = f"/static/uploads/input/{safe_name}"
    print(f"Image uploaded and set for prediction: {input_path}")
    return jsonify({'image_url': web_accessible_url})
def segment():
    """Run point-prompted SAM segmentation and return the blended image URL.

    Expects a JSON object with ``points`` and ``labels`` (same length) and an
    optional ``class`` (``'voids'`` by default). The raw mask is written to the
    class's mask folder as ``raw_mask.png`` and a colour-blended overlay to the
    class's segmented folder.

    Returns:
        JSON ``{'segmented_url': ...}`` on success; a JSON error with status
        400 for invalid input (bad body, unknown class, mismatched prompts, no
        image set) and 500 for any other failure.
    """
    try:
        # silent=True yields None instead of raising on a missing/invalid JSON
        # body, so bad requests are reported as 400 rather than 500
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object.")
        points = np.array(data.get('points', []))
        labels = np.array(data.get('labels', []))
        current_class = data.get('class', 'voids')  # Default to 'voids' if class not provided
        if points.shape[:1] != labels.shape[:1]:
            raise ValueError("'points' and 'labels' must have the same length.")
        # Resolve output folders first so an unknown class is rejected before
        # any prediction work is done
        mask_folder = UPLOAD_FOLDERS.get(f'mask_{current_class}')
        segmented_folder = UPLOAD_FOLDERS.get(f'segmented_{current_class}')
        if not mask_folder or not segmented_folder:
            raise ValueError(f"Invalid class '{current_class}' provided.")
        # Ensure predictor has an image set
        if not predictor.image_set:
            raise ValueError("No image set for prediction.")
        # Perform SAM prediction
        masks, _, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=False
        )
        # Check if masks exist and have non-zero elements
        if masks is None or masks.size == 0:
            raise RuntimeError("No masks were generated by the predictor.")
        os.makedirs(mask_folder, exist_ok=True)
        os.makedirs(segmented_folder, exist_ok=True)
        # Save the raw mask
        mask_path = os.path.join(mask_folder, 'raw_mask.png')
        save_mask_as_png(masks[0], mask_path)
        # Green overlay for voids, blue for chips
        blend_color = [34, 139, 34] if current_class == 'voids' else [30, 144, 255]
        blended_image = blend_mask_with_image(predictor.image, masks[0], blend_color)
        # Save blended image
        blended_filename = f"blended_{current_class}.png"
        blended_path = os.path.join(segmented_folder, blended_filename)
        Image.fromarray(blended_image).save(blended_path)
        # Return URL for frontend access
        segmented_url = f"/static/uploads/segmented/{current_class}/{blended_filename}"
        logging.info(f"Segmentation completed for {current_class}. Points: {points}, Labels: {labels}")
        return jsonify({'segmented_url': segmented_url})
    except ValueError as ve:
        logging.error(f"Value error during segmentation: {ve}")
        return jsonify({'error': str(ve)}), 400
    except Exception as e:
        # logging.exception records the traceback alongside the message
        logging.exception(f"Unexpected error during segmentation: {e}")
        return jsonify({'error': 'Segmentation failed', 'details': str(e)}), 500
def automatic_segment():
    """Segment an uploaded image with YOLO and report per-chip void metrics.

    Expects a multipart upload with a ``file`` part. The annotated prediction is
    saved as ``<name>_pred.jpg`` in the automatic-segmentation folder. Each
    detected void is assigned to the first chip whose bounding box contains the
    void's box centre; areas are bounding-box areas.

    Returns:
        JSON with ``segmented_url`` and ``table_data`` (image name plus, per
        chip, its area, total void percentage and largest-void percentage);
        a JSON error with status 400 when no usable file was sent, or 500 when
        prediction fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'No file selected'}), 400
    # The filename is chosen by the client: keep only its last path component so
    # it cannot point outside the input folder
    safe_name = os.path.basename(file.filename.replace('\\', '/'))
    if safe_name in ('', '.', '..'):
        return jsonify({'error': 'No file selected'}), 400
    input_path = os.path.join(UPLOAD_FOLDERS['input'], safe_name)
    file.save(input_path)
    try:
        # Perform YOLO segmentation
        results = yolo_model.predict(input_path, save=False, save_txt=False)
        if not results:
            raise RuntimeError("YOLO returned no results for the uploaded image.")
        output_folder = UPLOAD_FOLDERS['automatic_segmented']
        os.makedirs(output_folder, exist_ok=True)
        result_filename = f"{safe_name.rsplit('.', 1)[0]}_pred.jpg"
        result_path = os.path.join(output_folder, result_filename)
        chips = []
        voids = []
        # Process results and save segmented images
        for result in results:
            # Results.plot() returns a BGR array; reverse the channel axis so
            # PIL, which expects RGB, saves the annotation with correct colours
            annotated_image = result.plot()
            Image.fromarray(annotated_image[..., ::-1]).save(result_path)
            # Separate chips and voids
            for i, label in enumerate(result.boxes.cls):
                label_name = result.names[int(label)]  # e.g. 'chip' or 'void'
                box = result.boxes.xyxy[i].cpu().numpy()  # (x1, y1, x2, y2)
                area = float((box[2] - box[0]) * (box[3] - box[1]))
                if label_name == 'chip':
                    chips.append({'box': box, 'area': area, 'voids': []})
                elif label_name == 'void':
                    voids.append({'box': box, 'area': area})
        # Assign each void to the first chip whose box contains the void's centre
        for void in voids:
            centre_x = (void['box'][0] + void['box'][2]) / 2
            centre_y = (void['box'][1] + void['box'][3]) / 2
            for chip in chips:
                x1, y1, x2, y2 = chip['box']
                if x1 <= centre_x <= x2 and y1 <= centre_y <= y2:
                    chip['voids'].append(void)
                    break
        # Calculate metrics for each chip
        chips_data = []
        for chip_number, chip in enumerate(chips, start=1):
            chip_area = chip['area']
            void_areas = [float(void['area']) for void in chip['voids']]
            total_void_area = sum(void_areas)
            max_void_area = max(void_areas, default=0)
            # Guard against degenerate zero-area chip boxes
            void_percentage = (total_void_area / chip_area) * 100 if chip_area > 0 else 0
            max_void_percentage = (max_void_area / chip_area) * 100 if chip_area > 0 else 0
            chips_data.append({
                "chip_number": chip_number,
                "chip_area": round(chip_area, 2),
                "void_percentage": round(void_percentage, 2),
                "max_void_percentage": round(max_void_percentage, 2)
            })
        # Return the segmented image URL and table data
        segmented_url = f"/static/uploads/segmented/automatic/{result_filename}"
        return jsonify({
            "segmented_url": segmented_url,
            "table_data": {
                "image_name": safe_name,
                "chips": chips_data
            }
        })
    except Exception as e:
        # logging.exception records the traceback alongside the message
        logging.exception(f"Error in automatic segmentation: {e}")
        return jsonify({'error': 'Segmentation failed.'}), 500
def save_both():
    """Copy the named input image and any current raw masks into the history folders.

    Reads ``image_name`` from the JSON body. The image is copied from the input
    upload folder; each class's ``raw_mask.png`` that exists is copied under the
    image's stem with a ``.png`` extension. Missing masks are only reported.
    """
    payload = request.json
    image_name = payload.get('image_name')
    if not image_name:
        return jsonify({'error': 'Image name not provided'}), 400
    try:
        # Keep only the file name; any directory part is discarded
        image_name = os.path.basename(image_name)
        print(f"Sanitized Image Name: {image_name}")

        input_image_path = os.path.join(UPLOAD_FOLDERS['input'], image_name)
        if not os.path.exists(input_image_path):
            print(f"Input image does not exist: {input_image_path}")
            return jsonify({'error': f'Input image not found: {input_image_path}'}), 404

        # Image -> history/images
        image_history_path = os.path.join(HISTORY_FOLDERS['images'], image_name)
        os.makedirs(os.path.dirname(image_history_path), exist_ok=True)
        shutil.copy(input_image_path, image_history_path)
        print(f"Image saved to history: {image_history_path}")

        # Masks -> history/masks/<class>, voids first and then chips
        mask_name = f"{os.path.splitext(image_name)[0]}.png"
        mask_sources = (
            ('Voids', 'mask_voids', 'masks_void'),
            ('Chips', 'mask_chips', 'masks_chip'),
        )
        for label, upload_key, history_key in mask_sources:
            source_mask = os.path.join(UPLOAD_FOLDERS[upload_key], 'raw_mask.png')
            if not os.path.exists(source_mask):
                print(f"{label} mask not found: {source_mask}")
                continue
            history_mask = os.path.join(HISTORY_FOLDERS[history_key], mask_name)
            os.makedirs(os.path.dirname(history_mask), exist_ok=True)
            shutil.copy(source_mask, history_mask)
            print(f"{label} mask saved to history: {history_mask}")

        return jsonify({'message': 'Image and masks saved successfully!'}), 200
    except Exception as e:
        print(f"Error saving files: {e}")
        return jsonify({'error': 'Failed to save files.', 'details': str(e)}), 500
def get_history():
    """Return the file names found in the history images folder as JSON."""
    try:
        payload = {'status': 'success', 'images': os.listdir(HISTORY_FOLDERS['images'])}
        return jsonify(payload), 200
    except Exception as e:
        failure = {'status': 'error', 'message': f'Failed to fetch history: {e}'}
        return jsonify(failure), 500
def delete_history_item():
    """Delete one history image together with its void and chip masks.

    Reads ``image_name`` from the JSON body. Only the base name is used, so the
    request cannot address files outside the history folders. Files that do not
    exist are skipped.

    Returns:
        A JSON confirmation, a 400 error when no usable name was supplied, or a
        500 error when deletion fails.
    """
    data = request.get_json(silent=True)
    image_name = data.get('image_name') if isinstance(data, dict) else None
    if not image_name:
        return jsonify({'error': 'Image name not provided'}), 400
    # The name comes from the client: strip any directory components so that
    # values like "../../app.py" cannot delete arbitrary files
    image_name = os.path.basename(str(image_name).replace('\\', '/'))
    if image_name in ('', '.', '..'):
        return jsonify({'error': 'Image name not provided'}), 400
    try:
        mask_name = f"{os.path.splitext(image_name)[0]}.png"
        targets = (
            os.path.join(HISTORY_FOLDERS['images'], image_name),
            os.path.join(HISTORY_FOLDERS['masks_void'], mask_name),
            os.path.join(HISTORY_FOLDERS['masks_chip'], mask_name),
        )
        for target in targets:
            if os.path.exists(target):
                os.remove(target)
        return jsonify({'message': f'{image_name} and associated masks deleted successfully.'}), 200
    except Exception as e:
        return jsonify({'error': f'Failed to delete files: {e}'}), 500
# Serialises writes to the shared training_status dict
status_lock = Lock()


def update_training_status(key, value):
    """Store *value* under *key* in ``training_status`` while holding ``status_lock``."""
    status_lock.acquire()
    try:
        training_status[key] = value
    finally:
        status_lock.release()
def retrain_model():
    """Prepare the dataset and start YOLO retraining in a background thread.

    The "already running" check and the switch to the running state happen under
    ``status_lock`` in one step, so two overlapping requests cannot both start a
    training run. History data is then backed up, YOLO labels are generated and
    ``run_yolo_training`` is launched on a new thread.

    Returns:
        A JSON confirmation, a 400 error when training is already in progress,
        or a 500 error when preparation fails (the running flag is cleared).
    """
    # Atomic check-and-set; update_training_status is not used here because it
    # acquires the same non-reentrant lock
    with status_lock:
        if training_status.get('running', False):
            return jsonify({'error': 'Training is already in progress'}), 400
        training_status['running'] = True
        training_status['cancelled'] = False
    logging.info("Training status updated. Starting training workflow.")
    try:
        # Backup masks and images
        backup_masks_and_images()
        logging.info("Backup completed successfully.")
        # Prepare YOLO labels
        prepare_yolo_labels()
        logging.info("YOLO labels prepared successfully.")
        # Start YOLO training in a separate thread
        threading.Thread(target=run_yolo_training).start()
        return jsonify({'message': 'Training started successfully!'}), 200
    except Exception as e:
        logging.error(f"Error during training preparation: {e}")
        update_training_status('running', False)
        return jsonify({'error': f"Failed to start training: {e}"}), 500
import random


def prepare_yolo_labels():
    """Build the YOLO train/val dataset from the history images and masks.

    History images with a ``.jpg``, ``.jpeg`` or ``.png`` extension (matched
    case-insensitively) are shuffled and split 80 % / 20 % into the training and
    validation sets. Each image is copied and its masks converted to a YOLO
    label file by ``process_image_and_mask``.

    Raises:
        Exception: any error raised while listing or processing images is
            logged and re-raised unchanged.
    """
    images_folder = HISTORY_FOLDERS['images']  # Use history images as the source
    train_labels_folder = DATASET_FOLDERS['train_labels']
    train_images_folder = DATASET_FOLDERS['train_images']
    val_labels_folder = DATASET_FOLDERS['val_labels']
    val_images_folder = DATASET_FOLDERS['val_images']
    # Ensure destination directories exist
    for folder in (train_labels_folder, train_images_folder, val_labels_folder, val_images_folder):
        os.makedirs(folder, exist_ok=True)
    try:
        # Lower-case the name before matching so files such as "IMG.JPG" or
        # "scan.jpeg" are not silently left out of the dataset
        all_images = [
            img for img in os.listdir(images_folder)
            if img.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]
        random.shuffle(all_images)  # Shuffle the images for randomness
        # 80% for training, 20% for validation
        split_idx = int(len(all_images) * 0.8)
        splits = (
            (all_images[:split_idx], train_images_folder, train_labels_folder),
            (all_images[split_idx:], val_images_folder, val_labels_folder),
        )
        for image_names, images_dest, labels_dest in splits:
            for image_name in image_names:
                process_image_and_mask(
                    image_name,
                    source_images_folder=images_folder,
                    dest_images_folder=images_dest,
                    dest_labels_folder=labels_dest
                )
        logging.info("YOLO labels prepared, and images split into train and validation successfully.")
    except Exception as e:
        logging.error(f"Error in preparing YOLO labels: {e}")
        raise
def process_image_and_mask(image_name, source_images_folder, dest_images_folder, dest_labels_folder):
    """
    Copy one image into the dataset and write its YOLO label file.

    Any existing label file for the image is removed first. The void mask
    (class 0) is converted first and the chip mask (class 1) is then converted
    with ``append=True``; a mask that does not exist in history is skipped.
    Errors are logged and re-raised.
    """
    try:
        stem = os.path.splitext(image_name)[0]
        image_path = os.path.join(source_images_folder, image_name)
        label_file_path = os.path.join(dest_labels_folder, f"{stem}.txt")

        # Copy image to the destination images folder
        shutil.copy(image_path, os.path.join(dest_images_folder, image_name))

        # Start from a clean label file
        if os.path.exists(label_file_path):
            os.remove(label_file_path)

        # (history folder key, extra converter arguments): void first, then chip
        mask_jobs = (
            ('masks_void', {'class_id': 0}),
            ('masks_chip', {'class_id': 1, 'append': True}),
        )
        for history_key, extra_args in mask_jobs:
            mask_path = os.path.join(HISTORY_FOLDERS[history_key], f"{stem}.png")
            if not os.path.exists(mask_path):
                continue
            convert_mask_to_yolo(
                mask_path=mask_path,
                image_path=image_path,
                output_path=label_file_path,
                **extra_args
            )
        logging.info(f"Processed {image_name} into YOLO format.")
    except Exception as e:
        logging.error(f"Error processing {image_name}: {e}")
        raise
def backup_masks_and_images():
    """Copy the history images and masks into freshly emptied temp-backup folders.

    Raises:
        RuntimeError: if any copy fails (the original error is logged).
    """
    backup_root = DATASET_FOLDERS['temp_backup']
    temp_backup_paths = {
        'voids': os.path.join(backup_root, 'masks/voids'),
        'chips': os.path.join(backup_root, 'masks/chips'),
        'images': os.path.join(backup_root, 'images')
    }
    # Start from empty backup directories
    for backup_dir in temp_backup_paths.values():
        if os.path.exists(backup_dir):
            shutil.rmtree(backup_dir)
        os.makedirs(backup_dir, exist_ok=True)

    # (history folder key, backup folder key) in copy order
    copy_plan = (
        ('images', 'images'),
        ('masks_void', 'voids'),
        ('masks_chip', 'chips'),
    )
    try:
        for history_key, backup_key in copy_plan:
            for entry in os.listdir(HISTORY_FOLDERS[history_key]):
                source_path = os.path.join(HISTORY_FOLDERS[history_key], entry)
                target_path = os.path.join(temp_backup_paths[backup_key], entry)
                shutil.copy(source_path, target_path)
        logging.info("Masks and images backed up successfully from history.")
    except Exception as e:
        logging.error(f"Error during backup: {e}")
        raise RuntimeError("Backup process failed.")
def run_yolo_training(num_epochs=10):
    """Run the ``yolo train`` CLI as a subprocess and stream its output.

    Every output line is printed, logged and emitted to the frontend as a
    ``training_update`` event. On exit code 0 the new model is promoted via
    ``finalize_training``. A run that ended because cancellation was requested
    is logged and left to the cancel handler; any other failure restores the
    backup and emits an error ``training_status`` event.

    Args:
        num_epochs: number of training epochs passed to the CLI.
    """
    global training_process
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        data_cfg_path = os.path.join(BASE_DIR, "models/data.yaml")  # Ensure correct YAML path
        # Absolute output folder: finalize_training() looks for runs under
        # BASE_DIR, which a relative "runs" only matches when cwd is BASE_DIR
        runs_dir = os.path.join(BASE_DIR, "runs")
        logging.info(f"Starting YOLO training on {device} with {num_epochs} epochs.")
        logging.info(f"Using dataset configuration: {data_cfg_path}")
        training_command = [
            "yolo",
            "train",
            f"data={data_cfg_path}",
            f"model={os.path.join(DATASET_FOLDERS['models'], 'best.pt')}",
            f"device={device}",
            f"epochs={num_epochs}",
            f"project={runs_dir}",
            "name=train"
        ]
        # Keep a local handle: the module-level training_process may be reset
        # by another thread while this one is still reading or waiting
        process = subprocess.Popen(
            training_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            env=os.environ.copy(),
        )
        training_process = process
        # Display and log output in real time; the with-block closes the pipe
        with process.stdout:
            for line in iter(process.stdout.readline, ''):
                message = line.strip()
                print(message)
                logging.info(message)
                socketio.emit('training_update', {'message': message})  # Send updates to the frontend
        process.wait()
        if process.returncode == 0:
            finalize_training()  # Finalize successfully completed training
        elif training_status.get('cancelled', False) or process.returncode == -signal.SIGTERM:
            # On POSIX, terminate() makes the return code -SIGTERM. A requested
            # stop is not a failure: skip the error event and the restore,
            # which the cancel handler performs itself.
            logging.info("YOLO training stopped after a cancellation request.")
        else:
            raise RuntimeError("YOLO training process failed. Check logs for details.")
    except Exception as e:
        logging.error(f"Training error: {e}")
        restore_backup()  # Restore the dataset and masks
        # Emit training error event to the frontend
        socketio.emit('training_status', {'status': 'error', 'message': f"Training failed: {str(e)}"})
    finally:
        update_training_status('running', False)
        training_process = None  # Reset the process
def handle_cancel_training():
    """Cancel the running YOLO training process and notify the frontend.

    When no training is running, only a ``button_update`` event is emitted.
    When training is flagged as running but no subprocess exists yet, an error
    ``training_status`` event is emitted. Otherwise the subprocess is
    terminated, the backup restored, the train/val folders emptied and the
    cancelled state emitted.
    """
    if not training_status.get('running', False):
        socketio.emit('button_update', {'action': 'retrain'})  # Update button to retrain
        return
    # Snapshot the handle: the training thread resets the global when it ends
    process = training_process
    if process is None:
        # Training is still in its preparation phase, so there is nothing to terminate
        message = 'No active training process to cancel.'
        logging.error(message)
        socketio.emit('training_status', {'status': 'error', 'message': message})
        return
    try:
        # Record the request before stopping the process so that anything
        # observing the exit can tell it was asked for
        update_training_status('cancelled', True)
        process.terminate()
        process.wait()
        update_training_status('running', False)
        restore_backup()
        cleanup_train_val_directories()
        # Emit button state change
        socketio.emit('button_update', {'action': 'retrain'})
        socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'})
    except Exception as e:
        logging.error(f"Error cancelling training: {e}")
        socketio.emit('training_status', {'status': 'error', 'message': str(e)})
def finalize_training():
    """Promote the newest trained ``best.pt`` and clean up the dataset folders.

    The most recently modified run directory under ``BASE_DIR/runs`` is used.
    The current ``models/best.pt`` is moved to the old-models folder under a
    timestamped name, the new weights take its place and are loaded into the
    module-level ``yolo_model``. A ``training_status`` event reports the result.
    Errors are logged and emitted to the frontend, not raised.
    """
    global yolo_model
    try:
        # Locate the most recent training directory
        runs_dir = os.path.join(BASE_DIR, 'runs')
        if not os.path.exists(runs_dir):
            raise FileNotFoundError("Training runs directory does not exist.")
        # Only directories can be runs; stray files must not win the mtime comparison
        run_dirs = [
            path for path in (os.path.join(runs_dir, entry) for entry in os.listdir(runs_dir))
            if os.path.isdir(path)
        ]
        if not run_dirs:
            raise FileNotFoundError("No training run folders found in the runs directory.")
        latest_run = max(run_dirs, key=os.path.getmtime)
        weights_dir = os.path.join(latest_run, 'weights')
        best_model_path = os.path.join(weights_dir, 'best.pt')
        if not os.path.exists(best_model_path):
            raise FileNotFoundError(f"'best.pt' not found in {weights_dir}.")
        # Backup the old model
        old_model_folder = DATASET_FOLDERS['models_old']
        os.makedirs(old_model_folder, exist_ok=True)
        new_model_dest = os.path.join(DATASET_FOLDERS['models'], 'best.pt')
        if os.path.exists(new_model_dest):
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            shutil.move(new_model_dest, os.path.join(old_model_folder, f"old_{timestamp}.pt"))
            logging.info(f"Old model backed up to {old_model_folder}.")
        # Move the new model to the models directory
        shutil.move(best_model_path, new_model_dest)
        logging.info(f"New model saved to {new_model_dest}.")
        # Load the promoted weights so predictions use them without a restart;
        # otherwise the model loaded at import time would keep serving requests
        yolo_model = YOLO(new_model_dest)
        # Notify frontend that training is completed
        socketio.emit('training_status', {
            'status': 'completed',
            'message': 'Training completed successfully! Model saved as best.pt.'
        })
        # Clean up train/val directories
        cleanup_train_val_directories()
        logging.info("Train and validation directories cleaned up successfully.")
    except Exception as e:
        logging.error(f"Error finalizing training: {e}")
        # Emit error status to the frontend
        socketio.emit('training_status', {'status': 'error', 'message': f"Error finalizing training: {str(e)}"})
def restore_backup():
    """Copy the temp-backup images and masks back into the history folders.

    ``backup_masks_and_images`` fills the backup from ``HISTORY_FOLDERS``, so
    the data is restored to those same folders. Each subfolder is restored
    independently: a missing or failing one is logged and the others are still
    processed. Nothing is raised.
    """
    temp_backup = DATASET_FOLDERS['temp_backup']
    # (backup subfolder, history folder it was copied from)
    restore_plan = (
        ('masks/voids', HISTORY_FOLDERS['masks_void']),
        ('masks/chips', HISTORY_FOLDERS['masks_chip']),
        ('images', HISTORY_FOLDERS['images']),
    )
    failed = False
    for relative, destination in restore_plan:
        source = os.path.join(temp_backup, relative)
        if not os.path.isdir(source):
            logging.warning(f"Backup folder not found, skipping restore: {source}")
            continue
        try:
            shutil.copytree(source, destination, dirs_exist_ok=True)
        except Exception as e:
            failed = True
            logging.error(f"Error restoring backup: {e}")
    if not failed:
        logging.info("Backup restored successfully.")
def cancel_training():
    """Terminate the running training subprocess and restore the backed-up data.

    Returns:
        A JSON confirmation on success, a 400 error when no training subprocess
        exists, or a 500 error when cancelling fails.
    """
    # Work on a snapshot and leave the module-level handle alone: the training
    # thread still uses it and clears it itself once the process has exited
    process = training_process
    if process is None:
        logging.error("No active training process to terminate.")
        return jsonify({'error': 'No active training process to cancel.'}), 400
    try:
        # Record the request before stopping the process so that anything
        # observing the exit can tell it was asked for
        update_training_status('cancelled', True)
        process.terminate()
        process.wait()
        update_training_status('running', False)
        # Check if the model is already saved as best.pt
        best_model_path = os.path.join(DATASET_FOLDERS['models'], 'best.pt')
        if os.path.exists(best_model_path):
            logging.info(f"Model already saved as best.pt at {best_model_path}.")
            socketio.emit('button_update', {'action': 'revert'})  # Notify frontend to revert button state
        else:
            logging.info("Training canceled, but no new model was saved.")
        # Restore backup if needed
        restore_backup()
        cleanup_train_val_directories()
        # Emit status update to frontend
        socketio.emit('training_status', {'status': 'cancelled', 'message': 'Training was canceled by the user.'})
        return jsonify({'message': 'Training canceled and data restored successfully.'}), 200
    except Exception as e:
        logging.error(f"Error cancelling training: {e}")
        return jsonify({'error': f"Failed to cancel training: {e}"}), 500
def clear_history():
    """Empty the history image and mask folders, recreating each one afterwards."""
    try:
        history_dirs = [HISTORY_FOLDERS[key] for key in ('images', 'masks_chip', 'masks_void')]
        for history_dir in history_dirs:
            # Deletion errors are ignored; the folder is recreated either way
            shutil.rmtree(history_dir, ignore_errors=True)
            os.makedirs(history_dir, exist_ok=True)
        return jsonify({'message': 'History cleared successfully!'}), 200
    except Exception as e:
        return jsonify({'error': f'Failed to clear history: {e}'}), 500
def get_training_status():
    """Report whether training is running, was cancelled, or is idle.

    The ``running`` flag takes precedence over ``cancelled``.
    """
    if training_status.get('running', False):
        status, message = 'running', 'Training in progress.'
    elif training_status.get('cancelled', False):
        status, message = 'cancelled', 'Training was cancelled.'
    else:
        status, message = 'idle', 'No training is currently running.'
    return jsonify({'status': status, 'message': message}), 200
def cleanup_train_val_directories():
    """Empty the train/val image and label folders; failures are logged, not raised."""
    try:
        dataset_keys = ('train_images', 'train_labels', 'val_images', 'val_labels')
        targets = [DATASET_FOLDERS[key] for key in dataset_keys]
        for target in targets:
            shutil.rmtree(target, ignore_errors=True)  # Drop the folder and its contents
            os.makedirs(target, exist_ok=True)  # Put an empty folder back
        logging.info("Train and validation directories cleaned up successfully.")
    except Exception as e:
        logging.error(f"Error cleaning up train/val directories: {e}")
# Script entry point: only runs when the file is executed directly
if __name__ == '__main__':
    multiprocessing.set_start_method('spawn')  # Required for multiprocessing on Windows
    # NOTE(review): debug=True enables the Werkzeug debugger — confirm this is never
    # exposed publicly. The app is started through Flask's app.run rather than
    # socketio.run; verify the Socket.IO transport still works as intended.
    app.run(debug=True, use_reloader=False)