"""XSS scanning endpoints.

Exposes ``POST /scan`` (one snippet) and ``POST /scan/batch`` (several
snippets); both delegate detection to ``ModelService.predict_multi``.
"""

import logging
import re
import time
import uuid
from enum import Enum
from typing import List, Optional, Tuple

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field

from api.dependencies import get_model_service
from api.services.model_service import ModelService

logger = logging.getLogger(__name__)

router = APIRouter()


class LanguageEnum(str, Enum):
    """Languages accepted by the scanner.

    ``js`` and ``javascript`` are handled identically by the helpers in this
    module.
    """

    PHP = "php"
    JS = "js"
    JAVASCRIPT = "javascript"


class VulnerabilityDetail(BaseModel):
    """A single finding reported for a scanned snippet."""

    type: str = "xss"
    severity: str  # one of "critical" / "high" / "medium" / "low" (see _severity_for)
    line_number: Optional[int] = None  # 1-based line where the flagged range starts
    description: str
    code_snippet: str  # at most 6 lines / 500 characters of the flagged range
    suggestion: str


class ScanRequest(BaseModel):
    """Payload for scanning one piece of source code."""

    code: str = Field(..., description="Source code to analyze")
    language: LanguageEnum = Field(..., description="Programming language (php or js)")
    file_path: Optional[str] = Field(None, description="File path for context")


class ScanResult(BaseModel):
    """Outcome of scanning one snippet."""

    is_vulnerable: bool
    confidence: float  # highest confidence reported by the model
    label: str  # "VULNERABLE", "SAFE", or "ERROR" (batch scans only)
    vulnerabilities: List[VulnerabilityDetail] = Field(default_factory=list)
    processing_time_ms: Optional[int] = None
    cached: bool = False


class BatchScanRequest(BaseModel):
    """Payload for scanning several snippets in one call."""

    files: List[ScanRequest]


class BatchScanResult(BaseModel):
    """Per-file results of a batch scan, in request order."""

    job_id: str
    total_files: int
    results: List[ScanResult]


# (minimum confidence, severity) pairs, checked from highest to lowest.
_SEVERITY_THRESHOLDS = (
    (0.95, "critical"),
    (0.85, "high"),
    (0.70, "medium"),
)

# Maximum number of lines / characters of the flagged range shown in a finding.
_SNIPPET_MAX_LINES = 6
_SNIPPET_MAX_CHARS = 500


def _severity_for(confidence: float) -> str:
    """Map a model confidence in [0, 1] to a severity label ("low" below 0.70)."""
    for threshold, severity in _SEVERITY_THRESHOLDS:
        if confidence >= threshold:
            return severity
    return "low"


def _build_vulnerability(vuln_info: dict, lines: List[str], suggestion: str) -> VulnerabilityDetail:
    """Convert one entry of ``predict_multi(...)['vulnerabilities']`` into a finding.

    Args:
        vuln_info: Mapping with ``confidence``, ``start_line`` and ``end_line``
            keys. Line numbers are treated as 1-based.
        lines: The scanned code, already split on ``'\\n'``.
        suggestion: Remediation hint to attach to the finding.

    Returns:
        A populated ``VulnerabilityDetail``.
    """
    confidence = vuln_info['confidence']
    start_line = vuln_info['start_line']
    # Clamp to the actual code length in case the model reports a range past EOF.
    end_line = min(vuln_info['end_line'], len(lines))

    # max(..., 0) stops a start_line of 0 or less from producing a negative
    # slice index, which would silently take lines from the end of the file.
    first_index = max(start_line - 1, 0)
    last_index = min(start_line - 1 + _SNIPPET_MAX_LINES, end_line)
    code_snippet = '\n'.join(lines[first_index:last_index])

    return VulnerabilityDetail(
        type="xss",
        severity=_severity_for(confidence),
        line_number=start_line,
        description=f"Potential XSS vulnerability detected with {confidence:.1%} confidence (lines {start_line}-{end_line})",
        code_snippet=code_snippet[:_SNIPPET_MAX_CHARS],
        suggestion=suggestion,
    )


@router.post("/scan", response_model=ScanResult)
async def scan_code(
    request: ScanRequest,
    model_service: ModelService = Depends(get_model_service),
):
    """
    Analyze a single code snippet for XSS vulnerabilities
    """
    try:
        started = time.perf_counter()

        # NOTE(review): predict_multi is called synchronously inside an async
        # endpoint, so inference blocks the event loop for its duration.
        # Consider fastapi.concurrency.run_in_threadpool once the model is
        # confirmed to be thread-safe.
        result = model_service.predict_multi(
            request.code,
            request.language.value,
        )

        # Split once and reuse the lines for every reported range.
        lines = request.code.split('\n')
        suggestion = _get_suggestion(request.language.value)
        vulnerabilities = [
            _build_vulnerability(vuln_info, lines, suggestion)
            for vuln_info in result['vulnerabilities']
        ]

        processing_time = int((time.perf_counter() - started) * 1000)

        is_vulnerable = result['is_vulnerable']
        return ScanResult(
            is_vulnerable=is_vulnerable,
            confidence=result['max_confidence'],
            label="VULNERABLE" if is_vulnerable else "SAFE",
            vulnerabilities=vulnerabilities,
            processing_time_ms=processing_time,
            cached=False,
        )
    except Exception as e:
        # NOTE(review): str(e) is returned to the client and may expose
        # internal details; kept for compatibility with existing clients.
        logger.exception("XSS scan failed (file_path=%s)", request.file_path)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/scan/batch", response_model=BatchScanResult)
async def scan_batch(
    request: BatchScanRequest,
    model_service: ModelService = Depends(get_model_service),
):
    """
    Analyze multiple code files in batch
    """
    job_id = str(uuid.uuid4())
    results = []

    for index, file_request in enumerate(request.files):
        try:
            results.append(await scan_code(file_request, model_service))
        except Exception:
            # A failing file must not abort the batch: record an ERROR entry
            # so results stay aligned with request.files. The traceback has
            # already been logged by scan_code.
            logger.warning(
                "Batch %s: file #%d (%s) could not be scanned",
                job_id,
                index,
                file_request.file_path,
            )
            results.append(ScanResult(
                is_vulnerable=False,
                confidence=0.0,
                label="ERROR",
                vulnerabilities=[],
                processing_time_ms=0,
                cached=False,
            ))

    return BatchScanResult(
        job_id=job_id,
        total_files=len(request.files),
        results=results,
    )


# Line-level heuristics used by _extract_vulnerable_code, compiled once.
_PHP_PATTERNS = tuple(re.compile(pattern, re.IGNORECASE) for pattern in (
    # Direct output of user input superglobals
    r'echo\s+\$_(GET|POST|REQUEST|COOKIE)',
    r'print\s+\$_(GET|POST|REQUEST|COOKIE)',
    # Echo with array access (database output) - common stored XSS
    r'echo\s+["\'].*\.\s*\$\w+\[',
    r'echo\s+["\'].*\$\w+\[.*\]',
    # Print with concatenation
    r'print\s+["\'].*\.\s*\$',
    # Unescaped variable in echo
    r'echo\s+\$\w+\s*;',
    r'print\s+\$\w+\s*;',
    # Short echo tag with variable
    r'<\?=\s*\$\w+',
    # Dangerous functions
    r'eval\s*\(',
    r'innerHTML\s*=',
    # SQL with user input (can lead to stored XSS)
    r'query\s*\(.*\$_(GET|POST|REQUEST)',
    r'INSERT INTO.*\$\w+',
    r'mysql_query\s*\(.*\$',
    # Direct concatenation in HTML
    r'echo\s+["\']<[^>]+>\s*["\'].*\.\s*\$',
))

_JS_PATTERNS = tuple(re.compile(pattern, re.IGNORECASE) for pattern in (
    r'innerHTML\s*=',
    r'outerHTML\s*=',
    r'document\.write\s*\(',
    r'eval\s*\(',
    r'\.html\s*\(',  # jQuery
    r'insertAdjacentHTML\s*\(',
    r'location\s*=.*\+',  # URL manipulation
    r'window\.location\s*=',
))

# Stripped lines starting with any of these are treated as comments.
_COMMENT_PREFIXES = ('//', '/*', '*', '#')


def _context_around(lines: List[str], line_number: int, before: int = 2, after: int = 2) -> str:
    """Return the 1-based ``line_number`` plus up to ``before``/``after`` neighbours."""
    start = max(0, line_number - 1 - before)
    end = min(len(lines), line_number + after)
    return '\n'.join(lines[start:end])


def _extract_vulnerable_code(code: str, language: str) -> Tuple[str, int]:
    """
    Extract the most likely vulnerable code snippet and line number.

    Resolution order:

    1. The first line matching a language-specific heuristic. PHP patterns are
       used for "php"; any other language uses the JavaScript patterns. This
       returns that line with two lines of context on each side.
    2. Otherwise, the first line that is not blank, a comment, or a bare
       ``<?php`` tag, with the same context.
    3. Otherwise, the start of ``code`` (truncated to 300 characters) with
       line number 1.

    NOTE(review): this helper is not called by the endpoints in this module.

    Returns (code_snippet, line_number)
    """
    lines = code.split('\n')
    patterns = _PHP_PATTERNS if language == "php" else _JS_PATTERNS

    # Lines are the outer loop, so the earliest matching line wins regardless
    # of which pattern matched it.
    for line_number, line in enumerate(lines, 1):
        if any(pattern.search(line) for pattern in patterns):
            return _context_around(lines, line_number), line_number

    # NOTE(review): the remainder of this function was corrupted in the
    # reviewed copy (only "stripped != '" and "300 else code, 1" survived).
    # It has been reconstructed from those fragments; verify against VCS.
    for line_number, line in enumerate(lines, 1):
        stripped = line.strip()
        if stripped and not stripped.startswith(_COMMENT_PREFIXES) and stripped != '<?php':
            return _context_around(lines, line_number), line_number

    return (code[:300] + '...' if len(code) > 300 else code), 1


def _get_suggestion(language: str) -> str:
    """Get language-specific security suggestion"""
    if language == "php":
        return "Use htmlspecialchars($var, ENT_QUOTES, 'UTF-8') for output encoding"
    elif language in ["js", "javascript"]:
        return "Use textContent instead of innerHTML, or sanitize with DOMPurify"
    return "Sanitize user input before output"