# Scraped from a Hugging Face Spaces file view (Space status: Runtime error; file size: 14,684 bytes; commit 9e5bc69).
"""Graph setup module for database and model initialization. Phase A (Steps 1-2)"""
import os
import logging
from typing import Dict, Optional, Any
import sys
sys.path.append('..') # Add parent directory to path for imports
from my_config import MY_CONFIG
from neo4j import GraphDatabase
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings, VectorStoreIndex, StorageContext
from llama_index.vector_stores.milvus import MilvusVectorStore
from llama_index.llms.litellm import LiteLLM
# Set up environment: export the configured Hugging Face endpoint for this process.
# NOTE(review): this runs after the HuggingFace-related imports above; if those
# libraries read HF_ENDPOINT at import time the assignment may come too late — confirm.
# NOTE(review): os.environ only accepts str values, so this raises if
# MY_CONFIG.HF_ENDPOINT is None — confirm the config always provides a string.
os.environ['HF_ENDPOINT'] = MY_CONFIG.HF_ENDPOINT
# Configure logging: root logger at WARNING, this module's logger at INFO.
# NOTE(review): force=True removes any handlers already attached to the root
# logger, so importing this module overrides logging configured by the host app.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Neo4jConnection:
    """Neo4j database connection manager.

    Reads the connection settings from ``MY_CONFIG`` at construction time and
    owns a single driver instance that is created by :meth:`connect` and
    released by :meth:`disconnect`.
    """

    def __init__(self):
        """Load and validate the connection settings.

        Raises:
            ValueError: If the URI, user, password or database setting is
                missing or empty.
        """
        self.uri = MY_CONFIG.NEO4J_URI
        self.username = MY_CONFIG.NEO4J_USER
        self.password = MY_CONFIG.NEO4J_PASSWORD
        self.database = getattr(MY_CONFIG, "NEO4J_DATABASE", None)

        # Validate required configuration
        if not self.uri:
            raise ValueError("NEO4J_URI config is required")
        if not self.username:
            raise ValueError("NEO4J_USERNAME config is required")
        if not self.password:
            raise ValueError("NEO4J_PASSWORD config is required")
        if not self.database:
            raise ValueError("NEO4J_DATABASE config is required")

        # ``GraphDatabase.driver`` is a factory method, not a type, so it
        # cannot be used as the annotation; the driver object is kept opaque.
        self.driver: Optional[Any] = None

    def connect(self):
        """STEP 1.2: Initialize Neo4j driver with verification.

        Does nothing if a driver is already held. Failures are logged rather
        than raised; afterwards ``self.driver`` is ``None`` and callers are
        expected to check it.
        """
        if self.driver is not None:
            return

        driver = None
        try:
            driver = GraphDatabase.driver(
                self.uri,
                auth=(self.username, self.password),
            )
            driver.verify_connectivity()
        except Exception as e:
            logger.error(f"❌ STEP 1.2 FAILED: Neo4j connection error: {e}")
            # The driver object may exist even though verification failed;
            # close it so it is not leaked when the reference is dropped.
            if driver is not None:
                try:
                    driver.close()
                except Exception as close_error:
                    logger.debug(f"Ignoring error while closing failed Neo4j driver: {close_error}")
            self.driver = None
            return

        self.driver = driver
        logger.info(f"Connected to Neo4j at {self.uri}")

    def disconnect(self):
        """Close the driver, if any, and forget it.

        Safe to call repeatedly. The reference is cleared even when
        ``close()`` raises, so a broken driver is never reused.
        """
        if self.driver:
            try:
                self.driver.close()
            finally:
                self.driver = None
            logger.info("Neo4j connection closed")

    def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None):
        """Run a Cypher query and return all rows as plain dictionaries.

        Args:
            query: Cypher statement to execute.
            parameters: Optional query parameters; ``None`` is sent as ``{}``.

        Returns:
            A list with one ``record.data()`` dict per result row, fully
            materialized before the session is closed.

        Raises:
            ConnectionError: If :meth:`connect` has not produced a driver.
            Exception: Driver/session errors are not caught here and
                propagate to the caller.
        """
        if not self.driver:
            raise ConnectionError("Not connected to Neo4j database")

        with self.driver.session(database=self.database) as session:
            result = session.run(query, parameters or {})
            return [record.data() for record in result]
class GraphRAGSetup:
    """Main setup class for graph-based retrieval system.

    Handles core initialization and configuration:
    - Database connections (Neo4j and vector database)
    - Model initialization and configuration
    - Graph statistics and validation
    - Search configuration loading

    Every setup step is best-effort: failures are logged and leave the
    corresponding attribute empty (``None`` / ``{}``) instead of raising, so
    callers should use :meth:`validate_system_readiness` or
    :meth:`get_system_status` to find out what is actually available.
    """

    # (result column, key under which the node's properties are stored in
    # ``drift_config``, node label used in the log message).
    _DRIFT_SECTIONS = (
        ('dc', 'configuration', 'DriftConfiguration'),
        ('csi', 'community_search_index', 'CommunitySearchIndex'),
        ('gm', 'graph_metadata', 'GraphMetadata'),
        ('cm', 'communities_metadata', 'CommunitiesMetadata'),
    )

    def __init__(self):
        """Create the setup object and immediately run the initialization sequence."""
        logger.info("Starting graph system initialization")

        # Initialize core components
        self.config = MY_CONFIG  # Add config attribute for GraphQueryEngine
        self.neo4j_conn = None
        self.query_engine = None
        self.graph_stats = {}
        self.drift_config = {}
        self.llm = None
        self.embedding_model = None

        # Execute Step 1 initialization sequence
        self._execute_step1_sequence()
        logger.info("Graph system initialization complete")

    def _execute_step1_sequence(self):
        """Execute complete Step 1 initialization sequence (plus the Step 2 loaders)."""
        self._setup_neo4j()               # STEP 1.2
        self._setup_vector_search()       # STEP 1.3-1.6
        self._load_graph_statistics()     # STEP 2.1-2.4
        self._load_drift_configuration()  # STEP 2.5

    def _setup_neo4j(self):
        """STEP 1.2: Initialize Neo4j driver with verification.

        On any failure the error is logged, the partially opened connection
        is closed, and ``self.neo4j_conn`` is reset to ``None``.
        """
        try:
            logger.info("Initializing Neo4j connection...")
            self.neo4j_conn = Neo4jConnection()
            self.neo4j_conn.connect()

            # Verify connection with test query
            if self.neo4j_conn.driver:
                test_result = self.neo4j_conn.execute_query(
                    "MATCH (n) RETURN count(n) as total_nodes LIMIT 1"
                )
                node_count = test_result[0]['total_nodes'] if test_result else 0
                logger.info(f"Neo4j connected - {node_count} nodes found")
        except Exception as e:
            logger.error(f"Neo4j connection error: {e}")
            # Release the driver before dropping the wrapper; otherwise a
            # connection that opened but failed the test query would leak.
            failed_conn, self.neo4j_conn = self.neo4j_conn, None
            if failed_conn is not None:
                try:
                    failed_conn.disconnect()
                except Exception as close_error:
                    logger.debug(f"Ignoring error while closing Neo4j connection: {close_error}")

    def _setup_vector_search(self):
        """STEP 1.3-1.6: Initialize embedding model, vector database and LLM.

        On success ``self.embedding_model``, ``self.llm`` and
        ``self.query_engine`` are populated and the llama-index global
        ``Settings`` are updated. On failure the error is logged and
        ``self.query_engine`` is set to ``None``.
        """
        try:
            logger.info("Setting up vector search and LLM...")

            # STEP 1.5: Load embedding model
            self.embedding_model = HuggingFaceEmbedding(
                model_name=MY_CONFIG.EMBEDDING_MODEL
            )
            Settings.embed_model = self.embedding_model
            logger.info(f"Embedding model loaded: {MY_CONFIG.EMBEDDING_MODEL}")

            # STEP 1.6: Connect to vector database based on configuration
            if MY_CONFIG.VECTOR_DB_TYPE == "cloud_zilliz":
                if not MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT or not MY_CONFIG.ZILLIZ_TOKEN:
                    raise ValueError("Cloud database configuration missing. Set ZILLIZ_CLUSTER_ENDPOINT and ZILLIZ_TOKEN in .env")
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT,
                    token=MY_CONFIG.ZILLIZ_TOKEN,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False,
                )
                connected_message = "Connected to cloud vector database"
            else:
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.MILVUS_URI_HYBRID_GRAPH,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False,
                )
                connected_message = "Connected to local vector database"

            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            logger.info(connected_message)

            index = VectorStoreIndex.from_vector_store(
                vector_store=vector_store, storage_context=storage_context)
            logger.info("Vector index loaded successfully")

            # STEP 1.4: Initialize LLM provider
            llm_model = MY_CONFIG.LLM_MODEL
            self.llm = LiteLLM(model=llm_model)
            Settings.llm = self.llm
            logger.info(f"LLM initialized: {llm_model}")

            self.query_engine = index.as_query_engine()
        except Exception as e:
            logger.error(f"Vector setup error: {e}")
            self.query_engine = None

    def _load_graph_statistics(self):
        """STEP 2.1-2.4: Load and validate graph data structure.

        Fills ``self.graph_stats`` with ``node_count``,
        ``relationship_count`` and ``community_count``. Without a Neo4j
        connection the stats are left untouched; on a query error they are
        reset to ``{}``.
        """
        try:
            logger.info("Loading graph statistics and validation...")

            if not self.neo4j_conn or not self.neo4j_conn.driver:
                logger.warning("No Neo4j connection for statistics")
                return

            # STEP 2.1: Get node and relationship counts.
            # Nodes are aggregated down to a single row *before* the
            # relationships are matched, so the query never builds the
            # nodes x relationships Cartesian product. Matching with a
            # direction visits each relationship exactly once.
            stats_query = """
            MATCH (n)
            WITH count(n) as node_count,
                 count(DISTINCT n.community_id) as community_count
            OPTIONAL MATCH ()-[r]->()
            RETURN node_count,
                   count(r) as relationship_count,
                   community_count
            """

            result = self.neo4j_conn.execute_query(stats_query)
            if result:
                stats = result[0]
                self.graph_stats = {
                    'node_count': stats.get('node_count', 0),
                    'relationship_count': stats.get('relationship_count', 0),
                    'community_count': stats.get('community_count', 0)
                }

                logger.info(f"Graph validated - {self.graph_stats['node_count']} nodes, "
                            f"{self.graph_stats['relationship_count']} relationships, "
                            f"{self.graph_stats['community_count']} communities")
        except Exception as e:
            logger.error(f"Graph statistics error: {e}")
            self.graph_stats = {}

    def _load_drift_configuration(self):
        """STEP 2.5: Load DRIFT search metadata and configuration.

        Properties of the ``DriftMetadata`` node are merged into the top
        level of ``self.drift_config``; the other metadata nodes are stored
        under their own keys (see ``_DRIFT_SECTIONS``). Like the other setup
        steps this is best-effort: a query error is logged and leaves
        ``self.drift_config`` empty instead of aborting initialization.
        """
        logger.info("Loading search configuration...")

        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.warning("No Neo4j connection for search configuration")
            self.drift_config = {}
            return

        # Query for all DRIFT-related nodes
        drift_metadata_query = """
        OPTIONAL MATCH (dm:DriftMetadata)
        OPTIONAL MATCH (dc:DriftConfiguration)
        OPTIONAL MATCH (csi:CommunitySearchIndex)
        OPTIONAL MATCH (gm:GraphMetadata)
        OPTIONAL MATCH (cm:CommunitiesMetadata)
        RETURN dm, dc, csi, gm, cm
        """

        try:
            result = self.neo4j_conn.execute_query(drift_metadata_query)
        except Exception as e:
            logger.error(f"Search configuration error: {e}")
            self.drift_config = {}
            return

        record = result[0] if result else None
        drift_config = {}

        if record:
            # DriftMetadata properties live at the top level of the config.
            if record.get('dm'):
                drift_config.update(dict(record['dm']))
                logger.info("DriftMetadata node found")

            # Every other metadata node is nested under its own key.
            for column, config_key, label in self._DRIFT_SECTIONS:
                if record.get(column):
                    drift_config[config_key] = dict(record[column])
                    logger.info(f"{label} node found")

        self.drift_config = drift_config

        # OPTIONAL MATCH always yields a row, so "nothing found" has to be
        # detected from the extracted content, not from the row's presence.
        if drift_config:
            logger.info("Search configuration loaded from Neo4j nodes")
        else:
            logger.warning("No metadata nodes found in Neo4j")

    def validate_system_readiness(self):
        """Validate all required components are initialized.

        Returns:
            ``True`` when both the Neo4j driver and the vector query engine
            are available. Missing graph statistics only produce a warning.
        """
        ready = True

        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.error("Neo4j connection not available")
            ready = False

        if not self.query_engine:
            logger.error("Vector query engine not available")
            ready = False

        if not self.graph_stats:
            logger.warning("Graph statistics not loaded")

        if ready:
            logger.info("System readiness validated")

        return ready

    def get_system_status(self):
        """Get detailed system status information.

        Returns:
            A dict of boolean readiness flags plus the current
            ``graph_stats`` and ``drift_config`` dictionaries (returned by
            reference, not copied).
        """
        return {
            "neo4j_connected": bool(self.neo4j_conn and self.neo4j_conn.driver),
            "vector_engine_ready": bool(self.query_engine),
            "graph_stats_loaded": bool(self.graph_stats),
            "drift_config_loaded": bool(self.drift_config),
            "llm_ready": bool(self.llm),
            "graph_stats": self.graph_stats,
            "drift_config": self.drift_config
        }

    async def cleanup_async_tasks(self, timeout: float = 2.0) -> None:
        """Clean up async tasks and pending operations.

        Handles proper cleanup of LiteLLM and other async tasks to prevent
        'Task was destroyed but it is pending!' warnings.

        Uses ``litellm_patch.cleanup_all_async_tasks`` when that module can
        be imported; otherwise cancels every other unfinished task on the
        running loop and waits up to ``timeout`` seconds for them to finish.

        Args:
            timeout: Maximum number of seconds to wait for the cleanup.
        """
        try:
            import asyncio

            # Import cleanup function if available
            try:
                from litellm_patch import cleanup_all_async_tasks
                await cleanup_all_async_tasks(timeout=timeout)
                logger.info(f"Cleaned up async tasks with timeout {timeout}s")
            except ImportError:
                # Fallback: Cancel pending tasks manually.
                # The task running this coroutine must be excluded:
                # cancelling and then awaiting ourselves would abort the
                # cleanup with CancelledError instead of completing it.
                current_task = asyncio.current_task()
                pending_tasks = [
                    task for task in asyncio.all_tasks()
                    if task is not current_task and not task.done()
                ]
                if pending_tasks:
                    logger.info(f"Cancelling {len(pending_tasks)} pending tasks")
                    for task in pending_tasks:
                        task.cancel()

                    # Wait for cancellation with timeout
                    try:
                        await asyncio.wait_for(
                            asyncio.gather(*pending_tasks, return_exceptions=True),
                            timeout=timeout
                        )
                    except asyncio.TimeoutError:
                        logger.warning("Some tasks did not complete within timeout")
        except Exception as e:
            logger.error(f"Error during async cleanup: {e}")

    def close(self):
        """Clean up all connections."""
        if self.neo4j_conn:
            self.neo4j_conn.disconnect()
        logger.info("Setup cleanup complete")
def create_graphrag_setup():
    """Build a ``GraphRAGSetup`` and hand it back to the caller.

    Constructing the object runs its initialization sequence, so the
    returned instance has already attempted every setup step.
    """
    setup = GraphRAGSetup()
    return setup
# Exports: the public API of this module.
__all__ = ['GraphRAGSetup', 'create_graphrag_setup']