Update tokenizer.py
Browse files

tokenizer.py  CHANGED  +125 -33
@@ -1,6 +1,4 @@
-import os
 import re
-import textwrap
 from typing import List, Optional, Union, Dict, Any
 from functools import cached_property
 
@@ -20,12 +18,10 @@ from tokenizers import Tokenizer
 from tokenizers.pre_tokenizers import WhitespaceSplit
 from tokenizers.processors import TemplateProcessing
 
-from 
+from auralis.models.xttsv2.components.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
 
 import cutlet
 
-# Text preprocessing functions
-
 def get_spacy_lang(lang):
     if lang == "zh":
         return Chinese()
@@ -39,36 +35,131 @@ def get_spacy_lang(lang):
         # For most languages, English does the job
         return English()
 
-def split_sentence(text, lang, text_split_length=250):
-    """Preprocess the input text"""
-    text_splits = []
-    if text_split_length is not None and len(text) >= text_split_length:
-        text_splits.append("")
-        nlp = get_spacy_lang(lang)
+
+def find_best_split_point(text: str, target_pos: int, window_size: int = 30) -> int:
+    """
+    Find the best split point near the target position, considering punctuation
+    and language markers. Added for better sentence splitting in TTS.
+    """
+    # Define split markers by priority
+    markers = [
+        # Strong breaks (longest pause)
+        (r'[.!?؟။။။]+[\s]*', 1.0),  # Periods, exclamation and question marks (multi-script)
+        (r'[\n\r]+\s*[\n\r]+', 1.0),  # Multiple newlines
+        (r'[:|;;:;][\s]*', 0.9),  # Colons, semicolons (multi-script)
+
+        # Medium breaks
+        (r'[,,،、][\s]*', 0.8),  # Commas (multi-script)
+        (r'[)}\])】』»›》\s]+', 0.7),  # Closing brackets/parentheses
+        (r'[-—−]+[\s]*', 0.7),  # Dashes
+
+        # Weak breaks
+        (r'\s+[&+=/\s]+\s+', 0.6),  # Special characters with spaces
+        (r'[\s]+', 0.5),  # Any whitespace as a last resort
+    ]
+
+    # Calculate window boundaries
+    start = max(0, target_pos - window_size)
+    end = min(len(text), target_pos + window_size)
+    window = text[start:end]
+
+    best_pos = target_pos
+    best_score = 0
+
+    for pattern, priority in markers:
+        matches = list(re.finditer(pattern, window))
+        for match in matches:
+            # Calculate a position score based on distance from the target
+            pos = start + match.end()
+            distance = abs(pos - target_pos)
+            distance_score = 1 - (distance / (window_size * 2))
+
+            # Combine priority and position scores
+            score = priority * distance_score
+
+            if score > best_score:
+                best_score = score
+                best_pos = pos
+
+    return best_pos
+
+
+def split_sentence(text: str, lang: str, text_split_length: int = 250) -> List[str]:
+    """
+    Enhanced sentence splitting with language awareness and optimal breakpoints.
+
+    Args:
+        text: Input text to split
+        lang: Language code
+        text_split_length: Target length for splits
+
+    Returns:
+        List of text splits optimized for TTS
+    """
+    text = text.strip()
+    if len(text) <= text_split_length:
+        return [text]
+
+    nlp = get_spacy_lang(lang)
+    if "sentencizer" not in nlp.pipe_names:
         nlp.add_pipe("sentencizer")
-        doc = nlp(text)
-        for sentence in doc.sents:
-            if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
-                text_splits[-1] += " " + str(sentence)
-                text_splits[-1] = text_splits[-1].lstrip()
-            elif len(str(sentence)) > text_split_length:
-                for line in textwrap.wrap(
-                    str(sentence),
-                    width=text_split_length,
-                    drop_whitespace=True,
-                    break_on_hyphens=False,
-                    tabsize=1,
-                ):
-                    text_splits.append(str(line))
-            else:
-                text_splits.append(str(sentence))
 
-        if len(text_splits) > 1 and text_splits[0] == "":
-            del text_splits[0]
-    else:
-        text_splits = [text.lstrip()]
+    # Get base sentences using spaCy
+    doc = nlp(text)
+    sentences = list(doc.sents)
+
+    splits = []
+    current_split = []
+    current_length = 0
+
+    for sent in sentences:
+        sentence_text = str(sent).strip()
+        sentence_length = len(sentence_text)
+
+        # If the sentence fits in the current split
+        if current_length + sentence_length <= text_split_length:
+            current_split.append(sentence_text)
+            current_length += sentence_length + 1
+
+        # Handle long sentences
+        elif sentence_length > text_split_length:
+            # Flush the current split if one exists
+            if current_split:
+                splits.append(" ".join(current_split))
+                current_split = []
+                current_length = 0
+
+            # Split the long sentence at optimal points
+            remaining = sentence_text
+            while len(remaining) > text_split_length:
+                split_pos = find_best_split_point(
+                    remaining,
+                    text_split_length,
+                    window_size=30
+                )
+
+                # Add the split and continue with the remainder
+                splits.append(remaining[:split_pos].strip())
+                remaining = remaining[split_pos:].strip()
+
+            # Handle any remaining text
+            if remaining:
+                current_split = [remaining]
+                current_length = len(remaining)
+
+        # Start a new split
+        else:
+            splits.append(" ".join(current_split))
+            current_split = [sentence_text]
+            current_length = sentence_length
+
+    # Add the final split if needed
+    if current_split:
+        splits.append(" ".join(current_split))
 
-    return text_splits
+    # Clean up splits: drop trailing periods (prevents annoying sounds in Italian)
+    cleaned_sentences = [s[:-1] + ' ' if s.endswith('.') else s for s in splits if s]
+    return cleaned_sentences
 
 _whitespace_re = re.compile(r"\s+")
 
@@ -452,6 +543,7 @@ _ordinal_re = {
     "ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
 }
 _number_re = re.compile(r"[0-9]+")
+# noinspection Annotator
 _currency_re = {
     "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
     "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
@@ -681,7 +773,7 @@ class XTTSTokenizerFast(PreTrainedTokenizerFast):
             char_limit = self.char_limits.get(base_lang, 250)
 
             # Clean and preprocess
-            text = self.preprocess_text(text, text_lang)
+            # text = self.preprocess_text(text, text_lang)  # preprocessing now happens in the hidden function
 
             # Split text into sentences/chunks based on language
             chunk_list = split_sentence(text, base_lang, text_split_length=char_limit)
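Note on the new find_best_split_point: every candidate break inside the +/-30-character window is scored as priority x (1 - distance / (2 x window_size)), so a strong marker slightly off-target beats a weak marker right on it. For instance, a sentence-ending period 4 characters past the target scores 1.0 x (1 - 4/60), about 0.93, while a plain space 3 characters away reaches only 0.5 x (1 - 3/60), about 0.48. A minimal usage sketch, assuming the module is importable from the same package as zh_num2words (hypothetical path, adjust to your checkout):

    # Minimal sketch; the import path below is assumed, not confirmed by this diff.
    from auralis.models.xttsv2.components.tts.layers.xtts.tokenizer import find_best_split_point

    text = "First sentence ends here. Second sentence is much longer and keeps going."
    # Ask for a cut near index 30: the period boundary (match ends at index 26)
    # outscores the closer plain spaces, so the cut lands on the sentence break.
    pos = find_best_split_point(text, target_pos=30, window_size=30)
    print(repr(text[:pos]))  # 'First sentence ends here. '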
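And split_sentence end to end (same assumed import path; spaCy must be installed, since the sentencizer pipe does the base segmentation). Note that the final list comprehension swaps a trailing period for a space on each chunk, which per the inline comment prevents annoying sounds in Italian:

    # Minimal sketch; import path assumed as above, expected output shown only as a comment.
    from auralis.models.xttsv2.components.tts.layers.xtts.tokenizer import split_sentence

    passage = (
        "Short opener. "
        "This second sentence is deliberately padded so that it overflows the limit "
        "and has to be cut at the best punctuation marker, not at an arbitrary index."
    )
    for chunk in split_sentence(passage, lang="en", text_split_length=80):
        print(repr(chunk))  # a trailing '.' on a chunk comes back as ' '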

