import json
import os
import random
import re

from src.customlogger import log_time, logger

PROVERBS_FILE = os.path.join("datasets", "proverbs.json")
PROMPTS_FILE = os.path.join("datasets", "prompts.json")
PROMPTS_TRAIN_FILE = os.path.join("datasets", "prompts_train.json")
PROMPTS_TEST_FILE = os.path.join("datasets", "prompts_test.json")

PROVERB_FIELDS = ["proverb", "themes", "sentiment", "explanation", "usage"]
|
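# Illustrative only: an assumed shape for one entry in proverbs.json, inferred from
# PROVERB_FIELDS and the selectors below (the real dataset may differ):
#   {
#     "proverb": "...",
#     "themes": ["...", "..."],
#     "sentiment": "...",
#     "explanation": "...",
#     "usage": "..."
#   }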
|
@log_time
def load_proverbs(proverbs_file: str = PROVERBS_FILE) -> list[dict]:
    """Load the proverb dataset from a JSON file."""
    logger.debug(f"Loading proverb dataset from '{proverbs_file}'...")
    proverbs = load_dataset(proverbs_file)
    logger.debug(f"Loaded {len(proverbs)} proverb entries")
    return proverbs
|
|
@log_time
def load_prompts(prompts_file: str = PROMPTS_FILE) -> list[dict]:
    """Load the prompts dataset from a JSON file."""
    logger.debug(f"Loading prompts dataset from '{prompts_file}'...")
    prompts = load_dataset(prompts_file)
    logger.debug(f"Loaded {len(prompts)} prompt entries")
    return prompts
|
|
def load_dataset(file: str) -> list[dict]:
    """Load a JSON dataset from disk."""
    with open(file, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data
|
|
def default_proverb_fields_selection(proverb: dict) -> list[str]:
    """Select the fields used for the text representation (proverb and themes)."""
    return [proverb["proverb"]] + proverb["themes"]
|
|
@log_time
def build_text_representations(proverbs: list[dict], map: callable = None) -> list[str]:
    """Build text representations of proverbs for embedding."""
    if not map:
        map = default_proverb_fields_selection

    text_representations = [
        re.sub(r"\.+", ".", ". ".join(map(proverb))) for proverb in proverbs]
    return text_representations
|
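# Example (illustrative, not used elsewhere in this module): a custom selector can
# be passed via the `map` argument to include more of PROVERB_FIELDS in the text
# representation, e.g.:
#
#   def verbose_proverb_fields_selection(proverb: dict) -> list[str]:
#       return [proverb["proverb"], proverb["explanation"], proverb["usage"]] + proverb["themes"]
#
#   texts = build_text_representations(proverbs, map=verbose_proverb_fields_selection)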
|
def prompts_dataset_splits_exists(train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> bool:
    """Check if the prompt dataset splits exist."""
    return os.path.exists(train_file) and os.path.exists(test_file)
|
|
@log_time
def load_prompt_dataset_splits(train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> tuple[list[dict], list[dict]]:
    """Load the prompt dataset splits."""
    with open(train_file, "r", encoding="utf-8") as f:
        train_set = json.load(f)

    with open(test_file, "r", encoding="utf-8") as f:
        test_set = json.load(f)

    return train_set, test_set
|
|
@log_time
def split_dataset(dataset: list[dict], train_ratio: float = 0.8, seed: int = 42,
                  train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> tuple[list[dict], list[dict]]:
    """Split a dataset into train and test sets and save them to JSON files.

    Args:
        dataset (list[dict]): The dataset to split.
        train_ratio (float): The ratio of the dataset to use for training.
        seed (int): The random seed for reproducibility.
        train_file (str): Path to save the training dataset.
        test_file (str): Path to save the testing dataset.
    """
    logger.debug(
        f"Splitting dataset into train and test with train_ratio={train_ratio}, seed={seed}")
    random.seed(seed)

    # Shuffle a copy so the caller's list is left untouched.
    dataset = dataset.copy()
    random.shuffle(dataset)

    split_index = int(len(dataset) * train_ratio)
    train_set = dataset[:split_index]
    test_set = dataset[split_index:]

    logger.debug(
        f"Train set size: {len(train_set)}, Test set size: {len(test_set)}")

    with open(train_file, "w", encoding="utf-8") as f:
        json.dump(train_set, f, ensure_ascii=False, indent=2)

    with open(test_file, "w", encoding="utf-8") as f:
        json.dump(test_set, f, ensure_ascii=False, indent=2)

    logger.info(f"Train dataset saved to '{train_file}'")
    logger.info(f"Test dataset saved to '{test_file}'")

    return train_set, test_set
|
|
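# Minimal usage sketch (illustrative; assumes the dataset files listed above exist
# and that prompts.json holds the prompt entries to be split):
if __name__ == "__main__":
    proverbs = load_proverbs()
    texts = build_text_representations(proverbs)
    logger.info(f"Built {len(texts)} text representations")

    # Reuse existing prompt splits if present, otherwise create and save them.
    if prompts_dataset_splits_exists():
        train_set, test_set = load_prompt_dataset_splits()
    else:
        prompts = load_prompts()
        train_set, test_set = split_dataset(prompts, train_ratio=0.8)
    logger.info(f"Prompts: {len(train_set)} train / {len(test_set)} test")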
|