import torch
import torch.nn as nn
from torchaudio.models import Conformer
from huggingface_hub import PyTorchModelHubMixin

from .config import (
    N_MELS,
    CNN_CH,
    N_HEADS,
    D_MODEL,
    FF_DIM,
    N_LAYERS,
    DROPOUT,
    DEPTHWISE_CONV_KERNEL_SIZE,
    HIDDEN_DIM,
    DEVICE,
)


class TaikoConformer5(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # 1) CNN frontend: frequency-only pooling
        self.cnn = nn.Sequential(
            nn.Conv2d(1, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
            nn.Conv2d(CNN_CH, CNN_CH, 3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(CNN_CH),
            nn.GELU(),
            nn.Dropout2d(DROPOUT),
        )
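
        # The two stride-(2, 1) convolutions halve the mel axis twice
        # (N_MELS -> N_MELS // 4) while leaving the time axis untouched,
        # so each encoder frame still corresponds 1:1 to a mel frame.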
        feat_dim = CNN_CH * (N_MELS // 4)

        # 2) Linear projection to model dimension
        self.proj = nn.Linear(feat_dim, D_MODEL)

        # 3) FiLM conditioning for notes_per_second
        self.film = nn.Linear(1, 2 * D_MODEL)
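        # The FiLM layer above maps the scalar control value to per-channel scale
        # (gamma) and shift (beta) vectors; forward() applies them to every time
        # step of the projected features (feature-wise linear modulation).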

        # 4) Conformer encoder
        self.encoder = Conformer(
            input_dim=D_MODEL,
            num_heads=N_HEADS,
            ffn_dim=FF_DIM,
            num_layers=N_LAYERS,
            depthwise_conv_kernel_size=DEPTHWISE_CONV_KERNEL_SIZE,
            dropout=DROPOUT,
            use_group_norm=False,
            convolution_first=False,
        )

        # 5) Presence regressor head
        self.presence_regressor = nn.Sequential(
            nn.Dropout(DROPOUT),
            nn.Linear(D_MODEL, HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(DROPOUT),
            nn.Linear(HIDDEN_DIM, 3),  # Don, Ka, DrumRoll energy
            nn.Sigmoid(),  # outputs between 0 and 1
        )
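
        # Sigmoid (rather than softmax) lets the Don, Ka, and DrumRoll energies be
        # predicted independently per frame instead of competing for a single label.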

        # 6) Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self, mel: torch.Tensor, lengths: torch.Tensor, notes_per_second: torch.Tensor
    ):
        """
        Args:
            mel: (B, 1, N_MELS, T_mel) log-mel spectrogram
            lengths: (B,) valid sequence lengths after the CNN frontend
                (the CNN does not downsample time, so these equal the mel-frame lengths)
            notes_per_second: (B,) scalar difficulty control per example
        Returns:
            Dict with:
                'presence': (B, T_cnn_out, 3) sigmoid energies for Don, Ka, DrumRoll
                'lengths': the input lengths, unchanged
        """
        # CNN frontend: (B, 1, N_MELS, T) -> (B, C, N_MELS // 4, T)
        x = self.cnn(mel)  # (B, C, F, T)
        B, C, F, T = x.size()
        x = x.permute(0, 3, 1, 2).reshape(B, T, C * F)  # flatten channels x freq per frame

        # Project to model dimension
        x = self.proj(x)  # (B, T, D_MODEL)

        # FiLM conditioning
        nps = notes_per_second.unsqueeze(-1)  # (B, 1)
        gamma_beta = self.film(nps)  # (B, 2*D_MODEL)
        gamma, beta = gamma_beta.chunk(2, dim=-1)  # (B, D_MODEL) each
        x = gamma.unsqueeze(1) * x + beta.unsqueeze(1)  # broadcast over time

        # Conformer encoder (returns output and valid lengths)
        x, _ = self.encoder(x, lengths=lengths)

        # Presence prediction
        presence = self.presence_regressor(x)
        return {"presence": presence, "lengths": lengths}


if __name__ == "__main__":
    model = TaikoConformer5().to(device=DEVICE)
    print(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.numel():,}")
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {params / 1e6:.2f}M")

    batch_size = 4
    mel_time_steps = 1024
    input_mel = torch.randn(batch_size, 1, N_MELS, mel_time_steps).to(DEVICE)
    conformer_lengths = torch.tensor(
        [mel_time_steps] * batch_size, dtype=torch.long
    ).to(DEVICE)
    notes_per_second_input = torch.tensor([10.0] * batch_size, dtype=torch.float32).to(
        DEVICE
    )

    output = model(input_mel, conformer_lengths, notes_per_second_input)
    print("Output shapes:")
    for key, value in output.items():
        print(f"{key}: {value.shape}")
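
    # Because the class mixes in PyTorchModelHubMixin, it can also be serialized to
    # and restored from a directory (or pushed to the Hugging Face Hub). A minimal
    # sketch, assuming "./taiko_conformer5" is a hypothetical local path rather than
    # anything shipped with this project:
    model.save_pretrained("./taiko_conformer5")
    reloaded = TaikoConformer5.from_pretrained("./taiko_conformer5")
    print(f"Reloaded: {type(reloaded).__name__}")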