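"""Configuration classes for the XTTS-v2 GPT component."""
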
								from dataclasses import asdict, dataclass
from typing import Dict, Optional, List
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


@dataclass
class GPTAudioConfig:
    """Configuration for GPT audio processing parameters"""
    mel_channels: int = 80
    sample_rate: int = 22050
    output_sample_rate: int = 24000
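

# Note (assumption): GPTAudioConfig above carries just the subset of audio fields the
# GPT component needs; the full STFT/mel parameters live in XTTSAudioConfig below.
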
@dataclass
class XTTSAudioConfig:
    """Configuration for audio processing parameters"""
    sample_rate: int = 22050
    output_sample_rate: int = 24000
    mel_channels: int = 80
    hop_length: int = 256
    win_length: int = 1024
    n_fft: int = 1024
    fmin: int = 0
    fmax: int = 8000
    power: float = 1.0
    mel_norms_file: Optional[str] = None
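

# Derived from the XTTSAudioConfig defaults above: one mel frame covers
# hop_length / sample_rate = 256 / 22050 ≈ 11.6 ms of audio, i.e. roughly
# 86 frames per second at the 22.05 kHz model sample rate.
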
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS."""
    model_type = "xtts_gpt"
    def __init__(
            self,
            # Model architecture
            hidden_size: int = 1024,  # gpt_n_model_channels in original
            n_inner: int = 4096,
            num_hidden_layers: int = 30,  # gpt_layers in original
            num_attention_heads: int = 16,  # gpt_n_heads in original
            # Tokenizer settings
            vocab_size: int = 6681,  # gpt_number_text_tokens in original
            number_text_tokens: int = 6681,  # Explicit text token vocabulary size
            start_text_token: Optional[int] = None,
            stop_text_token: Optional[int] = None,
            # Audio token settings
            num_audio_tokens: int = 1026,  # gpt_num_audio_tokens in original
            start_audio_token: int = 1024,  # gpt_start_audio_token in original
            stop_audio_token: int = 1025,  # gpt_stop_audio_token in original
            # Sequence length settings
            max_audio_tokens: int = 605,  # gpt_max_audio_tokens in original
            max_text_tokens: int = 402,  # gpt_max_text_tokens in original
            max_prompt_tokens: int = 70,  # gpt_max_prompt_tokens in original
            gpt_max_audio_tokens: int = 605,  # Used for generation
            # Model behavior settings
            use_masking_gt_prompt_approach: bool = True,  # gpt_use_masking_gt_prompt_approach in original
            use_perceiver_resampler: bool = True,  # gpt_use_perceiver_resampler in original
            kv_cache: bool = True,
            enable_redaction: bool = False,
            # GPT batch settings
            gpt_batch_size: int = 1,
            # Audio processing
            audio_config: Optional[Dict] = None,
            # Architecture specifics
            layer_norm_epsilon: float = 1e-5,
            initializer_range: float = 0.02,
            add_cross_attention: bool = False,
            scale_attn_by_inverse_layer_idx: bool = False,
            reorder_and_upcast_attn: bool = False,
            # Size settings for the decoder
            decoder_input_dim: int = 1024,
            architectures=["XttsGPT"],
            auto_map={
                "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
            },
            activation_function: str = "gelu",
            attn_pdrop: float = 0.1,
            **kwargs
    ):
        super().__init__(**kwargs)
        # Avoid shared mutable defaults in the signature; fall back to the
        # original values here instead.
        self.architectures = architectures if architectures is not None else ["XttsGPT"]
        self.auto_map = auto_map if auto_map is not None else {
            "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
            "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
        }
        self.audio_config = GPTAudioConfig(
            **(audio_config if audio_config is not None else {})
        )
        self.activation_function = activation_function
        self.attn_pdrop = attn_pdrop
        self.hidden_size = hidden_size
        self.n_inner = n_inner
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.vocab_size = vocab_size
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token
        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token
        self.max_audio_tokens = max_audio_tokens
        self.max_text_tokens = max_text_tokens
        self.max_prompt_tokens = max_prompt_tokens
        self.gpt_max_audio_tokens = gpt_max_audio_tokens
        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
        self.use_perceiver_resampler = use_perceiver_resampler
        self.kv_cache = kv_cache
        self.enable_redaction = enable_redaction
        self.gpt_batch_size = gpt_batch_size
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.add_cross_attention = add_cross_attention
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn
        self.decoder_input_dim = decoder_input_dim

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary."""
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
        """Create a config from a dictionary, e.g. the output of `to_dict()`."""
        # Let explicit kwargs override dictionary values, mirroring the base
        # PretrainedConfig.from_dict convention instead of silently dropping them.
        return cls(**{**config_dict, **kwargs})
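

# Minimal usage sketch (assumes the defaults above match the released checkpoint):
# build a config with a custom audio_config dict and round-trip it through
# to_dict() / from_dict().
if __name__ == "__main__":
    config = XTTSGPTConfig(audio_config={"sample_rate": 22050, "output_sample_rate": 24000})
    restored = XTTSGPTConfig.from_dict(config.to_dict())
    # audio_config is rebuilt as a GPTAudioConfig dataclass on load
    assert isinstance(restored.audio_config, GPTAudioConfig)
    assert restored.audio_config.sample_rate == config.audio_config.sample_rate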