Spaces:
Runtime error
Runtime error
| import os | |
| from typing import List, Dict, Union, Optional, Any | |
| import numpy as np | |
| from .embedding_provider import EmbeddingProvider | |
| from .database.annoydb import AnnoyDB | |
| from .keyword_search_provider import KeywordSearchProvider | |
class HybridSearch:
    """Combine embedding-based (semantic) search with keyword search.

    On construction every document is embedded and added to an ``AnnoyDB``
    vector index, and a ``KeywordSearchProvider`` is built over the same
    documents. ``hybrid_search`` queries both back ends and merges their
    results using a weighted sum of the per-source scores.

    Attributes set by ``__init__``:
        embedding_provider: Object used to embed documents and queries.
        documents: The documents passed to the constructor.
        embeddings: Result of ``embedding_provider.embed_documents(documents)``.
        vector_db: ``AnnoyDB`` instance holding one entry per document.
        keyword_search: ``KeywordSearchProvider`` built from ``documents``.
        semantic_weight: Multiplier applied to semantic scores.
        keyword_weight: Multiplier applied to keyword scores.
    """

    def __init__(
        self,
        embedding_provider: "EmbeddingProvider",
        documents: Optional[List[str]] = None,
        ann_filepath: Optional[str] = None,
        semantic_weight: float = 0.7,
        keyword_weight: float = 0.3,
    ) -> None:
        """Build the vector index and the keyword index over ``documents``.

        Args:
            embedding_provider: Must provide ``embed_documents`` (whose result
                exposes ``.shape`` and iterates one embedding per document)
                and ``embed_query``.
            documents: Texts to index. Required in practice; the ``None``
                default is kept only for signature compatibility.
            ann_filepath: Optional path to an existing index file (see the
                review note in the body).
            semantic_weight: Weight applied to semantic scores.
            keyword_weight: Weight applied to keyword scores.

        Raises:
            ValueError: If ``documents`` is ``None``.

        Note:
            Unlike ``set_weights``, the constructor does not validate the
            weights, so existing callers passing non-normalised weights keep
            working.
        """
        # Fail fast with a clear message instead of an obscure error from
        # deep inside the embedding provider or ``zip``.
        if documents is None:
            raise ValueError(
                "documents must be provided to build the search indexes"
            )

        self.embedding_provider = embedding_provider
        self.documents = documents

        # Weights for hybrid search
        self.semantic_weight = semantic_weight
        self.keyword_weight = keyword_weight

        if ann_filepath and os.path.exists(ann_filepath):
            # NOTE(review): this stores the AnnoyDB *class*, not a loaded
            # index, and nothing in this class reads ``self.index``. It looks
            # like an unfinished "load index from disk" path - confirm against
            # AnnoyDB's API before relying on it. Kept as-is for compatibility.
            self.index = AnnoyDB

        # Semantic search setup: embed every document and index it.
        self.embeddings = self.embedding_provider.embed_documents(documents)
        self.vector_db = AnnoyDB(embedding_dim=self.embeddings.shape[1])
        for embedding, document in zip(self.embeddings, documents):
            self.vector_db.add_data(embedding, document)
        self.vector_db.build()

        # Keyword search setup
        self.keyword_search = KeywordSearchProvider(documents)

    def hybrid_search(
        self, query: str, top_k: int = 5
    ) -> List[Dict[str, Union[str, float]]]:
        """Run semantic and keyword search and merge the weighted results.

        Each back end is asked for ``top_k`` results; every result is expected
        to be a mapping with ``'document'`` and ``'score'`` keys. Scores are
        multiplied by the matching weight and summed per document.

        Args:
            query: The search query text.
            top_k: Maximum number of results to return.

        Returns:
            At most ``top_k`` dicts with the keys ``'document'``,
            ``'semantic_score'``, ``'keyword_score'`` and ``'hybrid_score'``,
            sorted by ``'hybrid_score'`` in descending order. An empty list is
            returned when ``top_k`` is not positive.
        """
        # A non-positive top_k cannot yield results; avoid hitting the back
        # ends and avoid the surprising behaviour of a negative slice bound.
        if top_k <= 0:
            return []

        query_embedding = self.embedding_provider.embed_query(query)
        semantic_results = self.vector_db.search(query_embedding, top_k)
        keyword_results = self.keyword_search.search(query, top_k)

        # NOTE(review): ranking assumes a higher 'score' is better for both
        # back ends - confirm AnnoyDB.search does not return raw distances.
        weighted_sources = (
            ("semantic_score", semantic_results, self.semantic_weight),
            ("keyword_score", keyword_results, self.keyword_weight),
        )

        combined: Dict[str, Dict[str, float]] = {}
        for score_key, results, weight in weighted_sources:
            for result in results:
                entry = combined.setdefault(
                    result["document"],
                    {"semantic_score": 0.0, "keyword_score": 0.0},
                )
                entry[score_key] = result["score"] * weight

        # hybrid_score is derived from the two components so the three values
        # reported for a document are always consistent with each other.
        ranked = sorted(
            (
                {
                    "document": document,
                    **scores,
                    "hybrid_score": (
                        scores["semantic_score"] + scores["keyword_score"]
                    ),
                }
                for document, scores in combined.items()
            ),
            key=lambda item: item["hybrid_score"],
            reverse=True,
        )
        return ranked[:top_k]

    def set_weights(self, semantic_weight: float, keyword_weight: float) -> None:
        """Dynamically update search weights.

        Args:
            semantic_weight: New weight for semantic search.
            keyword_weight: New weight for keyword search.

        Raises:
            ValueError: If either weight is outside ``[0, 1]`` or the two
                weights do not sum to 1.0 (compared with ``np.isclose``).
        """
        if not (0 <= semantic_weight <= 1 and 0 <= keyword_weight <= 1):
            raise ValueError("Weights must be between 0 and 1")
        if not np.isclose(semantic_weight + keyword_weight, 1.0):
            raise ValueError("Semantic and keyword weights must sum to 1.0")
        self.semantic_weight = semantic_weight
        self.keyword_weight = keyword_weight