import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from SSL_WavLM.WavLM import WavLMConfig, WavLM
from collections import OrderedDict
class MHFA(nn.Module):
    """
    Multi-Head Factorized Attentive (MHFA) Pooling.

    This layer takes representations from all layers of a model (like WavLM)
    and aggregates them into a fixed-size embedding using a multi-head
    attention-like mechanism.
    """

    def __init__(self, head_nb=8, inputs_dim=768, compression_dim=128, outputs_dim=256, nb_layer=13):
        super(MHFA, self).__init__()

        # Learnable weights to compute a weighted average over the layers
        self.weights_k = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
        self.weights_v = nn.Parameter(torch.ones(nb_layer), requires_grad=True)

        self.head_nb = head_nb
        self.cmp_dim = compression_dim

        # Linear layers for processing
        self.cmp_linear_k = nn.Linear(inputs_dim, self.cmp_dim)
        self.cmp_linear_v = nn.Linear(inputs_dim, self.cmp_dim)
        self.att_head = nn.Linear(self.cmp_dim, self.head_nb)
        self.pooling_fc = nn.Linear(self.head_nb * self.cmp_dim, outputs_dim)
    def forward(self, x):
        # Input x shape: [Batch, Dim, Frame_len, Nb_Layer]

        # 1. Compute a weighted average for Key and Value across layers.
        #    The softmax ensures the layer weights sum to 1.
        k = torch.sum(x * F.softmax(self.weights_k, dim=-1), dim=-1).transpose(1, 2)
        v = torch.sum(x * F.softmax(self.weights_v, dim=-1), dim=-1).transpose(1, 2)
        # Shape of k, v is now [Batch, Frame_len, Dim]

        # 2. Compress the Key and Value representations
        k = self.cmp_linear_k(k)  # -> [B, T, cmp_dim]
        v = self.cmp_linear_v(v)  # -> [B, T, cmp_dim]

        # 3. Compute attention scores from the compressed key
        att_scores = self.att_head(k)  # -> [B, T, head_nb]
        att_weights = F.softmax(att_scores, dim=1)  # Softmax over the time dimension

        # 4. Perform attention pooling.
        #    Reshape for broadcasting:
        #      v:           [B, T, 1, cmp_dim]
        #      att_weights: [B, T, head_nb, 1]
        #    The multiplication broadcasts to [B, T, head_nb, cmp_dim];
        #    summing over the time dimension gives [B, head_nb, cmp_dim].
        pooled_features = torch.sum(v.unsqueeze(-2) * att_weights.unsqueeze(-1), dim=1)

        # 5. Flatten and project to the final output dimension
        b, h, f = pooled_features.shape
        pooled_features = pooled_features.reshape(b, -1)  # -> [B, head_nb * cmp_dim]
        output_embedding = self.pooling_fc(pooled_features)  # -> [B, outputs_dim]

        return output_embedding
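
# --- Illustrative usage sketch (added; not part of the original file) ---
# The MHFA backend can be sanity-checked in isolation with random "layer
# representations" of shape [Batch, Dim, Frame_len, Nb_Layer]. Assuming a
# WavLM-Base-like setup (12 transformer layers plus the pre-transformer
# features = 13 layers, 768-dim):
#
#   mhfa = MHFA(head_nb=8, inputs_dim=768, compression_dim=128,
#               outputs_dim=256, nb_layer=13)
#   dummy_reps = torch.randn(2, 768, 150, 13)   # [B, D, T, L]
#   print(mhfa(dummy_reps).shape)               # torch.Size([2, 256])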
class WavLM_MHFA(nn.Module):
    """
    The main model that combines a pre-trained WavLM with the MHFA backend.
    """

    def __init__(self, model_path):
        super(WavLM_MHFA, self).__init__()
        print(f"Loading base model checkpoint from: {model_path}")

        # Use map_location so the checkpoint also loads on machines without a GPU
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        # Build the WavLM encoder from the configuration stored in the checkpoint
        cfg_dict = checkpoint['cfg']
        cfg = WavLMConfig(cfg_dict)
        self.model = WavLM(cfg)

        # The backend sees one representation per transformer layer plus the
        # initial pre-transformer features, hence encoder_layers + 1
        inputs_dim = cfg_dict['encoder_embed_dim']
        nb_layer = cfg_dict['encoder_layers'] + 1
        self.back_end = MHFA(inputs_dim=inputs_dim, head_nb=32, outputs_dim=256, nb_layer=nb_layer)

        # Load the pre-trained weights for both the WavLM encoder and the MHFA backend
        self.load_checkpoint(checkpoint['model'])
    def load_checkpoint(self, checkpoint_state):
        loaded_state = checkpoint_state

        # Create a new state_dict to hold the cleaned keys
        cleaned_state_dict = OrderedDict()

        # Handle checkpoints that might be nested (e.g., inside a 'speaker_extractor')
        prefix_to_strip = 'speaker_extractor.'
        for k, v in loaded_state.items():
            # Skip projection-head weights, which have no counterpart in this model
            if 'projection' in k:
                continue
            if k.startswith(prefix_to_strip):
                cleaned_key = k[len(prefix_to_strip):]
                cleaned_state_dict[cleaned_key] = v
            else:
                cleaned_state_dict[k] = v

        # Now load the cleaned state_dict into the current model
        super().load_state_dict(cleaned_state_dict, strict=True)
        print("Successfully loaded weights for both WavLM and MHFA backend.")
    def forward(self, raw_wav):
        # Feature extraction: the demo below runs this under model.eval() and
        # torch.no_grad(). output_layer=100 requests more layers than exist, so
        # representations from every layer end up in layer_results.
        _, layer_results = self.model.extract_features(raw_wav, output_layer=100)

        # Prepare the layer representations for the MHFA backend.
        # layer_results is a list of tuples whose first element is a hidden state
        # of shape [Time, Batch, Dim].
        # Stack them to create [Batch, Time, Dim, Nb_Layer] ...
        stacked_reps = torch.stack([x.transpose(0, 1) for x, _ in layer_results], dim=-1)
        # ... then permute to match the MHFA input layout: [Batch, Dim, Time, Nb_Layer]
        layer_reps = stacked_reps.permute(0, 2, 1, 3)
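        # Added note (illustrative, assuming the public WavLM-Base layout of
        # 13 layer outputs with 768-dim features): a batch of 4 two-second,
        # 16 kHz waveforms yields roughly 99 frames, so 13 tensors of shape
        # [99, 4, 768] become stacked_reps of shape [4, 99, 768, 13] and
        # layer_reps of shape [4, 768, 99, 13].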
        # The backend part is trainable
        spk_embedding = self.back_end(layer_reps)
        return spk_embedding
if __name__ == "__main__":
    # Step 1: Instantiate the main model.
    # The path should point to a converted checkpoint that contains weights for
    # both the WavLM encoder and the MHFA backend (e.g., model_convert.pt),
    # since load_checkpoint() loads them with strict=True.
    print("Loading checkpoint file ...")
    base_model_path = './SSL_WavLM/model_convert.pt'
    model = WavLM_MHFA(model_path=base_model_path)
    model.eval()  # Set the model to evaluation mode
    print("\nModel WavLM_MHFA initialized successfully.")

    # Step 2: Perform a forward pass with dummy data
    batch_size = 4
    audio_samples = 32000  # 2 seconds of audio at 16 kHz
    dummy_wav = torch.randn(batch_size, audio_samples)
    print(f"\nPerforming forward pass with dummy input of shape: {dummy_wav.shape}")
    with torch.no_grad():  # inference only, no gradients needed
        speaker_embedding = model(dummy_wav)
    print("Forward pass successful!")
    print(f"Output speaker embedding shape: {speaker_embedding.shape}")