medmekk (HF Staff) committed
Commit 7587d0b · 1 Parent(s): 8f4b908

first commit

.gitignore ADDED
@@ -0,0 +1,4 @@
.venv/
__pycache__/
torch-ext/__pycache__/
torch-ext/residual_rms_rocm/__pycache__/
README.md ADDED
@@ -0,0 +1 @@
Fused residual-add + RMSNorm kernel for ROCm devices, from https://github.com/huggingface/hf-rocm-kernels
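A minimal usage sketch (assuming the built extension is importable as residual_rms_rocm and follows the Python wrapper in torch-ext/residual_rms_rocm/wrapped_rms.py; shapes and epsilon below are illustrative):

import torch
from residual_rms_rocm import residual_rms

rows, cols = 8, 4096
input = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
residual = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
weight = torch.randn(cols, dtype=torch.float16, device="cuda")
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

# fp8 path: residual is updated in place (input + residual), output is float8_e4m3fnuz
output, residual = residual_rms(input, residual, weight, epsilon=1e-6, scale_tensor=scale)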
build.toml ADDED
@@ -0,0 +1,23 @@
[general]
name = "residual_rms_rocm"
universal = false

[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]

[kernel.residual_rms_rocm]
depends = ["torch"]
backend = "rocm"
rocm-archs = [
  "gfx90a",
]
src = [
  "residual_rms_rocm/residual_rms_dispatch.cu",
  "residual_rms_rocm/residual_rms_scalar.cu",
  "residual_rms_rocm/residual_rms_vectorized.cu",
  "residual_rms_rocm/utils.h",
]
include = ["."]
flake.lock ADDED
@@ -0,0 +1,168 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1747046372,
        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-compat_2": {
      "locked": {
        "lastModified": 1747046372,
        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "hf-nix": {
      "inputs": {
        "flake-compat": "flake-compat_2",
        "flake-utils": "flake-utils_2",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1757675377,
        "narHash": "sha256-JQKZOI1ZYO4faJnanuoTXziSmqzXe5rEFSGliWDWqWw=",
        "owner": "huggingface",
        "repo": "hf-nix",
        "rev": "faf3354403a7381958d08e826c15fe30f6986a4f",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "hf-nix",
        "type": "github"
      }
    },
    "kernel-builder": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils",
        "hf-nix": "hf-nix",
        "nixpkgs": [
          "kernel-builder",
          "hf-nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1759322505,
        "narHash": "sha256-RzjCEn0zDfdwQp4WAb0BBuLlHxypr+4+a4BMON23SNw=",
        "owner": "huggingface",
        "repo": "kernel-builder",
        "rev": "437d0f5c253a78d0be8b5998d9c1fcf32ac2360c",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "kernel-builder",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1755963616,
        "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
        "owner": "nixos",
        "repo": "nixpkgs",
        "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
        "type": "github"
      },
      "original": {
        "owner": "nixos",
        "ref": "nixos-unstable-small",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "kernel-builder": "kernel-builder"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_2": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
flake.nix ADDED
@@ -0,0 +1,13 @@
{
  description = "Flake for Torch kernel extension";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs = { self, kernel-builder, }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
residual_rms_rocm/residual_rms_dispatch.cu ADDED
@@ -0,0 +1,63 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_runtime.h>

#include "residual_rms_vectorized.cu"
#include "residual_rms_scalar.cu"

void residual_rms(torch::Tensor& input,        // Shape: [m, n] / Layout: row-major / Dtype: fp16
                  torch::Tensor& residual,     // Shape: [m, n] / Layout: row-major / Dtype: fp16
                  torch::Tensor& weight,       // Shape: [n, ]  / Layout: row-major / Dtype: fp16
                  torch::Tensor& scale_tensor, // Shape: [1, ]  / Layout: row-major / Dtype: fp32
                  double epsilon,
                  torch::Tensor& output,       // Shape: [m, n] / Layout: row-major / Dtype: fp8 or fp16
                  torch::Tensor& next_buffer,  // Shape: [m, o] / Layout: don't-care / Dtype: fp16
                  int64_t num_threads, bool force_scalar) {
  // Retrieve shapes
  const int rows = input.size(0);
  const int cols = input.size(1);
  // Activate device guard
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

  // Prepare kernel launch arguments
  dim3 grid(rows);
  dim3 block(num_threads);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Check tensor alignment
  bool vectorized_available = IS_16B_ALIGNED(input) && IS_16B_ALIGNED(residual) && IS_16B_ALIGNED(weight);
  vectorized_available = vectorized_available && (!force_scalar) && (cols <= 16384);

  // Case: output is fp16
  if (output.dtype() == torch::kFloat16) {
    vectorized_available = vectorized_available && IS_16B_ALIGNED(output);

    if (vectorized_available) {
      _residual_rms_vectorized<half2, false><<<grid, block, 0, stream>>>(
          (half*)input.data_ptr(), (half*)residual.data_ptr(), (half*)weight.data_ptr(), (float*)NULL,
          (half2*)output.data_ptr(), (half*)NULL, epsilon, cols, 0);
    } else {
      _residual_rms_scalar<half, false><<<grid, block, 0, stream>>>(
          (half*)input.data_ptr(), (half*)residual.data_ptr(), (half*)weight.data_ptr(), (float*)NULL,
          (half*)output.data_ptr(), (half*)NULL, epsilon, cols, 0);
    }
  }

  // Case: output is fp8 (e4m3fnuz)
  else {
    vectorized_available = vectorized_available && IS_8B_ALIGNED(output) && (next_buffer.size(1) % 8 == 0);

    // Launch kernel
    if (vectorized_available) {
      _residual_rms_vectorized<__hip_fp8x2_storage_t, true><<<grid, block, 0, stream>>>(
          (half*)input.data_ptr(), (half*)residual.data_ptr(), (half*)weight.data_ptr(),
          (float*)scale_tensor.data_ptr(), (__hip_fp8x2_storage_t*)output.data_ptr(),
          (half*)next_buffer.data_ptr(), epsilon, cols, next_buffer.size(1));
    } else {
      _residual_rms_scalar<__hip_fp8_storage_t, true><<<grid, block, 0, stream>>>(
          (half*)input.data_ptr(), (half*)residual.data_ptr(), (half*)weight.data_ptr(),
          (float*)scale_tensor.data_ptr(), (__hip_fp8_storage_t*)output.data_ptr(), (half*)next_buffer.data_ptr(),
          epsilon, cols, next_buffer.size(1));
    }
  }
}
residual_rms_rocm/residual_rms_scalar.cu ADDED
@@ -0,0 +1,74 @@
#include <torch/all.h>

#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include <hip/hip_fp8.h>

#include "utils.h"

template <typename T, bool clean_next_buffer>
__global__ void _residual_rms_scalar(const half* __restrict__ input, half* __restrict__ residual,
                                     const half* __restrict__ weight, const float* __restrict__ scale_tensor,
                                     T* __restrict__ output, half* __restrict__ next_buffer, const float epsilon,
                                     const int cols, const int buffer_cols) {
  // Advance pointers according to the position of the thread in the grid
  input += blockIdx.x * cols;
  residual += blockIdx.x * cols;
  output += blockIdx.x * cols;

  // Residual connection: in-place add of input to residual, accumulating the squared norm along the way
  float variance = 0.0f;

  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    half z = input[i];
    z += residual[i];
    float x = (float)z;
    variance += (x * x);
    residual[i] = z;
  }
  variance /= cols;

  // Block reduce to compute the total norm
  __shared__ float shared_normalizer;
  using BlockReduce = hipcub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
  if (threadIdx.x == 0) {
    shared_normalizer = rsqrtf(variance + epsilon);
  }
  __syncthreads();

  // Get inverse scale (only for fp8)
  float inv_scale = 1.0f;
  if constexpr (std::is_same_v<T, __hip_fp8_storage_t>) {
    inv_scale = 1 / scale_tensor[0];
  }

  // Normalize and store
  for (int idx = threadIdx.x; idx < cols; idx += blockDim.x) {
    float x = (float)residual[idx];
    half y = (half)(x * shared_normalizer);
    y = (y * weight[idx]);

    if constexpr (std::is_same_v<T, __hip_fp8_storage_t>) {
      x = (float)y;
      x *= inv_scale;
      FP8_CLAMP(x, float);
      output[idx] = __hip_cvt_float_to_fp8(x, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
    }
    if constexpr (std::is_same_v<T, half>) {
      output[idx] = y;
    }
  }

  // Initialize next buffer
  if constexpr (clean_next_buffer) {
    next_buffer += blockIdx.x * buffer_cols;
    for (int i = threadIdx.x; i < buffer_cols; i += blockDim.x) {
      next_buffer[i] = 0;
    }
  }
}
residual_rms_rocm/residual_rms_vectorized.cu ADDED
@@ -0,0 +1,196 @@
#include <torch/all.h>

#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include <hip/hip_fp8.h>

#include "utils.h"

#define USE_SMEM true  // TODO: figure out if this is needed in practice

template <typename T, bool clean_next_buffer>
__global__ void _residual_rms_vectorized(const half* __restrict__ input, half* __restrict__ residual,
                                         const half* __restrict__ weight, const float* __restrict__ scale_tensor,
                                         T* __restrict__ output,  // half2 or __hip_fp8x2_storage_t
                                         half* __restrict__ next_buffer, const float epsilon, const int cols,
                                         const int buffer_cols) {
  static constexpr int elems_per_load = 8;
  static constexpr int smem_size = USE_SMEM ? 16384 : 0;
  __shared__ half _smem[smem_size];

  // Advance pointers according to the position of the thread in the grid
  input += blockIdx.x * cols + elems_per_load * threadIdx.x;
  residual += blockIdx.x * cols + elems_per_load * threadIdx.x;
  weight += elems_per_load * threadIdx.x;
  output += (blockIdx.x * cols + elems_per_load * threadIdx.x) / 2;

  half* residual_start = residual;
  half* residual_smem_buffer = &_smem[0] + elems_per_load * threadIdx.x;

  // Residual connection: in-place add of input to residual, accumulating the squared norm along the way
  float variance = 0.0f;
  float fp32_residual;
  half input_buffer[elems_per_load];
  half residual_buffer[elems_per_load];

  const int loop_stride = elems_per_load * blockDim.x;
  const int iterations = CDIV(cols - elems_per_load * threadIdx.x, loop_stride);
  for (int i = 0; i < iterations; i++) {
    // Load data using 128-bit loads
#pragma unroll
    for (int j = 0; j < elems_per_load; j++) {
      input_buffer[j] = input[j];
    }
#pragma unroll
    for (int j = 0; j < elems_per_load; j++) {
      residual_buffer[j] = residual[j];
    }

    // Add everything in the residual buffer and accumulate variance
#pragma unroll
    for (int j = 0; j < elems_per_load; j++) {
      residual_buffer[j] += input_buffer[j];
      float float_res = (float)residual_buffer[j];
      variance += float_res * float_res;
    }

    // 128-bit smem store
#pragma unroll
    for (int j = 0; j < elems_per_load; j++) {
      if constexpr (USE_SMEM) {
        residual_smem_buffer[j] = residual_buffer[j];
      } else {
        residual[j] = residual_buffer[j];
      }
    }

    // Advance pointers
    input += loop_stride;
    residual += loop_stride;
    residual_smem_buffer += loop_stride;
  }
  variance /= cols;

  // Block reduce to compute the total norm
  __shared__ float shared_normalizer;
  using BlockReduce = hipcub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;

  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
  if (threadIdx.x == 0) {
    shared_normalizer = rsqrtf(variance + epsilon);
  }
  __syncthreads();

  // Normalize and convert
  __half2 weight_buffer[elems_per_load / 2];
  T output_buffer[elems_per_load / 2];

  // Apply inverse scale (only for fp8)
  if constexpr (std::is_same_v<T, __hip_fp8x2_storage_t>) {
    shared_normalizer = shared_normalizer / scale_tensor[0];
  }

  residual = residual_start;
  residual_smem_buffer = &_smem[0] + elems_per_load * threadIdx.x;

  for (int i = 0; i < iterations; i++) {
    // 128-bit loads
#pragma unroll
    for (int j = 0; j < elems_per_load; j++) {
      if constexpr (USE_SMEM) {
        residual_buffer[j] = residual_smem_buffer[j];
      } else {
        residual_buffer[j] = residual[j];
      }
    }
#pragma unroll
    for (int j = 0; j < elems_per_load / 2; j++) {
      weight_buffer[j] = reinterpret_cast<const __half2*>(weight)[j];
    }

    // 128-bit store
#pragma unroll
    for (int j = 0; j < elems_per_load; j++) {
      residual[j] = residual_buffer[j];
    }

    // Compute and fill buffer
#pragma unroll
    for (int j = 0; j < elems_per_load / 2; j++) {
      // Output is fp8
      if constexpr (std::is_same_v<T, __hip_fp8x2_storage_t>) {
        __half2 tmp_res = {residual_buffer[2 * j], residual_buffer[2 * j + 1]};
        // tmp_res = tmp_res * weight_buffer[j];
        float2 tmp_float2 = __half22float2(tmp_res);
        // INCREASES PRECISION | TODO: figure out a better test
        tmp_float2 = tmp_float2 * __half22float2(weight_buffer[j]);
        tmp_float2 *= shared_normalizer;

        // tmp_float2.x = __builtin_amdgcn_fmed3f(tmp_float2.x, 448.0, -448.0);  // TODO: are they needed?
        // tmp_float2.y = __builtin_amdgcn_fmed3f(tmp_float2.y, 448.0, -448.0);  // TODO: are they needed?
        output_buffer[j] = __hip_cvt_float2_to_fp8x2(tmp_float2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
      }

      // Output is fp16
      if constexpr (std::is_same_v<T, half2>) {
        float2 tmp_float2;
        tmp_float2.x = (float)residual_buffer[2 * j];
        tmp_float2.y = (float)residual_buffer[2 * j + 1];
        tmp_float2 *= shared_normalizer;
        half2 tmp = {(half)tmp_float2.x, (half)tmp_float2.y};
        tmp *= reinterpret_cast<const half2*>(weight_buffer)[j];
        output_buffer[j] = tmp;
      }
    }

    // 64-bit store
#pragma unroll
    for (int j = 0; j < elems_per_load / 2; j++) {
      output[j] = output_buffer[j];
    }

    // Advance pointers
    residual += loop_stride;
    residual_smem_buffer += loop_stride;
    weight += loop_stride;
    output += loop_stride / 2;
  }

  // Initialize next buffer  TODO: add this as a template (eventually w/ vector granularity)
  if constexpr (clean_next_buffer) {
    next_buffer += blockIdx.x * buffer_cols;
    for (int i = elems_per_load * threadIdx.x; i < buffer_cols; i += elems_per_load * blockDim.x) {
#pragma unroll
      for (int j = 0; j < elems_per_load; j++) {
        next_buffer[i + j] = 0;
      }
    }
  }
}

// Nb. rows    Ref (μs)    Pointwise (μs)    Vectorized (μs)
// ----------  ----------  ----------------  -----------------
//          1     40.6864           10.4857             4.8905
//          2     42.8676           10.5499             5.04421
//          4     43.7978           10.5729             5.05962
//          8     44.0237           10.6909             5.10061
//         16     47.1026           10.7823             5.19516
//         32     56.3393           11.0192             5.45101
//         64     74.0383           14.0895             5.86153
//        128     98.3725           15.2012             6.59527
//        256    119.426            27.7393            11.5191

// Nb. rows    Ref (μs)    Pointwise (μs)    Vectorized (μs)
// ----------  ----------  ----------------  -----------------
//          1     38.8908           10.5276             4.28524
//          2     42.8806           10.5337             4.30209
//          4     43.2694           10.6618             4.38874
//          8     43.4718           10.6979             4.41091
//         16     46.6662           10.8013             4.49634
//         32     55.7943           11.0203             4.78883
//         64     75.1326           14.1084             5.38017
//        128    100                15.129              6.31691
//        256    118.571            27.3881            11.1394
residual_rms_rocm/utils.h ADDED
@@ -0,0 +1,15 @@
#pragma once

#define WARPSIZE 64

#define FP8_CLAMP(x, type)                    \
  x = (x > (type)448.0) ? (type)448.0 : x;    \
  x = (x < (type)-448.0) ? (type)-448.0 : x;
// TODO: reformat clamping

#define IS_8B_ALIGNED(tensor) (reinterpret_cast<std::uintptr_t>(tensor.data_ptr()) % 4 == 0)
#define IS_16B_ALIGNED(tensor) (reinterpret_cast<std::uintptr_t>(tensor.data_ptr()) % 16 == 0)

#define CDIV(a, b) (((a) + (b) - 1) / (b))

#define FP8_MAX 224.0f  // TODO: check if this or 448.0f
torch-ext/residual_rms_rocm/__init__.py ADDED
@@ -0,0 +1,3 @@
from .wrapped_rms import residual_rms, reference_residual_rms

__all__ = ["residual_rms", "reference_residual_rms"]
torch-ext/residual_rms_rocm/wrapped_rms.py ADDED
@@ -0,0 +1,171 @@
from typing import Tuple, Optional
import torch
from torch import Tensor

from ._ops import ops


_HIGHEST_RESIDUAL_RMS_MODE = 3


def residual_rms_checks(
    input: Tensor,
    residual: Tensor,
    weight: Tensor,
    scale_tensor: Optional[Tensor],
    epsilon: float,
    next_buffer: Tensor,
) -> None:
    # Check shapes
    assert input.dim() == 2, f"Expected input to have 2 dimensions but got {input.dim() = } instead."
    assert residual.shape == input.shape, \
        f"Expected input and residual to have the same shape but got {input.shape = } and {residual.shape = }"
    assert weight.shape == (input.size(1), ), \
        f"Expected weight to have shape {(input.size(1), ) = } but got {weight.shape = }"
    # Check devices
    device = input.device
    assert device.type == "cuda", f"Expected input.device to be of type cuda, but got {device.type = } instead."
    assert residual.device == device, f"Expected {residual.device = } to be the same as {input.device = }"
    if scale_tensor is not None:
        assert scale_tensor.device == device, f"Expected {scale_tensor.device = } to be the same as {input.device = }"
    assert next_buffer.device == device, f"Expected {next_buffer.device = } to be the same as {input.device = }"
    # Check layouts
    assert input.is_contiguous(), f"Expected input to be contiguous but got {input.stride() = }"
    assert residual.is_contiguous(), f"Expected residual to be contiguous but got {residual.stride() = }"
    # Check scalars
    assert epsilon > 0, f"Expected RMS epsilon to be > 0 to avoid division by zero, but got {epsilon = }"


def residual_rms_choose_mode(
    input: Tensor,
    residual: Tensor,
    weight: Tensor,
    next_buffer: Tensor,
    mode: int,
) -> int:
    cols_is_multiple_of_8 = (input.size(1) % 8 == 0) and (next_buffer.size(1) % 8 == 0)
    tensors_are_16b_aligned = all([x.data_ptr() % 16 == 0 for x in [input, residual, weight]])
    if mode == -1:
        mode = _HIGHEST_RESIDUAL_RMS_MODE if (tensors_are_16b_aligned and cols_is_multiple_of_8) else 0
    elif mode > 0:
        assert tensors_are_16b_aligned, (
            f"Requested {mode = } > 0, which requires tensors to be 16-byte aligned, but got "
            f"{input.data_ptr() % 16 = }, {residual.data_ptr() % 16 = }, {weight.data_ptr() % 16 = }"
        )
        assert cols_is_multiple_of_8, f"Requested {mode = } requires {input.size(1) = } to be a multiple of 8."
    return mode


def infer_num_threads(rows: int, num_threads: int) -> int:
    # Error case
    if num_threads < 0 or num_threads > 1024:
        raise ValueError(f"{num_threads = } is not between 0 and 1024")
    # Case: num_threads was specified
    elif num_threads != 0:
        return num_threads
    # Otherwise, we branch upon the number of rows
    if rows <= 16:
        return 1024
    if rows <= 32:
        return 768
    if rows <= 64:
        return 1024
    if rows <= 256:
        return 960
    return 1024

## Main kernel
def residual_rms(
    input: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
    scale_tensor: Optional[Tensor] = None,
    next_buffer: Optional[Tensor] = None,
    num_threads: int = 0,
    force_scalar: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Kernel that fuses a residual connection, an RMS normalization and a conversion to fp8. The residual argument
    is modified in place (residual <- input + residual).
    Args:
    - input: a fp16 tensor of shape (rows, cols) in row-major format
    - residual: a fp16 tensor of shape (rows, cols) in row-major format
    - weight: a fp16 tensor of shape (cols, ) in row-major format which contains the weight of the RMS norm
    - epsilon: the small epsilon used inside the RMS norm to avoid division by zero
    - scale_tensor: a fp32 one-item tensor to divide the output of the RMS norm before its conversion to fp8. If
      set to None, then the output dtype is fp16
    - next_buffer: an optional tensor of shape (rows, .) to initialize to zero if the output dtype is fp8
    - num_threads: the number of threads per block in the kernel. The default value is 0, which lets the wrapper
      choose a block size based on the number of rows
    Outputs:
    - an fp8 tensor (or fp16 tensor, if scale_tensor is None) of shape (rows, cols) in row-major format
    - the residual modified in place
    """
    if next_buffer is None:
        next_buffer = torch.empty(size=(input.size(0), 0), device=input.device, dtype=torch.float16)

    residual_rms_checks(input, residual, weight, scale_tensor, epsilon, next_buffer)
    num_threads = infer_num_threads(input.size(0), num_threads)

    if scale_tensor is not None:
        output = torch.empty(size=input.shape, dtype=torch.float8_e4m3fnuz, device=input.device)
    else:
        # TODO: here, we could use input as the output tensor
        output = torch.empty(size=input.shape, dtype=torch.float16, device=input.device)
    ops.residual_rms(
        input=input,
        residual=residual,
        weight=weight,
        scale_tensor=scale_tensor,
        epsilon=epsilon,
        output=output,
        next_buffer=next_buffer,
        num_threads=num_threads,
        force_scalar=force_scalar,
    )
    return output, residual

## Reference implementation
def fp8_quantize(
    x_full_precision: Tensor,
    scale: Tensor,
) -> Tuple[Tensor, Tensor]:
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_quantized = (x_full_precision * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
    x_quantized = x_quantized.to(torch.float8_e4m3fn)
    weight_as_int8 = x_quantized.view(torch.int8)
    ROCM_FP8_NAN_AS_INT = -128
    mask = weight_as_int8 == ROCM_FP8_NAN_AS_INT
    weight_as_int8[mask] = 0
    x_quantized = weight_as_int8.view(torch.float8_e4m3fnuz)
    return x_quantized, scale * 2.0

def reference_residual_rms(
    input: Tensor,
    residual: Tensor,
    weight: Tensor,
    epsilon: float,
    scale_tensor: Optional[Tensor],
    next_buffer: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Reference for the residual_rms operation. Check its docstring for more details; the only difference here is
    that the scale needs to be passed as a tensor and not a float."""
    assert input.dtype == torch.float16, f"Expected torch.float16 but got {input.dtype = }"
    assert residual.dtype == torch.float16, f"Expected torch.float16 but got {residual.dtype = }"
    input += residual
    residual = input
    input = reference_rms(input, epsilon)
    if weight.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(weight.dtype)
    input = weight * input
    if scale_tensor is not None:
        qinput, scale_tensor = fp8_quantize(input, scale_tensor)
        if next_buffer is not None:
            next_buffer.fill_(0)
    else:
        qinput = input
    return qinput, residual, scale_tensor

def reference_rms(x: Tensor, eps: float) -> Tensor:
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps)
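A quick correctness check of the kernel against the pure-PyTorch reference above might look like the sketch below (assuming a ROCm device and the fp8 path; the tolerances and the dequantized comparison are illustrative, since the reference doubles the scale as part of its e4m3fn -> e4m3fnuz reinterpretation trick):

import torch
from residual_rms_rocm import residual_rms, reference_residual_rms

rows, cols = 4, 1024
x = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
res = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
w = torch.randn(cols, dtype=torch.float16, device="cuda")
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

# Both paths modify input and residual in place, so pass clones to keep the comparison fair
out, new_res = residual_rms(x.clone(), res.clone(), w, 1e-6, scale_tensor=scale)
ref_out, ref_res, ref_scale = reference_residual_rms(x.clone(), res.clone(), w, 1e-6, scale_tensor=scale)

# The residual update (input + residual) should match
torch.testing.assert_close(new_res, ref_res)

# The fp8 outputs use different scale conventions, so compare dequantized values with a loose tolerance
torch.testing.assert_close(out.float() * scale, ref_out.float() * ref_scale, atol=0.2, rtol=0.2)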
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,12 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"


TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("residual_rms(Tensor input, Tensor residual, Tensor weight, Tensor scale_tensor, double epsilon, Tensor! output, Tensor next_buffer, int64_t num_threads, bool force_scalar) -> ()");
  ops.impl("residual_rms", torch::kCUDA, &residual_rms);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,5 @@
#pragma once

#include <torch/torch.h>

void residual_rms(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, torch::Tensor& scale_tensor, double epsilon, torch::Tensor& output, torch::Tensor& next_buffer, int64_t num_threads, bool force_scalar);