import os
import gc
import time
import asyncio
import torch
import uuid
from contextlib import contextmanager
from neo4j import GraphDatabase
from pyvis.network import Network
from src.query_processing.late_chunking.late_chunker import LateChunker
from src.query_processing.query_processor import QueryProcessor
from src.reasoning.reasoner import Reasoner
from src.utils.api_key_manager import APIKeyManager
from src.search.search_engine import SearchEngine
from src.crawl.crawler import Crawler, CustomCrawler
from sentence_transformers import SentenceTransformer
from bert_score.scorer import BERTScorer
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any
class Neo4jGraphRAG:
    def __init__(self, num_workers: int = 1):
        """Initialize Neo4j connection and required components."""
        # Neo4j connection setup
        self.neo4j_uri = os.getenv("NEO4J_URI")
        self.neo4j_user = os.getenv("NEO4J_USER")
        self.neo4j_password = os.getenv("NEO4J_PASSWORD")
        self.driver = GraphDatabase.driver(
            self.neo4j_uri,
            auth=(self.neo4j_user, self.neo4j_password)
        )

        # Component initialization
        self.num_workers = num_workers
        self.search_engine = SearchEngine()
        self.query_processor = QueryProcessor()
        self.reasoner = Reasoner()
        # self.crawler = Crawler(verbose=True)
        self.custom_crawler = CustomCrawler(max_concurrent_requests=1000)
        self.chunking = LateChunker()
        self.llm = APIKeyManager().get_llm()

        # Model initialization
        self.model = SentenceTransformer(
            "dunzhang/stella_en_400M_v5",
            trust_remote_code=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
        self.scorer = BERTScorer(
            model_type="roberta-base",
            lang="en",
            rescale_with_baseline=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )

        # Counters and tracking.
        # Node IDs follow a fixed scheme: "QR" is the root query node,
        # "SQ<n>" are sub-query nodes, and "SSQ<n>" are sub-sub-query nodes.
        self.root_node_id = "QR"
        self.node_counter = 0
        self.sub_node_counter = 0
        self.cross_connections = set()

        # Graph tracking
        self.current_graph_id = None

        # Thread pool
        self.executor = ThreadPoolExecutor(max_workers=self.num_workers)

        # Callback used to emit events to interested consumers
        self.on_event_callback = None
    def set_on_event_callback(self, callback):
        """Register a single callback to be triggered for various event types."""
        self.on_event_callback = callback

    async def emit_event(self, event_type: str, data: dict):
        """Helper method to safely emit an event if a callback is registered."""
        if self.on_event_callback:
            # The callback signature is callback(event_type, data);
            # it may be either asynchronous or synchronous.
            if asyncio.iscoroutinefunction(self.on_event_callback):
                return await self.on_event_callback(event_type, data)
            else:
                return self.on_event_callback(event_type, data)
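
    # Illustrative usage (a sketch, not part of the original file): a caller
    # could register either a sync or an async callback, e.g.
    #
    #     async def on_event(event_type: str, data: dict):
    #         print(f"[{event_type}] {data}")
    #
    #     rag = Neo4jGraphRAG()
    #     rag.set_on_event_callback(on_event)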
    @contextmanager
    def transaction(self, max_retries: int = 1):
        """Synchronous context manager for Neo4j transactions.

        Note: as a generator-based context manager, the body of the `with`
        block runs only once; the retry loop applies to the commit attempt.
        """
        session = self.driver.session()
        retry_count = 0
        try:
            while True:
                try:
                    tx = session.begin_transaction()
                    try:
                        yield tx
                        tx.commit()
                        break
                    except Exception as e:
                        tx.rollback()
                        raise e
                except Exception as e:
                    retry_count += 1
                    if retry_count >= max_retries:
                        print(f"Transaction failed after {max_retries} attempts: {str(e)}")
                        raise e
                    print(f"Transaction failed, retrying ({retry_count}/{max_retries}): {str(e)}")
                    time.sleep(1)  # Use regular sleep for sync context
        finally:
            session.close()
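
    # Usage sketch: every method below acquires a transaction like
    #
    #     with self.transaction() as tx:
    #         tx.run("MATCH (n:Node) RETURN count(n)")
    #
    # committing on a clean exit and rolling back on error.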
    def initialize_schema(self):
        """Check and initialize database schema."""
        constraint_node_id_per_graph = None
        index_node_graph_id = None
        index_node_role = None
        constraint_graph_id = None
        index_graph_created = None
        constraint_node_graph = None
        try:
            with self.transaction() as tx:
                # Check if schema already exists by looking for our composite constraint
                constraint_node_id_per_graph = tx.run("""
                    SHOW CONSTRAINTS
                    WHERE name = 'constraint_node_id_per_graph'
                """).data()
                index_node_role = tx.run("""
                    SHOW INDEXES
                    WHERE name = 'index_node_role'
                """).data()
                index_node_graph_id = tx.run("""
                    SHOW INDEXES
                    WHERE name = 'index_node_graph_id'
                """).data()
                constraint_graph_id = tx.run("""
                    SHOW CONSTRAINTS
                    WHERE name = 'constraint_graph_id'
                """).data()
                index_graph_created = tx.run("""
                    SHOW INDEXES
                    WHERE name = 'index_graph_created'
                """).data()
                constraint_node_graph = tx.run("""
                    SHOW CONSTRAINTS
                    WHERE name = 'constraint_node_graph'
                """).data()

                if constraint_node_id_per_graph and index_node_role and \
                   index_node_graph_id and constraint_graph_id and \
                   index_graph_created and constraint_node_graph:
                    print("Database schema already initialized")
                    return

                print("Initializing database schema...")

                # Create composite constraint for node ID uniqueness within each graph
                if not constraint_node_id_per_graph:
                    tx.run("""
                        CREATE CONSTRAINT constraint_node_id_per_graph IF NOT EXISTS
                        FOR (n:Node)
                        REQUIRE (n.id, n.graph_id) IS UNIQUE
                    """)
                if not index_node_role:
                    tx.run("""
                        CREATE INDEX index_node_role IF NOT EXISTS FOR (n:Node)
                        ON (n.role)
                    """)
                if not index_node_graph_id:
                    tx.run("""
                        CREATE INDEX index_node_graph_id IF NOT EXISTS FOR (n:Node)
                        ON (n.graph_id)
                    """)

                # Graph management constraints
                if not constraint_graph_id:
                    tx.run("""
                        CREATE CONSTRAINT constraint_graph_id IF NOT EXISTS
                        FOR (g:Graph)
                        REQUIRE g.id IS UNIQUE
                    """)
                if not index_graph_created:
                    tx.run("""
                        CREATE INDEX index_graph_created IF NOT EXISTS FOR (g:Graph)
                        ON (g.created)
                    """)
                if not constraint_node_graph:
                    tx.run("""
                        CREATE CONSTRAINT constraint_node_graph IF NOT EXISTS
                        FOR (n:Node)
                        REQUIRE n.graph_id IS NOT NULL
                    """)
                print("Database schema initialization complete")
        except Exception as e:
            print(f"Error ensuring schema exists: {str(e)}")
            raise
    def add_node(self, node_id: str, query: str, data: str = "", role: str = None):
        """Add a node to the current graph."""
        if self.current_graph_id is None:
            raise Exception("Error: No current graph selected")
        try:
            with self.transaction() as tx:
                # Generate embedding
                embedding = self.model.encode(query).tolist()

                # Create node with properties including embedding and graph ID
                tx.run(
                    """
                    MERGE (n:Node {id: $node_id, graph_id: $graph_id})
                    SET n.query = $node_query,
                        n.embedding = $embedding,
                        n.data = $data,
                        n.role = $role
                    """,
                    node_id=node_id,
                    graph_id=self.current_graph_id,
                    node_query=query,
                    embedding=embedding,
                    data=data,
                    role=role
                )
                print(f"Added node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}'")
        except Exception as e:
            print(f"Error adding node '{node_id}' to graph '{self.current_graph_id}' with role '{role}' and query: '{query}': {str(e)}")
            raise
    def add_edge(self, node1: str, node2: str, weight: float = 1.0, relationship_type: str = None):
        """Add an edge between two nodes in a way that preserves a DAG structure in the graph."""
        if self.current_graph_id is None:
            raise Exception("Error: No current graph selected")

        # 1) Prevent self loops
        if node1 == node2:
            print(f"Cannot add edge to the same node {node1}!")
            return
        try:
            with self.transaction() as tx:
                # 2) Check if there is already a path from node2 back to node1
                check_path = tx.run(
                    """
                    MATCH (start:Node {id: $node2, graph_id: $graph_id})
                    MATCH (end:Node {id: $node1, graph_id: $graph_id})
                    // If there's any path of length >= 0 from 'start' to 'end',
                    // then creating (end)->(start) would introduce a cycle.
                    WHERE (start)-[:RELATION*0..]->(end)
                    RETURN COUNT(start) AS pathExists
                    """,
                    node1=node1,
                    node2=node2,
                    graph_id=self.current_graph_id
                )
                path_count = check_path.single()["pathExists"]
                if path_count > 0:
                    print(f"Skipping edge {node1} -> {node2}: it would introduce a cycle!")
                    return

                # 3) Otherwise, it is safe to create a new directed edge
                tx.run(
                    """
                    MATCH (a:Node {id: $node1, graph_id: $graph_id})
                    MATCH (b:Node {id: $node2, graph_id: $graph_id})
                    MERGE (a)-[r:RELATION {type: $rel_type}]->(b)
                    SET r.weight = $weight
                    """,
                    node1=node1,
                    node2=node2,
                    graph_id=self.current_graph_id,
                    rel_type=relationship_type,
                    weight=weight
                )
                print(
                    f"Added edge between '{node1}' and '{node2}' in graph "
                    f"'{self.current_graph_id}' (type='{relationship_type}', weight={weight})"
                )
        except Exception as e:
            print(f"Error adding edge between '{node1}' and '{node2}': {str(e)}")
            raise
    def edge_exists(self, node1: str, node2: str) -> bool:
        """Check if an edge exists between two nodes."""
        try:
            with self.transaction() as tx:
                result = tx.run(
                    """
                    MATCH (a:Node {id: $node1})-[r:RELATION]-(b:Node {id: $node2})
                    RETURN COUNT(r) as count
                    """,
                    node1=node1,
                    node2=node2
                )
                return result.single()["count"] > 0
        except Exception as e:
            print(f"Error checking edge existence between {node1} and {node2}: {str(e)}")
            raise

    def graph_exists(self) -> bool:
        """Check if a graph exists in Neo4j."""
        try:
            with self.transaction() as tx:
                result = tx.run("""
                    MATCH (n:Node)
                    RETURN count(n) > 0 as has_nodes
                """)
                return result.single()["has_nodes"]
        except Exception as e:
            print(f"Error checking graph existence: {str(e)}")
            raise
    def get_graphs(self) -> list:
        """Get detailed information about all existing graphs and their nodes."""
        try:
            with self.transaction() as tx:
                # Match nodes and their outgoing relationships separately so that
                # nodes without outgoing edges are still included in the count.
                result = tx.run(
                    """
                    MATCH (g:Graph)
                    OPTIONAL MATCH (n:Node {graph_id: g.id})
                    OPTIONAL MATCH (n)-[r:RELATION]->(:Node)
                    WITH g, collect(DISTINCT n) AS nodes, collect(DISTINCT r) AS rels
                    RETURN {
                        graph_id: g.id,
                        created: g.created,
                        updated: g.updated,
                        node_count: size(nodes),
                        edge_count: size(rels),
                        nodes: [node IN nodes | {
                            id: node.id,
                            query: node.query,
                            data: node.data,
                            role: node.role,
                            pagerank: node.pagerank
                        }]
                    } as graph_info
                    ORDER BY g.created DESC
                    """
                )
                return list(result)
        except Exception as e:
            print(f"Error getting graphs: {str(e)}")
            raise
    def select_graph(self, graph_id: str) -> bool:
        """Select a specific graph as the current working graph."""
        try:
            with self.transaction() as tx:
                result = tx.run("""
                    MATCH (g:Graph {id: $graph_id})
                    RETURN g
                """, graph_id=graph_id)
                if result.single():
                    self.current_graph_id = graph_id
                    return True
                return False
        except Exception as e:
            print(f"Error selecting graph: {str(e)}")
            raise
    def create_new_graph(self) -> str:
        """Create a new graph instance and return its ID."""
        try:
            with self.transaction() as tx:
                graph_id = str(uuid.uuid4())
                tx.run("""
                    CREATE (g:Graph {
                        id: $graph_id,
                        created: datetime(),
                        updated: datetime()
                    })
                """, graph_id=graph_id)
                self.current_graph_id = graph_id
            return graph_id
        except Exception as e:
            print(f"Error creating new graph: {str(e)}")
            raise
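
    # Illustrative usage (sketch): `graph_id = rag.create_new_graph()` on a
    # first run, then `rag.select_graph(graph_id)` later to resume that graph.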
    def load_graph(self, node_id: str) -> bool:
        """Load an existing graph structure from Neo4j based on node ID."""
        # Helper function to safely extract the numeric suffix from a node ID
        def extract_number(nid: str) -> int:
            try:
                # Extract all digits from the string
                num_str = ''.join(filter(str.isdigit, nid))
                return int(num_str) if num_str else 0
            except ValueError:
                print(f"Warning: Could not extract number from node ID: {nid}")
                return 0

        try:
            with self.driver.session() as session:
                # Start transaction
                tx = session.begin_transaction()
                try:
                    # Get all related nodes and relationships
                    result = tx.run("""
                        MATCH path = (n:Node)-[r:RELATION*0..]->(m:Node)
                        WHERE n.id = $node_id
                        RETURN DISTINCT n, r, m,
                               length(path) as depth,
                               [rel in r | type(rel)] as rel_types,
                               [rel in r | rel.weight] as weights
                    """, node_id=node_id)

                    # Reset internal state
                    self.node_counter = 0
                    self.sub_node_counter = 0
                    self.cross_connections.clear()

                    # Track processed nodes to avoid duplicates
                    processed_nodes = set()

                    # Update counters based on node ID patterns. Check the "SSQ"
                    # prefix before "SQ", since "SQ" is a substring of "SSQ".
                    for record in result:
                        for record_node in (record["n"], record["m"]):
                            current_id = record_node["id"]
                            if current_id in processed_nodes:
                                continue
                            if current_id.startswith("SSQ"):
                                current_num = extract_number(current_id)
                                self.sub_node_counter = max(self.sub_node_counter, current_num)
                            elif current_id.startswith("SQ"):
                                current_num = extract_number(current_id)
                                self.node_counter = max(self.node_counter, current_num)
                            processed_nodes.add(current_id)

                    # Increment counters for next use
                    self.node_counter += 1
                    self.sub_node_counter += 1

                    # Track cross-connections
                    result = tx.run("""
                        MATCH (n:Node)-[r:RELATION]->(m:Node)
                        WHERE r.type = 'logical'
                        RETURN n.id as source, m.id as target
                    """)
                    for record in result:
                        connection = tuple(sorted([record["source"], record["target"]]))
                        self.cross_connections.add(connection)

                    tx.commit()
                    print(f"Successfully loaded graph. Current counters - Node: {self.node_counter}, Sub: {self.sub_node_counter}")
                    return True
                except Exception as e:
                    tx.rollback()
                    print(f"Transaction error while loading graph: {str(e)}")
                    return False
        except Exception as e:
            print(f"Error loading graph: {str(e)}")
            return False
    async def modify_graph(self, new_query: str, similar_node_id: str, session_id: str = None):
        """Modify an existing graph structure by integrating a new query."""
        # Inner function to add a new node as a sibling
        async def add_as_sibling(node_id: str, query: str):
            with self.transaction() as tx:
                result = tx.run("""
                    MATCH (n:Node)<-[r:RELATION]-(parent:Node)
                    WHERE n.id = $node_id
                    RETURN parent.id as parent_id,
                           parent.query as parent_query,
                           r.type as rel_type
                """, node_id=node_id)
                parent_data = result.single()
            if not parent_data:
                raise ValueError(f"No parent found for node {node_id}")

            # A sibling lives at the same level as the existing node.
            # Check the "SSQ" prefix first, since "SQ" is a substring of "SSQ".
            if node_id.startswith("SSQ"):
                self.sub_node_counter += 1
                new_node_id = f"SSQ{self.sub_node_counter}"
            else:
                self.node_counter += 1
                new_node_id = f"SQ{self.node_counter}"
            self.add_node(
                node_id=new_node_id,
                query=query,
                role="independent"
            )
            self.add_edge(
                parent_data["parent_id"],
                new_node_id,
                relationship_type=parent_data["rel_type"]
            )
            return new_node_id

        # Inner function to add a new node as a child
        async def add_as_child(node_id: str, query: str):
            # Children of the root are sub-queries (SQ); children of SQ/SSQ
            # nodes are sub-sub-queries (SSQ).
            if node_id == self.root_node_id:
                self.node_counter += 1
                new_node_id = f"SQ{self.node_counter}"
            else:
                self.sub_node_counter += 1
                new_node_id = f"SSQ{self.sub_node_counter}"
            self.add_node(
                node_id=new_node_id,
                query=query,
                role="dependent"
            )
            self.add_edge(
                node_id,
                new_node_id,
                relationship_type="logical"
            )
            return new_node_id
        # Inner function to collect context from existing graph nodes
        def collect_graph_context() -> list:
            try:
                with self.transaction() as tx:
                    # Get all nodes except root, ordered by depth and ID to maintain hierarchy
                    result = tx.run("""
                        MATCH (n:Node)
                        WHERE n.id <> $root_id AND n.graph_id = $graph_id
                        WITH n
                        ORDER BY
                            CASE
                                WHEN n.id STARTS WITH 'SQ' THEN 1
                                WHEN n.id STARTS WITH 'SSQ' THEN 2
                                ELSE 3
                            END,
                            n.id
                        RETURN COLLECT({
                            id: n.id,
                            query: n.query,
                            role: n.role
                        }) as nodes
                    """, root_id=self.root_node_id, graph_id=self.current_graph_id)
                    nodes = result.single()["nodes"]
                    if not nodes:
                        return []

                    # Group nodes by hierarchy level
                    level_queries = {}
                    current_sq = None
                    for node in nodes:
                        node_id = node["id"]
                        if node_id.startswith("SQ"):
                            current_sq = node_id
                            if current_sq not in level_queries:
                                level_queries[current_sq] = {
                                    "originalquery": node["query"],
                                    "subqueries": []
                                }
                            # Add the SQ node itself as a sub-query
                            level_queries[current_sq]["subqueries"].append({
                                "subquery": node["query"],
                                "role": node["role"],
                                "dependson": []  # Dependencies will be added below
                            })
                        elif node_id.startswith("SSQ") and current_sq:
                            level_queries[current_sq]["subqueries"].append({
                                "subquery": node["query"],
                                "role": node["role"],
                                "dependson": []  # Dependencies will be added below
                            })

                    # Add dependency information
                    for sq_id, query_data in level_queries.items():
                        for i, sub_query in enumerate(query_data["subqueries"]):
                            # Get dependencies for this sub_query
                            deps = tx.run("""
                                MATCH (n:Node {query: $node_query})-[r:RELATION {type: 'logical'}]->(m:Node)
                                WHERE n.graph_id = $graph_id
                                RETURN COLLECT(m.query) as dependencies
                            """, node_query=sub_query["subquery"], graph_id=self.current_graph_id)
                            dep_queries = deps.single()["dependencies"]
                            if dep_queries:
                                # Find indices of dependent queries
                                curr_deps = []
                                prev_deps = []
                                for dep_query in dep_queries:
                                    # Check current level dependencies
                                    curr_idx = next(
                                        (idx for idx, sq in enumerate(query_data["subqueries"])
                                         if sq["subquery"] == dep_query),
                                        None
                                    )
                                    if curr_idx is not None:
                                        curr_deps.append(curr_idx)
                                    else:
                                        # Check previous level dependencies
                                        for prev_idx, prev_data in enumerate(level_queries.values()):
                                            if dep_query in [sq["subquery"] for sq in prev_data["subqueries"]]:
                                                prev_deps.append(prev_idx)
                                                break
                                query_data["subqueries"][i]["dependson"] = [prev_deps, curr_deps]

                    # Convert to list maintaining order
                    return list(level_queries.values())
            except Exception as e:
                print(f"Error collecting graph context: {str(e)}")
                raise
        try:
            # Get the role and other metadata of the similar node
            with self.transaction() as tx:
                result = tx.run("""
                    MATCH (n:Node {id: $node_id})
                    RETURN n.role as role,
                           n.query as query,
                           EXISTS((n)<-[:RELATION]-()) as has_parent
                """, node_id=similar_node_id)
                node_data = result.single()
            if not node_data:
                raise Exception(f"Node {similar_node_id} not found")

            # Collect context from the existing graph
            context = collect_graph_context()

            # Determine modification strategy
            if node_data["role"] == "independent":
                # Add as sibling if the node has a parent, else as child
                if node_data["has_parent"]:
                    new_node_id = await add_as_sibling(similar_node_id, new_query)
                else:
                    new_node_id = await add_as_child(similar_node_id, new_query)
            else:
                # Add as child for dependent or pre-requisite nodes
                new_node_id = await add_as_child(similar_node_id, new_query)

            # Recursively build a subgraph for the new node if needed.
            # New SQ nodes continue from depth 1, new SSQ nodes from depth 2
            # (check "SSQ" explicitly, since "SQ" is a substring of "SSQ").
            await self.build_graph(
                query=new_query,
                parent_node_id=new_node_id,
                depth=2 if new_node_id.startswith("SSQ") else 1,
                context=context,  # Pass the collected context
                session_id=session_id
            )
        except Exception as e:
            print(f"Error modifying graph: {str(e)}")
            raise
    async def build_graph(self, query: str, data: str = None, parent_node_id: str = None,
                          depth: int = 0, threshold: float = 0.8, recurse: bool = True,
                          context: list = None, session_id: str = None, max_tokens_allowed: int = 128000):
        """Build a new graph structure in Neo4j."""
        async def process_node(self, node_id: str, sub_query: str,
                               session_id: str, future: asyncio.Future,
                               depth=depth, max_tokens_allowed=max_tokens_allowed):
            """Process a node asynchronously."""
            try:
                # Generate an optimized search query
                optimized_query = await self.search_engine.generate_optimized_query(sub_query)

                # Search for the sub-query
                results = await self.search_engine.search(
                    query=optimized_query,
                    num_results=10,
                    exclude_filetypes=["pdf"]
                )
                # Emit event with the raw results
                await self.emit_event("search_results_fetched", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "optimized_query": optimized_query,
                    "search_results": results
                })

                # Filter the URLs based on the query
                filtered_urls = await self.search_engine.filter_urls(
                    sub_query,
                    "ultra",
                    results
                )
                # Emit an event with the filtered URLs
                await self.emit_event("search_results_filtered", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "filtered_urls": filtered_urls
                })

                # Get the URLs
                urls = [result.get('link', 'No URL') for result in filtered_urls]

                # Fetch URL contents
                search_contents = await self.custom_crawler.fetch_page_contents(
                    urls,
                    sub_query,
                    session_id=session_id,
                    max_attempts=1,
                    timeout=30
                )
                # Emit an event with the fetched contents
                await self.emit_event("search_contents_fetched", {
                    "node_id": node_id,
                    "sub_query": sub_query,
                    "contents": search_contents
                })

                # Format the contents
                contents = ""
                for k, content in enumerate(search_contents, 1):
                    if isinstance(content, Exception):
                        print(f"Error fetching content: {content}")
                    elif content:
                        contents += f"Document {k}:\n{content}\n\n"

                if len(contents.strip()) > 0:
                    if depth == 0:
                        # Emit an event to indicate the completion of sub-query processing
                        await self.emit_event("sub_query_processed", {
                            "node_id": node_id,
                            "sub_query": sub_query,
                            "contents": contents
                        })

                    # Chunk the contents if they exceed the token limit
                    token_count = self.llm.get_num_tokens(contents)
                    if token_count > max_tokens_allowed:
                        contents = await self.chunking.chunker(
                            text=contents,
                            query=sub_query,
                            max_tokens=max_tokens_allowed
                        )
                        print(f"Number of tokens before chunking: {token_count}")
                        print(f"Number of tokens after chunking: {self.llm.get_num_tokens(contents)}")
                else:
                    if depth == 0:
                        # Emit an event to indicate the failure of sub-query processing
                        await self.emit_event("sub_query_failed", {
                            "node_id": node_id,
                            "sub_query": sub_query,
                            "contents": contents
                        })

                # Update the node with its data atomically
                with self.transaction() as tx:
                    tx.run(
                        """
                        MATCH (n:Node {id: $node_id})
                        SET n.data = $data
                        """,
                        node_id=node_id,
                        data=contents
                    )
                # Set the result in the future
                future.set_result(contents)
            except Exception as e:
                print(f"Error processing node {node_id}: {str(e)}")
                if depth == 0:
                    await self.emit_event("sub_query_failed", {
                        "node_id": node_id,
                        "sub_query": sub_query
                    })
                future.set_exception(e)
                raise
        async def process_dependent_node(self, node_id: str, sub_query: str, depth, dep_futures: list, future):
            """Process a dependent node asynchronously."""
            try:
                loop = asyncio.get_running_loop()

                # Wait for dependencies
                dep_data = [await f for f in dep_futures]

                # Modify the query based on its dependencies
                modified_query = await self.query_processor.modify_query(
                    sub_query,
                    dep_data
                )

                # Generate a new embedding for the modified query
                embedding = await loop.run_in_executor(
                    self.executor,
                    self.model.encode,
                    modified_query
                )

                # Update node query and embedding atomically
                with self.transaction() as tx:
                    tx.run(
                        """
                        MATCH (n:Node {id: $node_id})
                        SET n.query = $modified_query,
                            n.embedding = $embedding
                        """,
                        node_id=node_id,
                        modified_query=modified_query,
                        embedding=embedding.tolist()
                    )

                # Process the modified node
                try:
                    if not future.done():
                        await process_node(
                            self, node_id, modified_query, session_id, future, depth, max_tokens_allowed
                        )
                except Exception as e:
                    if depth == 0:
                        await self.emit_event("sub_query_failed", {
                            "node_id": node_id,
                            "sub_query": sub_query
                        })
                    if not future.done():
                        future.set_exception(e)
                    raise
            except Exception as e:
                print(f"Error processing dependent node {node_id}: {str(e)}")
                if depth == 0:
                    await self.emit_event("sub_query_failed", {
                        "node_id": node_id,
                        "sub_query": sub_query
                    })
                if not future.done():
                    future.set_exception(e)
                raise
        def create_cross_connections(self, node_id=None, depth=None, role=None):
            """Create cross connections based on dependencies."""
            try:
                # Get all logical relationships
                relationships = self.get_node_relationships(
                    node_id=node_id,
                    depth=depth,
                    role=role,
                    relationship_type='logical'
                )
                for current_node_id, edges in relationships.items():
                    # Get the node role
                    with self.transaction() as tx:
                        result = tx.run(
                            "MATCH (n:Node {id: $node_id}) RETURN n.role as role",
                            node_id=current_node_id
                        )
                        node_data = result.single()
                    if not node_data or not node_data["role"]:
                        continue
                    node_role = node_data["role"].lower()

                    # Only process dependent nodes
                    if node_role == 'dependent':
                        # Process incoming edges (dependencies)
                        for source_id, target_id, edge_data in edges['in_edges']:
                            if not source_id or source_id == self.root_node_id:
                                continue
                            # Create connection key
                            connection = tuple(sorted([current_node_id, source_id]))
                            # Add cross-connection if it does not already exist
                            if connection not in self.cross_connections:
                                if not self.edge_exists(source_id, current_node_id):
                                    print(f"Adding cross-connection edge between {source_id} and {current_node_id}")
                                    self.add_edge(
                                        source_id,
                                        current_node_id,
                                        weight=edge_data.get('weight', 1.0),
                                        relationship_type='logical'
                                    )
                                    self.cross_connections.add(connection)

                        # Process outgoing edges (children)
                        for source_id, target_id, edge_data in edges['out_edges']:
                            if not target_id or target_id == self.root_node_id:
                                continue
                            # Create connection key
                            connection = tuple(sorted([current_node_id, target_id]))
                            # Add cross-connection if it does not already exist
                            if connection not in self.cross_connections:
                                if not self.edge_exists(current_node_id, target_id):
                                    print(f"Adding cross-connection edge between {current_node_id} and {target_id}")
                                    self.add_edge(
                                        current_node_id,
                                        target_id,
                                        weight=edge_data.get('weight', 1.0),
                                        relationship_type='logical'
                                    )
                                    self.cross_connections.add(connection)
            except Exception as e:
                print(f"Error creating cross connections: {str(e)}")
                raise
        # Main build_graph implementation
        # Limit recursion depth
        if depth > 1:
            return

        # Initialize context if not provided
        if context is None:
            context = []

        # Dictionary to keep track of node data and their futures
        node_data_futures = {}

        if parent_node_id is None:
            # If no parent node, this is the root (original query)
            self.add_node(self.root_node_id, query, data)
            parent_node_id = self.root_node_id

        # Get the query intent
        intent = await self.query_processor.get_query_intent(query)
        if depth == 0:
            # Decompose the query into sub-queries
            response_data, sub_queries, roles, dependencies = \
                await self.query_processor.decompose_query_with_dependencies(query, intent)
        else:
            # Decompose the sub-query into sub-sub-queries with past context
            response_data, sub_queries, roles, dependencies = \
                await self.query_processor.decompose_query_with_dependencies(
                    query,
                    intent,
                    context
                )

        # Add current query data to the context for the next iteration
        if response_data:
            context.append(response_data)

        # If no further decomposition is possible, sub_queries will contain only the original query
        if len(sub_queries) > 1 and sub_queries[0] != query:
            sub_query_ids = []
            pre_req_nodes = {}

            # Create the structure (nodes and edges) of the graph at the current level
            for idx, (sub_query, role, dependency) in enumerate(zip(sub_queries, roles, dependencies)):
                # If this is the sub-queries level,
                # fire the event, letting the callback know about the sub-query
                if depth == 0:
                    await self.emit_event(
                        "sub_query_created",
                        {
                            "depth": depth,
                            "sub_query": sub_query,
                            "role": role,
                            "dependency": dependency,
                            "parent_node_id": parent_node_id,
                        }
                    )

                # Generate a unique ID for the sub-query
                if depth == 0:
                    self.node_counter += 1
                    sub_node_id = f"SQ{self.node_counter}"
                else:
                    self.sub_node_counter += 1
                    sub_node_id = f"SSQ{self.sub_node_counter}"

                # Add the node ID to the list of sub-query IDs
                sub_query_ids.append(sub_node_id)

                # Add the node to the graph, but without any data yet
                self.add_node(node_id=sub_node_id, query=sub_query, role=role)

                # Create a future for the node
                future = asyncio.Future()
                node_data_futures[sub_node_id] = future

                if role.lower() in ('pre-requisite', 'prerequisite'):
                    pre_req_nodes[idx] = sub_node_id

                # Determine how to add edges based on the role
                if role.lower() in ('pre-requisite', 'prerequisite', 'independent'):
                    # Pre-requisite and independent nodes connect directly to the parent
                    self.add_edge(parent_node_id, sub_node_id, relationship_type='hierarchical')
                elif role.lower() == 'dependent':
                    if isinstance(dependency, list) and (
                        len(dependency) == 2 and all(isinstance(d, list) for d in dependency)
                    ):
                        print(f"Dependency: {dependency}")
                        # Dependencies are [previous-level, current-level] index lists
                        prev_deps, current_deps = dependency

                        # Handle previous query dependencies
                        if context and prev_deps not in [None, []]:
                            for dep_idx in prev_deps:
                                if dep_idx is not None:
                                    # Find the corresponding context data
                                    for context_data in context:
                                        if context_data and 'subqueries' in context_data:
                                            if dep_idx < len(context_data['subqueries']):
                                                # Get the query from context
                                                sub_query_data = context_data['subqueries'][dep_idx]
                                                if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data:
                                                    dep_query = sub_query_data['subquery']
                                                    # Find matching nodes
                                                    matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                                    # Get the best matching node ID and score
                                                    if matching_nodes not in [None, []]:
                                                        dep_node_id = matching_nodes[0].get('node_id')
                                                        score = matching_nodes[0].get('score', 0)
                                                        if score >= 0.9:
                                                            self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')

                        # Add edges from current query dependencies
                        if current_deps not in [None, []]:
                            for dep_idx in current_deps:
                                if dep_idx < len(sub_queries):
                                    dep_node_id = sub_query_ids[dep_idx]
                                    self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
                                else:
                                    # Dependency index is out of range
                                    raise ValueError(f"Invalid dependency index: {dep_idx}")
                    elif len(dependency) > 0:
                        for dep_idx in dependency:
                            if dep_idx < len(sub_queries):
                                # Get the node ID of the dependency
                                dep_node_id = sub_query_ids[dep_idx]
                                # Add an edge from the dependency to the current sub-query
                                self.add_edge(dep_node_id, sub_node_id, relationship_type='logical')
                            else:
                                raise ValueError(f"Invalid dependency index: {dep_idx}")
                    else:
                        # Dependency is incorrect or empty
                        raise ValueError(f"Invalid dependency: {dependency}")
                else:
                    # Handle any unexpected roles
                    raise ValueError(f"Unexpected role: {role}")
            # Proceed to process the nodes
            tasks = []

            # Process pre-requisite and independent nodes concurrently
            for idx in range(len(sub_queries)):
                node_id = sub_query_ids[idx]
                future = node_data_futures[node_id]
                if roles[idx].lower() in ('pre-requisite', 'prerequisite', 'independent'):
                    tasks.append(process_node(
                        self, node_id, sub_queries[idx], session_id, future, depth, max_tokens_allowed
                    ))

            # Process dependent nodes as soon as their dependencies are ready
            for idx in range(len(sub_queries)):
                node_id = sub_query_ids[idx]
                future = node_data_futures[node_id]
                if roles[idx].lower() == 'dependent':
                    dep_futures = []
                    if isinstance(dependencies[idx], list) and len(dependencies[idx]) == 2:
                        prev_deps, current_deps = dependencies[idx]

                        # Get futures from previous context dependencies
                        if context and prev_deps not in [None, []]:
                            for context_idx, context_data in enumerate(context):
                                # If prev_deps is a list, process the corresponding dependency
                                if isinstance(prev_deps, list) and context_idx < len(prev_deps):
                                    context_dep = prev_deps[context_idx]
                                    if context_dep is not None:
                                        if context_data and 'subqueries' in context_data:
                                            if context_dep < len(context_data['subqueries']):
                                                sub_query_data = context_data['subqueries'][context_dep]
                                                if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data:
                                                    dep_query = sub_query_data['subquery']
                                                    # Find matching nodes
                                                    matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                                    if matching_nodes not in [None, []]:
                                                        # Get the exact matching node ID and score
                                                        dep_node_id = matching_nodes[0].get('node_id', None)
                                                        score = float(matching_nodes[0].get('score', 0))
                                                        if score == 1.0 and dep_node_id in node_data_futures:
                                                            dep_futures.append(node_data_futures[dep_node_id])
                                # If prev_deps is an integer, process it for the current context
                                elif isinstance(prev_deps, int):
                                    if prev_deps < len(context_data['subqueries']):
                                        sub_query_data = context_data['subqueries'][prev_deps]
                                        if isinstance(sub_query_data, dict) and 'subquery' in sub_query_data:
                                            dep_query = sub_query_data['subquery']
                                            # Find matching nodes
                                            matching_nodes = self.find_nodes_by_properties(query=dep_query)
                                            if matching_nodes not in [None, []]:
                                                # Get the exact matching node ID and score
                                                dep_node_id = matching_nodes[0].get('node_id', None)
                                                score = float(matching_nodes[0].get('score', 0))
                                                if score == 1.0 and dep_node_id in node_data_futures:
                                                    dep_futures.append(node_data_futures[dep_node_id])

                        # Get futures from current dependencies
                        if current_deps not in [None, []]:
                            current_deps_list = [current_deps] if isinstance(current_deps, int) else current_deps
                            for dep_idx in current_deps_list:
                                if dep_idx < len(sub_queries):
                                    dep_node_id = sub_query_ids[dep_idx]
                                    if dep_node_id in node_data_futures:
                                        dep_futures.append(node_data_futures[dep_node_id])

                    # Start a coroutine that waits for dependencies and then processes the node
                    tasks.append(process_dependent_node(
                        self, node_id, sub_queries[idx], depth, dep_futures, future
                    ))

            # Emit an event to indicate the start of the search process
            if depth == 0:
                await self.emit_event("search_process_started", {
                    "depth": depth,
                    "sub_queries": sub_queries,
                    "roles": roles
                })

            # Wait for all tasks to complete
            await asyncio.gather(*tasks)
            # Recurse into sub-queries if needed
            if recurse:
                recursion_tasks = []
                for idx, sub_query in enumerate(sub_queries):
                    try:
                        sub_node_id = sub_query_ids[idx]
                        recursion_tasks.append(
                            self.build_graph(
                                query=sub_query,
                                parent_node_id=sub_node_id,
                                depth=depth + 1,
                                threshold=threshold,
                                recurse=recurse,
                                context=context,  # Pass the context
                                session_id=session_id
                            ))
                    except Exception as e:
                        print(f"Failed to create recursion task for sub-query {sub_query}: {e}")
                        continue

                # Only proceed if there are any recursion tasks
                if recursion_tasks:
                    try:
                        await asyncio.gather(*recursion_tasks)
                    except Exception as e:
                        raise Exception(f"Error during recursive processing: {e}")

        # Process completion tasks
        if depth == 0:
            print("Graph building complete, processing final tasks...")
            await self.emit_event("search_process_completed", {
                "depth": depth,
                "sub_queries": sub_queries,
                "roles": roles
            })

            # Create cross-connections
            create_cross_connections(self)
            print("All cross-connections have been created!")

            # Add similarity-based edges
            print(f"Adding similarity edges with threshold {threshold}")
            all_nodes = []
            with self.driver.session() as session:
                result = session.run(
                    "MATCH (n:Node) WHERE n.id <> $root_id RETURN n.id as id",
                    root_id=self.root_node_id
                )
                all_nodes = [record["id"] for record in result]
            for i, node1 in enumerate(all_nodes):
                for node2 in all_nodes[i+1:]:
                    if not self.edge_exists(node1, node2):
                        self.add_edge_based_on_similarity_and_relevance(
                            node1, node2, query, threshold
                        )
            print("All similarity-based edges have been added!")
    async def process_graph(
        self,
        query: str,
        data: str = None,
        similarity_threshold: float = 0.8,
        relevance_threshold: float = 0.7,
        sub_sub_queries: bool = True,
        session_id: str = None,
        max_tokens_allowed: int = 128000
    ):
        """Process a query and manage graph creation/modification."""
        # Inner function to check similarity between the new query and existing queries in the graph
        def check_query_similarity(new_query: str, similarity_threshold: float = 0.8) -> Dict[str, Any]:
            if self.current_graph_id is None:
                raise Exception("Error: No current graph ID. Cannot check query similarity.")
            try:
                # Get all existing queries of the current graph and their metadata from Neo4j
                print(f"Retrieving existing queries and their metadata for graph {self.current_graph_id}")
                with self.transaction() as tx:
                    result = tx.run("""
                        MATCH (n:Node)
                        WHERE n.graph_id IS NOT NULL
                        AND n.graph_id = $graph_id
                        RETURN n.id as id,
                               n.query as query,
                               n.role as role
                    """, graph_id=self.current_graph_id)
                    records = list(result)  # Materialize results to avoid session timeout

                if not records:  # No existing queries
                    return {"should_create_new": True}

                # Process results and calculate similarities
                similarities = []
                for record in records:
                    # Skip if missing required data
                    if not record["query"]:
                        continue
                    # Calculate query similarity
                    similarity = self.calculate_query_similarity(
                        new_query,
                        record["query"]
                    )
                    if similarity >= similarity_threshold:
                        similarities.append({
                            "node_id": record["id"],
                            "query": record["query"],
                            "score": similarity,
                            "role": record["role"]
                        })

                # If no similar queries were found
                if not similarities:
                    print(f"No similar queries found above threshold {similarity_threshold}")
                    return {"should_create_new": True}

                # Find the best match
                best_match = max(similarities, key=lambda x: x["score"])

                # Determine the relationship type based on the node ID pattern
                rel_type = "root"
                if "SSQ" in best_match["node_id"]:
                    rel_type = "sub-sub"
                elif "SQ" in best_match["node_id"]:
                    rel_type = "sub"
                return {
                    "most_similar_query": best_match["query"],
                    "similarity_score": best_match["score"],
                    "relationship_type": rel_type,
                    "node_id": best_match["node_id"],
                    "should_create_new": best_match["score"] < similarity_threshold
                }
            except Exception as e:
                print(f"Error checking query similarity: {str(e)}")
                raise
        try:
            # Check if a graph already exists
            print("Checking for existing graphs...")
            result = self.get_graphs()
            graphs = list(result)
            if not graphs:  # No existing graphs
                print("No existing graphs found. Creating new graph.")
                self.create_new_graph()

                # Emit event for creating a new graph
                await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
                await self.build_graph(
                    query=query,
                    data=data,
                    threshold=relevance_threshold,
                    recurse=sub_sub_queries,
                    session_id=session_id,
                    max_tokens_allowed=max_tokens_allowed
                )

                # Memory cleanup
                gc.collect()

                # Prune edges and update pagerank
                self.prune_edges()
                self.update_pagerank()

                # Verify graph integrity and consistency
                self.verify_graph_integrity()
                self.verify_graph_consistency()
                return

            # Check similarity with existing root queries
            max_similarity = 0
            most_similar_graph = None

            # First, consolidate nodes from graphs with the same ID
            consolidated_graphs = {}
            for graph in graphs:
                graph_info = graph.get("graph_info")
                if not graph_info:
                    continue
                graph_id = graph_info.get("graph_id")
                if not graph_id:
                    continue
                # Initialize or append nodes for this graph_id
                if graph_id not in consolidated_graphs:
                    consolidated_graphs[graph_id] = {
                        "graph_id": graph_id,
                        "nodes": []
                    }
                # Add nodes if they exist
                if graph_info.get("nodes"):
                    consolidated_graphs[graph_id]["nodes"].extend(graph_info["nodes"])

            # Now process the consolidated graphs
            for graph_id, graph_data in consolidated_graphs.items():
                nodes = graph_data["nodes"]
                # Calculate similarity with each node's query
                for node in nodes:
                    if node.get("query"):  # Skip nodes without queries
                        similarity = self.calculate_query_similarity(
                            query,
                            node["query"]
                        )
                        if node.get("id", "").startswith("SQ"):
                            await self.emit_event("retrieved_sub_query", {
                                "sub_query": node["query"]
                            })
                        if similarity > max_similarity:
                            max_similarity = similarity
                            most_similar_graph = graph_id

            if max_similarity >= similarity_threshold:
                # Use the existing graph
                print(f"Found similar query with score {round(max_similarity, 2)}")
                self.current_graph_id = most_similar_graph
                if round(max_similarity, 2) == 1.0:
                    print("Loading and using existing graph")

                    # Emit event for loading an existing graph
                    await self.emit_event("graph_operation", {"operation_type": "loading_existing_graph"})
                    success = self.load_graph(self.root_node_id)
                    if not success:
                        raise Exception("Failed to load existing graph")
                else:
                    # Check for node-level similarity
                    print("Checking for node-level similarity...")
                    similarity_info = check_query_similarity(
                        query,
                        similarity_threshold
                    )
                    # Use .get() since check_query_similarity may return only
                    # {"should_create_new": True} with no relationship_type
                    if similarity_info.get("relationship_type") in ["sub", "sub-sub"]:
                        print(f"Most Similar Query: {similarity_info['most_similar_query']}")
                        print("Modifying existing graph structure")

                        # Emit event for modifying the graph
                        await self.emit_event("graph_operation", {"operation_type": "modifying_existing_graph"})
                        await self.modify_graph(
                            query,
                            similarity_info["node_id"],
                            session_id=session_id
                        )

                        # Memory cleanup
                        gc.collect()

                        # Prune edges and update pagerank
                        self.prune_edges()
                        self.update_pagerank()

                        # Verify graph integrity and consistency
                        self.verify_graph_integrity()
                        self.verify_graph_consistency()
            else:
                # Create a new graph
                print(f"Creating new graph for query: {query}")
                self.create_new_graph()

                # Emit event for creating a new graph
                await self.emit_event("graph_operation", {"operation_type": "creating_new_graph"})
                await self.build_graph(
                    query=query,
                    data=data,
                    threshold=relevance_threshold,
                    recurse=sub_sub_queries,
                    session_id=session_id,
                    max_tokens_allowed=max_tokens_allowed
                )

                # Memory cleanup
                gc.collect()

                # Prune edges and update pagerank
                self.prune_edges()
                self.update_pagerank()

                # Verify graph integrity and consistency
                self.verify_graph_integrity()
                self.verify_graph_consistency()
        except Exception as e:
            print(f"Error in process_graph: {str(e)}")
            raise
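
    # End-to-end usage sketch (assumptions: the NEO4J_* environment variables
    # are set and the search/crawl/LLM backends are reachable; the query text
    # is illustrative):
    #
    #     rag = Neo4jGraphRAG(num_workers=4)
    #     rag.initialize_schema()
    #     asyncio.run(rag.process_graph("How do transformers work?"))
    #     print(rag.query_graph("How do transformers work?"))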
    def add_edge_based_on_similarity_and_relevance(self, node1_id: str, node2_id: str, query: str, threshold: float = 0.8):
        """Add edges based on node similarity and relevance."""
        try:
            with self.transaction() as tx:
                # Get node data atomically
                result = tx.run(
                    """
                    MATCH (n1:Node {id: $node1_id})
                    WITH n1
                    MATCH (n2:Node {id: $node2_id})
                    RETURN n1.embedding as emb1, n1.data as data1,
                           n2.embedding as emb2, n2.data as data2
                    """,
                    node1_id=node1_id,
                    node2_id=node2_id
                )
                data = result.single()
                if not data or not all([data["emb1"], data["emb2"], data["data1"], data["data2"]]):
                    return

                # Calculate similarities and relevance
                similarity = self.cosine_similarity(data["emb1"], data["emb2"])
                query_relevance1 = self.calculate_relevance(query, data["data1"])
                query_relevance2 = self.calculate_relevance(query, data["data2"])
                node_relevance = self.calculate_relevance(data["data1"], data["data2"])

                # Calculate the combined weight as the mean of the four scores
                weight = (similarity + query_relevance1 + query_relevance2 + node_relevance) / 4

                # Add an edge if the weight meets the threshold
                if weight >= threshold:
                    tx.run(
                        """
                        MATCH (a:Node {id: $node1_id}), (b:Node {id: $node2_id})
                        MERGE (a)-[r:RELATION {type: 'similarity_and_relevance'}]->(b)
                        ON CREATE SET r.weight = $weight
                        ON MATCH SET r.weight = $weight
                        """,
                        node1_id=node1_id,
                        node2_id=node2_id,
                        weight=weight
                    )
                    print(f"Added edge between {node1_id} and {node2_id} with type similarity_and_relevance and weight {weight}")
        except Exception as e:
            print(f"Error in similarity edge creation between {node1_id} and {node2_id}: {str(e)}")
            raise
    def calculate_relevance(self, data1: str, data2: str) -> float:
        """Calculate relevance between two pieces of text using BERTScore F1."""
        try:
            if not data1 or not data2:
                return 0.0
            P, R, F1 = self.scorer.score([data1], [data2])
            return F1.mean().item()
        except Exception as e:
            print(f"Error calculating relevance: {str(e)}")
            return 0.0
    def calculate_query_similarity(self, query1: str, query2: str) -> float:
        """Calculate similarity between two queries."""
        try:
            # Generate embeddings
            embedding1 = self.model.encode(query1).tolist()
            embedding2 = self.model.encode(query2).tolist()

            # Calculate cosine similarity
            return self.cosine_similarity(embedding1, embedding2)
        except Exception as e:
            print(f"Error calculating query similarity: {str(e)}")
            return 0.0
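
    # NOTE: `cosine_similarity` is used throughout this class but is defined
    # outside this excerpt. A minimal sketch of what it presumably computes
    # (an assumption, not the original implementation):
    #
    #     def cosine_similarity(self, v1: list, v2: list) -> float:
    #         a, b = np.array(v1), np.array(v2)
    #         denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    #         return float(np.dot(a, b) / denom) if denom else 0.0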
    def get_similarities_and_relevance(self, threshold: float = 0.8) -> list:
        """Get similarities and relevance between nodes."""
        try:
            with self.transaction() as tx:
                # Get all nodes except root
                result = tx.run(
                    """
                    MATCH (n:Node)
                    WHERE n.id <> $root_id
                    RETURN n.id as id, n.embedding as embedding, n.data as data
                    """,
                    root_id=self.root_node_id
                )
                nodes = list(result)
                similarities = []

                # Calculate similarities between each pair
                for i, node1 in enumerate(nodes):
                    for node2 in nodes[i + 1:]:
                        similarity = self.cosine_similarity(node1["embedding"], node2["embedding"])
                        relevance = self.calculate_relevance(node1["data"], node2["data"])

                        # Calculate the combined weight
                        weight = (similarity + relevance) / 2

                        # Add to results if the weight meets the threshold
                        if weight >= threshold:
                            similarities.append({
                                'node1': node1["id"],
                                'node2': node2["id"],
                                'similarity': similarity,
                                'relevance': relevance,
                                'weight': weight
                            })
                return similarities
        except Exception as e:
            print(f"Error getting similarities and relevance: {str(e)}")
            return []
    def get_node_relationships(self, node_id=None, depth=None, role=None, relationship_type=None):
        """Get relationships between nodes with filtering options."""
        try:
            with self.transaction() as tx:
                # Build the base query
                cypher_query = """
                    MATCH (n:Node)
                    WHERE n.id <> $root_id
                    AND n.graph_id = $current_graph_id
                """
                params = {
                    "root_id": self.root_node_id,
                    "current_graph_id": self.current_graph_id
                }

                # Add filters
                if node_id:
                    cypher_query += " AND n.id = $node_id"
                    params["node_id"] = node_id
                if role:
                    cypher_query += " AND n.role = $role"
                    params["role"] = role
                if depth is not None:
                    cypher_query += " AND n.depth = $depth"
                    params["depth"] = depth

                # First get outgoing relationships
                cypher_query += """
                    WITH n
                    OPTIONAL MATCH (n)-[r1:RELATION]->(m1:Node)
                    WHERE m1.id <> $root_id
                    AND m1.graph_id = $current_graph_id
                """
                # Add relationship type filter if specified
                if relationship_type:
                    cypher_query += " AND r1.type = $rel_type"
                    params["rel_type"] = relationship_type

                # Then get incoming relationships in a separate match
                cypher_query += """
                    WITH n, collect({source: n.id, target: m1.id, weight: r1.weight, type: r1.type}) as out_edges
                    OPTIONAL MATCH (n)<-[r2:RELATION]-(m2:Node)
                    WHERE m2.id <> $root_id
                    AND m2.graph_id = $current_graph_id
                """
                # Add the same relationship type filter for incoming edges
                if relationship_type:
                    cypher_query += " AND r2.type = $rel_type"

                # Return both collections
                cypher_query += """
                    RETURN n.id as node_id,
                           collect({source: m2.id, target: n.id, weight: r2.weight, type: r2.type}) as in_edges,
                           out_edges
                """
                result = tx.run(cypher_query, params)

                relationships = {}
                for record in result:
                    current_id = record["node_id"]
                    relationships[current_id] = {
                        'in_edges': [(edge['source'], edge['target'], {
                            'weight': edge['weight'],
                            'type': edge['type']
                        }) for edge in record["in_edges"] if edge['source'] is not None],
                        'out_edges': [(edge['source'], edge['target'], {
                            'weight': edge['weight'],
                            'type': edge['type']
                        }) for edge in record["out_edges"] if edge['target'] is not None]
                    }
                return relationships
        except Exception as e:
            print(f"Error getting node relationships: {str(e)}")
            raise
    def find_nodes_by_properties(self, query: str = None, embedding: list = None,
                                 node_data: dict = None, similarity_threshold: float = 0.8) -> list:
        """Find nodes based on properties."""
        try:
            with self.transaction() as tx:
                where_conditions = []
                params = {}

                # Build query conditions
                if query:
                    where_conditions.append("n.query CONTAINS $node_query")
                    params["node_query"] = query
                if node_data:
                    for key, value in node_data.items():
                        where_conditions.append(f"n.{key} = ${key}")
                        params[key] = value

                # Construct the base query
                cypher_query = "MATCH (n:Node)"
                if where_conditions:
                    cypher_query += " WHERE " + " AND ".join(where_conditions)
                cypher_query += " RETURN n"
                result = tx.run(cypher_query, params)
                matching_nodes = []

                # Process results and calculate similarities
                for record in result:
                    node = record["n"]
                    match_score = 0
                    matches = 0

                    # Score based on property matches
                    if query and query.lower() in node["query"].lower():
                        match_score += 1
                        matches += 1

                    # Score based on embedding similarity
                    if embedding and "embedding" in node:
                        similarity = self.cosine_similarity(embedding, node["embedding"])
                        if similarity >= similarity_threshold:
                            match_score += similarity
                            matches += 1

                    # Score based on node_data matches
                    if node_data:
                        data_matches = sum(1 for k, v in node_data.items()
                                           if k in node and node[k] == v)
                        if data_matches > 0:
                            match_score += data_matches / len(node_data)
                            matches += 1

                    # Add to results if any match was found
                    if matches > 0:
                        matching_nodes.append({
                            "node_id": node["id"],
                            "score": match_score / matches,
                            "data": dict(node)
                        })

                # Sort by score
                matching_nodes.sort(key=lambda x: x["score"], reverse=True)
                return matching_nodes
        except Exception as e:
            print(f"Error finding nodes by properties: {str(e)}")
            raise
    def query_graph(self, query: str) -> str:
        """Query the graph in Neo4j for a specific query, collecting data from the entire relevant subgraph."""
        try:
            with self.transaction() as tx:
                # Find the query node
                query_node = tx.run("""
                    MATCH (n:Node {query: $node_query})
                    WHERE n.graph_id = $graph_id
                    RETURN n
                """, node_query=query, graph_id=self.current_graph_id).single()

                if not query_node:
                    raise ValueError(f"Query node not found for: {query}")

                query_node_id = query_node['n']['id']
                datas = []

                # Get the entire subgraph, including all relationship types and independent nodes
                subgraph_paths = tx.run("""
                    // First get the query node and all its connected paths
                    MATCH path = (n:Node {id: $node_id})-[r:RELATION*0..]->(m:Node)
                    WHERE n.graph_id = $graph_id
                    // Collect all nodes and relationships in these paths
                    WITH COLLECT(path) as paths
                    UNWIND paths as path
                    WITH DISTINCT path
                    // Get all nodes and relationships from the paths
                    WITH nodes(path) as nodes, relationships(path) as rels
                    // Calculate path weight considering all relationship types
                    WITH nodes, rels,
                         reduce(weight = 1.0, rel in rels |
                            CASE rel.type
                                WHEN 'logical' THEN weight * rel.weight * 1.2
                                WHEN 'hierarchical' THEN weight * rel.weight * 1.1
                                WHEN 'similarity_and_relevance' THEN weight * rel.weight * 0.9
                                ELSE weight * rel.weight
                            END
                         ) as path_weight
                    // Unwind nodes to get individual records
                    UNWIND nodes as node
                    WITH DISTINCT node, path_weight
                    WHERE node.data IS NOT NULL
                      AND node.data <> '' // Ensure data is not empty
                    // Return ordered by weight and pagerank for better context flow
                    RETURN node.data as data,
                           path_weight,
                           node.role as role,
                           node.pagerank as pagerank
                    ORDER BY
                        CASE node.role
                            WHEN 'pre-requisite' THEN 3
                            WHEN 'independent' THEN 2
                            ELSE 1
                        END DESC,
                        path_weight DESC,
                        pagerank DESC
                """, node_id=query_node_id, graph_id=self.current_graph_id)

                # Collect data in the order returned (already optimally sorted)
                for record in subgraph_paths:
                    data = record["data"]
                    if data and isinstance(data, str):
                        datas.append(data.strip())

                # If no data is found, return an empty string
                if not datas:
                    print(f"No data found for: {query}")
                    return ""

                # Return the combined data
                return "\n\n".join([f"Data {i+1}:\n{data}" for i, data in enumerate(datas)])
        except Exception as e:
            print(f"Error querying graph for specific query: {str(e)}")
            raise
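    # Worked example of the path-weight reduce() above, for a hypothetical
    # two-hop path whose edges carry weights 0.8 ('logical') and 0.5
    # ('hierarchical'):
    #
    #     weight = 1.0
    #     weight = 1.0  * 0.8 * 1.2 = 0.96    # logical edges are boosted by 1.2
    #     weight = 0.96 * 0.5 * 1.1 = 0.528   # hierarchical edges by 1.1
    #
    # Because the weights multiply, deeper paths only rank highly when every
    # hop carries a strong weight.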
    def prune_edges(self, max_edges: int = 1000):
        """Prune excess edges while preserving node data."""
        try:
            print(f"Pruning edges to keep top {max_edges} edges by weight...")
            # The transaction() context manager commits on success and rolls
            # back on error, so no explicit commit/rollback is needed here.
            with self.transaction() as tx:
                # Count current edges
                result = tx.run(
                    """
                    MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                    RETURN count(r) AS count
                    """,
                    graphID=self.current_graph_id
                )
                current_edges = result.single()["count"]

                if current_edges > max_edges:
                    # Flag the edges to keep. Cypher only allows labels on
                    # nodes, so a temporary property is used instead of the
                    # invalid `SET r:KEEP`.
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        WITH r
                        ORDER BY r.weight DESC
                        LIMIT $max_edges
                        SET r.keep = true
                        """,
                        graphID=self.current_graph_id,
                        max_edges=max_edges
                    )

                    # Remove the excess edges
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        WHERE r.keep IS NULL
                        DELETE r
                        """,
                        graphID=self.current_graph_id
                    )

                    # Remove the temporary property
                    tx.run(
                        """
                        MATCH (a:Node {graph_id: $graphID})-[r:RELATION]->(b:Node {graph_id: $graphID})
                        REMOVE r.keep
                        """,
                        graphID=self.current_graph_id
                    )
                    print(f"Pruned edges. Kept top {max_edges} edges by weight.")
                else:
                    print("No pruning needed. Current edge count is within limits.")
        except Exception as e:
            print(f"Error pruning edges: {str(e)}")
            raise
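    # The pruning above relies on a temporary relationship property rather
    # than a label, because Cypher permits labels only on nodes. The same
    # three-step pattern as standalone Cypher (the limit of 1000 and the
    # property name `keep` are illustrative):
    #
    #     MATCH ()-[r:RELATION]->()
    #     WITH r ORDER BY r.weight DESC LIMIT 1000
    #     SET r.keep = true;
    #
    #     MATCH ()-[r:RELATION]->() WHERE r.keep IS NULL DELETE r;
    #     MATCH ()-[r:RELATION]->() REMOVE r.keep;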
    def update_pagerank(self):
        """Update PageRank values using Neo4j's graph algorithms."""
        if not self.current_graph_id:
            print("No current graph selected. Cannot compute PageRank.")
            return

        try:
            with self.transaction() as tx:
                # Create a graph projection with weighted relationships
                tx.run(
                    """
                    CALL gds.graph.project.cypher(
                        'graphProjection',
                        'MATCH (n:Node) WHERE n.graph_id = $myParam RETURN id(n) AS id',
                        'MATCH (n:Node)-[r:RELATION]->(m:Node)
                         WHERE n.graph_id = $myParam AND m.graph_id = $myParam
                         RETURN id(n) AS source,
                                id(m) AS target,
                                CASE r.type
                                    WHEN "logical" THEN r.weight * 2
                                    ELSE r.weight
                                END AS weight',
                        { parameters: { myParam: $graphId } }
                    )
                    """,
                    graphId=self.current_graph_id
                )

                # Run PageRank with relationship weights
                tx.run(
                    """
                    CALL gds.pageRank.write(
                        'graphProjection',
                        {
                            relationshipWeightProperty: 'weight',
                            writeProperty: 'pagerank',
                            maxIterations: 20,
                            dampingFactor: 0.85,
                            concurrency: 4
                        }
                    )
                    """
                )

                # Clean up the projection
                tx.run("CALL gds.graph.drop('graphProjection')")

            print("PageRank updated successfully")
        except Exception as e:
            print(f"Error updating PageRank: {str(e)}")
            raise
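    # Caveat: re-running update_pagerank() fails if 'graphProjection' survived
    # an earlier, aborted run, since gds.graph.project.cypher refuses to
    # overwrite an existing projection. A hedged safeguard is to drop any stale
    # projection first; gds.graph.drop's second argument (failIfMissing=false)
    # makes the drop a no-op when nothing exists:
    #
    #     CALL gds.graph.drop('graphProjection', false)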
    def display_graph(self, query: str):
        """Display the graph."""
        try:
            with self.transaction() as tx:
                # 1. Find the graph_id(s) of the node using the provided query
                cypher_query = """
                    MATCH (n:Node)
                    WHERE n.query = $node_query
                    RETURN COLLECT(DISTINCT n.graph_id) AS graph_ids
                """
                result = tx.run(cypher_query, node_query=query)
                graph_ids = result.single().get("graph_ids", [])
                if not graph_ids:
                    print("No graph found for the given query.")
                    return

                # Create the PyVis network once, so all data can be added to it
                net = Network(
                    height="600px",
                    width="100%",
                    directed=True,
                    bgcolor="#222222",
                    font_color="white"
                )
                # Disable physics initially
                net.options = {"physics": {"enabled": False}}

                all_nodes = set()
                all_edges = []
                for graph_id in graph_ids:
                    # 2. Fetch the graph data for this graph_id (parameterized
                    # to avoid Cypher injection via string interpolation)
                    result = tx.run(
                        "MATCH (n)-[r]->(m) WHERE n.graph_id = $graph_id RETURN n, r, m",
                        graph_id=graph_id
                    )
                    for record in result:
                        source_node = record["n"]
                        target_node = record["m"]
                        relationship = record["r"]
                        source_id = source_node.get("id")
                        target_id = target_node.get("id")

                        # Build a descriptive tooltip for each node
                        source_tooltip = f"Query: {source_node.get('query', 'N/A')}"
                        target_tooltip = f"Query: {target_node.get('query', 'N/A')}"

                        # Add the source node if not already in the set
                        if source_id not in all_nodes:
                            net.add_node(
                                source_id,
                                label=source_id,
                                title=source_tooltip,
                                size=20,
                                color="#00cc66"
                            )
                            all_nodes.add(source_id)

                        # Add the target node if not already in the set
                        if target_id not in all_nodes:
                            net.add_node(
                                target_id,
                                label=target_id,
                                title=target_tooltip,
                                size=20,
                                color="#00cc66"
                            )
                            all_nodes.add(target_id)

                        # Record the edge
                        all_edges.append({
                            "from": source_id,
                            "to": target_id,
                            "label": relationship.type,
                        })

                # 3. Add all edges
                for edge in all_edges:
                    net.add_edge(
                        edge["from"],
                        edge["to"],
                        title=edge["label"],
                        color="#cccccc"
                    )

                # 4. Enable improved layout and dragNodes
                net.options["layout"] = {"improvedLayout": True}
                net.options["interaction"] = {"dragNodes": True}

                # 5. Save to a temporary file, read it, then remove that file
                net.save_graph("temp_graph.html")
                with open("temp_graph.html", "r", encoding="utf-8") as f:
                    html_str = f.read()
                os.remove("temp_graph.html")  # Clean up the temp file

                return html_str
        except Exception as e:
            print(f"Error displaying graph: {str(e)}")
            raise
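    # Hedged usage sketch: display_graph() returns an HTML string, so it can
    # be embedded in any web frontend. For example with Streamlit (an
    # assumption, not a dependency declared in this file):
    #
    #     import streamlit.components.v1 as components
    #     html_str = rag.display_graph("What is inflation?")
    #     if html_str:
    #         components.html(html_str, height=600)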
    def verify_graph_integrity(self):
        """Verify and fix graph integrity issues."""
        try:
            with self.transaction() as tx:
                # Check for orphaned nodes
                orphaned = tx.run(
                    """
                    MATCH (n:Node {graph_id: $graph_id})
                    WHERE NOT (n)-[:RELATION]-()
                    RETURN n.id as node_id
                    """,
                    graph_id=self.current_graph_id
                ).values()
                if orphaned:
                    print(f"Found orphaned nodes: {orphaned}")

                # Check for edges that point outside the current graph
                invalid_edges = tx.run(
                    """
                    MATCH (a:Node)-[r:RELATION]->(b:Node)
                    WHERE a.graph_id = $graph_id
                      AND (b.graph_id <> $graph_id OR b.graph_id IS NULL)
                    RETURN a.id as from_id, b.id as to_id
                    """,
                    graph_id=self.current_graph_id
                ).values()
                if invalid_edges:
                    print(f"Found invalid edges: {invalid_edges}")
                    # Fix the issue by removing the offending edges
                    tx.run(
                        """
                        MATCH (a:Node)-[r:RELATION]->(b:Node)
                        WHERE a.graph_id = $graph_id
                          AND (b.graph_id <> $graph_id OR b.graph_id IS NULL)
                        DELETE r
                        """,
                        graph_id=self.current_graph_id
                    )

            print("Graph integrity verified successfully")
            return True
        except Exception as e:
            print(f"Error verifying graph integrity: {str(e)}")
            raise
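    # Note: the check above only reports orphaned nodes. If orphans should be
    # removed as well, a hedged optional cleanup would be:
    #
    #     MATCH (n:Node {graph_id: $graph_id})
    #     WHERE NOT (n)-[:RELATION]-()
    #     DETACH DELETE n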
    def verify_graph_consistency(self):
        """Verify consistency of the Neo4j graph."""
        try:
            with self.driver.session() as session:
                # Check for nodes without required properties
                missing_props = session.run("""
                    MATCH (n:Node)
                    WHERE n.id IS NULL OR n.query IS NULL
                    RETURN count(n) as count
                """)
                if missing_props.single()["count"] > 0:
                    raise ValueError("Found nodes with missing required properties")

                # Check for relationship consistency
                invalid_rels = session.run("""
                    MATCH ()-[r:RELATION]->()
                    WHERE r.type IS NULL OR r.weight IS NULL
                    RETURN count(r) as count
                """)
                if invalid_rels.single()["count"] > 0:
                    raise ValueError("Found relationships with missing required properties")

                print("Graph consistency verified successfully")
                return True
        except Exception as e:
            print(f"Error verifying graph consistency: {str(e)}")
            raise
    async def close(self):
        """Properly clean up all resources."""
        try:
            # Shut down the executor
            if hasattr(self, 'executor'):
                self.executor.shutdown(wait=True)

            # Close the Neo4j driver
            if hasattr(self, 'driver'):
                self.driver.close()

            # Clean up crawler resources and browser contexts. The instance
            # attribute set in __init__ is `custom_crawler` (the original check
            # against `crawler` could never fire); this assumes CustomCrawler
            # exposes the same cleanup API. A session id may never have been set,
            # so it is looked up defensively.
            if hasattr(self, 'custom_crawler'):
                await asyncio.shield(self.custom_crawler.cleanup_expired_sessions())
                session_id = getattr(self, 'session_id', None)
                if session_id:
                    await asyncio.shield(self.custom_crawler.cleanup_browser_context(session_id))
        except Exception as e:
            print(f"Error during cleanup: {e}")
    @staticmethod
    def cosine_similarity(v1: List[float], v2: List[float]) -> float:
        """Calculate cosine similarity between two vectors."""
        try:
            v1_array = np.array(v1)
            v2_array = np.array(v2)
            denom = np.linalg.norm(v1_array) * np.linalg.norm(v2_array)
            if denom == 0:
                return 0.0  # Avoid division by zero for zero-length vectors
            return float(np.dot(v1_array, v2_array) / denom)
        except Exception as e:
            print(f"Error calculating cosine similarity: {str(e)}")
            return 0.0
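# Quick sanity checks for cosine_similarity (values follow directly from the
# definition above):
#
#     Neo4jGraphRAG.cosine_similarity([1.0, 0.0], [1.0, 0.0])  # -> 1.0
#     Neo4jGraphRAG.cosine_similarity([1.0, 0.0], [0.0, 1.0])  # -> 0.0
#     Neo4jGraphRAG.cosine_similarity([1.0, 2.0], [2.0, 4.0])  # -> 1.0 (parallel)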
if __name__ == "__main__":
    from dotenv import load_dotenv
    from src.reasoning.reasoner import Reasoner
    from src.evaluation.evaluator import Evaluator

    load_dotenv()

    graph_search = Neo4jGraphRAG(num_workers=24)
    evaluator = Evaluator()
    reasoner = Reasoner()

    async def test_graph_search():
        # Sample data for testing
        queries = [
            """In the context of global economic recovery and energy security concerns, provide an in-depth comparative assessment of the renewable energy policies among G20 countries.
            Specifically, examine how short-term economic stimulus measures intersect with long-term decarbonization commitments, including:
            1. Carbon pricing mechanisms
            2. Subsidies for emerging technologies (such as green hydrogen and battery storage)
            3. Cross-border climate finance initiatives
            Highlight the unique challenges faced by both advanced and emerging economies in addressing:
            1. Energy poverty
            2. Supply chain disruptions
            3. Geopolitical tensions (e.g., the Russia-Ukraine conflict)
            Discuss how these factors influence policy effectiveness, and evaluate the degree to which each country is on track to meet—or exceed—its Paris Agreement targets.
            Note any significant policy gaps, regional collaborations, or innovative best practices.
            Lastly, provide a forward-looking perspective on how these renewable energy strategies may evolve over the next decade, considering:
            1. Technological breakthroughs
            2. Global market trends
            3. Potential climate-related disasters
            Present your analysis as a detailed, well-formatted report.""",
            """Analyse the impact of 'hot-money' on the value of the Indian Rupee and answer the following questions:-
            1. How does it affect the exchange rate?
            2. How can it be mitigated/eliminated?
            3. Why is it a problem?
            4. What are the consequences?
            5. What are the alternatives?
                - Evaluate the alternatives for pros and cons.
                - Evaluate the impact of alternatives on the exchange rate.
                - How can they be implemented?
                - What are the consequences of each alternative?
                - Evaluate the feasibility of the alternatives.
                - Pick top 5 alternatives and justify your choices in detail.
            6. What are the implications for the Indian economy? Furthermore:-
                - Evaluate the impact of the chosen alternatives on the Indian economy.""",
            """Inflation has been an intrinsic part of human civilization since the very beginning. Answer the following questions:-
            1. How true is the above statement?
            2. What are the causes of inflation?
            3. What are the consequences of inflation?
            4. Can we completely eliminate inflation?""",
            """Perform a detailed comparison between the ancient Greek and Roman civilizations.
            1. What were the key differences between the two civilizations?
                - Evaluate the differences in governance, society, and culture
                - Evaluate the differences in economy, trade, and military
                - Evaluate the differences in technology and infrastructure
            2. What were the similarities between the two civilizations?
                - Evaluate the similarities in governance, society, and culture
                - Evaluate the similarities in economy, trade, and military
                - Evaluate the similarities in technology and infrastructure
            3. How did these two civilizations influence each other?
                - Evaluate the influence of one civilization on the other
            4. How did these two civilizations influence the modern world?
            5. Was there another civilization that influenced these two? If yes, how?""",
            """Evaluate the long-term effects of colonialism on economic development in Asia:-
            1. Include case studies of at least five different countries
            2. Analyze how these effects differ based on colonial power, time of independence, and resource distribution
                - Evaluate the impact of colonialism on the economy of the country
                - Evaluate the impact of colonialism on the economy of the region
                - Evaluate the impact of colonialism on the economy of the world
            3. How do these effects compare to Africa?"""
        ]
        follow_on_queries = [
            "How is 'hot-money' related to the current economic situation in India?",
            "What is inflation?",
            "Did ancient Greece and Rome have any impact on modern democracy? If yes, how?",
            "Did colonialism have any impact on the trade between Africa and Asia, both in colonial and post-colonial times? If yes, how?"
        ]
        query = queries[2]

        # Initialize the database schema
        graph_search.initialize_schema()

        # Build the graph in Neo4j
        await graph_search.process_graph(query, similarity_threshold=0.8, relevance_threshold=0.8)

        # Query the graph and generate a response
        answer = graph_search.query_graph(query)
        response = ""
        async for chunk in reasoner.answer(query, answer):
            response += chunk
            print(chunk, end="", flush=True)  # Stream each chunk as it arrives
        print()

        # Display the graph
        graph_search.display_graph(query)

        # Evaluate the response
        evaluation = await evaluator.evaluate_response(query, response, [answer])
        print(f"Faithfulness: {evaluation['faithfulness']}")
        print(f"Answer Relevancy: {evaluation['answer relevancy']}")
        print(f"Contextual Recall: {evaluation['contextual recall']}")

        # Shut down the executor after all tasks are complete
        await graph_search.close()

    # Run the test function
    asyncio.run(test_graph_search())