import torch
import torch.nn as nn
import torch.nn.functional as F
from SSL_WavLM.WavLM import WavLMConfig, WavLM
from collections import OrderedDict


class MHFA(nn.Module):
    """
    Multi-Head Factorized Attentive (MHFA) Pooling.
    This layer takes representations from all layers of a model (like WavLM)
    and aggregates them into a fixed-size embedding using a multi-head
    attention-like mechanism.
    """
    def __init__(self, head_nb=8, inputs_dim=768, compression_dim=128, outputs_dim=256, nb_layer=13):
        super(MHFA, self).__init__()
        # Learnable weights to compute a weighted average over the layers
        self.weights_k = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
        self.weights_v = nn.Parameter(torch.ones(nb_layer), requires_grad=True)
        
        self.head_nb = head_nb
        self.cmp_dim = compression_dim

        # Linear layers for processing
        self.cmp_linear_k = nn.Linear(inputs_dim, self.cmp_dim)
        self.cmp_linear_v = nn.Linear(inputs_dim, self.cmp_dim)
        self.att_head = nn.Linear(self.cmp_dim, self.head_nb)
        self.pooling_fc = nn.Linear(self.head_nb * self.cmp_dim, outputs_dim)

    def forward(self, x):
        # Input x shape: [Batch, Dim, Frame_len, Nb_Layer]
        
        # 1. Compute weighted average for Key and Value across layers
        # The softmax ensures the weights sum to 1.
        k = torch.sum(x * F.softmax(self.weights_k, dim=-1), dim=-1).transpose(1, 2)
        v = torch.sum(x * F.softmax(self.weights_v, dim=-1), dim=-1).transpose(1, 2)
        # Shape of k, v is now [Batch, Frame_len, Dim]

        # 2. Compress Key and Value representations
        k = self.cmp_linear_k(k) # -> [B, T, cmp_dim]
        v = self.cmp_linear_v(v) # -> [B, T, cmp_dim]

        # 3. Compute attention scores from the compressed key
        att_scores = self.att_head(k) # -> [B, T, head_nb]
        att_weights = F.softmax(att_scores, dim=1) # Softmax over time dimension

        # 4. Perform attention-pooling
        # Reshape for broadcasting:
        # v: [B, T, 1, cmp_dim]
        # att_weights: [B, T, head_nb, 1]
        # The multiplication broadcasts to [B, T, head_nb, cmp_dim]
        pooled_features = torch.sum(v.unsqueeze(-2) * att_weights.unsqueeze(-1), dim=1)
        # Sum over time dimension results in [B, head_nb, cmp_dim]

        # 5. Flatten and project to final output dimension
        pooled_features = pooled_features.flatten(start_dim=1)  # -> [B, head_nb * cmp_dim]
        output_embedding = self.pooling_fc(pooled_features)     # -> [B, outputs_dim]
        
        return output_embedding
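
# Shape sketch for MHFA (illustrative values, e.g. a 12-layer base model with
# 768-dim features and head_nb=32, compression_dim=128 as configured below):
#   x            [B, 768, T, 13]   stacked per-layer WavLM representations
#   k, v         [B, T, 128]       layer-weighted average, then compression
#   att_weights  [B, T, 32]        one softmax-over-time distribution per head
#   pooled       [B, 32, 128] -> flatten -> [B, 4096] -> pooling_fc -> [B, 256]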

class WavLM_MHFA(nn.Module):
    """
    The main model that combines a pre-trained WavLM with the MHFA backend.
    """
    def __init__(self, model_path):
        super(WavLM_MHFA, self).__init__()
        
        print(f"Loading base model checkpoint from: {model_path}")
        # Use map_location to ensure it works on CPU if no GPU is available
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        
        # Correctly access the config dictionary
        cfg_dict = checkpoint['cfg']
        cfg = WavLMConfig(cfg_dict)
        self.model = WavLM(cfg)
        
        inputs_dim = cfg_dict['encoder_embed_dim']
        # +1: layer_results also includes the pre-transformer (post-CNN) representation
        nb_layer = cfg_dict['encoder_layers'] + 1

        self.back_end = MHFA(inputs_dim=inputs_dim, head_nb=32, outputs_dim=256, nb_layer=nb_layer)
        
        # Load the pre-trained weights for both the WavLM encoder and the MHFA backend
        self.load_checkpoint(checkpoint['model'])

    def load_checkpoint(self, checkpoint_state):
        loaded_state = checkpoint_state
        
        # Create a new state_dict to hold the cleaned keys
        cleaned_state_dict = OrderedDict()
        
        # Handle checkpoints whose keys are nested under a wrapper module
        # (e.g., 'speaker_extractor.'), and skip training-only projection
        # heads that are not part of this model.
        prefix_to_strip = 'speaker_extractor.'
        for k, v in loaded_state.items():
            if 'projection' in k:
                continue
            if k.startswith(prefix_to_strip):
                cleaned_state_dict[k[len(prefix_to_strip):]] = v
            else:
                cleaned_state_dict[k] = v

        # Now load the cleaned state_dict into the current model
        super().load_state_dict(cleaned_state_dict, strict=True)
        print("Successfully loaded weights for both WavLM and MHFA backend.")

    def forward(self, raw_wav):
        # Extract the hidden states of every WavLM layer. With ret_layer_results=True,
        # extract_features returns ((features, layer_results), padding_mask) in the
        # reference WavLM implementation; output_layer is set above the actual layer
        # count so that all transformer layers are traversed and recorded.
        _, layer_results = self.model.extract_features(raw_wav, output_layer=100, ret_layer_results=True)[0]
        
        # Prepare layer representations for the MHFA backend
        # Input layer_results: List of (Time, Batch, Dim) tensors
        # Stack them to create [Batch, Time, Dim, Nb_Layer]
        stacked_reps = torch.stack([x.transpose(0, 1) for x, _ in layer_results], dim=-1)
        # Permute to match MHFA input: [Batch, Dim, Time, Nb_Layer]
        layer_reps = stacked_reps.permute(0, 2, 1, 3)
        
        # The backend part is trainable
        spk_embedding = self.back_end(layer_reps)
        
        return spk_embedding

if __name__ == "__main__":

    # Step 1: Instantiate the main model
    # The path should point to a checkpoint that bundles the WavLM config ('cfg')
    # with the weights of both the WavLM encoder and the MHFA backend ('model')
    print("Loading checkpoint file ...")

    base_model_path = './SSL_WavLM/model_convert.pt'
    model = WavLM_MHFA(model_path=base_model_path)
    model.eval() # Set the model to evaluation mode

    print("\nModel WavLM_MHFA initialized successfully.")

    # Step 2: Perform a forward pass with dummy data
    batch_size = 4
    audio_samples = 32000 # 2 seconds of audio at 16 kHz
    dummy_wav = torch.randn(batch_size, audio_samples)
    
    print(f"\nPerforming forward pass with dummy input of shape: {dummy_wav.shape}")
    
    # Gradients are not needed for this inference-only demo
    with torch.no_grad():
        speaker_embedding = model(dummy_wav)

    print("Forward pass successful!")
    print(f"Output speaker embedding shape: {speaker_embedding.shape}")