|
|
from fastapi import APIRouter, HTTPException, Depends |
|
|
from pydantic import BaseModel, Field |
|
|
from typing import List, Optional |
|
|
from enum import Enum |
|
|
|
|
|
from api.dependencies import get_model_service |
|
|
from api.services.model_service import ModelService |
|
|
|
|
|
|
|
|
# Router holding the /scan and /scan/batch endpoints defined in this module.
router = APIRouter()
|
|
|
|
|
|
|
|
class LanguageEnum(str, Enum):
    """Languages accepted by the scan endpoints.

    Members subclass ``str``, so each compares equal to its value
    (e.g. ``LanguageEnum.PHP == "php"``).
    """

    PHP = "php"
    # "js" and "javascript" are separate members with distinct values;
    # _get_suggestion treats both as JavaScript.
    JS = "js"
    JAVASCRIPT = "javascript"
|
|
|
|
|
|
|
|
class VulnerabilityDetail(BaseModel):
    """A single finding reported inside a ScanResult."""

    # Vulnerability category; scan_code only ever sets "xss".
    type: str = "xss"
    # scan_code fills this with "critical", "high", "medium" or "low",
    # bucketed from the model confidence.
    severity: str
    # 1-based first line of the flagged region, if known.
    line_number: Optional[int] = None
    # Human-readable summary (scan_code includes confidence and line range).
    description: str
    # Excerpt of the flagged source (scan_code truncates it to 500 chars).
    code_snippet: str
    # Language-specific remediation hint.
    suggestion: str
|
|
|
|
|
|
|
|
class ScanRequest(BaseModel):
    """Request body for scanning one source file."""

    # Raw source text; scan_code splits it on '\n' to build snippets.
    code: str = Field(..., description="Source code to analyze")
    # Validated against LanguageEnum, which accepts "php", "js" and
    # "javascript" -- the description below lists all three so the
    # generated API docs match what is actually accepted.
    language: LanguageEnum = Field(..., description="Programming language (php, js or javascript)")
    # Optional context only; not read by the endpoints in this module.
    file_path: Optional[str] = Field(None, description="File path for context")
|
|
|
|
|
|
|
|
class ScanResult(BaseModel):
    """Outcome of scanning one file."""

    # Overall verdict as reported by the model service.
    is_vulnerable: bool
    # Highest confidence from the model service (0.0 for error placeholders).
    confidence: float
    # "VULNERABLE"/"SAFE" from scan_code; scan_batch uses "ERROR" for a file
    # whose scan raised.
    label: str
    # Per-region findings. Pydantic copies field defaults per instance, so
    # this [] literal is not shared between results.
    vulnerabilities: List[VulnerabilityDetail] = []
    # Elapsed time spent in scan_code, in milliseconds.
    processing_time_ms: Optional[int] = None
    # Always False in this module; presumably reserved for a result cache
    # -- TODO confirm.
    cached: bool = False
|
|
|
|
|
|
|
|
class BatchScanRequest(BaseModel):
    """Request body for scanning several files in one call."""

    # Each entry is scanned independently, in list order.
    files: List[ScanRequest]
|
|
|
|
|
|
|
|
class BatchScanResult(BaseModel):
    """Response for a batch scan."""

    # Random identifier (scan_batch uses a UUID4 string) for this batch.
    job_id: str
    # Number of files submitted in the request.
    total_files: int
    # One result per submitted file, in submission order.
    results: List[ScanResult]
|
|
|
|
|
|
|
|
@router.post("/scan", response_model=ScanResult)
async def scan_code(
    request: ScanRequest,
    model_service: ModelService = Depends(get_model_service)
):
    """
    Analyze a single code snippet for XSS vulnerabilities.

    Runs ``model_service.predict_multi`` on the submitted code and turns
    every reported region into a ``VulnerabilityDetail`` whose severity is
    bucketed from the model confidence.

    Args:
        request: The code, its language and an optional file path.
        model_service: Injected service that performs the prediction.

    Returns:
        ScanResult with the overall verdict, the highest confidence, the
        per-region findings and the elapsed processing time in ms.

    Raises:
        HTTPException: Re-raised unchanged if raised inside; any other
            exception is wrapped as a 500 with ``str(e)`` as detail.
    """
    import time

    # perf_counter is monotonic, so the measured duration cannot jump or go
    # negative if the system clock is adjusted mid-request.
    start = time.perf_counter()
    try:
        # NOTE(review): predict_multi is called synchronously inside an async
        # handler; if it is CPU-heavy it blocks the event loop -- consider
        # offloading to a thread pool once its thread-safety is confirmed.
        result = model_service.predict_multi(
            request.code,
            request.language.value
        )

        # Loop-invariant work: split the source and pick the hint once.
        lines = request.code.split('\n')
        suggestion = _get_suggestion(request.language.value)

        vulnerabilities = []
        for vuln_info in result['vulnerabilities']:
            confidence = vuln_info['confidence']

            # Clamp to a valid 1-based line: a 0/negative start would become a
            # negative slice index and pull lines from the end of the file.
            start_line = max(1, vuln_info['start_line'])
            end_line = min(vuln_info['end_line'], len(lines))
            # At most 6 lines, starting at start_line and never past end_line.
            code_snippet = '\n'.join(
                lines[start_line - 1:min(start_line + 5, end_line)]
            )

            vulnerabilities.append(VulnerabilityDetail(
                type="xss",
                severity=_severity_from_confidence(confidence),
                line_number=start_line,
                description=f"Potential XSS vulnerability detected with {confidence:.1%} confidence (lines {start_line}-{end_line})",
                code_snippet=code_snippet[:500],
                suggestion=suggestion,
            ))

        processing_time = int((time.perf_counter() - start) * 1000)

        is_vulnerable = result['is_vulnerable']
        return ScanResult(
            is_vulnerable=is_vulnerable,
            confidence=result['max_confidence'],
            label="VULNERABLE" if is_vulnerable else "SAFE",
            vulnerabilities=vulnerabilities,
            processing_time_ms=processing_time,
            cached=False
        )
    except HTTPException:
        # Already an HTTP error with its own status; don't rewrap it as 500.
        raise
    except Exception as e:
        # NOTE(review): str(e) may expose internal details to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e


def _severity_from_confidence(confidence: float) -> str:
    """Map a model confidence to a severity bucket.

    >= 0.95 -> "critical", >= 0.85 -> "high", >= 0.70 -> "medium",
    anything lower -> "low".
    """
    if confidence >= 0.95:
        return "critical"
    if confidence >= 0.85:
        return "high"
    if confidence >= 0.70:
        return "medium"
    return "low"
|
|
|
|
|
|
|
|
@router.post("/scan/batch", response_model=BatchScanResult)
async def scan_batch(
    request: BatchScanRequest,
    model_service: ModelService = Depends(get_model_service)
):
    """
    Analyze multiple code files in batch.

    Every entry in ``request.files`` is passed to ``scan_code`` one after
    another. A file whose scan raises does not abort the batch; it is
    represented by a placeholder result labelled ``"ERROR"`` (not
    vulnerable, zero confidence, no findings), so ``results`` lines up
    index-for-index with the submitted files. The original exception is
    not recorded in the response.

    Returns:
        BatchScanResult with a fresh UUID4 job id, the number of submitted
        files and one ScanResult per file in submission order.
    """
    import uuid

    def _error_placeholder() -> ScanResult:
        # Stand-in for a file whose scan raised.
        return ScanResult(
            is_vulnerable=False,
            confidence=0.0,
            label="ERROR",
            vulnerabilities=[],
            processing_time_ms=0,
            cached=False,
        )

    batch_id = str(uuid.uuid4())
    collected = []

    for entry in request.files:
        try:
            outcome = await scan_code(entry, model_service)
        except Exception:
            outcome = _error_placeholder()
        collected.append(outcome)

    return BatchScanResult(
        job_id=batch_id,
        total_files=len(request.files),
        results=collected,
    )
|
|
|
|
|
|
|
|
def _extract_vulnerable_code(code: str, language: str) -> tuple: |
|
|
""" |
|
|
Extract the most likely vulnerable code snippet and line number. |
|
|
Returns (code_snippet, line_number) |
|
|
""" |
|
|
import re |
|
|
|
|
|
lines = code.split('\n') |
|
|
|
|
|
|
|
|
if language == "php": |
|
|
patterns = [ |
|
|
|
|
|
r'echo\s+\$_(GET|POST|REQUEST|COOKIE)', |
|
|
r'print\s+\$_(GET|POST|REQUEST|COOKIE)', |
|
|
|
|
|
r'echo\s+["\'].*\.\s*\$\w+\[', |
|
|
r'echo\s+["\'].*\$\w+\[.*\]', |
|
|
|
|
|
r'print\s+["\'].*\.\s*\$', |
|
|
|
|
|
r'echo\s+\$\w+\s*;', |
|
|
r'print\s+\$\w+\s*;', |
|
|
|
|
|
r'<\?=\s*\$\w+', |
|
|
|
|
|
r'eval\s*\(', |
|
|
r'innerHTML\s*=', |
|
|
|
|
|
r'query\s*\(.*\$_(GET|POST|REQUEST)', |
|
|
r'INSERT INTO.*\$\w+', |
|
|
r'mysql_query\s*\(.*\$', |
|
|
|
|
|
r'echo\s+["\']<[^>]+>\s*["\'].*\.\s*\$', |
|
|
] |
|
|
else: |
|
|
patterns = [ |
|
|
r'innerHTML\s*=', |
|
|
r'outerHTML\s*=', |
|
|
r'document\.write\s*\(', |
|
|
r'eval\s*\(', |
|
|
r'\.html\s*\(', |
|
|
r'insertAdjacentHTML\s*\(', |
|
|
r'location\s*=.*\+', |
|
|
r'window\.location\s*=', |
|
|
] |
|
|
|
|
|
|
|
|
for i, line in enumerate(lines, 1): |
|
|
for pattern in patterns: |
|
|
if re.search(pattern, line, re.IGNORECASE): |
|
|
|
|
|
start = max(0, i - 3) |
|
|
end = min(len(lines), i + 2) |
|
|
context_lines = lines[start:end] |
|
|
|
|
|
|
|
|
snippet = '\n'.join(context_lines) |
|
|
return snippet, i |
|
|
|
|
|
|
|
|
for i, line in enumerate(lines, 1): |
|
|
stripped = line.strip() |
|
|
|
|
|
if (stripped and |
|
|
not stripped.startswith('//') and |
|
|
not stripped.startswith('/*') and |
|
|
not stripped.startswith('*') and |
|
|
not stripped.startswith('#') and |
|
|
stripped != '<?php' and |
|
|
not stripped.startswith('/**')): |
|
|
|
|
|
start = max(0, i - 1) |
|
|
end = min(len(lines), i + 5) |
|
|
context_lines = lines[start:end] |
|
|
snippet = '\n'.join(context_lines) |
|
|
return snippet, i |
|
|
|
|
|
|
|
|
return code[:300] + "..." if len(code) > 300 else code, 1 |
|
|
|
|
|
|
|
|
def _get_suggestion(language: str) -> str: |
|
|
"""Get language-specific security suggestion""" |
|
|
if language == "php": |
|
|
return "Use htmlspecialchars($var, ENT_QUOTES, 'UTF-8') for output encoding" |
|
|
elif language in ["js", "javascript"]: |
|
|
return "Use textContent instead of innerHTML, or sanitize with DOMPurify" |
|
|
return "Sanitize user input before output" |
|
|
|