| """ | |
| Context stuffing query module. | |
| Loads full documents and uses heuristics to select relevant content. | |
| """ | |
| import pickle | |
| import logging | |
| import re | |
| from typing import List, Tuple, Optional, Dict, Any | |
| from openai import OpenAI | |
| import tiktoken | |
| from config import * | |
| logger = logging.getLogger(__name__) | |


class ContextStuffingRetriever:
    """Context stuffing with heuristic document selection."""

    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.encoding = tiktoken.get_encoding("cl100k_base")
        self.documents = None
        self._load_documents()

    def _load_documents(self):
        """Load full documents for context stuffing."""
        try:
            if CONTEXT_DOCS.exists():
                logger.info("Loading documents for context stuffing...")
                with open(CONTEXT_DOCS, 'rb') as f:
                    data = pickle.load(f)
                if isinstance(data, list) and len(data) > 0:
                    # Handle both the old format (list of dicts) and the
                    # new format (list of DocumentChunk objects)
                    if hasattr(data[0], 'text'):  # New format: DocumentChunk objects
                        self.documents = []
                        for chunk in data:
                            self.documents.append({
                                'text': chunk.text,
                                'metadata': chunk.metadata,
                                'chunk_id': chunk.chunk_id
                            })
                    else:  # Old format: dict objects
                        self.documents = data
                    logger.info(f"✓ Loaded {len(self.documents)} documents for context stuffing")
                else:
                    logger.warning("No documents found in context stuffing file")
                    self.documents = []
            else:
                logger.warning("Context stuffing documents not found. Run preprocess.py first.")
                self.documents = []
        except Exception as e:
            logger.error(f"Error loading context stuffing documents: {e}")
            self.documents = []
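
    # Note on the expected on-disk format (inferred from the branches above):
    # the pickle should hold either a list of dicts with 'text', 'metadata',
    # and 'chunk_id' keys, or a list of DocumentChunk objects exposing the
    # same names as attributes (produced by preprocess.py).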

    def _calculate_keyword_score(self, text: str, question: str) -> float:
        """Calculate the keyword-overlap score between text and question."""
        # Simple keyword-matching heuristic
        question_words = set(re.findall(r'\w+', question.lower()))
        text_words = set(re.findall(r'\w+', text.lower()))
        if not question_words:
            return 0.0
        overlap = len(question_words & text_words)
        return overlap / len(question_words)
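
    # Worked example (illustrative, not taken from the source corpus): for the
    # question "What is machine guarding?", question_words is {"what", "is",
    # "machine", "guarding"}. A section containing "machine" and "guarding"
    # but neither of the other two words scores 2 / 4 = 0.5.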

    def _calculate_section_relevance(self, text: str, question: str) -> float:
        """Calculate section relevance using multiple heuristics."""
        score = 0.0

        # Keyword-overlap score (weight: 0.5)
        keyword_score = self._calculate_keyword_score(text, question)
        score += 0.5 * keyword_score

        # Length score (weight: 0.2): favor sections near the optimal length,
        # ramping up to 1.0 at optimal_length and decaying beyond it
        text_length = len(text.split())
        optimal_length = 200  # words
        if text_length < optimal_length:
            length_score = text_length / optimal_length
        else:
            length_score = max(0.1, optimal_length / text_length)
        score += 0.2 * length_score

        # Header/title bonus (if the text starts with a common header pattern)
        if re.match(r'^#+\s|^\d+\.\s|^[A-Z\s]{3,20}:', text.strip()):
            score += 0.1

        # Question-type-specific bonuses
        question_lower = question.lower()
        text_lower = text.lower()
        if any(word in question_lower for word in ['what', 'define', 'definition']):
            if any(phrase in text_lower for phrase in ['means', 'defined as', 'definition', 'refers to']):
                score += 0.2
        if any(word in question_lower for word in ['how', 'procedure', 'steps']):
            if any(phrase in text_lower for phrase in ['step', 'procedure', 'process', 'method']):
                score += 0.2
        if any(word in question_lower for word in ['requirement', 'shall', 'must']):
            if any(phrase in text_lower for phrase in ['shall', 'must', 'required', 'requirement']):
                score += 0.2

        return min(1.0, score)  # Cap at 1.0
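
    # Worked example (illustrative numbers): a 100-word section with keyword
    # overlap 0.5 and no header or question-type bonus scores
    # 0.5 * 0.5 + 0.2 * (100 / 200) = 0.35. The three question-type bonuses
    # can stack, so min(1.0, score) keeps results in [0, 1].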

    def select_relevant_documents(self, question: str, max_tokens: Optional[int] = None) -> List[Dict[str, Any]]:
        """Select the most relevant documents using heuristics."""
        if not self.documents:
            return []
        if max_tokens is None:
            max_tokens = MAX_CONTEXT_TOKENS

        # Score all documents
        scored_docs = []
        for doc in self.documents:
            text = doc.get('text', '')
            if text.strip():
                relevance_score = self._calculate_section_relevance(text, question)
                scored_docs.append({
                    'text': text,
                    'metadata': doc.get('metadata', {}),
                    'score': relevance_score,
                    'token_count': len(self.encoding.encode(text))
                })

        # Sort by relevance score, best first
        scored_docs.sort(key=lambda x: x['score'], reverse=True)

        # Select documents within the token limit
        selected_docs = []
        total_tokens = 0
        for doc in scored_docs:
            if doc['score'] > 0.1:  # Minimum relevance threshold
                if total_tokens + doc['token_count'] <= max_tokens:
                    selected_docs.append(doc)
                    total_tokens += doc['token_count']
                else:
                    # Budget nearly exhausted; try to include a truncated version
                    remaining_tokens = max_tokens - total_tokens
                    if remaining_tokens > 100:  # Only if meaningful content can fit
                        truncated_text = self._truncate_text(doc['text'], remaining_tokens)
                        if truncated_text:
                            doc['text'] = truncated_text
                            doc['token_count'] = len(self.encoding.encode(truncated_text))
                            selected_docs.append(doc)
                            total_tokens += doc['token_count']
                    break

        logger.info(f"Selected {len(selected_docs)} documents with {total_tokens} total tokens")
        return selected_docs
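
    # Minimal usage sketch (assumes the pickle written by preprocess.py exists
    # and config.py defines MAX_CONTEXT_TOKENS; question and budget invented):
    #
    #   retriever = ContextStuffingRetriever()
    #   docs = retriever.select_relevant_documents(
    #       "What are lockout/tagout requirements?", max_tokens=4000)
    #   for d in docs:
    #       print(d['score'], d['token_count'], d['metadata'].get('source'))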

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Truncate text to fit within a token limit while preserving meaning."""
        tokens = self.encoding.encode(text)
        if len(tokens) <= max_tokens:
            return text

        # Truncate, then try to end at a sentence boundary. Note that the
        # rejoin below normalizes '!' and '?' to '.', which is acceptable
        # for context text.
        truncated_tokens = tokens[:max_tokens]
        truncated_text = self.encoding.decode(truncated_tokens)
        sentences = re.split(r'[.!?]+', truncated_text)
        if len(sentences) > 1:
            # Drop the final (likely incomplete) sentence fragment
            truncated_text = '.'.join(sentences[:-1]) + '.'
        return truncated_text
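
    # Example of the boundary trim (illustrative text): decoding the first
    # max_tokens tokens might yield "Guards shall be affixed. Remove only wh".
    # The split gives ["Guards shall be affixed", " Remove only wh"]; the
    # trailing fragment is dropped and "Guards shall be affixed." is returned.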

    def generate_answer(self, question: str, context_docs: List[Dict[str, Any]]) -> str:
        """Generate an answer using the full context stuffing approach."""
        if not context_docs:
            return "I couldn't find any relevant documents to answer your question."
        try:
            # Assemble context from the selected documents
            context_parts = []
            sources = []
            for i, doc in enumerate(context_docs, 1):
                text = doc['text']
                metadata = doc['metadata']
                source = metadata.get('source', f'Document {i}')
                context_parts.append(f"=== {source} ===\n{text}")
                if source not in sources:
                    sources.append(source)
            full_context = "\n\n".join(context_parts)

            # System message for context stuffing
            system_message = (
                "You are an expert in occupational safety and health regulations. "
                "Answer the user's question using the provided regulatory documents and technical materials. "
                "Provide comprehensive, accurate answers that directly address the question. "
                "Reference specific sections or requirements when applicable. "
                "If the provided context doesn't fully answer the question, clearly state what information is missing."
            )

            # User message
            user_message = f"""Based on the following regulatory and technical documents, please answer this question:

QUESTION: {question}

DOCUMENTS:
{full_context}

Please provide a thorough answer based on the information in these documents. If any important details are missing from the provided context, please indicate that as well."""

            # For GPT-5 models, temperature must stay at the default (1.0),
            # so no temperature argument is passed here
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )
            answer = response.choices[0].message.content.strip()

            # Append source information
            if len(sources) > 1:
                answer += f"\n\n*Sources consulted: {', '.join(sources)}*"
            elif sources:
                answer += f"\n\n*Source: {sources[0]}*"
            return answer
        except Exception as e:
            logger.error(f"Error generating context stuffing answer: {e}")
            return "I apologize, but I encountered an error while generating the answer using context stuffing."


# Global retriever instance
_retriever = None


def get_retriever() -> ContextStuffingRetriever:
    """Get or create the global context stuffing retriever instance."""
    global _retriever
    if _retriever is None:
        _retriever = ContextStuffingRetriever()
    return _retriever


def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict]]:
    """
    Main context stuffing query function with the unified signature.

    Args:
        question: User question
        image_path: Optional image path (not used in context stuffing but kept for consistency)
        top_k: Not used in context stuffing (uses heuristic selection instead)

    Returns:
        Tuple of (answer, citations)
    """
    try:
        retriever = get_retriever()

        # Select relevant documents using heuristics
        relevant_docs = retriever.select_relevant_documents(question)
        if not relevant_docs:
            return "I couldn't find any relevant documents to answer your question.", []

        # Generate a comprehensive answer
        answer = retriever.generate_answer(question, relevant_docs)

        # Prepare citations
        citations = []
        for i, doc in enumerate(relevant_docs, 1):
            metadata = doc['metadata']
            citations.append({
                'rank': i,
                'score': float(doc['score']),
                'source': metadata.get('source', 'Unknown'),
                'type': metadata.get('type', 'unknown'),
                'method': 'context_stuffing',
                'tokens_used': doc['token_count']
            })

        logger.info(f"Context stuffing query completed. Used {len(citations)} documents.")
        return answer, citations
    except Exception as e:
        logger.error(f"Error in context stuffing query: {e}")
        error_message = "I apologize, but I encountered an error while processing your question with context stuffing."
        return error_message, []


def query_with_details(question: str, image_path: Optional[str] = None,
                       top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict], List[Tuple]]:
    """
    Context stuffing query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)

    # Convert citations to the chunk format for backward compatibility
    chunks = []
    for citation in citations:
        chunks.append((
            f"Document {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            f"Context from {citation['source']} ({citation['tokens_used']} tokens)",
            citation['source']
        ))
    return answer, citations, chunks
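

# Illustrative chunk tuple produced above (values invented for the example):
#   ("Document 1 (Score: 0.750)", 0.75,
#    "Context from OSHA 1910.212 (1200 tokens)", "OSHA 1910.212")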


if __name__ == "__main__":
    # Test the context stuffing system
    test_question = "What are the general requirements for machine guarding?"
    print("Testing context stuffing retrieval system...")
    print(f"Question: {test_question}")
    print("-" * 50)
    try:
        answer, citations = query(test_question)
        print("Answer:")
        print(answer)
        print(f"\nCitations ({len(citations)} documents used):")
        for citation in citations:
            print(f"- {citation['source']} (Relevance: {citation['score']:.3f}, Tokens: {citation['tokens_used']})")
    except Exception as e:
        print(f"Error during testing: {e}")