File size: 2,344 Bytes
86a2cc3
 
 
 
 
 
92579d0
86a2cc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import Wav2Vec2Model, Wav2Vec2Config
from .conformer import FinalConformer

class DF_Arena_1B(nn.Module):
    """wav2vec2 (XLS-R 1B) backbone + per-layer attention gating + conformer head.

    Each hidden state returned by the backbone is attention-pooled over time;
    the pooled vector yields one sigmoid gate per layer, the gated layers are
    summed into a single (B, T', D) feature map, which is batch-normalised,
    passed through SELU and fed to ``FinalConformer``.
    """

    def __init__(self):
        super().__init__()
        # Built from the config only, so the backbone weights start randomly
        # initialised; presumably a full checkpoint is loaded into this module
        # afterwards — confirm with the caller.
        self.ssl_model = Wav2Vec2Model(Wav2Vec2Config.from_pretrained("facebook/wav2vec2-xls-r-1b"))
        # Kept for compatibility; forward() also requests hidden states explicitly.
        self.ssl_model.config.output_hidden_states = True
        # Backbone feature width: 1280 for 1b, 1920 for 2b.
        emb_size = self.ssl_model.config.hidden_size
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)
        # Maps each pooled layer vector to one scalar gate logit.
        self.fc0 = nn.Linear(emb_size, 1)
        self.sig = nn.Sigmoid()

        self.conformer = FinalConformer(
            emb_size=emb_size, heads=4, ffmult=4, exp_fac=2, kernel_size=31, n_encoders=4
        )

        # Learnable attention scores for temporal pooling, shared by all layers.
        self.attn_scores = nn.Linear(emb_size, 1, bias=False)

    def get_attenF1Dpooling(self, x):
        """Attention-pool ``x`` over its time axis.

        Args:
            x: tensor of shape (B, T, D); D must equal the input size of
                ``self.attn_scores``.

        Returns:
            Tensor of shape (B, 1, D): the softmax(attn_scores(x))-weighted
            sum of ``x`` over T.
        """
        logits = self.attn_scores(x)  # (B, T, 1)
        weights = torch.softmax(logits, dim=1)  # normalise over time
        return torch.sum(weights * x, dim=1, keepdim=True)  # (B, 1, D)

    def get_attenF1D(self, layerResult):
        """Pool every layer over time and stack pooled and full features.

        Args:
            layerResult: sequence of L tensors, each of shape (B, T, D) — e.g.
                the ``hidden_states`` tuple of a Hugging Face wav2vec2 model.

        Returns:
            ``(pooled, full)`` where ``pooled`` is (B, L, D) and ``full`` is
            (B, L, T, D).
        """
        layers = tuple(layerResult)  # iterate the input only once
        pooled = [self.get_attenF1Dpooling(layer) for layer in layers]  # each (B, 1, D)
        full = [layer.unsqueeze(1) for layer in layers]  # each (B, 1, T, D)
        return torch.cat(pooled, dim=1), torch.cat(full, dim=1)

    def forward(self, x):
        """Run the model on raw waveform input.

        Args:
            x: waveform tensor of shape (T,) for a single clip, or (B, T) for
                a batch.

        Returns:
            The first value returned by ``FinalConformer`` (its second return
            value is discarded).

        Raises:
            ValueError: if ``x`` is neither 1-D nor 2-D.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)  # add the batch axis wav2vec2 expects
        elif x.dim() != 2:
            raise ValueError(
                f"expected waveform of shape (T,) or (B, T), got {tuple(x.shape)}"
            )

        out_ssl = self.ssl_model(x, output_hidden_states=True)
        # hidden_states: embedding output plus one entry per transformer
        # layer, each (B, T', D) with T' the number of frames.
        pooled, fullfeature = self.get_attenF1D(out_ssl.hidden_states)  # (B, L, D), (B, L, T', D)

        # One sigmoid gate per layer, broadcast over frames and channels.
        gates = self.sig(self.fc0(pooled)).unsqueeze(-1)  # (B, L, 1, 1)
        fused = torch.sum(fullfeature * gates, dim=1)  # (B, T', D)

        # BatchNorm2d(num_features=1) needs an explicit channel axis.
        fused = self.selu(self.first_bn(fused.unsqueeze(1)))  # (B, 1, T', D)

        output, _ = self.conformer(fused.squeeze(1))
        return output