danieldk (HF Staff) committed
Commit 59ff995 · verified · 1 Parent(s): 500eef9

Build uploaded using `kernels`.
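A build published this way is meant to be consumed through the `kernels` loader rather than installed as a wheel. A minimal loading sketch, assuming the `kernels` package is installed; the repository id below is a placeholder, not taken from this commit:

```python
import torch
from kernels import get_kernel

# Downloads the uploaded build from the Hub and imports it as a module-like object.
# "<user-or-org>/scattermoe" is a placeholder repository id.
scattermoe = get_kernel("<user-or-org>/scattermoe")

# The torch-universal variant in this commit exposes the scattermoe package contents,
# e.g. the ParallelExperts module and the flatten_sort_count helper.
print(scattermoe.ParallelExperts, scattermoe.flatten_sort_count)
```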
build/torch-universal/scattermoe/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts
+ from . import parallel_experts
+ from . import kernels
+ from . import layers
+
+ __all__ = [
+     "flatten_sort_count",
+     "parallel_linear",
+     "ParallelExperts",
+     "parallel_experts",
+     "kernels",
+     "layers"
+ ]
build/torch-universal/scattermoe/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (435 Bytes).
 
build/torch-universal/scattermoe/__pycache__/layers.cpython-313.pyc ADDED
Binary file (2.56 kB).
 
build/torch-universal/scattermoe/__pycache__/parallel_experts.cpython-313.pyc ADDED
Binary file (8.4 kB).
 
build/torch-universal/scattermoe/_ops.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+ ops = torch.ops._scattermoe_bccb5f3
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_scattermoe_bccb5f3::{op_name}"
build/torch-universal/scattermoe/kernels/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from . import ops
+
+ __all__ = ["ops"]
build/torch-universal/scattermoe/kernels/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (243 Bytes).
 
build/torch-universal/scattermoe/kernels/__pycache__/ops.cpython-313.pyc ADDED
Binary file (20.6 kB).
 
build/torch-universal/scattermoe/kernels/ops.py ADDED
@@ -0,0 +1,457 @@
+ import torch
+ import triton
+ import triton.language as tl
+ from typing import Optional
+
+ BLOCK_M = 128
+ ALLOW_TF32 = True
+
+
+
+ @triton.jit
+ def _compute_expert_block(
+     E_idx, E_mask,
+     M_in_idx,
+     N_block, N_mask,
+     X_ptr, stride_xm, stride_xk,
+     W_ptr, stride_we, stride_wk, stride_wn,
+     K,
+     acc,
+     no_k_mask,
+     BLOCK_K,
+     allow_tf32=True,
+ ):
+
+     K_block = tl.arange(0, BLOCK_K)
+     X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
+     W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
+     iters = tl.cdiv(K, BLOCK_K)
+
+     for K_block_id in range(iters):
+         if no_k_mask:
+             x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
+             w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
+         else:
+             K_mask = (K_block_id * BLOCK_K + K_block) < K
+             x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
+             w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
+
+         X_blk_ptrs += BLOCK_K * stride_xk
+         W_blk_ptrs += BLOCK_K * stride_wk
+         acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
+     return acc
+
+
+ def _scatter2scatter_configs():
+     return [
+         triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+     ]
+
+ @triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
+ @triton.heuristics({
+     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
+     "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
+ })
+ @triton.jit
+ def _scatter2scatter(
+     X_ptr, stride_xm: tl.constexpr, stride_xk: tl.constexpr,
+     W_ptr, stride_we, stride_wk: tl.constexpr, stride_wn: tl.constexpr,
+     Y_ptr, stride_ym: tl.constexpr, stride_yn: tl.constexpr,
+     B_ptr, stride_be: tl.constexpr, stride_bn: tl.constexpr,
+     grouped_idx_ptr, expert_idxs_ptr,
+     # block_start_idx_ptr,
+     FAN_OUT: tl.constexpr,
+     M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
+     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+     ACC_TYPE: tl.constexpr,
+     # OUT_M,
+     allow_tf32: tl.constexpr,
+     x_grouped: tl.constexpr, y_grouped: tl.constexpr,
+     NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
+ ):
+     pid = tl.program_id(axis=0)
+
+     N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
+     M_block_id = pid // N_BLOCK_COUNT
+     N_block_id = pid % N_BLOCK_COUNT
+
+     M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+     N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+     N_mask = N_block < N
+     M_boundary_mask = M_block < (FAN_OUT * M)
+     E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
+
+     no_k_mask = K % BLOCK_K == 0
+
+     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+     E_first_idx = tl.min(E_idxs)
+     E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
+     M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
+     for E_idx in range(E_first_idx, E_last_idx + 1):
+         E_mask = E_idxs == E_idx
+         E_M_idx = M_idx
+         if x_grouped:
+             M_in_idx = M_block
+         else:
+             M_in_idx = E_M_idx // FAN_OUT
+         acc = _compute_expert_block(
+             E_idx, E_mask,
+             M_in_idx, N_block, N_mask,
+             X_ptr, stride_xm, stride_xk,
+             W_ptr, stride_we, stride_wk, stride_wn,
+             K,
+             acc,
+             no_k_mask,
+             BLOCK_K,
+             allow_tf32=allow_tf32,
+         )
+
+     if B_ptr is not None:
+         B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
+         acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
+
+     if y_grouped:
+         M_out_idx = M_block
+     else:
+         M_out_idx = M_idx
+     Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
+     tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
+
+ def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
+                     b=None,
+                     x_grouped=False, y_grouped=False,
+                     out=None):
+     assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
+     assert sorted_scattered_idxs.size(0) == X.size(0) * k
+     # Pre-kernel setup
+     y_dim = W.size(-1)
+     L_scattered = sorted_expert_idxs.size(0)
+     if out is None:
+         output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
+     else:
+         assert out.size(0) == L_scattered and out.size(1) == y_dim
+         output = out
+
+     scatter2scatter_compileable(output, W, X, k, sorted_expert_idxs, sorted_scattered_idxs,
+                                 b, x_grouped, y_grouped)
+     return output
+
+
+ @torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
+ def scatter2scatter_compileable(
+     output: torch.Tensor,
+     W: torch.Tensor,
+     X: torch.Tensor,
+     k: int,
+     sorted_expert_idxs: torch.Tensor,
+     sorted_scattered_idxs: torch.Tensor,
+     b: Optional[torch.Tensor],
+     x_grouped: bool, y_grouped: bool) -> None:
+     def grid(META):
+         grid_num = (
+             triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"]) *
+             triton.cdiv(META['N'], META['BLOCK_N']),
+         )
+         return grid_num
+
+     if b is None:
+         b = None
+         stride_be = stride_bk = 0
+     else:
+         stride_be, stride_bk = b.stride()
+
+     _scatter2scatter[grid](
+         # X_ptr, stride_xm, stride_xk,
+         X, X.stride(0), X.stride(1),
+         # W_ptr, stride_we, stride_wk, stride_wn,
+         W, W.stride(0), W.stride(1), W.stride(2),
+         # Y_ptr, stride_ym, stride_yn,
+         output, output.stride(0), output.stride(1),
+         # B_ptr, stride_be, stride_bk
+         b, stride_be, stride_bk,
+         grouped_idx_ptr=sorted_scattered_idxs,
+         expert_idxs_ptr=sorted_expert_idxs,
+         # block_start_idx_ptr=padded_block_idxs,
+         FAN_OUT=k,
+         M=X.size(0),
+         K=X.size(1),
+         N=output.size(1), E=W.size(0),
+         BLOCK_M=BLOCK_M,
+         ACC_TYPE=tl.float32,
+         allow_tf32=ALLOW_TF32,
+         x_grouped=x_grouped, y_grouped=y_grouped,
+     )
+
+
+ def _config_XtY():
+     return [
+         triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
+     ]
+
+ def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
+     DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
+     DW = DWt.permute(0, 2, 1)
+     if has_bias:
+         Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
+     else:
+         Db = None
+     groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
+     return DW, Db
+
+
+ @torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"})
+ def groupXtY_compileable(
+     E: int,
+     DW: torch.Tensor,
+     Db: Optional[torch.Tensor],
+     DY: torch.Tensor,
+     X: torch.Tensor,
+     expert_offsets: torch.Tensor) -> None:
+     def grid(META):
+         grid = (
+             E * triton.cdiv(META['K'], META['BLOCK_K']),
+             triton.cdiv(META['N'], META['BLOCK_N']),
+         )
+         return grid
+
+     if Db is None:
+         stride_dbe = 0
+         stride_dbn = 0
+     else:
+         stride_dbe, stride_dbn = Db.stride()
+
+     _groupXtY[grid](
+         # DY_ptr, stride_dym, stride_dyk,
+         DY, DY.stride(0), DY.stride(1),
+         # X_ptr, stride_xm, stride_xn,
+         X, X.stride(0), X.stride(1),
+         # DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+         DW, DW.stride(0), DW.stride(1), DW.stride(2),
+         # Db_ptr, stride_dwe, stride_dbn,
+         Db, stride_dbe, stride_dbn,
+         # expert_offsets_ptr,
+         expert_offsets,
+         # K: tl.constexpr, N: tl.constexpr,
+         M=DY.size(0), N=DY.size(-1), K=X.size(-1),
+         # ACC_TYPE: tl.constexpr,
+         ACC_TYPE=tl.float32,
+         allow_tf32=ALLOW_TF32
+     )
+
+
+ @triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
+ @triton.heuristics({
+     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
+     "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
+ })
+ @triton.jit
+ def _groupXtY(
+     DY_ptr, stride_dym, stride_dyk,
+     X_ptr, stride_xm, stride_xn,
+     DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+     Db_ptr, stride_dbe, stride_dbn,
+     expert_offsets_ptr,
+     M, K: tl.constexpr, N: tl.constexpr,
+     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+     ACC_TYPE: tl.constexpr,
+     allow_tf32: tl.constexpr,
+     NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
+ ):
+     pid0 = tl.program_id(axis=0)
+     pid1 = tl.program_id(axis=1)
+     num0 = tl.num_programs(0)
+     num1 = tl.num_programs(1)
+     # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
+     pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
+
+     K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
+     E_idx = pid0 // K_BLOCK_COUNT
+     K_block_id = pid0 % K_BLOCK_COUNT
+     N_block_id = pid1
+
+     if E_idx == 0:
+         start_idx = 0
+     else:
+         start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
+     end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
+
+
+     if end_idx > start_idx:
+         M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
+
+         K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
+         K_mask = K_block < K
+         K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
+
+         N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+         N_mask = N_block < N
+         N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
+
+         M_idxs = M_block
+         xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
+         dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
+         if (Db_ptr is not None) and (K_block_id == 0):
+             _xty_and_bias(
+                 E_idx, start_idx, end_idx,
+                 M_block,
+                 K_block, K_mask, N_block, N_mask,
+                 dy_blk_ptrs, stride_dym,
+                 xt_blk_ptrs, stride_xm,
+                 DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+                 Db_ptr, stride_dbe, stride_dbn,
+                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
+                 allow_tf32, NO_K_MASK, NO_N_MASK,
+                 compute_bias=True
+             )
+         else:
+             _xty_and_bias(
+                 E_idx, start_idx, end_idx,
+                 M_block,
+                 K_block, K_mask, N_block, N_mask,
+                 dy_blk_ptrs, stride_dym,
+                 xt_blk_ptrs, stride_xm,
+                 DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+                 Db_ptr, stride_dbe, stride_dbn,
+                 BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
+                 allow_tf32, NO_K_MASK, NO_N_MASK,
+                 compute_bias=False
+             )
+
+
+ @triton.jit
+ def _xty_and_bias(
+     E_idx, start_idx, end_idx,
+     M_block,
+     K_block, K_mask, N_block, N_mask,
+     dy_blk_ptrs, stride_dym,
+     xt_blk_ptrs, stride_xm,
+     DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+     Db_ptr, stride_dbe, stride_dbn,
+     BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
+     allow_tf32, NO_K_MASK, NO_N_MASK,
+     compute_bias: tl.constexpr
+ ):
+
+     if compute_bias:
+         db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
+     else:
+         db_acc = None
+
+     acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
+     iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
+     for i in range(0, iters):
+         M_mask = (i * BLOCK_M + M_block) < end_idx
+         if NO_K_MASK:
+             xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
+         else:
+             xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
+         if NO_N_MASK:
+             dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
+         else:
+             dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
+
+         acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
+
+         xt_blk_ptrs += BLOCK_M * stride_xm
+         dy_blk_ptrs += BLOCK_M * stride_dym
+
+         if compute_bias:
+             db_acc += tl.sum(dy, axis=0)
+
+     DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
+     acc = acc.to(DW_blk_ptrs.dtype.element_ty)
+     tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
+     if compute_bias:
+         Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
+         tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
+
+
+
+ def _config_grouping():
+     return [
+         triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
+         # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+         # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+     ]
+
+ def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
+     N = sorted_expert_idxs.size(0)
+     K = A.size(1)
+     assert A.size(0) * fan_out == N
+     if out is not None:
+         Y = out
+     else:
+         Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
+     group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
+     return Y
+
+
+ @torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
+ def group_compileable(
+     A: torch.Tensor,
+     K: int,
+     N: int,
+     Y: torch.Tensor,
+     coeff: torch.Tensor, has_coeff: bool,
+     fan_out: int,
+     sorted_expert_idxs: torch.Tensor) -> None:
+     def grid(META):
+         grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
+         return grid_num
+     _group[grid](
+         # A_ptr, stride_an, stride_ai,
+         A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out,
+         # Y_ptr, stride_yn, stride_yk,
+         Y, Y.stride(0), Y.stride(1),
+         # grouped_idx_ptr,
+         sorted_expert_idxs,
+         # N: tl.constexpr, K: tl.constexpr,
+         N, K
+     )
+
+
+ @triton.autotune(configs=_config_grouping(), key=['K'])
+ @triton.heuristics({
+     "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
+ })
+ @triton.jit
+ def _group(
+     src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
+     tgt_ptr, stride_tn, stride_ti,
+     grouped_idx_ptr,
+     N, K: tl.constexpr,
+     BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+     NO_K_MASK: tl.constexpr
+ ):
+     pid = tl.program_id(axis=0)
+
+     N_block_id = pid
+     N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+     N_mask = N_blk < N
+     N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
+     N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
+
+     K_blk = tl.arange(0, BLOCK_K)
+     src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
+     tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
+
+     if has_coeff:
+         c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
+
+     iters = tl.cdiv(K, BLOCK_K)
+     for i in range(0, iters):
+         if NO_K_MASK or i < iters - 1:
+             block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
+             if has_coeff:
+                 block *= c
+             tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
+
+         else:
+             K_mask = (i * BLOCK_K + K_blk) < K
+             mask = N_mask[:, None] & K_mask[None, :]
+             block = tl.load(src_blk_ptrs, mask=mask)
+             if has_coeff:
+                 block *= c
+             tl.store(tgt_blk_ptrs, block, mask=mask)
+         src_blk_ptrs += BLOCK_K * stride_sk
+         tgt_blk_ptrs += BLOCK_K * stride_ti
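For orientation, a direct host-side call of `scatter2scatter` could look like the sketch below. This is a minimal sketch, assuming a CUDA device with Triton available and that this build is importable as `scattermoe` (e.g. after loading it through `kernels`); the shapes are arbitrary:

```python
import torch
from scattermoe.kernels.ops import scatter2scatter

E, M, K, N, k = 4, 8, 64, 128, 1           # experts, tokens, in-dim, out-dim, top-k
X = torch.randn(M, K, device="cuda")
W = torch.randn(E, K, N, device="cuda")     # expert weights laid out as (E, K, N)
expert_idxs = torch.randint(0, E, (M, k), device="cuda")

# Sort token->expert assignments; the argsort gives the "scattered" index order.
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(expert_idxs.flatten())

Y = scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k)
print(Y.shape)  # (M * k, N)
```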
build/torch-universal/scattermoe/kernels/single.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ @triton.jit
+ def _single2scatter(
+     X_ptr, stride_xm, stride_xk,
+     W_ptr, stride_we, stride_wk, stride_wn,
+     Y_ptr, stride_ym, stride_yn,
+     expert_idxs_ptr,
+     FAN_OUT: tl.constexpr,
+     K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
+     BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+     ACC_TYPE: tl.constexpr,
+ ):
+     pid0 = tl.program_id(axis=0)
+     pid1 = tl.program_id(axis=1)
+
+     N_block_id = pid0
+     if FAN_OUT == 1:
+         in_idx = pid1
+     else:
+         in_idx = 0
+     out_idx = pid1
+
+     K_block = tl.arange(0, BLOCK_K)
+     N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
+     E_idx = tl.load(expert_idxs_ptr + pid1)
+     X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
+     W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
+     acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
+     for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
+         x = tl.load(X_blk_ptrs)
+         w = tl.load(W_blk_ptrs)
+         acc += tl.sum(x * w, axis=0)[None, :]
+         X_blk_ptrs += BLOCK_K * stride_xk
+         W_blk_ptrs += BLOCK_K * stride_wk
+     Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
+     tl.store(Y_blk_ptrs, acc)
+
+ def single2scatter(X, W, expert_idxs):
+     E, xdim, ydim = W.size()
+     k = expert_idxs.size(1)
+     assert X.size(0) == k or X.size(0) == 1
+     Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
+     BLOCK_N = 128
+     BLOCK_K = 128
+     grid = ydim // BLOCK_N, k
+     _single2scatter[grid](
+         X, X.stride(0), X.stride(1),
+         W, W.stride(0), W.stride(1), W.stride(2),
+         Y, Y.stride(0), Y.stride(1),
+         expert_idxs,
+         FAN_OUT=Y.size(0) // X.size(0),
+         K=xdim, N=ydim, E=E,
+         BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+         ACC_TYPE=tl.float32
+     )
+     return Y
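`single2scatter` covers the single-row (decode-style) case, where one input row is pushed through its top-k experts. A minimal sketch under the same CUDA/Triton and import assumptions as above, with dimensions chosen as multiples of the fixed 128-wide blocks:

```python
import torch
from scattermoe.kernels.single import single2scatter

E, K, N, k = 8, 256, 256, 2                   # experts, in-dim, out-dim, top-k
X = torch.randn(1, K, device="cuda")           # a single token
W = torch.randn(E, K, N, device="cuda")
expert_idxs = torch.randint(0, E, (1, k), device="cuda")

Y = single2scatter(X, W, expert_idxs)
print(Y.shape)  # (k, N): one output row per selected expert
```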
build/torch-universal/scattermoe/layers.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+ from torch.nn import functional as F
+ from torch import nn
+
+ from . import parallel_linear, flatten_sort_count
+
+ class ScatterMoEGatedMLP(nn.Module):
+     def forward(self, layer_input):
+         """
+         Forward pass of the mixture of experts layer.
+
+         Args:
+             layer_input (Tensor):
+                 Input tensor.
+
+         Returns:
+             Tensor:
+                 Output tensor.
+             Tensor:
+                 Router logits.
+         """
+         bsz, length, emb_size = layer_input.size()
+         layer_input = layer_input.reshape(-1, emb_size)
+         # compute the top_k routing decision
+         router_logits = self.router.layer(layer_input)
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(routing_weights, self.router.top_k, dim=-1)
+         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+         routing_weights = routing_weights.to(layer_input.dtype)
+         sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
+             flatten_sort_count(selected_experts, num_experts=self.router.num_experts)
+
+         # compute experts
+         gates, h = parallel_linear(
+             layer_input, self.input_linear.weight.transpose(2, 1),
+             self.router.top_k,
+             sorted_expert_idxs, sorted_scattered_idxs,
+             expert_offsets,
+             grouped_in=False, grouped_out=True,
+         ).chunk(2, dim=-1)
+         h = self.activation(gates) * h
+         layer_output = parallel_linear(
+             h, self.output_linear.weight.transpose(2, 1),
+             1,
+             sorted_expert_idxs, sorted_scattered_idxs,
+             expert_offsets,
+             grouped_in=True, grouped_out=False,
+             gates=routing_weights
+         )
+         layer_output = layer_output.view(bsz, length, emb_size)
+         return layer_output, router_logits
+
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import kernels
4
+ from typing import Optional
5
+
6
+ @torch.library.custom_op("scattermoe::bincount", mutates_args={})
7
+ def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
8
+ return x.bincount(minlength=minlength)
9
+
10
+ @compileable_bincount.register_fake
11
+ def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
12
+ return torch.empty(minlength, dtype=torch.long, device=x.device)
13
+
14
+ @torch.compile
15
+ def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
16
+ with torch.no_grad():
17
+ flattened_expert_idxs = expert_idxs.flatten()
18
+ sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
19
+ expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts)
20
+ expert_offsets = expert_counts.cumsum(-1)
21
+ return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
22
+
23
+
24
+
25
+ class ParallelLinear(torch.autograd.Function):
26
+ @staticmethod
27
+ def forward(
28
+ ctx,
29
+ x: torch.Tensor, expert_weights: torch.Tensor, k: int,
30
+ sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor,
31
+ expert_offsets: torch.Tensor,
32
+ expert_biases: Optional[torch.Tensor]=None,
33
+ gates: Optional[torch.Tensor]=None,
34
+ grouped_in: bool =False, grouped_out: bool=False,
35
+ ):
36
+ with torch.device(x.device):
37
+ output = kernels.ops.scatter2scatter(
38
+ X=x, W=expert_weights,
39
+ b=expert_biases, k=k,
40
+ sorted_expert_idxs=sorted_expert_idxs,
41
+ sorted_scattered_idxs=sorted_scattered_idxs,
42
+ x_grouped=grouped_in, y_grouped=grouped_out
43
+ )
44
+ if gates is not None:
45
+ output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
46
+ output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
47
+ else:
48
+ output_expanded = None
49
+
50
+ ctx.save_for_backward(
51
+ x, expert_weights,
52
+ expert_biases,
53
+ sorted_expert_idxs,
54
+ sorted_scattered_idxs,
55
+ expert_offsets,
56
+ gates,
57
+ output_expanded
58
+ )
59
+ ctx.grouped_in = grouped_in
60
+ ctx.grouped_out = grouped_out
61
+ ctx.k = k
62
+ return output
63
+ @staticmethod
64
+ def backward(ctx, grad_out: torch.Tensor):
65
+ with torch.device(grad_out.device):
66
+ (x, expert_weights, expert_biases,
67
+ sorted_expert_idxs,
68
+ sorted_scattered_idxs,
69
+ expert_offsets,
70
+ gates, output_expanded) = ctx.saved_tensors
71
+ k = ctx.k
72
+ grouped_in = ctx.grouped_in
73
+ grouped_out = ctx.grouped_out
74
+ # print("backward")
75
+
76
+ if gates is not None:
77
+ # calculate gates gradient
78
+ # d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
79
+ d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
80
+ gates_flat = gates.flatten()
81
+ gate_fan = gates.size(1)
82
+ grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later
83
+ else:
84
+ d_gates = None
85
+ gates_flat = None
86
+ gate_fan = 1
87
+ grouped_grad_out = None
88
+
89
+ if grouped_out:
90
+ grouped_grad_out = grad_out
91
+ else:
92
+ grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
93
+ fan_out=gate_fan, coeff=gates_flat,
94
+ out=grouped_grad_out)
95
+ if grouped_in:
96
+ grouped_x = x
97
+ d_expanded_input = None
98
+ else:
99
+ grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
100
+ d_expanded_input = grouped_x
101
+
102
+ d_weights, d_biases = kernels.ops.group_bwd_W(
103
+ DY=grouped_grad_out, X=grouped_x,
104
+ expert_offsets=expert_offsets,
105
+ E=expert_weights.size(0),
106
+ has_bias=expert_biases is not None
107
+ )
108
+
109
+
110
+ d_expanded_input = kernels.ops.scatter2scatter(
111
+ X=grouped_grad_out, x_grouped=True,
112
+ W=expert_weights.permute(0, 2, 1),
113
+ sorted_expert_idxs=sorted_expert_idxs,
114
+ sorted_scattered_idxs=sorted_scattered_idxs,
115
+ k=1,
116
+ y_grouped=grouped_in,
117
+ out=d_expanded_input # Reuse grouped_x buffer
118
+ )
119
+
120
+ if k == 1:
121
+ d_input = d_expanded_input
122
+ else:
123
+ d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
124
+ # print("backward end.")
125
+ return (
126
+ # x, expert_weights,
127
+ d_input, d_weights,
128
+ # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
129
+ None, None, None, None,
130
+ # bias, gates
131
+ d_biases, d_gates,
132
+ # grouped_in, grouped_out,
133
+ None, None
134
+ )
135
+
136
+ def parallel_linear(inputs, expert_weights, k,
137
+ sorted_expert_idxs, sorted_scattered_idxs,
138
+ expert_offsets,
139
+ expert_biases=None,
140
+ gates=None, grouped_in=False, grouped_out=False):
141
+ results = ParallelLinear.apply(inputs, expert_weights, k,
142
+ sorted_expert_idxs, sorted_scattered_idxs,
143
+ expert_offsets,
144
+ expert_biases,
145
+ gates, grouped_in, grouped_out)
146
+ return results
147
+
148
+ class ParallelExperts(nn.Module):
149
+ def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
150
+ super().__init__()
151
+ self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
152
+
153
+ if bias:
154
+ self.bias = nn.Parameter(torch.empty(num_experts, output_size))
155
+ else:
156
+ self.bias = None
157
+
158
+ self.num_experts = num_experts
159
+ self.input_size = input_size
160
+ self.output_size = output_size
161
+ self.reset_parameters()
162
+
163
+ def extra_repr(self):
164
+ return 'num_experts={}, input_size={}, output_size={}'.format(
165
+ self.num_experts, self.input_size, self.output_size)
166
+
167
+ def reset_parameters(self) -> None:
168
+ nn.init.normal_(self.weight, std=0.02)
169
+ if self.bias is not None:
170
+ nn.init.zeros_(self.bias)
171
+
172
+ def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
173
+ expert_offsets,
174
+ gates=None, grouped_in=False, grouped_out=False):
175
+
176
+ results = parallel_linear(
177
+ inputs, self.weight.permute(0, 2, 1), k,
178
+ sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
179
+ expert_biases=self.bias,
180
+ gates=gates, grouped_in=grouped_in, grouped_out=grouped_out
181
+ )
182
+ return results
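Finally, an end-to-end sketch of the module-level API in this file, assuming a CUDA device with Triton and that the build is importable as `scattermoe`; shapes are arbitrary, and with `gates=None` the output keeps one row per (token, expert) pair:

```python
import torch
from scattermoe import ParallelExperts, flatten_sort_count

E, d_in, d_out, k, tokens = 8, 64, 128, 2, 16
experts = ParallelExperts(E, d_in, d_out).cuda()

x = torch.randn(tokens, d_in, device="cuda")
top_k_idxs = torch.randint(0, E, (tokens, k), device="cuda")   # stand-in for a router's top-k choice

sorted_e, sorted_s, offsets = flatten_sort_count(top_k_idxs, num_experts=E)
y = experts(x, k, sorted_e, sorted_s, offsets)
print(y.shape)  # (tokens * k, d_out); routing weights would normally be applied via `gates`
```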