from collections.abc import Callable

import numpy as np

import src.datasets as datasets
import src.indexes as indexes
import src.models as models
def load_or_create_index(tokenizer: models.Tokenizer, model: models.Model, dataset: list[dict]) -> indexes.Index:
    """Return the persisted index if one exists, otherwise build one from *dataset*.

    Args:
        tokenizer: Tokenizer used when embedding the dataset for a fresh index.
        model: Embedding model used when building a fresh index.
        dataset: Records to embed if no saved index is found.

    Returns:
        The loaded or newly created index.
    """
    if not indexes.index_exists():
        # No saved index on disk: embed the whole dataset and build one now.
        dataset_embeddings = embed_dataset(dataset, tokenizer, model)
        return indexes.create_index(dataset_embeddings)
    # Reuse the previously persisted index.
    return indexes.load_index()
def embed_dataset(
    dataset: list[dict],
    tokenizer: models.Tokenizer,
    model: models.Model,
    # NOTE(review): `map` shadows the builtin; name kept so keyword callers keep working.
    map: Callable | None = None,
    pooling_method: str = models.DEFAULT_POOLING_METHOD
) -> np.ndarray:
    """Embed every record of *dataset* and return the embedding matrix.

    (The previous docstring claimed this builds a FAISS index; it only
    produces the embeddings that ``load_or_create_index`` feeds to the index.)

    Args:
        dataset: Records to embed.
        tokenizer: Tokenizer used by the embedding model.
        model: Model that produces the embeddings.
        map: Optional transform forwarded to
            ``datasets.build_text_representations`` — presumably maps a record
            to its text form; confirm against that helper.
        pooling_method: Pooling strategy forwarded to ``models.embed_texts``.

    Returns:
        The embeddings array produced by ``models.embed_texts``.
    """
    texts = datasets.build_text_representations(dataset, map)
    return models.embed_texts(texts, tokenizer, model, pooling_method)
def inference(
    inputs: list[str],
    index,
    tokenizer: models.Tokenizer,
    model: models.Model,
    dataset: list[dict],
    pooling_method: str = models.DEFAULT_POOLING_METHOD,
    k: int = 1
) -> list[dict]:
    """Embed *inputs*, query *index* for the k nearest entries, and return structured results.

    Args:
        inputs: Query strings to embed and search for.
        index: The search index queried via ``indexes.find_closest``.
        tokenizer: Tokenizer used by the embedding model.
        model: Model that embeds the queries.
        dataset: Records the index positions refer back to.
        pooling_method: Pooling strategy forwarded to ``models.embed_texts``.
        k: Number of nearest matches to return per input.

    Returns:
        One result dict per input, as assembled by ``build_inference_results``.
    """
    query_vectors = models.embed_texts(inputs, tokenizer, model, pooling_method)
    dists, idxs = indexes.find_closest(query_vectors, index, k=k)
    return build_inference_results(inputs, query_vectors, dataset, dists, idxs)
def build_inference_results(
    inputs: list[str],
    embeddings: np.ndarray,
    dataset: list[dict],
    distances: np.ndarray,
    indices: np.ndarray,
) -> list[dict]:
    """Assemble per-input result dicts from nearest-neighbour search output.

    Args:
        inputs: The original query strings, aligned row-for-row with *distances*.
        embeddings: Query embeddings, one row per input.
        dataset: Records searched over; each entry must carry a "proverb" key.
        distances: Per-input distances of each match, shape (num_inputs, k).
        indices: Per-input positions into *dataset*, same shape as *distances*.

    Returns:
        One dict per input with keys "input", "embedding", and "matches";
        "matches" is a ranked list of {"rank", "proverb", "distance"}.
    """
    results = []
    # zip keeps the four parallel sequences aligned — replaces the former
    # `range(len(distances))` index loop and the redundant list copy of indices[i].
    for text, embedding, row_dists, row_idxs in zip(inputs, embeddings, distances, indices):
        matches = [
            {
                "rank": rank,
                "proverb": dataset[idx]["proverb"],
                # cast so the result is a plain Python float, not a numpy scalar
                "distance": float(dist),
            }
            # enumerate(..., start=1) replaces the j + 1 rank arithmetic
            for rank, (idx, dist) in enumerate(zip(row_idxs, row_dists), start=1)
        ]
        results.append({"input": text, "embedding": embedding, "matches": matches})
    return results