from collections.abc import Callable

import numpy as np

import src.datasets as datasets
import src.indexes as indexes
import src.models as models


def load_or_create_index(
    tokenizer: models.Tokenizer,
    model: models.Model,
    dataset: list[dict],
) -> indexes.Index:
    """Load an existing index, or build a new one from *dataset* if none exists.

    Args:
        tokenizer: Tokenizer used to embed the dataset when a new index is built.
        model: Embedding model used when a new index is built.
        dataset: Records to embed; each dict must carry the fields consumed by
            ``datasets.build_text_representations``.

    Returns:
        The loaded or freshly created index.
    """
    if indexes.index_exists():
        return indexes.load_index()
    # No persisted index yet: embed the whole dataset and build one from scratch.
    embeddings = embed_dataset(dataset, tokenizer, model)
    return indexes.create_index(embeddings)


def embed_dataset(
    dataset: list[dict],
    tokenizer: models.Tokenizer,
    model: models.Model,
    # NOTE: `map` shadows the builtin and is kept only for caller compatibility.
    # `callable` was not a valid annotation; Callable | None is the correct type.
    map: Callable | None = None,
    pooling_method: str = models.DEFAULT_POOLING_METHOD,
) -> np.ndarray:
    """Embed the dataset's text representations and return the embedding matrix.

    (The previous docstring claimed this creates a FAISS index; it only computes
    embeddings — the index itself is built by the caller.)

    Args:
        dataset: Records to embed.
        tokenizer: Tokenizer forwarded to ``models.embed_texts``.
        model: Embedding model forwarded to ``models.embed_texts``.
        map: Optional transform applied while building text representations.
        pooling_method: Pooling strategy forwarded to ``models.embed_texts``.

    Returns:
        Array of embeddings, one row per dataset record.
    """
    texts = datasets.build_text_representations(dataset, map)
    return models.embed_texts(texts, tokenizer, model, pooling_method)


def inference(
    inputs: list[str],
    index: indexes.Index,
    tokenizer: models.Tokenizer,
    model: models.Model,
    dataset: list[dict],
    pooling_method: str = models.DEFAULT_POOLING_METHOD,
    k: int = 1,
) -> list[dict]:
    """Embed *inputs*, query *index*, and return the top-*k* matches per input.

    Args:
        inputs: Query texts.
        index: Index to search (previously built over *dataset*).
        tokenizer: Tokenizer forwarded to ``models.embed_texts``.
        model: Embedding model forwarded to ``models.embed_texts``.
        dataset: Records the index rows refer back to.
        pooling_method: Pooling strategy forwarded to ``models.embed_texts``.
        k: Number of nearest neighbours to return per input.

    Returns:
        One result dict per input (see ``build_inference_results``).
    """
    embeddings = models.embed_texts(inputs, tokenizer, model, pooling_method)
    distances, indices = indexes.find_closest(embeddings, index, k=k)
    return build_inference_results(inputs, embeddings, dataset, distances, indices)


def build_inference_results(
    inputs: list[str],
    embeddings: np.ndarray,
    dataset: list[dict],
    distances: np.ndarray,
    indices: np.ndarray,
) -> list[dict]:
    """Assemble per-input result dicts from the search distances and indices.

    Args:
        inputs: Query texts, one per row of *distances*/*indices*.
        embeddings: Query embeddings, one row per input.
        dataset: Records the index rows refer back to; each must have a
            ``"proverb"`` key.
        distances: Per-input neighbour distances, shape (len(inputs), k).
        indices: Per-input neighbour row indices into *dataset*, same shape.

    Returns:
        For each input: ``{"input", "embedding", "matches"}`` where ``matches``
        is a rank-ordered list of ``{"rank", "proverb", "distance"}`` dicts.
    """
    results = []
    for text, embedding, row_distances, row_indices in zip(
        inputs, embeddings, distances, indices
    ):
        # NOTE(review): FAISS-style searches return -1 indices when fewer than
        # k neighbours exist; dataset[-1] would then silently pick the last
        # record. Confirm upstream guarantees k <= index size.
        matches = [
            {
                "rank": rank,
                "proverb": dataset[idx]["proverb"],
                "distance": float(dist),
            }
            for rank, (idx, dist) in enumerate(
                zip(row_indices, row_distances), start=1
            )
        ]
        results.append(
            {"input": text, "embedding": embedding, "matches": matches}
        )
    return results