import os
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity

# Initialize OpenAI client
load_dotenv(override=True)
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
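
# Optional guard (a minimal sketch, not part of the original flow): fail fast if the
# key was not picked up from .env; assumes the variable name OPENAI_API_KEY used above.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError("OPENAI_API_KEY not found; check your .env file.")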

# Load graph from GML
G = nx.read_gml("graph.gml")
enodes = list(G.nodes)
embeddings = np.array([G.nodes[n]['embedding'] for n in enodes])
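
# Optional sanity check (a sketch, assuming every node was stored with 'embedding',
# 'text', and 'source' attributes, which the query code below relies on):
assert all('text' in G.nodes[n] and 'source' in G.nodes[n] for n in enodes), \
    "every node needs 'text' and 'source' attributes"
print(f"Loaded {len(enodes)} chunks; embedding matrix shape: {embeddings.shape}")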

def query_graph(question, top_k=5):
    """
    Embed the question, retrieve the top_k relevant chunks,
    and return: (answer, sources, chunks)
    - answer: generated response string
    - sources: list of unique source names
    - chunks: list of tuples (header, score, full_text, source_url_or_path)
    """
    # Embed question
    emb_resp = client.embeddings.create(
        model="text-embedding-3-large",
        input=question
    )
    q_vec = emb_resp.data[0].embedding

    # Compute cosine similarities against all stored chunk embeddings
    sims = cosine_similarity([q_vec], embeddings)[0]
    idxs = sims.argsort()[::-1][:top_k]

    # Collect chunk-level info
    chunks = []
    sources = []
    for i in idxs:
        node = enodes[i]
        text = G.nodes[node]['text']
        # Use the first line as the header
        header = text.split('\n', 1)[0].lstrip('# ').strip()
        score = sims[i]
        # Determine citation (URL for HTML, path for PDF)
        citation = G.nodes[node].get('url') or G.nodes[node].get('path') or G.nodes[node]['source']
        chunks.append((header, score, text, citation))
        sources.append(G.nodes[node]['source'])

    # Deduplicate sources while preserving order
    sources = list(dict.fromkeys(sources))

    # Assemble the prompt from the chunk texts
    context = "\n\n---\n\n".join([c[2] for c in chunks])
    prompt = (
        "Use the following context to answer the question:\n\n" +
        context +
        f"\n\nQuestion: {question}\nAnswer:"
    )

    # Query chat model
    chat_resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety."},
            {"role": "user", "content": prompt}
        ]
    )
    answer = chat_resp.choices[0].message.content

    return answer, sources, chunks
| """ | |
| Embed the user question, retrieve the top_k relevant chunks from the graph, | |
| assemble a prompt with those chunks, call the chat model, and return: | |
| - answer: the generated response | |
| - sources: unique list of source documents | |
| - chunks: list of (header, score, full_text) for the top_k passages | |
| """ | |
| # Embed the question | |
| emb_resp = client.embeddings.create( | |
| model="text-embedding-3-large", | |
| input=question | |
| ) | |
| q_vec = emb_resp.data[0].embedding | |
| # Compute similarities against all stored embeddings | |
| sims = cosine_similarity([q_vec], embeddings)[0] | |
| idxs = sims.argsort()[::-1][:top_k] | |
| # Gather chunk‑level info and sources | |
| chunks = [] | |
| sources = [] | |
| for i in idxs: | |
| node = enodes[i] | |
| text = G.nodes[node]['text'] | |
| # Use the first line as the header | |
| header = text.split('\n', 1)[0].lstrip('# ').strip() | |
| score = sims[i] | |
| chunks.append((header, score, text)) | |
| sources.append(G.nodes[node]['source']) | |
| # Deduplicate sources while preserving order | |
| sources = list(dict.fromkeys(sources)) | |
| # Assemble the prompt from the chunk texts | |
| context_text = "\n\n---\n\n".join([chunk[2] for chunk in chunks]) | |
| prompt = ( | |
| "Use the following context to answer the question:\n\n" | |
| + context_text | |
| + f"\n\nQuestion: {question}\nAnswer:" | |
| ) | |
| # Call the chat model | |
| chat_resp = client.chat.completions.create( | |
| model="gpt-4o-mini", | |
| messages=[ | |
| {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety."}, | |
| {"role": "user", "content": prompt} | |
| ] | |
| ) | |
| answer = chat_resp.choices[0].message.content | |
| return answer, sources, chunks | |

# Test queries
# test_questions = [
#     "What are general machine guarding requirements?",
#     "Explain the key steps in lockout/tagout procedures."
# ]
# for q in test_questions:
#     answer, sources, chunks = query_graph(q)
#     print(f"Q: {q}")
#     print(f"Answer: {answer}\n")
#     print("Sources:")
#     for src in sources:
#         print(f"- {src}")
#     print("\nTop Chunks:")
#     for header, score, _, citation in chunks:
#         print(f"  * {header} (score: {score:.2f}) from {citation}")
#     print("\n", "#"*40, "\n")