Spaces:
Paused
Paused
| """ | |
| BM25 keyword search with cross-encoder re-ranking and hybrid search support. | |
| """ | |
| import numpy as np | |
| import faiss | |
| from typing import Tuple, List, Optional | |
| from openai import OpenAI | |
| from sentence_transformers import CrossEncoder | |
| import config | |
| import utils | |
| # Initialize models | |
| client = OpenAI(api_key=config.OPENAI_API_KEY) | |
| cross_encoder = CrossEncoder(config.CROSS_ENCODER_MODEL) | |
| # Global variables for lazy loading | |
| _bm25_index = None | |
| _texts = None | |
| _metadata = None | |
| _semantic_index = None | |
| def _load_bm25_index(): | |
| """Lazy load BM25 index and metadata.""" | |
| global _bm25_index, _texts, _metadata, _semantic_index | |
| if _bm25_index is None: | |
| # Initialize defaults | |
| _texts = [] | |
| _metadata = [] | |
| _semantic_index = None | |
| try: | |
| import pickle | |
| if config.BM25_INDEX.exists(): | |
| with open(config.BM25_INDEX, 'rb') as f: | |
| bm25_data = pickle.load(f) | |
| if isinstance(bm25_data, dict): | |
| _bm25_index = bm25_data.get('index') or bm25_data.get('bm25') | |
| chunks = bm25_data.get('texts', []) | |
| if chunks: | |
| _texts = [chunk.text for chunk in chunks if hasattr(chunk, 'text')] | |
| _metadata = [chunk.metadata for chunk in chunks if hasattr(chunk, 'metadata')] | |
| else: | |
| _texts = [] | |
| _metadata = [] | |
| # Load semantic embeddings if available for hybrid search | |
| if 'embeddings' in bm25_data: | |
| semantic_embeddings = bm25_data['embeddings'] | |
| # Build FAISS index | |
| import faiss | |
| dimension = semantic_embeddings.shape[1] | |
| _semantic_index = faiss.IndexFlatIP(dimension) | |
| faiss.normalize_L2(semantic_embeddings) | |
| _semantic_index.add(semantic_embeddings) | |
| else: | |
| _bm25_index = bm25_data | |
| _texts = [] | |
| _metadata = [] | |
| print(f"Loaded BM25 index with {len(_texts)} documents") | |
| else: | |
| print("BM25 index not found. Run preprocess.py first.") | |
| except Exception as e: | |
| print(f"Error loading BM25 index: {e}") | |
| _bm25_index = None | |
| _texts = [] | |
| _metadata = [] | |
| def query(question: str, image_path: Optional[str] = None, top_k: int = None) -> Tuple[str, List[dict]]: | |
| """ | |
| Query using BM25 keyword search with re-ranking. | |
| Args: | |
| question: User's question | |
| image_path: Optional path to an image | |
| top_k: Number of relevant chunks to retrieve | |
| Returns: | |
| Tuple of (answer, citations) | |
| """ | |
| if top_k is None: | |
| top_k = config.DEFAULT_TOP_K | |
| # Load index if not already loaded | |
| _load_bm25_index() | |
| if _bm25_index is None or len(_texts) == 0: | |
| return "BM25 index not loaded. Please run preprocess.py first.", [] | |
| # Tokenize query for BM25 | |
| tokenized_query = question.lower().split() | |
| # Get BM25 scores | |
| bm25_scores = _bm25_index.get_scores(tokenized_query) | |
| # Get top candidates (retrieve more for re-ranking) | |
| top_indices = np.argsort(bm25_scores)[::-1][:top_k * config.RERANK_MULTIPLIER] | |
| # Prepare candidates for re-ranking | |
| candidates = [] | |
| for idx in top_indices: | |
| if idx < len(_texts) and bm25_scores[idx] > 0: | |
| candidates.append({ | |
| 'text': _texts[idx], | |
| 'bm25_score': bm25_scores[idx], | |
| 'metadata': _metadata[idx], | |
| 'idx': idx | |
| }) | |
| # Re-rank with cross-encoder | |
| if candidates: | |
| pairs = [[question, cand['text']] for cand in candidates] | |
| cross_scores = cross_encoder.predict(pairs) | |
| # Add cross-encoder scores and sort | |
| for i, score in enumerate(cross_scores): | |
| candidates[i]['cross_score'] = score | |
| candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k] | |
| # Collect citations | |
| citations = [] | |
| sources_seen = set() | |
| for chunk in candidates: | |
| chunk_meta = chunk['metadata'] | |
| if chunk_meta['source'] not in sources_seen: | |
| citation = { | |
| 'source': chunk_meta['source'], | |
| 'type': chunk_meta['type'], | |
| 'bm25_score': round(chunk['bm25_score'], 3), | |
| 'rerank_score': round(chunk['cross_score'], 3) | |
| } | |
| if chunk_meta['type'] == 'pdf': | |
| citation['path'] = chunk_meta['path'] | |
| else: | |
| citation['url'] = chunk_meta.get('url', '') | |
| citations.append(citation) | |
| sources_seen.add(chunk_meta['source']) | |
| # Handle image if provided | |
| image_context = "" | |
| if image_path: | |
| try: | |
| classification = utils.classify_image(image_path) | |
| # classification is a string, not a dict | |
| image_context = f"\n\n[Image Analysis: The image appears to show a {classification}.]" | |
| except Exception as e: | |
| print(f"Error processing image: {e}") | |
| # Build context from retrieved chunks | |
| context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates]) | |
| if not context: | |
| return "No relevant documents found for your query.", [] | |
| # Generate answer | |
| prompt = f"""Answer the following question using the retrieved documents: | |
| Retrieved Documents: | |
| {context}{image_context} | |
| Question: {question} | |
| Instructions: | |
| 1. Provide a comprehensive answer based on the retrieved documents | |
| 2. Mention specific details from the sources | |
| 3. If the documents don't fully answer the question, indicate what information is missing""" | |
| # For GPT-5, temperature must be default (1.0) | |
| response = client.chat.completions.create( | |
| model=config.OPENAI_CHAT_MODEL, | |
| messages=[ | |
| {"role": "system", "content": "You are a technical expert on manufacturing safety and regulations. Provide accurate, detailed answers based on the retrieved documents."}, | |
| {"role": "user", "content": prompt} | |
| ], | |
| max_completion_tokens=config.DEFAULT_MAX_TOKENS | |
| ) | |
| answer = response.choices[0].message.content | |
| return answer, citations | |
| def query_hybrid(question: str, top_k: int = None, alpha: float = None) -> Tuple[str, List[dict]]: | |
| """ | |
| Hybrid search combining BM25 and semantic search. | |
| Args: | |
| question: User's question | |
| top_k: Number of relevant chunks to retrieve | |
| alpha: Weight for BM25 scores (1-alpha for semantic) | |
| Returns: | |
| Tuple of (answer, citations) | |
| """ | |
| if top_k is None: | |
| top_k = config.DEFAULT_TOP_K | |
| if alpha is None: | |
| alpha = config.DEFAULT_HYBRID_ALPHA | |
| # Load index if not already loaded | |
| _load_bm25_index() | |
| if _bm25_index is None or _semantic_index is None: | |
| return "Hybrid search requires both BM25 and semantic indices. Please run preprocess.py with semantic embeddings.", [] | |
| # Get BM25 scores | |
| tokenized_query = question.lower().split() | |
| bm25_scores = _bm25_index.get_scores(tokenized_query) | |
| # Normalize BM25 scores | |
| if bm25_scores.max() > 0: | |
| bm25_scores = bm25_scores / bm25_scores.max() | |
| # Get semantic scores using FAISS | |
| embedding_generator = utils.EmbeddingGenerator() | |
| query_embedding = embedding_generator.embed_text_openai([question]).astype(np.float32) | |
| faiss.normalize_L2(query_embedding) | |
| # Search semantic index for all documents | |
| k_search = min(len(_texts), top_k * config.RERANK_MULTIPLIER) | |
| distances, indices = _semantic_index.search(query_embedding.reshape(1, -1), k_search) | |
| # Create semantic scores array | |
| semantic_scores = np.zeros(len(_texts)) | |
| for idx, dist in zip(indices[0], distances[0]): | |
| if idx < len(_texts): | |
| semantic_scores[idx] = dist | |
| # Combine scores | |
| hybrid_scores = alpha * bm25_scores + (1 - alpha) * semantic_scores | |
| # Get top candidates | |
| top_indices = np.argsort(hybrid_scores)[::-1][:top_k * config.RERANK_MULTIPLIER] | |
| # Prepare candidates | |
| candidates = [] | |
| for idx in top_indices: | |
| if idx < len(_texts) and hybrid_scores[idx] > 0: | |
| candidates.append({ | |
| 'text': _texts[idx], | |
| 'hybrid_score': hybrid_scores[idx], | |
| 'bm25_score': bm25_scores[idx], | |
| 'semantic_score': semantic_scores[idx], | |
| 'metadata': _metadata[idx], | |
| 'idx': idx | |
| }) | |
| # Re-rank with cross-encoder | |
| if candidates: | |
| pairs = [[question, cand['text']] for cand in candidates] | |
| cross_scores = cross_encoder.predict(pairs) | |
| for i, score in enumerate(cross_scores): | |
| candidates[i]['cross_score'] = score | |
| # Final ranking using cross-encoder scores | |
| candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k] | |
| # Collect citations | |
| citations = [] | |
| sources_seen = set() | |
| for chunk in candidates: | |
| chunk_meta = chunk['metadata'] | |
| if chunk_meta['source'] not in sources_seen: | |
| citation = { | |
| 'source': chunk_meta['source'], | |
| 'type': chunk_meta['type'], | |
| 'hybrid_score': round(chunk['hybrid_score'], 3), | |
| 'rerank_score': round(chunk.get('cross_score', 0), 3) | |
| } | |
| if chunk_meta['type'] == 'pdf': | |
| citation['path'] = chunk_meta['path'] | |
| else: | |
| citation['url'] = chunk_meta.get('url', '') | |
| citations.append(citation) | |
| sources_seen.add(chunk_meta['source']) | |
| # Build context | |
| context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates]) | |
| if not context: | |
| return "No relevant documents found for your query.", [] | |
| # Generate answer | |
| prompt = f"""Using the following retrieved passages, answer the question: | |
| {context} | |
| Question: {question} | |
| Provide a clear, detailed answer based on the information in the passages.""" | |
| # For GPT-5, temperature must be default (1.0) | |
| response = client.chat.completions.create( | |
| model=config.OPENAI_CHAT_MODEL, | |
| messages=[ | |
| {"role": "system", "content": "You are a safety expert. Answer questions accurately using the provided passages."}, | |
| {"role": "user", "content": prompt} | |
| ], | |
| max_completion_tokens=config.DEFAULT_MAX_TOKENS | |
| ) | |
| answer = response.choices[0].message.content | |
| return answer, citations | |
| if __name__ == "__main__": | |
| # Test BM25 query | |
| test_questions = [ | |
| "lockout tagout procedures", | |
| "machine guard requirements OSHA", | |
| "robot safety collaborative workspace" | |
| ] | |
| for q in test_questions: | |
| print(f"\nQuestion: {q}") | |
| answer, citations = query(q) | |
| print(f"Answer: {answer[:200]}...") | |
| print(f"Citations: {citations}") | |
| print("-" * 50) |