# Uploaded by pauhmolins, commit 1200d57 ("Upload app").
# NOTE(review): the original lines here were scraped web-page residue, not
# Python; converted to a comment so the module parses.
import json
import os
import random
import re
from collections.abc import Callable

from src.customlogger import log_time, logger
PROVERBS_FILE = os.path.join("datasets", "proverbs.json")
PROMPTS_FILE = os.path.join("datasets", "prompts.json")
PROMPTS_TRAIN_FILE = os.path.join("datasets", "prompts_train.json")
PROMPTS_TEST_FILE = os.path.join("datasets", "prompts_test.json")
PROVERB_FIELDS = ["proverb", "themes", "sentiment", "explanation", "usage"]
@log_time
def load_proverbs(proverbs_file: str = PROVERBS_FILE) -> list[dict]:
    """Read the proverb dataset from *proverbs_file* and return its entries."""
    logger.debug(f"Loading proverb dataset from '{proverbs_file}'...")
    entries = load_dataset(proverbs_file)
    logger.debug(f"Loaded {len(entries)} proverb entries")
    return entries
@log_time
def load_prompts(prompts_file: str = PROMPTS_FILE) -> list[dict]:
    """Read the prompts dataset from *prompts_file* and return its entries."""
    logger.debug(f"Loading prompts dataset from '{prompts_file}'...")
    entries = load_dataset(prompts_file)
    logger.debug(f"Loaded {len(entries)} prompt entries")
    return entries
def load_dataset(file):
    """Parse *file* as UTF-8 JSON and return the decoded content."""
    with open(file, "r", encoding="utf-8") as handle:
        return json.load(handle)
def default_proverb_fields_selection(proverb: dict) -> list[str]:
    """Return the default text fields used to build a proverb's representation.

    The selection is the proverb text followed by each of its themes
    (the ``themes`` value is expected to be a list of strings).

    Args:
        proverb: A proverb entry with at least ``"proverb"`` and ``"themes"`` keys.

    Returns:
        ``[proverb_text, theme_1, ..., theme_n]``.
    """
    return [proverb["proverb"]] + proverb["themes"]
@log_time
def build_text_representations(proverbs: list[dict],
                               map: Callable[[dict], list[str]] | None = None) -> list[str]:
    """Build one text representation per proverb for embedding.

    Args:
        proverbs: Proverb entries to convert.
        map: Callable that selects the text fields of one proverb entry;
            defaults to ``default_proverb_fields_selection``. (The name
            shadows the builtin but is kept for backward compatibility.)

    Returns:
        One string per proverb: the selected fields joined with ". ",
        with runs of periods collapsed to a single one (some fields
        already end with their own period).
    """
    selector = map or default_proverb_fields_selection
    return [re.sub(r"\.+", ".", ". ".join(selector(entry))) for entry in proverbs]
def prompts_dataset_splits_exists(train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> bool:
    """Return True when both prompt dataset split files are present on disk."""
    return all(os.path.exists(path) for path in (train_file, test_file))
@log_time
def load_prompt_dataset_splits(train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> tuple[list[dict], list[dict]]:
    """Read the train and test prompt splits from disk and return them as a pair."""
    splits = []
    # Train file first, then test file — the return order relies on this.
    for path in (train_file, test_file):
        with open(path, "r", encoding="utf-8") as f:
            splits.append(json.load(f))
    return splits[0], splits[1]
@log_time
def split_dataset(dataset: list[dict], train_ratio: float = 0, seed: int = 42,
                  train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> tuple[list[dict], list[dict]]:
    """Split a dataset into train and test sets and save them to JSON files.

    Args:
        dataset (list[dict]): The dataset to split.
        train_ratio (float): The ratio of the dataset to use for training.
        seed (int): The random seed for reproducibility.
        train_file (str): Path to save the training dataset.
        test_file (str): Path to save the testing dataset.

    Returns:
        The (train_set, test_set) pair that was written to disk.
    """
    # NOTE(review): the default train_ratio=0 sends the whole dataset to the
    # test split — confirm callers always pass an explicit ratio.
    logger.debug(
        f"Splitting dataset into train and test with train_ratio={train_ratio}, seed={seed}")
    random.seed(seed)
    # Shuffle a copy so the caller's list stays untouched (cheap — the
    # dataset is small).
    shuffled = dataset.copy()
    random.shuffle(shuffled)
    cutoff = int(len(shuffled) * train_ratio)
    train_set, test_set = shuffled[:cutoff], shuffled[cutoff:]
    logger.debug(
        f"Train set size: {len(train_set)}, Test set size: {len(test_set)}")
    for path, subset in ((train_file, train_set), (test_file, test_set)):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(subset, f, ensure_ascii=False, indent=2)
    logger.info(f"Train dataset saved to '{train_file}'")
    logger.info(f"Test dataset saved to '{test_file}'")
    return train_set, test_set