# ------------------------------------------------------------------------------
# This file includes code copied and adapted from DINO:
#   - DINO (https://github.com/facebookresearch/dino)
#
# ------------------------------------------------------------------------------
import math
import random
import warnings
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor


def make_2tuple(x):
    """Return ``x`` as a 2-tuple: a 2-tuple passes through, an int is duplicated."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x
    assert isinstance(x, int)
    return (x, x)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place from N(mean, std^2) truncated to [a, b]; return it."""
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        # Relies on the module-level `import warnings`.
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """In-place truncated-normal initialisation of ``tensor``; returns ``tensor``.

    Emits a ``UserWarning`` when ``mean`` lies more than 2 std outside [a, b].
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: zero entire samples of ``x`` with probability ``drop_prob``.

    Surviving samples are scaled by ``1 / (1 - drop_prob)`` so the expectation
    is unchanged. Returns ``x`` itself when not training or when ``drop_prob``
    is 0 or None.
    """
    # `not drop_prob` also covers None (DropPath's default), which would
    # otherwise fail at `1 - drop_prob` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcast over all remaining dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> act -> dropout -> fc2 -> dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over ``x`` of shape (B, N, C).

        Returns:
            ``(out, attn)`` with ``out`` of shape (B, N, C) and ``attn`` of shape
            (B, num_heads, N, N) (post-softmax, post-dropout).
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    """Pre-norm transformer block: x + DropPath(Attn(LN(x))), then x + DropPath(MLP(LN(x)))."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        """Run the block; if ``return_attention`` return only this block's attention map."""
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """Sensor-specific patch embedding.

    Three conv stems (S1, S2-L1C, S2-L2A) map patches to ``attn_dim`` channels.
    Each stem adds its own learned sensor embedding, and a shared
    ``TokenProjection`` maps the result to ``embed_dim``.
    """

    def __init__(
        self,
        img_size: int,
        embed_dim: int,
        patch_size: int,
        in_chans_s1: int,
        in_chans_s2: int,
    ):
        super().__init__()
        attn_dim = embed_dim * 3  # from Panopticon design
        self.img_size = img_size
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.num_patches = num_patches

        self.conv2d_s2_l2a = nn.Conv2d(in_chans_s2, attn_dim, kernel_size=patch_size, stride=patch_size)
        self.conv2d_s2_l1c = nn.Conv2d(in_chans_s2, attn_dim, kernel_size=patch_size, stride=patch_size)
        self.conv2d_s1 = nn.Conv2d(in_chans_s1, attn_dim, kernel_size=patch_size, stride=patch_size)

        self.projection = TokenProjection(embed_dim=embed_dim, attn_dim=attn_dim)

        self.s2_l2a_embed = nn.Parameter(torch.zeros(1, attn_dim))
        self.s2_l1c_embed = nn.Parameter(torch.zeros(1, attn_dim))
        self.s1_embed = nn.Parameter(torch.zeros(1, attn_dim))
        self.attn_dim = attn_dim

    def forward(self, x12: Tensor, is_l2a: bool = False) -> Tensor:
        """Embed a 4-D input (B, C, H, W) into (B, H'*W', embed_dim) tokens.

        A 2-channel input is routed to the S1 stem. Any other channel count goes
        to an S2 stem: L2A if ``is_l2a`` is true, otherwise L1C.
        """
        _, channels, _, _ = x12.shape  # also enforces a 4-D input
        if channels == 2:
            conv, sensor_embed = self.conv2d_s1, self.s1_embed
        elif is_l2a:
            conv, sensor_embed = self.conv2d_s2_l2a, self.s2_l2a_embed
        else:
            conv, sensor_embed = self.conv2d_s2_l1c, self.s2_l1c_embed
        # (B, attn_dim, H', W') -> (B, H'*W', attn_dim)
        x = conv(x12).flatten(2).transpose(1, 2)
        x = x + sensor_embed
        return self.projection(x)


class TokenProjection(nn.Module):
    """Project patch tokens from ``attn_dim`` to ``embed_dim``."""

    def __init__(self, embed_dim: int, attn_dim: int):
        super().__init__()
        self.proj1 = nn.Linear(attn_dim, attn_dim, bias=False)
        self.norm_input = nn.LayerNorm(attn_dim)
        self.proj2 = nn.Linear(attn_dim, attn_dim)
        self.proj3 = nn.Linear(attn_dim, embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply proj1 -> LayerNorm -> proj2 -> proj3 to every token.

        Per the original design notes, proj1 plays the role of the value
        projection in a cross-attention, and proj2 is the projection that in the
        multi-channel ("Case N") variant follows a weighted mean. No weighted
        mean is computed here; the four steps run sequentially.

        Args:
            x: Tokens of shape (B, N, attn_dim).

        Returns:
            Tensor of shape (B, N, embed_dim).
        """
        x = self.proj1(x)  # V in cross-attention
        x = self.norm_input(x)
        x = self.proj2(x)
        x = self.proj3(x)  # final projection
        return x


class TerraFM(nn.Module):
    """DINO-style Vision Transformer with the multi-sensor ``PatchEmbed`` stem.

    Args:
        img_size: Reference image side length that sizes ``pos_embed``. Either an
            int or a sequence whose first element is used.
        patch_size: Patch side length of the conv stems.
        in_chans: Unused; ``PatchEmbed`` is built with 2 (S1) and 12 (S2) channels.
        num_classes: If > 0, a linear ``head`` is created. ``forward`` does not
            apply it.
        embed_dim, depth, num_heads, mlp_ratio, qkv_bias, qk_scale: Transformer
            width, number of blocks, heads, MLP expansion and attention options.
        drop_rate, attn_drop_rate, drop_path_rate: Dropout rates. The stochastic
            depth rate rises linearly from 0 to ``drop_path_rate`` over the blocks.
        norm_layer: Normalization layer factory.
        **kwargs: Ignored.
    """

    def __init__(self, img_size=(224,), patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        if not isinstance(img_size, int):
            img_size = img_size[0]  # DINO convention: img_size is a 1-element sequence
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans_s1=2,
            in_chans_s2=12,
            embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        """Return ``pos_embed`` resized bicubically to the token grid of ``x``.

        ``w`` and ``h`` are the sizes of input dims 2 and 3 (as unpacked in
        ``prepare_tokens``). The patch grid stored in ``pos_embed`` is assumed
        square, since it is built from a single ``img_size``.
        """
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        """Patch-embed ``x`` (B, C, H, W), prepend [CLS], add positional encoding, apply dropout."""
        B, nc, w, h = x.shape
        # NOTE(review): `is_l2a` is never forwarded here, so S2 inputs always use
        # the L1C stem through this model's entry points — confirm this is intended.
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward_features(self, x):
        """Alias of ``forward``: returns the normalized [CLS] features."""
        return self.forward(x)

    def forward(self, x):
        """Return the normalized [CLS] token, shape (B, embed_dim). ``head`` is not applied."""
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        """Return the attention map of the last block, shape (B, heads, N, N)."""
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1, return_class_token=False, norm=False):
        """Collect the outputs of the last ``n`` blocks.

        Returns:
            A list of patch-token tensors (B, N, D), ordered from shallow to deep.
            With ``return_class_token``, a tuple of ``(patch_tokens, cls_token)``
            pairs instead. ``norm`` applies the final LayerNorm to each output first.
        """
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(x)
        if norm:
            output = [self.norm(out) for out in output]
        class_tokens = [out[:, 0] for out in output]
        output = [out[:, 1:] for out in output]
        if return_class_token:
            return tuple(zip(output, class_tokens))
        return output

    def extract_feature(self, images, return_h_w=True, out_indices=(3, 5, 7, 11)):
        """Return normalized patch-token feature maps from the blocks in ``out_indices``.

        Each map has shape (B, embed_dim, H // patch_size, W // patch_size),
        where H and W are ``images.shape[2]`` and ``images.shape[3]``.
        ``return_h_w`` is accepted for API compatibility but unused.
        """
        x = self.prepare_tokens(images)
        output = []
        patch_size = self.patch_embed.patch_size
        h, w = images.shape[2] // patch_size, images.shape[3] // patch_size
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in out_indices:
                out = self.norm(x[:, 1:])  # drop [CLS]
                B, _, C = out.shape
                # Tokens come from a row-major flatten of the conv grid, so they
                # reshape back to (B, h, w, C).
                out = out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()
                output.append(out)
        return output


def terrafm_base(patch_size=16, **kwargs):
    """ViT-B/16-sized TerraFM (768 dim, 12 blocks, 12 heads)."""
    model = TerraFM(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def terrafm_large(patch_size=16, **kwargs):
    """ViT-L-sized TerraFM (1024 dim, 24 blocks, 16 heads)."""
    # patch_size was previously hard-coded to 16 here, ignoring the argument.
    model = TerraFM(
        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model