"""
Context stuffing query module.
Loads full documents and uses heuristics to select relevant content.
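
Typical usage (a minimal sketch, assuming this module is importable as query_context):

    from query_context import query
    answer, citations = query("What are the general requirements for machine guarding?")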
"""
import pickle
import logging
import re
from typing import List, Tuple, Optional, Dict, Any
from openai import OpenAI
import tiktoken
from config import *
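# Names expected from config (via the star import above): OPENAI_API_KEY, CONTEXT_DOCS,
# MAX_CONTEXT_TOKENS, OPENAI_CHAT_MODEL, DEFAULT_MAX_TOKENS, DEFAULT_TOP_K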
logger = logging.getLogger(__name__)
class ContextStuffingRetriever:
"""Context stuffing with heuristic document selection."""
def __init__(self):
self.client = OpenAI(api_key=OPENAI_API_KEY)
self.encoding = tiktoken.get_encoding("cl100k_base")
self.documents = None
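        # Populated by _load_documents() as a list of dicts (each with at least 'text' and 'metadata')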
self._load_documents()
def _load_documents(self):
"""Load full documents for context stuffing."""
try:
if CONTEXT_DOCS.exists():
logger.info("Loading documents for context stuffing...")
with open(CONTEXT_DOCS, 'rb') as f:
data = pickle.load(f)
if isinstance(data, list) and len(data) > 0:
# Handle both old format (list of chunks) and new format (list of DocumentChunk objects)
if hasattr(data[0], 'text'): # New format with DocumentChunk objects
self.documents = []
for chunk in data:
self.documents.append({
'text': chunk.text,
'metadata': chunk.metadata,
'chunk_id': chunk.chunk_id
})
else: # Old format with dict objects
self.documents = data
logger.info(f"✓ Loaded {len(self.documents)} documents for context stuffing")
else:
logger.warning("No documents found in context stuffing file")
self.documents = []
else:
logger.warning("Context stuffing documents not found. Run preprocess.py first.")
self.documents = []
except Exception as e:
logger.error(f"Error loading context stuffing documents: {e}")
self.documents = []
def _calculate_keyword_score(self, text: str, question: str) -> float:
"""Calculate keyword overlap score between text and question."""
# Simple keyword matching heuristic
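        # Example: for "What is machine guarding?" (4 unique words), a section containing
        # "machine" and "guarding" but not "what" or "is" scores 2 / 4 = 0.5.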
question_words = set(re.findall(r'\w+', question.lower()))
text_words = set(re.findall(r'\w+', text.lower()))
if not question_words:
return 0.0
overlap = len(question_words & text_words)
return overlap / len(question_words)
def _calculate_section_relevance(self, text: str, question: str) -> float:
"""Calculate section relevance using multiple heuristics."""
score = 0.0
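        # Score components (capped at 1.0 below): keyword overlap (0.5), length (0.2),
        # header bonus (0.1), and up to three question-type bonuses (0.2 each).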
# Keyword overlap score (weight: 0.5)
keyword_score = self._calculate_keyword_score(text, question)
score += 0.5 * keyword_score
        # Length score (weight: 0.2): peaks around 200 words, discounting very short or very long sections
        text_length = len(text.split())
        optimal_length = 200  # words
        if text_length < optimal_length:
            length_score = min(1.0, text_length / optimal_length)
        else:
            length_score = max(0.1, optimal_length / text_length)
        score += 0.2 * length_score
# Header/title bonus (if text starts with common header patterns)
if re.match(r'^#+\s|^\d+\.\s|^[A-Z\s]{3,20}:', text.strip()):
score += 0.1
# Question type specific bonuses
question_lower = question.lower()
text_lower = text.lower()
if any(word in question_lower for word in ['what', 'define', 'definition']):
if any(phrase in text_lower for phrase in ['means', 'defined as', 'definition', 'refers to']):
score += 0.2
if any(word in question_lower for word in ['how', 'procedure', 'steps']):
if any(phrase in text_lower for phrase in ['step', 'procedure', 'process', 'method']):
score += 0.2
if any(word in question_lower for word in ['requirement', 'shall', 'must']):
if any(phrase in text_lower for phrase in ['shall', 'must', 'required', 'requirement']):
score += 0.2
return min(1.0, score) # Cap at 1.0
    def select_relevant_documents(self, question: str, max_tokens: Optional[int] = None) -> List[Dict[str, Any]]:
"""Select most relevant documents using heuristics."""
if not self.documents:
return []
if max_tokens is None:
max_tokens = MAX_CONTEXT_TOKENS
# Score all documents
scored_docs = []
for doc in self.documents:
text = doc.get('text', '')
if text.strip():
relevance_score = self._calculate_section_relevance(text, question)
doc_info = {
'text': text,
'metadata': doc.get('metadata', {}),
'score': relevance_score,
'token_count': len(self.encoding.encode(text))
}
scored_docs.append(doc_info)
# Sort by relevance score
scored_docs.sort(key=lambda x: x['score'], reverse=True)
# Select documents within token limit
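        # Greedy fill: walk documents from highest to lowest score, adding whole documents
        # until the token budget is exhausted; the first overflowing document may be
        # truncated to use the remaining budget, after which selection stops.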
selected_docs = []
total_tokens = 0
        for doc in scored_docs:
            if doc['score'] > 0.1:  # Minimum relevance threshold
                if total_tokens + doc['token_count'] <= max_tokens:
                    selected_docs.append(doc)
                    total_tokens += doc['token_count']
                else:
                    # Budget exhausted: try to include a truncated version of this document
                    remaining_tokens = max_tokens - total_tokens
                    if remaining_tokens > 100:  # Only if meaningful content can fit
                        truncated_text = self._truncate_text(doc['text'], remaining_tokens)
                        if truncated_text:
                            doc['text'] = truncated_text
                            doc['token_count'] = len(self.encoding.encode(truncated_text))
                            selected_docs.append(doc)
                            total_tokens += doc['token_count']
                    break
        logger.info(f"Selected {len(selected_docs)} documents with {total_tokens} total tokens")
return selected_docs
def _truncate_text(self, text: str, max_tokens: int) -> str:
"""Truncate text to fit within token limit while preserving meaning."""
tokens = self.encoding.encode(text)
if len(tokens) <= max_tokens:
return text
# Truncate and try to end at a sentence boundary
truncated_tokens = tokens[:max_tokens]
truncated_text = self.encoding.decode(truncated_tokens)
# Try to end at a sentence boundary
sentences = re.split(r'[.!?]+', truncated_text)
if len(sentences) > 1:
# Remove the last incomplete sentence
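            # (Rejoining with '.' normalizes any '!' or '?' sentence boundaries to periods.)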
truncated_text = '.'.join(sentences[:-1]) + '.'
return truncated_text
def generate_answer(self, question: str, context_docs: List[Dict[str, Any]]) -> str:
"""Generate answer using full context stuffing approach."""
if not context_docs:
return "I couldn't find any relevant documents to answer your question."
try:
# Assemble context from selected documents
context_parts = []
sources = []
for i, doc in enumerate(context_docs, 1):
text = doc['text']
metadata = doc['metadata']
source = metadata.get('source', f'Document {i}')
context_parts.append(f"=== {source} ===\n{text}")
if source not in sources:
sources.append(source)
full_context = "\n\n".join(context_parts)
# Create system message for context stuffing
system_message = (
"You are an expert in occupational safety and health regulations. "
"Answer the user's question using the provided regulatory documents and technical materials. "
"Provide comprehensive, accurate answers that directly address the question. "
"Reference specific sections or requirements when applicable. "
"If the provided context doesn't fully answer the question, clearly state what information is missing."
)
# Create user message
user_message = f"""Based on the following regulatory and technical documents, please answer this question:
QUESTION: {question}
DOCUMENTS:
{full_context}
Please provide a thorough answer based on the information in these documents. If any important details are missing from the provided context, please indicate that as well."""
# For GPT-5, temperature must be default (1.0)
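            # Note: max_completion_tokens (rather than the older max_tokens parameter) is used here.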
response = self.client.chat.completions.create(
model=OPENAI_CHAT_MODEL,
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
max_completion_tokens=DEFAULT_MAX_TOKENS
)
answer = response.choices[0].message.content.strip()
# Add source information
if len(sources) > 1:
answer += f"\n\n*Sources consulted: {', '.join(sources)}*"
elif sources:
answer += f"\n\n*Source: {sources[0]}*"
return answer
except Exception as e:
logger.error(f"Error generating context stuffing answer: {e}")
return "I apologize, but I encountered an error while generating the answer using context stuffing."
# Global retriever instance
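# (created lazily by get_retriever() so documents are loaded only once per process)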
_retriever = None
def get_retriever() -> ContextStuffingRetriever:
"""Get or create global context stuffing retriever instance."""
global _retriever
if _retriever is None:
_retriever = ContextStuffingRetriever()
return _retriever
def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict]]:
"""
Main context stuffing query function with unified signature.
Args:
question: User question
image_path: Optional image path (not used in context stuffing but kept for consistency)
top_k: Not used in context stuffing (uses heuristic selection instead)
Returns:
Tuple of (answer, citations)
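        Each citation is a dict with 'rank', 'score', 'source', 'type', 'method',
        and 'tokens_used' keys.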
"""
try:
retriever = get_retriever()
# Select relevant documents using heuristics
relevant_docs = retriever.select_relevant_documents(question)
if not relevant_docs:
return "I couldn't find any relevant documents to answer your question.", []
# Generate comprehensive answer
answer = retriever.generate_answer(question, relevant_docs)
# Prepare citations
citations = []
for i, doc in enumerate(relevant_docs, 1):
metadata = doc['metadata']
citations.append({
'rank': i,
'score': float(doc['score']),
'source': metadata.get('source', 'Unknown'),
'type': metadata.get('type', 'unknown'),
'method': 'context_stuffing',
'tokens_used': doc['token_count']
})
logger.info(f"Context stuffing query completed. Used {len(citations)} documents.")
return answer, citations
except Exception as e:
logger.error(f"Error in context stuffing query: {e}")
error_message = "I apologize, but I encountered an error while processing your question with context stuffing."
return error_message, []
def query_with_details(question: str, image_path: Optional[str] = None,
top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict], List[Tuple]]:
"""
Context stuffing query function that returns detailed chunk information (for compatibility).
Returns:
Tuple of (answer, citations, chunks)
"""
answer, citations = query(question, image_path, top_k)
# Convert citations to chunk format for backward compatibility
chunks = []
for citation in citations:
chunks.append((
f"Document {citation['rank']} (Score: {citation['score']:.3f})",
citation['score'],
f"Context from {citation['source']} ({citation['tokens_used']} tokens)",
citation['source']
))
return answer, citations, chunks
if __name__ == "__main__":
# Test the context stuffing system
test_question = "What are the general requirements for machine guarding?"
print("Testing context stuffing retrieval system...")
print(f"Question: {test_question}")
print("-" * 50)
try:
answer, citations = query(test_question)
print("Answer:")
print(answer)
print(f"\nCitations ({len(citations)} documents used):")
for citation in citations:
print(f"- {citation['source']} (Relevance: {citation['score']:.3f}, Tokens: {citation['tokens_used']})")
except Exception as e:
print(f"Error during testing: {e}")