SSLAM_pretrain / eat_model.py
ta012's picture
Upload 7 files
501ebee verified
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from functools import partial
import numpy as np
from .model_core import (
PatchEmbed_new,
get_2d_sincos_pos_embed_flexible,
FixedPositionalEncoder,
AltBlock
)
class EAT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.mode = config.model_variant # "pretrain" or "finetune"
# === Embedding / Encoder ===
self.local_encoder = PatchEmbed_new(
img_size=config.img_size,
patch_size=config.patch_size,
in_chans=config.in_chans,
embed_dim=config.embed_dim,
stride=config.stride
)
self.extra_tokens = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
self.pos_drop = nn.Dropout(p=config.drop_rate, inplace=True)
trunc_normal_(self.extra_tokens, std=.02)
self.fixed_positional_encoder = (
FixedPositionalEncoder(self.build_sincos_pos_embed()) if config.fixed_positions else None
)
norm_layer = partial(nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine)
dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)
self.blocks = nn.ModuleList([
AltBlock(config.embed_dim, config.num_heads, config.mlp_ratio,
qkv_bias=config.qkv_bias, drop=config.drop_rate,
attn_drop=config.attn_drop_rate, mlp_drop=config.activation_dropout,
post_mlp_drop=config.post_mlp_drop, drop_path=dpr[i],
norm_layer=norm_layer, layer_norm_first=config.layer_norm_first,
ffn_targets=True)
for i in range(config.depth)
])
self.pre_norm = norm_layer(config.embed_dim)
# === Head (for finetune) ===
if self.mode == "finetune":
self.fc_norm = nn.LayerNorm(config.embed_dim)
self.head = nn.Linear(config.embed_dim, config.num_classes, bias=True)
else:
self.head = nn.Identity()
self.apply(self._init_weights)
def build_sincos_pos_embed(self):
W = self.config.mel_bins // self.config.patch_size
max_length = self.config.max_length
embed_dim = self.config.embed_dim
pos_embed = nn.Parameter(torch.zeros(1, max_length * W, embed_dim), requires_grad=False)
emb = get_2d_sincos_pos_embed_flexible(embed_dim, (max_length, W), cls_token=False)
pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
return pos_embed
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def encode(self, x):
B = x.shape[0]
x = self.local_encoder(x)
if self.fixed_positional_encoder is not None:
x = x + self.fixed_positional_encoder(x, None)[:, :x.size(1), :]
x = torch.cat((self.extra_tokens.expand(B, -1, -1), x), dim=1)
x = self.pre_norm(x)
x = self.pos_drop(x)
for blk in self.blocks:
x, _ = blk(x)
return x
def forward(self, x):
x = self.encode(x)
if self.mode == "finetune":
x = x[:, 0] # use cls token
x = self.fc_norm(x)
x = self.head(x)
return x
def extract_features(self, x):
x = self.encode(x)
return x