from typing import Optional, Union

from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig


class StepAudio2EncoderConfig(PretrainedConfig):
    model_type = "step_audio_2_encoder"

    def __init__(
        self,
        n_mels=128,
        n_audio_ctx=1500,
        n_audio_state=512,
        n_audio_head=8,
        n_audio_layer=6,
        llm_dim=4096,
        kernel_size=3,
        adapter_stride=2,
        **kwargs,
    ):
        # Whisper-style audio encoder hyperparameters.
        self.n_mels = n_mels                  # mel filterbank bins in the input features
        self.n_audio_ctx = n_audio_ctx        # maximum number of audio positions
        self.n_audio_state = n_audio_state    # encoder hidden size
        self.n_audio_head = n_audio_head      # attention heads per encoder layer
        self.n_audio_layer = n_audio_layer    # number of encoder layers
        self.llm_dim = llm_dim                # adapter output dim, matched to the LLM hidden size
        self.kernel_size = kernel_size        # adapter convolution kernel size
        self.adapter_stride = adapter_stride  # adapter temporal downsampling stride
        super().__init__(**kwargs)
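
# Usage sketch for the encoder config, kept as comments so importing this module
# stays side-effect free (the llm_dim override below is illustrative only):
#
#   enc_cfg = StepAudio2EncoderConfig(llm_dim=3584)
#   enc_cfg.n_audio_state  # -> 512, the default encoder width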


class StepAudio2TextConfig(PretrainedConfig):
    model_type = "step_audio_2_text"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        eos_token_id=None,
        **kwargs,
    ):
        # Always include the model's built-in terminator ids and merge in any
        # caller-supplied ones; dict.fromkeys deduplicates while preserving
        # insertion order, unlike a plain set().
        builtin_eos = [151643, 151645, 151665]
        if eos_token_id is None:
            eos_token_id = builtin_eos
        elif isinstance(eos_token_id, list):
            eos_token_id = list(dict.fromkeys(builtin_eos + eos_token_id))
        else:
            eos_token_id = list(dict.fromkeys(builtin_eos + [eos_token_id]))
        super().__init__(
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads
        if self.num_attention_groups != self.num_key_value_heads:
            # Raise rather than assert so the check survives `python -O`.
            raise ValueError("num_attention_groups must be equal to num_key_value_heads")
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Mirror the text settings in a nested Qwen2Config so the decoder can be
        # instantiated directly from it.
        self.text_config = Qwen2Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )
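
# Example of the eos_token_id merge above (151666 is a hypothetical extra id,
# not one defined by the model):
#
#   cfg = StepAudio2TextConfig(eos_token_id=151666)
#   cfg.eos_token_id  # -> [151643, 151645, 151665, 151666]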


class StepAudio2Config(PretrainedConfig):
    model_type = "step_audio_2"
    architectures = ["StepAudio2ForCausalLM"]

    def __init__(
        self,
        audio_encoder_config: Optional[Union[dict, StepAudio2EncoderConfig]] = None,
        text_config: Optional[Union[dict, StepAudio2TextConfig]] = None,
        use_sliding_window: bool = False,
        sliding_window: Optional[int] = 2048,
        max_window_layers: Optional[int] = None,
        **kwargs,
    ):
        # Explicit arguments only supply defaults; values already present in
        # kwargs (e.g. loaded from a config.json) take precedence.
        kwargs.setdefault("use_sliding_window", use_sliding_window)
        kwargs.setdefault("sliding_window", sliding_window)
        if max_window_layers is None:
            # Fall back to a top-level num_hidden_layers, if one was passed, so
            # the sliding window covers the whole stack by default.
            max_window_layers = kwargs.get("num_hidden_layers", None)
        kwargs.setdefault("max_window_layers", max_window_layers)
        super().__init__(**kwargs)
        # Normalize text_config to the nested Qwen2Config that the decoder consumes.
        if text_config is None:
            text_config = StepAudio2TextConfig().text_config
        elif isinstance(text_config, dict):
            text_config = StepAudio2TextConfig(**text_config).text_config
        elif isinstance(text_config, StepAudio2TextConfig):
            text_config = text_config.text_config
        self.text_config = text_config
        if audio_encoder_config is None:
            self.audio_encoder_config = StepAudio2EncoderConfig()
        elif isinstance(audio_encoder_config, dict):
            self.audio_encoder_config = StepAudio2EncoderConfig(**audio_encoder_config)
        else:
            # An already-constructed StepAudio2EncoderConfig; store it as-is so
            # the attribute is set on every path.
            self.audio_encoder_config = audio_encoder_config
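

if __name__ == "__main__":
    # Minimal smoke test: build the composite config from plain dicts, the way
    # values from a config.json would arrive. The overrides here are
    # illustrative, not values from any released checkpoint.
    cfg = StepAudio2Config(
        audio_encoder_config={"n_mels": 128, "llm_dim": 4096},
        text_config={"hidden_size": 4096, "num_hidden_layers": 48},
    )
    print(type(cfg.text_config).__name__)           # Qwen2Config
    print(cfg.audio_encoder_config.adapter_stride)  # 2
    print(cfg.use_sliding_window)                   # False (kwargs default)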