first commit
- .gitignore +4 -0
- README.md +1 -0
- build.toml +23 -0
- flake.lock +168 -0
- flake.nix +13 -0
- residual_rms_rocm/residual_rms_dispatch.cu +63 -0
- residual_rms_rocm/residual_rms_scalar.cu +74 -0
- residual_rms_rocm/residual_rms_vectorized.cu +196 -0
- residual_rms_rocm/utils.h +15 -0
- torch-ext/residual_rms_rocm/__init__.py +3 -0
- torch-ext/residual_rms_rocm/wrapped_rms.py +171 -0
- torch-ext/torch_binding.cpp +12 -0
- torch-ext/torch_binding.h +5 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+.venv/
+__pycache__/
+torch-ext/__pycache__/
+torch-ext/residual_rms_rocm/__pycache__/
README.md
ADDED
@@ -0,0 +1 @@
+RMSNorm kernel for ROCm devices from https://github.com/huggingface/hf-rocm-kernels
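The README points at the upstream repo; below is a hedged sketch of how a kernel published in this layout is typically fetched from the Hub at runtime. The `get_kernel` call comes from the Hugging Face `kernels` package, and the repo id is a placeholder assumption, not something this commit establishes.

    # Sketch only: the repo id below is a placeholder, not defined by this commit.
    from kernels import get_kernel

    residual_rms_rocm = get_kernel("<org>/residual_rms_rocm")
    # residual_rms_rocm.residual_rms(...) would then call the compiled ROCm op described below.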
build.toml
ADDED
@@ -0,0 +1,23 @@
+[general]
+name = "residual_rms_rocm"
+universal = false
+
+[torch]
+src = [
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h",
+]
+
+[kernel.residual_rms_rocm]
+depends = ["torch"]
+backend = "rocm"
+rocm-archs = [
+  "gfx90a",
+]
+src = [
+  "residual_rms_rocm/residual_rms_dispatch.cu",
+  "residual_rms_rocm/residual_rms_scalar.cu",
+  "residual_rms_rocm/residual_rms_vectorized.cu",
+  "residual_rms_rocm/utils.h",
+]
+include = ["."]
flake.lock
ADDED
@@ -0,0 +1,168 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-compat_2": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "flake-utils_2": {
+      "inputs": {
+        "systems": "systems_2"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "hf-nix": {
+      "inputs": {
+        "flake-compat": "flake-compat_2",
+        "flake-utils": "flake-utils_2",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1757675377,
+        "narHash": "sha256-JQKZOI1ZYO4faJnanuoTXziSmqzXe5rEFSGliWDWqWw=",
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "rev": "faf3354403a7381958d08e826c15fe30f6986a4f",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "hf-nix": "hf-nix",
+        "nixpkgs": [
+          "kernel-builder",
+          "hf-nix",
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1759322505,
+        "narHash": "sha256-RzjCEn0zDfdwQp4WAb0BBuLlHxypr+4+a4BMON23SNw=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "437d0f5c253a78d0be8b5998d9c1fcf32ac2360c",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1755963616,
+        "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
+        "owner": "nixos",
+        "repo": "nixpkgs",
+        "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nixos",
+        "ref": "nixos-unstable-small",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "kernel-builder": "kernel-builder"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    },
+    "systems_2": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
flake.nix
ADDED
@@ -0,0 +1,13 @@
+{
+  description = "Flake for Torch kernel extension";
+
+  inputs = {
+    kernel-builder.url = "github:huggingface/kernel-builder";
+  };
+
+  outputs = { self, kernel-builder, }:
+    kernel-builder.lib.genFlakeOutputs {
+      path = ./.;
+      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+    };
+}
residual_rms_rocm/residual_rms_dispatch.cu
ADDED
@@ -0,0 +1,63 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <hip/hip_runtime.h>
+
+#include "residual_rms_vectorized.cu"
+#include "residual_rms_scalar.cu"
+
+void residual_rms(torch::Tensor& input,        // Shape: [m, n] / Layout: row-major / Dtype: fp16
+                  torch::Tensor& residual,     // Shape: [m, n] / Layout: row-major / Dtype: fp16
+                  torch::Tensor& weight,       // Shape: [n, ]  / Layout: row-major / Dtype: fp16
+                  torch::Tensor& scale_tensor, // Shape: [1, ]  / Layout: row-major / Dtype: fp32
+                  double epsilon,
+                  torch::Tensor& output,       // Shape: [m, n] / Layout: row-major / Dtype: fp8 or fp16
+                  torch::Tensor& next_buffer,  // Shape: [m, o] / Layout: don't-care / Dtype: fp16
+                  int64_t num_threads, bool force_scalar) {
+  // Retrieve shapes
+  const int rows = input.size(0);
+  const int cols = input.size(1);
+  // Activate device guard
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+
+  // Prepare kernel launch arguments
+  dim3 grid(rows);
+  dim3 block(num_threads);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Check tensor alignment
+  bool vectorized_available = IS_16B_ALIGNED(input) && IS_16B_ALIGNED(residual) && IS_16B_ALIGNED(weight);
+  vectorized_available = vectorized_available && (!force_scalar) && (cols <= 16384);
+
+  // Case: output is fp16
+  if (output.dtype() == torch::kFloat16) {
+    vectorized_available = vectorized_available && IS_16B_ALIGNED(output);
+
+    if (vectorized_available) {
+      _residual_rms_vectorized<half2, false><<<grid, block, 0, stream>>>(
+          (half*)input.data_ptr(), (half*)residual.data_ptr(), (half*)weight.data_ptr(), (float*)NULL,
+          (half2*)output.data_ptr(), (half*)NULL, epsilon, cols, 0);
+    } else {
+      _residual_rms_scalar<half, false><<<grid, block, 0, stream>>>(
+          (half*)input.data_ptr(), (half*)residual.data_ptr(), (half*)weight.data_ptr(), (float*)NULL,
+          (half*)output.data_ptr(), (half*)NULL, epsilon, cols, 0);
+    }
+  }
+
+  // Case: output is fp8 e4m3fnuz
+  else {
+    vectorized_available = vectorized_available && IS_8B_ALIGNED(output) && (next_buffer.size(1) % 8 == 0);
+
+    // Launch kernel
+    if (vectorized_available) {
+      _residual_rms_vectorized<__hip_fp8x2_storage_t, true><<<grid, block, 0, stream>>>(
+          (half*)input.data_ptr(), (half*)residual.data_ptr(), (half*)weight.data_ptr(),
+          (float*)scale_tensor.data_ptr(), (__hip_fp8x2_storage_t*)output.data_ptr(),
+          (half*)next_buffer.data_ptr(), epsilon, cols, next_buffer.size(1));
+    } else {
+      _residual_rms_scalar<__hip_fp8_storage_t, true><<<grid, block, 0, stream>>>(
+          (half*)input.data_ptr(), (half*)residual.data_ptr(), (half*)weight.data_ptr(),
+          (float*)scale_tensor.data_ptr(), (__hip_fp8_storage_t*)output.data_ptr(), (half*)next_buffer.data_ptr(),
+          epsilon, cols, next_buffer.size(1));
+    }
+  }
+}
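The dispatcher above falls back to the scalar kernel whenever a tensor fails the alignment checks, `force_scalar` is set, or a row is wider than the 16384-element shared-memory staging buffer. A minimal Python sketch of that decision, useful for predicting which path a given input takes; the helper name is ours and is not part of the extension:

    # Illustration only: mirrors the checks in residual_rms_dispatch.cu.
    import torch

    def would_use_vectorized(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
                             output: torch.Tensor, next_buffer: torch.Tensor,
                             force_scalar: bool = False) -> bool:
        aligned_16 = all(t.data_ptr() % 16 == 0 for t in (input, residual, weight))
        ok = aligned_16 and not force_scalar and input.size(1) <= 16384
        if output.dtype == torch.float16:
            return ok and output.data_ptr() % 16 == 0
        # fp8 output: looser alignment check plus a constraint on the next buffer's width
        return ok and output.data_ptr() % 4 == 0 and next_buffer.size(1) % 8 == 0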
residual_rms_rocm/residual_rms_scalar.cu
ADDED
@@ -0,0 +1,74 @@
+#include <torch/all.h>
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#include <hip/hip_fp8.h>
+
+#include "utils.h"
+
+template <typename T, bool clean_next_buffer>
+__global__ void _residual_rms_scalar(const half* __restrict__ input, half* __restrict__ residual,
+                                     const half* __restrict__ weight, const float* __restrict__ scale_tensor,
+                                     T* __restrict__ output, half* __restrict__ next_buffer, const float epsilon,
+                                     const int cols, const int buffer_cols) {
+  // Advance pointers according to the position of the block in the grid
+  input += blockIdx.x * cols;
+  residual += blockIdx.x * cols;
+  output += blockIdx.x * cols;
+
+  // Residual connection: in-place add of input to residual, accumulate the norm along the way
+  float variance = 0.0f;
+
+  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
+    half z = input[i];
+    z += residual[i];
+    float x = (float)z;
+    variance += (x * x);
+    residual[i] = z;
+  }
+  variance /= cols;
+
+  // Block reduce to compute the total norm
+  __shared__ float shared_normalizer;
+  using BlockReduce = hipcub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+  if (threadIdx.x == 0) {
+    shared_normalizer = rsqrtf(variance + epsilon);
+  }
+  __syncthreads();
+
+  // Get inverse scale (only for fp8)
+  float inv_scale = 1.0f;
+  if constexpr (std::is_same_v<T, __hip_fp8_storage_t>) {
+    inv_scale = 1 / scale_tensor[0];
+  }
+
+  // Normalize and store
+  for (int idx = threadIdx.x; idx < cols; idx += blockDim.x) {
+    float x = (float)residual[idx];
+    half y = (half)(x * shared_normalizer);
+    y = (y * weight[idx]);
+
+    if constexpr (std::is_same_v<T, __hip_fp8_storage_t>) {
+      x = (float)y;
+      x *= inv_scale;
+      FP8_CLAMP(x, float);
+      output[idx] = __hip_cvt_float_to_fp8(x, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+    }
+    if constexpr (std::is_same_v<T, half>) {
+      output[idx] = y;
+    }
+  }
+
+  // Initialize next buffer
+  if constexpr (clean_next_buffer) {
+    next_buffer += blockIdx.x * buffer_cols;
+    for (int i = threadIdx.x; i < buffer_cols; i += blockDim.x) {
+      next_buffer[i] = 0;
+    }
+  }
+}
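Per row, the scalar and vectorized kernels compute the same thing: add the input to the residual in place, normalize the sum by its root mean square, then apply the weight (and, in the fp8 case, the inverse scale before conversion). A small PyTorch restatement of the fp16-output math, as a sketch with names of our choosing, not part of the extension:

    # Sketch only: the per-row math of the kernels, fp16 output case.
    import torch

    def residual_rms_row(input: torch.Tensor, residual: torch.Tensor,
                         weight: torch.Tensor, epsilon: float) -> torch.Tensor:
        # residual <- input + residual, then scale each row by 1/RMS(row) and by the weight.
        residual += input
        x = residual.float()
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
        return (x * inv_rms).half() * weight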
residual_rms_rocm/residual_rms_vectorized.cu
ADDED
@@ -0,0 +1,196 @@
+#include <torch/all.h>
+
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#include <hip/hip_fp8.h>
+
+#include "utils.h"
+
+#define USE_SMEM true  // TODO: figure out if this is needed in practice
+
+template <typename T, bool clean_next_buffer>
+__global__ void _residual_rms_vectorized(const half* __restrict__ input, half* __restrict__ residual,
+                                         const half* __restrict__ weight, const float* __restrict__ scale_tensor,
+                                         T* __restrict__ output,  // half2 or __hip_fp8x2_storage_t
+                                         half* __restrict__ next_buffer, const float epsilon, const int cols,
+                                         const int buffer_cols) {
+  static constexpr int elems_per_load = 8;
+  static constexpr int smem_size = USE_SMEM ? 16384 : 0;
+  __shared__ half _smem[smem_size];
+
+  // Advance pointers according to the position of the thread in the grid
+  input += blockIdx.x * cols + elems_per_load * threadIdx.x;
+  residual += blockIdx.x * cols + elems_per_load * threadIdx.x;
+  weight += elems_per_load * threadIdx.x;
+  output += (blockIdx.x * cols + elems_per_load * threadIdx.x) / 2;
+
+  half* residual_start = residual;
+  half* residual_smem_buffer = &_smem[0] + elems_per_load * threadIdx.x;
+
+  // Residual connection: in-place add of input to residual, accumulate the norm along the way
+  float variance = 0.0f;
+  float fp32_residual;
+  half input_buffer[elems_per_load];
+  half residual_buffer[elems_per_load];
+
+  const int loop_stride = elems_per_load * blockDim.x;
+  const int iterations = CDIV(cols - elems_per_load * threadIdx.x, loop_stride);
+  for (int i = 0; i < iterations; i++) {
+    // Load data using 128-bit loads
+#pragma unroll
+    for (int j = 0; j < elems_per_load; j++) {
+      input_buffer[j] = input[j];
+    }
+#pragma unroll
+    for (int j = 0; j < elems_per_load; j++) {
+      residual_buffer[j] = residual[j];
+    }
+
+    // Add everything in the residual buffer and accumulate variance
+#pragma unroll
+    for (int j = 0; j < elems_per_load; j++) {
+      residual_buffer[j] += input_buffer[j];
+      float float_res = (float)residual_buffer[j];
+      variance += float_res * float_res;
+    }
+
+    // 128-bit smem store
+#pragma unroll
+    for (int j = 0; j < elems_per_load; j++) {
+      if constexpr (USE_SMEM) {
+        residual_smem_buffer[j] = residual_buffer[j];
+      } else {
+        residual[j] = residual_buffer[j];
+      }
+    }
+
+    // Advance pointers
+    input += loop_stride;
+    residual += loop_stride;
+    residual_smem_buffer += loop_stride;
+  }
+  variance /= cols;
+
+  // Block reduce to compute the total norm
+  __shared__ float shared_normalizer;
+  using BlockReduce = hipcub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x);
+  if (threadIdx.x == 0) {
+    shared_normalizer = rsqrtf(variance + epsilon);
+  }
+  __syncthreads();
+
+  // Normalize and convert
+  __half2 weight_buffer[elems_per_load / 2];
+  T output_buffer[elems_per_load / 2];
+
+  // Apply inverse scale (only for fp8)
+  if constexpr (std::is_same_v<T, __hip_fp8x2_storage_t>) {
+    shared_normalizer = shared_normalizer / scale_tensor[0];
+  }
+
+  residual = residual_start;
+  residual_smem_buffer = &_smem[0] + elems_per_load * threadIdx.x;
+
+  for (int i = 0; i < iterations; i++) {
+    // 128-bit loads
+#pragma unroll
+    for (int j = 0; j < elems_per_load; j++) {
+      if constexpr (USE_SMEM) {
+        residual_buffer[j] = residual_smem_buffer[j];
+      } else {
+        residual_buffer[j] = residual[j];
+      }
+    }
+#pragma unroll
+    for (int j = 0; j < elems_per_load / 2; j++) {
+      weight_buffer[j] = reinterpret_cast<const __half2*>(weight)[j];
+    }
+
+    // 128-bit store
+#pragma unroll
+    for (int j = 0; j < elems_per_load; j++) {
+      residual[j] = residual_buffer[j];
+    }
+
+    // Compute and fill buffer
+#pragma unroll
+    for (int j = 0; j < elems_per_load / 2; j++) {
+      // Output is fp8
+      if constexpr (std::is_same_v<T, __hip_fp8x2_storage_t>) {
+        __half2 tmp_res = {residual_buffer[2 * j], residual_buffer[2 * j + 1]};
+        // tmp_res = tmp_res * weight_buffer[j];
+        float2 tmp_float2 = __half22float2(tmp_res);
+        // INCREASES PRECISION | TODO: figure out a better test
+        tmp_float2 = tmp_float2 * __half22float2(weight_buffer[j]);
+        tmp_float2 *= shared_normalizer;
+
+        // tmp_float2.x = __builtin_amdgcn_fmed3f(tmp_float2.x, 448.0, -448.0); // TODO: are they needed?
+        // tmp_float2.y = __builtin_amdgcn_fmed3f(tmp_float2.y, 448.0, -448.0); // TODO: are they needed?
+        output_buffer[j] = __hip_cvt_float2_to_fp8x2(tmp_float2, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
+      }
+
+      // Output is fp16
+      if constexpr (std::is_same_v<T, half2>) {
+        float2 tmp_float2;
+        tmp_float2.x = (float)residual_buffer[2 * j];
+        tmp_float2.y = (float)residual_buffer[2 * j + 1];
+        tmp_float2 *= shared_normalizer;
+        half2 tmp = {(half)tmp_float2.x, (half)tmp_float2.y};
+        tmp *= reinterpret_cast<const half2*>(weight_buffer)[j];
+        output_buffer[j] = tmp;
+      }
+    }
+
+    // 64-bit store
+#pragma unroll
+    for (int j = 0; j < elems_per_load / 2; j++) {
+      output[j] = output_buffer[j];
+    }
+
+    // Advance pointers
+    residual += loop_stride;
+    residual_smem_buffer += loop_stride;
+    weight += loop_stride;
+    output += loop_stride / 2;
+  }
+
+  // Initialize next buffer  TODO: add this as a template (eventually w/ vector granularity)
+  if constexpr (clean_next_buffer) {
+    next_buffer += blockIdx.x * buffer_cols;
+    for (int i = elems_per_load * threadIdx.x; i < buffer_cols; i += elems_per_load * blockDim.x) {
+#pragma unroll
+      for (int j = 0; j < elems_per_load; j++) {
+        next_buffer[i + j] = 0;
+      }
+    }
+  }
+}
+
+// Nb. rows    Ref (μs)    Pointwise (μs)    Vectorized (μs)
+// ----------  ----------  ----------------  -----------------
+//          1     40.6864           10.4857            4.8905
+//          2     42.8676           10.5499            5.04421
+//          4     43.7978           10.5729            5.05962
+//          8     44.0237           10.6909            5.10061
+//         16     47.1026           10.7823            5.19516
+//         32     56.3393           11.0192            5.45101
+//         64     74.0383           14.0895            5.86153
+//        128     98.3725           15.2012            6.59527
+//        256    119.426            27.7393           11.5191
+
+// Nb. rows    Ref (μs)    Pointwise (μs)    Vectorized (μs)
+// ----------  ----------  ----------------  -----------------
+//          1     38.8908           10.5276            4.28524
+//          2     42.8806           10.5337            4.30209
+//          4     43.2694           10.6618            4.38874
+//          8     43.4718           10.6979            4.41091
+//         16     46.6662           10.8013            4.49634
+//         32     55.7943           11.0203            4.78883
+//         64     75.1326           14.1084            5.38017
+//        128    100                15.129             6.31691
+//        256    118.571            27.3881           11.1394
residual_rms_rocm/utils.h
ADDED
@@ -0,0 +1,15 @@
+#pragma once
+
+#define WARPSIZE 64
+
+#define FP8_CLAMP(x, type)                     \
+  x = (x > (type)448.0) ? (type)448.0 : x;     \
+  x = (x < (type)-448.0) ? (type)-448.0 : x;
+// TODO: reformat clamping
+
+#define IS_8B_ALIGNED(tensor) (reinterpret_cast<std::uintptr_t>(tensor.data_ptr()) % 4 == 0)
+#define IS_16B_ALIGNED(tensor) (reinterpret_cast<std::uintptr_t>(tensor.data_ptr()) % 16 == 0)
+
+#define CDIV(a, b) ((a + b - 1) / (b))
+
+#define FP8_MAX 224.0f  // TODO: check if this or 448.0f
torch-ext/residual_rms_rocm/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .wrapped_rms import residual_rms, reference_residual_rms
+
+__all__ = ["residual_rms", "reference_residual_rms"]
torch-ext/residual_rms_rocm/wrapped_rms.py
ADDED
@@ -0,0 +1,171 @@
+from typing import Tuple, Optional
+import torch
+from torch import Tensor
+
+from ._ops import ops
+
+
+_HIGHEST_RESIDUAL_RMS_MODE = 3
+
+
+def residual_rms_checks(
+    input: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    scale_tensor: Tensor,
+    epsilon: float,
+    next_buffer: Tensor,
+) -> None:
+    # Check shapes
+    assert input.dim() == 2, f"Expected input to have 2 dimensions but got {input.dim() = } instead."
+    assert residual.shape == input.shape, \
+        f"Expected input and residual to have the same shape but got {input.shape = } and {residual.shape = }"
+    assert weight.shape == (input.size(1), ), \
+        f"Expected weight to have shape {(input.size(1), ) = } but got {weight.shape = }"
+    # Check devices
+    device = input.device
+    assert device.type == "cuda", f"Expected input.device to be of type cuda, but got {device.type = } instead."
+    assert residual.device == device, f"Expected {residual.device = } to be the same as {input.device = }"
+    if scale_tensor is not None:
+        assert scale_tensor.device == device, f"Expected {scale_tensor.device = } to be the same as {input.device = }"
+    assert next_buffer.device == device, f"Expected {next_buffer.device = } to be the same as {input.device = }"
+    # Check layouts
+    assert input.is_contiguous(), f"Expected input to be contiguous but got {input.stride() = }"
+    assert residual.is_contiguous(), f"Expected residual to be contiguous but got {residual.stride() = }"
+    # Check scalars
+    assert epsilon > 0, f"Expected RMS epsilon to be > 0 to avoid division by zero, but got {epsilon = }"
+
+
+def residual_rms_choose_mode(
+    input: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    next_buffer: Tensor,
+    mode: int,
+) -> int:
+    cols_is_multiple_of_8 = (input.size(1) % 8 == 0) and (next_buffer.size(1) % 8 == 0)
+    tensors_are_16b_aligned = all([x.data_ptr() % 16 == 0 for x in [input, residual, weight]])
+    if mode == -1:
+        mode = _HIGHEST_RESIDUAL_RMS_MODE if (tensors_are_16b_aligned and cols_is_multiple_of_8) else 0
+    elif mode > 0:
+        assert tensors_are_16b_aligned, (
+            f"Requested {mode = } > 0, which requires tensors to be 16-byte aligned, but got {input.data_ptr() % 16 = }, "
+            f"{residual.data_ptr() % 16 = }, {weight.data_ptr() % 16 = }"
+        )
+        assert cols_is_multiple_of_8, f"Requested {mode = } requires {input.size(1) = } to be a multiple of 8."
+    return mode
+
+
+def infer_num_threads(rows: int, num_threads: int) -> int:
+    # Error case
+    if num_threads < 0 or num_threads > 1024:
+        raise ValueError(f"{num_threads = } is not between 0 and 1024")
+    # Case: num_threads was specified
+    elif num_threads != 0:
+        return num_threads
+    # Otherwise, we branch on the number of rows
+    if rows <= 16:
+        return 1024
+    if rows <= 32:
+        return 768
+    if rows <= 64:
+        return 1024
+    if rows <= 256:
+        return 960
+    return 1024
+
+
+## Main kernel
+def residual_rms(
+    input: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    scale_tensor: Optional[Tensor] = None,
+    next_buffer: Optional[Tensor] = None,
+    num_threads: int = 0,
+    force_scalar: bool = False,
+) -> Tuple[Tensor, Tensor]:
+    """Kernel that fuses a residual connection, an RMS normalization and a conversion to fp8. The residual argument is
+    modified in place (residual <- input + residual).
+
+    Args:
+    - input: an fp16 tensor of shape (rows, cols) in row-major format
+    - residual: an fp16 tensor of shape (rows, cols) in row-major format
+    - weight: an fp16 tensor of shape (cols, ) in row-major format which contains the weight of the RMS norm
+    - epsilon: the small epsilon used inside the RMS norm to avoid division by zero
+    - scale_tensor: an fp32 one-item tensor used to divide the output of the RMS norm before its conversion to fp8. If
+        set to None, then the output dtype is fp16
+    - next_buffer: an optional tensor of shape (rows, .) to initialize to zero if the output dtype is fp8
+    - num_threads: the number of threads per block in the kernel. Default value is 0, which then defaults to 1024
+
+    Outputs:
+    - an fp8 (or fp16) tensor of shape (rows, cols) in row-major format
+    - the residual modified in place
+    """
+    if next_buffer is None:
+        next_buffer = torch.empty(size=(input.size(0), 0), device=input.device, dtype=torch.float16)
+
+    residual_rms_checks(input, residual, weight, scale_tensor, epsilon, next_buffer)
+    num_threads = infer_num_threads(input.size(0), num_threads)
+
+    if scale_tensor is not None:
+        output = torch.empty(size=input.shape, dtype=torch.float8_e4m3fnuz, device=input.device)
+    else:
+        # TODO: here, we could use input as the output tensor
+        output = torch.empty(size=input.shape, dtype=torch.float16, device=input.device)
+    ops.residual_rms(
+        input=input,
+        residual=residual,
+        weight=weight,
+        scale_tensor=scale_tensor,
+        epsilon=epsilon,
+        output=output,
+        next_buffer=next_buffer,
+        num_threads=num_threads,
+        force_scalar=force_scalar,
+    )
+    return output, residual
+
+
+## Reference implementation
+def fp8_quantize(
+    x_full_precision: Tensor,
+    scale: Tensor,
+) -> Tuple[Tensor, Tensor]:
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    x_quantized = (x_full_precision * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
+    x_quantized = x_quantized.to(torch.float8_e4m3fn)
+    weight_as_int8 = x_quantized.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    mask = weight_as_int8 == ROCM_FP8_NAN_AS_INT
+    weight_as_int8[mask] = 0
+    x_quantized = weight_as_int8.view(torch.float8_e4m3fnuz)
+    return x_quantized, scale * 2.0
+
+
+def reference_residual_rms(
+    input: Tensor,
+    residual: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    scale_tensor: Optional[Tensor],
+    next_buffer: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor, float]:
+    """Reference for the residual_rms operation. Check its docstring for more details; the only difference here is that
+    the scale needs to be passed as a tensor and not a float."""
+    assert input.dtype == torch.float16, f"Expected torch.float16 but got {input.dtype = }"
+    assert residual.dtype == torch.float16, f"Expected torch.float16 but got {residual.dtype = }"
+    input += residual
+    residual = input
+    input = reference_rms(input, epsilon)
+    if weight.dtype in [torch.float16, torch.bfloat16]:
+        input = input.to(weight.dtype)
+    input = weight * input
+    if scale_tensor is not None:
+        qinput, scale_tensor = fp8_quantize(input, scale_tensor)
+        if next_buffer is not None:
+            next_buffer.fill_(0)
+    else:
+        qinput = input
+    return qinput, residual, scale_tensor
+
+
+def reference_rms(x: Tensor, eps: float) -> Tensor:
+    x = x.to(torch.float32)
+    variance = x.pow(2).mean(-1, keepdim=True)
+    return x * torch.rsqrt(variance + eps)
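A short, hedged usage sketch of the wrapper above. It assumes the extension has been built and is importable as `residual_rms_rocm` (per the `__init__.py` above) and that a ROCm device is available; shapes and dtypes follow the docstring.

    # Sketch only: assumes the built extension is importable and a ROCm GPU is present.
    import torch
    from residual_rms_rocm import residual_rms

    rows, cols = 8, 4096
    input = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
    residual = torch.randn(rows, cols, dtype=torch.float16, device="cuda")
    weight = torch.randn(cols, dtype=torch.float16, device="cuda")
    scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

    # fp8 output path: scale_tensor is given, so the kernel quantizes to float8_e4m3fnuz.
    output, residual = residual_rms(input, residual, weight, epsilon=1e-6, scale_tensor=scale)

    # fp16 output path: omit scale_tensor.
    # output, residual = residual_rms(input, residual, weight, epsilon=1e-6)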
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,12 @@
+#include <torch/library.h>
+
+#include "registration.h"
+#include "torch_binding.h"
+
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("residual_rms(Tensor input, Tensor residual, Tensor weight, Tensor scale_tensor, double epsilon, Tensor! output, Tensor next_buffer, int64_t num_threads, bool force_scalar) -> ()");
+  ops.impl("residual_rms", torch::kCUDA, &residual_rms);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <torch/torch.h>
+
+void residual_rms(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, torch::Tensor& scale_tensor, double epsilon, torch::Tensor& output, torch::Tensor& next_buffer, int64_t num_threads, bool force_scalar);