version 2.0.0
Changed files:
- Dockerfile +67 -26
- analytics_db.py +509 -0
- app.py +1298 -47
- config.py +217 -0
- preprocess.py +187 -77
- query_bm25.py +336 -0
- query_context.py +334 -0
- query_dpr.py +279 -0
- query_graph.py +372 -99
- query_vanilla.py +197 -0
- query_vision.py +393 -0
- realtime_server.py +402 -0
- requirements.txt +63 -10
- utils.py +679 -0
Dockerfile
CHANGED
@@ -1,26 +1,67 @@

[Previous 26-line Dockerfile removed; it was based on the same python:3.12.11-slim image.]
FROM python:3.12.11-slim

# Set environment variables for Python
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Create app directory and switch to it
WORKDIR /app

# Install system dependencies required for your packages
RUN apt-get update && \
    apt-get install -y \
    build-essential \
    curl \
    git \
    libgomp1 \
    supervisor \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better Docker layer caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy the entire application
COPY . .

# Create necessary directories if they don't exist
RUN mkdir -p /app/data /app/embeddings /app/graph /app/metadata /var/log/supervisor

# Create supervisor configuration
COPY <<EOF /etc/supervisor/conf.d/supervisord.conf
[supervisord]
nodaemon=true
logfile=/var/log/supervisor/supervisord.log
pidfile=/var/run/supervisord.pid

[program:realtime_server]
command=python realtime_server.py --port=7861 --host=0.0.0.0
directory=/app
autostart=true
autorestart=true
stderr_logfile=/var/log/supervisor/realtime_server.err.log
stdout_logfile=/var/log/supervisor/realtime_server.out.log
priority=100

[program:streamlit]
command=streamlit run app.py --server.port=7860 --server.address=0.0.0.0 --server.enableXsrfProtection=false --server.enableCORS=false
directory=/app
autostart=true
autorestart=true
stderr_logfile=/var/log/supervisor/streamlit.err.log
stdout_logfile=/var/log/supervisor/streamlit.out.log
priority=200
EOF

# Expose both ports (Streamlit on 7860, Realtime API on 7861)
EXPOSE 7860 7861

# Health check for both services
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl --fail http://localhost:7860/_stcore/health && curl --fail http://localhost:7861/health || exit 1

# Use supervisor to run both services
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
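
The HEALTHCHECK above only reports the container healthy when both services answer. The same probe can be run by hand; the sketch below is not part of this commit and assumes the default ports from this Dockerfile (Streamlit on 7860, the realtime server on 7861) with both services already running.

# Manual health probe mirroring the Dockerfile HEALTHCHECK (illustrative sketch,
# not part of this commit). Assumes both services run on their default ports.
import requests

ENDPOINTS = [
    "http://localhost:7860/_stcore/health",  # Streamlit health endpoint
    "http://localhost:7861/health",          # realtime_server health endpoint
]

def services_healthy(timeout: float = 5.0) -> bool:
    """Return True only if every endpoint answers with HTTP 200."""
    try:
        return all(requests.get(url, timeout=timeout).status_code == 200 for url in ENDPOINTS)
    except requests.RequestException:
        return False

if __name__ == "__main__":
    print("healthy" if services_healthy() else "unhealthy")
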
analytics_db.py
ADDED
@@ -0,0 +1,509 @@

"""
Analytics Database Module for Query Logging and Performance Tracking.
Tracks every query, method, answer, and citation for comprehensive analytics.
"""

import sqlite3
import json
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import logging

from config import DATA_DIR

logger = logging.getLogger(__name__)

# Database path
ANALYTICS_DB = DATA_DIR / "analytics.db"

class AnalyticsDB:
    """Database manager for query analytics and logging."""

    def __init__(self):
        self.db_path = ANALYTICS_DB
        self._init_database()

    def _init_database(self):
        """Initialize analytics database with required tables."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # Main queries table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS queries (
                query_id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                user_query TEXT NOT NULL,
                retrieval_method TEXT NOT NULL,
                answer TEXT NOT NULL,
                response_time_ms REAL,
                num_citations INTEGER DEFAULT 0,
                image_path TEXT,
                error_message TEXT,
                top_k_used INTEGER DEFAULT 5,
                additional_settings TEXT,
                answer_length INTEGER,
                session_id TEXT
            )
        ''')

        # Citations table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS citations (
                citation_id INTEGER PRIMARY KEY AUTOINCREMENT,
                query_id INTEGER NOT NULL,
                source TEXT NOT NULL,
                citation_type TEXT,
                relevance_score REAL,
                bm25_score REAL,
                rerank_score REAL,
                similarity_score REAL,
                url TEXT,
                path TEXT,
                rank INTEGER,
                FOREIGN KEY (query_id) REFERENCES queries (query_id)
            )
        ''')

        # Performance metrics table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS performance_metrics (
                metric_id INTEGER PRIMARY KEY AUTOINCREMENT,
                query_id INTEGER NOT NULL,
                retrieval_time_ms REAL,
                generation_time_ms REAL,
                total_time_ms REAL,
                chunks_retrieved INTEGER,
                tokens_estimated INTEGER,
                FOREIGN KEY (query_id) REFERENCES queries (query_id)
            )
        ''')

        conn.commit()
        conn.close()
        logger.info("Analytics database initialized")

    def log_query(self, user_query: str, method: str, answer: str,
                  citations: List[Dict], response_time: float = None,
                  image_path: str = None, error_message: str = None,
                  top_k: int = 5, additional_settings: Dict = None,
                  session_id: str = None) -> int:
        """
        Log a complete query interaction.

        Args:
            user_query: The user's question
            method: Retrieval method used
            answer: Generated answer
            citations: List of citation dictionaries
            response_time: Time taken in milliseconds
            image_path: Path to uploaded image (if any)
            error_message: Error message (if any)
            top_k: Number of chunks retrieved
            additional_settings: Method-specific settings
            session_id: Session identifier

        Returns:
            query_id: The ID of the logged query
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            # Insert main query record
            cursor.execute('''
                INSERT INTO queries (
                    timestamp, user_query, retrieval_method, answer,
                    response_time_ms, num_citations, image_path, error_message,
                    top_k_used, additional_settings, answer_length, session_id
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                datetime.now().isoformat(),
                user_query,
                method,
                answer,
                response_time,
                len(citations),
                image_path,
                error_message,
                top_k,
                json.dumps(additional_settings) if additional_settings else None,
                len(answer),
                session_id
            ))

            query_id = cursor.lastrowid

            # Insert citations
            for rank, citation in enumerate(citations, 1):
                cursor.execute('''
                    INSERT INTO citations (
                        query_id, source, citation_type, relevance_score,
                        bm25_score, rerank_score, similarity_score, url, path, rank
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    query_id,
                    citation.get('source', ''),
                    citation.get('type', ''),
                    citation.get('relevance_score'),
                    citation.get('bm25_score'),
                    citation.get('rerank_score'),
                    citation.get('similarity_score'),
                    citation.get('url'),
                    citation.get('path'),
                    rank
                ))

            conn.commit()
            logger.info(f"Logged query {query_id} with {len(citations)} citations")
            return query_id

        except Exception as e:
            logger.error(f"Error logging query: {e}")
            conn.rollback()
            return None
        finally:
            conn.close()

    def get_query_stats(self, days: int = 30) -> Dict[str, Any]:
        """Get comprehensive query statistics."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        since_date = (datetime.now() - timedelta(days=days)).isoformat()

        try:
            stats = {}

            # Total queries
            cursor.execute('''
                SELECT COUNT(*) FROM queries
                WHERE timestamp >= ?
            ''', (since_date,))
            stats['total_queries'] = cursor.fetchone()[0]

            # Method usage
            cursor.execute('''
                SELECT retrieval_method, COUNT(*) as count
                FROM queries
                WHERE timestamp >= ?
                GROUP BY retrieval_method
                ORDER BY count DESC
            ''', (since_date,))
            stats['method_usage'] = dict(cursor.fetchall())

            # Average response times by method
            cursor.execute('''
                SELECT retrieval_method, AVG(response_time_ms) as avg_time
                FROM queries
                WHERE timestamp >= ? AND response_time_ms IS NOT NULL
                GROUP BY retrieval_method
            ''', (since_date,))
            stats['avg_response_times'] = dict(cursor.fetchall())

            # Citation statistics
            cursor.execute('''
                SELECT AVG(num_citations) as avg_citations,
                       SUM(num_citations) as total_citations
                FROM queries
                WHERE timestamp >= ?
            ''', (since_date,))
            result = cursor.fetchone()
            stats['avg_citations'] = result[0] or 0
            stats['total_citations'] = result[1] or 0

            # Citation types
            cursor.execute('''
                SELECT c.citation_type, COUNT(*) as count
                FROM citations c
                JOIN queries q ON c.query_id = q.query_id
                WHERE q.timestamp >= ?
                GROUP BY c.citation_type
                ORDER BY count DESC
            ''', (since_date,))
            stats['citation_types'] = dict(cursor.fetchall())

            # Error rate
            cursor.execute('''
                SELECT
                    COUNT(CASE WHEN error_message IS NOT NULL THEN 1 END) as errors,
                    COUNT(*) as total
                FROM queries
                WHERE timestamp >= ?
            ''', (since_date,))
            result = cursor.fetchone()
            stats['error_rate'] = (result[0] / result[1]) * 100 if result[1] > 0 else 0

            # Most common query topics (simple word analysis)
            cursor.execute('''
                SELECT user_query FROM queries
                WHERE timestamp >= ?
            ''', (since_date,))
            queries = [row[0].lower() for row in cursor.fetchall()]

            # Simple keyword extraction
            keywords = {}
            for query in queries:
                words = [word for word in query.split() if len(word) > 3]
                for word in words:
                    keywords[word] = keywords.get(word, 0) + 1

            # Top 10 keywords
            stats['top_keywords'] = dict(sorted(keywords.items(),
                                                key=lambda x: x[1],
                                                reverse=True)[:10])

            return stats

        except Exception as e:
            logger.error(f"Error getting query stats: {e}")
            return {}
        finally:
            conn.close()

    def get_method_performance(self) -> Dict[str, Dict[str, float]]:
        """Get detailed performance metrics by method."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            cursor.execute('''
                SELECT
                    retrieval_method,
                    AVG(response_time_ms) as avg_response_time,
                    AVG(num_citations) as avg_citations,
                    AVG(answer_length) as avg_answer_length,
                    COUNT(*) as query_count
                FROM queries
                WHERE response_time_ms IS NOT NULL
                GROUP BY retrieval_method
            ''')

            results = {}
            for row in cursor.fetchall():
                method, avg_time, avg_cites, avg_length, count = row
                results[method] = {
                    'avg_response_time': avg_time,
                    'avg_citations': avg_cites,
                    'avg_answer_length': avg_length,
                    'query_count': count
                }

            return results

        except Exception as e:
            logger.error(f"Error getting method performance: {e}")
            return {}
        finally:
            conn.close()

    def get_recent_queries(self, limit: int = 20, include_answers: bool = True) -> List[Dict[str, Any]]:
        """Get recent queries with basic information and optionally full answers."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            if include_answers:
                cursor.execute('''
                    SELECT query_id, timestamp, user_query, retrieval_method,
                           answer, answer_length, num_citations, response_time_ms, error_message
                    FROM queries
                    ORDER BY timestamp DESC
                    LIMIT ?
                ''', (limit,))

                columns = ['query_id', 'timestamp', 'query', 'method',
                           'answer', 'answer_length', 'citations', 'response_time', 'error_message']
            else:
                cursor.execute('''
                    SELECT query_id, timestamp, user_query, retrieval_method,
                           answer_length, num_citations, response_time_ms
                    FROM queries
                    ORDER BY timestamp DESC
                    LIMIT ?
                ''', (limit,))

                columns = ['query_id', 'timestamp', 'query', 'method',
                           'answer_length', 'citations', 'response_time']

            return [dict(zip(columns, row)) for row in cursor.fetchall()]

        except Exception as e:
            logger.error(f"Error getting recent queries: {e}")
            return []
        finally:
            conn.close()

    def get_query_with_citations(self, query_id: int) -> Dict[str, Any]:
        """Get full query details including citations."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        try:
            # Get query details
            cursor.execute('''
                SELECT query_id, timestamp, user_query, retrieval_method, answer,
                       response_time_ms, num_citations, error_message, top_k_used
                FROM queries WHERE query_id = ?
            ''', (query_id,))

            query_row = cursor.fetchone()
            if not query_row:
                return {}

            query_data = {
                'query_id': query_row[0],
                'timestamp': query_row[1],
                'user_query': query_row[2],
                'method': query_row[3],
                'answer': query_row[4],
                'response_time': query_row[5],
                'num_citations': query_row[6],
                'error_message': query_row[7],
                'top_k_used': query_row[8]
            }

            # Get citations
            cursor.execute('''
                SELECT source, citation_type, relevance_score, bm25_score,
                       rerank_score, similarity_score, url, path, rank
                FROM citations WHERE query_id = ?
                ORDER BY rank
            ''', (query_id,))

            citations = []
            for row in cursor.fetchall():
                citation = {
                    'source': row[0],
                    'type': row[1],
                    'relevance_score': row[2],
                    'bm25_score': row[3],
                    'rerank_score': row[4],
                    'similarity_score': row[5],
                    'url': row[6],
                    'path': row[7],
                    'rank': row[8]
                }
                citations.append(citation)

            query_data['citations'] = citations
            return query_data

        except Exception as e:
            logger.error(f"Error getting query with citations: {e}")
            return {}
        finally:
            conn.close()

    def get_query_trends(self, days: int = 30) -> Dict[str, List[Tuple[str, int]]]:
        """Get query trends over time."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        since_date = (datetime.now() - timedelta(days=days)).isoformat()

        try:
            # Queries per day
            cursor.execute('''
                SELECT DATE(timestamp) as date, COUNT(*) as count
                FROM queries
                WHERE timestamp >= ?
                GROUP BY DATE(timestamp)
                ORDER BY date
            ''', (since_date,))

            daily_queries = cursor.fetchall()

            # Method usage trends
            cursor.execute('''
                SELECT DATE(timestamp) as date, retrieval_method, COUNT(*) as count
                FROM queries
                WHERE timestamp >= ?
                GROUP BY DATE(timestamp), retrieval_method
                ORDER BY date, retrieval_method
            ''', (since_date,))

            method_trends = {}
            for date, method, count in cursor.fetchall():
                if method not in method_trends:
                    method_trends[method] = []
                method_trends[method].append((date, count))

            return {
                'daily_queries': daily_queries,
                'method_trends': method_trends
            }

        except Exception as e:
            logger.error(f"Error getting query trends: {e}")
            return {}
        finally:
            conn.close()

    def get_voice_interaction_stats(self) -> Dict[str, Any]:
        """Get statistics about voice interactions."""
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()

            # Count voice interactions (those with voice_interaction=true in additional_settings)
            cursor.execute('''
                SELECT COUNT(*) as total_voice_queries
                FROM queries
                WHERE additional_settings LIKE '%voice_interaction%'
                   OR session_id LIKE 'voice_%'
            ''')
            result = cursor.fetchone()
            total_voice = result[0] if result else 0

            # Get voice queries by method
            cursor.execute('''
                SELECT retrieval_method, COUNT(*) as count
                FROM queries
                WHERE additional_settings LIKE '%voice_interaction%'
                   OR session_id LIKE 'voice_%'
                GROUP BY retrieval_method
            ''')
            voice_by_method = dict(cursor.fetchall())

            # Average response time for voice queries
            cursor.execute('''
                SELECT AVG(response_time_ms) as avg_response_time
                FROM queries
                WHERE (additional_settings LIKE '%voice_interaction%'
                       OR session_id LIKE 'voice_%')
                  AND response_time_ms IS NOT NULL
            ''')
            result = cursor.fetchone()
            avg_response_time = result[0] if result and result[0] else 0

            return {
                'total_voice_queries': total_voice,
                'voice_by_method': voice_by_method,
                'avg_voice_response_time': avg_response_time
            }

        except Exception as e:
            logger.error(f"Error getting voice interaction stats: {e}")
            return {}
        finally:
            conn.close()

# Global instance
analytics_db = AnalyticsDB()

# Convenience functions
def log_query(user_query: str, method: str, answer: str, citations: List[Dict],
              **kwargs) -> int:
    """Log a query to the analytics database."""
    return analytics_db.log_query(user_query, method, answer, citations, **kwargs)

def get_analytics_stats(days: int = 30) -> Dict[str, Any]:
    """Get analytics statistics."""
    return analytics_db.get_query_stats(days)

def get_method_performance() -> Dict[str, Dict[str, float]]:
    """Get method performance metrics."""
    return analytics_db.get_method_performance()
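
The module ends by exporting a module-level analytics_db instance and thin convenience wrappers. The usage sketch below is not part of this commit; the query text, citation dicts, and session id are hypothetical, and it assumes config.DATA_DIR points at a writable directory so the SQLite file can be created.

# Illustrative use of the convenience wrappers above (hypothetical data, not part
# of this commit). Importing analytics_db creates analytics.db under DATA_DIR.
from analytics_db import log_query, get_analytics_stats, get_method_performance

citations = [
    {"source": "example_guarding_doc.pdf", "type": "pdf", "relevance_score": 0.82},
    {"source": "example_standard_page", "type": "html", "url": "https://example.org/standard"},
]

query_id = log_query(
    user_query="What are the requirements for machine guards?",
    method="bm25",
    answer="Machine guards must ...",
    citations=citations,
    response_time=412.5,   # milliseconds
    top_k=5,
    session_id="demo1234",
)
print(f"Logged query {query_id} with {len(citations)} citations")

stats = get_analytics_stats(days=7)
print(stats.get("method_usage"))   # e.g. {'bm25': 1}
print(get_method_performance())    # per-method averages keyed by retrieval method
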
app.py
CHANGED
@@ -1,82 +1,1333 @@

[Removed (truncated in the diff): most of the previous 82-line app.py, including the old format_citations_html() helper, whose per-chunk preamble read "The text in the following detail is reproduced from [{citation}]. It had a cosine similarity of {sc:.2f} with the user question, and it ranked {idx} among the text chunks in our graph database.", and the sidebar's hard-coded version, date, and model lines.]

"""
Multi-Method RAG System - SIGHT
Enhanced Streamlit application with method comparison and analytics.

Directory structure:
    /data/        # Original PDFs, HTML
    /embeddings/  # FAISS, Chroma, DPR vector stores
    /graph/       # Graph database files
    /metadata/    # Image metadata (SQLite or MongoDB)
"""

import streamlit as st
import os
import logging
import tempfile
import time
import uuid
from typing import Tuple, List, Dict, Any, Optional
from pathlib import Path

# Import all query modules
from query_graph import query as graph_query, query_graph
from query_vanilla import query as vanilla_query
from query_dpr import query as dpr_query
from query_bm25 import query as bm25_query
from query_context import query as context_query
from query_vision import query as vision_query, query_image_only

from config import *
from analytics_db import log_query, get_analytics_stats, get_method_performance, analytics_db
import streamlit.components.v1 as components
import requests

logger = logging.getLogger(__name__)

# Check realtime server health
@st.cache_data(ttl=30)  # Cache for 30 seconds
def check_realtime_server_health():
    """Check if the realtime server is running."""
    try:
        response = requests.get("http://localhost:7861/health", timeout=2)
        return response.status_code == 200
    except:
        return False

# Query method dispatch
QUERY_DISPATCH = {
    'graph': graph_query,
    'vanilla': vanilla_query,
    'dpr': dpr_query,
    'bm25': bm25_query,
    'context': context_query,
    'vision': vision_query
}

# Method options for speech interface
METHOD_OPTIONS = ['graph', 'vanilla', 'dpr', 'bm25', 'context', 'vision']

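# Each entry in QUERY_DISPATCH follows the calling convention used later in this
# file: it is invoked as query_func(question, top_k=...) in the comparison tab and
# as query_func(question, image_path=..., top_k=...) in the chat tab, and returns
# an (answer, citations) tuple whose citations list is rendered by
# format_citations_html() below.
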
def format_citations_html(chunks):
    """Format citations for display (backward compatibility)."""
    html = []
    for idx, (hdr, sc, txt, citation) in enumerate(chunks, start=1):
        body = txt.replace("\n", "<br>")
        html.append(
            f"<details>"
            f"<summary>{hdr} (relevance score: {sc:.3f})</summary>"
            f"<div style='font-size:0.9em; margin-top:0.5em;'>"
            f"<strong>Source:</strong> {citation} "
            f"</div>"
            f"<div style='font-size:0.8em; margin-left:1em; margin-top:0.5em;'>{body}</div>"
            f"</details><br><br>"
        )
    return "<br>".join(html)

def format_citations_html(citations: List[dict], method: str) -> str:
    """Format citations as HTML based on method and citation type."""
    if not citations:
        return "<p><em>No citations available</em></p>"

    html_parts = ["<div style='margin-top: 1em;'><strong>Sources:</strong><ul>"]

    for citation in citations:
        # Skip citations without source
        if 'source' not in citation:
            continue

        source = citation['source']
        cite_type = citation.get('type', 'unknown')

        # Build citation text based on type
        if cite_type == 'pdf':
            cite_text = f"📄 {source} (PDF)"
        elif cite_type == 'html':
            url = citation.get('url', '')
            if url:
                cite_text = f"🌐 <a href='{url}' target='_blank'>{source}</a> (Web)"
            else:
                cite_text = f"🌐 {source} (Web)"
        elif cite_type == 'image':
            page = citation.get('page', 'N/A')
            cite_text = f"🖼️ {source} (Image, page {page})"
        elif cite_type == 'image_analysis':
            classification = citation.get('classification', 'N/A')
            cite_text = f"🔍 {source} - {classification}"
        else:
            cite_text = f"📚 {source}"

        # Add scores if available
        scores = []
        if 'relevance_score' in citation:
            scores.append(f"relevance: {citation['relevance_score']}")
        if 'bm25_score' in citation:
            scores.append(f"BM25: {citation['bm25_score']}")
        if 'rerank_score' in citation:
            scores.append(f"rerank: {citation['rerank_score']}")
        if 'similarity' in citation:
            scores.append(f"similarity: {citation['similarity']}")
        if 'score' in citation:
            scores.append(f"score: {citation['score']:.3f}")

        if scores:
            cite_text += f" <small>({', '.join(scores)})</small>"

        html_parts.append(f"<li>{cite_text}</li>")

    html_parts.append("</ul></div>")
    return "".join(html_parts)

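# Citation dicts handled by the renderer above: 'source' is required (entries
# without it are skipped); 'type' selects the icon ('pdf', 'html', 'image',
# 'image_analysis', anything else falls back to a generic book icon); 'url' is
# used for html citations, 'page' for image citations, and 'classification' for
# image_analysis citations; any of 'relevance_score', 'bm25_score',
# 'rerank_score', 'similarity', or 'score' are appended as score annotations.
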
def save_uploaded_file(uploaded_file) -> str:
    """Save uploaded file to temporary location."""
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            return tmp_file.name
    except Exception as e:
        st.error(f"Error saving file: {e}")
        return None


# Page configuration
st.set_page_config(
    page_title="Multi-Method RAG System - SIGHT",
    page_icon="🔍",
    layout="wide"
)

# Sidebar configuration
st.sidebar.title("Configuration")

# Method selector
st.sidebar.markdown("### Retrieval Method")
selected_method = st.sidebar.radio(
    "Choose retrieval method:",
    options=['graph', 'vanilla', 'dpr', 'bm25', 'context', 'vision'],
    format_func=lambda x: x.capitalize(),
    help="Select different RAG methods to compare results"
)

# Display method description
st.sidebar.info(METHOD_DESCRIPTIONS[selected_method])


# Advanced settings
with st.sidebar.expander("Advanced Settings"):
    top_k = st.slider("Number of chunks to retrieve", min_value=1, max_value=10, value=DEFAULT_TOP_K)

    if selected_method == 'bm25':
        use_hybrid = st.checkbox("Use hybrid search (BM25 + semantic)", value=False)
        if use_hybrid:
            alpha = st.slider("BM25 weight (alpha)", min_value=0.0, max_value=1.0, value=0.5)

# Sidebar info

st.sidebar.markdown("---")
st.sidebar.markdown("### About")
st.sidebar.markdown("**Authors:** [The SIGHT Project Team](https://sites.miamioh.edu/sight/)")
st.sidebar.markdown(f"**Version:** V. {VERSION}")
st.sidebar.markdown(f"**Date:** {DATE}")
st.sidebar.markdown(f"**Model:** {OPENAI_CHAT_MODEL}")

st.sidebar.markdown("---")
st.sidebar.markdown(
    "**Funding:** SIGHT is funded by [OHBWC WSIC](https://info.bwc.ohio.gov/for-employers/safety-services/workplace-safety-innovation-center/wsic-overview)"
)

# Main interface with dynamic status
col1, col2 = st.columns([3, 1])
with col1:
    st.title("🔍 Multi-Method RAG System - SIGHT")
    st.markdown("### Compare different retrieval methods for machine safety Q&A")

with col2:
    # Quick stats in the header
    if 'chat_history' in st.session_state:
        total_queries = len(st.session_state.chat_history)
        st.metric("Session Queries", total_queries, delta=None if total_queries == 0 else "+1" if total_queries == 1 else f"+{total_queries}")

    # Voice chat status indicator
    if st.session_state.get('voice_session_active', False):
        st.success("🔴 Voice LIVE")

# Create tabs for different interfaces
tab1, tab2, tab3, tab4 = st.tabs(["💬 Chat", "📊 Method Comparison", "🔊 Voice Chat", "📈 Analytics"])

with tab1:
    # Example questions
    with st.expander("📝 Example Questions", expanded=False):
        example_cols = st.columns(2)
        with example_cols[0]:
            st.markdown(
                "**General Safety:**\n"
                "- What are general machine guarding requirements?\n"
                "- How do I perform lockout/tagout?\n"
                "- What is required for emergency stops?"
            )
        with example_cols[1]:
            st.markdown(
                "**Specific Topics:**\n"
                "- Summarize robot safety requirements from OSHA\n"
                "- Compare guard types: fixed vs interlocked\n"
                "- What are the ANSI standards for machine safety?"
            )

    # File uploader for vision method
    uploaded_file = None
    if selected_method == 'vision':
        st.markdown("#### 🖼️ Upload an image for analysis")
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=['png', 'jpg', 'jpeg', 'bmp', 'gif'],
            help="Upload an image of safety equipment, signs, or machinery"
        )

        if uploaded_file:
            col1, col2 = st.columns([1, 2])
            with col1:
                st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

    # Initialize session state
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    if 'session_id' not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())[:8]

    # Chat input
    query = st.text_input(
        "Ask a question:",
        placeholder="E.g., What are the safety requirements for collaborative robots?",
        key="chat_input"
    )

    col1, col2, col3 = st.columns([1, 1, 8])
    with col1:
        send_button = st.button("🚀 Send", type="primary", use_container_width=True)
    with col2:
        clear_button = st.button("🗑️ Clear", use_container_width=True)

    if clear_button:
        st.session_state.chat_history = []
        st.rerun()

    if send_button and query:
        # Save uploaded file if present
        image_path = None
        if uploaded_file and selected_method == 'vision':
            image_path = save_uploaded_file(uploaded_file)

        # Show spinner while processing
        with st.spinner(f"Searching using {selected_method.upper()} method..."):
            start_time = time.time()
            error_message = None
            answer = ""
            citations = []

            try:
                # Get the appropriate query function
                query_func = QUERY_DISPATCH[selected_method]

                # Call the query function
                if selected_method == 'vision' and not image_path:
                    error_message = "Please upload an image for vision-based search"
                    st.error(error_message)
                else:
                    answer, citations = query_func(query, image_path=image_path, top_k=top_k)

                    # Add to history
                    st.session_state.chat_history.append({
                        'query': query,
                        'answer': answer,
                        'citations': citations,
                        'method': selected_method,
                        'image_path': image_path
                    })

            except Exception as e:
                error_message = str(e)
                answer = f"Error: {error_message}"
                st.error(f"Error processing query: {error_message}")
                st.info("Make sure you've run preprocess.py to generate the required indices.")

            finally:
                # Log query to analytics database (always, even on error)
                response_time = (time.time() - start_time) * 1000  # Convert to ms

                try:
                    log_query(
                        user_query=query,
                        method=selected_method,
                        answer=answer,
                        citations=citations,
                        response_time=response_time,
                        image_path=image_path,
                        error_message=error_message,
                        top_k=top_k,
                        session_id=st.session_state.session_id
                    )
                except Exception as log_error:
                    logger.error(f"Failed to log query: {log_error}")

                # Clean up temp file
                if image_path and os.path.exists(image_path):
                    os.unlink(image_path)

    # Display chat history
    if st.session_state.chat_history:
        st.markdown("---")
        st.markdown("### Chat History")

        for i, entry in enumerate(reversed(st.session_state.chat_history)):
            with st.container():
                # User message
                st.markdown(f"**🧑 You** ({entry['method'].upper()}):")
                st.markdown(entry['query'])

                # Assistant response
                st.markdown("**🤖 Assistant:**")
                st.markdown(entry['answer'])

                # Citations
                st.markdown(format_citations_html(entry['citations'], entry['method']), unsafe_allow_html=True)

                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")

with tab2:
    st.markdown("### Method Comparison")
    st.markdown("Compare results from different retrieval methods for the same query.")

    comparison_query = st.text_input(
        "Enter a query to compare across methods:",
        placeholder="E.g., What are the requirements for machine guards?",
        key="comparison_input"
    )

    methods_to_compare = st.multiselect(
        "Select methods to compare:",
        options=['graph', 'vanilla', 'dpr', 'bm25', 'context'],
        default=['vanilla', 'bm25'],
        help="Vision method requires an image and is not included in comparison"
    )

    col1, col2 = st.columns([3, 1])
    with col1:
        compare_button = st.button("🔍 Compare Methods", type="primary")
    with col2:
        if 'comparison_results' in st.session_state and st.session_state.comparison_results:
            if st.button("🪟 Full Screen View", help="View results in a dedicated comparison window"):
                st.session_state.show_comparison_window = True
                st.rerun()

    if compare_button:
        if comparison_query and methods_to_compare:
            results = {}

            progress_bar = st.progress(0)
            for idx, method in enumerate(methods_to_compare):
                with st.spinner(f"Running {method.upper()}..."):
                    start_time = time.time()
                    error_message = None

                    try:
                        query_func = QUERY_DISPATCH[method]
                        answer, citations = query_func(comparison_query, top_k=top_k)
                        results[method] = {
                            'answer': answer,
                            'citations': citations
                        }
                    except Exception as e:
                        error_message = str(e)
                        answer = f"Error: {error_message}"
                        citations = []
                        results[method] = {
                            'answer': answer,
                            'citations': citations
                        }

                    finally:
                        # Log comparison queries too
                        response_time = (time.time() - start_time) * 1000
                        try:
                            log_query(
                                user_query=comparison_query,
                                method=method,
                                answer=results[method]['answer'],
                                citations=results[method]['citations'],
                                response_time=response_time,
                                error_message=error_message,
                                top_k=top_k,
                                session_id=st.session_state.session_id,
                                additional_settings={'comparison_mode': True}
                            )
                        except Exception as log_error:
                            logger.error(f"Failed to log comparison query: {log_error}")

                progress_bar.progress((idx + 1) / len(methods_to_compare))

            # Store results in session state for full screen view
            st.session_state.comparison_results = {
                'query': comparison_query,
                'methods': methods_to_compare,
                'results': results,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
            }

            # Display results in compact columns
            cols = st.columns(len(methods_to_compare))
            for idx, (method, col) in enumerate(zip(methods_to_compare, cols)):
                with col:
                    st.markdown(f"#### {method.upper()}")

                    # Use expandable container for full text without truncation
                    answer = results[method]['answer']
                    if len(answer) > 800:
                        # Show first 300 chars, then expandable for full text
                        st.markdown(answer[:300] + "...")
                        with st.expander("📖 Show full answer"):
                            st.markdown(answer)
                    else:
                        # Short answers display fully
                        st.markdown(answer)

                    st.markdown(format_citations_html(results[method]['citations'], method), unsafe_allow_html=True)
        else:
            st.warning("Please enter a query and select at least one method to compare.")

with tab3:
    st.markdown("### 🔊 Voice Chat - Hands-free AI Assistant")

    # Server status check
    server_healthy = check_realtime_server_health()
    if server_healthy:
        st.success("✅ **Voice Server Online** - Ready for voice interactions")
    else:
        st.error("❌ **Voice Server Offline** - Please start the realtime server: `python realtime_server.py`")
        st.code("python realtime_server.py", language="bash")
        st.stop()

    st.info(
        "🎤 **Real-time Voice Interaction**: Speak naturally and get instant responses from your chosen RAG method. "
        "The AI will automatically transcribe your speech, search the knowledge base, and respond with synthesized voice."
    )

    # Voice Chat Status and Configuration
    col1, col2 = st.columns([2, 1])

    with col1:
        # Use the same method from sidebar
        st.info(f"🔍 **Voice using {selected_method.upper()} method** (change in sidebar)")

    with col2:
        # Voice settings (simplified)
        voice_choice = st.selectbox(
            "🎙️ AI Voice:",
            ["alloy", "echo", "fable", "onyx", "nova", "shimmer"],
            index=0,
            help="Select the AI voice for responses"
        )
        response_speed = st.slider(
            "⏱️ Response Speed (seconds):",
            min_value=1, max_value=5, value=2,
            help="How quickly the AI should respond after you stop speaking"
        )

    # Auto-detect server URL (hide the manual configuration)
    server_url = "http://localhost:5050"  # Default for local development

    # Voice Chat Interface
    st.markdown("---")

    # Initialize voice chat session state
    if 'voice_chat_history' not in st.session_state:
        st.session_state.voice_chat_history = []
    if 'voice_session_active' not in st.session_state:
        st.session_state.voice_session_active = False

    # Simple Status Display
    if st.session_state.voice_session_active:
        st.success("🔴 **LIVE** - Voice chat active using " + selected_method.upper())

    # Enhanced Voice Interface with better UX
    components.html(f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8" />
        <style>
            body {{
                font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
                padding: 20px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border-radius: 10px;
            }}
            .container {{
                max-width: 800px;
                margin: 0 auto;
                background: rgba(255,255,255,0.1);
                padding: 30px;
                border-radius: 15px;
                backdrop-filter: blur(10px);
            }}
            .controls {{
                display: flex;
                gap: 20px;
                align-items: center;
                justify-content: center;
                margin-bottom: 30px;
            }}
            .status-display {{
                text-align: center;
                margin: 20px 0;
                padding: 15px;
                border-radius: 10px;
                background: rgba(255,255,255,0.2);
            }}
            .status-idle {{ background: rgba(108, 117, 125, 0.3); }}
            .status-connecting {{ background: rgba(255, 193, 7, 0.3); }}
            .status-active {{ background: rgba(40, 167, 69, 0.3); }}
            .status-error {{ background: rgba(220, 53, 69, 0.3); }}

            button {{
                padding: 12px 24px;
                font-size: 16px;
                border: none;
                border-radius: 25px;
                cursor: pointer;
                transition: all 0.3s ease;
                font-weight: bold;
            }}
            .start-btn {{
                background: linear-gradient(45deg, #28a745, #20c997);
                color: white;
            }}
            .start-btn:hover {{ transform: translateY(-2px); box-shadow: 0 4px 12px rgba(40,167,69,0.4); }}
            .start-btn:disabled {{
                background: #6c757d;
                cursor: not-allowed;
                transform: none;
                box-shadow: none;
            }}
            .stop-btn {{
                background: linear-gradient(45deg, #dc3545, #fd7e14);
                color: white;
            }}
            .stop-btn:hover {{ transform: translateY(-2px); box-shadow: 0 4px 12px rgba(220,53,69,0.4); }}
            .stop-btn:disabled {{
                background: #6c757d;
                cursor: not-allowed;
                transform: none;
                box-shadow: none;
            }}

            .log {{
                height: 200px;
                overflow-y: auto;
                border: 1px solid rgba(255,255,255,0.3);
                padding: 15px;
                background: rgba(0,0,0,0.2);
                border-radius: 10px;
                font-family: 'Monaco', 'Menlo', monospace;
                font-size: 13px;
                line-height: 1.4;
            }}
            .audio-controls {{
                text-align: center;
                margin: 20px 0;
            }}
            .pulse {{
                animation: pulse 2s infinite;
            }}
            @keyframes pulse {{
                0% {{ transform: scale(1); }}
                50% {{ transform: scale(1.05); }}
                100% {{ transform: scale(1); }}
            }}
            .visualizer {{
                width: 100%;
                height: 60px;
                background: rgba(0,0,0,0.2);
                border-radius: 10px;
                margin: 10px 0;
                display: flex;
                align-items: center;
                justify-content: center;
                font-size: 14px;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="status-display status-idle" id="statusDisplay">
                <h3 id="statusTitle">🎤 Voice Chat</h3>
                <p id="statusText">Click "Start Listening" to begin</p>
            </div>

            <div class="controls">
                <button id="startBtn" class="start-btn">🎤 Start Listening</button>
                <button id="stopBtn" class="stop-btn" disabled>⏹️ Stop</button>
            </div>

            <div class="audio-controls">
                <audio id="remoteAudio" autoplay style="width: 100%; max-width: 400px;"></audio>
            </div>

            <div class="visualizer" id="visualizer">
                🔇 Audio will appear here when active
            </div>

            <div class="log" id="log"></div>
        </div>

        <script>
        (async () => {{
            const serverBase = {server_url!r};
            const chosenMethod = {selected_method!r};
            const voiceChoice = {voice_choice!r};
            const responseSpeed = {response_speed!r};

            const logEl = document.getElementById('log');
            const statusDisplay = document.getElementById('statusDisplay');
            const statusTitle = document.getElementById('statusTitle');
            const statusText = document.getElementById('statusText');
            const startBtn = document.getElementById('startBtn');
            const stopBtn = document.getElementById('stopBtn');
            const visualizer = document.getElementById('visualizer');
|
| 647 |
+
|
| 648 |
+
let pc, dc, micStream;
|
| 649 |
+
let isConnected = false;
|
| 650 |
+
let questionStartTime = null;
|
| 651 |
+
|
| 652 |
+
function updateStatus(status, title, text, className) {{
|
| 653 |
+
statusDisplay.className = `status-display ${{className}}`;
|
| 654 |
+
statusTitle.textContent = title;
|
| 655 |
+
statusText.textContent = text;
|
| 656 |
+
}}
|
| 657 |
+
|
| 658 |
+
function log(msg, type = 'info') {{
|
| 659 |
+
const timestamp = new Date().toLocaleTimeString();
|
| 660 |
+
const icon = type === 'error' ? '❌' : type === 'success' ? '✅' : type === 'warning' ? '⚠️' : 'ℹ️';
|
| 661 |
+
logEl.innerHTML += `<div>${{timestamp}} ${{icon}} ${{msg}}</div>`;
|
| 662 |
+
logEl.scrollTop = logEl.scrollHeight;
|
| 663 |
+
}}
|
| 664 |
|
| 665 |
+
async function start() {{
|
| 666 |
+
startBtn.disabled = true;
|
| 667 |
+
stopBtn.disabled = false;
|
| 668 |
+
updateStatus('connecting', '🔄 Connecting...', 'Establishing secure connection to voice services', 'status-connecting');
|
| 669 |
+
|
| 670 |
+
try {{
|
| 671 |
+
log('Initializing voice session...', 'info');
|
| 672 |
+
|
| 673 |
+
// 1) Fetch ephemeral session token
|
| 674 |
+
const sessResp = await fetch(serverBase + "/session", {{
|
| 675 |
+
method: "POST",
|
| 676 |
+
headers: {{ "Content-Type": "application/json" }},
|
| 677 |
+
body: JSON.stringify({{ voice: voiceChoice }})
|
| 678 |
+
}});
|
| 679 |
+
|
| 680 |
+
if (!sessResp.ok) {{
|
| 681 |
+
throw new Error(`Server error: ${{sessResp.status}} ${{sessResp.statusText}}`);
|
| 682 |
+
}}
|
| 683 |
+
|
| 684 |
+
const sess = await sessResp.json();
|
| 685 |
+
if (sess.error) throw new Error(sess.error);
|
| 686 |
+
|
| 687 |
+
const EPHEMERAL_KEY = sess.client_secret;
|
| 688 |
+
if (!EPHEMERAL_KEY) throw new Error("No ephemeral token from server");
|
| 689 |
+
|
| 690 |
+
log('✅ Session token obtained', 'success');
|
| 691 |
+
|
| 692 |
+
// 2) Setup WebRTC
|
| 693 |
+
pc = new RTCPeerConnection();
|
| 694 |
+
const remoteAudio = document.getElementById('remoteAudio');
|
| 695 |
+
pc.ontrack = (event) => {{
|
| 696 |
+
log('🔊 Audio track received from OpenAI', 'success');
|
| 697 |
+
console.log('WebRTC track event:', event);
|
| 698 |
+
const stream = event.streams[0];
|
| 699 |
+
if (stream && stream.getAudioTracks().length > 0) {{
|
| 700 |
+
remoteAudio.srcObject = stream;
|
| 701 |
+
visualizer.textContent = '🔊 Audio stream connected - AI can speak';
|
| 702 |
+
log(`🎵 Audio tracks: ${{stream.getAudioTracks().length}}`, 'success');
|
| 703 |
+
}} else {{
|
| 704 |
+
log('⚠️ No audio tracks in stream', 'warning');
|
| 705 |
+
visualizer.textContent = '⚠️ No audio stream received';
|
| 706 |
+
}}
|
| 707 |
+
}};
|
| 708 |
|
| 709 |
+
// 3) Create data channel
|
| 710 |
+
dc = pc.createDataChannel("oai-data");
|
| 711 |
+
dc.onopen = () => {{
|
| 712 |
+
log('🔗 Data channel established', 'success');
|
| 713 |
+
}};
|
| 714 |
+
dc.onerror = (error) => {{
|
| 715 |
+
log('❌ Data channel error: ' + error, 'error');
|
| 716 |
+
}};
|
| 717 |
+
dc.onmessage = (e) => handleDataMessage(e);
|
|
|
|
| 718 |
|
| 719 |
+
// 4) Get microphone
|
| 720 |
+
log('🎤 Requesting microphone access...', 'info');
|
| 721 |
+
micStream = await navigator.mediaDevices.getUserMedia({{ audio: true }});
|
| 722 |
+
log('✅ Microphone access granted', 'success');
|
| 723 |
+
visualizer.textContent = '🎤 Microphone active - speak naturally';
|
| 724 |
+
|
| 725 |
+
for (const track of micStream.getTracks()) {{
|
| 726 |
+
pc.addTrack(track, micStream);
|
| 727 |
+
}}
|
| 728 |
|
| 729 |
+
// 5) Setup audio receiving
|
| 730 |
+
pc.addTransceiver("audio", {{ direction: "recvonly" }});
|
| 731 |
+
log('🔊 Audio receiver configured', 'success');
|
| 732 |
+
|
| 733 |
+
// 6) Create and set local description
|
| 734 |
+
const offer = await pc.createOffer();
|
| 735 |
+
await pc.setLocalDescription(offer);
|
| 736 |
+
log('📡 WebRTC offer created', 'success');
|
| 737 |
+
|
| 738 |
+
// 7) Exchange SDP with OpenAI Realtime
|
| 739 |
+
const baseUrl = "https://api.openai.com/v1/realtime";
|
| 740 |
+
const model = sess.model || "gpt-4o-realtime-preview";
|
| 741 |
+
const sdpResp = await fetch(`${{baseUrl}}?model=${{encodeURIComponent(model)}}`, {{
|
| 742 |
+
method: "POST",
|
| 743 |
+
body: offer.sdp,
|
| 744 |
+
headers: {{
|
| 745 |
+
Authorization: `Bearer ${{EPHEMERAL_KEY}}`,
|
| 746 |
+
"Content-Type": "application/sdp"
|
| 747 |
+
}}
|
| 748 |
+
}});
|
| 749 |
+
|
| 750 |
+
if (!sdpResp.ok) throw new Error(`WebRTC setup failed: ${{sdpResp.status}}`);
|
| 751 |
+
|
| 752 |
+
const answer = {{ type: "answer", sdp: await sdpResp.text() }};
|
| 753 |
+
await pc.setRemoteDescription(answer);
|
| 754 |
+
|
| 755 |
+
// 8) Configure the session with tools and faster response
|
| 756 |
+
// Wait a bit to ensure data channel is fully ready
|
| 757 |
+
setTimeout(() => {{
|
| 758 |
+
if (dc.readyState === 'open') {{
|
| 759 |
+
const toolDecl = {{
|
| 760 |
+
type: "session.update",
|
| 761 |
+
session: {{
|
| 762 |
+
tools: [{{
|
| 763 |
+
"type": "function",
|
| 764 |
+
"name": "ask_rag",
|
| 765 |
+
"description": "Search the safety knowledge base for accurate, authoritative information. Call this immediately when users ask safety questions to get current, reliable information with proper citations.",
|
| 766 |
+
"parameters": {{
|
| 767 |
+
"type": "object",
|
| 768 |
+
"properties": {{
|
| 769 |
+
"query": {{ "type": "string", "description": "User's safety question" }},
|
| 770 |
+
"top_k": {{ "type": "integer", "minimum": 1, "maximum": 20, "default": 5 }}
|
| 771 |
+
}},
|
| 772 |
+
"required": ["query"]
|
| 773 |
+
}}
|
| 774 |
+
}}],
|
| 775 |
+
turn_detection: {{
|
| 776 |
+
type: "server_vad",
|
| 777 |
+
threshold: 0.5,
|
| 778 |
+
prefix_padding_ms: 300,
|
| 779 |
+
silence_duration_ms: {response_speed * 1000}
|
| 780 |
+
}},
|
| 781 |
+
input_audio_transcription: {{
|
| 782 |
+
model: "whisper-1"
|
| 783 |
+
}},
|
| 784 |
+
voice: voiceChoice,
|
| 785 |
+
temperature: 0.7, // Higher temperature for more natural speech
|
| 786 |
+
max_response_output_tokens: 1000, // Allow full responses
|
| 787 |
+
modalities: ["audio", "text"],
|
| 788 |
+
response_format: "audio"
|
| 789 |
+
}}
|
| 790 |
+
}};
|
| 791 |
+
dc.send(JSON.stringify(toolDecl));
|
| 792 |
+
log('🛠️ RAG tools configured', 'success');
|
| 793 |
+
|
| 794 |
+
// Send initial conversation starter to prime the model for natural interaction
|
| 795 |
+
const initialMessage = {{
|
| 796 |
+
type: "conversation.item.create",
|
| 797 |
+
item: {{
|
| 798 |
+
type: "message",
|
| 799 |
+
role: "user",
|
| 800 |
+
content: [{{
|
| 801 |
+
type: "input_text",
|
| 802 |
+
text: "Hello! I'm ready to ask you questions about machine safety. Please speak naturally like a safety expert - no need to mention specific documents or sources, just give me the information as your expertise."
|
| 803 |
+
}}]
|
| 804 |
+
}}
|
| 805 |
+
}};
|
| 806 |
+
dc.send(JSON.stringify(initialMessage));
|
| 807 |
+
|
| 808 |
+
const responseRequest = {{
|
| 809 |
+
type: "response.create",
|
| 810 |
+
response: {{
|
| 811 |
+
modalities: ["audio"],
|
| 812 |
+
instructions: "Acknowledge briefly that you're ready to help with safety questions. Speak naturally and confidently as a safety expert - no citations or document references needed."
|
| 813 |
+
}}
|
| 814 |
+
}};
|
| 815 |
+
dc.send(JSON.stringify(responseRequest));
|
| 816 |
+
|
| 817 |
+
}} else {{
|
| 818 |
+
log('⚠️ Data channel not ready, retrying...', 'warning');
|
| 819 |
+
// Retry after another second
|
| 820 |
+
setTimeout(() => {{
|
| 821 |
+
if (dc.readyState === 'open') {{
|
| 822 |
+
dc.send(JSON.stringify(toolDecl));
|
| 823 |
+
log('🛠️ RAG tools configured (retry)', 'success');
|
| 824 |
+
}}
|
| 825 |
+
}}, 1000);
|
| 826 |
+
}}
|
| 827 |
+
}}, 500);
|
| 828 |
+
|
| 829 |
+
isConnected = true;
|
| 830 |
+
updateStatus('active', '🎤 Live - Speak Now!', `Using ${{chosenMethod.toUpperCase()}} method • Voice: ${{voiceChoice}} • Response: ${{responseSpeed}}s`, 'status-active');
|
| 831 |
+
startBtn.classList.add('pulse');
|
| 832 |
+
|
| 833 |
+
}} catch (error) {{
|
| 834 |
+
log(`❌ Connection failed: ${{error.message}}`, 'error');
|
| 835 |
+
updateStatus('error', '❌ Connection Failed', error.message, 'status-error');
|
| 836 |
+
startBtn.disabled = false;
|
| 837 |
+
stopBtn.disabled = true;
|
| 838 |
+
cleanup();
|
| 839 |
+
}}
|
| 840 |
+
}}
|
| 841 |
+
|
| 842 |
+
function cleanup() {{
|
| 843 |
+
try {{
|
| 844 |
+
if (dc && dc.readyState === 'open') dc.close();
|
| 845 |
+
if (pc) pc.close();
|
| 846 |
+
if (micStream) micStream.getTracks().forEach(t => t.stop());
|
| 847 |
+
}} catch (e) {{ /* ignore cleanup errors */ }}
|
| 848 |
+
startBtn.classList.remove('pulse');
|
| 849 |
+
visualizer.textContent = '🔇 Audio inactive';
|
| 850 |
+
}}
|
| 851 |
+
|
| 852 |
+
async function stop() {{
|
| 853 |
+
startBtn.disabled = false;
|
| 854 |
+
stopBtn.disabled = true;
|
| 855 |
+
isConnected = false;
|
| 856 |
+
updateStatus('idle', '⚪ Session Ended', 'Click "Start Listening" to begin a new voice session', 'status-idle');
|
| 857 |
+
log('🛑 Voice session terminated', 'info');
|
| 858 |
+
cleanup();
|
| 859 |
+
}}
|
| 860 |
+
|
| 861 |
+
// Handle realtime events
|
| 862 |
+
async function handleDataMessage(e) {{
|
| 863 |
+
if (!isConnected) return;
|
| 864 |
+
|
| 865 |
+
try {{
|
| 866 |
+
const msg = JSON.parse(e.data);
|
| 867 |
+
|
| 868 |
+
if (msg.type === "response.function_call") {{
|
| 869 |
+
const {{ name, call_id, arguments: args }} = msg;
|
| 870 |
+
if (name === "ask_rag") {{
|
| 871 |
+
visualizer.textContent = '✅ Question received - searching...';
|
| 872 |
+
const query = JSON.parse(args || "{{}}").query;
|
| 873 |
+
log(`✅ AI heard: "${{query}}"`, 'success');
|
| 874 |
+
log('🔍 Searching knowledge base...', 'info');
|
| 875 |
+
|
| 876 |
+
// Store the transcribed query for analytics
|
| 877 |
+
window.lastVoiceQuery = query;
|
| 878 |
+
|
| 879 |
+
const payload = JSON.parse(args || "{{}}");
|
| 880 |
+
const ragResp = await fetch(serverBase + "/rag", {{
|
| 881 |
+
method: "POST",
|
| 882 |
+
headers: {{ "Content-Type": "application/json" }},
|
| 883 |
+
body: JSON.stringify({{
|
| 884 |
+
query: payload.query,
|
| 885 |
+
top_k: payload.top_k ?? 5,
|
| 886 |
+
method: chosenMethod
|
| 887 |
+
}})
|
| 888 |
+
}});
|
| 889 |
+
|
| 890 |
+
const rag = await ragResp.json();
|
| 891 |
+
|
| 892 |
+
// Send result back to model (check if data channel is still open)
|
| 893 |
+
if (dc && dc.readyState === 'open') {{
|
| 894 |
+
dc.send(JSON.stringify({{
|
| 895 |
+
type: "response.function_call_result",
|
| 896 |
+
call_id,
|
| 897 |
+
output: JSON.stringify({{
|
| 898 |
+
answer: rag.answer,
|
| 899 |
+
instruction: "Speak this information naturally as your expertise. Do not mention sources or documents."
|
| 900 |
+
}})
|
| 901 |
+
}}));
|
| 902 |
+
}} else {{
|
| 903 |
+
log('⚠️ Data channel closed, cannot send result', 'warning');
|
| 904 |
+
}}
|
| 905 |
+
|
| 906 |
+
const searchTime = ((Date.now() - questionStartTime) / 1000).toFixed(1);
|
| 907 |
+
log(`✅ Found ${{rag.citations?.length || 0}} citations in ${{searchTime}}s`, 'success');
|
| 908 |
+
visualizer.textContent = '🎙️ AI is speaking your answer...';
|
| 909 |
+
}}
|
| 910 |
+
}}
|
| 911 |
+
|
| 912 |
+
if (msg.type === "input_audio_buffer.speech_started") {{
|
| 913 |
+
questionStartTime = Date.now();
|
| 914 |
+
visualizer.textContent = '🎙️ Listening to you...';
|
| 915 |
+
log('🎤 Speech detected', 'info');
|
| 916 |
+
}}
|
| 917 |
+
|
| 918 |
+
if (msg.type === "input_audio_buffer.speech_stopped") {{
|
| 919 |
+
visualizer.textContent = '🤔 Processing your question...';
|
| 920 |
+
log('⏸️ Processing speech...', 'info');
|
| 921 |
+
}}
|
| 922 |
+
|
| 923 |
+
if (msg.type === "response.audio.delta") {{
|
| 924 |
+
visualizer.textContent = '🔊 AI speaking...';
|
| 925 |
+
}}
|
| 926 |
+
|
| 927 |
+
if (msg.type === "response.done") {{
|
| 928 |
+
if (questionStartTime) {{
|
| 929 |
+
const totalTime = ((Date.now() - questionStartTime) / 1000).toFixed(1);
|
| 930 |
+
visualizer.textContent = '🎤 Your turn - speak now';
|
| 931 |
+
log(`✅ Response complete in ${{totalTime}}s`, 'success');
|
| 932 |
+
questionStartTime = null;
|
| 933 |
+
}} else {{
|
| 934 |
+
visualizer.textContent = '🎤 Your turn - speak now';
|
| 935 |
+
log('✅ Response complete', 'success');
|
| 936 |
+
}}
|
| 937 |
+
}}
|
| 938 |
+
|
| 939 |
+
}} catch (err) {{
|
| 940 |
+
// Ignore non-JSON messages
|
| 941 |
+
}}
|
| 942 |
+
}}
|
| 943 |
+
|
| 944 |
+
startBtn.onclick = start;
|
| 945 |
+
stopBtn.onclick = stop;
|
| 946 |
+
|
| 947 |
+
// Initialize
|
| 948 |
+
log('🚀 Voice chat interface loaded', 'success');
|
| 949 |
+
}})();
|
| 950 |
+
</script>
|
| 951 |
+
</body>
|
| 952 |
+
</html>
|
| 953 |
+
""", height=600, scrolling=True)
|
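For reference, below is a minimal Python sketch of the HTTP contract that the embedded JavaScript above assumes realtime_server.py exposes: POST /session returns an ephemeral client_secret (and optionally a model name), and POST /rag accepts query, top_k, and method and returns an answer plus citations. The endpoint paths, payload fields, and the localhost port are taken from the code above and are assumptions about the running server, not a documented API.

import requests

SERVER = "http://localhost:5050"  # default server_url used by the page above

# Mint an ephemeral Realtime session token (mirrors the JS /session call).
sess = requests.post(f"{SERVER}/session", json={"voice": "alloy"}).json()
print(sess.get("model"), bool(sess.get("client_secret")))

# Query the RAG backend directly (mirrors the ask_rag tool's /rag request).
rag = requests.post(
    f"{SERVER}/rag",
    json={"query": "How do I perform lockout/tagout?", "top_k": 5, "method": "graph"},
).json()
print(rag.get("answer", "")[:200])
print("citations:", len(rag.get("citations", [])))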
| 954 |
+
|
| 955 |
+
# Voice Chat History
|
| 956 |
+
if st.session_state.voice_chat_history:
|
| 957 |
+
st.markdown("### 🗣️ Recent Voice Conversations")
|
| 958 |
+
for i, entry in enumerate(reversed(st.session_state.voice_chat_history[-5:])):
|
| 959 |
+
with st.expander(f"🎤 Conversation {len(st.session_state.voice_chat_history)-i} - {entry.get('method', 'unknown').upper()}"):
|
| 960 |
+
st.write(f"**Query**: {entry.get('query', 'N/A')}")
|
| 961 |
+
st.write(f"**Response**: {entry.get('answer', 'N/A')[:200]}...")
|
| 962 |
+
st.write(f"**Citations**: {len(entry.get('citations', []))}")
|
| 963 |
+
st.write(f"**Timestamp**: {entry.get('timestamp', 'N/A')}")
|
| 964 |
+
|
| 965 |
+
with tab4:
|
| 966 |
+
st.markdown("### 📊 Analytics Dashboard")
|
| 967 |
+
st.markdown("*Persistent analytics from all user interactions*")
|
| 968 |
+
|
| 969 |
+
# Time period selector
|
| 970 |
+
col1, col2 = st.columns([3, 1])
|
| 971 |
+
with col1:
|
| 972 |
+
st.markdown("")
|
| 973 |
+
with col2:
|
| 974 |
+
days_filter = st.selectbox("Time Period", [7, 30, 90, 365], index=1, format_func=lambda x: f"Last {x} days")
|
| 975 |
+
|
| 976 |
+
# Get analytics data
|
| 977 |
+
try:
|
| 978 |
+
stats = get_analytics_stats(days=days_filter)
|
| 979 |
+
performance = get_method_performance()
|
| 980 |
+
recent_queries = analytics_db.get_recent_queries(limit=10)
|
| 981 |
+
|
| 982 |
+
# Overview Metrics
|
| 983 |
+
st.markdown("#### 📈 Overview")
|
| 984 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 985 |
+
|
| 986 |
+
with col1:
|
| 987 |
+
st.metric(
|
| 988 |
+
"Total Queries",
|
| 989 |
+
stats.get('total_queries', 0),
|
| 990 |
+
help="All queries processed in the selected time period"
|
| 991 |
+
)
|
| 992 |
+
|
| 993 |
+
with col2:
|
| 994 |
+
avg_citations = stats.get('avg_citations', 0)
|
| 995 |
+
st.metric(
|
| 996 |
+
"Avg Citations",
|
| 997 |
+
f"{avg_citations:.1f}",
|
| 998 |
+
help="Average number of citations per query"
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
with col3:
|
| 1002 |
+
error_rate = stats.get('error_rate', 0)
|
| 1003 |
+
st.metric(
|
| 1004 |
+
"Success Rate",
|
| 1005 |
+
f"{100 - error_rate:.1f}%",
|
| 1006 |
+
delta=f"-{error_rate:.1f}% errors" if error_rate > 0 else None,
|
| 1007 |
+
help="Percentage of successful queries"
|
| 1008 |
+
)
|
| 1009 |
+
|
| 1010 |
+
with col4:
|
| 1011 |
+
total_citations = stats.get('total_citations', 0)
|
| 1012 |
+
st.metric(
|
| 1013 |
+
"Total Citations",
|
| 1014 |
+
total_citations,
|
| 1015 |
+
help="Total citations generated across all queries"
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
+
# Method Performance Comparison
|
| 1019 |
+
if performance:
|
| 1020 |
+
st.markdown("#### ⚡ Method Performance")
|
| 1021 |
+
|
| 1022 |
+
perf_data = []
|
| 1023 |
+
for method, metrics in performance.items():
|
| 1024 |
+
perf_data.append({
|
| 1025 |
+
'Method': method.upper(),
|
| 1026 |
+
'Avg Response Time (ms)': f"{metrics['avg_response_time']:.0f}",
|
| 1027 |
+
'Avg Citations': f"{metrics['avg_citations']:.1f}",
|
| 1028 |
+
'Avg Answer Length': f"{metrics['avg_answer_length']:.0f}",
|
| 1029 |
+
'Query Count': int(metrics['query_count'])
|
| 1030 |
+
})
|
| 1031 |
+
|
| 1032 |
+
if perf_data:
|
| 1033 |
+
st.dataframe(perf_data, use_container_width=True, hide_index=True)
|
| 1034 |
+
|
| 1035 |
+
# Method Usage with Voice Interaction Indicator
|
| 1036 |
+
method_usage = stats.get('method_usage', {})
|
| 1037 |
+
if method_usage:
|
| 1038 |
+
st.markdown("#### 🎯 Method Usage Distribution")
|
| 1039 |
+
col1, col2 = st.columns([2, 1])
|
| 1040 |
+
|
| 1041 |
+
with col1:
|
| 1042 |
+
st.bar_chart(method_usage)
|
| 1043 |
+
|
| 1044 |
+
with col2:
|
| 1045 |
+
st.markdown("**Most Popular Methods:**")
|
| 1046 |
+
sorted_methods = sorted(method_usage.items(), key=lambda x: x[1], reverse=True)
|
| 1047 |
+
for i, (method, count) in enumerate(sorted_methods[:3], 1):
|
| 1048 |
+
percentage = (count / sum(method_usage.values())) * 100
|
| 1049 |
+
st.markdown(f"{i}. **{method.upper()}** - {count} queries ({percentage:.1f}%)")
|
| 1050 |
+
|
| 1051 |
+
# Voice interaction stats
|
| 1052 |
+
try:
|
| 1053 |
+
voice_queries = analytics_db.get_voice_interaction_stats()
|
| 1054 |
+
if voice_queries and voice_queries.get('total_voice_queries', 0) > 0:
|
| 1055 |
+
st.markdown("---")
|
| 1056 |
+
st.markdown("**🎤 Voice Interactions:**")
|
| 1057 |
+
st.markdown(f"🔊 Voice queries: {voice_queries['total_voice_queries']}")
|
| 1058 |
+
if voice_queries.get('avg_voice_response_time', 0) > 0:
|
| 1059 |
+
st.markdown(f"⏱️ Avg response time: {voice_queries['avg_voice_response_time']:.1f}ms")
|
| 1060 |
+
if sum(method_usage.values()) > 0:
|
| 1061 |
+
voice_percentage = (voice_queries['total_voice_queries'] / sum(method_usage.values())) * 100
|
| 1062 |
+
st.markdown(f"📊 Voice usage: {voice_percentage:.1f}%")
|
| 1063 |
+
except Exception as e:
|
| 1064 |
+
logger.error(f"Voice stats error: {e}")
|
| 1065 |
+
pass
|
| 1066 |
+
|
| 1067 |
+
# Voice Analytics Section (if voice interactions exist)
|
| 1068 |
+
try:
|
| 1069 |
+
voice_queries = analytics_db.get_voice_interaction_stats()
|
| 1070 |
+
if voice_queries and voice_queries.get('total_voice_queries', 0) > 0:
|
| 1071 |
+
st.markdown("#### 🎤 Voice Interaction Analytics")
|
| 1072 |
+
col1, col2 = st.columns([2, 1])
|
| 1073 |
+
|
| 1074 |
+
with col1:
|
| 1075 |
+
voice_by_method = voice_queries.get('voice_by_method', {})
|
| 1076 |
+
if voice_by_method:
|
| 1077 |
+
st.bar_chart(voice_by_method)
|
| 1078 |
+
else:
|
| 1079 |
+
st.info("No voice method breakdown available yet")
|
| 1080 |
+
|
| 1081 |
+
with col2:
|
| 1082 |
+
st.markdown("**Voice Stats:**")
|
| 1083 |
+
total_voice = voice_queries['total_voice_queries']
|
| 1084 |
+
st.markdown(f"🔊 Total voice queries: {total_voice}")
|
| 1085 |
+
|
| 1086 |
+
avg_response = voice_queries.get('avg_voice_response_time', 0)
|
| 1087 |
+
if avg_response > 0:
|
| 1088 |
+
st.markdown(f"⏱️ Avg response: {avg_response:.1f}ms")
|
| 1089 |
+
|
| 1090 |
+
# Most used voice method
|
| 1091 |
+
if voice_by_method:
|
| 1092 |
+
most_used_voice = max(voice_by_method.items(), key=lambda x: x[1])
|
| 1093 |
+
st.markdown(f"🎯 Top voice method: {most_used_voice[0].upper()}")
|
| 1094 |
+
except Exception as e:
|
| 1095 |
+
logger.error(f"Voice analytics error: {e}")
|
| 1096 |
+
|
| 1097 |
+
# Citation Analysis
|
| 1098 |
+
citation_types = stats.get('citation_types', {})
|
| 1099 |
+
if citation_types:
|
| 1100 |
+
st.markdown("#### 📚 Citation Sources")
|
| 1101 |
+
col1, col2 = st.columns([2, 1])
|
| 1102 |
+
|
| 1103 |
+
with col1:
|
| 1104 |
+
# Filter out empty/null citation types
|
| 1105 |
+
filtered_citations = {k: v for k, v in citation_types.items() if k and k.strip()}
|
| 1106 |
+
if filtered_citations:
|
| 1107 |
+
st.bar_chart(filtered_citations)
|
| 1108 |
+
|
| 1109 |
+
with col2:
|
| 1110 |
+
st.markdown("**Source Breakdown:**")
|
| 1111 |
+
total_citations = sum(citation_types.values())
|
| 1112 |
+
for cite_type, count in sorted(citation_types.items(), key=lambda x: x[1], reverse=True):
|
| 1113 |
+
if cite_type and cite_type.strip():
|
| 1114 |
+
percentage = (count / total_citations) * 100
|
| 1115 |
+
icon = "📄" if cite_type == "pdf" else "🌐" if cite_type == "html" else "🖼️" if cite_type == "image" else "📚"
|
| 1116 |
+
st.markdown(f"{icon} **{cite_type.title()}**: {count} ({percentage:.1f}%)")
|
| 1117 |
+
|
| 1118 |
+
# Popular Keywords
|
| 1119 |
+
keywords = stats.get('top_keywords', {})
|
| 1120 |
+
if keywords:
|
| 1121 |
+
st.markdown("#### 🔍 Popular Query Topics")
|
| 1122 |
+
col1, col2, col3 = st.columns(3)
|
| 1123 |
+
|
| 1124 |
+
keyword_items = list(keywords.items())
|
| 1125 |
+
for i, (word, count) in enumerate(keyword_items[:9]): # Top 9 keywords
|
| 1126 |
+
col = [col1, col2, col3][i % 3]
|
| 1127 |
+
with col:
|
| 1128 |
+
st.metric(word.title(), count)
|
| 1129 |
+
|
| 1130 |
+
# Recent Queries with Responses
|
| 1131 |
+
if recent_queries:
|
| 1132 |
+
st.markdown("#### 🕒 Recent Queries & Responses")
|
| 1133 |
+
|
| 1134 |
+
for query in recent_queries[:5]: # Show last 5
|
| 1135 |
+
# Create expander title with query preview
|
| 1136 |
+
query_preview = query['query'][:60] + "..." if len(query['query']) > 60 else query['query']
|
| 1137 |
+
expander_title = f"🧑 **{query['method'].upper()}**: {query_preview}"
|
| 1138 |
+
|
| 1139 |
+
with st.expander(expander_title):
|
| 1140 |
+
# Query details
|
| 1141 |
+
st.markdown(f"**📝 Full Query:** {query['query']}")
|
| 1142 |
+
|
| 1143 |
+
# Metrics row
|
| 1144 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 1145 |
+
with col1:
|
| 1146 |
+
st.metric("Answer Length", f"{query['answer_length']} chars")
|
| 1147 |
+
with col2:
|
| 1148 |
+
st.metric("Citations", query['citations'])
|
| 1149 |
+
with col3:
|
| 1150 |
+
if query['response_time']:
|
| 1151 |
+
st.metric("Response Time", f"{query['response_time']:.0f}ms")
|
| 1152 |
+
else:
|
| 1153 |
+
st.metric("Response Time", "N/A")
|
| 1154 |
+
with col4:
|
| 1155 |
+
status = "❌ Error" if query.get('error_message') else "✅ Success"
|
| 1156 |
+
st.markdown(f"**Status:** {status}")
|
| 1157 |
+
|
| 1158 |
+
# Show error message if exists
|
| 1159 |
+
if query.get('error_message'):
|
| 1160 |
+
st.error(f"**Error:** {query['error_message']}")
|
| 1161 |
+
else:
|
| 1162 |
+
# Show answer in a styled container
|
| 1163 |
+
st.markdown("**🤖 Response:**")
|
| 1164 |
+
answer = query.get('answer', 'No answer available')
|
| 1165 |
+
|
| 1166 |
+
# Truncate very long answers for better UX
|
| 1167 |
+
if len(answer) > 1000:
|
| 1168 |
+
st.markdown(
|
| 1169 |
+
f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #28a745;">'
|
| 1170 |
+
f'{answer[:500].replace(chr(10), "<br>")}<br><br>'
|
| 1171 |
+
f'<i>... (truncated, showing first 500 chars of {len(answer)} total)</i>'
|
| 1172 |
+
f'</div>',
|
| 1173 |
+
unsafe_allow_html=True
|
| 1174 |
+
)
|
| 1175 |
+
|
| 1176 |
+
# Option to view full answer
|
| 1177 |
+
if st.button(f"📖 View Full Answer", key=f"full_answer_{query['query_id']}"):
|
| 1178 |
+
st.markdown("**Full Answer:**")
|
| 1179 |
+
st.markdown(
|
| 1180 |
+
f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; max-height: 400px; overflow-y: auto;">'
|
| 1181 |
+
f'{answer.replace(chr(10), "<br>")}'
|
| 1182 |
+
f'</div>',
|
| 1183 |
+
unsafe_allow_html=True
|
| 1184 |
+
)
|
| 1185 |
+
else:
|
| 1186 |
+
# Short answers display fully
|
| 1187 |
+
st.markdown(
|
| 1188 |
+
f'<div style="background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #28a745;">'
|
| 1189 |
+
f'{answer.replace(chr(10), "<br>")}'
|
| 1190 |
+
f'</div>',
|
| 1191 |
+
unsafe_allow_html=True
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
# Show detailed citation info
|
| 1195 |
+
if query['citations'] > 0:
|
| 1196 |
+
if st.button(f"📚 View Citations", key=f"citations_{query['query_id']}"):
|
| 1197 |
+
detailed_query = analytics_db.get_query_with_citations(query['query_id'])
|
| 1198 |
+
if detailed_query and 'citations' in detailed_query:
|
| 1199 |
+
st.markdown("**Citations:**")
|
| 1200 |
+
for i, citation in enumerate(detailed_query['citations'], 1):
|
| 1201 |
+
scores = []
|
| 1202 |
+
if citation.get('relevance_score'):
|
| 1203 |
+
scores.append(f"relevance: {citation['relevance_score']:.3f}")
|
| 1204 |
+
if citation.get('bm25_score'):
|
| 1205 |
+
scores.append(f"BM25: {citation['bm25_score']:.3f}")
|
| 1206 |
+
if citation.get('rerank_score'):
|
| 1207 |
+
scores.append(f"rerank: {citation['rerank_score']:.3f}")
|
| 1208 |
+
|
| 1209 |
+
score_text = f" ({', '.join(scores)})" if scores else ""
|
| 1210 |
+
st.markdown(f"{i}. **{citation['source']}** {score_text}")
|
| 1211 |
+
|
| 1212 |
+
st.markdown(f"**🕐 Timestamp:** {query['timestamp']}")
|
| 1213 |
+
st.markdown("---")
|
| 1214 |
+
|
| 1215 |
+
# Session Info
|
| 1216 |
+
st.markdown("---")
|
| 1217 |
+
col1, col2 = st.columns([3, 1])
|
| 1218 |
+
with col1:
|
| 1219 |
+
st.markdown("*Analytics are updated in real-time and persist across sessions*")
|
| 1220 |
+
with col2:
|
| 1221 |
+
st.markdown(f"**Session ID:** `{st.session_state.session_id}`")
|
| 1222 |
+
|
| 1223 |
+
except Exception as e:
|
| 1224 |
+
st.error(f"Error loading analytics: {e}")
|
| 1225 |
+
st.info("Analytics data will appear after your first query. The database is created automatically.")
|
| 1226 |
+
|
| 1227 |
+
# Fallback to session analytics
|
| 1228 |
+
if st.session_state.chat_history:
|
| 1229 |
+
st.markdown("#### 📊 Current Session")
|
| 1230 |
+
col1, col2 = st.columns(2)
|
| 1231 |
+
with col1:
|
| 1232 |
+
st.metric("Session Queries", len(st.session_state.chat_history))
|
| 1233 |
+
with col2:
|
| 1234 |
+
methods_used = [entry['method'] for entry in st.session_state.chat_history]
|
| 1235 |
+
most_used = max(set(methods_used), key=methods_used.count) if methods_used else "N/A"
|
| 1236 |
+
st.metric("Most Used Method", most_used.upper() if most_used != "N/A" else most_used)
|
| 1237 |
+
|
| 1238 |
+
# Full Screen Comparison Window (Modal-like)
|
| 1239 |
+
if st.session_state.get('show_comparison_window', False):
|
| 1240 |
+
st.markdown("---")
|
| 1241 |
+
|
| 1242 |
+
# Header with close button
|
| 1243 |
+
col1, col2 = st.columns([4, 1])
|
| 1244 |
+
with col1:
|
| 1245 |
+
comparison_data = st.session_state.comparison_results
|
| 1246 |
+
st.markdown(f"## 🪟 Full Screen Comparison")
|
| 1247 |
+
st.markdown(f"**Query:** {comparison_data['query']}")
|
| 1248 |
+
st.markdown(f"**Generated:** {comparison_data['timestamp']} | **Methods:** {', '.join([m.upper() for m in comparison_data['methods']])}")
|
| 1249 |
+
|
| 1250 |
+
with col2:
|
| 1251 |
+
if st.button("✖️ Close", help="Close full screen view"):
|
| 1252 |
+
st.session_state.show_comparison_window = False
|
| 1253 |
+
st.rerun()
|
| 1254 |
+
|
| 1255 |
+
st.markdown("---")
|
| 1256 |
+
|
| 1257 |
+
# Full-width comparison display
|
| 1258 |
+
results = comparison_data['results']
|
| 1259 |
+
methods = comparison_data['methods']
|
| 1260 |
+
|
| 1261 |
+
for method in methods:
|
| 1262 |
+
st.markdown(f"### 🔸 {method.upper()} Method")
|
| 1263 |
+
|
| 1264 |
+
# Answer
|
| 1265 |
+
answer = results[method]['answer']
|
| 1266 |
+
st.markdown("**Answer:**")
|
| 1267 |
+
|
| 1268 |
+
# Use a container with custom styling for better readability
|
| 1269 |
+
with st.container():
|
| 1270 |
+
st.markdown(
|
| 1271 |
+
f'<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin: 10px 0; border-left: 5px solid #1f77b4;">'
|
| 1272 |
+
f'{answer.replace(chr(10), "<br>")}'
|
| 1273 |
+
f'</div>',
|
| 1274 |
+
unsafe_allow_html=True
|
| 1275 |
+
)
|
| 1276 |
+
|
| 1277 |
+
# Citations
|
| 1278 |
+
st.markdown("**Citations:**")
|
| 1279 |
+
st.markdown(format_citations_html(results[method]['citations'], method), unsafe_allow_html=True)
|
| 1280 |
+
|
| 1281 |
+
# Statistics
|
| 1282 |
+
col1, col2, col3 = st.columns(3)
|
| 1283 |
+
with col1:
|
| 1284 |
+
st.metric("Answer Length", f"{len(answer)} chars")
|
| 1285 |
+
with col2:
|
| 1286 |
+
st.metric("Citations", len(results[method]['citations']))
|
| 1287 |
+
with col3:
|
| 1288 |
+
word_count = len(answer.split())
|
| 1289 |
+
st.metric("Word Count", word_count)
|
| 1290 |
+
|
| 1291 |
+
if method != methods[-1]: # Not the last method
|
| 1292 |
+
st.markdown("---")
|
| 1293 |
+
|
| 1294 |
+
# Summary comparison table
|
| 1295 |
+
st.markdown("### 📊 Method Comparison Summary")
|
| 1296 |
+
|
| 1297 |
+
summary_data = []
|
| 1298 |
+
for method in methods:
|
| 1299 |
+
summary_data.append({
|
| 1300 |
+
'Method': method.upper(),
|
| 1301 |
+
'Answer Length (chars)': len(results[method]['answer']),
|
| 1302 |
+
'Word Count': len(results[method]['answer'].split()),
|
| 1303 |
+
'Citations': len(results[method]['citations']),
|
| 1304 |
+
'Avg Citation Score': round(
|
| 1305 |
+
sum(float(c.get('relevance_score', 0) or c.get('score', 0) or 0)
|
| 1306 |
+
for c in results[method]['citations']) / len(results[method]['citations'])
|
| 1307 |
+
if results[method]['citations'] else 0, 3
|
| 1308 |
+
)
|
| 1309 |
+
})
|
| 1310 |
+
|
| 1311 |
+
st.dataframe(summary_data, use_container_width=True, hide_index=True)
|
| 1312 |
+
|
| 1313 |
+
st.markdown("---")
|
| 1314 |
+
|
| 1315 |
+
# Return to normal view button
|
| 1316 |
+
col1, col2, col3 = st.columns([2, 1, 2])
|
| 1317 |
+
with col2:
|
| 1318 |
+
if st.button("⬅️ Back to Comparison Tab", type="primary", use_container_width=True):
|
| 1319 |
+
st.session_state.show_comparison_window = False
|
| 1320 |
+
st.rerun()
|
| 1321 |
+
|
| 1322 |
+
st.stop() # Stop rendering the rest of the app when in full screen mode
|
| 1323 |
|
| 1324 |
# Footer
|
| 1325 |
st.markdown("---")
|
| 1326 |
st.markdown(
|
| 1327 |
+
"**⚠️ Disclaimer:** *This system uses AI to retrieve and generate responses. "
|
| 1328 |
+
"While we strive for accuracy, please verify critical safety information with official sources.*"
|
| 1329 |
+
)
|
| 1330 |
st.markdown(
|
| 1331 |
+
"**🙏 Acknowledgment:** *We thank [Ohio BWC/WSIC](https://info.bwc.ohio.gov/) "
|
| 1332 |
+
"for funding that made this multi-method RAG system possible.*"
|
| 1333 |
+
)
|
config.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
| 1 |
+
"""
|
| 2 |
+
Central configuration file for the Multi-Method RAG System.
|
| 3 |
+
All shared parameters and settings are defined here.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
# Load environment variables
|
| 11 |
+
load_dotenv(override=True)
|
| 12 |
+
|
| 13 |
+
# ==================== Versioning and Date ====================
|
| 14 |
+
DATE = "August 13, 2025"
|
| 15 |
+
VERSION = "2.0.1"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ==================== API Configuration ====================
|
| 19 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 20 |
+
OPENAI_CHAT_MODEL = "gpt-5-chat-latest" # This is the non-reasoning model for gpt-5 so it has no latency
|
| 21 |
+
OPENAI_EMBEDDING_MODEL = "text-embedding-3-large" # Options: text-embedding-3-large, text-embedding-3-small, text-embedding-ada-002
|
| 22 |
+
|
| 23 |
+
# ==================== Realtime API Configuration ====================
|
| 24 |
+
# OpenAI Realtime API settings for speech-to-speech functionality
|
| 25 |
+
OPENAI_REALTIME_MODEL = "gpt-4o-realtime-preview" # Realtime model for speech-to-speech
|
| 26 |
+
REALTIME_VOICE = "alloy" # Available voices: alloy, echo, fable, onyx, nova, shimmer
|
| 27 |
+
REALTIME_INSTRUCTIONS = (
|
| 28 |
+
"You are a knowledgeable safety expert speaking naturally in conversation. "
|
| 29 |
+
|
| 30 |
+
"VOICE BEHAVIOR: "
|
| 31 |
+
"- Speak like a confident safety professional talking to a colleague "
|
| 32 |
+
"- Acknowledge what you heard: 'You're asking about [topic]...' "
|
| 33 |
+
"- Use natural speech with appropriate pauses and emphasis "
|
| 34 |
+
"- Sound authoritative and knowledgeable - you ARE the expert "
|
| 35 |
+
"- Never mention document names, page numbers, or citation details when speaking "
|
| 36 |
+
"- Just state the facts naturally as if you know them from your expertise "
|
| 37 |
+
|
| 38 |
+
"RESPONSE PROCESS: "
|
| 39 |
+
"1. Briefly acknowledge the question: 'You're asking about [topic]...' "
|
| 40 |
+
"2. Call ask_rag to get the accurate information "
|
| 41 |
+
"3. Speak the information naturally as YOUR expertise, not as 'according to document X' "
|
| 42 |
+
"4. Organize complex topics: 'There are three key requirements here...' "
|
| 43 |
+
"5. Be thorough but conversational - like explaining to a colleague "
|
| 44 |
+
|
| 45 |
+
"CITATION RULE: "
|
| 46 |
+
"NEVER mention specific documents, sources, or page numbers in speech. "
|
| 47 |
+
"Just state the information confidently as if it's your professional knowledge. "
|
| 48 |
+
"For example, don't say 'According to OSHA 1910.147...' - just say 'The lockout tagout requirements are...' "
|
| 49 |
+
|
| 50 |
+
"IMPORTANT: Always use ask_rag for safety questions to get accurate information, "
|
| 51 |
+
"but speak the results as your own expertise, not as citations."
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# ==================== Model Parameters ====================
|
| 55 |
+
# Generation parameters
|
| 56 |
+
DEFAULT_TEMPERATURE = 0 # Range: 0.0-1.0 (0=deterministic, 1=creative)
|
| 57 |
+
DEFAULT_MAX_TOKENS = 5000 # Maximum tokens in response
|
| 58 |
+
DEFAULT_TOP_K = 5 # Number of chunks to retrieve by default
|
| 59 |
+
DEFAULT_TOP_P = 1.0 # Nucleus sampling parameter
|
| 60 |
+
|
| 61 |
+
# Context window management
|
| 62 |
+
MAX_CONTEXT_TOKENS = 7500 # Maximum context for models with 8k window
|
| 63 |
+
CHUNK_SIZE = 2000 # Tokens per chunk (used by TextPreprocessor.chunk_text_by_tokens)
|
| 64 |
+
CHUNK_OVERLAP = 200 # Token overlap between chunks
|
| 65 |
+
|
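As an illustration of how CHUNK_SIZE and CHUNK_OVERLAP interact, the sketch below applies a sliding window over a token sequence. The actual chunker is TextPreprocessor.chunk_text_by_tokens in utils.py (not shown here), so this is only a sketch of the arithmetic, not the project's implementation.

def chunk_token_ids(token_ids, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    # Each new chunk starts chunk_size - overlap tokens after the previous one,
    # so with 2000/200 the stride is 1800 tokens.
    stride = chunk_size - overlap
    return [token_ids[i:i + chunk_size] for i in range(0, len(token_ids), stride)]

# 5000 tokens -> chunks of 2000, 2000, and 1400 tokens.
print([len(c) for c in chunk_token_ids(list(range(5000)))])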
| 66 |
+
# ==================== Embedding Models ====================
|
| 67 |
+
# Sentence Transformers models
|
| 68 |
+
SENTENCE_TRANSFORMER_MODEL = 'all-MiniLM-L6-v2' # For DPR
|
| 69 |
+
CROSS_ENCODER_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2' # For re-ranking
|
| 70 |
+
|
| 71 |
+
# CLIP model
|
| 72 |
+
CLIP_MODEL = "ViT-L/14" # Options: ViT-B/32, ViT-L/14, RN50
|
| 73 |
+
|
| 74 |
+
# ==================== Search Parameters ====================
|
| 75 |
+
# BM25 parameters
|
| 76 |
+
BM25_K1 = 1.5 # Term frequency saturation parameter
|
| 77 |
+
BM25_B = 0.75 # Length normalization parameter
|
| 78 |
+
|
| 79 |
+
# Hybrid search
|
| 80 |
+
DEFAULT_HYBRID_ALPHA = 0.5 # Weight for BM25 (1-alpha for semantic)
|
| 81 |
+
|
| 82 |
+
# Re-ranking
|
| 83 |
+
RERANK_MULTIPLIER = 2 # Retrieve this many times top_k for re-ranking
|
| 84 |
+
MIN_RELEVANCE_SCORE = 0.3 # Minimum score threshold
|
| 85 |
+
|
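The search parameters above are consumed in query_bm25.py. As a sketch of what DEFAULT_HYBRID_ALPHA means, alpha weights the (normalized) BM25 score and 1 - alpha weights the semantic similarity; the function below is illustrative only and is not the repository's scoring code.

def hybrid_score(bm25_score, semantic_score, alpha=DEFAULT_HYBRID_ALPHA):
    # alpha = 1.0 is pure keyword matching, alpha = 0.0 is pure semantic search.
    return alpha * bm25_score + (1 - alpha) * semantic_score

print(hybrid_score(0.8, 0.4))        # 0.6 with the default 50/50 split
print(hybrid_score(0.8, 0.4, 0.75))  # 0.7 when keyword matching is favored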
| 86 |
+
# ==================== Directory Structure ====================
|
| 87 |
+
# Project directories
|
| 88 |
+
PROJECT_ROOT = Path(__file__).parent
|
| 89 |
+
DATA_DIR = PROJECT_ROOT / "data"
|
| 90 |
+
EMBEDDINGS_DIR = PROJECT_ROOT / "embeddings"
|
| 91 |
+
GRAPH_DIR = PROJECT_ROOT / "graph"
|
| 92 |
+
METADATA_DIR = PROJECT_ROOT / "metadata"
|
| 93 |
+
IMAGES_DIR = DATA_DIR / "images"
|
| 94 |
+
|
| 95 |
+
# File paths
|
| 96 |
+
VANILLA_FAISS_INDEX = EMBEDDINGS_DIR / "vanilla_faiss.index"
|
| 97 |
+
VANILLA_METADATA = EMBEDDINGS_DIR / "vanilla_metadata.pkl"
|
| 98 |
+
DPR_FAISS_INDEX = EMBEDDINGS_DIR / "dpr_faiss.index"
|
| 99 |
+
DPR_METADATA = EMBEDDINGS_DIR / "dpr_metadata.pkl"
|
| 100 |
+
BM25_INDEX = EMBEDDINGS_DIR / "bm25_index.pkl"
|
| 101 |
+
CONTEXT_DOCS = EMBEDDINGS_DIR / "context_stuffing_docs.pkl"
|
| 102 |
+
GRAPH_FILE = GRAPH_DIR / "graph.gml"
|
| 103 |
+
IMAGES_DB = METADATA_DIR / "images.db"
|
| 104 |
+
CHROMA_PATH = EMBEDDINGS_DIR / "chroma"
|
| 105 |
+
|
| 106 |
+
# ==================== Batch Processing ====================
|
| 107 |
+
EMBEDDING_BATCH_SIZE = 100 # Batch size for OpenAI embeddings
|
| 108 |
+
PROCESSING_BATCH_SIZE = 50 # Documents to process at once
|
| 109 |
+
|
| 110 |
+
# ==================== UI Configuration ====================
|
| 111 |
+
# Streamlit settings
|
| 112 |
+
MAX_CHAT_HISTORY = 5 # Maximum chat messages to keep
|
| 113 |
+
EXAMPLE_QUESTIONS = [
|
| 114 |
+
"What are general machine guarding requirements?",
|
| 115 |
+
"How do I perform lockout/tagout?",
|
| 116 |
+
"What safety measures are needed for robotic systems?",
|
| 117 |
+
"Explain the difference between guards and devices in machine safety.",
|
| 118 |
+
"What are the OSHA requirements for emergency stops?",
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
# Default method
|
| 122 |
+
DEFAULT_METHOD = "graph"
|
| 123 |
+
|
| 124 |
+
# Method descriptions for UI
|
| 125 |
+
METHOD_DESCRIPTIONS = {
|
| 126 |
+
'graph': "Graph-based RAG using NetworkX with relationship-aware retrieval",
|
| 127 |
+
'vanilla': "Standard vector search with FAISS and OpenAI embeddings",
|
| 128 |
+
'dpr': "Dense Passage Retrieval with bi-encoder and cross-encoder re-ranking",
|
| 129 |
+
'bm25': "BM25 keyword search with neural re-ranking for exact term matching",
|
| 130 |
+
'context': "Context stuffing with full document loading and heuristic selection",
|
| 131 |
+
'vision': "Vision-based search using GPT-5 Vision for image analysis and classification"
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
# ==================== Document Processing ====================
|
| 135 |
+
# Document types
|
| 136 |
+
SUPPORTED_EXTENSIONS = ['.pdf', '.txt', '.md', '.html']
|
| 137 |
+
IMAGE_EXTENSIONS = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']
|
| 138 |
+
|
| 139 |
+
# Text splitting
|
| 140 |
+
MARKDOWN_HEADER_LEVEL = 3 # Split by this header level (###)
|
| 141 |
+
MAX_SECTIONS_PER_DOC = 500 # Maximum sections to extract from a document
|
| 142 |
+
|
| 143 |
+
# ==================== Logging ====================
|
| 144 |
+
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") # DEBUG, INFO, WARNING, ERROR
|
| 145 |
+
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 146 |
+
|
| 147 |
+
# ==================== Performance ====================
|
| 148 |
+
# Device configuration
|
| 149 |
+
import torch
|
| 150 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 151 |
+
NUM_WORKERS = 4 # Parallel processing workers
|
| 152 |
+
|
| 153 |
+
# Cache settings
|
| 154 |
+
ENABLE_CACHE = True
|
| 155 |
+
CACHE_TTL = 3600 # Cache time-to-live in seconds
|
| 156 |
+
|
| 157 |
+
# ==================== Safety & Validation ====================
|
| 158 |
+
# Input validation
|
| 159 |
+
MAX_QUESTION_LENGTH = 1000 # Maximum characters in a question
|
| 160 |
+
MAX_IMAGE_SIZE_MB = 10 # Maximum image file size
|
| 161 |
+
|
| 162 |
+
# Rate limiting (if needed)
|
| 163 |
+
RATE_LIMIT_ENABLED = False
|
| 164 |
+
MAX_QUERIES_PER_MINUTE = 60
|
| 165 |
+
|
| 166 |
+
# ==================== Default HTML Sources ====================
|
| 167 |
+
DEFAULT_HTML_SOURCES = [
|
| 168 |
+
{
|
| 169 |
+
"title": "NIOSH Robotics in the Workplace – Safety Overview",
|
| 170 |
+
"url": "https://www.cdc.gov/niosh/robotics/about/",
|
| 171 |
+
"source": "NIOSH",
|
| 172 |
+
"year": 2024,
|
| 173 |
+
"category": "Technical Guide",
|
| 174 |
+
"format": "HTML"
|
| 175 |
+
}
|
| 176 |
+
]
|
| 177 |
+
|
| 178 |
+
# ==================== Helper Functions ====================
|
| 179 |
+
def ensure_directories():
|
| 180 |
+
"""Create all required directories if they don't exist."""
|
| 181 |
+
for directory in [DATA_DIR, EMBEDDINGS_DIR, GRAPH_DIR, METADATA_DIR, IMAGES_DIR]:
|
| 182 |
+
directory.mkdir(parents=True, exist_ok=True)
|
| 183 |
+
|
| 184 |
+
def get_model_context_length(model_name: str = OPENAI_CHAT_MODEL) -> int:
|
| 185 |
+
"""Get the context length for a given model."""
|
| 186 |
+
context_lengths = {
|
| 187 |
+
"gpt-5": 128000,
|
| 188 |
+
"gpt-4o-mini": 8192,
|
| 189 |
+
"gpt-4o": 128000,
|
| 190 |
+
}
|
| 191 |
+
return context_lengths.get(model_name, 4096)
|
| 192 |
+
|
| 193 |
+
def validate_api_key():
|
| 194 |
+
"""Check if OpenAI API key is set."""
|
| 195 |
+
if not OPENAI_API_KEY:
|
| 196 |
+
raise ValueError(
|
| 197 |
+
"OpenAI API key not found. Please set OPENAI_API_KEY in .env file."
|
| 198 |
+
)
|
| 199 |
+
return True
|
| 200 |
+
|
| 201 |
+
# ==================== System Info ====================
|
| 202 |
+
def print_config():
|
| 203 |
+
"""Print current configuration for debugging."""
|
| 204 |
+
print("="*50)
|
| 205 |
+
print("RAG System Configuration")
|
| 206 |
+
print("="*50)
|
| 207 |
+
print(f"OpenAI Model: {OPENAI_CHAT_MODEL}")
|
| 208 |
+
print(f"Embedding Model: {OPENAI_EMBEDDING_MODEL}")
|
| 209 |
+
print(f"Device: {DEVICE}")
|
| 210 |
+
print(f"Default Temperature: {DEFAULT_TEMPERATURE}")
|
| 211 |
+
print(f"Default Top-K: {DEFAULT_TOP_K}")
|
| 212 |
+
print(f"Chunk Size: {CHUNK_SIZE}")
|
| 213 |
+
print(f"Project Root: {PROJECT_ROOT}")
|
| 214 |
+
print("="*50)
|
| 215 |
+
|
| 216 |
+
# Ensure directories exist on import
|
| 217 |
+
ensure_directories()
|
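A short usage sketch of the helpers defined above, in the way the other modules use them (preprocess.py, for example, calls validate_api_key() and print_config() before running); the context-budget arithmetic is only an illustration of get_model_context_length and MAX_CONTEXT_TOKENS.

import config

config.validate_api_key()   # raises ValueError if OPENAI_API_KEY is not set
config.print_config()       # print the active models, device, and chunking settings

# Models not listed in get_model_context_length fall back to a 4096-token window.
window = config.get_model_context_length(config.OPENAI_CHAT_MODEL)
print("usable context:", min(window, config.MAX_CONTEXT_TOKENS), "tokens")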
preprocess.py
CHANGED
|
@@ -1,85 +1,195 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
from dotenv import load_dotenv
|
| 5 |
-
import requests
|
| 6 |
-
from bs4 import BeautifulSoup
|
| 7 |
-
import pandas as pd
|
| 8 |
-
import pymupdf4llm
|
| 9 |
-
import networkx as nx
|
| 10 |
-
from openai import OpenAI
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
parts = re.split(r'(?m)^### ', md_text)
|
| 19 |
-
return [('### ' + p) if not p.startswith('### ') else p for p in parts if p.strip()]
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
for pdf_path in glob.glob("scrapped_data/*.pdf"):
|
| 26 |
-
filename = os.path.basename(pdf_path)
|
| 27 |
-
title = os.path.splitext(filename)[0]
|
| 28 |
-
# Convert PDF to Markdown
|
| 29 |
-
md_text = pymupdf4llm.to_markdown(pdf_path)
|
| 30 |
-
# Split into sections
|
| 31 |
-
sections = split_by_header(md_text)
|
| 32 |
-
for idx, sec in enumerate(sections):
|
| 33 |
-
resp = client.embeddings.create(model="text-embedding-3-large", input=sec)
|
| 34 |
-
vector = resp.data[0].embedding
|
| 35 |
-
node_id = f"PDF::{title}::section{idx}"
|
| 36 |
-
# Store the local file path for citation
|
| 37 |
-
G.add_node(node_id,
|
| 38 |
-
text=sec,
|
| 39 |
-
embedding=vector,
|
| 40 |
-
source=title,
|
| 41 |
-
path=pdf_path)
|
| 42 |
|
| 43 |
-
#
|
| 44 |
-
|
| 45 |
-
{
|
| 46 |
-
"title": "NIOSH Robotics in the Workplace – Safety Overview (Human-Robot Collaboration)",
|
| 47 |
-
"url": "https://www.cdc.gov/niosh/robotics/about/",
|
| 48 |
-
"source": "NIOSH",
|
| 49 |
-
"year": 2024,
|
| 50 |
-
"category": "Technical Guide",
|
| 51 |
-
"summary": "A NIOSH overview of emerging safety challenges as robots increasingly collaborate with human workers. Updated in 2024, this page discusses how robots can improve safety by taking over dangerous tasks but also introduces new struck-by and caught-between hazards. It emphasizes the need for updated safety standards, risk assessments, and research on human-robot interaction, and it outlines NIOSH’s efforts (through its Center for Occupational Robotics Research) to develop best practices and guidance for safe integration of robotics in industry.",
|
| 52 |
-
"format": "HTML"
|
| 53 |
-
}
|
| 54 |
-
]
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
# Run HTML processing
|
| 80 |
-
for item in html_data:
|
| 81 |
-
process_html(item)
|
| 82 |
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
print("
|
|
| 1 |
+
"""
|
| 2 |
+
Refactored preprocessing pipeline for all RAG methods.
|
| 3 |
+
Uses utils.py functions and supports multiple retrieval methods.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
Directory Layout:
|
| 6 |
+
/data/ # Original PDFs, HTML
|
| 7 |
+
/embeddings/ # FAISS, Chroma, DPR vector stores
|
| 8 |
+
/graph/ # Graph database files
|
| 9 |
+
/metadata/ # Image metadata (SQLite or MongoDB)
|
| 10 |
+
"""
|
| 11 |
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
from config import *
|
| 16 |
+
from utils import (
|
| 17 |
+
DocumentLoader, TextPreprocessor, VectorStoreManager,
|
| 18 |
+
ImageProcessor, ImageData
|
| 19 |
+
)
|
| 20 |
|
| 21 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
# Ensure all directories exist
|
| 24 |
+
ensure_directories()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
def preprocess_for_method(method: str, documents: list):
|
| 27 |
+
"""Preprocess documents for a specific retrieval method."""
|
| 28 |
+
|
| 29 |
+
print(f"\n{'='*50}")
|
| 30 |
+
print(f"Preprocessing for method: {method}")
|
| 31 |
+
print(f"{'='*50}")
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
# Initialize processors
|
| 35 |
+
text_processor = TextPreprocessor()
|
| 36 |
+
vector_manager = VectorStoreManager()
|
| 37 |
+
|
| 38 |
+
# Preprocess text chunks for this method
|
| 39 |
+
chunks = text_processor.preprocess_for_method(documents, method)
|
| 40 |
+
|
| 41 |
+
if method == 'vanilla':
|
| 42 |
+
# Build FAISS index with OpenAI embeddings
|
| 43 |
+
index, metadata = vector_manager.build_faiss_index(chunks, method="vanilla")
|
| 44 |
+
vector_manager.save_index(index, metadata, method)
|
| 45 |
+
|
| 46 |
+
elif method == 'dpr':
|
| 47 |
+
# Build FAISS index with sentence transformer embeddings
|
| 48 |
+
index, metadata = vector_manager.build_faiss_index(chunks, method="dpr")
|
| 49 |
+
vector_manager.save_index(index, metadata, method)
|
| 50 |
+
|
| 51 |
+
elif method == 'bm25':
|
| 52 |
+
# Build BM25 index
|
| 53 |
+
bm25_index = vector_manager.build_bm25_index(chunks)
|
| 54 |
+
vector_manager.save_index(bm25_index, chunks, method)
|
| 55 |
+
|
| 56 |
+
elif method == 'graph':
|
| 57 |
+
# Build NetworkX graph
|
| 58 |
+
graph = vector_manager.build_graph_index(chunks)
|
| 59 |
+
vector_manager.save_index(graph, None, method)
|
| 60 |
+
|
| 61 |
+
elif method == 'context_stuffing':
|
| 62 |
+
# Save full documents for context stuffing
|
| 63 |
+
vector_manager.save_index(None, chunks, method)
|
| 64 |
+
|
| 65 |
+
else:
|
| 66 |
+
raise ValueError(f"Unknown method: {method}")
|
| 67 |
+
|
| 68 |
+
print(f"Successfully preprocessed for method '{method}'")
|
| 69 |
+
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"Error preprocessing for {method}: {e}")
|
| 72 |
+
raise
|
| 73 |
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
def extract_and_process_images(documents: list):
|
| 76 |
+
"""Extract images from documents and process them."""
|
| 77 |
+
print("\n" + "="*50)
|
| 78 |
+
print("Extracting and processing images...")
|
| 79 |
+
print("="*50)
|
| 80 |
+
|
| 81 |
+
image_processor = ImageProcessor()
|
| 82 |
+
processed_count = 0
|
| 83 |
+
filtered_count = 0
|
| 84 |
+
filter_reasons = {}
|
| 85 |
+
|
| 86 |
+
for doc in documents:
|
| 87 |
+
if 'images' in doc and doc['images']:
|
| 88 |
+
for image_info in doc['images']:
|
| 89 |
+
try:
|
| 90 |
+
# Check if image should be filtered out
|
| 91 |
+
should_filter, reason = image_processor.should_filter_image(image_info['image_path'])
|
| 92 |
+
|
| 93 |
+
if should_filter:
|
| 94 |
+
filtered_count += 1
|
| 95 |
+
filter_reasons[reason] = filter_reasons.get(reason, 0) + 1
|
| 96 |
+
print(f" Filtered: {image_info['image_id']} - {reason}")
|
| 97 |
+
|
| 98 |
+
# Optionally delete the filtered image file
|
| 99 |
+
try:
|
| 100 |
+
import os
|
| 101 |
+
os.remove(image_info['image_path'])
|
| 102 |
+
print(f" Deleted: {image_info['image_path']}")
|
| 103 |
+
except Exception as e:
|
| 104 |
+
logger.warning(f"Could not delete filtered image {image_info['image_path']}: {e}")
|
| 105 |
+
|
| 106 |
+
continue
|
| 107 |
+
|
| 108 |
+
# Classify image
|
| 109 |
+
classification = image_processor.classify_image(image_info['image_path'])
|
| 110 |
+
|
| 111 |
+
# Generate embedding (placeholder for now)
|
| 112 |
+
# embedding = embed_image_clip([image_info['image_path']])[0]
|
| 113 |
+
|
| 114 |
+
# Create ImageData object
|
| 115 |
+
image_data = ImageData(
|
| 116 |
+
image_path=image_info['image_path'],
|
| 117 |
+
image_id=image_info['image_id'],
|
| 118 |
+
classification=classification,
|
| 119 |
+
metadata={
|
| 120 |
+
'source': doc['source'],
|
| 121 |
+
'page': image_info.get('page'),
|
| 122 |
+
'extracted_from': doc['path']
|
| 123 |
+
}
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Store in database
|
| 127 |
+
image_processor.store_image_metadata(image_data)
|
| 128 |
+
processed_count += 1
|
| 129 |
+
|
| 130 |
+
except Exception as e:
|
| 131 |
+
logger.error(f"Error processing image {image_info['image_id']}: {e}")
|
| 132 |
+
continue
|
| 133 |
+
|
| 134 |
+
# Print filtering summary
|
| 135 |
+
if filtered_count > 0:
|
| 136 |
+
print(f"\nImage Filtering Summary:")
|
| 137 |
+
print(f" Total filtered: {filtered_count}")
|
| 138 |
+
for reason, count in filter_reasons.items():
|
| 139 |
+
print(f" {reason}: {count}")
|
| 140 |
+
print()
|
| 141 |
+
|
| 142 |
+
if processed_count > 0:
|
| 143 |
+
print(f"Processed and stored metadata for {processed_count} images")
|
| 144 |
+
else:
|
| 145 |
+
print("No images found in documents")
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def main():
|
| 149 |
+
"""Main preprocessing pipeline."""
|
| 150 |
+
# Validate configuration
|
| 151 |
+
try:
|
| 152 |
+
validate_api_key()
|
| 153 |
+
except ValueError as e:
|
| 154 |
+
print(f"Error: {e}")
|
| 155 |
+
return
|
| 156 |
+
|
| 157 |
+
# Print configuration
|
| 158 |
+
print_config()
|
| 159 |
+
|
| 160 |
+
print("\nStarting preprocessing pipeline...")
|
| 161 |
+
|
| 162 |
+
# Load documents using utils
|
| 163 |
+
print("\nLoading documents...")
|
| 164 |
+
loader = DocumentLoader()
|
| 165 |
+
documents = loader.load_text_documents()
|
| 166 |
+
|
| 167 |
+
print(f"Loaded {len(documents)} documents")
|
| 168 |
+
|
| 169 |
+
# Define methods to preprocess
|
| 170 |
+
methods = ['vanilla', 'dpr', 'bm25', 'graph', 'context_stuffing']
|
| 171 |
+
|
| 172 |
+
# Preprocess for each method
|
| 173 |
+
for method in methods:
|
| 174 |
+
try:
|
| 175 |
+
preprocess_for_method(method, documents)
|
| 176 |
+
except Exception as e:
|
| 177 |
+
print(f"Error preprocessing for {method}: {e}")
|
| 178 |
+
import traceback
|
| 179 |
+
traceback.print_exc()
|
| 180 |
+
|
| 181 |
+
# Extract and process images
|
| 182 |
+
try:
|
| 183 |
+
extract_and_process_images(documents)
|
| 184 |
+
except Exception as e:
|
| 185 |
+
print(f"Error processing images: {e}")
|
| 186 |
+
import traceback
|
| 187 |
+
traceback.print_exc()
|
| 188 |
+
|
| 189 |
+
print("\n" + "="*50)
|
| 190 |
+
print("Preprocessing complete!")
|
| 191 |
+
print("="*50)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
if __name__ == "__main__":
|
| 195 |
+
main()
|
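main() rebuilds every index in one pass; below is a minimal sketch of rebuilding a single method instead (here the BM25 index) using the functions defined above.

from preprocess import preprocess_for_method
from utils import DocumentLoader

docs = DocumentLoader().load_text_documents()   # load PDFs/HTML from the data directory
preprocess_for_method("bm25", docs)             # writes the BM25 pickle under embeddings/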
query_bm25.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
"""
BM25 keyword search with cross-encoder re-ranking and hybrid search support.
"""

import numpy as np
import faiss
from typing import Tuple, List, Optional
from openai import OpenAI
from sentence_transformers import CrossEncoder

import config
import utils

# Initialize models
client = OpenAI(api_key=config.OPENAI_API_KEY)
cross_encoder = CrossEncoder(config.CROSS_ENCODER_MODEL)

# Global variables for lazy loading
_bm25_index = None
_texts = None
_metadata = None
_semantic_index = None

def _load_bm25_index():
    """Lazy load BM25 index and metadata."""
    global _bm25_index, _texts, _metadata, _semantic_index

    if _bm25_index is None:
        # Initialize defaults
        _texts = []
        _metadata = []
        _semantic_index = None
        try:
            import pickle

            if config.BM25_INDEX.exists():
                with open(config.BM25_INDEX, 'rb') as f:
                    bm25_data = pickle.load(f)

                if isinstance(bm25_data, dict):
                    _bm25_index = bm25_data.get('index') or bm25_data.get('bm25')
                    chunks = bm25_data.get('texts', [])
                    if chunks:
                        _texts = [chunk.text for chunk in chunks if hasattr(chunk, 'text')]
                        _metadata = [chunk.metadata for chunk in chunks if hasattr(chunk, 'metadata')]
                    else:
                        _texts = []
                        _metadata = []

                    # Load semantic embeddings if available for hybrid search
                    if 'embeddings' in bm25_data:
                        semantic_embeddings = bm25_data['embeddings']
                        # Build FAISS index
                        import faiss
                        dimension = semantic_embeddings.shape[1]
                        _semantic_index = faiss.IndexFlatIP(dimension)
                        faiss.normalize_L2(semantic_embeddings)
                        _semantic_index.add(semantic_embeddings)
                else:
                    _bm25_index = bm25_data
                    _texts = []
                    _metadata = []

                print(f"Loaded BM25 index with {len(_texts)} documents")
            else:
                print("BM25 index not found. Run preprocess.py first.")

        except Exception as e:
            print(f"Error loading BM25 index: {e}")
            _bm25_index = None
            _texts = []
            _metadata = []


def query(question: str, image_path: Optional[str] = None, top_k: int = None) -> Tuple[str, List[dict]]:
    """
    Query using BM25 keyword search with re-ranking.

    Args:
        question: User's question
        image_path: Optional path to an image
        top_k: Number of relevant chunks to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    if top_k is None:
        top_k = config.DEFAULT_TOP_K

    # Load index if not already loaded
    _load_bm25_index()

    if _bm25_index is None or len(_texts) == 0:
        return "BM25 index not loaded. Please run preprocess.py first.", []

    # Tokenize query for BM25
    tokenized_query = question.lower().split()

    # Get BM25 scores
    bm25_scores = _bm25_index.get_scores(tokenized_query)

    # Get top candidates (retrieve more for re-ranking)
    top_indices = np.argsort(bm25_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]

    # Prepare candidates for re-ranking
    candidates = []
    for idx in top_indices:
        if idx < len(_texts) and bm25_scores[idx] > 0:
            candidates.append({
                'text': _texts[idx],
                'bm25_score': bm25_scores[idx],
                'metadata': _metadata[idx],
                'idx': idx
            })

    # Re-rank with cross-encoder
    if candidates:
        pairs = [[question, cand['text']] for cand in candidates]
        cross_scores = cross_encoder.predict(pairs)

        # Add cross-encoder scores and sort
        for i, score in enumerate(cross_scores):
            candidates[i]['cross_score'] = score

        candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]

    # Collect citations
    citations = []
    sources_seen = set()

    for chunk in candidates:
        chunk_meta = chunk['metadata']

        if chunk_meta['source'] not in sources_seen:
            citation = {
                'source': chunk_meta['source'],
                'type': chunk_meta['type'],
                'bm25_score': round(chunk['bm25_score'], 3),
                'rerank_score': round(chunk['cross_score'], 3)
            }

            if chunk_meta['type'] == 'pdf':
                citation['path'] = chunk_meta['path']
            else:
                citation['url'] = chunk_meta.get('url', '')

            citations.append(citation)
            sources_seen.add(chunk_meta['source'])

    # Handle image if provided
    image_context = ""
    if image_path:
        try:
            classification = utils.classify_image(image_path)
            # classification is a string, not a dict
            image_context = f"\n\n[Image Analysis: The image appears to show a {classification}.]"
        except Exception as e:
            print(f"Error processing image: {e}")

    # Build context from retrieved chunks
    context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])

    if not context:
        return "No relevant documents found for your query.", []

    # Generate answer
    prompt = f"""Answer the following question using the retrieved documents:

Retrieved Documents:
{context}{image_context}

Question: {question}

Instructions:
1. Provide a comprehensive answer based on the retrieved documents
2. Mention specific details from the sources
3. If the documents don't fully answer the question, indicate what information is missing"""

    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=config.OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a technical expert on manufacturing safety and regulations. Provide accurate, detailed answers based on the retrieved documents."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=config.DEFAULT_MAX_TOKENS
    )

    answer = response.choices[0].message.content

    return answer, citations


def query_hybrid(question: str, top_k: int = None, alpha: float = None) -> Tuple[str, List[dict]]:
    """
    Hybrid search combining BM25 and semantic search.

    Args:
        question: User's question
        top_k: Number of relevant chunks to retrieve
        alpha: Weight for BM25 scores (1-alpha for semantic)

    Returns:
        Tuple of (answer, citations)
    """
    if top_k is None:
        top_k = config.DEFAULT_TOP_K
    if alpha is None:
        alpha = config.DEFAULT_HYBRID_ALPHA

    # Load index if not already loaded
    _load_bm25_index()

    if _bm25_index is None or _semantic_index is None:
        return "Hybrid search requires both BM25 and semantic indices. Please run preprocess.py with semantic embeddings.", []

    # Get BM25 scores
    tokenized_query = question.lower().split()
    bm25_scores = _bm25_index.get_scores(tokenized_query)

    # Normalize BM25 scores
    if bm25_scores.max() > 0:
        bm25_scores = bm25_scores / bm25_scores.max()

    # Get semantic scores using FAISS
    embedding_generator = utils.EmbeddingGenerator()
    query_embedding = embedding_generator.embed_text_openai([question]).astype(np.float32)
    faiss.normalize_L2(query_embedding)

    # Search semantic index for all documents
    k_search = min(len(_texts), top_k * config.RERANK_MULTIPLIER)
    distances, indices = _semantic_index.search(query_embedding.reshape(1, -1), k_search)

    # Create semantic scores array
    semantic_scores = np.zeros(len(_texts))
    for idx, dist in zip(indices[0], distances[0]):
        if idx < len(_texts):
            semantic_scores[idx] = dist

    # Combine scores
    hybrid_scores = alpha * bm25_scores + (1 - alpha) * semantic_scores

    # Get top candidates
    top_indices = np.argsort(hybrid_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]

    # Prepare candidates
    candidates = []
    for idx in top_indices:
        if idx < len(_texts) and hybrid_scores[idx] > 0:
            candidates.append({
                'text': _texts[idx],
                'hybrid_score': hybrid_scores[idx],
                'bm25_score': bm25_scores[idx],
                'semantic_score': semantic_scores[idx],
                'metadata': _metadata[idx],
                'idx': idx
            })

    # Re-rank with cross-encoder
    if candidates:
        pairs = [[question, cand['text']] for cand in candidates]
        cross_scores = cross_encoder.predict(pairs)

        for i, score in enumerate(cross_scores):
            candidates[i]['cross_score'] = score

        # Final ranking using cross-encoder scores
        candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]

    # Collect citations
    citations = []
    sources_seen = set()

    for chunk in candidates:
        chunk_meta = chunk['metadata']

        if chunk_meta['source'] not in sources_seen:
            citation = {
                'source': chunk_meta['source'],
                'type': chunk_meta['type'],
                'hybrid_score': round(chunk['hybrid_score'], 3),
                'rerank_score': round(chunk.get('cross_score', 0), 3)
            }

            if chunk_meta['type'] == 'pdf':
                citation['path'] = chunk_meta['path']
            else:
                citation['url'] = chunk_meta.get('url', '')

            citations.append(citation)
            sources_seen.add(chunk_meta['source'])

    # Build context
    context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])

    if not context:
        return "No relevant documents found for your query.", []

    # Generate answer
    prompt = f"""Using the following retrieved passages, answer the question:

{context}

Question: {question}

Provide a clear, detailed answer based on the information in the passages."""

    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=config.OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a safety expert. Answer questions accurately using the provided passages."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=config.DEFAULT_MAX_TOKENS
    )

    answer = response.choices[0].message.content

    return answer, citations


if __name__ == "__main__":
    # Test BM25 query
    test_questions = [
        "lockout tagout procedures",
        "machine guard requirements OSHA",
        "robot safety collaborative workspace"
    ]

    for q in test_questions:
        print(f"\nQuestion: {q}")
        answer, citations = query(q)
        print(f"Answer: {answer[:200]}...")
        print(f"Citations: {citations}")
        print("-" * 50)
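A minimal usage sketch for the two entry points above, assuming the pickled BM25 index (and, for hybrid search, the stored embeddings) were produced by preprocess.py:

    import query_bm25

    # Pure keyword retrieval with cross-encoder re-ranking.
    answer, citations = query_bm25.query("lockout tagout procedures", top_k=5)

    # Hybrid retrieval: scores are combined as
    #   hybrid = alpha * bm25_normalized + (1 - alpha) * cosine_similarity,
    # so alpha=1.0 reduces to BM25 only and alpha=0.0 to semantic only.
    answer, citations = query_bm25.query_hybrid("machine guard requirements", top_k=5, alpha=0.6)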
query_context.py
ADDED
@@ -0,0 +1,334 @@
"""
Context stuffing query module.
Loads full documents and uses heuristics to select relevant content.
"""

import pickle
import logging
import re
from typing import List, Tuple, Optional, Dict, Any
from openai import OpenAI
import tiktoken

from config import *

logger = logging.getLogger(__name__)

class ContextStuffingRetriever:
    """Context stuffing with heuristic document selection."""

    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.encoding = tiktoken.get_encoding("cl100k_base")
        self.documents = None
        self._load_documents()

    def _load_documents(self):
        """Load full documents for context stuffing."""
        try:
            if CONTEXT_DOCS.exists():
                logger.info("Loading documents for context stuffing...")

                with open(CONTEXT_DOCS, 'rb') as f:
                    data = pickle.load(f)

                if isinstance(data, list) and len(data) > 0:
                    # Handle both old format (list of chunks) and new format (list of DocumentChunk objects)
                    if hasattr(data[0], 'text'):  # New format with DocumentChunk objects
                        self.documents = []
                        for chunk in data:
                            self.documents.append({
                                'text': chunk.text,
                                'metadata': chunk.metadata,
                                'chunk_id': chunk.chunk_id
                            })
                    else:  # Old format with dict objects
                        self.documents = data

                    logger.info(f"✓ Loaded {len(self.documents)} documents for context stuffing")
                else:
                    logger.warning("No documents found in context stuffing file")
                    self.documents = []
            else:
                logger.warning("Context stuffing documents not found. Run preprocess.py first.")
                self.documents = []

        except Exception as e:
            logger.error(f"Error loading context stuffing documents: {e}")
            self.documents = []

    def _calculate_keyword_score(self, text: str, question: str) -> float:
        """Calculate keyword overlap score between text and question."""
        # Simple keyword matching heuristic
        question_words = set(re.findall(r'\w+', question.lower()))
        text_words = set(re.findall(r'\w+', text.lower()))

        if not question_words:
            return 0.0

        overlap = len(question_words & text_words)
        return overlap / len(question_words)

    def _calculate_section_relevance(self, text: str, question: str) -> float:
        """Calculate section relevance using multiple heuristics."""
        score = 0.0

        # Keyword overlap score (weight: 0.5)
        keyword_score = self._calculate_keyword_score(text, question)
        score += 0.5 * keyword_score

        # Length penalty (prefer medium-length sections)
        text_length = len(text.split())
        optimal_length = 200  # words
        length_score = min(1.0, text_length / optimal_length) if text_length < optimal_length else max(0.1, optimal_length / text_length)
        score += 0.2 * length_score

        # Header/title bonus (if text starts with common header patterns)
        if re.match(r'^#+\s|^\d+\.\s|^[A-Z\s]{3,20}:', text.strip()):
            score += 0.1

        # Question type specific bonuses
        question_lower = question.lower()
        text_lower = text.lower()

        if any(word in question_lower for word in ['what', 'define', 'definition']):
            if any(phrase in text_lower for phrase in ['means', 'defined as', 'definition', 'refers to']):
                score += 0.2

        if any(word in question_lower for word in ['how', 'procedure', 'steps']):
            if any(phrase in text_lower for phrase in ['step', 'procedure', 'process', 'method']):
                score += 0.2

        if any(word in question_lower for word in ['requirement', 'shall', 'must']):
            if any(phrase in text_lower for phrase in ['shall', 'must', 'required', 'requirement']):
                score += 0.2

        return min(1.0, score)  # Cap at 1.0

    def select_relevant_documents(self, question: str, max_tokens: int = None) -> List[Dict[str, Any]]:
        """Select most relevant documents using heuristics."""
        if not self.documents:
            return []

        if max_tokens is None:
            max_tokens = MAX_CONTEXT_TOKENS

        # Score all documents
        scored_docs = []
        for doc in self.documents:
            text = doc.get('text', '')
            if text.strip():
                relevance_score = self._calculate_section_relevance(text, question)

                doc_info = {
                    'text': text,
                    'metadata': doc.get('metadata', {}),
                    'score': relevance_score,
                    'token_count': len(self.encoding.encode(text))
                }
                scored_docs.append(doc_info)

        # Sort by relevance score
        scored_docs.sort(key=lambda x: x['score'], reverse=True)

        # Select documents within token limit
        selected_docs = []
        total_tokens = 0

        for doc in scored_docs:
            if doc['score'] > 0.1:  # Minimum relevance threshold
                if total_tokens + doc['token_count'] <= max_tokens:
                    selected_docs.append(doc)
                    total_tokens += doc['token_count']
                else:
                    # Try to include a truncated version
                    remaining_tokens = max_tokens - total_tokens
                    if remaining_tokens > 100:  # Only if meaningful content can fit
                        truncated_text = self._truncate_text(doc['text'], remaining_tokens)
                        if truncated_text:
                            doc['text'] = truncated_text
                            doc['token_count'] = len(self.encoding.encode(truncated_text))
                            selected_docs.append(doc)
                    break

        logger.info(f"Selected {len(selected_docs)} documents with {total_tokens} total tokens")
        return selected_docs

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Truncate text to fit within token limit while preserving meaning."""
        tokens = self.encoding.encode(text)
        if len(tokens) <= max_tokens:
            return text

        # Truncate and try to end at a sentence boundary
        truncated_tokens = tokens[:max_tokens]
        truncated_text = self.encoding.decode(truncated_tokens)

        # Try to end at a sentence boundary
        sentences = re.split(r'[.!?]+', truncated_text)
        if len(sentences) > 1:
            # Remove the last incomplete sentence
            truncated_text = '.'.join(sentences[:-1]) + '.'

        return truncated_text

    def generate_answer(self, question: str, context_docs: List[Dict[str, Any]]) -> str:
        """Generate answer using full context stuffing approach."""
        if not context_docs:
            return "I couldn't find any relevant documents to answer your question."

        try:
            # Assemble context from selected documents
            context_parts = []
            sources = []

            for i, doc in enumerate(context_docs, 1):
                text = doc['text']
                metadata = doc['metadata']
                source = metadata.get('source', f'Document {i}')

                context_parts.append(f"=== {source} ===\n{text}")
                if source not in sources:
                    sources.append(source)

            full_context = "\n\n".join(context_parts)

            # Create system message for context stuffing
            system_message = (
                "You are an expert in occupational safety and health regulations. "
                "Answer the user's question using the provided regulatory documents and technical materials. "
                "Provide comprehensive, accurate answers that directly address the question. "
                "Reference specific sections or requirements when applicable. "
                "If the provided context doesn't fully answer the question, clearly state what information is missing."
            )

            # Create user message
            user_message = f"""Based on the following regulatory and technical documents, please answer this question:

QUESTION: {question}

DOCUMENTS:
{full_context}

Please provide a thorough answer based on the information in these documents. If any important details are missing from the provided context, please indicate that as well."""

            # For GPT-5, temperature must be default (1.0)
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )

            answer = response.choices[0].message.content.strip()

            # Add source information
            if len(sources) > 1:
                answer += f"\n\n*Sources consulted: {', '.join(sources)}*"
            elif sources:
                answer += f"\n\n*Source: {sources[0]}*"

            return answer

        except Exception as e:
            logger.error(f"Error generating context stuffing answer: {e}")
            return "I apologize, but I encountered an error while generating the answer using context stuffing."

# Global retriever instance
_retriever = None

def get_retriever() -> ContextStuffingRetriever:
    """Get or create global context stuffing retriever instance."""
    global _retriever
    if _retriever is None:
        _retriever = ContextStuffingRetriever()
    return _retriever

def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict]]:
    """
    Main context stuffing query function with unified signature.

    Args:
        question: User question
        image_path: Optional image path (not used in context stuffing but kept for consistency)
        top_k: Not used in context stuffing (uses heuristic selection instead)

    Returns:
        Tuple of (answer, citations)
    """
    try:
        retriever = get_retriever()

        # Select relevant documents using heuristics
        relevant_docs = retriever.select_relevant_documents(question)

        if not relevant_docs:
            return "I couldn't find any relevant documents to answer your question.", []

        # Generate comprehensive answer
        answer = retriever.generate_answer(question, relevant_docs)

        # Prepare citations
        citations = []
        for i, doc in enumerate(relevant_docs, 1):
            metadata = doc['metadata']
            citations.append({
                'rank': i,
                'score': float(doc['score']),
                'source': metadata.get('source', 'Unknown'),
                'type': metadata.get('type', 'unknown'),
                'method': 'context_stuffing',
                'tokens_used': doc['token_count']
            })

        logger.info(f"Context stuffing query completed. Used {len(citations)} documents.")
        return answer, citations

    except Exception as e:
        logger.error(f"Error in context stuffing query: {e}")
        error_message = "I apologize, but I encountered an error while processing your question with context stuffing."
        return error_message, []

def query_with_details(question: str, image_path: Optional[str] = None,
                       top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict], List[Tuple]]:
    """
    Context stuffing query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)

    # Convert citations to chunk format for backward compatibility
    chunks = []
    for citation in citations:
        chunks.append((
            f"Document {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            f"Context from {citation['source']} ({citation['tokens_used']} tokens)",
            citation['source']
        ))

    return answer, citations, chunks

if __name__ == "__main__":
    # Test the context stuffing system
    test_question = "What are the general requirements for machine guarding?"

    print("Testing context stuffing retrieval system...")
    print(f"Question: {test_question}")
    print("-" * 50)

    try:
        answer, citations = query(test_question)

        print("Answer:")
        print(answer)
        print(f"\nCitations ({len(citations)} documents used):")
        for citation in citations:
            print(f"- {citation['source']} (Relevance: {citation['score']:.3f}, Tokens: {citation['tokens_used']})")

    except Exception as e:
        print(f"Error during testing: {e}")
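The heuristic scorer can be exercised directly, which makes its weighting easy to inspect. A small sketch only, assuming the pickled context documents and API key are available so ContextStuffingRetriever() can initialize:

    from query_context import get_retriever

    retriever = get_retriever()
    passage = "1. Lockout procedure: each step shall be documented before servicing."
    question = "What steps are required in a lockout procedure?"
    # Keyword overlap contributes up to 0.5, length up to 0.2; the procedure/steps
    # and shall/must bonuses add 0.2 each, the numbered-header bonus 0.1; capped at 1.0.
    print(retriever._calculate_section_relevance(passage, question))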
query_dpr.py
ADDED
@@ -0,0 +1,279 @@
"""
Dense Passage Retrieval (DPR) query module.
Uses bi-encoder for retrieval and cross-encoder for re-ranking.
"""

import pickle
import logging
from typing import List, Tuple, Optional
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
from openai import OpenAI

from config import *

logger = logging.getLogger(__name__)

class DPRRetriever:
    """Dense Passage Retrieval with cross-encoder re-ranking."""

    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.bi_encoder = None
        self.cross_encoder = None
        self.index = None
        self.metadata = None
        self._load_models()
        self._load_index()

    def _load_models(self):
        """Load bi-encoder and cross-encoder models."""
        try:
            logger.info("Loading DPR models...")
            self.bi_encoder = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
            self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)

            if DEVICE == "cuda":
                self.bi_encoder = self.bi_encoder.to(DEVICE)
                self.cross_encoder = self.cross_encoder.to(DEVICE)

            logger.info("✓ DPR models loaded successfully")

        except Exception as e:
            logger.error(f"Error loading DPR models: {e}")
            raise

    def _load_index(self):
        """Load FAISS index and metadata."""
        try:
            if DPR_FAISS_INDEX.exists() and DPR_METADATA.exists():
                logger.info("Loading DPR index and metadata...")

                # Load FAISS index
                self.index = faiss.read_index(str(DPR_FAISS_INDEX))

                # Load metadata
                with open(DPR_METADATA, 'rb') as f:
                    data = pickle.load(f)
                    self.metadata = data

                logger.info(f"✓ Loaded DPR index with {len(self.metadata)} chunks")
            else:
                logger.warning("DPR index not found. Run preprocess.py first.")

        except Exception as e:
            logger.error(f"Error loading DPR index: {e}")
            raise

    def retrieve_candidates(self, question: str, top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Retrieve candidate passages using bi-encoder."""
        if self.index is None or self.metadata is None:
            raise ValueError("DPR index not loaded. Run preprocess.py first.")

        try:
            # Encode question with bi-encoder
            question_embedding = self.bi_encoder.encode([question], convert_to_numpy=True)

            # Normalize for cosine similarity
            faiss.normalize_L2(question_embedding)

            # Search FAISS index
            # Retrieve more candidates for re-ranking
            retrieve_k = min(top_k * RERANK_MULTIPLIER, len(self.metadata))
            scores, indices = self.index.search(question_embedding, retrieve_k)

            # Prepare candidates
            candidates = []
            for score, idx in zip(scores[0], indices[0]):
                if idx < len(self.metadata):
                    chunk_data = self.metadata[idx]
                    candidates.append((
                        chunk_data['text'],
                        float(score),
                        chunk_data['metadata']
                    ))

            logger.info(f"Retrieved {len(candidates)} candidates for re-ranking")
            return candidates

        except Exception as e:
            logger.error(f"Error in candidate retrieval: {e}")
            raise

    def rerank_candidates(self, question: str, candidates: List[Tuple[str, float, dict]],
                          top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Re-rank candidates using cross-encoder."""
        if not candidates:
            return []

        try:
            # Prepare pairs for cross-encoder
            pairs = [(question, candidate[0]) for candidate in candidates]

            # Get cross-encoder scores
            cross_scores = self.cross_encoder.predict(pairs)

            # Combine with candidate data and re-sort
            reranked = []
            for i, (text, bi_score, metadata) in enumerate(candidates):
                cross_score = float(cross_scores[i])

                # Filter by minimum relevance score
                if cross_score >= MIN_RELEVANCE_SCORE:
                    reranked.append((text, cross_score, metadata))

            # Sort by cross-encoder score (descending)
            reranked.sort(key=lambda x: x[1], reverse=True)

            # Return top-k
            final_results = reranked[:top_k]
            logger.info(f"Re-ranked to {len(final_results)} final results")

            return final_results

        except Exception as e:
            logger.error(f"Error in re-ranking: {e}")
            # Fall back to bi-encoder results
            return candidates[:top_k]

    def generate_answer(self, question: str, context_chunks: List[Tuple[str, float, dict]]) -> str:
        """Generate answer using GPT with retrieved context."""
        if not context_chunks:
            return "I couldn't find relevant information to answer your question."

        try:
            # Prepare context
            context_parts = []
            for i, (text, score, metadata) in enumerate(context_chunks, 1):
                source = metadata.get('source', 'Unknown')
                context_parts.append(f"[Context {i}] Source: {source}\n{text}")

            context = "\n\n".join(context_parts)

            # Create system message
            system_message = (
                "You are a helpful assistant specialized in occupational safety and health. "
                "Answer questions based only on the provided context. "
                "If the context doesn't contain enough information, say so clearly. "
                "Always cite the source when referencing information."
            )

            # Create user message
            user_message = f"Context:\n{context}\n\nQuestion: {question}"

            # Generate response
            # For GPT-5, temperature must be default (1.0)
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )

            return response.choices[0].message.content.strip()

        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return "I apologize, but I encountered an error while generating the answer."

# Global retriever instance
_retriever = None

def get_retriever() -> DPRRetriever:
    """Get or create global DPR retriever instance."""
    global _retriever
    if _retriever is None:
        _retriever = DPRRetriever()
    return _retriever

def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
    """
    Main DPR query function with unified signature.

    Args:
        question: User question
        image_path: Optional image path (not used in DPR but kept for consistency)
        top_k: Number of top results to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    try:
        retriever = get_retriever()

        # Step 1: Retrieve candidates with bi-encoder
        candidates = retriever.retrieve_candidates(question, top_k)

        if not candidates:
            return "I couldn't find any relevant information for your question.", []

        # Step 2: Re-rank with cross-encoder
        reranked_candidates = retriever.rerank_candidates(question, candidates, top_k)

        # Step 3: Generate answer
        answer = retriever.generate_answer(question, reranked_candidates)

        # Step 4: Prepare citations
        citations = []
        for i, (text, score, metadata) in enumerate(reranked_candidates, 1):
            citations.append({
                'rank': i,
                'text': text,
                'score': float(score),
                'source': metadata.get('source', 'Unknown'),
                'type': metadata.get('type', 'unknown'),
                'method': 'dpr'
            })

        logger.info(f"DPR query completed. Retrieved {len(citations)} citations.")
        return answer, citations

    except Exception as e:
        logger.error(f"Error in DPR query: {e}")
        error_message = "I apologize, but I encountered an error while processing your question with DPR."
        return error_message, []

def query_with_details(question: str, image_path: Optional[str] = None,
                       top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict], List[Tuple]]:
    """
    DPR query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)

    # Convert citations to chunk format for backward compatibility
    chunks = []
    for citation in citations:
        chunks.append((
            f"Rank {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            citation['text'],
            citation['source']
        ))

    return answer, citations, chunks

if __name__ == "__main__":
    # Test the DPR system
    test_question = "What are the general requirements for machine guarding?"

    print("Testing DPR retrieval system...")
    print(f"Question: {test_question}")
    print("-" * 50)

    try:
        answer, citations = query(test_question)

        print("Answer:")
        print(answer)
        print("\nCitations:")
        for citation in citations:
            print(f"- {citation['source']} (Score: {citation['score']:.3f})")

    except Exception as e:
        print(f"Error during testing: {e}")
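A usage sketch that runs the two DPR stages separately to see what re-ranking changes; it assumes the FAISS index and metadata written by preprocess.py are present:

    from query_dpr import get_retriever

    retriever = get_retriever()
    q = "robot safety collaborative workspace"
    candidates = retriever.retrieve_candidates(q, top_k=5)          # bi-encoder + FAISS
    reranked = retriever.rerank_candidates(q, candidates, top_k=5)  # cross-encoder
    for text, score, meta in reranked:
        print(round(score, 3), meta.get('source', 'Unknown'))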
query_graph.py
CHANGED
@@ -1,140 +1,413 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
-
|
|
|
|
| 4 |
from openai import OpenAI
|
| 5 |
import networkx as nx
|
| 6 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# Initialize OpenAI client
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
# Load graph from GML
|
| 13 |
-
G = nx.read_gml("graph.gml")
|
| 14 |
-
enodes = list(G.nodes)
|
| 15 |
-
embeddings = np.array([G.nodes[n]['embedding'] for n in enodes])
|
| 16 |
|
| 17 |
-
def
|
| 18 |
"""
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"""
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
emb_resp = client.embeddings.create(
|
| 27 |
-
model=
|
| 28 |
input=question
|
| 29 |
)
|
| 30 |
-
q_vec = emb_resp.data[0].embedding
|
| 31 |
-
|
| 32 |
# Compute cosine similarities
|
| 33 |
-
sims = cosine_similarity([q_vec],
|
| 34 |
idxs = sims.argsort()[::-1][:top_k]
|
| 35 |
-
|
| 36 |
# Collect chunk-level info
|
| 37 |
chunks = []
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
for rank, i in enumerate(idxs, start=1):
|
| 40 |
-
node =
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
| 43 |
score = sims[i]
|
| 44 |
-
|
| 45 |
-
citation
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
chat_resp = client.chat.completions.create(
|
| 61 |
-
model=
|
| 62 |
messages=[
|
| 63 |
-
{"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety."},
|
| 64 |
-
{"role": "user",
|
| 65 |
-
]
|
|
|
|
| 66 |
)
|
|
|
|
| 67 |
answer = chat_resp.choices[0].message.content
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
return answer, sources, chunks
|
| 70 |
|
|
|
|
| 71 |
"""
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
"""
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
emb_resp = client.embeddings.create(
|
| 80 |
-
model=
|
| 81 |
input=question
|
| 82 |
)
|
| 83 |
-
q_vec = emb_resp.data[0].embedding
|
| 84 |
-
|
| 85 |
-
# Compute similarities
|
| 86 |
-
sims = cosine_similarity([q_vec],
|
| 87 |
idxs = sims.argsort()[::-1][:top_k]
|
| 88 |
-
|
| 89 |
-
#
|
| 90 |
chunks = []
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
for i in idxs:
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
messages=[
|
| 115 |
-
{"role": "system", "content": "You are
|
| 116 |
-
{"role": "user",
|
| 117 |
-
]
|
|
|
|
| 118 |
)
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
return answer, sources, chunks
|
| 122 |
|
| 123 |
|
| 124 |
-
|
| 125 |
-
#
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
# for header, score, _, citation in chunks:
|
| 139 |
-
# print(f" * {header} (score: {score:.2f}) from {citation}")
|
| 140 |
-
# print("\n", "#"*40, "\n")
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Graph-based RAG using NetworkX.
|
| 3 |
+
Updated to match the common query signature used by other methods.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
import numpy as np
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Tuple, List, Optional
|
| 9 |
from openai import OpenAI
|
| 10 |
import networkx as nx
|
| 11 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 12 |
|
| 13 |
+
from config import *
|
| 14 |
+
from utils import classify_image
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
# Initialize OpenAI client
|
| 19 |
+
client = OpenAI(api_key=OPENAI_API_KEY)
|
| 20 |
+
|
| 21 |
+
# Global variables for lazy loading
|
| 22 |
+
_graph = None
|
| 23 |
+
_enodes = None
|
| 24 |
+
_embeddings = None
|
| 25 |
+
|
| 26 |
+
def _load_graph():
|
| 27 |
+
"""Lazy load graph database."""
|
| 28 |
+
global _graph, _enodes, _embeddings
|
| 29 |
+
|
| 30 |
+
if _graph is None:
|
| 31 |
+
try:
|
| 32 |
+
if GRAPH_FILE.exists():
|
| 33 |
+
logger.info("Loading graph database...")
|
| 34 |
+
_graph = nx.read_gml(str(GRAPH_FILE))
|
| 35 |
+
_enodes = list(_graph.nodes)
|
| 36 |
+
# Convert embeddings from lists back to numpy arrays
|
| 37 |
+
embeddings_list = []
|
| 38 |
+
for n in _enodes:
|
| 39 |
+
embedding = _graph.nodes[n]['embedding']
|
| 40 |
+
if isinstance(embedding, list):
|
| 41 |
+
embeddings_list.append(np.array(embedding))
|
| 42 |
+
else:
|
| 43 |
+
embeddings_list.append(embedding)
|
| 44 |
+
_embeddings = np.array(embeddings_list)
|
| 45 |
+
logger.info(f"✓ Loaded graph with {len(_enodes)} nodes")
|
| 46 |
+
else:
|
| 47 |
+
logger.warning("Graph database not found. Run preprocess.py first.")
|
| 48 |
+
_graph = nx.Graph()
|
| 49 |
+
_enodes = []
|
| 50 |
+
_embeddings = np.array([])
|
| 51 |
+
except Exception as e:
|
| 52 |
+
logger.error(f"Error loading graph: {e}")
|
| 53 |
+
_graph = nx.Graph()
|
| 54 |
+
_enodes = []
|
| 55 |
+
_embeddings = np.array([])
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
|
| 59 |
"""
|
| 60 |
+
Query using graph-based retrieval.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
question: User's question
|
| 64 |
+
image_path: Optional path to an image (for multimodal queries)
|
| 65 |
+
top_k: Number of relevant chunks to retrieve
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Tuple of (answer, citations)
|
| 69 |
"""
|
| 70 |
+
|
| 71 |
+
# Load graph if not already loaded
|
| 72 |
+
_load_graph()
|
| 73 |
+
|
| 74 |
+
if len(_enodes) == 0:
|
| 75 |
+
return "Graph database is empty. Please run preprocess.py first.", []
|
| 76 |
+
|
| 77 |
+
# Embed question using OpenAI
|
| 78 |
emb_resp = client.embeddings.create(
|
| 79 |
+
model=OPENAI_EMBEDDING_MODEL,
|
| 80 |
input=question
|
| 81 |
)
|
| 82 |
+
q_vec = np.array(emb_resp.data[0].embedding)
|
| 83 |
+
|
| 84 |
# Compute cosine similarities
|
| 85 |
+
sims = cosine_similarity([q_vec], _embeddings)[0]
|
| 86 |
idxs = sims.argsort()[::-1][:top_k]
|
| 87 |
+
|
| 88 |
# Collect chunk-level info
|
| 89 |
chunks = []
|
| 90 |
+
citations = []
|
| 91 |
+
sources_seen = set()
|
| 92 |
+
|
| 93 |
for rank, i in enumerate(idxs, start=1):
|
| 94 |
+
node = _enodes[i]
|
| 95 |
+
node_data = _graph.nodes[node]
|
| 96 |
+
text = node_data['text']
|
| 97 |
+
|
| 98 |
+
# Extract header from text
|
| 99 |
+
header = text.split('\n', 1)[0].lstrip('#').strip()
|
| 100 |
score = sims[i]
|
| 101 |
+
|
| 102 |
+
# Extract citation format - get source from metadata or node_data
|
| 103 |
+
metadata = node_data.get('metadata', {})
|
| 104 |
+
source = metadata.get('source') or node_data.get('source')
|
| 105 |
+
|
| 106 |
+
if not source:
|
| 107 |
+
continue
|
| 108 |
+
|
| 109 |
+
if 'url' in metadata: # HTML source
|
| 110 |
+
citation_ref = metadata['url']
|
| 111 |
+
cite_type = 'html'
|
| 112 |
+
elif 'path' in metadata: # PDF source
|
| 113 |
+
citation_ref = metadata['path']
|
| 114 |
+
cite_type = 'pdf'
|
| 115 |
+
elif 'url' in node_data: # Legacy format
|
| 116 |
+
citation_ref = node_data['url']
|
| 117 |
+
cite_type = 'html'
|
| 118 |
+
elif 'path' in node_data: # Legacy format
|
| 119 |
+
citation_ref = node_data['path']
|
| 120 |
+
cite_type = 'pdf'
|
| 121 |
+
else:
|
| 122 |
+
citation_ref = source
|
| 123 |
+
cite_type = 'unknown'
|
| 124 |
+
|
| 125 |
+
chunks.append({
|
| 126 |
+
'header': header,
|
| 127 |
+
'score': score,
|
| 128 |
+
'text': text,
|
| 129 |
+
'citation': citation_ref
|
| 130 |
+
})
|
| 131 |
+
|
| 132 |
+
# Add unique citation
|
| 133 |
+
if source not in sources_seen:
|
| 134 |
+
citation_entry = {
|
| 135 |
+
'source': source,
|
| 136 |
+
'type': cite_type,
|
| 137 |
+
'relevance_score': round(float(score), 3)
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
if cite_type == 'html':
|
| 141 |
+
citation_entry['url'] = citation_ref
|
| 142 |
+
elif cite_type == 'pdf':
|
| 143 |
+
citation_entry['path'] = citation_ref
|
| 144 |
+
|
| 145 |
+
citations.append(citation_entry)
|
| 146 |
+
sources_seen.add(source)
|
| 147 |
+
|
| 148 |
+
# Handle image if provided
|
| 149 |
+
image_context = ""
|
| 150 |
+
if image_path:
|
| 151 |
+
try:
|
| 152 |
+
# Classify the image
|
| 153 |
+
classification = classify_image(image_path)
|
| 154 |
+
image_context = f"\n\n[Image Context: The provided image appears to be a {classification}.]"
|
| 155 |
+
|
| 156 |
+
# Optionally, find related nodes in graph based on image classification
|
| 157 |
+
# This would require storing image-related metadata in the graph
|
| 158 |
+
|
| 159 |
+
except Exception as e:
|
| 160 |
+
print(f"Error processing image: {e}")
|
| 161 |
+
|
| 162 |
+
# Assemble context for prompt
|
| 163 |
+
context = "\n\n---\n\n".join([c['text'] for c in chunks])
|
| 164 |
+
|
| 165 |
+
prompt = f"""Use the following context to answer the question:
|
| 166 |
|
| 167 |
+
{context}{image_context}
|
| 168 |
+
|
| 169 |
+
Question: {question}
|
| 170 |
+
|
| 171 |
+
Please provide a comprehensive answer based on the context provided. Cite specific sources when providing information."""
|
| 172 |
+
|
| 173 |
+
# For GPT-5, temperature must be default (1.0)
|
| 174 |
chat_resp = client.chat.completions.create(
|
| 175 |
+
model=OPENAI_CHAT_MODEL,
|
| 176 |
messages=[
|
| 177 |
+
{"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety. Always provide accurate information based on the given context."},
|
| 178 |
+
{"role": "user", "content": prompt}
|
| 179 |
+
],
|
| 180 |
+
max_completion_tokens=DEFAULT_MAX_TOKENS
|
| 181 |
)
|
| 182 |
+
|
| 183 |
answer = chat_resp.choices[0].message.content
|
| 184 |
+
|
| 185 |
+
return answer, citations
|
| 186 |
|
|
|
|
| 187 |
|
| 188 |
+
def query_with_graph_traversal(question: str, top_k: int = 5, max_hops: int = 2) -> Tuple[str, List[dict]]:
|
| 189 |
"""
|
| 190 |
+
Enhanced graph query that can traverse edges to find related information.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
question: User's question
|
| 194 |
+
top_k: Number of initial nodes to retrieve
|
| 195 |
+
max_hops: Maximum graph traversal depth
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
Tuple of (answer, citations)
|
| 199 |
"""
|
| 200 |
+
|
| 201 |
+
# Load graph if not already loaded
|
| 202 |
+
_load_graph()
|
| 203 |
+
|
| 204 |
+
if len(_enodes) == 0:
|
| 205 |
+
return "Graph database is empty. Please run preprocess.py first.", []
|
| 206 |
+
|
| 207 |
+
# Get initial nodes using standard query
|
| 208 |
+
initial_answer, initial_citations = query(question, top_k=top_k)
|
| 209 |
+
|
| 210 |
+
# For a more sophisticated implementation, you would:
|
| 211 |
+
# 1. Add edges between related nodes during preprocessing
|
| 212 |
+
# 2. Traverse from initial nodes to find related content
|
| 213 |
+
# 3. Score the related nodes based on path distance and relevance
|
| 214 |
+
|
| 215 |
+
# For now, return the standard query results
|
| 216 |
+
return initial_answer, initial_citations
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def query_subgraph(question: str, source_filter: str = None, top_k: int = 5) -> Tuple[str, List[dict]]:
    """
    Query a specific subgraph filtered by source.

    Args:
        question: User's question
        source_filter: Filter nodes by source (e.g., specific PDF name)
        top_k: Number of relevant chunks to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    # Load graph if not already loaded
    _load_graph()

    # Filter nodes if source specified
    if source_filter:
        filtered_nodes = []
        for n in _enodes:
            node_data = _graph.nodes[n]
            metadata = node_data.get('metadata', {})
            source = metadata.get('source') or node_data.get('source', '')
            source_from_meta = metadata.get('source', '')

            # Check both direct source and metadata source
            if (source_filter.lower() in source.lower() or
                    source_filter.lower() in source_from_meta.lower()):
                filtered_nodes.append(n)

        if not filtered_nodes:
            return f"No nodes found for source: {source_filter}", []
    else:
        filtered_nodes = _enodes

    # Get embeddings for filtered nodes
    filtered_embeddings = np.array([_graph.nodes[n]['embedding'] for n in filtered_nodes])

    # Embed question
    emb_resp = client.embeddings.create(
        model=OPENAI_EMBEDDING_MODEL,
        input=question
    )
    q_vec = np.array(emb_resp.data[0].embedding)

    # Compute similarities
    sims = cosine_similarity([q_vec], filtered_embeddings)[0]
    idxs = sims.argsort()[::-1][:top_k]

    # Collect results
    chunks = []
    citations = []
    sources_seen = set()

    for i in idxs:
        if i < len(filtered_nodes):
            node = filtered_nodes[i]
            node_data = _graph.nodes[node]

            chunks.append(node_data['text'])

            # Skip if source information missing
            metadata = node_data.get('metadata', {})
            source = metadata.get('source') or node_data.get('source')

            if not source:
                continue

            if source not in sources_seen:
                citation = {
                    'source': source,
                    'type': 'pdf' if ('path' in metadata or 'path' in node_data) else 'html',
                    'relevance_score': round(float(sims[i]), 3)
                }

                # Check metadata first, then node_data for legacy support
                if 'url' in metadata:
                    citation['url'] = metadata['url']
                elif 'path' in metadata:
                    citation['path'] = metadata['path']
                elif 'url' in node_data:
                    citation['url'] = node_data['url']
                elif 'path' in node_data:
                    citation['path'] = node_data['path']

                citations.append(citation)
                sources_seen.add(source)

    # Build context and generate answer
    context = "\n\n---\n\n".join(chunks)

    prompt = f"""Answer the following question using the provided context:

Context from {source_filter if source_filter else 'all sources'}:
{context}

Question: {question}

Provide a detailed answer based on the context."""

    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are an expert on manufacturing safety. Answer based on the provided context."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=DEFAULT_MAX_TOKENS
    )

    answer = response.choices[0].message.content

    return answer, citations


# Maintain backward compatibility with original function signature
def query_graph(question: str, top_k: int = 5) -> Tuple[str, List[str], List[tuple]]:
    """
    Original query_graph function signature for backward compatibility.

    Args:
        question: User's question
        top_k: Number of relevant chunks to retrieve

    Returns:
        Tuple of (answer, sources, chunks)
    """
    # Call the new query function
    answer, citations = query(question, top_k=top_k)

    # Convert citations to old format
    sources = [c['source'] for c in citations]

    # Get chunks in old format (header, score, text, citation)
    _load_graph()

    if len(_enodes) == 0:
        return answer, sources, []

    # Regenerate chunks for backward compatibility
    emb_resp = client.embeddings.create(
        model=OPENAI_EMBEDDING_MODEL,
        input=question
    )
    q_vec = np.array(emb_resp.data[0].embedding)

    sims = cosine_similarity([q_vec], _embeddings)[0]
    idxs = sims.argsort()[::-1][:top_k]

    chunks = []
    for i in idxs:
        node = _enodes[i]
        node_data = _graph.nodes[node]
        text = node_data['text']
        header = text.split('\n', 1)[0].lstrip('#').strip()
        score = sims[i]

        # Skip if source information missing
        metadata = node_data.get('metadata', {})
        source = metadata.get('source') or node_data.get('source')

        if not source:
            continue

        if 'url' in metadata:
            citation = metadata['url']
        elif 'path' in metadata:
            citation = metadata['path']
        elif 'url' in node_data:
            citation = node_data['url']
        elif 'path' in node_data:
            citation = node_data['path']
        else:
            citation = source

        chunks.append((header, score, text, citation))

    return answer, sources, chunks


if __name__ == "__main__":
    # Test the updated graph query
    test_questions = [
        "What are general machine guarding requirements?",
        "How do I perform lockout/tagout procedures?",
        "What safety measures are needed for robotic systems?"
    ]

    for q in test_questions:
        print(f"\nQuestion: {q}")
        answer, citations = query(q)
        print(f"Answer: {answer[:200]}...")
        print(f"Citations: {[c['source'] for c in citations]}")
        print("-" * 50)
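A short usage sketch of the two entry points above, assuming the graph has already been built by preprocess.py; the PDF name passed as source_filter is only an example, not a file shipped with the repository.

# Illustrative usage of query_subgraph and the backward-compatible query_graph.
from query_graph import query_subgraph, query_graph

# Restrict retrieval to chunks whose source matches a particular document
answer, citations = query_subgraph(
    "What guards are required on power presses?",
    source_filter="machine_guarding.pdf",  # example filter, any source substring works
    top_k=5,
)

# Legacy callers still receive the (answer, sources, chunks) tuple
answer, sources, chunks = query_graph("What are lockout/tagout steps?", top_k=3)
for header, score, text, citation in chunks:
    print(f"{score:.3f}  {header}  ->  {citation}")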
query_vanilla.py
ADDED
@@ -0,0 +1,197 @@
"""
Vanilla vector search using FAISS index and OpenAI embeddings.
"""

import numpy as np
import faiss
from typing import Tuple, List, Optional
from openai import OpenAI

import pickle
import logging
from config import *
from utils import EmbeddingGenerator, classify_image

logger = logging.getLogger(__name__)

# Initialize OpenAI client
client = OpenAI(api_key=OPENAI_API_KEY)

# Global variables for lazy loading
_index = None
_texts = None
_metadata = None

def _load_vanilla_index():
    """Lazy load vanilla FAISS index and metadata."""
    global _index, _texts, _metadata

    if _index is None:
        try:
            if VANILLA_FAISS_INDEX.exists() and VANILLA_METADATA.exists():
                logger.info("Loading vanilla FAISS index...")

                # Load FAISS index
                _index = faiss.read_index(str(VANILLA_FAISS_INDEX))

                # Load metadata
                with open(VANILLA_METADATA, 'rb') as f:
                    data = pickle.load(f)

                if isinstance(data, list):
                    # New format with metadata list
                    _texts = [item['text'] for item in data]
                    _metadata = [item['metadata'] for item in data]
                else:
                    # Old format with dict
                    _texts = data.get('texts', [])
                    _metadata = data.get('metadata', [])

                logger.info(f"✓ Loaded vanilla index with {len(_texts)} documents")
            else:
                logger.warning("Vanilla index not found. Run preprocess.py first.")
                _index = None
                _texts = []
                _metadata = []

        except Exception as e:
            logger.error(f"Error loading vanilla index: {e}")
            _index = None
            _texts = []
            _metadata = []
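The loader above accepts either a pickled list of {'text', 'metadata'} items (the "new format") or a dict with 'texts'/'metadata' keys. A minimal sketch of how preprocessing could write the new format together with a cosine-style FAISS index follows; it assumes the embeddings already exist as a float32 NumPy array, and the function and variable names are illustrative rather than taken from preprocess.py.

# Sketch of producing the index/metadata pair that _load_vanilla_index expects.
# `embeddings`, `texts`, and `metadatas` are assumed inputs; the paths would
# come from config (e.g. VANILLA_FAISS_INDEX / VANILLA_METADATA).
import pickle
import faiss
import numpy as np

def build_vanilla_index(embeddings: np.ndarray, texts, metadatas, index_path, metadata_path):
    vectors = embeddings.astype(np.float32)
    faiss.normalize_L2(vectors)                  # inner product then equals cosine similarity
    index = faiss.IndexFlatIP(vectors.shape[1])  # exact inner-product search
    index.add(vectors)
    faiss.write_index(index, str(index_path))

    records = [{'text': t, 'metadata': m} for t, m in zip(texts, metadatas)]
    with open(metadata_path, 'wb') as f:
        pickle.dump(records, f)                  # list form == "new format" handled above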
def query(question: str, image_path: Optional[str] = None, top_k: int = None) -> Tuple[str, List[dict]]:
    """
    Query using vanilla vector search.

    Args:
        question: User's question
        image_path: Optional path to an image (for multimodal queries)
        top_k: Number of relevant chunks to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    if top_k is None:
        top_k = DEFAULT_TOP_K

    # Load index if not already loaded
    _load_vanilla_index()

    if _index is None or len(_texts) == 0:
        return "Index not loaded. Please run preprocess.py first.", []

    # Generate query embedding using embedding generator
    embedding_gen = EmbeddingGenerator()
    query_embedding = embedding_gen.embed_text_openai([question])

    # Normalize for cosine similarity
    query_embedding = query_embedding.astype(np.float32)
    faiss.normalize_L2(query_embedding)

    # Search the index
    distances, indices = _index.search(query_embedding, top_k)

    # Collect retrieved chunks and citations
    retrieved_chunks = []
    citations = []
    sources_seen = set()

    for idx, distance in zip(indices[0], distances[0]):
        if idx < len(_texts) and distance > MIN_RELEVANCE_SCORE:
            chunk_text = _texts[idx]
            chunk_meta = _metadata[idx]

            retrieved_chunks.append({
                'text': chunk_text,
                'score': float(distance),
                'metadata': chunk_meta
            })

            # Build citation
            if chunk_meta['source'] not in sources_seen:
                citation = {
                    'source': chunk_meta['source'],
                    'type': chunk_meta['type'],
                    'relevance_score': round(float(distance), 3)
                }

                if chunk_meta['type'] == 'pdf':
                    citation['path'] = chunk_meta['path']
                else:  # HTML
                    citation['url'] = chunk_meta.get('url', '')

                citations.append(citation)
                sources_seen.add(chunk_meta['source'])

    # Handle image if provided
    image_context = ""
    if image_path:
        try:
            classification = classify_image(image_path)
            image_context = f"\n\n[Image Context: The provided image appears to be a {classification}.]"
        except Exception as e:
            logger.error(f"Error processing image: {e}")

    # Build context for the prompt
    context = "\n\n---\n\n".join([chunk['text'] for chunk in retrieved_chunks])

    if not context:
        return "No relevant documents found for your query.", []

    # Generate answer using OpenAI
    prompt = f"""Use the following context to answer the question:

{context}{image_context}

Question: {question}

Please provide a comprehensive answer based on the context provided. If the context doesn't contain enough information, say so."""

    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety. Always cite your sources when providing information."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=DEFAULT_MAX_TOKENS
    )

    answer = response.choices[0].message.content

    return answer, citations


def query_with_feedback(question: str, feedback_scores: List[float] = None, top_k: int = 5) -> Tuple[str, List[dict]]:
    """
    Query with relevance feedback to refine results.

    Args:
        question: User's question
        feedback_scores: Optional relevance scores for previous results
        top_k: Number of relevant chunks to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    # For now, just use regular query
    # TODO: Implement Rocchio algorithm or similar for relevance feedback
    return query(question, top_k=top_k)


if __name__ == "__main__":
    # Test the vanilla query
    test_questions = [
        "What are general machine guarding requirements?",
        "How do I perform lockout/tagout procedures?",
        "What safety measures are needed for robotic systems?"
    ]

    for q in test_questions:
        print(f"\nQuestion: {q}")
        answer, citations = query(q)
        print(f"Answer: {answer[:200]}...")
        print(f"Citations: {[c['source'] for c in citations]}")
        print("-" * 50)
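query_with_feedback currently ignores feedback_scores, and its TODO names the Rocchio algorithm. A minimal sketch of a Rocchio-style update on embedding vectors is shown below, under the assumption that embeddings of previously retrieved chunks are available to the caller; the alpha/beta/gamma weights are conventional textbook defaults, not values defined in config.py, and the helper is not wired into query().

# Rocchio-style refinement of a query embedding (sketch only).
import numpy as np

def rocchio_update(query_vec: np.ndarray,
                   relevant_vecs: np.ndarray,
                   nonrelevant_vecs: np.ndarray,
                   alpha: float = 1.0, beta: float = 0.75, gamma: float = 0.15) -> np.ndarray:
    """Move the query toward relevant results and away from non-relevant ones."""
    updated = alpha * query_vec
    if len(relevant_vecs):
        updated = updated + beta * relevant_vecs.mean(axis=0)
    if len(nonrelevant_vecs):
        updated = updated - gamma * nonrelevant_vecs.mean(axis=0)
    # Re-normalize so the refined vector can be searched against the cosine index above
    return updated / (np.linalg.norm(updated) + 1e-12)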
query_vision.py
ADDED
@@ -0,0 +1,393 @@
"""
Vision-based query module using GPT-5 Vision.
Supports multimodal queries combining text and images.
"""

import base64
import json
import logging
import sqlite3
from typing import List, Tuple, Optional, Dict, Any
import numpy as np
from PIL import Image
from openai import OpenAI

from config import *
from utils import ImageProcessor, classify_image

logger = logging.getLogger(__name__)

class VisionRetriever:
    """Vision-based retrieval using GPT-5 Vision for image analysis and classification."""

    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.image_processor = ImageProcessor()

    def get_similar_images(self, query_image_path: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Find similar images in the database based on classification similarity."""
        try:
            # Uses GPT-5 Vision for classification-based similarity search
            # Note: This implementation uses classification similarity rather than embeddings

            # Classify the query image
            query_classification = classify_image(query_image_path)

            # Query database for similar images
            conn = sqlite3.connect(IMAGES_DB)
            cursor = conn.cursor()

            # Search for images with similar classification
            cursor.execute("""
                SELECT image_id, image_path, classification, metadata
                FROM images
                WHERE classification LIKE ?
                ORDER BY created_at DESC
                LIMIT ?
            """, (f"%{query_classification}%", top_k))

            results = cursor.fetchall()
            conn.close()

            similar_images = []
            for row in results:
                image_id, image_path, classification, metadata_json = row
                metadata = json.loads(metadata_json) if metadata_json else {}

                similar_images.append({
                    'image_id': image_id,
                    'image_path': image_path,
                    'classification': classification,
                    'metadata': metadata,
                    'similarity_score': 0.8  # Classification-based similarity score
                })

            logger.info(f"Found {len(similar_images)} similar images for query")
            return similar_images

        except Exception as e:
            logger.error(f"Error finding similar images: {e}")
            return []
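get_similar_images assumes an images table in the SQLite database at IMAGES_DB with image_id, image_path, classification, metadata, and created_at columns. A sketch of a schema compatible with that query is given below; the exact column types are an assumption, since the table is presumably created during preprocessing rather than in this module.

# Assumed-compatible schema for the `images` table queried above (sketch only).
import sqlite3

def ensure_images_table(db_path: str) -> None:
    conn = sqlite3.connect(db_path)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS images (
            image_id       TEXT PRIMARY KEY,
            image_path     TEXT NOT NULL,
            classification TEXT,                                 -- label from classify_image()
            metadata       TEXT,                                 -- JSON blob (source document, page, ...)
            created_at     TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.commit()
    conn.close()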
    def analyze_image_safety(self, image_path: str, question: str = None) -> str:
        """Analyze image for safety concerns using GPT-5 Vision."""
        try:
            # Convert image to base64
            with open(image_path, "rb") as image_file:
                image_b64 = base64.b64encode(image_file.read()).decode()

            # Create analysis prompt
            if question:
                analysis_prompt = (
                    f"Analyze this image in the context of the following question: {question}\n\n"
                    "Please provide a detailed safety analysis covering:\n"
                    "1. What equipment, machinery, or workplace elements are visible\n"
                    "2. Any potential safety hazards or compliance issues\n"
                    "3. Relevant OSHA standards or regulations that may apply\n"
                    "4. Recommendations for safety improvements\n"
                    "5. How this relates to the specific question asked"
                )
            else:
                analysis_prompt = (
                    "Analyze this image for occupational safety and health concerns. Provide:\n"
                    "1. Description of what's shown in the image\n"
                    "2. Identification of potential safety hazards\n"
                    "3. Relevant OSHA standards or safety regulations\n"
                    "4. Recommendations for improving safety"
                )

            messages = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": analysis_prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}", "detail": "high"}}
                ]
            }]

            # For GPT-5 vision, temperature must be default (1.0) and reasoning is not supported
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=messages,
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )

            return response.choices[0].message.content.strip()

        except Exception as e:
            logger.error(f"Error analyzing image: {e}")
            return f"I encountered an error while analyzing the image: {e}"

    def retrieve_relevant_text(self, image_classification: str, question: str, top_k: int = 3) -> List[Dict[str, Any]]:
        """Retrieve text documents relevant to the image classification and question."""
        # This would integrate with other retrieval methods to find relevant text
        # For now, reuse the vanilla vector retriever with an enhanced query

        try:
            # Import other query modules for text retrieval
            from query_vanilla import query as vanilla_query

            # Create an enhanced query combining image classification and original question
            enhanced_question = f"safety requirements for {image_classification} {question}"

            # Use vanilla retrieval to find relevant text
            _, text_citations = vanilla_query(enhanced_question, top_k=top_k)

            return text_citations

        except Exception as e:
            logger.error(f"Error retrieving relevant text: {e}")
            return []

    def generate_multimodal_answer(self, question: str, image_analysis: str,
                                   text_citations: List[Dict], similar_images: List[Dict]) -> str:
        """Generate answer combining image analysis and text retrieval."""
        try:
            # Prepare context from text citations
            text_context = ""
            if text_citations:
                text_parts = []
                for i, citation in enumerate(text_citations, 1):
                    if 'text' in citation:
                        text_parts.append(f"[Text Source {i}] {citation['source']}: {citation['text'][:500]}...")
                    else:
                        text_parts.append(f"[Text Source {i}] {citation['source']}")
                text_context = "\n\n".join(text_parts)

            # Prepare context from similar images
            image_context = ""
            if similar_images:
                image_parts = []
                for img in similar_images[:3]:  # Limit to top 3
                    source = img['metadata'].get('source', 'Unknown')
                    classification = img.get('classification', 'unknown')
                    image_parts.append(f"Similar image from {source}: classified as {classification}")
                image_context = "\n".join(image_parts)

            # Create comprehensive prompt
            system_message = (
                "You are an expert in occupational safety and health. "
                "You have been provided with an image analysis, relevant text documents, "
                "and information about similar images in the database. "
                "Provide a comprehensive answer that integrates all this information."
            )

            user_message = f"""Question: {question}

Image Analysis:
{image_analysis}

Relevant Text Documentation:
{text_context}

Similar Images Context:
{image_context}

Please provide a comprehensive answer that:
1. Addresses the specific question asked
2. Incorporates insights from the image analysis
3. References relevant regulatory information from the text sources
4. Notes any connections to similar cases or images
5. Provides actionable recommendations based on safety standards"""

            # For GPT-5, temperature must be default (1.0) and reasoning is not supported
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS * 2  # Allow longer response for comprehensive analysis
            )

            return response.choices[0].message.content.strip()

        except Exception as e:
            logger.error(f"Error generating multimodal answer: {e}")
            return "I apologize, but I encountered an error while generating the comprehensive answer."

# Global retriever instance
_retriever = None

def get_retriever() -> VisionRetriever:
    """Get or create global vision retriever instance."""
    global _retriever
    if _retriever is None:
        _retriever = VisionRetriever()
    return _retriever

def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict]]:
    """
    Main vision-based query function with unified signature.

    Args:
        question: User question
        image_path: Path to image file (required for vision queries)
        top_k: Number of relevant results to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    if not image_path:
        return "Vision queries require an image. Please provide an image file.", []

    try:
        retriever = get_retriever()

        # Step 1: Analyze the provided image
        logger.info(f"Analyzing image: {image_path}")
        image_analysis = retriever.analyze_image_safety(image_path, question)

        # Step 2: Classify the image
        image_classification = classify_image(image_path)

        # Step 3: Find similar images
        similar_images = retriever.get_similar_images(image_path, top_k=3)

        # Step 4: Retrieve relevant text documents
        text_citations = retriever.retrieve_relevant_text(image_classification, question, top_k)

        # Step 5: Generate comprehensive multimodal answer
        answer = retriever.generate_multimodal_answer(
            question, image_analysis, text_citations, similar_images
        )

        # Step 6: Prepare citations
        citations = []

        # Add image analysis as primary citation
        citations.append({
            'rank': 1,
            'type': 'image_analysis',
            'source': f"Analysis of {image_path.split('/')[-1] if '/' in image_path else image_path.split('\\')[-1]}",
            'method': 'vision',
            'classification': image_classification,
            'score': 1.0
        })

        # Add text citations
        for i, citation in enumerate(text_citations, 2):
            citation_copy = citation.copy()
            citation_copy['rank'] = i
            citation_copy['method'] = 'vision_text'
            citations.append(citation_copy)

        # Add similar images
        for i, img in enumerate(similar_images):
            citations.append({
                'rank': len(citations) + 1,
                'type': 'similar_image',
                'source': img['metadata'].get('source', 'Image Database'),
                'method': 'vision',
                'classification': img.get('classification', 'unknown'),
                'similarity_score': img.get('similarity_score', 0.0),
                'image_id': img.get('image_id')
            })

        logger.info(f"Vision query completed. Generated {len(citations)} citations.")
        return answer, citations

    except Exception as e:
        logger.error(f"Error in vision query: {e}")
        error_message = "I apologize, but I encountered an error while processing your vision-based question."
        return error_message, []

def query_image_only(image_path: str, question: str = None) -> Tuple[str, List[Dict]]:
    """
    Analyze image without text retrieval (faster for simple image analysis).

    Args:
        image_path: Path to image file
        question: Optional specific question about the image

    Returns:
        Tuple of (analysis, citations)
    """
    try:
        retriever = get_retriever()

        # Analyze image
        analysis = retriever.analyze_image_safety(image_path, question)

        # Classify image
        classification = classify_image(image_path)

        # Create citation for image analysis
        citations = [{
            'rank': 1,
            'type': 'image_analysis',
            'source': f"Analysis of {image_path.split('/')[-1] if '/' in image_path else image_path.split('\\')[-1]}",
            'method': 'vision_only',
            'classification': classification,
            'score': 1.0
        }]

        return analysis, citations

    except Exception as e:
        logger.error(f"Error in image-only analysis: {e}")
        return "Error analyzing image.", []

def query_with_details(question: str, image_path: Optional[str] = None,
                       top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict], List[Tuple]]:
    """
    Vision query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)

    # Convert citations to chunk format for backward compatibility
    chunks = []
    for citation in citations:
        if citation['type'] == 'image_analysis':
            chunks.append((
                f"Image Analysis ({citation['classification']})",
                citation['score'],
                "Analysis of uploaded image for safety compliance",
                citation['source']
            ))
        elif citation['type'] == 'similar_image':
            chunks.append((
                f"Similar Image (Score: {citation.get('similarity_score', 0):.3f})",
                citation.get('similarity_score', 0),
                f"Similar image classified as {citation['classification']}",
                citation['source']
            ))
        else:
            chunks.append((
                f"Text Reference {citation['rank']}",
                citation.get('score', 0.5),
                citation.get('text', 'Referenced document'),
                citation['source']
            ))

    return answer, citations, chunks

if __name__ == "__main__":
    # Test the vision system (requires an actual image file)
    import sys

    if len(sys.argv) > 1:
        test_image_path = sys.argv[1]
        test_question = "What safety issues can you identify in this image?"

        print("Testing vision retrieval system...")
        print(f"Image: {test_image_path}")
        print(f"Question: {test_question}")
        print("-" * 50)

        try:
            answer, citations = query(test_question, test_image_path)

            print("Answer:")
            print(answer)
            print(f"\nCitations ({len(citations)}):")
            for citation in citations:
                print(f"- {citation['source']} (Type: {citation.get('type', 'unknown')})")

        except Exception as e:
            print(f"Error during testing: {e}")
    else:
        print("To test vision system, provide an image path as argument:")
        print("python query_vision.py /path/to/image.jpg")
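A quick usage sketch of the vision entry points, assuming an image file exists at the path shown; the path and the questions are examples, not data shipped with the repository.

# Illustrative calls to the vision query module.
from query_vision import query, query_image_only

# Full multimodal pipeline: image analysis + text retrieval + similar images
answer, citations = query(
    "Does this setup meet machine guarding requirements?",
    image_path="photos/press_brake.jpg",  # example path
)

# Faster, image-only analysis without text retrieval
analysis, image_citations = query_image_only("photos/press_brake.jpg")
print(analysis[:300])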
realtime_server.py
ADDED
@@ -0,0 +1,402 @@
"""
FastAPI server for OpenAI Realtime API integration with RAG system.
Provides endpoints for session management and RAG tool calls.

Directory structure:
    /data/        # Original PDFs, HTML
    /embeddings/  # FAISS, Chroma, DPR vector stores
    /graph/       # Graph database files
    /metadata/    # Image metadata (SQLite or MongoDB)
"""

import json
import logging
import os
import time
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from pydantic import BaseModel
import uvicorn
from openai import OpenAI

# Import all query modules
from query_graph import query as graph_query
from query_vanilla import query as vanilla_query
from query_dpr import query as dpr_query
from query_bm25 import query as bm25_query
from query_context import query as context_query
from query_vision import query as vision_query

from config import OPENAI_API_KEY, OPENAI_CHAT_MODEL, OPENAI_REALTIME_MODEL, REALTIME_VOICE, REALTIME_INSTRUCTIONS, DEFAULT_METHOD
from analytics_db import log_query

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="SIGHT Realtime API Server", version="1.0.0")

# CORS middleware for frontend integration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict to your domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log all incoming requests for debugging."""
    logger.info(f"Incoming request: {request.method} {request.url}")
    try:
        response = await call_next(request)
        logger.info(f"Response status: {response.status_code}")
        return response
    except Exception as e:
        logger.error(f"Request processing error: {e}")
        return JSONResponse(
            content={"error": "Internal server error"},
            status_code=500
        )

# Exception handlers
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    logger.warning(f"Validation error for {request.url}: {exc}")
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={"error": "Invalid request format", "details": str(exc)}
    )

@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    logger.warning(f"HTTP error for {request.url}: {exc.status_code} - {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail}
    )

@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    logger.error(f"Unhandled error for {request.url}: {exc}")
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"error": "Internal server error"}
    )

# Initialize OpenAI client
client = OpenAI(api_key=OPENAI_API_KEY)

# Query method dispatch
QUERY_DISPATCH = {
    'graph': graph_query,
    'vanilla': vanilla_query,
    'dpr': dpr_query,
    'bm25': bm25_query,
    'context': context_query,
    'vision': vision_query
}

# Use configuration from config.py with environment variable overrides
REALTIME_MODEL = os.getenv("REALTIME_MODEL", OPENAI_REALTIME_MODEL)
VOICE = os.getenv("REALTIME_VOICE", REALTIME_VOICE)
INSTRUCTIONS = os.getenv("REALTIME_INSTRUCTIONS", REALTIME_INSTRUCTIONS)

# Pydantic models for request/response
class SessionRequest(BaseModel):
    """Request model for creating ephemeral sessions."""
    model: Optional[str] = "gpt-4o-realtime-preview"
    instructions: Optional[str] = None
    voice: Optional[str] = None

class RAGRequest(BaseModel):
    """Request model for RAG queries."""
    query: str
    method: str = "graph"
    top_k: int = 5
    image_path: Optional[str] = None

class RAGResponse(BaseModel):
    """Response model for RAG queries."""
    answer: str
    citations: list
    method: str
    citations_html: Optional[str] = None

@app.post("/session")
async def create_ephemeral_session(request: SessionRequest) -> JSONResponse:
    """
    Create an ephemeral session token for OpenAI Realtime API.
    This token will be used by the frontend WebRTC client.
    """
    try:
        logger.info(f"Creating ephemeral session with model: {request.model or REALTIME_MODEL}")

        # Create ephemeral token using direct HTTP call to OpenAI API
        # Since the Python SDK doesn't support realtime sessions yet
        import requests

        session_data = {
            "model": request.model or REALTIME_MODEL,
            "voice": request.voice or VOICE,
            "modalities": ["audio", "text"],
            "instructions": request.instructions or INSTRUCTIONS,
        }

        headers = {
            "Authorization": f"Bearer {OPENAI_API_KEY}",
            "Content-Type": "application/json"
        }

        # Make direct HTTP request to OpenAI's realtime sessions endpoint
        response = requests.post(
            "https://api.openai.com/v1/realtime/sessions",
            json=session_data,
            headers=headers,
            timeout=30
        )

        if response.status_code == 200:
            session_result = response.json()

            response_data = {
                "client_secret": session_result.get("client_secret", {}).get("value") or session_result.get("client_secret"),
                "model": request.model or REALTIME_MODEL,
                "session_id": session_result.get("id")
            }

            logger.info("Ephemeral session created successfully")
            return JSONResponse(content=response_data, status_code=200)
        else:
            logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
            return JSONResponse(
                content={"error": f"OpenAI API error: {response.status_code} - {response.text}"},
                status_code=response.status_code
            )

    except requests.exceptions.RequestException as e:
        logger.error(f"Network error creating ephemeral session: {e}")
        return JSONResponse(
            content={"error": f"Network error: {str(e)}"},
            status_code=500
        )
    except Exception as e:
        logger.error(f"Error creating ephemeral session: {e}")
        return JSONResponse(
            content={"error": f"Session creation failed: {str(e)}"},
            status_code=500
        )

@app.post("/rag", response_model=RAGResponse)
async def rag_query(request: RAGRequest) -> RAGResponse:
    """
    Handle RAG queries from the realtime interface.
    This endpoint is called by the JavaScript frontend when the model
    requests the ask_rag function.
    """
    try:
        logger.info(f"RAG query: {request.query} using method: {request.method}")

        # Validate and default method if needed
        method = request.method
        if method not in QUERY_DISPATCH:
            logger.warning(f"Invalid method '{method}', using default '{DEFAULT_METHOD}'")
            method = DEFAULT_METHOD

        # Get the appropriate query function
        query_func = QUERY_DISPATCH[method]

        # Execute the query
        start_time = time.time()

        answer, citations = query_func(
            question=request.query,
            image_path=request.image_path,
            top_k=request.top_k
        )
        response_time = (time.time() - start_time) * 1000  # Convert to ms

        # Format citations for HTML display (optional)
        citations_html = format_citations_html(citations, method)

        # Log to analytics database (mark as voice interaction)
        try:
            # Generate unique session ID for each voice interaction
            import uuid
            voice_session_id = f"voice_{uuid.uuid4().hex[:8]}"

            log_query(
                user_query=request.query,
                method=method,
                answer=answer,
                citations=citations,
                response_time=response_time,
                image_path=request.image_path,
                top_k=request.top_k,
                session_id=voice_session_id,
                additional_settings={'voice_interaction': True, 'interaction_type': 'speech_to_speech'}
            )
            logger.info(f"Voice interaction logged: {request.query[:50]}...")
        except Exception as log_error:
            logger.error(f"Failed to log voice query: {log_error}")

        logger.info(f"RAG query completed: {len(answer)} chars, {len(citations)} citations")

        return RAGResponse(
            answer=answer,
            citations=citations,
            method=method,
            citations_html=citations_html
        )

    except Exception as e:
        logger.error(f"Error processing RAG query: {e}")
        raise HTTPException(status_code=500, detail=f"RAG query failed: {str(e)}")

def format_citations_html(citations: list, method: str) -> str:
    """Format citations as HTML for display."""
    if not citations:
        return "<p><em>No citations available</em></p>"

    html_parts = ["<div style='margin-top: 1em;'><strong>Sources:</strong><ul>"]

    for citation in citations:
        if isinstance(citation, dict) and 'source' in citation:
            source = citation['source']
            cite_type = citation.get('type', 'unknown')

            # Build citation text based on type
            if cite_type == 'pdf':
                cite_text = f"📄 {source} (PDF)"
            elif cite_type == 'html':
                url = citation.get('url', '')
                if url:
                    cite_text = f"🌐 <a href='{url}' target='_blank'>{source}</a> (Web)"
                else:
                    cite_text = f"🌐 {source} (Web)"
            elif cite_type == 'image':
                page = citation.get('page', 'N/A')
                cite_text = f"🖼️ {source} (Image, page {page})"
            else:
                cite_text = f"📚 {source}"

            # Add scores if available
            scores = []
            if 'relevance_score' in citation:
                scores.append(f"relevance: {citation['relevance_score']:.3f}")
            if 'score' in citation:
                scores.append(f"score: {citation['score']:.3f}")

            if scores:
                cite_text += f" <small>({', '.join(scores)})</small>"

            html_parts.append(f"<li>{cite_text}</li>")
        elif isinstance(citation, (list, tuple)) and len(citation) >= 4:
            # Handle legacy citation format (header, score, text, source)
            header, score, text, source = citation[:4]
            cite_text = f"📚 {source} <small>(score: {score:.3f})</small>"
            html_parts.append(f"<li>{cite_text}</li>")

    html_parts.append("</ul></div>")
    return "".join(html_parts)

@app.get("/")
async def root():
    """Root endpoint to prevent invalid HTTP request warnings."""
    return {
        "service": "SIGHT Realtime API Server",
        "version": "1.0.0",
        "status": "running",
        "endpoints": {
            "session": "POST /session - Create realtime session",
            "rag": "POST /rag - Query RAG system",
            "health": "GET /health - Health check",
            "methods": "GET /methods - List available RAG methods"
        }
    }

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "service": "SIGHT Realtime API Server"}

@app.get("/methods")
async def list_methods():
    """List available RAG methods."""
    return {
        "methods": list(QUERY_DISPATCH.keys()),
        "descriptions": {
            'graph': "Graph-based RAG using NetworkX with relationship-aware retrieval",
            'vanilla': "Standard vector search with FAISS and OpenAI embeddings",
            'dpr': "Dense Passage Retrieval with bi-encoder and cross-encoder re-ranking",
            'bm25': "BM25 keyword search with neural re-ranking for exact term matching",
            'context': "Context stuffing with full document loading and heuristic selection",
            'vision': "Vision-based search using GPT-5 Vision for image analysis"
        }
    }

@app.options("/{full_path:path}")
async def options_handler(request: Request, response: Response):
    """Handle CORS preflight requests."""
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
    response.headers["Access-Control-Allow-Headers"] = "*"
    return response

if __name__ == "__main__":
    import argparse

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="SIGHT Realtime API Server")
    parser.add_argument("--https", action="store_true", help="Enable HTTPS with self-signed certificate")
    parser.add_argument("--port", type=int, default=5050, help="Port to run the server on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind the server to")
    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # Suppress uvicorn access logs for cleaner output
    uvicorn_logger = logging.getLogger("uvicorn.access")
    uvicorn_logger.setLevel(logging.WARNING)

    # Prepare uvicorn configuration
    uvicorn_config = {
        "app": "realtime_server:app",
        "host": args.host,
        "port": args.port,
        "reload": True,
        "log_level": "warning",
        "access_log": False
    }

    # Add SSL configuration if HTTPS is requested
    if args.https:
        logger.info("Starting server with HTTPS (self-signed certificate)")
        logger.warning("⚠️ Self-signed certificate will show security warnings in browser")
        logger.info("For production, use a proper SSL certificate from a CA")

        # Note: You would need to generate SSL certificates
        # For development, you can create self-signed certificates:
        # openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes
        uvicorn_config.update({
            "ssl_keyfile": "key.pem",
            "ssl_certfile": "cert.pem"
        })

        print(f"🔒 Starting HTTPS server on https://{args.host}:{args.port}")
        print("📝 To generate self-signed certificates, run:")
        print("   openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes")
    else:
        print(f"🌐 Starting HTTP server on http://{args.host}:{args.port}")
        print("⚠️ HTTP only works for localhost. Use --https for production deployment.")

    # Run the server
    uvicorn.run(**uvicorn_config)
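A minimal client-side sketch of the two HTTP endpoints above, assuming the server is running locally on the argparse default port (5050); the question in the payload is an example.

# Sketch of calling the realtime server from Python.
import requests

BASE = "http://localhost:5050"  # default --port from the argparse setup above

# 1) Mint an ephemeral Realtime session token for the browser client
session = requests.post(f"{BASE}/session", json={}, timeout=30).json()
print("client_secret:", str(session.get("client_secret"))[:12], "...")

# 2) Query the RAG system directly (payload mirrors the RAGRequest model)
payload = {
    "query": "What are general machine guarding requirements?",  # example question
    "method": "graph",
    "top_k": 5,
}
result = requests.post(f"{BASE}/rag", json=payload, timeout=120).json()
print(result["answer"][:200])
for citation in result["citations"]:
    print("-", citation.get("source"))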
requirements.txt
CHANGED
@@ -1,10 +1,63 @@
# Core dependencies (updated versions)
python-dotenv>=1.1.1
pymupdf4llm>=0.0.27
beautifulsoup4>=4.13.4
requests>=2.32.4
pandas>=2.2.3
openai>=1.99.9
networkx>=3.5
numpy>=2.3.1
scikit-learn>=1.7.1
streamlit>=1.47.0

# FastAPI and realtime API dependencies
fastapi>=0.104.0           # For realtime API server
uvicorn[standard]>=0.24.0  # ASGI server for FastAPI
pydantic>=2.4.0            # Data validation and settings management

# Document processing
pymupdf>=1.24.0  # For PDF processing and image extraction
Pillow>=10.0.0   # For image processing
lxml>=5.0.0      # For HTML parsing
html5lib>=1.1    # Alternative HTML parser

# Vector stores and search
faiss-cpu>=1.8.0  # For vector similarity search (use faiss-gpu if CUDA available)
chromadb>=0.5.0   # Alternative vector database
rank-bm25>=0.2.2  # For BM25 keyword search

# Language models and embeddings
sentence-transformers>=3.0.0  # For DPR and cross-encoder
transformers>=4.40.0          # Required by sentence-transformers
torch>=2.0.0                  # For neural models (CPU version)
# For GPU support, install separately:
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
ftfy>=6.1.1      # Text preprocessing for CLIP
regex>=2023.0.0  # Text processing
# For CLIP (optional - enable if needed):
# git+https://github.com/openai/CLIP.git

# Token counting and management
tiktoken>=0.7.0  # For OpenAI token counting

# Database (optional)
# pymongo>=4.0.0  # Uncomment if using MongoDB for metadata

# Development and debugging
tqdm>=4.65.0    # Progress bars
ipython>=8.0.0  # For interactive debugging
jupyter>=1.0.0  # For notebook development

# Data visualization (optional)
matplotlib>=3.7.0  # For plotting
seaborn>=0.12.0    # Statistical visualization
plotly>=5.15.0     # Interactive plots

# Optional advanced features (uncomment if needed)
# langchain>=0.2.11         # For advanced RAG patterns
# langchain-openai>=0.1.20  # OpenAI integration for LangChain
# llama-index>=0.10.51      # Alternative RAG framework

# Additional utility packages
colorama>=0.4.6  # Colored console output
rich>=13.0.0     # Rich text and beautiful formatting in terminal
utils.py
ADDED
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Utility functions for the Multi-Method RAG System.

Directory Layout:
    /data/        # Original PDFs, HTML
    /embeddings/  # FAISS, Chroma, DPR vector stores
    /graph/       # Graph database files
    /metadata/    # Image metadata (SQLite or MongoDB)
"""

import os
import json
import pickle
import sqlite3
import base64
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any, Union
from dataclasses import dataclass
import logging

import pymupdf4llm
import pymupdf
import numpy as np
import pandas as pd
from PIL import Image
import requests
from bs4 import BeautifulSoup

# Vector stores and search
import faiss
import chromadb
from rank_bm25 import BM25Okapi
import networkx as nx

# ML models
from openai import OpenAI
from sentence_transformers import SentenceTransformer, CrossEncoder
import torch
# import clip

# Text processing
from sklearn.feature_extraction.text import TfidfVectorizer
import tiktoken

from config import *

logger = logging.getLogger(__name__)

@dataclass
class DocumentChunk:
    """Data structure for document chunks."""
    text: str
    metadata: Dict[str, Any]
    chunk_id: str
    embedding: Optional[np.ndarray] = None

@dataclass
class ImageData:
    """Data structure for image metadata."""
    image_path: str
    image_id: str
    classification: Optional[str] = None
    embedding: Optional[np.ndarray] = None
    metadata: Optional[Dict[str, Any]] = None

class DocumentLoader:
    """Load and extract text from various document formats."""

    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        validate_api_key()

    def load_pdf_documents(self, pdf_paths: List[Union[str, Path]]) -> List[Dict[str, Any]]:
        """Load text from PDF files using pymupdf4llm."""
        documents = []

        for pdf_path in pdf_paths:
            try:
                pdf_path = Path(pdf_path)
                logger.info(f"Loading PDF: {pdf_path}")

                # Extract text using pymupdf4llm
                text = pymupdf4llm.to_markdown(str(pdf_path))

                # Extract images if present
                images = self._extract_pdf_images(pdf_path)

                doc = {
                    'text': text,
                    'source': str(pdf_path.name),
                    'path': str(pdf_path),
                    'type': 'pdf',
                    'images': images,
                    'metadata': {
                        'file_size': pdf_path.stat().st_size,
                        'modified': pdf_path.stat().st_mtime
                    }
                }
                documents.append(doc)

            except Exception as e:
                logger.error(f"Error loading PDF {pdf_path}: {e}")
                continue

        return documents

    def _extract_pdf_images(self, pdf_path: Path) -> List[Dict[str, Any]]:
        """Extract images from PDF using pymupdf."""
        images = []

        try:
            doc = pymupdf.open(str(pdf_path))

            for page_num in range(len(doc)):
                page = doc[page_num]
                image_list = page.get_images(full=True)

                for img_index, img in enumerate(image_list):
                    try:
                        # Extract image
                        xref = img[0]
                        pix = pymupdf.Pixmap(doc, xref)

                        # Skip if pixmap is invalid or has no colorspace
                        if not pix or pix.colorspace is None:
                            if pix:
                                pix = None
                            continue

                        # Only process images with valid color channels
                        if pix.n - pix.alpha < 4:  # GRAY or RGB
                            image_id = f"{pdf_path.stem}_p{page_num}_img{img_index}"
                            image_path = IMAGES_DIR / f"{image_id}.png"

                            # Convert to RGB if grayscale or other formats
                            if pix.n == 1:  # Grayscale
                                rgb_pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                                pix = None  # Clean up original
                                pix = rgb_pix
                            elif pix.n == 4 and pix.alpha == 0:  # CMYK
                                rgb_pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
                                pix = None  # Clean up original
                                pix = rgb_pix

                            # Save image
                            pix.save(str(image_path))

                            images.append({
                                'image_id': image_id,
                                'image_path': str(image_path),
                                'page': page_num,
                                'source': str(pdf_path.name)
                            })

                        pix = None

                    except Exception as e:
                        logger.warning(f"Error extracting image {img_index} from page {page_num}: {e}")
                        if 'pix' in locals() and pix:
                            pix = None
                        continue

            doc.close()

        except Exception as e:
            logger.error(f"Error extracting images from {pdf_path}: {e}")

        return images

    def load_html_documents(self, html_sources: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """Load text from HTML sources."""
        documents = []

        for source in html_sources:
            try:
                logger.info(f"Loading HTML: {source.get('title', source['url'])}")

                # Fetch HTML content
                response = requests.get(source['url'], timeout=30)
                response.raise_for_status()

                # Parse with BeautifulSoup
                soup = BeautifulSoup(response.text, 'html.parser')

                # Extract text
                text = soup.get_text(separator=' ', strip=True)

                doc = {
                    'text': text,
                    'source': source.get('title', source['url']),
                    'path': source['url'],
                    'type': 'html',
                    'images': [],
                    'metadata': {
                        'url': source['url'],
                        'title': source.get('title', ''),
                        'year': source.get('year', ''),
                        'category': source.get('category', ''),
                        'format': source.get('format', 'HTML')
                    }
                }
                documents.append(doc)

            except Exception as e:
                logger.error(f"Error loading HTML {source['url']}: {e}")
                continue

        return documents

    def load_text_documents(self, data_dir: Path = DATA_DIR) -> List[Dict[str, Any]]:
        """Load all supported document types from data directory."""
        documents = []

        # Load PDFs
        pdf_files = list(data_dir.glob("*.pdf"))
        if pdf_files:
            documents.extend(self.load_pdf_documents(pdf_files))

        # Load HTML sources (from config)
        if DEFAULT_HTML_SOURCES:
            documents.extend(self.load_html_documents(DEFAULT_HTML_SOURCES))

        logger.info(f"Loaded {len(documents)} documents total")
        return documents

class TextPreprocessor:
    """Preprocess text for different retrieval methods."""

    def __init__(self):
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def chunk_text_by_tokens(self, text: str, chunk_size: int = CHUNK_SIZE,
                             overlap: int = CHUNK_OVERLAP) -> List[str]:
        """Split text into chunks by token count."""
        tokens = self.encoding.encode(text)
        chunks = []

        start = 0
        while start < len(tokens):
            end = start + chunk_size
            chunk_tokens = tokens[start:end]
            chunk_text = self.encoding.decode(chunk_tokens)
            chunks.append(chunk_text)
            start = end - overlap

        return chunks

    def chunk_text_by_sections(self, text: str, method: str = "vanilla") -> List[str]:
        """Split text by sections based on method requirements."""
        if method in ["vanilla", "dpr"]:
            return self.chunk_text_by_tokens(text)
        elif method == "bm25":
            # BM25 works better with paragraph-level chunks
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
            return paragraphs
        elif method == "graph":
            # Graph method uses larger sections
            return self.chunk_text_by_tokens(text, chunk_size=CHUNK_SIZE*2)
        elif method == "context_stuffing":
            # Context stuffing uses full documents
            return [text]
        else:
            return self.chunk_text_by_tokens(text)

    def preprocess_for_method(self, documents: List[Dict[str, Any]],
                              method: str) -> List[DocumentChunk]:
        """Preprocess documents for specific retrieval method."""
        chunks = []

        for doc in documents:
            text_chunks = self.chunk_text_by_sections(doc['text'], method)

            for i, chunk_text in enumerate(text_chunks):
                chunk_id = f"{doc['source']}_{method}_chunk_{i}"

                chunk = DocumentChunk(
                    text=chunk_text,
                    metadata={
                        'source': doc['source'],
                        'path': doc['path'],
                        'type': doc['type'],
                        'chunk_index': i,
                        'method': method,
                        **doc.get('metadata', {})
                    },
                    chunk_id=chunk_id
                )
                chunks.append(chunk)

        logger.info(f"Created {len(chunks)} chunks for method '{method}'")
        return chunks

class EmbeddingGenerator:
    """Generate embeddings using various models."""

    def __init__(self):
        self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
        self.sentence_transformer = None
        # self.clip_model = None
        # self.clip_preprocess = None

    def _get_sentence_transformer(self):
        """Lazy loading of sentence transformer."""
        if self.sentence_transformer is None:
            self.sentence_transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
            if DEVICE == "cuda":
                self.sentence_transformer = self.sentence_transformer.to(DEVICE)
        return self.sentence_transformer

    # def _get_clip_model(self):
    #     """Lazy loading of CLIP model."""
    #     if self.clip_model is None:
    #         self.clip_model, self.clip_preprocess = clip.load(CLIP_MODEL, device=DEVICE)
    #     return self.clip_model, self.clip_preprocess

    def embed_text_openai(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings using OpenAI API."""
        embeddings = []

        # Process in batches
        for i in range(0, len(texts), EMBEDDING_BATCH_SIZE):
            batch = texts[i:i + EMBEDDING_BATCH_SIZE]

            try:
                response = self.openai_client.embeddings.create(
                    model=OPENAI_EMBEDDING_MODEL,
                    input=batch
                )

                batch_embeddings = [data.embedding for data in response.data]
                embeddings.extend(batch_embeddings)

            except Exception as e:
                logger.error(f"Error generating OpenAI embeddings: {e}")
                raise

        return np.array(embeddings)

    def embed_text_sentence_transformer(self, texts: List[str]) -> np.ndarray:
        """Generate embeddings using sentence transformers."""
        model = self._get_sentence_transformer()

        try:
            embeddings = model.encode(texts, convert_to_numpy=True,
                                      show_progress_bar=True, batch_size=32)
            return embeddings

        except Exception as e:
            logger.error(f"Error generating sentence transformer embeddings: {e}")
            raise

    def embed_image_clip(self, image_paths: List[str]) -> np.ndarray:
        """Generate image embeddings using CLIP."""
        # model, preprocess = self._get_clip_model()
        # embeddings = []

        # for image_path in image_paths:
        #     try:
        #         image = preprocess(Image.open(image_path)).unsqueeze(0).to(DEVICE)
        #
        #         with torch.no_grad():
        #             image_features = model.encode_image(image)
        #             image_features /= image_features.norm(dim=-1, keepdim=True)
        #
        #         embeddings.append(image_features.cpu().numpy().flatten())
        #
        #     except Exception as e:
        #         logger.error(f"Error embedding image {image_path}: {e}")
        #         continue

        # return np.array(embeddings) if embeddings else np.array([])

        # Placeholder for CLIP embeddings
        logger.warning("CLIP embeddings not implemented - returning dummy embeddings")
        return np.random.rand(len(image_paths), 512)

class VectorStoreManager:
    """Manage vector stores for different methods."""

    def __init__(self):
        self.embedding_generator = EmbeddingGenerator()

    def build_faiss_index(self, chunks: List[DocumentChunk], method: str = "vanilla") -> Tuple[Any, List[Dict]]:
        """Build FAISS index for vanilla or DPR method."""

        # Generate embeddings
        texts = [chunk.text for chunk in chunks]

        if method == "vanilla":
            embeddings = self.embedding_generator.embed_text_openai(texts)
        elif method == "dpr":
            embeddings = self.embedding_generator.embed_text_sentence_transformer(texts)
        else:
            raise ValueError(f"Unsupported method for FAISS: {method}")

        # Build FAISS index
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity

        # Ensure embeddings are float32 and normalize for cosine similarity
        embeddings = embeddings.astype(np.float32)
        faiss.normalize_L2(embeddings)
        index.add(embeddings)

        # Store chunk metadata
        metadata = []
        for i, chunk in enumerate(chunks):
            metadata.append({
                'chunk_id': chunk.chunk_id,
                'text': chunk.text,
                'metadata': chunk.metadata,
                'embedding': embeddings[i].tolist()
            })

        logger.info(f"Built FAISS index with {index.ntotal} vectors for method '{method}'")
        return index, metadata

    def build_chroma_index(self, chunks: List[DocumentChunk], method: str = "vanilla") -> Any:
        """Build Chroma vector database."""

        # Initialize Chroma client
        chroma_client = chromadb.PersistentClient(path=str(CHROMA_PATH / method))
        collection = chroma_client.get_or_create_collection(
            name=f"{method}_collection",
            metadata={"method": method}
        )

        # Prepare data for Chroma
        texts = [chunk.text for chunk in chunks]
        ids = [chunk.chunk_id for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]

        # Add to collection (Chroma handles embeddings internally)
        collection.add(
            documents=texts,
            ids=ids,
            metadatas=metadatas
        )

        logger.info(f"Built Chroma collection with {collection.count()} documents for method '{method}'")
        return collection

    def build_bm25_index(self, chunks: List[DocumentChunk]) -> BM25Okapi:
        """Build BM25 index for keyword search."""

        # Tokenize texts
        tokenized_corpus = []
        for chunk in chunks:
            tokens = chunk.text.lower().split()
            tokenized_corpus.append(tokens)

        # Build BM25 index
        bm25 = BM25Okapi(tokenized_corpus, k1=BM25_K1, b=BM25_B)

        logger.info(f"Built BM25 index with {len(tokenized_corpus)} documents")
        return bm25

    def build_graph_index(self, chunks: List[DocumentChunk]) -> nx.Graph:
        """Build NetworkX graph for graph-based retrieval."""

        # Create graph
        G = nx.Graph()

        # Generate embeddings for similarity calculation
        texts = [chunk.text for chunk in chunks]
        embeddings = self.embedding_generator.embed_text_openai(texts)

        # Add nodes (convert embeddings to lists for GML serialization)
        for i, chunk in enumerate(chunks):
            G.add_node(chunk.chunk_id,
                       text=chunk.text,
                       metadata=chunk.metadata,
                       embedding=embeddings[i].tolist())  # Convert to list for serialization

        # Add edges based on similarity
        threshold = 0.7  # Similarity threshold
        for i in range(len(chunks)):
            for j in range(i + 1, len(chunks)):
                # Calculate cosine similarity
                sim = np.dot(embeddings[i], embeddings[j]) / (
                    np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
                )

                if sim > threshold:
                    G.add_edge(chunks[i].chunk_id, chunks[j].chunk_id,
                               weight=float(sim))

        logger.info(f"Built graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
        return G

    def save_index(self, index: Any, metadata: Any, method: str):
        """Save index and metadata to disk."""

        if method == "vanilla":
            faiss.write_index(index, str(VANILLA_FAISS_INDEX))
            with open(VANILLA_METADATA, 'wb') as f:
                pickle.dump(metadata, f)

        elif method == "dpr":
            faiss.write_index(index, str(DPR_FAISS_INDEX))
            with open(DPR_METADATA, 'wb') as f:
                pickle.dump(metadata, f)

        elif method == "bm25":
            with open(BM25_INDEX, 'wb') as f:
                pickle.dump({'index': index, 'texts': metadata}, f)

        elif method == "context_stuffing":
            with open(CONTEXT_DOCS, 'wb') as f:
                pickle.dump(metadata, f)

        elif method == "graph":
            nx.write_gml(index, str(GRAPH_FILE))

        logger.info(f"Saved {method} index to disk")

class ImageProcessor:
    """Process and classify images."""

    def __init__(self):
        self.embedding_generator = EmbeddingGenerator()
        self.openai_client = OpenAI(api_key=OPENAI_API_KEY)
        self._init_database()

    def _init_database(self):
        """Initialize SQLite database for image metadata."""
        conn = sqlite3.connect(IMAGES_DB)
        cursor = conn.cursor()

        cursor.execute('''
            CREATE TABLE IF NOT EXISTS images (
                image_id TEXT PRIMARY KEY,
                image_path TEXT NOT NULL,
                classification TEXT,
                metadata TEXT,
                embedding BLOB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')

        conn.commit()
        conn.close()

    def classify_image(self, image_path: str) -> str:
        """Classify image using GPT-5 Vision."""
        try:
            # Convert image to base64
            with open(image_path, "rb") as image_file:
                image_b64 = base64.b64encode(image_file.read()).decode()

            messages = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": "Classify this image in 1-2 words (e.g., 'machine guard', 'press brake', 'conveyor belt', 'safety sign')."},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}", "detail": "low"}}
                ]
            }]

            # For GPT-5 vision, temperature must be default (1.0)
            response = self.openai_client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=messages,
                max_completion_tokens=50
            )

            return response.choices[0].message.content.strip()

        except Exception as e:
            logger.error(f"Error classifying image {image_path}: {e}")
            return "unknown"

    def should_filter_image(self, image_path: str) -> tuple[bool, str]:
        """
        Check if image should be filtered out based on height and black image criteria.

        Args:
            image_path: Path to the image file

        Returns:
            Tuple of (should_filter: bool, reason: str)
        """
        try:
            from PIL import Image
            import numpy as np

            # Open and analyze the image
            with Image.open(image_path) as img:
                # Convert to RGB if needed
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                width, height = img.size

                # Filter 1: Height less than 40 pixels
                if height < 40:
                    return True, f"height too small ({height}px)"

                # Filter 2: Check if image is mostly black
                img_array = np.array(img)
                mean_brightness = np.mean(img_array)

                # If mean brightness is very low (mostly black)
                if mean_brightness < 10:  # Adjust threshold as needed
                    return True, "mostly black image"

        except Exception as e:
            logger.warning(f"Error analyzing image {image_path}: {e}")
            # If we can't analyze it, don't filter it out
            return False, "analysis failed"

        return False, "passed all filters"

    def store_image_metadata(self, image_data: ImageData):
        """Store image metadata in database."""
        conn = sqlite3.connect(IMAGES_DB)
        cursor = conn.cursor()

        # Serialize metadata and embedding
        metadata_json = json.dumps(image_data.metadata) if image_data.metadata else None
        embedding_blob = image_data.embedding.tobytes() if image_data.embedding is not None else None

        cursor.execute('''
            INSERT OR REPLACE INTO images
            (image_id, image_path, classification, metadata, embedding)
            VALUES (?, ?, ?, ?, ?)
        ''', (image_data.image_id, image_data.image_path,
              image_data.classification, metadata_json, embedding_blob))

        conn.commit()
        conn.close()

    def get_image_metadata(self, image_id: str) -> Optional[ImageData]:
        """Retrieve image metadata from database."""
        conn = sqlite3.connect(IMAGES_DB)
        cursor = conn.cursor()

        cursor.execute('''
            SELECT image_id, image_path, classification, metadata, embedding
            FROM images WHERE image_id = ?
        ''', (image_id,))

        row = cursor.fetchone()
        conn.close()

        if row:
            image_id, image_path, classification, metadata_json, embedding_blob = row

            metadata = json.loads(metadata_json) if metadata_json else None
            embedding = np.frombuffer(embedding_blob, dtype=np.float32) if embedding_blob else None

            return ImageData(
                image_path=image_path,
                image_id=image_id,
                classification=classification,
                embedding=embedding,
                metadata=metadata
            )

        return None

def load_text_documents() -> List[Dict[str, Any]]:
    """Convenience function to load all text documents."""
    loader = DocumentLoader()
    return loader.load_text_documents()

def embed_image_clip(image_paths: List[str]) -> np.ndarray:
    """Convenience function to embed images with CLIP."""
    generator = EmbeddingGenerator()
    return generator.embed_image_clip(image_paths)

def store_image_metadata(image_data: ImageData):
    """Convenience function to store image metadata."""
    processor = ImageProcessor()
    processor.store_image_metadata(image_data)

def classify_image(image_path: str) -> str:
    """Convenience function to classify an image."""
    processor = ImageProcessor()
    return processor.classify_image(image_path)