import math

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from SSL_WavLM.WavLM import WavLMConfig, WavLM

class MHFA(nn.Module):
    """
    Multi-Head Factorized Attentive (MHFA) pooling.

    This layer takes representations from all layers of an upstream model
    (such as WavLM) and aggregates them into a fixed-size utterance embedding
    using a multi-head attention-like mechanism.
    """

    def __init__(self, head_nb=8, inputs_dim=768, compression_dim=128, outputs_dim=256, nb_layer=13):
        super(MHFA, self).__init__()

        # Learnable per-layer weights used to build the key and value streams.
        self.weights_k = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
        self.weights_v = nn.Parameter(torch.ones(nb_layer), requires_grad=True)

        self.head_nb = head_nb
        self.cmp_dim = compression_dim

        # Compress keys and values to a lower dimension, compute per-head
        # attention scores from the keys, and project the concatenated
        # per-head summaries to the output embedding size.
        self.cmp_linear_k = nn.Linear(inputs_dim, self.cmp_dim)
        self.cmp_linear_v = nn.Linear(inputs_dim, self.cmp_dim)
        self.att_head = nn.Linear(self.cmp_dim, self.head_nb)
        self.pooling_fc = nn.Linear(self.head_nb * self.cmp_dim, outputs_dim)

    def forward(self, x):
        """
        Args:
            x: stacked layer representations of shape (B, D, T, L), where B is
               the batch size, D the feature dimension, T the number of frames
               and L the number of layers.

        Returns:
            Utterance-level embedding of shape (B, outputs_dim).
        """
        # Layer-weighted averages for keys and values: (B, D, T, L) -> (B, T, D).
        k = torch.sum(x * F.softmax(self.weights_k, dim=-1), dim=-1).transpose(1, 2)
        v = torch.sum(x * F.softmax(self.weights_v, dim=-1), dim=-1).transpose(1, 2)

        # Compress both streams to (B, T, cmp_dim).
        k = self.cmp_linear_k(k)
        v = self.cmp_linear_v(v)

        # Per-head attention weights over time: (B, T, head_nb), softmax along T.
        att_scores = self.att_head(k)
        att_weights = F.softmax(att_scores, dim=1)

        # Attentive pooling: (B, T, 1, cmp_dim) * (B, T, head_nb, 1), summed over
        # time, gives one cmp_dim-sized summary per head -> (B, head_nb, cmp_dim).
        pooled_features = torch.sum(v.unsqueeze(-2) * att_weights.unsqueeze(-1), dim=1)

        # Flatten the heads and project to the final embedding.
        b, h, f = pooled_features.shape
        pooled_features = pooled_features.reshape(b, -1)
        output_embedding = self.pooling_fc(pooled_features)

        return output_embedding
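

# Shape sketch (illustrative only; the sizes below are made up): with the
# default arguments, which match WavLM-Base (inputs_dim=768, nb_layer=13),
# MHFA maps a stack of layer outputs laid out as (batch, feat_dim, frames,
# layers) to a fixed-size embedding:
#
#     mhfa = MHFA()                            # head_nb=8, outputs_dim=256
#     reps = torch.randn(2, 768, 150, 13)      # (B, D, T, L)
#     emb = mhfa(reps)                         # -> torch.Size([2, 256])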

class WavLM_MHFA(nn.Module):
    """
    The main model: a pre-trained WavLM front-end combined with the MHFA
    pooling back-end to produce speaker embeddings.
    """

    def __init__(self, model_path):
        super(WavLM_MHFA, self).__init__()

        print(f"Loading base model checkpoint from: {model_path}")
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        # Rebuild the WavLM encoder from the configuration stored in the checkpoint.
        cfg_dict = checkpoint['cfg']
        cfg = WavLMConfig(cfg_dict)
        self.model = WavLM(cfg)

        # MHFA consumes one representation per transformer layer plus the output
        # of the convolutional feature extractor, hence encoder_layers + 1.
        inputs_dim = cfg_dict['encoder_embed_dim']
        nb_layer = cfg_dict['encoder_layers'] + 1

        self.back_end = MHFA(inputs_dim=inputs_dim, head_nb=32, outputs_dim=256, nb_layer=nb_layer)

        # Load the pre-trained weights for both the WavLM encoder and the MHFA back-end.
        self.load_checkpoint(checkpoint['model'])

    def load_checkpoint(self, checkpoint_state):
        loaded_state = checkpoint_state

        # The checkpoint was saved from a wrapper module, so its keys may carry a
        # 'speaker_extractor.' prefix; strip it so they match this module's names.
        # Keys containing 'projection' are skipped because they have no
        # counterpart in this embedding extractor.
        cleaned_state_dict = OrderedDict()
        prefix_to_strip = 'speaker_extractor.'

        for k, v in loaded_state.items():
            if 'projection' in k:
                continue
            if k.startswith(prefix_to_strip):
                cleaned_key = k[len(prefix_to_strip):]
                cleaned_state_dict[cleaned_key] = v
            else:
                cleaned_state_dict[k] = v

        # strict=True ensures every parameter of both WavLM and MHFA is covered.
        super().load_state_dict(cleaned_state_dict, strict=True)
        print("Successfully loaded weights for both WavLM and MHFA backend.")

    def forward(self, raw_wav):
        """
        Args:
            raw_wav: batch of raw waveforms, shape (B, num_samples).

        Returns:
            Speaker embeddings of shape (B, 256).
        """
        # output_layer is set past the encoder depth so that all layers are run
        # and layer_results holds the hidden states of every layer.
        _, layer_results = self.model.extract_features(raw_wav, output_layer=100)

        # Each entry in layer_results is (T, B, D); stack them into (B, T, D, L)
        # and permute to the (B, D, T, L) layout expected by MHFA.
        stacked_reps = torch.stack([x.transpose(0, 1) for x, _ in layer_results], dim=-1)
        layer_reps = stacked_reps.permute(0, 2, 1, 3)

        spk_embedding = self.back_end(layer_reps)

        return spk_embedding
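

# Usage sketch (illustrative; the waveform size and no_grad context are
# assumptions, not part of the original pipeline): WavLM is pre-trained on
# 16 kHz audio, so inference would typically look like
#
#     model = WavLM_MHFA(model_path='./SSL_WavLM/model_convert.pt')
#     model.eval()
#     with torch.no_grad():
#         emb = model(torch.randn(1, 16000))   # 1 s of 16 kHz audio -> (1, 256)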

if __name__ == "__main__":
    # Smoke test: build the model from the checkpoint and run a forward pass
    # on a batch of random waveforms.
    print("Loading checkpoint file ...")

    base_model_path = './SSL_WavLM/model_convert.pt'
    model = WavLM_MHFA(model_path=base_model_path)
    model.eval()

    print("\nModel WavLM_MHFA initialized successfully.")

    # 32000 samples per utterance (2 s at WavLM's 16 kHz sampling rate).
    batch_size = 4
    audio_samples = 32000
    dummy_wav = torch.randn(batch_size, audio_samples)

    print(f"\nPerforming forward pass with dummy input of shape: {dummy_wav.shape}")

    speaker_embedding = model(dummy_wav)

    print("Forward pass successful!")
    print(f"Output speaker embedding shape: {speaker_embedding.shape}")