preethishsg's picture
Upload 4 files
1631730 verified
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Optional
import uvicorn
from pathlib import Path
from rag_system import RAGSystem, initialize_from_documents
app = FastAPI(title="RAG System API", version="1.0.0")
# CORS middleware for frontend
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global RAG system instance
rag_system: Optional[RAGSystem] = None
DB_PATH = "vector_db.json"
# Request/Response Models
class Document(BaseModel):
text: str
metadata: Optional[Dict] = None
class InsertRequest(BaseModel):
documents: List[Document]
class InsertResponse(BaseModel):
success: bool
document_ids: List[str]
message: str
class SearchRequest(BaseModel):
query: str
k: int = 5
class SearchResponse(BaseModel):
results: List[Dict]
class QueryRequest(BaseModel):
query: str
k: int = 3
max_length: int = 150
class QueryResponse(BaseModel):
query: str
answer: str
retrieved_documents: List[Dict]
context: str
class StatsResponse(BaseModel):
total_documents: int
dimension: int
next_id: int
@app.on_event("startup")
async def startup_event():
"""Initialize RAG system on startup"""
global rag_system
print("Starting RAG System...")
# Check if we need to initialize from documents.json
documents_path = Path("documents.json")
if documents_path.exists() and not Path(DB_PATH).exists():
print("Initializing database from documents.json...")
rag_system = initialize_from_documents(str(documents_path), DB_PATH)
else:
print("Loading existing database...")
rag_system = RAGSystem(db_path=DB_PATH if Path(DB_PATH).exists() else None)
print("RAG System ready!")
@app.get("/")
async def root():
"""Health check endpoint"""
return {
"status": "healthy",
"message": "RAG System API is running",
"version": "1.0.0"
}
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
"""Get database statistics"""
if rag_system is None:
raise HTTPException(status_code=500, detail="RAG system not initialized")
stats = rag_system.get_stats()
return StatsResponse(**stats)
@app.post("/insert", response_model=InsertResponse)
async def insert_documents(request: InsertRequest):
"""Insert documents into the vector database"""
if rag_system is None:
raise HTTPException(status_code=500, detail="RAG system not initialized")
try:
# Convert Pydantic models to dicts
documents = []
for doc in request.documents:
doc_dict = {"text": doc.text}
if doc.metadata:
doc_dict.update(doc.metadata)
documents.append(doc_dict)
# Insert documents
doc_ids = rag_system.insert_documents(documents)
# Save database
rag_system.save_db(DB_PATH)
return InsertResponse(
success=True,
document_ids=doc_ids,
message=f"Successfully inserted {len(doc_ids)} documents"
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error inserting documents: {str(e)}")
@app.post("/search", response_model=SearchResponse)
async def search_documents(request: SearchRequest):
"""Search for similar documents"""
if rag_system is None:
raise HTTPException(status_code=500, detail="RAG system not initialized")
try:
results = rag_system.retrieve(request.query, k=request.k)
return SearchResponse(results=results)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}")
@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
"""Complete RAG query: retrieve +ßgenerate"""
if rag_system is None:
raise HTTPException(status_code=500, detail="RAG system not initialized")
try:
result = rag_system.query(
request.query,
k=request.k,
max_length=request.max_length
)
return QueryResponse(**result)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
if __name__ == "__main__":
import os
port = int(os.environ.get("PORT", 8080))
uvicorn.run(app, host="0.0.0.0", port=port)