"""Loading, text-representation, and train/test-splitting helpers for the
proverbs and prompts JSON datasets."""

import json
import os
import random
import re
from collections.abc import Callable
from typing import Optional

from src.customlogger import log_time, logger

# Dataset file locations, relative to the project root.
PROVERBS_FILE = os.path.join("datasets", "proverbs.json")
PROMPTS_FILE = os.path.join("datasets", "prompts.json")
PROMPTS_TRAIN_FILE = os.path.join("datasets", "prompts_train.json")
PROMPTS_TEST_FILE = os.path.join("datasets", "prompts_test.json")

# Fields present on each proverb entry.
PROVERB_FIELDS = ["proverb", "themes", "sentiment", "explanation", "usage"]


@log_time
def load_proverbs(proverbs_file: str = PROVERBS_FILE) -> list[dict]:
    """Load the proverb dataset from a JSON file.

    Args:
        proverbs_file (str): Path to the proverbs JSON file.

    Returns:
        list[dict]: The proverb entries.
    """
    logger.debug(f"Loading proverb dataset from '{proverbs_file}'...")
    proverbs = load_dataset(proverbs_file)
    logger.debug(f"Loaded {len(proverbs)} proverb entries")
    return proverbs


@log_time
def load_prompts(prompts_file: str = PROMPTS_FILE) -> list[dict]:
    """Load the prompts dataset from a JSON file.

    Args:
        prompts_file (str): Path to the prompts JSON file.

    Returns:
        list[dict]: The prompt entries.
    """
    logger.debug(f"Loading prompts dataset from '{prompts_file}'...")
    prompts = load_dataset(prompts_file)
    logger.debug(f"Loaded {len(prompts)} prompt entries")
    return prompts


def load_dataset(file: str) -> list[dict]:
    """Read and parse a UTF-8 JSON dataset file.

    Args:
        file (str): Path to the JSON file.

    Returns:
        list[dict]: The parsed JSON content.
    """
    with open(file, "r", encoding="utf-8") as f:
        return json.load(f)


def default_proverb_fields_selection(proverb: dict) -> list[str]:
    """Default field selection for a proverb's text representation.

    Returns the proverb text followed by its themes.
    """
    return [proverb["proverb"]] + proverb["themes"]


@log_time
def build_text_representations(
    proverbs: list[dict],
    map: Optional[Callable[[dict], list[str]]] = None,
) -> list[str]:
    """Build text representations of proverbs for embedding.

    Args:
        proverbs (list[dict]): Proverb entries to represent.
        map (Callable, optional): Selects the text fields of a single proverb
            as a list of strings. Defaults to
            ``default_proverb_fields_selection``. NOTE(review): the name
            shadows the builtin ``map``; kept for backward compatibility.

    Returns:
        list[str]: One text string per proverb.
    """
    # Alias to avoid shadowing the builtin inside the function body.
    select_fields = map if map else default_proverb_fields_selection
    # Join selected fields with ". ", then collapse duplicate periods
    # (some fields already end with one).
    return [
        re.sub(r"\.+", ".", ". ".join(select_fields(proverb)))
        for proverb in proverbs
    ]


def prompts_dataset_splits_exists(train_file: str = PROMPTS_TRAIN_FILE,
                                  test_file: str = PROMPTS_TEST_FILE) -> bool:
    """Check if both prompt dataset split files exist on disk."""
    return os.path.exists(train_file) and os.path.exists(test_file)


@log_time
def load_prompt_dataset_splits(
    train_file: str = PROMPTS_TRAIN_FILE,
    test_file: str = PROMPTS_TEST_FILE,
) -> tuple[list[dict], list[dict]]:
    """Load the prompt dataset train/test splits from JSON files.

    Args:
        train_file (str): Path to the training split JSON file.
        test_file (str): Path to the testing split JSON file.

    Returns:
        tuple[list[dict], list[dict]]: The (train, test) entries.
    """
    with open(train_file, "r", encoding="utf-8") as f:
        train_set = json.load(f)
    with open(test_file, "r", encoding="utf-8") as f:
        test_set = json.load(f)
    return train_set, test_set


@log_time
def split_dataset(dataset: list[dict],
                  train_ratio: float = 0,
                  seed: int = 42,
                  train_file: str = PROMPTS_TRAIN_FILE,
                  test_file: str = PROMPTS_TEST_FILE) -> tuple[list[dict], list[dict]]:
    """Split a dataset into train and test sets and save them to JSON files.

    Args:
        dataset (list[dict]): The dataset to split.
        train_ratio (float): The ratio of the dataset to use for training.
            NOTE(review): the default of 0 puts the entire dataset into the
            test split — confirm this is intended (0.8 is more typical).
        seed (int): The random seed for reproducibility.
        train_file (str): Path to save the training dataset.
        test_file (str): Path to save the testing dataset.

    Returns:
        tuple[list[dict], list[dict]]: The (train, test) entries.
    """
    logger.debug(
        f"Splitting dataset into train and test with train_ratio={train_ratio}, seed={seed}")
    random.seed(seed)
    # Copy the dataset to avoid modifying the original
    # (it's okay since it's a small dataset)
    dataset = dataset.copy()
    random.shuffle(dataset)
    split_index = int(len(dataset) * train_ratio)
    train_set = dataset[:split_index]
    test_set = dataset[split_index:]
    logger.debug(
        f"Train set size: {len(train_set)}, Test set size: {len(test_set)}")
    with open(train_file, "w", encoding="utf-8") as f:
        json.dump(train_set, f, ensure_ascii=False, indent=2)
    with open(test_file, "w", encoding="utf-8") as f:
        json.dump(test_set, f, ensure_ascii=False, indent=2)
    logger.info(f"Train dataset saved to '{train_file}'")
    logger.info(f"Test dataset saved to '{test_file}'")
    return train_set, test_set