import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from timm.models.layers import to_2tuple


class PatchEmbed_new(nn.Module):
    """Flexible Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=16):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size

        # with overlapped patches when stride < patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, H'*W', embed_dim)
        return x


def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """
    grid_size: (height, width) tuple of the patch grid
    return:
    pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [1+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class FixedPositionalEncoder(nn.Module):
    def __init__(self, pos_embed):
        super().__init__()
        self.positions = pos_embed

    def forward(self, x, padding_mask):
        return self.positions


class AltBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None):
        # Returns the block output x and the layer target t (the FFN output when
        # ffn_targets is set, otherwise the end-of-block output).
        if self.layer_norm_first:
            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            r = x = self.mlp(self.norm2(x))
            t = x
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x

        return x, t


class AltAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention

        if cosine_attention:
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, x, padding_mask=None, alibi_bias=None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)  # qkv x B x H x L x D
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        dtype = q.dtype

        if self.cosine_attention:
            # cosine attention
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            attn = attn * logit_scale
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

        if alibi_bias is not None:
            attn = attn.type_as(alibi_bias)
            attn[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            attn = attn.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
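

# ---------------------------------------------------------------------------
# Minimal usage sketch of the modules above, assuming a single-channel
# 1024x128 spectrogram input with 16x16 non-overlapping patches. These shapes
# and the num_heads/embed_dim values are illustrative assumptions, not values
# prescribed by this file; adjust them to match the actual data pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embed_dim = 768
    img_size = (1024, 128)  # assumed (time, freq) spectrogram size
    patch_size = 16

    # Patchify: stride == patch_size here, so patches do not overlap.
    patch_embed = PatchEmbed_new(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=1,
        embed_dim=embed_dim,
        stride=patch_size,
    )
    x = torch.randn(2, 1, img_size[0], img_size[1])
    tokens = patch_embed(x)  # (2, 64*8, 768) for the assumed shapes
    print("tokens:", tokens.shape)

    # Fixed 2-D sin-cos positional embedding over the patch grid.
    grid_size = (img_size[0] // patch_size, img_size[1] // patch_size)
    pos = get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False)
    pos = torch.from_numpy(pos).float().unsqueeze(0)  # (1, num_patches, 768)
    pos_encoder = FixedPositionalEncoder(pos)
    tokens = tokens + pos_encoder(tokens, padding_mask=None)

    # One pre-norm transformer block; it returns both the block output and the
    # layer target used by data2vec-style teachers.
    block = AltBlock(dim=embed_dim, num_heads=12, mlp_ratio=4.0, qkv_bias=True)
    out, target = block(tokens)
    print("block out:", out.shape, "target:", target.shape)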