| """ | |
| Graph-based RAG using NetworkX. | |
| Updated to match the common query signature used by other methods. | |
| """ | |
| import numpy as np | |
| import logging | |
| from typing import Tuple, List, Optional | |
| from openai import OpenAI | |
| import networkx as nx | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from config import * | |
| from utils import classify_image | |
| logger = logging.getLogger(__name__) | |
| # Initialize OpenAI client | |
| client = OpenAI(api_key=OPENAI_API_KEY) | |
| # Global variables for lazy loading | |
| _graph = None | |
| _enodes = None | |
| _embeddings = None | |

def _load_graph():
    """Lazy load graph database."""
    global _graph, _enodes, _embeddings
    if _graph is None:
        try:
            if GRAPH_FILE.exists():
                logger.info("Loading graph database...")
                _graph = nx.read_gml(str(GRAPH_FILE))
                _enodes = list(_graph.nodes)
                # Convert embeddings from lists back to numpy arrays
                embeddings_list = []
                for n in _enodes:
                    embedding = _graph.nodes[n]['embedding']
                    if isinstance(embedding, list):
                        embeddings_list.append(np.array(embedding))
                    else:
                        embeddings_list.append(embedding)
                _embeddings = np.array(embeddings_list)
                logger.info(f"✓ Loaded graph with {len(_enodes)} nodes")
            else:
                logger.warning("Graph database not found. Run preprocess.py first.")
                _graph = nx.Graph()
                _enodes = []
                _embeddings = np.array([])
        except Exception as e:
            logger.error(f"Error loading graph: {e}")
            _graph = nx.Graph()
            _enodes = []
            _embeddings = np.array([])

def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
    """
    Query using graph-based retrieval.

    Args:
        question: User's question
        image_path: Optional path to an image (for multimodal queries)
        top_k: Number of relevant chunks to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    # Load graph if not already loaded
    _load_graph()
    if len(_enodes) == 0:
        return "Graph database is empty. Please run preprocess.py first.", []

    # Embed question using OpenAI
    emb_resp = client.embeddings.create(
        model=OPENAI_EMBEDDING_MODEL,
        input=question
    )
    q_vec = np.array(emb_resp.data[0].embedding)

    # Compute cosine similarities
    sims = cosine_similarity([q_vec], _embeddings)[0]
    idxs = sims.argsort()[::-1][:top_k]

    # Collect chunk-level info
    chunks = []
    citations = []
    sources_seen = set()
    for i in idxs:
        node = _enodes[i]
        node_data = _graph.nodes[node]
        text = node_data['text']
        # Extract header from text
        header = text.split('\n', 1)[0].lstrip('#').strip()
        score = sims[i]

        # Extract citation format - get source from metadata or node_data
        metadata = node_data.get('metadata', {})
        source = metadata.get('source') or node_data.get('source')
        if not source:
            continue

        if 'url' in metadata:  # HTML source
            citation_ref = metadata['url']
            cite_type = 'html'
        elif 'path' in metadata:  # PDF source
            citation_ref = metadata['path']
            cite_type = 'pdf'
        elif 'url' in node_data:  # Legacy format
            citation_ref = node_data['url']
            cite_type = 'html'
        elif 'path' in node_data:  # Legacy format
            citation_ref = node_data['path']
            cite_type = 'pdf'
        else:
            citation_ref = source
            cite_type = 'unknown'

        chunks.append({
            'header': header,
            'score': score,
            'text': text,
            'citation': citation_ref
        })

        # Add unique citation
        if source not in sources_seen:
            citation_entry = {
                'source': source,
                'type': cite_type,
                'relevance_score': round(float(score), 3)
            }
            if cite_type == 'html':
                citation_entry['url'] = citation_ref
            elif cite_type == 'pdf':
                citation_entry['path'] = citation_ref
            citations.append(citation_entry)
            sources_seen.add(source)

    # Handle image if provided
    image_context = ""
    if image_path:
        try:
            # Classify the image
            classification = classify_image(image_path)
            image_context = f"\n\n[Image Context: The provided image appears to be a {classification}.]"
            # Optionally, the classification could also be matched against
            # image-related metadata stored in the graph; see the
            # _find_image_related_nodes sketch below.
        except Exception as e:
            logger.warning(f"Error processing image: {e}")

    # Assemble context for prompt
    context = "\n\n---\n\n".join([c['text'] for c in chunks])
    prompt = f"""Use the following context to answer the question:

{context}{image_context}

Question: {question}

Please provide a comprehensive answer based on the context provided. Cite specific sources when providing information."""

    # For GPT-5, temperature must be left at its default (1.0)
    chat_resp = client.chat.completions.create(
        model=OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety. Always provide accurate information based on the given context."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=DEFAULT_MAX_TOKENS
    )
    answer = chat_resp.choices[0].message.content

    return answer, citations
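

# The image branch in query() mentions matching classifications against
# image-related metadata in the graph. A minimal, hypothetical sketch of that
# idea: it assumes nodes may carry an 'image_labels' list in their metadata,
# which the current preprocess.py does not necessarily write.
def _find_image_related_nodes(classification: str, limit: int = 3) -> List[str]:
    """Return up to `limit` node ids whose (assumed) 'image_labels' metadata
    mentions the classification string. Illustrative only; yields [] when
    the metadata is absent."""
    _load_graph()
    related = []
    for n in _enodes:
        labels = _graph.nodes[n].get('metadata', {}).get('image_labels', [])
        if any(classification.lower() in str(label).lower() for label in labels):
            related.append(n)
            if len(related) >= limit:
                break
    return related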


def query_with_graph_traversal(question: str, top_k: int = 5, max_hops: int = 2) -> Tuple[str, List[dict]]:
    """
    Enhanced graph query that can traverse edges to find related information.

    Args:
        question: User's question
        top_k: Number of initial nodes to retrieve
        max_hops: Maximum graph traversal depth

    Returns:
        Tuple of (answer, citations)
    """
    # Load graph if not already loaded
    _load_graph()
    if len(_enodes) == 0:
        return "Graph database is empty. Please run preprocess.py first.", []

    # Get initial nodes using standard query
    initial_answer, initial_citations = query(question, top_k=top_k)

    # A more sophisticated implementation would:
    #   1. Add edges between related nodes during preprocessing
    #   2. Traverse from the initial nodes to find related content
    #   3. Score related nodes by path distance and relevance
    # The _traverse_related_nodes sketch below outlines steps 2 and 3.
    # For now, return the standard query results.
    return initial_answer, initial_citations
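

# A hedged sketch of the traversal described above: starting from seed node ids
# with known similarity scores, walk up to max_hops along existing edges and
# discount each neighbour's score by path distance. This assumes preprocess.py
# adds meaningful edges; it is not yet wired into query_with_graph_traversal.
def _traverse_related_nodes(seed_scores: dict, max_hops: int = 2, decay: float = 0.5) -> dict:
    """Return {node_id: score} for the seeds plus nodes within max_hops,
    where each hop multiplies the best upstream score by `decay`."""
    _load_graph()
    scores = dict(seed_scores)
    frontier = list(seed_scores)
    for _ in range(max_hops):
        next_frontier = []
        for node in frontier:
            base = scores.get(node, 0.0)
            for neighbor in _graph.neighbors(node):
                candidate = base * decay
                # Keep the best-scoring path to each neighbour
                if candidate > scores.get(neighbor, 0.0):
                    scores[neighbor] = candidate
                    next_frontier.append(neighbor)
        frontier = next_frontier
    return scores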


def query_subgraph(question: str, source_filter: Optional[str] = None, top_k: int = 5) -> Tuple[str, List[dict]]:
    """
    Query a specific subgraph filtered by source.

    Args:
        question: User's question
        source_filter: Filter nodes by source (e.g., specific PDF name)
        top_k: Number of relevant chunks to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    # Load graph if not already loaded
    _load_graph()
    if len(_enodes) == 0:
        return "Graph database is empty. Please run preprocess.py first.", []

    # Filter nodes if a source is specified
    if source_filter:
        filtered_nodes = []
        for n in _enodes:
            node_data = _graph.nodes[n]
            metadata = node_data.get('metadata', {})
            # Check the metadata source first, then the legacy node-level source
            source = metadata.get('source') or node_data.get('source', '')
            if source_filter.lower() in source.lower():
                filtered_nodes.append(n)
        if not filtered_nodes:
            return f"No nodes found for source: {source_filter}", []
    else:
        filtered_nodes = _enodes

    # Get embeddings for the filtered nodes
    filtered_embeddings = np.array([_graph.nodes[n]['embedding'] for n in filtered_nodes])

    # Embed question
    emb_resp = client.embeddings.create(
        model=OPENAI_EMBEDDING_MODEL,
        input=question
    )
    q_vec = np.array(emb_resp.data[0].embedding)

    # Compute similarities
    sims = cosine_similarity([q_vec], filtered_embeddings)[0]
    idxs = sims.argsort()[::-1][:top_k]

    # Collect results
    chunks = []
    citations = []
    sources_seen = set()
    for i in idxs:
        node = filtered_nodes[i]
        node_data = _graph.nodes[node]

        # Skip if source information missing
        metadata = node_data.get('metadata', {})
        source = metadata.get('source') or node_data.get('source')
        if not source:
            continue

        chunks.append(node_data['text'])

        if source not in sources_seen:
            citation = {
                'source': source,
                'type': 'pdf' if ('path' in metadata or 'path' in node_data) else 'html',
                'relevance_score': round(float(sims[i]), 3)
            }
            # Check metadata first, then node_data for legacy support
            if 'url' in metadata:
                citation['url'] = metadata['url']
            elif 'path' in metadata:
                citation['path'] = metadata['path']
            elif 'url' in node_data:
                citation['url'] = node_data['url']
            elif 'path' in node_data:
                citation['path'] = node_data['path']
            citations.append(citation)
            sources_seen.add(source)

    # Build context and generate answer
    context = "\n\n---\n\n".join(chunks)
    prompt = f"""Answer the following question using the provided context:

Context from {source_filter if source_filter else 'all sources'}:
{context}

Question: {question}

Provide a detailed answer based on the context."""

    # For GPT-5, temperature must be left at its default (1.0)
    response = client.chat.completions.create(
        model=OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are an expert on manufacturing safety. Answer based on the provided context."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=DEFAULT_MAX_TOKENS
    )
    answer = response.choices[0].message.content

    return answer, citations


# Maintain backward compatibility with the original function signature
def query_graph(question: str, top_k: int = 5) -> Tuple[str, List[str], List[tuple]]:
    """
    Original query_graph function signature for backward compatibility.

    Args:
        question: User's question
        top_k: Number of relevant chunks to retrieve

    Returns:
        Tuple of (answer, sources, chunks)
    """
    # Call the new query function
    answer, citations = query(question, top_k=top_k)

    # Convert citations to old format
    sources = [c['source'] for c in citations]

    # Get chunks in old format (header, score, text, citation)
    _load_graph()
    if len(_enodes) == 0:
        return answer, sources, []

    # Regenerate chunks for backward compatibility (note: this re-embeds the
    # question, so it costs a second embeddings call)
    emb_resp = client.embeddings.create(
        model=OPENAI_EMBEDDING_MODEL,
        input=question
    )
    q_vec = np.array(emb_resp.data[0].embedding)
    sims = cosine_similarity([q_vec], _embeddings)[0]
    idxs = sims.argsort()[::-1][:top_k]

    chunks = []
    for i in idxs:
        node = _enodes[i]
        node_data = _graph.nodes[node]
        text = node_data['text']
        header = text.split('\n', 1)[0].lstrip('#').strip()
        score = sims[i]

        # Skip if source information missing
        metadata = node_data.get('metadata', {})
        source = metadata.get('source') or node_data.get('source')
        if not source:
            continue

        if 'url' in metadata:
            citation = metadata['url']
        elif 'path' in metadata:
            citation = metadata['path']
        elif 'url' in node_data:
            citation = node_data['url']
        elif 'path' in node_data:
            citation = node_data['path']
        else:
            citation = source

        chunks.append((header, score, text, citation))

    return answer, sources, chunks


if __name__ == "__main__":
    # Test the updated graph query
    test_questions = [
        "What are general machine guarding requirements?",
        "How do I perform lockout/tagout procedures?",
        "What safety measures are needed for robotic systems?"
    ]
    for q in test_questions:
        print(f"\nQuestion: {q}")
        answer, citations = query(q)
        print(f"Answer: {answer[:200]}...")
        print(f"Citations: {[c['source'] for c in citations]}")
        print("-" * 50)