daniel committed on
Commit
c3717d3
·
0 Parent(s):

Clean Space - loads model from Hub

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.bin filter=lfs diff=lfs merge=lfs -text
2
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.so
5
+ .Python
6
+ env/
7
+ venv/
8
+ models/
9
+ *.log
10
+ .env
11
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# C/C++ compilers for pip packages that build native extensions
# (presumably needed by a dependency in requirements.txt; verify before removing).
RUN apt-get update && apt-get install -y gcc g++ && rm -rf /var/lib/apt/lists/*

# Run as an unprivileged user with a fixed UID.
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH

# Single working directory for the app. The earlier `WORKDIR /app` was dropped:
# nothing was ever copied into it and it was immediately superseded by this one.
WORKDIR $HOME/app

# Install dependencies first so this layer is cached across code-only changes.
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements.txt

COPY --chown=user api/ ./api/

# Models are loaded from HuggingFace Hub at runtime.
# Both repos are declared explicitly; the values match the defaults used by
# api/services/model_service.py.
ENV PHP_MODEL_REPO=mekbus/codebert-xss-php
ENV JS_MODEL_REPO=mekbus/codebert-xss-js
EXPOSE 7860

CMD ["python", "-m", "uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # XSS Vulnerability Scanner API
2
+
3
+ A FastAPI-based API for detecting XSS vulnerabilities in JavaScript and PHP code using fine-tuned CodeBERT models.
4
+
5
+ ## Features
6
+
7
+ - **JavaScript XSS Detection** - Trained on 14,000+ patterns
8
+ - **PHP XSS Detection** - Trained on 9,700+ balanced patterns
9
+ - **Multi-vulnerability detection** - Finds multiple vulnerabilities per file
10
+ - **Chunking support** - Handles large files by splitting into chunks
11
+
12
+ ## API Endpoints
13
+
14
+ ### Health Check
15
+ ```
16
+ GET /api/v1/health
17
+ ```
18
+
19
+ ### Scan Code
20
+ ```
21
+ POST /api/v1/scan
22
+ {
23
+ "code": "<?php echo $_GET['name']; ?>",
24
+ "language": "php"
25
+ }
26
+ ```
27
+
28
+ ### Languages Supported
29
+ - `php` - PHP code
30
+ - `js` / `javascript` - JavaScript code
31
+
32
+ ## Models
33
+
34
+ This Space uses fine-tuned CodeBERT models:
35
+ - PHP Model: `checkpoint-1867` (92% accuracy on test cases)
36
+ - JS Model: `best_model` (trained on 14k real-world patterns)
37
+
38
+ ## Local Development
39
+
40
+ ```bash
41
+ pip install -r requirements.txt
42
+ python -m api.main
43
+ ```
44
+
45
+ ## License
46
+
47
+ MIT License
api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Empty __init__ files for Python modules
api/dependencies.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dependency injection for FastAPI
3
+ """
4
+ from typing import Optional
5
+ from fastapi import HTTPException
6
+ from api.services.model_service import ModelService
7
+
8
+ # Global model service instance
9
+ model_service: Optional[ModelService] = None
10
+
11
+
12
def get_model_service() -> ModelService:
    """Return the shared ModelService instance for FastAPI dependency injection.

    Raises:
        HTTPException: 503 when the module-level ``model_service`` is still None.
    """
    service = model_service
    if service is not None:
        return service
    raise HTTPException(status_code=503, detail="Models not loaded")
17
+
18
+
19
async def initialize_models():
    """Create the shared ModelService on application startup.

    The service is built into a local variable and only published to the
    module-level ``model_service`` once construction (and the optional async
    load hook) has succeeded, so a failed startup never leaves a
    half-initialized instance visible to ``get_model_service``.
    """
    global model_service
    print("🚀 Loading CodeBERT models...")
    service = ModelService()

    # ModelService loads its models inside __init__ and does not necessarily
    # expose ``load_models``; calling it unconditionally raised AttributeError
    # at startup. Await the hook only when the service actually provides one.
    load_hook = getattr(service, "load_models", None)
    if load_hook is not None:
        await load_hook()

    model_service = service
    print("✅ Models loaded successfully!")
26
+
27
+
28
def cleanup_models():
    """Drop the shared ModelService reference on application shutdown."""
    global model_service
    # The log line previously contained a mis-encoded (mojibake) emoji.
    print("👋 Shutting down...")
    model_service = None
api/main.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import uvicorn
4
+ from contextlib import asynccontextmanager
5
+
6
+ from api.routes import scan, health
7
+ from api import dependencies
8
+
9
+
10
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup, cleanup on shutdown.

    Everything before ``yield`` runs at startup; the ``finally`` clause runs
    at shutdown, including when an exception is thrown into the context
    manager at the ``yield`` point.
    """
    await dependencies.initialize_models()
    try:
        yield
    finally:
        dependencies.cleanup_models()
16
+
17
+
18
app = FastAPI(
    title="XSS Detection API",
    description="CodeBERT-based XSS vulnerability detection for PHP and JavaScript",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration.
# Wildcard origins must not be combined with credentialed requests: it would
# let any website call this API with the visitor's cookies/auth attached.
# Credentials are therefore disabled while origins stay open. To enable them,
# list explicit origins (and explicit methods/headers) instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend URL
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers (both are mounted under the versioned /api/v1 prefix)
app.include_router(scan.router, prefix="/api/v1", tags=["scan"])
app.include_router(health.router, prefix="/api/v1", tags=["health"])
37
+
38
+
39
@app.get("/")
async def root():
    """Return static service metadata and the location of the API docs."""
    service_info = {}
    service_info["service"] = "XSS Detection API"
    service_info["version"] = "1.0.0"
    service_info["status"] = "running"
    service_info["docs"] = "/docs"
    return service_info
47
+
48
+
49
if __name__ == "__main__":
    # Direct-run entry point (``python -m api.main``): serves on port 8080
    # with uvicorn's auto-reload switched on.
    server_options = {
        "host": "0.0.0.0",
        "port": 8080,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("api.main:app", **server_options)
api/routes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Empty __init__ file
api/routes/health.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends
2
+ from pydantic import BaseModel
3
+ from api.dependencies import get_model_service
4
+ from api.services.model_service import ModelService
5
+
6
+
7
+ router = APIRouter()
8
+
9
+
10
class HealthResponse(BaseModel):
    """Response body of the health endpoint."""

    # Overall service state as a free-form string.
    status: str
    # Whether each classifier is available on the model service.
    php_model_loaded: bool
    js_model_loaded: bool
    # String form of the device the model service runs on.
    device: str
15
+
16
+
17
@router.get("/health", response_model=HealthResponse)
async def health_check(model_service: ModelService = Depends(get_model_service)):
    """
    Health check endpoint for load balancer.

    Reports which classifiers are loaded and the device in use. ``status`` is
    "healthy" only when both the PHP and the JS model are available; when one
    of them failed to load it is "degraded" instead of unconditionally
    claiming health.
    """
    php_loaded = model_service.php_model is not None
    js_loaded = model_service.js_model is not None

    return HealthResponse(
        status="healthy" if (php_loaded and js_loaded) else "degraded",
        php_model_loaded=php_loaded,
        js_model_loaded=js_loaded,
        device=str(model_service.device),
    )
28
+
29
+
30
@router.get("/metrics")
async def metrics():
    """
    Placeholder for a Prometheus metrics endpoint; returns a fixed payload.
    """
    placeholder = dict(status="ok", metrics="Not implemented yet")
    return placeholder
api/routes/scan.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException, Depends
2
+ from pydantic import BaseModel, Field
3
+ from typing import List, Optional
4
+ from enum import Enum
5
+
6
+ from api.dependencies import get_model_service
7
+ from api.services.model_service import ModelService
8
+
9
+
10
+ router = APIRouter()
11
+
12
+
13
class LanguageEnum(str, Enum):
    """Languages accepted by the scan endpoints ("js" and "javascript" are both allowed)."""

    PHP = "php"
    JS = "js"
    JAVASCRIPT = "javascript"
17
+
18
+
19
class VulnerabilityDetail(BaseModel):
    """One finding reported by a scan."""

    # Vulnerability category; defaults to "xss".
    type: str = "xss"
    # Severity label for the finding.
    severity: str
    # 1-based line where the finding starts, when known.
    line_number: Optional[int] = None
    # Human-readable explanation of the finding.
    description: str
    # Excerpt of the scanned code the finding refers to.
    code_snippet: str
    # Remediation advice.
    suggestion: str
26
+
27
+
28
class ScanRequest(BaseModel):
    """Request body for scanning a single piece of source code."""

    code: str = Field(
        ...,
        description="Source code to analyze",
    )
    language: LanguageEnum = Field(
        ...,
        description="Programming language (php or js)",
    )
    file_path: Optional[str] = Field(
        None,
        description="File path for context",
    )
32
+
33
+
34
class ScanResult(BaseModel):
    """Outcome of scanning one piece of source code."""

    # Overall verdict, its confidence score and a textual label for it.
    is_vulnerable: bool
    confidence: float
    label: str
    # Individual findings; empty by default.
    vulnerabilities: List[VulnerabilityDetail] = []
    # Time spent handling the scan, in milliseconds, when measured.
    processing_time_ms: Optional[int] = None
    # Cache flag; defaults to False.
    cached: bool = False
41
+
42
+
43
class BatchScanRequest(BaseModel):
    """Request body for scanning several files in one call."""

    # One ScanRequest per file to analyze.
    files: List[ScanRequest]
45
+
46
+
47
class BatchScanResult(BaseModel):
    """Response body of a batch scan."""

    # Identifier assigned to this batch.
    job_id: str
    # Number of files submitted in the request.
    total_files: int
    # One ScanResult per file.
    results: List[ScanResult]
51
+
52
+
53
@router.post("/scan", response_model=ScanResult)
async def scan_code(
    request: ScanRequest,
    model_service: ModelService = Depends(get_model_service)
):
    """
    Analyze a single code snippet for XSS vulnerabilities.

    Runs ``model_service.predict_multi`` on the submitted code and turns each
    reported line range into a ``VulnerabilityDetail`` carrying a severity
    derived from its confidence, a short code excerpt and a language-specific
    suggestion.

    Raises:
        HTTPException: 500 carrying the underlying error message when the
            prediction or result construction fails.
    """
    import asyncio
    import time

    try:
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # or jump if the wall clock is adjusted mid-request.
        started = time.perf_counter()
        language = request.language.value

        # Inference is synchronous and compute-bound; running it in a worker
        # thread keeps the event loop free to serve other requests meanwhile.
        result = await asyncio.to_thread(
            model_service.predict_multi,
            request.code,
            language,
        )

        # Loop-invariant work hoisted out of the per-finding loop.
        source_lines = request.code.split('\n')
        suggestion = _get_suggestion(language)

        # Build vulnerability list from all detected vulnerabilities
        vulnerabilities = []
        for vuln_info in result['vulnerabilities']:
            confidence = vuln_info['confidence']
            start_line = vuln_info['start_line']
            end_line = min(vuln_info['end_line'], len(source_lines))

            # Excerpt: at most six lines from the start of the range,
            # never past its end.
            excerpt_end = min(start_line + 5, end_line)
            code_snippet = '\n'.join(source_lines[start_line - 1:excerpt_end])

            vulnerabilities.append(VulnerabilityDetail(
                type="xss",
                severity=_severity_for_confidence(confidence),
                line_number=start_line,
                description=f"Potential XSS vulnerability detected with {confidence:.1%} confidence (lines {start_line}-{end_line})",
                code_snippet=code_snippet[:500],  # Limit snippet length
                suggestion=suggestion,
            ))

        processing_time = int((time.perf_counter() - started) * 1000)

        # Use max confidence for overall result
        is_vulnerable = result['is_vulnerable']
        return ScanResult(
            is_vulnerable=is_vulnerable,
            confidence=result['max_confidence'],
            label="VULNERABLE" if is_vulnerable else "SAFE",
            vulnerabilities=vulnerabilities,
            processing_time_ms=processing_time,
            cached=False,
        )

    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e


def _severity_for_confidence(confidence: float) -> str:
    """Map a model confidence in [0, 1] to a severity label."""
    if confidence >= 0.95:
        return "critical"
    if confidence >= 0.85:
        return "high"
    if confidence >= 0.70:
        return "medium"
    return "low"
120
+
121
+
122
@router.post("/scan/batch", response_model=BatchScanResult)
async def scan_batch(
    request: BatchScanRequest,
    model_service: ModelService = Depends(get_model_service)
):
    """
    Analyze multiple code files in batch.

    Files are scanned one after another with ``scan_code``. A file whose scan
    fails does not abort the batch: it is logged and represented by a
    placeholder result labelled "ERROR", so ``results`` always has one entry
    per submitted file, in order.
    """
    import logging
    import uuid

    logger = logging.getLogger(__name__)
    job_id = str(uuid.uuid4())
    results = []

    for index, file_request in enumerate(request.files):
        try:
            results.append(await scan_code(file_request, model_service))
        except Exception:
            # Previously the failure was swallowed without a trace; keep the
            # best-effort behaviour but record why this file failed.
            logger.exception(
                "Batch %s: scan failed for file #%d (%s)",
                job_id,
                index,
                file_request.file_path,
            )
            results.append(ScanResult(
                is_vulnerable=False,
                confidence=0.0,
                label="ERROR",
                vulnerabilities=[],
                processing_time_ms=0,
                cached=False,
            ))

    return BatchScanResult(
        job_id=job_id,
        total_files=len(request.files),
        results=results,
    )
155
+
156
+
157
+ def _extract_vulnerable_code(code: str, language: str) -> tuple:
158
+ """
159
+ Extract the most likely vulnerable code snippet and line number.
160
+ Returns (code_snippet, line_number)
161
+ """
162
+ import re
163
+
164
+ lines = code.split('\n')
165
+
166
+ # Define vulnerable patterns by language
167
+ if language == "php":
168
+ patterns = [
169
+ # Direct output of user input superglobals
170
+ r'echo\s+\$_(GET|POST|REQUEST|COOKIE)',
171
+ r'print\s+\$_(GET|POST|REQUEST|COOKIE)',
172
+ # Echo with array access (database output) - common stored XSS
173
+ r'echo\s+["\'].*\.\s*\$\w+\[',
174
+ r'echo\s+["\'].*\$\w+\[.*\]',
175
+ # Print with concatenation
176
+ r'print\s+["\'].*\.\s*\$',
177
+ # Unescaped variable in echo
178
+ r'echo\s+\$\w+\s*;',
179
+ r'print\s+\$\w+\s*;',
180
+ # Short echo tag with variable
181
+ r'<\?=\s*\$\w+',
182
+ # Dangerous functions
183
+ r'eval\s*\(',
184
+ r'innerHTML\s*=',
185
+ # SQL with user input (can lead to stored XSS)
186
+ r'query\s*\(.*\$_(GET|POST|REQUEST)',
187
+ r'INSERT INTO.*\$\w+',
188
+ r'mysql_query\s*\(.*\$',
189
+ # Direct concatenation in HTML
190
+ r'echo\s+["\']<[^>]+>\s*["\'].*\.\s*\$',
191
+ ]
192
+ else: # JavaScript
193
+ patterns = [
194
+ r'innerHTML\s*=',
195
+ r'outerHTML\s*=',
196
+ r'document\.write\s*\(',
197
+ r'eval\s*\(',
198
+ r'\.html\s*\(', # jQuery
199
+ r'insertAdjacentHTML\s*\(',
200
+ r'location\s*=.*\+', # URL manipulation
201
+ r'window\.location\s*=',
202
+ ]
203
+
204
+ # Search for patterns and find matching lines
205
+ for i, line in enumerate(lines, 1):
206
+ for pattern in patterns:
207
+ if re.search(pattern, line, re.IGNORECASE):
208
+ # Get context: 2 lines before and after
209
+ start = max(0, i - 3)
210
+ end = min(len(lines), i + 2)
211
+ context_lines = lines[start:end]
212
+
213
+ # Mark the vulnerable line
214
+ snippet = '\n'.join(context_lines)
215
+ return snippet, i
216
+
217
+ # If no specific pattern found, skip comments and find real code
218
+ for i, line in enumerate(lines, 1):
219
+ stripped = line.strip()
220
+ # Skip empty lines, comments, and PHP opening tag
221
+ if (stripped and
222
+ not stripped.startswith('//') and
223
+ not stripped.startswith('/*') and
224
+ not stripped.startswith('*') and
225
+ not stripped.startswith('#') and
226
+ stripped != '<?php' and
227
+ not stripped.startswith('/**')):
228
+ # Found first real code line, get context
229
+ start = max(0, i - 1)
230
+ end = min(len(lines), i + 5)
231
+ context_lines = lines[start:end]
232
+ snippet = '\n'.join(context_lines)
233
+ return snippet, i
234
+
235
+ # Fallback: return truncated code
236
+ return code[:300] + "..." if len(code) > 300 else code, 1
237
+
238
+
239
+ def _get_suggestion(language: str) -> str:
240
+ """Get language-specific security suggestion"""
241
+ if language == "php":
242
+ return "Use htmlspecialchars($var, ENT_QUOTES, 'UTF-8') for output encoding"
243
+ elif language in ["js", "javascript"]:
244
+ return "Use textContent instead of innerHTML, or sanitize with DOMPurify"
245
+ return "Sanitize user input before output"
api/services/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Empty __init__ file
api/services/model_service.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model service for XSS detection - loads model from Hugging Face Hub
3
+ """
4
+ import os
5
+ import re
6
+ import torch
7
+ from typing import Tuple, List
8
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
9
+
10
+
11
class ModelService:
    """
    Loads the fine-tuned CodeBERT classifiers and runs XSS predictions.

    Models and the tokenizer are loaded eagerly in ``__init__``. A classifier
    that fails to load is stored as ``None``; predictions for that language
    then raise ``RuntimeError``.

    NOTE: line numbers reported by ``predict_multi`` refer to the
    *preprocessed* code (comments removed, blank lines collapsed, PHP blocks
    extracted), not to the text originally submitted.
    """

    # Line-window chunking: 50-line windows that overlap by 6 lines.
    CHUNK_MAX_LINES = 50
    CHUNK_OVERLAP_LINES = 6
    # Preprocessed code longer than this many characters is scanned in chunks.
    CHUNKING_CHAR_THRESHOLD = 2500
    # Minimum "vulnerable" probability for a chunk to be reported.
    VULNERABLE_THRESHOLD = 0.5
    # Tokenizer truncation length used for every prediction.
    MAX_SEQUENCE_LENGTH = 512

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load tokenizer
        self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

        # Load both classifiers from HuggingFace Hub (repos overridable via env)
        self.php_model = self._load_classifier(
            'PHP', os.getenv('PHP_MODEL_REPO', 'mekbus/codebert-xss-php')
        )
        self.js_model = self._load_classifier(
            'JS', os.getenv('JS_MODEL_REPO', 'mekbus/codebert-xss-js')
        )

    def _load_classifier(self, label: str, repo: str):
        """Load one classifier in eval mode on ``self.device``; return None on failure."""
        try:
            model = RobertaForSequenceClassification.from_pretrained(repo)
            model.to(self.device)
            model.eval()
        except Exception as e:
            # Deliberately best-effort: the service still starts with the
            # other model when one repo is unavailable.
            print(f"⚠️ {label} model not found: {e}")
            return None
        print(f"✅ {label} model loaded from {repo}")
        return model

    async def load_models(self) -> None:
        """Startup hook awaited by ``api.dependencies.initialize_models``.

        All loading already happens in ``__init__``, so there is nothing left
        to do. The method exists because the startup code calls it; without
        it application startup failed with AttributeError.
        """
        return None

    def extract_php_blocks(self, code: str) -> str:
        """Extract PHP code from mixed PHP/HTML and remove comments.

        ``<?= expr ?>`` blocks are rewritten to ``echo expr;``. When no PHP
        tag is found the input is treated as PHP as-is.
        """
        php_blocks = re.findall(r'<\?(?:php)?(.*?)(?:\?>|$)', code, re.DOTALL | re.IGNORECASE)

        if php_blocks:
            processed_blocks = []
            for block in php_blocks:
                block = block.strip()
                if block.startswith('='):
                    block = 'echo ' + block[1:].strip() + ';'
                processed_blocks.append(block)
            php_code = '\n'.join(processed_blocks)
        else:
            php_code = code

        # Remove comments.
        # NOTE(review): these regexes are not string-aware, so "//" or "#"
        # inside a string literal (e.g. a URL or a colour code) is stripped
        # as well. Left unchanged because it alters the model input.
        php_code = re.sub(r'/\*.*?\*/', '', php_code, flags=re.DOTALL)
        php_code = re.sub(r'//.*$', '', php_code, flags=re.MULTILINE)
        php_code = re.sub(r'#.*$', '', php_code, flags=re.MULTILINE)
        php_code = re.sub(r'\n\s*\n+', '\n', php_code.strip())

        return php_code

    def extract_js_code(self, code: str) -> str:
        """Remove block/line comments from JavaScript and collapse blank lines."""
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
        code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)
        code = re.sub(r'\n\s*\n+', '\n', code.strip())
        return code

    def _chunk_spans(self, code: str) -> List[Tuple[int, int, str]]:
        """Split code into overlapping line windows with their positions.

        Returns ``(start_line, end_line, chunk)`` tuples with 1-based,
        inclusive line numbers. Whitespace-only windows are skipped, so the
        positions stay correct even when a window is dropped.
        """
        lines = code.split('\n')
        step = self.CHUNK_MAX_LINES - self.CHUNK_OVERLAP_LINES
        spans = []
        for start in range(0, len(lines), step):
            window = lines[start:start + self.CHUNK_MAX_LINES]
            chunk = '\n'.join(window)
            if chunk.strip():
                spans.append((start + 1, start + len(window), chunk))
        return spans

    def chunk_code(self, code: str, max_tokens: int = 400, overlap: int = 50) -> List[str]:
        """Split large code into overlapping chunks.

        Chunking is line-based (``CHUNK_MAX_LINES`` per chunk,
        ``CHUNK_OVERLAP_LINES`` shared between neighbours). ``max_tokens`` and
        ``overlap`` are accepted for interface compatibility but are not used.
        Returns ``[code]`` when no non-blank chunk exists.
        """
        chunks = [chunk for _, _, chunk in self._chunk_spans(code)]
        return chunks if chunks else [code]

    def predict_single(self, code: str, model) -> Tuple[float, float]:
        """Classify one piece of code; return (P(class 0), P(class 1)).

        Callers in this class treat class 1 as the "vulnerable" class. Input
        beyond ``MAX_SEQUENCE_LENGTH`` tokens is truncated.
        """
        inputs = self.tokenizer(
            code,
            return_tensors='pt',
            truncation=True,
            max_length=self.MAX_SEQUENCE_LENGTH,
            padding=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            return probs[0][0].item(), probs[0][1].item()

    def predict(self, code: str, language: str) -> Tuple[bool, float, str]:
        """Predict if code is vulnerable; return (is_vulnerable, confidence, label)."""
        result = self.predict_multi(code, language)
        if result['vulnerabilities']:
            max_vuln = max(result['vulnerabilities'], key=lambda x: x['confidence'])
            return True, max_vuln['confidence'], "VULNERABLE"
        else:
            return False, result['max_confidence'], "SAFE"

    def predict_multi(self, code: str, language: str) -> dict:
        """Predict vulnerabilities - returns multiple if found.

        Returns a dict with ``is_vulnerable`` (bool), ``max_confidence``
        (highest "vulnerable" probability seen) and ``vulnerabilities`` (one
        dict per flagged chunk with ``chunk_id``, ``start_line``,
        ``end_line`` and ``confidence``).

        Raises:
            ValueError: for a language other than php/js/javascript.
            RuntimeError: when the model for the language is not loaded.
        """
        if language == 'php':
            model = self.php_model
            code = self.extract_php_blocks(code)
        elif language in ['js', 'javascript']:
            model = self.js_model
            code = self.extract_js_code(code)
        else:
            raise ValueError(f"Unsupported language: {language}")

        if model is None:
            raise RuntimeError(f"{language.upper()} model not loaded")

        whole_file = [(1, len(code.split('\n')), code)]
        if len(code) > self.CHUNKING_CHAR_THRESHOLD:
            # Positions come from the chunker itself instead of being
            # recomputed from hard-coded 44/49 offsets.
            spans = self._chunk_spans(code) or whole_file
            print(f"📄 Large {language.upper()} file: {len(spans)} chunks")
        else:
            spans = whole_file

        vulnerabilities = []
        max_vuln_prob = 0.0
        for chunk_id, (start_line, end_line, chunk) in enumerate(spans, 1):
            _, vuln_prob = self.predict_single(chunk, model)
            max_vuln_prob = max(max_vuln_prob, vuln_prob)
            if vuln_prob >= self.VULNERABLE_THRESHOLD:
                vulnerabilities.append({
                    'chunk_id': chunk_id,
                    'start_line': start_line,
                    'end_line': end_line,
                    'confidence': vuln_prob
                })

        return {
            'is_vulnerable': len(vulnerabilities) > 0,
            'max_confidence': max_vuln_prob,
            'vulnerabilities': vulnerabilities
        }
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ transformers==4.35.2
4
+ torch==2.1.1
5
+ pydantic==2.5.0
6
+ python-multipart==0.0.6
7
+ huggingface-hub>=0.19.0
test_api.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test script for XSS Detection API
3
+ Run this after starting the server to verify everything works
4
+ """
5
+ import requests
6
+ import json
7
+
8
BASE_URL = "http://localhost:8080/api/v1"

# Seconds to wait for the server before giving up. Without a timeout,
# requests can block forever if the server accepts the connection but never
# answers (for example while it is still loading models).
REQUEST_TIMEOUT = 60


def _report(response):
    """Print the HTTP status and the pretty-printed JSON body of a response."""
    print(f"Status: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}\n")


def _post_and_report(path, payload):
    """POST a JSON payload to BASE_URL + path and print the outcome."""
    response = requests.post(f"{BASE_URL}{path}", json=payload, timeout=REQUEST_TIMEOUT)
    _report(response)


def test_health():
    """Test health endpoint"""
    print("🔍 Testing health endpoint...")
    response = requests.get(f"{BASE_URL}/health", timeout=REQUEST_TIMEOUT)
    _report(response)


def test_php_vulnerable():
    """Test PHP vulnerable code"""
    print("🔍 Testing PHP vulnerable code...")
    _post_and_report("/scan", {
        "code": "<?php echo $_GET['input']; ?>",
        "language": "php",
        "file_path": "test.php"
    })


def test_php_safe():
    """Test PHP safe code"""
    print("🔍 Testing PHP safe code...")
    _post_and_report("/scan", {
        "code": "<?php echo htmlspecialchars($_GET['input'], ENT_QUOTES, 'UTF-8'); ?>",
        "language": "php",
        "file_path": "safe.php"
    })


def test_js_vulnerable():
    """Test JS vulnerable code"""
    print("🔍 Testing JS vulnerable code...")
    _post_and_report("/scan", {
        "code": "document.getElementById('output').innerHTML = userInput;",
        "language": "js",
        "file_path": "test.js"
    })


def test_batch():
    """Test batch scanning"""
    print("🔍 Testing batch scan...")
    _post_and_report("/scan/batch", {
        "files": [
            {
                "code": "<?php echo $_POST['name']; ?>",
                "language": "php"
            },
            {
                "code": "<?php echo htmlspecialchars($_POST['name']); ?>",
                "language": "php"
            }
        ]
    })


if __name__ == "__main__":
    print("🚀 Starting API tests...\n")

    try:
        test_health()
        test_php_vulnerable()
        test_php_safe()
        test_js_vulnerable()
        test_batch()

        print("✅ All tests completed!")

    except requests.exceptions.ConnectionError:
        print("❌ Error: Cannot connect to API server")
        print("Make sure the server is running: python -m api.main")
    except requests.exceptions.Timeout:
        print(f"❌ Error: API server did not respond within {REQUEST_TIMEOUT} seconds")
    except Exception as e:
        print(f"❌ Error: {e}")