|
|
from typing import Dict, List, Optional, Tuple, Callable, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint
from torch.utils.cpp_extension import load
|
|
|
|
|
# Machine epsilon of float32 (2**-23, about 1.19e-7). `norm` passes it to
# torch.rms_norm as the stabilizer, independent of the input tensor's dtype.
eps = torch.finfo(torch.float32).eps
|
|
|
|
|
def norm(x: torch.Tensor):
    """RMS-normalize ``x`` over its last axis, without a learnable scale.

    The stabilizing epsilon is the module-level float32 ``eps``.
    """
    normalized_shape = (x.shape[-1],)
    return torch.rms_norm(x, normalized_shape, eps=eps)
|
|
|
|
|
class Rotary(nn.Module):
    """Rotary position embedding (RoPE) backed by a precomputed cos/sin table.

    The table has ``dim // 2`` frequency columns. The first ``dim // 4`` are
    geometrically spaced from 1 down to 1/1024 radians per position. The
    remaining ``dim // 4`` are zero, so the channel pairs they address pass
    through unrotated.

    Args:
        dim: Size of the last (feature) axis of inputs to ``forward``. Must be
            a positive multiple of 4 so that the table width matches each of
            the two halves the input is split into.
        max_seq_len: Number of positions precomputed. ``forward`` rejects
            inputs whose sequence axis (axis -3) is longer.

    Raises:
        ValueError: If ``dim`` is not a positive multiple of 4.
    """

    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # Any other dim either fails later with a broadcasting error or
        # (for dim < 4) silently produces an empty feature axis.
        if dim <= 0 or dim % 4 != 0:
            raise ValueError(f"Rotary dim must be a positive multiple of 4, got {dim}")

        quarter = dim // 4
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=quarter, dtype=torch.float32)
        # Zero-frequency half: these channel pairs get cos=1, sin=0 (identity).
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(quarter)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)  # (max_seq_len, dim // 2)

        # Non-persistent: derived from dim/max_seq_len, so not saved in state_dict.
        self.register_buffer("cos", theta.cos(), persistent=False)
        self.register_buffer("sin", theta.sin(), persistent=False)

    def forward(self, x_BTHD: torch.Tensor) -> torch.Tensor:
        """Rotate the feature axis of ``x_BTHD`` according to position on axis -3.

        The input is laid out as ``(..., T, H, D)`` with ``D == dim``. For
        inputs of rank >= 4 the result has the same shape. It is cast back
        to the input dtype after a float32 computation.
        """
        seq_len = x_BTHD.size(-3)
        assert self.cos.size(0) >= seq_len, (
            f"sequence length {seq_len} exceeds rotary table size {self.cos.size(0)}"
        )
        cos = self.cos[None, :seq_len, None, :]
        sin = self.sin[None, :seq_len, None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        # Concatenate on the feature axis via a negative index. A fixed dim=3
        # is only the feature axis for 4-D inputs.
        return torch.cat((y1, y2), dim=-1).type_as(x_BTHD)
|
|
|
|
|
class CausalSoftmaxAttention(nn.Module):
    """Causal multi-head self-attention with QK RMS-norm, rotary embeddings
    and a sigmoid output gate.

    The input is RMS-normalized inside this sublayer, so callers pass the raw
    residual stream. The output projection is zero-initialized and bias-free,
    so a freshly constructed module outputs zeros.

    Args:
        layer_id: Index of this layer; stored as ``self.layer_id``.
        layers, vocab_size, hidden_dims: Not used by this class; accepted to
            keep the constructor signature identical to ``MLP``'s.
        num_heads: Number of attention heads; must divide ``input_dims``.
        input_dims: Model width ``C``.
        max_seq_len: Longest sequence the rotary table supports. Defaults to
            2048, the previously hard-coded value.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
        max_seq_len: int = 2048,
    ):
        super().__init__()

        # Validate before deriving the per-head width from it.
        assert input_dims % num_heads == 0

        self.layer_id = layer_id
        self.num_heads = num_heads
        self.head_dim = input_dims // num_heads

        N = self.head_dim
        C = input_dims

        with torch.no_grad():
            init_bounds = 0.5 / (C ** 0.5)

            # All projections are created before any is re-initialized, so the
            # RNG draw order is: default Linear inits, then the uniform_ calls.
            self.q_proj = nn.Linear(C, C, bias=False)
            self.k_proj = nn.Linear(C, C, bias=False)
            self.v_proj = nn.Linear(C, C, bias=False)
            self.g_proj = nn.Linear(C, C, bias=False)
            self.o_proj = nn.Linear(C, C, bias=False)

            self.rotary = Rotary(N, max_seq_len)

            for proj in (self.q_proj, self.k_proj, self.v_proj, self.g_proj):
                proj.weight.data.uniform_(-init_bounds, init_bounds)
            # Zero output projection: the sublayer contributes nothing at init.
            self.o_proj.weight.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape ``(B, T, C)``; returns ``(B, T, C)``."""
        B, T, C = x.size()
        H = self.num_heads
        N = C // H

        def project(inp):
            # Pre-norm, then per-head projections (B, T, H, N) and a (B, T, C) gate.
            h = norm(inp)
            q = self.q_proj(h).view(B, T, H, N)
            k = self.k_proj(h).view(B, T, H, N)
            v = self.v_proj(h).view(B, T, H, N)
            g = self.g_proj(h).sigmoid()
            # RMS-normalizing q and k per head bounds their norms, and with
            # them the attention logits; the rotation preserves those norms.
            q, k = norm(q), norm(k)
            q, k = self.rotary(q), self.rotary(k)
            return q, k, v, g

        # Activation checkpointing: the projections are recomputed during
        # backward instead of keeping their activations alive.
        q, k, v, g = torch.utils.checkpoint.checkpoint(project, x, use_reentrant=False)

        # scaled_dot_product_attention works on (B, H, T, N).
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        return self.o_proj(y * g)
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward sublayer: RMS-norm, up-projection, squared ReLU, down-projection.

    Args:
        layer_id: Index of this layer; stored as ``self.layer_id``.
        layers, num_heads, vocab_size: Not used by this class; accepted to keep
            the constructor signature identical to ``CausalSoftmaxAttention``'s.
        input_dims: Model width ``C``.
        hidden_dims: Inner width. Any falsy value (``None`` or ``0``) selects
            ``4 * C``.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id

        width = input_dims
        inner = hidden_dims if hidden_dims else 4 * width

        with torch.no_grad():
            bound = 0.5 / (width ** 0.5)
            self.k_proj = nn.Linear(width, inner, bias=False)
            self.v_proj = nn.Linear(inner, width, bias=False)
            self.k_proj.weight.data.uniform_(-bound, bound)
            # Zero down-projection: the sublayer outputs zeros at initialization.
            self.v_proj.weight.data.zero_()

    def forward(self, x):
        """Apply the feed-forward sublayer to ``x`` of shape ``(B, T, C)``."""
        # The unpack rejects non-3-D input; the sizes themselves are unused.
        _batch, _seq, _width = x.shape

        def _ffn(inp):
            hidden = torch.relu(self.k_proj(norm(inp))).square()
            return self.v_proj(hidden)

        # Recompute the hidden activations during backward instead of storing them.
        return torch.utils.checkpoint.checkpoint(_ffn, x, use_reentrant=False)
|
|
|
|
|
class SoftmaxBlock(nn.Module):
    """One residual layer: an attention sublayer followed by a feed-forward sublayer.

    Each sublayer's output is added to the running residual stream. Both
    sublayers receive the same constructor arguments.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id

        sublayer_args = (layer_id, layers, num_heads, vocab_size, input_dims, hidden_dims)
        self.att = CausalSoftmaxAttention(*sublayer_args)
        self.ffn = MLP(*sublayer_args)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        x = x + self.att(x)
        return x + self.ffn(x)
|
|
|
|
|
class Transformer(nn.Module):
    """Decoder-only language model built from ``SoftmaxBlock`` layers.

    Token embeddings are RMS-normalized before the first block and again after
    the last one. Logits are produced with the embedding matrix itself (weight
    tying), so the output's last axis has size ``vocab_size``.

    Args:
        layers: Number of ``SoftmaxBlock`` layers.
        num_heads: Attention heads per layer.
        vocab_size: Number of token ids.
        input_dims: Model width.
        hidden_dims: Forwarded unchanged to every block.
        dtype: NOTE(review): accepted but not used anywhere in this class.
    """

    def __init__(
        self,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
        dtype=None,
    ):
        super().__init__()

        self.emb = nn.Embedding(vocab_size, input_dims)
        # Tiny uniform init; the embedding is also the output projection.
        self.emb.weight.data.uniform_(-1e-4, 1e-4)

        self.blocks = nn.ModuleList(
            SoftmaxBlock(layer_idx, layers, num_heads, vocab_size, input_dims, hidden_dims)
            for layer_idx in range(layers)
        )

    def forward(self, idx):
        """Map token ids ``idx`` to logits over the vocabulary."""
        hidden = norm(self.emb(idx))
        for block in self.blocks:
            hidden = block(hidden)
        return F.linear(norm(hidden), self.emb.weight)