ta012 committed
Commit 501ebee · verified · 1 parent: faca072

Upload 7 files

Files changed (7):
  1. README.md +116 -3
  2. config.json +38 -0
  3. configuration_eat.py +66 -0
  4. eat_model.py +99 -0
  5. model.safetensors +3 -0
  6. model_core.py +224 -0
  7. modeling_eat.py +18 -0
README.md CHANGED
@@ -1,3 +1,116 @@
- ---
- license: cc-by-4.0
- ---
+ ---
+ license: mit
+ tags:
+ - Audio
+ - SSL
+ - SSLAM
+ library_name: transformers
+ ---
+
+ # SSLAM Pretrain (ViT Base, 15 epochs)
+
+ This repository provides an SSLAM checkpoint formatted for use with Hugging Face Transformers. It is intended for feature extraction in audio LLMs, sound event detection, and general-purpose audio representation learning. The implementation follows the [EAT](https://arxiv.org/abs/2401.03497) code path while swapping in SSLAM-pretrained weights.
+
+
+
+ ## 🔧 Usage
+
+ You can load the model and run feature extraction directly via Hugging Face Transformers:
+
+ ```python
+ import torchaudio
+ import torch
+ import soundfile as sf
+ import numpy as np
+ from transformers import AutoModel
+
+ model_id = "ta012/SSLAM_pretrain"
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
+
+ source_file = "/path/to/input.wav"
+ target_length = 1024  # Recommended: 1024 for 10s audio
+ norm_mean = -4.268
+ norm_std = 4.569
+
+ # Load and resample audio
+ wav, sr = sf.read(source_file)  # mono waveform expected
+ waveform = torch.tensor(wav).float().cuda()
+ if sr != 16000:
+     waveform = torchaudio.functional.resample(waveform, sr, 16000)
+
+ # Normalize and convert to mel-spectrogram
+ waveform = waveform - waveform.mean()
+ mel = torchaudio.compliance.kaldi.fbank(
+     waveform.unsqueeze(0),
+     htk_compat=True,
+     sample_frequency=16000,
+     use_energy=False,
+     window_type='hanning',
+     num_mel_bins=128,
+     dither=0.0,
+     frame_shift=10
+ ).unsqueeze(0)
+
+ # Pad or truncate
+ n_frames = mel.shape[1]
+ if n_frames < target_length:
+     mel = torch.nn.ZeroPad2d((0, 0, 0, target_length - n_frames))(mel)
+ else:
+     mel = mel[:, :target_length, :]
+
+ # Normalize
+ mel = (mel - norm_mean) / (norm_std * 2)
+ mel = mel.unsqueeze(0).cuda()  # shape: [1, 1, T, F]
+
+ # Extract features
+ with torch.no_grad():
+     feat = model.extract_features(mel)
+
+ feat = feat.squeeze(0).cpu().numpy()
+ print(f"Feature shape: {feat.shape}")
+ ```
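+
+ The returned `feat` has one row per token: index 0 is the prepended CLS/utterance token (the token the finetuned variant classifies on), and the remaining rows are patch tokens. If you need a single clip-level vector, a minimal sketch (illustrative, not part of the official SSLAM/EAT recipe) is:
+
+ ```python
+ # Clip-level embeddings from `feat` above (shape: [num_tokens, 768])
+ cls_embedding = feat[0]             # prepended CLS/utterance token
+ patch_mean = feat[1:].mean(axis=0)  # mean pooling over the patch tokens
+ print(cls_embedding.shape, patch_mean.shape)  # (768,) (768,)
+ ```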
+
+ ## 📌 Notes
+
+
+ See the [feature extraction guide](https://github.com/cwx-worst-one/EAT/tree/main/feature_extract) for more instructions.
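+
+ The checkpoint ships its modeling code, and `auto_map` in `config.json` routes `AutoConfig`/`AutoModel` to the custom `EATConfig`/`EATModel` classes when `trust_remote_code=True`. A quick sanity check (a minimal sketch; the expected values come from this repository's `config.json`):
+
+ ```python
+ from transformers import AutoConfig
+
+ cfg = AutoConfig.from_pretrained("ta012/SSLAM_pretrain", trust_remote_code=True)
+ print(type(cfg).__name__)                       # EATConfig
+ print(cfg.embed_dim, cfg.depth, cfg.num_heads)  # 768 12 12 (ViT-Base)
+ print(cfg.img_size, cfg.mel_bins)               # [1024, 128] 128
+ ```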
+
+
+ ## 🙌 Acknowledgments
+
+ This repository builds on the EAT implementation for Hugging Face models; we remap the SSLAM-pretrained weights onto that interface.
+
+ - Paper: EAT: Self-Supervised Pre-Training with Efficient Audio Transformer
+ - Code: https://github.com/cwx-worst-one/EAT
+
+ We are not affiliated with the EAT authors. All credit for the original implementation belongs to them.
+
+
+ ## 📚 Citation
+
+
+ If you find our work useful, please cite:
+
+
+ ```bibtex
+ @inproceedings{alex2025sslam,
+   title={{SSLAM}: Enhancing Self-Supervised Models with Audio Mixtures for Polyphonic Soundscapes},
+   author={Tony Alex and Sara Atito and Armin Mustafa and Muhammad Awais and Philip J B Jackson},
+   booktitle={The Thirteenth International Conference on Learning Representations},
+   year={2025},
+   url={https://openreview.net/forum?id=odU59TxdiB}
+ }
+ ```
+
+
+
+ Please also cite EAT:
+
+ ```bibtex
+ @article{chen2024eat,
+   title={EAT: Self-supervised pre-training with efficient audio transformer},
+   author={Chen, Wenxi and Liang, Yuzhe and Ma, Ziyang and Zheng, Zhisheng and Chen, Xie},
+   journal={arXiv preprint arXiv:2401.03497},
+   year={2024}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "activation_dropout": 0.0,
+   "architectures": [
+     "EATModel"
+   ],
+   "auto_map": {
+     "AutoModel": "modeling_eat.EATModel",
+     "AutoConfig": "configuration_eat.EATConfig"
+   },
+   "attn_drop_rate": 0.0,
+   "depth": 12,
+   "drop_rate": 0.0,
+   "embed_dim": 768,
+   "end_drop_path_rate": 0.0,
+   "fixed_positions": true,
+   "img_size": [
+     1024,
+     128
+   ],
+   "in_chans": 1,
+   "layer_norm_first": false,
+   "max_length": 768,
+   "mel_bins": 128,
+   "mlp_ratio": 4.0,
+   "model_type": "eat",
+   "model_variant": "pretrain",
+   "norm_affine": true,
+   "norm_eps": 1e-06,
+   "num_classes": 527,
+   "num_heads": 12,
+   "patch_size": 16,
+   "post_mlp_drop": 0.0,
+   "qkv_bias": true,
+   "start_drop_path_rate": 0.0,
+   "stride": 16,
+   "torch_dtype": "float32",
+   "transformers_version": "4.51.3"
+ }
configuration_eat.py ADDED
@@ -0,0 +1,66 @@
+ # configuration_eat.py
+
+ from transformers import PretrainedConfig
+
+ class EATConfig(PretrainedConfig):
+     model_type = "eat"
+
+     def __init__(
+         self,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         patch_size=16,
+         stride=16,
+         in_chans=1,
+         mel_bins=128,
+         max_length=768,
+         num_classes=527,
+         model_variant="pretrain",  # or "finetune"
+
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         drop_rate=0.0,
+         attn_drop_rate=0.0,
+         activation_dropout=0.0,
+         post_mlp_drop=0.0,
+         start_drop_path_rate=0.0,
+         end_drop_path_rate=0.0,
+
+         layer_norm_first=False,
+         norm_eps=1e-6,
+         norm_affine=True,
+         fixed_positions=True,
+
+         img_size=(1024, 128),  # (target_length, mel_bins)
+
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.embed_dim = embed_dim
+         self.depth = depth
+         self.num_heads = num_heads
+         self.patch_size = patch_size
+         self.stride = stride
+         self.in_chans = in_chans
+         self.mel_bins = mel_bins
+         self.max_length = max_length
+         self.num_classes = num_classes
+         self.model_variant = model_variant
+
+         self.mlp_ratio = mlp_ratio
+         self.qkv_bias = qkv_bias
+         self.drop_rate = drop_rate
+         self.attn_drop_rate = attn_drop_rate
+         self.activation_dropout = activation_dropout
+         self.post_mlp_drop = post_mlp_drop
+         self.start_drop_path_rate = start_drop_path_rate
+         self.end_drop_path_rate = end_drop_path_rate
+
+         self.layer_norm_first = layer_norm_first
+         self.norm_eps = norm_eps
+         self.norm_affine = norm_affine
+         self.fixed_positions = fixed_positions
+
+         self.img_size = img_size
eat_model.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import torch.nn as nn
+ from timm.models.layers import trunc_normal_
+ from functools import partial
+ import numpy as np
+ from .model_core import (
+     PatchEmbed_new,
+     get_2d_sincos_pos_embed_flexible,
+     FixedPositionalEncoder,
+     AltBlock
+ )
+
+ class EAT(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.mode = config.model_variant  # "pretrain" or "finetune"
+
+         # === Embedding / Encoder ===
+         self.local_encoder = PatchEmbed_new(
+             img_size=config.img_size,
+             patch_size=config.patch_size,
+             in_chans=config.in_chans,
+             embed_dim=config.embed_dim,
+             stride=config.stride
+         )
+
+         self.extra_tokens = nn.Parameter(torch.zeros(1, 1, config.embed_dim))  # CLS / utterance token
+         self.pos_drop = nn.Dropout(p=config.drop_rate, inplace=True)
+         trunc_normal_(self.extra_tokens, std=.02)
+
+         self.fixed_positional_encoder = (
+             FixedPositionalEncoder(self.build_sincos_pos_embed()) if config.fixed_positions else None
+         )
+
+         norm_layer = partial(nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine)
+         dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)
+         self.blocks = nn.ModuleList([
+             AltBlock(config.embed_dim, config.num_heads, config.mlp_ratio,
+                      qkv_bias=config.qkv_bias, drop=config.drop_rate,
+                      attn_drop=config.attn_drop_rate, mlp_drop=config.activation_dropout,
+                      post_mlp_drop=config.post_mlp_drop, drop_path=dpr[i],
+                      norm_layer=norm_layer, layer_norm_first=config.layer_norm_first,
+                      ffn_targets=True)
+             for i in range(config.depth)
+         ])
+
+         self.pre_norm = norm_layer(config.embed_dim)
+
+         # === Head (for finetune) ===
+         if self.mode == "finetune":
+             self.fc_norm = nn.LayerNorm(config.embed_dim)
+             self.head = nn.Linear(config.embed_dim, config.num_classes, bias=True)
+         else:
+             self.head = nn.Identity()
+
+         self.apply(self._init_weights)
+
+     def build_sincos_pos_embed(self):
+         W = self.config.mel_bins // self.config.patch_size
+         max_length = self.config.max_length
+         embed_dim = self.config.embed_dim
+         pos_embed = nn.Parameter(torch.zeros(1, max_length * W, embed_dim), requires_grad=False)
+         emb = get_2d_sincos_pos_embed_flexible(embed_dim, (max_length, W), cls_token=False)
+         pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
+         return pos_embed
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def encode(self, x):
+         B = x.shape[0]
+         x = self.local_encoder(x)  # (B, num_patches, embed_dim)
+         if self.fixed_positional_encoder is not None:
+             x = x + self.fixed_positional_encoder(x, None)[:, :x.size(1), :]
+         x = torch.cat((self.extra_tokens.expand(B, -1, -1), x), dim=1)  # prepend CLS token
+         x = self.pre_norm(x)
+         x = self.pos_drop(x)
+         for blk in self.blocks:
+             x, _ = blk(x)  # second output (FFN-target features) is unused here
+         return x
+
+     def forward(self, x):
+         x = self.encode(x)
+         if self.mode == "finetune":
+             x = x[:, 0]  # use cls token
+             x = self.fc_norm(x)
+             x = self.head(x)
+         return x
+
+     def extract_features(self, x):
+         x = self.encode(x)
+         return x
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8ec670adb241c710422ddd894ff1bade142ef0b25cf1ee68577aa45f89432298
+ size 359905840
model_core.py ADDED
@@ -0,0 +1,224 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from timm.models.layers import to_2tuple
+
+ class PatchEmbed_new(nn.Module):
+     """ Flexible Image to Patch Embedding
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=16):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         stride = to_2tuple(stride)
+
+         self.img_size = img_size
+         self.patch_size = patch_size
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)  # supports overlapping patches when stride < patch_size
+
+     def forward(self, x):
+         x = self.proj(x)
+         x = x.flatten(2).transpose(1, 2)
+         return x
+
+
+ def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: (H, W) tuple giving the grid height and width
+     return:
+     pos_embed: [H*W, embed_dim] or [1+H*W, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size[0], dtype=np.float32)
+     grid_w = np.arange(grid_size[1], dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float32)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000 ** omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ class FixedPositionalEncoder(nn.Module):
+     def __init__(self, pos_embed):
+         super().__init__()
+         self.positions = pos_embed
+
+     def forward(self, x, padding_mask):
+         return self.positions
+
+
+ class AltBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads,
+         mlp_ratio=4.0,
+         qkv_bias=False,
+         qk_scale=None,
+         drop=0.0,
+         attn_drop=0.0,
+         mlp_drop=0.0,
+         post_mlp_drop=0.0,
+         drop_path=0.0,
+         act_layer=nn.GELU,
+         norm_layer=nn.LayerNorm,
+         layer_norm_first=True,
+         ffn_targets=False,
+         cosine_attention=False,
+     ):
+         super().__init__()
+
+         self.layer_norm_first = layer_norm_first
+         self.ffn_targets = ffn_targets
+
+         from timm.models.vision_transformer import DropPath, Mlp
+
+         self.norm1 = norm_layer(dim)
+         self.attn = AltAttention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             qk_scale=qk_scale,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             cosine_attention=cosine_attention,
+         )
+
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=mlp_drop,
+         )
+         self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
+
+     def forward(self, x, padding_mask=None, alibi_bias=None):
+         if self.layer_norm_first:
+             x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
+             r = x = self.mlp(self.norm2(x))
+             t = x
+             x = r + self.drop_path(self.post_mlp_dropout(x))
+             if not self.ffn_targets:
+                 t = x
+         else:
+             x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
+             r = x = self.norm1(x)
+             x = self.mlp(x)
+             t = x
+             x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
+             if not self.ffn_targets:
+                 t = x
+
+         return x, t  # t: FFN-output features when ffn_targets=True, else the block output
+
+
+ class AltAttention(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads=8,
+         qkv_bias=False,
+         qk_scale=None,
+         attn_drop=0.0,
+         proj_drop=0.0,
+         cosine_attention=False,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+         self.cosine_attention = cosine_attention
+
+         if cosine_attention:
+             self.logit_scale = nn.Parameter(
+                 torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+             )
+
+     def forward(self, x, padding_mask=None, alibi_bias=None):
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+             .permute(2, 0, 3, 1, 4)  # qkv x B x H x L x D
+         )
+         q, k, v = (
+             qkv[0],
+             qkv[1],
+             qkv[2],
+         )  # make torchscript happy (cannot use tensor as tuple)
+
+         dtype = q.dtype
+
+         if self.cosine_attention:
+             # cosine attention
+             attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+             logit_scale = torch.clamp(
+                 self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
+             ).exp()
+             attn = attn * logit_scale
+         else:
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+
+         if alibi_bias is not None:
+             attn = attn.type_as(alibi_bias)
+             attn[:, : alibi_bias.size(1)] += alibi_bias
+
+         if padding_mask is not None and padding_mask.any():
+             attn = attn.masked_fill(
+                 padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                 float("-inf"),
+             )
+
+         attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
+         attn = self.attn_drop(attn)
+         x = (attn @ v).transpose(1, 2)  # (B, H, N, D) -> (B, N, H, D)
+         x = x.reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
modeling_eat.py ADDED
@@ -0,0 +1,18 @@
+ # modeling_eat.py
+
+ from transformers import PreTrainedModel
+ from .configuration_eat import EATConfig
+ from .eat_model import EAT
+
+ class EATModel(PreTrainedModel):
+     config_class = EATConfig
+
+     def __init__(self, config: EATConfig):
+         super().__init__(config)
+         self.model = EAT(config)
+
+     def forward(self, *args, **kwargs):
+         return self.model(*args, **kwargs)
+
+     def extract_features(self, x):
+         return self.model.extract_features(x)