pauhmolins's picture
Fixed import
5d6e1ce
import src.datasets as datasets
import src.models as models
import src.indexes as indexes
import numpy as np
def load_or_create_index(
    tokenizer: models.Tokenizer,
    model: models.Model,
    dataset: list[dict],
) -> indexes.Index:
    """Return an index, loading it if one exists and building it otherwise.

    When ``indexes.index_exists()`` reports an index, it is loaded and
    returned as-is; ``tokenizer``, ``model`` and ``dataset`` are not used in
    that case. Otherwise ``dataset`` is embedded via ``embed_dataset`` (with
    its default ``map`` and pooling method) and the embeddings are handed to
    ``indexes.create_index``.
    """
    # Fast path: an index is already available, so no embedding is needed.
    if indexes.index_exists():
        return indexes.load_index()

    # No index yet: embed the dataset with default settings and build one.
    return indexes.create_index(embed_dataset(dataset, tokenizer, model))
def embed_dataset(
    dataset: list[dict],
    tokenizer: models.Tokenizer,
    model: models.Model,
    map: callable = None,
    pooling_method: str = models.DEFAULT_POOLING_METHOD
) -> np.ndarray:
    """Embed the dataset and return the resulting embeddings.

    The dataset is first turned into texts by
    ``datasets.build_text_representations`` (``map`` is forwarded to it
    unchanged), and those texts are then encoded by ``models.embed_texts``
    using ``tokenizer``, ``model`` and ``pooling_method``.

    Note that this function only produces embeddings; building an index from
    them is left to the caller.
    """
    return models.embed_texts(
        datasets.build_text_representations(dataset, map),
        tokenizer,
        model,
        pooling_method,
    )
def inference(
    inputs: list[str],
    index,
    tokenizer: models.Tokenizer,
    model: models.Model,
    dataset: list[dict],
    pooling_method: str = models.DEFAULT_POOLING_METHOD,
    k: int = 1
) -> list[dict]:
    """Embed the input texts and report their closest entries in the index.

    The inputs are encoded with ``models.embed_texts``, the ``k`` closest
    entries per input are looked up through ``indexes.find_closest``, and the
    outcome is formatted by ``build_inference_results``, whose return value is
    passed straight back to the caller.
    """
    query_embeddings = models.embed_texts(inputs, tokenizer, model, pooling_method)

    # find_closest yields a (distances, indices) pair for the queries.
    neighbour_distances, neighbour_indices = indexes.find_closest(
        query_embeddings, index, k=k
    )

    return build_inference_results(
        inputs=inputs,
        embeddings=query_embeddings,
        dataset=dataset,
        distances=neighbour_distances,
        indices=neighbour_indices,
    )
def build_inference_results(
    inputs: list[str],
    embeddings: np.ndarray,
    dataset: list[dict],
    distances: np.ndarray,
    indices: np.ndarray,
) -> list[dict]:
    """Build the inference results from the distances and indices.

    Args:
        inputs: The query texts, one per row of ``distances``.
        embeddings: The query embeddings; ``embeddings[i]`` is stored
            unchanged in the i-th result.
        dataset: Entries looked up by neighbour index; each matched entry
            must provide a ``"proverb"`` key.
        distances: Per-query sequence of neighbour distances.
        indices: Per-query sequence of neighbour positions into ``dataset``,
            aligned element-wise with ``distances``.

    Returns:
        One dict per row of ``distances`` with the keys ``"input"``,
        ``"embedding"`` and ``"matches"``. ``"matches"`` is a list of
        ``{"rank", "proverb", "distance"}`` dicts where ``rank`` is the
        1-based position of the neighbour in its row and ``distance`` is
        converted to a plain ``float``.

    Raises:
        IndexError: If ``inputs`` or ``embeddings`` has fewer entries than
            ``distances``, or a neighbour index lies beyond ``dataset``.
        KeyError: If a matched dataset entry has no ``"proverb"`` key.
    """
    results = []
    for i, (row_distances, row_indices) in enumerate(zip(distances, indices)):
        matches = [
            {
                "rank": rank,
                "proverb": dataset[idx]["proverb"],
                "distance": float(distance),
            }
            for rank, (idx, distance) in enumerate(
                zip(row_indices, row_distances), start=1
            )
            # A negative index would silently wrap around and report an
            # unrelated entry from the end of the dataset, so drop it.
            # NOTE(review): FAISS-style searches use -1 as the "no neighbour"
            # filler; presumably indexes.find_closest passes it through --
            # confirm against that helper.
            if idx >= 0
        ]
        results.append(
            {
                "input": inputs[i],
                "embedding": embeddings[i],
                "matches": matches,
            }
        )
    return results