from typing import List, Optional, Union, Dict, Any
from functools import cached_property

from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import TruncationStrategy, PaddingStrategy
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing

from hangul_romanize import Transliter
from hangul_romanize.rule import academic
import cutlet

from TTS.tts.layers.xtts.tokenizer import (
    multilingual_cleaners,
    basic_cleaners,
    chinese_transliterate,
    korean_transliterate,
    japanese_cleaners,
)


class XTTSTokenizerFast(PreTrainedTokenizerFast):
    """
    Fast tokenizer implementation for the XTTS model, built on HuggingFace's
    PreTrainedTokenizerFast.
    """

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        tokenizer_object: Optional[Tokenizer] = None,
        unk_token: str = "[UNK]",
        pad_token: str = "[PAD]",
        bos_token: str = "[START]",
        eos_token: str = "[STOP]",
        clean_up_tokenization_spaces: bool = True,
        **kwargs,
    ):
        if tokenizer_object is None and vocab_file is not None:
            tokenizer_object = Tokenizer.from_file(vocab_file)

        if tokenizer_object is not None:
            # Split on whitespace only; spaces are re-inserted explicitly as
            # "[SPACE]" tokens during encoding (see _batch_encode_plus).
            tokenizer_object.pre_tokenizer = WhitespaceSplit()
            # Right-pad batches with the pad token, falling back to id 0 if
            # the pad token is missing from the vocabulary.
            tokenizer_object.enable_padding(
                direction="right",
                pad_id=tokenizer_object.token_to_id(pad_token) or 0,
                pad_token=pad_token,
            )
            # Wrap every encoded sequence as: [START] <tokens> [STOP].
            tokenizer_object.post_processor = TemplateProcessing(
                single=f"{bos_token} $A {eos_token}",
                special_tokens=[
                    (bos_token, tokenizer_object.token_to_id(bos_token)),
                    (eos_token, tokenizer_object.token_to_id(eos_token)),
                ],
            )

        super().__init__(
            tokenizer_object=tokenizer_object,
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

        # Per-language character limits used by check_input_length to warn
        # about over-long inputs.
        self.char_limits = {
            "en": 250, "de": 253, "fr": 273, "es": 239,
            "it": 213, "pt": 203, "pl": 224, "zh": 82,
            "ar": 166, "cs": 186, "ru": 182, "nl": 251,
            "tr": 226, "ja": 71, "hu": 224, "ko": 95,
        }

        # Korean romanizer (academic rules); the Japanese romanizer is built
        # lazily in the `katsu` property below.
        self._korean_transliter = Transliter(academic)

    @cached_property
    def katsu(self):
        # Construct the Japanese romanizer lazily; cached_property stores the
        # instance after the first access, so no manual caching is needed.
        return cutlet.Cutlet()

    def check_input_length(self, text: str, lang: str):
        """Warn if the input text exceeds the character limit for the language."""
        lang = lang.split("-")[0]
        limit = self.char_limits.get(lang, 250)
        if len(text) > limit:
            print(
                f"Warning: text length ({len(text)}) exceeds the {limit}-character "
                f"limit for '{lang}' and may cause truncation."
            )

    def preprocess_text(self, text: str, lang: str) -> str:
        """Apply language-specific text preprocessing."""
        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it",
                    "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
            text = multilingual_cleaners(text, lang)
            if lang == "zh":
                text = chinese_transliterate(text)
            if lang == "ko":
                text = korean_transliterate(text)
        elif lang == "ja":
            text = japanese_cleaners(text, self.katsu)
        else:
            text = basic_cleaners(text)
        return text
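
    # Note: preprocess_text only normalizes and transliterates the text; the
    # "[lang]" prefix and the explicit "[SPACE]" tokens are added later in
    # _batch_encode_plus.
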
    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = 402,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> Dict[str, Any]:
						|  | """ | 
					
						
						|  | Override batch encoding to handle language-specific preprocessing | 
					
						
						|  | """ | 
					
						
        lang = kwargs.pop("lang", ["en"] * len(batch_text_or_text_pairs))
        if isinstance(lang, str):
            lang = [lang] * len(batch_text_or_text_pairs)

        processed_texts = []
        for text, text_lang in zip(batch_text_or_text_pairs, lang):
            if isinstance(text, str):
                self.check_input_length(text, text_lang)
                processed_text = self.preprocess_text(text, text_lang)

                # Prefix the language tag ("zh" maps to the "[zh-cn]" tag used
                # by XTTS) and encode spaces as explicit [SPACE] tokens.
                lang_code = "zh-cn" if text_lang == "zh" else text_lang
                processed_text = f"[{lang_code}]{processed_text}"
                processed_text = processed_text.replace(" ", "[SPACE]")

                processed_texts.append(processed_text)
            else:
                processed_texts.append(text)

        return super()._batch_encode_plus(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )
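
    # For reference: with lang="en", the input "Hello world" reaches the
    # underlying tokenizer as "[en]Hello[SPACE]world", which the
    # TemplateProcessing post-processor then wraps in [START] ... [STOP].
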
    def __call__(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]] = "en",
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = True,
        max_length: Optional[int] = 402,
        stride: int = 0,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = True,
        **kwargs,
    ):
						|  | """ | 
					
						
						|  | Main tokenization method | 
					
						
						|  | Args: | 
					
						
						|  | text: Text or list of texts to tokenize | 
					
						
						|  | lang: Language code or list of language codes corresponding to each text | 
					
						
						|  | add_special_tokens: Whether to add special tokens | 
					
						
						|  | padding: Padding strategy (default True) | 
					
						
						|  | truncation: Truncation strategy (default True) | 
					
						
						|  | max_length: Maximum length | 
					
						
						|  | stride: Stride for truncation | 
					
						
						|  | return_tensors: Format of output tensors ("pt" for PyTorch) | 
					
						
						|  | return_token_type_ids: Whether to return token type IDs | 
					
						
						|  | return_attention_mask: Whether to return attention mask (default True) | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
        if isinstance(text, str):
            text = [text]
        if isinstance(lang, str):
            lang = [lang]

        if len(text) != len(lang):
            raise ValueError(
                f"Number of texts ({len(text)}) must match "
                f"number of language codes ({len(lang)})."
            )

        # Resolve the padding strategy; a bare True pads to max_length so
        # that every batch has a fixed sequence length.
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.MAX_LENGTH if padding else PaddingStrategy.DO_NOT_PAD
        else:
            padding_strategy = PaddingStrategy(padding)

        # Resolve the truncation strategy.
        if isinstance(truncation, bool):
            truncation_strategy = TruncationStrategy.LONGEST_FIRST if truncation else TruncationStrategy.DO_NOT_TRUNCATE
        else:
            truncation_strategy = TruncationStrategy(truncation)

        encoded = self._batch_encode_plus(
            text,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            lang=lang,
            **kwargs,
        )

        return encoded
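

if __name__ == "__main__":
    # Minimal usage sketch (illustrative): assumes an XTTS vocabulary file
    # saved at "vocab.json"; adjust the path for your checkpoint.
    tokenizer = XTTSTokenizerFast(vocab_file="vocab.json")
    batch = tokenizer(
        ["Hello world", "Bonjour le monde"],
        lang=["en", "fr"],
        return_tensors="pt",
    )
    # With the defaults, every sequence is padded to max_length=402.
    print(batch["input_ids"].shape)
    print(batch["attention_mask"].shape)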