import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
    Computationally cheaper than LayerNorm (no mean-centering or bias) with comparable stability.
Used in LLaMA, PaLM, Gopher.
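
    Computes x / sqrt(mean(x^2, dim=-1) + eps) * weight.

    Example (illustrative; dims are arbitrary):
        >>> norm = RMSNorm(64)
        >>> norm(torch.randn(2, 64)).shape
        torch.Size([2, 64])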
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
        output = self._norm(x.float()).type_as(x)  # normalize in fp32 for stability, then cast back
return output * self.weight
class SwiGLU(nn.Module):
"""
Swish-Gated Linear Unit.
    A GLU variant that tends to outperform GELU/ReLU in transformer FFNs
    (Shazeer, 2020); used in LLaMA and PaLM.
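
    Example (illustrative; dims are arbitrary):
        >>> ffn = SwiGLU(dim=64, hidden_dim=128)
        >>> ffn(torch.randn(2, 64)).shape
        torch.Size([2, 64])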
"""
def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(hidden_dim, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
        # SwiGLU gate: SiLU(w1(x)) * w2(x), where SiLU(z) = z * sigmoid(z)
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = F.silu(x1) * x2
return self.w3(self.dropout(hidden))
class SEBlock(nn.Module):
"""
Squeeze-and-Excitation Block.
    Reweights embedding dimensions with a learned gate conditioned on the
    input vector itself (channel attention for flat [B, D] features).
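
    Example (illustrative):
        >>> se = SEBlock(dim=64)
        >>> se(torch.randn(2, 64)).shape
        torch.Size([2, 64])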
"""
def __init__(self, dim: int, reduction: int = 4):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(dim, dim // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(dim // reduction, dim, bias=False),
nn.Sigmoid()
)
def forward(self, x):
        # The classic SE block pools over spatial dims first; the input here is
        # already a flat [B, D] vector, so the squeeze step is a no-op and we
        # gate channels directly.
        y = self.fc(x)  # [B, D], per-channel gates in (0, 1)
        return x * y
class DropPath(nn.Module):
"""Stochastic depth regularizer (Improved)."""
def __init__(self, drop_prob: float = 0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1.0 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # Bernoulli mask: keep_prob + U[0, 1) floored yields 1 with
        # probability keep_prob, else 0.
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class ModernBlock(nn.Module):
"""
A Pre-Norm Block combining RMSNorm, SwiGLU, and Channel Attention.
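
    Example (illustrative):
        >>> block = ModernBlock(dim=64)
        >>> block(torch.randn(2, 64)).shape
        torch.Size([2, 64])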
"""
def __init__(self, dim: int, expand: int = 4, dropout: float = 0.1,
layer_scale_init: float = 1e-6, drop_path: float = 0.0):
super().__init__()
# 1. Normalization
self.norm = RMSNorm(dim)
        # 2. Feed Forward (SwiGLU)
        # SwiGLU carries three weight matrices, so the hidden dim is scaled by
        # 2/3 to roughly match the parameter count of a standard 2-matrix MLP.
        self.ffn = SwiGLU(dim, int(dim * expand * 2 / 3), dropout=dropout)
# 3. Channel Attention (Context awareness)
self.se = SEBlock(dim, reduction=4)
        # 4. Regularization: LayerScale (tiny per-channel residual scaling) + stochastic depth
        self.layer_scale = nn.Parameter(torch.ones(dim) * layer_scale_init) if layer_scale_init > 0 else None
self.drop_path = DropPath(drop_path)
def forward(self, x):
residual = x
# Pre-Norm Architecture
out = self.norm(x)
out = self.ffn(out)
out = self.se(out) # Apply attention
if self.layer_scale is not None:
out = out * self.layer_scale
out = self.drop_path(out)
return residual + out
class ModernTrajectoryNet(nn.Module):
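    """
    Residual MLP backbone: input projection -> stack of ModernBlocks ->
    RMSNorm -> two-layer projection head (SimCLR / CLIP-style projector).

    Expects a `config` exposing `d_model` and `n_layers`; `dropout`, `expand`,
    and `drop_path_rate` are read with defaults when absent.
    """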
def __init__(self, config):
super().__init__()
self.d_model = config.d_model
self.n_layers = config.n_layers
# Config defaults
dropout = getattr(config, "dropout", 0.1)
expand = getattr(config, "expand", 4)
drop_path_rate = getattr(config, "drop_path_rate", 0.1)
# Input Projection (Projects to latent space)
self.input_proj = nn.Sequential(
RMSNorm(self.d_model),
nn.Linear(self.d_model, self.d_model)
)
# Backbone
self.blocks = nn.ModuleList([
ModernBlock(
dim=self.d_model,
expand=expand,
dropout=dropout,
                drop_path=drop_path_rate * i / max(self.n_layers - 1, 1)  # increases linearly with depth
) for i in range(self.n_layers)
])
self.final_norm = RMSNorm(self.d_model)
# Projector Head (SimCLR / CLIP style)
# Important: Keep high dimension for the final linear probe
self.head = nn.Sequential(
nn.Linear(self.d_model, self.d_model),
nn.GELU(),
nn.Linear(self.d_model, self.d_model)
)
self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Defensive: the model itself uses RMSNorm, whose weight already
            # defaults to ones, so this branch is normally not hit.
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
def forward(self, x, return_trajectory=False):
        # Mean-pool the sequence dimension if present: [B, T, D] -> [B, D]
        if x.dim() == 3:
            x = x.mean(dim=1)
x = self.input_proj(x)
trajectory = []
for block in self.blocks:
x = block(x)
trajectory.append(x)
x = self.final_norm(x)
        # Residual connections run inside the blocks; the projection head alone
        # dictates the final shift used for trajectory learning.
output = self.head(x)
# OPTIONAL: Add Denoising / Residual connection to input
# output = output + input_tensor_if_saved
if return_trajectory:
return output, torch.stack(trajectory, dim=1)
return output
# Backwards compatibility
HybridMambaAttentionModel = ModernTrajectoryNet
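

# Minimal smoke test (illustrative sketch, not part of the training pipeline).
# Assumes a plain namespace config with the fields read above; the project's
# real config class is defined elsewhere.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(d_model=64, n_layers=4, dropout=0.1, expand=4,
                          drop_path_rate=0.1)
    model = ModernTrajectoryNet(cfg)

    x = torch.randn(2, 10, 64)  # [batch, seq, dim]; sequence is mean-pooled
    out, traj = model(x, return_trajectory=True)
    print(out.shape)   # torch.Size([2, 64])
    print(traj.shape)  # torch.Size([2, 4, 64])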