# WeSpeaker_SV_WavLM_MHFA / Transformer_WavLM.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from SSL_WavLM.WavLM import WavLMConfig, WavLM
from collections import OrderedDict
class MHFA(nn.Module):
"""
Multi-Head Factorized Attentive (MHFA) Pooling.
This layer takes representations from all layers of a model (like WavLM)
and aggregates them into a fixed-size embedding using a multi-head
attention-like mechanism.
"""
def __init__(self, head_nb=8, inputs_dim=768, compression_dim=128, outputs_dim=256, nb_layer=13):
super(MHFA, self).__init__()
# Learnable weights to compute a weighted average over the layers
self.weights_k = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
self.weights_v = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
self.head_nb = head_nb
self.cmp_dim = compression_dim
# Linear layers for processing
self.cmp_linear_k = nn.Linear(inputs_dim, self.cmp_dim)
self.cmp_linear_v = nn.Linear(inputs_dim, self.cmp_dim)
self.att_head = nn.Linear(self.cmp_dim, self.head_nb)
self.pooling_fc = nn.Linear(self.head_nb * self.cmp_dim, outputs_dim)
def forward(self, x):
# Input x shape: [Batch, Dim, Frame_len, Nb_Layer]
# 1. Compute weighted average for Key and Value across layers
# The softmax ensures the weights sum to 1.
k = torch.sum(x * F.softmax(self.weights_k, dim=-1), dim=-1).transpose(1, 2)
v = torch.sum(x * F.softmax(self.weights_v, dim=-1), dim=-1).transpose(1, 2)
# Shape of k, v is now [Batch, Frame_len, Dim]
# 2. Compress Key and Value representations
k = self.cmp_linear_k(k) # -> [B, T, cmp_dim]
v = self.cmp_linear_v(v) # -> [B, T, cmp_dim]
# 3. Compute attention scores from the compressed key
att_scores = self.att_head(k) # -> [B, T, head_nb]
att_weights = F.softmax(att_scores, dim=1) # Softmax over time dimension
# 4. Perform attention-pooling
# Reshape for broadcasting:
# v: [B, T, 1, cmp_dim]
# att_weights: [B, T, head_nb, 1]
# The multiplication broadcasts to [B, T, head_nb, cmp_dim]
pooled_features = torch.sum(v.unsqueeze(-2) * att_weights.unsqueeze(-1), dim=1)
# Sum over time dimension results in [B, head_nb, cmp_dim]
# 5. Flatten and project to final output dimension
b, h, f = pooled_features.shape
pooled_features = pooled_features.reshape(b, -1) # -> [B, head_nb * cmp_dim]
output_embedding = self.pooling_fc(pooled_features) # -> [B, outputs_dim]
return output_embedding
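

# A minimal shape-check sketch for the MHFA block in isolation (not part of the
# original file). It pushes a random tensor shaped like stacked WavLM-Base layer
# outputs ([Batch, Dim, Frames, Layers], with Dim=768 and 12 + 1 = 13 layers)
# through the pooling layer; the batch size and frame count below are
# illustrative assumptions only.
def _mhfa_shape_check():
    pooling = MHFA(head_nb=8, inputs_dim=768, compression_dim=128,
                   outputs_dim=256, nb_layer=13)
    dummy_layers = torch.randn(2, 768, 150, 13)  # [B, D, T, nb_layer]
    with torch.no_grad():
        embedding = pooling(dummy_layers)
    assert embedding.shape == (2, 256)
    return embedding

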
class WavLM_MHFA(nn.Module):
"""
The main model that combines a pre-trained WavLM with the MHFA backend.
"""
def __init__(self, model_path):
super(WavLM_MHFA, self).__init__()
print(f"Loading base model checkpoint from: {model_path}")
# Use map_location to ensure it works on CPU if no GPU is available
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
# Correctly access the config dictionary
cfg_dict = checkpoint['cfg']
cfg = WavLMConfig(cfg_dict)
self.model = WavLM(cfg)
inputs_dim = checkpoint['cfg']['encoder_embed_dim']
nb_layer = checkpoint['cfg']['encoder_layers'] + 1
self.back_end = MHFA(inputs_dim=inputs_dim, head_nb=32, outputs_dim=256, nb_layer=nb_layer)
        # Load the pre-trained weights for both the WavLM encoder and the MHFA backend
self.load_checkpoint(checkpoint['model'])
def load_checkpoint(self, checkpoint_state):
loaded_state = checkpoint_state
# Create a new state_dict to hold the cleaned keys
cleaned_state_dict = OrderedDict()
# Handle checkpoints that might be nested (e.g., inside a 'speaker_extractor')
prefix_to_strip = 'speaker_extractor.'
for k, v in loaded_state.items():
            # Skip the training-time projection/classification head; it is not
            # part of this embedding extractor
            if 'projection' in k:
                continue
if k.startswith(prefix_to_strip):
cleaned_key = k[len(prefix_to_strip):]
cleaned_state_dict[cleaned_key] = v
else:
cleaned_state_dict[k] = v
# Now load the cleaned state_dict into the current model
super().load_state_dict(cleaned_state_dict, strict=True)
print("Successfully loaded weights for both WavLM and MHFA backend.")
def forward(self, raw_wav):
        # Extract hidden states from every transformer layer; output_layer=100 is
        # larger than the number of encoder layers, so representations from all
        # layers are returned by the bundled SSL_WavLM implementation
        _, layer_results = self.model.extract_features(raw_wav, output_layer=100)
# Prepare layer representations for the MHFA backend
# Input layer_results: List of (Time, Batch, Dim) tensors
# Stack them to create [Batch, Time, Dim, Nb_Layer]
stacked_reps = torch.stack([x.transpose(0, 1) for x, _ in layer_results], dim=-1)
# Permute to match MHFA input: [Batch, Dim, Time, Nb_Layer]
layer_reps = stacked_reps.permute(0, 2, 1, 3)
        # The MHFA backend pools the stacked layer representations into a single embedding
spk_embedding = self.back_end(layer_reps)
return spk_embedding
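

# A small, illustrative helper (not part of the original file) showing how two
# utterances are typically compared for speaker verification with this model:
# embed both waveforms and score them with cosine similarity, accepting the
# trial when the score exceeds a threshold tuned on a development set. The
# function name and the default threshold are assumptions for illustration.
def cosine_verification_score(model, wav_a, wav_b, threshold=0.5):
    model.eval()
    with torch.no_grad():
        emb_a = model(wav_a)  # [B, 256]
        emb_b = model(wav_b)  # [B, 256]
    scores = F.cosine_similarity(emb_a, emb_b, dim=-1)  # [B], values in [-1, 1]
    return scores, scores > threshold

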
if __name__ == "__main__":
    # Step 1: Instantiate the main model
    # The path should point to a converted checkpoint that contains the 'cfg'
    # dictionary plus the weights of both the WavLM encoder and the MHFA backend
    print("Loading checkpoint file ...")
    base_model_path = './SSL_WavLM/model_convert.pt'
model = WavLM_MHFA(model_path=base_model_path)
model.eval() # Set the model to evaluation mode
print("\nModel WavLM_MHFA initialized successfully.")
# Step 2: Perform a forward pass with dummy data
batch_size = 4
    audio_samples = 32000  # 2 seconds of audio at 16 kHz
dummy_wav = torch.randn(batch_size, audio_samples)
print(f"\nPerforming forward pass with dummy input of shape: {dummy_wav.shape}")
speaker_embedding = model(dummy_wav)
print("Forward pass successful!")
print(f"Output speaker embedding shape: {speaker_embedding.shape}")