preethishsg committed on
Commit
1631730
·
verified ·
1 Parent(s): 36d88d5

Upload 4 files

Browse files
Files changed (4) hide show
  1. main.py +161 -0
  2. rag_system.py +214 -0
  3. requirements.txt +11 -0
  4. vector_db.py +118 -0
main.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import List, Dict, Optional
5
+ import uvicorn
6
+ from pathlib import Path
7
+
8
+ from rag_system import RAGSystem, initialize_from_documents
9
+
10
app = FastAPI(title="RAG System API", version="1.0.0")

# CORS middleware for the frontend.
# Every origin, method and header is allowed. Credentialed requests are
# switched off: the endpoints in this file read no cookies or Authorization
# headers, and the FastAPI CORS documentation states that wildcard origins
# must not be combined with allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global RAG system instance; stays None until the startup hook assigns it.
rag_system: Optional[RAGSystem] = None
# JSON file the vector database is loaded from and saved to.
DB_PATH = "vector_db.json"
24
+
25
+ # Request/Response Models
26
class Document(BaseModel):
    # A single document submitted for indexing.
    # text: the document body.
    text: str
    # metadata: optional extra key/value pairs stored alongside the text.
    metadata: Optional[Dict] = None
29
+
30
class InsertRequest(BaseModel):
    # Request body of POST /insert: the batch of documents to index.
    documents: List[Document]
32
+
33
class InsertResponse(BaseModel):
    # Response body of POST /insert.
    # success: True when the batch was inserted.
    success: bool
    # document_ids: ids assigned to the inserted documents.
    document_ids: List[str]
    # message: human-readable summary of the operation.
    message: str
37
+
38
class SearchRequest(BaseModel):
    # Request body of POST /search.
    # query: free-text query to embed and search with.
    query: str
    # k: number of results requested (defaults to 5).
    k: int = 5
41
+
42
class SearchResponse(BaseModel):
    # Response body of POST /search: one dict per retrieved document.
    results: List[Dict]
44
+
45
class QueryRequest(BaseModel):
    # Request body of POST /query.
    # query: the user's question.
    query: str
    # k: number of documents to retrieve as context (defaults to 3).
    k: int = 3
    # max_length: passed through to the RAG system's query() call (defaults to 150).
    max_length: int = 150
49
+
50
class QueryResponse(BaseModel):
    # Response body of POST /query.
    # query: the question that was asked.
    query: str
    # answer: the generated answer text.
    answer: str
    # retrieved_documents: the documents that were retrieved for the query.
    retrieved_documents: List[Dict]
    # context: the context string returned by the RAG system.
    context: str
55
+
56
class StatsResponse(BaseModel):
    # Response body of GET /stats; built from the dict returned by get_stats().
    # total_documents: number of stored documents.
    total_documents: int
    # dimension: vector dimension of the database.
    dimension: int
    # next_id: counter value the next inserted document will receive.
    next_id: int
60
+
61
@app.on_event("startup")
async def startup_event():
    """Create the module-level RAG system when the server starts.

    When ``documents.json`` exists and no database file is present at
    ``DB_PATH``, the database is built from that file. Otherwise a
    ``RAGSystem`` is constructed, pointed at ``DB_PATH`` only if that file
    exists.
    """
    global rag_system

    print("Starting RAG System...")

    seed_file = Path("documents.json")
    db_present = Path(DB_PATH).exists()

    # Seeding happens only on first run, i.e. before a database was saved.
    if not db_present and seed_file.exists():
        print("Initializing database from documents.json...")
        rag_system = initialize_from_documents(str(seed_file), DB_PATH)
    else:
        print("Loading existing database...")
        existing_db = DB_PATH if db_present else None
        rag_system = RAGSystem(db_path=existing_db)

    print("RAG System ready!")
79
+
80
@app.get("/")
async def root():
    """Health check endpoint"""
    # Static payload; it does not depend on whether the RAG system is loaded.
    payload = dict(
        status="healthy",
        message="RAG System API is running",
        version="1.0.0",
    )
    return payload
88
+
89
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Get database statistics"""
    system = rag_system
    # Guard: the startup hook has not assigned the global instance.
    if system is None:
        raise HTTPException(status_code=500, detail="RAG system not initialized")

    # The stats dict is unpacked straight into the response model.
    return StatsResponse(**system.get_stats())
97
+
98
@app.post("/insert", response_model=InsertResponse)
async def insert_documents(request: InsertRequest):
    """Insert documents into the vector database"""
    if rag_system is None:
        raise HTTPException(status_code=500, detail="RAG system not initialized")

    # RAGSystem.insert_documents reads the document body from the "data" or
    # "text" key. Caller-supplied metadata using either key would silently
    # replace the text that gets embedded, so those keys are dropped from the
    # metadata and "text" is always set from the request's own field.
    reserved_keys = ("text", "data")

    try:
        # Convert Pydantic models to the flat dicts the RAG system expects.
        documents = []
        for doc in request.documents:
            doc_dict = {
                key: value
                for key, value in (doc.metadata or {}).items()
                if key not in reserved_keys
            }
            doc_dict["text"] = doc.text
            documents.append(doc_dict)

        # Insert documents, then persist so the new entries survive a restart.
        doc_ids = rag_system.insert_documents(documents)
        rag_system.save_db(DB_PATH)

        return InsertResponse(
            success=True,
            document_ids=doc_ids,
            message=f"Successfully inserted {len(doc_ids)} documents",
        )

    except Exception as e:
        # Endpoint boundary: report any failure as a 500, keeping the cause chained.
        raise HTTPException(
            status_code=500, detail=f"Error inserting documents: {str(e)}"
        ) from e
127
+
128
@app.post("/search", response_model=SearchResponse)
async def search_documents(request: SearchRequest):
    """Search for similar documents"""
    system = rag_system
    # Guard: the startup hook has not assigned the global instance.
    if system is None:
        raise HTTPException(status_code=500, detail="RAG system not initialized")

    try:
        hits = system.retrieve(request.query, k=request.k)
        response = SearchResponse(results=hits)
        return response
    except Exception as e:
        # Any failure while retrieving or building the response becomes a 500.
        detail = f"Error searching documents: {str(e)}"
        raise HTTPException(status_code=500, detail=detail)
140
+
141
@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Complete RAG query: retrieve + generate"""
    if rag_system is None:
        raise HTTPException(status_code=500, detail="RAG system not initialized")

    try:
        # The RAG system returns a dict whose keys are unpacked into QueryResponse.
        result = rag_system.query(
            request.query,
            k=request.k,
            max_length=request.max_length,
        )
        return QueryResponse(**result)

    except Exception as e:
        # Endpoint boundary: report any failure as a 500, keeping the cause chained.
        raise HTTPException(
            status_code=500, detail=f"Error processing query: {str(e)}"
        ) from e
157
+
158
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces. The port comes from
    # the PORT environment variable and falls back to 8080.
    import os

    listen_port = int(os.environ.get("PORT", 8080))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
rag_system.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import requests
4
+ import numpy as np
5
+ from typing import List, Dict
6
+ from pathlib import Path
7
+ from transformers import AutoTokenizer, AutoModel
8
+
9
+ from vector_db import VectorDatabase
10
+
11
+
12
class RAGSystem:
    """
    RAG System:
    - Local embeddings using BGE-micro
    - Custom vector database for retrieval
    - Hosted lightweight LLM (Hugging Face Inference API) for generation
    """

    # Hugging Face model id used for both the tokenizer and the embedding model.
    EMBED_MODEL_NAME = "TaylorAI/bge-micro"
    # Dimension a freshly created vector database is configured with.
    EMBED_DIMENSION = 384
    # Inference endpoint that generate_response() posts prompts to.
    HF_MODEL_URL = "https://api-inference.huggingface.co/models/google/flan-t5-small"
    # Keys a document body may be stored under, in lookup priority order.
    _TEXT_KEYS = ("data", "text")

    def __init__(self, db_path: str = None):
        """Load the embedding model and open (or create) the vector database.

        Args:
            db_path: Path of a saved vector database. When it is falsy or the
                file does not exist, an empty database is created instead.
        """
        print("Initializing RAG System...")

        # -----------------------------
        # Embedding Model (Local)
        # -----------------------------
        print("Loading embedding model (BGE-micro)...")
        self.embed_tokenizer = AutoTokenizer.from_pretrained(self.EMBED_MODEL_NAME)
        self.embed_model = AutoModel.from_pretrained(self.EMBED_MODEL_NAME)
        self.embed_model.eval()

        # -----------------------------
        # Vector Database
        # -----------------------------
        if db_path and Path(db_path).exists():
            print(f"Loading vector DB from {db_path}")
            self.db = VectorDatabase.load(db_path)
        else:
            print("Creating new vector DB")
            self.db = VectorDatabase(dimension=self.EMBED_DIMENSION)

        # -----------------------------
        # Hosted LLM Config
        # -----------------------------
        self.hf_api_token = os.getenv("HF_API_TOKEN")
        self.hf_model_url = self.HF_MODEL_URL

        if not self.hf_api_token:
            print("WARNING: HF_API_TOKEN not set. Generation will fail.")

        print("RAG System initialized successfully!")

    # --------------------------------------------------
    # Embedding
    # --------------------------------------------------
    def encode_text(self, text: str) -> np.ndarray:
        """Return the embedding of *text* as a 1-D NumPy array.

        The text is tokenized (truncated to 512 tokens) and run through the
        embedding model without gradient tracking; the hidden state at
        sequence position 0 is used as the embedding.
        """
        inputs = self.embed_tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = self.embed_model(**inputs)
        # [:, 0, :] selects position 0 for every batch row; [0] drops the batch axis.
        return outputs.last_hidden_state[:, 0, :].numpy()[0]

    def encode_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Embed every string in *texts*, one at a time, preserving order."""
        return [self.encode_text(text) for text in texts]

    # --------------------------------------------------
    # Insert
    # --------------------------------------------------
    @classmethod
    def _split_document(cls, doc: Dict) -> tuple:
        """Split a raw document dict into ``(text, metadata)``.

        The text comes from ``doc["data"]`` or, failing that, ``doc["text"]``;
        a missing or ``None`` body becomes ``""`` so the tokenizer never
        receives ``None``. The metadata holds the text under ``"text"`` plus
        every other key of *doc*.
        """
        text = doc.get("data") or doc.get("text") or ""
        metadata = {"text": text}
        metadata.update(
            (key, value) for key, value in doc.items() if key not in cls._TEXT_KEYS
        )
        return text, metadata

    def insert_documents(self, documents: List[Dict]) -> List[str]:
        """Embed and store *documents*; return the ids assigned by the database."""
        pairs = [self._split_document(doc) for doc in documents]
        texts = [text for text, _ in pairs]
        processed_docs = [metadata for _, metadata in pairs]

        embeddings = self.encode_batch(texts)
        return self.db.batch_insert(embeddings, processed_docs)

    # --------------------------------------------------
    # Retrieve
    # --------------------------------------------------
    def retrieve(self, query: str, k: int = 5) -> List[Dict]:
        """Return up to *k* database hits for *query*.

        Each hit is a dict with the keys ``"id"``, ``"score"`` and
        ``"metadata"``, in the order the database returned them.
        """
        query_embedding = self.encode_text(query)
        results = self.db.search(query_embedding, k=k)

        return [
            {"id": doc_id, "score": score, "metadata": metadata}
            for doc_id, score, metadata in results
        ]

    # --------------------------------------------------
    # Hosted LLM Generation (Optimized Prompt)
    # --------------------------------------------------
    @staticmethod
    def _build_prompt(query: str, context: str) -> str:
        """Assemble the instruction prompt sent to the hosted model."""
        lines = [
            "You are an intelligent assistant answering questions strictly using the provided context.",
            "",
            "Rules:",
            "- Use only the given context.",
            '- If the answer is not present, say: "The information is not available in the provided documents."',
            "- Answer clearly and concisely.",
            "",
            "Context:",
            context,
            "",
            "Question:",
            query,
            "",
            "Answer:",
        ]
        return "\n".join(lines).strip()

    @staticmethod
    def _extract_generated_text(result) -> str:
        """Pull the generated text out of a decoded API response.

        Returns the stripped ``generated_text`` of the first list element when
        the response has that shape; any other payload (empty list, error
        dict, ...) is returned as ``str(result)`` so the caller still sees it.
        """
        if isinstance(result, list) and result:
            first = result[0]
            if isinstance(first, dict):
                generated = first.get("generated_text")
                if isinstance(generated, str):
                    return generated.strip()
        return str(result)

    def generate_response(self, query: str, context: str, max_length: int = 150) -> str:
        """Ask the hosted model to answer *query* using *context*.

        Never raises for transport or decoding problems: those are reported
        as an ``"LLM generation error: ..."`` string, and a missing API token
        yields ``"HF_API_TOKEN not configured."``.

        Args:
            query: The user's question.
            context: Text the answer must be based on.
            max_length: Sent to the API as ``max_new_tokens``.
        """
        if not self.hf_api_token:
            return "HF_API_TOKEN not configured."

        headers = {
            "Authorization": f"Bearer {self.hf_api_token}",
            "Content-Type": "application/json",
        }

        payload = {
            "inputs": self._build_prompt(query, context),
            "parameters": {
                "max_new_tokens": max_length,
                # NOTE(review): with do_sample=False the sampling knobs below
                # are presumably ignored by the backend - confirm before tuning.
                "temperature": 0.2,
                "top_p": 0.9,
                "do_sample": False,
            },
        }

        try:
            response = requests.post(
                self.hf_model_url,
                headers=headers,
                json=payload,
                timeout=30,
            )
            response.raise_for_status()
            result = response.json()
        except (requests.RequestException, ValueError) as e:
            # Network/HTTP failures and undecodable bodies are best-effort:
            # surface them as the answer text instead of failing the query.
            return f"LLM generation error: {str(e)}"

        return self._extract_generated_text(result)

    # --------------------------------------------------
    # Full RAG Query
    # --------------------------------------------------
    def query(self, query: str, k: int = 3, max_length: int = 150) -> Dict:
        """Run retrieval followed by generation.

        Returns:
            A dict with the keys ``"query"``, ``"answer"``,
            ``"retrieved_documents"`` and ``"context"``. The returned context
            is cut to its first 500 characters; the full context is what the
            model receives.
        """
        retrieved_docs = self.retrieve(query, k=k)

        if not retrieved_docs:
            return {
                "query": query,
                "answer": "No relevant documents found.",
                "retrieved_documents": [],
                "context": "",
            }

        context = " ".join(
            doc["metadata"].get("text", "") for doc in retrieved_docs
        )

        answer = self.generate_response(query, context, max_length)

        return {
            "query": query,
            "answer": answer,
            "retrieved_documents": retrieved_docs,
            "context": context[:500],
        }

    # --------------------------------------------------
    # Utilities
    # --------------------------------------------------
    def save_db(self, filepath: str):
        """Persist the vector database to *filepath*."""
        self.db.save(filepath)

    def get_stats(self) -> Dict:
        """Return the statistics dict reported by the vector database."""
        return self.db.stats()
199
+
200
+
201
def initialize_from_documents(json_path: str, db_path: str = "vector_db.json"):
    """Build a RAG system from a JSON document file and save its database.

    The file is read and validated before the RAG system is constructed, so a
    bad input file fails before any model is loaded.

    Args:
        json_path: Path of a UTF-8 JSON file whose top level is a list of
            document dicts.
        db_path: Where the populated vector database is saved.

    Returns:
        The populated ``RAGSystem`` instance.

    Raises:
        ValueError: If the JSON top level is not a list.
    """
    import json

    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(json_path, "r", encoding="utf-8") as f:
        documents = json.load(f)

    if not isinstance(documents, list):
        raise ValueError(
            f"Expected a JSON list of documents in {json_path}, "
            f"got {type(documents).__name__}"
        )

    rag = RAGSystem()

    print(f"Loading {len(documents)} documents...")
    rag.insert_documents(documents)
    rag.save_db(db_path)

    print("Database initialized successfully.")
    return rag
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ pydantic==2.5.0
4
+ torch==2.5.1
5
+ transformers==4.35.0
6
+ numpy==1.24.3
7
+ python-multipart==0.0.6
8
+ sentencepiece==0.1.99
9
+ accelerate==0.24.1
10
+ openai
11
+ requests==2.31.0
vector_db.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Dict, Tuple
3
+ import json
4
+ from pathlib import Path
5
+
6
class VectorDatabase:
    """
    Custom vector database with flat index supporting:
    - Insert operations (single and batch)
    - Top-k search using dot product similarity of length-normalized vectors
    - Saving to and loading from a JSON file
    """

    def __init__(self, dimension: int = 384):
        """Create an empty database for vectors of length *dimension*."""
        self.dimension = dimension
        self.vectors = []   # stored vectors, parallel to metadata / ids
        self.metadata = []  # one metadata dict per stored vector
        self.ids = []       # "doc_<n>" identifiers, in insertion order
        self.next_id = 0    # counter used to mint the next identifier

    def _check_vector(self, vector, label: str) -> np.ndarray:
        """Return *vector* as an array after verifying it is 1-D with the right length.

        Raises:
            ValueError: If the array is not one-dimensional or its length
                differs from the database dimension.
        """
        arr = np.asarray(vector)
        if arr.ndim != 1:
            raise ValueError(f"{label} must be one-dimensional, got shape {arr.shape}")
        if arr.shape[0] != self.dimension:
            raise ValueError(
                f"{label} dimension {arr.shape[0]} doesn't match database dimension {self.dimension}"
            )
        return arr

    def insert(self, vector: np.ndarray, metadata: Dict = None) -> str:
        """Insert a single vector with optional metadata and return its id.

        Raises:
            ValueError: If *vector* does not match the database dimension.
        """
        arr = self._check_vector(vector, "Vector")

        doc_id = f"doc_{self.next_id}"
        self.next_id += 1

        self.vectors.append(arr)
        self.metadata.append(metadata or {})
        self.ids.append(doc_id)

        return doc_id

    def batch_insert(self, vectors: List[np.ndarray], metadata_list: List[Dict] = None) -> List[str]:
        """Insert multiple vectors at once and return their ids in order.

        Every vector is validated before anything is stored, so a bad entry
        leaves the database unchanged.

        Raises:
            ValueError: If the two lists differ in length or any vector does
                not match the database dimension.
        """
        if metadata_list is None:
            metadata_list = [{} for _ in vectors]

        if len(vectors) != len(metadata_list):
            raise ValueError("Number of vectors and metadata entries must match")

        checked = [self._check_vector(vector, "Vector") for vector in vectors]
        return [
            self.insert(vector, metadata)
            for vector, metadata in zip(checked, metadata_list)
        ]

    def search(self, query_vector: np.ndarray, k: int = 5) -> List[Tuple[str, float, Dict]]:
        """
        Search for top-k most similar vectors using dot product similarity
        of length-normalized vectors.

        Returns: List of (doc_id, similarity_score, metadata) tuples, best
        match first; ties keep insertion order. An empty database or a
        non-positive k yields an empty list.

        Raises:
            ValueError: If the database is non-empty and *query_vector* does
                not match the database dimension.
        """
        if not self.vectors:
            return []

        query = self._check_vector(query_vector, "Query vector")

        # A negative k would otherwise slice "all but the last |k|" results.
        if k <= 0:
            return []

        # The small epsilon keeps all-zero vectors from dividing by zero.
        query_unit = query / (np.linalg.norm(query) + 1e-8)
        matrix = np.asarray(self.vectors, dtype=float)
        row_norms = np.linalg.norm(matrix, axis=1) + 1e-8
        scores = (matrix @ query_unit) / row_norms

        # Stable sort on the negated scores: descending, ties in insertion order.
        best = np.argsort(-scores, kind="stable")[:k]
        return [(self.ids[i], float(scores[i]), self.metadata[i]) for i in best]

    def save(self, filepath: str):
        """Save database to disk as JSON.

        The data is written to a sibling ``<name>.tmp`` file and then renamed
        over *filepath*, so an interrupted write cannot leave a half-written
        database behind. Missing parent directories are created.
        """
        data = {
            'dimension': self.dimension,
            'vectors': [np.asarray(v).tolist() for v in self.vectors],
            'metadata': self.metadata,
            'ids': self.ids,
            'next_id': self.next_id,
        }

        target = Path(filepath)
        target.parent.mkdir(parents=True, exist_ok=True)
        scratch = target.with_name(target.name + ".tmp")
        with open(scratch, 'w', encoding='utf-8') as f:
            json.dump(data, f)
        scratch.replace(target)

    @classmethod
    def load(cls, filepath: str) -> 'VectorDatabase':
        """Load database from a JSON file previously written by save()."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        db = cls(dimension=data['dimension'])
        db.vectors = [np.array(v) for v in data['vectors']]
        db.metadata = data['metadata']
        db.ids = data['ids']
        db.next_id = data['next_id']

        return db

    def __len__(self):
        """Return the number of stored vectors."""
        return len(self.vectors)

    def stats(self) -> Dict:
        """Return database statistics: document count, dimension and next id counter."""
        return {
            'total_documents': len(self.vectors),
            'dimension': self.dimension,
            'next_id': self.next_id,
        }