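"""Hyperparameter sweep for proverb retrieval.

For every combination of model, proverb field map, pooling method and index type,
the proverbs are embedded, inference is run on the test prompts, and the average
and variance of the distances between the resulting embeddings and the expected
proverb embeddings are written to a timestamped folder under ``tests_runs``.
"""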
import itertools
import json
import os
import shutil
import time

import faiss

import src.commons as commons
import src.datasets as datasets
import src.indexes as indexes
import src.models as models
from src.customlogger import log_time, logger

def build_field_selection_maps(fields: list[str]) -> dict:
    """Build a proverb selection function for every non-empty combination of fields."""
    combos = []
    for r in range(1, len(fields) + 1):
        combos.extend(itertools.combinations(fields, r))

    maps = {}
    for combo in combos:
        # Bind the current combo as a default argument so each lambda keeps its own
        # combination instead of the final value of the loop variable.
        maps["_".join(combo)] = (
            lambda proverb, combo=combo:
                [proverb[field] for field in combo if field != "themes"]
                + (proverb["themes"] if "themes" in combo else [])
        )
    return maps
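
# Illustration (the field names here are only examples):
# build_field_selection_maps(["proverb", "usage", "themes"]) yields keys such as
# "proverb", "proverb_usage", "proverb_usage_themes", ..., where each value maps a
# proverb dict to a list of the selected string fields, with the "themes" list (if
# selected) concatenated at the end rather than nested.
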
def setup():
    """Set up the environment by loading the model, tokenizer, and dataset."""
    tokenizer = models.load_tokenizer()
    model = models.load_model()

    proverbs = datasets.load_proverbs()
    prompts = datasets.load_prompts()

    # Reuse an existing train/test split if one has been saved, otherwise create one.
    if datasets.prompts_dataset_splits_exists():
        _, prompts_test_set = datasets.load_prompt_dataset_splits()
    else:
        _, prompts_test_set = datasets.split_dataset(prompts)

    return tokenizer, model, proverbs, prompts_test_set

@log_time
def test_distances(tokenizer: models.Tokenizer, model, model_name: str,
                   proverbs: list[dict], prompts_test_set: list[dict],
                   field_map: tuple[str, callable], index_type: type, pooling_method: str,
                   remarks: str = "") -> dict:
    """Test the distances between the actual and expected proverbs."""
    # Embed the proverbs with the selected fields and build the search index.
    embeddings = commons.embed_dataset(
        proverbs, tokenizer, model, map=field_map[1], pooling_method=pooling_method)
    index = indexes.create_index(embeddings, index_type)

    # Run inference on the test prompts and collect the resulting embeddings.
    test_prompts = [entry["prompt"] for entry in prompts_test_set]
    results = commons.inference(
        test_prompts, index, tokenizer, model, proverbs, pooling_method)
    actual_proverbs_embeddings = [result["embedding"] for result in results]

    # Look up the embeddings of the proverbs each test prompt is expected to match.
    proverb_to_index = {proverb["proverb"]: i for i, proverb in enumerate(proverbs)}
    test_proverbs = [entry["proverb"] for entry in prompts_test_set]
    proverbs_indexes = [proverb_to_index[proverb] for proverb in test_proverbs]
    expected_proverbs_embeddings = [embeddings[i] for i in proverbs_indexes]

    distances = faiss.pairwise_distances(
        actual_proverbs_embeddings, expected_proverbs_embeddings, metric=index.metric_type)
    avg_distance = distances.mean()
    var_distance = distances.var()
    logger.info(
        f"Computed average distance between actual and expected proverbs: {avg_distance}")
    logger.info(
        f"Computed variance of distances between actual and expected proverbs: {var_distance}")

    test_results = {
        "model": model_name,
        "index_type": index_type.__name__,
        "prompts_test_set_length": len(prompts_test_set),
        "avg_distance": float(avg_distance),
        "var_distance": float(var_distance),
        "map": field_map[0],
        "map_fields": field_map[0].split("_"),
        "remarks": remarks,
        "pooling_method": pooling_method,
    }
    return test_results
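
# Usage sketch (the model name is illustrative; the real sweep in the __main__ block
# below iterates over MODELS, PROVERB_FIELD_MAPS, POOLING_METHODS and INDEX_TYPES):
# result = test_distances(tokenizer, model, "some-model-name", proverbs, prompts_test_set,
#                         ("proverb_sentiment_usage", datasets.default_proverb_fields_selection),
#                         indexes.DEFAULT_INDEX_TYPE, models.DEFAULT_POOLING_METHOD)
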
def generate_unique_id() -> str:
    """Build a unique identifier including the current timestamp."""
    return time.strftime("%Y%m%d_%H%M%S")

if __name__ == "__main__":

    MODELS = models.MODELS
    PROVERB_FIELD_MAPS = {
        "proverb_sentiment_usage": datasets.default_proverb_fields_selection
    }
    INDEX_TYPES = [indexes.DEFAULT_INDEX_TYPE]
    POOLING_METHODS = [models.DEFAULT_POOLING_METHOD]
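
    # To sweep every field combination instead of only the default selection, the maps
    # could be generated with build_field_selection_maps (field names are illustrative):
    # PROVERB_FIELD_MAPS = build_field_selection_maps(["proverb", "sentiment", "usage", "themes"])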

    remarks = "ALL hyperparameter combinations, this is going to take a while..."

    def log_test_case(test_number: int, test_case_id: str) -> None:
        """Log the test case information, using the sweep variables from the enclosing scope."""
        max_len_models = max(len(model) for model in MODELS)
        max_len_maps = max(len(map_name) for map_name in PROVERB_FIELD_MAPS)
        max_len_index_types = max(len(index_type.__name__)
                                  for index_type in INDEX_TYPES)
        max_len_pooling_methods = max(len(pooling_method)
                                      for pooling_method in POOLING_METHODS)
        total_number_tests = (len(MODELS) * len(PROVERB_FIELD_MAPS)
                              * len(INDEX_TYPES) * len(POOLING_METHODS))
        max_len_test_number = len(str(total_number_tests))

        logger.info(
            f"({str(test_number).rjust(max_len_test_number)}/{total_number_tests}) "
            f"Test case {test_case_id}: "
            f"model = {model_name.ljust(max_len_models)}, "
            f"index type = {index_type.__name__.ljust(max_len_index_types)}, "
            f"map = {field_map[0].ljust(max_len_maps)}, "
            f"pooling = {pooling_method.ljust(max_len_pooling_methods)}"
        )

    tokenizer, model, proverbs, prompts_test_set = setup()

    # Create a timestamped folder for this run and snapshot the prompts test file into it.
    tests_run_id = generate_unique_id()
    run_folder = os.path.join("tests_runs", tests_run_id)
    os.makedirs(run_folder)
    tests_run_file = os.path.join(
        run_folder, f"results_test_run_{tests_run_id}.json")

    shutil.copy2(datasets.PROMPTS_TEST_FILE, run_folder)
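
    # Expected contents of the run folder after the sweep:
    #   results_test_run_<tests_run_id>.json    - aggregated results for the whole run
    #   results_test_case_<test_case_id>.json   - one file per hyperparameter combination
    #   a copy of datasets.PROMPTS_TEST_FILE    - the prompts test data used for this run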

    tests_run_results = {}
    test_number = 1

    for model_name in MODELS:
        model = models.load_model(model_name)
        tokenizer = models.load_tokenizer(model_name)
        for field_map in PROVERB_FIELD_MAPS.items():
            for pooling_method in POOLING_METHODS:
                for index_type in INDEX_TYPES:
                    test_case_id = generate_unique_id()
                    log_test_case(test_number, test_case_id)
                    test_case_results = test_distances(
                        tokenizer, model, model_name, proverbs, prompts_test_set,
                        field_map, index_type, pooling_method, remarks
                    )

                    # Persist each test case as soon as it completes, then add it to the run results.
                    test_case_file = os.path.join(
                        run_folder, f"results_test_case_{test_case_id}.json")
                    with open(test_case_file, "w") as f:
                        json.dump(test_case_results, f, indent=2)

                    tests_run_results[test_case_id] = test_case_results
                    test_number += 1

    with open(tests_run_file, "w") as f:
        json.dump(tests_run_results, f, indent=2)