# sight_chat / query_graph.py
import os
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
# Initialize OpenAI client
load_dotenv(override=True)
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Load graph from GML
G = nx.read_gml("graph.gml")
enodes = list(G.nodes)
embeddings = np.array([G.nodes[n]['embedding'] for n in enodes])
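
# Assumed node schema in graph.gml (inferred from how attributes are read below):
#   'text'       - chunk contents, with a markdown-style header on the first line
#   'embedding'  - text-embedding-3-large vector for that chunk
#   'source'     - originating document name
#   'url'/'path' - optional citation target for HTML / PDF sources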
def query_graph(question, top_k=5):
    """
    Embed the question, retrieve the top_k relevant chunks,
    and return: (answer, sources, chunks)
    - answer: generated response string
    - sources: list of unique source names
    - chunks: list of tuples (header, score, full_text, source_url_or_path)
    """
    # Embed the question
    emb_resp = client.embeddings.create(
        model="text-embedding-3-large",
        input=question
    )
    q_vec = emb_resp.data[0].embedding

    # Compute cosine similarities against all stored node embeddings
    sims = cosine_similarity([q_vec], embeddings)[0]
    idxs = sims.argsort()[::-1][:top_k]

    # Collect chunk-level info
    chunks = []
    sources = []
    for i in idxs:
        node = enodes[i]
        text = G.nodes[node]['text']
        # Use the first line (minus any leading '#') as the chunk header
        header = text.split('\n', 1)[0].lstrip('# ').strip()
        score = sims[i]
        # Determine citation (URL for HTML sources, path for PDFs)
        citation = G.nodes[node].get('url') or G.nodes[node].get('path') or G.nodes[node]['source']
        chunks.append((header, score, text, citation))
        sources.append(G.nodes[node]['source'])

    # Deduplicate sources while preserving order
    sources = list(dict.fromkeys(sources))

    # Assemble the prompt from the retrieved chunk texts
    context = "\n\n---\n\n".join([c[2] for c in chunks])
    prompt = (
        "Use the following context to answer the question:\n\n" +
        context +
        f"\n\nQuestion: {question}\nAnswer:"
    )

    # Query the chat model
    chat_resp = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety."},
            {"role": "user", "content": prompt}
        ]
    )
    answer = chat_resp.choices[0].message.content

    return answer, sources, chunks
"""
Embed the user question, retrieve the top_k relevant chunks from the graph,
assemble a prompt with those chunks, call the chat model, and return:
- answer: the generated response
- sources: unique list of source documents
- chunks: list of (header, score, full_text) for the top_k passages
"""
# Embed the question
emb_resp = client.embeddings.create(
model="text-embedding-3-large",
input=question
)
q_vec = emb_resp.data[0].embedding
# Compute similarities against all stored embeddings
sims = cosine_similarity([q_vec], embeddings)[0]
idxs = sims.argsort()[::-1][:top_k]
# Gather chunk‑level info and sources
chunks = []
sources = []
for i in idxs:
node = enodes[i]
text = G.nodes[node]['text']
# Use the first line as the header
header = text.split('\n', 1)[0].lstrip('# ').strip()
score = sims[i]
chunks.append((header, score, text))
sources.append(G.nodes[node]['source'])
# Deduplicate sources while preserving order
sources = list(dict.fromkeys(sources))
# Assemble the prompt from the chunk texts
context_text = "\n\n---\n\n".join([chunk[2] for chunk in chunks])
prompt = (
"Use the following context to answer the question:\n\n"
+ context_text
+ f"\n\nQuestion: {question}\nAnswer:"
)
# Call the chat model
chat_resp = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety."},
{"role": "user", "content": prompt}
]
)
answer = chat_resp.choices[0].message.content
return answer, sources, chunks
# Test queries
# test_questions = [
# "What are general machine guarding requirements?",
# "Explain the key steps in lockout/tagout procedures."
# ]
# for q in test_questions:
# answer, sources, chunks = query_graph(q)
# print(f"Q: {q}")
# print(f"Answer: {answer}\n")
# print("Sources:")
# for src in sources:
# print(f"- {src}")
# print("\nTop Chunks:")
# for header, score, _, citation in chunks:
# print(f" * {header} (score: {score:.2f}) from {citation}")
# print("\n", "#"*40, "\n")