"""PhishGuardian AI: a Gradio front end for a DistilBERT phishing-email classifier."""

import logging

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MODEL_NAME = "cybersectony/phishing-email-detection-distilbert_v2.4.1"
_MAX_TOKENS = 512  # truncation length handed to the tokenizer
_TEMPERATURE = 1.5  # logits are divided by this before softmax

# Display configuration
_BAR_WIDTH = 20  # number of cells in one probability bar
_PERCENT_PER_CELL = 100 / _BAR_WIDTH

# Substrings that mark a label as "risky". Shared by the bar colouring and the
# risk-level logic so the two can never disagree about a label.
_RISK_KEYWORDS = ("phishing", "suspicious", "spam")

# Display labels used only when the model config carries no id2label mapping,
# keyed by the number of output classes.
_FALLBACK_LABELS = {
    2: ("Legitimate Email", "Phishing Email"),
    4: ("Legitimate Email", "Phishing Email", "Suspicious Content", "Spam Email"),
}

# Global variables for model and tokenizer (populated by load_model)
tokenizer = None
model = None


def load_model():
    """Load the tokenizer and model into the module-level globals.

    After loading, a one-sentence smoke-test inference is run and its
    probabilities are logged. The globals are only assigned once every step
    has succeeded, so a failed load never leaves a half-initialised pair
    behind for ``predict_email`` to pick up.

    Returns:
        True if the model and tokenizer were loaded and the smoke test ran,
        False if any step raised.
    """
    global tokenizer, model
    try:
        logger.info("Loading model and tokenizer...")
        loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

        # Debug: log model configuration
        logger.info("Model config: %s", loaded_model.config)
        logger.info("Number of labels: %s", loaded_model.config.num_labels)
        if hasattr(loaded_model.config, "id2label"):
            logger.info("Label mapping: %s", loaded_model.config.id2label)

        # Smoke test: run a trivial input through the model.
        test_inputs = loaded_tokenizer(
            "Hello world",
            return_tensors="pt",
            truncation=True,
            max_length=_MAX_TOKENS,
        )
        with torch.no_grad():
            test_outputs = loaded_model(**test_inputs)
        test_probs = torch.nn.functional.softmax(test_outputs.logits, dim=-1)
        logger.info("Test probabilities: %s", test_probs[0].tolist())
    except Exception as e:
        # Top-level boundary: report the failure (with traceback) and let the
        # caller decide what to do with the False result.
        logger.exception("Error loading model: %s", e)
        return False

    tokenizer, model = loaded_tokenizer, loaded_model
    logger.info("Model loaded successfully!")
    return True


def _is_risky(label):
    """Return True if *label* names a phishing, suspicious or spam class."""
    lowered = label.lower()
    return any(word in lowered for word in _RISK_KEYWORDS)


def get_colored_bar(percentage, label):
    """Build a fixed-width emoji bar for one class probability.

    Args:
        percentage: Probability of the class on a 0-100 scale.
        label: Class label; risky labels (see ``_RISK_KEYWORDS``) use a
            red/orange/yellow scale, all others a green/yellow/white scale.

    Returns:
        A string of exactly ``_BAR_WIDTH`` emoji: the filled cells followed by
        white padding. A percentage of 0 or below yields no filled cells, any
        positive percentage yields at least one, and values above 100 are
        clamped to a full bar.
    """
    if _is_risky(label):
        high, medium, low = "🟥", "🟠", "🟡"
    else:
        high, medium, low = "🟢", "🟡", "⚪"

    if percentage >= 70:
        color = high
    elif percentage >= 40:
        color = medium
    else:
        color = low

    if percentage <= 0:
        filled = 0
    else:
        # At least one cell for any non-zero probability, never more than the bar.
        filled = min(_BAR_WIDTH, max(1, int(percentage / _PERCENT_PER_CELL)))

    return color * filled + "⚪" * (_BAR_WIDTH - filled)


def _label_probabilities(probs, id2label):
    """Pair each probability with a display label.

    Uses *id2label* when it is non-empty (indices missing from it fall back to
    ``"Class <i>"``), otherwise the built-in names for 2- or 4-class outputs,
    otherwise generic ``"Class <i>"`` names.
    """
    if id2label:
        return {id2label.get(i, f"Class {i}"): p for i, p in enumerate(probs)}
    fallback = _FALLBACK_LABELS.get(len(probs))
    if fallback is not None:
        return dict(zip(fallback, probs))
    return {f"Class {i}": p for i, p in enumerate(probs)}


def _assess_risk(prediction_name, confidence):
    """Return the ``(emoji, risk_level)`` headline for the top prediction."""
    if _is_risky(prediction_name):
        if confidence > 0.8:
            return "🚨", "HIGH RISK"
        if confidence > 0.6:
            return "⚠️", "MEDIUM RISK"
        return "⚡", "LOW RISK"
    if confidence > 0.8:
        return "✅", "SAFE"
    if confidence > 0.6:
        return "✅", "LIKELY SAFE"
    return "❓", "UNCERTAIN"


def _recommendation(prediction_name, confidence):
    """Return the closing recommendation paragraph for the top prediction."""
    lowered = prediction_name.lower()
    if any(word in lowered for word in ("phishing", "suspicious")) and confidence > 0.6:
        return (
            "\n⚠️ **Recommendation**: This email shows signs of being malicious. "
            "Avoid clicking links or providing personal information."
        )
    if "spam" in lowered:
        return (
            "\n🗑️ **Recommendation**: This appears to be spam. "
            "Consider deleting or marking as junk."
        )
    if confidence > 0.7:
        return (
            "\n✅ **Recommendation**: This email appears legitimate, "
            "but always remain vigilant."
        )
    return (
        "\n❓ **Recommendation**: Classification uncertain. "
        "Exercise caution and verify sender if needed."
    )


def _format_report(labels, probs, prob_variance, prediction_name, confidence):
    """Assemble the text report shown in the results box."""
    risk_emoji, risk_level = _assess_risk(prediction_name, confidence)
    parts = [
        f"{risk_emoji} **{risk_level}**\n\n",
        f"**Primary Classification**: {prediction_name}\n",
        f"**Confidence**: {confidence:.1%}\n\n",
        "**Detailed Analysis**:\n",
    ]

    # One line per class, most probable first.
    for label, prob in sorted(labels.items(), key=lambda item: item[1], reverse=True):
        percentage = prob * 100
        parts.append(f"{label}: {percentage:.1f}% {get_colored_bar(percentage, label)}\n")

    parts.append("\n**Debug Info**:\n")
    parts.append(f"Model Variance: {prob_variance:.4f}\n")
    # These are the temperature-scaled probabilities before any smoothing.
    parts.append(f"Raw Probabilities: {[f'{p:.3f}' for p in probs]}\n")
    parts.append(_recommendation(prediction_name, confidence))
    return "".join(parts)


def predict_email(email_text):
    """Classify an email and return a formatted risk report.

    The text is tokenized (truncated to ``_MAX_TOKENS``), run through the
    model, and the logits are temperature-scaled before softmax. If the
    resulting distribution is either nearly uniform or extremely peaked, it is
    blended with a uniform distribution before being reported.

    Args:
        email_text: Raw email body entered by the user.

    Returns:
        A report string. Validation problems, model-load failures and
        inference errors are reported in the returned string rather than
        raised.
    """
    # Input validation
    if not email_text or not email_text.strip():
        return "⚠️ **Error**: Please enter some email text to analyze."
    if len(email_text.strip()) < 5:
        return "⚠️ **Warning**: Email text too short for reliable analysis."

    # Lazily retry loading if the startup load failed.
    if (tokenizer is None or model is None) and not load_model():
        return "❌ **Error**: Failed to load the model."

    try:
        inputs = tokenizer(
            email_text,
            return_tensors="pt",
            truncation=True,
            max_length=_MAX_TOKENS,
            padding=True,
        )
        with torch.no_grad():
            outputs = model(**inputs)

        # Temperature scaling softens the distribution to curb overconfidence.
        scaled_logits = outputs.logits / _TEMPERATURE
        probs = torch.nn.functional.softmax(scaled_logits, dim=-1)[0].tolist()

        logger.info("Raw logits: %s", outputs.logits[0].tolist())
        logger.info("Scaled probabilities: %s", probs)

        labels = _label_probabilities(probs, getattr(model.config, "id2label", None))

        prob_variance = float(np.var(probs))
        max_prob = max(probs)

        # The two symptoms are alternatives, not a conjunction: with one class
        # above 0.99 the variance can only drop below 0.01 when there are on
        # the order of a hundred classes, so requiring both would never
        # trigger for a 2- or 4-class model.
        if prob_variance < 0.01 or max_prob > 0.99:
            logger.warning("Model showing signs of overconfidence or poor calibration")
            uniform_share = 0.2 / len(probs)
            labels = {name: p * 0.8 + uniform_share for name, p in labels.items()}

        prediction_name, confidence = max(labels.items(), key=lambda item: item[1])
        return _format_report(labels, probs, prob_variance, prediction_name, confidence)

    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        logger.error(f"Error during prediction: {e}", exc_info=True)
        return f"❌ **Error**: Analysis failed - {str(e)}"


# Example emails for testing
example_legitimate = """Dear Customer,

Thank you for your recent purchase from TechStore. Your order #ORD-2024-001234 has been successfully processed.

Order Details:
- Product: Wireless Headphones
- Amount: $79.99
- Estimated delivery: 3-5 business days

You will receive a tracking number once your item ships.

Best regards,
TechStore Customer Service"""

example_phishing = """URGENT SECURITY ALERT!!!

Your account has been COMPROMISED! Immediate action required!

Click here NOW to secure your account: http://fake-security-site.malicious.com/urgent-verify

WARNING: You have only 24 hours before your account is permanently suspended!

This is your FINAL notice - act immediately!

Security Department"""

example_neutral = """Hi team,

Hope everyone is doing well. Just wanted to remind you about the meeting scheduled for tomorrow at 2 PM in the conference room.

Please bring your project updates and any questions you might have.

Thanks,
Sarah"""

# Markdown shown above and below the main panel. Built at column 0 so that
# code indentation can never leak into the rendered Markdown.
_HEADER_MARKDOWN = (
    "# 🛡️ PhishGuardian AI - Enhanced Detection\n"
    "\n"
    "Advanced phishing email detection with colored risk indicators and "
    "improved model handling.\n"
)

_FOOTER_MARKDOWN = (
    "---\n"
    "\n"
    f"**🔧 Model**: {MODEL_NAME}\n"
    "\n"
    "**🎯 Features**: Temperature scaling, colored risk bars, enhanced debugging\n"
    "\n"
    "**🏛️ Institution**: University of Dar es Salaam (UDSM)\n"
)

# Load model on startup; predict_email retries lazily if this fails.
load_model()

# Create enhanced Gradio interface
with gr.Blocks(title="PhishGuardian AI", theme=gr.themes.Soft()) as iface:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        with gr.Column(scale=2):
            email_input = gr.Textbox(
                lines=10,
                placeholder="Paste your email content here for analysis...",
                label="📧 Email Content",
                info="Enter the complete email text for comprehensive analysis",
            )
            with gr.Row():
                analyze_btn = gr.Button("🔍 Analyze Email", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        with gr.Column(scale=2):
            output = gr.Textbox(
                label="🛡️ Security Analysis Results",
                lines=20,
                interactive=False,
                show_copy_button=True,
            )

    # Example section
    gr.Markdown("### 📝 Test Examples")
    with gr.Row():
        legit_btn = gr.Button("✅ Legitimate Email", size="sm")
        phish_btn = gr.Button("🚨 Phishing Email", size="sm")
        neutral_btn = gr.Button("📄 Neutral Text", size="sm")

    # Event handlers
    analyze_btn.click(predict_email, inputs=email_input, outputs=output)
    clear_btn.click(lambda: ("", ""), outputs=[email_input, output])
    legit_btn.click(lambda: example_legitimate, outputs=email_input)
    phish_btn.click(lambda: example_phishing, outputs=email_input)
    neutral_btn.click(lambda: example_neutral, outputs=email_input)

    # Footer with model info
    gr.Markdown(_FOOTER_MARKDOWN)

if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio link and
    # server_name="0.0.0.0" listens on all interfaces; together with
    # debug=True this looks like a development configuration - confirm
    # before deploying anywhere sensitive.
    iface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        debug=True,
    )