File size: 13,964 Bytes
ef821d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
"""

Context stuffing query module.

Loads full documents and uses heuristics to select relevant content.

"""

import pickle
import logging
import re
from typing import List, Tuple, Optional, Dict, Any
from openai import OpenAI
import tiktoken

from config import *

logger = logging.getLogger(__name__)

class ContextStuffingRetriever:
    """Context stuffing with heuristic document selection.

    Loads every preprocessed document chunk into memory, scores each one
    against a question using cheap lexical heuristics (keyword overlap,
    length, header and question-type cues), then packs the best-scoring
    chunks into a single LLM prompt up to a token budget.
    """

    def __init__(self):
        # Chat client for answer generation; tokenizer for token budgeting.
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.encoding = tiktoken.get_encoding("cl100k_base")
        # List of {'text', 'metadata', 'chunk_id'} dicts; never None after
        # _load_documents() returns (it falls back to [] on any failure).
        self.documents = None
        self._load_documents()

    def _load_documents(self):
        """Load full documents for context stuffing.

        Populates ``self.documents`` from the pickled ``CONTEXT_DOCS`` file,
        accepting either the legacy list-of-dicts format or the newer list
        of DocumentChunk objects (normalized to dicts). Degrades to an
        empty list on a missing file or any load error.
        """
        try:
            if CONTEXT_DOCS.exists():
                logger.info("Loading documents for context stuffing...")

                # NOTE(review): pickle.load assumes CONTEXT_DOCS comes from
                # our own trusted preprocessing step; never point it at
                # untrusted data.
                with open(CONTEXT_DOCS, 'rb') as f:
                    data = pickle.load(f)

                if isinstance(data, list) and len(data) > 0:
                    if hasattr(data[0], 'text'):
                        # New format: DocumentChunk objects -> plain dicts.
                        self.documents = [
                            {
                                'text': chunk.text,
                                'metadata': chunk.metadata,
                                'chunk_id': chunk.chunk_id,
                            }
                            for chunk in data
                        ]
                    else:
                        # Old format: already a list of dicts.
                        self.documents = data

                    logger.info(f"✓ Loaded {len(self.documents)} documents for context stuffing")
                else:
                    logger.warning("No documents found in context stuffing file")
                    self.documents = []
            else:
                logger.warning("Context stuffing documents not found. Run preprocess.py first.")
                self.documents = []

        except Exception as e:
            # Deliberate best-effort: a corrupt file means "no documents",
            # not a crash at import/construction time.
            logger.error(f"Error loading context stuffing documents: {e}")
            self.documents = []

    def _calculate_keyword_score(self, text: str, question: str) -> float:
        """Return the fraction of question words that also appear in *text*.

        Case-insensitive bag-of-words overlap in [0.0, 1.0]; returns 0.0
        when the question contains no word characters.
        """
        question_words = set(re.findall(r'\w+', question.lower()))
        text_words = set(re.findall(r'\w+', text.lower()))

        if not question_words:
            return 0.0

        overlap = len(question_words & text_words)
        return overlap / len(question_words)

    def _calculate_section_relevance(self, text: str, question: str) -> float:
        """Score *text*'s relevance to *question* in [0.0, 1.0].

        Combines keyword overlap (weight 0.5), a length preference peaking
        around 200 words (weight 0.2), a small bonus for header-like
        openings, and bonuses when the question type (definition / how-to /
        requirement) matches phrasing found in the text.
        """
        score = 0.0

        # Keyword overlap score (weight: 0.5)
        keyword_score = self._calculate_keyword_score(text, question)
        score += 0.5 * keyword_score

        # Length preference: ramp up to the optimum, decay past it.
        text_length = len(text.split())
        optimal_length = 200  # words
        if text_length < optimal_length:
            length_score = min(1.0, text_length / optimal_length)
        else:
            length_score = max(0.1, optimal_length / text_length)
        score += 0.2 * length_score

        # Header/title bonus: markdown '#', '1. ', or 'ALL CAPS:' openings.
        if re.match(r'^#+\s|^\d+\.\s|^[A-Z\s]{3,20}:', text.strip()):
            score += 0.1

        # Question-type bonuses. NOTE(review): these are substring checks,
        # so e.g. 'how' also matches inside 'show' — kept as-is since the
        # heuristic is intentionally loose.
        question_lower = question.lower()
        text_lower = text.lower()

        if any(word in question_lower for word in ['what', 'define', 'definition']):
            if any(phrase in text_lower for phrase in ['means', 'defined as', 'definition', 'refers to']):
                score += 0.2

        if any(word in question_lower for word in ['how', 'procedure', 'steps']):
            if any(phrase in text_lower for phrase in ['step', 'procedure', 'process', 'method']):
                score += 0.2

        if any(word in question_lower for word in ['requirement', 'shall', 'must']):
            if any(phrase in text_lower for phrase in ['shall', 'must', 'required', 'requirement']):
                score += 0.2

        return min(1.0, score)  # Cap at 1.0

    def select_relevant_documents(self, question: str, max_tokens: Optional[int] = None) -> List[Dict[str, Any]]:
        """Select the most relevant documents for *question* within a token budget.

        Args:
            question: User question used for heuristic scoring.
            max_tokens: Token budget for the combined selection; defaults
                to ``MAX_CONTEXT_TOKENS`` when None.

        Returns:
            Relevance-ordered list of dicts with keys 'text', 'metadata',
            'score', and 'token_count'. Empty when no documents are loaded
            or none clear the relevance threshold.
        """
        if not self.documents:
            return []

        if max_tokens is None:
            max_tokens = MAX_CONTEXT_TOKENS

        # Score every non-empty document once.
        scored_docs = []
        for doc in self.documents:
            text = doc.get('text', '')
            if text.strip():
                scored_docs.append({
                    'text': text,
                    'metadata': doc.get('metadata', {}),
                    'score': self._calculate_section_relevance(text, question),
                    'token_count': len(self.encoding.encode(text)),
                })

        # Greedy selection: highest-scoring docs first, until budget spent.
        scored_docs.sort(key=lambda x: x['score'], reverse=True)

        selected_docs = []
        total_tokens = 0

        for doc in scored_docs:
            # Docs are sorted by score, so everything after the first
            # sub-threshold doc is also sub-threshold: stop early.
            if doc['score'] <= 0.1:
                break
            if total_tokens + doc['token_count'] <= max_tokens:
                selected_docs.append(doc)
                total_tokens += doc['token_count']
            else:
                # Budget exceeded: try to squeeze in a truncated version.
                remaining_tokens = max_tokens - total_tokens
                if remaining_tokens > 100:  # Only if meaningful content can fit
                    truncated_text = self._truncate_text(doc['text'], remaining_tokens)
                    if truncated_text:
                        doc['text'] = truncated_text
                        doc['token_count'] = len(self.encoding.encode(truncated_text))
                        selected_docs.append(doc)
                        # BUGFIX: count the truncated doc's tokens so the
                        # log line below reports the true total.
                        total_tokens += doc['token_count']
                break

        logger.info(f"Selected {len(selected_docs)} documents with {total_tokens} total tokens")
        return selected_docs

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Truncate *text* to at most *max_tokens* tokens, preferring a sentence boundary."""
        tokens = self.encoding.encode(text)
        if len(tokens) <= max_tokens:
            return text

        # Hard-truncate at the token limit, then trim back to the last
        # complete sentence if one exists.
        truncated_tokens = tokens[:max_tokens]
        truncated_text = self.encoding.decode(truncated_tokens)

        # NOTE(review): rejoining with '.' flattens '!'/'?' boundaries to
        # periods; acceptable for context-stuffing purposes.
        sentences = re.split(r'[.!?]+', truncated_text)
        if len(sentences) > 1:
            # Remove the last incomplete sentence
            truncated_text = '.'.join(sentences[:-1]) + '.'

        return truncated_text

    def generate_answer(self, question: str, context_docs: List[Dict[str, Any]]) -> str:
        """Generate an answer to *question* from the pre-selected *context_docs*.

        Stuffs all selected documents into one chat prompt, queries the
        model, and appends a source footer. Returns a user-facing error
        string instead of raising on API failure.
        """
        if not context_docs:
            return "I couldn't find any relevant documents to answer your question."

        try:
            # Assemble the stuffed context, tracking unique sources for the footer.
            context_parts = []
            sources = []

            for i, doc in enumerate(context_docs, 1):
                text = doc['text']
                metadata = doc['metadata']
                source = metadata.get('source', f'Document {i}')

                context_parts.append(f"=== {source} ===\n{text}")
                if source not in sources:
                    sources.append(source)

            full_context = "\n\n".join(context_parts)

            # Create system message for context stuffing
            system_message = (
                "You are an expert in occupational safety and health regulations. "
                "Answer the user's question using the provided regulatory documents and technical materials. "
                "Provide comprehensive, accurate answers that directly address the question. "
                "Reference specific sections or requirements when applicable. "
                "If the provided context doesn't fully answer the question, clearly state what information is missing."
            )

            # Create user message
            user_message = f"""Based on the following regulatory and technical documents, please answer this question:



QUESTION: {question}



DOCUMENTS:

{full_context}



Please provide a thorough answer based on the information in these documents. If any important details are missing from the provided context, please indicate that as well."""

            # For GPT-5, temperature must be default (1.0)
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )

            # BUGFIX: message.content is Optional in the SDK; guard against
            # None so we don't mask it behind the generic error path.
            answer = (response.choices[0].message.content or "").strip()

            # Add source information
            if len(sources) > 1:
                answer += f"\n\n*Sources consulted: {', '.join(sources)}*"
            elif sources:
                answer += f"\n\n*Source: {sources[0]}*"

            return answer

        except Exception as e:
            logger.error(f"Error generating context stuffing answer: {e}")
            return "I apologize, but I encountered an error while generating the answer using context stuffing."

# Lazily-created module-level singleton shared by all query calls.
_retriever = None

def get_retriever() -> ContextStuffingRetriever:
    """Return the shared ContextStuffingRetriever, building it on first use.

    Construction loads all documents from disk, so it is deferred until
    the first query rather than done at import time.
    """
    global _retriever
    if _retriever is None:
        _retriever = ContextStuffingRetriever()
    return _retriever

def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict]]:
    """Answer *question* via context stuffing; unified retriever signature.

    Args:
        question: User question.
        image_path: Accepted for signature parity with other retrievers;
            unused by context stuffing.
        top_k: Unused; selection is driven by heuristic scoring and the
            token budget instead.

    Returns:
        Tuple of (answer, citations), where citations is a list of dicts
        describing each document that contributed to the answer.
    """
    try:
        retriever = get_retriever()

        # Heuristic selection, then a single stuffed-context generation.
        relevant_docs = retriever.select_relevant_documents(question)
        if not relevant_docs:
            return "I couldn't find any relevant documents to answer your question.", []

        answer = retriever.generate_answer(question, relevant_docs)

        citations = [
            {
                'rank': rank,
                'score': float(doc['score']),
                'source': doc['metadata'].get('source', 'Unknown'),
                'type': doc['metadata'].get('type', 'unknown'),
                'method': 'context_stuffing',
                'tokens_used': doc['token_count'],
            }
            for rank, doc in enumerate(relevant_docs, 1)
        ]

        logger.info(f"Context stuffing query completed. Used {len(citations)} documents.")
        return answer, citations

    except Exception as e:
        logger.error(f"Error in context stuffing query: {e}")
        return "I apologize, but I encountered an error while processing your question with context stuffing.", []

def query_with_details(question: str, image_path: Optional[str] = None,
                      top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict], List[Tuple]]:
    """Context-stuffing query that also returns chunk tuples for compatibility.

    Delegates to :func:`query`, then reshapes each citation into the
    legacy 4-tuple chunk format expected by older callers.

    Returns:
        Tuple of (answer, citations, chunks).
    """
    answer, citations = query(question, image_path, top_k)

    chunks = [
        (
            f"Document {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            f"Context from {citation['source']} ({citation['tokens_used']} tokens)",
            citation['source'],
        )
        for citation in citations
    ]

    return answer, citations, chunks

if __name__ == "__main__":
    # Test the context stuffing system
    test_question = "What are the general requirements for machine guarding?"
    
    print("Testing context stuffing retrieval system...")
    print(f"Question: {test_question}")
    print("-" * 50)
    
    try:
        answer, citations = query(test_question)
        
        print("Answer:")
        print(answer)
        print(f"\nCitations ({len(citations)} documents used):")
        for citation in citations:
            print(f"- {citation['source']} (Relevance: {citation['score']:.3f}, Tokens: {citation['tokens_used']})")
            
    except Exception as e:
        print(f"Error during testing: {e}")