Upload 7 files

- handler.py +126 -0
- main.py +596 -0
- src/models.py +74 -0
- src/models_utils.py +561 -0
- src/plot_helpers.py +58 -0
- src/running_params.py +3 -0
- src/utiles_data.py +737 -0
handler.py (ADDED)
from typing import Any, Dict

import os

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

from src.models import DNikudModel, ModelConfig
from src.models_utils import predict
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
from src.utiles_data import Nikud, NikudDataset


class EndpointHandler:
    def __init__(self, path=""):
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")
        dir_model_config = os.path.join("models", "config.yml")
        self.config = ModelConfig.load_from_file(dir_model_config)
        self.model = DNikudModel(
            self.config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            device=self.DEVICE,
        ).to(self.DEVICE)
        # Overlay the trained D-nikud weights onto the freshly constructed model.
        state_dict_model = self.model.state_dict()
        state_dict_model.update(torch.load("./models/Dnikud_best_model.pth"))
        self.model.load_state_dict(state_dict_model)
        self.max_length = MAX_LENGTH_SEN

    def back_2_text(self, labels, text):
        # Rebuild the diacritized string: after each base character, append the
        # predicted dagesh mark, then the sin/shin dot, then the vowel (nikud).
        nikud = Nikud()
        new_line = ""

        for indx_char, c in enumerate(text):
            new_line += (
                c
                + nikud.id_2_char(labels[indx_char][1][1], "dagesh")
                + nikud.id_2_char(labels[indx_char][1][2], "sin")
                + nikud.id_2_char(labels[indx_char][1][0], "nikud")
            )
        return new_line
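    # Illustration (assumed label semantics, not part of the upload): Hebrew
    # combining marks attach to the preceding base letter, so concatenating, e.g.,
    #   "\u05E9" + "\u05BC" + "\u05C1" + "\u05B8"  # shin + dagesh + shin dot + qamats
    # renders as a single pointed glyph, which is why back_2_text appends the
    # predicted marks directly after each character.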
    def prepare_data(self, data, name="train"):
        dataset = []
        for index, (sentence, label) in tqdm(
            enumerate(data), desc=f"Prepare data {name}"
        ):
            encoded_sequence = self.tokenizer.encode_plus(
                sentence,
                add_special_tokens=True,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            label_lists = [
                [letter.nikud, letter.dagesh, letter.sin] for letter in label
            ]
            # Align labels with the tokenized sequence: one PAD row for the
            # leading special token, the per-letter labels truncated to
            # max_length - 1, then PAD rows out to max_length.
            label = torch.tensor(
                [
                    [
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                    ]
                ]
                + label_lists[: (self.max_length - 1)]
                + [
                    [
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                    ]
                    for i in range(self.max_length - len(label) - 1)
                ]
            )

            dataset.append(
                (
                    encoded_sequence["input_ids"][0],
                    encoded_sequence["attention_mask"][0],
                    label,
                )
            )

        self.prepered_data = dataset

    def predict_single_text(self, text):
        dataset = NikudDataset(tokenizer=self.tokenizer, max_length=MAX_LENGTH_SEN)
        data, orig_data = dataset.read_single_text(text)
        dataset.prepare_data(name="inference")
        mtb_prediction_dl = torch.utils.data.DataLoader(
            dataset.prepered_data, batch_size=BATCH_SIZE
        )
        all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
        text_data_with_labels = dataset.back_2_text(labels=all_labels)
        return text_data_with_labels

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            text (str): Hebrew text to diacritize.
        Returns the diacritized text.
        """
        # get inputs
        inputs = data.pop("text", data)

        # run normal prediction
        prediction = self.predict_single_text(inputs)

        return prediction
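For reference, a minimal usage sketch (assumptions: the models/ directory and the src/ package sit next to handler.py, and the payload carries the input under the "text" key that __call__ pops):

    from handler import EndpointHandler

    handler = EndpointHandler()
    diacritized = handler({"text": "שלום עולם"})  # returns the diacritized string
    print(diacritized)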
    	
main.py (ADDED)
# general
import argparse
import os
import sys
from datetime import datetime
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path

# ML
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer

# DL
from src.models import DNikudModel, ModelConfig
from src.models_utils import training, evaluate, predict
from src.plot_helpers import (
    generate_plot_by_nikud_dagesh_sin_dict,
    generate_word_and_letter_accuracy_plot,
)
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
from src.utiles_data import (
    NikudDataset,
    Nikud,
    create_missing_folders,
    extract_text_to_compare_nakdimon,
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Note: this script insists on a GPU and fails fast on CPU-only machines.
assert DEVICE == "cuda"

def get_logger(
    log_level, name_func, date_time=datetime.now().strftime("%d_%m_%y__%H_%M")
):
    log_location = os.path.join(
        os.path.join(Path(__file__).parent, "logging"),
        f"log_model_{name_func}_{date_time}",
    )
    create_missing_folders(log_location)

    log_format = "%(asctime)s %(levelname)-8s Thread_%(thread)-6d ::: %(funcName)s(%(lineno)d) ::: %(message)s"
    logger = logging.getLogger("algo")
    logger.setLevel(getattr(logging, log_level))
    cnsl_log_formatter = logging.Formatter(log_format)
    cnsl_handler = logging.StreamHandler()
    cnsl_handler.setFormatter(cnsl_log_formatter)
    cnsl_handler.setLevel(log_level)
    logger.addHandler(cnsl_handler)

    file_location = os.path.join(log_location, "Diacritization_Model_DEBUG.log")
    file_log_formatter = logging.Formatter(log_format)

    SINGLE_LOG_SIZE = 2 * 1024 * 1024  # in bytes
    MAX_LOG_FILES = 20
    file_handler = RotatingFileHandler(
        file_location, mode="a", maxBytes=SINGLE_LOG_SIZE, backupCount=MAX_LOG_FILES
    )
    file_handler.setFormatter(file_log_formatter)
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)

    return logger

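# Illustration (assumed, not part of the upload): get_logger("DEBUG", "train")
# mirrors records to the console and to a rotating file, keeping up to 20
# backups of 2 MB each, at:
#   logging/log_model_train_<dd_mm_yy__HH_MM>/Diacritization_Model_DEBUG.log
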
def evaluate_text(
    path,
    dnikud_model,
    tokenizer_tavbert,
    logger,
    plots_folder=None,
    batch_size=BATCH_SIZE,
):
    path_name = os.path.basename(path)

    msg = f"evaluate text: {path_name} on D-nikud Model"
    logger.debug(msg)

    if os.path.isfile(path):
        dataset = NikudDataset(
            tokenizer_tavbert, file=path, logger=logger, max_length=MAX_LENGTH_SEN
        )
    elif os.path.isdir(path):
        dataset = NikudDataset(
            tokenizer_tavbert, folder=path, logger=logger, max_length=MAX_LENGTH_SEN
        )
    else:
        raise Exception("input path doesn't exist")

    dataset.prepare_data(name="evaluate")
    mtb_dl = torch.utils.data.DataLoader(dataset.prepered_data, batch_size=batch_size)

    word_level_correct, letter_level_correct_dev = evaluate(
        dnikud_model, mtb_dl, plots_folder, device=DEVICE
    )

    msg = (
        f"Dnikud Model\n{path_name} evaluate\nLetter level accuracy: {letter_level_correct_dev}\n"
        f"Word level accuracy: {word_level_correct}"
    )
    logger.debug(msg)

def predict_text(
    text_file,
    tokenizer_tavbert,
    output_file,
    logger,
    dnikud_model,
    compare_nakdimon=False,
):
    dataset = NikudDataset(
        tokenizer_tavbert, file=text_file, logger=logger, max_length=MAX_LENGTH_SEN
    )

    dataset.prepare_data(name="prediction")
    mtb_prediction_dl = torch.utils.data.DataLoader(
        dataset.prepered_data, batch_size=BATCH_SIZE
    )
    all_labels = predict(dnikud_model, mtb_prediction_dl, DEVICE)
    text_data_with_labels = dataset.back_2_text(labels=all_labels)

    # Print to stdout when no output file is given; otherwise write the
    # diacritized text (optionally stripped for comparison with Nakdimon).
    if output_file is None:
        for line in text_data_with_labels:
            print(line)
    else:
        with open(output_file, "w", encoding="utf-8") as f:
            if compare_nakdimon:
                f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
            else:
                f.write(text_data_with_labels)

def predict_folder(
    folder,
    output_folder,
    logger,
    tokenizer_tavbert,
    dnikud_model,
    compare_nakdimon=False,
):
    # Walk the input tree recursively, mirroring its structure under
    # output_folder; .txt files are diacritized, .git and README.md are skipped.
    create_missing_folders(output_folder)

    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)

        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
            output_file = os.path.join(output_folder, filename)
            predict_text(
                file_path,
                output_file=output_file,
                logger=logger,
                tokenizer_tavbert=tokenizer_tavbert,
                dnikud_model=dnikud_model,
                compare_nakdimon=compare_nakdimon,
            )
        elif (
            os.path.isdir(file_path) and filename != ".git" and filename != "README.md"
        ):
            sub_folder = file_path
            sub_folder_output = os.path.join(output_folder, filename)
            predict_folder(
                sub_folder,
                sub_folder_output,
                logger,
                tokenizer_tavbert,
                dnikud_model,
                compare_nakdimon=compare_nakdimon,
            )

def update_compare_folder(folder, output_folder):
    create_missing_folders(output_folder)

    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)

        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
            output_file = os.path.join(output_folder, filename)
            with open(file_path, "r", encoding="utf-8") as f:
                text_data_with_labels = f.read()
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
        elif os.path.isdir(file_path) and filename != ".git":
            sub_folder = file_path
            sub_folder_output = os.path.join(output_folder, filename)
            update_compare_folder(sub_folder, sub_folder_output)


def check_files_excepted(folder):
    # Report files that fail to parse as a NikudDataset.
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)

        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
            try:
                NikudDataset(None, file=file_path)
            except Exception:
                print(f"failed in file: {filename}")
        elif os.path.isdir(file_path) and filename != ".git":
            check_files_excepted(file_path)

def do_predict(
    input_path, output_path, tokenizer_tavbert, logger, dnikud_model, compare_nakdimon
):
    if os.path.isdir(input_path):
        predict_folder(
            input_path,
            output_path,
            logger,
            tokenizer_tavbert,
            dnikud_model,
            compare_nakdimon=compare_nakdimon,
        )
    elif os.path.isfile(input_path):
        predict_text(
            input_path,
            output_file=output_path,
            logger=logger,
            tokenizer_tavbert=tokenizer_tavbert,
            dnikud_model=dnikud_model,
            compare_nakdimon=compare_nakdimon,
        )
    else:
        raise Exception("Input file does not exist")

def evaluate_folder(folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder):
    msg = f"evaluate sub folder: {folder_path}"
    logger.info(msg)

    evaluate_text(
        folder_path,
        dnikud_model=dnikud_model,
        tokenizer_tavbert=tokenizer_tavbert,
        logger=logger,
        plots_folder=plots_folder,
        batch_size=BATCH_SIZE,
    )

    msg = "\n***************************************\n"
    logger.info(msg)

    for sub_folder_name in os.listdir(folder_path):
        sub_folder_path = os.path.join(folder_path, sub_folder_name)

        if (
            not os.path.isdir(sub_folder_path)
            or sub_folder_name == ".git"
            or "not_use" in sub_folder_path
            or "NakdanResults" in sub_folder_path
        ):
            continue

        evaluate_folder(
            sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder
        )


def do_evaluate(
    input_path,
    logger,
    dnikud_model,
    tokenizer_tavbert,
    plots_folder,
    eval_sub_folders=False,
):
    msg = f"evaluate all_data: {input_path}"
    logger.info(msg)

    evaluate_text(
        input_path,
        dnikud_model=dnikud_model,
        tokenizer_tavbert=tokenizer_tavbert,
        logger=logger,
        plots_folder=plots_folder,
        batch_size=BATCH_SIZE,
    )

    msg = "\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n"
    logger.info(msg)

    if eval_sub_folders:
        for sub_folder_name in os.listdir(input_path):
            sub_folder_path = os.path.join(input_path, sub_folder_name)

            if (
                not os.path.isdir(sub_folder_path)
                or sub_folder_name == ".git"
                or "not_use" in sub_folder_path
                or "NakdanResults" in sub_folder_path
            ):
                continue

            evaluate_folder(
                sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder
            )

def do_train(
    logger,
    plots_folder,
    dir_model_config,
    tokenizer_tavbert,
    dnikud_model,
    output_trained_model_dir,
    data_folder,
    n_epochs,
    checkpoints_frequency,
    learning_rate,
    batch_size,
):
    msg = "Loading data..."
    logger.debug(msg)

    dataset_train = NikudDataset(
        tokenizer_tavbert,
        folder=os.path.join(data_folder, "train"),
        logger=logger,
        max_length=MAX_LENGTH_SEN,
        is_train=True,
    )
    dataset_dev = NikudDataset(
        tokenizer=tokenizer_tavbert,
        folder=os.path.join(data_folder, "dev"),
        logger=logger,
        max_length=dataset_train.max_length,
        is_train=True,
    )
    dataset_test = NikudDataset(
        tokenizer=tokenizer_tavbert,
        folder=os.path.join(data_folder, "test"),
        logger=logger,
        max_length=dataset_train.max_length,
        is_train=True,
    )

    dataset_train.show_data_labels(plots_folder=plots_folder)

    msg = f"Max length of data: {dataset_train.max_length}"
    logger.debug(msg)

    msg = (
        f"Num rows in train data: {len(dataset_train.data)}, "
        f"Num rows in dev data: {len(dataset_dev.data)}, "
        f"Num rows in test data: {len(dataset_test.data)}"
    )
    logger.debug(msg)

    msg = "Loading tokenizer and preparing data..."
    logger.debug(msg)

    dataset_train.prepare_data(name="train")
    dataset_dev.prepare_data(name="dev")
    dataset_test.prepare_data(name="test")

    mtb_train_dl = torch.utils.data.DataLoader(
        dataset_train.prepered_data, batch_size=batch_size
    )
    mtb_dev_dl = torch.utils.data.DataLoader(
        dataset_dev.prepered_data, batch_size=batch_size
    )

    if not os.path.isfile(dir_model_config):
        our_model_config = ModelConfig(dataset_train.max_length)
        our_model_config.save_to_file(dir_model_config)

    optimizer = torch.optim.Adam(dnikud_model.parameters(), lr=learning_rate)

    msg = "training..."
    logger.debug(msg)

    # One cross-entropy head per task; padded positions are masked out via
    # the PAD_OR_IRRELEVANT ignore index.
    criterion_nikud = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(
        DEVICE
    )
    criterion_dagesh = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(
        DEVICE
    )
    criterion_sin = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(DEVICE)

    training_params = {
        "n_epochs": n_epochs,
        "checkpoints_frequency": checkpoints_frequency,
    }
    (
        best_model_details,
        best_accuracy,
        epochs_loss_train_values,
        steps_loss_train_values,
        loss_dev_values,
        accuracy_dev_values,
    ) = training(
        dnikud_model,
        mtb_train_dl,
        mtb_dev_dl,
        criterion_nikud,
        criterion_dagesh,
        criterion_sin,
        training_params,
        logger,
        output_trained_model_dir,
        optimizer,
        device=DEVICE,
    )

    generate_plot_by_nikud_dagesh_sin_dict(
        epochs_loss_train_values, "Train epochs loss", "Loss", plots_folder
    )
    generate_plot_by_nikud_dagesh_sin_dict(
        steps_loss_train_values, "Train steps loss", "Loss", plots_folder
    )
    generate_plot_by_nikud_dagesh_sin_dict(
        loss_dev_values, "Dev epochs loss", "Loss", plots_folder
    )
    generate_plot_by_nikud_dagesh_sin_dict(
        accuracy_dev_values, "Dev accuracy", "Accuracy", plots_folder
    )
    generate_word_and_letter_accuracy_plot(
        accuracy_dev_values, "Accuracy", plots_folder
    )

    msg = "Done"
    logger.info(msg)

if __name__ == "__main__":
    tokenizer_tavbert = AutoTokenizer.from_pretrained("tau/tavbert-he")

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""Predict D-nikud""",
    )
    parser.add_argument(
        "-l",
        "--log",
        dest="log_level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        default="DEBUG",
        help="Set the logging level",
    )
    parser.add_argument(
        "-m",
        "--output_model_dir",
        type=str,
        default="models",
        help="save directory for the model",
    )
    subparsers = parser.add_subparsers(
        help="sub-command help", dest="command", required=True
    )

    parser_predict = subparsers.add_parser("predict", help="diacritize text files")
    parser_predict.add_argument("input_path", help="input file or folder")
    parser_predict.add_argument("output_path", help="output file")
    parser_predict.add_argument(
        "-ptmp",
        "--pretrain_model_path",
        type=str,
        default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
        help="pre-trained model path - use only if you want to use trained model weights",
    )
    parser_predict.add_argument(
        "-c",
        "--compare",
        dest="compare_nakdimon",
        action="store_true",
        default=False,
        help="predict text for comparison with Nakdimon",
    )
    parser_predict.set_defaults(func=do_predict)

    parser_evaluate = subparsers.add_parser("evaluate", help="evaluate D-nikud")
    parser_evaluate.add_argument("input_path", help="input file or folder")
    parser_evaluate.add_argument(
        "-ptmp",
        "--pretrain_model_path",
        type=str,
        default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
        help="pre-trained model path - use only if you want to use trained model weights",
    )
    parser_evaluate.add_argument(
        "-df",
        "--plots_folder",
        dest="plots_folder",
        default=os.path.join(Path(__file__).parent, "plots"),
        help="set the plots folder",
    )
    parser_evaluate.add_argument(
        "-es",
        "--eval_sub_folders",
        dest="eval_sub_folders",
        action="store_true",
        default=False,
        help="accuracy calculation includes the evaluation of sub-folders "
        "within the input_path folder, providing independent assessments "
        "for each subfolder.",
    )
    parser_evaluate.set_defaults(func=do_evaluate)

    # train --n_epochs 20

    parser_train = subparsers.add_parser("train", help="train D-nikud")
    parser_train.add_argument(
        "-ptmp",
        "--pretrain_model_path",
        type=str,
        default=None,
        help="pre-trained model path - use only if you want to use trained model weights",
    )
    parser_train.add_argument(
        "--learning_rate", type=float, default=0.001, help="learning rate"
    )
    parser_train.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser_train.add_argument(
        "--n_epochs", type=int, default=10, help="number of epochs"
    )
    parser_train.add_argument(
        "--data_folder",
        dest="data_folder",
        default=os.path.join(Path(__file__).parent, "data"),
        help="set the data folder",
    )
    parser_train.add_argument(
        "--checkpoints_frequency",
        type=int,
        default=1,
        help="checkpoint frequency for saving the model",
    )
    parser_train.add_argument(
        "-df",
        "--plots_folder",
        dest="plots_folder",
        default=os.path.join(Path(__file__).parent, "plots"),
        help="set the plots folder",
    )
    parser_train.set_defaults(func=do_train)

    args = parser.parse_args()
    kwargs = vars(args).copy()
    date_time = datetime.now().strftime("%d_%m_%y__%H_%M")
    logger = get_logger(kwargs["log_level"], args.command, date_time)

    del kwargs["log_level"]

    kwargs["tokenizer_tavbert"] = tokenizer_tavbert
    kwargs["logger"] = logger

    msg = "Loading model..."
    logger.debug(msg)

    if args.command in ["evaluate", "predict"] or (
        args.command == "train" and args.pretrain_model_path is not None
    ):
        # Rebuild the model from the saved config and overlay trained weights.
        dir_model_config = os.path.join("models", "config.yml")
        config = ModelConfig.load_from_file(dir_model_config)

        dnikud_model = DNikudModel(
            config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            device=DEVICE,
        ).to(DEVICE)
        state_dict_model = dnikud_model.state_dict()
        state_dict_model.update(torch.load(args.pretrain_model_path))
        dnikud_model.load_state_dict(state_dict_model)
    else:
        # Fresh model on top of the tavBERT base, for training from scratch.
        base_model_name = "tau/tavbert-he"
        config = AutoConfig.from_pretrained(base_model_name)
        dnikud_model = DNikudModel(
            config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            pretrain_model=base_model_name,
            device=DEVICE,
        ).to(DEVICE)

    if args.command == "train":
        output_trained_model_dir = os.path.join(
            kwargs["output_model_dir"], "latest", f"output_models_{date_time}"
        )
        create_missing_folders(output_trained_model_dir)
        dir_model_config = os.path.join(kwargs["output_model_dir"], "config.yml")
        kwargs["dir_model_config"] = dir_model_config
        kwargs["output_trained_model_dir"] = output_trained_model_dir
            +
                del kwargs["pretrain_model_path"]
         | 
| 589 | 
            +
                del kwargs["output_model_dir"]
         | 
| 590 | 
            +
                kwargs["dnikud_model"] = dnikud_model
         | 
| 591 | 
            +
             | 
| 592 | 
            +
                del kwargs["command"]
         | 
| 593 | 
            +
                del kwargs["func"]
         | 
| 594 | 
            +
                args.func(**kwargs)
         | 
| 595 | 
            +
             | 
| 596 | 
            +
                sys.exit(0)
         | 
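
The `set_defaults(func=...)` / `vars(args)` dispatch above is worth seeing in isolation: each subparser pins its handler onto `args.func`, the parsed namespace is copied into a `kwargs` dict, the bookkeeping keys are deleted, and the handler is called with whatever remains. A minimal, self-contained sketch of the same pattern (the handler and flags below are illustrative stand-ins, not D-nikud's real ones):

import argparse

def do_train(n_epochs, learning_rate):
    # stand-in for the real do_train; only shows the calling convention
    print(f"training for {n_epochs} epochs at lr={learning_rate}")

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command", required=True)

parser_train = subparsers.add_parser("train")
parser_train.add_argument("--n_epochs", type=int, default=10)
parser_train.add_argument("--learning_rate", type=float, default=0.001)
parser_train.set_defaults(func=do_train)

args = parser.parse_args(["train", "--n_epochs", "20"])
kwargs = vars(args).copy()
# strip the bookkeeping keys so only do_train's parameters remain
del kwargs["command"]
del kwargs["func"]
args.func(**kwargs)  # -> do_train(n_epochs=20, learning_rate=0.001)

Deleting `command` and `func` before the call is what lets `args.func(**kwargs)` line up exactly with the handler's signature.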
    	
        src/models.py
    ADDED
    
# general
import subprocess
import yaml

# ML
import torch.nn as nn
from transformers import AutoConfig, RobertaForMaskedLM, PretrainedConfig


class DNikudModel(nn.Module):
    def __init__(self, config, nikud_size, dagesh_size, sin_size, pretrain_model=None, device='cpu'):
        super(DNikudModel, self).__init__()

        if pretrain_model is not None:
            model_base = RobertaForMaskedLM.from_pretrained(pretrain_model).to(device)
        else:
            model_base = RobertaForMaskedLM(config=config).to(device)

        self.model = model_base.roberta
        for name, param in self.model.named_parameters():
            param.requires_grad = False

        self.lstm1 = nn.LSTM(config.hidden_size, config.hidden_size, bidirectional=True, dropout=0.1, batch_first=True)
        self.lstm2 = nn.LSTM(2 * config.hidden_size, config.hidden_size, bidirectional=True, dropout=0.1, batch_first=True)
        self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
        self.out_n = nn.Linear(config.hidden_size, nikud_size)
        self.out_d = nn.Linear(config.hidden_size, dagesh_size)
        self.out_s = nn.Linear(config.hidden_size, sin_size)

    def forward(self, input_ids, attention_mask):
        last_hidden_state = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
        lstm1, _ = self.lstm1(last_hidden_state)
        lstm2, _ = self.lstm2(lstm1)
        dense = self.dense(lstm2)

        nikud = self.out_n(dense)
        dagesh = self.out_d(dense)
        sin = self.out_s(dense)

        return nikud, dagesh, sin


def get_git_commit_hash():
    try:
        commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
        return commit_hash
    except subprocess.CalledProcessError:
        # This will be raised if you're not in a Git repository
        print("Not inside a Git repository!")
        return None


class ModelConfig(PretrainedConfig):
    def __init__(self, max_length=None, dict=None):
        super(ModelConfig, self).__init__()
        if dict is None:
            self.__dict__.update(AutoConfig.from_pretrained("tau/tavbert-he").__dict__)
            self.max_length = max_length
            self._commit_hash = get_git_commit_hash()
        else:
            self.__dict__.update(dict)

    def print(self):
        print(self.__dict__)

    def save_to_file(self, file_path):
        with open(file_path, "w") as yaml_file:
            yaml.dump(self.__dict__, yaml_file, default_flow_style=False)

    @classmethod
    def load_from_file(cls, file_path):
        with open(file_path, "r") as yaml_file:
            config_dict = yaml.safe_load(yaml_file)
        return cls(dict=config_dict)
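
A quick way to sanity-check `DNikudModel` is to drive it with a deliberately tiny `RobertaConfig`, which avoids downloading `tau/tavbert-he`. This is an editor's sketch, not part of the repository; the head sizes are made up for the demo (the real ones come from `len(Nikud.label_2_id[...])`):

import torch
from transformers import RobertaConfig
from src.models import DNikudModel

# tiny config so the sketch runs offline and in seconds
cfg = RobertaConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=20,
)
model = DNikudModel(cfg, nikud_size=5, dagesh_size=3, sin_size=4)

input_ids = torch.randint(3, 100, (2, 10))   # batch of 2 sequences, 10 tokens each
attention_mask = torch.ones_like(input_ids)
nikud, dagesh, sin = model(input_ids, attention_mask)
print(nikud.shape, dagesh.shape, sin.shape)  # (2, 10, 5), (2, 10, 3), (2, 10, 4)

Each head predicts one diacritic class per input character, which is why all three outputs keep the (batch, sequence) shape and differ only in class count.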
    	
        src/models_utils.py
    ADDED
    
# general
import json
import os

# ML
import numpy as np
import pandas as pd
import torch

# visual
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

from src.running_params import DEBUG_MODE
from src.utiles_data import Nikud, create_missing_folders

CLASSES_LIST = ["nikud", "dagesh", "sin"]


def calc_num_correct_words(input, letter_correct_mask):
    SPACE_TOKEN = 104
    START_SENTENCE_TOKEN = 1
    END_SENTENCE_TOKEN = 2

    correct_words_count = 0
    words_count = 0
    for index in range(input.shape[0]):
        input[index][np.where(input[index] == SPACE_TOKEN)[0]] = 0
        input[index][np.where(input[index] == START_SENTENCE_TOKEN)[0]] = 0
        input[index][np.where(input[index] == END_SENTENCE_TOKEN)[0]] = 0
        words_end_index = np.concatenate(
            (np.array([-1]), np.where(input[index] == 0)[0])
        )
        is_correct_words_array = [
            bool(
                letter_correct_mask[index][
                    list(range((words_end_index[s] + 1), words_end_index[s + 1]))
                ].all()
            )
            for s in range(len(words_end_index) - 1)
            if words_end_index[s + 1] - (words_end_index[s] + 1) > 1
        ]
        correct_words_count += np.array(is_correct_words_array).sum()
        words_count += len(is_correct_words_array)

    return correct_words_count, words_count
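
# --- Editor's illustration (hypothetical, not part of the original file) --
# calc_num_correct_words treats token id 0 (after the special ids above are
# zeroed out) as a word separator: a word counts as correct only when every
# letter between two separators is predicted correctly, and segments of a
# single letter are skipped by the length check. A tiny worked example:
def _demo_calc_num_correct_words():
    # one "sentence": [start] X X [space] Y Y [end]; letter ids are arbitrary
    tokens = np.array([[1, 10, 11, 104, 12, 13, 2]])
    # the second word has one wrongly-predicted letter
    mask = np.array([[True, True, True, True, True, False, True]])
    correct, total = calc_num_correct_words(tokens.copy(), mask)
    print(correct, total)  # -> 1 2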
def predict(model, data_loader, device="cpu"):
    model.to(device)

    all_labels = None
    with torch.no_grad():
        for index_data, data in enumerate(data_loader):
            (inputs, attention_mask, labels_demo) = data
            inputs = inputs.to(device)
            attention_mask = attention_mask.to(device)
            labels_demo = labels_demo.to(device)

            mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
            mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
            mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1

            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

            pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
                inputs.shape[0], inputs.shape[1], 1
            )
            pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
                inputs.shape[0], inputs.shape[1], 1
            )
            pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
                inputs.shape[0], inputs.shape[1], 1
            )

            pred_nikud[mask_cant_be_nikud] = -1
            pred_dagesh[mask_cant_be_dagesh] = -1
            pred_sin[mask_cant_be_sin] = -1

            pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)

            if all_labels is None:
                all_labels = pred_labels
            else:
                all_labels = np.concatenate((all_labels, pred_labels), axis=0)

    return all_labels


def predict_single(model, data, device="cpu"):
    # model.to(device)

    all_labels = None
    with torch.no_grad():
        (inputs, attention_mask, labels_demo) = data
        inputs = inputs.to(device)
        attention_mask = attention_mask.to(device)
        labels_demo = labels_demo.to(device)

        mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
        mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
        mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1

        nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
        print("model output: ", nikud_probs, dagesh_probs, sin_probs)

        pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
            inputs.shape[0], inputs.shape[1], 1
        )
        pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
            inputs.shape[0], inputs.shape[1], 1
        )
        pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
            inputs.shape[0], inputs.shape[1], 1
        )

        pred_nikud[mask_cant_be_nikud] = -1
        pred_dagesh[mask_cant_be_dagesh] = -1
        pred_sin[mask_cant_be_sin] = -1
        # print(pred_nikud, pred_dagesh, pred_sin)
        pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
        print(pred_labels)
        if all_labels is None:
            all_labels = pred_labels
        else:
            all_labels = np.concatenate((all_labels, pred_labels), axis=0)

    return all_labels
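
# --- Editor's illustration (hypothetical, not part of the original file) --
# Both predict functions share one convention: positions whose label slot is
# -1 ("this letter cannot carry this diacritic") are forced back to -1 after
# the per-position argmax, so impossible diacritics never reach the output:
def _demo_mask_convention():
    probs = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])  # (batch=1, seq=2, classes=2)
    labels_demo = torch.tensor([[0, -1]])             # second position is impossible
    pred = np.array(torch.max(probs, 2).indices)      # argmax -> [[1, 0]]
    pred[np.array(labels_demo) == -1] = -1            # masked  -> [[1, -1]]
    print(pred)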
def training(
    model,
    train_loader,
    dev_loader,
    criterion_nikud,
    criterion_dagesh,
    criterion_sin,
    training_params,
    logger,
    output_model_path,
    optimizer,
    device="cpu",
):
    max_length = None
    best_accuracy = 0.0

    logger.info(f"start training with training_params: {training_params}")
    model = model.to(device)

    criteria = {
        "nikud": criterion_nikud.to(device),
        "dagesh": criterion_dagesh.to(device),
        "sin": criterion_sin.to(device),
    }

    output_checkpoints_path = os.path.join(output_model_path, "checkpoints")
    create_missing_folders(output_checkpoints_path)

    train_steps_loss_values = {"nikud": [], "dagesh": [], "sin": []}
    train_epochs_loss_values = {"nikud": [], "dagesh": [], "sin": []}
    dev_loss_values = {"nikud": [], "dagesh": [], "sin": []}
    dev_accuracy_values = {
        "nikud": [],
        "dagesh": [],
        "sin": [],
        "all_nikud_letter": [],
        "all_nikud_word": [],
    }

    for epoch in tqdm(range(training_params["n_epochs"]), desc="Training"):
        model.train()
        train_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}

        for index_data, data in enumerate(train_loader):
            (inputs, attention_mask, labels) = data

            if max_length is None:
                max_length = labels.shape[1]

            inputs = inputs.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

            for i, (probs, class_name) in enumerate(
                zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
            ):
                reshaped_tensor = (
                    torch.transpose(probs, 1, 2)
                    .contiguous()
                    .view(probs.shape[0], probs.shape[2], probs.shape[1])
                )
                loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(device)

                num_relevant = (labels[:, :, i] != -1).sum()
                train_loss[class_name] += loss.item() * num_relevant
                relevant_count[class_name] += num_relevant

                loss.backward(retain_graph=True)

            for i, class_name in enumerate(CLASSES_LIST):
                train_steps_loss_values[class_name].append(
                    float(train_loss[class_name] / relevant_count[class_name])
                )

            optimizer.step()
            if (index_data + 1) % 100 == 0:
                msg = f"epoch: {epoch} , index_data: {index_data + 1}\n"
                for i, class_name in enumerate(CLASSES_LIST):
                    msg += f"mean loss train {class_name}: {float(train_loss[class_name] / relevant_count[class_name])}, "

                logger.debug(msg[:-2])

        for i, class_name in enumerate(CLASSES_LIST):
            train_epochs_loss_values[class_name].append(
                float(train_loss[class_name] / relevant_count[class_name])
            )

        for class_name in train_loss.keys():
            train_loss[class_name] /= relevant_count[class_name]

        msg = f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
        for i, class_name in enumerate(CLASSES_LIST):
            msg += f"mean loss train {class_name}: {train_loss[class_name]}, "
        logger.debug(msg[:-2])

        model.eval()
        dev_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        dev_accuracy = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        correct_preds = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        un_masks = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        predictions = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}

        all_nikud_types_correct_preds_letter = 0.0

        letter_count = 0.0
        correct_words_count = 0.0
        word_count = 0.0
        with torch.no_grad():
            for index_data, data in enumerate(dev_loader):
                (inputs, attention_mask, labels) = data
                inputs = inputs.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)

                nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

                for i, (probs, class_name) in enumerate(
                    zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
                ):
                    reshaped_tensor = (
                        torch.transpose(probs, 1, 2)
                        .contiguous()
                        .view(probs.shape[0], probs.shape[2], probs.shape[1])
                    )
                    loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(
                        device
                    )
                    un_masked = labels[:, :, i] != -1
                    num_relevant = un_masked.sum()
                    relevant_count[class_name] += num_relevant
                    _, preds = torch.max(probs, 2)
                    dev_loss[class_name] += loss.item() * num_relevant
                    correct_preds[class_name] += torch.sum(
                        preds[un_masked] == labels[:, :, i][un_masked]
                    )
                    un_masks[class_name] = un_masked
                    predictions[class_name] = preds
                    labels_class[class_name] = labels[:, :, i]

                un_mask_all_or = torch.logical_or(
                    torch.logical_or(un_masks["nikud"], un_masks["dagesh"]),
                    un_masks["sin"],
                )

                correct = {
                    class_name: (torch.ones(un_mask_all_or.shape) == 1).to(device)
                    for class_name in CLASSES_LIST
                }

                for i, class_name in enumerate(CLASSES_LIST):
                    correct[class_name][un_masks[class_name]] = (
                        predictions[class_name][un_masks[class_name]]
                        == labels_class[class_name][un_masks[class_name]]
                    )

                letter_correct_mask = torch.logical_and(
                    torch.logical_and(correct["sin"], correct["dagesh"]),
                    correct["nikud"],
                )
                all_nikud_types_correct_preds_letter += torch.sum(
                    letter_correct_mask[un_mask_all_or]
                )

                letter_correct_mask[~un_mask_all_or] = True
                correct_num, total_words_num = calc_num_correct_words(
                    inputs.cpu(), letter_correct_mask
                )

                word_count += total_words_num
                correct_words_count += correct_num
                letter_count += un_mask_all_or.sum()

        for class_name in CLASSES_LIST:
            dev_loss[class_name] /= relevant_count[class_name]
            dev_accuracy[class_name] = float(
                correct_preds[class_name].double() / relevant_count[class_name]
            )

            dev_loss_values[class_name].append(float(dev_loss[class_name]))
            dev_accuracy_values[class_name].append(float(dev_accuracy[class_name]))

        dev_all_nikud_types_accuracy_letter = float(
            all_nikud_types_correct_preds_letter / letter_count
        )

        dev_accuracy_values["all_nikud_letter"].append(
            dev_all_nikud_types_accuracy_letter
        )

        word_all_nikud_accuracy = correct_words_count / word_count
        dev_accuracy_values["all_nikud_word"].append(word_all_nikud_accuracy)

        msg = (
            f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
            f'mean loss Dev nikud: {dev_loss["nikud"]}, '
            f'mean loss Dev dagesh: {dev_loss["dagesh"]}, '
            f'mean loss Dev sin: {dev_loss["sin"]}, '
            f"Dev all nikud types letter Accuracy: {dev_all_nikud_types_accuracy_letter}, "
            f'Dev nikud letter Accuracy: {dev_accuracy["nikud"]}, '
            f'Dev dagesh letter Accuracy: {dev_accuracy["dagesh"]}, '
            f'Dev sin letter Accuracy: {dev_accuracy["sin"]}, '
            f"Dev word Accuracy: {word_all_nikud_accuracy}"
        )
        logger.debug(msg)

        save_progress_details(
            dev_accuracy_values,
            train_epochs_loss_values,
            dev_loss_values,
            train_steps_loss_values,
        )

        if dev_all_nikud_types_accuracy_letter > best_accuracy:
            best_accuracy = dev_all_nikud_types_accuracy_letter
            best_model = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
            }

        if epoch % training_params["checkpoints_frequency"] == 0:
            save_checkpoint_path = os.path.join(
                output_checkpoints_path, f"checkpoint_model_epoch_{epoch + 1}.pth"
            )
            checkpoint = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
            }
            torch.save(checkpoint["model_state_dict"], save_checkpoint_path)

    save_model_path = os.path.join(output_model_path, "best_model.pth")
    torch.save(best_model["model_state_dict"], save_model_path)
    return (
        best_model,
        best_accuracy,
        train_epochs_loss_values,
        train_steps_loss_values,
        dev_loss_values,
        dev_accuracy_values,
    )
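
# --- Editor's sketch (hypothetical wiring; the real setup lives in main.py) --
# training() expects one criterion per output head; a CrossEntropyLoss with
# ignore_index=-1 matches the -1 "impossible position" convention used above.
# Everything below (learning rate, epoch counts, paths) is illustrative only:
def _demo_training_setup(model, train_loader, dev_loader, logger):
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    return training(
        model,
        train_loader,
        dev_loader,
        criterion_nikud=criterion,
        criterion_dagesh=criterion,
        criterion_sin=criterion,
        training_params={"n_epochs": 10, "checkpoints_frequency": 1},
        logger=logger,
        output_model_path="models",
        optimizer=optimizer,
    )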
def save_progress_details(
    accuracy_dev_values,
    epochs_loss_train_values,
    loss_dev_values,
    steps_loss_train_values,
):
    epochs_data_path = "epochs_data"
    create_missing_folders(epochs_data_path)

    save_dict_as_json(
        steps_loss_train_values, epochs_data_path, "steps_loss_train_values.json"
    )
    save_dict_as_json(
        epochs_loss_train_values, epochs_data_path, "epochs_loss_train_values.json"
    )
    save_dict_as_json(loss_dev_values, epochs_data_path, "loss_dev_values.json")
    save_dict_as_json(accuracy_dev_values, epochs_data_path, "accuracy_dev_values.json")


def save_dict_as_json(dict, file_path, file_name):
    json_data = json.dumps(dict, indent=4)
    with open(os.path.join(file_path, file_name), "w") as json_file:
        json_file.write(json_data)
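
# --- Editor's illustration (hypothetical, not part of the original file) --
# save_dict_as_json writes pretty-printed JSON into an existing folder, e.g.:
def _demo_save_dict_as_json():
    create_missing_folders("epochs_data")  # the folder must exist beforehand
    save_dict_as_json({"nikud": [0.91, 0.93]}, "epochs_data", "accuracy_dev_values.json")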
def evaluate(model, test_data, plots_folder=None, device="cpu"):
    model.to(device)
    model.eval()

    true_labels = {"nikud": [], "dagesh": [], "sin": []}
    predictions = {"nikud": 0, "dagesh": 0, "sin": 0}
    predicted_labels_2_report = {"nikud": [], "dagesh": [], "sin": []}
    not_masks = {"nikud": 0, "dagesh": 0, "sin": 0}
    correct_preds = {"nikud": 0, "dagesh": 0, "sin": 0}
    relevant_count = {"nikud": 0, "dagesh": 0, "sin": 0}
    labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}

    all_nikud_types_letter_level_correct = 0.0
    nikud_letter_level_correct = 0.0
    dagesh_letter_level_correct = 0.0
    sin_letter_level_correct = 0.0

    letters_count = 0.0
    words_count = 0.0
    correct_words_count = 0.0
    with torch.no_grad():
        for index_data, data in enumerate(test_data):
            if DEBUG_MODE and index_data > 100:
                break

            (inputs, attention_mask, labels) = data

            inputs = inputs.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

            for i, (probs, class_name) in enumerate(
                zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
            ):
                labels_class[class_name] = labels[:, :, i]
                not_masked = labels_class[class_name] != -1
                num_relevant = not_masked.sum()
                relevant_count[class_name] += num_relevant
                _, preds = torch.max(probs, 2)
                correct_preds[class_name] += torch.sum(
                    preds[not_masked] == labels_class[class_name][not_masked]
                )
                predictions[class_name] = preds
                not_masks[class_name] = not_masked

                if len(true_labels[class_name]) == 0:
                    true_labels[class_name] = (
                        labels_class[class_name][not_masked].cpu().numpy()
                    )
                else:
                    true_labels[class_name] = np.concatenate(
                        (
                            true_labels[class_name],
                            labels_class[class_name][not_masked].cpu().numpy(),
                        )
                    )

                if len(predicted_labels_2_report[class_name]) == 0:
                    predicted_labels_2_report[class_name] = (
                        preds[not_masked].cpu().numpy()
                    )
                else:
                    predicted_labels_2_report[class_name] = np.concatenate(
                        (
                            predicted_labels_2_report[class_name],
                            preds[not_masked].cpu().numpy(),
                        )
                    )

            not_mask_all_or = torch.logical_or(
                torch.logical_or(not_masks["nikud"], not_masks["dagesh"]),
                not_masks["sin"],
            )

            correct_nikud = (torch.ones(not_mask_all_or.shape) == 1).to(device)
            correct_dagesh = (torch.ones(not_mask_all_or.shape) == 1).to(device)
            correct_sin = (torch.ones(not_mask_all_or.shape) == 1).to(device)

            correct_nikud[not_masks["nikud"]] = (
                predictions["nikud"][not_masks["nikud"]]
                == labels_class["nikud"][not_masks["nikud"]]
            )
            correct_dagesh[not_masks["dagesh"]] = (
                predictions["dagesh"][not_masks["dagesh"]]
                == labels_class["dagesh"][not_masks["dagesh"]]
            )
            correct_sin[not_masks["sin"]] = (
                predictions["sin"][not_masks["sin"]]
                == labels_class["sin"][not_masks["sin"]]
            )

            letter_correct_mask = torch.logical_and(
                torch.logical_and(correct_sin, correct_dagesh), correct_nikud
            )
            all_nikud_types_letter_level_correct += torch.sum(
                letter_correct_mask[not_mask_all_or]
            )

            letter_correct_mask[~not_mask_all_or] = True
            total_correct_count, total_words_num = calc_num_correct_words(
                inputs.cpu(), letter_correct_mask
            )

            words_count += total_words_num
            correct_words_count += total_correct_count

            letters_count += not_mask_all_or.sum()

            nikud_letter_level_correct += torch.sum(correct_nikud[not_mask_all_or])
            dagesh_letter_level_correct += torch.sum(correct_dagesh[not_mask_all_or])
            sin_letter_level_correct += torch.sum(correct_sin[not_mask_all_or])

    for i, name in enumerate(CLASSES_LIST):
        index_labels = np.unique(true_labels[name])
        cm = confusion_matrix(
            true_labels[name], predicted_labels_2_report[name], labels=index_labels
        )

        vowel_label = [Nikud.id_2_label[name][l] for l in index_labels]
        unique_vowels_names = [
            Nikud.sign_2_name[int(vowel)] for vowel in vowel_label if vowel != "WITHOUT"
        ]
        if "WITHOUT" in vowel_label:
            unique_vowels_names += ["WITHOUT"]
        cm_df = pd.DataFrame(cm, index=unique_vowels_names, columns=unique_vowels_names)

        # Display confusion matrix
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm_df, annot=True, cmap="Blues", fmt="d")
        plt.title("Confusion Matrix")
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        if plots_folder is None:
            plt.show()
        else:
            plt.savefig(os.path.join(plots_folder, f"Confusion_Matrix_{name}.jpg"))

    all_nikud_types_letter_level_correct = (
        all_nikud_types_letter_level_correct / letters_count
                    all_nikud_types_letter_level_correct / letters_count
         | 
| 550 | 
            +
                )
         | 
| 551 | 
            +
                all_nikud_types_word_level_correct = correct_words_count / words_count
         | 
| 552 | 
            +
                nikud_letter_level_correct = nikud_letter_level_correct / letters_count
         | 
| 553 | 
            +
                dagesh_letter_level_correct = dagesh_letter_level_correct / letters_count
         | 
| 554 | 
            +
                sin_letter_level_correct = sin_letter_level_correct / letters_count
         | 
| 555 | 
            +
                print("\n")
         | 
| 556 | 
            +
                print(f"nikud_letter_level_correct = {nikud_letter_level_correct}")
         | 
| 557 | 
            +
                print(f"dagesh_letter_level_correct = {dagesh_letter_level_correct}")
         | 
| 558 | 
            +
                print(f"sin_letter_level_correct = {sin_letter_level_correct}")
         | 
| 559 | 
            +
                print(f"word_level_correct = {all_nikud_types_word_level_correct}")
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                return all_nikud_types_word_level_correct, all_nikud_types_letter_level_correct
         | 
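The bookkeeping above evaluates each classification head only at positions whose mask says the letter can actually carry that diacritic. The idea reduces to this minimal, self-contained sketch (toy tensors, not the project's real batch layout):

import torch

# 0/1/2 are class ids; -1 marks padding or letters that cannot carry the mark.
predictions = torch.tensor([2, 0, 1, 1])
labels = torch.tensor([2, 0, 0, -1])

not_masked = labels != -1                                  # evaluate real letters only
correct = predictions[not_masked] == labels[not_masked]
print(correct.sum().item() / not_masked.sum().item())      # 2 of 3 unmasked -> 0.666...
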
    	
        src/plot_helpers.py
    ADDED
    
@@ -0,0 +1,58 @@
# general
import os

# visual
import matplotlib.pyplot as plt

cols = ["precision", "recall", "f1-score", "support"]


def generate_plot_by_nikud_dagesh_sin_dict(nikud_dagesh_sin_dict, title, y_axis, plot_folder=None):
    # Create a figure and axis
    plt.figure(figsize=(8, 6))
    plt.title(title)

    ax = plt.gca()
    indexes = list(range(1, len(nikud_dagesh_sin_dict["nikud"]) + 1))

    # Plot data series with different colors and labels
    ax.plot(indexes, nikud_dagesh_sin_dict["nikud"], color='blue', label='Nikud')
    ax.plot(indexes, nikud_dagesh_sin_dict["dagesh"], color='green', label='Dagesh')
    ax.plot(indexes, nikud_dagesh_sin_dict["sin"], color='red', label='Sin')

    # Add legend
    ax.legend()

    # Set labels and title
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_axis)

    if plot_folder is None:
        plt.show()
    else:
        plt.savefig(os.path.join(plot_folder, f'{title.replace(" ", "_")}_plot.jpg'))


def generate_word_and_letter_accuracy_plot(word_and_letter_accuracy_dict, title, plot_folder=None):
    # Create a figure and axis
    plt.figure(figsize=(8, 6))
    plt.title(title)

    ax = plt.gca()
    indexes = list(range(1, len(word_and_letter_accuracy_dict["all_nikud_letter"]) + 1))

    # Plot data series with different colors and labels
    ax.plot(indexes, word_and_letter_accuracy_dict["all_nikud_letter"], color='blue', label='Letter')
    ax.plot(indexes, word_and_letter_accuracy_dict["all_nikud_word"], color='green', label='Word')

    # Add legend
    ax.legend()

    # Set labels and title
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")

    if plot_folder is None:
        plt.show()
    else:
        plt.savefig(os.path.join(plot_folder, 'word_and_letter_accuracy_plot.jpg'))
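Both helpers expect one value per epoch under fixed dict keys. A minimal usage sketch for generate_plot_by_nikud_dagesh_sin_dict; only the keys and the signature come from the file above, the accuracy histories are made up:

from src.plot_helpers import generate_plot_by_nikud_dagesh_sin_dict

history = {
    "nikud": [0.81, 0.88, 0.92],    # hypothetical per-epoch dev accuracy
    "dagesh": [0.85, 0.90, 0.93],
    "sin": [0.95, 0.97, 0.98],
}
# plot_folder=None shows the figure instead of saving it
generate_plot_by_nikud_dagesh_sin_dict(history, "Accuracy Dev", "Accuracy", plot_folder=None)
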
    	
        src/running_params.py
    ADDED
    
@@ -0,0 +1,3 @@
DEBUG_MODE = False
BATCH_SIZE = 32
MAX_LENGTH_SEN = 1024
    	
        src/utiles_data.py
    ADDED
    
@@ -0,0 +1,737 @@
# general
import os.path
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from uuid import uuid1
import re
import glob2

# visual
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm

# ML
import numpy as np
import torch
from torch.utils.data import Dataset

from src.running_params import DEBUG_MODE, MAX_LENGTH_SEN

matplotlib.use("agg")
unique_key = str(uuid1())


class Nikud:
    """
    1456 HEBREW POINT SHEVA
    1457 HEBREW POINT HATAF SEGOL
    1458 HEBREW POINT HATAF PATAH
    1459 HEBREW POINT HATAF QAMATS
    1460 HEBREW POINT HIRIQ
    1461 HEBREW POINT TSERE
    1462 HEBREW POINT SEGOL
    1463 HEBREW POINT PATAH
    1464 HEBREW POINT QAMATS
    1465 HEBREW POINT HOLAM
    1466 HEBREW POINT HOLAM HASER FOR VAV     ***EXTENDED***
    1467 HEBREW POINT QUBUTS
    1468 HEBREW POINT DAGESH OR MAPIQ
    1469 HEBREW POINT METEG                   ***EXTENDED***
    1470 HEBREW PUNCTUATION MAQAF             ***EXTENDED***
    1471 HEBREW POINT RAFE                    ***EXTENDED***
    1472 HEBREW PUNCTUATION PASEQ             ***EXTENDED***
    1473 HEBREW POINT SHIN DOT
    1474 HEBREW POINT SIN DOT
    """

    nikud_dict = {
        "SHVA": 1456,
        "REDUCED_SEGOL": 1457,
        "REDUCED_PATAKH": 1458,
        "REDUCED_KAMATZ": 1459,
        "HIRIK": 1460,
        "TZEIRE": 1461,
        "SEGOL": 1462,
        "PATAKH": 1463,
        "KAMATZ": 1464,
        "KAMATZ_KATAN": 1479,
        "HOLAM": 1465,
        "HOLAM HASER VAV": 1466,
        "KUBUTZ": 1467,
        "DAGESH OR SHURUK": 1468,
        "METEG": 1469,
        "PUNCTUATION MAQAF": 1470,
        "RAFE": 1471,
        "PUNCTUATION PASEQ": 1472,
        "SHIN_YEMANIT": 1473,
        "SHIN_SMALIT": 1474,
    }

    skip_nikud = (
        []
    )  # [nikud_dict["KAMATZ_KATAN"], nikud_dict["HOLAM HASER VAV"], nikud_dict["METEG"], nikud_dict["PUNCTUATION MAQAF"], nikud_dict["PUNCTUATION PASEQ"]]
    sign_2_name = {sign: name for name, sign in nikud_dict.items()}
    sin = [nikud_dict["RAFE"], nikud_dict["SHIN_YEMANIT"], nikud_dict["SHIN_SMALIT"]]
    dagesh = [
        nikud_dict["RAFE"],
        nikud_dict["DAGESH OR SHURUK"],
    ]  # note that DAGESH and SHURUK are one and the same
    nikud = []
    for v in nikud_dict.values():
        if v not in sin and v not in skip_nikud:
            nikud.append(v)
    all_nikud_ord = {v for v in nikud_dict.values()}
    all_nikud_chr = {chr(v) for v in nikud_dict.values()}

    label_2_id = {
        "nikud": {label: i for i, label in enumerate(nikud + ["WITHOUT"])},
        "dagesh": {label: i for i, label in enumerate(dagesh + ["WITHOUT"])},
        "sin": {label: i for i, label in enumerate(sin + ["WITHOUT"])},
    }
    id_2_label = {
        "nikud": {i: label for i, label in enumerate(nikud + ["WITHOUT"])},
        "dagesh": {i: label for i, label in enumerate(dagesh + ["WITHOUT"])},
        "sin": {i: label for i, label in enumerate(sin + ["WITHOUT"])},
    }

    DAGESH_LETTER = nikud_dict["DAGESH OR SHURUK"]
    RAFE = nikud_dict["RAFE"]
    PAD_OR_IRRELEVANT = -1

    LEN_NIKUD = len(label_2_id["nikud"])
    LEN_DAGESH = len(label_2_id["dagesh"])
    LEN_SIN = len(label_2_id["sin"])

    def id_2_char(self, c, class_type):
        # Map a dense class id back to the diacritic character it encodes;
        # -1 (padding/irrelevant) and "WITHOUT" map to the empty string.
        if c == -1:
            return ""

        label = self.id_2_label[class_type][c]

        if label != "WITHOUT":
            return chr(label)
        return ""

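# Usage sketch (doctest-style; values follow from the tables defined above):
#   >>> kamatz = Nikud.nikud_dict["KAMATZ"]        # 1464, the raw code point
#   >>> Nikud.sign_2_name[kamatz]
#   'KAMATZ'
#   >>> dense_id = Nikud.label_2_id["nikud"][kamatz]   # id used by the model heads
#   >>> Nikud.id_2_label["nikud"][dense_id] == kamatz
#   True
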
class Letters:
    hebrew = [chr(c) for c in range(0x05D0, 0x05EA + 1)]
    VALID_LETTERS = [
        " ",
        "!",
        '"',
        "'",
        "(",
        ")",
        ",",
        "-",
        ".",
        ":",
        ";",
        "?",
    ] + hebrew
    SPECIAL_TOKENS = ["H", "O", "5", "1"]
    ENDINGS_TO_REGULAR = dict(zip("ךםןףץ", "כמנפצ"))
    vocab = VALID_LETTERS + SPECIAL_TOKENS
    vocab_size = len(vocab)

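# Usage sketch (doctest-style; final "sofit" forms fold onto their regular letters):
#   >>> Letters.ENDINGS_TO_REGULAR["ם"]
#   'מ'
#   >>> Letters.vocab_size == len(Letters.VALID_LETTERS) + len(Letters.SPECIAL_TOKENS)
#   True
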
class Letter:
    def __init__(self, letter):
        self.letter = letter
        self.normalized = None
        self.dagesh = None
        self.sin = None
        self.nikud = None

    def normalize(self, letter):
        if letter in Letters.VALID_LETTERS:
            return letter
        if letter in Letters.ENDINGS_TO_REGULAR:
            return Letters.ENDINGS_TO_REGULAR[letter]
        if letter in ["\n", "\t"]:
            return " "
        if letter in ["‒", "–", "—", "―", "−", "+"]:
            return "-"
        if letter == "[":
            return "("
        if letter == "]":
            return ")"
        if letter in ["´", "‘", "’"]:
            return "'"
        if letter in ["“", "”", "״"]:
            return '"'
        if letter.isdigit():
            if int(letter) == 1:
                return "1"
            else:
                return "5"
        if letter == "…":
            return ","
        if letter in ["ײ", "װ", "ױ"]:
            return "H"
        return "O"

    def can_dagesh(self, letter):
        return letter in ("בגדהוזטיכלמנספצקשת" + "ךף")

    def can_sin(self, letter):
        return letter == "ש"

    def can_nikud(self, letter):
        return letter in ("אבגדהוזחטיכלמנסעפצקרשת" + "ךן")

    def get_label_letter(self, labels):
        dagesh_sin_nikud = [
            self.can_dagesh(self.letter),
            self.can_sin(self.letter),
            self.can_nikud(self.letter),
        ]

        labels_ids = {
            "nikud": Nikud.PAD_OR_IRRELEVANT,
            "dagesh": Nikud.PAD_OR_IRRELEVANT,
            "sin": Nikud.PAD_OR_IRRELEVANT,
        }

        normalized = self.normalize(self.letter)

        i = 0
        # Drop marks the model does not predict, and fold kamatz katan into kamatz.
        if Nikud.nikud_dict["PUNCTUATION PASEQ"] in labels:
            labels.remove(Nikud.nikud_dict["PUNCTUATION PASEQ"])
        if Nikud.nikud_dict["PUNCTUATION MAQAF"] in labels:
            labels.remove(Nikud.nikud_dict["PUNCTUATION MAQAF"])
        if Nikud.nikud_dict["HOLAM HASER VAV"] in labels:
            labels.remove(Nikud.nikud_dict["HOLAM HASER VAV"])
        if Nikud.nikud_dict["METEG"] in labels:
            labels.remove(Nikud.nikud_dict["METEG"])
        if Nikud.nikud_dict["KAMATZ_KATAN"] in labels:
            labels[labels.index(Nikud.nikud_dict["KAMATZ_KATAN"])] = Nikud.nikud_dict[
                "KAMATZ"
            ]
        for index, (class_name, group) in enumerate(
            zip(
                ["dagesh", "sin", "nikud"],
                [[Nikud.DAGESH_LETTER], Nikud.sin, Nikud.nikud],
            )
        ):
            # notice - order is important: dagesh then sin and then nikud
            if dagesh_sin_nikud[index]:
                if i < len(labels) and labels[i] in group:
                    labels_ids[class_name] = Nikud.label_2_id[class_name][labels[i]]
                    i += 1
                else:
                    labels_ids[class_name] = Nikud.label_2_id[class_name]["WITHOUT"]

        if (
            np.array(dagesh_sin_nikud).all()
            and len(labels) == 3
            and labels[0] in Nikud.sin
        ):
            labels_ids["nikud"] = Nikud.label_2_id["nikud"][labels[2]]
            labels_ids["dagesh"] = Nikud.label_2_id["dagesh"][labels[1]]

        if (
            self.can_sin(self.letter)
            and len(labels) == 2
            and labels[1] == Nikud.DAGESH_LETTER
        ):
            labels_ids["dagesh"] = Nikud.label_2_id["dagesh"][labels[1]]
            labels_ids["nikud"] = Nikud.label_2_id["nikud"]["WITHOUT"]

        if (
            self.letter == "ו"
            and labels_ids["dagesh"] == Nikud.label_2_id["dagesh"][Nikud.DAGESH_LETTER]
            and labels_ids["nikud"] == Nikud.label_2_id["nikud"]["WITHOUT"]
        ):
            # A dotted vav with no other vowel is a shuruk: move the mark from
            # the dagesh class to the nikud class. (Both sides of these
            # comparisons are label ids, not raw code points.)
            labels_ids["dagesh"] = Nikud.label_2_id["dagesh"]["WITHOUT"]
            labels_ids["nikud"] = Nikud.label_2_id["nikud"][Nikud.DAGESH_LETTER]

        self.normalized = normalized
        self.dagesh = labels_ids["dagesh"]
        self.sin = labels_ids["sin"]
        self.nikud = labels_ids["nikud"]

    def name_of(self, letter):
        if "א" <= letter <= "ת":
            return letter
        # The diacritic constants are ints, so compare against the code point.
        code = ord(letter)
        if code == Nikud.DAGESH_LETTER:
            return "דגש\\שורוק"
        if code == Nikud.nikud_dict["KAMATZ"]:
            return "קמץ"
        if code == Nikud.nikud_dict["PATAKH"]:
            return "פתח"
        if code == Nikud.nikud_dict["TZEIRE"]:
            return "צירה"
        if code == Nikud.nikud_dict["SEGOL"]:
            return "סגול"
        if code == Nikud.nikud_dict["SHVA"]:
            return "שוא"
        if code == Nikud.nikud_dict["HOLAM"]:
            return "חולם"
        if code == Nikud.nikud_dict["KUBUTZ"]:
            return "קובוץ"
        if code == Nikud.nikud_dict["HIRIK"]:
            return "חיריק"
        if code == Nikud.nikud_dict["REDUCED_KAMATZ"]:
            return "חטף-קמץ"
        if code == Nikud.nikud_dict["REDUCED_PATAKH"]:
            return "חטף-פתח"
        if code == Nikud.nikud_dict["REDUCED_SEGOL"]:
            return "חטף-סגול"
        if code == Nikud.nikud_dict["SHIN_SMALIT"]:
            return "שין-שמאלית"
        if code == Nikud.nikud_dict["SHIN_YEMANIT"]:
            return "שין-ימנית"
        if letter.isprintable():
            return letter
        return "לא ידוע ({})".format(hex(ord(letter)))

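# Usage sketch (doctest-style; a bet carrying dagesh + patakh, following the
# "dagesh, then sin, then nikud" consumption order above):
#   >>> l = Letter("ב")
#   >>> l.get_label_letter([Nikud.nikud_dict["DAGESH OR SHURUK"], Nikud.nikud_dict["PATAKH"]])
#   >>> l.dagesh == Nikud.label_2_id["dagesh"][Nikud.DAGESH_LETTER]
#   True
#   >>> l.nikud == Nikud.label_2_id["nikud"][Nikud.nikud_dict["PATAKH"]]
#   True
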
def text_contains_nikud(text):
    return len(set(text) & Nikud.all_nikud_chr) > 0


def combine_sentences(list_sentences, max_length=0, is_train=False):
    all_new_sentences = []
    new_sen = ""
    index = 0
    while index < len(list_sentences):
        sen = list_sentences[index]

        if not text_contains_nikud(sen) and (
            "------------------" in sen or sen == "\n"
        ):
            if len(new_sen) > 0:
                all_new_sentences.append(new_sen)
                if not is_train:
                    all_new_sentences.append(sen)
                new_sen = ""
                index += 1
                continue

        if not text_contains_nikud(sen) and is_train:
            index += 1
            continue

        if len(sen) > max_length:
            # Prefer to split an over-long sentence at punctuation boundaries.
            update_sen = sen.replace(". ", f". {unique_key}")
            update_sen = update_sen.replace("? ", f"? {unique_key}")
            update_sen = update_sen.replace("! ", f"! {unique_key}")
            update_sen = update_sen.replace("” ", f"” {unique_key}")
            update_sen = update_sen.replace("\t", f"\t{unique_key}")
            part_sentence = update_sen.split(unique_key)

            good_parts = []
            for p in part_sentence:
                if len(p) < max_length:
                    good_parts.append(p)
                else:
                    # Still too long: cut each max_length window at its last space.
                    prev = 0
                    while prev <= len(p):
                        part = p[prev : (prev + max_length)]
                        last_space = 0
                        if " " in part:
                            last_space = part[::-1].index(" ") + 1
                        next_index = prev + max_length - last_space
                        part = p[prev:next_index]
                        good_parts.append(part)
                        prev = next_index
            list_sentences = (
                list_sentences[:index] + good_parts + list_sentences[index + 1 :]
            )
            continue
        if new_sen == "":
            new_sen = sen
        elif len(new_sen) + len(sen) < max_length:
            new_sen += sen
        else:
            all_new_sentences.append(new_sen)
            new_sen = sen

        index += 1
    if len(new_sen) > 0:
        all_new_sentences.append(new_sen)
    return all_new_sentences

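# Usage sketch: short vocalized lines are packed together up to max_length
# characters (output below assumes exactly these two inputs):
#   >>> combine_sentences(["שָׁלוֹם\n", "עוֹלָם\n"], max_length=100)
#   ['שָׁלוֹם\nעוֹלָם\n']
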
class NikudDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        folder=None,
        file=None,
        logger=None,
        max_length=0,
        is_train=False,
    ):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.data = None
        self.origin_data = None
        if folder is not None:
            self.data, self.origin_data = self.read_data_folder(folder, logger)
        elif file is not None:
            self.data, self.origin_data = self.read_data(file, logger)
        self.prepered_data = None

    def read_data_folder(self, folder_path: str, logger=None):
        all_files = glob2.glob(f"{folder_path}/**/*.txt", recursive=True)
        msg = f"number of files: {len(all_files)}"
        if logger:
            logger.debug(msg)
        else:
            print(msg)
        all_data = []
        all_origin_data = []
        if DEBUG_MODE:
            all_files = all_files[0:2]
        for file in all_files:
            if "not_use" in file or "NakdanResults" in file:
                continue
            data, origin_data = self.read_data(file, logger)
            all_data.extend(data)
            all_origin_data.extend(origin_data)
        return all_data, all_origin_data

    def read_data(self, filepath: str, logger=None) -> List[Tuple[str, list]]:
        msg = f"read file: {filepath}"
        if logger:
            logger.debug(msg)
        else:
            print(msg)
        data = []
        orig_data = []
        with open(filepath, "r", encoding="utf-8") as file:
            file_data = file.read()
        data_list = self.split_text(file_data)

        for sen in tqdm(data_list, desc=f"Source: {os.path.basename(filepath)}"):
            if sen == "":
                continue

            labels = []
            text = ""
            text_org = ""
            index = 0
            sentence_length = len(sen)
            while index < sentence_length:
                # Marks that are dropped from both the text and the labels.
                if (
                    ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION MAQAF"]
                    or ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION PASEQ"]
                    or ord(sen[index]) == Nikud.nikud_dict["METEG"]
                ):
                    index += 1
                    continue

                label = []
                l = Letter(sen[index])
                if l.letter in Nikud.all_nikud_chr:
                    # A stray diacritic right after a newline is skipped.
                    if sen[index - 1] == "\n":
                        index += 1
                        continue
                assert l.letter not in Nikud.all_nikud_chr
                if sen[index] in Letters.hebrew:
                    # Collect the diacritics that follow the letter.
                    index += 1
                    while (
                        index < sentence_length
                        and ord(sen[index]) in Nikud.all_nikud_ord
                    ):
                        label.append(ord(sen[index]))
                        index += 1
                else:
                    index += 1

                l.get_label_letter(label)
                text += l.normalized
                text_org += l.letter
                labels.append(l)

            data.append((text, labels))
            orig_data.append(text_org)

        return data, orig_data

    def read_single_text(self, text: str, logger=None) -> List[Tuple[str, list]]:
        # Same letter-by-letter parsing as read_data, applied to an in-memory
        # string instead of a file.
        data = []
        orig_data = []
        data_list = self.split_text(text)
        for sen in tqdm(data_list, desc="Source: text"):
            if sen == "":
                continue

            labels = []
            text = ""
            text_org = ""
            index = 0
            sentence_length = len(sen)
            while index < sentence_length:
                if (
                    ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION MAQAF"]
                    or ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION PASEQ"]
                    or ord(sen[index]) == Nikud.nikud_dict["METEG"]
                ):
                    index += 1
                    continue

                label = []
                l = Letter(sen[index])
                if l.letter in Nikud.all_nikud_chr:
                    if sen[index - 1] == "\n":
                        index += 1
                        continue
                assert l.letter not in Nikud.all_nikud_chr
                if sen[index] in Letters.hebrew:
                    index += 1
                    while (
                        index < sentence_length
                        and ord(sen[index]) in Nikud.all_nikud_ord
                    ):
                        label.append(ord(sen[index]))
                        index += 1
                else:
                    index += 1

                l.get_label_letter(label)
                text += l.normalized
                text_org += l.letter
                labels.append(l)

            data.append((text, labels))
            orig_data.append(text_org)
        self.data = data
        self.origin_data = orig_data
        return data, orig_data

    def split_text(self, file_data):
        file_data = file_data.replace("\n", f"\n{unique_key}")
        data_list = file_data.split(unique_key)
        data_list = combine_sentences(
            data_list, is_train=self.is_train, max_length=MAX_LENGTH_SEN
        )
        return data_list

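    # Usage sketch (hypothetical folder path; tokenizer as in the repo's handler):
    #   >>> tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")
    #   >>> ds = NikudDataset(tokenizer, folder="data/train", max_length=MAX_LENGTH_SEN, is_train=True)
    #   >>> ds.show_data_labels(plots_folder=None)
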
| 525 | 
            +
    def show_data_labels(self, plots_folder=None):
        nikud = [
            Nikud.id_2_label["nikud"][label.nikud]
            for _, label_list in self.data
            for label in label_list
            if label.nikud != -1
        ]
        dagesh = [
            Nikud.id_2_label["dagesh"][label.dagesh]
            for _, label_list in self.data
            for label in label_list
            if label.dagesh != -1
        ]
        sin = [
            Nikud.id_2_label["sin"][label.sin]
            for _, label_list in self.data
            for label in label_list
            if label.sin != -1
        ]

        vowels = nikud + dagesh + sin
        unique_vowels, label_counts = np.unique(vowels, return_counts=True)
        unique_vowels_names = [
            Nikud.sign_2_name[int(vowel)]
            for vowel in unique_vowels
            if vowel != "WITHOUT"
        ] + ["WITHOUT"]

        fig, ax = plt.subplots(figsize=(16, 6))
        bar_positions = np.arange(len(unique_vowels))
        bar_width = 0.15
        ax.bar(bar_positions, list(label_counts), bar_width)

        ax.set_title("Distribution of Vowels in dataset")
        ax.set_xlabel("Vowels")
        ax.set_ylabel("Count")
        ax.set_xticks(bar_positions)
        ax.set_xticklabels(unique_vowels_names, rotation=30, ha="right", fontsize=8)

        if plots_folder is None:
            plt.show()
        else:
            plt.savefig(os.path.join(plots_folder, "show_data_labels.jpg"))
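    # Usage sketch (illustrative): write the distribution plot to disk rather
    # than showing it; create_missing_folders (defined below) can ensure the
    # target directory exists:
    #
    #   create_missing_folders("plots")                 # "plots" is an example path
    #   dataset.show_data_labels(plots_folder="plots")  # -> plots/show_data_labels.jpg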
    def calc_max_length(self, maximum=MAX_LENGTH_SEN):
        if self.max_length > maximum:
            self.max_length = maximum
        return self.max_length
    def prepare_data(self, name="train"):
        dataset = []
        for index, (sentence, label) in tqdm(
            enumerate(self.data), desc=f"prepare data {name}"
        ):
            encoded_sequence = self.tokenizer.encode_plus(
                sentence,
                add_special_tokens=True,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            label_lists = [
                [letter.nikud, letter.dagesh, letter.sin] for letter in label
            ]
            label = torch.tensor(
                [
                    [
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                    ]
                ]
                + label_lists[: (self.max_length - 1)]
                + [
                    [
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                        Nikud.PAD_OR_IRRELEVANT,
                    ]
                    for i in range(self.max_length - len(label) - 1)
                ]
            )

            dataset.append(
                (
                    encoded_sequence["input_ids"][0],
                    encoded_sequence["attention_mask"][0],
                    label,
                )
            )

        self.prepered_data = dataset
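    # The leading PAD_OR_IRRELEVANT row lines the labels up with the special
    # start token that encode_plus prepends, and the trailing rows pad the
    # tensor out to max_length, so each item carries labels of shape
    # (max_length, 3): one [nikud, dagesh, sin] triple per character position.
    # Usage sketch (illustrative):
    #
    #   dataset.prepare_data(name="dev")
    #   input_ids, attention_mask, labels = dataset.prepered_data[0]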
    def back_2_text(self, labels):
        nikud = Nikud()
        all_text = ""
        for indx_sentance, (input_ids, _, label) in enumerate(self.prepered_data):
            new_line = ""
            for indx_char, c in enumerate(self.origin_data[indx_sentance]):
                new_line += (
                    c
                    + nikud.id_2_char(labels[indx_sentance, indx_char + 1, 1], "dagesh")
                    + nikud.id_2_char(labels[indx_sentance, indx_char + 1, 2], "sin")
                    + nikud.id_2_char(labels[indx_sentance, indx_char + 1, 0], "nikud")
                )
            all_text += new_line
        return all_text
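    # The indx_char + 1 offset skips the start-token position that
    # prepare_data reserves at index 0. Sketch of the intended flow (the
    # predicted-label shape is an assumption, not confirmed here):
    #
    #   labels = predict(model, data_loader, device)  # (n_sentences, max_length, 3)
    #   vocalized = dataset.back_2_text(labels)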
    def __len__(self):
        # self.data is a plain Python list of (text, labels) pairs.
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def get_sub_folders_paths(main_folder):
    list_paths = []
    for filename in os.listdir(main_folder):
        path = os.path.join(main_folder, filename)
        if os.path.isdir(path) and filename != ".git":
            list_paths.append(path)
            list_paths.extend(get_sub_folders_paths(path))
    return list_paths
def create_missing_folders(folder_path):
    # Check if the folder doesn't exist and create it if needed
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
def info_folder(folder, num_files, num_hebrew_letters):
    """
    Recursively counts the number of files and the number of Hebrew letters in
    all subfolders of the given folder path.

    Args:
        folder (str): The path of the folder to be analyzed.
        num_files (int): The running total of the number of files encountered so far.
        num_hebrew_letters (int): The running total of the number of Hebrew letters encountered so far.

    Returns:
        Tuple[int, int]: A tuple containing the total number of files and the total number of Hebrew letters.
    """
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
            num_files += 1
            dataset = NikudDataset(None, file=file_path)
            for line in dataset.data:
                for c in line[0]:
                    if c in Letters.hebrew:
                        num_hebrew_letters += 1

        elif os.path.isdir(file_path) and filename != ".git":
            # The recursive call is seeded with the running totals, so its
            # return values already include them; assigning (rather than
            # adding on top) avoids counting everything twice.
            num_files, num_hebrew_letters = info_folder(
                file_path, num_files, num_hebrew_letters
            )
    return num_files, num_hebrew_letters
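# Usage sketch (illustrative; the path is an example): seed both running
# totals with zero.
#
#   n_files, n_letters = info_folder("data/hebrew_diacritized", 0, 0)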
def extract_text_to_compare_nakdimon(text):
    res = text.replace("|", "")
    res = res.replace(
        chr(Nikud.nikud_dict["KUBUTZ"]) + "ו" + chr(Nikud.nikud_dict["METEG"]),
        "ו" + chr(Nikud.nikud_dict["DAGESH OR SHURUK"]),
    )
    res = res.replace(
        chr(Nikud.nikud_dict["HOLAM"]) + "ו" + chr(Nikud.nikud_dict["METEG"]), "ו"
    )
    res = res.replace(
        "ו" + chr(Nikud.nikud_dict["HOLAM"]) + chr(Nikud.nikud_dict["KAMATZ"]),
        "ו" + chr(Nikud.nikud_dict["KAMATZ"]),
    )
    res = res.replace(chr(Nikud.nikud_dict["METEG"]), "")
    res = res.replace(
        chr(Nikud.nikud_dict["KAMATZ"]) + chr(Nikud.nikud_dict["HIRIK"]),
        chr(Nikud.nikud_dict["KAMATZ"]) + "י" + chr(Nikud.nikud_dict["HIRIK"]),
    )
    res = res.replace(
        chr(Nikud.nikud_dict["PATAKH"]) + chr(Nikud.nikud_dict["HIRIK"]),
        chr(Nikud.nikud_dict["PATAKH"]) + "י" + chr(Nikud.nikud_dict["HIRIK"]),
    )
    res = res.replace(chr(Nikud.nikud_dict["PUNCTUATION MAQAF"]), "")
    res = res.replace(chr(Nikud.nikud_dict["PUNCTUATION PASEQ"]), "")
    res = res.replace(
        chr(Nikud.nikud_dict["KAMATZ_KATAN"]), chr(Nikud.nikud_dict["KAMATZ"])
    )

    res = re.sub(chr(Nikud.nikud_dict["KUBUTZ"]) + "ו" + "(?=[א-ת])", "ו", res)
    res = res.replace(chr(Nikud.nikud_dict["REDUCED_KAMATZ"]) + "ו", "ו")

    res = res.replace(
        chr(Nikud.nikud_dict["DAGESH OR SHURUK"]) * 2,
        chr(Nikud.nikud_dict["DAGESH OR SHURUK"]),
    )
    res = res.replace("\u05be", "-")
    res = res.replace("יְהוָֹה", "יהוה")

    return res
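# This normalization makes model output comparable with the Nakdimon
# benchmark's conventions: meteg and paseq are stripped, kamatz katan is
# folded into plain kamatz, and doubled shuruk/dagesh marks are collapsed.
# Usage sketch (illustrative):
#
#   comparable = extract_text_to_compare_nakdimon(vocalized_text)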
def orgenize_data(main_folder, logger):
    x = NikudDataset(None)
    x.delete_files(os.path.join(Path(main_folder).parent, "train"))
    x.delete_files(os.path.join(Path(main_folder).parent, "dev"))
    x.delete_files(os.path.join(Path(main_folder).parent, "test"))
    x.split_data(
        main_folder, main_folder_name=os.path.basename(main_folder), logger=logger
    )
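# Usage sketch (illustrative; assumes split_data writes fresh train/dev/test
# folders next to main_folder, which is what the delete_files calls prepare for):
#
#   import logging
#   orgenize_data("data/hebrew_diacritized", logging.getLogger(__name__))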