from typing import Callable, Optional

import numpy as np

import src.datasets as datasets
import src.models as models
import src.indexes as indexes


def load_or_create_index(tokenizer: models.Tokenizer, model: models.Model, dataset: list[dict]) -> indexes.Index:
    """Load the saved index if one exists; otherwise embed the dataset and build a new index."""
    if indexes.index_exists():
        # Reuse the existing index
        index = indexes.load_index()
    else:
        # Embed every dataset entry with the default pooling method and index the vectors
        embeddings = embed_dataset(dataset, tokenizer, model)
        index = indexes.create_index(embeddings)
    return index


def embed_dataset(
    dataset: list[dict],
    tokenizer: models.Tokenizer,
    model: models.Model,
    map: Optional[Callable] = None,
    pooling_method: str = models.DEFAULT_POOLING_METHOD,
) -> np.ndarray:
    """Build a text representation of each dataset entry and embed it.

    ``map`` is forwarded unchanged to ``datasets.build_text_representations``.
    Returns the embeddings of the resulting texts as a NumPy array.
    """
    texts = datasets.build_text_representations(dataset, map)
    return models.embed_texts(texts, tokenizer, model, pooling_method)
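

# Sketch (hypothetical helper, not part of src.models): masked mean pooling, one
# common way to turn per-token model outputs into a single vector per text. It
# averages token embeddings while ignoring padding positions. Whether
# models.DEFAULT_POOLING_METHOD works this way is an assumption; check
# src.models for the pooling methods it actually supports.
import numpy as np


def _mean_pool_sketch(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average (batch, seq_len, dim) token embeddings over non-padding tokens."""
    # Expand the (batch, seq_len) mask so it broadcasts over the embedding dimension
    mask = attention_mask[..., np.newaxis].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    # Clamp the token count so an all-padding row does not divide by zero
    counts = np.clip(mask.sum(axis=1), 1e-9, None)
    return summed / counts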


def inference(
    inputs: list[str],
    index: indexes.Index,
    tokenizer: models.Tokenizer,
    model: models.Model,
    dataset: list[dict],
    pooling_method: str = models.DEFAULT_POOLING_METHOD,
    k: int = 1,
) -> list[dict]:
    """Embed each input text and return its k closest dataset entries from the index."""
    # Use the same pooling method the index was built with, otherwise query and
    # index vectors are not comparable
    embeddings = models.embed_texts(inputs, tokenizer, model, pooling_method)
    distances, indices = indexes.find_closest(embeddings, index, k=k)
    return build_inference_results(inputs, embeddings, dataset, distances, indices)
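

# Sketch (hypothetical helper, not part of src.indexes): an exact k-nearest-neighbour
# search in plain NumPy. It returns (distances, indices) with shape (n_queries, k),
# the same layout FAISS's IndexFlatL2.search produces (squared L2 distances,
# closest first). Whether indexes.find_closest uses L2 or inner-product distance
# is an assumption here. Unlike FAISS, which pads missing neighbours with -1,
# this sketch simply caps k at the corpus size.
import numpy as np


def _exact_knn_sketch(queries: np.ndarray, corpus: np.ndarray, k: int = 1) -> tuple[np.ndarray, np.ndarray]:
    """Return squared L2 distances and indices of the k closest corpus rows per query."""
    k = min(k, len(corpus))
    # ||q - c||^2 = ||q||^2 - 2 q.c + ||c||^2, computed for all pairs at once
    sq_dists = (
        (queries ** 2).sum(axis=1, keepdims=True)
        - 2.0 * queries @ corpus.T
        + (corpus ** 2).sum(axis=1)
    )
    # Floating-point cancellation can produce tiny negatives; clamp them to zero
    sq_dists = np.maximum(sq_dists, 0.0)
    # Sort each row and keep the k smallest distances
    indices = np.argsort(sq_dists, axis=1)[:, :k]
    distances = np.take_along_axis(sq_dists, indices, axis=1)
    return distances, indices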


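def build_inference_results(
    inputs: list[str],
    embeddings: np.ndarray,
    dataset: list[dict],
    distances: np.ndarray,
    indices: np.ndarray,
) -> list[dict]:
    """Pair each input with its ranked matches from the dataset.

    ``distances`` and ``indices`` have shape (len(inputs), k): row i holds the
    neighbours of inputs[i] in the order returned by the index (best match first).
    """
    results = []
    for i, (row_distances, row_indices) in enumerate(zip(distances, indices)):
        # One result per input: the query text, its embedding, and its ranked matches
        result = {
            "input": inputs[i],
            "embedding": embeddings[i],
            "matches": [
                {
                    "rank": rank,
                    "proverb": dataset[int(idx)]["proverb"],
                    "distance": float(dist),
                }
                for rank, (dist, idx) in enumerate(zip(row_distances, row_indices), start=1)
                # FAISS pads rows with -1 when fewer than k vectors are available;
                # skip those so they do not silently wrap around to dataset[-1]
                if idx >= 0
            ],
        }
        results.append(result)
    return results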
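

# Sketch (hypothetical helper, not part of src.indexes): load_or_create_index
# re-embeds the whole dataset whenever no saved index exists. One way to avoid
# repeating that work is to apply the same load-or-create pattern to the
# embedding array itself. The cache path and the `compute` callback are
# illustrative assumptions, e.g. compute=lambda: embed_dataset(dataset, tokenizer, model).
import os
from typing import Callable

import numpy as np


def _load_or_compute_embeddings_sketch(path: str, compute: Callable[[], np.ndarray]) -> np.ndarray:
    """Return the array cached at `path`, computing and saving it on a cache miss."""
    if os.path.exists(path):
        # Cache hit: reuse the previously saved array
        return np.load(path)
    embeddings = compute()
    # Write through a file object so np.save keeps `path` as-is instead of appending ".npy"
    with open(path, "wb") as f:
        np.save(f, embeddings)
    return embeddings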