File size: 3,854 Bytes
1200d57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import json
import os
import random
import re
from src.customlogger import log_time, logger

PROVERBS_FILE = os.path.join("datasets", "proverbs.json")
PROMPTS_FILE = os.path.join("datasets", "prompts.json")
PROMPTS_TRAIN_FILE = os.path.join("datasets", "prompts_train.json")
PROMPTS_TEST_FILE = os.path.join("datasets", "prompts_test.json")
PROVERB_FIELDS = ["proverb", "themes", "sentiment", "explanation", "usage"]


@log_time
def load_proverbs(proverbs_file: str = PROVERBS_FILE) -> list[dict]:
    """Read the proverb dataset from a JSON file, logging progress at debug level.

    Args:
        proverbs_file (str): Path of the proverbs JSON file.

    Returns:
        list[dict]: The parsed entries, exactly as returned by ``load_dataset``.
    """
    logger.debug(f"Loading proverb dataset from '{proverbs_file}'...")
    entries = load_dataset(proverbs_file)
    entry_count = len(entries)
    logger.debug(f"Loaded {entry_count} proverb entries")
    return entries


@log_time
def load_prompts(prompts_file: str = PROMPTS_FILE) -> list[dict]:
    """Read the prompts dataset from a JSON file, logging progress at debug level.

    Args:
        prompts_file (str): Path of the prompts JSON file.

    Returns:
        list[dict]: The parsed entries, exactly as returned by ``load_dataset``.
    """
    logger.debug(f"Loading prompts dataset from '{prompts_file}'...")
    records = load_dataset(prompts_file)
    record_count = len(records)
    logger.debug(f"Loaded {record_count} prompt entries")
    return records


def load_dataset(file):
    """Parse a UTF-8 encoded JSON file and return its contents.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The deserialized JSON value (whatever ``json.load`` produces).
    """
    # The file handle is closed by the context manager even though we
    # return from inside it.
    with open(file, encoding="utf-8") as handle:
        return json.load(handle)


def default_proverb_fields_selection(proverb):
    """Select the fields of a proverb entry used for its text representation.

    The result is a new list holding the proverb text (``proverb["proverb"]``)
    followed by each item of ``proverb["themes"]``. No other fields (e.g.
    sentiment, explanation, usage) are included.

    Args:
        proverb (dict): A proverb entry; it must contain the "proverb" and
            "themes" keys (a KeyError is raised otherwise).

    Returns:
        list: The proverb text followed by its themes, in order.
    """
    # Unpacking accepts any iterable of themes (list, tuple, ...), whereas
    # list concatenation would raise TypeError for non-list themes.
    return [proverb["proverb"], *proverb["themes"]]


@log_time
def build_text_representations(proverbs: list[dict], map: callable = None) -> list[str]:
    """Build text representations of proverbs for embedding.

    Each proverb is passed through a field-selection function. The selected
    strings are joined with ". ", and every run of consecutive periods in the
    joined text is then collapsed into a single "." (some fields already end
    with a period, so joining would otherwise yield ".."; note this also
    collapses an ellipsis "..." to ".").

    Args:
        proverbs (list[dict]): Proverb entries to convert.
        map (callable, optional): Function taking one proverb entry and
            returning the strings to join. When None, the module's
            ``default_proverb_fields_selection`` is used. The parameter name
            shadows the builtin ``map`` but is kept so keyword callers keep
            working.

    Returns:
        list[str]: One text representation per proverb, in input order.
    """
    select_fields = default_proverb_fields_selection if map is None else map
    # Compile once instead of looking the pattern up for every proverb.
    repeated_periods = re.compile(r"\.+")
    return [
        repeated_periods.sub(".", ". ".join(select_fields(proverb)))
        for proverb in proverbs
    ]


def prompts_dataset_splits_exists(train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> bool:
    """Report whether both prompt split files are present on disk.

    Args:
        train_file (str): Path of the training split file.
        test_file (str): Path of the test split file.

    Returns:
        bool: True only when both paths exist; the test path is not checked
        if the training path is missing.
    """
    return all(os.path.exists(path) for path in (train_file, test_file))


@log_time
def load_prompt_dataset_splits(train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> tuple[list[dict], list[dict]]:
    """Read the previously saved train and test prompt splits.

    Both files are parsed as UTF-8 JSON via ``load_dataset``, training split
    first.

    Args:
        train_file (str): Path of the training split JSON file.
        test_file (str): Path of the test split JSON file.

    Returns:
        tuple[list[dict], list[dict]]: The (train_set, test_set) pair.
    """
    training_split = load_dataset(train_file)
    testing_split = load_dataset(test_file)
    return training_split, testing_split


@log_time
def split_dataset(dataset: list[dict], train_ratio: float = 0, seed: int = 42,
                  train_file: str = PROMPTS_TRAIN_FILE, test_file: str = PROMPTS_TEST_FILE) -> tuple[list[dict], list[dict]]:
    """Shuffle a dataset, split it into train and test sets, and save both to JSON files.

    The shuffle uses a private ``random.Random(seed)`` instance. It produces
    the same permutation as ``random.seed(seed); random.shuffle(...)`` would,
    but leaves the global ``random`` module state untouched.

    Args:
        dataset (list[dict]): The dataset to split. It is not modified.
        train_ratio (float): Fraction of the dataset, in [0, 1], assigned to
            the training set (the count is rounded down); the remainder forms
            the test set. With the default of 0 everything goes to the test set.
        seed (int): The random seed for reproducibility.
        train_file (str): Path to save the training dataset.
        test_file (str): Path to save the testing dataset.

    Returns:
        tuple[list[dict], list[dict]]: The (train_set, test_set) pair.

    Raises:
        ValueError: If ``train_ratio`` is not between 0 and 1 (inclusive).
    """
    # Out-of-range ratios would otherwise silently produce degenerate splits
    # (e.g. a negative ratio slices from the end of the list).
    if not 0 <= train_ratio <= 1:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")

    logger.debug(
        f"Splitting dataset into train and test with train_ratio={train_ratio}, seed={seed}")
    # Shallow copy so the caller's list order is preserved
    # (it's okay since it's a small dataset).
    shuffled = list(dataset)
    # Local RNG: reproducible without reseeding the shared global generator.
    random.Random(seed).shuffle(shuffled)

    split_index = int(len(shuffled) * train_ratio)
    train_set = shuffled[:split_index]
    test_set = shuffled[split_index:]

    logger.debug(
        f"Train set size: {len(train_set)}, Test set size: {len(test_set)}")

    with open(train_file, "w", encoding="utf-8") as f:
        json.dump(train_set, f, ensure_ascii=False, indent=2)

    with open(test_file, "w", encoding="utf-8") as f:
        json.dump(test_set, f, ensure_ascii=False, indent=2)

    logger.info(f"Train dataset saved to '{train_file}'")
    logger.info(f"Test dataset saved to '{test_file}'")

    return train_set, test_set