|
|
import src.datasets as datasets |
|
|
import src.models as models |
|
|
import src.indexes as indexes |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def load_or_create_index(tokenizer: models.Tokenizer, model: models.Model, dataset: list[dict]) -> indexes.Index:
    """Load an existing index or create a new one with default values if it doesn't exist.

    Args:
        tokenizer: Tokenizer used to embed the dataset when a new index must be built.
        model: Model used to embed the dataset when a new index must be built.
        dataset: Records to embed and index when no stored index is found.

    Returns:
        The loaded (or freshly created) index.
    """
    if not indexes.index_exists():
        # No stored index yet: embed the whole dataset and index the vectors.
        vectors = embed_dataset(dataset, tokenizer, model)
        return indexes.create_index(vectors)
    return indexes.load_index()
|
|
|
|
|
|
|
|
def embed_dataset(
    dataset: list[dict],
    tokenizer: models.Tokenizer,
    model: models.Model,
    map: callable = None,  # NOTE(review): shadows the builtin `map`; name kept for backward compatibility with keyword callers
    pooling_method: str = models.DEFAULT_POOLING_METHOD
) -> np.ndarray:
    """Embed the dataset and return the resulting embedding matrix.

    Note: this function does NOT build an index (the previous docstring
    claimed it created a FAISS index) — it only computes embeddings.
    Index creation happens in the caller (see `load_or_create_index`).

    Args:
        dataset: Records to embed.
        tokenizer: Tokenizer passed through to the embedding model.
        model: Model that produces the embeddings.
        map: Optional callable forwarded to
            `datasets.build_text_representations` to turn each record into
            its text representation.
        pooling_method: Pooling strategy used to reduce token embeddings.

    Returns:
        Array with one embedding vector per dataset record.
    """
    texts = datasets.build_text_representations(dataset, map)
    return models.embed_texts(texts, tokenizer, model, pooling_method)
|
|
|
|
|
|
|
|
def inference(
    inputs: list[str],
    index: indexes.Index,  # annotation added for consistency with load_or_create_index's return type
    tokenizer: models.Tokenizer,
    model: models.Model,
    dataset: list[dict],
    pooling_method: str = models.DEFAULT_POOLING_METHOD,
    k: int = 1
) -> list[dict]:
    """Perform inference on the input texts and return top matches from the index.

    Args:
        inputs: Texts to embed and look up.
        index: Index to search (as produced by `load_or_create_index`).
        tokenizer: Tokenizer passed through to the embedding model.
        model: Model that produces the query embeddings.
        dataset: Records the index was built from; used to resolve matches.
        pooling_method: Pooling strategy used to reduce token embeddings.
        k: Number of nearest neighbors to return per input.

    Returns:
        One result dict per input, each with the input text, its embedding,
        and the ranked list of matches (see `build_inference_results`).
    """
    embeddings = models.embed_texts(inputs, tokenizer, model, pooling_method)
    distances, indices = indexes.find_closest(embeddings, index, k=k)
    return build_inference_results(inputs, embeddings, dataset, distances, indices)
|
|
|
|
|
|
|
|
def build_inference_results(inputs: list[str], embeddings: np.ndarray, dataset: list[dict], distances: np.ndarray, indices: np.ndarray) -> list[dict]:
    """Build the inference results from the distances and indices.

    Args:
        inputs: Query texts, one per row of `distances`/`indices`.
        embeddings: Query embeddings, aligned with `inputs`.
        dataset: Records the index was built from; each row of `indices`
            holds positions into this list. Assumes each record has a
            "proverb" key.
        distances: Per-query distances to the k nearest neighbors.
        indices: Per-query dataset positions of the k nearest neighbors.

    Returns:
        One dict per input with keys "input", "embedding", and "matches",
        where "matches" is a rank-ordered list of
        {"rank", "proverb", "distance"} dicts.
    """
    results = []
    # zip the parallel per-query rows instead of indexing by range(len(...)).
    for text, embedding, dist_row, idx_row in zip(inputs, embeddings, distances, indices):
        matches = [
            {
                "rank": rank,
                "proverb": dataset[idx]["proverb"],
                # Cast numpy scalar to a plain float for JSON-friendly output.
                "distance": float(dist),
            }
            for rank, (idx, dist) in enumerate(zip(idx_row, dist_row), start=1)
        ]
        results.append({
            "input": text,
            "embedding": embedding,
            "matches": matches,
        })
    return results
|
|
|