""" ConvNeXt CheXpert Classifier with GradCAM - Security Enhanced Version Security Improvements: - Removed os.system() calls for package installation - Added weights_only=True for torch.load() where possible - Input validation for images and confidence thresholds - State dict validation to prevent malicious model loading - Error message sanitization to prevent information disclosure - Environment variable usage for secrets (HF_TOKEN) - File extension validation for downloaded models """ import os import torch import json import timm import gradio as gr import numpy as np import torch.nn as nn import matplotlib.pyplot as plt from PIL import Image from torchvision import transforms import cv2 try: from pytorch_grad_cam import GradCAM from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget from pytorch_grad_cam.utils.image import show_cam_on_image except ImportError: raise ImportError("Required package 'grad-cam' not found. Please install it manually: pip install grad-cam") try: from huggingface_hub import hf_hub_download except ImportError: raise ImportError("Required package 'huggingface_hub' not found. Please install it manually: pip install huggingface_hub") DISEASE_LABELS = [ "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity", "Lung Lesion", "Edema", "Consolidation", "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices" ] MODEL_CONFIG = { "input_size": 384, "num_classes": 14, "mean": [0.5029414296150208] * 3, "std": [0.2892409563064575] * 3 } ENSEMBLE_CONFIG = {} try: with open("ensemble_config.json", "r") as f: ENSEMBLE_CONFIG = json.load(f) print("✅ Loaded ensemble_config.json successfully.") except FileNotFoundError: print("⚠️ Warning: ensemble_config.json not found. Using default configuration.") ENSEMBLE_CONFIG = {"weights": {}} except json.JSONDecodeError: print("❌ Error: Could not decode ensemble_config.json. Using default configuration.") ENSEMBLE_CONFIG = {"weights": {}} class CBAM(nn.Module): """Convolutional Block Attention Module - matches training implementation""" def __init__(self, channels, reduction=16, kernel_size=7): super(CBAM, self).__init__() self.channel_attention = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(channels, channels // reduction, 1, bias=False), nn.ReLU(), nn.Conv2d(channels // reduction, channels, 1, bias=False), nn.Sigmoid() ) self.spatial_attention = nn.Sequential( nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False), nn.Sigmoid() ) def forward(self, x): ca = self.channel_attention(x) x = x * ca avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) sa = self.spatial_attention(torch.cat([avg_out, max_out], dim=1)) return x * sa class ConvNeXtWithMetadata(nn.Module): """ConvNeXt model with CBAM attention and metadata fusion - matches training architecture""" def __init__(self, num_classes=14, metadata_input_dim=8, model_name="convnext_base", pretrained=False): super().__init__() self.convnext_backbone = timm.create_model( model_name, pretrained=pretrained, num_classes=0, features_only=True ) self.cbam = CBAM(self.convnext_backbone.feature_info.channels()[-1], reduction=16) self.global_pool = nn.AdaptiveAvgPool2d(1) self.num_image_features = self.convnext_backbone.feature_info.channels()[-1] self.metadata_fc = nn.Sequential( nn.Linear(metadata_input_dim, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, 32), nn.ReLU(), nn.Dropout(0.2) ) self.num_metadata_features = 32 self.classifier = nn.Linear(self.num_image_features + self.num_metadata_features, num_classes) def forward(self, pixel_values, metadata=None): # Handle case where metadata is not provided (for backward compatibility) if metadata is None: # Create default metadata tensor (batch_size, metadata_dim) batch_size = pixel_values.size(0) metadata = torch.zeros(batch_size, 8).to(pixel_values.device) feats_list = self.convnext_backbone(pixel_values) feats = feats_list[-1] feats = self.cbam(feats) feats = self.global_pool(feats) feats = feats.view(feats.size(0), -1) metadata = metadata.float().to(feats.device) metadata_features = self.metadata_fc(metadata) combined_features = torch.cat((feats, metadata_features), dim=1) logits = self.classifier(combined_features) return logits def load_ensemble_model(model_repo="calender/Ensemble_C"): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") models = [] model_weights = [] # Determine if loading locally or from HF Hub # For HF Spaces, prioritize HF Hub; locally prioritize local files load_locally = os.environ.get("LOAD_MODELS_LOCALLY", "false").lower() == "true" is_hf_spaces = os.environ.get("SPACE_ID") is not None if is_hf_spaces: print("🫁 Running on HuggingFace Spaces - prioritizing HF Hub") load_locally = False else: print("💻 Running locally - checking local files first") # Map config filenames to local filenames # Local clean files: Model4.pth, Model5.pth, Model6.pth config_to_local_map = { "Model4.pth": "Model4.pth", "Model5.pth": "Model5.pth", "Model6.pth": "Model6.pth" } # Map config filenames to actual repository filenames # Repository contains clean files: Model4.pth, Model5.pth, Model6.pth # (NOT the checkpoint files Iteration4_BEST.pth which have extra training data) config_to_repo_map = { "Model4.pth": "Model4.pth", "Model5.pth": "Model5.pth", "Model6.pth": "Model6.pth" } # Use filenames from config if available, otherwise default config_model_files = list(ENSEMBLE_CONFIG.get("weights", {}).keys()) if not config_model_files: print("⚠️ No model weights in config, using default model files and equal weights.") config_model_files = ["Model4.pth", "Model5.pth", "Model6.pth"] print(f" Expected repository files: {config_model_files}") else: print(f"📋 Config specifies models: {config_model_files}") print(f"🔍 Will download from repository: {[config_to_repo_map.get(f, f) for f in config_model_files]}") for config_filename in config_model_files: weight = ENSEMBLE_CONFIG.get("weights", {}).get(config_filename, 1.0) if load_locally: local_filename = config_to_local_map.get(config_filename) if not local_filename: print(f"❌ Error: No local mapping for '{config_filename}'. Skipping.") continue model_path = local_filename if not os.path.exists(model_path): print(f"❌ ERROR: Local model file not found at '{model_path}'") return [], device, [] print(f"Found local model: {model_path}") else: # Use the actual filename in the repository repo_filename = config_to_repo_map.get(config_filename, config_filename) print(f"Downloading {repo_filename} from repo {model_repo}...") try: # Set HF token for private repo access hf_token = os.environ.get("HF_TOKEN") if hf_token: print(f"✅ Using HF_TOKEN for private repo access") model_path = hf_hub_download(repo_id=model_repo, filename=repo_filename, token=hf_token) print(f"✅ Downloaded {repo_filename} successfully.") # Basic validation of downloaded file if not model_path.endswith('.pth') and not model_path.endswith('.pt'): print(f"⚠️ Warning: Downloaded file {model_path} doesn't have expected .pth/.pt extension") if not os.path.exists(model_path): print(f"❌ Downloaded file not found at {model_path}") continue except Exception as e: print(f"❌ Failed to download {repo_filename}: {e}") print(f" Make sure '{repo_filename}' exists in the '{model_repo}' repository.") continue # Load state dict with proper error handling try: # Try loading with weights_only=True first (more secure) state_dict = torch.load(model_path, map_location=device, weights_only=True) except Exception: # Fall back to full loading for backward compatibility with older models try: state_dict = torch.load(model_path, map_location=device, weights_only=False) print("⚠️ Loaded model with full serialization (less secure)") except Exception as e: print(f"❌ Failed to load model {model_path}: {e}") continue if 'model_state_dict' in state_dict: state_dict = state_dict['model_state_dict'] # Basic validation of state dict if not isinstance(state_dict, dict) or len(state_dict) == 0: print(f"❌ Invalid state dict structure in {model_path}") continue # Check for suspicious keys (potential code injection) suspicious_keys = [k for k in state_dict.keys() if any(suspicious in k.lower() for suspicious in ['exec', 'eval', 'import', 'open', '__'])] if suspicious_keys: print(f"❌ Suspicious keys found in state dict: {suspicious_keys}") continue # Check the actual classifier shape to determine the required feature dimension if 'classifier.weight' in state_dict: required_features = state_dict['classifier.weight'].shape[1] print(f"Saved model requires {required_features} input features") # Try to find a model that matches the required features model_configs = [ ("convnext_base", 1024), ("convnext_large", 1536), ("convnext_small", 768), ("convnext_tiny", 512), ] # Find the best matching model best_match = None for model_name, features in model_configs: if features == required_features: best_match = model_name break if best_match: print(f"Using {best_match} (matches {required_features} features)") model = ConvNeXtWithMetadata(num_classes=14, metadata_input_dim=8, model_name=best_match).to(device) else: print(f"Warning: No standard model matches {required_features} features") print("Using default ConvNeXtWithMetadata architecture (matches training)") model = ConvNeXtWithMetadata(num_classes=14, metadata_input_dim=8).to(device) else: # Check if we can infer from other keys in the state dict print("Warning: Could not determine required features from saved model") print("Available classifier-related keys:", [k for k in state_dict.keys() if 'classifier' in k][:3]) # Show first 3 classifier keys only print("Using default ConvNeXtWithMetadata architecture") model = ConvNeXtWithMetadata(num_classes=14, metadata_input_dim=8).to(device) # Handle module prefix (from DataParallel) if any(key.startswith('module.') for key in state_dict.keys()): state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} print("Removed 'module.' prefix from state dict keys") try: model.load_state_dict(state_dict, strict=False) print(f"✅ Successfully loaded model for '{config_filename}' with weight {weight:.2f}") except Exception as e: print(f"❌ Failed to load state dict for '{config_filename}': {e}") print("This might indicate architecture mismatch. Consider retraining or using correct model architecture.") continue model.eval() models.append(model) model_weights.append(weight) if not models: print("❌ No models were loaded. Ensemble cannot be created.") return [], device, [] # Normalize weights to sum to 1 total_weight = sum(model_weights) normalized_weights = [w / total_weight for w in model_weights] print(f"Ensemble loaded with {len(models)} models. Normalized weights: {[f'{w:.2f}' for w in normalized_weights]}") print(f"📊 Models loaded: {config_model_files}") print(f"🔗 Repository: {model_repo}") print(f"🎯 Ready for Gradio interface!") return models, device, normalized_weights def predict_with_ensemble(models, device, normalized_weights, image, confidence_threshold=0.55): # Validate image input if image is None: raise ValueError("Image cannot be None") if not hasattr(image, 'size') or not hasattr(image, 'convert'): raise ValueError("Invalid image format. Expected PIL Image.") # Check image dimensions are reasonable (not too small or too large) width, height = image.size if width < 10 or height < 10 or width > 10000 or height > 10000: raise ValueError(f"Image dimensions too extreme: {width}x{height}") transform = transforms.Compose([ transforms.Grayscale(num_output_channels=3), transforms.Resize((MODEL_CONFIG["input_size"], MODEL_CONFIG["input_size"])), transforms.ToTensor(), transforms.Normalize(mean=MODEL_CONFIG["mean"], std=MODEL_CONFIG["std"]) ]) input_tensor = transform(image).unsqueeze(0).to(device) # Validate confidence threshold if not isinstance(confidence_threshold, (int, float)) or not 0 <= confidence_threshold <= 1: raise ValueError(f"Invalid confidence threshold: {confidence_threshold}. Must be between 0 and 1.") metadata_tensor = torch.tensor([[ 0.0, # sex (0 = male, 1 = female) 50.0, # age (scaled by 100) 0.0, # age missing indicator 1.0, # frontal/lateral (1 = frontal, 0 = lateral) 0.0, # AP (0 = PA, 1 = AP) 1.0, # PA (1 = PA, 0 = AP) 0.0, # AP/PA no label 0.0 # AP/PA unknown ]], dtype=torch.float).to(device) all_probabilities = [] # Get temperature scaling settings temperature_scaling = ENSEMBLE_CONFIG.get("temperature_scaling", False) temperature_value = ENSEMBLE_CONFIG.get("temperature_value", 1.0) with torch.no_grad(): for i, model in enumerate(models): # Pass both image and metadata to the model logits = model(input_tensor, metadata_tensor) # Apply temperature scaling if enabled if temperature_scaling and temperature_value != 1.0: logits = logits / temperature_value probabilities = torch.sigmoid(logits).squeeze().cpu().numpy() all_probabilities.append(probabilities * normalized_weights[i]) # Sum the weighted probabilities if all_probabilities: ensemble_probabilities = np.sum(all_probabilities, axis=0) else: ensemble_probabilities = np.array([]) confident_indices = [] confident_predictions = [] # Get threshold optimization settings and optimized thresholds threshold_optimization = ENSEMBLE_CONFIG.get("threshold_optimization", False) optimized_thresholds = ENSEMBLE_CONFIG.get("optimized_thresholds", {}) for idx, (prob, disease) in enumerate(zip(ensemble_probabilities, DISEASE_LABELS)): # Use per-class optimized threshold if available, otherwise use global threshold if threshold_optimization and disease in optimized_thresholds: current_threshold = optimized_thresholds[disease] else: current_threshold = confidence_threshold if prob > current_threshold: confident_indices.append(idx) confident_predictions.append({ 'disease': disease, 'confidence': float(prob), 'class_idx': idx, 'threshold_used': current_threshold }) if not confident_predictions: threshold_info = "optimized thresholds" if threshold_optimization else f"{confidence_threshold:.0%} threshold" return { 'predictions': [], 'message': f'No findings above {threshold_info}', 'visualizations': None, 'ensemble_auc': 0.817, 'temperature_scaling_used': temperature_scaling, 'temperature_value': temperature_value } target_layer = None # Find target layer for GradCAM - look in convnext_backbone specifically for module in reversed(list(models[0].convnext_backbone.modules())): if isinstance(module, nn.Conv2d): target_layer = module print(f"Found target layer for GradCAM in convnext_backbone") break if target_layer is None: print("Warning: Could not find suitable layer for GradCAM in convnext_backbone") return { 'predictions': confident_predictions, 'message': 'Could not find suitable layer for GradCAM', 'visualizations': None, 'ensemble_auc': 0.817, 'temperature_scaling_used': temperature_scaling, 'temperature_value': temperature_value } visualizations = {} for pred in confident_predictions: class_idx = pred['class_idx'] disease = pred['disease'] confidence = pred['confidence'] targets = [ClassifierOutputTarget(class_idx)] try: with GradCAM(model=models[0], target_layers=[target_layer]) as cam: grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :] rgb_img = np.array(image.convert('RGB'), dtype=np.float32) / 255.0 grayscale_cam_resized = cv2.resize(grayscale_cam, (rgb_img.shape[1], rgb_img.shape[0])) cam_overlay = show_cam_on_image( rgb_img, grayscale_cam_resized, use_rgb=True, image_weight=0.5, colormap=cv2.COLORMAP_JET ) visualizations[disease] = { 'heatmap': grayscale_cam_resized, 'overlay': cam_overlay, 'confidence': confidence } except Exception as e: print(f"Error generating GradCAM for {disease}: {e}") continue # Create threshold info for successful predictions if threshold_optimization: threshold_summary = "optimized thresholds" else: threshold_summary = f"{confidence_threshold:.0%} threshold" return { 'predictions': confident_predictions, 'message': f'Found {len(confident_predictions)} confident predictions above {threshold_summary} (Ensemble AUC: 0.817)', 'visualizations': visualizations, 'ensemble_auc': 0.817, 'temperature_scaling_used': temperature_scaling, 'temperature_value': temperature_value } def create_gradio_interface(): models, device, normalized_weights = load_ensemble_model() if not models: print("Models could not be loaded. Aborting Gradio interface creation.") return None print("🎉 All models loaded successfully! Ready for Gradio interface.") def analyze_xray(image, confidence_threshold=0.55): if image is None: return "Please upload a chest X-ray image", None, None # Validate inputs if not isinstance(confidence_threshold, (int, float)) or not 0 <= confidence_threshold <= 1: return "❌ Invalid confidence threshold. Must be between 0 and 1.", None, "Invalid input" try: results = predict_with_ensemble(models, device, normalized_weights, image, confidence_threshold) if not results['predictions']: return results['message'], None, None prediction_text = f"## Analysis Results\n\n{results['message']}\n\n" # Add temperature scaling info if used if results.get('temperature_scaling_used', False): temp_value = results.get('temperature_value', 1.0) prediction_text += f"**🌡️ Temperature Scaling:** Applied (T={temp_value})\n\n" prediction_text += "### Confident Predictions:\n\n" for pred in results['predictions']: disease = pred['disease'] confidence = pred['confidence'] threshold = pred.get('threshold_used', confidence_threshold) prediction_text += f"🔍 **{disease}**: {confidence:.1%} (threshold: {threshold:.0%})\n" if results['visualizations']: num_plots = len(results['visualizations']) fig, axes = plt.subplots(num_plots, 3, figsize=(15, 5 * num_plots)) if num_plots == 1: axes = axes.reshape(1, -1) for i, (disease, vis_data) in enumerate(results['visualizations'].items()): axes[i, 0].imshow(image, cmap='gray') axes[i, 0].set_title(f"Original X-ray\n{disease}", fontsize=10) axes[i, 0].axis('off') axes[i, 1].imshow(vis_data['heatmap'], cmap='jet') axes[i, 1].set_title(f"GradCAM Heatmap\n{vis_data['confidence']:.1%}", fontsize=10) axes[i, 1].axis('off') axes[i, 2].imshow(vis_data['overlay']) axes[i, 2].set_title(f"GradCAM Overlay\n{disease}", fontsize=10) axes[i, 2].axis('off') plt.tight_layout() return prediction_text, fig, "✅ Analysis completed successfully!" return prediction_text, None, "✅ Analysis completed successfully!" except Exception as e: # Sanitize error message to prevent information disclosure error_msg = str(e) # Remove potentially sensitive information from error messages if "model_path" in error_msg or "state_dict" in error_msg or "classifier" in error_msg: error_msg = "Model loading failed due to architecture mismatch" return f"❌ Error analyzing image: {error_msg}", None, "Analysis failed" interface = gr.Interface( fn=analyze_xray, inputs=[ gr.Image(label="Upload Chest X-ray", type="pil"), gr.Slider(minimum=0.0, maximum=1.0, value=0.55, step=0.01, label="Confidence Threshold") ], outputs=[ gr.Markdown(label="Analysis Results"), gr.Plot(label="GradCAM Visualizations"), gr.Textbox(label="Status", interactive=False) ], title="🫁 ConvNeXt Ensemble Classifier with GradCAM", description=""" **Medical AI for Chest X-ray Analysis** This tool uses an **Ensemble of 3 ConvNeXt-Base models** with CBAM attention and **metadata fusion** to analyze chest X-rays and identify 14 different thoracic pathologies. **Features:** - 🔍 Multi-label classification of 14 chest conditions - 📊 Shows only confident predictions using optimized per-class thresholds - 🌡️ **Temperature scaling** for improved calibration (T=1.5) - 📋 **Metadata fusion** combining X-ray features with patient metadata - 🎯 GradCAM visualization showing model attention regions - 🏥 **Ensemble AUC: 0.817** (better than single model 0.811) - 🏥 Designed for research and educational purposes **⚠️ Important Medical Disclaimer:** This tool is for research and educational purposes only. Always consult qualified healthcare professionals for medical decisions. **Supported Conditions:** No Finding, Enlarged Cardiomediastinum, Cardiomegaly, Lung Opacity, Lung Lesion, Edema, Consolidation, Pneumonia, Atelectasis, Pneumothorax, Pleural Effusion, Pleural Other, Fracture, Support Devices """, theme="default", allow_flagging="never" ) return interface if __name__ == "__main__": print("Starting ConvNeXt CheXpert GradCAM App...") interface = create_gradio_interface() if interface is not None: interface.launch( server_name="0.0.0.0", server_port=7860, share=False, # Set to False for HF Spaces show_error=True ) else: print("❌ Failed to create Gradio interface. Check model loading.")