niloydebbarma's picture
Upload 8 files
9e5bc69 verified
"""Graph setup module for database and model initialization. Phase A (Steps 1-2)"""
import os
import logging
from typing import Dict, Optional, Any
import sys
sys.path.append('..') # Add parent directory to path for imports
from my_config import MY_CONFIG
from neo4j import GraphDatabase
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings, VectorStoreIndex, StorageContext
from llama_index.vector_stores.milvus import MilvusVectorStore
from llama_index.llms.litellm import LiteLLM
# Set up environment: expose the configured Hugging Face endpoint via HF_ENDPOINT.
# NOTE(review): this assignment runs after the HuggingFace/llama_index imports above;
# if those libraries read HF_ENDPOINT at import time the value may be set too late — confirm.
os.environ['HF_ENDPOINT'] = MY_CONFIG.HF_ENDPOINT
# Configure logging: the root logger is (re)configured at WARNING (force=True replaces
# any root handlers that were already installed), while this module's own logger is set
# to INFO so the setup progress messages below are emitted.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Neo4jConnection:
    """Neo4j database connection manager.

    Reads its settings from ``MY_CONFIG`` and wraps a single driver object
    with connect / disconnect / query helpers.
    """

    def __init__(self):
        """Load and validate the connection settings (no connection is opened).

        Raises:
            ValueError: If the URI, user, password or database name is
                missing or empty in ``MY_CONFIG``.
        """
        self.uri = MY_CONFIG.NEO4J_URI
        self.username = MY_CONFIG.NEO4J_USER
        self.password = MY_CONFIG.NEO4J_PASSWORD
        self.database = getattr(MY_CONFIG, "NEO4J_DATABASE", None)

        # Validate required configuration
        if not self.uri:
            raise ValueError("NEO4J_URI config is required")
        if not self.username:
            # NOTE(review): the message says NEO4J_USERNAME while the attribute
            # read above is MY_CONFIG.NEO4J_USER — confirm which name users set.
            raise ValueError("NEO4J_USERNAME config is required")
        if not self.password:
            raise ValueError("NEO4J_PASSWORD config is required")
        if not self.database:
            raise ValueError("NEO4J_DATABASE config is required")

        # Object returned by GraphDatabase.driver(); None while disconnected.
        self.driver: Optional[Any] = None

    def connect(self):
        """STEP 1.2: Initialize Neo4j driver with verification.

        Does nothing when a driver is already held. Connection errors are
        logged rather than raised; callers detect failure by checking that
        ``self.driver`` is still ``None`` afterwards.
        """
        if self.driver is not None:
            return

        candidate = None
        try:
            candidate = GraphDatabase.driver(
                self.uri,
                auth=(self.username, self.password),
            )
            candidate.verify_connectivity()
        except Exception as e:
            logger.error(f"❌ STEP 1.2 FAILED: Neo4j connection error: {e}")
            # A driver that was created but failed verification must still be
            # closed; otherwise it is dropped without ever being released.
            if candidate is not None:
                try:
                    candidate.close()
                except Exception as close_error:
                    logger.warning(f"Failed to close unverified Neo4j driver: {close_error}")
            self.driver = None
            return

        self.driver = candidate
        logger.info(f"Connected to Neo4j at {self.uri}")

    def disconnect(self):
        """Clean up Neo4j connection (safe to call when not connected)."""
        if self.driver:
            self.driver.close()
            self.driver = None
            logger.info("Neo4j connection closed")

    def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None):
        """Execute a Cypher query against the configured database.

        Args:
            query: Cypher statement to run.
            parameters: Optional query parameters; an empty dict is sent
                when omitted.

        Returns:
            A list with one ``record.data()`` value per result record,
            materialized before the session is closed.

        Raises:
            ConnectionError: If ``connect()`` has not produced a driver.
        """
        if not self.driver:
            raise ConnectionError("Not connected to Neo4j database")

        with self.driver.session(database=self.database) as session:
            result = session.run(query, parameters or {})
            # Consume the result inside the session; it is not usable afterwards.
            return [record.data() for record in result]
class GraphRAGSetup:
    """Main setup class for graph-based retrieval system.

    Handles core initialization and configuration:
    - Database connections (Neo4j and vector database)
    - Model initialization and configuration
    - Graph statistics and validation
    - Search configuration loading

    Every setup step logs its own failure instead of raising, so construction
    always completes; use ``validate_system_readiness()`` or
    ``get_system_status()`` to find out which components are usable.
    """

    def __init__(self):
        """Create the setup object and immediately run the initialization sequence."""
        logger.info("Starting graph system initialization")

        # Initialize core components
        self.config = MY_CONFIG  # Add config attribute for GraphQueryEngine
        self.neo4j_conn = None
        self.query_engine = None
        self.graph_stats = {}
        self.drift_config = {}
        self.llm = None
        self.embedding_model = None

        # Execute Step 1 initialization sequence
        self._execute_step1_sequence()
        logger.info("Graph system initialization complete")

    def _execute_step1_sequence(self):
        """Execute complete Step 1 initialization sequence."""
        # STEP 1.1-1.6: Initialize all components
        self._setup_neo4j()  # STEP 1.2
        self._setup_vector_search()  # STEP 1.3-1.6
        self._load_graph_statistics()  # STEP 2.1-2.4
        self._load_drift_configuration()  # STEP 2.5

    def _setup_neo4j(self):
        """STEP 1.2: Initialize Neo4j driver with verification.

        On any failure the error is logged, the partially opened connection
        is closed, and ``self.neo4j_conn`` is reset to ``None``.
        """
        try:
            logger.info("Initializing Neo4j connection...")
            self.neo4j_conn = Neo4jConnection()
            self.neo4j_conn.connect()

            # Verify connection with test query
            if self.neo4j_conn.driver:
                test_result = self.neo4j_conn.execute_query("MATCH (n) RETURN count(n) as total_nodes LIMIT 1")
                node_count = test_result[0]['total_nodes'] if test_result else 0
                logger.info(f"Neo4j connected - {node_count} nodes found")
        except Exception as e:
            logger.error(f"Neo4j connection error: {e}")
            # Release the driver before dropping the wrapper; otherwise a
            # connection that opened but failed the test query is never closed.
            if self.neo4j_conn is not None:
                try:
                    self.neo4j_conn.disconnect()
                except Exception as close_error:
                    logger.warning(f"Failed to close Neo4j connection: {close_error}")
            self.neo4j_conn = None

    def _setup_vector_search(self):
        """STEP 1.3-1.6: Initialize embedding model, vector database and LLM.

        On any failure the error is logged and ``self.query_engine`` is set
        to ``None``.
        """
        try:
            logger.info("Setting up vector search and LLM...")

            # STEP 1.5: Load embedding model
            self.embedding_model = HuggingFaceEmbedding(
                model_name=MY_CONFIG.EMBEDDING_MODEL
            )
            Settings.embed_model = self.embedding_model
            logger.info(f"Embedding model loaded: {MY_CONFIG.EMBEDDING_MODEL}")

            # STEP 1.6: Connect to vector database based on configuration
            if MY_CONFIG.VECTOR_DB_TYPE == "cloud_zilliz":
                if not MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT or not MY_CONFIG.ZILLIZ_TOKEN:
                    raise ValueError("Cloud database configuration missing. Set ZILLIZ_CLUSTER_ENDPOINT and ZILLIZ_TOKEN in .env")
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT,
                    token=MY_CONFIG.ZILLIZ_TOKEN,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False,
                )
                db_kind = "cloud"
            else:
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.MILVUS_URI_HYBRID_GRAPH,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False,
                )
                db_kind = "local"

            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            logger.info(f"Connected to {db_kind} vector database")

            index = VectorStoreIndex.from_vector_store(
                vector_store=vector_store, storage_context=storage_context)
            logger.info("Vector index loaded successfully")

            # STEP 1.4: Initialize LLM provider
            llm_model = MY_CONFIG.LLM_MODEL
            self.llm = LiteLLM(model=llm_model)
            Settings.llm = self.llm
            logger.info(f"LLM initialized: {llm_model}")

            self.query_engine = index.as_query_engine()
        except Exception as e:
            logger.error(f"Vector setup error: {e}")
            self.query_engine = None

    def _load_graph_statistics(self):
        """STEP 2.1-2.4: Load node, relationship and community counts.

        Leaves ``self.graph_stats`` untouched when there is no Neo4j
        connection and resets it to ``{}`` when the query fails.
        """
        try:
            logger.info("Loading graph statistics and validation...")

            if not self.neo4j_conn or not self.neo4j_conn.driver:
                logger.warning("No Neo4j connection for statistics")
                return

            # STEP 2.1: Get node and relationship counts.
            # Nodes are aggregated first so the relationship match runs against
            # a single row instead of once per node (the combined
            # MATCH (n) OPTIONAL MATCH ()-[r]-() form builds a nodes x
            # relationships product before de-duplicating). The directed
            # pattern visits each relationship exactly once.
            stats_query = """
            MATCH (n)
            WITH count(n) AS node_count,
                 count(DISTINCT n.community_id) AS community_count
            OPTIONAL MATCH ()-[r]->()
            RETURN node_count,
                   count(r) AS relationship_count,
                   community_count
            """

            result = self.neo4j_conn.execute_query(stats_query)
            if result:
                stats = result[0]
                self.graph_stats = {
                    'node_count': stats.get('node_count', 0),
                    'relationship_count': stats.get('relationship_count', 0),
                    'community_count': stats.get('community_count', 0),
                }
                logger.info(f"Graph validated - {self.graph_stats['node_count']} nodes, "
                            f"{self.graph_stats['relationship_count']} relationships, "
                            f"{self.graph_stats['community_count']} communities")
        except Exception as e:
            logger.error(f"Graph statistics error: {e}")
            self.graph_stats = {}

    def _load_drift_configuration(self):
        """STEP 2.5: Load DRIFT search metadata and configuration.

        Populates ``self.drift_config`` from the first row of metadata nodes.
        ``DriftMetadata`` properties are merged at the top level; the other
        node types are stored under their own keys. Like the other loaders,
        a query failure is logged and leaves ``self.drift_config`` empty
        rather than aborting initialization.
        """
        logger.info("Loading search configuration...")

        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.warning("No Neo4j connection for search configuration")
            self.drift_config = {}
            return

        # Query for all DRIFT-related nodes. Only the first row is consumed
        # below, so LIMIT 1 lets the database stop after producing it.
        drift_metadata_query = """
        OPTIONAL MATCH (dm:DriftMetadata)
        OPTIONAL MATCH (dc:DriftConfiguration)
        OPTIONAL MATCH (csi:CommunitySearchIndex)
        OPTIONAL MATCH (gm:GraphMetadata)
        OPTIONAL MATCH (cm:CommunitiesMetadata)
        RETURN dm, dc, csi, gm, cm
        LIMIT 1
        """

        # (result column, key in drift_config, node label used in the log line)
        nested_sections = (
            ('dc', 'configuration', 'DriftConfiguration'),
            ('csi', 'community_search_index', 'CommunitySearchIndex'),
            ('gm', 'graph_metadata', 'GraphMetadata'),
            ('cm', 'communities_metadata', 'CommunitiesMetadata'),
        )

        try:
            result = self.neo4j_conn.execute_query(drift_metadata_query)

            if not (result and result[0]):
                logger.warning("No metadata nodes found in Neo4j")
                self.drift_config = {}
                return

            record = result[0]
            drift_config = {}

            # DriftMetadata properties are merged directly into the top level.
            if record.get('dm'):
                drift_config.update(dict(record['dm']))
                logger.info("DriftMetadata node found")

            # The remaining node types are each stored under a dedicated key.
            for column, config_key, label in nested_sections:
                if record.get(column):
                    drift_config[config_key] = dict(record[column])
                    logger.info(f"{label} node found")

            self.drift_config = drift_config
            logger.info("Search configuration loaded from Neo4j nodes")
        except Exception as e:
            logger.error(f"Search configuration error: {e}")
            self.drift_config = {}

    def validate_system_readiness(self):
        """Validate all required components are initialized.

        Returns:
            ``True`` when both the Neo4j driver and the vector query engine
            are available. Missing graph statistics only produce a warning.
        """
        ready = True

        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.error("Neo4j connection not available")
            ready = False

        if not self.query_engine:
            logger.error("Vector query engine not available")
            ready = False

        if not self.graph_stats:
            logger.warning("Graph statistics not loaded")

        if ready:
            logger.info("System readiness validated")

        return ready

    def get_system_status(self):
        """Get detailed system status information.

        Returns:
            A dict of boolean readiness flags plus the loaded ``graph_stats``
            and ``drift_config`` dictionaries themselves.
        """
        return {
            "neo4j_connected": bool(self.neo4j_conn and self.neo4j_conn.driver),
            "vector_engine_ready": bool(self.query_engine),
            "graph_stats_loaded": bool(self.graph_stats),
            "drift_config_loaded": bool(self.drift_config),
            "llm_ready": bool(self.llm),
            "graph_stats": self.graph_stats,
            "drift_config": self.drift_config,
        }

    async def cleanup_async_tasks(self, timeout: float = 2.0) -> None:
        """Clean up async tasks and pending operations.

        Handles proper cleanup of LiteLLM and other async tasks to prevent
        'Task was destroyed but it is pending!' warnings.

        Uses ``litellm_patch.cleanup_all_async_tasks`` when that module can be
        imported; otherwise cancels every other unfinished task on the
        running loop and waits up to ``timeout`` seconds for them to finish.

        Args:
            timeout: Maximum number of seconds to wait for the cleanup.
        """
        try:
            import asyncio

            # Import cleanup function if available
            try:
                from litellm_patch import cleanup_all_async_tasks
                await cleanup_all_async_tasks(timeout=timeout)
                logger.info(f"Cleaned up async tasks with timeout {timeout}s")
            except ImportError:
                # Fallback: Cancel pending tasks manually.
                # The task running this coroutine is part of all_tasks();
                # cancelling it would make the await below raise
                # CancelledError into our own caller, so it is excluded.
                current = asyncio.current_task()
                pending_tasks = [
                    task for task in asyncio.all_tasks()
                    if task is not current and not task.done()
                ]
                if pending_tasks:
                    logger.info(f"Cancelling {len(pending_tasks)} pending tasks")
                    for task in pending_tasks:
                        task.cancel()

                    # Wait for cancellation with timeout
                    try:
                        await asyncio.wait_for(
                            asyncio.gather(*pending_tasks, return_exceptions=True),
                            timeout=timeout,
                        )
                    except asyncio.TimeoutError:
                        logger.warning("Some tasks did not complete within timeout")
        except Exception as e:
            logger.error(f"Error during async cleanup: {e}")

    def close(self):
        """Clean up all connections."""
        if self.neo4j_conn:
            self.neo4j_conn.disconnect()
        logger.info("Setup cleanup complete")
def create_graphrag_setup():
    """Build a ``GraphRAGSetup`` instance and hand it back to the caller.

    Constructing the instance runs its full initialization sequence.
    """
    setup = GraphRAGSetup()
    return setup
# Exports: names made available by `from <this module> import *`.
# Neo4jConnection is defined above but is not listed here.
__all__ = ['GraphRAGSetup', 'create_graphrag_setup']