melissachall commited on
Commit
298b5ca
Β·
verified Β·
1 Parent(s): e044e45

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +968 -0
app.py ADDED
@@ -0,0 +1,968 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ APT Classification System - Version corrigΓ©e et synchronisΓ©e
4
+ Correction des incohΓ©rences entre Streamlit et Gradio
5
+ """
6
+
7
+ import gradio as gr
8
+ import torch
9
+ import torch.nn as nn
10
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
11
+ import numpy as np
12
+ import json
13
+ import time
14
+ from datetime import datetime
15
+ import plotly.graph_objects as go
16
+ import re
17
+ import requests
18
+ import os
19
+ import io
20
+ from typing import Dict, List, Optional
21
+ import logging
22
+ from dataclasses import dataclass
23
+
24
+ # Configure logging
25
+ logging.basicConfig(level=logging.INFO)
26
+ logger = logging.getLogger(__name__)
27
+
28
+ @dataclass
29
+ class ClassificationResult:
30
+ predicted_class: str
31
+ confidence: float
32
+ top5_probabilities: Dict[str, float]
33
+ processing_time: float
34
+ extracted_features: Dict[str, List[str]]
35
+ attribution_factors: List[str]
36
+ timestamp: str
37
+
38
+ class CySecBERTMaxPerformance(nn.Module):
39
+ """Version EXACTEMENT identique Γ  Streamlit"""
40
+
41
+ def __init__(
42
+ self,
43
+ model_name: str = "markusbayer/CySecBERT",
44
+ num_classes: int = 12, # ⚠️ IMPORTANT: Doit correspondre au modèle sauvegardé
45
+ max_length: int = 384,
46
+ dropout_rate: float = 0.15
47
+ ):
48
+ super(CySecBERTMaxPerformance, self).__init__()
49
+
50
+ self.model_name = model_name
51
+ self.num_classes = num_classes
52
+ self.max_length = max_length
53
+
54
+ # CySecBERT specialized for cybersecurity
55
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
56
+ self.config = AutoConfig.from_pretrained(model_name)
57
+ self.bert = AutoModel.from_pretrained(model_name)
58
+
59
+ # EXPANDED architecture for maximum capacity
60
+ self.dropout = nn.Dropout(dropout_rate)
61
+ self.intermediate1 = nn.Linear(self.config.hidden_size, 512)
62
+ self.intermediate_dropout1 = nn.Dropout(dropout_rate * 0.6)
63
+ self.intermediate2 = nn.Linear(512, 256)
64
+ self.intermediate_dropout2 = nn.Dropout(dropout_rate * 0.7)
65
+
66
+ # Batch normalization for stability
67
+ self.batch_norm1 = nn.BatchNorm1d(512)
68
+ self.batch_norm2 = nn.BatchNorm1d(256)
69
+
70
+ self.classifier = nn.Linear(256, num_classes)
71
+
72
+ # Optimized activations
73
+ self.relu = nn.ReLU()
74
+ self.gelu = nn.GELU()
75
+
76
+ def forward(self, input_ids, attention_mask):
77
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
78
+
79
+ # [CLS] token with minimal dropout
80
+ cls_output = outputs.last_hidden_state[:, 0]
81
+ cls_output = self.dropout(cls_output)
82
+
83
+ # First LARGE intermediate layer
84
+ intermediate1 = self.gelu(self.intermediate1(cls_output))
85
+ intermediate1 = self.intermediate_dropout1(intermediate1)
86
+
87
+ if intermediate1.size(0) > 1:
88
+ intermediate1 = self.batch_norm1(intermediate1)
89
+
90
+ # Second intermediate layer
91
+ intermediate2 = self.relu(self.intermediate2(intermediate1))
92
+ intermediate2 = self.intermediate_dropout2(intermediate2)
93
+
94
+ if intermediate2.size(0) > 1:
95
+ intermediate2 = self.batch_norm2(intermediate2)
96
+
97
+ # Final classification
98
+ logits = self.classifier(intermediate2)
99
+
100
+ return {
101
+ 'logits': logits,
102
+ 'probabilities': torch.softmax(logits, dim=-1)
103
+ }
104
+
105
+ class APTClassifier:
106
+ def __init__(self):
107
+ self.device = torch.device('cpu')
108
+ self.model = None
109
+ self.class_names = []
110
+ self.label_encoder = None
111
+
112
+ # βœ… PROFILS APT IDENTIQUES Γ€ STREAMLIT (corrigΓ©s)
113
+ self.apt_profiles = {
114
+ 'APT1': {
115
+ 'country': 'China',
116
+ 'flag': 'πŸ‡¨πŸ‡³',
117
+ 'aliases': ['Comment Crew', 'Comment Group', 'PLA Unit 61398', 'Shanghai Group'],
118
+ 'description': 'Chinese cyber espionage group attributed to the People\'s Liberation Army Unit 61398. Known for large-scale intellectual property theft and targeting of over 140 organizations across 20 industries.',
119
+ 'first_observed': '2006',
120
+ 'attribution_confidence': 'High',
121
+ 'sponsor': 'State-sponsored (PLA Unit 61398)',
122
+ 'malware': ['WEBC2', 'BACKDOOR.BARKIOFORK', 'AURIGA', 'BANGAT', 'BISCUIT'],
123
+ 'tools': ['HTRAN', 'GSECDUMP', 'GETMAIL', 'MAPIGET'],
124
+ 'targets': ['Intellectual property', 'Government agencies', 'Industrial companies', 'Legal services', 'IT companies'],
125
+ 'sectors': ['Information Technology', 'Energy', 'Financial Services', 'Government', 'Healthcare'],
126
+ 'regions': ['United States', 'Canada', 'United Kingdom', 'India'],
127
+ 'ttps': ['T1566.001', 'T1059.003', 'T1071.001', 'T1083', 'T1005'],
128
+ 'mitre_groups': ['G0006'],
129
+ 'notable_campaigns': ['Operation Aurora (2009)', 'RSA SecurID breach (2011)', 'Elderwood campaigns'],
130
+ 'motivations': ['Espionage', 'Intellectual property theft'],
131
+ 'sophistication': 'Medium to High'
132
+ },
133
+ 'APT28': {
134
+ 'country': 'Russia',
135
+ 'flag': 'πŸ‡·πŸ‡Ί',
136
+ 'aliases': ['Fancy Bear', 'Sofacy', 'Sednit', 'STRONTIUM', 'Pawn Storm', 'Swallowtail'],
137
+ 'description': 'Russian military intelligence cyber operations unit attributed to GRU Unit 26165. Highly sophisticated group known for targeting government, military, and security organizations worldwide.',
138
+ 'first_observed': '2007',
139
+ 'attribution_confidence': 'High',
140
+ 'sponsor': 'State-sponsored (GRU Unit 26165)',
141
+ 'malware': ['X-Agent', 'Sofacy', 'GAMEFISH', 'Zebrocy', 'CHOPSTICK', 'EVILTOSS'],
142
+ 'tools': ['Responder', 'Mimikatz', 'Compiled HTML Help', 'PowerShell Empire'],
143
+ 'targets': ['Government agencies', 'Military organizations', 'Defense contractors', 'Aerospace', 'Media'],
144
+ 'sectors': ['Government', 'Defense', 'Aerospace', 'Media', 'Think Tanks'],
145
+ 'regions': ['United States', 'Europe', 'Asia-Pacific', 'Middle East'],
146
+ 'ttps': ['T1566.001', 'T1059.001', 'T1055', 'T1027', 'T1083', 'T1203'],
147
+ 'mitre_groups': ['G0007'],
148
+ 'notable_campaigns': ['DNC hack (2016)', 'Olympic Destroyer (2018)', 'UEFI rootkit campaigns'],
149
+ 'motivations': ['Espionage', 'Political influence', 'Military intelligence'],
150
+ 'sophistication': 'Very High'
151
+ },
152
+ 'APT29': {
153
+ 'country': 'Russia',
154
+ 'flag': 'πŸ‡·πŸ‡Ί',
155
+ 'aliases': ['Cozy Bear', 'The Dukes', 'NOBELIUM', 'Midnight Blizzard', 'UNC2452'],
156
+ 'description': 'Russian foreign intelligence service (SVR) cyber unit. Extremely sophisticated group known for stealth, persistence, and advanced techniques in espionage operations.',
157
+ 'first_observed': '2008',
158
+ 'attribution_confidence': 'High',
159
+ 'sponsor': 'State-sponsored (SVR)',
160
+ 'malware': ['HAMMERTOSS', 'COZYCAR', 'SeaDuke', 'SUNBURST', 'TEARDROP', 'BEACON'],
161
+ 'tools': ['PowerShell', 'WMI', 'Cobalt Strike', 'AdFind', 'BloodHound'],
162
+ 'targets': ['Government agencies', 'Think tanks', 'Healthcare organizations', 'Technology companies'],
163
+ 'sectors': ['Government', 'Healthcare', 'Technology', 'Research', 'NGOs'],
164
+ 'regions': ['United States', 'Europe', 'Global'],
165
+ 'ttps': ['T1566.002', 'T1071.001', 'T1055', 'T1027', 'T1078', 'T1490'],
166
+ 'mitre_groups': ['G0016'],
167
+ 'notable_campaigns': ['SolarWinds supply chain attack (2020)', 'COVID-19 research targeting', 'Azure/M365 attacks'],
168
+ 'motivations': ['Espionage', 'Intelligence gathering', 'Political influence'],
169
+ 'sophistication': 'Very High'
170
+ },
171
+ 'Lazarus': {
172
+ 'country': 'North Korea',
173
+ 'flag': 'πŸ‡°πŸ‡΅',
174
+ 'aliases': ['Lazarus Group', 'Hidden Cobra', 'ZINC', 'TEMP.Hermit', 'Labyrinth Chollima'],
175
+ 'description': 'North Korean state-sponsored hacking group known for financially motivated attacks, cryptocurrency theft, and destructive operations. Connected to RGB (Reconnaissance General Bureau).',
176
+ 'first_observed': '2009',
177
+ 'attribution_confidence': 'High',
178
+ 'sponsor': 'State-sponsored (RGB)',
179
+ 'malware': ['WannaCry', 'HOPLIGHT', 'TYPEFRAME', 'BADCALL', 'FALLCHILL', 'ELECTRICFISH'],
180
+ 'tools': ['PowerShell', 'Mimikatz', 'PsExec', 'Living-off-the-land binaries'],
181
+ 'targets': ['Financial institutions', 'Cryptocurrency exchanges', 'Entertainment companies', 'Defense contractors'],
182
+ 'sectors': ['Financial Services', 'Entertainment', 'Cryptocurrency', 'Defense', 'Healthcare'],
183
+ 'regions': ['Global', 'South Korea', 'United States', 'Europe'],
184
+ 'ttps': ['T1566.001', 'T1059.003', 'T1055', 'T1027', 'T1486', 'T1490'],
185
+ 'mitre_groups': ['G0032'],
186
+ 'notable_campaigns': ['Sony Pictures attack (2014)', 'WannaCry ransomware (2017)', 'SWIFT banking attacks'],
187
+ 'motivations': ['Financial gain', 'Espionage', 'Destruction', 'Sanctions evasion'],
188
+ 'sophistication': 'High'
189
+ },
190
+ 'Equation': {
191
+ 'country': 'United States (suspected)',
192
+ 'flag': 'πŸ‡ΊπŸ‡Έ',
193
+ 'aliases': ['Equation Group', 'EQGRP', 'Tilded Team'],
194
+ 'description': 'Highly sophisticated cyber espionage group suspected to be linked to the NSA. Known for advanced persistent threats, zero-day exploits, and firmware-level implants.',
195
+ 'first_observed': '2001',
196
+ 'attribution_confidence': 'Medium',
197
+ 'sponsor': 'State-sponsored (suspected NSA)',
198
+ 'malware': ['DOUBLEFANTASY', 'EQUATIONDRUG', 'GRAYFISH', 'FANNY', 'STUXNET'],
199
+ 'tools': ['EternalBlue', 'EternalRomance', 'DoublePulsar', 'FuzzBunch'],
200
+ 'targets': ['High-value targets', 'Government agencies', 'Telecommunications', 'Research institutions'],
201
+ 'sectors': ['Government', 'Telecommunications', 'Research', 'Technology', 'Energy'],
202
+ 'regions': ['Middle East', 'Asia', 'Europe', 'Global'],
203
+ 'ttps': ['T1055', 'T1027', 'T1083', 'T1068', 'T1542.009', 'T1014'],
204
+ 'mitre_groups': ['G0020'],
205
+ 'notable_campaigns': ['Operation Equation (2008-2015)', 'STUXNET collaboration', 'Flame malware'],
206
+ 'motivations': ['Espionage', 'Intelligence gathering', 'Sabotage'],
207
+ 'sophistication': 'Extremely High'
208
+ },
209
+ 'Carbanak': {
210
+ 'country': 'International',
211
+ 'flag': '🌍',
212
+ 'aliases': ['FIN7', 'Carbanak Group', 'Anunak', 'Carbon Spider'],
213
+ 'description': 'Financially motivated cybercriminal organization responsible for stealing over $1 billion from financial institutions worldwide through ATM and point-of-sale attacks.',
214
+ 'first_observed': '2013',
215
+ 'attribution_confidence': 'High',
216
+ 'sponsor': 'Cybercriminal',
217
+ 'malware': ['Carbanak', 'CARBANAK', 'HALFBAKED', 'BABYMETAL', 'GRIFFON'],
218
+ 'tools': ['Cobalt Strike', 'Mimikatz', 'PowerShell Empire', 'Metasploit'],
219
+ 'targets': ['Financial institutions', 'Banks', 'Payment processors', 'Hospitality', 'Retail'],
220
+ 'sectors': ['Financial Services', 'Hospitality', 'Retail', 'Restaurant'],
221
+ 'regions': ['Global', 'United States', 'Europe', 'Asia'],
222
+ 'ttps': ['T1566.001', 'T1059.003', 'T1055', 'T1027', 'T1021.001', 'T1083'],
223
+ 'mitre_groups': ['G0008', 'G0046'],
224
+ 'notable_campaigns': ['Carbanak banking attacks', 'FIN7 point-of-sale attacks', 'Restaurant POS campaigns'],
225
+ 'motivations': ['Financial gain'],
226
+ 'sophistication': 'High'
227
+ },
228
+ 'APT40': {
229
+ 'country': 'China',
230
+ 'flag': 'πŸ‡¨πŸ‡³',
231
+ 'aliases': ['Leviathan', 'TEMP.Periscope', 'TEMP.Jumper', 'Kryptonite Panda'],
232
+ 'description': 'Chinese state-sponsored cyber espionage group focused on maritime industries, engineering companies, and research organizations to support China\'s Belt and Road Initiative.',
233
+ 'first_observed': '2013',
234
+ 'attribution_confidence': 'High',
235
+ 'sponsor': 'State-sponsored (MSS Hainan)',
236
+ 'malware': ['BADFLICK', 'PHOTO', 'HOMEFRY', 'MURKYTOP', 'LUNCHMONEY'],
237
+ 'tools': ['China Chopper', 'Mimikatz', 'PowerShell', 'WMI'],
238
+ 'targets': ['Maritime industries', 'Engineering companies', 'Research organizations', 'Government agencies'],
239
+ 'sectors': ['Maritime', 'Engineering', 'Research', 'Government', 'Healthcare'],
240
+ 'regions': ['United States', 'Europe', 'Asia-Pacific'],
241
+ 'ttps': ['T1566.001', 'T1190', 'T1059.003', 'T1055', 'T1027'],
242
+ 'mitre_groups': ['G0065'],
243
+ 'notable_campaigns': ['Maritime industry targeting', 'COVID-19 research theft', 'Belt and Road surveillance'],
244
+ 'motivations': ['Espionage', 'Economic advantage', 'Strategic intelligence'],
245
+ 'sophistication': 'High'
246
+ }
247
+ }
248
+
249
+ # Cybersecurity indicators (identiques Γ  Streamlit)
250
+ self.security_indicators = {
251
+ 'malware': r'\b(trojan|virus|worm|ransomware|backdoor|rootkit|spyware|adware|botnet|rat|loader)\b',
252
+ 'techniques': r'\bT\d{4}(\.\d{3})?\b',
253
+ 'domains': r'\b[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b',
254
+ 'ips': r'\b(?:\d{1,3}\.){3}\d{1,3}\b',
255
+ 'hashes': r'\b[a-fA-F0-9]{32,64}\b',
256
+ 'cve': r'\bCVE-\d{4}-\d{4,}\b',
257
+ 'tools': r'\b(cobalt strike|metasploit|mimikatz|powershell|psexec|wmi|bloodhound)\b'
258
+ }
259
+
260
+ self.load_model()
261
+
262
+ def download_model_from_hf(self):
263
+ """TΓ©lΓ©chargement robuste avec vΓ©rification du checksum"""
264
+ try:
265
+ model_url = "https://huggingface.co/melissachall/cysecbert-apt-classifier/resolve/main/best_cysecbert_max_performance.pt"
266
+
267
+ logger.info(f"Downloading model from: {model_url}")
268
+ response = requests.get(model_url, timeout=300, stream=True)
269
+
270
+ if response.status_code == 200:
271
+ model_path = "downloaded_model.pt"
272
+ total_size = int(response.headers.get('content-length', 0))
273
+
274
+ with open(model_path, "wb") as f:
275
+ downloaded = 0
276
+ for chunk in response.iter_content(chunk_size=8192):
277
+ if chunk:
278
+ f.write(chunk)
279
+ downloaded += len(chunk)
280
+ if total_size > 0:
281
+ percent = (downloaded / total_size) * 100
282
+ if downloaded % 1000000 == 0: # Log every MB
283
+ logger.info(f"Download progress: {percent:.1f}%")
284
+
285
+ logger.info(f"βœ… Model downloaded: {downloaded} bytes")
286
+
287
+ # βœ… VALIDATION CRITIQUE DU MODÈLE TΓ‰LΓ‰CHARGΓ‰
288
+ try:
289
+ test_checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
290
+
291
+ # VΓ©rifier que les champs critiques existent
292
+ required_fields = ['model_state_dict', 'class_names']
293
+ missing_fields = [field for field in required_fields if field not in test_checkpoint]
294
+
295
+ if missing_fields:
296
+ raise ValueError(f"Missing critical fields in checkpoint: {missing_fields}")
297
+
298
+ class_names = test_checkpoint.get('class_names', [])
299
+ if len(class_names) == 0:
300
+ raise ValueError("Checkpoint has empty class_names")
301
+
302
+ logger.info(f"βœ… Model validation passed. Classes: {class_names}")
303
+ return model_path
304
+
305
+ except Exception as e:
306
+ logger.error(f"❌ Downloaded model validation failed: {e}")
307
+ if os.path.exists(model_path):
308
+ os.remove(model_path)
309
+ return None
310
+ else:
311
+ logger.error(f"❌ HTTP {response.status_code} for {model_url}")
312
+ return None
313
+
314
+ except Exception as e:
315
+ logger.error(f"❌ Download error: {e}")
316
+ return None
317
+
318
+ def load_model(self):
319
+ """Chargement EXACTEMENT identique Γ  Streamlit"""
320
+ try:
321
+ # Étape 1: Télécharger le modèle
322
+ model_path = self.download_model_from_hf()
323
+
324
+ if not model_path or not os.path.exists(model_path):
325
+ raise RuntimeError("❌ Cannot download model from HuggingFace")
326
+
327
+ # Γ‰tape 2: Charger le checkpoint
328
+ logger.info(f"Loading checkpoint from {model_path}")
329
+ checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
330
+
331
+ # βœ… Γ‰TAPE CRITIQUE: RΓ©cupΓ©rer les mΓ©tadonnΓ©es EXACTEMENT comme Streamlit
332
+ self.class_names = checkpoint.get('class_names', [])
333
+ if not self.class_names:
334
+ raise RuntimeError("❌ Checkpoint missing class_names. Upload complete .pt file with metadata.")
335
+
336
+ self.label_encoder = checkpoint.get('label_encoder')
337
+ num_classes = len(self.class_names)
338
+
339
+ logger.info(f"βœ… Class names from checkpoint: {self.class_names}")
340
+ logger.info(f"βœ… Number of classes: {num_classes}")
341
+
342
+ # βœ… VΓ‰RIFICATION: Les class_names doivent correspondre aux profils
343
+ profile_classes = set(self.apt_profiles.keys())
344
+ checkpoint_classes = set(self.class_names)
345
+
346
+ if profile_classes != checkpoint_classes:
347
+ logger.warning(f"⚠️ MISMATCH DETECTED!")
348
+ logger.warning(f" Profile classes: {profile_classes}")
349
+ logger.warning(f" Checkpoint classes: {checkpoint_classes}")
350
+ logger.warning(f" Missing in profiles: {checkpoint_classes - profile_classes}")
351
+ logger.warning(f" Extra in profiles: {profile_classes - checkpoint_classes}")
352
+
353
+ # Étape 3: Créer le modèle avec les bonnes dimensions
354
+ self.model = CySecBERTMaxPerformance(
355
+ num_classes=num_classes,
356
+ dropout_rate=checkpoint.get('config', {}).get('dropout_rate', 0.15)
357
+ ).to(self.device)
358
+
359
+ # Γ‰tape 4: Charger les poids
360
+ if 'model_state_dict' not in checkpoint:
361
+ raise RuntimeError("❌ Checkpoint missing model_state_dict")
362
+
363
+ self.model.load_state_dict(checkpoint['model_state_dict'])
364
+ self.model.eval()
365
+
366
+ logger.info("βœ… MODEL LOADED SUCCESSFULLY - IDENTICAL TO STREAMLIT!")
367
+
368
+ # Nettoyage
369
+ if os.path.exists(model_path):
370
+ os.remove(model_path)
371
+
372
+ return True
373
+
374
+ except Exception as e:
375
+ logger.error(f"❌ Model loading error: {e}")
376
+ raise RuntimeError(f"Cannot load model: {e}")
377
+
378
+ def extract_features(self, text: str) -> Dict[str, List[str]]:
379
+ """IDENTIQUE Γ  Streamlit"""
380
+ features = {}
381
+ text_lower = text.lower()
382
+
383
+ for feature_type, pattern in self.security_indicators.items():
384
+ matches = re.findall(pattern, text_lower, re.IGNORECASE)
385
+ features[feature_type] = list(set(matches))[:10] # Limit to 10 items
386
+
387
+ return features
388
+
389
+ def get_attribution_factors(self, text: str, predicted_class: str) -> List[str]:
390
+ """IDENTIQUE Γ  Streamlit"""
391
+ factors = []
392
+ text_lower = text.lower()
393
+
394
+ if predicted_class in self.apt_profiles:
395
+ profile = self.apt_profiles[predicted_class]
396
+
397
+ # Check for group mentions
398
+ if predicted_class.lower() in text_lower:
399
+ factors.append(f"Direct mention of {predicted_class}")
400
+
401
+ # Check for aliases
402
+ for alias in profile.get('aliases', []):
403
+ if alias.lower() in text_lower:
404
+ factors.append(f"Alias detected: {alias}")
405
+
406
+ # Check for known malware
407
+ for malware in profile.get('malware', []):
408
+ if malware.lower() in text_lower:
409
+ factors.append(f"Known malware: {malware}")
410
+
411
+ # Check for tools
412
+ for tool in profile.get('tools', []):
413
+ if tool.lower() in text_lower:
414
+ factors.append(f"Known tool: {tool}")
415
+
416
+ # Check for target sectors
417
+ for target in profile.get('targets', []):
418
+ if target.lower() in text_lower:
419
+ factors.append(f"Target sector match: {target}")
420
+
421
+ # Check for TTPs
422
+ for ttp in profile.get('ttps', []):
423
+ if ttp in text:
424
+ factors.append(f"MITRE technique: {ttp}")
425
+
426
+ return factors
427
+
428
+ def classify(self, text: str, confidence_threshold: float = 0.5) -> ClassificationResult:
429
+ """Classification EXACTEMENT identique Γ  Streamlit"""
430
+ start_time = time.time()
431
+
432
+ # VΓ©rifications strictes
433
+ if self.model is None:
434
+ raise RuntimeError("❌ Model not loaded")
435
+
436
+ if not hasattr(self.model, 'tokenizer') or self.model.tokenizer is None:
437
+ raise RuntimeError("❌ Tokenizer not loaded")
438
+
439
+ logger.info("πŸš€ Using CySecBERTMaxPerformance (identical to Streamlit)")
440
+
441
+ # Tokenisation IDENTIQUE
442
+ encoding = self.model.tokenizer(
443
+ text,
444
+ max_length=self.model.max_length,
445
+ padding='max_length',
446
+ truncation=True,
447
+ return_tensors='pt'
448
+ )
449
+
450
+ input_ids = encoding['input_ids'].to(self.device)
451
+ attention_mask = encoding['attention_mask'].to(self.device)
452
+
453
+ # PrΓ©diction
454
+ with torch.no_grad():
455
+ outputs = self.model(input_ids, attention_mask)
456
+ probabilities = outputs['probabilities'].cpu().numpy()[0]
457
+
458
+ # Top 5 IDENTIQUE
459
+ top5_indices = np.argsort(probabilities)[::-1][:5]
460
+ predicted_class = self.class_names[top5_indices[0]]
461
+ confidence = float(probabilities[top5_indices[0]])
462
+
463
+ # Distribution top 5
464
+ top5_probabilities = {
465
+ self.class_names[idx]: float(probabilities[idx])
466
+ for idx in top5_indices
467
+ }
468
+
469
+ logger.info(f"βœ… Prediction: {predicted_class} ({confidence:.1%})")
470
+ logger.info(f"βœ… Top 5: {top5_probabilities}")
471
+
472
+ # Features et attribution
473
+ extracted_features = self.extract_features(text)
474
+ attribution_factors = self.get_attribution_factors(text, predicted_class)
475
+
476
+ processing_time = time.time() - start_time
477
+
478
+ return ClassificationResult(
479
+ predicted_class=predicted_class,
480
+ confidence=confidence,
481
+ top5_probabilities=top5_probabilities,
482
+ processing_time=processing_time,
483
+ extracted_features=extracted_features,
484
+ attribution_factors=attribution_factors,
485
+ timestamp=datetime.now().isoformat()
486
+ )
487
+
488
+ # ===== FONCTIONS UTILITAIRES IDENTIQUES =====
489
+
490
+ def process_uploaded_file(uploaded_file):
491
+ """Process uploaded file and extract text content"""
492
+ if uploaded_file is None:
493
+ return ""
494
+
495
+ try:
496
+ file_name = uploaded_file.name.lower()
497
+
498
+ if file_name.endswith('.txt'):
499
+ content = uploaded_file.read()
500
+ if isinstance(content, bytes):
501
+ return content.decode('utf-8', errors='ignore')
502
+ return str(content)
503
+
504
+ elif file_name.endswith('.json'):
505
+ content = uploaded_file.read()
506
+ if isinstance(content, bytes):
507
+ content = content.decode('utf-8')
508
+
509
+ try:
510
+ json_data = json.loads(content)
511
+ text_fields = []
512
+
513
+ def extract_text_from_json(obj, depth=0):
514
+ if depth > 3:
515
+ return
516
+
517
+ if isinstance(obj, dict):
518
+ for key, value in obj.items():
519
+ if isinstance(value, str) and len(value) > 10:
520
+ text_fields.append(f"{key}: {value}")
521
+ elif isinstance(value, (dict, list)):
522
+ extract_text_from_json(value, depth + 1)
523
+ elif isinstance(obj, list):
524
+ for item in obj:
525
+ extract_text_from_json(item, depth + 1)
526
+
527
+ extract_text_from_json(json_data)
528
+ return "\n".join(text_fields)
529
+ except:
530
+ return content
531
+
532
+ else:
533
+ # Generic text extraction
534
+ content = uploaded_file.read()
535
+ if isinstance(content, bytes):
536
+ return content.decode('utf-8', errors='ignore')
537
+ return str(content)
538
+
539
+ except Exception as e:
540
+ logger.error(f"File processing error: {e}")
541
+ return f"Error processing file: {str(e)}"
542
+
543
+ def create_prediction_plot(top5_probs):
544
+ """CrΓ©er le graphique des top 5 prΓ©dictions"""
545
+ fig = go.Figure(go.Bar(
546
+ x=list(top5_probs.values()),
547
+ y=list(top5_probs.keys()),
548
+ orientation='h',
549
+ marker=dict(
550
+ color=['#667eea', '#764ba2', '#f093fb', '#f5576c', '#4facfe'][:len(top5_probs)],
551
+ line=dict(color='rgba(50,50,50,0.8)', width=1)
552
+ ),
553
+ text=[f"{prob:.2%}" for prob in top5_probs.values()],
554
+ textposition='auto',
555
+ textfont=dict(size=12, color='white')
556
+ ))
557
+
558
+ fig.update_layout(
559
+ title=dict(
560
+ text="🎯 Top 5 APT Group Predictions",
561
+ font=dict(size=18, color='#2c3e50'),
562
+ x=0.5
563
+ ),
564
+ xaxis=dict(
565
+ title=dict(text="Confidence Score", font=dict(size=14)),
566
+ tickfont=dict(size=12),
567
+ range=[0, max(top5_probs.values()) * 1.1]
568
+ ),
569
+ yaxis=dict(
570
+ title=dict(text="APT Groups", font=dict(size=14)),
571
+ tickfont=dict(size=12)
572
+ ),
573
+ height=400,
574
+ margin=dict(l=100, r=50, t=80, b=50),
575
+ plot_bgcolor='rgba(248,249,250,0.8)',
576
+ paper_bgcolor='white'
577
+ )
578
+
579
+ return fig
580
+
581
+ def format_apt_profile(predicted_class, classifier):
582
+ """Formater le profil APT (IDENTIQUE Γ  Streamlit)"""
583
+ if predicted_class not in classifier.apt_profiles:
584
+ return f"<div style='padding: 1rem; background: #f8d7da; border-radius: 8px; color: #721c24;'>⚠️ No profile available for '{predicted_class}'. This might indicate a class mapping issue.</div>"
585
+
586
+ profile = classifier.apt_profiles[predicted_class]
587
+
588
+ html = f"""
589
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 15px; margin: 1rem 0; box-shadow: 0 10px 30px rgba(0,0,0,0.15);">
590
+ <h3 style="margin: 0 0 1.5rem 0; font-size: 1.8rem; text-align: center;">{profile.get('flag', '🌍')} {predicted_class} - Complete Profile</h3>
591
+
592
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 2rem;">
593
+ <div>
594
+ <h4 style="color: #ffd700; margin-bottom: 1rem;">πŸ“‹ Basic Information</h4>
595
+ <p><strong>Origin:</strong> {profile.get('country', 'Unknown')}</p>
596
+ <p><strong>First Observed:</strong> {profile.get('first_observed', 'Unknown')}</p>
597
+ <p><strong>Attribution Confidence:</strong> {profile.get('attribution_confidence', 'Unknown')}</p>
598
+ <p><strong>Sponsor:</strong> {profile.get('sponsor', 'Unknown')}</p>
599
+ <p><strong>Sophistication:</strong> {profile.get('sophistication', 'Unknown')}</p>
600
+
601
+ <h4 style="color: #ffd700; margin: 1rem 0;">🎭 Known Aliases</h4>
602
+ <ul style="margin: 0.5rem 0;">
603
+ {''.join([f"<li style='margin: 0.3rem 0;'>{alias}</li>" for alias in profile.get('aliases', [])])}
604
+ </ul>
605
+
606
+ <h4 style="color: #ffd700; margin: 1rem 0;">🦠 Associated Malware</h4>
607
+ <div style="display: flex; flex-wrap: wrap; gap: 0.5rem;">
608
+ {''.join([f"<span style='background: rgba(255,255,255,0.2); padding: 0.3rem 0.6rem; border-radius: 12px; font-size: 0.9rem;'>{malware}</span>" for malware in profile.get('malware', [])])}
609
+ </div>
610
+ </div>
611
+
612
+ <div>
613
+ <h4 style="color: #ffd700; margin-bottom: 1rem;">🎯 Typical Targets</h4>
614
+ <ul style="margin: 0.5rem 0;">
615
+ {''.join([f"<li style='margin: 0.3rem 0;'>{target}</li>" for target in profile.get('targets', [])])}
616
+ </ul>
617
+
618
+ <h4 style="color: #ffd700; margin: 1rem 0;">πŸ› οΈ Known Tools</h4>
619
+ <div style="display: flex; flex-wrap: wrap; gap: 0.5rem;">
620
+ {''.join([f"<span style='background: rgba(255,255,255,0.2); padding: 0.3rem 0.6rem; border-radius: 12px; font-size: 0.9rem;'>{tool}</span>" for tool in profile.get('tools', [])])}
621
+ </div>
622
+
623
+ <h4 style="color: #ffd700; margin: 1rem 0;">βš™οΈ MITRE ATT&CK TTPs</h4>
624
+ <div style="display: flex; flex-wrap: wrap; gap: 0.5rem;">
625
+ {''.join([f"<span style='background: rgba(255,255,255,0.3); padding: 0.3rem 0.6rem; border-radius: 12px; font-family: monospace; font-size: 0.9rem;'>{ttp}</span>" for ttp in profile.get('ttps', [])])}
626
+ </div>
627
+ </div>
628
+ </div>
629
+
630
+ <div style="margin-top: 1.5rem; padding: 1rem; background: rgba(255,255,255,0.1); border-radius: 10px;">
631
+ <h4 style="color: #ffd700; margin: 0 0 0.5rem 0;">πŸ“– Description</h4>
632
+ <p style="line-height: 1.6; margin: 0;">{profile.get('description', 'No description available')}</p>
633
+ </div>
634
+
635
+ <div style="margin-top: 1rem; padding: 1rem; background: rgba(255,255,255,0.1); border-radius: 10px;">
636
+ <h4 style="color: #ffd700; margin: 0 0 0.5rem 0;">🚨 Notable Campaigns</h4>
637
+ <ul style="margin: 0.5rem 0;">
638
+ {''.join([f"<li style='margin: 0.3rem 0;'>{campaign}</li>" for campaign in profile.get('notable_campaigns', [])])}
639
+ </ul>
640
+ </div>
641
+ </div>
642
+ """
643
+
644
+ return html
645
+
646
+ def classify_text(text, uploaded_file, confidence_threshold, show_features, show_attribution, show_profile):
647
+ """Fonction principale de classification - IDENTIQUE Γ  Streamlit"""
648
+
649
+ # Traitement de l'input
650
+ input_text = ""
651
+ file_processed = False
652
+
653
+ if uploaded_file is not None:
654
+ file_content = process_uploaded_file(uploaded_file)
655
+ if file_content.strip():
656
+ input_text = file_content
657
+ file_processed = True
658
+
659
+ if not input_text and text.strip():
660
+ input_text = text
661
+
662
+ if not input_text:
663
+ return (
664
+ "Please enter text or upload a file",
665
+ "No confidence",
666
+ None,
667
+ "0.000s",
668
+ "No input provided",
669
+ "No features extracted" if show_features else "",
670
+ "No attribution factors" if show_attribution else "",
671
+ "No profile available" if show_profile else "",
672
+ {},
673
+ "File processed successfully" if file_processed else ""
674
+ )
675
+
676
+ # Classification
677
+ try:
678
+ classifier = APTClassifier()
679
+ result = classifier.classify(input_text, confidence_threshold)
680
+
681
+ # Formatage des rΓ©sultats
682
+ confidence_color = "#27ae60" if result.confidence > 0.8 else "#f39c12" if result.confidence > 0.6 else "#e74c3c"
683
+
684
+ main_result = f"""
685
+ <div style="background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); padding: 2rem; border-radius: 15px; border-left: 6px solid #667eea; box-shadow: 0 8px 25px rgba(0,0,0,0.1);">
686
+ <div style="text-align: center; margin-bottom: 1.5rem;">
687
+ <h2 style="color: #2c3e50; margin: 0;">πŸ›‘οΈ APT Classification Result</h2>
688
+ <p style="color: #7f8c8d; margin: 0.5rem 0;">βœ… Model synchronized with Streamlit version</p>
689
+ </div>
690
+
691
+ <div style="display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 1rem; margin-bottom: 1rem;">
692
+ <div style="text-align: center; padding: 1rem; background: white; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
693
+ <h3 style="color: #667eea; margin: 0 0 0.5rem 0;">Predicted Group</h3>
694
+ <p style="font-size: 1.5rem; font-weight: bold; color: #2c3e50; margin: 0;">{result.predicted_class}</p>
695
+ </div>
696
+ <div style="text-align: center; padding: 1rem; background: white; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
697
+ <h3 style="color: #667eea; margin: 0 0 0.5rem 0;">Confidence</h3>
698
+ <p style="font-size: 1.5rem; font-weight: bold; color: {confidence_color}; margin: 0;">{result.confidence:.2%}</p>
699
+ </div>
700
+ <div style="text-align: center; padding: 1rem; background: white; border-radius: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
701
+ <h3 style="color: #667eea; margin: 0 0 0.5rem 0;">Processing Time</h3>
702
+ <p style="font-size: 1.5rem; font-weight: bold; color: #2c3e50; margin: 0;">{result.processing_time:.3f}s</p>
703
+ </div>
704
+ </div>
705
+
706
+ <div style="text-align: center; padding: 1rem; background: rgba(102, 126, 234, 0.1); border-radius: 10px;">
707
+ <p style="margin: 0; color: #667eea;"><strong>Analysis completed at:</strong> {datetime.fromisoformat(result.timestamp).strftime('%H:%M:%S UTC')}</p>
708
+ </div>
709
+ </div>
710
+ """
711
+
712
+ # Graphique
713
+ plot = create_prediction_plot(result.top5_probabilities)
714
+
715
+ # Features
716
+ features_html = ""
717
+ if show_features and any(result.extracted_features.values()):
718
+ features_html = "<div style='background: #f8f9fa; padding: 1.5rem; border-radius: 10px; border-left: 4px solid #17a2b8;'>"
719
+ features_html += "<h4 style='color: #2c3e50; margin-bottom: 1rem;'>πŸ” Extracted Cybersecurity Features</h4>"
720
+
721
+ icon_map = {'malware': '🦠', 'techniques': 'βš™οΈ', 'domains': '🌐', 'ips': 'πŸ”’', 'hashes': '#️⃣', 'cve': '🚨', 'tools': 'πŸ› οΈ'}
722
+
723
+ for feature_type, feature_list in result.extracted_features.items():
724
+ if feature_list:
725
+ icon = icon_map.get(feature_type, 'πŸ“Œ')
726
+ features_html += f"<p><strong>{icon} {feature_type.title()}:</strong></p>"
727
+ for feature in feature_list[:5]:
728
+ features_html += f"<span style='background: #e9ecef; padding: 0.3rem 0.6rem; margin: 0.2rem; border-radius: 12px; font-family: monospace; font-size: 0.9rem; display: inline-block;'>{feature}</span>"
729
+
730
+ features_html += "</div>"
731
+
732
+ # Attribution
733
+ attribution_html = ""
734
+ if show_attribution and result.attribution_factors:
735
+ attribution_html = "<div style='background: #f8f9fa; padding: 1.5rem; border-radius: 10px; border-left: 4px solid #28a745;'>"
736
+ attribution_html += "<h4 style='color: #2c3e50; margin-bottom: 1rem;'>🎯 Attribution Factors</h4>"
737
+ for factor in result.attribution_factors:
738
+ attribution_html += f"<div style='background: #e8f5e8; padding: 0.8rem; margin: 0.5rem 0; border-radius: 8px; border-left: 3px solid #28a745;'>{factor}</div>"
739
+ attribution_html += "</div>"
740
+
741
+ # Profil
742
+ profile_html = ""
743
+ if show_profile:
744
+ profile_html = format_apt_profile(result.predicted_class, classifier)
745
+
746
+ # Export data
747
+ export_data = {
748
+ 'predicted_class': result.predicted_class,
749
+ 'confidence': result.confidence,
750
+ 'processing_time': result.processing_time,
751
+ 'top5_probabilities': result.top5_probabilities,
752
+ 'extracted_features': result.extracted_features,
753
+ 'attribution_factors': result.attribution_factors,
754
+ 'timestamp': result.timestamp,
755
+ 'model_info': 'CySecBERTMaxPerformance - Synchronized with Streamlit'
756
+ }
757
+
758
+ file_status = f"βœ… File '{uploaded_file.name}' processed successfully ({len(input_text)} characters)" if file_processed else ""
759
+
760
+ return (
761
+ result.predicted_class,
762
+ f"{result.confidence:.2%}",
763
+ plot,
764
+ f"{result.processing_time:.3f}s",
765
+ main_result,
766
+ features_html,
767
+ attribution_html,
768
+ profile_html,
769
+ export_data,
770
+ file_status
771
+ )
772
+
773
+ except Exception as e:
774
+ error_msg = f"❌ Classification error: {str(e)}"
775
+ logger.error(error_msg)
776
+ return (
777
+ "Error",
778
+ "0%",
779
+ None,
780
+ "0.000s",
781
+ f"<div style='background: #f8d7da; padding: 1rem; border-radius: 8px; color: #721c24;'>{error_msg}</div>",
782
+ "",
783
+ "",
784
+ "",
785
+ {},
786
+ ""
787
+ )
788
+
789
+ # ===== INTERFACE GRADIO =====
790
+
791
+ # CSS optimisΓ©
792
+ css = """
793
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
794
+
795
+ .gradio-container {
796
+ font-family: 'Inter', sans-serif !important;
797
+ max-width: 1200px !important;
798
+ margin: 0 auto !important;
799
+ }
800
+
801
+ .header-container {
802
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
803
+ padding: 3rem 2rem;
804
+ border-radius: 20px;
805
+ text-align: center;
806
+ margin-bottom: 2rem;
807
+ box-shadow: 0 15px 35px rgba(102, 126, 234, 0.3);
808
+ }
809
+
810
+ .header-title {
811
+ color: white;
812
+ font-size: 3.5rem;
813
+ font-weight: 700;
814
+ margin: 0;
815
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
816
+ }
817
+
818
+ .header-subtitle {
819
+ color: rgba(255,255,255,0.95);
820
+ font-size: 1.3rem;
821
+ font-weight: 400;
822
+ margin: 1rem 0 0 0;
823
+ }
824
+
825
+ .status-success {
826
+ background: linear-gradient(135deg, #28a745 0%, #20c997 100%);
827
+ color: white;
828
+ padding: 1rem 2rem;
829
+ border-radius: 10px;
830
+ margin: 1rem 0;
831
+ text-align: center;
832
+ font-weight: 500;
833
+ box-shadow: 0 4px 15px rgba(40, 167, 69, 0.3);
834
+ }
835
+ """
836
+
837
+ # Examples optimisΓ©s pour les 7 groupes de Streamlit
838
+ example_texts = [
839
+ "Advanced persistent threat campaign attributed to APT28 (Fancy Bear) targeting government entities. Spear-phishing emails deliver X-Agent payload with T1566.001 techniques. Network analysis reveals C2 communications consistent with Sofacy operations and GRU Unit 26165 characteristics.",
840
+
841
+ "Financial institutions targeted by Lazarus Group operations. Watering hole attacks deploy custom malware for cryptocurrency theft. TTPs include T1566.001 and T1059.003, consistent with Hidden Cobra methodologies. HOPLIGHT and TYPEFRAME malware observed.",
842
+
843
+ "Government agencies report sophisticated malware attributed to APT29 (Cozy Bear). Advanced T1566.002 techniques with HAMMERTOSS C2. Campaign characteristics consistent with Russian SVR operations. NOBELIUM techniques with SUNBURST and TEARDROP malware detected.",
844
+
845
+ "Chinese cyber espionage operations attributed to APT1 targeting intellectual property. Comment Crew techniques observed with WEBC2 and BACKDOOR.BARKIOFORK malware. Campaign consistent with PLA Unit 61398 operations targeting industrial companies.",
846
+
847
+ "Advanced threat operations attributed to Equation Group. Zero-day exploits and firmware-level implants detected. DOUBLEFANTASY and EQUATIONDRUG malware with characteristics consistent with NSA-linked operations."
848
+ ]
849
+
850
+ # Interface Gradio principale
851
+ with gr.Blocks(theme=gr.themes.Soft(), css=css, title="APT Classification System - Fixed") as demo:
852
+
853
+ # Header
854
+ gr.HTML("""
855
+ <div class="header-container">
856
+ <h1 class="header-title">πŸ›‘οΈ APT Classification System</h1>
857
+ <p class="header-subtitle">CySecBERTMaxPerformance - Synchronized with Streamlit</p>
858
+ <p style="color: rgba(255,255,255,0.9); margin: 0.5rem 0;">πŸ”„ Fixed version - Identical behavior to Streamlit interface</p>
859
+ </div>
860
+ """)
861
+
862
+ # Status
863
+ gr.HTML("""
864
+ <div class="status-success">
865
+ βœ… <strong>SYNCHRONIZED MODEL</strong> β€’ Fixed class mapping β€’ Identical to Streamlit version
866
+ </div>
867
+ """)
868
+
869
+ with gr.Row():
870
+ # Main input column
871
+ with gr.Column(scale=2):
872
+ gr.Markdown("### πŸ“ Threat Intelligence Input")
873
+
874
+ with gr.Tab("Text Input"):
875
+ text_input = gr.Textbox(
876
+ lines=6,
877
+ placeholder="Describe the cybersecurity incident, including TTPs, malware, targets, and attribution indicators...",
878
+ label="Incident Description",
879
+ show_label=False
880
+ )
881
+
882
+ gr.Examples(
883
+ examples=[[text] for text in example_texts],
884
+ inputs=text_input,
885
+ label="πŸ“š Test Cases for APT Groups"
886
+ )
887
+
888
+ with gr.Tab("File Upload"):
889
+ file_input = gr.File(
890
+ file_types=[".txt", ".log", ".json", ".csv"],
891
+ label="Upload Threat Intelligence Report",
892
+ file_count="single"
893
+ )
894
+
895
+ file_status = gr.HTML("")
896
+
897
+ # Configuration column
898
+ with gr.Column(scale=1):
899
+ gr.Markdown("### βš™οΈ Analysis Configuration")
900
+
901
+ confidence_threshold = gr.Slider(
902
+ 0.0, 1.0, value=0.3, step=0.05,
903
+ label="🎯 Confidence Threshold",
904
+ info="Minimum confidence for predictions"
905
+ )
906
+
907
+ gr.Markdown("### πŸ“Š Display Options")
908
+ show_features = gr.Checkbox(value=True, label="πŸ” Extract Features")
909
+ show_attribution = gr.Checkbox(value=True, label="🎯 Show Attribution")
910
+ show_profile = gr.Checkbox(value=True, label="πŸ“‹ Complete Profile")
911
+
912
+ analyze_button = gr.Button(
913
+ "πŸ” ANALYZE THREAT",
914
+ variant="primary",
915
+ size="lg"
916
+ )
917
+
918
+ gr.Markdown("### πŸ“ˆ Model Information")
919
+ gr.HTML("""
920
+ <div style="background: #f8f9fa; padding: 1rem; border-radius: 8px; font-size: 0.9rem;">
921
+ <strong>Status:</strong> βœ… Fixed & Synchronized<br>
922
+ <strong>Architecture:</strong> CySecBERT + Custom Layers<br>
923
+ <strong>Classes:</strong> 7 APT groups (same as Streamlit)<br>
924
+ <strong>Source:</strong> HuggingFace Hub download<br>
925
+ <strong>Validation:</strong> Class mapping verified
926
+ </div>
927
+ """)
928
+
929
+ # Results section
930
+ gr.Markdown("## πŸ“Š Analysis Results")
931
+
932
+ with gr.Row():
933
+ predicted_class = gr.Textbox(label="🎯 Predicted APT Group", interactive=False)
934
+ confidence_display = gr.Textbox(label="πŸ“Š Confidence Score", interactive=False)
935
+ processing_time = gr.Textbox(label="⚑ Processing Time", interactive=False)
936
+
937
+ # Main results
938
+ main_result = gr.HTML(label="Main Results")
939
+
940
+ # Visualization
941
+ prediction_plot = gr.Plot(label="Top 5 Predictions Visualization")
942
+
943
+ # Additional results
944
+ with gr.Row():
945
+ features_output = gr.HTML(label="Extracted Features")
946
+ attribution_output = gr.HTML(label="Attribution Factors")
947
+
948
+ # APT Profile
949
+ profile_output = gr.HTML(label="Complete APT Profile")
950
+
951
+ # Export
952
+ gr.Markdown("### πŸ’Ύ Export Results")
953
+ export_data = gr.JSON(label="Analysis Data", visible=False)
954
+
955
+ # Event handler
956
+ analyze_button.click(
957
+ classify_text,
958
+ inputs=[text_input, file_input, confidence_threshold, show_features, show_attribution, show_profile],
959
+ outputs=[predicted_class, confidence_display, prediction_plot, processing_time, main_result, features_output, attribution_output, profile_output, export_data, file_status]
960
+ )
961
+
962
+ if __name__ == "__main__":
963
+ demo.launch(
964
+ server_name="0.0.0.0",
965
+ server_port=7860,
966
+ share=False,
967
+ show_error=True
968
+ )