import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.nn import Module
from torch.nn.modules.linear import Linear
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.normalization import LayerNorm
from torch.nn.init import xavier_uniform_, constant_
class TransformerEncoderRPR(Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoderRPR, self).__init__()
        # Fix for tracing; note this repeats the same layer instance, so all layers share parameters
        self.layers = torch.nn.ModuleList([encoder_layer for _ in range(num_layers)])
self.num_layers = num_layers
self.norm = norm
def forward(self, src, mask=None, src_key_padding_mask=None, **kwargs):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        if self.norm is not None:
output = self.norm(output)
return output
class TransformerEncoderLayerRPR(Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
super(TransformerEncoderLayerRPR, self).__init__()
self.self_attn = MultiheadAttentionRPR(d_model, nhead, dropout=dropout, er_len=er_len)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
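# Multi-head self-attention with an optional learned relative position embedding
# table Er of shape (er_len, head_dim), shared across heads (see the skew logic in
# multi_head_attention_forward_rpr below).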
class MultiheadAttentionRPR(Module):
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None):
super(MultiheadAttentionRPR, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
if bias:
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
self.add_zero_attn = add_zero_attn
if er_len is not None:
self.Er = Parameter(torch.rand((er_len, self.head_dim), dtype=torch.float32))
else:
self.Er = None
self._reset_parameters()
def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.)
constant_(self.out_proj.bias, 0.)
def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
return multi_head_attention_forward_rpr(
query, key, value, self.embed_dim, self.num_heads, self.head_dim,
self.in_proj_weight, self.in_proj_bias,
None, None, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=not self._qkv_same_embed_dim,
q_proj_weight=getattr(self, 'q_proj_weight', None),
k_proj_weight=getattr(self, 'k_proj_weight', None),
v_proj_weight=getattr(self, 'v_proj_weight', None),
rpr_mat=self.Er)
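# Standalone functional forward mirroring torch.nn.functional.multi_head_attention_forward,
# extended so that a relative-position score matrix built from rpr_mat is added to the
# attention logits before softmax. Several arguments (bias_k, bias_v, add_zero_attn,
# static_k, static_v) are accepted for signature compatibility but are not applied here.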
def multi_head_attention_forward_rpr(query, key, value, embed_dim_to_check, num_heads, head_dim,
in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn,
dropout_p, out_proj_weight, out_proj_bias, training=True,
key_padding_mask=None, need_weights=True, attn_mask=None,
use_separate_proj_weight=False, q_proj_weight=None,
k_proj_weight=None, v_proj_weight=None, static_k=None,
static_v=None, rpr_mat=None):
tgt_len, bsz, embed_dim = query.size()
scaling = float(head_dim) ** -0.5
if not use_separate_proj_weight:
q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
else:
        # Slice the packed bias per projection, guarding against bias=False (in_proj_bias is None)
        b_q, b_k, b_v = in_proj_bias.chunk(3) if in_proj_bias is not None else (None, None, None)
        q = F.linear(query, q_proj_weight, b_q)
        k = F.linear(key, k_proj_weight, b_k)
        v = F.linear(value, v_proj_weight, b_v)
q = q * scaling
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
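    # Relative positional attention (Music Transformer-style "skew"): Q is scored against
    # the learned relative embeddings Er, then shifted so that entry (i, j) ends up holding
    # the term for relative offset j - i. Requires er_len >= current sequence length
    # (start_idx must be non-negative).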
if rpr_mat is not None:
# Safe Explicit Skew
len_q = q.shape[1]
start_idx = rpr_mat.shape[0] - len_q
rpr_mat_valid = rpr_mat[start_idx:, :]
qe = torch.einsum("hld,md->hlm", q, rpr_mat_valid)
# Indices logic (Flatten -> Gather -> Reshape)
B, L, _ = qe.shape
# Mask out upper triangle BEFORE skewing
mask_tri = torch.triu(torch.ones((L, L), device=qe.device, dtype=torch.bool)).flip(0)
qe = qe.masked_fill(~mask_tri, 0.0) # Fill with 0 before shift
zeros = torch.zeros((B, L, 1), device=qe.device, dtype=qe.dtype)
qe_pad = torch.cat([zeros, qe], dim=2).view(B, -1)
offsets = torch.arange(L * L, device=qe.device, dtype=torch.int64) + L
offsets = offsets.unsqueeze(0).expand(B, -1)
srel = torch.gather(qe_pad, 1, offsets).view(B, L, L)
attn_output_weights = attn_output_weights + srel
# --- MASKING FIX (Boolean Masked Fill) ---
    if attn_mask is not None:
        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
        # ONNX prefers masked_fill with a boolean mask over adding -inf
        if attn_mask.dtype == torch.bool:
            is_causal_mask = attn_mask  # boolean convention: True marks positions to drop
        else:
            is_causal_mask = (attn_mask == float('-inf')) | (attn_mask < -1e4)
        attn_output_weights = attn_output_weights.masked_fill(is_causal_mask, float('-inf'))
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, tgt_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, tgt_len)
attn_output_weights = F.softmax(attn_output_weights, dim=-1)
attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
attn_output = torch.bmm(attn_output_weights, v)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
return attn_output, attn_output_weights.view(bsz, num_heads, tgt_len, tgt_len).sum(dim=1) / num_heads
return attn_output, None
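# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the dimensions and hyperparameters
# below are assumptions, not the released NanoMaestro configuration.
if __name__ == "__main__":
    d_model, nhead, seq_len, batch = 512, 8, 64, 2
    layer = TransformerEncoderLayerRPR(d_model, nhead, dim_feedforward=1024, dropout=0.1, er_len=seq_len)
    encoder = TransformerEncoderRPR(layer, num_layers=6, norm=LayerNorm(d_model))
    encoder.eval()  # disable dropout for a deterministic pass
    # Inputs are (seq_len, batch, d_model); the additive causal mask uses -inf above
    # the diagonal, matching the boolean conversion inside the attention forward.
    src = torch.rand(seq_len, batch, d_model)
    causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    out = encoder(src, mask=causal_mask)
    print(out.shape)  # torch.Size([64, 2, 512])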