Upload 7 files
Browse files- README.md +116 -3
- config.json +38 -0
- configuration_eat.py +66 -0
- eat_model.py +99 -0
- model.safetensors +3 -0
- model_core.py +224 -0
- modeling_eat.py +18 -0
README.md
CHANGED
|
@@ -1,3 +1,116 @@
|
|
| 1 |
-
---
|
| 2 |
-
license:
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- Audio
|
| 5 |
+
- SSL
|
| 6 |
+
- SSLAM
|
| 7 |
+
library_name: transformers
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# SSLAM Pretrain (ViT Base, 15 epochs)
|
| 11 |
+
|
| 12 |
+
This repository provides an SSLAM checkpoint formatted for use with Hugging Face Transformers. It is intended for feature extraction in audio LLMs, sound event detection, and general purpose audio representation learning. The implementation follows the [EAT](https://arxiv.org/abs/2401.03497) code path while swapping in SSLAM pretrained weights.
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
## 🔧 Usage
|
| 17 |
+
|
| 18 |
+
You can load and use the model for feature extraction directly via Hugging Face Transformers:
|
| 19 |
+
|
| 20 |
+
```python
|
| 21 |
+
import torchaudio
|
| 22 |
+
import torch
|
| 23 |
+
import soundfile as sf
|
| 24 |
+
import numpy as np
|
| 25 |
+
from transformers import AutoModel
|
| 26 |
+
|
| 27 |
+
model_id = "ta012/SSLAM_pretrain"
|
| 28 |
+
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
|
| 29 |
+
|
| 30 |
+
source_file = "/path/to/input.wav"
|
| 31 |
+
target_length = 1024 # Recommended: 1024 for 10s audio
|
| 32 |
+
norm_mean = -4.268
|
| 33 |
+
norm_std = 4.569
|
| 34 |
+
|
| 35 |
+
# Load and resample audio
|
| 36 |
+
wav, sr = sf.read(source_file)
|
| 37 |
+
waveform = torch.tensor(wav).float().cuda()
|
| 38 |
+
if sr != 16000:
|
| 39 |
+
waveform = torchaudio.functional.resample(waveform, sr, 16000)
|
| 40 |
+
|
| 41 |
+
# Normalize and convert to mel-spectrogram
|
| 42 |
+
waveform = waveform - waveform.mean()
|
| 43 |
+
mel = torchaudio.compliance.kaldi.fbank(
|
| 44 |
+
waveform.unsqueeze(0),
|
| 45 |
+
htk_compat=True,
|
| 46 |
+
sample_frequency=16000,
|
| 47 |
+
use_energy=False,
|
| 48 |
+
window_type='hanning',
|
| 49 |
+
num_mel_bins=128,
|
| 50 |
+
dither=0.0,
|
| 51 |
+
frame_shift=10
|
| 52 |
+
).unsqueeze(0)
|
| 53 |
+
|
| 54 |
+
# Pad or truncate
|
| 55 |
+
n_frames = mel.shape[1]
|
| 56 |
+
if n_frames < target_length:
|
| 57 |
+
mel = torch.nn.ZeroPad2d((0, 0, 0, target_length - n_frames))(mel)
|
| 58 |
+
else:
|
| 59 |
+
mel = mel[:, :target_length, :]
|
| 60 |
+
|
| 61 |
+
# Normalize
|
| 62 |
+
mel = (mel - norm_mean) / (norm_std * 2)
|
| 63 |
+
mel = mel.unsqueeze(0).cuda() # shape: [1, 1, T, F]
|
| 64 |
+
|
| 65 |
+
# Extract features
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
feat = model.extract_features(mel)
|
| 68 |
+
|
| 69 |
+
feat = feat.squeeze(0).cpu().numpy()
|
| 70 |
+
print(f"Feature shape: {feat.shape}")
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## 📌 Notes
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
See the [feature extraction guide](https://github.com/cwx-worst-one/EAT/tree/main/feature_extract) for more instructions.
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
## 🙌 Acknowledgments
|
| 80 |
+
|
| 81 |
+
This repository builds on the EAT implementation for Hugging Face models. We remap SSLAM weights to that interface.
|
| 82 |
+
|
| 83 |
+
- Paper: EAT: Self supervised pretraining with Efficient Audio Transformer
|
| 84 |
+
- Code: https://github.com/cwx-worst-one/EAT
|
| 85 |
+
|
| 86 |
+
We are not affiliated with the EAT authors. All credit for the original implementation belongs to them.
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
## 📚 Citation
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
If you find our work useful, please cite it as:
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
```bibtex
|
| 96 |
+
@inproceedings{alex2025sslam,
|
| 97 |
+
title={{SSLAM}: Enhancing Self-Supervised Models with Audio Mixtures for Polyphonic Soundscapes},
|
| 98 |
+
author={Tony Alex and Sara Atito and Armin Mustafa and Muhammad Awais and Philip J B Jackson},
|
| 99 |
+
booktitle={The Thirteenth International Conference on Learning Representations},
|
| 100 |
+
year={2025},
|
| 101 |
+
url={https://openreview.net/forum?id=odU59TxdiB}
|
| 102 |
+
}
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
Please also cite EAT:
|
| 108 |
+
|
| 109 |
+
```bibtex
|
| 110 |
+
@article{chen2024eat,
|
| 111 |
+
title={EAT: Self-supervised pre-training with efficient audio transformer},
|
| 112 |
+
author={Chen, Wenxi and Liang, Yuzhe and Ma, Ziyang and Zheng, Zhisheng and Chen, Xie},
|
| 113 |
+
journal={arXiv preprint arXiv:2401.03497},
|
| 114 |
+
year={2024}
|
| 115 |
+
}
|
| 116 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation_dropout": 0.0,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"EATModel"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoModel": "modeling_eat.EATModel",
|
| 8 |
+
"AutoConfig": "configuration_eat.EATConfig"
|
| 9 |
+
},
|
| 10 |
+
"attn_drop_rate": 0.0,
|
| 11 |
+
"depth": 12,
|
| 12 |
+
"drop_rate": 0.0,
|
| 13 |
+
"embed_dim": 768,
|
| 14 |
+
"end_drop_path_rate": 0.0,
|
| 15 |
+
"fixed_positions": true,
|
| 16 |
+
"img_size": [
|
| 17 |
+
1024,
|
| 18 |
+
128
|
| 19 |
+
],
|
| 20 |
+
"in_chans": 1,
|
| 21 |
+
"layer_norm_first": false,
|
| 22 |
+
"max_length": 768,
|
| 23 |
+
"mel_bins": 128,
|
| 24 |
+
"mlp_ratio": 4.0,
|
| 25 |
+
"model_type": "eat",
|
| 26 |
+
"model_variant": "pretrain",
|
| 27 |
+
"norm_affine": true,
|
| 28 |
+
"norm_eps": 1e-06,
|
| 29 |
+
"num_classes": 527,
|
| 30 |
+
"num_heads": 12,
|
| 31 |
+
"patch_size": 16,
|
| 32 |
+
"post_mlp_drop": 0.0,
|
| 33 |
+
"qkv_bias": true,
|
| 34 |
+
"start_drop_path_rate": 0.0,
|
| 35 |
+
"stride": 16,
|
| 36 |
+
"torch_dtype": "float32",
|
| 37 |
+
"transformers_version": "4.51.3"
|
| 38 |
+
}
|
configuration_eat.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# configuration_eat.py
|
| 2 |
+
|
| 3 |
+
from transformers import PretrainedConfig
|
| 4 |
+
|
| 5 |
+
class EATConfig(PretrainedConfig):
|
| 6 |
+
model_type = "eat"
|
| 7 |
+
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
embed_dim=768,
|
| 11 |
+
depth=12,
|
| 12 |
+
num_heads=12,
|
| 13 |
+
patch_size=16,
|
| 14 |
+
stride=16,
|
| 15 |
+
in_chans=1,
|
| 16 |
+
mel_bins=128,
|
| 17 |
+
max_length=768,
|
| 18 |
+
num_classes=527,
|
| 19 |
+
model_variant="pretrain", # or "finetune"
|
| 20 |
+
|
| 21 |
+
mlp_ratio=4.0,
|
| 22 |
+
qkv_bias=True,
|
| 23 |
+
drop_rate=0.0,
|
| 24 |
+
attn_drop_rate=0.0,
|
| 25 |
+
activation_dropout=0.0,
|
| 26 |
+
post_mlp_drop=0.0,
|
| 27 |
+
start_drop_path_rate=0.0,
|
| 28 |
+
end_drop_path_rate=0.0,
|
| 29 |
+
|
| 30 |
+
layer_norm_first=False,
|
| 31 |
+
norm_eps=1e-6,
|
| 32 |
+
norm_affine=True,
|
| 33 |
+
fixed_positions=True,
|
| 34 |
+
|
| 35 |
+
img_size=(1024, 128), # (target_length, mel_bins)
|
| 36 |
+
|
| 37 |
+
**kwargs,
|
| 38 |
+
):
|
| 39 |
+
super().__init__(**kwargs)
|
| 40 |
+
|
| 41 |
+
self.embed_dim = embed_dim
|
| 42 |
+
self.depth = depth
|
| 43 |
+
self.num_heads = num_heads
|
| 44 |
+
self.patch_size = patch_size
|
| 45 |
+
self.stride = stride
|
| 46 |
+
self.in_chans = in_chans
|
| 47 |
+
self.mel_bins = mel_bins
|
| 48 |
+
self.max_length = max_length
|
| 49 |
+
self.num_classes = num_classes
|
| 50 |
+
self.model_variant = model_variant
|
| 51 |
+
|
| 52 |
+
self.mlp_ratio = mlp_ratio
|
| 53 |
+
self.qkv_bias = qkv_bias
|
| 54 |
+
self.drop_rate = drop_rate
|
| 55 |
+
self.attn_drop_rate = attn_drop_rate
|
| 56 |
+
self.activation_dropout = activation_dropout
|
| 57 |
+
self.post_mlp_drop = post_mlp_drop
|
| 58 |
+
self.start_drop_path_rate = start_drop_path_rate
|
| 59 |
+
self.end_drop_path_rate = end_drop_path_rate
|
| 60 |
+
|
| 61 |
+
self.layer_norm_first = layer_norm_first
|
| 62 |
+
self.norm_eps = norm_eps
|
| 63 |
+
self.norm_affine = norm_affine
|
| 64 |
+
self.fixed_positions = fixed_positions
|
| 65 |
+
|
| 66 |
+
self.img_size = img_size
|
eat_model.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from timm.models.layers import trunc_normal_
|
| 4 |
+
from functools import partial
|
| 5 |
+
import numpy as np
|
| 6 |
+
from .model_core import (
|
| 7 |
+
PatchEmbed_new,
|
| 8 |
+
get_2d_sincos_pos_embed_flexible,
|
| 9 |
+
FixedPositionalEncoder,
|
| 10 |
+
AltBlock
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
class EAT(nn.Module):
|
| 14 |
+
def __init__(self, config):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.config = config
|
| 17 |
+
self.mode = config.model_variant # "pretrain" or "finetune"
|
| 18 |
+
|
| 19 |
+
# === Embedding / Encoder ===
|
| 20 |
+
self.local_encoder = PatchEmbed_new(
|
| 21 |
+
img_size=config.img_size,
|
| 22 |
+
patch_size=config.patch_size,
|
| 23 |
+
in_chans=config.in_chans,
|
| 24 |
+
embed_dim=config.embed_dim,
|
| 25 |
+
stride=config.stride
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
self.extra_tokens = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
|
| 29 |
+
self.pos_drop = nn.Dropout(p=config.drop_rate, inplace=True)
|
| 30 |
+
trunc_normal_(self.extra_tokens, std=.02)
|
| 31 |
+
|
| 32 |
+
self.fixed_positional_encoder = (
|
| 33 |
+
FixedPositionalEncoder(self.build_sincos_pos_embed()) if config.fixed_positions else None
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
norm_layer = partial(nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine)
|
| 37 |
+
dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)
|
| 38 |
+
self.blocks = nn.ModuleList([
|
| 39 |
+
AltBlock(config.embed_dim, config.num_heads, config.mlp_ratio,
|
| 40 |
+
qkv_bias=config.qkv_bias, drop=config.drop_rate,
|
| 41 |
+
attn_drop=config.attn_drop_rate, mlp_drop=config.activation_dropout,
|
| 42 |
+
post_mlp_drop=config.post_mlp_drop, drop_path=dpr[i],
|
| 43 |
+
norm_layer=norm_layer, layer_norm_first=config.layer_norm_first,
|
| 44 |
+
ffn_targets=True)
|
| 45 |
+
for i in range(config.depth)
|
| 46 |
+
])
|
| 47 |
+
|
| 48 |
+
self.pre_norm = norm_layer(config.embed_dim)
|
| 49 |
+
|
| 50 |
+
# === Head (for finetune) ===
|
| 51 |
+
if self.mode == "finetune":
|
| 52 |
+
self.fc_norm = nn.LayerNorm(config.embed_dim)
|
| 53 |
+
self.head = nn.Linear(config.embed_dim, config.num_classes, bias=True)
|
| 54 |
+
else:
|
| 55 |
+
self.head = nn.Identity()
|
| 56 |
+
|
| 57 |
+
self.apply(self._init_weights)
|
| 58 |
+
|
| 59 |
+
def build_sincos_pos_embed(self):
|
| 60 |
+
W = self.config.mel_bins // self.config.patch_size
|
| 61 |
+
max_length = self.config.max_length
|
| 62 |
+
embed_dim = self.config.embed_dim
|
| 63 |
+
pos_embed = nn.Parameter(torch.zeros(1, max_length * W, embed_dim), requires_grad=False)
|
| 64 |
+
emb = get_2d_sincos_pos_embed_flexible(embed_dim, (max_length, W), cls_token=False)
|
| 65 |
+
pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
|
| 66 |
+
return pos_embed
|
| 67 |
+
|
| 68 |
+
def _init_weights(self, m):
|
| 69 |
+
if isinstance(m, nn.Linear):
|
| 70 |
+
trunc_normal_(m.weight, std=.02)
|
| 71 |
+
if m.bias is not None:
|
| 72 |
+
nn.init.constant_(m.bias, 0)
|
| 73 |
+
elif isinstance(m, nn.LayerNorm):
|
| 74 |
+
nn.init.constant_(m.bias, 0)
|
| 75 |
+
nn.init.constant_(m.weight, 1.0)
|
| 76 |
+
|
| 77 |
+
def encode(self, x):
|
| 78 |
+
B = x.shape[0]
|
| 79 |
+
x = self.local_encoder(x)
|
| 80 |
+
if self.fixed_positional_encoder is not None:
|
| 81 |
+
x = x + self.fixed_positional_encoder(x, None)[:, :x.size(1), :]
|
| 82 |
+
x = torch.cat((self.extra_tokens.expand(B, -1, -1), x), dim=1)
|
| 83 |
+
x = self.pre_norm(x)
|
| 84 |
+
x = self.pos_drop(x)
|
| 85 |
+
for blk in self.blocks:
|
| 86 |
+
x, _ = blk(x)
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
def forward(self, x):
|
| 90 |
+
x = self.encode(x)
|
| 91 |
+
if self.mode == "finetune":
|
| 92 |
+
x = x[:, 0] # use cls token
|
| 93 |
+
x = self.fc_norm(x)
|
| 94 |
+
x = self.head(x)
|
| 95 |
+
return x
|
| 96 |
+
|
| 97 |
+
def extract_features(self, x):
|
| 98 |
+
x = self.encode(x)
|
| 99 |
+
return x
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ec670adb241c710422ddd894ff1bade142ef0b25cf1ee68577aa45f89432298
|
| 3 |
+
size 359905840
|
model_core.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
from timm.models.layers import to_2tuple
|
| 6 |
+
|
| 7 |
+
class PatchEmbed_new(nn.Module):
|
| 8 |
+
""" Flexible Image to Patch Embedding
|
| 9 |
+
"""
|
| 10 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=16):
|
| 11 |
+
super().__init__()
|
| 12 |
+
img_size = to_2tuple(img_size)
|
| 13 |
+
patch_size = to_2tuple(patch_size)
|
| 14 |
+
stride = to_2tuple(stride)
|
| 15 |
+
|
| 16 |
+
self.img_size = img_size
|
| 17 |
+
self.patch_size = patch_size
|
| 18 |
+
|
| 19 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
x = self.proj(x)
|
| 23 |
+
x = x.flatten(2).transpose(1, 2)
|
| 24 |
+
return x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
|
| 28 |
+
"""
|
| 29 |
+
grid_size: int of the grid height and width
|
| 30 |
+
return:
|
| 31 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 32 |
+
"""
|
| 33 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32)
|
| 34 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32)
|
| 35 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 36 |
+
grid = np.stack(grid, axis=0)
|
| 37 |
+
|
| 38 |
+
grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
|
| 39 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 40 |
+
if cls_token:
|
| 41 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 42 |
+
return pos_embed
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 46 |
+
assert embed_dim % 2 == 0
|
| 47 |
+
|
| 48 |
+
# use half of dimensions to encode grid_h
|
| 49 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 50 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 51 |
+
|
| 52 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 53 |
+
return emb
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 57 |
+
"""
|
| 58 |
+
embed_dim: output dimension for each position
|
| 59 |
+
pos: a list of positions to be encoded: size (M,)
|
| 60 |
+
out: (M, D)
|
| 61 |
+
"""
|
| 62 |
+
assert embed_dim % 2 == 0
|
| 63 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
| 64 |
+
omega /= embed_dim / 2.0
|
| 65 |
+
omega = 1.0 / 10000 ** omega # (D/2,)
|
| 66 |
+
|
| 67 |
+
pos = pos.reshape(-1) # (M,)
|
| 68 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 69 |
+
|
| 70 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 71 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 72 |
+
|
| 73 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 74 |
+
return emb
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class FixedPositionalEncoder(nn.Module):
|
| 78 |
+
def __init__(self, pos_embed):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.positions = pos_embed
|
| 81 |
+
|
| 82 |
+
def forward(self, x, padding_mask):
|
| 83 |
+
return self.positions
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class AltBlock(nn.Module):
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
dim,
|
| 90 |
+
num_heads,
|
| 91 |
+
mlp_ratio=4.0,
|
| 92 |
+
qkv_bias=False,
|
| 93 |
+
qk_scale=None,
|
| 94 |
+
drop=0.0,
|
| 95 |
+
attn_drop=0.0,
|
| 96 |
+
mlp_drop=0.0,
|
| 97 |
+
post_mlp_drop=0.0,
|
| 98 |
+
drop_path=0.0,
|
| 99 |
+
act_layer=nn.GELU,
|
| 100 |
+
norm_layer=nn.LayerNorm,
|
| 101 |
+
layer_norm_first=True,
|
| 102 |
+
ffn_targets=False,
|
| 103 |
+
cosine_attention=False,
|
| 104 |
+
):
|
| 105 |
+
super().__init__()
|
| 106 |
+
|
| 107 |
+
self.layer_norm_first = layer_norm_first
|
| 108 |
+
self.ffn_targets = ffn_targets
|
| 109 |
+
|
| 110 |
+
from timm.models.vision_transformer import DropPath, Mlp
|
| 111 |
+
|
| 112 |
+
self.norm1 = norm_layer(dim)
|
| 113 |
+
self.attn = AltAttention(
|
| 114 |
+
dim,
|
| 115 |
+
num_heads=num_heads,
|
| 116 |
+
qkv_bias=qkv_bias,
|
| 117 |
+
qk_scale=qk_scale,
|
| 118 |
+
attn_drop=attn_drop,
|
| 119 |
+
proj_drop=drop,
|
| 120 |
+
cosine_attention=cosine_attention,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 124 |
+
self.norm2 = norm_layer(dim)
|
| 125 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 126 |
+
self.mlp = Mlp(
|
| 127 |
+
in_features=dim,
|
| 128 |
+
hidden_features=mlp_hidden_dim,
|
| 129 |
+
act_layer=act_layer,
|
| 130 |
+
drop=mlp_drop,
|
| 131 |
+
)
|
| 132 |
+
self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
|
| 133 |
+
|
| 134 |
+
def forward(self, x, padding_mask=None, alibi_bias=None):
|
| 135 |
+
if self.layer_norm_first:
|
| 136 |
+
x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
|
| 137 |
+
r = x = self.mlp(self.norm2(x))
|
| 138 |
+
t = x
|
| 139 |
+
x = r + self.drop_path(self.post_mlp_dropout(x))
|
| 140 |
+
if not self.ffn_targets:
|
| 141 |
+
t = x
|
| 142 |
+
else:
|
| 143 |
+
x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
|
| 144 |
+
r = x = self.norm1(x)
|
| 145 |
+
x = self.mlp(x)
|
| 146 |
+
t = x
|
| 147 |
+
x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
|
| 148 |
+
if not self.ffn_targets:
|
| 149 |
+
t = x
|
| 150 |
+
|
| 151 |
+
return x, t
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class AltAttention(nn.Module):
|
| 155 |
+
def __init__(
|
| 156 |
+
self,
|
| 157 |
+
dim,
|
| 158 |
+
num_heads=8,
|
| 159 |
+
qkv_bias=False,
|
| 160 |
+
qk_scale=None,
|
| 161 |
+
attn_drop=0.0,
|
| 162 |
+
proj_drop=0.0,
|
| 163 |
+
cosine_attention=False,
|
| 164 |
+
):
|
| 165 |
+
super().__init__()
|
| 166 |
+
self.num_heads = num_heads
|
| 167 |
+
head_dim = dim // num_heads
|
| 168 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 169 |
+
|
| 170 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 171 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 172 |
+
self.proj = nn.Linear(dim, dim)
|
| 173 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 174 |
+
|
| 175 |
+
self.cosine_attention = cosine_attention
|
| 176 |
+
|
| 177 |
+
if cosine_attention:
|
| 178 |
+
self.logit_scale = nn.Parameter(
|
| 179 |
+
torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def forward(self, x, padding_mask=None, alibi_bias=None):
|
| 183 |
+
B, N, C = x.shape
|
| 184 |
+
qkv = (
|
| 185 |
+
self.qkv(x)
|
| 186 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 187 |
+
.permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
|
| 188 |
+
)
|
| 189 |
+
q, k, v = (
|
| 190 |
+
qkv[0],
|
| 191 |
+
qkv[1],
|
| 192 |
+
qkv[2],
|
| 193 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
| 194 |
+
|
| 195 |
+
dtype = q.dtype
|
| 196 |
+
|
| 197 |
+
if self.cosine_attention:
|
| 198 |
+
# cosine attention
|
| 199 |
+
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
|
| 200 |
+
logit_scale = torch.clamp(
|
| 201 |
+
self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
|
| 202 |
+
).exp()
|
| 203 |
+
attn = attn * logit_scale
|
| 204 |
+
else:
|
| 205 |
+
q = q * self.scale
|
| 206 |
+
attn = q @ k.transpose(-2, -1)
|
| 207 |
+
|
| 208 |
+
if alibi_bias is not None:
|
| 209 |
+
attn = attn.type_as(alibi_bias)
|
| 210 |
+
attn[:, : alibi_bias.size(1)] += alibi_bias
|
| 211 |
+
|
| 212 |
+
if padding_mask is not None and padding_mask.any():
|
| 213 |
+
attn = attn.masked_fill(
|
| 214 |
+
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
| 215 |
+
float("-inf"),
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
|
| 219 |
+
attn = self.attn_drop(attn)
|
| 220 |
+
x = (attn @ v).transpose(1, 2) #
|
| 221 |
+
x = x.reshape(B, N, C)
|
| 222 |
+
x = self.proj(x)
|
| 223 |
+
x = self.proj_drop(x)
|
| 224 |
+
return x
|
modeling_eat.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# modeling_eat.py
|
| 2 |
+
|
| 3 |
+
from transformers import PreTrainedModel
|
| 4 |
+
from .configuration_eat import EATConfig
|
| 5 |
+
from .eat_model import EAT
|
| 6 |
+
|
| 7 |
+
class EATModel(PreTrainedModel):
|
| 8 |
+
config_class = EATConfig
|
| 9 |
+
|
| 10 |
+
def __init__(self, config: EATConfig):
|
| 11 |
+
super().__init__(config)
|
| 12 |
+
self.model = EAT(config)
|
| 13 |
+
|
| 14 |
+
def forward(self, *args, **kwargs):
|
| 15 |
+
return self.model(*args, **kwargs)
|
| 16 |
+
|
| 17 |
+
def extract_features(self, x):
|
| 18 |
+
return self.model.extract_features(x)
|