Spaces:
Runtime error
Runtime error
| """Graph setup module for database and model initialization. Phase A (Steps 1-2)""" | |
| import os | |
| import logging | |
| from typing import Dict, Optional, Any | |
| import sys | |
| sys.path.append('..') # Add parent directory to path for imports | |
| from my_config import MY_CONFIG | |
| from neo4j import GraphDatabase | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| from llama_index.core import Settings, VectorStoreIndex, StorageContext | |
| from llama_index.vector_stores.milvus import MilvusVectorStore | |
| from llama_index.llms.litellm import LiteLLM | |
# Set up environment: point HuggingFace downloads at the configured endpoint
# (must happen before any HF model is loaded below).
os.environ['HF_ENDPOINT'] = MY_CONFIG.HF_ENDPOINT

# Configure logging: root logger at WARNING (force=True replaces any handlers
# installed by earlier imports), while this module's own logger emits INFO.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Neo4jConnection:
    """
    Neo4j database connection manager.

    Reads URI/credentials/database from MY_CONFIG at construction time and
    validates that all four are present (raising ValueError otherwise).
    The driver itself is created lazily by connect(); callers are expected
    to check ``self.driver`` for None to detect a failed connection.
    """

    def __init__(self):
        self.uri = MY_CONFIG.NEO4J_URI
        self.username = MY_CONFIG.NEO4J_USER
        self.password = MY_CONFIG.NEO4J_PASSWORD
        self.database = getattr(MY_CONFIG, "NEO4J_DATABASE", None)

        # Validate required configuration
        if not self.uri:
            raise ValueError("NEO4J_URI config is required")
        if not self.username:
            raise ValueError("NEO4J_USERNAME config is required")
        if not self.password:
            raise ValueError("NEO4J_PASSWORD config is required")
        if not self.database:
            raise ValueError("NEO4J_DATABASE config is required")

        # FIX: was annotated Optional[GraphDatabase.driver], which names the
        # driver *factory method*, not a type (and attribute annotations are
        # evaluated at runtime). The actual instance type is neo4j.Driver;
        # Optional[Any] keeps the intent without importing it.
        self.driver: Optional[Any] = None

    def connect(self):
        """STEP 1.2: Initialize Neo4j driver with verification.

        Best-effort: on failure the error is logged and ``self.driver``
        stays None rather than raising, matching how callers probe it.
        """
        if self.driver is None:
            try:
                self.driver = GraphDatabase.driver(
                    self.uri,
                    auth=(self.username, self.password)
                )
                # Fail fast here instead of on the first query.
                self.driver.verify_connectivity()
                logger.info(f"Connected to Neo4j at {self.uri}")
            except Exception as e:
                logger.error(f"❌ STEP 1.2 FAILED: Neo4j connection error: {e}")
                self.driver = None

    def disconnect(self):
        """Clean up Neo4j connection (safe to call when not connected)."""
        if self.driver:
            self.driver.close()
            self.driver = None
            logger.info("Neo4j connection closed")

    def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None):
        """Execute a Cypher query and return all records as plain dicts.

        Args:
            query: Cypher query text.
            parameters: Optional query parameters (defaults to {}).

        Raises:
            ConnectionError: if connect() has not succeeded yet.
        """
        if not self.driver:
            raise ConnectionError("Not connected to Neo4j database")

        # Session is scoped to the configured database and closed by ``with``;
        # records are materialized before the session closes.
        with self.driver.session(database=self.database) as session:
            result = session.run(query, parameters or {})
            records = [record.data() for record in result]
            return records
class GraphRAGSetup:
    """
    Main setup class for graph-based retrieval system.

    Handles core initialization and configuration:
    - Database connections (Neo4j and vector database)
    - Model initialization and configuration
    - Graph statistics and validation
    - Search configuration loading

    Initialization is best-effort: each step logs failures and leaves the
    corresponding attribute unset (None / {}) instead of raising, so callers
    should check validate_system_readiness() before using the instance.
    """

    def __init__(self):
        logger.info("Starting graph system initialization")

        # Initialize core components
        self.config = MY_CONFIG  # Add config attribute for GraphQueryEngine
        self.neo4j_conn = None       # Neo4jConnection, or None on failure
        self.query_engine = None     # LlamaIndex query engine over the vector store
        self.graph_stats = {}        # node/relationship/community counts
        self.drift_config = {}       # DRIFT metadata loaded from Neo4j
        self.llm = None
        self.embedding_model = None

        # Execute Step 1 initialization sequence
        self._execute_step1_sequence()
        logger.info("Graph system initialization complete")

    def _execute_step1_sequence(self):
        """Execute complete Step 1-2 initialization sequence, in order."""
        self._setup_neo4j()               # STEP 1.2
        self._setup_vector_search()       # STEP 1.3-1.6
        self._load_graph_statistics()     # STEP 2.1-2.4
        self._load_drift_configuration()  # STEP 2.5

    def _setup_neo4j(self):
        """STEP 1.2: Initialize Neo4j driver with verification.

        On any failure ``self.neo4j_conn`` is reset to None; later steps
        check it and degrade gracefully.
        """
        try:
            logger.info("Initializing Neo4j connection...")
            self.neo4j_conn = Neo4jConnection()
            self.neo4j_conn.connect()

            # Verify connection with test query
            if self.neo4j_conn.driver:
                test_result = self.neo4j_conn.execute_query("MATCH (n) RETURN count(n) as total_nodes LIMIT 1")
                node_count = test_result[0]['total_nodes'] if test_result else 0
                logger.info(f"Neo4j connected - {node_count} nodes found")
        except Exception as e:
            logger.error(f"Neo4j connection error: {e}")
            self.neo4j_conn = None

    def _setup_vector_search(self):
        """STEP 1.3-1.6: Initialize embedding model, vector store and LLM.

        On failure ``self.query_engine`` is reset to None.
        """
        try:
            logger.info("Setting up vector search and LLM...")

            # STEP 1.5: Load embedding model and register it globally so
            # the index built below uses it.
            self.embedding_model = HuggingFaceEmbedding(
                model_name=MY_CONFIG.EMBEDDING_MODEL
            )
            Settings.embed_model = self.embedding_model
            logger.info(f"Embedding model loaded: {MY_CONFIG.EMBEDDING_MODEL}")

            # STEP 1.6: Connect to vector database based on configuration
            if MY_CONFIG.VECTOR_DB_TYPE == "cloud_zilliz":
                if not MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT or not MY_CONFIG.ZILLIZ_TOKEN:
                    raise ValueError("Cloud database configuration missing. Set ZILLIZ_CLUSTER_ENDPOINT and ZILLIZ_TOKEN in .env")
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT,
                    token=MY_CONFIG.ZILLIZ_TOKEN,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False  # never clobber an existing collection
                )
                logger.info("Connected to cloud vector database")
            else:
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.MILVUS_URI_HYBRID_GRAPH,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False
                )
                logger.info("Connected to local vector database")

            # Hoisted out of the if/else above (was duplicated per branch).
            storage_context = StorageContext.from_defaults(vector_store=vector_store)

            index = VectorStoreIndex.from_vector_store(
                vector_store=vector_store, storage_context=storage_context)
            logger.info("Vector index loaded successfully")

            # STEP 1.4: Initialize LLM provider
            llm_model = MY_CONFIG.LLM_MODEL
            self.llm = LiteLLM(model=llm_model)
            Settings.llm = self.llm
            logger.info(f"LLM initialized: {llm_model}")

            self.query_engine = index.as_query_engine()
        except Exception as e:
            logger.error(f"Vector setup error: {e}")
            self.query_engine = None

    def _load_graph_statistics(self):
        """STEP 2.1-2.4: Load and validate graph data structure.

        Populates ``self.graph_stats`` with node/relationship/community
        counts; leaves it {} on error or when Neo4j is unavailable.
        """
        try:
            logger.info("Loading graph statistics and validation...")
            if not self.neo4j_conn or not self.neo4j_conn.driver:
                logger.warning("No Neo4j connection for statistics")
                return

            # STEP 2.1: Get node and relationship counts.
            # NOTE(review): MATCH (n) OPTIONAL MATCH ()-[r]-() forms a cross
            # product; DISTINCT keeps the counts correct but this can be slow
            # on large graphs — consider separate count queries if needed.
            stats_query = """
            MATCH (n)
            OPTIONAL MATCH ()-[r]-()
            RETURN count(DISTINCT n) as node_count,
                   count(DISTINCT r) as relationship_count,
                   count(DISTINCT n.community_id) as community_count
            """
            result = self.neo4j_conn.execute_query(stats_query)
            if result:
                stats = result[0]
                self.graph_stats = {
                    'node_count': stats.get('node_count', 0),
                    'relationship_count': stats.get('relationship_count', 0),
                    'community_count': stats.get('community_count', 0)
                }
                logger.info(f"Graph validated - {self.graph_stats['node_count']} nodes, "
                            f"{self.graph_stats['relationship_count']} relationships, "
                            f"{self.graph_stats['community_count']} communities")
        except Exception as e:
            logger.error(f"Graph statistics error: {e}")
            self.graph_stats = {}

    def _load_drift_configuration(self):
        """STEP 2.5: Load DRIFT search metadata and configuration.

        Populates ``self.drift_config`` from the metadata nodes stored in
        Neo4j; leaves it {} when unavailable.
        """
        logger.info("Loading search configuration...")
        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.warning("No Neo4j connection for search configuration")
            self.drift_config = {}
            return

        # FIX: wrapped in try/except for consistency with the other setup
        # steps — previously a query error here aborted __init__ entirely.
        try:
            # Query for all DRIFT-related nodes
            drift_metadata_query = """
            OPTIONAL MATCH (dm:DriftMetadata)
            OPTIONAL MATCH (dc:DriftConfiguration)
            OPTIONAL MATCH (csi:CommunitySearchIndex)
            OPTIONAL MATCH (gm:GraphMetadata)
            OPTIONAL MATCH (cm:CommunitiesMetadata)
            RETURN dm, dc, csi, gm, cm
            """
            result = self.neo4j_conn.execute_query(drift_metadata_query)

            if result and result[0]:
                record = result[0]
                drift_config = {}

                # (alias, target key in drift_config or None for top-level
                # merge, node label used in the log message)
                sections = [
                    ('dm', None, "DriftMetadata"),
                    ('dc', 'configuration', "DriftConfiguration"),
                    ('csi', 'community_search_index', "CommunitySearchIndex"),
                    ('gm', 'graph_metadata', "GraphMetadata"),
                    ('cm', 'communities_metadata', "CommunitiesMetadata"),
                ]
                for alias, target, label in sections:
                    node = record.get(alias)
                    if node:
                        props = dict(node)
                        if target is None:
                            # DriftMetadata properties merge into the root.
                            drift_config.update(props)
                        else:
                            drift_config[target] = props
                        logger.info(f"{label} node found")

                self.drift_config = drift_config
                logger.info("Search configuration loaded from Neo4j nodes")
            else:
                logger.warning("No metadata nodes found in Neo4j")
                self.drift_config = {}
        except Exception as e:
            logger.error(f"Search configuration error: {e}")
            self.drift_config = {}

    def validate_system_readiness(self):
        """Validate all required components are initialized.

        Returns:
            bool: True when Neo4j and the vector query engine are ready.
                  Missing graph statistics only produce a warning.
        """
        ready = True
        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.error("Neo4j connection not available")
            ready = False
        if not self.query_engine:
            logger.error("Vector query engine not available")
            ready = False
        if not self.graph_stats:
            logger.warning("Graph statistics not loaded")
        if ready:
            logger.info("System readiness validated")
        return ready

    def get_system_status(self):
        """Get detailed system status information as a plain dict."""
        return {
            "neo4j_connected": bool(self.neo4j_conn and self.neo4j_conn.driver),
            "vector_engine_ready": bool(self.query_engine),
            "graph_stats_loaded": bool(self.graph_stats),
            "drift_config_loaded": bool(self.drift_config),
            "llm_ready": bool(self.llm),
            "graph_stats": self.graph_stats,
            "drift_config": self.drift_config
        }

    async def cleanup_async_tasks(self, timeout: float = 2.0) -> None:
        """
        Clean up async tasks and pending operations.

        Handles proper cleanup of LiteLLM and other async tasks to prevent
        'Task was destroyed but it is pending!' warnings.

        Args:
            timeout: Maximum seconds to wait for cancelled tasks to finish.
        """
        try:
            import asyncio

            # Import cleanup function if available
            try:
                from litellm_patch import cleanup_all_async_tasks
                await cleanup_all_async_tasks(timeout=timeout)
                logger.info(f"Cleaned up async tasks with timeout {timeout}s")
            except ImportError:
                # Fallback: cancel pending tasks manually.
                # FIX: exclude the task running this coroutine —
                # asyncio.all_tasks() includes it, and cancelling ourselves
                # would abort the cleanup before the wait below runs.
                current = asyncio.current_task()
                pending_tasks = [
                    task for task in asyncio.all_tasks()
                    if not task.done() and task is not current
                ]
                if pending_tasks:
                    logger.info(f"Cancelling {len(pending_tasks)} pending tasks")
                    for task in pending_tasks:
                        task.cancel()

                    # Wait for cancellation with timeout
                    try:
                        await asyncio.wait_for(
                            asyncio.gather(*pending_tasks, return_exceptions=True),
                            timeout=timeout
                        )
                    except asyncio.TimeoutError:
                        logger.warning("Some tasks did not complete within timeout")
        except Exception as e:
            logger.error(f"Error during async cleanup: {e}")

    def close(self):
        """Clean up all connections (idempotent)."""
        if self.neo4j_conn:
            self.neo4j_conn.disconnect()
        logger.info("Setup cleanup complete")
def create_graphrag_setup():
    """Factory helper: build and return a fully initialized GraphRAGSetup."""
    setup = GraphRAGSetup()
    return setup
# Exports
# NOTE(review): Neo4jConnection is not exported — confirm that is intended.
__all__ = ['GraphRAGSetup', 'create_graphrag_setup']