import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    Simpler and cheaper than LayerNorm (no mean-centering, no bias) while
    remaining numerically stable. Used in LLaMA, PaLM, and Gopher.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
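

# For reference, the computation above is equivalent to:
#     RMSNorm(x) = weight * x / sqrt(mean(x**2, dim=-1) + eps)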


class SwiGLU(nn.Module):
    """
    Swish-Gated Linear Unit.
    Gated feed-forward activation (SiLU(W1 x) * W2 x) used in PaLM and LLaMA;
    it typically outperforms plain GELU/ReLU feed-forward layers.
    """
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Gate: SiLU(w1(x)) multiplies w2(x) elementwise, then project back.
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = F.silu(x1) * x2
        return self.w3(self.dropout(hidden))
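

# Usage sketch (illustrative values): SwiGLU expands only internally and
# returns the model dimension, so it is a drop-in FFN replacement:
#     ffn = SwiGLU(dim=64, hidden_dim=170)
#     y = ffn(torch.randn(8, 64))    # y.shape == (8, 64)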


class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block.
    Lets the model dynamically reweight individual embedding dimensions based
    on the whole embedding, via a bottleneck MLP with a sigmoid gate.
    """
    def __init__(self, dim: int, reduction: int = 4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, dim // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(dim // reduction, dim, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Inputs are (batch, dim), so there is no sequence axis to pool over
        # ("squeeze"); the excitation MLP directly produces per-dimension
        # gates in (0, 1).
        b, d = x.shape
        y = self.fc(x)
        return x * y
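

# Gate equation, for reference: out = x * sigmoid(W2 @ relu(W1 @ x)),
# where W1 maps dim -> dim // reduction and W2 maps back to dim.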


class DropPath(nn.Module):
    """Stochastic depth: randomly drops the whole residual branch per sample."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli(keep_prob) sample per example, broadcast over all
        # remaining dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor
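

# Note: dividing by keep_prob keeps the expected activation equal to the input
# during training, so no rescaling is needed at inference time.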


class ModernBlock(nn.Module):
    """
    A pre-norm residual block combining RMSNorm, SwiGLU, and channel attention
    (squeeze-and-excitation), with LayerScale and stochastic depth.
    """
    def __init__(self, dim: int, expand: int = 4, dropout: float = 0.1,
                 layer_scale_init: float = 1e-6, drop_path: float = 0.0):
        super().__init__()

        self.norm = RMSNorm(dim)

        # The 2/3 factor keeps the gated FFN's parameter count roughly equal
        # to a standard (non-gated) FFN with the same expansion ratio.
        self.ffn = SwiGLU(dim, int(dim * expand * 2 / 3), dropout=dropout)

        self.se = SEBlock(dim, reduction=4)

        # LayerScale: a small learnable per-channel scale on the residual branch.
        self.layer_scale = nn.Parameter(torch.ones(dim) * layer_scale_init) if layer_scale_init > 0 else None
        self.drop_path = DropPath(drop_path)

    def forward(self, x):
        residual = x

        out = self.norm(x)
        out = self.ffn(out)
        out = self.se(out)

        if self.layer_scale is not None:
            out = out * self.layer_scale

        out = self.drop_path(out)

        return residual + out
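

# Per-block data flow, for reference:
#     y = x + DropPath(layer_scale * SE(SwiGLU(RMSNorm(x))))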


class ModernTrajectoryNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.n_layers = config.n_layers

        dropout = getattr(config, "dropout", 0.1)
        expand = getattr(config, "expand", 4)
        drop_path_rate = getattr(config, "drop_path_rate", 0.1)

        self.input_proj = nn.Sequential(
            RMSNorm(self.d_model),
            nn.Linear(self.d_model, self.d_model)
        )

        # Stochastic depth increases linearly with depth: 0 at the first block,
        # drop_path_rate at the last. The max() guards the single-layer case.
        self.blocks = nn.ModuleList([
            ModernBlock(
                dim=self.d_model,
                expand=expand,
                dropout=dropout,
                drop_path=drop_path_rate * (i / max(self.n_layers - 1, 1))
            ) for i in range(self.n_layers)
        ])

        self.final_norm = RMSNorm(self.d_model)

        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Linear(self.d_model, self.d_model)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, return_trajectory=False):
        # Accept (batch, seq, dim) inputs by mean-pooling over the sequence axis.
        if x.dim() == 3:
            x = x.mean(dim=1)

        x = self.input_proj(x)

        # Record the hidden state after every block so the layer-by-layer
        # "trajectory" can be returned for analysis.
        trajectory = []
        for block in self.blocks:
            x = block(x)
            trajectory.append(x)

        x = self.final_norm(x)

        output = self.head(x)

        if return_trajectory:
            # trajectory has shape (batch, n_layers, d_model).
            return output, torch.stack(trajectory, dim=1)

        return output


# Exported alias. Note: despite the name, the model contains no Mamba or
# attention layers; it is the feed-forward ModernBlock stack defined above.
HybridMambaAttentionModel = ModernTrajectoryNet
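

# Minimal usage sketch. The constructor reads config.d_model and config.n_layers
# directly and falls back to getattr() defaults for the rest, so any attribute
# container such as types.SimpleNamespace should work; the values below are
# arbitrary and only illustrate the expected shapes.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(d_model=64, n_layers=4, dropout=0.1,
                          expand=4, drop_path_rate=0.1)
    model = ModernTrajectoryNet(cfg)

    x = torch.randn(8, 16, 64)                 # (batch, seq, d_model)
    out, traj = model(x, return_trajectory=True)
    print(out.shape)                           # torch.Size([8, 64])
    print(traj.shape)                          # torch.Size([8, 4, 64])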