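"""
Fixed Multi-Agent Hate Speech Detection — Gradio app for a Hugging Face Space.

A fine-tuned DistilBERT classifier (loaded from ./model, with unitary/toxic-bert as a
fallback) flags hate speech, a RoBERTa sentiment model adds tone context, and
template-based moderation and counter-speech agents turn the scores into guidance.
"""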
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import numpy as np
import json
import random
from datetime import datetime
import logging
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FixedMultiAgentSystem:
    def __init__(self):
        self.detection_agent = None
        self.counter_speech_agent = None
        self.moderation_agent = None
        self.sentiment_agent = None
        # Load prompt configurations with better error handling
        self.counter_speech_prompts = self.load_prompts("counter_speech_prompts.json")
        self.moderation_prompts = self.load_prompts("moderation_prompts.json")
        self.initialize_agents()

    def load_prompts(self, filename):
        """Load prompts from JSON file with robust fallback"""
        try:
            if os.path.exists(filename):
                with open(filename, 'r', encoding='utf-8') as f:
                    return json.load(f)
            else:
                logger.warning(f"Prompt file {filename} not found, using built-in prompts")
                return self.get_default_prompts(filename)
        except Exception as e:
            logger.error(f"Error loading prompts from {filename}: {e}")
            return self.get_default_prompts(filename)

    def get_default_prompts(self, filename):
        """Comprehensive default prompts as fallback"""
        if "counter_speech" in filename:
            return {
                "counter_speech_prompts": {
                    "high_risk": {
                        "system_prompt": "You are an expert educator specializing in counter-speech and conflict de-escalation.",
                        "user_prompt_template": "Generate a respectful, educational counter-speech response to address harmful content while promoting understanding. Original text (Risk: {risk_level}, Confidence: {confidence}%, Sentiment: {sentiment}): \"{original_text}\"\n\nProvide a constructive response that educates without attacking:",
                        "fallback_responses": [
                            "This type of language can cause real harm to individuals and communities. Consider expressing your concerns in a way that respects everyone's dignity and opens constructive dialogue.",
                            "Instead of divisive language, try focusing on shared values and common ground. Everyone deserves respect regardless of their background.",
                            "Strong communities are built on mutual respect and understanding. How can we work together rather than against each other?"
                        ]
                    },
                    "medium_risk": {
                        "fallback_responses": [
                            "This message might be interpreted as harmful by some. Consider rephrasing to express your thoughts more constructively.",
                            "Try framing your message to invite discussion rather than potentially excluding others.",
                            "How might you express this sentiment in a way that brings people together rather than apart?"
                        ]
                    },
                    "low_risk": {
                        "fallback_responses": [
                            "While this seems mostly positive, consider how your words might be received by everyone in the conversation.",
                            "Every interaction is a chance to build understanding and connection.",
                            "Consider how you can use your voice to create an even more welcoming environment."
                        ]
                    },
                    "general_template": {
                        "fallback_responses": [
                            "Thank you for sharing your thoughts. Building strong communities works best when we focus on shared values and constructive dialogue.",
                            "I appreciate your perspective. Sometimes our strongest feelings can be expressed in ways that bring people together.",
                            "Your engagement with this topic is clear. When we channel that energy into inclusive dialogue, we often find solutions that work for everyone."
                        ]
                    }
                }
            }
        else:
            return {
                "moderation_prompts": {
                    "comprehensive_analysis": {
                        "system_prompt": "You are an expert content moderation specialist analyzing text for safety and compliance.",
                        "user_prompt_template": "Analyze this text for potential violations: \"{text}\"\n\nProvide brief analysis: 1) Safety level 2) Main concerns 3) Recommended action\n\nAnalysis:"
                    }
                }
            }
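
    # The external prompt files are optional. If counter_speech_prompts.json /
    # moderation_prompts.json sit next to this script, load_prompts() expects them
    # to mirror the built-in defaults above, e.g. (illustrative sketch only; the
    # real files may carry additional keys):
    #
    # {
    #   "counter_speech_prompts": {
    #     "high_risk":   {"fallback_responses": ["..."]},
    #     "medium_risk": {"fallback_responses": ["..."]},
    #     "low_risk":    {"fallback_responses": ["..."]}
    #   }
    # }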

    def initialize_agents(self):
        """Initialize all AI agents with proper error handling"""
        logger.info("🤖 Initializing Fixed Multi-Agent System...")
        self.setup_detection_agent()
        self.setup_lightweight_agents()
        logger.info("✅ All agents initialized successfully!")

    def setup_detection_agent(self):
        """Initialize the hate speech detection agent with proper label handling"""
        try:
            logger.info("🔍 Loading Detection Agent (Fine-tuned DistilBERT)...")
            model_path = "./model"
            # Load model components
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch.float32
            )
            self.detection_agent = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                return_all_scores=True,
                device=0 if torch.cuda.is_available() else -1
            )
            # Test the model to understand its label mapping
            self.test_model_labels()
            logger.info("✅ Detection Agent loaded successfully")
        except Exception as e:
            logger.error(f"❌ Detection Agent failed: {e}")
            logger.info("🔄 Using fallback detection model...")
            self.detection_agent = pipeline(
                "text-classification",
                model="unitary/toxic-bert",
                return_all_scores=True
            )
            # NOTE: these label names are assumed for the fallback model and may need
            # adjusting to its actual config; hate_label/normal_label are set here so
            # detect_hate_speech() uses the same names as the mapping.
            self.model_label_mapping = {"TOXIC": "hate", "NORMAL": "normal"}
            self.hate_label = "TOXIC"
            self.normal_label = "NORMAL"
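
    # NOTE: "./model" is expected to hold a Hugging Face checkpoint saved with
    # save_pretrained() (config.json, tokenizer files, and model weights); if it is
    # missing or incomplete, the fallback branch above takes over.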

    def test_model_labels(self):
        """Test model to understand its label mapping"""
        try:
            # Test with obviously safe text
            safe_text = "I love sunny days and happy people."
            results = self.detection_agent(safe_text)
            if isinstance(results, list) and len(results) > 0:
                if isinstance(results[0], list):
                    results = results[0]
            # Find the label with highest score for safe text
            max_result = max(results, key=lambda x: x['score'])
            safe_label = max_result['label']
            # Determine label mapping
            if safe_label in ['LABEL_0', '0']:
                self.model_label_mapping = {"LABEL_0": "normal", "LABEL_1": "hate"}
                self.hate_label = "LABEL_1"
                self.normal_label = "LABEL_0"
            elif safe_label in ['LABEL_1', '1']:
                self.model_label_mapping = {"LABEL_0": "hate", "LABEL_1": "normal"}
                self.hate_label = "LABEL_0"
                self.normal_label = "LABEL_1"
            else:
                # For models with explicit labels
                self.model_label_mapping = {safe_label: "normal"}
                self.normal_label = safe_label
                # Find the other label
                other_labels = [r['label'] for r in results if r['label'] != safe_label]
                if other_labels:
                    self.hate_label = other_labels[0]
                    self.model_label_mapping[self.hate_label] = "hate"
            logger.info(f"Model label mapping determined: {self.model_label_mapping}")
            logger.info(f"Normal label: {self.normal_label}, Hate label: {self.hate_label}")
        except Exception as e:
            logger.error(f"Error testing model labels: {e}")
            # Default assumption
            self.model_label_mapping = {"LABEL_0": "normal", "LABEL_1": "hate"}
            self.hate_label = "LABEL_1"
            self.normal_label = "LABEL_0"

    def setup_lightweight_agents(self):
        """Setup only essential additional agents to reduce load time"""
        try:
            logger.info("📊 Loading Lightweight Sentiment Agent...")
            self.sentiment_agent = pipeline(
                "sentiment-analysis",
                model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                return_all_scores=True,
                device=0 if torch.cuda.is_available() else -1
            )
            logger.info("✅ Sentiment Agent loaded")
            # Skip heavy FLAN-T5 models for now - use template-based responses
            logger.info("💬 Using template-based counter-speech (fast mode)")
            self.counter_speech_agent = None
            self.moderation_agent = None
        except Exception as e:
            logger.error(f"❌ Lightweight agents failed: {e}")
            self.sentiment_agent = None

    def detect_hate_speech(self, text):
        """Fixed detection with proper label interpretation"""
        if not text or not text.strip():
            return {
                "status": "❌ Please enter some text to analyze.",
                "prediction": "No input",
                "confidence": 0.0,
                "all_scores": {},
                "risk_level": "Unknown",
                "is_hate_speech": False
            }
        try:
            results = self.detection_agent(text.strip())
            if isinstance(results, list) and len(results) > 0:
                if isinstance(results[0], list):
                    results = results[0]
            all_scores = {}
            hate_score = 0.0
            normal_score = 0.0
            # Process results with correct label mapping
            for result in results:
                label = result["label"]
                score = result["score"]
                # Map to human-readable labels
                mapped_label = self.model_label_mapping.get(label, label)
                all_scores[f"{label} ({mapped_label})"] = {
                    "score": score,
                    "percentage": f"{score*100:.2f}%",
                    "confidence": f"{score:.4f}"
                }
                # Track hate vs normal scores
                if label == getattr(self, 'hate_label', 'LABEL_1'):
                    hate_score = score
                elif label == getattr(self, 'normal_label', 'LABEL_0'):
                    normal_score = score
            # Determine final classification based on hate score
            is_hate_speech = False
            risk_level = "Low"
            predicted_label = "Normal"
            confidence = normal_score
            if hate_score > normal_score:
                # This is likely hate speech
                confidence = hate_score
                predicted_label = "Hate Speech"
                if hate_score > 0.8:
                    is_hate_speech = True
                    risk_level = "High"
                    status = f"🚨 High confidence hate speech detected! (Hate: {hate_score:.2%})"
                elif hate_score > 0.6:
                    is_hate_speech = True
                    risk_level = "Medium"
                    status = f"⚠️ Potential hate speech detected (Hate: {hate_score:.2%})"
                else:
                    risk_level = "Low-Medium"
                    status = f"⚡ Low confidence hate detection (Hate: {hate_score:.2%})"
            else:
                # This is normal/safe content
                risk_level = "Low"
                status = f"✅ No hate speech detected (Normal: {normal_score:.2%})"
            return {
                "status": status,
                "prediction": predicted_label,
                "confidence": confidence,
                "all_scores": all_scores,
                "risk_level": risk_level,
                "is_hate_speech": is_hate_speech,
                "hate_score": hate_score,
                "normal_score": normal_score
            }
        except Exception as e:
            logger.error(f"Detection error: {e}")
            return {
                "status": f"❌ Detection error: {str(e)}",
                "prediction": "Error",
                "confidence": 0.0,
                "all_scores": {},
                "risk_level": "Unknown",
                "is_hate_speech": False
            }

    def analyze_sentiment(self, text):
        """Fast sentiment analysis"""
        if not self.sentiment_agent or not text or not text.strip():
            return {"sentiment": "neutral", "confidence": 0.0, "all_sentiments": {}}
        try:
            results = self.sentiment_agent(text.strip())
            if isinstance(results, list) and len(results) > 0:
                if isinstance(results[0], list):
                    results = results[0]
            best_sentiment = max(results, key=lambda x: x['score'])
            return {
                "sentiment": best_sentiment['label'].lower(),
                "confidence": best_sentiment['score'],
                "all_sentiments": {r['label']: r['score'] for r in results}
            }
        except Exception as e:
            logger.error(f"Sentiment analysis error: {e}")
            return {"sentiment": "neutral", "confidence": 0.0, "all_sentiments": {}}

    def generate_template_moderation(self, text, detection_result, sentiment_result):
        """Fast template-based moderation analysis"""
        risk_level = detection_result.get("risk_level", "Low").lower()
        confidence = detection_result.get("confidence", 0.0)
        hate_score = detection_result.get("hate_score", 0.0)
        if hate_score > 0.8:
            analysis = f"🚨 HIGH RISK: Clear hate speech detected with {confidence:.1%} confidence. Immediate review recommended. Content may violate community standards and could cause harm."
            safety_level = "harmful"
        elif hate_score > 0.6:
            analysis = f"⚠️ MEDIUM RISK: Potentially problematic content detected with {confidence:.1%} confidence. Human review recommended to assess context and intent."
            safety_level = "concerning"
        elif hate_score > 0.3:
            analysis = f"⚡ LOW RISK: Minor concerns detected with {confidence:.1%} confidence. Content appears mostly acceptable but may benefit from user education."
            safety_level = "review_needed"
        else:
            analysis = f"✅ SAFE: No significant violations detected. Content appears to meet community standards with {confidence:.1%} confidence."
            safety_level = "safe"
        return {
            "analysis": analysis,
            "confidence": confidence,
            "safety_level": safety_level,
            "method": "template_based_fast"
        }

    def generate_template_counter_speech(self, text, detection_result, sentiment_result):
        """Fast template-based counter-speech"""
        if not detection_result.get("is_hate_speech", False):
            return "✨ This text promotes positive communication. Great job maintaining respectful dialogue!"
        risk_level = detection_result.get("risk_level", "Low").lower()
        # Get appropriate responses from prompts
        counter_config = self.counter_speech_prompts.get("counter_speech_prompts", {})
        if risk_level == "high":
            responses = counter_config.get("high_risk", {}).get("fallback_responses", [
                "This type of language can cause real harm. Consider expressing concerns in a way that respects everyone's dignity."
            ])
        elif risk_level == "medium":
            responses = counter_config.get("medium_risk", {}).get("fallback_responses", [
                "This message might be harmful to some. Consider rephrasing to express thoughts more constructively."
            ])
        else:
            responses = counter_config.get("low_risk", {}).get("fallback_responses", [
                "Consider how your words might be received by everyone in the conversation."
            ])
        return f"📝 **Educational Response** ({risk_level.title()} Risk): {random.choice(responses)}"

    def comprehensive_analysis(self, text):
        """Fast comprehensive analysis with fixed logic"""
        start_time = datetime.now()
        # Run core analysis
        detection_result = self.detect_hate_speech(text)
        sentiment_result = self.analyze_sentiment(text)
        # Run fast template-based analysis
        moderation_result = self.generate_template_moderation(text, detection_result, sentiment_result)
        counter_speech = self.generate_template_counter_speech(text, detection_result, sentiment_result)
        processing_time = (datetime.now() - start_time).total_seconds()
        return {
            "detection": detection_result,
            "sentiment": sentiment_result,
            "moderation": moderation_result,
            "counter_speech": counter_speech,
            "processing_time": processing_time,
            "timestamp": datetime.now().isoformat()
        }

# Initialize the fixed system
logger.info("🚀 Starting Fixed Multi-Agent System...")
agent_system = FixedMultiAgentSystem()

def analyze_text_fixed(text):
    """Fixed analysis function with proper logic"""
    if not text or not text.strip():
        return (
            "❌ Please enter some text to analyze.",
            "No analysis performed.",
            "No input provided",
            {}
        )
    # Run fixed analysis
    results = agent_system.comprehensive_analysis(text)
    # Extract results for display
    detection_status = results["detection"]["status"]
    detection_scores = results["detection"]["all_scores"]
    counter_speech = results["counter_speech"]
    # Create detailed agent summary
    agent_summary = f"""
🔍 **Detection Agent**: {results['detection']['risk_level']} risk ({results['detection']['confidence']:.2%} confidence)
↳ Hate Score: {results['detection'].get('hate_score', 0):.2%} | Normal Score: {results['detection'].get('normal_score', 0):.2%}
📊 **Sentiment Agent**: {results['sentiment']['sentiment'].title()} ({results['sentiment']['confidence']:.2%} confidence)
🛡️ **Moderation Agent**: {results['moderation']['safety_level'].title()} ({results['moderation']['method']})
💬 **Counter-Speech Agent**: Template-based response system
⚡ **Processing Time**: {results['processing_time']:.2f} seconds (Fixed & Optimized)
📋 **Quick Analysis**: {results['moderation']['analysis'][:150]}...
"""
    # Compile comprehensive data
    all_agent_data = {
        "Detection_Analysis": {
            "corrected_scores": detection_scores,
            "hate_score": results['detection'].get('hate_score', 0),
            "normal_score": results['detection'].get('normal_score', 0),
            "final_prediction": results['detection']['prediction'],
            "risk_level": results['detection']['risk_level'],
            "is_hate_speech": results['detection']['is_hate_speech']
        },
        "Sentiment_Analysis": {
            "primary_sentiment": results['sentiment']['sentiment'],
            "all_sentiments": results['sentiment'].get('all_sentiments', {})
        },
        "Moderation_Analysis": {
            "safety_level": results['moderation']['safety_level'],
            "analysis": results['moderation']['analysis'],
            "method": results['moderation']['method']
        },
        "System_Info": {
            "mode": "Fixed & Optimized",
            "processing_time_seconds": results['processing_time'],
            "timestamp": results['timestamp'],
            "model_labels": getattr(agent_system, 'model_label_mapping', {})
        }
    }
    # The per-label detection scores are already included in all_agent_data, so return
    # exactly one value per output component wired up in the click handler below.
    return detection_status, counter_speech, agent_summary, all_agent_data

# Create the fixed interface
with gr.Blocks(
    title="Fixed Multi-Agent Hate Speech Detection",
    theme=gr.themes.Soft()
) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Enter text for fixed multi-agent analysis",
                placeholder="Test the fixed system with any text...",
                lines=4
            )
            with gr.Row():
                analyze_btn = gr.Button("🔧 Run Fixed Analysis", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
            gr.Examples(
                examples=[
                    ["The diversity in our group makes our discussions much richer and more meaningful."],
                    ["I love collaborating with people from different backgrounds."],
                    ["This is a wonderful day to learn something new!"],
                    ["Thank you for sharing your perspective with us."],
                    ["Let's work together to build something amazing."]
                ],
                inputs=text_input,
                label="📝 Test with these examples (should show as SAFE):"
            )
    with gr.Row():
        with gr.Column():
            detection_output = gr.Textbox(
                label="🎯 Fixed Detection Result",
                interactive=False,
                lines=3
            )
            agent_summary = gr.Textbox(
                label="🔧 Fixed Agent Summary",
                interactive=False,
                lines=8
            )
        with gr.Column():
            counter_speech_output = gr.Textbox(
                label="💬 Counter-Speech Response",
                interactive=False,
                lines=4
            )
    with gr.Row():
        all_agents_output = gr.JSON(
            label="📊 Complete Fixed Analysis Data",
            visible=True
        )

    # Event handlers (one output component per value returned by analyze_text_fixed)
    analyze_btn.click(
        fn=analyze_text_fixed,
        inputs=text_input,
        outputs=[detection_output, counter_speech_output, agent_summary, all_agents_output]
    )
    clear_btn.click(
        fn=lambda: ("", "", "", "", {}),
        outputs=[text_input, detection_output, counter_speech_output, agent_summary, all_agents_output]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
        share=False
    )
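
# Quick local smoke test (a sketch; assumes this file is saved as app.py, the usual
# Space entry point, and that either ./model is present or the fallback model can be
# downloaded): run `python app.py` and open http://localhost:7860, or exercise the
# pipeline directly from a Python shell:
#
#   from app import agent_system
#   result = agent_system.comprehensive_analysis("Have a great day!")
#   print(result["detection"]["status"], result["counter_speech"])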