Trouter-Library committed
Commit dd84874 · verified · 1 Parent(s): e54d66c

Create inference/security.py

Files changed (1)
  1. inference/inference/security.py +591 -0
inference/inference/security.py ADDED
@@ -0,0 +1,591 @@
+ #!/usr/bin/env python3
+ """
+ Helion-2.5-Rnd Security Implementation
+ Comprehensive security features for safe model deployment
+ """
+
+ import hashlib
+ import hmac
+ import json
+ import logging
+ import re
+ import secrets
+ from collections import defaultdict
+ from datetime import datetime, timedelta
+ from typing import Dict, List, Optional, Tuple
+ from pathlib import Path
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ class InputValidator:
+     """Validate and sanitize user inputs"""
+
+     MAX_PROMPT_LENGTH = 131072
+     MAX_TOKEN_LIMIT = 8192
+
+     # Dangerous patterns that could indicate attacks
+     DANGEROUS_PATTERNS = [
+         r'<script.*?>',
+         r'javascript:',
+         r'on\w+\s*=',
+         r'\beval\s*\(',
+         r'\bexec\s*\(',
+         r'__import__',
+         r'\bos\.',
+         r'\bsystem\(',
+         r'subprocess',
+         r'\[\[.*?\]\].*?\[\[.*?\]\]',  # Repeated prompt injection
+     ]
+
+     @classmethod
+     def validate_prompt(cls, prompt: str) -> Tuple[bool, Optional[str]]:
+         """
+         Validate prompt input for security issues
+
+         Args:
+             prompt: Input prompt string
+
+         Returns:
+             (is_valid, error_message)
+         """
+         # Check for empty input
+         if not prompt or not prompt.strip():
+             return False, "Prompt cannot be empty"
+
+         # Length validation
+         if len(prompt) > cls.MAX_PROMPT_LENGTH:
+             return False, f"Prompt exceeds maximum length of {cls.MAX_PROMPT_LENGTH}"
+
+         # Check for null bytes
+         if '\x00' in prompt:
+             return False, "Prompt contains null bytes"
+
+         # Check for dangerous patterns
+         for pattern in cls.DANGEROUS_PATTERNS:
+             if re.search(pattern, prompt, re.IGNORECASE | re.MULTILINE):
+                 logger.warning(f"Dangerous pattern detected: {pattern}")
+                 return False, "Prompt contains potentially dangerous content"
+
+         # Check for excessive repetition (possible DoS)
+         words = prompt.split()
+         if len(words) > 100:
+             word_counts = {}
+             for word in words:
+                 word_counts[word] = word_counts.get(word, 0) + 1
+                 if word_counts[word] > len(words) * 0.5:
+                     return False, "Excessive repetition detected"
+
+         return True, None
+
+     @classmethod
+     def sanitize_text(cls, text: str) -> str:
+         """
+         Sanitize text by removing dangerous content
+
+         Args:
+             text: Input text
+
+         Returns:
+             Sanitized text
+         """
+         # Remove script tags
+         text = re.sub(r'<script.*?</script>', '', text, flags=re.DOTALL | re.IGNORECASE)
+
+         # Remove javascript: protocol
+         text = re.sub(r'javascript:', '', text, flags=re.IGNORECASE)
+
+         # Remove event handlers
+         text = re.sub(r'\bon\w+\s*=\s*["\'].*?["\']', '', text, flags=re.IGNORECASE)
+
+         return text
+
+     @classmethod
+     def validate_generation_params(cls, params: Dict) -> Tuple[bool, Optional[str]]:
+         """
+         Validate generation parameters
+
+         Args:
+             params: Generation parameters dictionary
+
+         Returns:
+             (is_valid, error_message)
+         """
+         # Temperature validation
+         if 'temperature' in params:
+             temp = params['temperature']
+             if not isinstance(temp, (int, float)) or temp < 0 or temp > 2.0:
+                 return False, "Temperature must be between 0 and 2.0"
+
+         # Max tokens validation
+         if 'max_tokens' in params:
+             max_tok = params['max_tokens']
+             if not isinstance(max_tok, int) or max_tok < 1 or max_tok > cls.MAX_TOKEN_LIMIT:
+                 return False, f"max_tokens must be between 1 and {cls.MAX_TOKEN_LIMIT}"
+
+         # Top-p validation
+         if 'top_p' in params:
+             top_p = params['top_p']
+             if not isinstance(top_p, (int, float)) or top_p < 0 or top_p > 1.0:
+                 return False, "top_p must be between 0 and 1.0"
+
+         return True, None
+
+
+ class ContentFilter:
+     """Filter inappropriate and unsafe content"""
+
+     # Toxicity patterns
+     TOXICITY_PATTERNS = [
+         r'\b(kill|murder|assassinate|destroy)\s+(myself|yourself|themselves|someone)',
+         r'\bhow\s+to\s+(make|build|create)\s+(bomb|weapon|explosive)',
+         r'\b(suicide|self-harm|cutting)\s+(method|way|how)',
+         r'\b(hack|crack|exploit)\s+(password|account|system)',
+     ]
+
+     # Hate speech patterns
+     HATE_SPEECH_PATTERNS = [
+         r'\b(hate|despise)\s+\w+\s+(people|race|religion|gender)',
+         r'\b(inferior|superior)\s+(race|ethnicity|gender)',
+     ]
+
+     # PII patterns
+     PII_PATTERNS = {
+         'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
+         'ssn': r'\b\d{3}-\d{2}-\d{4}\b',
+         'phone': r'\b(\+\d{1,3}[-.]?)?\(?\d{3}\)?[-.]?\d{3}[-.]?\d{4}\b',
+         'credit_card': r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
+         'ip_address': r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b',
+     }
+
+     @classmethod
+     def check_toxicity(cls, text: str) -> Tuple[bool, List[str]]:
+         """
+         Check text for toxic content
+
+         Args:
+             text: Input text
+
+         Returns:
+             (is_safe, violations)
+         """
+         violations = []
+
+         # Check toxicity patterns
+         for pattern in cls.TOXICITY_PATTERNS:
+             if re.search(pattern, text, re.IGNORECASE):
+                 violations.append(f"toxicity:{pattern[:30]}")
+
+         # Check hate speech
+         for pattern in cls.HATE_SPEECH_PATTERNS:
+             if re.search(pattern, text, re.IGNORECASE):
+                 violations.append(f"hate_speech:{pattern[:30]}")
+
+         return len(violations) == 0, violations
+
+     @classmethod
+     def detect_pii(cls, text: str) -> List[Tuple[str, str]]:
+         """
+         Detect PII in text
+
+         Args:
+             text: Input text
+
+         Returns:
+             List of (pii_type, matched_value) tuples
+         """
+         detected = []
+
+         for pii_type, pattern in cls.PII_PATTERNS.items():
+             matches = re.finditer(pattern, text)
+             for match in matches:
+                 detected.append((pii_type, match.group()))
+
+         return detected
+
+     @classmethod
+     def redact_pii(cls, text: str) -> str:
+         """
+         Redact PII from text
+
+         Args:
+             text: Input text
+
+         Returns:
+             Text with PII redacted
+         """
+         for pii_type, pattern in cls.PII_PATTERNS.items():
+             text = re.sub(pattern, f'[REDACTED_{pii_type.upper()}]', text)
+
+         return text
+
+     @classmethod
+     def filter_content(cls, text: str, redact_pii: bool = True) -> Tuple[str, Dict]:
+         """
+         Comprehensive content filtering
+
+         Args:
+             text: Input text
+             redact_pii: Whether to redact PII
+
+         Returns:
+             (filtered_text, metadata)
+         """
+         metadata = {
+             'original_length': len(text),
+             'pii_detected': [],
+             'toxic_content': False,
+             'violations': []
+         }
+
+         # Check toxicity
+         is_safe, violations = cls.check_toxicity(text)
+         if not is_safe:
+             metadata['toxic_content'] = True
+             metadata['violations'] = violations
+
+         # Detect PII
+         pii_found = cls.detect_pii(text)
+         if pii_found:
+             metadata['pii_detected'] = [pii_type for pii_type, _ in pii_found]
+
+         # Redact PII if requested
+         filtered_text = text
+         if redact_pii and pii_found:
+             filtered_text = cls.redact_pii(text)
+
+         metadata['filtered_length'] = len(filtered_text)
+
+         return filtered_text, metadata
+
+
+ class RateLimiter:
+     """Token bucket rate limiter for API requests"""
+
+     def __init__(
+         self,
+         requests_per_minute: int = 60,
+         burst_size: int = 10,
+         cleanup_interval: int = 3600
+     ):
+         """
+         Initialize rate limiter
+
+         Args:
+             requests_per_minute: Sustained rate limit
+             burst_size: Maximum burst requests
+             cleanup_interval: Cleanup old entries after this many seconds
+         """
+         self.rate = requests_per_minute / 60.0
+         self.burst_size = burst_size
+         self.buckets: Dict[str, Dict] = defaultdict(lambda: {
+             'tokens': burst_size,
+             'last_update': datetime.now(),
+             'total_requests': 0
+         })
+         self.cleanup_interval = cleanup_interval
+         self.last_cleanup = datetime.now()
+
+     def _cleanup_old_entries(self):
+         """Remove inactive client entries"""
+         now = datetime.now()
+         if (now - self.last_cleanup).total_seconds() < self.cleanup_interval:
+             return
+
+         cutoff = now - timedelta(seconds=self.cleanup_interval)
+         inactive = [
+             client_id for client_id, bucket in self.buckets.items()
+             if bucket['last_update'] < cutoff
+         ]
+
+         for client_id in inactive:
+             del self.buckets[client_id]
+
+         self.last_cleanup = now
+         logger.info(f"Cleaned up {len(inactive)} inactive rate limit entries")
+
+     def allow_request(self, client_id: str) -> Tuple[bool, Dict]:
+         """
+         Check if request is allowed for client
+
+         Args:
+             client_id: Unique client identifier
+
+         Returns:
+             (allowed, metadata)
+         """
+         self._cleanup_old_entries()
+
+         bucket = self.buckets[client_id]
+         now = datetime.now()
+
+         # Calculate elapsed time and add tokens
+         elapsed = (now - bucket['last_update']).total_seconds()
+         bucket['tokens'] = min(
+             self.burst_size,
+             bucket['tokens'] + elapsed * self.rate
+         )
+         bucket['last_update'] = now
+
+         # Check if request allowed
+         if bucket['tokens'] >= 1.0:
+             bucket['tokens'] -= 1.0
+             bucket['total_requests'] += 1
+
+             return True, {
+                 'allowed': True,
+                 'remaining_tokens': int(bucket['tokens']),
+                 'total_requests': bucket['total_requests']
+             }
+         else:
+             wait_time = (1.0 - bucket['tokens']) / self.rate
+
+             return False, {
+                 'allowed': False,
+                 'retry_after': int(wait_time) + 1,
+                 'total_requests': bucket['total_requests']
+             }
+
+     def get_stats(self, client_id: str) -> Dict:
+         """Get rate limit statistics for client"""
+         if client_id not in self.buckets:
+             return {'exists': False}
+
+         bucket = self.buckets[client_id]
+         return {
+             'exists': True,
+             'tokens': bucket['tokens'],
+             'total_requests': bucket['total_requests'],
+             'last_update': bucket['last_update'].isoformat()
+         }
+
+
+ class APIKeyManager:
+     """Secure API key management"""
+
+     def __init__(self, storage_path: Optional[str] = None):
+         """
+         Initialize API key manager
+
+         Args:
+             storage_path: Path to store key hashes
+         """
+         self.storage_path = Path(storage_path) if storage_path else None
+         self.keys: Dict[str, Dict] = {}
+
+         if self.storage_path and self.storage_path.exists():
+             self._load_keys()
+
+     def generate_key(self, client_id: str, description: str = "") -> str:
+         """
+         Generate new API key
+
+         Args:
+             client_id: Client identifier
+             description: Key description
+
+         Returns:
+             Generated API key
+         """
+         # Generate cryptographically secure key
+         key = f"helion_{secrets.token_urlsafe(32)}"
+
+         # Hash for storage
+         key_hash = hashlib.sha256(key.encode()).hexdigest()
+
+         # Store metadata
+         self.keys[key_hash] = {
+             'client_id': client_id,
+             'description': description,
+             'created_at': datetime.now().isoformat(),
+             'last_used': None,
+             'usage_count': 0
+         }
+
+         self._save_keys()
+
+         logger.info(f"Generated API key for client: {client_id}")
+         return key
+
+     def verify_key(self, key: str) -> Tuple[bool, Optional[str]]:
+         """
+         Verify API key
+
+         Args:
+             key: API key to verify
+
+         Returns:
+             (is_valid, client_id)
+         """
+         if not key or not key.startswith('helion_'):
+             return False, None
+
+         key_hash = hashlib.sha256(key.encode()).hexdigest()
+
+         if key_hash in self.keys:
+             # Update usage statistics
+             self.keys[key_hash]['last_used'] = datetime.now().isoformat()
+             self.keys[key_hash]['usage_count'] += 1
+             self._save_keys()
+
+             return True, self.keys[key_hash]['client_id']
+
+         return False, None
+
+     def revoke_key(self, key: str) -> bool:
+         """
+         Revoke API key
+
+         Args:
+             key: API key to revoke
+
+         Returns:
+             Success status
+         """
+         key_hash = hashlib.sha256(key.encode()).hexdigest()
+
+         if key_hash in self.keys:
+             del self.keys[key_hash]
+             self._save_keys()
+             logger.info(f"Revoked API key: {key_hash[:16]}...")
+             return True
+
+         return False
+
+     def _load_keys(self):
+         """Load keys from storage"""
+         try:
+             with open(self.storage_path, 'r') as f:
+                 self.keys = json.load(f)
+             logger.info(f"Loaded {len(self.keys)} API keys")
+         except Exception as e:
+             logger.error(f"Failed to load API keys: {e}")
+
+     def _save_keys(self):
+         """Save keys to storage"""
+         if not self.storage_path:
+             return
+
+         try:
+             self.storage_path.parent.mkdir(parents=True, exist_ok=True)
+             with open(self.storage_path, 'w') as f:
+                 json.dump(self.keys, f, indent=2)
+         except Exception as e:
+             logger.error(f"Failed to save API keys: {e}")
+
+
+ class SecurityLogger:
+     """Security event logging"""
+
+     def __init__(self, log_file: str = "security.log"):
+         """
+         Initialize security logger
+
+         Args:
+             log_file: Path to security log file
+         """
+         self.log_file = Path(log_file)
+         self.log_file.parent.mkdir(parents=True, exist_ok=True)
+
+         self.logger = logging.getLogger("security")
+         handler = logging.FileHandler(self.log_file)
+         formatter = logging.Formatter('%(message)s')
+         handler.setFormatter(formatter)
+         self.logger.addHandler(handler)
+         self.logger.setLevel(logging.INFO)
+
+     def log_event(self, event_type: str, details: Dict):
+         """
+         Log security event
+
+         Args:
+             event_type: Type of security event
+             details: Event details
+         """
+         event = {
+             'timestamp': datetime.utcnow().isoformat(),
+             'type': event_type,
+             'details': details
+         }
+         self.logger.info(json.dumps(event))
+
+     def log_authentication(self, client_id: str, success: bool, ip_address: Optional[str] = None):
+         """Log authentication attempt"""
+         self.log_event('authentication', {
+             'client_id': client_id,
+             'success': success,
+             'ip_address': ip_address
+         })
+
+     def log_rate_limit(self, client_id: str, ip_address: Optional[str] = None):
+         """Log rate limit violation"""
+         self.log_event('rate_limit', {
+             'client_id': client_id,
+             'ip_address': ip_address
+         })
+
+     def log_content_violation(self, client_id: str, violation_type: str, details: str):
+         """Log content policy violation"""
+         self.log_event('content_violation', {
+             'client_id': client_id,
+             'violation_type': violation_type,
+             'details': details
+         })
+
+     def log_input_validation_failure(self, client_id: str, reason: str):
+         """Log input validation failure"""
+         self.log_event('validation_failure', {
+             'client_id': client_id,
+             'reason': reason
+         })
+
+
+ # Example usage and integration
+ def create_secure_inference_middleware():
+     """
+     Create middleware for secure inference
+
+     Returns:
+         Dictionary of security components
+     """
+     return {
+         'validator': InputValidator(),
+         'content_filter': ContentFilter(),
+         'rate_limiter': RateLimiter(requests_per_minute=60),
+         'api_key_manager': APIKeyManager(storage_path='./keys/api_keys.json'),
+         'security_logger': SecurityLogger(log_file='./logs/security.log')
+     }
+
+
+ if __name__ == "__main__":
+     # Demo security features
+     print("Helion Security Module - Feature Demo\n")
+
+     # Input validation
+     validator = InputValidator()
+     test_prompt = "Write a Python function to sort a list"
+     is_valid, error = validator.validate_prompt(test_prompt)
+     print(f"Validation test: {'PASS' if is_valid else 'FAIL'}")
+
+     # Content filtering
+     content_filter = ContentFilter()
+     test_text = "My email is [email protected] and phone is 555-1234"
+     filtered, metadata = content_filter.filter_content(test_text, redact_pii=True)
+     print(f"\nPII Detection: Found {len(metadata['pii_detected'])} types")
+     print(f"Filtered text: {filtered}")
+
+     # Rate limiting
+     rate_limiter = RateLimiter(requests_per_minute=10)
+     allowed, meta = rate_limiter.allow_request("client_123")
+     print(f"\nRate limit test: {'ALLOWED' if allowed else 'DENIED'}")
+
+     # API key management
+     key_manager = APIKeyManager()
+     api_key = key_manager.generate_key("test_client", "Demo key")
+     print(f"\nGenerated API key: {api_key[:20]}...")
+
+     is_valid, client = key_manager.verify_key(api_key)
+     print(f"Key verification: {'VALID' if is_valid else 'INVALID'}")
+
+     print("\nSecurity module ready for deployment!")
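For reviewers, a minimal sketch of how the components returned by `create_secure_inference_middleware()` might be chained in front of a generation call. This is not part of the commit; it assumes the new file is importable as `security`, and `generate_fn` is a placeholder for whatever inference callable the server exposes.

```python
# Hypothetical request flow built on the classes added in this commit.
# Assumes the module is importable as `security`; `generate_fn` and the
# return-value convention (payload, status_code) are illustrative only.
from security import create_secure_inference_middleware

mw = create_secure_inference_middleware()

def handle_request(api_key: str, prompt: str, params: dict, generate_fn):
    # 1. Authenticate the caller via APIKeyManager
    ok, client_id = mw['api_key_manager'].verify_key(api_key)
    mw['security_logger'].log_authentication(client_id or "unknown", ok)
    if not ok:
        return {'error': 'invalid API key'}, 401

    # 2. Apply the per-client token-bucket rate limit
    allowed, meta = mw['rate_limiter'].allow_request(client_id)
    if not allowed:
        mw['security_logger'].log_rate_limit(client_id)
        return {'error': 'rate limited', 'retry_after': meta['retry_after']}, 429

    # 3. Validate the prompt and the generation parameters
    valid, err = mw['validator'].validate_prompt(prompt)
    if valid:
        valid, err = mw['validator'].validate_generation_params(params)
    if not valid:
        mw['security_logger'].log_input_validation_failure(client_id, err)
        return {'error': err}, 400

    # 4. Generate, then filter the output (toxicity check + PII redaction)
    output = generate_fn(prompt, **params)
    filtered, info = mw['content_filter'].filter_content(output, redact_pii=True)
    if info['toxic_content']:
        mw['security_logger'].log_content_violation(
            client_id, 'toxicity', str(info['violations'])
        )
        return {'error': 'content policy violation'}, 403

    return {'text': filtered, 'pii_redacted': info['pii_detected']}, 200
```

The ordering mirrors the cheapest-check-first pattern: key verification and rate limiting reject abusive traffic before any validation or model work is done, and content filtering runs only on text that will actually be returned.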