Mohamed Mekkouri committed
Commit bbb69f0 · 1 Parent(s): 0ee88ce

add README

Files changed (1)
  1. README.md +238 -3
README.md CHANGED
---
tags:
- kernels
- gptoss
---

# gptoss_kernels

Metal kernels that back the OpenAI GPT-OSS reference implementation, repackaged for local experiments on Apple Silicon GPUs. The GPT-OSS project distributes optimized inference primitives for the `gpt-oss-20b` and `gpt-oss-120b` open-weight models, including MXFP4-packed linear layers and fused attention paths that target Metal Performance Shaders on macOS [gpt-oss](https://github.com/openai/gpt-oss).

## Installation

```bash
pip install kernels  # we just need to install the kernels package
```

The package exposes Python bindings through `gptoss_kernels.ops`; these symbols are re-exported in `gptoss_kernels.__init__` for convenience. All kernels expect Metal (`mps`) tensors and operate in place on user-provided outputs to minimize additional allocations.

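A quick sanity check after installing: load the build from the Hub and list the symbols it exposes. The MPS check and the name filter below are conveniences for this snippet, not part of the package API.

```python
import torch
from kernels import get_kernel

# These kernels target the Metal (mps) backend on Apple Silicon
assert torch.backends.mps.is_available(), "an mps device is required"

gptoss_kernels = get_kernel("kernels-community/gptoss_kernels")

# List the ops exposed by the build (simple name filter, illustration only)
ops = [name for name in dir(gptoss_kernels) if name.startswith(("f32_", "bf16_", "expert_"))]
print(ops)
```
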
## Available Ops

- `f32_bf16w_matmul`, `f32_bf16w_matmul_add`
- `f32_bf16w_dense_matmul_qkv`, `f32_bf16w_dense_matmul_attn_output`, `f32_bf16w_dense_matmul_mlp_gate`
- `f32_bf16w_rmsnorm`
- `bf16_f32_embeddings`
- `f32_rope`
- `f32_bf16w_matmul_qkv`
- `f32_sdpa`
- `f32_topk`, `expert_routing_metadata`, `f32_scatter`

For implementation details, inspect the `.metal` shader files.

## Usage & Consistency Checks

Each example below compares a Metal kernel against the canonical PyTorch equivalent using shared random inputs. The snippets assume an Apple Silicon machine with an `mps` device and that `kernels` is installed in the active environment.

### 1. BF16-weight matmul vs PyTorch `matmul`

```python
import torch
from kernels import get_kernel

gptoss_kernels = get_kernel("kernels-community/gptoss_kernels")

torch.manual_seed(0)
device = "mps"
batch, rows, cols = 2, 128, 1024

# float32 activations, bfloat16 weights and bias
activations = torch.randn(batch, rows, device=device, dtype=torch.float32)
weights = torch.randn(rows, cols, device=device, dtype=torch.bfloat16)
bias = torch.zeros(cols, device=device, dtype=torch.bfloat16)

# PyTorch reference, computed in float32
out_ref = activations @ weights.float() + bias.float()

# the kernel writes into a preallocated float32 output
out_kernel = torch.empty(batch, cols, device=device, dtype=torch.float32)
gptoss_kernels.f32_bf16w_matmul(
    activations,
    weights,
    bias,
    out_kernel,
    num_tokens=batch,
    num_cols=rows,
    num_rows=cols,
    threadgroup_size=32,
)

print(out_kernel)
print(out_ref)

torch.testing.assert_close(out_kernel, out_ref, atol=1e-3, rtol=1e-3)
```
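
The check above uses a zero bias, so the bias path is not exercised. A variant with a random bias, reusing the same shapes and argument layout, covers it (tolerances copied from the snippet above):

```python
# Same shapes and argument layout as above, but with a non-zero bias
bias = torch.randn(cols, device=device, dtype=torch.bfloat16)
out_ref = activations @ weights.float() + bias.float()

out_kernel = torch.empty(batch, cols, device=device, dtype=torch.float32)
gptoss_kernels.f32_bf16w_matmul(
    activations,
    weights,
    bias,
    out_kernel,
    num_tokens=batch,
    num_cols=rows,
    num_rows=cols,
    threadgroup_size=32,
)
torch.testing.assert_close(out_kernel, out_ref, atol=1e-3, rtol=1e-3)
```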

### 2. RMSNorm vs a PyTorch reference

```python
import torch
from kernels import get_kernel

gptoss_kernels = get_kernel("kernels-community/gptoss_kernels")
device = "mps"

hidden = 4096
eps = 1e-5
x = torch.randn(4, hidden, device=device, dtype=torch.float32)
weight = torch.randn(hidden, device=device, dtype=torch.bfloat16)

# manual RMSNorm reference: x / sqrt(mean(x^2) + eps), scaled by the weight
variance = x.pow(2).mean(dim=-1, keepdim=True)
out_ref = (x * torch.rsqrt(variance + eps)) * weight.float()

out_kernel = torch.empty_like(x)
gptoss_kernels.f32_bf16w_rmsnorm(x, weight, out_kernel, epsilon=eps)

print(out_kernel)
print(out_ref)

torch.testing.assert_close(out_kernel, out_ref, atol=1e-3, rtol=1e-3)
```
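
To drop the kernel into an existing model, one option is to wrap it in a small `nn.Module`. The wrapper below is a sketch built on the call signature shown above; it is not an official wrapper shipped with this package, and it is inference-only (no gradients flow through the kernel call).

```python
import torch
import torch.nn as nn


class KernelRMSNorm(nn.Module):
    """Hypothetical wrapper around f32_bf16w_rmsnorm (inference only)."""

    def __init__(self, kernel_module, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.kernel_module = kernel_module  # object returned by get_kernel(...)
        self.eps = eps
        # bf16 scale, matching the dtype the kernel expects for the weight
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)  # kernel writes into a preallocated output
        self.kernel_module.f32_bf16w_rmsnorm(x, self.weight, out, epsilon=self.eps)
        return out


# Usage with the tensors from the snippet above
norm = KernelRMSNorm(gptoss_kernels, hidden).to(device)
print(norm(x))
```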

### 3. Embedding lookup with BF16 tables

```python
import torch
from kernels import get_kernel

device = "mps"
gptoss_kernels = get_kernel("kernels-community/gptoss_kernels")

vocab, dim = 1024, 256
# int32 token ids, bfloat16 embedding table
token_ids = torch.randint(0, vocab, (16,), device=device, dtype=torch.int32)
emb_table = torch.randn(vocab, dim, device=device, dtype=torch.bfloat16)

# reference: gather rows and upcast to float32
out_ref = emb_table.float().index_select(0, token_ids.long())

out_kernel = torch.empty_like(out_ref)
gptoss_kernels.bf16_f32_embeddings(token_ids, emb_table, out_kernel, threadgroup_size=32)

print(out_kernel)
print(out_ref)

torch.testing.assert_close(out_kernel, out_ref, atol=1e-4, rtol=1e-3)
```

### 4. Scaled dot-product attention (SDPA)

```python
import torch
from kernels import get_kernel

device = "mps"
gptoss_kernels = get_kernel("kernels-community/gptoss_kernels")

head_dim = 64
kv_heads = 8
qmul = 8
num_q_heads = kv_heads * qmul
history_tokens = 3
num_q_tokens = 2
total_tokens = history_tokens + num_q_tokens
max_tokens = total_tokens
num_kv_tokens = history_tokens

# Generate Q/K/V tensors
Q_chunk = torch.randn(num_q_tokens, num_q_heads, head_dim, device=device, dtype=torch.float32)
K_all = torch.randn(kv_heads, total_tokens, head_dim, device=device, dtype=torch.float32)
V_all = torch.randn(kv_heads, total_tokens, head_dim, device=device, dtype=torch.float32)

# Pack Q for the new tokens (plus their K/V slices) into the fused QKV layout
qkv_dim = head_dim * (num_q_heads + 2 * kv_heads)
q_buffer = torch.zeros(num_q_tokens, qkv_dim, device=device, dtype=torch.float32)
for t in range(num_q_tokens):
    q_buffer[t, : num_q_heads * head_dim] = Q_chunk[t].reshape(-1)
    token_idx = history_tokens + t
    q_buffer[
        t,
        num_q_heads * head_dim : num_q_heads * head_dim + kv_heads * head_dim,
    ] = K_all[:, token_idx, :].reshape(-1)
    q_buffer[
        t,
        num_q_heads * head_dim + kv_heads * head_dim :,
    ] = V_all[:, token_idx, :].reshape(-1)

# Interleave K and V per token inside the KV cache
token_stride = 2 * head_dim
kv_stride = token_stride * max_tokens
kv_cache = torch.zeros(kv_heads, kv_stride, device=device, dtype=torch.float32)
for h in range(kv_heads):
    for t in range(total_tokens):
        base = t * token_stride
        kv_cache[h, base : base + head_dim] = K_all[h, t]
        kv_cache[h, base + head_dim : base + 2 * head_dim] = V_all[h, t]

# One attention-sink logit per query head, plus the output buffer
sink = torch.full((num_q_heads,), -1e4, device=device, dtype=torch.bfloat16)
output = torch.empty(num_q_tokens, num_q_heads, head_dim, device=device, dtype=torch.float32)

gptoss_kernels.f32_sdpa(
    q_buffer,
    0,
    kv_cache,
    0,
    sink,
    0,
    output,
    0,
    window=total_tokens,
    kv_stride=kv_stride,
    num_q_tokens=num_q_tokens,
    num_kv_tokens=num_kv_tokens,
    num_q_heads=num_q_heads,
    num_kv_heads=kv_heads,
    head_dim=head_dim,
)
```

For this kernel, the outputs match the reference about 97% of the time; the residual mismatches likely stem from how the reference implementation below is written:

```python
def sdpa(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, S: torch.Tensor, sm_scale: float, sliding_window: int = 0) -> torch.Tensor:
    Q = Q.reshape(Q.shape[0], Q.shape[1], -1, Q.shape[-1])
    n_tokens, n_heads, q_mult, d_head = Q.shape
    assert K.shape == (n_tokens, n_heads, d_head)
    assert V.shape == (n_tokens, n_heads, d_head)
    K = K[:, :, None, :].expand(-1, -1, q_mult, -1)
    V = V[:, :, None, :].expand(-1, -1, q_mult, -1)
    S = S.reshape(n_heads, q_mult, 1, 1).expand(-1, -1, n_tokens, -1)
    mask = torch.triu(Q.new_full((n_tokens, n_tokens), -float("inf")), diagonal=1)
    if sliding_window > 0:
        mask += torch.tril(
            mask.new_full((n_tokens, n_tokens), -float("inf")), diagonal=-sliding_window
        )
    QK = torch.einsum("qhmd,khmd->hmqk", Q, K)
    QK *= sm_scale
    QK += mask[None, None, :, :]
    QK = torch.cat([QK, S], dim=-1)
    W = torch.softmax(QK, dim=-1)
    W = W[..., :-1]
    attn = torch.einsum("hmqk,khmd->qhmd", W, V)
    return attn.reshape(n_tokens, -1)


scale = head_dim ** -0.5
q_buffer_cpu = q_buffer.detach().cpu()
kv_cache_cpu = kv_cache.detach().cpu()
sinks_cpu = sink.detach().to(torch.float32).cpu()

# Rebuild full-sequence Q/K/V on the CPU from the buffers fed to the kernel
Q_total_cpu = torch.zeros(total_tokens, kv_heads, qmul, head_dim, dtype=torch.float32)
for idx, abs_idx in enumerate(range(num_kv_tokens, total_tokens)):
    q_flat = q_buffer_cpu[idx, : num_q_heads * head_dim]
    Q_total_cpu[abs_idx] = q_flat.view(kv_heads, qmul, head_dim)

K_total_cpu = torch.empty(total_tokens, kv_heads, head_dim, dtype=torch.float32)
V_total_cpu = torch.empty(total_tokens, kv_heads, head_dim, dtype=torch.float32)
for t in range(total_tokens):
    base = t * token_stride
    K_total_cpu[t] = kv_cache_cpu[:, base : base + head_dim]
    V_total_cpu[t] = kv_cache_cpu[:, base + head_dim : base + 2 * head_dim]

output_ref = sdpa(
    Q_total_cpu,
    K_total_cpu,
    V_total_cpu,
    sinks_cpu,
    scale,
    sliding_window=0,
)
```
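
To quantify how often the two paths agree, one option (not part of the original snippets) is to count mismatching elements between the kernel output and the reference rows for the new query tokens:

```python
# Compare the kernel output against the reference rows for the new query tokens.
# Tolerances are copied from the other examples; this comparison is only an illustration.
out_kernel_flat = output.reshape(num_q_tokens, -1).cpu()
out_ref_tail = output_ref[num_kv_tokens:]

mismatch = (~torch.isclose(out_kernel_flat, out_ref_tail, atol=1e-3, rtol=1e-3)).float().mean()
print(f"fraction of mismatched elements: {mismatch.item():.4%}")
```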

These kernels form the core of the GPT-OSS inference stack, enabling BF16 activations with MXFP4 weights while keeping latency low on Metal GPUs [gpt-oss](https://github.com/openai/gpt-oss). Use the snippets as templates when validating your own model integrations or when extending the kernel set.
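
As a rough way to gauge latency on your own machine, the sketch below times the BF16-weight matmul kernel against the float32 PyTorch reference from example 1. The shapes are arbitrary, and the wall-clock loop with `torch.mps.synchronize()` is only an illustration, not a rigorous benchmark.

```python
import time

import torch
from kernels import get_kernel

gptoss_kernels = get_kernel("kernels-community/gptoss_kernels")
device = "mps"

tokens, in_dim, out_dim = 64, 2048, 2048  # arbitrary shapes, chosen only for illustration
activations = torch.randn(tokens, in_dim, device=device, dtype=torch.float32)
weights = torch.randn(in_dim, out_dim, device=device, dtype=torch.bfloat16)
bias = torch.zeros(out_dim, device=device, dtype=torch.bfloat16)
out = torch.empty(tokens, out_dim, device=device, dtype=torch.float32)


def run_kernel():
    gptoss_kernels.f32_bf16w_matmul(
        activations, weights, bias, out,
        num_tokens=tokens, num_cols=in_dim, num_rows=out_dim, threadgroup_size=32,
    )


def run_reference():
    return activations @ weights.float() + bias.float()


def time_op(fn, iters=50):
    # warm up, then time with MPS synchronization around the loop
    for _ in range(5):
        fn()
    torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.mps.synchronize()
    return (time.perf_counter() - start) / iters


print(f"kernel:    {time_op(run_kernel) * 1e3:.3f} ms/iter")
print(f"reference: {time_op(run_reference) * 1e3:.3f} ms/iter")
```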