|
|
import os |
|
|
import faiss |
|
|
import numpy as np |
|
|
from src.customlogger import log_time, logger |
|
|
|
|
|
|
|
|
# Annotation alias for the index objects created, loaded and searched below.
Index = faiss.IndexFlat

# Default location of the serialized index; a relative path, so it is
# resolved against the current working directory.
INDEX_FILE = "proverbs.index"

# Exact (flat) FAISS index classes available for `index_type`:
# L2 distance first, inner product second.
INDEX_TYPES = [faiss.IndexFlatL2, faiss.IndexFlatIP]

# The first entry of INDEX_TYPES (faiss.IndexFlatL2) is the default.
DEFAULT_INDEX_TYPE = INDEX_TYPES[0]
|
|
|
|
|
|
|
|
def index_exists(index_file: str = INDEX_FILE) -> bool:
    """Report whether something exists at the index path.

    Args:
        index_file: Path to check; defaults to ``INDEX_FILE``.

    Returns:
        ``True`` if ``os.path.exists`` finds the path, otherwise ``False``.
    """
    found = os.path.exists(index_file)
    return found
|
|
|
|
|
|
|
|
@log_time
def create_index(embeddings: np.ndarray, index_type: type = None, index_file: str = INDEX_FILE) -> Index:
    """Create a FAISS index, add ``embeddings`` to it and save it to ``index_file``.

    Args:
        embeddings: Vectors to index, one per row, as a 2-D array(-like) of
            shape ``(n, d)``. Converted to a C-contiguous ``float32`` array
            before being handed to FAISS.
        index_type: Index class to instantiate; it is called with the
            embedding dimension ``d``. ``None`` (the default) selects
            ``DEFAULT_INDEX_TYPE``.
        index_file: Destination path passed to ``faiss.write_index``.

    Returns:
        The newly created index containing all ``n`` vectors.

    Raises:
        ValueError: If ``embeddings`` is not two-dimensional.
    """
    if index_type is None:
        index_type = DEFAULT_INDEX_TYPE

    # FAISS flat indexes operate on contiguous float32 rows. Converting up
    # front avoids dtype/layout errors from older Python bindings; no copy
    # is made when the input is already float32 and C-contiguous.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    if embeddings.ndim != 2:
        raise ValueError(
            f"Expected a 2-D array of embeddings with shape (n, d), got shape {embeddings.shape}")

    count, dimension = embeddings.shape
    logger.debug(
        f"Creating FAISS index with {count} {dimension}-dimensional embeddings...")
    index = index_type(dimension)
    index.add(embeddings)

    logger.debug(f"Saving FAISS index to '{index_file}'...")
    faiss.write_index(index, index_file)

    return index
|
|
|
|
|
|
|
|
@log_time
def load_index(index_file: str = INDEX_FILE) -> Index:
    """Read a serialized FAISS index back from disk.

    Args:
        index_file: Path previously written by ``faiss.write_index``;
            defaults to ``INDEX_FILE``.

    Returns:
        The index object produced by ``faiss.read_index``.

    Errors raised by ``faiss.read_index`` are not caught here; they reach
    the caller unchanged.
    """
    logger.debug(f"Loading FAISS index from '{index_file}'...")
    return faiss.read_index(index_file)
|
|
|
|
|
|
|
|
@log_time
def find_closest(embeddings: np.ndarray, index: Index, k: int = 5) -> tuple[np.ndarray, np.ndarray]:
    """Find the ``k`` closest stored vectors in ``index`` for each query.

    Args:
        embeddings: Query vectors, either a single vector of shape ``(d,)``
            or a batch of shape ``(n, d)``. A single vector is searched as a
            batch of one. Queries are converted to a C-contiguous
            ``float32`` array before searching.
        index: Index to search.
        k: Number of neighbours requested per query.

    Returns:
        ``(distances, indices)`` exactly as returned by ``index.search``.
        FAISS documents both as ``(n, k)`` arrays; a single ``(d,)`` query
        therefore yields arrays with one row.
    """
    # Promote a lone (d,) vector to (1, d): index.search expects a 2-D
    # batch, and len() of a 1-D vector would report d rather than 1 query.
    queries = np.ascontiguousarray(np.atleast_2d(embeddings), dtype=np.float32)
    n_queries = len(queries)

    noun = "embedding" if n_queries == 1 else "embeddings"
    logger.debug(
        f"Performing search for the top {k} matches of {n_queries} {noun}...")
    distances, indices = index.search(queries, k)

    return distances, indices
|
|
|