Text Generation
Safetensors
English
Chinese
xTimeCrystal committed on
Commit 43433f9 · verified · 1 Parent(s): 2e9c550

Upload model.py

Files changed (1)
  1. model.py +190 -0
model.py ADDED
@@ -0,0 +1,190 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ from torch.utils.cpp_extension import load
+ from typing import Dict, List, Optional, Tuple, Callable, Union
+
+ eps = torch.finfo(torch.float32).eps
+
+ def norm(x: torch.Tensor):
+     return torch.rms_norm(x, (x.size(-1),), eps=eps)
+
+ class Rotary(nn.Module):
+     def __init__(self, dim: int, max_seq_len: int):
+         super().__init__()
+         # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
+         angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
+         angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
+         t = torch.arange(max_seq_len, dtype=torch.float32)
+         theta = torch.einsum("i,j -> ij", t, angular_freq)
+         self.cos = nn.Buffer(theta.cos(), persistent=False)
+         self.sin = nn.Buffer(theta.sin(), persistent=False)
+
+     def forward(self, x_BTHD: torch.Tensor):
+         assert self.cos.size(0) >= x_BTHD.size(-3)
+         cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
+         x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
+         y1 = x1 * cos + x2 * sin
+         y2 = x1 * (-sin) + x2 * cos
+         return torch.cat((y1, y2), 3).type_as(x_BTHD)
+
+ class CausalSoftmaxAttention(nn.Module):
+     def __init__(
+         self,
+         layer_id: int,
+         layers: int,
+         num_heads: int,
+         vocab_size: int,
+         input_dims: int,
+         hidden_dims: Union[int, None] = None,
+     ):
+         super().__init__()
+
+         self.layer_id = layer_id
+         self.head_dim = input_dims // num_heads
+         self.num_heads = num_heads
+         assert input_dims % self.num_heads == 0
+
+         H = self.num_heads
+         N = self.head_dim
+         C = input_dims
+
+         with torch.no_grad():
+             init_bounds = 0.5 / (C ** 0.5)
+
+         self.q_proj = nn.Linear(C, C, bias=False)
+         self.k_proj = nn.Linear(C, C, bias=False)
+         self.v_proj = nn.Linear(C, C, bias=False)
+         self.g_proj = nn.Linear(C, C, bias=False)
+         self.o_proj = nn.Linear(C, C, bias=False)
+
+         self.rotary = Rotary(N, 2048)
+
+         self.q_proj.weight.data.uniform_(-init_bounds, init_bounds)
+         self.k_proj.weight.data.uniform_(-init_bounds, init_bounds)
+         self.v_proj.weight.data.uniform_(-init_bounds, init_bounds)
+         self.g_proj.weight.data.uniform_(-init_bounds, init_bounds)
+         self.o_proj.weight.data.zero_()
+
+     def forward(self, x):
+         B, T, C = x.size()
+         H = self.num_heads
+         N = C // H
+
+         def forward1(x):
+             x = norm(x)
+
+             q = self.q_proj(x).view(B, T, H, N)
+             k = self.k_proj(x).view(B, T, H, N)
+             v = self.v_proj(x).view(B, T, H, N)
+             g = self.g_proj(x).sigmoid()
+
+             q, k = norm(q), norm(k)
+             q, k = self.rotary(q), self.rotary(k)
+
+             return (q, k, v, g)
+
+         (q, k, v, g) = torch.utils.checkpoint.checkpoint(forward1, x, use_reentrant=False)
+
+         x = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True).transpose(1, 2).contiguous().view(B, T, C)
+
+         x = self.o_proj(x * g)
+
+         return x
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         layer_id: int,
+         layers: int,
+         num_heads: int,
+         vocab_size: int,
+         input_dims: int,
+         hidden_dims: Union[int, None] = None,
+     ):
+         super().__init__()
+
+         self.layer_id = layer_id
+
+         C = input_dims
+         hidden_dims = hidden_dims or 4 * C
+
+         with torch.no_grad():
+             init_bounds = 0.5 / (C ** 0.5)
+
+         self.k_proj = nn.Linear(C, hidden_dims, bias=False)
+         self.v_proj = nn.Linear(hidden_dims, C, bias=False)
+
+         self.k_proj.weight.data.uniform_(-init_bounds, init_bounds)
+         self.v_proj.weight.data.zero_()
+
+     def forward(self, x):
+         B, T, C = x.size()
+
+         def forward1(x):
+             x = norm(x)
+
+             k = torch.relu(self.k_proj(x)).square()
+
+             return self.v_proj(k)
+
+         output = torch.utils.checkpoint.checkpoint(forward1, x, use_reentrant=False)
+
+         return output
+
+ class SoftmaxBlock(nn.Module):
+     def __init__(
+         self,
+         layer_id: int,
+         layers: int,
+         num_heads: int,
+         vocab_size: int,
+         input_dims: int,
+         hidden_dims: Union[int, None] = None,
+     ):
+         super().__init__()
+         self.layer_id = layer_id
+
+         self.att = CausalSoftmaxAttention(layer_id, layers, num_heads, vocab_size, input_dims, hidden_dims)
+         self.ffn = MLP(layer_id, layers, num_heads, vocab_size, input_dims, hidden_dims)
+
+     def forward(self, x):
+         xx = self.att(x)
+         x = x + xx
+
+         xx = self.ffn(x)
+         x = x + xx
+
+         return x
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         layers: int,
+         num_heads: int,
+         vocab_size: int,
+         input_dims: int,
+         hidden_dims: Union[int, None] = None,
+         dtype = None
+     ):
+         super().__init__()
+
+         self.emb = nn.Embedding(vocab_size, input_dims)
+         self.emb.weight.data.uniform_(-1e-4, 1e-4)
+
+         self.blocks = nn.ModuleList([SoftmaxBlock(i, layers, num_heads, vocab_size, input_dims, hidden_dims) for i in range(layers)])
+
+     def forward(self, idx):
+
+         x = norm(self.emb(idx))
+
+         for i, block in enumerate(self.blocks):
+             x = block(x)
+
+         x = norm(x)
+
+         logits = F.linear(x, self.emb.weight)
+
+         return logits
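
For reference, a minimal usage sketch (not part of the commit): it instantiates the uploaded Transformer and runs a forward pass plus a next-token loss. The hyperparameters (layers=12, num_heads=8, vocab_size=50304, input_dims=512) are illustrative assumptions, not the checkpoint's actual configuration; it also assumes a recent PyTorch (torch.rms_norm and nn.Buffer) and that Rotary's hard-coded max_seq_len of 2048 is not exceeded.

import torch
import torch.nn.functional as F
from model import Transformer

# Hypothetical hyperparameters for illustration only; the published
# checkpoint's configuration may differ.
model = Transformer(layers=12, num_heads=8, vocab_size=50304, input_dims=512)

idx = torch.randint(0, 50304, (2, 128))   # (batch, seq) token ids; seq must be <= 2048
logits = model(idx)                        # (2, 128, 50304); output head is tied to the embedding

# Standard next-token language-modeling loss over the shifted sequence.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    idx[:, 1:].reshape(-1),
)

The sketch only exercises the forward pass; note that both the gated attention and the ReLU-squared MLP wrap their projections in torch.utils.checkpoint, so during training activation memory is traded for recomputation in the backward pass.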