yangwang825 committed
Commit 7c2eaae · 1 Parent(s): 3b0efa0

Upload MERTForSequenceClassification

Files changed (3)
  1. config.json +86 -0
  2. modeling_mert.py +577 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,86 @@
+ {
+   "_name_or_path": "m-a-p/MERT-v1-95M",
+   "activation_dropout": 0.0,
+   "apply_spec_augment": true,
+   "architectures": [
+     "MERTForSequenceClassification"
+   ],
+   "attention_dropout": 0.1,
+   "attention_relax": -1.0,
+   "auto_map": {
+     "AutoConfig": "configuration_MERT.MERTConfig",
+     "AutoModel": "modeling_MERT.MERTModel",
+     "AutoModelForAudioClassification": "modeling_mert.MERTForSequenceClassification"
+   },
+   "bos_token_id": 1,
+   "classifier_proj_size": 256,
+   "conv_bias": false,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "ctc_loss_reduction": "sum",
+   "ctc_zero_infinity": false,
+   "deepnorm": false,
+   "do_stable_layer_norm": false,
+   "eos_token_id": 2,
+   "feat_extract_activation": "gelu",
+   "feat_extract_dropout": 0.0,
+   "feat_extract_norm": "group",
+   "feat_proj_dropout": 0.1,
+   "feat_proj_layer_norm": true,
+   "feature_extractor_cqt": false,
+   "feature_extractor_cqt_bins": 336,
+   "final_dropout": 0.1,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.1,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "layerdrop": 0.05,
+   "mask_feature_length": 10,
+   "mask_feature_min_masks": 0,
+   "mask_feature_prob": 0.0,
+   "mask_time_length": 10,
+   "mask_time_min_masks": 2,
+   "mask_time_prob": 0.05,
+   "model_type": "mert_model",
+   "num_attention_heads": 12,
+   "num_conv_pos_embedding_groups": 16,
+   "num_conv_pos_embeddings": 128,
+   "num_feat_extract_layers": 7,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "sample_rate": 24000,
+   "tokenizer_class": "Wav2Vec2CTCTokenizer",
+   "torch_dtype": "float32",
+   "transformers_version": "4.25.1",
+   "use_weighted_layer_sum": false,
+   "vocab_size": 32
+ }
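
For reference, a minimal loading sketch based on the auto_map above (not part of this commit): the config routes AutoModelForAudioClassification to MERTForSequenceClassification in modeling_mert.py, so the checkpoint has to be loaded with trust_remote_code=True. The repo id below is a placeholder for the repository this commit belongs to, the example assumes the remote code in the repo resolves as configured, and the feature extractor id comes from "_name_or_path".

import torch
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor

repo_id = "yangwang825/..."  # placeholder: the repository this commit was pushed to

# auto_map dispatches this to modeling_mert.MERTForSequenceClassification inside the repo
model = AutoModelForAudioClassification.from_pretrained(repo_id, trust_remote_code=True)

# the base MERT checkpoint operates on 24 kHz audio (see "sample_rate": 24000)
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)

waveform = torch.randn(24000 * 5)  # 5 seconds of dummy audio at 24 kHz
inputs = processor(waveform.numpy(), sampling_rate=24000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits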
modeling_mert.py ADDED
@@ -0,0 +1,577 @@
+ import torch
+ import warnings
+ import torch.nn as nn
+ from typing import Optional, Tuple, Union
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
+ from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+ from transformers.models.hubert.modeling_hubert import (
+     HubertFeatureEncoder,
+     HubertModel,
+     HubertEncoderStableLayerNorm,
+     HubertEncoder,
+     HubertEncoderLayer,
+     HubertPositionalConvEmbedding,
+     HubertAttention,
+     HubertFeedForward,
+     PreTrainedModel
+ )
+
+ try:
+     from nnAudio import features as nnAudioFeatures
+     NNAUDIO_INSTALLED = True
+ except ImportError:
+     print("WARNING: feature_extractor_cqt requires the library 'nnAudio'")
+     NNAUDIO_INSTALLED = False
+
+ from src.models.mert.configuration_mert import MERTConfig
+
+ _HIDDEN_STATES_START_POSITION = 1
+
+
+ class MERTFeatureProjection(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.feat_proj_layer_norm = config.feat_proj_layer_norm
+         self.feature_extractor_cqt = config.feature_extractor_cqt
+
+         if self.feature_extractor_cqt:
+             # v3: concatenate CQT features with the convolutional features
+             self.feature_dimension = config.conv_dim[-1] + config.feature_extractor_cqt_bins
+             print(f"feature dimension: {self.feature_dimension}")
+         else:
+             self.feature_dimension = config.conv_dim[-1]
+         if self.feat_proj_layer_norm:
+             self.layer_norm = nn.LayerNorm(self.feature_dimension, eps=config.layer_norm_eps)
+         self.projection = nn.Linear(self.feature_dimension, config.hidden_size)
+         self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+     def forward(self, hidden_states):
+         # non-projected hidden states are needed for quantization
+         if self.feat_proj_layer_norm:
+             hidden_states = self.layer_norm(hidden_states)
+         hidden_states = self.projection(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         return hidden_states
+
+
+ class MERTModel(HubertModel):
+
+     config_class = MERTConfig
+     base_model_prefix = "mert_model"
+
+     def __init__(
+         self,
+         config: MERTConfig,
+     ) -> None:
+         """
+         Initialize with the grandparent method HubertPreTrainedModel.__init__()
+         and modify HubertModel.__init__().
+         """
+         super(HubertModel, self).__init__(config)
+
+         self.config = config
+
+         self.feature_extractor = HubertFeatureEncoder(config)
+         self.feature_projection = MERTFeatureProjection(config)  # replace the feature projection to introduce the new features
+
+         if self.config.feature_extractor_cqt:
+             assert NNAUDIO_INSTALLED, "ERROR: feature_extractor_cqt requires the library 'nnAudio', try again after `pip install nnAudio`"
+             print('initializing cqt extractor for MERT')
+             self.feature_extractor_cqt = nnAudioFeatures.cqt.CQT(
+                 sr=self.config.sample_rate, hop_length=self.config.sample_rate // 50, fmin=32.7,
+                 fmax=None, n_bins=self.config.feature_extractor_cqt_bins,
+                 bins_per_octave=self.config.feature_extractor_cqt_bins // 7,
+                 filter_scale=1, norm=1, window='hann', center=True,
+                 pad_mode='constant', trainable=False,
+                 output_format='Magnitude', verbose=True)
+
+         if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+             self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+         if config.do_stable_layer_norm:
+             assert not config.deepnorm, "must use post-layer_norm with deepnorm"
+             self.encoder = HubertEncoderStableLayerNorm(config)
+         else:
+             if config.deepnorm:
+                 self.encoder = HubertEncoder_extend(config)
+             else:
+                 self.encoder = HubertEncoder(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         mask_time_indices: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         extract_features = self.feature_extractor(input_values)
+         extract_features = extract_features.transpose(1, 2)
+
+         # add additional cqt features for transformer input
+         if self.config.feature_extractor_cqt:
+             features_cqt = self.feature_extractor_cqt(input_values).transpose(1, 2)
+             features_cqt = features_cqt[:, :extract_features.shape[1], :]  # align shape
+             # # v2
+             # features_cqt = self.post_cqt_feature_proj(features_cqt)
+             # extract_features = self.feature_projection.layer_norm(extract_features) + self.feature_projection.layer_norm(features_cqt)  # v2
+             # v3
+             extract_features = torch.cat([extract_features, features_cqt], 2)
+
+         if attention_mask is not None:
+             # compute reduced attention_mask corresponding to feature vectors
+             attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+         hidden_states = self.feature_projection(extract_features)
+         hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
+
+         encoder_outputs = self.encoder(
+             hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = encoder_outputs[0]  # take last_hidden from encoder output
+
+         if not return_dict:
+             return (hidden_states,) + encoder_outputs[1:]
+
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ class HubertEncoder_extend(HubertEncoder):
+
+     def __init__(self, config):
+         # call nn.Module initialization directly instead of HubertEncoder.__init__()
+         nn.Module.__init__(self)
+
+         self.config = config
+         self.pos_conv_embed = HubertPositionalConvEmbedding(config)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout)
+
+         self.layers = nn.ModuleList([HubertEncoderLayerExtend(config) for _ in range(config.num_hidden_layers)])
+
+         self.gradient_checkpointing = False
+
+         if config.deepnorm:
+             import math
+             init_scale = math.pow(8.0 * config.num_hidden_layers, 0.25)
+             for name, p in self.named_parameters():
+                 if (
+                     "feed_forward.intermediate_dense" in name
+                     or "feed_forward.output_dense" in name
+                     or "out_proj" in name
+                     or "v_proj" in name
+                 ):
+                     p.data.div_(init_scale)
+
+
+ class HubertEncoderLayerExtend(HubertEncoderLayer):
+
+     def __init__(self, config):
+         # call nn.Module initialization directly instead of HubertEncoderLayer.__init__()
+         nn.Module.__init__(self)
+         if config.attention_relax > 0:
+             self.attention = HubertAttention_extend(
+                 embed_dim=config.hidden_size,
+                 num_heads=config.num_attention_heads,
+                 dropout=config.attention_dropout,
+                 is_decoder=False,
+                 attention_relax=config.attention_relax,
+             )
+         else:
+             self.attention = HubertAttention(
+                 embed_dim=config.hidden_size,
+                 num_heads=config.num_attention_heads,
+                 dropout=config.attention_dropout,
+                 is_decoder=False,
+             )
+         self.dropout = nn.Dropout(config.hidden_dropout)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.feed_forward = HubertFeedForward(config)
+         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         if config.deepnorm:
+             import math
+             self.residual_alpha = math.pow(2.0 * config.num_hidden_layers, 0.25)
+         else:
+             self.residual_alpha = 1.0
+
+     def residual_connection(self, x, residual):
+         """
+         residual: input before f()
+         x: output of f(residual)
+         """
+         return residual * self.residual_alpha + x
+
+     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
+         attn_residual = hidden_states
+         hidden_states, attn_weights, _ = self.attention(
+             hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+         )
+         hidden_states = self.dropout(hidden_states)
+
+         # hidden_states = attn_residual + hidden_states
+         hidden_states = self.residual_connection(hidden_states, attn_residual)
+
+         hidden_states = self.layer_norm(hidden_states)
+
+         # hidden_states = hidden_states + self.feed_forward(hidden_states)
+         ffn_residual = hidden_states
+         hidden_states = self.feed_forward(hidden_states)
+         hidden_states = self.residual_connection(hidden_states, ffn_residual)
+
+         hidden_states = self.final_layer_norm(hidden_states)
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (attn_weights,)
+
+         return outputs
+
+
+ class HubertAttention_extend(nn.Module):
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         dropout: float = 0.0,
+         is_decoder: bool = False,
+         bias: bool = True,
+         attention_relax: float = -1.0,
+     ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.head_dim = embed_dim // num_heads
+
+         if (self.head_dim * num_heads) != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                 f" and `num_heads`: {num_heads})."
+             )
+         self.scaling = self.head_dim**-0.5
+         self.is_decoder = is_decoder
+
+         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+         # store the relaxation coefficient unconditionally so that forward() can always check it
+         self.attention_relax = attention_relax
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_value_states: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         layer_head_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         """Input shape: Batch x Time x Channel"""
+
+         # if key_value_states are provided this layer is used as a cross-attention layer
+         # for the decoder
+         is_cross_attention = key_value_states is not None
+
+         bsz, tgt_len, _ = hidden_states.size()
+
+         # get query proj
+         query_states = self.q_proj(hidden_states) * self.scaling
+         # get key, value proj
+         # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+         # is checking that the `sequence_length` of the `past_key_value` is the same as
+         # the provided `key_value_states` to support prefix tuning
+         if (
+             is_cross_attention
+             and past_key_value is not None
+             and past_key_value[0].shape[2] == key_value_states.shape[1]
+         ):
+             # reuse k,v, cross_attentions
+             key_states = past_key_value[0]
+             value_states = past_key_value[1]
+         elif is_cross_attention:
+             # cross_attentions
+             key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+             value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+         elif past_key_value is not None:
+             # reuse k, v, self_attention
+             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+         else:
+             # self_attention
+             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+         if self.is_decoder:
+             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+             # Further calls to cross_attention layer can then reuse all cross-attention
+             # key/value_states (first "if" case)
+             # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+             # all previous decoder key/value_states. Further calls to uni-directional self-attention
+             # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+             # if encoder bi-directional self-attention `past_key_value` is always `None`
+             past_key_value = (key_states, value_states)
+
+         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+         key_states = key_states.view(*proj_shape)
+         value_states = value_states.view(*proj_shape)
+
+         src_len = key_states.size(1)
+         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+         if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+             raise ValueError(
+                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                 f" {attn_weights.size()}"
+             )
+
+         if attention_mask is not None:
+             if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                 raise ValueError(
+                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                 )
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         if self.attention_relax > 0:
+             # attention relaxation: rescale the logits by 1/attention_relax, subtract the
+             # per-query maximum, then scale back before the softmax
+             # => (bsz*self.num_heads, tgt_len, src_len)
+             attn_weights_relax = attn_weights / self.attention_relax
+
+             # => (bsz*self.num_heads, tgt_len, 1)
+             # torch.max(..., dim=-1) returns a (values, indices) namedtuple; take .values
+             attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=False).values.unsqueeze(2)
+             attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax
+
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+         if layer_head_mask is not None:
+             if layer_head_mask.size() != (self.num_heads,):
+                 raise ValueError(
+                     f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                     f" {layer_head_mask.size()}"
+                 )
+             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         if output_attentions:
+             # this operation is a bit awkward, but it's required to
+             # make sure that attn_weights keeps its gradient.
+             # In order to do so, attn_weights have to be reshaped
+             # twice and have to be reused in the following
+             attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+         else:
+             attn_weights_reshaped = None
+
+         attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+         attn_output = torch.bmm(attn_probs, value_states)
+
+         if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+             raise ValueError(
+                 f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                 f" {attn_output.size()}"
+             )
+
+         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+         attn_output = attn_output.transpose(1, 2)
+
+         # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+         # partitioned across GPUs when using tensor-parallelism.
+         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output, attn_weights_reshaped, past_key_value
+
+
+ class MERTPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = MERTConfig
+     base_model_prefix = "mert"
+     main_input_name = "input_values"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, nn.Conv1d):
+             nn.init.kaiming_normal_(module.weight.data)
+
+         if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
+             module.bias.data.zero_()
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)):
+             module.gradient_checkpointing = value
+
+     def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+         """
+         Computes the output length of the convolutional layers
+         """
+
+         def _conv_out_length(input_length, kernel_size, stride):
+             # 1D convolutional layer output length formula taken
+             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+             return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+         return input_lengths
+
+     def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+         output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+         batch_size = attention_mask.shape[0]
+
+         attention_mask = torch.zeros(
+             (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+         )
+         # these two operations make sure that all values before the output length indices are attended to
+         attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+         return attention_mask
+
+
+ class MERTForSequenceClassification(MERTPreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         if hasattr(config, "add_adapter") and config.add_adapter:
+             raise ValueError(
+                 "Sequence classification does not support the use of MERT adapters (config.add_adapter=True)"
+             )
+         self.mert = MERTModel(config)
+         num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+         if config.use_weighted_layer_sum:
+             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def freeze_feature_extractor(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         warnings.warn(
+             "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+             "Please use the equivalent `freeze_feature_encoder` method instead.",
+             FutureWarning,
+         )
+         self.freeze_feature_encoder()
+
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.mert.feature_extractor._freeze_parameters()
+
+     def freeze_base_model(self):
+         """
+         Calling this function will disable the gradient computation for the base model so that its parameters will not
+         be updated during training. Only the classification head will be updated.
+         """
+         for param in self.mert.parameters():
+             param.requires_grad = False
+
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.Tensor] = None,
+     ) -> Union[Tuple, SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+         outputs = self.mert(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.config.use_weighted_layer_sum:
+             hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+             hidden_states = torch.stack(hidden_states, dim=1)
+             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+         else:
+             hidden_states = outputs[0]
+
+         hidden_states = self.projector(hidden_states)
+         if attention_mask is None:
+             pooled_output = hidden_states.mean(dim=1)
+         else:
+             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+             hidden_states[~padding_mask] = 0.0
+             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
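
As a side note on the geometry implied by config.json: with conv_kernel = [10, 3, 3, 3, 3, 2, 2] and conv_stride = [5, 2, 2, 2, 2, 2, 2], the feature encoder downsamples 24 kHz audio by a factor of 320, i.e. roughly 75 feature frames per second. A small standalone sketch of that computation, mirroring the formula used in _get_feat_extract_output_lengths (not part of the uploaded file):

def conv_out_length(input_length, kernel_size, stride):
    # same 1D convolution output-length formula as in _get_feat_extract_output_lengths
    return (input_length - kernel_size) // stride + 1

length = 24000 * 5  # 5 seconds at 24 kHz -> 120000 samples
for k, s in zip([10, 3, 3, 3, 3, 2, 2], [5, 2, 2, 2, 2, 2, 2]):
    length = conv_out_length(length, k, s)
print(length)  # 374 frames for 5 s of audio, i.e. ~75 Hz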
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34eb474bc73867231042cb3003d56c9963c9ef010eabd3e4518039a4e5051a54
+ size 378346601
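
pytorch_model.bin is stored as a Git LFS pointer, so only the oid and size above are versioned here. After downloading the actual weights, the blob can be checked against that digest with a generic integrity check such as the following sketch (not something shipped in the repo):

import hashlib

expected = "34eb474bc73867231042cb3003d56c9963c9ef010eabd3e4518039a4e5051a54"

sha256 = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    # hash in 1 MiB chunks so the ~378 MB file is not read into memory at once
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

assert sha256.hexdigest() == expected, "checksum mismatch: re-download the file"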