Abdennacer Badaoui committed
Commit 29547e2 · 1 Parent(s): b83c8d6

gemm radeon kernel

Files changed (34):
  1. build.toml +39 -0
  2. build/torch27-cxx11-rocm63-x86_64-linux/gemm/__init__.py +19 -0
  3. build/torch27-cxx11-rocm63-x86_64-linux/gemm/__pycache__/__init__.cpython-313.pyc +0 -0
  4. build/torch27-cxx11-rocm63-x86_64-linux/gemm/__pycache__/_ops.cpython-313.pyc +0 -0
  5. build/torch27-cxx11-rocm63-x86_64-linux/gemm/_gemm_60de0b2.abi3.so +3 -0
  6. build/torch27-cxx11-rocm63-x86_64-linux/gemm/_ops.py +9 -0
  7. build/torch28-cxx11-rocm63-x86_64-linux/gemm/__init__.py +19 -0
  8. build/torch28-cxx11-rocm63-x86_64-linux/gemm/__pycache__/__init__.cpython-313.pyc +0 -0
  9. build/torch28-cxx11-rocm63-x86_64-linux/gemm/__pycache__/_ops.cpython-313.pyc +0 -0
  10. build/torch28-cxx11-rocm63-x86_64-linux/gemm/_gemm_60de0b2.abi3.so +3 -0
  11. build/torch28-cxx11-rocm63-x86_64-linux/gemm/_ops.py +9 -0
  12. build/torch28-cxx11-rocm64-x86_64-linux/gemm/__init__.py +19 -0
  13. build/torch28-cxx11-rocm64-x86_64-linux/gemm/__pycache__/__init__.cpython-313.pyc +0 -0
  14. build/torch28-cxx11-rocm64-x86_64-linux/gemm/__pycache__/_ops.cpython-313.pyc +0 -0
  15. build/torch28-cxx11-rocm64-x86_64-linux/gemm/_gemm_60de0b2.abi3.so +3 -0
  16. build/torch28-cxx11-rocm64-x86_64-linux/gemm/_ops.py +9 -0
  17. flake.lock +169 -0
  18. flake.nix +18 -0
  19. gemm/gemm_kernel.h +896 -0
  20. gemm/gemm_kernel_legacy.h +377 -0
  21. gemm/gemm_launcher.hip +267 -0
  22. gemm/transpose_kernel.h +120 -0
  23. include/clangd_workaround.h +9 -0
  24. include/gpu_libs.h +42 -0
  25. include/gpu_types.h +28 -0
  26. include/timer.h +57 -0
  27. src/utils/arithmetic.h +12 -0
  28. src/utils/timer.hip +104 -0
  29. tests/checker/checker.cpp +469 -0
  30. tests/checker/checker.h +33 -0
  31. tests/checker/metrics.h +13 -0
  32. torch-ext/gemm/__init__.py +19 -0
  33. torch-ext/torch_binding.cpp +53 -0
  34. torch-ext/torch_binding.h +6 -0
build.toml ADDED
@@ -0,0 +1,39 @@
+ [general]
+ name = "gemm"
+ universal = false
+
+ [torch]
+ src = [
+     "torch-ext/torch_binding.cpp",
+     "torch-ext/torch_binding.h",
+ ]
+
+ [kernel.gemm]
+ backend = "rocm"
+ rocm-archs = [
+     #"gfx908",
+     # "gfx90a",
+     # "gfx940",
+     # "gfx941",
+     "gfx942",
+     # "gfx1100",
+     # "gfx1101",
+     # "gfx1102",
+     # "gfx1200",
+     # "gfx1201",
+ ]
+ depends = ["torch"]
+ src = [
+     "include/clangd_workaround.h",
+     "include/gpu_libs.h",
+     "include/gpu_types.h",
+     "include/timer.h",
+     "gemm/gemm_kernel.h",
+     "gemm/gemm_kernel_legacy.h",
+     "gemm/gemm_launcher.hip",
+     "gemm/transpose_kernel.h",
+     "src/utils/arithmetic.h",
+     "src/utils/timer.hip",
+     "tests/checker/metrics.h",
+ ]
+ include = ["include"]
build/torch27-cxx11-rocm63-x86_64-linux/gemm/__init__.py ADDED
@@ -0,0 +1,19 @@
+ from typing import Optional
+ import torch
+ from ._ops import ops
+
+ def gemm(a: torch.Tensor, b: torch.Tensor, as_: torch.Tensor, bs: torch.Tensor,
+          out: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+     if out is None:
+         # Create output tensor with appropriate shape and dtype
+         M, K = a.shape
+         K_b, N = b.shape
+         assert K == K_b, f"Matrix dimension mismatch: A has {K} cols, B has {K_b} rows"
+
+         # Output should be BF16 type on the same device as inputs
+         out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)
+
+     ops.gemm(out, a, b, as_, bs)
+     return out
+
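A minimal usage sketch of this wrapper (the FP8 dtype, the layout of `b`, and the scale shapes are assumptions inferred from the wrapper above and the kernel sources, not a documented API):

    import torch
    from gemm import gemm

    M, N, K = 6144, 4608, 7168          # problem size used elsewhere in this repo
    QUANT = 128                         # quantization block size assumed from gemm_kernel.h

    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fnuz)    # FP8 activations
    b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fnuz)    # FP8 weights (shape per the check above)
    a_scale = torch.ones(K // QUANT, M, dtype=torch.float32, device="cuda")           # assumed per-(K-block, row) scales
    b_scale = torch.ones(K // QUANT, N // QUANT, dtype=torch.float32, device="cuda")  # assumed per-(K-block, N-block) scales

    out = gemm(a, b, a_scale, b_scale)  # BF16 tensor of shape (M, N)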
build/torch27-cxx11-rocm63-x86_64-linux/gemm/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.16 kB).
 
build/torch27-cxx11-rocm63-x86_64-linux/gemm/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (509 Bytes).
 
build/torch27-cxx11-rocm63-x86_64-linux/gemm/_gemm_60de0b2.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:abf5894843b26edc6c1894a0e5e9829441567319f446600f1e772fe2621e4faf
+ size 2196896
build/torch27-cxx11-rocm63-x86_64-linux/gemm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _gemm_60de0b2
+ ops = torch.ops._gemm_60de0b2
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_gemm_60de0b2::{op_name}"
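The qualified name produced by this helper is the form expected by torch.library-style registration APIs; a small sketch (how the repo actually uses it is not shown in this commit):

    from gemm._ops import ops, add_op_namespace_prefix

    qualified = add_op_namespace_prefix("gemm")   # -> "_gemm_60de0b2::gemm"
    op = ops.gemm                                 # same op via torch.ops._gemm_60de0b2.gemm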
build/torch28-cxx11-rocm63-x86_64-linux/gemm/__init__.py ADDED
@@ -0,0 +1,19 @@
+ from typing import Optional
+ import torch
+ from ._ops import ops
+
+ def gemm(a: torch.Tensor, b: torch.Tensor, as_: torch.Tensor, bs: torch.Tensor,
+          out: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+     if out is None:
+         # Create output tensor with appropriate shape and dtype
+         M, K = a.shape
+         K_b, N = b.shape
+         assert K == K_b, f"Matrix dimension mismatch: A has {K} cols, B has {K_b} rows"
+
+         # Output should be BF16 type on the same device as inputs
+         out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)
+
+     ops.gemm(out, a, b, as_, bs)
+     return out
+
build/torch28-cxx11-rocm63-x86_64-linux/gemm/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.16 kB).
 
build/torch28-cxx11-rocm63-x86_64-linux/gemm/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (509 Bytes).
 
build/torch28-cxx11-rocm63-x86_64-linux/gemm/_gemm_60de0b2.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da6a47f577368eda36d189c4840e5a1da512a4a5ed55ac159fb516698f417d49
+ size 2196896
build/torch28-cxx11-rocm63-x86_64-linux/gemm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _gemm_60de0b2
+ ops = torch.ops._gemm_60de0b2
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_gemm_60de0b2::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/gemm/__init__.py ADDED
@@ -0,0 +1,19 @@
+ from typing import Optional
+ import torch
+ from ._ops import ops
+
+ def gemm(a: torch.Tensor, b: torch.Tensor, as_: torch.Tensor, bs: torch.Tensor,
+          out: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+     if out is None:
+         # Create output tensor with appropriate shape and dtype
+         M, K = a.shape
+         K_b, N = b.shape
+         assert K == K_b, f"Matrix dimension mismatch: A has {K} cols, B has {K_b} rows"
+
+         # Output should be BF16 type on the same device as inputs
+         out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)
+
+     ops.gemm(out, a, b, as_, bs)
+     return out
+
build/torch28-cxx11-rocm64-x86_64-linux/gemm/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (1.16 kB).
 
build/torch28-cxx11-rocm64-x86_64-linux/gemm/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (509 Bytes).
 
build/torch28-cxx11-rocm64-x86_64-linux/gemm/_gemm_60de0b2.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aaedc163991e19d7ed36c1367fd5627daf741d54c1abfda128d00ea702cb18f6
+ size 2195560
build/torch28-cxx11-rocm64-x86_64-linux/gemm/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _gemm_60de0b2
+ ops = torch.ops._gemm_60de0b2
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_gemm_60de0b2::{op_name}"
flake.lock ADDED
@@ -0,0 +1,169 @@
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1747046372,
21
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1759385472,
77
+ "narHash": "sha256-a1YMZp3Yc1RJfLIObRKBTTbjMKL91IYbzTjG/HNZN+I=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "050dd78a64cb58fb1f9fb29ca498c73107a9a13e",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1759391810,
102
+ "narHash": "sha256-7thAD4hsNGvDb59rne7Vt0JsdyjK/DaeW3PIXWjueYc=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "6029e51ead3bc6328fc817be44a546f366b84d93",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "ref": "rocm-per-source-arches",
111
+ "repo": "kernel-builder",
112
+ "type": "github"
113
+ }
114
+ },
115
+ "nixpkgs": {
116
+ "locked": {
117
+ "lastModified": 1755963616,
118
+ "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
119
+ "owner": "nixos",
120
+ "repo": "nixpkgs",
121
+ "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
122
+ "type": "github"
123
+ },
124
+ "original": {
125
+ "owner": "nixos",
126
+ "ref": "nixos-unstable-small",
127
+ "repo": "nixpkgs",
128
+ "type": "github"
129
+ }
130
+ },
131
+ "root": {
132
+ "inputs": {
133
+ "kernel-builder": "kernel-builder"
134
+ }
135
+ },
136
+ "systems": {
137
+ "locked": {
138
+ "lastModified": 1681028828,
139
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
140
+ "owner": "nix-systems",
141
+ "repo": "default",
142
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
143
+ "type": "github"
144
+ },
145
+ "original": {
146
+ "owner": "nix-systems",
147
+ "repo": "default",
148
+ "type": "github"
149
+ }
150
+ },
151
+ "systems_2": {
152
+ "locked": {
153
+ "lastModified": 1681028828,
154
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
155
+ "owner": "nix-systems",
156
+ "repo": "default",
157
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
158
+ "type": "github"
159
+ },
160
+ "original": {
161
+ "owner": "nix-systems",
162
+ "repo": "default",
163
+ "type": "github"
164
+ }
165
+ }
166
+ },
167
+ "root": "root",
168
+ "version": 7
169
+ }
flake.nix ADDED
@@ -0,0 +1,18 @@
+ {
+   description = "Flake for GEMM kernel";
+
+   inputs = {
+     kernel-builder.url = "github:huggingface/kernel-builder/rocm-per-source-arches";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+
+     kernel-builder.lib.genFlakeOutputs {
+       inherit self;
+       path = ./.;
+     };
+ }
gemm/gemm_kernel.h ADDED
@@ -0,0 +1,896 @@
1
+ // Pipelined GEMM kernel. This version was written in a hurry and may not apply to all shapes.
2
+ // Currently, only selected parameter combinations are tested (see gemm_launcher).
3
+ #ifndef GEMM_KERNEL
4
+ #define GEMM_KERNEL
5
+
6
+ #include <cstdio>
7
+ #include <hip/amd_detail/amd_hip_runtime.h>
8
+ #include <hip/amd_detail/amd_warp_functions.h>
9
+ #pragma clang diagnostic push
10
+ #pragma clang diagnostic ignored "-Wunknown-attributes"
11
+ #include "../include/gpu_libs.h"
12
+ #include "../include/gpu_types.h"
13
+ #include "../src/utils/arithmetic.h"
14
+ #include "../include/clangd_workaround.h"
15
+ #include <cstdlib>
16
+ #include <cfloat>
17
+
18
+ namespace gemm_kernel {
19
+
20
+ template <typename data_type, int BATCH_SIZE> __device__ inline void read_batch(data_type *dst, const data_type *src) {
21
+ if constexpr ((sizeof(data_type) * BATCH_SIZE) == 2 * sizeof(ulong4)) {
22
+ *(reinterpret_cast<ulong4 *>(dst) + 0) = *(reinterpret_cast<const ulong4 *>(src) + 0);
23
+ *(reinterpret_cast<ulong4 *>(dst) + 1) = *(reinterpret_cast<const ulong4 *>(src) + 1);
24
+ } else if constexpr ((sizeof(data_type) * BATCH_SIZE) == sizeof(ulong4)) {
25
+ *reinterpret_cast<ulong4 *>(dst) = *reinterpret_cast<const ulong4 *>(src);
26
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong2)) {
27
+ *reinterpret_cast<ulong2 *>(dst) = *reinterpret_cast<const ulong2 *>(src);
28
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong1)) {
29
+ *reinterpret_cast<ulong1 *>(dst) = *reinterpret_cast<const ulong1 *>(src);
30
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(uint1)) {
31
+ *reinterpret_cast<uint1 *>(dst) = *reinterpret_cast<const uint1 *>(src);
32
+ } else {
33
+ #pragma unroll
34
+ for (int b = 0; b < BATCH_SIZE; ++b) {
35
+ dst[b] = src[b];
36
+ }
37
+ }
38
+ }
39
+
40
+ template <typename data_type, int BATCH_SIZE> __device__ inline void zero_batch(data_type *dst) {
41
+ if constexpr ((sizeof(data_type) * BATCH_SIZE) == sizeof(ulong4)) {
42
+ *reinterpret_cast<ulong4 *>(dst) = ulong4{};
43
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong2)) {
44
+ *reinterpret_cast<ulong2 *>(dst) = ulong2{};
45
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong1)) {
46
+ *reinterpret_cast<ulong1 *>(dst) = ulong1{};
47
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(uint1)) {
48
+ *reinterpret_cast<uint *>(dst) = uint{};
49
+ } else {
50
+ #pragma unroll
51
+ for (int b = 0; b < BATCH_SIZE; ++b) {
52
+ dst[b] = 0;
53
+ }
54
+ }
55
+ }
56
+
57
+ template <typename data_type, int DST_Y, int DST_X, int SRC_Y, int SRC_X, int BLOCK_DIM, int BATCH_SIZE>
58
+ __device__ inline void load_input(data_type dst[DST_Y][DST_X], const data_type src[SRC_Y][SRC_X], const int begin_x,
59
+ const int begin_y) {
60
+ static_assert(BATCH_SIZE > 0);
61
+ /**
62
+ Consider (SRC_X % DST_X == 0) && (SRC_Y % DST_Y == 0)
63
+ Step 1:
64
+ [ ][***][ ][ ]
65
+ [ ][ ][ ][ ]
66
+ [ ][ ][ ][ ]
67
+ [ ][ ][ ][ ]
68
+ Step 2:
69
+ [ ][ ][ ][ ]
70
+ [ ][***][ ][ ]
71
+ [ ][ ][ ][ ]
72
+ [ ][ ][ ][ ]
73
+ */
74
+ static_assert((SRC_X % BATCH_SIZE == 0) && (SRC_Y % BATCH_SIZE == 0));
75
+ static_assert((DST_X % BATCH_SIZE == 0) && (DST_Y % BATCH_SIZE == 0));
76
+ static_assert(BATCH_SIZE <= DST_X && DST_X % BATCH_SIZE == 0);
77
+ const int begin_idx = threadIdx.x * BATCH_SIZE;
78
+ const constexpr int total_elements = DST_X * DST_Y;
79
+ const constexpr int elements_per_step = BLOCK_DIM * BATCH_SIZE;
80
+ // FIXME: loop unrolling
81
+ #pragma unroll
82
+ for (int k = begin_idx; k < total_elements; k += elements_per_step) {
83
+ int l_kx = k % DST_X;
84
+ int l_ky = k / DST_X;
85
+ int g_kx = l_kx + begin_x;
86
+ int g_ky = l_ky + begin_y;
87
+ auto *dst_flatten = &dst[l_ky][l_kx];
88
+ // const auto *src_flatten = &src[g_ky][g_kx];
89
+ // read_batch<data_type, BATCH_SIZE>(dst_flatten, src_flatten);
90
+ if (((SRC_X % DST_X == 0) || (g_kx < SRC_X)) && ((SRC_Y % DST_Y == 0) || (g_ky < SRC_Y))) {
91
+ const auto *src_flatten = &src[g_ky][g_kx];
92
+ read_batch<data_type, BATCH_SIZE>(dst_flatten, src_flatten);
93
+ } else {
94
+ zero_batch<data_type, BATCH_SIZE>(dst_flatten);
95
+ }
96
+ }
97
+ }
98
+
99
+ template <int PM, int PN, int QM, int QN, int QK, int QUANT_SIZE, int BLOCK_SIZE, int BATCH_SIZE>
100
+ __device__ void load_scale(float s_s[PM][PN], const float sa[QK][QM], const float sb[QK][QN], const int m, const int n,
101
+ const int k) {
102
+ constexpr int total_elements = PM * PN;
103
+ constexpr int elements_per_step = BLOCK_SIZE * BATCH_SIZE;
104
+ // static_assert(PN % BATCH_SIZE)
105
+
106
+ const int begin_idx = threadIdx.x * BATCH_SIZE;
107
+ #pragma unroll
108
+ for (int idx = begin_idx; idx < total_elements; idx += elements_per_step) {
109
+ static_assert(BATCH_SIZE == 1);
110
+ int i = idx / PN;
111
+ int j = idx % PN;
112
+ if (((QM % PM == 0) || (m + i < QM)) && ((QN % PN == 0) || ((n + j) / QUANT_SIZE < QN))) {
113
+ s_s[i][j] = sa[k / QUANT_SIZE][(m + i)] * sb[k / QUANT_SIZE][(n) / QUANT_SIZE + j];
114
+ } else {
115
+ s_s[i][j] = 1.0f;
116
+ }
117
+ }
118
+ }
119
+
120
+ // don't use __builtin_readcyclecounter(), which would insert waitcnt
121
+ __device__ auto getclock() {
122
+ uint64_t clk;
123
+ asm volatile("s_memtime %0" : "=r"(clk));
124
+ return clk;
125
+ }
126
+
127
+
128
+ template <typename Elem> __global__ void check_trans(const Elem *origin, const Elem *tranposed, int m, int n) {
129
+ auto x = threadIdx.x + blockIdx.x * blockDim.x;
130
+ auto y = threadIdx.y + blockIdx.y * blockDim.y;
131
+ if (x < m && y < n) {
132
+ if (origin[x * n + y] != tranposed[y * m + x]) {
133
+ printf("Error: %d %d\n", x, y);
134
+ }
135
+ }
136
+ }
137
+
138
+ template <typename in_data_type, typename acc_data_type, typename FragC, typename FragA, typename FragB, int PM, int PN,
139
+ int BM, int BN, int BK, int FRAG_M, int FRAG_N, int FRAG_K, int WMMA_M, int WMMA_N, int WMMA_K, int WARP_M,
140
+ int WARP_N, int BLOCK_SIZE, int BATCH_SIZE, int QUANT_SIZE>
141
+ __device__ void wmma_compute(const in_data_type s_a[BM][BK + 8], const in_data_type s_b[BN][BK + 8],
142
+ const float s_s[PN][PM], FragC frag_r[FRAG_M][FRAG_N], const int comp_c_frag_m,
143
+ const int comp_c_frag_n) {
144
+ FragC frag_c[FRAG_M][FRAG_N];
145
+
146
+ #pragma unroll
147
+ for (int i = 0; i < FRAG_M; i++) {
148
+ #pragma unroll
149
+ for (int j = 0; j < FRAG_N; j++) {
150
+ wmma::fill_fragment(frag_c[i][j], 0.0F);
151
+ }
152
+ }
153
+
154
+ #pragma unroll
155
+ for (int k = 0; k < FRAG_K; ++k) {
156
+ #pragma unroll
157
+ for (int i = 0; i < FRAG_M; i++) {
158
+ FragA frag_a;
159
+ int s_a_row = k * WMMA_K;
160
+ int s_a_col = (comp_c_frag_m * FRAG_M + i) * WMMA_M;
161
+ wmma::load_matrix_sync(frag_a, &s_a[s_a_col][s_a_row], BK + 8);
162
+ #pragma unroll
163
+ for (int j = 0; j < FRAG_N; j++) {
164
+ FragB frag_b;
165
+ int s_b_row = k * WMMA_K;
166
+ int s_b_col = (comp_c_frag_n * FRAG_N + j) * WMMA_N;
167
+ wmma::load_matrix_sync(frag_b, &s_b[s_b_col][s_b_row], BK + 8);
168
+
169
+ wmma::mma_sync(frag_c[i][j], frag_a, frag_b, frag_c[i][j]);
170
+ }
171
+ }
172
+ }
173
+ #pragma unroll
174
+ for (int i = 0; i < FRAG_M; i++) {
175
+ #pragma unroll
176
+ for (int j = 0; j < FRAG_N; j++) {
177
+ #pragma unroll
178
+ for (int k = 0; k < FragC::num_elements; ++k) {
179
+ #ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
180
+ int m = ((threadIdx.x & 16) >> 1) | (k & 7) | (comp_c_frag_m * FRAG_M + i) * WMMA_M;
181
+ #else // CDNA3, WAVE_SIZE = 64
182
+ // int m = ((threadIdx.x & 48) >> 2) | (k & 3) | (comp_c_frag_m * FRAG_M + i) * WMMA_M;
183
+ #endif
184
+ // int n = ((threadIdx.x & 15) | (comp_c_frag_n * FRAG_N + j) * WMMA_N) / QUANT_SIZE;
185
+ auto lane = threadIdx.x % 64;
186
+ int m, n;
187
+ if constexpr (WMMA_M == 32) {
188
+ // C or D i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
189
+ // C or D j: (lane % 32)
190
+ m = (8 * (k / 4) % 32) + 4 * (lane / 32) + (k % 4);
191
+ n = lane % 32;
192
+ } else {
193
+ // C or D i: 4 * floor(lane / 16) + (GPR_num % 4)
194
+ // C or D j: (lane % 16)
195
+ m = 4 * (lane / 16) + (k % 4);
196
+ n = lane % 16;
197
+ }
198
+ m += (comp_c_frag_m * FRAG_M + i) * WMMA_M;
199
+ n += (comp_c_frag_n * FRAG_N + j) * WMMA_N;
200
+ n = n / QUANT_SIZE;
201
+ // if(threadIdx.x == 192 && blockIdx.x ==0 && blockIdx.y == 0 && blockIdx.z == 0)
202
+ // printf("m: %d, n: %d\n", m, n);
203
+ float scale = s_s[n][m];
204
+ frag_r[i][j].x[k] += (acc_data_type)scale * (acc_data_type)frag_c[i][j].x[k];
205
+ }
206
+ }
207
+ }
208
+ }
209
+
210
+ __device__ rocwmma::bfloat16_t fast_f32tob16(float f) {
211
+ union {
212
+ float fp32;
213
+ unsigned int u32;
214
+ } u = {f};
215
+ u.u32 += 0x7fff + ((u.u32 >> 16) & 1);
216
+ auto ret = u.u32 >> 16;
217
+ return reinterpret_cast<rocwmma::bfloat16_t &>(ret);
218
+ }
219
+
220
+ template <typename acc_data_type, typename out_data_type, typename FragC, typename FragOut, int WMMA_M, int WMMA_N,
221
+ int BM, int BN, int M, int N, int FRAG_M, int FRAG_N>
222
+ __device__ inline void store_result(out_data_type c[M][N], FragC frag_r[FRAG_M][FRAG_N], const int m, const int n,
223
+ const int comp_c_frag_m, const int comp_c_frag_n) {
224
+ #pragma unroll
225
+ for (int i = 0; i < FRAG_M; i++) {
226
+ #pragma unroll
227
+ for (int j = 0; j < FRAG_N; j++) {
228
+ int frag_m = comp_c_frag_m * FRAG_M + i;
229
+ int frag_n = comp_c_frag_n * FRAG_N + j;
230
+ int row = m + frag_m * WMMA_M;
231
+ int col = n + frag_n * WMMA_N;
232
+ if (((M % BM == 0) || (row < M)) && ((N % BN == 0) || (col < N))) {
233
+ out_data_type *c_ptr = &c[row][col];
234
+ if constexpr (sizeof(acc_data_type) == sizeof(out_data_type)) { // split_k
235
+ auto lane = threadIdx.x % 64;
236
+ #pragma unroll
237
+ for (int k = 0; k < FragC::num_elements; ++k) {
238
+ int m, n;
239
+ if constexpr (WMMA_M == 32) {
240
+ // C or D i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
241
+ // C or D j: (lane % 32)
242
+ m = (8 * (k / 4) % 32) + 4 * (lane / 32) + (k % 4);
243
+ n = lane % 32;
244
+ } else {
245
+ // C or D i: 4 * floor(lane / 16) + (GPR_num % 4)
246
+ // C or D j: (lane % 16)
247
+ m = 4 * (lane / 16) + (k % 4);
248
+ n = lane % 16;
249
+ }
250
+ c_ptr[m * N + n] = frag_r[i][j].x[k];
251
+ }
252
+
253
+ // wmma::store_matrix_sync(reinterpret_cast<out_data_type *>(c_ptr), frag_r[i][j], N,
254
+ // wmma::mem_row_major);
255
+ } else if constexpr (sizeof(out_data_type) == sizeof(half)) {
256
+ FragOut frag_out;
257
+ static_assert(sizeof(half) == sizeof(out_data_type));
258
+ static_assert(FragOut::num_elements == FragC::num_elements);
259
+ for (int k = 0; k < FragOut::num_elements; ++k) {
260
+ auto reg = fast_f32tob16(frag_r[i][j].x[k]);
261
+ frag_out.x[k] = *reinterpret_cast<half *>(&reg);
262
+ }
263
+ wmma::store_matrix_sync(reinterpret_cast<half *>(c_ptr), frag_out, N, wmma::mem_row_major);
264
+ } else {
265
+ static_assert(0, "Unsupported data type for output");
266
+ }
267
+ }
268
+ }
269
+ }
270
+ }
271
+
272
+ // a dummy template to allow including this file
273
+ template <int Splitk> __global__ void reduce(uint32_t m, uint32_t n, const float *c_splitk, __hip_bfloat16 *c) {
274
+ auto tid = blockIdx.x * blockDim.x + threadIdx.x;
275
+ if (tid >= m * n) {
276
+ return;
277
+ }
278
+ float4 sum{};
279
+ #pragma unroll
280
+ for (auto i = 0; i < Splitk; ++i) {
281
+ sum += *(float4 *)&c_splitk[i * (m * n) + tid * 4];
282
+ }
283
+ auto res =
284
+ rocwmma::make_vector(fast_f32tob16(sum.x), fast_f32tob16(sum.y), fast_f32tob16(sum.z), fast_f32tob16(sum.w));
285
+ *(decltype(res) *)&c[tid * 4] = res;
286
+ }
287
+
288
+ template<int M, int N, int SPLITK_FACTOR, int BLOCK_SIZE>
289
+ __launch_bounds__(BLOCK_SIZE)
290
+ __global__ void reduce_kernel(const float c_splitk[SPLITK_FACTOR][M][N], __hip_bfloat16 c[M][N]) {
291
+ auto tid = blockIdx.x * blockDim.x + threadIdx.x;
292
+ if (tid >= M * N) {
293
+ return;
294
+ }
295
+ float4 sum{};
296
+ #pragma unroll
297
+ for (auto i = 0; i < SPLITK_FACTOR; ++i) {
298
+ sum += *(float4 *)&reinterpret_cast<const float*>(c_splitk)[i * (M * N) + tid * 4];
299
+ }
300
+ auto res =
301
+ rocwmma::make_vector(fast_f32tob16(sum.x), fast_f32tob16(sum.y), fast_f32tob16(sum.z), fast_f32tob16(sum.w));
302
+ *(decltype(res) *)&reinterpret_cast< __BF16_TYPE*>(c)[tid * 4] = res;
303
+ }
304
+
305
+
306
+ #ifdef PARAMETERIZE_LIBRARY
307
+ template <typename in_data_type,
308
+ typename acc_data_type, // Accumulator type (e.g., float)
309
+ typename out_data_type, // Output type (e.g., __hip_bfloat16)
310
+ int M, int N, int K, // Matrix dimensions
311
+ int BM, int BN, int BK, // Tile dimensions
312
+ int QUANT_SIZE, // Quantization block size
313
+ int BLOCK_SIZE, // Block size
314
+ int WARP_M, int WARP_N, // Warp dimensions
315
+ int LDA, int LDB,
316
+ int LOAD_BATCH_SIZE> // Load batch size for vectorized memory operations
317
+ #else
318
+ using in_data_type = __FP8_TYPE;
319
+ using out_data_type = __BF16_TYPE;
320
+ using acc_data_type = float;
321
+ // constexpr int M = 4096, N = 4096, K = 4096;
322
+ constexpr int M = 6144, N = 4608, K = 7168;
323
+ constexpr int LDA = K, LDB = K;
324
+ // constexpr int M = 512, N = 512, K = 512;
325
+ constexpr int BM = 256, BN = 128, BK = 128;
326
+ constexpr int QUANT_SIZE = 128, BLOCK_SIZE = 512;
327
+ constexpr int LOAD_BATCH_SIZE = 16;
328
+ #ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
329
+ constexpr int WARP_M = 4, WARP_N = 2;
330
+ #else // CDNA3, WAVE_SIZE = 64
331
+ constexpr int WARP_M = 4, WARP_N = 2;
332
+ #endif
333
+ #endif // End of parameterization
334
+ __global__ __launch_bounds__(BLOCK_SIZE) void gemm_kernel(
335
+ const in_data_type a[M][LDA], const in_data_type b[N][LDB], out_data_type c[M][N],
336
+ const float sa[ceil_div(K, QUANT_SIZE)][M / 1], // Assuming M is divisible by 1 (always true)
337
+ const float sb[ceil_div(K, QUANT_SIZE)][ceil_div(N, QUANT_SIZE)]) {
338
+ // --- Start: Derived parameters and constants ---
339
+ constexpr int WMMA_M = 16; // Fixed WMMA dimension M
340
+ constexpr int WMMA_N = 16; // Fixed WMMA dimension N
341
+ constexpr int WMMA_K = 32; // Fixed WMMA dimension K (for FP8)
342
+
343
+ // WARP_M/N define the 2D arrangement of warps in the block grid.
344
+ // These might need adjustment based on BLOCK_DIM_X/Y strategy.
345
+ // Using fixed values based on the non-parameterized version for now.
346
+ // TODO: Derive WARP_M/N from BLOCK_DIM_X/Y if a flexible strategy is needed.
347
+ constexpr int WARP_NUM = WARP_M * WARP_N; // Total warps per block
348
+
349
+ // Assertion: Check if the assumed warp layout matches the block size
350
+ static_assert(WARP_NUM * WAVE_SIZE == BLOCK_SIZE, "WARP_M * WARP_N * WAVE_SIZE must equal BLOCK_SIZE");
351
+
352
+ // Fragments per warp
353
+ constexpr int FRAG_M_PER_WARP = BM / WMMA_M / WARP_M;
354
+ constexpr int FRAG_N_PER_WARP = BN / WMMA_N / WARP_N;
355
+ constexpr int FRAG_K = BK / WMMA_K; // Fragments along K dimension tile
356
+
357
+ static_assert(BM % (WMMA_M * WARP_M) == 0, "BM must be divisible by WMMA_M * WARP_M");
358
+ static_assert(BN % (WMMA_N * WARP_N) == 0, "BN must be divisible by WMMA_N * WARP_N");
359
+ static_assert(BK % WMMA_K == 0, "BK must be divisible by WMMA_K");
360
+ static_assert(BK >= 32, "BK must be at least 32");
361
+ // --- End: Derived parameters and constants ---
362
+
363
+ constexpr int QM = M; // Dimension M for scale A
364
+ constexpr int QN = ceil_div(N, QUANT_SIZE); // Dimension N for scale B (quantized)
365
+ constexpr int QK = ceil_div(K, QUANT_SIZE); // Dimension K for scales (quantized)
366
+ constexpr int PM = BM; // Block size M for scale A * B
367
+ constexpr int PN = ceil_div(BN, QUANT_SIZE); // Block size N for scale A * B
368
+
369
+ // Ensure derived fragment counts are positive
370
+ static_assert(FRAG_M_PER_WARP > 0, "FRAG_M_PER_WARP must be positive");
371
+ static_assert(FRAG_N_PER_WARP > 0, "FRAG_N_PER_WARP must be positive");
372
+ static_assert(FRAG_K > 0, "FRAG_K must be positive");
373
+
374
+ using FragA = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, in_data_type, wmma::row_major>;
375
+ using FragB = wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, in_data_type, wmma::col_major>;
376
+ using FragC = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, acc_data_type>;
377
+ using FragOut = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K,
378
+ half>; // Output uses half for storage via bfloat16 reinterpret
379
+
380
+ __shared__ in_data_type s_a[BM][BK + 8];
381
+ __shared__ in_data_type s_b[BN][BK + 8];
382
+ __shared__ acc_data_type s_s[PN][PM]; // Accumulator type for scales
383
+ FragC frag_r[FRAG_M_PER_WARP][FRAG_N_PER_WARP]; // Accumulator fragments
384
+
385
+ // handle splitk
386
+ a = (decltype(a))((in_data_type *)a + blockIdx.z * K);
387
+ b = (decltype(b))((in_data_type *)b + blockIdx.z * K);
388
+ c += blockIdx.z * M;
389
+ sa += blockIdx.z * QK;
390
+ sb += blockIdx.z * QK;
391
+
392
+ int tid = threadIdx.x; // Linear thread ID within the block
393
+ int wid = tid / WAVE_SIZE; // Warp ID within the block
394
+
395
+ // Split and compute fragments
396
+ constexpr int iteration_over_k = ceil_div(K, BK); // Use ceil_div for potentially non-divisible K
397
+ static_assert(LOAD_BATCH_SIZE > 0, "LOAD_BATCH_SIZE must be positive");
398
+
399
+ constexpr auto PIPELINE = true;
400
+ // using LoadVec = rocwmma::VecT<float, LOAD_BATCH_SIZE / sizeof(float)>;
401
+ using LoadVec = __attribute__((__vector_size__(LOAD_BATCH_SIZE))) float;
402
+ static_assert(((BK * BM) % (BLOCK_SIZE * LOAD_BATCH_SIZE)) == 0,
403
+ "BK * BM must be divisible by BLOCK_SIZE * LOAD_BATCH_SIZE");
404
+ static_assert(BK % LOAD_BATCH_SIZE == 0, "BK must be divisible by LOAD_BATCH_SIZE");
405
+ LoadVec reg_a[BK * BM / BLOCK_SIZE / LOAD_BATCH_SIZE];
406
+ LoadVec reg_b[BK * BN / BLOCK_SIZE / LOAD_BATCH_SIZE];
407
+ constexpr auto PK = ceil_div(BK, QUANT_SIZE);
408
+ static_assert(PK == 1, "PK must be 1 for now");
409
+ float reg_sa[ceil_div(PM, BLOCK_SIZE)];
410
+ float reg_sb[ceil_div(PN, BLOCK_SIZE)];
411
+
412
+ // threadblock swizzle
413
+ auto log_tile = 1;
414
+ auto block_idx_x = blockIdx.x >> log_tile;
415
+ auto block_idx_y = (blockIdx.y << log_tile) + ((blockIdx.x) & ((1 << (log_tile)) - 1));
416
+ if (block_idx_x >= ceil_div(N, BN) || block_idx_y >= ceil_div(M, BM)) {
417
+ return;
418
+ }
419
+
420
+ const int m = block_idx_y * BM;
421
+ const int n = block_idx_x * BN;
422
+ int k = 0;
423
+
424
+ auto global2reg = [&]() {
425
+ #pragma unroll
426
+ for (int reg = 0; reg < sizeof(reg_sa) / sizeof(float); reg++) {
427
+ // NOTE: must iter over reg to make compiler unroll the loop
428
+ // and thus be able to allocate the array in registers instead of scratch memory
429
+ int t = tid + reg * BLOCK_SIZE;
430
+ // NOTE: don't branch here
431
+ // if (t > PM) {
432
+ // break;
433
+ // }
434
+ int i = t / PM;
435
+ int j = t % PM;
436
+ reg_sa[reg] = sa[k / QUANT_SIZE][(m + j)];
437
+ }
438
+ #pragma unroll
439
+ for (int reg = 0; reg < sizeof(reg_sb) / sizeof(float); reg++) {
440
+ // NOTE: must iter over reg to make compiler unroll the loop
441
+ // and thus be able to allocate the array in registers instead of scratch memory
442
+ int t = tid + reg * BLOCK_SIZE;
443
+ // NOTE: don't branch here
444
+ // if (t > PN) {
445
+ // break;
446
+ // }
447
+ int i = t / PN;
448
+ int j = t % PN;
449
+ reg_sb[reg] = sb[k / QUANT_SIZE][(n) / QUANT_SIZE + j];
450
+ }
451
+ #pragma unroll
452
+ for (int reg = 0; reg < sizeof(reg_a) / sizeof(LoadVec); reg++) {
453
+ // NOTE: must iter over reg to make compiler unroll the loop
454
+ // and thus be able to allocate the array in registers instead of scratch memory
455
+ int t = tid * LOAD_BATCH_SIZE + reg * BLOCK_SIZE * LOAD_BATCH_SIZE;
456
+ int i = t / BK;
457
+ int j = t % BK;
458
+ reg_a[reg] = *(LoadVec *)&a[m + i][k + j];
459
+ }
460
+ #pragma unroll
461
+ for (int reg = 0; reg < sizeof(reg_b) / sizeof(LoadVec); reg++) {
462
+ // NOTE: must iter over reg to make compiler unroll the loop
463
+ // and thus be able to allocate the array in registers instead of scratch memory
464
+ int t = tid * LOAD_BATCH_SIZE + reg * BLOCK_SIZE * LOAD_BATCH_SIZE;
465
+ int i = t / BK;
466
+ int j = t % BK;
467
+ reg_b[reg] = *(LoadVec *)&b[n + i][k + j];
468
+ }
469
+ };
470
+
471
+ auto reg2lds = [&]() {
472
+ #pragma unroll
473
+ for (int rega = 0; rega < sizeof(reg_sa) / sizeof(float); rega++) {
474
+ int ta = tid + rega * BLOCK_SIZE;
475
+ int j = ta % PM;
476
+ #pragma unroll
477
+ for (int regb = 0; regb < sizeof(reg_sb) / sizeof(float); regb++) {
478
+ int tb = tid + regb * BLOCK_SIZE;
479
+ int i = tb % PN;
480
+ s_s[i][j] = reg_sa[rega] * reg_sb[regb];
481
+ }
482
+ }
483
+ #pragma unroll
484
+ for (int reg = 0; reg < sizeof(reg_a) / sizeof(LoadVec); reg++) {
485
+ int t = tid * LOAD_BATCH_SIZE + reg * BLOCK_SIZE * LOAD_BATCH_SIZE;
486
+ int i = t / BK;
487
+ int j = t % BK;
488
+ *(LoadVec *)&s_a[i][j] = reg_a[reg];
489
+ }
490
+ #pragma unroll
491
+ for (int reg = 0; reg < sizeof(reg_b) / sizeof(LoadVec); reg++) {
492
+ int t = tid * LOAD_BATCH_SIZE + reg * BLOCK_SIZE * LOAD_BATCH_SIZE;
493
+ int i = t / BK;
494
+ int j = t % BK;
495
+ *(LoadVec *)&s_b[i][j] = reg_b[reg];
496
+ }
497
+ };
498
+
499
+ if constexpr (PIPELINE) {
500
+ global2reg();
501
+ }
502
+
503
+ // Initialize the output accumulator fragments to zero
504
+ #pragma unroll
505
+ for (int i = 0; i < FRAG_M_PER_WARP; i++) {
506
+ #pragma unroll
507
+ for (int j = 0; j < FRAG_N_PER_WARP; j++) {
508
+ wmma::fill_fragment(frag_r[i][j], 0.0f); // Use float literal
509
+ }
510
+ }
511
+
512
+ if constexpr (!PIPELINE) {
513
+ global2reg();
514
+ }
515
+
516
+ reg2lds();
517
+
518
+ for (int bk = 1; bk < iteration_over_k; bk++) {
519
+ k = bk * BK;
520
+
521
+ // Calculate remaining K for boundary checks if needed (not currently used by load_input)
522
+ // const int k_rem = K - k;
523
+
524
+ // Load data into shared memory
525
+ // load_input<in_data_type, BK, BM, K, M, BLOCK_SIZE, 32>(
526
+ // s_a, a, m, k);
527
+ // load_input<in_data_type, BK, BN, K, N, BLOCK_SIZE, 32>(
528
+ // s_b, b, n, k);
529
+ // Load scales into shared memory (using acc_data_type for s_s)
530
+ // load_scale<PM, PN, QM, QN, QK, QUANT_SIZE, BLOCK_SIZE, 1>(
531
+ // s_s, sa, sb, m, n, k);
532
+
533
+ if constexpr (PIPELINE) {
534
+ global2reg();
535
+ }
536
+
537
+ __syncthreads();
538
+
539
+ // Perform matrix multiplication using WMMA
540
+ wmma_compute<in_data_type, acc_data_type, FragC, FragA, FragB, PM, PN, BM, BN, BK, FRAG_M_PER_WARP,
541
+ FRAG_N_PER_WARP, FRAG_K, WMMA_M, WMMA_N, WMMA_K, WARP_M, WARP_N, BLOCK_SIZE, LOAD_BATCH_SIZE,
542
+ QUANT_SIZE>( // Pass calculated BLOCK_SIZE and LOAD_BATCH_SIZE
543
+ s_a, s_b, s_s, frag_r, wid / WARP_N, wid % WARP_N);
544
+ __syncthreads();
545
+
546
+ if constexpr (!PIPELINE) {
547
+ global2reg();
548
+ }
549
+
550
+ // __builtin_amdgcn_sched_barrier(0);
551
+
552
+ reg2lds();
553
+ }
554
+ __syncthreads();
555
+ wmma_compute<in_data_type, acc_data_type, FragC, FragA, FragB, PM, PN, BM, BN, BK, FRAG_M_PER_WARP, FRAG_N_PER_WARP,
556
+ FRAG_K, WMMA_M, WMMA_N, WMMA_K, WARP_M, WARP_N, BLOCK_SIZE, LOAD_BATCH_SIZE,
557
+ QUANT_SIZE>( // Pass calculated BLOCK_SIZE and LOAD_BATCH_SIZE
558
+ s_a, s_b, s_s, frag_r, wid / WARP_N, wid % WARP_N);
559
+ // Store results from accumulator fragments to global memory
560
+ store_result<acc_data_type, out_data_type, FragC, FragOut, WMMA_M, WMMA_N, BM, BN, M, N, FRAG_M_PER_WARP,
561
+ FRAG_N_PER_WARP>(c, frag_r, block_idx_y * BM, block_idx_x * BN, wid / WARP_N, wid % WARP_N);
562
+ };
563
+
564
+ }; // namespace gemm_kernel
565
+
566
+ HOST_CODE_BELOW
567
+
568
+ #ifndef PARAMETERIZE_LIBRARY
569
+ // Define type aliases to match those in the namespace
570
+ using fp8_type = gemm_kernel::in_data_type; // __hip_fp8_e4m3
571
+ using fp16_type = gemm_kernel::out_data_type; // __hip_bfloat16
572
+ using acc_data_type = gemm_kernel::acc_data_type; // float
573
+
574
+ // Define constants to match those in the namespace
575
+ constexpr int M = gemm_kernel::M; // 4096
576
+ constexpr int N = gemm_kernel::N; // 4096
577
+ constexpr int K = gemm_kernel::K; // 4096
578
+ constexpr int BM = gemm_kernel::BM; // 256
579
+ constexpr int BN = gemm_kernel::BN; // 128
580
+ constexpr int BK = gemm_kernel::BK; // 32
581
+ constexpr int BLOCK_SIZE = gemm_kernel::BLOCK_SIZE;
582
+ constexpr int QUANT_SIZE = gemm_kernel::QUANT_SIZE; // 128
583
+
584
+ // Define derived constants for the test
585
+ constexpr int QK = K / QUANT_SIZE;
586
+ constexpr int QM = M;
587
+ constexpr int QN = N / QUANT_SIZE;
588
+
589
+ // Helper function to check HIP errors
590
+ #define CHECK_HIP_ERROR(val) check((val), #val, __FILE__, __LINE__)
591
+ template <typename T> void check(T err, const char *const func, const char *const file, const int line) {
592
+ if (err != hipSuccess) {
593
+ fprintf(stderr, "HIP Runtime Error at: %s:%d\n", file, line);
594
+ fprintf(stderr, "%s %s\n", hipGetErrorString(err), func);
595
+ exit(1);
596
+ }
597
+ }
598
+
599
+ // Define a macro to check HIP errors
600
+ #define HIP_CALL(call) \
601
+ do { \
602
+ hipError_t err = call; \
603
+ if (err != hipSuccess) { \
604
+ fprintf(stderr, "HIP Error: %s at %s:%d\n", hipGetErrorString(err), __FILE__, __LINE__); \
605
+ exit(EXIT_FAILURE); \
606
+ } \
607
+ } while (0)
608
+
609
+ // CPU matrix multiplication implementation for result verification
610
+ void cpu_gemm(const fp8_type a[K][M], const fp8_type b[K][N], fp16_type c[M][N], const float sa[QK][QM],
611
+ const float sb[QK][QN]) {
612
+ float(*rc)[N] = new float[M][N];
613
+ for (int m = 0; m < M; ++m) {
614
+ for (int n = 0; n < N; ++n) {
615
+ rc[m][n] = 0.0f;
616
+ }
617
+ }
618
+ for (int k = 0; k < K; ++k) {
619
+ for (int m = 0; m < M; ++m) {
620
+ for (int n = 0; n < N; ++n) {
621
+ float scale = sa[k / QUANT_SIZE][m] * sb[k / QUANT_SIZE][n / QUANT_SIZE];
622
+ rc[m][n] += (scale * (float)a[k][m] * (float)b[k][n]);
623
+ }
624
+ }
625
+ }
626
+ for (int m = 0; m < M; ++m) {
627
+ for (int n = 0; n < N; ++n) {
628
+ c[m][n] = (fp16_type)rc[m][n];
629
+ }
630
+ }
631
+ delete[] rc;
632
+ }
633
+
634
+ int main() {
635
+ // Allocate host memory
636
+ fp8_type(*h_a)[M] = new fp8_type[K][M];
637
+ fp8_type(*h_b)[N] = new fp8_type[K][N];
638
+ fp16_type(*h_c)[N] = new fp16_type[M][N];
639
+ fp16_type(*h_c_ref)[N] = new fp16_type[M][N];
640
+
641
+ // Allocate host memory for quantization scale factors
642
+ float(*h_sa)[QM] = new float[QK][QM];
643
+ float(*h_sb)[QN] = new float[QK][QN];
644
+
645
+ // Initialize input data
646
+ for (int i = 0; i < K; ++i) {
647
+ for (int j = 0; j < M; ++j) {
648
+ h_a[i][j] = (fp8_type)((rand() % 10000) / 10000.0f);
649
+ }
650
+ }
651
+ for (int i = 0; i < K; ++i) {
652
+ for (int j = 0; j < N; ++j) {
653
+ h_b[i][j] = (fp8_type)((rand() % 10000) / 10000.0f);
654
+ }
655
+ }
656
+
657
+ // Initialize quantization scale factors
658
+ for (int i = 0; i < QK; ++i) {
659
+ for (int j = 0; j < QM; ++j) {
660
+ h_sa[i][j] = 1.0f;
661
+ }
662
+ }
663
+ for (int i = 0; i < QK; ++i) {
664
+ for (int j = 0; j < QN; ++j) {
665
+ h_sb[i][j] = 1.0f;
666
+ }
667
+ }
668
+
669
+ // Allocate device memory
670
+ fp8_type(*d_a)[K];
671
+ fp8_type(*d_b)[K];
672
+ fp16_type(*d_c)[N];
673
+ float(*d_sa)[QM];
674
+ float(*d_sb)[QN];
675
+
676
+ CHECK_HIP_ERROR(hipMalloc(&d_a, K * M * sizeof(fp8_type)));
677
+ CHECK_HIP_ERROR(hipMalloc(&d_b, K * N * sizeof(fp8_type)));
678
+ CHECK_HIP_ERROR(hipMalloc(&d_c, M * N * sizeof(fp16_type)));
679
+ CHECK_HIP_ERROR(hipMalloc(&d_sa, QK * QM * sizeof(float)));
680
+ CHECK_HIP_ERROR(hipMalloc(&d_sb, QK * QN * sizeof(float)));
681
+
682
+ // Copy data from host memory to device memory
683
+ CHECK_HIP_ERROR(hipMemcpy(d_a, h_a, K * M * sizeof(fp8_type), hipMemcpyHostToDevice));
684
+ CHECK_HIP_ERROR(hipMemcpy(d_b, h_b, K * N * sizeof(fp8_type), hipMemcpyHostToDevice));
685
+ CHECK_HIP_ERROR(hipMemcpy(d_sa, h_sa, QK * QM * sizeof(float), hipMemcpyHostToDevice));
686
+ CHECK_HIP_ERROR(hipMemcpy(d_sb, h_sb, QK * QN * sizeof(float), hipMemcpyHostToDevice));
687
+
688
+ // Calculate grid and block sizes - ensure coverage of the entire matrix
689
+ dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
690
+ dim3 block(BLOCK_SIZE);
691
+
692
+ // Ensure block size is a multiple of 32, since warp size is 32
693
+ if (BLOCK_SIZE % 32 != 0) {
694
+ printf("Error: Block size must be a multiple of warp size (32)\n");
695
+ return 1;
696
+ }
697
+
698
+ // Check if device supports required compute capability
699
+ int deviceId;
700
+ HIP_CALL(hipGetDevice(&deviceId));
701
+ hipDeviceProp_t deviceProp;
702
+ HIP_CALL(hipGetDeviceProperties(&deviceProp, deviceId));
703
+
704
+ if (deviceProp.major < 7) {
705
+ printf("Error: This kernel requires a GPU with compute capability 7.0 or higher\n");
706
+ return 1;
707
+ }
708
+
709
+ printf("Running GEMM kernel with grid(%d,%d), block(%d)...\n", grid.x, grid.y, block.x);
710
+
711
+ // Query and print kernel and device information
712
+ printf("Querying kernel and device information...\n");
713
+
714
+ // Get device properties
715
+ HIP_CALL(hipGetDeviceProperties(&deviceProp, deviceId));
716
+ printf("Device Name: %s\n", deviceProp.name);
717
+ printf("Total Global Memory: %lu bytes\n", deviceProp.totalGlobalMem);
718
+ printf("Shared Memory per Block: %lu bytes\n", deviceProp.sharedMemPerBlock);
719
+ printf("Registers per Block: %d\n", deviceProp.regsPerBlock);
720
+ printf("Warp Size: %d\n", deviceProp.warpSize);
721
+ printf("Max Threads per Block: %d\n", deviceProp.maxThreadsPerBlock);
722
+ printf("Max Threads per Multiprocessor: %d\n", deviceProp.maxThreadsPerMultiProcessor);
723
+ printf("Number of Multiprocessors: %d\n", deviceProp.multiProcessorCount);
724
+
725
+ // Query kernel attributes
726
+ hipFuncAttributes funcAttr;
727
+ HIP_CALL(hipFuncGetAttributes(&funcAttr, (const void *)gemm_kernel::gemm_kernel));
728
+ printf("Kernel Attributes:\n");
729
+ printf(" Shared Memory Size: %lu bytes\n", funcAttr.sharedSizeBytes);
730
+ printf(" Number of Registers: %d\n", funcAttr.numRegs);
731
+ printf(" Max Threads per Block: %d\n", funcAttr.maxThreadsPerBlock);
732
+ printf(" Local Memory Size: %lu bytes\n", funcAttr.localSizeBytes);
733
+
734
+ // Zero the C matrix before launching kernel
735
+ CHECK_HIP_ERROR(hipMemset(d_c, 0, M * N * sizeof(fp16_type)));
736
+
737
+ // Perform warmup runs
738
+ printf("Performing warmup runs...\n");
739
+ gemm_kernel::gemm_kernel<<<grid, block>>>(d_a, d_b, d_c, d_sa, d_sb);
740
+ CHECK_HIP_ERROR(hipDeviceSynchronize());
741
+ gemm_kernel::gemm_kernel<<<grid, block>>>(d_a, d_b, d_c, d_sa, d_sb);
742
+ CHECK_HIP_ERROR(hipDeviceSynchronize());
743
+
744
+ // Declare and create timing events
745
+ hipEvent_t start, stop;
746
+ HIP_CALL(hipEventCreate(&start));
747
+ HIP_CALL(hipEventCreate(&stop));
748
+
749
+ // Ensure device synchronization before formal timing
750
+ CHECK_HIP_ERROR(hipDeviceSynchronize());
751
+ HIP_CALL(hipEventRecord(start));
752
+
753
+ // Launch kernel
754
+ printf("Launching kernel...\n");
755
+ gemm_kernel::gemm_kernel<<<grid, block>>>(d_a, d_b, d_c, d_sa, d_sb);
756
+
757
+ // Record end time and calculate execution time
758
+ HIP_CALL(hipEventRecord(stop));
759
+
760
+ // Record end time and calculate execution time
761
+ HIP_CALL(hipEventSynchronize(stop));
762
+ float milliseconds = 0;
763
+ HIP_CALL(hipEventElapsedTime(&milliseconds, start, stop));
764
+ printf("Kernel execution time: %f ms\n", milliseconds);
765
+
766
+ // Check HIP errors
767
+ CHECK_HIP_ERROR(hipGetLastError());
768
+
769
+ // Calculate GPU performance metrics
770
+ double operations = 2.0 * M * N * K; // Each multiply-add operation counts as 2 floating-point operations
771
+ double seconds = milliseconds / 1000.0;
772
+ double tflops = (operations / seconds) / 1e12;
773
+ printf("GPU Performance: %.2f TFLOPS\n", tflops);
774
+
775
+ return 0;
776
+
777
+ // Copy results from device memory back to host memory
778
+ CHECK_HIP_ERROR(hipMemcpy(h_c, d_c, M * N * sizeof(fp16_type), hipMemcpyDeviceToHost));
779
+
780
+ // Calculate reference results
781
+ printf("Computing reference result on CPU...\n");
782
+ cpu_gemm(h_a, h_b, h_c_ref, h_sa, h_sb);
783
+
784
+ // Print the first 10 values for comparison
785
+ printf("First 10 values (GPU vs CPU):\n");
786
+ int print_count = 0;
787
+ for (int i = 0; i < M && print_count < 10; ++i) {
788
+ for (int j = 0; j < N && print_count < 10; ++j) {
789
+ printf(" [%d, %d]: GPU=%f, CPU=%f\n", i, j, (float)h_c[i][j], (float)h_c_ref[i][j]);
790
+ print_count++;
791
+ }
792
+ }
793
+
794
+ // Verify results
795
+ printf("Verifying results...\n");
796
+ int errors = 0;
797
+ float max_abs_diff = 0.0f;
798
+ float max_rel_diff = 0.0f;
799
+ struct ErrorInfo {
800
+ int row, col;
801
+ float gpu_val, cpu_val, abs_diff, rel_diff;
802
+ };
803
+ ErrorInfo first_10_errors[10];
804
+ ErrorInfo max_10_errors[10] = {};
805
+
806
+ // Add a configurable variable for the number of errors to output
807
+ int max_errors_to_output = 10; // You can modify this value as needed
808
+
809
+ for (int i = 0; i < M; ++i) {
810
+ for (int j = 0; j < N; ++j) {
811
+ float gpu_val = (float)h_c[i][j];
812
+ float cpu_val = (float)h_c_ref[i][j];
813
+ float abs_diff;
814
+ float rel_diff;
815
+
816
+ if (std::isnan(gpu_val) || std::isnan(cpu_val)) {
817
+ abs_diff = INFINITY;
818
+ rel_diff = INFINITY;
819
+ } else {
820
+ abs_diff = abs(gpu_val - cpu_val);
821
+ rel_diff = abs_diff / (abs(cpu_val) + FLT_EPSILON);
822
+ }
823
+
824
+ // Track max absolute and relative differences
825
+ max_abs_diff = fmaxf(max_abs_diff, abs_diff);
826
+ max_rel_diff = fmaxf(max_rel_diff, rel_diff);
827
+
828
+ // Record first 10 errors
829
+ if (errors < max_errors_to_output && (rel_diff > 1e-2 || abs_diff > 1e-3)) {
830
+ first_10_errors[errors] = {i, j, gpu_val, cpu_val, abs_diff, rel_diff};
831
+ }
832
+
833
+ // Track top 10 largest errors
834
+ if (rel_diff > 1e-2 || abs_diff > 1e-3) {
835
+ errors++;
836
+ for (int k = 0; k < max_errors_to_output; ++k) {
837
+ if (abs_diff > max_10_errors[k].abs_diff) {
838
+ for (int l = max_errors_to_output - 1; l > k; --l) {
839
+ max_10_errors[l] = max_10_errors[l - 1];
840
+ }
841
+ max_10_errors[k] = {i, j, gpu_val, cpu_val, abs_diff, rel_diff};
842
+ break;
843
+ }
844
+ }
845
+ }
846
+ }
847
+ }
848
+
849
+ // Print first 10 errors
850
+ printf("First %d errors:\n", max_errors_to_output);
851
+ for (int i = 0; i < fmin(errors, max_errors_to_output); ++i) {
852
+ printf("Error at [%d, %d]: GPU=%f, CPU=%f, AbsDiff=%f, RelDiff=%f\n", first_10_errors[i].row,
853
+ first_10_errors[i].col, first_10_errors[i].gpu_val, first_10_errors[i].cpu_val,
854
+ first_10_errors[i].abs_diff, first_10_errors[i].rel_diff);
855
+ }
856
+
857
+ // Print top 10 largest errors
858
+ printf("Top %d largest errors:\n", max_errors_to_output);
859
+ for (int i = 0; i < max_errors_to_output && max_10_errors[i].abs_diff > 0; ++i) {
860
+ printf("Error at [%d, %d]: GPU=%f, CPU=%f, AbsDiff=%f, RelDiff=%f\n", max_10_errors[i].row,
861
+ max_10_errors[i].col, max_10_errors[i].gpu_val, max_10_errors[i].cpu_val, max_10_errors[i].abs_diff,
862
+ max_10_errors[i].rel_diff);
863
+ }
864
+
865
+ printf("Max abs_diff: %f, Max rel_diff: %f\n", max_abs_diff, max_rel_diff);
866
+ if (errors == 0) {
867
+ printf("Test PASSED!\n");
868
+ } else {
869
+ printf("Test FAILED with %d errors\n", errors);
870
+ }
871
+
872
+ // Calculate performance
873
+ double flops = 2.0 * M * N * K;
874
+ double gflops = (flops * 1e-9) / (milliseconds * 1e-3);
875
+ printf("Performance: %.2f GFLOPS\n", gflops);
876
+
877
+ // Free memory
878
+ delete[] h_a;
879
+ delete[] h_b;
880
+ delete[] h_c;
881
+ delete[] h_c_ref;
882
+ delete[] h_sa;
883
+ delete[] h_sb;
884
+ HIP_CALL(hipFree(d_a));
885
+ HIP_CALL(hipFree(d_b));
886
+ HIP_CALL(hipFree(d_c));
887
+ HIP_CALL(hipFree(d_sa));
888
+ HIP_CALL(hipFree(d_sb));
889
+ HIP_CALL(hipEventDestroy(start));
890
+ HIP_CALL(hipEventDestroy(stop));
891
+
892
+ return 0;
893
+ }
894
+ #endif
895
+ #pragma clang diagnostic pop
896
+ #endif
gemm/gemm_kernel_legacy.h ADDED
@@ -0,0 +1,377 @@
1
+ // Legacy version of the GEMM kernel; supports all shapes and various parameter values (BM, BN, BK, etc.).
2
+ // It has been replaced by the faster pipelined version.
3
+ #pragma once
4
+ #include <cstdio>
5
+ #include "../include/gpu_libs.h"
6
+ #include "../include/gpu_types.h"
7
+ #include "../src/utils/arithmetic.h"
8
+ #include "../include/clangd_workaround.h"
9
+ #include <cstdlib>
10
+ #include <cfloat>
11
+
12
+ DEVICE_CODE_BELOW
13
+ namespace gemm_kernel_legacy {
14
+
15
+
16
+
17
+ template <typename data_type, int BATCH_SIZE>
18
+ __device__ inline void read_batch(data_type *dst, const data_type *src) {
19
+ if constexpr ((sizeof(data_type) * BATCH_SIZE) == 2 * sizeof(ulong4)) {
20
+ *(reinterpret_cast<ulong4 *>(dst) + 0) = *(reinterpret_cast<const ulong4 *>(src) + 0);
21
+ *(reinterpret_cast<ulong4 *>(dst) + 1) = *(reinterpret_cast<const ulong4 *>(src) + 1);
22
+ } else if constexpr ((sizeof(data_type) * BATCH_SIZE) == sizeof(ulong4)) {
23
+ *reinterpret_cast<ulong4 *>(dst) = *reinterpret_cast<const ulong4 *>(src);
24
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong2)) {
25
+ *reinterpret_cast<ulong2 *>(dst) = *reinterpret_cast<const ulong2 *>(src);
26
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong1)) {
27
+ *reinterpret_cast<ulong1 *>(dst) = *reinterpret_cast<const ulong1 *>(src);
28
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(uint1)) {
29
+ *reinterpret_cast<uint1 *>(dst) = *reinterpret_cast<const uint1 *>(src);
30
+ } else {
31
+ #pragma unroll
32
+ for (int b = 0; b < BATCH_SIZE; ++b) {
33
+ dst[b] = src[b];
34
+ }
35
+ }
36
+ }
37
+
38
+ template <typename data_type, int BATCH_SIZE>
39
+ __device__ inline void zero_batch(data_type *dst) {
40
+ if constexpr ((sizeof(data_type) * BATCH_SIZE) == sizeof(ulong4)) {
41
+ *reinterpret_cast<ulong4 *>(dst) = ulong4{};
42
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong2)) {
43
+ *reinterpret_cast<ulong2 *>(dst) = ulong2{};
44
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(ulong1)) {
45
+ *reinterpret_cast<ulong1 *>(dst) = ulong1{};
46
+ } else if constexpr (sizeof(data_type) * BATCH_SIZE == sizeof(uint1)) {
47
+ *reinterpret_cast<uint *>(dst) = uint{};
48
+ } else {
49
+ #pragma unroll
50
+ for (int b = 0; b < BATCH_SIZE; ++b) {
51
+ dst[b] = 0;
52
+ }
53
+ }
54
+ }
55
+
56
+ template <typename data_type, int DST_Y, int DST_X, int SRC_Y, int SRC_X, int BLOCK_DIM, int BATCH_SIZE>
57
+ __device__ inline void load_input(data_type dst[DST_Y][DST_X], const data_type src[SRC_Y][SRC_X],
58
+ const int begin_x, const int begin_y) {
59
+ static_assert(BATCH_SIZE > 0);
60
+ /**
61
+ Consider (SRC_X % DST_X == 0) && (SRC_Y % DST_Y == 0)
62
+ Step 1:
63
+ [ ][***][ ][ ]
64
+ [ ][ ][ ][ ]
65
+ [ ][ ][ ][ ]
66
+ [ ][ ][ ][ ]
67
+ Step 2:
68
+ [ ][ ][ ][ ]
69
+ [ ][***][ ][ ]
70
+ [ ][ ][ ][ ]
71
+ [ ][ ][ ][ ]
72
+ */
73
+ static_assert((SRC_X % BATCH_SIZE == 0) && (SRC_Y % BATCH_SIZE == 0));
74
+ static_assert((DST_X % BATCH_SIZE == 0) && (DST_Y % BATCH_SIZE == 0));
75
+ static_assert(BATCH_SIZE <= DST_X && DST_X % BATCH_SIZE == 0);
76
+ const int begin_idx = threadIdx.x * BATCH_SIZE;
77
+ const constexpr int total_elements = DST_X * DST_Y;
78
+ const constexpr int elements_per_step = BLOCK_DIM * BATCH_SIZE;
79
+ // FIXME: loop unrolling
80
+ #pragma unroll
81
+ for (int k = begin_idx; k < total_elements; k += elements_per_step) {
82
+ int l_kx = k % DST_X;
83
+ int l_ky = k / DST_X;
84
+ int g_kx = l_kx + begin_x;
85
+ int g_ky = l_ky + begin_y;
86
+ auto *dst_flatten = &dst[l_ky][l_kx];
87
+ // const auto *src_flatten = &src[g_ky][g_kx];
88
+ // read_batch<data_type, BATCH_SIZE>(dst_flatten, src_flatten);
89
+ if (((SRC_X % DST_X == 0) || (g_kx < SRC_X)) && ((SRC_Y % DST_Y == 0) || (g_ky < SRC_Y))) {
90
+ const auto *src_flatten = &src[g_ky][g_kx];
91
+ read_batch<data_type, BATCH_SIZE>(dst_flatten, src_flatten);
92
+ } else {
93
+ zero_batch<data_type, BATCH_SIZE>(dst_flatten);
94
+ }
95
+ }
96
+ }
97
+
98
+ template <int PM, int PN, int QM, int QN, int QK, int QUANT_SIZE, int BLOCK_SIZE, int BATCH_SIZE>
99
+ __device__ void load_scale(float s_s[PM][PN], const float sa[QK][QM], const float sb[QK][QN],
100
+ const int m, const int n, const int k) {
101
+ constexpr int total_elements = PM * PN;
102
+ constexpr int elements_per_step = BLOCK_SIZE * BATCH_SIZE;
103
+ // static_assert(PN % BATCH_SIZE)
104
+
105
+ const int begin_idx = threadIdx.x * BATCH_SIZE;
106
+ #pragma unroll
107
+ for (int idx = begin_idx; idx < total_elements; idx += elements_per_step) {
108
+ static_assert(BATCH_SIZE == 1);
109
+ int i = idx / PN;
110
+ int j = idx % PN;
111
+ if (((QM % PM == 0) || (m + i < QM)) && ((QN % PN == 0) || ((n + j) / QUANT_SIZE < QN))) {
112
+ s_s[i][j] = sa[k / QUANT_SIZE][(m + i)] * sb[k / QUANT_SIZE][(n) / QUANT_SIZE + j];
113
+ } else {
114
+ s_s[i][j] = 1.0f;
115
+ }
116
+ }
117
+
118
+ }
119
+
120
+ template <typename in_data_type, typename acc_data_type,
121
+ typename FragC, typename FragA, typename FragB,
122
+ int PM, int PN,
123
+ int BM, int BN, int BK,
124
+ int FRAG_M, int FRAG_N, int FRAG_K,
125
+ int WMMA_M, int WMMA_N, int WMMA_K,
126
+ int WARP_M, int WARP_N,
127
+ int BLOCK_SIZE, int BATCH_SIZE, int QUANT_SIZE>
128
+ __device__ void wmma_compute(
129
+ const in_data_type s_a[BK][BM],
130
+ const in_data_type s_b[BK][BN],
131
+ const float s_s[PM][PN],
132
+ FragC frag_r[FRAG_M][FRAG_N],
133
+ const int comp_c_frag_m,
134
+ const int comp_c_frag_n
135
+ ) {
136
+ FragA frag_a[FRAG_K][FRAG_M];
137
+ FragB frag_b[FRAG_K][FRAG_N];
138
+
139
+ // Split k over BK
140
+ for (int k = 0; k < FRAG_K; ++k) {
141
+ #pragma unroll
142
+ for (int i = 0; i < FRAG_M; ++i) {
143
+ int s_a_row = k * WMMA_K;
144
+ int s_a_col = (comp_c_frag_m * FRAG_M + i) * WMMA_M;
145
+ wmma::load_matrix_sync(frag_a[k][i], &s_a[s_a_row][s_a_col], BM);
146
+ }
147
+ #pragma unroll
148
+ for (int j = 0; j < FRAG_N; ++j) {
149
+ int s_b_row = k * WMMA_K;
150
+ int s_b_col = (comp_c_frag_n * FRAG_N + j) * WMMA_N;
151
+ wmma::load_matrix_sync(frag_b[k][j], &s_b[s_b_row][s_b_col], BN);
152
+ }
153
+ }
154
+
155
+ #pragma unroll
156
+ for (int i = 0; i < FRAG_M; i++) {
157
+ #pragma unroll
158
+ for (int j = 0; j < FRAG_N; j++) {
159
+ FragC frag_c;
160
+ wmma::fill_fragment(frag_c, 0.0F);
161
+ #pragma unroll
162
+ for (int k = 0; k < FRAG_K; ++k) {
163
+ wmma::mma_sync(frag_c, frag_a[k][i], frag_b[k][j], frag_c);
164
+ }
165
+ #pragma unroll
166
+ for (int k = 0; k < FragC::num_elements; ++k) {
167
+ #ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
168
+ int m = ((threadIdx.x & 16) >> 1) | (k & 7) | (comp_c_frag_m * FRAG_M + i) * WMMA_M;
169
+ #else // CDNA3, WAVE_SIZE = 64
170
+ int m = ((threadIdx.x & 48) >> 2) | (k & 3) | (comp_c_frag_m * FRAG_M + i) * WMMA_M;
171
+ #endif
172
+ int n = ((threadIdx.x & 15) | (comp_c_frag_n * FRAG_N + j) * WMMA_N) / QUANT_SIZE;
173
+ float scale = s_s[m][n];
174
+ frag_r[i][j].x[k] += (acc_data_type)scale * (acc_data_type)frag_c.x[k];
175
+ }
176
+ }
177
+ }
178
+ }
179
+
180
+
181
+ template <typename acc_data_type, typename out_data_type,
182
+ typename FragC, typename FragOut, int WMMA_M, int WMMA_N,
183
+ int BM, int BN, int M, int N, int FRAG_M, int FRAG_N>
184
+ __device__ inline void store_result(
185
+ out_data_type c[M][N],
186
+ FragC frag_r[FRAG_M][FRAG_N],
187
+ const int m,
188
+ const int n,
189
+ const int comp_c_frag_m,
190
+ const int comp_c_frag_n
191
+ ) {
192
+ #pragma unroll
193
+ for (int i = 0; i < FRAG_M; i++) {
194
+ #pragma unroll
195
+ for (int j = 0; j < FRAG_N; j++) {
196
+ int frag_m = comp_c_frag_m * FRAG_M + i;
197
+ int frag_n = comp_c_frag_n * FRAG_N + j;
198
+ int row = m + frag_m * WMMA_M;
199
+ int col = n + frag_n * WMMA_N;
200
+ if (((M % BM == 0) || (row < M)) && ((N % BN == 0) || (col < N))) {
201
+ out_data_type *c_ptr = &c[row][col];
202
+ if constexpr (sizeof(acc_data_type) == sizeof(out_data_type)) {
203
+ wmma::store_matrix_sync(reinterpret_cast<out_data_type*>(c_ptr), frag_r[i][j], N, wmma::mem_row_major);
204
+ } else if constexpr (sizeof(out_data_type) == sizeof(half)) {
205
+ FragOut frag_out;
206
+ static_assert(sizeof(half) == sizeof(out_data_type));
207
+ static_assert(FragOut::num_elements == FragC::num_elements);
208
+ for (int k=0;k<FragOut::num_elements;++k) {
209
+ __hip_bfloat16 reg = frag_r[i][j].x[k];
210
+ frag_out.x[k] = *reinterpret_cast<half*>(&reg);
211
+ }
212
+ wmma::store_matrix_sync(reinterpret_cast<half*>(c_ptr), frag_out, N, wmma::mem_row_major);
213
+ } else {
214
+ static_assert(0, "Unsupported data type for output");
215
+ }
216
+
217
+ }
218
+ }
219
+ }
220
+ }
221
+
222
+ // a dummy template to allow including this file
223
+ template<int Dummy=0>
224
+ __global__ void reduce(uint32_t m, uint32_t n, uint32_t splitk, const float *c_splitk, __hip_bfloat16 *c) {
225
+ auto tid = blockIdx.x * blockDim.x + threadIdx.x;
226
+ if (tid >= m * n) {
227
+ return;
228
+ }
229
+ float sum = 0;
230
+ for (auto i = 0; i < splitk; ++i) {
231
+ sum += c_splitk[i * (m * n) + tid];
232
+ }
233
+ c[tid] = sum;
234
+ }
235
+
236
+
237
+ #ifdef PARAMETERIZE_LIBRARY
238
+ template <
239
+ typename in_data_type,
240
+ typename acc_data_type, // Accumulator type (e.g., float)
241
+ typename out_data_type, // Output type (e.g., __hip_bfloat16)
242
+ int M, int N, int K, // Matrix dimensions
243
+ int BM, int BN, int BK, // Tile dimensions
244
+ int QUANT_SIZE, // Quantization block size
245
+ int BLOCK_SIZE, // Block size
246
+ int WARP_M, int WARP_N // Warp dimensions
247
+ >
248
+ #else
249
+ using in_data_type = __FP8_TYPE;
250
+ using out_data_type = __BF16_TYPE;
251
+ using acc_data_type = float;
252
+ // constexpr int M = 4096, N = 4096, K = 4096;
253
+ constexpr int M = 96, N = 1024, K = 1024;
254
+ // constexpr int M = 512, N = 512, K = 512;
255
+ constexpr int BM = 64, BN = 256, BK = 32;
256
+ constexpr int QUANT_SIZE = 128, BLOCK_SIZE = 256;
257
+ #ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
258
+ constexpr int WARP_M = 4, WARP_N = 2;
259
+ #else // CDNA3, WAVE_SIZE = 64
260
+ constexpr int WARP_M = 2, WARP_N = 2;
261
+ #endif
262
+ #endif // End of parameterization
263
+ __global__ void gemm_kernel(
264
+ const in_data_type a[K][M],
265
+ const in_data_type b[K][N],
266
+ out_data_type c[M][N],
267
+ const float sa[ceil_div(K, QUANT_SIZE)][M / 1 ], // Assuming M is divisible by 1 (always true)
268
+ const float sb[ceil_div(K, QUANT_SIZE)][ceil_div(N, QUANT_SIZE)]
269
+ ) {
270
+ // --- Start: Derived parameters and constants ---
271
+ constexpr int WMMA_M = 16; // Fixed WMMA dimension M
272
+ constexpr int WMMA_N = 16; // Fixed WMMA dimension N
273
+ constexpr int WMMA_K = 32; // Fixed WMMA dimension K (for FP8)
274
+
275
+ // WARP_M/N define the 2D arrangement of warps in the block grid.
276
+ // These might need adjustment based on BLOCK_DIM_X/Y strategy.
277
+ // Using fixed values based on the non-parameterized version for now.
278
+ // TODO: Derive WARP_M/N from BLOCK_DIM_X/Y if a flexible strategy is needed.
279
+ constexpr int WARP_NUM = WARP_M * WARP_N; // Total warps per block
280
+
281
+ // Assertion: Check if the assumed warp layout matches the block size
282
+ static_assert(WARP_NUM * WAVE_SIZE == BLOCK_SIZE, "WARP_M * WARP_N * WAVE_SIZE must equal BLOCK_SIZE");
283
+
284
+ // Fragments per warp
285
+ constexpr int FRAG_M_PER_WARP = BM / WMMA_M / WARP_M;
286
+ constexpr int FRAG_N_PER_WARP = BN / WMMA_N / WARP_N;
287
+ constexpr int FRAG_K = BK / WMMA_K; // Fragments along K dimension tile
288
+
289
+ static_assert(BM % (WMMA_M * WARP_M) == 0, "BM must be divisible by WMMA_M * WARP_M");
290
+ static_assert(BN % (WMMA_N * WARP_N) == 0, "BN must be divisible by WMMA_N * WARP_N");
291
+ static_assert(BK % WMMA_K == 0, "BK must be divisible by WMMA_K");
292
+ static_assert(BK >= 32, "BK must be at least 32");
293
+ // --- End: Derived parameters and constants ---
294
+
295
+ constexpr int QM = M; // Dimension M for scale A
296
+ constexpr int QN = ceil_div(N, QUANT_SIZE); // Dimension N for scale B (quantized)
297
+ constexpr int QK = ceil_div(K, QUANT_SIZE); // Dimension K for scales (quantized)
298
+ constexpr int PM = BM; // Block size M for scale A * B
299
+ constexpr int PN = ceil_div(BN, QUANT_SIZE); // Block size N for scale A * B
300
+
301
+ // Ensure derived fragment counts are positive
302
+ static_assert(FRAG_M_PER_WARP > 0, "FRAG_M_PER_WARP must be positive");
303
+ static_assert(FRAG_N_PER_WARP > 0, "FRAG_N_PER_WARP must be positive");
304
+ static_assert(FRAG_K > 0, "FRAG_K must be positive");
305
+
306
+ using FragA = wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, in_data_type, wmma::col_major>;
307
+ using FragB = wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, in_data_type, wmma::row_major>;
308
+ using FragC = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, acc_data_type>;
309
+ using FragOut = wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half>; // Output uses half for storage via bfloat16 reinterpret
310
+
311
+ __shared__ in_data_type s_a[BK][BM];
312
+ __shared__ in_data_type s_b[BK][BN];
313
+ __shared__ acc_data_type s_s[PM][PN]; // Accumulator type for scales
314
+ FragC frag_r[FRAG_M_PER_WARP][FRAG_N_PER_WARP]; // Accumulator fragments
315
+
316
+ // handle splitk
317
+ a += blockIdx.z * K;
318
+ b += blockIdx.z * K;
319
+ c += blockIdx.z * M;
320
+ sa += blockIdx.z * QK;
321
+ sb += blockIdx.z * QK;
322
+
323
+ int tid = threadIdx.x; // Linear thread ID within the block
324
+ int wid = tid / WAVE_SIZE; // Warp ID within the block
325
+
326
+ // Initialize the output accumulator fragments to zero
327
+ #pragma unroll
328
+ for (int i = 0; i < FRAG_M_PER_WARP; i++) {
329
+ #pragma unroll
330
+ for (int j = 0; j < FRAG_N_PER_WARP; j++) {
331
+ wmma::fill_fragment(frag_r[i][j], 0.0f); // Use float literal
332
+ }
333
+ }
334
+
335
+ // Split K and compute fragments
336
+ constexpr int iteration_over_k = ceil_div(K, BK); // Use ceil_div for potentially non-divisible K
337
+ constexpr int LOAD_BATCH_SIZE = (2 * sizeof(float4) / sizeof(in_data_type)) > 0 ? (2 * sizeof(float4) / sizeof(in_data_type)) : 1; // Ensure batch size > 0
338
+ static_assert(LOAD_BATCH_SIZE > 0, "LOAD_BATCH_SIZE must be positive");
339
+
340
+ for (int bk = 0; bk < iteration_over_k; bk++) {
341
+ const int m = blockIdx.y * BM;
342
+ const int n = blockIdx.x * BN;
343
+ const int k = bk * BK;
344
+
345
+ // Calculate remaining K for boundary checks if needed (not currently used by load_input)
346
+ // const int k_rem = K - k;
347
+
348
+ // Load data into shared memory
349
+ load_input<in_data_type, BK, BM, K, M, BLOCK_SIZE, LOAD_BATCH_SIZE>(
350
+ s_a, a, m, k);
351
+ load_input<in_data_type, BK, BN, K, N, BLOCK_SIZE, LOAD_BATCH_SIZE>(
352
+ s_b, b, n, k);
353
+ // Load scales into shared memory (using acc_data_type for s_s)
354
+ load_scale<PM, PN, QM, QN, QK, QUANT_SIZE, BLOCK_SIZE, 1>(
355
+ s_s, sa, sb, m, n, k);
356
+ __syncthreads();
357
+
358
+ // Perform matrix multiplication using WMMA
359
+ wmma_compute<in_data_type, acc_data_type, FragC, FragA, FragB,
360
+ PM, PN, BM, BN, BK, FRAG_M_PER_WARP, FRAG_N_PER_WARP, FRAG_K,
361
+ WMMA_M, WMMA_N, WMMA_K,
362
+ WARP_M, WARP_N,
363
+ BLOCK_SIZE, LOAD_BATCH_SIZE, QUANT_SIZE>( // Pass calculated BLOCK_SIZE and LOAD_BATCH_SIZE
364
+ s_a, s_b, s_s, frag_r, wid / WARP_N, wid % WARP_N);
365
+ __syncthreads();
366
+ }
367
+ // Store results from accumulator fragments to global memory
368
+ store_result<acc_data_type, out_data_type, FragC, FragOut,
369
+ WMMA_M, WMMA_N, BM, BN, M, N, FRAG_M_PER_WARP, FRAG_N_PER_WARP>(
370
+ c, frag_r, blockIdx.y * BM, blockIdx.x * BN,
371
+ wid / WARP_N, wid % WARP_N);
372
+
373
+
374
+ };
375
+
376
+
377
+ }; // namespace gemm_kernel_legacy
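
The kernel above computes a blockwise-dequantized FP8 GEMM: each 128-wide K block is accumulated at full precision and then scaled by the per-row factor from `sa` and the per-128-column factor from `sb`. Below is a minimal CPU sketch of that math for reference, using plain `float` arrays as stand-ins for the FP8/BF16 operands (an illustrative assumption; the device code uses `__FP8_TYPE` inputs and `__BF16_TYPE` output):

```cpp
// Minimal CPU reference of the blockwise-scaled GEMM the kernel computes:
//   C[m][n] = sum over 128-wide K blocks kb of
//             ( sum_{k in kb} A[k][m] * B[k][n] ) * sa[kb][m] * sb[kb][n / QUANT_SIZE]
// A and B are K-major (a[K][M], b[K][N]), matching the kernel's argument layout.
#include <cstdio>
#include <vector>

int main() {
    constexpr int M = 4, N = 256, K = 256, QUANT_SIZE = 128;
    constexpr int QK = (K + QUANT_SIZE - 1) / QUANT_SIZE;    // K blocks of scales
    constexpr int QN = (N + QUANT_SIZE - 1) / QUANT_SIZE;    // N blocks of scales for B

    std::vector<float> a(K * M, 0.5f), b(K * N, 0.25f);      // stand-ins for FP8 data
    std::vector<float> sa(QK * M, 2.0f), sb(QK * QN, 4.0f);  // per-block dequant scales
    std::vector<float> c(M * N, 0.0f);

    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int kb = 0; kb < QK; ++kb) {
                float partial = 0.0f;                        // full-precision partial sum
                for (int k = kb * QUANT_SIZE; k < (kb + 1) * QUANT_SIZE && k < K; ++k)
                    partial += a[k * M + m] * b[k * N + n];
                // dequantize: row scale of A times 128-wide column-block scale of B
                c[m * N + n] += partial * sa[kb * M + m] * sb[kb * QN + n / QUANT_SIZE];
            }

    // 0.5 * 0.25 summed over K = 256 gives 32; scaled by 2 * 4 = 8 -> 256
    std::printf("c[0][0] = %.1f\n", c[0]);
    return 0;
}
```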
gemm/gemm_launcher.hip ADDED
@@ -0,0 +1,267 @@
1
+ // Wrapper for the GEMM kernel launchers.
2
+ #include <unistd.h>
3
+ #include <chrono>
4
+ #define PARAMETERIZE_LIBRARY
5
+ #include "gemm_kernel.h"
6
+ #include "gemm_kernel_legacy.h"
7
+ #include "transpose_kernel.h"
8
+ #undef PARAMETERIZE_LIBRARY
9
+ #include "../include/gpu_types.h"
10
+ #include "../include/timer.h"
11
+ #include "../tests/checker/metrics.h"
12
+ #include <iostream>
13
+
14
+ #include <stdio.h>
15
+
16
+ HOST_CODE_BELOW
17
+
18
+ std::vector<std::shared_ptr<KernelTimer>> timers;
19
+
20
+ using namespace std;
21
+
22
+ float *c_splitk = nullptr;
23
+ __FP8_TYPE *a_trans = nullptr;
24
+ __FP8_TYPE *b_trans = nullptr;
25
+ constexpr int MAX_MATRIX_M = 6144;
26
+ constexpr int MAX_MATRIX_N = 7168;
27
+ constexpr int MAX_MATRIX_K = 7168;
28
+ constexpr int MAX_SPLITK_FACTOR = 8;
29
+
30
+ void init_workspace() {
31
+ LIB_CALL(HOST_TYPE(Malloc)(&c_splitk, MAX_MATRIX_M * MAX_MATRIX_N * sizeof(float) * MAX_SPLITK_FACTOR));
32
+ LIB_CALL(HOST_TYPE(Malloc)(&a_trans, MAX_MATRIX_M * MAX_MATRIX_K * sizeof(__FP8_TYPE)));
33
+ LIB_CALL(HOST_TYPE(Malloc)(&b_trans, MAX_MATRIX_N * MAX_MATRIX_K * sizeof(__FP8_TYPE)));
34
+ // LIB_CALL(HOST_TYPE(StreamCreateWithFlags)(&job_stream0, HOST_TYPE(StreamNonBlocking)));
35
+ // job_stream0 = 0;
36
+ }
37
+
38
+
39
+ // Launch the pipelined GEMM kernels (most performant).
40
+ // 1. Transpose input A & B.
41
+ // 2. GEMM compute.
42
+ // 3. Reduce (if split-k is enabled)
43
+ template <int M, int N, int K, int BM, int BN, int BK, int WARP_M, int WARP_N, int BLOCK_SIZE, int QUANT_BLOCK_SIZE,
44
+ int SPLITK_FACTOR, int LOAD_BATCH_SIZE = 16>
45
+ void launch_gemm(const __FP8_TYPE *a, const __FP8_TYPE *b, __BF16_TYPE *c, const float *as, const float *bs, HOST_TYPE(Stream_t) job_stream0) {
46
+ static_assert(M <= MAX_MATRIX_M, "M exceeds maximum supported size");
47
+ static_assert(N <= MAX_MATRIX_N, "N exceeds maximum supported size");
48
+ static_assert(K <= MAX_MATRIX_K, "K exceeds maximum supported size");
49
+ static_assert(SPLITK_FACTOR <= MAX_SPLITK_FACTOR, "SPLITK_FACTOR exceeds maximum supported size");
50
+ if (__builtin_expect(c_splitk == nullptr, 0)) {
51
+ init_workspace();
52
+ LIB_CALL(hipDeviceSynchronize());
53
+ }
54
+
55
+ transpose_kernel::transpose_fp8<K, N>(b_trans, b, job_stream0);
56
+ transpose_kernel::transpose_fp8<K, M>(a_trans, a, job_stream0);
57
+ // transpose_kernel::launch_transpose<__FP8_TYPE, K, N, 64, 512, 4>(b_trans, b, job_stream0);
58
+ // transpose_kernel::launch_transpose<__FP8_TYPE, K, M, 64, 512, 4>(a_trans, a, job_stream0);
59
+ // Busy wait for 150 microseconds
60
+ // auto start = std::chrono::high_resolution_clock::now();
61
+ // while (std::chrono::duration_cast<std::chrono::microseconds>(
62
+ // std::chrono::high_resolution_clock::now() - start).count() < 150) {
63
+ // // Busy wait
64
+ // }
65
+ // be careful to keep the block size < 1024, or there's a silent fault
66
+ // gemm_kernel::check_trans<<<dim3(K / 32, M / 32), dim3(32, 32)>>>(a, a_trans, K, M);
67
+
68
+ static_assert(K % SPLITK_FACTOR == 0, "K not divisible by SPLITK_FACTOR");
69
+ dim3 grid(ceil_div(N, BN) << 1, ceil_div(M, BM) >> 1, SPLITK_FACTOR);
70
+ static_assert(BLOCK_SIZE >= 32, "BLOCK_SIZE must be at least 32");
71
+ dim3 block(BLOCK_SIZE);
72
+ if constexpr (SPLITK_FACTOR == 1) {
73
+ hipLaunchKernelGGL(
74
+ HIP_KERNEL_NAME(gemm_kernel::gemm_kernel<__FP8_TYPE, float, __BF16_TYPE, M, N, K, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N, K, K, LOAD_BATCH_SIZE>),
75
+ grid, block, 0, job_stream0,
76
+ reinterpret_cast<const __FP8_TYPE(*)[K]>(a_trans),
77
+ reinterpret_cast<const __FP8_TYPE(*)[K]>(b_trans),
78
+ reinterpret_cast<__BF16_TYPE(*)[N]>(c), reinterpret_cast<const float(*)[M]>(as),
79
+ reinterpret_cast<const float(*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
80
+ );
81
+ } else {
82
+ hipLaunchKernelGGL(
83
+ HIP_KERNEL_NAME(gemm_kernel::gemm_kernel<__FP8_TYPE, float, float, M, N, K / SPLITK_FACTOR, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N, K, K, LOAD_BATCH_SIZE>),
84
+ grid, block, 0, job_stream0,
85
+ reinterpret_cast<const __FP8_TYPE(*)[K]>(a_trans),
86
+ reinterpret_cast<const __FP8_TYPE(*)[K]>(b_trans),
87
+ reinterpret_cast<float(*)[N]>(c_splitk), reinterpret_cast<const float(*)[M]>(as),
88
+ reinterpret_cast<const float(*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs));
89
+ constexpr uint32_t REDUCE_BLOCK = 256;
90
+ hipLaunchKernelGGL(
91
+ HIP_KERNEL_NAME(gemm_kernel::reduce_kernel<M, N, SPLITK_FACTOR, REDUCE_BLOCK>),
92
+ ceil_div(M * N / 4, REDUCE_BLOCK), REDUCE_BLOCK, 0, job_stream0,
93
+ reinterpret_cast<const float(*)[M][N]>(c_splitk),
94
+ reinterpret_cast<__BF16_TYPE(*)[N]>(c)
95
+ ); }
96
+ auto err = HOST_TYPE(GetLastError)();
97
+ if (err != HOST_TYPE(Success)) {
98
+ std::cerr << "Kernel execution failed.\n" << HOST_TYPE(GetErrorString)(err) << std::endl;
99
+ abort();
100
+ }
101
+ }
102
+
103
+
104
+ // Launch the legacy GEMM kernel (most compatible).
105
+ template <int M, int N, int K, int BM, int BN, int BK, int WARP_M, int WARP_N, int BLOCK_SIZE, int QUANT_BLOCK_SIZE, int SPLITK_FACTOR>
106
+ void launch_gemm_legacy(const __FP8_TYPE *a, const __FP8_TYPE *b, __BF16_TYPE *c, const float *as, const float *bs, HOST_TYPE(Stream_t) job_stream0) {
107
+ static_assert(K % SPLITK_FACTOR == 0, "K not divisible by SPLITK_FACTOR");
108
+ dim3 grid(ceil_div(N, BN), ceil_div(M, BM), SPLITK_FACTOR);
109
+ static_assert(BLOCK_SIZE >= 32, "BLOCK_SIZE must be at least 32");
110
+ dim3 block(BLOCK_SIZE);
111
+ if (__builtin_expect(c_splitk == nullptr, 0)) {
112
+ init_workspace();
113
+ LIB_CALL(hipDeviceSynchronize());
114
+ }
115
+
116
+ if constexpr (SPLITK_FACTOR == 1) {
117
+ hipLaunchKernelGGL(
118
+ HIP_KERNEL_NAME(gemm_kernel_legacy::gemm_kernel<__FP8_TYPE, float, __BF16_TYPE, M, N, K, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N>),
119
+ grid, block, 0, job_stream0,
120
+ reinterpret_cast<const __FP8_TYPE (*)[M]>(a),
121
+ reinterpret_cast<const __FP8_TYPE (*)[N]>(b),
122
+ reinterpret_cast<__BF16_TYPE (*)[N]>(c),
123
+ reinterpret_cast<const float (*)[M]>(as),
124
+ reinterpret_cast<const float (*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
125
+ );
126
+ } else {
127
+ hipLaunchKernelGGL(
128
+ HIP_KERNEL_NAME(gemm_kernel_legacy::gemm_kernel<__FP8_TYPE, float, float, M, N, K / SPLITK_FACTOR, BM, BN, BK, QUANT_BLOCK_SIZE, BLOCK_SIZE, WARP_M, WARP_N>),
129
+ grid, block, 0, job_stream0,
130
+ reinterpret_cast<const __FP8_TYPE (*)[M]>(a),
131
+ reinterpret_cast<const __FP8_TYPE (*)[N]>(b),
132
+ reinterpret_cast<float (*)[N]>(c_splitk),
133
+ reinterpret_cast<const float (*)[M]>(as),
134
+ reinterpret_cast<const float (*)[ceil_div(N, QUANT_BLOCK_SIZE)]>(bs)
135
+ );
136
+ constexpr uint32_t REDUCE_BLOCK = 256;
137
+ hipLaunchKernelGGL(
138
+ HIP_KERNEL_NAME(gemm_kernel_legacy::reduce<0>),
139
+ ceil_div(M * N, REDUCE_BLOCK), REDUCE_BLOCK, 0, job_stream0,
140
+ M, N, SPLITK_FACTOR, c_splitk, (__BF16_TYPE *)c
141
+ );
142
+ }
143
+ auto err = HOST_TYPE(GetLastError)();
144
+ if (err != HOST_TYPE(Success)) {
145
+ std::cerr << "Kernel execution failed.\n" << HOST_TYPE(GetErrorString)(err) << std::endl;
146
+ abort();
147
+ }
148
+ }
149
+
150
+ constexpr inline uint32_t pack_shape(uint32_t m, uint32_t n, uint32_t k) {
151
+ // Pack m, n, k into a 32-bit integer
152
+ // Use 8 bits for each dimension (supports 32-aligned values from 32 to 8192)
153
+ // Divide each value by 32 to fit into 8 bits
154
+ return ((m / 32) << 16) | ((n / 32) << 8) | (k / 32);
155
+ }
156
+ // int M, int N, int K, int BM, int BN, int BK, int WARP_M, int WARP_N, int BLOCK_SIZE, int QUANT_BLOCK_SIZE, int
157
+ // SPLITK_FACTOR, int LOAD_BATCH_SIZE
158
+ #define DISPATCH_GEMM(M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE) \
159
+ case pack_shape_checked<M, N, K>(): { \
160
+ launch_gemm<M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, 128, SPLITK_FACTOR, LOAD_BATCH_SIZE>(a_ptr, b_ptr, c_ptr, as_ptr, bs_ptr, job_stream0); \
161
+ break; \
162
+ }
163
+
164
+ #define DISPATCH_GEMM_LEGACY(M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR) \
165
+ case pack_shape_checked<M, N, K>(): { \
166
+ launch_gemm_legacy<M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, 128, SPLITK_FACTOR>(a_ptr, b_ptr, c_ptr, as_ptr, bs_ptr, job_stream0); \
167
+ break; \
168
+ }
169
+
170
+ template <int M, int N, int K> constexpr inline uint32_t pack_shape_checked() {
171
+ static_assert(M % 32 == 0, "M must be a multiple of 32");
172
+ static_assert(N % 32 == 0, "N must be a multiple of 32");
173
+ static_assert(K % 32 == 0, "K must be a multiple of 32");
174
+ static_assert(M >= 32 && M <= 8192, "M must be between 32 and 8192");
175
+ static_assert(N >= 32 && N <= 8192, "N must be between 32 and 8192");
176
+ static_assert(K >= 32 && K <= 8192, "K must be between 32 and 8192");
177
+ return pack_shape(M, N, K);
178
+ }
179
+
180
+
181
+
182
+ extern "C" {
183
+ // Basically, it dispatches GEMM to the fastest implementation according to the input shapes.
184
+ void run(void *a, void *b, void *as, void *bs, void *c, int m, int n, int k, PerfMetrics *metrics, hipStream_t job_stream0) {
185
+ // Cast pointers once
186
+ const __FP8_TYPE *a_ptr = static_cast<const __FP8_TYPE *>(a);
187
+ const __FP8_TYPE *b_ptr = static_cast<const __FP8_TYPE *>(b);
188
+ __BF16_TYPE *c_ptr = static_cast<__BF16_TYPE *>(c);
189
+ const float *as_ptr = static_cast<const float *>(as);
190
+ const float *bs_ptr = static_cast<const float *>(bs);
191
+ KernelTimerScoped timer(timers, 2LL * m * n * k,
192
+ metrics ? &metrics->entries[0].time : nullptr,
193
+ metrics ? &metrics->entries[0].gflops : nullptr, job_stream0);
194
+
195
+ switch (pack_shape(m, n, k)) {
196
+ #ifdef TEST_ON_RDNA4 // RDNA4, WAVE_SIZE = 32
197
+ // Test: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE
198
+ DISPATCH_GEMM(64, 64, 128, 64, 64, 32, 1, 4, 128, 1, 16);
199
+ DISPATCH_GEMM(64, 1536, 7168, 64, 128, 64, 4, 2, 256, 1, 16);
200
+ DISPATCH_GEMM(64, 3072, 1536, 64, 128, 64, 4, 2, 256, 1, 16);
201
+ DISPATCH_GEMM(64, 576, 7168, 64, 128, 64, 4, 2, 256, 1, 16);
202
+ DISPATCH_GEMM(96, 7168, 256, 96, 256, 64, 2, 4, 256, 1, 16);
203
+ DISPATCH_GEMM(96, 7168, 2048, 96, 256, 64, 2, 4, 256, 1, 16);
204
+ DISPATCH_GEMM(96, 4608, 7168, 96, 256, 64, 2, 4, 256, 1, 16);
205
+ DISPATCH_GEMM(128, 7168, 2304, 128, 128, 64, 4, 2, 256, 1, 16);
206
+ DISPATCH_GEMM(128, 512, 7168, 128, 128, 64, 4, 2, 256, 1, 16);
207
+ DISPATCH_GEMM(512, 4096, 512, 256, 128, 64, 4, 2, 256, 1, 16);
208
+ DISPATCH_GEMM(512, 1536, 7168, 256, 128, 64, 4, 2, 256, 1, 16);
209
+ // Benchmark: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SIZE, SPLITK_FACTOR, LOAD_BATCH_SIZE
210
+ DISPATCH_GEMM(1024, 1536, 7168, 128, 128, 64, 1, 4, 128, 4, 16); // Optimized: 0.49 ms (45.65 TFlops)
211
+ DISPATCH_GEMM(1024, 3072, 1536, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.19 ms (51.32 TFlops)
212
+ DISPATCH_GEMM(1024, 576, 7168, 128, 64, 32, 4, 1, 128, 4, 16); // Optimized: 0.30 ms (28.16 TFlops)
213
+ DISPATCH_GEMM(1024, 7168, 256, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.08 ms (46.49 TFlops)
214
+ DISPATCH_GEMM(1024, 7168, 2048, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.49 ms (61.92 TFlops)
215
+ DISPATCH_GEMM(1024, 4608, 7168, 128, 128, 32, 2, 2, 128, 1, 16); // Optimized: 0.99 ms (68.16 TFlops)
216
+ DISPATCH_GEMM(1024, 7168, 2304, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.51 ms (66.04 TFlops)
217
+ DISPATCH_GEMM(1024, 512, 7168, 64, 128, 32, 2, 2, 128, 4, 16); // Optimized: 0.26 ms (28.97 TFlops)
218
+ DISPATCH_GEMM(1024, 4096, 512, 128, 256, 32, 2, 4, 256, 1, 16); // Optimized: 0.08 ms (54.27 TFlops)
219
+ DISPATCH_GEMM(6144, 1536, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 1.76 ms (76.76 TFlops)
220
+ DISPATCH_GEMM(6144, 3072, 1536, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.88 ms (66.00 TFlops)
221
+ DISPATCH_GEMM(6144, 576, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.84 ms (60.68 TFlops)
222
+ DISPATCH_GEMM(6144, 7168, 256, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.49 ms (45.76 TFlops)
223
+ DISPATCH_GEMM(6144, 7168, 2048, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 2.17 ms (83.11 TFlops)
224
+ DISPATCH_GEMM(6144, 4608, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 4.56 ms (88.99 TFlops)
225
+ DISPATCH_GEMM(6144, 7168, 2304, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 2.41 ms (84.32 TFlops)
226
+ DISPATCH_GEMM(6144, 512, 7168, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.67 ms (67.45 TFlops)
227
+ DISPATCH_GEMM(6144, 4096, 512, 256, 128, 32, 4, 2, 256, 1, 16); // Optimized: 0.51 ms (50.79 TFlops)
228
+ #else // CDNA3, WAVE_SIZE = 64
229
+ // Benchmark: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SZ, SPLITK_F, LOAD_BS
230
+ DISPATCH_GEMM(1024, 1536, 7168, 256, 128, 128, 4, 2, 512, 4, 16); // #0
231
+ DISPATCH_GEMM(1024, 3072, 1536, 256, 128, 128, 4, 2, 512, 2, 16); // #1
232
+ DISPATCH_GEMM(1024, 576, 7168, 256, 128, 128, 4, 2, 512, 8, 16); // #2
233
+ DISPATCH_GEMM(1024, 7168, 256, 256, 128, 128, 4, 2, 512, 1, 16); // #3
234
+ DISPATCH_GEMM(1024, 7168, 2048, 256, 128, 128, 4, 2, 512, 1, 16); // #4
235
+ DISPATCH_GEMM(1024, 4608, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #5
236
+ DISPATCH_GEMM(1024, 7168, 2304, 256, 128, 128, 4, 2, 512, 1, 16); // #6
237
+ DISPATCH_GEMM(1024, 512, 7168, 256, 128, 128, 4, 2, 512, 8, 16); // #7
238
+ DISPATCH_GEMM(1024, 4096, 512, 256, 128, 128, 4, 2, 512, 1, 16); // #8
239
+ DISPATCH_GEMM(6144, 1536, 7168, 256, 128, 128, 4, 2, 512, 1, 16); // #9
240
+ DISPATCH_GEMM(6144, 3072, 1536, 256, 128, 128, 4, 2, 512, 1, 16); // #10
241
+ DISPATCH_GEMM(6144, 576, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #11
242
+ DISPATCH_GEMM(6144, 7168, 256, 256, 128, 128, 4, 2, 512, 1, 16); // #12
243
+ DISPATCH_GEMM(6144, 7168, 2048, 256, 128, 128, 4, 2, 512, 1, 16); // #13
244
+ DISPATCH_GEMM(6144, 4608, 7168, 256, 128, 128, 4, 2, 512, 1, 16); // #14
245
+ DISPATCH_GEMM(6144, 7168, 2304, 256, 128, 128, 4, 2, 512, 1, 16); // #15
246
+ DISPATCH_GEMM(6144, 512, 7168, 256, 128, 128, 4, 2, 512, 2, 16); // #16
247
+ DISPATCH_GEMM(6144, 4096, 512, 256, 128, 128, 4, 2, 512, 1, 16); // #17
248
+ // Test: M, N, K, BM, BN, BK, WARP_M, WARP_N, BLOCK_SZ, SPLITK_F,
249
+ DISPATCH_GEMM_LEGACY(64, 64, 128, 64, 64, 32, 4, 2, 512, 1);
250
+ DISPATCH_GEMM_LEGACY(64, 1536, 7168, 64, 128, 64, 4, 2, 512, 1);
251
+ DISPATCH_GEMM_LEGACY(64, 3072, 1536, 64, 128, 64, 4, 2, 512, 1);
252
+ DISPATCH_GEMM_LEGACY(64, 576, 7168, 64, 128, 64, 4, 2, 512, 1);
253
+ DISPATCH_GEMM_LEGACY(96, 7168, 256, 96, 256, 64, 2, 4, 512, 1);
254
+ DISPATCH_GEMM_LEGACY(96, 7168, 2048, 96, 256, 64, 2, 4, 512, 1);
255
+ DISPATCH_GEMM_LEGACY(96, 4608, 7168, 96, 256, 64, 2, 4, 512, 1);
256
+ DISPATCH_GEMM_LEGACY(128, 7168, 2304, 128, 128, 64, 4, 2, 512, 1);
257
+ DISPATCH_GEMM_LEGACY(128, 512, 7168, 128, 128, 64, 4, 2, 512, 1);
258
+ DISPATCH_GEMM_LEGACY(512, 4096, 512, 256, 128, 64, 4, 2, 512, 1);
259
+ DISPATCH_GEMM_LEGACY(512, 1536, 7168, 256, 128, 64, 4, 2, 512, 1);
260
+ #endif
261
+ default: {
262
+ printf("Error: Unsupported shape M=%d, K=%d, N=%d\n", m, k, n);
263
+ abort();
264
+ }
265
+ }
266
+ }
267
+ } // extern "C"
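
For reference, the dispatch key used by the `switch` above packs `(M, N, K)` into one 32-bit value via `pack_shape`. Here is a standalone sketch of the same packing, with test shapes taken from the dispatch table (everything else is illustrative):

```cpp
// Standalone illustration of the 32-bit dispatch key: (m/32)<<16 | (n/32)<<8 | (k/32).
// n/32 and k/32 must stay below 256, or they would spill into the neighbouring field.
#include <cstdint>
#include <cstdio>

constexpr uint32_t pack_shape(uint32_t m, uint32_t n, uint32_t k) {
    return ((m / 32) << 16) | ((n / 32) << 8) | (k / 32);
}

int main() {
    // 1024/32 = 32, 1536/32 = 48, 7168/32 = 224
    static_assert(pack_shape(1024, 1536, 7168) == ((32u << 16) | (48u << 8) | 224u));
    std::printf("key(1024, 1536, 7168) = 0x%06x\n", (unsigned)pack_shape(1024, 1536, 7168));
    std::printf("key(6144, 4608, 7168) = 0x%06x\n", (unsigned)pack_shape(6144, 4608, 7168));
    return 0;
}
```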
gemm/transpose_kernel.h ADDED
@@ -0,0 +1,120 @@
1
+ // Implementation of transpose kernel.
2
+ #pragma once
3
+
4
+ #include <hip/amd_detail/amd_hip_runtime.h>
5
+ #include <hip/amd_detail/amd_warp_functions.h>
6
+ #include "../include/gpu_libs.h"
7
+ #include "../include/gpu_types.h"
8
+ #include "../src/utils/arithmetic.h"
9
+ #include "../include/clangd_workaround.h"
10
+
11
+ DEVICE_CODE_BELOW
12
+
13
+ namespace transpose_kernel {
14
+
15
+
16
+
17
+ template <typename Elem, int M, int N, int TILE_DIM, int BLOCK_SIZE, int VEC_SIZE>
18
+ __launch_bounds__(BLOCK_SIZE)
19
+ __global__ void transpose_kernel(Elem *odata, const Elem *idata) {
20
+ constexpr auto TBLOCK_X = TILE_DIM / VEC_SIZE;
21
+ constexpr auto TBLOCK_Y = BLOCK_SIZE / TBLOCK_X;
22
+
23
+ // avoid read bank conflict
24
+ // VEC_SIZE * (TILE_DIM + d) * sizeof(Elem) = TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE) * sizeof(Elem) + 128k
25
+ // each warp read row = TILE_DIM (in VEC_SIZE reads), col = TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE)
26
+ // warp 0 warp 1
27
+ // t0 t16 t32 t48 ...
28
+ // ...
29
+ // t1
30
+ // ...
31
+ // t15
32
+ // don't know why padding by d as described above is not working; maybe the GPU merges contiguous ds_read_u8 reads,
33
+ // which makes the effective padding TBLOCK_Y / (BLOCK_SIZE / WARP_SIZE)
34
+ constexpr auto PADDING = TBLOCK_Y / (BLOCK_SIZE / warpSize);
35
+ __shared__ Elem tile[TILE_DIM][TILE_DIM + PADDING];
36
+
37
+ int x = blockIdx.x * TILE_DIM + threadIdx.x * VEC_SIZE;
38
+ int y = blockIdx.y * TILE_DIM + threadIdx.y;
39
+
40
+ // Load tile
41
+ #pragma unroll
42
+ for (int i = 0; i < TILE_DIM; i += TBLOCK_Y) {
43
+ #pragma unroll
44
+ for (int v = 0; v < VEC_SIZE; v++) {
45
+ tile[threadIdx.y + i][threadIdx.x * VEC_SIZE + v] = idata[(y + i) * N + x + v];
46
+ }
47
+ }
48
+
49
+ __syncthreads();
50
+
51
+ // Transpose indices
52
+ x = blockIdx.y * TILE_DIM + threadIdx.x * VEC_SIZE;
53
+ y = blockIdx.x * TILE_DIM + threadIdx.y;
54
+
55
+ // Write tile
56
+ #pragma unroll
57
+ for (int i = 0; i < TILE_DIM; i += TBLOCK_Y) {
58
+ #pragma unroll
59
+ for (int v = 0; v < VEC_SIZE; v++) {
60
+ odata[(y + i) * M + x + v] = tile[threadIdx.x * VEC_SIZE + v][threadIdx.y + i];
61
+ }
62
+ }
63
+ }
64
+
65
+ template <typename Elem, int M, int N, int TILE_DIM, int BLOCK_SIZE, int VEC_SIZE>
66
+ void launch_transpose(Elem *out, const Elem *in, hipStream_t stream = 0) {
67
+ static_assert(TILE_DIM % VEC_SIZE == 0);
68
+ constexpr auto TBLOCK_X = TILE_DIM / VEC_SIZE;
69
+ static_assert(BLOCK_SIZE % TBLOCK_X == 0);
70
+ constexpr auto TBLOCK_Y = BLOCK_SIZE / TBLOCK_X;
71
+ static_assert(M % TILE_DIM == 0 && N % TILE_DIM == 0);
72
+ hipLaunchKernelGGL(
73
+ HIP_KERNEL_NAME(transpose_kernel<Elem, M, N, TILE_DIM, BLOCK_SIZE, VEC_SIZE>),
74
+ dim3(N / TILE_DIM, M / TILE_DIM), dim3(TBLOCK_X, TBLOCK_Y), 0, stream,
75
+ out, in);
76
+ }
77
+
78
+ #define DISPATCH_TRANSPOSE(DIM_0, DIM_1, TILE_DIM, BLOCK_SIZE, VEC_SIZE) else if constexpr(IN_DIM_0 == DIM_0 && IN_DIM_1 == DIM_1) launch_transpose<__FP8_TYPE, IN_DIM_0, IN_DIM_1, TILE_DIM, BLOCK_SIZE, VEC_SIZE>(out, in, stream)
79
+
80
+ template <int DIM0, int DIM1>
81
+ struct unsupported_config {
82
+ static_assert(DIM0 == -1, "Unsupported transpose configuration - check template parameters");
83
+ };
84
+
85
+ // Select the best parameters for the transpose kernel.
86
+ template <int IN_DIM_0, int IN_DIM_1>
87
+ void transpose_fp8(__FP8_TYPE *out, const __FP8_TYPE *in, hipStream_t stream = 0) {
88
+ if constexpr (false /* dummy*/ ) static_assert(true);
89
+ DISPATCH_TRANSPOSE( 256, 1024, 64, 256, 4); // Optimized: 2.71 µs (193.46 GB/s)
90
+ DISPATCH_TRANSPOSE( 256, 6144, 64, 256, 4); // Optimized: 2.72 µs (1157.37 GB/s)
91
+ DISPATCH_TRANSPOSE( 256, 7168, 64, 256, 8); // Optimized: 2.99 µs (1225.38 GB/s)
92
+ DISPATCH_TRANSPOSE( 512, 1024, 64, 512, 4); // Optimized: 2.55 µs (411.21 GB/s)
93
+ DISPATCH_TRANSPOSE( 512, 4096, 64, 256, 4); // Optimized: 3.01 µs (1394.85 GB/s)
94
+ DISPATCH_TRANSPOSE( 512, 6144, 64, 512, 4); // Optimized: 3.58 µs (1755.43 GB/s)
95
+ DISPATCH_TRANSPOSE( 1536, 1024, 64, 1024, 4); // Optimized: 2.78 µs (1130.74 GB/s)
96
+ DISPATCH_TRANSPOSE( 1536, 3072, 64, 512, 4); // Optimized: 3.57 µs (2641.99 GB/s)
97
+ DISPATCH_TRANSPOSE( 1536, 6144, 128, 1024, 8); // Optimized: 7.09 µs (2661.36 GB/s)
98
+ DISPATCH_TRANSPOSE( 2048, 1024, 64, 1024, 4); // Optimized: 2.84 µs (1477.91 GB/s)
99
+ DISPATCH_TRANSPOSE( 2048, 6144, 128, 512, 8); // Optimized: 8.94 µs (2816.23 GB/s)
100
+ DISPATCH_TRANSPOSE( 2048, 7168, 128, 512, 8); // Optimized: 9.56 µs (3070.50 GB/s)
101
+ DISPATCH_TRANSPOSE( 2304, 1024, 64, 1024, 4); // Optimized: 3.08 µs (1532.51 GB/s)
102
+ DISPATCH_TRANSPOSE( 2304, 6144, 128, 512, 8); // Optimized: 9.30 µs (3043.93 GB/s)
103
+ DISPATCH_TRANSPOSE( 2304, 7168, 128, 512, 8); // Optimized: 10.39 µs (3179.95 GB/s)
104
+ DISPATCH_TRANSPOSE( 7168, 512, 64, 512, 4); // Optimized: 3.25 µs (2257.78 GB/s)
105
+ DISPATCH_TRANSPOSE( 7168, 576, 64, 512, 4); // Optimized: 3.44 µs (2403.24 GB/s)
106
+ DISPATCH_TRANSPOSE( 7168, 1024, 64, 256, 4); // Optimized: 5.07 µs (2892.62 GB/s)
107
+ DISPATCH_TRANSPOSE( 7168, 1536, 128, 1024, 8); // Optimized: 7.72 µs (2851.97 GB/s)
108
+ DISPATCH_TRANSPOSE( 7168, 4608, 128, 512, 8); // Optimized: 16.87 µs (3915.84 GB/s)
109
+ DISPATCH_TRANSPOSE( 7168, 6144, 128, 256, 8); // Optimized: 21.59 µs (4079.12 GB/s)
110
+ else static_assert(false);
111
+ }
112
+
113
+ } // namespace transpose_kernel
114
+
115
+
116
+
117
+
118
+ #ifndef PARAMETERIZE_LIBRARY
119
+ int main() {}
120
+ #endif
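
The tuned `DISPATCH_TRANSPOSE` entries map to a launch configuration through the arithmetic in `launch_transpose`. A host-side sketch of that arithmetic for the `(7168, 6144)` entry, assuming a wave size of 64 as on CDNA3:

```cpp
// Launch-shape arithmetic behind launch_transpose for the (7168, 6144) entry:
// TILE_DIM = 128, BLOCK_SIZE = 256, VEC_SIZE = 8, wave size 64 assumed (CDNA3).
#include <cstdio>

int main() {
    constexpr int M = 7168, N = 6144;
    constexpr int TILE_DIM = 128, BLOCK_SIZE = 256, VEC_SIZE = 8, WAVE = 64;

    constexpr int TBLOCK_X = TILE_DIM / VEC_SIZE;             // 16: threads covering a tile row
    constexpr int TBLOCK_Y = BLOCK_SIZE / TBLOCK_X;           // 16: rows loaded per pass
    constexpr int PADDING  = TBLOCK_Y / (BLOCK_SIZE / WAVE);  // 4: skew elements per shared row
    constexpr int GRID_X   = N / TILE_DIM;                    // 48 tiles along N
    constexpr int GRID_Y   = M / TILE_DIM;                    // 56 tiles along M

    std::printf("block = (%d, %d), grid = (%d, %d), shared tile = %d bytes\n",
                TBLOCK_X, TBLOCK_Y, GRID_X, GRID_Y,
                TILE_DIM * (TILE_DIM + PADDING));             // 16896 bytes of FP8
    return 0;
}
```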
include/clangd_workaround.h ADDED
@@ -0,0 +1,9 @@
1
+ #define HOST_CODE_BELOW \
2
+ extern "C" int printf(const char *fmt, ...); \
3
+ extern void operator delete[](void *ptr) _GLIBCXX_USE_NOEXCEPT; \
4
+ extern void *operator new[](__SIZE_TYPE__ size);
5
+
6
+ #define DEVICE_CODE_BELOW \
7
+ extern "C" __device__ int printf(const char *fmt, ...); \
8
+ extern __device__ void operator delete[](void *ptr) _GLIBCXX_USE_NOEXCEPT; \
9
+ extern __device__ void *operator new[](__SIZE_TYPE__ size);
include/gpu_libs.h ADDED
@@ -0,0 +1,42 @@
1
+ #ifdef TEST_ON_CUDA
2
+ #include <mma.h>
3
+
4
+ #include <cuda_fp16.h>
5
+ #include <cuda_fp8.h>
6
+
7
+ namespace wmma = nvcuda::wmma;
8
+
9
+ #define LIB_CALL(call) \
10
+ do { \
11
+ cudaError_t err = call; \
12
+ if (err != cudaSuccess) { \
13
+ abort(); \
14
+ } \
15
+ } while (0)
16
+
17
+ #define HOST_TYPE(x) cuda##x
18
+
19
+ #else
20
+
21
+ #ifndef HIP_HEADERS__
22
+ #include <hip/hip_runtime.h>
23
+ #include <hip/hip_fp8.h>
24
+ #include <hip/hip_fp16.h>
25
+ #include <rocwmma/rocwmma.hpp>
26
+ #define HIP_HEADERS__
27
+ #endif
28
+
29
+ namespace wmma = rocwmma;
30
+
31
+ #define LIB_CALL(call) \
32
+ do { \
33
+ hipError_t err = call; \
34
+ if (err != hipSuccess) { \
35
+ abort(); \
36
+ } \
37
+ } while (0)
38
+
39
+ #define HOST_TYPE(x) hip##x
40
+
41
+ #endif
42
+
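
A short usage sketch for the `HOST_TYPE()`/`LIB_CALL()` portability macros above. The buffer name is hypothetical, the include is assumed to resolve to `include/gpu_libs.h`, and the ROCm path (built with hipcc) is shown:

```cpp
// Example use of the HOST_TYPE()/LIB_CALL() macros (hipMalloc/hipMemset/hipFree on
// the ROCm path, the cuda* equivalents when TEST_ON_CUDA is defined).
#include "gpu_libs.h"

int main() {
    float *buf = nullptr;                                      // hypothetical scratch buffer
    LIB_CALL(HOST_TYPE(Malloc)(&buf, 1024 * sizeof(float)));   // aborts on failure
    LIB_CALL(HOST_TYPE(Memset)(buf, 0, 1024 * sizeof(float)));
    LIB_CALL(HOST_TYPE(Free)(buf));
    return 0;
}
```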
include/gpu_types.h ADDED
@@ -0,0 +1,28 @@
1
+ #pragma once
2
+ #ifdef TEST_ON_CUDA
3
+ #define __FP8_TYPE __nv_fp8_e4m3
4
+ #define __FP8x4_TYPE __nv_fp8x4_e4m3
5
+ #define __BF16_TYPE __nv_bfloat16
6
+ #define __BF16x2_TYPE __nv_bfloat162
7
+
8
+ #else
9
+ #ifdef TEST_ON_RDNA4
10
+ #define __FP8_TYPE __hip_fp8_e4m3
11
+ #define __FP8x4_TYPE __hip_fp8x4_e4m3
12
+ constexpr const inline int WAVE_SIZE = 32;
13
+ constexpr const inline int XCD_SWIZZLE = 1;
14
+ #else
15
+ #define __FP8_TYPE __hip_fp8_e4m3_fnuz
16
+ #define __FP8x4_TYPE __hip_fp8x4_e4m3_fnuz
17
+ constexpr const inline int WAVE_SIZE = 64;
18
+ constexpr const inline int XCD_SWIZZLE = 8;
19
+ #endif
20
+
21
+ #define __BF16_TYPE __hip_bfloat16
22
+ #define __BF16x2_TYPE __hip_bfloat162
23
+ #define __FP16_TYPE __half
24
+ #define __INT16_TYPE int16_t
25
+
26
+ #endif
27
+
28
+ #define __FP32_TYPE float
include/timer.h ADDED
@@ -0,0 +1,57 @@
1
+ #pragma once
2
+ #include "gpu_types.h"
3
+ #include "gpu_libs.h"
4
+ #include <memory>
5
+ #include <vector>
6
+
7
+ typedef void (*TimerCompletionCallback)(float elapsed_time, size_t calc_ops, float *time_ptr, float *gflops_ptr,
8
+ void *user_data);
9
+
10
+ class KernelTimer {
11
+ private:
12
+ size_t calc_ops;
13
+ HOST_TYPE(Event_t) start, stop;
14
+ float *time_ptr;
15
+ float *gflops_ptr;
16
+ void *user_data;
17
+ TimerCompletionCallback callback;
18
+ bool callback_executed;
19
+
20
+ public:
21
+ KernelTimer(size_t calc_ops, float *time, float *gflops);
22
+
23
+ void start_timer(hipStream_t stream = 0);
24
+ void stop_timer(hipStream_t stream = 0);
25
+ void set_callback(TimerCompletionCallback cb, void *data = nullptr);
26
+
27
+ // Wait for the timer to complete and execute the callback if set
28
+ void synchronize();
29
+
30
+ // Getter methods for the callback
31
+ HOST_TYPE(Event_t) get_start_event() const { return start; }
32
+ HOST_TYPE(Event_t) get_stop_event() const { return stop; }
33
+ size_t get_calc_ops() const { return calc_ops; }
34
+ float *get_time_ptr() const { return time_ptr; }
35
+ float *get_gflops_ptr() const { return gflops_ptr; }
36
+ void execute_callback(float elapsed_time);
37
+ void set_callback_executed(bool executed) { callback_executed = executed; }
38
+ bool is_callback_executed() const { return callback_executed; }
39
+
40
+ ~KernelTimer();
41
+ };
42
+
43
+ class KernelTimerScoped {
44
+ private:
45
+ std::shared_ptr<KernelTimer> timer;
46
+ hipStream_t stream;
47
+
48
+ public:
49
+ KernelTimerScoped(std::vector<std::shared_ptr<KernelTimer>> &timers, size_t calc_ops, float *time, float *gflops,
50
+ hipStream_t stream = 0)
51
+ : timer(std::make_shared<KernelTimer>(calc_ops, time, gflops)), stream(stream) {
52
+ timers.push_back(timer);
53
+ timer->start_timer(stream);
54
+ }
55
+
56
+ ~KernelTimerScoped() { timer->stop_timer(stream); }
57
+ };
src/utils/arithmetic.h ADDED
@@ -0,0 +1,12 @@
 
1
+ #ifndef ARITHMETIC_H
2
+ #define ARITHMETIC_H
3
+
4
+ template <int x, int y> constexpr __device__ __host__ inline int exact_div() {
5
+ static_assert(x % y == 0);
6
+ static_assert(x >= y);
7
+ return x / y;
8
+ }
9
+
10
+ constexpr __device__ __host__ inline int ceil_div(int x, int y) { return (x + y - 1) / y; }
11
+
12
+ #endif
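
A quick sanity check of the two helpers (compiled with hipcc, since they carry `__device__ __host__` qualifiers):

```cpp
// Quick check of the two helpers: ceil_div rounds up, exact_div additionally
// asserts at compile time that the division has no remainder.
#include "arithmetic.h"
#include <cstdio>

int main() {
    static_assert(ceil_div(7168, 128) == 56);      // exact
    static_assert(ceil_div(1000, 128) == 8);       // rounded up from 7.8125
    static_assert(exact_div<7168, 128>() == 56);   // exact_div<1000, 128>() would not compile
    std::printf("ceil_div(1000, 128) = %d\n", ceil_div(1000, 128));
    return 0;
}
```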
src/utils/timer.hip ADDED
@@ -0,0 +1,104 @@
1
+ #include "../../include/timer.h"
2
+
3
+ // Define HIPRT_CB if not already defined
4
+ #ifndef HIPRT_CB
5
+ #define HIPRT_CB
6
+ #endif
7
+
8
+ // Forward declaration of KernelTimer
9
+ class KernelTimer;
10
+
11
+ // Static callback function for hipStreamAddCallback
12
+ static void HIPRT_CB eventCallback(hipStream_t stream, hipError_t status, void* userData) {
13
+ if (status != hipSuccess) return;
14
+
15
+ KernelTimer* timer = static_cast<KernelTimer*>(userData);
16
+ float elapsed_time;
17
+
18
+ // Use the getter methods to access the private members
19
+ HOST_TYPE(Event_t) start_event = timer->get_start_event();
20
+ HOST_TYPE(Event_t) stop_event = timer->get_stop_event();
21
+
22
+ LIB_CALL(HOST_TYPE(EventElapsedTime)(&elapsed_time, start_event, stop_event));
23
+
24
+ size_t calc_ops = timer->get_calc_ops();
25
+ double flops = static_cast<double>(calc_ops);
26
+ double gflops_val = (flops / (elapsed_time * 1e-3)) / 1e9;
27
+
28
+ // Store results in the provided pointers
29
+ float* time_ptr = timer->get_time_ptr();
30
+ float* gflops_ptr = timer->get_gflops_ptr();
31
+
32
+ if (time_ptr != nullptr) {
33
+ *time_ptr = elapsed_time;
34
+ }
35
+ if (gflops_ptr != nullptr) {
36
+ *gflops_ptr = static_cast<float>(gflops_val);
37
+ }
38
+
39
+ // Call user callback if provided
40
+ timer->execute_callback(elapsed_time);
41
+ timer->set_callback_executed(true);
42
+ }
43
+
44
+ KernelTimer::KernelTimer(size_t calc_ops, float *time, float *gflops)
45
+ : calc_ops(calc_ops), time_ptr(time), gflops_ptr(gflops), user_data(nullptr),
46
+ callback(nullptr), callback_executed(false) {
47
+ LIB_CALL(HOST_TYPE(EventCreate)(&start));
48
+ LIB_CALL(HOST_TYPE(EventCreate)(&stop));
49
+ }
50
+
51
+ void KernelTimer::start_timer(hipStream_t stream) {
52
+ LIB_CALL(HOST_TYPE(EventRecord)(start, stream));
53
+ callback_executed = false;
54
+ }
55
+
56
+ void KernelTimer::stop_timer(hipStream_t stream) {
57
+ LIB_CALL(HOST_TYPE(EventRecord)(stop, stream));
58
+ // Instead of synchronizing, add a callback to the stream that will be called when the event completes
59
+ LIB_CALL(hipStreamAddCallback(stream, eventCallback, this, 0));
60
+ }
61
+
62
+ void KernelTimer::set_callback(TimerCompletionCallback cb, void* data) {
63
+ callback = cb;
64
+ user_data = data;
65
+ }
66
+
67
+ void KernelTimer::execute_callback(float elapsed_time) {
68
+ if (callback && !callback_executed) {
69
+ callback(elapsed_time, calc_ops, time_ptr, gflops_ptr, user_data);
70
+ }
71
+ }
72
+
73
+ void KernelTimer::synchronize() {
74
+ // If callback hasn't been executed yet, synchronize and wait for event completion, then manually execute callback
75
+ if (!callback_executed) {
76
+ LIB_CALL(HOST_TYPE(EventSynchronize)(stop));
77
+ float elapsed_time;
78
+ LIB_CALL(HOST_TYPE(EventElapsedTime)(&elapsed_time, start, stop));
79
+
80
+ double flops = static_cast<double>(calc_ops);
81
+ double gflops_val = (flops / (elapsed_time * 1e-3)) / 1e9;
82
+
83
+ // Store results in the provided pointers
84
+ if (time_ptr != nullptr) {
85
+ *time_ptr = elapsed_time;
86
+ }
87
+ if (gflops_ptr != nullptr) {
88
+ *gflops_ptr = static_cast<float>(gflops_val);
89
+ }
90
+
91
+ // Execute callback
92
+ if (callback) {
93
+ callback(elapsed_time, calc_ops, time_ptr, gflops_ptr, user_data);
94
+ }
95
+ callback_executed = true;
96
+ }
97
+ }
98
+
99
+ KernelTimer::~KernelTimer() {
100
+ // Synchronize during destruction to ensure callback is executed
101
+ synchronize();
102
+ LIB_CALL(HOST_TYPE(EventDestroy)(start));
103
+ LIB_CALL(HOST_TYPE(EventDestroy)(stop));
104
+ }
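
The rate reported by the callback follows `GFLOPS = FLOPs / (elapsed_ms * 1e-3) / 1e9` with `FLOPs = 2*M*N*K` for a GEMM. A host-side sketch of that arithmetic, plugged with one shape and time from the dispatch table above (which comes out at roughly 89 TFLOPS, matching the comment there):

```cpp
// The rate reported by the timer callback, reproduced on the host:
// GFLOPS = FLOPs / (elapsed_ms * 1e-3) / 1e9, with FLOPs = 2*M*N*K for a GEMM.
#include <cstdio>

int main() {
    const double m = 6144, n = 4608, k = 7168;
    const double flops = 2.0 * m * n * k;            // ~4.06e11 floating-point ops
    const double elapsed_ms = 4.56;                  // best time noted for this shape
    const double gflops = (flops / (elapsed_ms * 1e-3)) / 1e9;
    std::printf("%.0f GFLOPS = %.2f TFLOPS\n", gflops, gflops / 1e3);  // ~89 TFLOPS
    return 0;
}
```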
tests/checker/checker.cpp ADDED
@@ -0,0 +1,469 @@
1
+ #include "checker.h"
2
+ #include <dlfcn.h>
3
+ #include <sstream>
4
+ #include <fstream>
5
+ #include <iomanip>
6
+ #include <limits>
7
+ #include <getopt.h>
8
+ #include <unistd.h>
9
+
10
+ std::pair<bool, std::string> verbose_allclose(const torch::Tensor &received, const torch::Tensor &expected,
11
+ float rtol = 1e-05, float atol = 1e-08, int max_print = 5) {
12
+ // Check if the shapes of the tensors match
13
+ if (received.sizes() != expected.sizes()) {
14
+ std::string expected_shape_str = "[";
15
+ std::string received_shape_str = "[";
16
+ auto expected_sizes = expected.sizes();
17
+ auto received_sizes = received.sizes();
18
+
19
+ for (int i = 0; i < expected_sizes.size(); i++) {
20
+ expected_shape_str += std::to_string(expected_sizes[i]);
21
+ if (i < expected_sizes.size() - 1)
22
+ expected_shape_str += ", ";
23
+ }
24
+ expected_shape_str += "]";
25
+
26
+ for (int i = 0; i < received_sizes.size(); i++) {
27
+ received_shape_str += std::to_string(received_sizes[i]);
28
+ if (i < received_sizes.size() - 1)
29
+ received_shape_str += ", ";
30
+ }
31
+ received_shape_str += "]";
32
+
33
+ return {false, "SIZE MISMATCH: expected " + expected_shape_str + " but got " + received_shape_str};
34
+ }
35
+
36
+ auto diff = torch::abs(received.to(torch::kFloat32) - expected.to(torch::kFloat32));
37
+
38
+ auto tolerance = atol + rtol * torch::abs(expected);
39
+
40
+ auto tol_mismatched = diff > tolerance;
41
+ auto nan_mismatched = torch::logical_xor(torch::isnan(received), torch::isnan(expected));
42
+ auto posinf_mismatched = torch::logical_xor(torch::isposinf(received), torch::isposinf(expected));
43
+ auto neginf_mismatched = torch::logical_xor(torch::isneginf(received), torch::isneginf(expected));
44
+
45
+ auto mismatched = torch::logical_or(torch::logical_or(tol_mismatched, nan_mismatched),
46
+ torch::logical_or(posinf_mismatched, neginf_mismatched));
47
+
48
+ auto mismatched_indices = torch::nonzero(mismatched);
49
+
50
+ // Count the number of mismatched elements
51
+ int64_t num_mismatched = mismatched.sum().item<int64_t>();
52
+
53
+ // Generate detailed information if there are mismatches
54
+ if (num_mismatched >= 1) {
55
+ std::stringstream mismatch_details;
56
+ auto sizes = received.sizes();
57
+ mismatch_details << "Mismatch found in tensors with shape [";
58
+ for (int i = 0; i < sizes.size(); i++) {
59
+ mismatch_details << sizes[i];
60
+ if (i < sizes.size() - 1)
61
+ mismatch_details << ", ";
62
+ }
63
+ mismatch_details << "]:\n";
64
+ mismatch_details << "Number of mismatched elements: " << num_mismatched << "\n";
65
+
66
+ for (int i = 0; i < std::min(max_print, (int)mismatched_indices.size(0)); i++) {
67
+ auto index = mismatched_indices[i];
68
+ std::vector<int64_t> idx_vec;
69
+ for (int j = 0; j < index.size(0); j++) {
70
+ idx_vec.push_back(index[j].item<int64_t>());
71
+ }
72
+
73
+ // Format the index as a string
74
+ std::string idx_str = "(";
75
+ for (size_t j = 0; j < idx_vec.size(); j++) {
76
+ idx_str += std::to_string(idx_vec[j]);
77
+ if (j < idx_vec.size() - 1)
78
+ idx_str += ", ";
79
+ }
80
+ idx_str += ")";
81
+
82
+ float received_val, expected_val;
83
+ torch::Tensor received_elem = received;
84
+ torch::Tensor expected_elem = expected;
85
+
86
+ for (size_t j = 0; j < idx_vec.size(); j++) {
87
+ received_elem = received_elem[idx_vec[j]];
88
+ expected_elem = expected_elem[idx_vec[j]];
89
+ }
90
+
91
+ received_val = received_elem.item<float>();
92
+ expected_val = expected_elem.item<float>();
93
+
94
+ mismatch_details << "ERROR at " << idx_str << ": " << received_val << " " << expected_val << "\n";
95
+ }
96
+
97
+ if (num_mismatched > max_print) {
98
+ mismatch_details << "... and " << (num_mismatched - max_print) << " more mismatched elements.";
99
+ }
100
+
101
+ return {false, mismatch_details.str()};
102
+ }
103
+
104
+ return {true, "Maximum error: " + std::to_string(diff.max().item<float>())};
105
+ }
106
+
107
+ // Check if implementation matches reference within tolerance
108
+ std::pair<bool, std::string> check_implementation(std::ofstream &fout, const torch::Tensor &output,
109
+ const torch::Tensor &expected, float rtol = 2e-02, float atol = 1e-03,
110
+ CheckerMode mode = CheckerMode::kElementWise) {
111
+ if (mode == CheckerMode::kRowIndex) {
112
+ // For row index mode, we need to sort each row before comparison
113
+ // since the order of indices with the same values might differ
114
+ auto sorted_output = output.clone();
115
+ auto sorted_expected = expected.clone();
116
+
117
+ sorted_output = std::get<0>(torch::sort(output, 1));
118
+ sorted_expected = std::get<0>(torch::sort(expected, 1));
119
+
120
+ return verbose_allclose(sorted_output, sorted_expected, rtol, atol);
121
+ } else if (mode == CheckerMode::kJustDump) {
122
+ // Dump output and expected tensors to file
123
+ {
124
+ fout << "=====OUTPUT=====" << std::endl;
125
+ fout << output.sizes() << std::endl;
126
+
127
+ // Manually print the full tensor to avoid truncation
128
+ auto sizes = output.sizes();
129
+ if (sizes.size() == 2) {
130
+ // For 2D tensors (matrices)
131
+ for (int64_t i = 0; i < sizes[0]; i++) {
132
+ for (int64_t j = 0; j < sizes[1]; j++) {
133
+ fout << std::setw(12) << std::setprecision(6) << output[i][j].item<float>() << " ";
134
+ }
135
+ fout << std::endl;
136
+ }
137
+ } else {
138
+ // Fallback for other tensor dimensions
139
+ fout << output << std::endl;
140
+ }
141
+ }
142
+
143
+ {
144
+ fout << "=====EXPECTED=====" << std::endl;
145
+ fout << expected.sizes() << std::endl;
146
+
147
+ // Manually print the full tensor to avoid truncation
148
+ auto sizes = expected.sizes();
149
+ if (sizes.size() == 2) {
150
+ // For 2D tensors (matrices)
151
+ for (int64_t i = 0; i < sizes[0]; i++) {
152
+ for (int64_t j = 0; j < sizes[1]; j++) {
153
+ fout << std::setw(12) << std::setprecision(6) << expected[i][j].item<float>() << " ";
154
+ }
155
+ fout << std::endl;
156
+ }
157
+ } else {
158
+ // Fallback for other tensor dimensions
159
+ fout << expected << std::endl;
160
+ }
161
+ }
162
+
163
+ return {true, ""};
164
+ }
165
+ return verbose_allclose(output, expected, rtol, atol);
166
+ }
167
+
168
+ constexpr int BENCHMARK_ITERS = 5;
169
+
170
+ void preload() {
171
+ void *handle_rocblas = dlopen("/usr/local/lib/python3.10/dist-packages/torch/lib/librocblas.so", RTLD_NOW | RTLD_GLOBAL);
172
+ void *handle_hipblas = dlopen("/usr/local/lib/python3.10/dist-packages/torch/lib/libhipblas.so", RTLD_NOW | RTLD_GLOBAL);
173
+ void *handle_hipblaslt = dlopen("/usr/local/lib/python3.10/dist-packages/torch/lib/libhipblaslt.so", RTLD_NOW | RTLD_GLOBAL);
174
+
175
+ if (!handle_rocblas || !handle_hipblas || !handle_hipblaslt) {
176
+ fprintf(stderr, "Failed to load required libraries: %s\n", dlerror());
177
+ exit(1);
178
+ }
179
+ }
180
+
181
+ int main(int argc, char **argv) {
182
+ // preload();
183
+ // bool benchmark = false;
184
+ bool benchmark = true;
185
+ bool profile_mode = false;
186
+ int target_test_case = -1;
187
+ int target_sub_case = -1;
188
+ int opt;
189
+
190
+ while ((opt = getopt(argc, argv, "bpt:c:")) != -1) {
191
+ switch (opt) {
192
+ case 'b':
193
+ benchmark = false;
194
+ break;
195
+ case 'p':
196
+ profile_mode = true;
197
+ break;
198
+ case 't':
199
+ target_sub_case = std::stoi(optarg);
200
+ break;
201
+ case 'c':
202
+ target_test_case = std::stoi(optarg);
203
+ break;
204
+ default:
205
+ fprintf(stderr, "Usage: %s [-b] [-p] [-t subcase_index] [-c test_case_index]\n", argv[0]);
206
+ fprintf(stderr, " -b: Disable benchmark mode\n");
207
+ fprintf(stderr, " -p: Enable profile mode (skips reference kernel and comparison)\n");
208
+ fprintf(stderr, " -t: Run only the specified subcase index\n");
209
+ fprintf(stderr, " -c: Run only the specified test case index\n");
210
+ exit(EXIT_FAILURE);
211
+ }
212
+ }
213
+
214
+ case_initialize();
215
+ int num_params, passed_cases = 0;
216
+ num_params = get_params_count();
217
+
218
+ // Validate test case index if specified
219
+ if (target_test_case >= 0) {
220
+ if (target_test_case >= num_params) {
221
+ std::cerr << "Error: Test case index " << target_test_case << " is out of range (0-" << (num_params - 1)
222
+ << ")" << std::endl;
223
+ exit(EXIT_FAILURE);
224
+ }
225
+ }
226
+
227
+ std::vector<std::vector<PerfMetrics>> run_times(num_params);
228
+ std::vector<std::tuple<bool, std::string, std::vector<std::pair<float, float>>>> results;
229
+
230
+ // If targeting specific test case and subcase, run multiple times and output only the best time
231
+ if (target_test_case >= 0 && target_sub_case >= 0) {
232
+ void *input = case_get_input(target_test_case);
233
+ std::vector<Checkee> output;
234
+ float best_time = std::numeric_limits<float>::max();
235
+
236
+ for (int j = 0; j < BENCHMARK_ITERS; j++) {
237
+ PerfMetrics metrics;
238
+ output = case_run_kernel(input, &metrics);
239
+
240
+ if (metrics.count <= target_sub_case) {
241
+ std::cerr << "Error: Subcase index " << target_sub_case << " is out of range (0-" << (metrics.count - 1)
242
+ << ")" << std::endl;
243
+ exit(EXIT_FAILURE);
244
+ }
245
+
246
+ best_time = std::min(best_time, metrics.entries[target_sub_case].time);
247
+ }
248
+
249
+ std::cout << std::fixed << std::setprecision(6) << best_time * 1e3 << std::endl;
250
+ case_destroy(input);
251
+ return 0;
252
+ }
253
+
254
+ // Normal execution path
255
+ if (!profile_mode && target_test_case < 0) {
256
+ std::cout << "Found " << num_params << " test cases for " << case_get_name() << '\n';
257
+ }
258
+ if (benchmark) {
259
+ std::cout << "Benchmark mode enabled\n";
260
+ }
261
+ if (profile_mode) {
262
+ std::cout << "Profile mode enabled (skipping reference kernels and comparison)\n";
263
+ }
264
+
265
+ // Determine which test cases to run
266
+ std::vector<int> test_cases_to_run;
267
+ if (target_test_case >= 0) {
268
+ test_cases_to_run.push_back(target_test_case);
269
+ } else {
270
+ for (int i = 0; i < num_params; i++) {
271
+ test_cases_to_run.push_back(i);
272
+ }
273
+ }
274
+
275
+ for (int i : test_cases_to_run) {
276
+ std::ofstream *fout = nullptr;
277
+ void *input = case_get_input(i);
278
+ if (!profile_mode && target_test_case < 0) {
279
+ std::cerr << "Running test case " << i << std::flush;
280
+ }
281
+ std::vector<Checkee> reference;
282
+ if (!profile_mode) {
283
+ reference = case_run_ref_kernel(input);
284
+ }
285
+ std::vector<Checkee> output;
286
+ for (int j = 0; j < (benchmark ? BENCHMARK_ITERS : 1); j++) {
287
+ PerfMetrics metrics;
288
+ output = case_run_kernel(input, &metrics);
289
+ run_times[i].push_back(metrics);
290
+ }
291
+
292
+ bool match = true;
293
+ std::string case_message;
294
+
295
+ if (!profile_mode) {
296
+ if (reference.size() != output.size()) {
297
+ std::cerr << "Wrong test definition: reference and output have different sizes" << '\n';
298
+ abort();
299
+ }
300
+
301
+ for (int j = 0; j < reference.size(); j++) {
302
+ float rtol, atol;
303
+ get_error_tolerance(&rtol, &atol);
304
+ if (output[j].mode == CheckerMode::kJustDump) {
305
+ if (!fout) {
306
+ fout = new std::ofstream(std::string("case_") + std::to_string(i) + ".txt");
307
+ }
308
+ *fout << "===== SUBCASE " << output[j].name << "=====" << std::endl;
309
+ }
310
+ auto [match_sub, message_sub] =
311
+ check_implementation(*fout, *output[j].tensor, *reference[j].tensor, rtol, atol, output[j].mode);
312
+ if (!match_sub) {
313
+ case_message += "Err on sub case " + std::to_string(j) + ": " + message_sub + "\n";
314
+ match = false;
315
+ }
316
+ }
317
+ if (match) {
318
+ passed_cases++;
319
+ }
320
+ } else {
321
+ match = true;
322
+ passed_cases++;
323
+ }
324
+
325
+ std::vector<std::pair<float, float>> case_metrics;
326
+
327
+ // Process metrics for each run
328
+ for (const auto &run : run_times[i]) {
329
+ if (run.count == 1) {
330
+ // Backward compatibility: single metric case
331
+ case_metrics.push_back({run.entries[0].time, run.entries[0].gflops});
+            } else {
+                // Multiple metrics case - first entry is the total result
+                case_metrics.push_back({run.entries[0].time, run.entries[0].gflops});
+            }
+        }
+
+        results.push_back(std::make_tuple(match, case_message, case_metrics));
+        case_destroy(input);
+        if (!profile_mode && target_test_case < 0) {
+            std::cout << "\033[2K\r" << std::flush;
+        }
+    }
+
+    // Only show detailed output if not in single test case mode
+    if (target_test_case < 0) {
+        std::cout << "=======================" << '\n';
+        if (!profile_mode) {
+            if (passed_cases == num_params) {
+                std::cout << "✅ All " << num_params << " test cases passed!" << '\n';
+            } else {
+                std::cout << "❌ [" << num_params - passed_cases << "/" << num_params << "] test cases failed!" << '\n';
+            }
+        } else {
+            std::cout << "Profile mode: results comparison skipped" << '\n';
+        }
+        std::cout << "-----------------------" << '\n';
+
+        for (int i = 0; i < num_params; i++) {
+            auto [match, message, metrics] = results[i];
+
+            // Calculate best and worst metrics
+            float best_time = std::numeric_limits<float>::max();
+            float best_gflops = 0.0f;
+            float worst_time = 0.0f;
+            float worst_gflops = std::numeric_limits<float>::max();
+
+            for (const auto &[time, gflops] : metrics) {
+                best_time = std::min(best_time, time);
+                best_gflops = std::max(best_gflops, gflops);
+                worst_time = std::max(worst_time, time);
+                worst_gflops = std::min(worst_gflops, gflops);
+            }
+
+            std::string timing_info;
+            if (benchmark) {
+                std::stringstream ss;
+                ss << std::fixed << std::setprecision(2);
+                ss << "Best: [\033[1m" << best_time * 1e3 << "\033[0m us, \033[1m" << best_gflops / 1e3
+                   << "\033[0m TFLOPS], "
+                   << "\033[2mSlowest: [" << worst_time * 1e3 << " us, " << worst_gflops / 1e3 << " TFLOPS]\033[0m";
+                timing_info = ss.str();
+            } else {
+                std::stringstream ss;
+                ss << std::fixed << std::setprecision(2);
+                ss << "Time: " << best_time * 1e3 << " us, TFLOPS: " << best_gflops / 1e3;
+                timing_info = ss.str();
+            }
+
+            if (!profile_mode && !match) {
+                std::cout << "❌ Test case " << i << ": " << timing_info << "\n" << message << '\n';
+            } else {
+                std::cout << "✅ Test case " << i << ": " << timing_info << "\n";
+            }
+
+            // Print sub-results if there are multiple metrics
+            if (run_times[i][0].count > 1) {
+                for (int j = 1; j < run_times[i][0].count; j++) {
+                    std::stringstream ss;
+                    ss << std::fixed << std::setprecision(2);
+                    ss << " - Sub-case " << run_times[i][0].entries[j].name << ": ";
+
+                    if (benchmark) {
+                        float sub_best_time = std::numeric_limits<float>::max();
+                        float sub_best_gflops = 0.0f;
+                        float sub_worst_time = 0.0f;
+                        float sub_worst_gflops = std::numeric_limits<float>::max();
+
+                        for (const auto &run : run_times[i]) {
+                            sub_best_time = std::min(sub_best_time, run.entries[j].time);
+                            sub_best_gflops = std::max(sub_best_gflops, run.entries[j].gflops);
+                            sub_worst_time = std::max(sub_worst_time, run.entries[j].time);
+                            sub_worst_gflops = std::min(sub_worst_gflops, run.entries[j].gflops);
+                        }
+
+                        ss << "Best: [\033[1m" << sub_best_time * 1e3 << "\033[0m us, \033[1m" << sub_best_gflops / 1e3
+                           << "\033[0m TFLOPS], "
+                           << "\033[2mSlowest: [" << sub_worst_time * 1e3 << " us, " << sub_worst_gflops / 1e3
+                           << " TFLOPS]\033[0m";
+                    } else {
+                        ss << "Time: " << run_times[i][0].entries[j].time * 1e3
+                           << " us, TFLOPS: " << run_times[i][0].entries[j].gflops / 1e3;
+                    }
+
+                    std::cout << ss.str() << std::endl;
+                }
+            }
+        }
+        std::cout << "-----------------------" << '\n';
+
+        // Calculate geometric mean of time and GFLOPS
+        double geo_mean_time = 1.0;
+        double geo_mean_gflops = 1.0;
+
+        for (int i = 0; i < num_params; i++) {
+            auto [match, message, metrics] = results[i];
+            // Always use the best performance metrics for geometric mean
+            float best_time = std::numeric_limits<float>::max();
+            float best_gflops = 0.0f;
+
+            for (const auto &[time, gflops] : metrics) {
+                best_time = std::min(best_time, time);
+                best_gflops = std::max(best_gflops, gflops);
+            }
+
+            geo_mean_time *= best_time;
+            geo_mean_gflops *= best_gflops;
+        }
+
+        geo_mean_time = std::pow(geo_mean_time, 1.0 / num_params);
+        geo_mean_gflops = std::pow(geo_mean_gflops, 1.0 / num_params);
+
+        if (benchmark) {
+            std::stringstream ss;
+            ss << std::fixed << std::setprecision(2);
+            ss << "GeoMean - Best Time: \033[1m" << geo_mean_time * 1e3 << "\033[0m us, Best TFLOPS: \033[1m"
+               << geo_mean_gflops / 1e3 << "\033[0m";
+            std::cout << ss.str() << std::endl;
+        } else {
+            std::stringstream ss;
+            ss << std::fixed << std::setprecision(2);
+            ss << "GeoMean - Time: " << geo_mean_time * 1e3 << " us, TFLOPS: " << geo_mean_gflops / 1e3;
+            std::cout << ss.str() << std::endl;
+        }
+        std::cout << "=======================" << '\n';
+    }
+
+    return 0;
+}
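
For context on the summary printed above: each test case keeps its best (minimum) time and best (maximum) throughput across its benchmark runs, and the final line reports the geometric mean of those per-case bests. A minimal Python sketch of the same aggregation, for illustration only (it is not part of this commit):

import math

def summarize(case_runs):
    """case_runs: one list per test case of (time_seconds, gflops) run results."""
    best_times = [min(t for t, _ in runs) for runs in case_runs]
    best_gflops = [max(g for _, g in runs) for runs in case_runs]
    n = len(case_runs)
    geo_time = math.prod(best_times) ** (1.0 / n)        # mirrors pow(product, 1/num_params)
    geo_gflops = math.prod(best_gflops) ** (1.0 / n)
    return geo_time, geo_gflops

# Two cases, two runs each (time in seconds, throughput in GFLOPS):
print(summarize([[(1.2e-4, 800.0), (1.0e-4, 950.0)],
                 [(2.0e-4, 600.0), (2.4e-4, 500.0)]]))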
tests/checker/checker.h ADDED
@@ -0,0 +1,33 @@
+
+#pragma once
+#include <iostream>
+#include <vector>
+#include "torch/torch.h"
+#include "metrics.h"
+
+enum class CheckerMode {
+    kElementWise,
+    kRowIndex,
+    kJustDump,
+};
+
+struct Checkee {
+    torch::Tensor *tensor;
+    CheckerMode mode;
+    std::string name;
+};
+
+void case_initialize();
+int get_params_count();
+void *case_get_input(int index);
+std::vector<Checkee> case_run_kernel(void *input, PerfMetrics *metrics);
+std::vector<Checkee> case_run_ref_kernel(void *input);
+const char *case_get_name();
+void get_error_tolerance(float *rtol, float *atol);
+void case_destroy(void *input);
+CheckerMode get_checker_mode();
+
+
+// using OutputData = torch::Tensor;
+// void ref_kernel(const BlockwiseMatmulInputs &data);
+// BlockwiseMatmulInputs generate_input(int m, int n, int k, int seed);
tests/checker/metrics.h ADDED
@@ -0,0 +1,13 @@
+#pragma once
+
+
+struct PerfMetricEntry {
+    char name[20];
+    float time;
+    float gflops;
+};
+
+struct PerfMetrics {
+    int count = 0;
+    PerfMetricEntry entries[20];
+};
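
PerfMetrics is a plain C aggregate (a count plus a fixed-size array of named time/GFLOPS entries), so it can cross the extern "C" boundary without serialization. A hypothetical ctypes mirror of the same layout, shown only to illustrate the struct; the Torch binding below currently passes nullptr for metrics, so nothing in this commit exposes it to Python:

import ctypes

class PerfMetricEntry(ctypes.Structure):
    # Matches: char name[20]; float time; float gflops;
    _fields_ = [("name", ctypes.c_char * 20),
                ("time", ctypes.c_float),
                ("gflops", ctypes.c_float)]

class PerfMetrics(ctypes.Structure):
    # Matches: int count; PerfMetricEntry entries[20];
    _fields_ = [("count", ctypes.c_int),
                ("entries", PerfMetricEntry * 20)]

print(ctypes.sizeof(PerfMetrics()))  # sanity check of the mirrored layout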
torch-ext/gemm/__init__.py ADDED
@@ -0,0 +1,19 @@
+from typing import Optional
+import torch
+from ._ops import ops
+
+def gemm(a: torch.Tensor, b: torch.Tensor, as_: torch.Tensor, bs: torch.Tensor,
+         out: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    if out is None:
+        # Create output tensor with appropriate shape and dtype
+        M, K = a.shape
+        K_b, N = b.shape
+        assert K == K_b, f"Matrix dimension mismatch: A has {K} cols, B has {K_b} rows"
+
+        # Output should be BF16 type on the same device as inputs
+        out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)
+
+    ops.gemm(out, a, b, as_, bs)
+    return out
+
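
A hedged usage sketch of the wrapper above. The FP8 input dtype and the scale-tensor shapes are assumptions made for illustration; the kernel's actual blockwise scale layout is defined in gemm/gemm_kernel.h, not by this snippet:

import torch
from gemm import gemm  # the package built from build.toml

M, N, K = 128, 128, 128
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fnuz)   # assumed FP8 dtype
b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fnuz)   # b laid out as [K, N], as the wrapper expects
as_ = torch.ones(M, K // 128, device="cuda", dtype=torch.float32)        # placeholder scale shape
bs = torch.ones(K // 128, N // 128, device="cuda", dtype=torch.float32)  # placeholder scale shape

c = gemm(a, b, as_, bs)        # allocates and returns a BF16 tensor of shape (M, N)
print(c.shape, c.dtype)        # torch.Size([128, 128]) torch.bfloat16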
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,53 @@
+#include <torch/all.h>
+#include <torch/library.h>
+#include <hip/hip_runtime.h>
+
+#include "registration.h"
+#include "torch_binding.h"
+
+// Forward declaration of the C function from gemm_launcher.hip
+extern "C" {
+struct PerfMetrics;
+void run(void *a, void *b, void *as, void *bs, void *c, int m, int n, int k, PerfMetrics *metrics, hipStream_t job_stream0);
+}
+
+void gemm(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b,
+          torch::Tensor const &as, torch::Tensor const &bs) {
+
+    // Validate tensor properties
+    TORCH_CHECK(a.device().is_cuda(), "Input tensor a must be on GPU device");
+    TORCH_CHECK(b.device().is_cuda(), "Input tensor b must be on GPU device");
+    TORCH_CHECK(as.device().is_cuda(), "Scale tensor as must be on GPU device");
+    TORCH_CHECK(bs.device().is_cuda(), "Scale tensor bs must be on GPU device");
+    TORCH_CHECK(out.device().is_cuda(), "Output tensor out must be on GPU device");
+
+    TORCH_CHECK(a.is_contiguous(), "Input tensor a must be contiguous");
+    TORCH_CHECK(b.is_contiguous(), "Input tensor b must be contiguous");
+    TORCH_CHECK(as.is_contiguous(), "Scale tensor as must be contiguous");
+    TORCH_CHECK(bs.is_contiguous(), "Scale tensor bs must be contiguous");
+    TORCH_CHECK(out.is_contiguous(), "Output tensor out must be contiguous");
+
+    // Get matrix dimensions from tensor shapes
+    // Assuming a is [M, K], b is [K, N], out is [M, N]
+    int M = a.size(0);
+    int K = a.size(1);
+    int N = b.size(1);
+
+    TORCH_CHECK(b.size(0) == K, "Matrix dimensions mismatch: a.size(1) != b.size(0)");
+    TORCH_CHECK(out.size(0) == M, "Output tensor dimension mismatch: out.size(0) != M");
+    TORCH_CHECK(out.size(1) == N, "Output tensor dimension mismatch: out.size(1) != N");
+
+    // Use default HIP stream (stream 0)
+    const hipStream_t stream = 0;
+
+    // Call the C function
+    run(a.data_ptr(), b.data_ptr(), as.data_ptr(), bs.data_ptr(), out.data_ptr(),
+        M, N, K, nullptr, stream);
+}
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+    ops.def("gemm(Tensor! out, Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> ()");
+    ops.impl("gemm", torch::kCUDA, &gemm);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
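
The registered schema `gemm(Tensor! out, ...) -> ()` marks `out` as mutated in place and returns nothing, so the output buffer can be allocated once and reused across calls. A sketch of that out-variant path through the raw op handle generated in `_ops.py`; as before, the FP8 dtype and scale shapes are placeholders, not the kernel's authoritative layout:

import torch
from gemm._ops import ops  # resolves the build-hash-suffixed extension namespace

M, N, K = 128, 128, 128
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fnuz)   # assumed FP8 dtype
b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fnuz)
as_ = torch.ones(M, K // 128, device="cuda", dtype=torch.float32)        # placeholder scales
bs = torch.ones(K // 128, N // 128, device="cuda", dtype=torch.float32)

out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")  # allocated once, reused
for _ in range(10):
    ops.gemm(out, a, b, as_, bs)   # writes into `out` each iteration; the op returns None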
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,6 @@
+#pragma once
+
+#include <torch/torch.h>
+
+void gemm(torch::Tensor &out, torch::Tensor const &a, torch::Tensor const &b,
+          torch::Tensor const &as, torch::Tensor const &bs);