fmegahed committed on
Commit ef821d9 · verified · 1 Parent(s): 94d4926

version 2.0.0

Files changed (14)
  1. Dockerfile +67 -26
  2. analytics_db.py +509 -0
  3. app.py +1298 -47
  4. config.py +217 -0
  5. preprocess.py +187 -77
  6. query_bm25.py +336 -0
  7. query_context.py +334 -0
  8. query_dpr.py +279 -0
  9. query_graph.py +372 -99
  10. query_vanilla.py +197 -0
  11. query_vision.py +393 -0
  12. realtime_server.py +402 -0
  13. requirements.txt +63 -10
  14. utils.py +679 -0
Dockerfile CHANGED
@@ -1,26 +1,67 @@
- FROM python:3.12.11-slim
-
- # 1) Create and switch to the app directory
- WORKDIR /app
-
- # 2) Install system dependencies
- RUN apt-get update && \
- apt-get install -y build-essential curl git && \
- rm -rf /var/lib/apt/lists/*
-
- # 3) Copy & install Python dependencies
- COPY requirements.txt .
- RUN pip install --no-cache-dir -r requirements.txt
-
- # 4) Copy all your code and data into the container
- COPY . .
-
- # 5) Expose Streamlit’s default port
- EXPOSE 8501
-
- # 6) Healthcheck for Streamlit
- HEALTHCHECK --interval=30s --timeout=5s \
- CMD curl --fail http://localhost:8501/_stcore/health || exit 1
-
- # 7) Launch your app.py at root
- ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+ FROM python:3.12.11-slim
+
+ # Set environment variables for Python
+ ENV PYTHONUNBUFFERED=1
+ ENV PYTHONDONTWRITEBYTECODE=1
+
+ # Create app directory and switch to it
+ WORKDIR /app
+
+ # Install system dependencies required for your packages
+ RUN apt-get update && \
+ apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ libgomp1 \
+ supervisor \
+ && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements first for better Docker layer caching
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir --upgrade pip && \
+ pip install --no-cache-dir -r requirements.txt
+
+ # Copy the entire application
+ COPY . .
+
+ # Create necessary directories if they don't exist
+ RUN mkdir -p /app/data /app/embeddings /app/graph /app/metadata /var/log/supervisor
+
+ # Create supervisor configuration
+ COPY <<EOF /etc/supervisor/conf.d/supervisord.conf
+ [supervisord]
+ nodaemon=true
+ logfile=/var/log/supervisor/supervisord.log
+ pidfile=/var/run/supervisord.pid
+
+ [program:realtime_server]
+ command=python realtime_server.py --port=7861 --host=0.0.0.0
+ directory=/app
+ autostart=true
+ autorestart=true
+ stderr_logfile=/var/log/supervisor/realtime_server.err.log
+ stdout_logfile=/var/log/supervisor/realtime_server.out.log
+ priority=100
+
+ [program:streamlit]
+ command=streamlit run app.py --server.port=7860 --server.address=0.0.0.0 --server.enableXsrfProtection=false --server.enableCORS=false
+ directory=/app
+ autostart=true
+ autorestart=true
+ stderr_logfile=/var/log/supervisor/streamlit.err.log
+ stdout_logfile=/var/log/supervisor/streamlit.out.log
+ priority=200
+ EOF
+
+ # Expose both ports (Streamlit on 7860, Realtime API on 7861)
+ EXPOSE 7860 7861
+
+ # Health check for both services
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+ CMD curl --fail http://localhost:7860/_stcore/health && curl --fail http://localhost:7861/health || exit 1
+
+ # Use supervisor to run both services
+ CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
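The HEALTHCHECK above gates container health on both services responding. A minimal sketch of the same probe run from the host, assuming the container is started with ports 7860 and 7861 published on localhost (the `requests` package is already listed in requirements.txt; the helper name is illustrative):

```python
# Hypothetical host-side health probe mirroring the Dockerfile HEALTHCHECK.
# Assumes the container was started with: -p 7860:7860 -p 7861:7861
import requests

def container_healthy(timeout: float = 5.0) -> bool:
    """Return True only if both the Streamlit UI and the realtime API answer."""
    endpoints = [
        "http://localhost:7860/_stcore/health",  # Streamlit health endpoint
        "http://localhost:7861/health",          # realtime_server.py health endpoint
    ]
    try:
        return all(requests.get(url, timeout=timeout).status_code == 200
                   for url in endpoints)
    except requests.RequestException:
        return False

if __name__ == "__main__":
    print("healthy" if container_healthy() else "unhealthy")
```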
analytics_db.py ADDED
@@ -0,0 +1,509 @@
1
+ """
2
+ Analytics Database Module for Query Logging and Performance Tracking.
3
+ Tracks every query, method, answer, and citation for comprehensive analytics.
4
+ """
5
+
6
+ import sqlite3
7
+ import json
8
+ import time
9
+ from datetime import datetime, timedelta
10
+ from typing import List, Dict, Any, Optional, Tuple
11
+ from pathlib import Path
12
+ import logging
13
+
14
+ from config import DATA_DIR
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Database path
19
+ ANALYTICS_DB = DATA_DIR / "analytics.db"
20
+
21
+ class AnalyticsDB:
22
+ """Database manager for query analytics and logging."""
23
+
24
+ def __init__(self):
25
+ self.db_path = ANALYTICS_DB
26
+ self._init_database()
27
+
28
+ def _init_database(self):
29
+ """Initialize analytics database with required tables."""
30
+ conn = sqlite3.connect(self.db_path)
31
+ cursor = conn.cursor()
32
+
33
+ # Main queries table
34
+ cursor.execute('''
35
+ CREATE TABLE IF NOT EXISTS queries (
36
+ query_id INTEGER PRIMARY KEY AUTOINCREMENT,
37
+ timestamp TEXT NOT NULL,
38
+ user_query TEXT NOT NULL,
39
+ retrieval_method TEXT NOT NULL,
40
+ answer TEXT NOT NULL,
41
+ response_time_ms REAL,
42
+ num_citations INTEGER DEFAULT 0,
43
+ image_path TEXT,
44
+ error_message TEXT,
45
+ top_k_used INTEGER DEFAULT 5,
46
+ additional_settings TEXT,
47
+ answer_length INTEGER,
48
+ session_id TEXT
49
+ )
50
+ ''')
51
+
52
+ # Citations table
53
+ cursor.execute('''
54
+ CREATE TABLE IF NOT EXISTS citations (
55
+ citation_id INTEGER PRIMARY KEY AUTOINCREMENT,
56
+ query_id INTEGER NOT NULL,
57
+ source TEXT NOT NULL,
58
+ citation_type TEXT,
59
+ relevance_score REAL,
60
+ bm25_score REAL,
61
+ rerank_score REAL,
62
+ similarity_score REAL,
63
+ url TEXT,
64
+ path TEXT,
65
+ rank INTEGER,
66
+ FOREIGN KEY (query_id) REFERENCES queries (query_id)
67
+ )
68
+ ''')
69
+
70
+ # Performance metrics table
71
+ cursor.execute('''
72
+ CREATE TABLE IF NOT EXISTS performance_metrics (
73
+ metric_id INTEGER PRIMARY KEY AUTOINCREMENT,
74
+ query_id INTEGER NOT NULL,
75
+ retrieval_time_ms REAL,
76
+ generation_time_ms REAL,
77
+ total_time_ms REAL,
78
+ chunks_retrieved INTEGER,
79
+ tokens_estimated INTEGER,
80
+ FOREIGN KEY (query_id) REFERENCES queries (query_id)
81
+ )
82
+ ''')
83
+
84
+ conn.commit()
85
+ conn.close()
86
+ logger.info("Analytics database initialized")
87
+
88
+ def log_query(self, user_query: str, method: str, answer: str,
89
+ citations: List[Dict], response_time: float = None,
90
+ image_path: str = None, error_message: str = None,
91
+ top_k: int = 5, additional_settings: Dict = None,
92
+ session_id: str = None) -> int:
93
+ """
94
+ Log a complete query interaction.
95
+
96
+ Args:
97
+ user_query: The user's question
98
+ method: Retrieval method used
99
+ answer: Generated answer
100
+ citations: List of citation dictionaries
101
+ response_time: Time taken in milliseconds
102
+ image_path: Path to uploaded image (if any)
103
+ error_message: Error message (if any)
104
+ top_k: Number of chunks retrieved
105
+ additional_settings: Method-specific settings
106
+ session_id: Session identifier
107
+
108
+ Returns:
109
+ query_id: The ID of the logged query
110
+ """
111
+ conn = sqlite3.connect(self.db_path)
112
+ cursor = conn.cursor()
113
+
114
+ try:
115
+ # Insert main query record
116
+ cursor.execute('''
117
+ INSERT INTO queries (
118
+ timestamp, user_query, retrieval_method, answer,
119
+ response_time_ms, num_citations, image_path, error_message,
120
+ top_k_used, additional_settings, answer_length, session_id
121
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
122
+ ''', (
123
+ datetime.now().isoformat(),
124
+ user_query,
125
+ method,
126
+ answer,
127
+ response_time,
128
+ len(citations),
129
+ image_path,
130
+ error_message,
131
+ top_k,
132
+ json.dumps(additional_settings) if additional_settings else None,
133
+ len(answer),
134
+ session_id
135
+ ))
136
+
137
+ query_id = cursor.lastrowid
138
+
139
+ # Insert citations
140
+ for rank, citation in enumerate(citations, 1):
141
+ cursor.execute('''
142
+ INSERT INTO citations (
143
+ query_id, source, citation_type, relevance_score,
144
+ bm25_score, rerank_score, similarity_score, url, path, rank
145
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
146
+ ''', (
147
+ query_id,
148
+ citation.get('source', ''),
149
+ citation.get('type', ''),
150
+ citation.get('relevance_score'),
151
+ citation.get('bm25_score'),
152
+ citation.get('rerank_score'),
153
+ citation.get('similarity_score'),
154
+ citation.get('url'),
155
+ citation.get('path'),
156
+ rank
157
+ ))
158
+
159
+ conn.commit()
160
+ logger.info(f"Logged query {query_id} with {len(citations)} citations")
161
+ return query_id
162
+
163
+ except Exception as e:
164
+ logger.error(f"Error logging query: {e}")
165
+ conn.rollback()
166
+ return None
167
+ finally:
168
+ conn.close()
169
+
170
+ def get_query_stats(self, days: int = 30) -> Dict[str, Any]:
171
+ """Get comprehensive query statistics."""
172
+ conn = sqlite3.connect(self.db_path)
173
+ cursor = conn.cursor()
174
+
175
+ since_date = (datetime.now() - timedelta(days=days)).isoformat()
176
+
177
+ try:
178
+ stats = {}
179
+
180
+ # Total queries
181
+ cursor.execute('''
182
+ SELECT COUNT(*) FROM queries
183
+ WHERE timestamp >= ?
184
+ ''', (since_date,))
185
+ stats['total_queries'] = cursor.fetchone()[0]
186
+
187
+ # Method usage
188
+ cursor.execute('''
189
+ SELECT retrieval_method, COUNT(*) as count
190
+ FROM queries
191
+ WHERE timestamp >= ?
192
+ GROUP BY retrieval_method
193
+ ORDER BY count DESC
194
+ ''', (since_date,))
195
+ stats['method_usage'] = dict(cursor.fetchall())
196
+
197
+ # Average response times by method
198
+ cursor.execute('''
199
+ SELECT retrieval_method, AVG(response_time_ms) as avg_time
200
+ FROM queries
201
+ WHERE timestamp >= ? AND response_time_ms IS NOT NULL
202
+ GROUP BY retrieval_method
203
+ ''', (since_date,))
204
+ stats['avg_response_times'] = dict(cursor.fetchall())
205
+
206
+ # Citation statistics
207
+ cursor.execute('''
208
+ SELECT AVG(num_citations) as avg_citations,
209
+ SUM(num_citations) as total_citations
210
+ FROM queries
211
+ WHERE timestamp >= ?
212
+ ''', (since_date,))
213
+ result = cursor.fetchone()
214
+ stats['avg_citations'] = result[0] or 0
215
+ stats['total_citations'] = result[1] or 0
216
+
217
+ # Citation types
218
+ cursor.execute('''
219
+ SELECT c.citation_type, COUNT(*) as count
220
+ FROM citations c
221
+ JOIN queries q ON c.query_id = q.query_id
222
+ WHERE q.timestamp >= ?
223
+ GROUP BY c.citation_type
224
+ ORDER BY count DESC
225
+ ''', (since_date,))
226
+ stats['citation_types'] = dict(cursor.fetchall())
227
+
228
+ # Error rate
229
+ cursor.execute('''
230
+ SELECT
231
+ COUNT(CASE WHEN error_message IS NOT NULL THEN 1 END) as errors,
232
+ COUNT(*) as total
233
+ FROM queries
234
+ WHERE timestamp >= ?
235
+ ''', (since_date,))
236
+ result = cursor.fetchone()
237
+ stats['error_rate'] = (result[0] / result[1]) * 100 if result[1] > 0 else 0
238
+
239
+ # Most common query topics (simple word analysis)
240
+ cursor.execute('''
241
+ SELECT user_query FROM queries
242
+ WHERE timestamp >= ?
243
+ ''', (since_date,))
244
+ queries = [row[0].lower() for row in cursor.fetchall()]
245
+
246
+ # Simple keyword extraction
247
+ keywords = {}
248
+ for query in queries:
249
+ words = [word for word in query.split() if len(word) > 3]
250
+ for word in words:
251
+ keywords[word] = keywords.get(word, 0) + 1
252
+
253
+ # Top 10 keywords
254
+ stats['top_keywords'] = dict(sorted(keywords.items(),
255
+ key=lambda x: x[1],
256
+ reverse=True)[:10])
257
+
258
+ return stats
259
+
260
+ except Exception as e:
261
+ logger.error(f"Error getting query stats: {e}")
262
+ return {}
263
+ finally:
264
+ conn.close()
265
+
266
+ def get_method_performance(self) -> Dict[str, Dict[str, float]]:
267
+ """Get detailed performance metrics by method."""
268
+ conn = sqlite3.connect(self.db_path)
269
+ cursor = conn.cursor()
270
+
271
+ try:
272
+ cursor.execute('''
273
+ SELECT
274
+ retrieval_method,
275
+ AVG(response_time_ms) as avg_response_time,
276
+ AVG(num_citations) as avg_citations,
277
+ AVG(answer_length) as avg_answer_length,
278
+ COUNT(*) as query_count
279
+ FROM queries
280
+ WHERE response_time_ms IS NOT NULL
281
+ GROUP BY retrieval_method
282
+ ''')
283
+
284
+ results = {}
285
+ for row in cursor.fetchall():
286
+ method, avg_time, avg_cites, avg_length, count = row
287
+ results[method] = {
288
+ 'avg_response_time': avg_time,
289
+ 'avg_citations': avg_cites,
290
+ 'avg_answer_length': avg_length,
291
+ 'query_count': count
292
+ }
293
+
294
+ return results
295
+
296
+ except Exception as e:
297
+ logger.error(f"Error getting method performance: {e}")
298
+ return {}
299
+ finally:
300
+ conn.close()
301
+
302
+ def get_recent_queries(self, limit: int = 20, include_answers: bool = True) -> List[Dict[str, Any]]:
303
+ """Get recent queries with basic information and optionally full answers."""
304
+ conn = sqlite3.connect(self.db_path)
305
+ cursor = conn.cursor()
306
+
307
+ try:
308
+ if include_answers:
309
+ cursor.execute('''
310
+ SELECT query_id, timestamp, user_query, retrieval_method,
311
+ answer, answer_length, num_citations, response_time_ms, error_message
312
+ FROM queries
313
+ ORDER BY timestamp DESC
314
+ LIMIT ?
315
+ ''', (limit,))
316
+
317
+ columns = ['query_id', 'timestamp', 'query', 'method',
318
+ 'answer', 'answer_length', 'citations', 'response_time', 'error_message']
319
+ else:
320
+ cursor.execute('''
321
+ SELECT query_id, timestamp, user_query, retrieval_method,
322
+ answer_length, num_citations, response_time_ms
323
+ FROM queries
324
+ ORDER BY timestamp DESC
325
+ LIMIT ?
326
+ ''', (limit,))
327
+
328
+ columns = ['query_id', 'timestamp', 'query', 'method',
329
+ 'answer_length', 'citations', 'response_time']
330
+
331
+ return [dict(zip(columns, row)) for row in cursor.fetchall()]
332
+
333
+ except Exception as e:
334
+ logger.error(f"Error getting recent queries: {e}")
335
+ return []
336
+ finally:
337
+ conn.close()
338
+
339
+ def get_query_with_citations(self, query_id: int) -> Dict[str, Any]:
340
+ """Get full query details including citations."""
341
+ conn = sqlite3.connect(self.db_path)
342
+ cursor = conn.cursor()
343
+
344
+ try:
345
+ # Get query details
346
+ cursor.execute('''
347
+ SELECT query_id, timestamp, user_query, retrieval_method, answer,
348
+ response_time_ms, num_citations, error_message, top_k_used
349
+ FROM queries WHERE query_id = ?
350
+ ''', (query_id,))
351
+
352
+ query_row = cursor.fetchone()
353
+ if not query_row:
354
+ return {}
355
+
356
+ query_data = {
357
+ 'query_id': query_row[0],
358
+ 'timestamp': query_row[1],
359
+ 'user_query': query_row[2],
360
+ 'method': query_row[3],
361
+ 'answer': query_row[4],
362
+ 'response_time': query_row[5],
363
+ 'num_citations': query_row[6],
364
+ 'error_message': query_row[7],
365
+ 'top_k_used': query_row[8]
366
+ }
367
+
368
+ # Get citations
369
+ cursor.execute('''
370
+ SELECT source, citation_type, relevance_score, bm25_score,
371
+ rerank_score, similarity_score, url, path, rank
372
+ FROM citations WHERE query_id = ?
373
+ ORDER BY rank
374
+ ''', (query_id,))
375
+
376
+ citations = []
377
+ for row in cursor.fetchall():
378
+ citation = {
379
+ 'source': row[0],
380
+ 'type': row[1],
381
+ 'relevance_score': row[2],
382
+ 'bm25_score': row[3],
383
+ 'rerank_score': row[4],
384
+ 'similarity_score': row[5],
385
+ 'url': row[6],
386
+ 'path': row[7],
387
+ 'rank': row[8]
388
+ }
389
+ citations.append(citation)
390
+
391
+ query_data['citations'] = citations
392
+ return query_data
393
+
394
+ except Exception as e:
395
+ logger.error(f"Error getting query with citations: {e}")
396
+ return {}
397
+ finally:
398
+ conn.close()
399
+
400
+ def get_query_trends(self, days: int = 30) -> Dict[str, List[Tuple[str, int]]]:
401
+ """Get query trends over time."""
402
+ conn = sqlite3.connect(self.db_path)
403
+ cursor = conn.cursor()
404
+
405
+ since_date = (datetime.now() - timedelta(days=days)).isoformat()
406
+
407
+ try:
408
+ # Queries per day
409
+ cursor.execute('''
410
+ SELECT DATE(timestamp) as date, COUNT(*) as count
411
+ FROM queries
412
+ WHERE timestamp >= ?
413
+ GROUP BY DATE(timestamp)
414
+ ORDER BY date
415
+ ''', (since_date,))
416
+
417
+ daily_queries = cursor.fetchall()
418
+
419
+ # Method usage trends
420
+ cursor.execute('''
421
+ SELECT DATE(timestamp) as date, retrieval_method, COUNT(*) as count
422
+ FROM queries
423
+ WHERE timestamp >= ?
424
+ GROUP BY DATE(timestamp), retrieval_method
425
+ ORDER BY date, retrieval_method
426
+ ''', (since_date,))
427
+
428
+ method_trends = {}
429
+ for date, method, count in cursor.fetchall():
430
+ if method not in method_trends:
431
+ method_trends[method] = []
432
+ method_trends[method].append((date, count))
433
+
434
+ return {
435
+ 'daily_queries': daily_queries,
436
+ 'method_trends': method_trends
437
+ }
438
+
439
+ except Exception as e:
440
+ logger.error(f"Error getting query trends: {e}")
441
+ return {}
442
+ finally:
443
+ conn.close()
444
+
445
+ def get_voice_interaction_stats(self) -> Dict[str, Any]:
446
+ """Get statistics about voice interactions."""
447
+ try:
448
+ conn = sqlite3.connect(self.db_path)
449
+ cursor = conn.cursor()
450
+
451
+ # Count voice interactions (those with voice_interaction=true in additional_settings)
452
+ cursor.execute('''
453
+ SELECT COUNT(*) as total_voice_queries
454
+ FROM queries
455
+ WHERE additional_settings LIKE '%voice_interaction%'
456
+ OR session_id LIKE 'voice_%'
457
+ ''')
458
+ result = cursor.fetchone()
459
+ total_voice = result[0] if result else 0
460
+
461
+ # Get voice queries by method
462
+ cursor.execute('''
463
+ SELECT retrieval_method, COUNT(*) as count
464
+ FROM queries
465
+ WHERE additional_settings LIKE '%voice_interaction%'
466
+ OR session_id LIKE 'voice_%'
467
+ GROUP BY retrieval_method
468
+ ''')
469
+ voice_by_method = dict(cursor.fetchall())
470
+
471
+ # Average response time for voice queries
472
+ cursor.execute('''
473
+ SELECT AVG(response_time_ms) as avg_response_time
474
+ FROM queries
475
+ WHERE (additional_settings LIKE '%voice_interaction%'
476
+ OR session_id LIKE 'voice_%')
477
+ AND response_time_ms IS NOT NULL
478
+ ''')
479
+ result = cursor.fetchone()
480
+ avg_response_time = result[0] if result and result[0] else 0
481
+
482
+ return {
483
+ 'total_voice_queries': total_voice,
484
+ 'voice_by_method': voice_by_method,
485
+ 'avg_voice_response_time': avg_response_time
486
+ }
487
+
488
+ except Exception as e:
489
+ logger.error(f"Error getting voice interaction stats: {e}")
490
+ return {}
491
+ finally:
492
+ conn.close()
493
+
494
+ # Global instance
495
+ analytics_db = AnalyticsDB()
496
+
497
+ # Convenience functions
498
+ def log_query(user_query: str, method: str, answer: str, citations: List[Dict],
499
+ **kwargs) -> int:
500
+ """Log a query to the analytics database."""
501
+ return analytics_db.log_query(user_query, method, answer, citations, **kwargs)
502
+
503
+ def get_analytics_stats(days: int = 30) -> Dict[str, Any]:
504
+ """Get analytics statistics."""
505
+ return analytics_db.get_query_stats(days)
506
+
507
+ def get_method_performance() -> Dict[str, Dict[str, float]]:
508
+ """Get method performance metrics."""
509
+ return analytics_db.get_method_performance()
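A minimal usage sketch for the module above (all values are illustrative; the citation keys mirror the columns written by `log_query`, and `DATA_DIR` must exist as configured in config.py):

```python
# Hypothetical usage of analytics_db.py; all literals below are sample data.
from analytics_db import log_query, get_analytics_stats, get_method_performance

query_id = log_query(
    user_query="How do I perform lockout/tagout?",
    method="bm25",
    answer="Lockout/tagout generally requires ...",
    citations=[{"source": "osha_1910_147.pdf", "type": "pdf", "bm25_score": 12.3}],
    response_time=850.0,   # milliseconds
    top_k=5,
    session_id="demo-session",
)
print(f"Logged query {query_id}")

print(get_analytics_stats(days=7))   # aggregate stats for the last week
print(get_method_performance())      # per-method response time / citation averages
```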
app.py CHANGED
@@ -1,82 +1,1333 @@
1
  import streamlit as st
2
- from query_graph import query_graph
3
 
4
- # Helper for <details>
5
  def format_citations_html(chunks):
 
6
  html = []
7
  for idx, (hdr, sc, txt, citation) in enumerate(chunks, start=1):
8
- preamble = (
9
- f"<p style='font-size:0.9em;'><strong>Preamble:</strong> "
10
- f"The text in the following detail is reproduced from [{citation}]. "
11
- f"It had a cosine similarity of {sc:.2f} with the user question, "
12
- f"and it ranked {idx} among the text chunks in our graph database.</p>"
13
- )
14
  body = txt.replace("\n", "<br>")
15
  html.append(
16
  f"<details>"
17
- f"<summary>{hdr} (cosine similarity: {sc:.2f})</summary>"
18
  f"<div style='font-size:0.9em; margin-top:0.5em;'>"
19
- f"<strong>Preamble:</strong> The text below is reproduced from {citation}. "
20
  f"</div>"
21
- f"<div style='font-size:0.7em; margin-left:1em; margin-top:0.5em;'>{body}</div>"
22
  f"</details><br><br>"
23
  )
24
  return "<br>".join(html)
25
 
26
 
27
  # Sidebar configuration
28
- st.sidebar.title("About")
29
  st.sidebar.markdown("**Authors:** [The SIGHT Project Team](https://sites.miamioh.edu/sight/)")
30
- st.sidebar.markdown("**Version:** V. 0.0.2")
31
- st.sidebar.markdown("**Date:** July 24, 2025")
32
- st.sidebar.markdown("**Model:** gpt4o")
33
 
34
  st.sidebar.markdown("---")
35
  st.sidebar.markdown(
36
  "**Funding:** SIGHT is funded by [OHBWC WSIC](https://info.bwc.ohio.gov/for-employers/safety-services/workplace-safety-innovation-center/wsic-overview)"
37
  )
38
 
 
 
 
 
 
39
 
40
- # Main interface
41
- st.set_page_config(page_title="Miami University's SIGHT Chatbot")
42
- st.title("Chat with SIGHT")
43
- st.write("Ask questions about machine safeguarding, LOTO, and hazard prevention based on OSHA/CFR's corpus.")
 
 
 
 
 
44
 
45
- # Example questions toggled in main window
46
- with st.expander("Example Questions", expanded=False):
47
- st.markdown(
48
- "- What are general machine guarding requirements? \n"
49
- "- How do I perform lockout/tagout? \n"
50
- "- Summarize the definition of machine guarding from 29 CFR 1910.211"
51
  )
52
 
53
- # Initialize chat history
54
- if 'history' not in st.session_state:
55
- st.session_state.history = []
56
 
57
- # User input
58
- query = st.text_input("Your question:")
59
- if st.button("Send") and query:
60
- answer, sources, chunks = query_graph(query)
61
- st.session_state.history.append({
62
- 'query': query,
63
- 'answer': answer,
64
- 'sources': sources,
65
- 'chunks': chunks
66
- })
67
 
68
- # Display chat history
69
- for entry in st.session_state.history[::-1]:
70
- st.markdown(f"**You:** {entry['query']}")
71
- st.markdown(f"**Assistant:** {entry['answer']}")
72
- st.markdown(format_citations_html(entry['chunks']), unsafe_allow_html=True)
 
 
 
 
73
 
74
 
75
  # Footer
76
  st.markdown("---")
77
  st.markdown(
78
- "**Disclaimer:** *Powered by a Graph RAG to reduce hallucinations; please verify as it can still make mistakes.*"
79
- )
 
80
  st.markdown(
81
- "**Funding:** *We are thankful for [Ohio BWC/WSIC](https://info.bwc.ohio.gov/for-employers/safety-services/workplace-safety-innovation-center/wsic-overview)'s funding that made this chat bot possible.*"
82
- )
 
 
1
+ """
2
+ Multi-Method RAG System - SIGHT
3
+ Enhanced Streamlit application with method comparison and analytics.
4
+
5
+ Directory structure:
6
+ /data/ # Original PDFs, HTML
7
+ /embeddings/ # FAISS, Chroma, DPR vector stores
8
+ /graph/ # Graph database files
9
+ /metadata/ # Image metadata (SQLite or MongoDB)
10
+ """
11
+
12
  import streamlit as st
13
+ import os
14
+ import logging
15
+ import tempfile
16
+ import time
17
+ import uuid
18
+ from typing import Tuple, List, Dict, Any, Optional
19
+ from pathlib import Path
20
+
21
+ # Import all query modules
22
+ from query_graph import query as graph_query, query_graph
23
+ from query_vanilla import query as vanilla_query
24
+ from query_dpr import query as dpr_query
25
+ from query_bm25 import query as bm25_query
26
+ from query_context import query as context_query
27
+ from query_vision import query as vision_query, query_image_only
28
+
29
+ from config import *
30
+ from analytics_db import log_query, get_analytics_stats, get_method_performance, analytics_db
31
+ import streamlit.components.v1 as components
32
+ import requests
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # Check realtime server health
37
+ @st.cache_data(ttl=30) # Cache for 30 seconds
38
+ def check_realtime_server_health():
39
+ """Check if the realtime server is running."""
40
+ try:
41
+ response = requests.get("http://localhost:7861/health", timeout=2)
42
+ return response.status_code == 200
43
+ except:
44
+ return False
45
+
46
+ # Query method dispatch
47
+ QUERY_DISPATCH = {
48
+ 'graph': graph_query,
49
+ 'vanilla': vanilla_query,
50
+ 'dpr': dpr_query,
51
+ 'bm25': bm25_query,
52
+ 'context': context_query,
53
+ 'vision': vision_query
54
+ }
55
+
56
+ # Method options for speech interface
57
+ METHOD_OPTIONS = ['graph', 'vanilla', 'dpr', 'bm25', 'context', 'vision']
58
 
 
59
  def format_citations_html(chunks):
60
+ """Format citations for display (backward compatibility)."""
61
  html = []
62
  for idx, (hdr, sc, txt, citation) in enumerate(chunks, start=1):
 
 
 
 
 
 
63
  body = txt.replace("\n", "<br>")
64
  html.append(
65
  f"<details>"
66
+ f"<summary>{hdr} (relevance score: {sc:.3f})</summary>"
67
  f"<div style='font-size:0.9em; margin-top:0.5em;'>"
68
+ f"<strong>Source:</strong> {citation} "
69
  f"</div>"
70
+ f"<div style='font-size:0.8em; margin-left:1em; margin-top:0.5em;'>{body}</div>"
71
  f"</details><br><br>"
72
  )
73
  return "<br>".join(html)
74
 
75
+ def format_citations_html(citations: List[dict], method: str) -> str:
76
+ """Format citations as HTML based on method and citation type."""
77
+ if not citations:
78
+ return "<p><em>No citations available</em></p>"
79
+
80
+ html_parts = ["<div style='margin-top: 1em;'><strong>Sources:</strong><ul>"]
81
+
82
+ for citation in citations:
83
+ # Skip citations without source
84
+ if 'source' not in citation:
85
+ continue
86
+
87
+ source = citation['source']
88
+ cite_type = citation.get('type', 'unknown')
89
+
90
+ # Build citation text based on type
91
+ if cite_type == 'pdf':
92
+ cite_text = f"📄 {source} (PDF)"
93
+ elif cite_type == 'html':
94
+ url = citation.get('url', '')
95
+ if url:
96
+ cite_text = f"🌐 <a href='{url}' target='_blank'>{source}</a> (Web)"
97
+ else:
98
+ cite_text = f"🌐 {source} (Web)"
99
+ elif cite_type == 'image':
100
+ page = citation.get('page', 'N/A')
101
+ cite_text = f"🖼️ {source} (Image, page {page})"
102
+ elif cite_type == 'image_analysis':
103
+ classification = citation.get('classification', 'N/A')
104
+ cite_text = f"🔍 {source} - {classification}"
105
+ else:
106
+ cite_text = f"📚 {source}"
107
+
108
+ # Add scores if available
109
+ scores = []
110
+ if 'relevance_score' in citation:
111
+ scores.append(f"relevance: {citation['relevance_score']}")
112
+ if 'bm25_score' in citation:
113
+ scores.append(f"BM25: {citation['bm25_score']}")
114
+ if 'rerank_score' in citation:
115
+ scores.append(f"rerank: {citation['rerank_score']}")
116
+ if 'similarity' in citation:
117
+ scores.append(f"similarity: {citation['similarity']}")
118
+ if 'score' in citation:
119
+ scores.append(f"score: {citation['score']:.3f}")
120
+
121
+ if scores:
122
+ cite_text += f" <small>({', '.join(scores)})</small>"
123
+
124
+ html_parts.append(f"<li>{cite_text}</li>")
125
+
126
+ html_parts.append("</ul></div>")
127
+ return "".join(html_parts)
128
+
129
+ def save_uploaded_file(uploaded_file) -> str:
130
+ """Save uploaded file to temporary location."""
131
+ try:
132
+ with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp_file:
133
+ tmp_file.write(uploaded_file.getvalue())
134
+ return tmp_file.name
135
+ except Exception as e:
136
+ st.error(f"Error saving file: {e}")
137
+ return None
138
+
139
+
140
+ # Page configuration
141
+ st.set_page_config(
142
+ page_title="Multi-Method RAG System - SIGHT",
143
+ page_icon="🔍",
144
+ layout="wide"
145
+ )
146
 
147
  # Sidebar configuration
148
+ st.sidebar.title("Configuration")
149
+
150
+ # Method selector
151
+ st.sidebar.markdown("### Retrieval Method")
152
+ selected_method = st.sidebar.radio(
153
+ "Choose retrieval method:",
154
+ options=['graph', 'vanilla', 'dpr', 'bm25', 'context', 'vision'],
155
+ format_func=lambda x: x.capitalize(),
156
+ help="Select different RAG methods to compare results"
157
+ )
158
+
159
+ # Display method description
160
+ st.sidebar.info(METHOD_DESCRIPTIONS[selected_method])
161
+
162
+
163
+ # Advanced settings
164
+ with st.sidebar.expander("Advanced Settings"):
165
+ top_k = st.slider("Number of chunks to retrieve", min_value=1, max_value=10, value=DEFAULT_TOP_K)
166
+
167
+ if selected_method == 'bm25':
168
+ use_hybrid = st.checkbox("Use hybrid search (BM25 + semantic)", value=False)
169
+ if use_hybrid:
170
+ alpha = st.slider("BM25 weight (alpha)", min_value=0.0, max_value=1.0, value=0.5)
171
+
172
+ # Sidebar info
173
+
174
+ st.sidebar.markdown("---")
175
+ st.sidebar.markdown("### About")
176
  st.sidebar.markdown("**Authors:** [The SIGHT Project Team](https://sites.miamioh.edu/sight/)")
177
+ st.sidebar.markdown(f"**Version:** V. {VERSION}")
178
+ st.sidebar.markdown(f"**Date:** {DATE}")
179
+ st.sidebar.markdown(f"**Model:** {OPENAI_CHAT_MODEL}")
180
 
181
  st.sidebar.markdown("---")
182
  st.sidebar.markdown(
183
  "**Funding:** SIGHT is funded by [OHBWC WSIC](https://info.bwc.ohio.gov/for-employers/safety-services/workplace-safety-innovation-center/wsic-overview)"
184
  )
185
 
186
+ # Main interface with dynamic status
187
+ col1, col2 = st.columns([3, 1])
188
+ with col1:
189
+ st.title("🔍 Multi-Method RAG System - SIGHT")
190
+ st.markdown("### Compare different retrieval methods for machine safety Q&A")
191
 
192
+ with col2:
193
+ # Quick stats in the header
194
+ if 'chat_history' in st.session_state:
195
+ total_queries = len(st.session_state.chat_history)
196
+ st.metric("Session Queries", total_queries, delta=None if total_queries == 0 else "+1" if total_queries == 1 else f"+{total_queries}")
197
+
198
+ # Voice chat status indicator
199
+ if st.session_state.get('voice_session_active', False):
200
+ st.success("🔴 Voice LIVE")
201
 
202
+ # Create tabs for different interfaces
203
+ tab1, tab2, tab3, tab4 = st.tabs(["💬 Chat", "📊 Method Comparison", "🔊 Voice Chat", "📈 Analytics"])
204
+
205
+ with tab1:
206
+ # Example questions
207
+ with st.expander("📝 Example Questions", expanded=False):
208
+ example_cols = st.columns(2)
209
+ with example_cols[0]:
210
+ st.markdown(
211
+ "**General Safety:**\n"
212
+ "- What are general machine guarding requirements?\n"
213
+ "- How do I perform lockout/tagout?\n"
214
+ "- What is required for emergency stops?"
215
+ )
216
+ with example_cols[1]:
217
+ st.markdown(
218
+ "**Specific Topics:**\n"
219
+ "- Summarize robot safety requirements from OSHA\n"
220
+ "- Compare guard types: fixed vs interlocked\n"
221
+ "- What are the ANSI standards for machine safety?"
222
+ )
223
+
224
+ # File uploader for vision method
225
+ uploaded_file = None
226
+ if selected_method == 'vision':
227
+ st.markdown("#### 🖼️ Upload an image for analysis")
228
+ uploaded_file = st.file_uploader(
229
+ "Choose an image file",
230
+ type=['png', 'jpg', 'jpeg', 'bmp', 'gif'],
231
+ help="Upload an image of safety equipment, signs, or machinery"
232
+ )
233
+
234
+ if uploaded_file:
235
+ col1, col2 = st.columns([1, 2])
236
+ with col1:
237
+ st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)
238
+
239
+ # Initialize session state
240
+ if 'chat_history' not in st.session_state:
241
+ st.session_state.chat_history = []
242
+
243
+ if 'session_id' not in st.session_state:
244
+ st.session_state.session_id = str(uuid.uuid4())[:8]
245
+
246
+ # Chat input
247
+ query = st.text_input(
248
+ "Ask a question:",
249
+ placeholder="E.g., What are the safety requirements for collaborative robots?",
250
+ key="chat_input"
251
+ )
252
+
253
+ col1, col2, col3 = st.columns([1, 1, 8])
254
+ with col1:
255
+ send_button = st.button("🚀 Send", type="primary", use_container_width=True)
256
+ with col2:
257
+ clear_button = st.button("🗑️ Clear", use_container_width=True)
258
+
259
+ if clear_button:
260
+ st.session_state.chat_history = []
261
+ st.rerun()
262
+
263
+ if send_button and query:
264
+ # Save uploaded file if present
265
+ image_path = None
266
+ if uploaded_file and selected_method == 'vision':
267
+ image_path = save_uploaded_file(uploaded_file)
268
+
269
+ # Show spinner while processing
270
+ with st.spinner(f"Searching using {selected_method.upper()} method..."):
271
+ start_time = time.time()
272
+ error_message = None
273
+ answer = ""
274
+ citations = []
275
+
276
+ try:
277
+ # Get the appropriate query function
278
+ query_func = QUERY_DISPATCH[selected_method]
279
+
280
+ # Call the query function
281
+ if selected_method == 'vision' and not image_path:
282
+ error_message = "Please upload an image for vision-based search"
283
+ st.error(error_message)
284
+ else:
285
+ answer, citations = query_func(query, image_path=image_path, top_k=top_k)
286
+
287
+ # Add to history
288
+ st.session_state.chat_history.append({
289
+ 'query': query,
290
+ 'answer': answer,
291
+ 'citations': citations,
292
+ 'method': selected_method,
293
+ 'image_path': image_path
294
+ })
295
+
296
+ except Exception as e:
297
+ error_message = str(e)
298
+ answer = f"Error: {error_message}"
299
+ st.error(f"Error processing query: {error_message}")
300
+ st.info("Make sure you've run preprocess.py to generate the required indices.")
301
+
302
+ finally:
303
+ # Log query to analytics database (always, even on error)
304
+ response_time = (time.time() - start_time) * 1000 # Convert to ms
305
+
306
+ try:
307
+ log_query(
308
+ user_query=query,
309
+ method=selected_method,
310
+ answer=answer,
311
+ citations=citations,
312
+ response_time=response_time,
313
+ image_path=image_path,
314
+ error_message=error_message,
315
+ top_k=top_k,
316
+ session_id=st.session_state.session_id
317
+ )
318
+ except Exception as log_error:
319
+ logger.error(f"Failed to log query: {log_error}")
320
+
321
+ # Clean up temp file
322
+ if image_path and os.path.exists(image_path):
323
+ os.unlink(image_path)
324
+
325
+ # Display chat history
326
+ if st.session_state.chat_history:
327
+ st.markdown("---")
328
+ st.markdown("### Chat History")
329
+
330
+ for i, entry in enumerate(reversed(st.session_state.chat_history)):
331
+ with st.container():
332
+ # User message
333
+ st.markdown(f"**🧑 You** ({entry['method'].upper()}):")
334
+ st.markdown(entry['query'])
335
+
336
+ # Assistant response
337
+ st.markdown("**🤖 Assistant:**")
338
+ st.markdown(entry['answer'])
339
+
340
+ # Citations
341
+ st.markdown(format_citations_html(entry['citations'], entry['method']), unsafe_allow_html=True)
342
+
343
+ if i < len(st.session_state.chat_history) - 1:
344
+ st.markdown("---")
345
+
346
+ with tab2:
347
+ st.markdown("### Method Comparison")
348
+ st.markdown("Compare results from different retrieval methods for the same query.")
349
+
350
+ comparison_query = st.text_input(
351
+ "Enter a query to compare across methods:",
352
+ placeholder="E.g., What are the requirements for machine guards?",
353
+ key="comparison_input"
354
+ )
355
+
356
+ methods_to_compare = st.multiselect(
357
+ "Select methods to compare:",
358
+ options=['graph', 'vanilla', 'dpr', 'bm25', 'context'],
359
+ default=['vanilla', 'bm25'],
360
+ help="Vision method requires an image and is not included in comparison"
361
  )
362
+
363
+ col1, col2 = st.columns([3, 1])
364
+ with col1:
365
+ compare_button = st.button("🔍 Compare Methods", type="primary")
366
+ with col2:
367
+ if 'comparison_results' in st.session_state and st.session_state.comparison_results:
368
+ if st.button("🪟 Full Screen View", help="View results in a dedicated comparison window"):
369
+ st.session_state.show_comparison_window = True
370
+ st.rerun()
371
+
372
+ if compare_button:
373
+ if comparison_query and methods_to_compare:
374
+ results = {}
375
+
376
+ progress_bar = st.progress(0)
377
+ for idx, method in enumerate(methods_to_compare):
378
+ with st.spinner(f"Running {method.upper()}..."):
379
+ start_time = time.time()
380
+ error_message = None
381
+
382
+ try:
383
+ query_func = QUERY_DISPATCH[method]
384
+ answer, citations = query_func(comparison_query, top_k=top_k)
385
+ results[method] = {
386
+ 'answer': answer,
387
+ 'citations': citations
388
+ }
389
+ except Exception as e:
390
+ error_message = str(e)
391
+ answer = f"Error: {error_message}"
392
+ citations = []
393
+ results[method] = {
394
+ 'answer': answer,
395
+ 'citations': citations
396
+ }
397
+
398
+ finally:
399
+ # Log comparison queries too
400
+ response_time = (time.time() - start_time) * 1000
401
+ try:
402
+ log_query(
403
+ user_query=comparison_query,
404
+ method=method,
405
+ answer=results[method]['answer'],
406
+ citations=results[method]['citations'],
407
+ response_time=response_time,
408
+ error_message=error_message,
409
+ top_k=top_k,
410
+ session_id=st.session_state.session_id,
411
+ additional_settings={'comparison_mode': True}
412
+ )
413
+ except Exception as log_error:
414
+ logger.error(f"Failed to log comparison query: {log_error}")
415
+
416
+ progress_bar.progress((idx + 1) / len(methods_to_compare))
417
+
418
+ # Store results in session state for full screen view
419
+ st.session_state.comparison_results = {
420
+ 'query': comparison_query,
421
+ 'methods': methods_to_compare,
422
+ 'results': results,
423
+ 'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
424
+ }
425
+
426
+ # Display results in compact columns
427
+ cols = st.columns(len(methods_to_compare))
428
+ for idx, (method, col) in enumerate(zip(methods_to_compare, cols)):
429
+ with col:
430
+ st.markdown(f"#### {method.upper()}")
431
+
432
+ # Use expandable container for full text without truncation
433
+ answer = results[method]['answer']
434
+ if len(answer) > 800:
435
+ # Show first 300 chars, then expandable for full text
436
+ st.markdown(answer[:300] + "...")
437
+ with st.expander("📖 Show full answer"):
438
+ st.markdown(answer)
439
+ else:
440
+ # Short answers display fully
441
+ st.markdown(answer)
442
+
443
+ st.markdown(format_citations_html(results[method]['citations'], method), unsafe_allow_html=True)
444
+ else:
445
+ st.warning("Please enter a query and select at least one method to compare.")
446
+
447
+ with tab3:
448
+ st.markdown("### 🔊 Voice Chat - Hands-free AI Assistant")
449
+
450
+ # Server status check
451
+ server_healthy = check_realtime_server_health()
452
+ if server_healthy:
453
+ st.success("✅ **Voice Server Online** - Ready for voice interactions")
454
+ else:
455
+ st.error("❌ **Voice Server Offline** - Please start the realtime server: `python realtime_server.py`")
456
+ st.code("python realtime_server.py", language="bash")
457
+ st.stop()
458
+
459
+ st.info(
460
+ "🎤 **Real-time Voice Interaction**: Speak naturally and get instant responses from your chosen RAG method. "
461
+ "The AI will automatically transcribe your speech, search the knowledge base, and respond with synthesized voice."
462
+ )
463
+
464
+ # Voice Chat Status and Configuration
465
+ col1, col2 = st.columns([2, 1])
466
+
467
+ with col1:
468
+ # Use the same method from sidebar
469
+ st.info(f"🔍 **Voice using {selected_method.upper()} method** (change in sidebar)")
470
+
471
+ with col2:
472
+ # Voice settings (simplified)
473
+ voice_choice = st.selectbox(
474
+ "🎙️ AI Voice:",
475
+ ["alloy", "echo", "fable", "onyx", "nova", "shimmer"],
476
+ index=0,
477
+ help="Select the AI voice for responses"
478
+ )
479
+ response_speed = st.slider(
480
+ "⏱️ Response Speed (seconds):",
481
+ min_value=1, max_value=5, value=2,
482
+ help="How quickly the AI should respond after you stop speaking"
483
+ )
484
+
485
+ # Auto-detect server URL (hide the manual configuration)
486
+ server_url = "http://localhost:5050" # Default for local development
487
+
488
+ # Voice Chat Interface
489
+ st.markdown("---")
490
+
491
+ # Initialize voice chat session state
492
+ if 'voice_chat_history' not in st.session_state:
493
+ st.session_state.voice_chat_history = []
494
+ if 'voice_session_active' not in st.session_state:
495
+ st.session_state.voice_session_active = False
496
+
497
+ # Simple Status Display
498
+ if st.session_state.voice_session_active:
499
+ st.success("🔴 **LIVE** - Voice chat active using " + selected_method.upper())
500
+
501
+ # Enhanced Voice Interface with better UX
502
+ components.html(f"""
503
+ <!DOCTYPE html>
504
+ <html>
505
+ <head>
506
+ <meta charset="utf-8" />
507
+ <style>
508
+ body {{
509
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
510
+ padding: 20px;
511
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
512
+ color: white;
513
+ border-radius: 10px;
514
+ }}
515
+ .container {{
516
+ max-width: 800px;
517
+ margin: 0 auto;
518
+ background: rgba(255,255,255,0.1);
519
+ padding: 30px;
520
+ border-radius: 15px;
521
+ backdrop-filter: blur(10px);
522
+ }}
523
+ .controls {{
524
+ display: flex;
525
+ gap: 20px;
526
+ align-items: center;
527
+ justify-content: center;
528
+ margin-bottom: 30px;
529
+ }}
530
+ .status-display {{
531
+ text-align: center;
532
+ margin: 20px 0;
533
+ padding: 15px;
534
+ border-radius: 10px;
535
+ background: rgba(255,255,255,0.2);
536
+ }}
537
+ .status-idle {{ background: rgba(108, 117, 125, 0.3); }}
538
+ .status-connecting {{ background: rgba(255, 193, 7, 0.3); }}
539
+ .status-active {{ background: rgba(40, 167, 69, 0.3); }}
540
+ .status-error {{ background: rgba(220, 53, 69, 0.3); }}
541
+
542
+ button {{
543
+ padding: 12px 24px;
544
+ font-size: 16px;
545
+ border: none;
546
+ border-radius: 25px;
547
+ cursor: pointer;
548
+ transition: all 0.3s ease;
549
+ font-weight: bold;
550
+ }}
551
+ .start-btn {{
552
+ background: linear-gradient(45deg, #28a745, #20c997);
553
+ color: white;
554
+ }}
555
+ .start-btn:hover {{ transform: translateY(-2px); box-shadow: 0 4px 12px rgba(40,167,69,0.4); }}
556
+ .start-btn:disabled {{
557
+ background: #6c757d;
558
+ cursor: not-allowed;
559
+ transform: none;
560
+ box-shadow: none;
561
+ }}
562
+ .stop-btn {{
563
+ background: linear-gradient(45deg, #dc3545, #fd7e14);
564
+ color: white;
565
+ }}
566
+ .stop-btn:hover {{ transform: translateY(-2px); box-shadow: 0 4px 12px rgba(220,53,69,0.4); }}
567
+ .stop-btn:disabled {{
568
+ background: #6c757d;
569
+ cursor: not-allowed;
570
+ transform: none;
571
+ box-shadow: none;
572
+ }}
573
+
574
+ .log {{
575
+ height: 200px;
576
+ overflow-y: auto;
577
+ border: 1px solid rgba(255,255,255,0.3);
578
+ padding: 15px;
579
+ background: rgba(0,0,0,0.2);
580
+ border-radius: 10px;
581
+ font-family: 'Monaco', 'Menlo', monospace;
582
+ font-size: 13px;
583
+ line-height: 1.4;
584
+ }}
585
+ .audio-controls {{
586
+ text-align: center;
587
+ margin: 20px 0;
588
+ }}
589
+ .pulse {{
590
+ animation: pulse 2s infinite;
591
+ }}
592
+ @keyframes pulse {{
593
+ 0% {{ transform: scale(1); }}
594
+ 50% {{ transform: scale(1.05); }}
595
+ 100% {{ transform: scale(1); }}
596
+ }}
597
+ .visualizer {{
598
+ width: 100%;
599
+ height: 60px;
600
+ background: rgba(0,0,0,0.2);
601
+ border-radius: 10px;
602
+ margin: 10px 0;
603
+ display: flex;
604
+ align-items: center;
605
+ justify-content: center;
606
+ font-size: 14px;
607
+ }}
608
+ </style>
609
+ </head>
610
+ <body>
611
+ <div class="container">
612
+ <div class="status-display status-idle" id="statusDisplay">
613
+ <h3 id="statusTitle">🎤 Voice Chat</h3>
614
+ <p id="statusText">Click "Start Listening" to begin</p>
615
+ </div>
616
+
617
+ <div class="controls">
618
+ <button id="startBtn" class="start-btn">🎤 Start Listening</button>
619
+ <button id="stopBtn" class="stop-btn" disabled>⏹️ Stop</button>
620
+ </div>
621
+
622
+ <div class="audio-controls">
623
+ <audio id="remoteAudio" autoplay style="width: 100%; max-width: 400px;"></audio>
624
+ </div>
625
+
626
+ <div class="visualizer" id="visualizer">
627
+ 🔇 Audio will appear here when active
628
+ </div>
629
+
630
+ <div class="log" id="log"></div>
631
+ </div>
632
+
633
+ <script>
634
+ (async () => {{
635
+ const serverBase = {server_url!r};
636
+ const chosenMethod = {selected_method!r};
637
+ const voiceChoice = {voice_choice!r};
638
+ const responseSpeed = {response_speed!r};
639
+
640
+ const logEl = document.getElementById('log');
641
+ const statusDisplay = document.getElementById('statusDisplay');
642
+ const statusTitle = document.getElementById('statusTitle');
643
+ const statusText = document.getElementById('statusText');
644
+ const startBtn = document.getElementById('startBtn');
645
+ const stopBtn = document.getElementById('stopBtn');
646
+ const visualizer = document.getElementById('visualizer');
647
+
648
+ let pc, dc, micStream;
649
+ let isConnected = false;
650
+ let questionStartTime = null;
651
+
652
+ function updateStatus(status, title, text, className) {{
653
+ statusDisplay.className = `status-display ${{className}}`;
654
+ statusTitle.textContent = title;
655
+ statusText.textContent = text;
656
+ }}
657
+
658
+ function log(msg, type = 'info') {{
659
+ const timestamp = new Date().toLocaleTimeString();
660
+ const icon = type === 'error' ? '❌' : type === 'success' ? '✅' : type === 'warning' ? '⚠️' : 'ℹ️';
661
+ logEl.innerHTML += `<div>${{timestamp}} ${{icon}} ${{msg}}</div>`;
662
+ logEl.scrollTop = logEl.scrollHeight;
663
+ }}
664
 
665
+ async function start() {{
666
+ startBtn.disabled = true;
667
+ stopBtn.disabled = false;
668
+ updateStatus('connecting', '🔄 Connecting...', 'Establishing secure connection to voice services', 'status-connecting');
669
+
670
+ try {{
671
+ log('Initializing voice session...', 'info');
672
+
673
+ // 1) Fetch ephemeral session token
674
+ const sessResp = await fetch(serverBase + "/session", {{
675
+ method: "POST",
676
+ headers: {{ "Content-Type": "application/json" }},
677
+ body: JSON.stringify({{ voice: voiceChoice }})
678
+ }});
679
+
680
+ if (!sessResp.ok) {{
681
+ throw new Error(`Server error: ${{sessResp.status}} ${{sessResp.statusText}}`);
682
+ }}
683
+
684
+ const sess = await sessResp.json();
685
+ if (sess.error) throw new Error(sess.error);
686
+
687
+ const EPHEMERAL_KEY = sess.client_secret;
688
+ if (!EPHEMERAL_KEY) throw new Error("No ephemeral token from server");
689
+
690
+ log('✅ Session token obtained', 'success');
691
+
692
+ // 2) Setup WebRTC
693
+ pc = new RTCPeerConnection();
694
+ const remoteAudio = document.getElementById('remoteAudio');
695
+ pc.ontrack = (event) => {{
696
+ log('🔊 Audio track received from OpenAI', 'success');
697
+ console.log('WebRTC track event:', event);
698
+ const stream = event.streams[0];
699
+ if (stream && stream.getAudioTracks().length > 0) {{
700
+ remoteAudio.srcObject = stream;
701
+ visualizer.textContent = '🔊 Audio stream connected - AI can speak';
702
+ log(`🎵 Audio tracks: ${{stream.getAudioTracks().length}}`, 'success');
703
+ }} else {{
704
+ log('⚠️ No audio tracks in stream', 'warning');
705
+ visualizer.textContent = '⚠️ No audio stream received';
706
+ }}
707
+ }};
708
 
709
+ // 3) Create data channel
710
+ dc = pc.createDataChannel("oai-data");
711
+ dc.onopen = () => {{
712
+ log('🔗 Data channel established', 'success');
713
+ }};
714
+ dc.onerror = (error) => {{
715
+ log('❌ Data channel error: ' + error, 'error');
716
+ }};
717
+ dc.onmessage = (e) => handleDataMessage(e);
 
718
 
719
+ // 4) Get microphone
720
+ log('🎤 Requesting microphone access...', 'info');
721
+ micStream = await navigator.mediaDevices.getUserMedia({{ audio: true }});
722
+ log('✅ Microphone access granted', 'success');
723
+ visualizer.textContent = '🎤 Microphone active - speak naturally';
724
+
725
+ for (const track of micStream.getTracks()) {{
726
+ pc.addTrack(track, micStream);
727
+ }}
728
 
729
+ // 5) Setup audio receiving
730
+ pc.addTransceiver("audio", {{ direction: "recvonly" }});
731
+ log('🔊 Audio receiver configured', 'success');
732
+
733
+ // 6) Create and set local description
734
+ const offer = await pc.createOffer();
735
+ await pc.setLocalDescription(offer);
736
+ log('📡 WebRTC offer created', 'success');
737
+
738
+ // 7) Exchange SDP with OpenAI Realtime
739
+ const baseUrl = "https://api.openai.com/v1/realtime";
740
+ const model = sess.model || "gpt-4o-realtime-preview";
741
+ const sdpResp = await fetch(`${{baseUrl}}?model=${{encodeURIComponent(model)}}`, {{
742
+ method: "POST",
743
+ body: offer.sdp,
744
+ headers: {{
745
+ Authorization: `Bearer ${{EPHEMERAL_KEY}}`,
746
+ "Content-Type": "application/sdp"
747
+ }}
748
+ }});
749
+
750
+ if (!sdpResp.ok) throw new Error(`WebRTC setup failed: ${{sdpResp.status}}`);
751
+
752
+ const answer = {{ type: "answer", sdp: await sdpResp.text() }};
753
+ await pc.setRemoteDescription(answer);
754
+
755
+ // 8) Configure the session with tools and faster response
756
+ // Wait a bit to ensure data channel is fully ready
757
+ setTimeout(() => {{
758
+ if (dc.readyState === 'open') {{
759
+ const toolDecl = {{
760
+ type: "session.update",
761
+ session: {{
762
+ tools: [{{
763
+ "type": "function",
764
+ "name": "ask_rag",
765
+ "description": "Search the safety knowledge base for accurate, authoritative information. Call this immediately when users ask safety questions to get current, reliable information with proper citations.",
766
+ "parameters": {{
767
+ "type": "object",
768
+ "properties": {{
769
+ "query": {{ "type": "string", "description": "User's safety question" }},
770
+ "top_k": {{ "type": "integer", "minimum": 1, "maximum": 20, "default": 5 }}
771
+ }},
772
+ "required": ["query"]
773
+ }}
774
+ }}],
775
+ turn_detection: {{
776
+ type: "server_vad",
777
+ threshold: 0.5,
778
+ prefix_padding_ms: 300,
779
+ silence_duration_ms: {response_speed * 1000}
780
+ }},
781
+ input_audio_transcription: {{
782
+ model: "whisper-1"
783
+ }},
784
+ voice: voiceChoice,
785
+ temperature: 0.7, // Higher temperature for more natural speech
786
+ max_response_output_tokens: 1000, // Allow full responses
787
+ modalities: ["audio", "text"],
788
+ response_format: "audio"
789
+ }}
790
+ }};
791
+ dc.send(JSON.stringify(toolDecl));
792
+ log('🛠️ RAG tools configured', 'success');
793
+
794
+ // Send initial conversation starter to prime the model for natural interaction
795
+ const initialMessage = {{
796
+ type: "conversation.item.create",
797
+ item: {{
798
+ type: "message",
799
+ role: "user",
800
+ content: [{{
801
+ type: "input_text",
802
+ text: "Hello! I'm ready to ask you questions about machine safety. Please speak naturally like a safety expert - no need to mention specific documents or sources, just give me the information as your expertise."
803
+ }}]
804
+ }}
805
+ }};
806
+ dc.send(JSON.stringify(initialMessage));
807
+
808
+ const responseRequest = {{
809
+ type: "response.create",
810
+ response: {{
811
+ modalities: ["audio"],
812
+ instructions: "Acknowledge briefly that you're ready to help with safety questions. Speak naturally and confidently as a safety expert - no citations or document references needed."
813
+ }}
814
+ }};
815
+ dc.send(JSON.stringify(responseRequest));
816
+
817
+ }} else {{
818
+ log('⚠️ Data channel not ready, retrying...', 'warning');
819
+ // Retry after another second
820
+ setTimeout(() => {{
821
+ if (dc.readyState === 'open') {{
822
+ dc.send(JSON.stringify(toolDecl));
823
+ log('🛠️ RAG tools configured (retry)', 'success');
824
+ }}
825
+ }}, 1000);
826
+ }}
827
+ }}, 500);
828
+
829
+ isConnected = true;
830
+ updateStatus('active', '🎤 Live - Speak Now!', `Using ${{chosenMethod.toUpperCase()}} method • Voice: ${{voiceChoice}} • Response: ${{responseSpeed}}s`, 'status-active');
831
+ startBtn.classList.add('pulse');
832
+
833
+ }} catch (error) {{
834
+ log(`❌ Connection failed: ${{error.message}}`, 'error');
835
+ updateStatus('error', '❌ Connection Failed', error.message, 'status-error');
836
+ startBtn.disabled = false;
837
+ stopBtn.disabled = true;
838
+ cleanup();
839
+ }}
840
+ }}
841
+
842
+ function cleanup() {{
843
+ try {{
844
+ if (dc && dc.readyState === 'open') dc.close();
845
+ if (pc) pc.close();
846
+ if (micStream) micStream.getTracks().forEach(t => t.stop());
847
+ }} catch (e) {{ /* ignore cleanup errors */ }}
848
+ startBtn.classList.remove('pulse');
849
+ visualizer.textContent = '🔇 Audio inactive';
850
+ }}
851
+
852
+ async function stop() {{
853
+ startBtn.disabled = false;
854
+ stopBtn.disabled = true;
855
+ isConnected = false;
856
+ updateStatus('idle', '⚪ Session Ended', 'Click "Start Listening" to begin a new voice session', 'status-idle');
857
+ log('🛑 Voice session terminated', 'info');
858
+ cleanup();
859
+ }}
860
+
861
+ // Handle realtime events
862
+ async function handleDataMessage(e) {{
863
+ if (!isConnected) return;
864
+
865
+ try {{
866
+ const msg = JSON.parse(e.data);
867
+
868
+ if (msg.type === "response.function_call") {{
869
+ const {{ name, call_id, arguments: args }} = msg;
870
+ if (name === "ask_rag") {{
871
+ visualizer.textContent = '✅ Question received - searching...';
872
+ const query = JSON.parse(args || "{{}}").query;
873
+ log(`✅ AI heard: "${{query}}"`, 'success');
874
+ log('🔍 Searching knowledge base...', 'info');
875
+
876
+ // Store the transcribed query for analytics
877
+ window.lastVoiceQuery = query;
878
+
879
+ const payload = JSON.parse(args || "{{}}");
880
+ const ragResp = await fetch(serverBase + "/rag", {{
881
+ method: "POST",
882
+ headers: {{ "Content-Type": "application/json" }},
883
+ body: JSON.stringify({{
884
+ query: payload.query,
885
+ top_k: payload.top_k ?? 5,
886
+ method: chosenMethod
887
+ }})
888
+ }});
889
+
890
+ const rag = await ragResp.json();
891
+
892
+ // Send result back to model (check if data channel is still open)
893
+ if (dc && dc.readyState === 'open') {{
894
+ dc.send(JSON.stringify({{
895
+ type: "response.function_call_result",
896
+ call_id,
897
+ output: JSON.stringify({{
898
+ answer: rag.answer,
899
+ instruction: "Speak this information naturally as your expertise. Do not mention sources or documents."
900
+ }})
901
+ }}));
902
+ }} else {{
903
+ log('⚠️ Data channel closed, cannot send result', 'warning');
904
+ }}
905
+
906
+ const searchTime = ((Date.now() - questionStartTime) / 1000).toFixed(1);
907
+ log(`✅ Found ${{rag.citations?.length || 0}} citations in ${{searchTime}}s`, 'success');
908
+ visualizer.textContent = '🎙️ AI is speaking your answer...';
909
+ }}
910
+ }}
911
+
912
+ if (msg.type === "input_audio_buffer.speech_started") {{
913
+ questionStartTime = Date.now();
914
+ visualizer.textContent = '🎙️ Listening to you...';
915
+ log('🎤 Speech detected', 'info');
916
+ }}
917
+
918
+ if (msg.type === "input_audio_buffer.speech_stopped") {{
919
+ visualizer.textContent = '🤔 Processing your question...';
920
+ log('⏸️ Processing speech...', 'info');
921
+ }}
922
+
923
+ if (msg.type === "response.audio.delta") {{
924
+ visualizer.textContent = '🔊 AI speaking...';
925
+ }}
926
+
927
+ if (msg.type === "response.done") {{
928
+ if (questionStartTime) {{
929
+ const totalTime = ((Date.now() - questionStartTime) / 1000).toFixed(1);
930
+ visualizer.textContent = '🎤 Your turn - speak now';
931
+ log(`✅ Response complete in ${{totalTime}}s`, 'success');
932
+ questionStartTime = null;
933
+ }} else {{
934
+ visualizer.textContent = '🎤 Your turn - speak now';
935
+ log('✅ Response complete', 'success');
936
+ }}
937
+ }}
938
+
939
+ }} catch (err) {{
940
+ // Ignore non-JSON messages
941
+ }}
942
+ }}
943
+
944
+ startBtn.onclick = start;
945
+ stopBtn.onclick = stop;
946
+
947
+ // Initialize
948
+ log('🚀 Voice chat interface loaded', 'success');
949
+ }})();
950
+ </script>
951
+ </body>
952
+ </html>
953
+ """, height=600, scrolling=True)
954
+
955
+ # Voice Chat History
956
+ if st.session_state.voice_chat_history:
957
+ st.markdown("### 🗣️ Recent Voice Conversations")
958
+ for i, entry in enumerate(reversed(st.session_state.voice_chat_history[-5:])):
959
+ with st.expander(f"🎤 Conversation {len(st.session_state.voice_chat_history)-i} - {entry.get('method', 'unknown').upper()}"):
960
+ st.write(f"**Query**: {entry.get('query', 'N/A')}")
961
+ st.write(f"**Response**: {entry.get('answer', 'N/A')[:200]}...")
962
+ st.write(f"**Citations**: {len(entry.get('citations', []))}")
963
+ st.write(f"**Timestamp**: {entry.get('timestamp', 'N/A')}")
964
+
965
+ with tab4:
966
+ st.markdown("### 📊 Analytics Dashboard")
967
+ st.markdown("*Persistent analytics from all user interactions*")
968
+
969
+ # Time period selector
970
+ col1, col2 = st.columns([3, 1])
971
+ with col1:
972
+ st.markdown("")
973
+ with col2:
974
+ days_filter = st.selectbox("Time Period", [7, 30, 90, 365], index=1, format_func=lambda x: f"Last {x} days")
975
+
976
+ # Get analytics data
977
+ try:
978
+ stats = get_analytics_stats(days=days_filter)
979
+ performance = get_method_performance()
980
+ recent_queries = analytics_db.get_recent_queries(limit=10)
981
+
982
+ # Overview Metrics
983
+ st.markdown("#### 📈 Overview")
984
+ col1, col2, col3, col4 = st.columns(4)
985
+
986
+ with col1:
987
+ st.metric(
988
+ "Total Queries",
989
+ stats.get('total_queries', 0),
990
+ help="All queries processed in the selected time period"
991
+ )
992
+
993
+ with col2:
994
+ avg_citations = stats.get('avg_citations', 0)
995
+ st.metric(
996
+ "Avg Citations",
997
+ f"{avg_citations:.1f}",
998
+ help="Average number of citations per query"
999
+ )
1000
+
1001
+ with col3:
1002
+ error_rate = stats.get('error_rate', 0)
1003
+ st.metric(
1004
+ "Success Rate",
1005
+ f"{100 - error_rate:.1f}%",
1006
+ delta=f"-{error_rate:.1f}% errors" if error_rate > 0 else None,
1007
+ help="Percentage of successful queries"
1008
+ )
1009
+
1010
+ with col4:
1011
+ total_citations = stats.get('total_citations', 0)
1012
+ st.metric(
1013
+ "Total Citations",
1014
+ total_citations,
1015
+ help="Total citations generated across all queries"
1016
+ )
1017
+
1018
+ # Method Performance Comparison
1019
+ if performance:
1020
+ st.markdown("#### ⚡ Method Performance")
1021
+
1022
+ perf_data = []
1023
+ for method, metrics in performance.items():
1024
+ perf_data.append({
1025
+ 'Method': method.upper(),
1026
+ 'Avg Response Time (ms)': f"{metrics['avg_response_time']:.0f}",
1027
+ 'Avg Citations': f"{metrics['avg_citations']:.1f}",
1028
+ 'Avg Answer Length': f"{metrics['avg_answer_length']:.0f}",
1029
+ 'Query Count': int(metrics['query_count'])
1030
+ })
1031
+
1032
+ if perf_data:
1033
+ st.dataframe(perf_data, use_container_width=True, hide_index=True)
1034
+
1035
+ # Method Usage with Voice Interaction Indicator
1036
+ method_usage = stats.get('method_usage', {})
1037
+ if method_usage:
1038
+ st.markdown("#### 🎯 Method Usage Distribution")
1039
+ col1, col2 = st.columns([2, 1])
1040
+
1041
+ with col1:
1042
+ st.bar_chart(method_usage)
1043
+
1044
+ with col2:
1045
+ st.markdown("**Most Popular Methods:**")
1046
+ sorted_methods = sorted(method_usage.items(), key=lambda x: x[1], reverse=True)
1047
+ for i, (method, count) in enumerate(sorted_methods[:3], 1):
1048
+ percentage = (count / sum(method_usage.values())) * 100
1049
+ st.markdown(f"{i}. **{method.upper()}** - {count} queries ({percentage:.1f}%)")
1050
+
1051
+ # Voice interaction stats
1052
+ try:
1053
+ voice_queries = analytics_db.get_voice_interaction_stats()
1054
+ if voice_queries and voice_queries.get('total_voice_queries', 0) > 0:
1055
+ st.markdown("---")
1056
+ st.markdown("**🎤 Voice Interactions:**")
1057
+ st.markdown(f"🔊 Voice queries: {voice_queries['total_voice_queries']}")
1058
+ if voice_queries.get('avg_voice_response_time', 0) > 0:
1059
+ st.markdown(f"⏱️ Avg response time: {voice_queries['avg_voice_response_time']:.1f}ms")
1060
+ if sum(method_usage.values()) > 0:
1061
+ voice_percentage = (voice_queries['total_voice_queries'] / sum(method_usage.values())) * 100
1062
+ st.markdown(f"📊 Voice usage: {voice_percentage:.1f}%")
1063
+ except Exception as e:
1064
+ logger.error(f"Voice stats error: {e}")
1065
+ pass
1066
+
1067
+ # Voice Analytics Section (if voice interactions exist)
1068
+ try:
1069
+ voice_queries = analytics_db.get_voice_interaction_stats()
1070
+ if voice_queries and voice_queries.get('total_voice_queries', 0) > 0:
1071
+ st.markdown("#### 🎤 Voice Interaction Analytics")
1072
+ col1, col2 = st.columns([2, 1])
1073
+
1074
+ with col1:
1075
+ voice_by_method = voice_queries.get('voice_by_method', {})
1076
+ if voice_by_method:
1077
+ st.bar_chart(voice_by_method)
1078
+ else:
1079
+ st.info("No voice method breakdown available yet")
1080
+
1081
+ with col2:
1082
+ st.markdown("**Voice Stats:**")
1083
+ total_voice = voice_queries['total_voice_queries']
1084
+ st.markdown(f"🔊 Total voice queries: {total_voice}")
1085
+
1086
+ avg_response = voice_queries.get('avg_voice_response_time', 0)
1087
+ if avg_response > 0:
1088
+ st.markdown(f"⏱️ Avg response: {avg_response:.1f}ms")
1089
+
1090
+ # Most used voice method
1091
+ if voice_by_method:
1092
+ most_used_voice = max(voice_by_method.items(), key=lambda x: x[1])
1093
+ st.markdown(f"🎯 Top voice method: {most_used_voice[0].upper()}")
1094
+ except Exception as e:
1095
+ logger.error(f"Voice analytics error: {e}")
1096
+
1097
+ # Citation Analysis
1098
+ citation_types = stats.get('citation_types', {})
1099
+ if citation_types:
1100
+ st.markdown("#### 📚 Citation Sources")
1101
+ col1, col2 = st.columns([2, 1])
1102
+
1103
+ with col1:
1104
+ # Filter out empty/null citation types
1105
+ filtered_citations = {k: v for k, v in citation_types.items() if k and k.strip()}
1106
+ if filtered_citations:
1107
+ st.bar_chart(filtered_citations)
1108
+
1109
+ with col2:
1110
+ st.markdown("**Source Breakdown:**")
1111
+ total_citations = sum(citation_types.values())
1112
+ for cite_type, count in sorted(citation_types.items(), key=lambda x: x[1], reverse=True):
1113
+ if cite_type and cite_type.strip():
1114
+ percentage = (count / total_citations) * 100
1115
+ icon = "📄" if cite_type == "pdf" else "🌐" if cite_type == "html" else "🖼️" if cite_type == "image" else "📚"
1116
+ st.markdown(f"{icon} **{cite_type.title()}**: {count} ({percentage:.1f}%)")
1117
+
1118
+ # Popular Keywords
1119
+ keywords = stats.get('top_keywords', {})
1120
+ if keywords:
1121
+ st.markdown("#### 🔍 Popular Query Topics")
1122
+ col1, col2, col3 = st.columns(3)
1123
+
1124
+ keyword_items = list(keywords.items())
1125
+ for i, (word, count) in enumerate(keyword_items[:9]): # Top 9 keywords
1126
+ col = [col1, col2, col3][i % 3]
1127
+ with col:
1128
+ st.metric(word.title(), count)
1129
+
1130
+ # Recent Queries with Responses
1131
+ if recent_queries:
1132
+ st.markdown("#### 🕒 Recent Queries & Responses")
1133
+
1134
+ for query in recent_queries[:5]: # Show last 5
1135
+ # Create expander title with query preview
1136
+ query_preview = query['query'][:60] + "..." if len(query['query']) > 60 else query['query']
1137
+ expander_title = f"🧑 **{query['method'].upper()}**: {query_preview}"
1138
+
1139
+ with st.expander(expander_title):
1140
+ # Query details
1141
+ st.markdown(f"**📝 Full Query:** {query['query']}")
1142
+
1143
+ # Metrics row
1144
+ col1, col2, col3, col4 = st.columns(4)
1145
+ with col1:
1146
+ st.metric("Answer Length", f"{query['answer_length']} chars")
1147
+ with col2:
1148
+ st.metric("Citations", query['citations'])
1149
+ with col3:
1150
+ if query['response_time']:
1151
+ st.metric("Response Time", f"{query['response_time']:.0f}ms")
1152
+ else:
1153
+ st.metric("Response Time", "N/A")
1154
+ with col4:
1155
+ status = "❌ Error" if query.get('error_message') else "✅ Success"
1156
+ st.markdown(f"**Status:** {status}")
1157
+
1158
+ # Show error message if exists
1159
+ if query.get('error_message'):
1160
+ st.error(f"**Error:** {query['error_message']}")
1161
+ else:
1162
+ # Show answer in a styled container
1163
+ st.markdown("**🤖 Response:**")
1164
+ answer = query.get('answer', 'No answer available')
1165
+
1166
+ # Truncate very long answers for better UX
1167
+ if len(answer) > 1000:
1168
+ st.markdown(
1169
+ f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #28a745;">'
1170
+ f'{answer[:500].replace(chr(10), "<br>")}<br><br>'
1171
+ f'<i>... (truncated, showing first 500 chars of {len(answer)} total)</i>'
1172
+ f'</div>',
1173
+ unsafe_allow_html=True
1174
+ )
1175
+
1176
+ # Option to view full answer
1177
+ if st.button(f"📖 View Full Answer", key=f"full_answer_{query['query_id']}"):
1178
+ st.markdown("**Full Answer:**")
1179
+ st.markdown(
1180
+ f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; max-height: 400px; overflow-y: auto;">'
1181
+ f'{answer.replace(chr(10), "<br>")}'
1182
+ f'</div>',
1183
+ unsafe_allow_html=True
1184
+ )
1185
+ else:
1186
+ # Short answers display fully
1187
+ st.markdown(
1188
+ f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #28a745;">'
1189
+ f'{answer.replace(chr(10), "<br>")}'
1190
+ f'</div>',
1191
+ unsafe_allow_html=True
1192
+ )
1193
+
1194
+ # Show detailed citation info
1195
+ if query['citations'] > 0:
1196
+ if st.button(f"📚 View Citations", key=f"citations_{query['query_id']}"):
1197
+ detailed_query = analytics_db.get_query_with_citations(query['query_id'])
1198
+ if detailed_query and 'citations' in detailed_query:
1199
+ st.markdown("**Citations:**")
1200
+ for i, citation in enumerate(detailed_query['citations'], 1):
1201
+ scores = []
1202
+ if citation.get('relevance_score'):
1203
+ scores.append(f"relevance: {citation['relevance_score']:.3f}")
1204
+ if citation.get('bm25_score'):
1205
+ scores.append(f"BM25: {citation['bm25_score']:.3f}")
1206
+ if citation.get('rerank_score'):
1207
+ scores.append(f"rerank: {citation['rerank_score']:.3f}")
1208
+
1209
+ score_text = f" ({', '.join(scores)})" if scores else ""
1210
+ st.markdown(f"{i}. **{citation['source']}** {score_text}")
1211
+
1212
+ st.markdown(f"**🕐 Timestamp:** {query['timestamp']}")
1213
+ st.markdown("---")
1214
+
1215
+ # Session Info
1216
+ st.markdown("---")
1217
+ col1, col2 = st.columns([3, 1])
1218
+ with col1:
1219
+ st.markdown("*Analytics are updated in real-time and persist across sessions*")
1220
+ with col2:
1221
+ st.markdown(f"**Session ID:** `{st.session_state.session_id}`")
1222
+
1223
+ except Exception as e:
1224
+ st.error(f"Error loading analytics: {e}")
1225
+ st.info("Analytics data will appear after your first query. The database is created automatically.")
1226
+
1227
+ # Fallback to session analytics
1228
+ if st.session_state.chat_history:
1229
+ st.markdown("#### 📊 Current Session")
1230
+ col1, col2 = st.columns(2)
1231
+ with col1:
1232
+ st.metric("Session Queries", len(st.session_state.chat_history))
1233
+ with col2:
1234
+ methods_used = [entry['method'] for entry in st.session_state.chat_history]
1235
+ most_used = max(set(methods_used), key=methods_used.count) if methods_used else "N/A"
1236
+ st.metric("Most Used Method", most_used.upper() if most_used != "N/A" else most_used)
1237
+
1238
+ # Full Screen Comparison Window (Modal-like)
1239
+ if st.session_state.get('show_comparison_window', False):
1240
+ st.markdown("---")
1241
+
1242
+ # Header with close button
1243
+ col1, col2 = st.columns([4, 1])
1244
+ with col1:
1245
+ comparison_data = st.session_state.comparison_results
1246
+ st.markdown(f"## 🪟 Full Screen Comparison")
1247
+ st.markdown(f"**Query:** {comparison_data['query']}")
1248
+ st.markdown(f"**Generated:** {comparison_data['timestamp']} | **Methods:** {', '.join([m.upper() for m in comparison_data['methods']])}")
1249
+
1250
+ with col2:
1251
+ if st.button("✖️ Close", help="Close full screen view"):
1252
+ st.session_state.show_comparison_window = False
1253
+ st.rerun()
1254
+
1255
+ st.markdown("---")
1256
+
1257
+ # Full-width comparison display
1258
+ results = comparison_data['results']
1259
+ methods = comparison_data['methods']
1260
+
1261
+ for method in methods:
1262
+ st.markdown(f"### 🔸 {method.upper()} Method")
1263
+
1264
+ # Answer
1265
+ answer = results[method]['answer']
1266
+ st.markdown("**Answer:**")
1267
+
1268
+ # Use a container with custom styling for better readability
1269
+ with st.container():
1270
+ st.markdown(
1271
+ f'<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin: 10px 0; border-left: 5px solid #1f77b4;">'
1272
+ f'{answer.replace(chr(10), "<br>")}'
1273
+ f'</div>',
1274
+ unsafe_allow_html=True
1275
+ )
1276
+
1277
+ # Citations
1278
+ st.markdown("**Citations:**")
1279
+ st.markdown(format_citations_html(results[method]['citations'], method), unsafe_allow_html=True)
1280
+
1281
+ # Statistics
1282
+ col1, col2, col3 = st.columns(3)
1283
+ with col1:
1284
+ st.metric("Answer Length", f"{len(answer)} chars")
1285
+ with col2:
1286
+ st.metric("Citations", len(results[method]['citations']))
1287
+ with col3:
1288
+ word_count = len(answer.split())
1289
+ st.metric("Word Count", word_count)
1290
+
1291
+ if method != methods[-1]: # Not the last method
1292
+ st.markdown("---")
1293
+
1294
+ # Summary comparison table
1295
+ st.markdown("### 📊 Method Comparison Summary")
1296
+
1297
+ summary_data = []
1298
+ for method in methods:
1299
+ summary_data.append({
1300
+ 'Method': method.upper(),
1301
+ 'Answer Length (chars)': len(results[method]['answer']),
1302
+ 'Word Count': len(results[method]['answer'].split()),
1303
+ 'Citations': len(results[method]['citations']),
1304
+ 'Avg Citation Score': round(
1305
+ sum(float(c.get('relevance_score', 0) or c.get('score', 0) or 0)
1306
+ for c in results[method]['citations']) / len(results[method]['citations'])
1307
+ if results[method]['citations'] else 0, 3
1308
+ )
1309
+ })
1310
+
1311
+ st.dataframe(summary_data, use_container_width=True, hide_index=True)
1312
+
1313
+ st.markdown("---")
1314
+
1315
+ # Return to normal view button
1316
+ col1, col2, col3 = st.columns([2, 1, 2])
1317
+ with col2:
1318
+ if st.button("⬅️ Back to Comparison Tab", type="primary", use_container_width=True):
1319
+ st.session_state.show_comparison_window = False
1320
+ st.rerun()
1321
+
1322
+ st.stop() # Stop rendering the rest of the app when in full screen mode
1323
 
1324
  # Footer
1325
  st.markdown("---")
1326
  st.markdown(
1327
+ "**⚠️ Disclaimer:** *This system uses AI to retrieve and generate responses. "
1328
+ "While we strive for accuracy, please verify critical safety information with official sources.*"
1329
+ )
1330
  st.markdown(
1331
+ "**🙏 Acknowledgment:** *We thank [Ohio BWC/WSIC](https://info.bwc.ohio.gov/) "
1332
+ "for funding that made this multi-method RAG system possible.*"
1333
+ )
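Example (not part of the committed code): the voice UI above posts to the realtime server's /rag route with the fields query, top_k, and method, and reads back answer and citations. A minimal Python sketch of the same call, assuming the server is reachable on port 7861 as configured in the Dockerfile:

import requests

# Hypothetical smoke test for the /rag endpoint served by realtime_server.py.
resp = requests.post(
    "http://localhost:7861/rag",
    json={"query": "What are lockout/tagout requirements?", "top_k": 5, "method": "graph"},
    timeout=60,
)
resp.raise_for_status()
rag = resp.json()
print(rag["answer"][:200])                          # generated answer text
print(len(rag.get("citations", [])), "citations")   # citation list consumed by the UI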
config.py ADDED
@@ -0,0 +1,217 @@
1
+ """
2
+ Central configuration file for the Multi-Method RAG System.
3
+ All shared parameters and settings are defined here.
4
+ """
5
+
6
+ import os
7
+ from pathlib import Path
8
+ from dotenv import load_dotenv
9
+
10
+ # Load environment variables
11
+ load_dotenv(override=True)
12
+
13
+ # ==================== Versioning and Date ====================
14
+ DATE = "August 13, 2025"
15
+ VERSION = "2.0.1"
16
+
17
+
18
+ # ==================== API Configuration ====================
19
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
+ OPENAI_CHAT_MODEL = "gpt-5-chat-latest" # Non-reasoning GPT-5 chat model, used here for lower response latency
21
+ OPENAI_EMBEDDING_MODEL = "text-embedding-3-large" # Options: text-embedding-3-large, text-embedding-3-small, text-embedding-ada-002
22
+
23
+ # ==================== Realtime API Configuration ====================
24
+ # OpenAI Realtime API settings for speech-to-speech functionality
25
+ OPENAI_REALTIME_MODEL = "gpt-4o-realtime-preview" # Realtime model for speech-to-speech
26
+ REALTIME_VOICE = "alloy" # Available voices: alloy, echo, fable, onyx, nova, shimmer
27
+ REALTIME_INSTRUCTIONS = (
28
+ "You are a knowledgeable safety expert speaking naturally in conversation. "
29
+
30
+ "VOICE BEHAVIOR: "
31
+ "- Speak like a confident safety professional talking to a colleague "
32
+ "- Acknowledge what you heard: 'You're asking about [topic]...' "
33
+ "- Use natural speech with appropriate pauses and emphasis "
34
+ "- Sound authoritative and knowledgeable - you ARE the expert "
35
+ "- Never mention document names, page numbers, or citation details when speaking "
36
+ "- Just state the facts naturally as if you know them from your expertise "
37
+
38
+ "RESPONSE PROCESS: "
39
+ "1. Briefly acknowledge the question: 'You're asking about [topic]...' "
40
+ "2. Call ask_rag to get the accurate information "
41
+ "3. Speak the information naturally as YOUR expertise, not as 'according to document X' "
42
+ "4. Organize complex topics: 'There are three key requirements here...' "
43
+ "5. Be thorough but conversational - like explaining to a colleague "
44
+
45
+ "CITATION RULE: "
46
+ "NEVER mention specific documents, sources, or page numbers in speech. "
47
+ "Just state the information confidently as if it's your professional knowledge. "
48
+ "For example, don't say 'According to OSHA 1910.147...' - just say 'The lockout tagout requirements are...' "
49
+
50
+ "IMPORTANT: Always use ask_rag for safety questions to get accurate information, "
51
+ "but speak the results as your own expertise, not as citations."
52
+ )
53
+
54
+ # ==================== Model Parameters ====================
55
+ # Generation parameters
56
+ DEFAULT_TEMPERATURE = 0 # Range: 0.0-1.0 (0=deterministic, 1=creative)
57
+ DEFAULT_MAX_TOKENS = 5000 # Maximum tokens in response
58
+ DEFAULT_TOP_K = 5 # Number of chunks to retrieve by default
59
+ DEFAULT_TOP_P = 1.0 # Nucleus sampling parameter
60
+
61
+ # Context window management
62
+ MAX_CONTEXT_TOKENS = 7500 # Maximum context for models with 8k window
63
+ CHUNK_SIZE = 2000 # Tokens per chunk (used by TextPreprocessor.chunk_text_by_tokens)
64
+ CHUNK_OVERLAP = 200 # Token overlap between chunks
65
+
66
+ # ==================== Embedding Models ====================
67
+ # Sentence Transformers models
68
+ SENTENCE_TRANSFORMER_MODEL = 'all-MiniLM-L6-v2' # For DPR
69
+ CROSS_ENCODER_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2' # For re-ranking
70
+
71
+ # CLIP model
72
+ CLIP_MODEL = "ViT-L/14" # Options: ViT-B/32, ViT-L/14, RN50
73
+
74
+ # ==================== Search Parameters ====================
75
+ # BM25 parameters
76
+ BM25_K1 = 1.5 # Term frequency saturation parameter
77
+ BM25_B = 0.75 # Length normalization parameter
78
+
79
+ # Hybrid search
80
+ DEFAULT_HYBRID_ALPHA = 0.5 # Weight for BM25 (1-alpha for semantic)
81
+
82
+ # Re-ranking
83
+ RERANK_MULTIPLIER = 2 # Retrieve top_k * RERANK_MULTIPLIER candidates before re-ranking
84
+ MIN_RELEVANCE_SCORE = 0.3 # Minimum score threshold
85
+
86
+ # ==================== Directory Structure ====================
87
+ # Project directories
88
+ PROJECT_ROOT = Path(__file__).parent
89
+ DATA_DIR = PROJECT_ROOT / "data"
90
+ EMBEDDINGS_DIR = PROJECT_ROOT / "embeddings"
91
+ GRAPH_DIR = PROJECT_ROOT / "graph"
92
+ METADATA_DIR = PROJECT_ROOT / "metadata"
93
+ IMAGES_DIR = DATA_DIR / "images"
94
+
95
+ # File paths
96
+ VANILLA_FAISS_INDEX = EMBEDDINGS_DIR / "vanilla_faiss.index"
97
+ VANILLA_METADATA = EMBEDDINGS_DIR / "vanilla_metadata.pkl"
98
+ DPR_FAISS_INDEX = EMBEDDINGS_DIR / "dpr_faiss.index"
99
+ DPR_METADATA = EMBEDDINGS_DIR / "dpr_metadata.pkl"
100
+ BM25_INDEX = EMBEDDINGS_DIR / "bm25_index.pkl"
101
+ CONTEXT_DOCS = EMBEDDINGS_DIR / "context_stuffing_docs.pkl"
102
+ GRAPH_FILE = GRAPH_DIR / "graph.gml"
103
+ IMAGES_DB = METADATA_DIR / "images.db"
104
+ CHROMA_PATH = EMBEDDINGS_DIR / "chroma"
105
+
106
+ # ==================== Batch Processing ====================
107
+ EMBEDDING_BATCH_SIZE = 100 # Batch size for OpenAI embeddings
108
+ PROCESSING_BATCH_SIZE = 50 # Documents to process at once
109
+
110
+ # ==================== UI Configuration ====================
111
+ # Streamlit settings
112
+ MAX_CHAT_HISTORY = 5 # Maximum chat messages to keep
113
+ EXAMPLE_QUESTIONS = [
114
+ "What are general machine guarding requirements?",
115
+ "How do I perform lockout/tagout?",
116
+ "What safety measures are needed for robotic systems?",
117
+ "Explain the difference between guards and devices in machine safety.",
118
+ "What are the OSHA requirements for emergency stops?",
119
+ ]
120
+
121
+ # Default method
122
+ DEFAULT_METHOD = "graph"
123
+
124
+ # Method descriptions for UI
125
+ METHOD_DESCRIPTIONS = {
126
+ 'graph': "Graph-based RAG using NetworkX with relationship-aware retrieval",
127
+ 'vanilla': "Standard vector search with FAISS and OpenAI embeddings",
128
+ 'dpr': "Dense Passage Retrieval with bi-encoder and cross-encoder re-ranking",
129
+ 'bm25': "BM25 keyword search with neural re-ranking for exact term matching",
130
+ 'context': "Context stuffing with full document loading and heuristic selection",
131
+ 'vision': "Vision-based search using GPT-5 Vision for image analysis and classification"
132
+ }
133
+
134
+ # ==================== Document Processing ====================
135
+ # Document types
136
+ SUPPORTED_EXTENSIONS = ['.pdf', '.txt', '.md', '.html']
137
+ IMAGE_EXTENSIONS = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']
138
+
139
+ # Text splitting
140
+ MARKDOWN_HEADER_LEVEL = 3 # Split by this header level (###)
141
+ MAX_SECTIONS_PER_DOC = 500 # Maximum sections to extract from a document
142
+
143
+ # ==================== Logging ====================
144
+ LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") # DEBUG, INFO, WARNING, ERROR
145
+ LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
146
+
147
+ # ==================== Performance ====================
148
+ # Device configuration
149
+ import torch
150
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
151
+ NUM_WORKERS = 4 # Parallel processing workers
152
+
153
+ # Cache settings
154
+ ENABLE_CACHE = True
155
+ CACHE_TTL = 3600 # Cache time-to-live in seconds
156
+
157
+ # ==================== Safety & Validation ====================
158
+ # Input validation
159
+ MAX_QUESTION_LENGTH = 1000 # Maximum characters in a question
160
+ MAX_IMAGE_SIZE_MB = 10 # Maximum image file size
161
+
162
+ # Rate limiting (if needed)
163
+ RATE_LIMIT_ENABLED = False
164
+ MAX_QUERIES_PER_MINUTE = 60
165
+
166
+ # ==================== Default HTML Sources ====================
167
+ DEFAULT_HTML_SOURCES = [
168
+ {
169
+ "title": "NIOSH Robotics in the Workplace – Safety Overview",
170
+ "url": "https://www.cdc.gov/niosh/robotics/about/",
171
+ "source": "NIOSH",
172
+ "year": 2024,
173
+ "category": "Technical Guide",
174
+ "format": "HTML"
175
+ }
176
+ ]
177
+
178
+ # ==================== Helper Functions ====================
179
+ def ensure_directories():
180
+ """Create all required directories if they don't exist."""
181
+ for directory in [DATA_DIR, EMBEDDINGS_DIR, GRAPH_DIR, METADATA_DIR, IMAGES_DIR]:
182
+ directory.mkdir(parents=True, exist_ok=True)
183
+
184
+ def get_model_context_length(model_name: str = OPENAI_CHAT_MODEL) -> int:
185
+ """Get the context length for a given model."""
186
+ context_lengths = {
187
+ "gpt-5": 128000,
188
+ "gpt-4o-mini": 8192,
189
+ "gpt-4o": 128000,
190
+ }
191
+ return context_lengths.get(model_name, 4096)
192
+
193
+ def validate_api_key():
194
+ """Check if OpenAI API key is set."""
195
+ if not OPENAI_API_KEY:
196
+ raise ValueError(
197
+ "OpenAI API key not found. Please set OPENAI_API_KEY in .env file."
198
+ )
199
+ return True
200
+
201
+ # ==================== System Info ====================
202
+ def print_config():
203
+ """Print current configuration for debugging."""
204
+ print("="*50)
205
+ print("RAG System Configuration")
206
+ print("="*50)
207
+ print(f"OpenAI Model: {OPENAI_CHAT_MODEL}")
208
+ print(f"Embedding Model: {OPENAI_EMBEDDING_MODEL}")
209
+ print(f"Device: {DEVICE}")
210
+ print(f"Default Temperature: {DEFAULT_TEMPERATURE}")
211
+ print(f"Default Top-K: {DEFAULT_TOP_K}")
212
+ print(f"Chunk Size: {CHUNK_SIZE}")
213
+ print(f"Project Root: {PROJECT_ROOT}")
214
+ print("="*50)
215
+
216
+ # Ensure directories exist on import
217
+ ensure_directories()
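A minimal usage sketch (not part of the committed code) showing how other modules are expected to consume this config; only functions and constants defined above are used, and printed values depend on your environment and .env:

import config

config.validate_api_key()   # raises ValueError if OPENAI_API_KEY is not set
config.print_config()       # dumps the active model, device, and chunking settings

# Paths and limits are plain module-level constants:
print(config.VANILLA_FAISS_INDEX)                    # embeddings/vanilla_faiss.index
print(config.get_model_context_length("gpt-4o"))     # 128000
print(config.get_model_context_length(config.OPENAI_CHAT_MODEL))  # falls back to 4096 for unlisted models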
preprocess.py CHANGED
@@ -1,85 +1,195 @@
1
- import os
2
- import re
3
- import glob
4
- from dotenv import load_dotenv
5
- import requests
6
- from bs4 import BeautifulSoup
7
- import pandas as pd
8
- import pymupdf4llm
9
- import networkx as nx
10
- from openai import OpenAI
11
 
12
- # Load environment and initialize OpenAI client
13
- load_dotenv(override=True)
14
- client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
 
 
15
 
16
- # Helper: split Markdown text by third-level headers
17
- def split_by_header(md_text):
18
- parts = re.split(r'(?m)^### ', md_text)
19
- return [('### ' + p) if not p.startswith('### ') else p for p in parts if p.strip()]
20
 
21
- # Initialize graph database
22
- G = nx.Graph()
 
 
 
23
 
24
- # Process local PDFs
25
- for pdf_path in glob.glob("scrapped_data/*.pdf"):
26
- filename = os.path.basename(pdf_path)
27
- title = os.path.splitext(filename)[0]
28
- # Convert PDF to Markdown
29
- md_text = pymupdf4llm.to_markdown(pdf_path)
30
- # Split into sections
31
- sections = split_by_header(md_text)
32
- for idx, sec in enumerate(sections):
33
- resp = client.embeddings.create(model="text-embedding-3-large", input=sec)
34
- vector = resp.data[0].embedding
35
- node_id = f"PDF::{title}::section{idx}"
36
- # Store the local file path for citation
37
- G.add_node(node_id,
38
- text=sec,
39
- embedding=vector,
40
- source=title,
41
- path=pdf_path)
42
 
43
- # HTML Document List
44
- html_data = [
45
- {
46
- "title": "NIOSH Robotics in the Workplace – Safety Overview (Human-Robot Collaboration)",
47
- "url": "https://www.cdc.gov/niosh/robotics/about/",
48
- "source": "NIOSH",
49
- "year": 2024,
50
- "category": "Technical Guide",
51
- "summary": "A NIOSH overview of emerging safety challenges as robots increasingly collaborate with human workers. Updated in 2024, this page discusses how robots can improve safety by taking over dangerous tasks but also introduces new struck-by and caught-between hazards. It emphasizes the need for updated safety standards, risk assessments, and research on human-robot interaction, and it outlines NIOSH’s efforts (through its Center for Occupational Robotics Research) to develop best practices and guidance for safe integration of robotics in industry.",
52
- "format": "HTML"
53
- }
54
- ]
55
 
56
- # Process HTML sources
57
- def process_html(item):
58
- resp = requests.get(item['url'])
59
- resp.raise_for_status()
60
- soup = BeautifulSoup(resp.text, 'html.parser')
61
- # Extract paragraph texts
62
- texts = [p.get_text() for p in soup.find_all('p')]
63
- # Extract tables as markdown
64
- tables = []
65
- for t in soup.find_all('table'):
66
- df = pd.read_html(str(t))[0]
67
- tables.append(df.to_markdown())
68
- # Join paragraphs and tables with double newlines
69
- full = "\n\n".join(texts + tables)
70
- # Embed the combined text
71
- resp_emb = client.embeddings.create(model="text-embedding-3-large", input=full)
72
- vec = resp_emb.data[0].embedding
73
- node_id = f"HTML::{item['title']}"
74
- # Add node with URL citation
75
- G.add_node(
76
- node_id, text=full, embedding=vec, source=item['title'], url=item['url']
77
- )
 
78
 
79
- # Run HTML processing
80
- for item in html_data:
81
- process_html(item)
82
 
83
- # Save graph
84
- nx.write_gml(G, "graph.gml")
85
- print("Graph RAG database created: graph.gml")
 
 
1
+ """
2
+ Refactored preprocessing pipeline for all RAG methods.
3
+ Uses utils.py functions and supports multiple retrieval methods.
 
 
 
 
 
 
 
4
 
5
+ Directory Layout:
6
+ /data/ # Original PDFs, HTML
7
+ /embeddings/ # FAISS, Chroma, DPR vector stores
8
+ /graph/ # Graph database files
9
+ /metadata/ # Image metadata (SQLite or MongoDB)
10
+ """
11
 
12
+ import logging
13
+ from pathlib import Path
 
 
14
 
15
+ from config import *
16
+ from utils import (
17
+ DocumentLoader, TextPreprocessor, VectorStoreManager,
18
+ ImageProcessor, ImageData
19
+ )
20
 
21
+ logger = logging.getLogger(__name__)
 
 
22
 
23
+ # Ensure all directories exist
24
+ ensure_directories()
 
 
25
 
26
+ def preprocess_for_method(method: str, documents: list):
27
+ """Preprocess documents for a specific retrieval method."""
28
+
29
+ print(f"\n{'='*50}")
30
+ print(f"Preprocessing for method: {method}")
31
+ print(f"{'='*50}")
32
+
33
+ try:
34
+ # Initialize processors
35
+ text_processor = TextPreprocessor()
36
+ vector_manager = VectorStoreManager()
37
+
38
+ # Preprocess text chunks for this method
39
+ chunks = text_processor.preprocess_for_method(documents, method)
40
+
41
+ if method == 'vanilla':
42
+ # Build FAISS index with OpenAI embeddings
43
+ index, metadata = vector_manager.build_faiss_index(chunks, method="vanilla")
44
+ vector_manager.save_index(index, metadata, method)
45
+
46
+ elif method == 'dpr':
47
+ # Build FAISS index with sentence transformer embeddings
48
+ index, metadata = vector_manager.build_faiss_index(chunks, method="dpr")
49
+ vector_manager.save_index(index, metadata, method)
50
+
51
+ elif method == 'bm25':
52
+ # Build BM25 index
53
+ bm25_index = vector_manager.build_bm25_index(chunks)
54
+ vector_manager.save_index(bm25_index, chunks, method)
55
+
56
+ elif method == 'graph':
57
+ # Build NetworkX graph
58
+ graph = vector_manager.build_graph_index(chunks)
59
+ vector_manager.save_index(graph, None, method)
60
+
61
+ elif method == 'context_stuffing':
62
+ # Save full documents for context stuffing
63
+ vector_manager.save_index(None, chunks, method)
64
+
65
+ else:
66
+ raise ValueError(f"Unknown method: {method}")
67
+
68
+ print(f"Successfully preprocessed for method '{method}'")
69
+
70
+ except Exception as e:
71
+ logger.error(f"Error preprocessing for {method}: {e}")
72
+ raise
73
 
 
 
 
74
 
75
+ def extract_and_process_images(documents: list):
76
+ """Extract images from documents and process them."""
77
+ print("\n" + "="*50)
78
+ print("Extracting and processing images...")
79
+ print("="*50)
80
+
81
+ image_processor = ImageProcessor()
82
+ processed_count = 0
83
+ filtered_count = 0
84
+ filter_reasons = {}
85
+
86
+ for doc in documents:
87
+ if 'images' in doc and doc['images']:
88
+ for image_info in doc['images']:
89
+ try:
90
+ # Check if image should be filtered out
91
+ should_filter, reason = image_processor.should_filter_image(image_info['image_path'])
92
+
93
+ if should_filter:
94
+ filtered_count += 1
95
+ filter_reasons[reason] = filter_reasons.get(reason, 0) + 1
96
+ print(f" Filtered: {image_info['image_id']} - {reason}")
97
+
98
+ # Optionally delete the filtered image file
99
+ try:
100
+ import os
101
+ os.remove(image_info['image_path'])
102
+ print(f" Deleted: {image_info['image_path']}")
103
+ except Exception as e:
104
+ logger.warning(f"Could not delete filtered image {image_info['image_path']}: {e}")
105
+
106
+ continue
107
+
108
+ # Classify image
109
+ classification = image_processor.classify_image(image_info['image_path'])
110
+
111
+ # Generate embedding (placeholder for now)
112
+ # embedding = embed_image_clip([image_info['image_path']])[0]
113
+
114
+ # Create ImageData object
115
+ image_data = ImageData(
116
+ image_path=image_info['image_path'],
117
+ image_id=image_info['image_id'],
118
+ classification=classification,
119
+ metadata={
120
+ 'source': doc['source'],
121
+ 'page': image_info.get('page'),
122
+ 'extracted_from': doc['path']
123
+ }
124
+ )
125
+
126
+ # Store in database
127
+ image_processor.store_image_metadata(image_data)
128
+ processed_count += 1
129
+
130
+ except Exception as e:
131
+ logger.error(f"Error processing image {image_info['image_id']}: {e}")
132
+ continue
133
+
134
+ # Print filtering summary
135
+ if filtered_count > 0:
136
+ print(f"\nImage Filtering Summary:")
137
+ print(f" Total filtered: {filtered_count}")
138
+ for reason, count in filter_reasons.items():
139
+ print(f" {reason}: {count}")
140
+ print()
141
+
142
+ if processed_count > 0:
143
+ print(f"Processed and stored metadata for {processed_count} images")
144
+ else:
145
+ print("No images found in documents")
146
+
147
+
148
+ def main():
149
+ """Main preprocessing pipeline."""
150
+ # Validate configuration
151
+ try:
152
+ validate_api_key()
153
+ except ValueError as e:
154
+ print(f"Error: {e}")
155
+ return
156
+
157
+ # Print configuration
158
+ print_config()
159
+
160
+ print("\nStarting preprocessing pipeline...")
161
+
162
+ # Load documents using utils
163
+ print("\nLoading documents...")
164
+ loader = DocumentLoader()
165
+ documents = loader.load_text_documents()
166
+
167
+ print(f"Loaded {len(documents)} documents")
168
+
169
+ # Define methods to preprocess
170
+ methods = ['vanilla', 'dpr', 'bm25', 'graph', 'context_stuffing']
171
+
172
+ # Preprocess for each method
173
+ for method in methods:
174
+ try:
175
+ preprocess_for_method(method, documents)
176
+ except Exception as e:
177
+ print(f"Error preprocessing for {method}: {e}")
178
+ import traceback
179
+ traceback.print_exc()
180
+
181
+ # Extract and process images
182
+ try:
183
+ extract_and_process_images(documents)
184
+ except Exception as e:
185
+ print(f"Error processing images: {e}")
186
+ import traceback
187
+ traceback.print_exc()
188
+
189
+ print("\n" + "="*50)
190
+ print("Preprocessing complete!")
191
+ print("="*50)
192
+
193
+
194
+ if __name__ == "__main__":
195
+ main()
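To rebuild a single index without re-running the whole pipeline, the same entry points can be called directly. A sketch (not part of the commit) that assumes data/ is already populated and that the utils classes behave as they are used above:

from utils import DocumentLoader
from preprocess import preprocess_for_method, extract_and_process_images

docs = DocumentLoader().load_text_documents()   # reads PDFs/HTML from data/
preprocess_for_method("bm25", docs)             # writes embeddings/bm25_index.pkl
# extract_and_process_images(docs)              # optional: refresh image metadata only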
query_bm25.py ADDED
@@ -0,0 +1,336 @@
1
+ """
2
+ BM25 keyword search with cross-encoder re-ranking and hybrid search support.
3
+ """
4
+
5
+ import numpy as np
6
+ import faiss
7
+ from typing import Tuple, List, Optional
8
+ from openai import OpenAI
9
+ from sentence_transformers import CrossEncoder
10
+
11
+ import config
12
+ import utils
13
+
14
+ # Initialize models
15
+ client = OpenAI(api_key=config.OPENAI_API_KEY)
16
+ cross_encoder = CrossEncoder(config.CROSS_ENCODER_MODEL)
17
+
18
+ # Global variables for lazy loading
19
+ _bm25_index = None
20
+ _texts = None
21
+ _metadata = None
22
+ _semantic_index = None
23
+
24
+ def _load_bm25_index():
25
+ """Lazy load BM25 index and metadata."""
26
+ global _bm25_index, _texts, _metadata, _semantic_index
27
+
28
+ if _bm25_index is None:
29
+ # Initialize defaults
30
+ _texts = []
31
+ _metadata = []
32
+ _semantic_index = None
33
+ try:
34
+ import pickle
35
+
36
+ if config.BM25_INDEX.exists():
37
+ with open(config.BM25_INDEX, 'rb') as f:
38
+ bm25_data = pickle.load(f)
39
+
40
+ if isinstance(bm25_data, dict):
41
+ _bm25_index = bm25_data.get('index') or bm25_data.get('bm25')
42
+ chunks = bm25_data.get('texts', [])
43
+ if chunks:
44
+ _texts = [chunk.text for chunk in chunks if hasattr(chunk, 'text')]
45
+ _metadata = [chunk.metadata for chunk in chunks if hasattr(chunk, 'metadata')]
46
+ else:
47
+ _texts = []
48
+ _metadata = []
49
+
50
+ # Load semantic embeddings if available for hybrid search
51
+ if 'embeddings' in bm25_data:
52
+ semantic_embeddings = bm25_data['embeddings']
53
+ # Build FAISS index
54
+ import faiss
55
+ dimension = semantic_embeddings.shape[1]
56
+ _semantic_index = faiss.IndexFlatIP(dimension)
57
+ faiss.normalize_L2(semantic_embeddings)
58
+ _semantic_index.add(semantic_embeddings)
59
+ else:
60
+ _bm25_index = bm25_data
61
+ _texts = []
62
+ _metadata = []
63
+
64
+ print(f"Loaded BM25 index with {len(_texts)} documents")
65
+ else:
66
+ print("BM25 index not found. Run preprocess.py first.")
67
+
68
+ except Exception as e:
69
+ print(f"Error loading BM25 index: {e}")
70
+ _bm25_index = None
71
+ _texts = []
72
+ _metadata = []
73
+
74
+
75
+ def query(question: str, image_path: Optional[str] = None, top_k: int = None) -> Tuple[str, List[dict]]:
76
+ """
77
+ Query using BM25 keyword search with re-ranking.
78
+
79
+ Args:
80
+ question: User's question
81
+ image_path: Optional path to an image
82
+ top_k: Number of relevant chunks to retrieve
83
+
84
+ Returns:
85
+ Tuple of (answer, citations)
86
+ """
87
+ if top_k is None:
88
+ top_k = config.DEFAULT_TOP_K
89
+
90
+ # Load index if not already loaded
91
+ _load_bm25_index()
92
+
93
+ if _bm25_index is None or len(_texts) == 0:
94
+ return "BM25 index not loaded. Please run preprocess.py first.", []
95
+
96
+ # Tokenize query for BM25
97
+ tokenized_query = question.lower().split()
98
+
99
+ # Get BM25 scores
100
+ bm25_scores = _bm25_index.get_scores(tokenized_query)
101
+
102
+ # Get top candidates (retrieve more for re-ranking)
103
+ top_indices = np.argsort(bm25_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]
104
+
105
+ # Prepare candidates for re-ranking
106
+ candidates = []
107
+ for idx in top_indices:
108
+ if idx < len(_texts) and bm25_scores[idx] > 0:
109
+ candidates.append({
110
+ 'text': _texts[idx],
111
+ 'bm25_score': bm25_scores[idx],
112
+ 'metadata': _metadata[idx],
113
+ 'idx': idx
114
+ })
115
+
116
+ # Re-rank with cross-encoder
117
+ if candidates:
118
+ pairs = [[question, cand['text']] for cand in candidates]
119
+ cross_scores = cross_encoder.predict(pairs)
120
+
121
+ # Add cross-encoder scores and sort
122
+ for i, score in enumerate(cross_scores):
123
+ candidates[i]['cross_score'] = score
124
+
125
+ candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]
126
+
127
+ # Collect citations
128
+ citations = []
129
+ sources_seen = set()
130
+
131
+ for chunk in candidates:
132
+ chunk_meta = chunk['metadata']
133
+
134
+ if chunk_meta['source'] not in sources_seen:
135
+ citation = {
136
+ 'source': chunk_meta['source'],
137
+ 'type': chunk_meta['type'],
138
+ 'bm25_score': round(chunk['bm25_score'], 3),
139
+ 'rerank_score': round(chunk['cross_score'], 3)
140
+ }
141
+
142
+ if chunk_meta['type'] == 'pdf':
143
+ citation['path'] = chunk_meta['path']
144
+ else:
145
+ citation['url'] = chunk_meta.get('url', '')
146
+
147
+ citations.append(citation)
148
+ sources_seen.add(chunk_meta['source'])
149
+
150
+ # Handle image if provided
151
+ image_context = ""
152
+ if image_path:
153
+ try:
154
+ classification = utils.classify_image(image_path)
155
+ # classification is a string, not a dict
156
+ image_context = f"\n\n[Image Analysis: The image appears to show a {classification}.]"
157
+ except Exception as e:
158
+ print(f"Error processing image: {e}")
159
+
160
+ # Build context from retrieved chunks
161
+ context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
162
+
163
+ if not context:
164
+ return "No relevant documents found for your query.", []
165
+
166
+ # Generate answer
167
+ prompt = f"""Answer the following question using the retrieved documents:
168
+
169
+ Retrieved Documents:
170
+ {context}{image_context}
171
+
172
+ Question: {question}
173
+
174
+ Instructions:
175
+ 1. Provide a comprehensive answer based on the retrieved documents
176
+ 2. Mention specific details from the sources
177
+ 3. If the documents don't fully answer the question, indicate what information is missing"""
178
+
179
+ # For GPT-5, temperature must be default (1.0)
180
+ response = client.chat.completions.create(
181
+ model=config.OPENAI_CHAT_MODEL,
182
+ messages=[
183
+ {"role": "system", "content": "You are a technical expert on manufacturing safety and regulations. Provide accurate, detailed answers based on the retrieved documents."},
184
+ {"role": "user", "content": prompt}
185
+ ],
186
+ max_completion_tokens=config.DEFAULT_MAX_TOKENS
187
+ )
188
+
189
+ answer = response.choices[0].message.content
190
+
191
+ return answer, citations
192
+
193
+
194
+ def query_hybrid(question: str, top_k: int = None, alpha: float = None) -> Tuple[str, List[dict]]:
195
+ """
196
+ Hybrid search combining BM25 and semantic search.
197
+
198
+ Args:
199
+ question: User's question
200
+ top_k: Number of relevant chunks to retrieve
201
+ alpha: Weight for BM25 scores (1-alpha for semantic)
202
+
203
+ Returns:
204
+ Tuple of (answer, citations)
205
+ """
206
+ if top_k is None:
207
+ top_k = config.DEFAULT_TOP_K
208
+ if alpha is None:
209
+ alpha = config.DEFAULT_HYBRID_ALPHA
210
+
211
+ # Load index if not already loaded
212
+ _load_bm25_index()
213
+
214
+ if _bm25_index is None or _semantic_index is None:
215
+ return "Hybrid search requires both BM25 and semantic indices. Please run preprocess.py with semantic embeddings.", []
216
+
217
+ # Get BM25 scores
218
+ tokenized_query = question.lower().split()
219
+ bm25_scores = _bm25_index.get_scores(tokenized_query)
220
+
221
+ # Normalize BM25 scores
222
+ if bm25_scores.max() > 0:
223
+ bm25_scores = bm25_scores / bm25_scores.max()
224
+
225
+ # Get semantic scores using FAISS
226
+ embedding_generator = utils.EmbeddingGenerator()
227
+ query_embedding = embedding_generator.embed_text_openai([question]).astype(np.float32)
228
+ faiss.normalize_L2(query_embedding)
229
+
230
+ # Search semantic index for all documents
231
+ k_search = min(len(_texts), top_k * config.RERANK_MULTIPLIER)
232
+ distances, indices = _semantic_index.search(query_embedding.reshape(1, -1), k_search)
233
+
234
+ # Create semantic scores array
235
+ semantic_scores = np.zeros(len(_texts))
236
+ for idx, dist in zip(indices[0], distances[0]):
237
+ if idx < len(_texts):
238
+ semantic_scores[idx] = dist
239
+
240
+ # Combine scores
241
+ hybrid_scores = alpha * bm25_scores + (1 - alpha) * semantic_scores
242
+
243
+ # Get top candidates
244
+ top_indices = np.argsort(hybrid_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]
245
+
246
+ # Prepare candidates
247
+ candidates = []
248
+ for idx in top_indices:
249
+ if idx < len(_texts) and hybrid_scores[idx] > 0:
250
+ candidates.append({
251
+ 'text': _texts[idx],
252
+ 'hybrid_score': hybrid_scores[idx],
253
+ 'bm25_score': bm25_scores[idx],
254
+ 'semantic_score': semantic_scores[idx],
255
+ 'metadata': _metadata[idx],
256
+ 'idx': idx
257
+ })
258
+
259
+ # Re-rank with cross-encoder
260
+ if candidates:
261
+ pairs = [[question, cand['text']] for cand in candidates]
262
+ cross_scores = cross_encoder.predict(pairs)
263
+
264
+ for i, score in enumerate(cross_scores):
265
+ candidates[i]['cross_score'] = score
266
+
267
+ # Final ranking using cross-encoder scores
268
+ candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]
269
+
270
+ # Collect citations
271
+ citations = []
272
+ sources_seen = set()
273
+
274
+ for chunk in candidates:
275
+ chunk_meta = chunk['metadata']
276
+
277
+ if chunk_meta['source'] not in sources_seen:
278
+ citation = {
279
+ 'source': chunk_meta['source'],
280
+ 'type': chunk_meta['type'],
281
+ 'hybrid_score': round(chunk['hybrid_score'], 3),
282
+ 'rerank_score': round(chunk.get('cross_score', 0), 3)
283
+ }
284
+
285
+ if chunk_meta['type'] == 'pdf':
286
+ citation['path'] = chunk_meta['path']
287
+ else:
288
+ citation['url'] = chunk_meta.get('url', '')
289
+
290
+ citations.append(citation)
291
+ sources_seen.add(chunk_meta['source'])
292
+
293
+ # Build context
294
+ context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
295
+
296
+ if not context:
297
+ return "No relevant documents found for your query.", []
298
+
299
+ # Generate answer
300
+ prompt = f"""Using the following retrieved passages, answer the question:
301
+
302
+ {context}
303
+
304
+ Question: {question}
305
+
306
+ Provide a clear, detailed answer based on the information in the passages."""
307
+
308
+ # For GPT-5, temperature must be default (1.0)
309
+ response = client.chat.completions.create(
310
+ model=config.OPENAI_CHAT_MODEL,
311
+ messages=[
312
+ {"role": "system", "content": "You are a safety expert. Answer questions accurately using the provided passages."},
313
+ {"role": "user", "content": prompt}
314
+ ],
315
+ max_completion_tokens=config.DEFAULT_MAX_TOKENS
316
+ )
317
+
318
+ answer = response.choices[0].message.content
319
+
320
+ return answer, citations
321
+
322
+
323
+ if __name__ == "__main__":
324
+ # Test BM25 query
325
+ test_questions = [
326
+ "lockout tagout procedures",
327
+ "machine guard requirements OSHA",
328
+ "robot safety collaborative workspace"
329
+ ]
330
+
331
+ for q in test_questions:
332
+ print(f"\nQuestion: {q}")
333
+ answer, citations = query(q)
334
+ print(f"Answer: {answer[:200]}...")
335
+ print(f"Citations: {citations}")
336
+ print("-" * 50)
query_context.py ADDED
@@ -0,0 +1,334 @@
1
+ """
2
+ Context stuffing query module.
3
+ Loads full documents and uses heuristics to select relevant content.
4
+ """
5
+
6
+ import pickle
7
+ import logging
8
+ import re
9
+ from typing import List, Tuple, Optional, Dict, Any
10
+ from openai import OpenAI
11
+ import tiktoken
12
+
13
+ from config import *
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class ContextStuffingRetriever:
18
+ """Context stuffing with heuristic document selection."""
19
+
20
+ def __init__(self):
21
+ self.client = OpenAI(api_key=OPENAI_API_KEY)
22
+ self.encoding = tiktoken.get_encoding("cl100k_base")
23
+ self.documents = None
24
+ self._load_documents()
25
+
26
+ def _load_documents(self):
27
+ """Load full documents for context stuffing."""
28
+ try:
29
+ if CONTEXT_DOCS.exists():
30
+ logger.info("Loading documents for context stuffing...")
31
+
32
+ with open(CONTEXT_DOCS, 'rb') as f:
33
+ data = pickle.load(f)
34
+
35
+ if isinstance(data, list) and len(data) > 0:
36
+ # Handle both old format (list of chunks) and new format (list of DocumentChunk objects)
37
+ if hasattr(data[0], 'text'): # New format with DocumentChunk objects
38
+ self.documents = []
39
+ for chunk in data:
40
+ self.documents.append({
41
+ 'text': chunk.text,
42
+ 'metadata': chunk.metadata,
43
+ 'chunk_id': chunk.chunk_id
44
+ })
45
+ else: # Old format with dict objects
46
+ self.documents = data
47
+
48
+ logger.info(f"✓ Loaded {len(self.documents)} documents for context stuffing")
49
+ else:
50
+ logger.warning("No documents found in context stuffing file")
51
+ self.documents = []
52
+ else:
53
+ logger.warning("Context stuffing documents not found. Run preprocess.py first.")
54
+ self.documents = []
55
+
56
+ except Exception as e:
57
+ logger.error(f"Error loading context stuffing documents: {e}")
58
+ self.documents = []
59
+
60
+ def _calculate_keyword_score(self, text: str, question: str) -> float:
61
+ """Calculate keyword overlap score between text and question."""
62
+ # Simple keyword matching heuristic
63
+ question_words = set(re.findall(r'\w+', question.lower()))
64
+ text_words = set(re.findall(r'\w+', text.lower()))
65
+
66
+ if not question_words:
67
+ return 0.0
68
+
69
+ overlap = len(question_words & text_words)
70
+ return overlap / len(question_words)
71
+
72
+ def _calculate_section_relevance(self, text: str, question: str) -> float:
73
+ """Calculate section relevance using multiple heuristics."""
74
+ score = 0.0
75
+
76
+ # Keyword overlap score (weight: 0.5)
77
+ keyword_score = self._calculate_keyword_score(text, question)
78
+ score += 0.5 * keyword_score
79
+
80
+ # Length penalty (prefer medium-length sections)
81
+ text_length = len(text.split())
82
+ optimal_length = 200 # words
83
+ length_score = min(1.0, text_length / optimal_length) if text_length < optimal_length else max(0.1, optimal_length / text_length)
84
+ score += 0.2 * length_score
85
+
86
+ # Header/title bonus (if text starts with common header patterns)
87
+ if re.match(r'^#+\s|^\d+\.\s|^[A-Z\s]{3,20}:', text.strip()):
88
+ score += 0.1
89
+
90
+ # Question type specific bonuses
91
+ question_lower = question.lower()
92
+ text_lower = text.lower()
93
+
94
+ if any(word in question_lower for word in ['what', 'define', 'definition']):
95
+ if any(phrase in text_lower for phrase in ['means', 'defined as', 'definition', 'refers to']):
96
+ score += 0.2
97
+
98
+ if any(word in question_lower for word in ['how', 'procedure', 'steps']):
99
+ if any(phrase in text_lower for phrase in ['step', 'procedure', 'process', 'method']):
100
+ score += 0.2
101
+
102
+ if any(word in question_lower for word in ['requirement', 'shall', 'must']):
103
+ if any(phrase in text_lower for phrase in ['shall', 'must', 'required', 'requirement']):
104
+ score += 0.2
105
+
106
+ return min(1.0, score) # Cap at 1.0
107
+
108
+ def select_relevant_documents(self, question: str, max_tokens: int = None) -> List[Dict[str, Any]]:
109
+ """Select most relevant documents using heuristics."""
110
+ if not self.documents:
111
+ return []
112
+
113
+ if max_tokens is None:
114
+ max_tokens = MAX_CONTEXT_TOKENS
115
+
116
+ # Score all documents
117
+ scored_docs = []
118
+ for doc in self.documents:
119
+ text = doc.get('text', '')
120
+ if text.strip():
121
+ relevance_score = self._calculate_section_relevance(text, question)
122
+
123
+ doc_info = {
124
+ 'text': text,
125
+ 'metadata': doc.get('metadata', {}),
126
+ 'score': relevance_score,
127
+ 'token_count': len(self.encoding.encode(text))
128
+ }
129
+ scored_docs.append(doc_info)
130
+
131
+ # Sort by relevance score
132
+ scored_docs.sort(key=lambda x: x['score'], reverse=True)
133
+
134
+ # Select documents within token limit
135
+ selected_docs = []
136
+ total_tokens = 0
137
+
138
+ for doc in scored_docs:
139
+ if doc['score'] > 0.1: # Minimum relevance threshold
140
+ if total_tokens + doc['token_count'] <= max_tokens:
141
+ selected_docs.append(doc)
142
+ total_tokens += doc['token_count']
143
+ else:
144
+ # Try to include a truncated version
145
+ remaining_tokens = max_tokens - total_tokens
146
+ if remaining_tokens > 100: # Only if meaningful content can fit
147
+ truncated_text = self._truncate_text(doc['text'], remaining_tokens)
148
+ if truncated_text:
149
+ doc['text'] = truncated_text
150
+ doc['token_count'] = len(self.encoding.encode(truncated_text))
151
+ selected_docs.append(doc)
152
+ break
153
+
154
+ logger.info(f"Selected {len(selected_docs)} documents with {total_tokens} total tokens")
155
+ return selected_docs
156
+
157
+ def _truncate_text(self, text: str, max_tokens: int) -> str:
158
+ """Truncate text to fit within token limit while preserving meaning."""
159
+ tokens = self.encoding.encode(text)
160
+ if len(tokens) <= max_tokens:
161
+ return text
162
+
163
+ # Truncate and try to end at a sentence boundary
164
+ truncated_tokens = tokens[:max_tokens]
165
+ truncated_text = self.encoding.decode(truncated_tokens)
166
+
167
+ # Try to end at a sentence boundary
168
+ sentences = re.split(r'[.!?]+', truncated_text)
169
+ if len(sentences) > 1:
170
+ # Remove the last incomplete sentence
171
+ truncated_text = '.'.join(sentences[:-1]) + '.'
172
+
173
+ return truncated_text
174
+
175
+ def generate_answer(self, question: str, context_docs: List[Dict[str, Any]]) -> str:
176
+ """Generate answer using full context stuffing approach."""
177
+ if not context_docs:
178
+ return "I couldn't find any relevant documents to answer your question."
179
+
180
+ try:
181
+ # Assemble context from selected documents
182
+ context_parts = []
183
+ sources = []
184
+
185
+ for i, doc in enumerate(context_docs, 1):
186
+ text = doc['text']
187
+ metadata = doc['metadata']
188
+ source = metadata.get('source', f'Document {i}')
189
+
190
+ context_parts.append(f"=== {source} ===\n{text}")
191
+ if source not in sources:
192
+ sources.append(source)
193
+
194
+ full_context = "\n\n".join(context_parts)
195
+
196
+ # Create system message for context stuffing
197
+ system_message = (
198
+ "You are an expert in occupational safety and health regulations. "
199
+ "Answer the user's question using the provided regulatory documents and technical materials. "
200
+ "Provide comprehensive, accurate answers that directly address the question. "
201
+ "Reference specific sections or requirements when applicable. "
202
+ "If the provided context doesn't fully answer the question, clearly state what information is missing."
203
+ )
204
+
205
+ # Create user message
206
+ user_message = f"""Based on the following regulatory and technical documents, please answer this question:
207
+
208
+ QUESTION: {question}
209
+
210
+ DOCUMENTS:
211
+ {full_context}
212
+
213
+ Please provide a thorough answer based on the information in these documents. If any important details are missing from the provided context, please indicate that as well."""
214
+
215
+ # For GPT-5, temperature must be default (1.0)
216
+ response = self.client.chat.completions.create(
217
+ model=OPENAI_CHAT_MODEL,
218
+ messages=[
219
+ {"role": "system", "content": system_message},
220
+ {"role": "user", "content": user_message}
221
+ ],
222
+ max_completion_tokens=DEFAULT_MAX_TOKENS
223
+ )
224
+
225
+ answer = response.choices[0].message.content.strip()
226
+
227
+ # Add source information
228
+ if len(sources) > 1:
229
+ answer += f"\n\n*Sources consulted: {', '.join(sources)}*"
230
+ elif sources:
231
+ answer += f"\n\n*Source: {sources[0]}*"
232
+
233
+ return answer
234
+
235
+ except Exception as e:
236
+ logger.error(f"Error generating context stuffing answer: {e}")
237
+ return "I apologize, but I encountered an error while generating the answer using context stuffing."
238
+
239
+ # Global retriever instance
240
+ _retriever = None
241
+
242
+ def get_retriever() -> ContextStuffingRetriever:
243
+ """Get or create global context stuffing retriever instance."""
244
+ global _retriever
245
+ if _retriever is None:
246
+ _retriever = ContextStuffingRetriever()
247
+ return _retriever
248
+
249
+ def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict]]:
250
+ """
251
+ Main context stuffing query function with unified signature.
252
+
253
+ Args:
254
+ question: User question
255
+ image_path: Optional image path (not used in context stuffing but kept for consistency)
256
+ top_k: Not used in context stuffing (uses heuristic selection instead)
257
+
258
+ Returns:
259
+ Tuple of (answer, citations)
260
+ """
261
+ try:
262
+ retriever = get_retriever()
263
+
264
+ # Select relevant documents using heuristics
265
+ relevant_docs = retriever.select_relevant_documents(question)
266
+
267
+ if not relevant_docs:
268
+ return "I couldn't find any relevant documents to answer your question.", []
269
+
270
+ # Generate comprehensive answer
271
+ answer = retriever.generate_answer(question, relevant_docs)
272
+
273
+ # Prepare citations
274
+ citations = []
275
+ for i, doc in enumerate(relevant_docs, 1):
276
+ metadata = doc['metadata']
277
+ citations.append({
278
+ 'rank': i,
279
+ 'score': float(doc['score']),
280
+ 'source': metadata.get('source', 'Unknown'),
281
+ 'type': metadata.get('type', 'unknown'),
282
+ 'method': 'context_stuffing',
283
+ 'tokens_used': doc['token_count']
284
+ })
285
+
286
+ logger.info(f"Context stuffing query completed. Used {len(citations)} documents.")
287
+ return answer, citations
288
+
289
+ except Exception as e:
290
+ logger.error(f"Error in context stuffing query: {e}")
291
+ error_message = "I apologize, but I encountered an error while processing your question with context stuffing."
292
+ return error_message, []
293
+
294
+ def query_with_details(question: str, image_path: Optional[str] = None,
295
+ top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict], List[Tuple]]:
296
+ """
297
+ Context stuffing query function that returns detailed chunk information (for compatibility).
298
+
299
+ Returns:
300
+ Tuple of (answer, citations, chunks)
301
+ """
302
+ answer, citations = query(question, image_path, top_k)
303
+
304
+ # Convert citations to chunk format for backward compatibility
305
+ chunks = []
306
+ for citation in citations:
307
+ chunks.append((
308
+ f"Document {citation['rank']} (Score: {citation['score']:.3f})",
309
+ citation['score'],
310
+ f"Context from {citation['source']} ({citation['tokens_used']} tokens)",
311
+ citation['source']
312
+ ))
313
+
314
+ return answer, citations, chunks
315
+
316
+ if __name__ == "__main__":
317
+ # Test the context stuffing system
318
+ test_question = "What are the general requirements for machine guarding?"
319
+
320
+ print("Testing context stuffing retrieval system...")
321
+ print(f"Question: {test_question}")
322
+ print("-" * 50)
323
+
324
+ try:
325
+ answer, citations = query(test_question)
326
+
327
+ print("Answer:")
328
+ print(answer)
329
+ print(f"\nCitations ({len(citations)} documents used):")
330
+ for citation in citations:
331
+ print(f"- {citation['source']} (Relevance: {citation['score']:.3f}, Tokens: {citation['tokens_used']})")
332
+
333
+ except Exception as e:
334
+ print(f"Error during testing: {e}")
query_dpr.py ADDED
@@ -0,0 +1,279 @@
1
+ """
2
+ Dense Passage Retrieval (DPR) query module.
3
+ Uses bi-encoder for retrieval and cross-encoder for re-ranking.
4
+ """
5
+
6
+ import pickle
7
+ import logging
8
+ from typing import List, Tuple, Optional
9
+ import numpy as np
10
+ import faiss
11
+ from sentence_transformers import SentenceTransformer, CrossEncoder
12
+ from openai import OpenAI
13
+
14
+ from config import *
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ class DPRRetriever:
19
+ """Dense Passage Retrieval with cross-encoder re-ranking."""
20
+
21
+ def __init__(self):
22
+ self.client = OpenAI(api_key=OPENAI_API_KEY)
23
+ self.bi_encoder = None
24
+ self.cross_encoder = None
25
+ self.index = None
26
+ self.metadata = None
27
+ self._load_models()
28
+ self._load_index()
29
+
30
+ def _load_models(self):
31
+ """Load bi-encoder and cross-encoder models."""
32
+ try:
33
+ logger.info("Loading DPR models...")
34
+ self.bi_encoder = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
35
+ self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
36
+
37
+ if DEVICE == "cuda":
38
+ self.bi_encoder = self.bi_encoder.to(DEVICE)
39
+ self.cross_encoder = self.cross_encoder.to(DEVICE)
40
+
41
+ logger.info("✓ DPR models loaded successfully")
42
+
43
+ except Exception as e:
44
+ logger.error(f"Error loading DPR models: {e}")
45
+ raise
46
+
47
+ def _load_index(self):
48
+ """Load FAISS index and metadata."""
49
+ try:
50
+ if DPR_FAISS_INDEX.exists() and DPR_METADATA.exists():
51
+ logger.info("Loading DPR index and metadata...")
52
+
53
+ # Load FAISS index
54
+ self.index = faiss.read_index(str(DPR_FAISS_INDEX))
55
+
56
+ # Load metadata
57
+ with open(DPR_METADATA, 'rb') as f:
58
+ data = pickle.load(f)
59
+ self.metadata = data
60
+
61
+ logger.info(f"✓ Loaded DPR index with {len(self.metadata)} chunks")
62
+ else:
63
+ logger.warning("DPR index not found. Run preprocess.py first.")
64
+
65
+ except Exception as e:
66
+ logger.error(f"Error loading DPR index: {e}")
67
+ raise
68
+
69
+ def retrieve_candidates(self, question: str, top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
70
+ """Retrieve candidate passages using bi-encoder."""
71
+ if self.index is None or self.metadata is None:
72
+ raise ValueError("DPR index not loaded. Run preprocess.py first.")
73
+
74
+ try:
75
+ # Encode question with bi-encoder
76
+ question_embedding = self.bi_encoder.encode([question], convert_to_numpy=True)
77
+
78
+ # Normalize for cosine similarity
79
+ faiss.normalize_L2(question_embedding)
80
+
81
+ # Search FAISS index
82
+ # Retrieve more candidates for re-ranking
83
+ retrieve_k = min(top_k * RERANK_MULTIPLIER, len(self.metadata))
84
+ scores, indices = self.index.search(question_embedding, retrieve_k)
85
+
86
+ # Prepare candidates
87
+ candidates = []
88
+ for score, idx in zip(scores[0], indices[0]):
89
+ if 0 <= idx < len(self.metadata):  # FAISS pads missing results with -1
90
+ chunk_data = self.metadata[idx]
91
+ candidates.append((
92
+ chunk_data['text'],
93
+ float(score),
94
+ chunk_data['metadata']
95
+ ))
96
+
97
+ logger.info(f"Retrieved {len(candidates)} candidates for re-ranking")
98
+ return candidates
99
+
100
+ except Exception as e:
101
+ logger.error(f"Error in candidate retrieval: {e}")
102
+ raise
103
+
104
+ def rerank_candidates(self, question: str, candidates: List[Tuple[str, float, dict]],
105
+ top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
106
+ """Re-rank candidates using cross-encoder."""
107
+ if not candidates:
108
+ return []
109
+
110
+ try:
111
+ # Prepare pairs for cross-encoder
112
+ pairs = [(question, candidate[0]) for candidate in candidates]
113
+
114
+ # Get cross-encoder scores
115
+ cross_scores = self.cross_encoder.predict(pairs)
116
+
117
+ # Combine with candidate data and re-sort
118
+ reranked = []
119
+ for i, (text, bi_score, metadata) in enumerate(candidates):
120
+ cross_score = float(cross_scores[i])
121
+
122
+ # Filter by minimum relevance score
123
+ if cross_score >= MIN_RELEVANCE_SCORE:
124
+ reranked.append((text, cross_score, metadata))
125
+
126
+ # Sort by cross-encoder score (descending)
127
+ reranked.sort(key=lambda x: x[1], reverse=True)
128
+
129
+ # Return top-k
130
+ final_results = reranked[:top_k]
131
+ logger.info(f"Re-ranked to {len(final_results)} final results")
132
+
133
+ return final_results
134
+
135
+ except Exception as e:
136
+ logger.error(f"Error in re-ranking: {e}")
137
+ # Fall back to bi-encoder results
138
+ return candidates[:top_k]
139
+
140
+ def generate_answer(self, question: str, context_chunks: List[Tuple[str, float, dict]]) -> str:
141
+ """Generate answer using GPT with retrieved context."""
142
+ if not context_chunks:
143
+ return "I couldn't find relevant information to answer your question."
144
+
145
+ try:
146
+ # Prepare context
147
+ context_parts = []
148
+ for i, (text, score, metadata) in enumerate(context_chunks, 1):
149
+ source = metadata.get('source', 'Unknown')
150
+ context_parts.append(f"[Context {i}] Source: {source}\n{text}")
151
+
152
+ context = "\n\n".join(context_parts)
153
+
154
+ # Create system message
155
+ system_message = (
156
+ "You are a helpful assistant specialized in occupational safety and health. "
157
+ "Answer questions based only on the provided context. "
158
+ "If the context doesn't contain enough information, say so clearly. "
159
+ "Always cite the source when referencing information."
160
+ )
161
+
162
+ # Create user message
163
+ user_message = f"Context:\n{context}\n\nQuestion: {question}"
164
+
165
+ # Generate response
166
+ # For GPT-5, temperature must be default (1.0)
167
+ response = self.client.chat.completions.create(
168
+ model=OPENAI_CHAT_MODEL,
169
+ messages=[
170
+ {"role": "system", "content": system_message},
171
+ {"role": "user", "content": user_message}
172
+ ],
173
+ max_completion_tokens=DEFAULT_MAX_TOKENS
174
+ )
175
+
176
+ return response.choices[0].message.content.strip()
177
+
178
+ except Exception as e:
179
+ logger.error(f"Error generating answer: {e}")
180
+ return "I apologize, but I encountered an error while generating the answer."
181
+
182
+ # Global retriever instance
183
+ _retriever = None
184
+
185
+ def get_retriever() -> DPRRetriever:
186
+ """Get or create global DPR retriever instance."""
187
+ global _retriever
188
+ if _retriever is None:
189
+ _retriever = DPRRetriever()
190
+ return _retriever
191
+
192
+ def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
193
+ """
194
+ Main DPR query function with unified signature.
195
+
196
+ Args:
197
+ question: User question
198
+ image_path: Optional image path (not used in DPR but kept for consistency)
199
+ top_k: Number of top results to retrieve
200
+
201
+ Returns:
202
+ Tuple of (answer, citations)
203
+ """
204
+ try:
205
+ retriever = get_retriever()
206
+
207
+ # Step 1: Retrieve candidates with bi-encoder
208
+ candidates = retriever.retrieve_candidates(question, top_k)
209
+
210
+ if not candidates:
211
+ return "I couldn't find any relevant information for your question.", []
212
+
213
+ # Step 2: Re-rank with cross-encoder
214
+ reranked_candidates = retriever.rerank_candidates(question, candidates, top_k)
215
+
216
+ # Step 3: Generate answer
217
+ answer = retriever.generate_answer(question, reranked_candidates)
218
+
219
+ # Step 4: Prepare citations
220
+ citations = []
221
+ for i, (text, score, metadata) in enumerate(reranked_candidates, 1):
222
+ citations.append({
223
+ 'rank': i,
224
+ 'text': text,
225
+ 'score': float(score),
226
+ 'source': metadata.get('source', 'Unknown'),
227
+ 'type': metadata.get('type', 'unknown'),
228
+ 'method': 'dpr'
229
+ })
230
+
231
+ logger.info(f"DPR query completed. Retrieved {len(citations)} citations.")
232
+ return answer, citations
233
+
234
+ except Exception as e:
235
+ logger.error(f"Error in DPR query: {e}")
236
+ error_message = "I apologize, but I encountered an error while processing your question with DPR."
237
+ return error_message, []
238
+
239
+ def query_with_details(question: str, image_path: Optional[str] = None,
240
+ top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict], List[Tuple]]:
241
+ """
242
+ DPR query function that returns detailed chunk information (for compatibility).
243
+
244
+ Returns:
245
+ Tuple of (answer, citations, chunks)
246
+ """
247
+ answer, citations = query(question, image_path, top_k)
248
+
249
+ # Convert citations to chunk format for backward compatibility
250
+ chunks = []
251
+ for citation in citations:
252
+ chunks.append((
253
+ f"Rank {citation['rank']} (Score: {citation['score']:.3f})",
254
+ citation['score'],
255
+ citation['text'],
256
+ citation['source']
257
+ ))
258
+
259
+ return answer, citations, chunks
260
+
261
+ if __name__ == "__main__":
262
+ # Test the DPR system
263
+ test_question = "What are the general requirements for machine guarding?"
264
+
265
+ print("Testing DPR retrieval system...")
266
+ print(f"Question: {test_question}")
267
+ print("-" * 50)
268
+
269
+ try:
270
+ answer, citations = query(test_question)
271
+
272
+ print("Answer:")
273
+ print(answer)
274
+ print("\nCitations:")
275
+ for citation in citations:
276
+ print(f"- {citation['source']} (Score: {citation['score']:.3f})")
277
+
278
+ except Exception as e:
279
+ print(f"Error during testing: {e}")
query_graph.py CHANGED
@@ -1,140 +1,413 @@
1
- import os
 
 
 
 
2
  import numpy as np
3
- from dotenv import load_dotenv
 
4
  from openai import OpenAI
5
  import networkx as nx
6
  from sklearn.metrics.pairwise import cosine_similarity
7
 
 
 
 
 
 
8
  # Initialize OpenAI client
9
- load_dotenv(override=True)
10
- client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Load graph from GML
13
- G = nx.read_gml("graph.gml")
14
- enodes = list(G.nodes)
15
- embeddings = np.array([G.nodes[n]['embedding'] for n in enodes])
16
 
17
- def query_graph(question, top_k=5):
18
  """
19
- Embed the question, retrieve the top_k relevant chunks,
20
- and return: (answer, sources, chunks)
21
- - answer: generated response string
22
- - sources: list of unique source names
23
- - chunks: list of tuples (header, score, full_text, source_url_or_path)
 
 
 
 
24
  """
25
- # Embed question
 
 
 
 
 
 
 
26
  emb_resp = client.embeddings.create(
27
- model="text-embedding-3-large",
28
  input=question
29
  )
30
- q_vec = emb_resp.data[0].embedding
31
-
32
  # Compute cosine similarities
33
- sims = cosine_similarity([q_vec], embeddings)[0]
34
  idxs = sims.argsort()[::-1][:top_k]
35
-
36
  # Collect chunk-level info
37
  chunks = []
38
- sources = []
 
 
39
  for rank, i in enumerate(idxs, start=1):
40
- node = enodes[i]
41
- text = G.nodes[node]['text']
42
- header = text.split('\n', 1)[0].lstrip('# ').strip()
 
 
 
43
  score = sims[i]
44
- # Determine citation (URL for HTML, path for PDF)
45
- citation = G.nodes[node].get('url') or G.nodes[node].get('path') or G.nodes[node]['source']
46
- chunks.append((header, score, text, citation))
47
- sources.append(G.nodes[node]['source'])
48
- # Deduplicate sources
49
- sources = list(dict.fromkeys(sources))
50
-
51
- # Assemble prompt
52
- context = "\n\n---\n\n".join([c[2] for c in chunks])
53
- prompt = (
54
- "Use the following context to answer the question:\n\n" +
55
- context +
56
- f"\n\nQuestion: {question}\nAnswer:"
57
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Query chat model
 
 
 
 
 
 
60
  chat_resp = client.chat.completions.create(
61
- model="gpt-4o-mini",
62
  messages=[
63
- {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety."},
64
- {"role": "user", "content": prompt}
65
- ]
 
66
  )
 
67
  answer = chat_resp.choices[0].message.content
 
 
68
 
69
- return answer, sources, chunks
70
 
 
71
  """
72
- Embed the user question, retrieve the top_k relevant chunks from the graph,
73
- assemble a prompt with those chunks, call the chat model, and return:
74
- - answer: the generated response
75
- - sources: unique list of source documents
76
- - chunks: list of (header, score, full_text) for the top_k passages
 
 
 
 
77
  """
78
- # Embed the question
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  emb_resp = client.embeddings.create(
80
- model="text-embedding-3-large",
81
  input=question
82
  )
83
- q_vec = emb_resp.data[0].embedding
84
-
85
- # Compute similarities against all stored embeddings
86
- sims = cosine_similarity([q_vec], embeddings)[0]
87
  idxs = sims.argsort()[::-1][:top_k]
88
-
89
- # Gather chunk‑level info and sources
90
  chunks = []
91
- sources = []
 
 
92
  for i in idxs:
93
- node = enodes[i]
94
- text = G.nodes[node]['text']
95
- # Use the first line as the header
96
- header = text.split('\n', 1)[0].lstrip('# ').strip()
97
- score = sims[i]
98
- chunks.append((header, score, text))
99
- sources.append(G.nodes[node]['source'])
100
- # Deduplicate sources while preserving order
101
- sources = list(dict.fromkeys(sources))
102
-
103
- # Assemble the prompt from the chunk texts
104
- context_text = "\n\n---\n\n".join([chunk[2] for chunk in chunks])
105
- prompt = (
106
- "Use the following context to answer the question:\n\n"
107
- + context_text
108
- + f"\n\nQuestion: {question}\nAnswer:"
109
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- # Call the chat model
112
- chat_resp = client.chat.completions.create(
113
- model="gpt-4o-mini",
 
 
 
 
 
 
 
114
  messages=[
115
- {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety."},
116
- {"role": "user", "content": prompt}
117
- ]
 
118
  )
119
- answer = chat_resp.choices[0].message.content
 
 
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  return answer, sources, chunks
122
 
123
 
124
- # Test queries
125
- # test_questions = [
126
- # "What are general machine guarding requirements?",
127
- # "Explain the key steps in lockout/tagout procedures."
128
- # ]
129
-
130
- # for q in test_questions:
131
- # answer, sources, chunks = query_graph(q)
132
- # print(f"Q: {q}")
133
- # print(f"Answer: {answer}\n")
134
- # print("Sources:")
135
- # for src in sources:
136
- # print(f"- {src}")
137
- # print("\nTop Chunks:")
138
- # for header, score, _, citation in chunks:
139
- # print(f" * {header} (score: {score:.2f}) from {citation}")
140
- # print("\n", "#"*40, "\n")
 
1
+ """
2
+ Graph-based RAG using NetworkX.
3
+ Updated to match the common query signature used by other methods.
4
+ """
5
+
6
  import numpy as np
7
+ import logging
8
+ from typing import Tuple, List, Optional
9
  from openai import OpenAI
10
  import networkx as nx
11
  from sklearn.metrics.pairwise import cosine_similarity
12
 
13
+ from config import *
14
+ from utils import classify_image
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
  # Initialize OpenAI client
19
+ client = OpenAI(api_key=OPENAI_API_KEY)
20
+
21
+ # Global variables for lazy loading
22
+ _graph = None
23
+ _enodes = None
24
+ _embeddings = None
25
+
26
+ def _load_graph():
27
+ """Lazy load graph database."""
28
+ global _graph, _enodes, _embeddings
29
+
30
+ if _graph is None:
31
+ try:
32
+ if GRAPH_FILE.exists():
33
+ logger.info("Loading graph database...")
34
+ _graph = nx.read_gml(str(GRAPH_FILE))
35
+ _enodes = list(_graph.nodes)
36
+ # Convert embeddings from lists back to numpy arrays
37
+ embeddings_list = []
38
+ for n in _enodes:
39
+ embedding = _graph.nodes[n]['embedding']
40
+ if isinstance(embedding, list):
41
+ embeddings_list.append(np.array(embedding))
42
+ else:
43
+ embeddings_list.append(embedding)
44
+ _embeddings = np.array(embeddings_list)
45
+ logger.info(f"✓ Loaded graph with {len(_enodes)} nodes")
46
+ else:
47
+ logger.warning("Graph database not found. Run preprocess.py first.")
48
+ _graph = nx.Graph()
49
+ _enodes = []
50
+ _embeddings = np.array([])
51
+ except Exception as e:
52
+ logger.error(f"Error loading graph: {e}")
53
+ _graph = nx.Graph()
54
+ _enodes = []
55
+ _embeddings = np.array([])
56
 
 
 
 
 
57
 
58
+ def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
59
  """
60
+ Query using graph-based retrieval.
61
+
62
+ Args:
63
+ question: User's question
64
+ image_path: Optional path to an image (for multimodal queries)
65
+ top_k: Number of relevant chunks to retrieve
66
+
67
+ Returns:
68
+ Tuple of (answer, citations)
69
  """
70
+
71
+ # Load graph if not already loaded
72
+ _load_graph()
73
+
74
+ if len(_enodes) == 0:
75
+ return "Graph database is empty. Please run preprocess.py first.", []
76
+
77
+ # Embed question using OpenAI
78
  emb_resp = client.embeddings.create(
79
+ model=OPENAI_EMBEDDING_MODEL,
80
  input=question
81
  )
82
+ q_vec = np.array(emb_resp.data[0].embedding)
83
+
84
  # Compute cosine similarities
85
+ sims = cosine_similarity([q_vec], _embeddings)[0]
86
  idxs = sims.argsort()[::-1][:top_k]
87
+
88
  # Collect chunk-level info
89
  chunks = []
90
+ citations = []
91
+ sources_seen = set()
92
+
93
  for rank, i in enumerate(idxs, start=1):
94
+ node = _enodes[i]
95
+ node_data = _graph.nodes[node]
96
+ text = node_data['text']
97
+
98
+ # Extract header from text
99
+ header = text.split('\n', 1)[0].lstrip('#').strip()
100
  score = sims[i]
101
+
102
+ # Extract citation format - get source from metadata or node_data
103
+ metadata = node_data.get('metadata', {})
104
+ source = metadata.get('source') or node_data.get('source')
105
+
106
+ if not source:
107
+ continue
108
+
109
+ if 'url' in metadata: # HTML source
110
+ citation_ref = metadata['url']
111
+ cite_type = 'html'
112
+ elif 'path' in metadata: # PDF source
113
+ citation_ref = metadata['path']
114
+ cite_type = 'pdf'
115
+ elif 'url' in node_data: # Legacy format
116
+ citation_ref = node_data['url']
117
+ cite_type = 'html'
118
+ elif 'path' in node_data: # Legacy format
119
+ citation_ref = node_data['path']
120
+ cite_type = 'pdf'
121
+ else:
122
+ citation_ref = source
123
+ cite_type = 'unknown'
124
+
125
+ chunks.append({
126
+ 'header': header,
127
+ 'score': score,
128
+ 'text': text,
129
+ 'citation': citation_ref
130
+ })
131
+
132
+ # Add unique citation
133
+ if source not in sources_seen:
134
+ citation_entry = {
135
+ 'source': source,
136
+ 'type': cite_type,
137
+ 'relevance_score': round(float(score), 3)
138
+ }
139
+
140
+ if cite_type == 'html':
141
+ citation_entry['url'] = citation_ref
142
+ elif cite_type == 'pdf':
143
+ citation_entry['path'] = citation_ref
144
+
145
+ citations.append(citation_entry)
146
+ sources_seen.add(source)
147
+
148
+ # Handle image if provided
149
+ image_context = ""
150
+ if image_path:
151
+ try:
152
+ # Classify the image
153
+ classification = classify_image(image_path)
154
+ image_context = f"\n\n[Image Context: The provided image appears to be a {classification}.]"
155
+
156
+ # Optionally, find related nodes in graph based on image classification
157
+ # This would require storing image-related metadata in the graph
158
+
159
+ except Exception as e:
160
+ print(f"Error processing image: {e}")
161
+
162
+ # Assemble context for prompt
163
+ context = "\n\n---\n\n".join([c['text'] for c in chunks])
164
+
165
+ prompt = f"""Use the following context to answer the question:
166
 
167
+ {context}{image_context}
168
+
169
+ Question: {question}
170
+
171
+ Please provide a comprehensive answer based on the context provided. Cite specific sources when providing information."""
172
+
173
+ # For GPT-5, temperature must be default (1.0)
174
  chat_resp = client.chat.completions.create(
175
+ model=OPENAI_CHAT_MODEL,
176
  messages=[
177
+ {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety. Always provide accurate information based on the given context."},
178
+ {"role": "user", "content": prompt}
179
+ ],
180
+ max_completion_tokens=DEFAULT_MAX_TOKENS
181
  )
182
+
183
  answer = chat_resp.choices[0].message.content
184
+
185
+ return answer, citations
186
 
 
187
 
188
+ def query_with_graph_traversal(question: str, top_k: int = 5, max_hops: int = 2) -> Tuple[str, List[dict]]:
189
  """
190
+ Enhanced graph query that can traverse edges to find related information.
191
+
192
+ Args:
193
+ question: User's question
194
+ top_k: Number of initial nodes to retrieve
195
+ max_hops: Maximum graph traversal depth
196
+
197
+ Returns:
198
+ Tuple of (answer, citations)
199
  """
200
+
201
+ # Load graph if not already loaded
202
+ _load_graph()
203
+
204
+ if len(_enodes) == 0:
205
+ return "Graph database is empty. Please run preprocess.py first.", []
206
+
207
+ # Get initial nodes using standard query
208
+ initial_answer, initial_citations = query(question, top_k=top_k)
209
+
210
+ # For a more sophisticated implementation, you would:
211
+ # 1. Add edges between related nodes during preprocessing
212
+ # 2. Traverse from initial nodes to find related content
213
+ # 3. Score the related nodes based on path distance and relevance
214
+
215
+ # For now, return the standard query results
216
+ return initial_answer, initial_citations
217
+
218
+
219
+ def query_subgraph(question: str, source_filter: str = None, top_k: int = 5) -> Tuple[str, List[dict]]:
220
+ """
221
+ Query a specific subgraph filtered by source.
222
+
223
+ Args:
224
+ question: User's question
225
+ source_filter: Filter nodes by source (e.g., specific PDF name)
226
+ top_k: Number of relevant chunks to retrieve
227
+
228
+ Returns:
229
+ Tuple of (answer, citations)
230
+ """
231
+
232
+ # Load graph if not already loaded
233
+ _load_graph()
234
+
235
+ # Filter nodes if source specified
236
+ if source_filter:
237
+ filtered_nodes = []
238
+ for n in _enodes:
239
+ node_data = _graph.nodes[n]
240
+ metadata = node_data.get('metadata', {})
241
+ source = metadata.get('source') or node_data.get('source', '')
242
+ source_from_meta = metadata.get('source', '')
243
+
244
+ # Check both direct source and metadata source
245
+ if (source_filter.lower() in source.lower() or
246
+ source_filter.lower() in source_from_meta.lower()):
247
+ filtered_nodes.append(n)
248
+
249
+ if not filtered_nodes:
250
+ return f"No nodes found for source: {source_filter}", []
251
+ else:
252
+ filtered_nodes = _enodes
253
+
254
+ # Get embeddings for filtered nodes
255
+ filtered_embeddings = np.array([_graph.nodes[n]['embedding'] for n in filtered_nodes])
256
+
257
+ # Embed question
258
  emb_resp = client.embeddings.create(
259
+ model=OPENAI_EMBEDDING_MODEL,
260
  input=question
261
  )
262
+ q_vec = np.array(emb_resp.data[0].embedding)
263
+
264
+ # Compute similarities
265
+ sims = cosine_similarity([q_vec], filtered_embeddings)[0]
266
  idxs = sims.argsort()[::-1][:top_k]
267
+
268
+ # Collect results
269
  chunks = []
270
+ citations = []
271
+ sources_seen = set()
272
+
273
  for i in idxs:
274
+ if i < len(filtered_nodes):
275
+ node = filtered_nodes[i]
276
+ node_data = _graph.nodes[node]
277
+
278
+ chunks.append(node_data['text'])
279
+
280
+ # Skip if source information missing
281
+ metadata = node_data.get('metadata', {})
282
+ source = metadata.get('source') or node_data.get('source')
283
+
284
+ if not source:
285
+ continue
286
+
287
+ if source not in sources_seen:
288
+ citation = {
289
+ 'source': source,
290
+ 'type': 'pdf' if ('path' in metadata or 'path' in node_data) else 'html',
291
+ 'relevance_score': round(float(sims[i]), 3)
292
+ }
293
+
294
+ # Check metadata first, then node_data for legacy support
295
+ if 'url' in metadata:
296
+ citation['url'] = metadata['url']
297
+ elif 'path' in metadata:
298
+ citation['path'] = metadata['path']
299
+ elif 'url' in node_data:
300
+ citation['url'] = node_data['url']
301
+ elif 'path' in node_data:
302
+ citation['path'] = node_data['path']
303
+
304
+ citations.append(citation)
305
+ sources_seen.add(source)
306
+
307
+ # Build context and generate answer
308
+ context = "\n\n---\n\n".join(chunks)
309
+
310
+ prompt = f"""Answer the following question using the provided context:
311
 
312
+ Context from {source_filter if source_filter else 'all sources'}:
313
+ {context}
314
+
315
+ Question: {question}
316
+
317
+ Provide a detailed answer based on the context."""
318
+
319
+ # For GPT-5, temperature must be default (1.0)
320
+ response = client.chat.completions.create(
321
+ model=OPENAI_CHAT_MODEL,
322
  messages=[
323
+ {"role": "system", "content": "You are an expert on manufacturing safety. Answer based on the provided context."},
324
+ {"role": "user", "content": prompt}
325
+ ],
326
+ max_completion_tokens=DEFAULT_MAX_TOKENS
327
  )
328
+
329
+ answer = response.choices[0].message.content
330
+
331
+ return answer, citations
332
 
333
+
334
+ # Maintain backward compatibility with original function signature
335
+ def query_graph(question: str, top_k: int = 5) -> Tuple[str, List[str], List[tuple]]:
336
+ """
337
+ Original query_graph function signature for backward compatibility.
338
+
339
+ Args:
340
+ question: User's question
341
+ top_k: Number of relevant chunks to retrieve
342
+
343
+ Returns:
344
+ Tuple of (answer, sources, chunks)
345
+ """
346
+
347
+ # Call the new query function
348
+ answer, citations = query(question, top_k=top_k)
349
+
350
+ # Convert citations to old format
351
+ sources = [c['source'] for c in citations]
352
+
353
+ # Get chunks in old format (header, score, text, citation)
354
+ _load_graph()
355
+
356
+ if len(_enodes) == 0:
357
+ return answer, sources, []
358
+
359
+ # Regenerate chunks for backward compatibility
360
+ emb_resp = client.embeddings.create(
361
+ model=OPENAI_EMBEDDING_MODEL,
362
+ input=question
363
+ )
364
+ q_vec = np.array(emb_resp.data[0].embedding)
365
+
366
+ sims = cosine_similarity([q_vec], _embeddings)[0]
367
+ idxs = sims.argsort()[::-1][:top_k]
368
+
369
+ chunks = []
370
+ for i in idxs:
371
+ node = _enodes[i]
372
+ node_data = _graph.nodes[node]
373
+ text = node_data['text']
374
+ header = text.split('\n', 1)[0].lstrip('#').strip()
375
+ score = sims[i]
376
+
377
+ # Skip if source information missing
378
+ metadata = node_data.get('metadata', {})
379
+ source = metadata.get('source') or node_data.get('source')
380
+
381
+ if not source:
382
+ continue
383
+
384
+ if 'url' in metadata:
385
+ citation = metadata['url']
386
+ elif 'path' in metadata:
387
+ citation = metadata['path']
388
+ elif 'url' in node_data:
389
+ citation = node_data['url']
390
+ elif 'path' in node_data:
391
+ citation = node_data['path']
392
+ else:
393
+ citation = source
394
+
395
+ chunks.append((header, score, text, citation))
396
+
397
  return answer, sources, chunks
398
 
399
 
400
+ if __name__ == "__main__":
401
+ # Test the updated graph query
402
+ test_questions = [
403
+ "What are general machine guarding requirements?",
404
+ "How do I perform lockout/tagout procedures?",
405
+ "What safety measures are needed for robotic systems?"
406
+ ]
407
+
408
+ for q in test_questions:
409
+ print(f"\nQuestion: {q}")
410
+ answer, citations = query(q)
411
+ print(f"Answer: {answer[:200]}...")
412
+ print(f"Citations: {[c['source'] for c in citations]}")
413
+ print("-" * 50)
 
 
 
query_vanilla.py ADDED
@@ -0,0 +1,197 @@
1
+ """
2
+ Vanilla vector search using FAISS index and OpenAI embeddings.
3
+ """
4
+
5
+ import numpy as np
6
+ import faiss
7
+ from typing import Tuple, List, Optional
8
+ from openai import OpenAI
9
+
10
+ import pickle
11
+ import logging
12
+ from config import *
13
+ from utils import EmbeddingGenerator, classify_image
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Initialize OpenAI client
18
+ client = OpenAI(api_key=OPENAI_API_KEY)
19
+
20
+ # Global variables for lazy loading
21
+ _index = None
22
+ _texts = None
23
+ _metadata = None
24
+
25
+ def _load_vanilla_index():
26
+ """Lazy load vanilla FAISS index and metadata."""
27
+ global _index, _texts, _metadata
28
+
29
+ if _index is None:
30
+ try:
31
+ if VANILLA_FAISS_INDEX.exists() and VANILLA_METADATA.exists():
32
+ logger.info("Loading vanilla FAISS index...")
33
+
34
+ # Load FAISS index
35
+ _index = faiss.read_index(str(VANILLA_FAISS_INDEX))
36
+
37
+ # Load metadata
38
+ with open(VANILLA_METADATA, 'rb') as f:
39
+ data = pickle.load(f)
40
+
41
+ if isinstance(data, list):
42
+ # New format with metadata list
43
+ _texts = [item['text'] for item in data]
44
+ _metadata = [item['metadata'] for item in data]
45
+ else:
46
+ # Old format with dict
47
+ _texts = data.get('texts', [])
48
+ _metadata = data.get('metadata', [])
49
+
50
+ logger.info(f"✓ Loaded vanilla index with {len(_texts)} documents")
51
+ else:
52
+ logger.warning("Vanilla index not found. Run preprocess.py first.")
53
+ _index = None
54
+ _texts = []
55
+ _metadata = []
56
+
57
+ except Exception as e:
58
+ logger.error(f"Error loading vanilla index: {e}")
59
+ _index = None
60
+ _texts = []
61
+ _metadata = []
62
+
63
+
64
+ def query(question: str, image_path: Optional[str] = None, top_k: int = None) -> Tuple[str, List[dict]]:
65
+ """
66
+ Query using vanilla vector search.
67
+
68
+ Args:
69
+ question: User's question
70
+ image_path: Optional path to an image (for multimodal queries)
71
+ top_k: Number of relevant chunks to retrieve
72
+
73
+ Returns:
74
+ Tuple of (answer, citations)
75
+ """
76
+ if top_k is None:
77
+ top_k = DEFAULT_TOP_K
78
+
79
+ # Load index if not already loaded
80
+ _load_vanilla_index()
81
+
82
+ if _index is None or len(_texts) == 0:
83
+ return "Index not loaded. Please run preprocess.py first.", []
84
+
85
+ # Generate query embedding using embedding generator
86
+ embedding_gen = EmbeddingGenerator()
87
+ query_embedding = embedding_gen.embed_text_openai([question])
88
+
89
+ # Normalize for cosine similarity
90
+ query_embedding = query_embedding.astype(np.float32)
91
+ faiss.normalize_L2(query_embedding)
92
+
93
+ # Search the index
94
+ distances, indices = _index.search(query_embedding, top_k)
95
+
96
+ # Collect retrieved chunks and citations
97
+ retrieved_chunks = []
98
+ citations = []
99
+ sources_seen = set()
100
+
101
+ for idx, distance in zip(indices[0], distances[0]):
102
+ if 0 <= idx < len(_texts) and distance > MIN_RELEVANCE_SCORE:  # FAISS pads missing results with -1
103
+ chunk_text = _texts[idx]
104
+ chunk_meta = _metadata[idx]
105
+
106
+ retrieved_chunks.append({
107
+ 'text': chunk_text,
108
+ 'score': float(distance),
109
+ 'metadata': chunk_meta
110
+ })
111
+
112
+ # Build citation
113
+ if chunk_meta['source'] not in sources_seen:
114
+ citation = {
115
+ 'source': chunk_meta['source'],
116
+ 'type': chunk_meta['type'],
117
+ 'relevance_score': round(float(distance), 3)
118
+ }
119
+
120
+ if chunk_meta['type'] == 'pdf':
121
+ citation['path'] = chunk_meta['path']
122
+ else: # HTML
123
+ citation['url'] = chunk_meta.get('url', '')
124
+
125
+ citations.append(citation)
126
+ sources_seen.add(chunk_meta['source'])
127
+
128
+ # Handle image if provided
129
+ image_context = ""
130
+ if image_path:
131
+ try:
132
+ classification = classify_image(image_path)
133
+ image_context = f"\n\n[Image Context: The provided image appears to be a {classification}.]"
134
+ except Exception as e:
135
+ logger.error(f"Error processing image: {e}")
136
+
137
+ # Build context for the prompt
138
+ context = "\n\n---\n\n".join([chunk['text'] for chunk in retrieved_chunks])
139
+
140
+ if not context:
141
+ return "No relevant documents found for your query.", []
142
+
143
+ # Generate answer using OpenAI
144
+ prompt = f"""Use the following context to answer the question:
145
+
146
+ {context}{image_context}
147
+
148
+ Question: {question}
149
+
150
+ Please provide a comprehensive answer based on the context provided. If the context doesn't contain enough information, say so."""
151
+
152
+ # For GPT-5, temperature must be default (1.0)
153
+ response = client.chat.completions.create(
154
+ model=OPENAI_CHAT_MODEL,
155
+ messages=[
156
+ {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety. Always cite your sources when providing information."},
157
+ {"role": "user", "content": prompt}
158
+ ],
159
+ max_completion_tokens=DEFAULT_MAX_TOKENS
160
+ )
161
+
162
+ answer = response.choices[0].message.content
163
+
164
+ return answer, citations
165
+
166
+
167
+ def query_with_feedback(question: str, feedback_scores: List[float] = None, top_k: int = 5) -> Tuple[str, List[dict]]:
168
+ """
169
+ Query with relevance feedback to refine results.
170
+
171
+ Args:
172
+ question: User's question
173
+ feedback_scores: Optional relevance scores for previous results
174
+ top_k: Number of relevant chunks to retrieve
175
+
176
+ Returns:
177
+ Tuple of (answer, citations)
178
+ """
179
+ # For now, just use regular query
180
+ # TODO: Implement Rocchio algorithm or similar for relevance feedback
181
+ return query(question, top_k=top_k)
182
+
183
+
184
+ if __name__ == "__main__":
185
+ # Test the vanilla query
186
+ test_questions = [
187
+ "What are general machine guarding requirements?",
188
+ "How do I perform lockout/tagout procedures?",
189
+ "What safety measures are needed for robotic systems?"
190
+ ]
191
+
192
+ for q in test_questions:
193
+ print(f"\nQuestion: {q}")
194
+ answer, citations = query(q)
195
+ print(f"Answer: {answer[:200]}...")
196
+ print(f"Citations: {[c['source'] for c in citations]}")
197
+ print("-" * 50)
query_vision.py ADDED
@@ -0,0 +1,393 @@
1
+ """
2
+ Vision-based query module using GPT-5 Vision.
3
+ Supports multimodal queries combining text and images.
4
+ """
5
+
6
+ import base64
7
+ import json
8
+ import logging
9
+ import sqlite3
10
+ from typing import List, Tuple, Optional, Dict, Any
11
+ import numpy as np
12
+ from PIL import Image
13
+ from openai import OpenAI
14
+
15
+ from config import *
16
+ from utils import ImageProcessor, classify_image
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class VisionRetriever:
21
+ """Vision-based retrieval using GPT-5 Vision for image analysis and classification."""
22
+
23
+ def __init__(self):
24
+ self.client = OpenAI(api_key=OPENAI_API_KEY)
25
+ self.image_processor = ImageProcessor()
26
+
27
+ def get_similar_images(self, query_image_path: str, top_k: int = 5) -> List[Dict[str, Any]]:
28
+ """Find similar images in the database based on classification similarity."""
29
+ try:
30
+ # Uses GPT-5 Vision for classification-based similarity search
31
+ # Note: This implementation uses classification similarity rather than embeddings
32
+
33
+ # Classify the query image
34
+ query_classification = classify_image(query_image_path)
35
+
36
+ # Query database for similar images
37
+ conn = sqlite3.connect(IMAGES_DB)
38
+ cursor = conn.cursor()
39
+
40
+ # Search for images with similar classification
41
+ cursor.execute("""
42
+ SELECT image_id, image_path, classification, metadata
43
+ FROM images
44
+ WHERE classification LIKE ?
45
+ ORDER BY created_at DESC
46
+ LIMIT ?
47
+ """, (f"%{query_classification}%", top_k))
48
+
49
+ results = cursor.fetchall()
50
+ conn.close()
51
+
52
+ similar_images = []
53
+ for row in results:
54
+ image_id, image_path, classification, metadata_json = row
55
+ metadata = json.loads(metadata_json) if metadata_json else {}
56
+
57
+ similar_images.append({
58
+ 'image_id': image_id,
59
+ 'image_path': image_path,
60
+ 'classification': classification,
61
+ 'metadata': metadata,
62
+ 'similarity_score': 0.8  # fixed placeholder; matches are based on classification text, not embeddings
63
+ })
64
+
65
+ logger.info(f"Found {len(similar_images)} similar images for query")
66
+ return similar_images
67
+
68
+ except Exception as e:
69
+ logger.error(f"Error finding similar images: {e}")
70
+ return []
71
+
72
+ def analyze_image_safety(self, image_path: str, question: str = None) -> str:
73
+ """Analyze image for safety concerns using GPT-5 Vision."""
74
+ try:
75
+ # Convert image to base64
76
+ with open(image_path, "rb") as image_file:
77
+ image_b64 = base64.b64encode(image_file.read()).decode()
78
+
79
+ # Create analysis prompt
80
+ if question:
81
+ analysis_prompt = (
82
+ f"Analyze this image in the context of the following question: {question}\n\n"
83
+ "Please provide a detailed safety analysis covering:\n"
84
+ "1. What equipment, machinery, or workplace elements are visible\n"
85
+ "2. Any potential safety hazards or compliance issues\n"
86
+ "3. Relevant OSHA standards or regulations that may apply\n"
87
+ "4. Recommendations for safety improvements\n"
88
+ "5. How this relates to the specific question asked"
89
+ )
90
+ else:
91
+ analysis_prompt = (
92
+ "Analyze this image for occupational safety and health concerns. Provide:\n"
93
+ "1. Description of what's shown in the image\n"
94
+ "2. Identification of potential safety hazards\n"
95
+ "3. Relevant OSHA standards or safety regulations\n"
96
+ "4. Recommendations for improving safety"
97
+ )
98
+
99
+ messages = [{
100
+ "role": "user",
101
+ "content": [
102
+ {"type": "text", "text": analysis_prompt},
103
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}", "detail": "high"}}
104
+ ]
105
+ }]
106
+
107
+ # For GPT-5 vision, temperature must be default (1.0) and reasoning is not supported
108
+ response = self.client.chat.completions.create(
109
+ model=OPENAI_CHAT_MODEL,
110
+ messages=messages,
111
+ max_completion_tokens=DEFAULT_MAX_TOKENS
112
+ )
113
+
114
+ return response.choices[0].message.content.strip()
115
+
116
+ except Exception as e:
117
+ logger.error(f"Error analyzing image: {e}")
118
+ return f"I encountered an error while analyzing the image: {e}"
119
+
120
+ def retrieve_relevant_text(self, image_classification: str, question: str, top_k: int = 3) -> List[Dict[str, Any]]:
121
+ """Retrieve text documents relevant to the image classification and question."""
122
+ # This would integrate with other retrieval methods to find relevant text
123
+ # For now, reuse the vanilla vector retriever with an enhanced query string
124
+
125
+ try:
126
+ # Import other query modules for text retrieval
127
+ from query_vanilla import query as vanilla_query
128
+
129
+ # Create an enhanced query combining image classification and original question
130
+ enhanced_question = f"safety requirements for {image_classification} {question}"
131
+
132
+ # Use vanilla retrieval to find relevant text
133
+ _, text_citations = vanilla_query(enhanced_question, top_k=top_k)
134
+
135
+ return text_citations
136
+
137
+ except Exception as e:
138
+ logger.error(f"Error retrieving relevant text: {e}")
139
+ return []
140
+
141
+ def generate_multimodal_answer(self, question: str, image_analysis: str,
142
+ text_citations: List[Dict], similar_images: List[Dict]) -> str:
143
+ """Generate answer combining image analysis and text retrieval."""
144
+ try:
145
+ # Prepare context from text citations
146
+ text_context = ""
147
+ if text_citations:
148
+ text_parts = []
149
+ for i, citation in enumerate(text_citations, 1):
150
+ if 'text' in citation:
151
+ text_parts.append(f"[Text Source {i}] {citation['source']}: {citation['text'][:500]}...")
152
+ else:
153
+ text_parts.append(f"[Text Source {i}] {citation['source']}")
154
+ text_context = "\n\n".join(text_parts)
155
+
156
+ # Prepare context from similar images
157
+ image_context = ""
158
+ if similar_images:
159
+ image_parts = []
160
+ for img in similar_images[:3]: # Limit to top 3
161
+ source = img['metadata'].get('source', 'Unknown')
162
+ classification = img.get('classification', 'unknown')
163
+ image_parts.append(f"Similar image from {source}: classified as {classification}")
164
+ image_context = "\n".join(image_parts)
165
+
166
+ # Create comprehensive prompt
167
+ system_message = (
168
+ "You are an expert in occupational safety and health. "
169
+ "You have been provided with an image analysis, relevant text documents, "
170
+ "and information about similar images in the database. "
171
+ "Provide a comprehensive answer that integrates all this information."
172
+ )
173
+
174
+ user_message = f"""Question: {question}
175
+
176
+ Image Analysis:
177
+ {image_analysis}
178
+
179
+ Relevant Text Documentation:
180
+ {text_context}
181
+
182
+ Similar Images Context:
183
+ {image_context}
184
+
185
+ Please provide a comprehensive answer that:
186
+ 1. Addresses the specific question asked
187
+ 2. Incorporates insights from the image analysis
188
+ 3. References relevant regulatory information from the text sources
189
+ 4. Notes any connections to similar cases or images
190
+ 5. Provides actionable recommendations based on safety standards"""
191
+
192
+ # For GPT-5, temperature must be default (1.0) and reasoning is not supported
193
+ response = self.client.chat.completions.create(
194
+ model=OPENAI_CHAT_MODEL,
195
+ messages=[
196
+ {"role": "system", "content": system_message},
197
+ {"role": "user", "content": user_message}
198
+ ],
199
+ max_completion_tokens=DEFAULT_MAX_TOKENS * 2 # Allow longer response for comprehensive analysis
200
+ )
201
+
202
+ return response.choices[0].message.content.strip()
203
+
204
+ except Exception as e:
205
+ logger.error(f"Error generating multimodal answer: {e}")
206
+ return "I apologize, but I encountered an error while generating the comprehensive answer."
207
+
208
+ # Global retriever instance
209
+ _retriever = None
210
+
211
+ def get_retriever() -> VisionRetriever:
212
+ """Get or create global vision retriever instance."""
213
+ global _retriever
214
+ if _retriever is None:
215
+ _retriever = VisionRetriever()
216
+ return _retriever
217
+
218
+ def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict]]:
219
+ """
220
+ Main vision-based query function with unified signature.
221
+
222
+ Args:
223
+ question: User question
224
+ image_path: Path to image file (required for vision queries)
225
+ top_k: Number of relevant results to retrieve
226
+
227
+ Returns:
228
+ Tuple of (answer, citations)
229
+ """
230
+ if not image_path:
231
+ return "Vision queries require an image. Please provide an image file.", []
232
+
233
+ try:
234
+ retriever = get_retriever()
235
+
236
+ # Step 1: Analyze the provided image
237
+ logger.info(f"Analyzing image: {image_path}")
238
+ image_analysis = retriever.analyze_image_safety(image_path, question)
239
+
240
+ # Step 2: Classify the image
241
+ image_classification = classify_image(image_path)
242
+
243
+ # Step 3: Find similar images
244
+ similar_images = retriever.get_similar_images(image_path, top_k=3)
245
+
246
+ # Step 4: Retrieve relevant text documents
247
+ text_citations = retriever.retrieve_relevant_text(image_classification, question, top_k)
248
+
249
+ # Step 5: Generate comprehensive multimodal answer
250
+ answer = retriever.generate_multimodal_answer(
251
+ question, image_analysis, text_citations, similar_images
252
+ )
253
+
254
+ # Step 6: Prepare citations
255
+ citations = []
256
+
257
+ # Add image analysis as primary citation
258
+ citations.append({
259
+ 'rank': 1,
260
+ 'type': 'image_analysis',
261
+ 'source': f"Analysis of {image_path.split('/')[-1] if '/' in image_path else image_path.split('\\')[-1]}",
262
+ 'method': 'vision',
263
+ 'classification': image_classification,
264
+ 'score': 1.0
265
+ })
266
+
267
+ # Add text citations
268
+ for i, citation in enumerate(text_citations, 2):
269
+ citation_copy = citation.copy()
270
+ citation_copy['rank'] = i
271
+ citation_copy['method'] = 'vision_text'
272
+ citations.append(citation_copy)
273
+
274
+ # Add similar images
275
+ for i, img in enumerate(similar_images):
276
+ citations.append({
277
+ 'rank': len(citations) + 1,
278
+ 'type': 'similar_image',
279
+ 'source': img['metadata'].get('source', 'Image Database'),
280
+ 'method': 'vision',
281
+ 'classification': img.get('classification', 'unknown'),
282
+ 'similarity_score': img.get('similarity_score', 0.0),
283
+ 'image_id': img.get('image_id')
284
+ })
285
+
286
+ logger.info(f"Vision query completed. Generated {len(citations)} citations.")
287
+ return answer, citations
288
+
289
+ except Exception as e:
290
+ logger.error(f"Error in vision query: {e}")
291
+ error_message = "I apologize, but I encountered an error while processing your vision-based question."
292
+ return error_message, []
293
+
294
+ def query_image_only(image_path: str, question: str = None) -> Tuple[str, List[Dict]]:
295
+ """
296
+ Analyze image without text retrieval (faster for simple image analysis).
297
+
298
+ Args:
299
+ image_path: Path to image file
300
+ question: Optional specific question about the image
301
+
302
+ Returns:
303
+ Tuple of (analysis, citations)
304
+ """
305
+ try:
306
+ retriever = get_retriever()
307
+
308
+ # Analyze image
309
+ analysis = retriever.analyze_image_safety(image_path, question)
310
+
311
+ # Classify image
312
+ classification = classify_image(image_path)
313
+
314
+ # Create citation for image analysis
315
+ citations = [{
316
+ 'rank': 1,
317
+ 'type': 'image_analysis',
318
+ 'source': f"Analysis of {image_path.split('/')[-1] if '/' in image_path else image_path.split('\\')[-1]}",
319
+ 'method': 'vision_only',
320
+ 'classification': classification,
321
+ 'score': 1.0
322
+ }]
323
+
324
+ return analysis, citations
325
+
326
+ except Exception as e:
327
+ logger.error(f"Error in image-only analysis: {e}")
328
+ return "Error analyzing image.", []
329
+
330
+ def query_with_details(question: str, image_path: Optional[str] = None,
331
+ top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict], List[Tuple]]:
332
+ """
333
+ Vision query function that returns detailed chunk information (for compatibility).
334
+
335
+ Returns:
336
+ Tuple of (answer, citations, chunks)
337
+ """
338
+ answer, citations = query(question, image_path, top_k)
339
+
340
+ # Convert citations to chunk format for backward compatibility
341
+ chunks = []
342
+ for citation in citations:
343
+ if citation['type'] == 'image_analysis':
344
+ chunks.append((
345
+ f"Image Analysis ({citation['classification']})",
346
+ citation['score'],
347
+ "Analysis of uploaded image for safety compliance",
348
+ citation['source']
349
+ ))
350
+ elif citation['type'] == 'similar_image':
351
+ chunks.append((
352
+ f"Similar Image (Score: {citation.get('similarity_score', 0):.3f})",
353
+ citation.get('similarity_score', 0),
354
+ f"Similar image classified as {citation['classification']}",
355
+ citation['source']
356
+ ))
357
+ else:
358
+ chunks.append((
359
+ f"Text Reference {citation['rank']}",
360
+ citation.get('score', 0.5),
361
+ citation.get('text', 'Referenced document'),
362
+ citation['source']
363
+ ))
364
+
365
+ return answer, citations, chunks
366
+
367
+ if __name__ == "__main__":
368
+ # Test the vision system (requires an actual image file)
369
+ import sys
370
+
371
+ if len(sys.argv) > 1:
372
+ test_image_path = sys.argv[1]
373
+ test_question = "What safety issues can you identify in this image?"
374
+
375
+ print("Testing vision retrieval system...")
376
+ print(f"Image: {test_image_path}")
377
+ print(f"Question: {test_question}")
378
+ print("-" * 50)
379
+
380
+ try:
381
+ answer, citations = query(test_question, test_image_path)
382
+
383
+ print("Answer:")
384
+ print(answer)
385
+ print(f"\nCitations ({len(citations)}):")
386
+ for citation in citations:
387
+ print(f"- {citation['source']} (Type: {citation.get('type', 'unknown')})")
388
+
389
+ except Exception as e:
390
+ print(f"Error during testing: {e}")
391
+ else:
392
+ print("To test vision system, provide an image path as argument:")
393
+ print("python query_vision.py /path/to/image.jpg")
realtime_server.py ADDED
@@ -0,0 +1,402 @@
1
+ """
2
+ FastAPI server for OpenAI Realtime API integration with RAG system.
3
+ Provides endpoints for session management and RAG tool calls.
4
+
5
+ Directory structure:
6
+ /data/ # Original PDFs, HTML
7
+ /embeddings/ # FAISS, Chroma, DPR vector stores
8
+ /graph/ # Graph database files
9
+ /metadata/ # Image metadata (SQLite or MongoDB)
10
+ """
11
+
12
+ import json
13
+ import logging
14
+ import os
15
+ import time
16
+ from typing import Dict, Any, Optional
17
+ from fastapi import FastAPI, HTTPException, Request, Response, status
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from fastapi.responses import JSONResponse
20
+ from fastapi.exceptions import RequestValidationError
21
+ from starlette.exceptions import HTTPException as StarletteHTTPException
22
+ from pydantic import BaseModel
23
+ import uvicorn
24
+ from openai import OpenAI
25
+
26
+ # Import all query modules
27
+ from query_graph import query as graph_query
28
+ from query_vanilla import query as vanilla_query
29
+ from query_dpr import query as dpr_query
30
+ from query_bm25 import query as bm25_query
31
+ from query_context import query as context_query
32
+ from query_vision import query as vision_query
33
+
34
+ from config import OPENAI_API_KEY, OPENAI_CHAT_MODEL, OPENAI_REALTIME_MODEL, REALTIME_VOICE, REALTIME_INSTRUCTIONS, DEFAULT_METHOD
35
+ from analytics_db import log_query
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+ # Initialize FastAPI app
40
+ app = FastAPI(title="SIGHT Realtime API Server", version="1.0.0")
41
+
42
+ # CORS middleware for frontend integration
43
+ app.add_middleware(
44
+ CORSMiddleware,
45
+ allow_origins=["*"], # In production, restrict to your domain
46
+ allow_credentials=True,
47
+ allow_methods=["*"],
48
+ allow_headers=["*"],
49
+ )
50
+
51
+ @app.middleware("http")
52
+ async def log_requests(request: Request, call_next):
53
+ """Log all incoming requests for debugging."""
54
+ logger.info(f"Incoming request: {request.method} {request.url}")
55
+ try:
56
+ response = await call_next(request)
57
+ logger.info(f"Response status: {response.status_code}")
58
+ return response
59
+ except Exception as e:
60
+ logger.error(f"Request processing error: {e}")
61
+ return JSONResponse(
62
+ content={"error": "Internal server error"},
63
+ status_code=500
64
+ )
65
+
66
+ # Exception handlers
67
+ @app.exception_handler(RequestValidationError)
68
+ async def validation_exception_handler(request: Request, exc: RequestValidationError):
69
+ logger.warning(f"Validation error for {request.url}: {exc}")
70
+ return JSONResponse(
71
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
72
+ content={"error": "Invalid request format", "details": str(exc)}
73
+ )
74
+
75
+ @app.exception_handler(StarletteHTTPException)
76
+ async def http_exception_handler(request: Request, exc: StarletteHTTPException):
77
+ logger.warning(f"HTTP error for {request.url}: {exc.status_code} - {exc.detail}")
78
+ return JSONResponse(
79
+ status_code=exc.status_code,
80
+ content={"error": exc.detail}
81
+ )
82
+
83
+ @app.exception_handler(Exception)
84
+ async def general_exception_handler(request: Request, exc: Exception):
85
+ logger.error(f"Unhandled error for {request.url}: {exc}")
86
+ return JSONResponse(
87
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
88
+ content={"error": "Internal server error"}
89
+ )
90
+
91
+ # Initialize OpenAI client
92
+ client = OpenAI(api_key=OPENAI_API_KEY)
93
+
94
+ # Query method dispatch
95
+ QUERY_DISPATCH = {
96
+ 'graph': graph_query,
97
+ 'vanilla': vanilla_query,
98
+ 'dpr': dpr_query,
99
+ 'bm25': bm25_query,
100
+ 'context': context_query,
101
+ 'vision': vision_query
102
+ }
103
+
104
+ # Use configuration from config.py with environment variable overrides
105
+ REALTIME_MODEL = os.getenv("REALTIME_MODEL", OPENAI_REALTIME_MODEL)
106
+ VOICE = os.getenv("REALTIME_VOICE", REALTIME_VOICE)
107
+ INSTRUCTIONS = os.getenv("REALTIME_INSTRUCTIONS", REALTIME_INSTRUCTIONS)
108
+
109
+ # Pydantic models for request/response
110
+ class SessionRequest(BaseModel):
111
+ """Request model for creating ephemeral sessions."""
112
+ model: Optional[str] = None  # when omitted, falls back to REALTIME_MODEL from config (env-overridable)
113
+ instructions: Optional[str] = None
114
+ voice: Optional[str] = None
115
+
116
+ class RAGRequest(BaseModel):
117
+ """Request model for RAG queries."""
118
+ query: str
119
+ method: str = "graph"
120
+ top_k: int = 5
121
+ image_path: Optional[str] = None
122
+
123
+ class RAGResponse(BaseModel):
124
+ """Response model for RAG queries."""
125
+ answer: str
126
+ citations: list
127
+ method: str
128
+ citations_html: Optional[str] = None
129
+
130
+ @app.post("/session")
131
+ async def create_ephemeral_session(request: SessionRequest) -> JSONResponse:
132
+ """
133
+ Create an ephemeral session token for OpenAI Realtime API.
134
+ This token will be used by the frontend WebRTC client.
135
+ """
136
+ try:
137
+ logger.info(f"Creating ephemeral session with model: {request.model or REALTIME_MODEL}")
138
+
139
+ # Create ephemeral token using direct HTTP call to OpenAI API
140
+ # Since the Python SDK doesn't support realtime sessions yet
141
+ import requests
142
+
143
+ session_data = {
144
+ "model": request.model or REALTIME_MODEL,
145
+ "voice": request.voice or VOICE,
146
+ "modalities": ["audio", "text"],
147
+ "instructions": request.instructions or INSTRUCTIONS,
148
+ }
149
+
150
+ headers = {
151
+ "Authorization": f"Bearer {OPENAI_API_KEY}",
152
+ "Content-Type": "application/json"
153
+ }
154
+
155
+ # Make direct HTTP request to OpenAI's realtime sessions endpoint
156
+ response = requests.post(
157
+ "https://api.openai.com/v1/realtime/sessions",
158
+ json=session_data,
159
+ headers=headers,
160
+ timeout=30
161
+ )
162
+
163
+ if response.status_code == 200:
164
+ session_result = response.json()
165
+
166
+ response_data = {
167
+ "client_secret": session_result.get("client_secret", {}).get("value") or session_result.get("client_secret"),
168
+ "model": request.model or REALTIME_MODEL,
169
+ "session_id": session_result.get("id")
170
+ }
171
+
172
+ logger.info("Ephemeral session created successfully")
173
+ return JSONResponse(content=response_data, status_code=200)
174
+ else:
175
+ logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
176
+ return JSONResponse(
177
+ content={"error": f"OpenAI API error: {response.status_code} - {response.text}"},
178
+ status_code=response.status_code
179
+ )
180
+
181
+ except requests.exceptions.RequestException as e:
182
+ logger.error(f"Network error creating ephemeral session: {e}")
183
+ return JSONResponse(
184
+ content={"error": f"Network error: {str(e)}"},
185
+ status_code=500
186
+ )
187
+ except Exception as e:
188
+ logger.error(f"Error creating ephemeral session: {e}")
189
+ return JSONResponse(
190
+ content={"error": f"Session creation failed: {str(e)}"},
191
+ status_code=500
192
+ )
193
+
194
+ @app.post("/rag", response_model=RAGResponse)
195
+ async def rag_query(request: RAGRequest) -> RAGResponse:
196
+ """
197
+ Handle RAG queries from the realtime interface.
198
+ This endpoint is called by the JavaScript frontend when the model
199
+ requests the ask_rag function.
200
+ """
201
+ try:
202
+ logger.info(f"RAG query: {request.query} using method: {request.method}")
203
+
204
+ # Validate and default method if needed
205
+ method = request.method
206
+ if method not in QUERY_DISPATCH:
207
+ logger.warning(f"Invalid method '{method}', using default '{DEFAULT_METHOD}'")
208
+ method = DEFAULT_METHOD
209
+
210
+ # Get the appropriate query function
211
+ query_func = QUERY_DISPATCH[method]
212
+
213
+ # Execute the query
214
+ start_time = time.time()
215
+
216
+ answer, citations = query_func(
217
+ question=request.query,
218
+ image_path=request.image_path,
219
+ top_k=request.top_k
220
+ )
221
+ response_time = (time.time() - start_time) * 1000 # Convert to ms
222
+
223
+ # Format citations for HTML display (optional)
224
+ citations_html = format_citations_html(citations, method)
225
+
226
+ # Log to analytics database (mark as voice interaction)
227
+ try:
228
+ # Generate unique session ID for each voice interaction
229
+ import uuid
230
+ voice_session_id = f"voice_{uuid.uuid4().hex[:8]}"
231
+
232
+ log_query(
233
+ user_query=request.query,
234
+ method=method,
235
+ answer=answer,
236
+ citations=citations,
237
+ response_time=response_time,
238
+ image_path=request.image_path,
239
+ top_k=request.top_k,
240
+ session_id=voice_session_id,
241
+ additional_settings={'voice_interaction': True, 'interaction_type': 'speech_to_speech'}
242
+ )
243
+ logger.info(f"Voice interaction logged: {request.query[:50]}...")
244
+ except Exception as log_error:
245
+ logger.error(f"Failed to log voice query: {log_error}")
246
+
247
+ logger.info(f"RAG query completed: {len(answer)} chars, {len(citations)} citations")
248
+
249
+ return RAGResponse(
250
+ answer=answer,
251
+ citations=citations,
252
+ method=method,
253
+ citations_html=citations_html
254
+ )
255
+
256
+ except Exception as e:
257
+ logger.error(f"Error processing RAG query: {e}")
258
+ raise HTTPException(status_code=500, detail=f"RAG query failed: {str(e)}")
259
+
260
+ def format_citations_html(citations: list, method: str) -> str:
261
+ """Format citations as HTML for display."""
262
+ if not citations:
263
+ return "<p><em>No citations available</em></p>"
264
+
265
+ html_parts = ["<div style='margin-top: 1em;'><strong>Sources:</strong><ul>"]
266
+
267
+ for citation in citations:
268
+ if isinstance(citation, dict) and 'source' in citation:
269
+ source = citation['source']
270
+ cite_type = citation.get('type', 'unknown')
271
+
272
+ # Build citation text based on type
273
+ if cite_type == 'pdf':
274
+ cite_text = f"📄 {source} (PDF)"
275
+ elif cite_type == 'html':
276
+ url = citation.get('url', '')
277
+ if url:
278
+ cite_text = f"🌐 <a href='{url}' target='_blank'>{source}</a> (Web)"
279
+ else:
280
+ cite_text = f"🌐 {source} (Web)"
281
+ elif cite_type == 'image':
282
+ page = citation.get('page', 'N/A')
283
+ cite_text = f"🖼️ {source} (Image, page {page})"
284
+ else:
285
+ cite_text = f"📚 {source}"
286
+
287
+ # Add scores if available
288
+ scores = []
289
+ if 'relevance_score' in citation:
290
+ scores.append(f"relevance: {citation['relevance_score']:.3f}")
291
+ if 'score' in citation:
292
+ scores.append(f"score: {citation['score']:.3f}")
293
+
294
+ if scores:
295
+ cite_text += f" <small>({', '.join(scores)})</small>"
296
+
297
+ html_parts.append(f"<li>{cite_text}</li>")
298
+ elif isinstance(citation, (list, tuple)) and len(citation) >= 4:
299
+ # Handle legacy citation format (header, score, text, source)
300
+ header, score, text, source = citation[:4]
301
+ cite_text = f"📚 {source} <small>(score: {score:.3f})</small>"
302
+ html_parts.append(f"<li>{cite_text}</li>")
303
+
304
+ html_parts.append("</ul></div>")
305
+ return "".join(html_parts)
306
+
307
+ @app.get("/")
308
+ async def root():
309
+ """Root endpoint to prevent invalid HTTP request warnings."""
310
+ return {
311
+ "service": "SIGHT Realtime API Server",
312
+ "version": "1.0.0",
313
+ "status": "running",
314
+ "endpoints": {
315
+ "session": "POST /session - Create realtime session",
316
+ "rag": "POST /rag - Query RAG system",
317
+ "health": "GET /health - Health check",
318
+ "methods": "GET /methods - List available RAG methods"
319
+ }
320
+ }
321
+
322
+ @app.get("/health")
323
+ async def health_check():
324
+ """Health check endpoint."""
325
+ return {"status": "healthy", "service": "SIGHT Realtime API Server"}
326
+
327
+ @app.get("/methods")
328
+ async def list_methods():
329
+ """List available RAG methods."""
330
+ return {
331
+ "methods": list(QUERY_DISPATCH.keys()),
332
+ "descriptions": {
333
+ 'graph': "Graph-based RAG using NetworkX with relationship-aware retrieval",
334
+ 'vanilla': "Standard vector search with FAISS and OpenAI embeddings",
335
+ 'dpr': "Dense Passage Retrieval with bi-encoder and cross-encoder re-ranking",
336
+ 'bm25': "BM25 keyword search with neural re-ranking for exact term matching",
337
+ 'context': "Context stuffing with full document loading and heuristic selection",
338
+ 'vision': "Vision-based search using GPT-5 Vision for image analysis"
339
+ }
340
+ }
341
+
342
+ @app.options("/{full_path:path}")
343
+ async def options_handler(request: Request, response: Response):
344
+ """Handle CORS preflight requests."""
345
+ response.headers["Access-Control-Allow-Origin"] = "*"
346
+ response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
347
+ response.headers["Access-Control-Allow-Headers"] = "*"
348
+ return response
349
+
350
+ if __name__ == "__main__":
351
+ import argparse
352
+
353
+ # Parse command line arguments
354
+ parser = argparse.ArgumentParser(description="SIGHT Realtime API Server")
355
+ parser.add_argument("--https", action="store_true", help="Enable HTTPS with self-signed certificate")
356
+ parser.add_argument("--port", type=int, default=5050, help="Port to run the server on")
357
+ parser.add_argument("--host", default="0.0.0.0", help="Host to bind the server to")
358
+ args = parser.parse_args()
359
+
360
+ # Configure logging
361
+ logging.basicConfig(
362
+ level=logging.INFO,
363
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
364
+ )
365
+
366
+ # Suppress uvicorn access logs for cleaner output
367
+ uvicorn_logger = logging.getLogger("uvicorn.access")
368
+ uvicorn_logger.setLevel(logging.WARNING)
369
+
370
+ # Prepare uvicorn configuration
371
+ uvicorn_config = {
372
+ "app": "realtime_server:app",
373
+ "host": args.host,
374
+ "port": args.port,
375
+ "reload": True,
376
+ "log_level": "warning",
377
+ "access_log": False
378
+ }
379
+
380
+ # Add SSL configuration if HTTPS is requested
381
+ if args.https:
382
+ logger.info("Starting server with HTTPS (self-signed certificate)")
383
+ logger.warning("⚠️ Self-signed certificate will show security warnings in browser")
384
+ logger.info("For production, use a proper SSL certificate from a CA")
385
+
386
+ # Note: You would need to generate SSL certificates
387
+ # For development, you can create self-signed certificates:
388
+ # openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes
389
+ uvicorn_config.update({
390
+ "ssl_keyfile": "key.pem",
391
+ "ssl_certfile": "cert.pem"
392
+ })
393
+
394
+ print(f"🔒 Starting HTTPS server on https://{args.host}:{args.port}")
395
+ print("📝 To generate self-signed certificates, run:")
396
+ print(" openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes")
397
+ else:
398
+ print(f"🌐 Starting HTTP server on http://{args.host}:{args.port}")
399
+ print("⚠️ HTTP only works for localhost. Use --https for production deployment.")
400
+
401
+ # Run the server
402
+ uvicorn.run(**uvicorn_config)
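
Editor's note, not part of the commit: the server above only needs its /session and /rag endpoints to be exercised by a frontend or a quick smoke test. A minimal client-side sketch, assuming the server runs locally on the argparse default port 5050 (the deployed port may differ) and using an illustrative question:

import requests

BASE_URL = "http://localhost:5050"  # argparse default; adjust to the deployed port

# Mint an ephemeral Realtime session token for the WebRTC client (all fields optional).
session = requests.post(f"{BASE_URL}/session", json={}, timeout=30).json()
print(session.get("session_id"), "token present:", bool(session.get("client_secret")))

# Query the RAG system directly, the same call the ask_rag tool triggers.
rag = requests.post(
    f"{BASE_URL}/rag",
    json={"query": "What guarding is required for a press brake?", "method": "graph", "top_k": 5},
    timeout=120,
).json()
print(rag["answer"][:200])
print(len(rag["citations"]), "citations")
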
requirements.txt CHANGED
@@ -1,10 +1,63 @@
1
- python-dotenv==1.1.1
2
- pymupdf4llm==0.0.27
3
- beautifulsoup4==4.13.4
4
- requests==2.32.4
5
- pandas==2.2.3
6
- openai==1.97.1
7
- networkx==3.5
8
- numpy==2.3.1
9
- scikit-learn==1.7.1
10
- streamlit==1.47.0
1
+ # Core dependencies (updated versions)
2
+ python-dotenv>=1.1.1
3
+ pymupdf4llm>=0.0.27
4
+ beautifulsoup4>=4.13.4
5
+ requests>=2.32.4
6
+ pandas>=2.2.3
7
+ openai>=1.99.9
8
+ networkx>=3.5
9
+ numpy>=2.3.1
10
+ scikit-learn>=1.7.1
11
+ streamlit>=1.47.0
12
+
13
+ # FastAPI and realtime API dependencies
14
+ fastapi>=0.104.0 # For realtime API server
15
+ uvicorn[standard]>=0.24.0 # ASGI server for FastAPI
16
+ pydantic>=2.4.0 # Data validation and settings management
17
+
18
+ # Document processing
19
+ pymupdf>=1.24.0 # For PDF processing and image extraction
20
+ Pillow>=10.0.0 # For image processing
21
+ lxml>=5.0.0 # For HTML parsing
22
+ html5lib>=1.1 # Alternative HTML parser
23
+
24
+ # Vector stores and search
25
+ faiss-cpu>=1.8.0 # For vector similarity search (use faiss-gpu if CUDA available)
26
+ chromadb>=0.5.0 # Alternative vector database
27
+ rank-bm25>=0.2.2 # For BM25 keyword search
28
+
29
+ # Language models and embeddings
30
+ sentence-transformers>=3.0.0 # For DPR and cross-encoder
31
+ transformers>=4.40.0 # Required by sentence-transformers
32
+ torch>=2.0.0 # For neural models (CPU version)
33
+ # For GPU support, install separately:
34
+ # pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
35
+ ftfy>=6.1.1 # Text preprocessing for CLIP
36
+ regex>=2023.0.0 # Text processing
37
+ # For CLIP (optional - enable if needed):
38
+ # git+https://github.com/openai/CLIP.git
39
+
40
+ # Token counting and management
41
+ tiktoken>=0.7.0 # For OpenAI token counting
42
+
43
+ # Database (optional)
44
+ # pymongo>=4.0.0 # Uncomment if using MongoDB for metadata
45
+
46
+ # Development and debugging
47
+ tqdm>=4.65.0 # Progress bars
48
+ ipython>=8.0.0 # For interactive debugging
49
+ jupyter>=1.0.0 # For notebook development
50
+
51
+ # Data visualization (optional)
52
+ matplotlib>=3.7.0 # For plotting
53
+ seaborn>=0.12.0 # Statistical visualization
54
+ plotly>=5.15.0 # Interactive plots
55
+
56
+ # Optional advanced features (uncomment if needed)
57
+ # langchain>=0.2.11 # For advanced RAG patterns
58
+ # langchain-openai>=0.1.20 # OpenAI integration for LangChain
59
+ # llama-index>=0.10.51 # Alternative RAG framework
60
+
61
+ # Additional utility packages
62
+ colorama>=0.4.6 # Colored console output
63
+ rich>=13.0.0 # Rich text and beautiful formatting in terminal
utils.py ADDED
@@ -0,0 +1,679 @@
1
+ """
2
+ Utility functions for the Multi-Method RAG System.
3
+
4
+ Directory Layout:
5
+ /data/ # Original PDFs, HTML
6
+ /embeddings/ # FAISS, Chroma, DPR vector stores
7
+ /graph/ # Graph database files
8
+ /metadata/ # Image metadata (SQLite or MongoDB)
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import pickle
14
+ import sqlite3
15
+ import base64
16
+ from pathlib import Path
17
+ from typing import List, Dict, Tuple, Optional, Any, Union
18
+ from dataclasses import dataclass
19
+ import logging
20
+
21
+ import pymupdf4llm
22
+ import pymupdf
23
+ import numpy as np
24
+ import pandas as pd
25
+ from PIL import Image
26
+ import requests
27
+ from bs4 import BeautifulSoup
28
+
29
+ # Vector stores and search
30
+ import faiss
31
+ import chromadb
32
+ from rank_bm25 import BM25Okapi
33
+ import networkx as nx
34
+
35
+ # ML models
36
+ from openai import OpenAI
37
+ from sentence_transformers import SentenceTransformer, CrossEncoder
38
+ import torch
39
+ # import clip
40
+
41
+ # Text processing
42
+ from sklearn.feature_extraction.text import TfidfVectorizer
43
+ import tiktoken
44
+
45
+ from config import *
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ @dataclass
50
+ class DocumentChunk:
51
+ """Data structure for document chunks."""
52
+ text: str
53
+ metadata: Dict[str, Any]
54
+ chunk_id: str
55
+ embedding: Optional[np.ndarray] = None
56
+
57
+ @dataclass
58
+ class ImageData:
59
+ """Data structure for image metadata."""
60
+ image_path: str
61
+ image_id: str
62
+ classification: Optional[str] = None
63
+ embedding: Optional[np.ndarray] = None
64
+ metadata: Optional[Dict[str, Any]] = None
65
+
66
+ class DocumentLoader:
67
+ """Load and extract text from various document formats."""
68
+
69
+ def __init__(self):
70
+ self.client = OpenAI(api_key=OPENAI_API_KEY)
71
+ validate_api_key()
72
+
73
+ def load_pdf_documents(self, pdf_paths: List[Union[str, Path]]) -> List[Dict[str, Any]]:
74
+ """Load text from PDF files using pymupdf4llm."""
75
+ documents = []
76
+
77
+ for pdf_path in pdf_paths:
78
+ try:
79
+ pdf_path = Path(pdf_path)
80
+ logger.info(f"Loading PDF: {pdf_path}")
81
+
82
+ # Extract text using pymupdf4llm
83
+ text = pymupdf4llm.to_markdown(str(pdf_path))
84
+
85
+ # Extract images if present
86
+ images = self._extract_pdf_images(pdf_path)
87
+
88
+ doc = {
89
+ 'text': text,
90
+ 'source': str(pdf_path.name),
91
+ 'path': str(pdf_path),
92
+ 'type': 'pdf',
93
+ 'images': images,
94
+ 'metadata': {
95
+ 'file_size': pdf_path.stat().st_size,
96
+ 'modified': pdf_path.stat().st_mtime
97
+ }
98
+ }
99
+ documents.append(doc)
100
+
101
+ except Exception as e:
102
+ logger.error(f"Error loading PDF {pdf_path}: {e}")
103
+ continue
104
+
105
+ return documents
106
+
107
+ def _extract_pdf_images(self, pdf_path: Path) -> List[Dict[str, Any]]:
108
+ """Extract images from PDF using pymupdf."""
109
+ images = []
110
+
111
+ try:
112
+ doc = pymupdf.open(str(pdf_path))
113
+
114
+ for page_num in range(len(doc)):
115
+ page = doc[page_num]
116
+ image_list = page.get_images(full=True)
117
+
118
+ for img_index, img in enumerate(image_list):
119
+ try:
120
+ # Extract image
121
+ xref = img[0]
122
+ pix = pymupdf.Pixmap(doc, xref)
123
+
124
+ # Skip if pixmap is invalid or has no colorspace
125
+ if not pix or pix.colorspace is None:
126
+ if pix:
127
+ pix = None
128
+ continue
129
+
130
+ # Only process images with valid color channels
131
+ if pix.n - pix.alpha <= 4: # GRAY, RGB, or CMYK (CMYK converted to RGB below)
132
+ image_id = f"{pdf_path.stem}_p{page_num}_img{img_index}"
133
+ image_path = IMAGES_DIR / f"{image_id}.png"
134
+
135
+ # Convert to RGB if grayscale or other formats
136
+ if pix.n == 1: # Grayscale
137
+ rgb_pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
138
+ pix = None # Clean up original
139
+ pix = rgb_pix
140
+ elif pix.n == 4 and pix.alpha == 0: # CMYK
141
+ rgb_pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
142
+ pix = None # Clean up original
143
+ pix = rgb_pix
144
+
145
+ # Save image
146
+ pix.save(str(image_path))
147
+
148
+ images.append({
149
+ 'image_id': image_id,
150
+ 'image_path': str(image_path),
151
+ 'page': page_num,
152
+ 'source': str(pdf_path.name)
153
+ })
154
+
155
+ pix = None
156
+
157
+ except Exception as e:
158
+ logger.warning(f"Error extracting image {img_index} from page {page_num}: {e}")
159
+ if 'pix' in locals() and pix:
160
+ pix = None
161
+ continue
162
+
163
+ doc.close()
164
+
165
+ except Exception as e:
166
+ logger.error(f"Error extracting images from {pdf_path}: {e}")
167
+
168
+ return images
169
+
170
+ def load_html_documents(self, html_sources: List[Dict[str, str]]) -> List[Dict[str, Any]]:
171
+ """Load text from HTML sources."""
172
+ documents = []
173
+
174
+ for source in html_sources:
175
+ try:
176
+ logger.info(f"Loading HTML: {source.get('title', source['url'])}")
177
+
178
+ # Fetch HTML content
179
+ response = requests.get(source['url'], timeout=30)
180
+ response.raise_for_status()
181
+
182
+ # Parse with BeautifulSoup
183
+ soup = BeautifulSoup(response.text, 'html.parser')
184
+
185
+ # Extract text
186
+ text = soup.get_text(separator=' ', strip=True)
187
+
188
+ doc = {
189
+ 'text': text,
190
+ 'source': source.get('title', source['url']),
191
+ 'path': source['url'],
192
+ 'type': 'html',
193
+ 'images': [],
194
+ 'metadata': {
195
+ 'url': source['url'],
196
+ 'title': source.get('title', ''),
197
+ 'year': source.get('year', ''),
198
+ 'category': source.get('category', ''),
199
+ 'format': source.get('format', 'HTML')
200
+ }
201
+ }
202
+ documents.append(doc)
203
+
204
+ except Exception as e:
205
+ logger.error(f"Error loading HTML {source['url']}: {e}")
206
+ continue
207
+
208
+ return documents
209
+
210
+ def load_text_documents(self, data_dir: Path = DATA_DIR) -> List[Dict[str, Any]]:
211
+ """Load all supported document types from data directory."""
212
+ documents = []
213
+
214
+ # Load PDFs
215
+ pdf_files = list(data_dir.glob("*.pdf"))
216
+ if pdf_files:
217
+ documents.extend(self.load_pdf_documents(pdf_files))
218
+
219
+ # Load HTML sources (from config)
220
+ if DEFAULT_HTML_SOURCES:
221
+ documents.extend(self.load_html_documents(DEFAULT_HTML_SOURCES))
222
+
223
+ logger.info(f"Loaded {len(documents)} documents total")
224
+ return documents
225
+
226
+ class TextPreprocessor:
227
+ """Preprocess text for different retrieval methods."""
228
+
229
+ def __init__(self):
230
+ self.encoding = tiktoken.get_encoding("cl100k_base")
231
+
232
+ def chunk_text_by_tokens(self, text: str, chunk_size: int = CHUNK_SIZE,
233
+ overlap: int = CHUNK_OVERLAP) -> List[str]:
234
+ """Split text into chunks by token count."""
235
+ tokens = self.encoding.encode(text)
236
+ chunks = []
237
+
238
+ start = 0
239
+ while start < len(tokens):
240
+ end = start + chunk_size
241
+ chunk_tokens = tokens[start:end]
242
+ chunk_text = self.encoding.decode(chunk_tokens)
243
+ chunks.append(chunk_text)
244
+ start = end - overlap
245
+
246
+ return chunks
247
+
248
+ def chunk_text_by_sections(self, text: str, method: str = "vanilla") -> List[str]:
249
+ """Split text by sections based on method requirements."""
250
+ if method in ["vanilla", "dpr"]:
251
+ return self.chunk_text_by_tokens(text)
252
+ elif method == "bm25":
253
+ # BM25 works better with paragraph-level chunks
254
+ paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
255
+ return paragraphs
256
+ elif method == "graph":
257
+ # Graph method uses larger sections
258
+ return self.chunk_text_by_tokens(text, chunk_size=CHUNK_SIZE*2)
259
+ elif method == "context_stuffing":
260
+ # Context stuffing uses full documents
261
+ return [text]
262
+ else:
263
+ return self.chunk_text_by_tokens(text)
264
+
265
+ def preprocess_for_method(self, documents: List[Dict[str, Any]],
266
+ method: str) -> List[DocumentChunk]:
267
+ """Preprocess documents for specific retrieval method."""
268
+ chunks = []
269
+
270
+ for doc in documents:
271
+ text_chunks = self.chunk_text_by_sections(doc['text'], method)
272
+
273
+ for i, chunk_text in enumerate(text_chunks):
274
+ chunk_id = f"{doc['source']}_{method}_chunk_{i}"
275
+
276
+ chunk = DocumentChunk(
277
+ text=chunk_text,
278
+ metadata={
279
+ 'source': doc['source'],
280
+ 'path': doc['path'],
281
+ 'type': doc['type'],
282
+ 'chunk_index': i,
283
+ 'method': method,
284
+ **doc.get('metadata', {})
285
+ },
286
+ chunk_id=chunk_id
287
+ )
288
+ chunks.append(chunk)
289
+
290
+ logger.info(f"Created {len(chunks)} chunks for method '{method}'")
291
+ return chunks
292
+
293
+ class EmbeddingGenerator:
294
+ """Generate embeddings using various models."""
295
+
296
+ def __init__(self):
297
+ self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
298
+ self.sentence_transformer = None
299
+ # self.clip_model = None
300
+ # self.clip_preprocess = None
301
+
302
+ def _get_sentence_transformer(self):
303
+ """Lazy loading of sentence transformer."""
304
+ if self.sentence_transformer is None:
305
+ self.sentence_transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
306
+ if DEVICE == "cuda":
307
+ self.sentence_transformer = self.sentence_transformer.to(DEVICE)
308
+ return self.sentence_transformer
309
+
310
+ # def _get_clip_model(self):
311
+ # """Lazy loading of CLIP model."""
312
+ # if self.clip_model is None:
313
+ # self.clip_model, self.clip_preprocess = clip.load(CLIP_MODEL, device=DEVICE)
314
+ # return self.clip_model, self.clip_preprocess
315
+
316
+ def embed_text_openai(self, texts: List[str]) -> np.ndarray:
317
+ """Generate embeddings using OpenAI API."""
318
+ embeddings = []
319
+
320
+ # Process in batches
321
+ for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
322
+ batch = texts[i:i + EMBEDDING_BATCH_SIZE]
323
+
324
+ try:
325
+ response = self.openai_client.embeddings.create(
326
+ model=OPENAI_EMBEDDING_MODEL,
327
+ input=batch
328
+ )
329
+
330
+ batch_embeddings = [data.embedding for data in response.data]
331
+ embeddings.extend(batch_embeddings)
332
+
333
+ except Exception as e:
334
+ logger.error(f"Error generating OpenAI embeddings: {e}")
335
+ raise
336
+
337
+ return np.array(embeddings)
338
+
339
+ def embed_text_sentence_transformer(self, texts: List[str]) -> np.ndarray:
340
+ """Generate embeddings using sentence transformers."""
341
+ model = self._get_sentence_transformer()
342
+
343
+ try:
344
+ embeddings = model.encode(texts, convert_to_numpy=True,
345
+ show_progress_bar=True, batch_size=32)
346
+ return embeddings
347
+
348
+ except Exception as e:
349
+ logger.error(f"Error generating sentence transformer embeddings: {e}")
350
+ raise
351
+
352
+ def embed_image_clip(self, image_paths: List[str]) -> np.ndarray:
353
+ """Generate image embeddings using CLIP."""
354
+ # model, preprocess = self._get_clip_model()
355
+ # embeddings = []
356
+
357
+ # for image_path in image_paths:
358
+ # try:
359
+ # image = preprocess(Image.open(image_path)).unsqueeze(0).to(DEVICE)
360
+ #
361
+ # with torch.no_grad():
362
+ # image_features = model.encode_image(image)
363
+ # image_features /= image_features.norm(dim=-1, keepdim=True)
364
+ #
365
+ # embeddings.append(image_features.cpu().numpy().flatten())
366
+ #
367
+ # except Exception as e:
368
+ # logger.error(f"Error embedding image {image_path}: {e}")
369
+ # continue
370
+
371
+ # return np.array(embeddings) if embeddings else np.array([])
372
+
373
+ # Placeholder for CLIP embeddings
374
+ logger.warning("CLIP embeddings not implemented - returning dummy embeddings")
375
+ return np.random.rand(len(image_paths), 512)
376
+
377
+ class VectorStoreManager:
378
+ """Manage vector stores for different methods."""
379
+
380
+ def __init__(self):
381
+ self.embedding_generator = EmbeddingGenerator()
382
+
383
+ def build_faiss_index(self, chunks: List[DocumentChunk], method: str = "vanilla") -> Tuple[Any, List[Dict]]:
384
+ """Build FAISS index for vanilla or DPR method."""
385
+
386
+ # Generate embeddings
387
+ texts = [chunk.text for chunk in chunks]
388
+
389
+ if method == "vanilla":
390
+ embeddings = self.embedding_generator.embed_text_openai(texts)
391
+ elif method == "dpr":
392
+ embeddings = self.embedding_generator.embed_text_sentence_transformer(texts)
393
+ else:
394
+ raise ValueError(f"Unsupported method for FAISS: {method}")
395
+
396
+ # Build FAISS index
397
+ dimension = embeddings.shape[1]
398
+ index = faiss.IndexFlatIP(dimension) # Inner product for cosine similarity
399
+
400
+ # Ensure embeddings are float32 and normalize for cosine similarity
401
+ embeddings = embeddings.astype(np.float32)
402
+ faiss.normalize_L2(embeddings)
403
+ index.add(embeddings)
404
+
405
+ # Store chunk metadata
406
+ metadata = []
407
+ for i, chunk in enumerate(chunks):
408
+ metadata.append({
409
+ 'chunk_id': chunk.chunk_id,
410
+ 'text': chunk.text,
411
+ 'metadata': chunk.metadata,
412
+ 'embedding': embeddings[i].tolist()
413
+ })
414
+
415
+ logger.info(f"Built FAISS index with {index.ntotal} vectors for method '{method}'")
416
+ return index, metadata
417
+
418
+ def build_chroma_index(self, chunks: List[DocumentChunk], method: str = "vanilla") -> Any:
419
+ """Build Chroma vector database."""
420
+
421
+ # Initialize Chroma client
422
+ chroma_client = chromadb.PersistentClient(path=str(CHROMA_PATH / method))
423
+ collection = chroma_client.get_or_create_collection(
424
+ name=f"{method}_collection",
425
+ metadata={"method": method}
426
+ )
427
+
428
+ # Prepare data for Chroma
429
+ texts = [chunk.text for chunk in chunks]
430
+ ids = [chunk.chunk_id for chunk in chunks]
431
+ metadatas = [chunk.metadata for chunk in chunks]
432
+
433
+ # Add to collection (Chroma handles embeddings internally)
434
+ collection.add(
435
+ documents=texts,
436
+ ids=ids,
437
+ metadatas=metadatas
438
+ )
439
+
440
+ logger.info(f"Built Chroma collection with {collection.count()} documents for method '{method}'")
441
+ return collection
442
+
443
+ def build_bm25_index(self, chunks: List[DocumentChunk]) -> BM25Okapi:
444
+ """Build BM25 index for keyword search."""
445
+
446
+ # Tokenize texts
447
+ tokenized_corpus = []
448
+ for chunk in chunks:
449
+ tokens = chunk.text.lower().split()
450
+ tokenized_corpus.append(tokens)
451
+
452
+ # Build BM25 index
453
+ bm25 = BM25Okapi(tokenized_corpus, k1=BM25_K1, b=BM25_B)
454
+
455
+ logger.info(f"Built BM25 index with {len(tokenized_corpus)} documents")
456
+ return bm25
457
+
458
+ def build_graph_index(self, chunks: List[DocumentChunk]) -> nx.Graph:
459
+ """Build NetworkX graph for graph-based retrieval."""
460
+
461
+ # Create graph
462
+ G = nx.Graph()
463
+
464
+ # Generate embeddings for similarity calculation
465
+ texts = [chunk.text for chunk in chunks]
466
+ embeddings = self.embedding_generator.embed_text_openai(texts)
467
+
468
+ # Add nodes (convert embeddings to lists for GML serialization)
469
+ for i, chunk in enumerate(chunks):
470
+ G.add_node(chunk.chunk_id,
471
+ text=chunk.text,
472
+ metadata=chunk.metadata,
473
+ embedding=embeddings[i].tolist()) # Convert to list for serialization
474
+
475
+ # Add edges based on similarity
476
+ threshold = 0.7 # Similarity threshold
477
+ for i in range(len(chunks)):
478
+ for j in range(i + 1, len(chunks)):
479
+ # Calculate cosine similarity
480
+ sim = np.dot(embeddings[i], embeddings[j]) / (
481
+ np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
482
+ )
483
+
484
+ if sim > threshold:
485
+ G.add_edge(chunks[i].chunk_id, chunks[j].chunk_id,
486
+ weight=float(sim))
487
+
488
+ logger.info(f"Built graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
489
+ return G
490
+
491
+ def save_index(self, index: Any, metadata: Any, method: str):
492
+ """Save index and metadata to disk."""
493
+
494
+ if method == "vanilla":
495
+ faiss.write_index(index, str(VANILLA_FAISS_INDEX))
496
+ with open(VANILLA_METADATA, 'wb') as f:
497
+ pickle.dump(metadata, f)
498
+
499
+ elif method == "dpr":
500
+ faiss.write_index(index, str(DPR_FAISS_INDEX))
501
+ with open(DPR_METADATA, 'wb') as f:
502
+ pickle.dump(metadata, f)
503
+
504
+ elif method == "bm25":
505
+ with open(BM25_INDEX, 'wb') as f:
506
+ pickle.dump({'index': index, 'texts': metadata}, f)
507
+
508
+ elif method == "context_stuffing":
509
+ with open(CONTEXT_DOCS, 'wb') as f:
510
+ pickle.dump(metadata, f)
511
+
512
+ elif method == "graph":
513
+ nx.write_gml(index, str(GRAPH_FILE))
514
+
515
+ logger.info(f"Saved {method} index to disk")
516
+
517
+ class ImageProcessor:
518
+ """Process and classify images."""
519
+
520
+ def __init__(self):
521
+ self.embedding_generator = EmbeddingGenerator()
522
+ self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
523
+ self._init_database()
524
+
525
+ def _init_database(self):
526
+ """Initialize SQLite database for image metadata."""
527
+ conn = sqlite3.connect(IMAGES_DB)
528
+ cursor = conn.cursor()
529
+
530
+ cursor.execute('''
531
+ CREATE TABLE IF NOT EXISTS images (
532
+ image_id TEXT PRIMARY KEY,
533
+ image_path TEXT NOT NULL,
534
+ classification TEXT,
535
+ metadata TEXT,
536
+ embedding BLOB,
537
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
538
+ )
539
+ ''')
540
+
541
+ conn.commit()
542
+ conn.close()
543
+
544
+ def classify_image(self, image_path: str) -> str:
545
+ """Classify image using GPT-5 Vision."""
546
+ try:
547
+ # Convert image to base64
548
+ with open(image_path, "rb") as image_file:
549
+ image_b64 = base64.b64encode(image_file.read()).decode()
550
+
551
+ messages = [{
552
+ "role": "user",
553
+ "content": [
554
+ {"type": "text", "text": "Classify this image in 1-2 words (e.g., 'machine guard', 'press brake', 'conveyor belt', 'safety sign')."},
555
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}", "detail": "low"}}
556
+ ]
557
+ }]
558
+
559
+ # For GPT-5 vision, temperature must be default (1.0)
560
+ response = self.openai_client.chat.completions.create(
561
+ model=OPENAI_CHAT_MODEL,
562
+ messages=messages,
563
+ max_completion_tokens=50
564
+ )
565
+
566
+ return response.choices[0].message.content.strip()
567
+
568
+ except Exception as e:
569
+ logger.error(f"Error classifying image {image_path}: {e}")
570
+ return "unknown"
571
+
572
+ def should_filter_image(self, image_path: str) -> tuple[bool, str]:
573
+ """
574
+ Check if image should be filtered out based on height and black image criteria.
575
+
576
+ Args:
577
+ image_path: Path to the image file
578
+
579
+ Returns:
580
+ Tuple of (should_filter: bool, reason: str)
581
+ """
582
+ try:
583
+ from PIL import Image
584
+ import numpy as np
585
+
586
+ # Open and analyze the image
587
+ with Image.open(image_path) as img:
588
+ # Convert to RGB if needed
589
+ if img.mode != 'RGB':
590
+ img = img.convert('RGB')
591
+
592
+ width, height = img.size
593
+
594
+ # Filter 1: Height less than 40 pixels
595
+ if height < 40:
596
+ return True, f"height too small ({height}px)"
597
+
598
+ # Filter 2: Check if image is mostly black
599
+ img_array = np.array(img)
600
+ mean_brightness = np.mean(img_array)
601
+
602
+ # If mean brightness is very low (mostly black)
603
+ if mean_brightness < 10: # Adjust threshold as needed
604
+ return True, "mostly black image"
605
+
606
+ except Exception as e:
607
+ logger.warning(f"Error analyzing image {image_path}: {e}")
608
+ # If we can't analyze it, don't filter it out
609
+ return False, "analysis failed"
610
+
611
+ return False, "passed all filters"
612
+
613
+ def store_image_metadata(self, image_data: ImageData):
614
+ """Store image metadata in database."""
615
+ conn = sqlite3.connect(IMAGES_DB)
616
+ cursor = conn.cursor()
617
+
618
+ # Serialize metadata and embedding
619
+ metadata_json = json.dumps(image_data.metadata) if image_data.metadata else None
620
+ embedding_blob = image_data.embedding.tobytes() if image_data.embedding is not None else None
621
+
622
+ cursor.execute('''
623
+ INSERT OR REPLACE INTO images
624
+ (image_id, image_path, classification, metadata, embedding)
625
+ VALUES (?, ?, ?, ?, ?)
626
+ ''', (image_data.image_id, image_data.image_path,
627
+ image_data.classification, metadata_json, embedding_blob))
628
+
629
+ conn.commit()
630
+ conn.close()
631
+
632
+ def get_image_metadata(self, image_id: str) -> Optional[ImageData]:
633
+ """Retrieve image metadata from database."""
634
+ conn = sqlite3.connect(IMAGES_DB)
635
+ cursor = conn.cursor()
636
+
637
+ cursor.execute('''
638
+ SELECT image_id, image_path, classification, metadata, embedding
639
+ FROM images WHERE image_id = ?
640
+ ''', (image_id,))
641
+
642
+ row = cursor.fetchone()
643
+ conn.close()
644
+
645
+ if row:
646
+ image_id, image_path, classification, metadata_json, embedding_blob = row
647
+
648
+ metadata = json.loads(metadata_json) if metadata_json else None
649
+ embedding = np.frombuffer(embedding_blob, dtype=np.float32) if embedding_blob else None
650
+
651
+ return ImageData(
652
+ image_path=image_path,
653
+ image_id=image_id,
654
+ classification=classification,
655
+ embedding=embedding,
656
+ metadata=metadata
657
+ )
658
+
659
+ return None
660
+
661
+ def load_text_documents() -> List[Dict[str, Any]]:
662
+ """Convenience function to load all text documents."""
663
+ loader = DocumentLoader()
664
+ return loader.load_text_documents()
665
+
666
+ def embed_image_clip(image_paths: List[str]) -> np.ndarray:
667
+ """Convenience function to embed images with CLIP."""
668
+ generator = EmbeddingGenerator()
669
+ return generator.embed_image_clip(image_paths)
670
+
671
+ def store_image_metadata(image_data: ImageData):
672
+ """Convenience function to store image metadata."""
673
+ processor = ImageProcessor()
674
+ processor.store_image_metadata(image_data)
675
+
676
+ def classify_image(image_path: str) -> str:
677
+ """Convenience function to classify an image."""
678
+ processor = ImageProcessor()
679
+ return processor.classify_image(image_path)
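
Editor's note, not part of the commit: the classes above compose into a straightforward offline indexing pass. A minimal sketch for the vanilla FAISS method (preprocess.py in this commit presumably orchestrates something similar; paths and settings come from config.py):

from utils import DocumentLoader, TextPreprocessor, VectorStoreManager

# Load PDFs from the data directory plus the configured HTML sources.
docs = DocumentLoader().load_text_documents()

# Chunk by tokens for the vanilla method.
chunks = TextPreprocessor().preprocess_for_method(docs, "vanilla")

# Embed with OpenAI, build the FAISS index, and persist it to disk.
store = VectorStoreManager()
index, metadata = store.build_faiss_index(chunks, method="vanilla")
store.save_index(index, metadata, "vanilla")
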