import torch
import torch.nn as nn
import torch.nn.functional as F
from SSL_WavLM.WavLM import WavLMConfig, WavLM
from collections import OrderedDict


class MHFA(nn.Module):
    """
    Multi-Head Factorized Attentive (MHFA) Pooling.

    This layer takes representations from all layers of a model (like WavLM)
    and aggregates them into a fixed-size embedding using a multi-head
    attention-like mechanism.
    """

    def __init__(self, head_nb=8, inputs_dim=768, compression_dim=128,
                 outputs_dim=256, nb_layer=13):
        super(MHFA, self).__init__()

        # Learnable weights to compute a weighted average over the layers
        self.weights_k = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
        self.weights_v = nn.Parameter(torch.ones(nb_layer), requires_grad=True)

        self.head_nb = head_nb
        self.cmp_dim = compression_dim

        # Linear layers for compression, attention scoring, and output projection
        self.cmp_linear_k = nn.Linear(inputs_dim, self.cmp_dim)
        self.cmp_linear_v = nn.Linear(inputs_dim, self.cmp_dim)
        self.att_head = nn.Linear(self.cmp_dim, self.head_nb)
        self.pooling_fc = nn.Linear(self.head_nb * self.cmp_dim, outputs_dim)

    def forward(self, x):
        # Input x shape: [Batch, Dim, Frame_len, Nb_Layer]

        # 1. Compute a weighted average for Key and Value across layers.
        #    The softmax ensures the layer weights sum to 1.
        k = torch.sum(x * F.softmax(self.weights_k, dim=-1), dim=-1).transpose(1, 2)
        v = torch.sum(x * F.softmax(self.weights_v, dim=-1), dim=-1).transpose(1, 2)
        # Shape of k, v is now [Batch, Frame_len, Dim]

        # 2. Compress the Key and Value representations
        k = self.cmp_linear_k(k)  # -> [B, T, cmp_dim]
        v = self.cmp_linear_v(v)  # -> [B, T, cmp_dim]

        # 3. Compute attention scores from the compressed Key
        att_scores = self.att_head(k)               # -> [B, T, head_nb]
        att_weights = F.softmax(att_scores, dim=1)  # softmax over the time dimension

        # 4. Perform attention pooling. Reshape for broadcasting:
        #      v:           [B, T, 1, cmp_dim]
        #      att_weights: [B, T, head_nb, 1]
        #    The multiplication broadcasts to [B, T, head_nb, cmp_dim];
        #    summing over the time dimension yields [B, head_nb, cmp_dim].
        pooled_features = torch.sum(v.unsqueeze(-2) * att_weights.unsqueeze(-1), dim=1)

        # 5. Flatten the heads and project to the final output dimension
        b, h, f = pooled_features.shape
        pooled_features = pooled_features.reshape(b, -1)     # -> [B, head_nb * cmp_dim]
        output_embedding = self.pooling_fc(pooled_features)  # -> [B, outputs_dim]

        return output_embedding
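

# A minimal, self-contained shape check for the MHFA block in isolation
# (a hypothetical helper, not part of the original training code). The batch
# size and frame count below are illustrative values chosen to match the
# default constructor arguments, not taken from any particular checkpoint.
def _mhfa_shape_check():
    pooling = MHFA(head_nb=8, inputs_dim=768, compression_dim=128,
                   outputs_dim=256, nb_layer=13)
    dummy_reps = torch.randn(2, 768, 150, 13)  # [Batch, Dim, Frame_len, Nb_Layer]
    embedding = pooling(dummy_reps)
    assert embedding.shape == (2, 256), embedding.shape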


class WavLM_MHFA(nn.Module):
    """
    The main model that combines a pre-trained WavLM with the MHFA backend.
    """

    def __init__(self, model_path):
        super(WavLM_MHFA, self).__init__()
        print(f"Loading base model checkpoint from: {model_path}")

        # map_location ensures the checkpoint also loads on CPU-only machines
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        # Build WavLM from the config dictionary stored in the checkpoint
        cfg = WavLMConfig(checkpoint['cfg'])
        self.model = WavLM(cfg)

        inputs_dim = checkpoint['cfg']['encoder_embed_dim']
        # +1 because the CNN front-end output is used alongside the
        # transformer layer outputs
        nb_layer = checkpoint['cfg']['encoder_layers'] + 1
        self.back_end = MHFA(inputs_dim=inputs_dim, head_nb=32,
                             outputs_dim=256, nb_layer=nb_layer)

        # Load the pre-trained weights for the WavLM part of the model
        self.load_checkpoint(checkpoint['model'])

    def load_checkpoint(self, checkpoint_state):
        # Create a new state_dict to hold the cleaned keys
        cleaned_state_dict = OrderedDict()
        # Handle checkpoints whose keys are nested under a prefix
        # (e.g., inside a 'speaker_extractor' module)
        prefix_to_strip = 'speaker_extractor.'
        for k, v in checkpoint_state.items():
            # Skip projection-head weights; they are not part of this model
            if 'projection' in k:
                continue
            if k.startswith(prefix_to_strip):
                cleaned_state_dict[k[len(prefix_to_strip):]] = v
            else:
                cleaned_state_dict[k] = v
        # Load the cleaned state_dict into the current model
        super().load_state_dict(cleaned_state_dict, strict=True)
        print("Successfully loaded weights for both WavLM and MHFA backend.")

    def forward(self, raw_wav):
        # Extract representations from every encoder layer; output_layer=100
        # exceeds the number of layers, so none are skipped. Note that
        # gradients flow through WavLM here unless the caller disables them
        # (e.g., with torch.no_grad() at inference).
        _, layer_results = self.model.extract_features(raw_wav, output_layer=100)

        # Prepare the layer representations for the MHFA backend.
        # layer_results is a list of (Time, Batch, Dim) tensors; stack them
        # into [Batch, Time, Dim, Nb_Layer]
        stacked_reps = torch.stack([x.transpose(0, 1) for x, _ in layer_results], dim=-1)

        # Permute to match the MHFA input layout: [Batch, Dim, Time, Nb_Layer]
        layer_reps = stacked_reps.permute(0, 2, 1, 3)

        # The MHFA backend is the trainable part
        spk_embedding = self.back_end(layer_reps)
        return spk_embedding


if __name__ == "__main__":
    # Step 1: Instantiate the main model. The model path should point to the
    # pre-trained base model (e.g., a converted WavLM-Base+ checkpoint).
    print("Loading checkpoint file ...")
    base_model_path = './SSL_WavLM/model_convert.pt'
    model = WavLM_MHFA(model_path=base_model_path)
    model.eval()  # Set the model to evaluation mode
    print("\nModel WavLM_MHFA initialized successfully.")

    # Step 2: Perform a forward pass with dummy data
    batch_size = 4
    audio_samples = 32000  # 2 seconds of audio at 16 kHz
    dummy_wav = torch.randn(batch_size, audio_samples)

    print(f"\nPerforming forward pass with dummy input of shape: {dummy_wav.shape}")
    speaker_embedding = model(dummy_wav)

    print("Forward pass successful!")
    print(f"Output speaker embedding shape: {speaker_embedding.shape}")
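
    # Step 3 (optional sketch): for pure inference, torch.no_grad() avoids
    # building the autograd graph and reduces memory use. This simply repeats
    # the forward pass above; the output shape is unchanged.
    with torch.no_grad():
        speaker_embedding = model(dummy_wav)
    print(f"Embedding shape under torch.no_grad(): {speaker_embedding.shape}")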