daniel
committed on
Commit
·
c3717d3
0
Parent(s):
Clean Space - loads model from Hub
Browse files- .gitattributes +2 -0
- .gitignore +11 -0
- Dockerfile +20 -0
- README.md +47 -0
- api/__init__.py +1 -0
- api/dependencies.py +32 -0
- api/main.py +56 -0
- api/routes/__init__.py +1 -0
- api/routes/health.py +38 -0
- api/routes/scan.py +245 -0
- api/services/__init__.py +1 -0
- api/services/model_service.py +165 -0
- requirements.txt +7 -0
- test_api.py +94 -0
.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*$py.class
|
| 4 |
+
*.so
|
| 5 |
+
.Python
|
| 6 |
+
env/
|
| 7 |
+
venv/
|
| 8 |
+
models/
|
| 9 |
+
*.log
|
| 10 |
+
.env
|
| 11 |
+
.DS_Store
|
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y gcc g++ && rm -rf /var/lib/apt/lists/*
|
| 6 |
+
RUN useradd -m -u 1000 user
|
| 7 |
+
USER user
|
| 8 |
+
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
|
| 9 |
+
WORKDIR $HOME/app
|
| 10 |
+
|
| 11 |
+
COPY --chown=user requirements.txt .
|
| 12 |
+
RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
COPY --chown=user api/ ./api/
|
| 15 |
+
|
| 16 |
+
# Model loaded from HuggingFace Hub at runtime
|
| 17 |
+
ENV PHP_MODEL_REPO=mekbus/codebert-xss-php
|
| 18 |
+
EXPOSE 7860
|
| 19 |
+
|
| 20 |
+
CMD ["python", "-m", "uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# XSS Vulnerability Scanner API
|
| 2 |
+
|
| 3 |
+
A FastAPI-based API for detecting XSS vulnerabilities in JavaScript and PHP code using fine-tuned CodeBERT models.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **JavaScript XSS Detection** - Trained on 14,000+ patterns
|
| 8 |
+
- **PHP XSS Detection** - Trained on 9,700+ balanced patterns
|
| 9 |
+
- **Multi-vulnerability detection** - Finds multiple vulnerabilities per file
|
| 10 |
+
- **Chunking support** - Handles large files by splitting into chunks
|
| 11 |
+
|
| 12 |
+
## API Endpoints
|
| 13 |
+
|
| 14 |
+
### Health Check
|
| 15 |
+
```
|
| 16 |
+
GET /api/v1/health
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
### Scan Code
|
| 20 |
+
```
|
| 21 |
+
POST /api/v1/scan
|
| 22 |
+
{
|
| 23 |
+
"code": "<?php echo $_GET['name']; ?>",
|
| 24 |
+
"language": "php"
|
| 25 |
+
}
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
### Languages Supported
|
| 29 |
+
- `php` - PHP code
|
| 30 |
+
- `js` / `javascript` - JavaScript code
|
| 31 |
+
|
| 32 |
+
## Models
|
| 33 |
+
|
| 34 |
+
This Space uses fine-tuned CodeBERT models:
|
| 35 |
+
- PHP Model: `checkpoint-1867` (92% accuracy on test cases)
|
| 36 |
+
- JS Model: `best_model` (trained on 14k real-world patterns)
|
| 37 |
+
|
| 38 |
+
## Local Development
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
python -m api.main
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## License
|
| 46 |
+
|
| 47 |
+
MIT License
|
api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Empty __init__ files for Python modules
|
api/dependencies.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dependency injection for FastAPI
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from fastapi import HTTPException
|
| 6 |
+
from api.services.model_service import ModelService
|
| 7 |
+
|
| 8 |
+
# Global model service instance
|
| 9 |
+
model_service: Optional[ModelService] = None
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_model_service() -> ModelService:
    """Return the shared ModelService for FastAPI dependency injection.

    Raises:
        HTTPException: 503 while the module-level ``model_service`` is None.
    """
    service = model_service
    if service is not None:
        return service
    raise HTTPException(status_code=503, detail="Models not loaded")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
async def initialize_models():
    """Create the shared ModelService during application startup.

    The service is published in the module-level ``model_service`` only
    after it has been fully constructed, so ``get_model_service`` never
    hands out a half-initialised instance.
    """
    global model_service
    print("π Loading CodeBERT models...")
    service = ModelService()
    # ModelService loads its tokenizer and models inside __init__ and does
    # not define load_models(); awaiting it unconditionally raised
    # AttributeError at startup. Call the hook only when it exists.
    loader = getattr(service, "load_models", None)
    if loader is not None:
        await loader()
    model_service = service
    print("✅ Models loaded successfully!")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def cleanup_models():
    """Release the shared ModelService when the application shuts down.

    Announces the shutdown on stdout and resets the module-level
    ``model_service`` to None.
    """
    global model_service
    print("π Shutting down...")
    model_service = None
|
api/main.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
import uvicorn
|
| 4 |
+
from contextlib import asynccontextmanager
|
| 5 |
+
|
| 6 |
+
from api.routes import scan, health
|
| 7 |
+
from api import dependencies
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup, cleanup on shutdown.

    The cleanup runs in a ``finally`` clause so it is executed even when
    an exception is thrown into the context manager at the ``yield``.
    """
    await dependencies.initialize_models()
    try:
        yield
    finally:
        dependencies.cleanup_models()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
app = FastAPI(
    title="XSS Detection API",
    description="CodeBERT-based XSS vulnerability detection for PHP and JavaScript",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration
# Credentials (cookies / Authorization headers) must not be combined with a
# wildcard origin: browsers reject "Access-Control-Allow-Origin: *" on
# credentialed requests, and reflecting arbitrary origins would let any site
# make authenticated calls. Enable credentials only together with an explicit
# origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend URL
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers
app.include_router(scan.router, prefix="/api/v1", tags=["scan"])
app.include_router(health.router, prefix="/api/v1", tags=["health"])
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@app.get("/")
async def root():
    """Describe the service: name, version, status and docs location."""
    service_info = dict(
        service="XSS Detection API",
        version="1.0.0",
        status="running",
        docs="/docs",
    )
    return service_info
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on port 8080 with
    # auto-reload enabled.
    server_options = {
        "host": "0.0.0.0",
        "port": 8080,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("api.main:app", **server_options)
|
api/routes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Empty __init__ file
|
api/routes/health.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from api.dependencies import get_model_service
|
| 4 |
+
from api.services.model_service import ModelService
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
router = APIRouter()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HealthResponse(BaseModel):
    """Response body returned by the ``/health`` endpoint."""

    # Overall service state string.
    status: str
    # True when the service holds a PHP model.
    php_model_loaded: bool
    # True when the service holds a JavaScript model.
    js_model_loaded: bool
    # String form of the device the service runs on.
    device: str
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@router.get("/health", response_model=HealthResponse)
async def health_check(model_service: ModelService = Depends(get_model_service)):
    """
    Health check endpoint for load balancer.

    Reports which models the injected service currently holds and the
    device it runs on.
    """
    php_ready = model_service.php_model is not None
    js_ready = model_service.js_model is not None
    device_name = str(model_service.device)
    return HealthResponse(
        status="healthy",
        php_model_loaded=php_ready,
        js_model_loaded=js_ready,
        device=device_name,
    )
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@router.get("/metrics")
async def metrics():
    """
    Prometheus metrics endpoint (placeholder).

    Returns a fixed payload; no metrics are collected yet.
    """
    return dict(status="ok", metrics="Not implemented yet")
|
api/routes/scan.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, Depends
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from enum import Enum
|
| 5 |
+
|
| 6 |
+
from api.dependencies import get_model_service
|
| 7 |
+
from api.services.model_service import ModelService
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
router = APIRouter()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LanguageEnum(str, Enum):
    """Language identifiers accepted by the scan endpoints."""

    PHP = "php"
    # Both spellings are accepted for JavaScript input.
    JS = "js"
    JAVASCRIPT = "javascript"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class VulnerabilityDetail(BaseModel):
    """One finding reported inside a scan result."""

    # Vulnerability category; defaults to "xss".
    type: str = "xss"
    # Severity label assigned to the finding.
    severity: str
    # 1-based line where the finding starts, when known.
    line_number: Optional[int] = None
    # Human-readable explanation of the finding.
    description: str
    # Excerpt of the scanned code related to the finding.
    code_snippet: str
    # Remediation advice for the finding.
    suggestion: str
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ScanRequest(BaseModel):
    """Request body for scanning a single piece of source code."""

    code: str = Field(
        ...,
        description="Source code to analyze",
    )
    language: LanguageEnum = Field(
        ...,
        description="Programming language (php or js)",
    )
    file_path: Optional[str] = Field(
        None,
        description="File path for context",
    )
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ScanResult(BaseModel):
    """Outcome of scanning one piece of source code."""

    # True when at least one finding was reported.
    is_vulnerable: bool
    # Confidence value attached to the overall verdict.
    confidence: float
    # Verdict label string.
    label: str
    # Individual findings; empty by default.
    vulnerabilities: List[VulnerabilityDetail] = []
    # Time spent on the scan in milliseconds, when measured.
    processing_time_ms: Optional[int] = None
    # Whether the result came from a cache; defaults to False.
    cached: bool = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class BatchScanRequest(BaseModel):
    """Request body for scanning several files in one call."""

    # One scan request per file.
    files: List[ScanRequest]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class BatchScanResult(BaseModel):
    """Response body of a batch scan."""

    # Identifier assigned to this batch.
    job_id: str
    # Number of files in the batch request.
    total_files: int
    # Per-file scan results.
    results: List[ScanResult]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@router.post("/scan", response_model=ScanResult)
async def scan_code(
    request: ScanRequest,
    model_service: ModelService = Depends(get_model_service)
):
    """
    Analyze a single code snippet for XSS vulnerabilities.

    Runs the multi-vulnerability prediction for the requested language and
    converts every reported range into a ``VulnerabilityDetail``.

    Raises:
        HTTPException: 503 when prediction raises RuntimeError (raised by
            the model service when the model for the language is not
            loaded); 500 for any other failure.
    """
    import time

    # Monotonic clock: elapsed time is immune to wall-clock adjustments.
    started = time.perf_counter()
    language = request.language.value

    try:
        # Run prediction with multi-vulnerability support
        result = model_service.predict_multi(request.code, language)

        # Split once instead of once per finding.
        source_lines = request.code.split('\n')
        suggestion = _get_suggestion(language)

        # Build vulnerability list from all detected vulnerabilities
        vulnerabilities = []
        for vuln_info in result['vulnerabilities']:
            confidence = vuln_info['confidence']

            # NOTE(review): the model service computes line ranges on its
            # preprocessed copy of the code, so they may not line up exactly
            # with the lines of request.code - confirm before relying on them.
            start_line = vuln_info['start_line']
            end_line = min(vuln_info['end_line'], len(source_lines))
            # At most six lines of context starting at the reported line.
            code_snippet = '\n'.join(
                source_lines[start_line - 1:min(start_line + 5, end_line)]
            )

            vulnerabilities.append(VulnerabilityDetail(
                type="xss",
                severity=_severity_for(confidence),
                line_number=start_line,
                description=f"Potential XSS vulnerability detected with {confidence:.1%} confidence (lines {start_line}-{end_line})",
                code_snippet=code_snippet[:500],  # Limit snippet length
                suggestion=suggestion
            ))

        processing_time = int((time.perf_counter() - started) * 1000)

        # Use max confidence for overall result
        is_vulnerable = result['is_vulnerable']
        return ScanResult(
            is_vulnerable=is_vulnerable,
            confidence=result['max_confidence'],
            label="VULNERABLE" if is_vulnerable else "SAFE",
            vulnerabilities=vulnerabilities,
            processing_time_ms=processing_time,
            cached=False
        )

    except HTTPException:
        # Already a well-formed HTTP error; do not re-wrap it as a 500.
        raise
    except RuntimeError as exc:
        # Same status get_model_service uses for "Models not loaded".
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


def _severity_for(confidence: float) -> str:
    """Map a prediction confidence to a severity label."""
    if confidence >= 0.95:
        return "critical"
    if confidence >= 0.85:
        return "high"
    if confidence >= 0.70:
        return "medium"
    return "low"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@router.post("/scan/batch", response_model=BatchScanResult)
async def scan_batch(
    request: BatchScanRequest,
    model_service: ModelService = Depends(get_model_service)
):
    """
    Analyze multiple code files in batch.

    Files are scanned one after another. A file whose scan fails does not
    abort the batch: the failure is logged and an ``ERROR`` placeholder
    result takes its position, so ``results`` stays aligned with
    ``request.files``.
    """
    import logging
    import uuid

    logger = logging.getLogger(__name__)
    job_id = str(uuid.uuid4())
    results = []

    for index, file_request in enumerate(request.files):
        try:
            result = await scan_code(file_request, model_service)
        except Exception:
            # Best effort per file, but leave a trace of what went wrong
            # instead of discarding the exception silently.
            logger.exception(
                "Batch %s: scan failed for file #%d (%s)",
                job_id, index, file_request.file_path,
            )
            result = ScanResult(
                is_vulnerable=False,
                confidence=0.0,
                label="ERROR",
                vulnerabilities=[],
                processing_time_ms=0,
                cached=False
            )
        results.append(result)

    return BatchScanResult(
        job_id=job_id,
        total_files=len(request.files),
        results=results
    )
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _extract_vulnerable_code(code: str, language: str) -> tuple:
|
| 158 |
+
"""
|
| 159 |
+
Extract the most likely vulnerable code snippet and line number.
|
| 160 |
+
Returns (code_snippet, line_number)
|
| 161 |
+
"""
|
| 162 |
+
import re
|
| 163 |
+
|
| 164 |
+
lines = code.split('\n')
|
| 165 |
+
|
| 166 |
+
# Define vulnerable patterns by language
|
| 167 |
+
if language == "php":
|
| 168 |
+
patterns = [
|
| 169 |
+
# Direct output of user input superglobals
|
| 170 |
+
r'echo\s+\$_(GET|POST|REQUEST|COOKIE)',
|
| 171 |
+
r'print\s+\$_(GET|POST|REQUEST|COOKIE)',
|
| 172 |
+
# Echo with array access (database output) - common stored XSS
|
| 173 |
+
r'echo\s+["\'].*\.\s*\$\w+\[',
|
| 174 |
+
r'echo\s+["\'].*\$\w+\[.*\]',
|
| 175 |
+
# Print with concatenation
|
| 176 |
+
r'print\s+["\'].*\.\s*\$',
|
| 177 |
+
# Unescaped variable in echo
|
| 178 |
+
r'echo\s+\$\w+\s*;',
|
| 179 |
+
r'print\s+\$\w+\s*;',
|
| 180 |
+
# Short echo tag with variable
|
| 181 |
+
r'<\?=\s*\$\w+',
|
| 182 |
+
# Dangerous functions
|
| 183 |
+
r'eval\s*\(',
|
| 184 |
+
r'innerHTML\s*=',
|
| 185 |
+
# SQL with user input (can lead to stored XSS)
|
| 186 |
+
r'query\s*\(.*\$_(GET|POST|REQUEST)',
|
| 187 |
+
r'INSERT INTO.*\$\w+',
|
| 188 |
+
r'mysql_query\s*\(.*\$',
|
| 189 |
+
# Direct concatenation in HTML
|
| 190 |
+
r'echo\s+["\']<[^>]+>\s*["\'].*\.\s*\$',
|
| 191 |
+
]
|
| 192 |
+
else: # JavaScript
|
| 193 |
+
patterns = [
|
| 194 |
+
r'innerHTML\s*=',
|
| 195 |
+
r'outerHTML\s*=',
|
| 196 |
+
r'document\.write\s*\(',
|
| 197 |
+
r'eval\s*\(',
|
| 198 |
+
r'\.html\s*\(', # jQuery
|
| 199 |
+
r'insertAdjacentHTML\s*\(',
|
| 200 |
+
r'location\s*=.*\+', # URL manipulation
|
| 201 |
+
r'window\.location\s*=',
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
# Search for patterns and find matching lines
|
| 205 |
+
for i, line in enumerate(lines, 1):
|
| 206 |
+
for pattern in patterns:
|
| 207 |
+
if re.search(pattern, line, re.IGNORECASE):
|
| 208 |
+
# Get context: 2 lines before and after
|
| 209 |
+
start = max(0, i - 3)
|
| 210 |
+
end = min(len(lines), i + 2)
|
| 211 |
+
context_lines = lines[start:end]
|
| 212 |
+
|
| 213 |
+
# Mark the vulnerable line
|
| 214 |
+
snippet = '\n'.join(context_lines)
|
| 215 |
+
return snippet, i
|
| 216 |
+
|
| 217 |
+
# If no specific pattern found, skip comments and find real code
|
| 218 |
+
for i, line in enumerate(lines, 1):
|
| 219 |
+
stripped = line.strip()
|
| 220 |
+
# Skip empty lines, comments, and PHP opening tag
|
| 221 |
+
if (stripped and
|
| 222 |
+
not stripped.startswith('//') and
|
| 223 |
+
not stripped.startswith('/*') and
|
| 224 |
+
not stripped.startswith('*') and
|
| 225 |
+
not stripped.startswith('#') and
|
| 226 |
+
stripped != '<?php' and
|
| 227 |
+
not stripped.startswith('/**')):
|
| 228 |
+
# Found first real code line, get context
|
| 229 |
+
start = max(0, i - 1)
|
| 230 |
+
end = min(len(lines), i + 5)
|
| 231 |
+
context_lines = lines[start:end]
|
| 232 |
+
snippet = '\n'.join(context_lines)
|
| 233 |
+
return snippet, i
|
| 234 |
+
|
| 235 |
+
# Fallback: return truncated code
|
| 236 |
+
return code[:300] + "..." if len(code) > 300 else code, 1
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _get_suggestion(language: str) -> str:
|
| 240 |
+
"""Get language-specific security suggestion"""
|
| 241 |
+
if language == "php":
|
| 242 |
+
return "Use htmlspecialchars($var, ENT_QUOTES, 'UTF-8') for output encoding"
|
| 243 |
+
elif language in ["js", "javascript"]:
|
| 244 |
+
return "Use textContent instead of innerHTML, or sanitize with DOMPurify"
|
| 245 |
+
return "Sanitize user input before output"
|
api/services/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Empty __init__ file
|
api/services/model_service.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model service for XSS detection - loads model from Hugging Face Hub
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import torch
|
| 7 |
+
from typing import Tuple, List
|
| 8 |
+
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ModelService:
    """Loads the CodeBERT XSS classifiers and runs predictions with them.

    The tokenizer and both models are loaded in the constructor. A model
    that fails to load is stored as None; ``predict_multi`` raises
    RuntimeError when asked to use it.
    """

    # Line-based chunking geometry, shared by chunk_code() and
    # predict_multi() so reported line ranges always match the chunks.
    CHUNK_LINES = 50
    CHUNK_OVERLAP_LINES = 6
    # Preprocessed code longer than this many characters is scanned in chunks.
    CHUNK_THRESHOLD_CHARS = 2500
    # Minimum "vulnerable" probability for a finding to be reported.
    VULN_THRESHOLD = 0.5

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load tokenizer
        self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

        # Load both models from HuggingFace Hub; repositories can be
        # overridden through environment variables.
        self.php_model = self._load_model(
            'PHP', os.getenv('PHP_MODEL_REPO', 'mekbus/codebert-xss-php'))
        self.js_model = self._load_model(
            'JS', os.getenv('JS_MODEL_REPO', 'mekbus/codebert-xss-js'))

    def _load_model(self, label: str, repo: str):
        """Load one classifier in eval mode on ``self.device``; None on failure."""
        try:
            model = RobertaForSequenceClassification.from_pretrained(repo)
            model.to(self.device)
            model.eval()
        except Exception as e:
            # Deliberate best effort: the service still starts with the
            # other model when one of them cannot be loaded.
            print(f"⚠️ {label} model not found: {e}")
            return None
        print(f"✅ {label} model loaded from {repo}")
        return model

    async def load_models(self) -> None:
        """Startup hook awaited by ``initialize_models``.

        All loading already happens in ``__init__``, so there is nothing
        left to do; the method exists so that awaiting it does not raise
        AttributeError.
        """
        return None

    def extract_php_blocks(self, code: str) -> str:
        """Extract PHP code from mixed PHP/HTML and remove comments.

        Short echo tags (``<?= expr ?>``) are rewritten to ``echo expr;``.
        When no PHP tag is found the whole input is treated as PHP. Comment
        stripping is regex based and does not parse string literals.
        """
        php_blocks = re.findall(r'<\?(?:php)?(.*?)(?:\?>|$)', code, re.DOTALL | re.IGNORECASE)

        if php_blocks:
            processed_blocks = []
            for block in php_blocks:
                block = block.strip()
                if block.startswith('='):
                    block = 'echo ' + block[1:].strip() + ';'
                processed_blocks.append(block)
            php_code = '\n'.join(processed_blocks)
        else:
            php_code = code

        # Remove comments
        php_code = re.sub(r'/\*.*?\*/', '', php_code, flags=re.DOTALL)
        php_code = re.sub(r'//.*$', '', php_code, flags=re.MULTILINE)
        php_code = re.sub(r'#.*$', '', php_code, flags=re.MULTILINE)
        # Collapse runs of blank lines
        php_code = re.sub(r'\n\s*\n+', '\n', php_code.strip())

        return php_code

    def extract_js_code(self, code: str) -> str:
        """Strip block and line comments from JavaScript and collapse blank lines."""
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
        code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)
        code = re.sub(r'\n\s*\n+', '\n', code.strip())
        return code

    def _iter_chunks(self, code: str):
        """Yield ``(first_line_index, chunk_text)`` for every non-blank chunk."""
        lines = code.split('\n')
        step = self.CHUNK_LINES - self.CHUNK_OVERLAP_LINES
        for offset in range(0, len(lines), step):
            chunk = '\n'.join(lines[offset:offset + self.CHUNK_LINES])
            if chunk.strip():
                yield offset, chunk

    def chunk_code(self, code: str, max_tokens: int = 400, overlap: int = 50) -> List[str]:
        """Split large code into overlapping chunks.

        Chunking is line based (``CHUNK_LINES`` lines per chunk, consecutive
        chunks sharing ``CHUNK_OVERLAP_LINES`` lines). ``max_tokens`` and
        ``overlap`` are accepted for backward compatibility and are not
        used. Blank chunks are dropped; if nothing remains, the input is
        returned as a single chunk.
        """
        chunks = [chunk for _, chunk in self._iter_chunks(code)]
        return chunks if chunks else [code]

    def predict_single(self, code: str, model) -> Tuple[float, float]:
        """Classify one piece of code.

        Input beyond 512 tokens is truncated by the tokenizer.

        Returns:
            The softmax probabilities at class index 0 and 1; callers treat
            them as (safe, vulnerable).
        """
        inputs = self.tokenizer(
            code,
            return_tensors='pt',
            truncation=True,
            max_length=512,
            padding=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            return probs[0][0].item(), probs[0][1].item()

    def predict(self, code: str, language: str) -> Tuple[bool, float, str]:
        """Predict if code is vulnerable.

        Returns:
            ``(is_vulnerable, confidence, label)`` where label is
            "VULNERABLE" or "SAFE" and confidence is the highest vulnerable
            probability seen.
        """
        result = self.predict_multi(code, language)
        if result['vulnerabilities']:
            max_vuln = max(result['vulnerabilities'], key=lambda x: x['confidence'])
            return True, max_vuln['confidence'], "VULNERABLE"
        else:
            return False, result['max_confidence'], "SAFE"

    def predict_multi(self, code: str, language: str) -> dict:
        """Predict vulnerabilities - returns multiple if found.

        The code is preprocessed for its language and, when longer than
        ``CHUNK_THRESHOLD_CHARS`` characters, scanned chunk by chunk. Line
        numbers in the result are 1-based and refer to the preprocessed
        code, not to the caller's original text.

        Returns:
            A dict with ``is_vulnerable``, ``max_confidence`` and
            ``vulnerabilities`` (each entry has ``chunk_id``,
            ``start_line``, ``end_line`` and ``confidence``).

        Raises:
            ValueError: for a language other than php / js / javascript.
            RuntimeError: when the model for the language is not loaded.
        """
        if language == 'php':
            model = self.php_model
            code = self.extract_php_blocks(code)
        elif language in ['js', 'javascript']:
            model = self.js_model
            code = self.extract_js_code(code)
        else:
            raise ValueError(f"Unsupported language: {language}")

        if model is None:
            raise RuntimeError(f"{language.upper()} model not loaded")

        vulnerabilities = []
        max_vuln_prob = 0.0
        total_lines = len(code.split('\n'))

        if len(code) > self.CHUNK_THRESHOLD_CHARS:
            # Keep each chunk's real starting line so the reported range
            # stays correct even when a blank chunk was skipped.
            pieces = list(self._iter_chunks(code)) or [(0, code)]
            print(f"π Large {language.upper()} file: {len(pieces)} chunks")

            for i, (offset, chunk) in enumerate(pieces):
                safe_prob, vuln_prob = self.predict_single(chunk, model)
                if vuln_prob > max_vuln_prob:
                    max_vuln_prob = vuln_prob
                if vuln_prob >= self.VULN_THRESHOLD:
                    start_line = offset + 1
                    end_line = min(start_line + self.CHUNK_LINES - 1, total_lines)
                    vulnerabilities.append({
                        'chunk_id': i + 1,
                        'start_line': start_line,
                        'end_line': end_line,
                        'confidence': vuln_prob
                    })
        else:
            safe_prob, vuln_prob = self.predict_single(code, model)
            max_vuln_prob = vuln_prob
            if vuln_prob >= self.VULN_THRESHOLD:
                vulnerabilities.append({
                    'chunk_id': 1,
                    'start_line': 1,
                    'end_line': total_lines,
                    'confidence': vuln_prob
                })

        return {
            'is_vulnerable': len(vulnerabilities) > 0,
            'max_confidence': max_vuln_prob,
            'vulnerabilities': vulnerabilities
        }
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.104.1
|
| 2 |
+
uvicorn[standard]==0.24.0
|
| 3 |
+
transformers==4.35.2
|
| 4 |
+
torch==2.1.1
|
| 5 |
+
pydantic==2.5.0
|
| 6 |
+
python-multipart==0.0.6
|
| 7 |
+
huggingface-hub>=0.19.0
|
test_api.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test script for XSS Detection API
|
| 3 |
+
Run this after starting the server to verify everything works
|
| 4 |
+
"""
|
| 5 |
+
import requests
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
BASE_URL = "http://localhost:8080/api/v1"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_health():
    """Test health endpoint"""
    print("π Testing health endpoint...")
    # requests waits forever by default; bound the wait so an unresponsive
    # server fails the check instead of hanging the script.
    response = requests.get(f"{BASE_URL}/health", timeout=10)
    print(f"Status: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}\n")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_php_vulnerable():
    """Test PHP vulnerable code"""
    print("π Testing PHP vulnerable code...")
    payload = {
        "code": "<?php echo $_GET['input']; ?>",
        "language": "php",
        "file_path": "test.php"
    }
    # requests waits forever by default; bound the wait so a stuck scan
    # fails the check instead of hanging the script.
    response = requests.post(f"{BASE_URL}/scan", json=payload, timeout=60)
    print(f"Status: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}\n")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_php_safe():
    """Test PHP safe code"""
    print("π Testing PHP safe code...")
    payload = {
        "code": "<?php echo htmlspecialchars($_GET['input'], ENT_QUOTES, 'UTF-8'); ?>",
        "language": "php",
        "file_path": "safe.php"
    }
    # requests waits forever by default; bound the wait so a stuck scan
    # fails the check instead of hanging the script.
    response = requests.post(f"{BASE_URL}/scan", json=payload, timeout=60)
    print(f"Status: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}\n")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_js_vulnerable():
    """Test JS vulnerable code"""
    print("π Testing JS vulnerable code...")
    payload = {
        "code": "document.getElementById('output').innerHTML = userInput;",
        "language": "js",
        "file_path": "test.js"
    }
    # requests waits forever by default; bound the wait so a stuck scan
    # fails the check instead of hanging the script.
    response = requests.post(f"{BASE_URL}/scan", json=payload, timeout=60)
    print(f"Status: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}\n")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_batch():
    """Test batch scanning"""
    print("π Testing batch scan...")
    payload = {
        "files": [
            {
                "code": "<?php echo $_POST['name']; ?>",
                "language": "php"
            },
            {
                "code": "<?php echo htmlspecialchars($_POST['name']); ?>",
                "language": "php"
            }
        ]
    }
    # requests waits forever by default; bound the wait so a stuck batch
    # fails the check instead of hanging the script.
    response = requests.post(f"{BASE_URL}/scan/batch", json=payload, timeout=60)
    print(f"Status: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}\n")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
    # Run every smoke check in order against the server at BASE_URL and
    # report connection problems with a hint on how to start the server.
    print("π Starting API tests...\n")

    try:
        test_health()
        test_php_vulnerable()
        test_php_safe()
        test_js_vulnerable()
        test_batch()

        print("✅ All tests completed!")

    except requests.exceptions.ConnectionError:
        print("❌ Error: Cannot connect to API server")
        print("Make sure the server is running: python -m api.main")
    except Exception as e:
        print(f"❌ Error: {e}")
|