"""
Dense Passage Retrieval (DPR) query module.

Uses bi-encoder for retrieval and cross-encoder for re-ranking.
"""
import pickle
import logging
from typing import List, Tuple, Optional

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
from openai import OpenAI

from config import *

logger = logging.getLogger(__name__)

class DPRRetriever:
    """Dense Passage Retrieval with cross-encoder re-ranking."""

    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.bi_encoder = None
        self.cross_encoder = None
        self.index = None
        self.metadata = None
        self._load_models()
        self._load_index()

    def _load_models(self):
        """Load bi-encoder and cross-encoder models."""
        try:
            logger.info("Loading DPR models...")
            # Pass the target device at construction; this covers both CPU and CUDA
            # and avoids calling .to() on the CrossEncoder wrapper directly.
            self.bi_encoder = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL, device=DEVICE)
            self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, device=DEVICE)
            logger.info("DPR models loaded successfully")
        except Exception as e:
            logger.error(f"Error loading DPR models: {e}")
            raise

    def _load_index(self):
        """Load FAISS index and metadata."""
        try:
            if DPR_FAISS_INDEX.exists() and DPR_METADATA.exists():
                logger.info("Loading DPR index and metadata...")
                # Load FAISS index
                self.index = faiss.read_index(str(DPR_FAISS_INDEX))
                # Load metadata
                with open(DPR_METADATA, 'rb') as f:
                    self.metadata = pickle.load(f)
                logger.info(f"Loaded DPR index with {len(self.metadata)} chunks")
            else:
                logger.warning("DPR index not found. Run preprocess.py first.")
        except Exception as e:
            logger.error(f"Error loading DPR index: {e}")
            raise

    def retrieve_candidates(self, question: str, top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Retrieve candidate passages using bi-encoder."""
        if self.index is None or self.metadata is None:
            raise ValueError("DPR index not loaded. Run preprocess.py first.")
        try:
            # Encode question with bi-encoder
            question_embedding = self.bi_encoder.encode([question], convert_to_numpy=True)
            # Normalize for cosine similarity
            faiss.normalize_L2(question_embedding)

            # Search FAISS index; retrieve more candidates than top_k for re-ranking
            retrieve_k = min(top_k * RERANK_MULTIPLIER, len(self.metadata))
            scores, indices = self.index.search(question_embedding, retrieve_k)

            # Prepare candidates (FAISS pads missing results with index -1)
            candidates = []
            for score, idx in zip(scores[0], indices[0]):
                if 0 <= idx < len(self.metadata):
                    chunk_data = self.metadata[idx]
                    candidates.append((
                        chunk_data['text'],
                        float(score),
                        chunk_data['metadata']
                    ))
            logger.info(f"Retrieved {len(candidates)} candidates for re-ranking")
            return candidates
        except Exception as e:
            logger.error(f"Error in candidate retrieval: {e}")
            raise

    def rerank_candidates(self, question: str, candidates: List[Tuple[str, float, dict]],
                          top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Re-rank candidates using cross-encoder."""
        if not candidates:
            return []
        try:
            # Prepare pairs for cross-encoder
            pairs = [(question, candidate[0]) for candidate in candidates]
            # Get cross-encoder scores
            cross_scores = self.cross_encoder.predict(pairs)

            # Combine with candidate data, filtering by minimum relevance score
            reranked = []
            for i, (text, bi_score, metadata) in enumerate(candidates):
                cross_score = float(cross_scores[i])
                if cross_score >= MIN_RELEVANCE_SCORE:
                    reranked.append((text, cross_score, metadata))

            # Sort by cross-encoder score (descending) and return top-k
            reranked.sort(key=lambda x: x[1], reverse=True)
            final_results = reranked[:top_k]
            logger.info(f"Re-ranked to {len(final_results)} final results")
            return final_results
        except Exception as e:
            logger.error(f"Error in re-ranking: {e}")
            # Fall back to bi-encoder results
            return candidates[:top_k]

    def generate_answer(self, question: str, context_chunks: List[Tuple[str, float, dict]]) -> str:
        """Generate answer using GPT with retrieved context."""
        if not context_chunks:
            return "I couldn't find relevant information to answer your question."
        try:
            # Prepare context
            context_parts = []
            for i, (text, score, metadata) in enumerate(context_chunks, 1):
                source = metadata.get('source', 'Unknown')
                context_parts.append(f"[Context {i}] Source: {source}\n{text}")
            context = "\n\n".join(context_parts)

            # Create system message
            system_message = (
                "You are a helpful assistant specialized in occupational safety and health. "
                "Answer questions based only on the provided context. "
                "If the context doesn't contain enough information, say so clearly. "
                "Always cite the source when referencing information."
            )
            # Create user message
            user_message = f"Context:\n{context}\n\nQuestion: {question}"

            # Generate response; for GPT-5, temperature must be default (1.0)
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return "I apologize, but I encountered an error while generating the answer."

# Global retriever instance
_retriever: Optional[DPRRetriever] = None


def get_retriever() -> DPRRetriever:
    """Get or create global DPR retriever instance."""
    global _retriever
    if _retriever is None:
        _retriever = DPRRetriever()
    return _retriever

def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
    """
    Main DPR query function with unified signature.

    Args:
        question: User question
        image_path: Optional image path (not used in DPR but kept for consistency)
        top_k: Number of top results to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    try:
        retriever = get_retriever()

        # Step 1: Retrieve candidates with bi-encoder
        candidates = retriever.retrieve_candidates(question, top_k)
        if not candidates:
            return "I couldn't find any relevant information for your question.", []

        # Step 2: Re-rank with cross-encoder
        reranked_candidates = retriever.rerank_candidates(question, candidates, top_k)

        # Step 3: Generate answer
        answer = retriever.generate_answer(question, reranked_candidates)

        # Step 4: Prepare citations
        citations = []
        for i, (text, score, metadata) in enumerate(reranked_candidates, 1):
            citations.append({
                'rank': i,
                'text': text,
                'score': float(score),
                'source': metadata.get('source', 'Unknown'),
                'type': metadata.get('type', 'unknown'),
                'method': 'dpr'
            })
        logger.info(f"DPR query completed. Retrieved {len(citations)} citations.")
        return answer, citations
    except Exception as e:
        logger.error(f"Error in DPR query: {e}")
        error_message = "I apologize, but I encountered an error while processing your question with DPR."
        return error_message, []

def query_with_details(question: str, image_path: Optional[str] = None,
                       top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict], List[Tuple]]:
    """
    DPR query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)
    # Convert citations to chunk format for backward compatibility
    chunks = []
    for citation in citations:
        chunks.append((
            f"Rank {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            citation['text'],
            citation['source']
        ))
    return answer, citations, chunks

if __name__ == "__main__":
    # Test the DPR system
    test_question = "What are the general requirements for machine guarding?"
    print("Testing DPR retrieval system...")
    print(f"Question: {test_question}")
    print("-" * 50)
    try:
        answer, citations = query(test_question)
        print("Answer:")
        print(answer)
        print("\nCitations:")
        for citation in citations:
            print(f"- {citation['source']} (Score: {citation['score']:.3f})")
    except Exception as e:
        print(f"Error during testing: {e}")