kernels-community/rotary
Commit 0a93654 · Parent: 8849043
danieldk committed: The kernel source is now on GitHub
README.md CHANGED
@@ -1,11 +1,14 @@
  ---
  license: bsd-3-clause
  tags:
  - kernel
  ---

  ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/rotary)

  ## rotary

  rotary embedding kernel from [Flash Attention](https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary).
+
+ Kernel source: https://github.com/huggingface/kernels-community/tree/main/rotary
+
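For context, published kernels like this one are consumed through the Hugging Face `kernels` package. Below is a minimal usage sketch; the repo name `kernels-community/rotary` is taken from the README badge, the `apply_rotary` signature from `torch-ext/rotary/__init__.py` further down, and the shapes mirror the deleted tests (illustrative only, not part of this commit):

```python
import torch
from kernels import get_kernel

# Fetch the prebuilt kernel from the Hub (repo name from the README badge).
rotary = get_kernel("kernels-community/rotary")

batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 8, 64, 32
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
cos = torch.randn(seqlen, 1, rotary_dim, device="cuda", dtype=torch.float16)
sin = torch.randn(seqlen, 1, rotary_dim, device="cuda", dtype=torch.float16)

# apply_rotary(x1, x2, cos, sin, out1, out2, conj) writes into out1/out2;
# passing the input slices as outputs rotates q in place.
q1 = q[..., :rotary_dim]
q2 = q[..., rotary_dim : 2 * rotary_dim]
rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False)
```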
build.toml DELETED
@@ -1,19 +0,0 @@
- [general]
- name = "rotary"
- universal = false
-
- [torch]
- src = ["torch-ext/torch_binding.cpp"]
-
- [kernel.activation]
- backend = "cuda"
- depends = ["torch"]
- src = ["rotary/rotary_cuda.cu"]
-
- [kernel.rotary_xpu]
- backend = "xpu"
- depends = ["torch"]
- src = [
-   "rotary-xpu/rotary_xpu.cpp",
-   "rotary-xpu/rotary_xpu.hpp",
- ]
flake.lock DELETED
@@ -1,168 +0,0 @@
- {
-   "nodes": {
-     "flake-compat": {
-       "locked": {
-         "lastModified": 1747046372,
-         "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
-         "type": "github"
-       },
-       "original": {
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "type": "github"
-       }
-     },
-     "flake-compat_2": {
-       "locked": {
-         "lastModified": 1747046372,
-         "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
-         "type": "github"
-       },
-       "original": {
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "type": "github"
-       }
-     },
-     "flake-utils": {
-       "inputs": {
-         "systems": "systems"
-       },
-       "locked": {
-         "lastModified": 1731533236,
-         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "type": "github"
-       }
-     },
-     "flake-utils_2": {
-       "inputs": {
-         "systems": "systems_2"
-       },
-       "locked": {
-         "lastModified": 1731533236,
-         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "type": "github"
-       }
-     },
-     "hf-nix": {
-       "inputs": {
-         "flake-compat": "flake-compat_2",
-         "flake-utils": "flake-utils_2",
-         "nixpkgs": "nixpkgs"
-       },
-       "locked": {
-         "lastModified": 1759493343,
-         "narHash": "sha256-8fhl0gwMAnOkQbogPIVq+Fha+Yeq52FaRXfwF+F9Q+k=",
-         "owner": "huggingface",
-         "repo": "hf-nix",
-         "rev": "b1fc3a18b52447a0f24bc6884418edc5e66082b9",
-         "type": "github"
-       },
-       "original": {
-         "owner": "huggingface",
-         "repo": "hf-nix",
-         "type": "github"
-       }
-     },
-     "kernel-builder": {
-       "inputs": {
-         "flake-compat": "flake-compat",
-         "flake-utils": "flake-utils",
-         "hf-nix": "hf-nix",
-         "nixpkgs": [
-           "kernel-builder",
-           "hf-nix",
-           "nixpkgs"
-         ]
-       },
-       "locked": {
-         "lastModified": 1759501552,
-         "narHash": "sha256-Wnrw3l22y9jdL4C9TGxznIB4qiQznWLtU9ykCbK49EE=",
-         "owner": "huggingface",
-         "repo": "kernel-builder",
-         "rev": "ed5722d95d9395fbc7d0239a97208f2b04147dfa",
-         "type": "github"
-       },
-       "original": {
-         "owner": "huggingface",
-         "repo": "kernel-builder",
-         "type": "github"
-       }
-     },
-     "nixpkgs": {
-       "locked": {
-         "lastModified": 1755963616,
-         "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
-         "owner": "nixos",
-         "repo": "nixpkgs",
-         "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nixos",
-         "ref": "nixos-unstable-small",
-         "repo": "nixpkgs",
-         "type": "github"
-       }
-     },
-     "root": {
-       "inputs": {
-         "kernel-builder": "kernel-builder"
-       }
-     },
-     "systems": {
-       "locked": {
-         "lastModified": 1681028828,
-         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-         "owner": "nix-systems",
-         "repo": "default",
-         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nix-systems",
-         "repo": "default",
-         "type": "github"
-       }
-     },
-     "systems_2": {
-       "locked": {
-         "lastModified": 1681028828,
-         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-         "owner": "nix-systems",
-         "repo": "default",
-         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nix-systems",
-         "repo": "default",
-         "type": "github"
-       }
-     }
-   },
-   "root": "root",
-   "version": 7
- }
flake.nix DELETED
@@ -1,11 +0,0 @@
- {
-   description = "Flake for Torch kernel extension";
-   inputs = {
-     kernel-builder.url = "github:huggingface/kernel-builder";
-   };
-   outputs = { self, kernel-builder, }:
-     kernel-builder.lib.genFlakeOutputs {
-       inherit self;
-       path = ./.;
-     };
- }
rotary-xpu/rotary_xpu.cpp DELETED
@@ -1,40 +0,0 @@
- #include <torch/all.h>
- #include "rotary_xpu.hpp"
-
- void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
-                    torch::Tensor const &cos, torch::Tensor const &sin,
-                    torch::Tensor &out1, torch::Tensor &out2,
-                    bool const conj) {
-   auto iter = at::TensorIteratorConfig()
-                   .add_output(out1)
-                   .add_output(out2)
-                   .add_input(x1)
-                   .add_input(x2)
-                   .add_input(cos)
-                   .add_input(sin)
-                   .check_all_same_dtype(false)
-                   .promote_inputs_to_common_dtype(false)
-                   .build();
-
-   if (!conj) {
-     AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel_xpu", [&] {
-       gpu_kernel_multiple_outputs(
-           iter, [] (scalar_t x1, scalar_t x2, scalar_t cos,
-                     scalar_t sin) -> std::tuple<scalar_t, scalar_t> {
-             scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
-             scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
-             return {out1, out2};
-           });
-     });
-   } else {
-     AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel_xpu", [&] {
-       gpu_kernel_multiple_outputs(
-           iter, [] (scalar_t x1, scalar_t x2, scalar_t cos,
-                     scalar_t sin) -> std::tuple<scalar_t, scalar_t> {
-             scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
-             scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
-             return {out1, out2};
-           });
-     });
-   }
- }
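For reference, this XPU kernel and the CUDA kernel further down compute the same elementwise plane rotation over pairs of rotary features. With $c = \cos\theta$ and $s = \sin\theta$:

$$
\begin{pmatrix} \mathrm{out}_1 \\ \mathrm{out}_2 \end{pmatrix}
=
\begin{pmatrix} c & -s \\ s & c \end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix},
\qquad
\text{conj:}\quad
\begin{pmatrix} \mathrm{out}_1 \\ \mathrm{out}_2 \end{pmatrix}
=
\begin{pmatrix} c & s \\ -s & c \end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix}
$$

Since $c^2 + s^2 = 1$, the `conj` matrix is the transpose and therefore the inverse of the forward rotation, which is exactly what the sign flips in the `else` branch implement.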
rotary-xpu/rotary_xpu.hpp DELETED
@@ -1,375 +0,0 @@
- #include <ATen/core/TensorBody.h>
- #include <ATen/detail/FunctionTraits.h>
- #include <ATen/native/TensorIterator.h>
- #include <sycl/sycl.hpp>
- #include <ATen/core/Array.h>
- #include <c10/macros/Macros.h>
- #include <c10/util/Exception.h>
- #include <c10/util/TypeCast.h>
- #include <cstdint>
- #include <type_traits>
- #include <array>
- #include <c10/core/ScalarType.h>
- #include <c10/xpu/XPUStream.h>
- #include <ATen/xpu/XPUContext.h>
-
- constexpr int MAX_DIMS = 12;
-
- struct LoadWithoutCast {
-   template <typename scalar_t>
-   C10_DEVICE scalar_t load(char* base_ptr, uint32_t offset, int arg) {
-     return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
-   }
- };
-
- struct StoreWithoutCast {
-   template <typename scalar_t>
-   C10_DEVICE void store(scalar_t value, char* base_ptr, uint32_t offset, int arg = 0) {
-     *(reinterpret_cast<scalar_t*>(base_ptr) + offset) = value;
-   }
- };
-
- template <template <int i> typename func, int end, int current = 0>
- struct static_unroll {
-   template <typename... Args>
-   static inline C10_HOST_DEVICE void with_args(Args&&... args) {
-     func<current>::apply(std::forward<Args>(args)...);
-     static_unroll<func, end, current + 1>::with_args(args...);
-   }
- };
-
- template <template <int i> typename func, int end>
- struct static_unroll<func, end, end> {
-   template <typename... Args>
-   static inline C10_HOST_DEVICE void with_args(Args... args) {}
- };
-
- template <int current>
- struct multi_outputs_store_helper {
-   template <int ntensors, int num_outputs, typename... Args>
-   static C10_HOST_DEVICE void apply(
-       at::detail::Array<char*, ntensors> data,
-       at::detail::Array<uint32_t, num_outputs> offsets,
-       std::tuple<Args...> ret) {
-     using T = typename std::tuple_element<current, std::tuple<Args...>>::type;
-     T* to = reinterpret_cast<T*>(data[current]) + offsets[current];
-     *to = std::get<current>(ret);
-   }
- };
-
- template <int arg_index>
- struct unroll_load_helper {
-   template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
-   static C10_DEVICE void apply(
-       policy_t& self,
-       args_t* args,
-       offset_t offset,
-       loader_t loader,
-       int j,
-       int num_outputs) {
-     using arg_t = std::tuple_element_t<arg_index, args_t>;
-     std::get<arg_index>(args[j]) = loader.template load<arg_t>(
-         self.data[arg_index + num_outputs], offset[arg_index], arg_index);
-   }
- };
-
- template <int item_work_size, typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
- struct multi_outputs_unroll {
-   data_t data;
-   int remaining;
-   inp_calc_t input_offset_calculator;
-   out_calc_t output_offset_calculator;
-   LoadWithoutCast loader;
-   StoreWithoutCast storer;
-   int item_idx;
-   int group_idx;
-   int num_items_per_group;
-   int group_work_size;
-
-   multi_outputs_unroll(
-       data_t data,
-       int remaining,
-       inp_calc_t ic,
-       out_calc_t oc,
-       int item_idx,
-       int group_idx,
-       int num_items_per_group)
-       : data(data),
-         remaining(remaining),
-         input_offset_calculator(ic),
-         output_offset_calculator(oc),
-         item_idx(item_idx),
-         group_idx(group_idx),
-         num_items_per_group(num_items_per_group),
-         group_work_size(item_work_size * num_items_per_group) {}
-
-   inline bool check_inbounds(int item_work_elem) const {
-     return (item_idx + item_work_elem * num_items_per_group < remaining);
-   }
-
-   template <typename args_t>
-   inline void load(args_t* args) {
-     constexpr int arity = std::tuple_size<args_t>::value;
-     int item_idx_ = item_idx;
- #pragma unroll
-     for (int i = 0; i < item_work_size; i++) {
-       if (item_idx_ >= remaining) {
-         return;
-       }
-       int linear_idx = item_idx_ + group_work_size * group_idx;
-       auto offset = input_offset_calculator.get(linear_idx);
-       static_unroll<unroll_load_helper, arity>::with_args(
-           *this, args, offset, loader, i, num_outputs);
-       item_idx_ += num_items_per_group;
-     }
-   }
-
-   template <typename return_t>
-   inline void store(return_t* from) {
-     int item_idx_ = item_idx;
- #pragma unroll
-     for (int i = 0; i < item_work_size; i++) {
-       if (item_idx_ >= this->remaining) {
-         return;
-       }
-       int linear_idx = item_idx_ + group_work_size * group_idx;
-       auto offsets = this->output_offset_calculator.get(linear_idx);
-       static_unroll<multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
-       item_idx_ += num_items_per_group;
-     }
-   }
- };
-
- template <int item_work_size, typename func_t, typename policy_t>
- inline void elementwise_kernel_helper(func_t f, policy_t policy) {
-   using traits = function_traits<func_t>;
-   using return_t = typename traits::result_type;
-   using args_t = typename traits::ArgsTuple;
-
-   return_t results[item_work_size];
-   args_t args[item_work_size];
-
-   policy.load(args);
-
- #pragma unroll
-   for (int i = 0; i < item_work_size; i++) {
-     if (policy.check_inbounds(i)) {
-       results[i] = std::apply(f, args[i]);
-     }
-   }
-
-   policy.store(results);
- }
-
- template <int num_outputs, typename func_t, typename array_t, typename in_calc_t, typename out_calc_t>
- struct UnrolledElementwiseForMultiOutputsKernel {
-   static constexpr int item_work_size = 4;
-
-   void operator()(sycl::nd_item<1> item_id) const {
-     int grpsz = item_id.get_local_range(0);
-     int grpid = item_id.get_group(0);
-     int lid = item_id.get_local_id(0);
-     int remaining = numel_ - item_work_size * grpsz * grpid;
-     auto policy = multi_outputs_unroll<item_work_size, array_t, in_calc_t, out_calc_t, num_outputs>(
-         data_, remaining, ic_, oc_, lid, grpid, grpsz);
-     elementwise_kernel_helper<item_work_size>(f_, policy);
-   };
-
-   UnrolledElementwiseForMultiOutputsKernel(int numel, func_t f, array_t data, in_calc_t ic, out_calc_t oc)
-       : numel_(numel), f_(f), data_(data), ic_(ic), oc_(oc) {}
-
-  private:
-   int numel_;
-   func_t f_;
-   array_t data_;
-   in_calc_t ic_;
-   out_calc_t oc_;
- };
-
- template <typename Value>
- struct IntDivider {
-   IntDivider() = default;
-   IntDivider(Value d) : divisor(d) {}
-
-   C10_HOST_DEVICE inline Value div(Value n) const {
-     return n / divisor;
-   }
-   C10_HOST_DEVICE inline Value mod(Value n) const {
-     return n % divisor;
-   }
-   C10_HOST_DEVICE inline auto divmod(Value n) const {
-     return std::make_pair(n / divisor, n % divisor);
-   }
-
-   Value divisor;
- };
-
- template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
- struct OffsetCalculator {
-   using stride_t = std::conditional_t<signed_strides, std::make_signed_t<index_t>, index_t>;
-   using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;
-
-   OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes = nullptr)
-       : dims(dims) {
-     TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
-     for (int i = 0; i < dims; i++) {
-       sizes_[i] = IntDivider<index_t>(sizes[i]);
-       for (int arg = 0; arg < NARGS; arg++) {
-         int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
-         strides_[i][arg] = strides[arg][i] / element_size;
-       }
-     }
-   }
-
-   C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
-     offset_type offsets;
- #pragma unroll
-     for (int arg = 0; arg < NARGS; arg++) {
-       offsets[arg] = 0;
-     }
-
- #pragma unroll
-     for (int dim = 0; dim < MAX_DIMS; ++dim) {
-       if (dim == dims) {
-         break;
-       }
-       auto divmod = sizes_[dim].divmod(linear_idx);
-       linear_idx = divmod.first;
-
- #pragma unroll
-       for (int arg = 0; arg < NARGS; arg++) {
-         offsets[arg] += divmod.second * strides_[dim][arg];
-       }
-     }
-     return offsets;
-   }
-
-   int dims;
-   IntDivider<index_t> sizes_[MAX_DIMS];
-   stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
- };
-
- template <int N>
- static OffsetCalculator<N> make_input_offset_calculator(const at::TensorIteratorBase& iter) {
-   constexpr int array_size = std::max<int>(N, 1);
-   TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
-   std::array<const int64_t*, array_size> strides;
-   int64_t element_sizes[array_size];
-   for (int i = 0; i < N; i++) {
-     strides[i] = iter.strides(i + iter.noutputs()).data();
-     element_sizes[i] = iter.element_size(i + iter.noutputs());
-   }
-   return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
- }
-
- template <int num_outputs = 1>
- static OffsetCalculator<num_outputs> make_output_offset_calculator(const at::TensorIteratorBase& iter) {
-   TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
-   std::array<const int64_t*, num_outputs> strides;
-   int64_t element_sizes[num_outputs];
-   for (int i = 0; i < num_outputs; i++) {
-     strides[i] = iter.strides(i).data();
-     element_sizes[i] = iter.element_size(i);
-   }
-   return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
- }
-
- static inline int64_t syclMaxWorkItemsPerSubSlice(at::DeviceIndex dev_id = c10::xpu::getCurrentXPUStream().device_index()) {
-   auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
-   int64_t simd_width = dev_prop->sub_group_sizes[0];
-   int64_t eu_count = dev_prop->gpu_eu_count_per_subslice;
-   return simd_width * eu_count;
- }
-
- template<class T>
- T ceil_div(T dividend, T divisor) {
-   return (dividend + divisor - 1) / divisor;
- }
-
- template <typename ker_t>
- static inline void sycl_kernel_submit(int64_t global_range, int64_t local_range, ::sycl::queue q, ker_t ker) {
-   q.parallel_for(
-       sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)),
-       ker
-   );
- }
-
- template <int num_outputs, typename func_t, typename array_t, typename in_calc_t, typename out_calc_t>
- static inline void launch_unrolled_kernel_for_multi_outputs(
-     int64_t N,
-     const func_t& f,
-     array_t data,
-     in_calc_t ic,
-     out_calc_t oc) {
-   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
-
-   auto ker = UnrolledElementwiseForMultiOutputsKernel<num_outputs, func_t, array_t, in_calc_t, out_calc_t>(N, f, data, ic, oc);
-   using ker_t = decltype(ker);
-
-   int wg_sz = syclMaxWorkItemsPerSubSlice();
-   int num_wg = ceil_div<int>(N, ker_t::item_work_size * wg_sz);
-   sycl_kernel_submit(wg_sz * num_wg, wg_sz, c10::xpu::getCurrentXPUStream().queue(), ker);
- }
-
- template <int N>
- struct TrivialOffsetCalculator {
-   using offset_type = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
-
-   C10_HOST_DEVICE offset_type get(uint32_t linear_idx) const {
-     offset_type offsets;
- #pragma unroll
-     for (int arg = 0; arg < N; arg++) {
-       offsets[arg] = linear_idx;
-     }
-     return offsets;
-   }
- };
-
- template <typename func_t>
- void gpu_kernel_multiple_outputs_impl(at::TensorIteratorBase& iter, const func_t& f) {
-   using traits = function_traits<func_t>;
-   using output_t = typename traits::result_type;
-   constexpr int num_outputs = std::tuple_size<output_t>::value;
-   constexpr int num_inputs = traits::arity;
-   constexpr int ntensors = num_outputs + num_inputs;
-
-   TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
-   TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);
-
-   at::detail::Array<char*, ntensors> data;
-   for (int i = 0; i < ntensors; i++) {
-     data[i] = (char*)iter.data_ptr(i);
-   }
-
-   int64_t numel = iter.numel();
-
-   if (iter.is_contiguous()) {
-     auto input_calc = TrivialOffsetCalculator<num_inputs>();
-     auto output_calc = TrivialOffsetCalculator<num_outputs>();
-     launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
-   } else {
-     auto input_calc = make_input_offset_calculator<num_inputs>(iter);
-     auto output_calc = make_output_offset_calculator<num_outputs>(iter);
-     launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
-   }
- }
-
- template <typename func_t>
- void gpu_kernel_multiple_outputs(at::TensorIteratorBase& iter, const func_t& f) {
-   for (int arg = 0; arg < iter.ntensors(); arg++) {
-     TORCH_INTERNAL_ASSERT(iter.device(arg).is_xpu());
-   }
-
-   if (iter.numel() == 0) {
-     return;
-   }
-
-   if (!iter.can_use_32bit_indexing()) {
-     for (auto& sub_iter : iter.with_32bit_indexing()) {
-       gpu_kernel_multiple_outputs(sub_iter, f);
-     }
-     return;
-   }
-
-   gpu_kernel_multiple_outputs_impl(iter, f);
- }
rotary/rotary_cuda.cu DELETED
@@ -1,45 +0,0 @@
- /******************************************************************************
-  * Copyright (c) 2023, Tri Dao.
-  ******************************************************************************/
-
- #include <torch/all.h>
- #include <ATen/native/TensorIterator.h>
- #include <ATen/native/cuda/Loops.cuh>
-
- void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
-                    torch::Tensor const &cos, torch::Tensor const &sin,
-                    torch::Tensor &out1, torch::Tensor &out2,
-                    bool const conj) {
-   auto iter = at::TensorIteratorConfig()
-                   .add_output(out1)
-                   .add_output(out2)
-                   .add_input(x1)
-                   .add_input(x2)
-                   .add_input(cos)
-                   .add_input(sin)
-                   .check_all_same_dtype(false)
-                   .promote_inputs_to_common_dtype(false)
-                   .build();
-
-   if (!conj) {
-     AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
-       at::native::gpu_kernel_multiple_outputs(
-           iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
-                                scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
-             scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
-             scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
-             return {out1, out2};
-           });
-     });
-   } else {
-     AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
-       at::native::gpu_kernel_multiple_outputs(
-           iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
-                                scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
-             scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
-             scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
-             return {out1, out2};
-           });
-     });
-   }
- }
tests/__init__.py DELETED
File without changes
tests/test_rotary.py DELETED
@@ -1,130 +0,0 @@
- import pytest
- import torch
-
- from tests.utils import infer_device, supports_bfloat16
- from pathlib import Path
-
- # import rotary
- # from transformers.trainer_utils import set_seed
- # set_seed(42)
-
- # Set the local repo path, relative path
- try:
-     import rotary
- except ImportError:
-     from kernels import get_local_kernel
-     repo_path = Path(__file__).parent.parent
-     rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
-
- def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
-     assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
-
-     if not conj:
-         out1 = x1 * cos - x2 * sin
-         out2 = x1 * sin + x2 * cos
-     else:
-         out1 = x1 * cos + x2 * sin
-         out2 = -x1 * sin + x2 * cos
-     return out1, out2
-
-
- def apply_rotary_torch_wrapper(q, k, cos, sin, conj: bool = False):
-     """the wrapper for apply_rotary_torch"""
-     rotary_dim = cos.shape[-1]
-
-     # apply rotation encoding to Q
-     q1 = q[..., :rotary_dim]
-     q2 = q[..., rotary_dim : 2 * rotary_dim]
-     q_out_1, q_out_2 = apply_rotary_torch(q1, q2, cos, sin, conj)
-     q_out = torch.cat([q_out_1, q_out_2, q[..., 2 * rotary_dim:]], dim=-1)
-
-     # apply rotation encoding to K
-     k1 = k[..., :rotary_dim]
-     k2 = k[..., rotary_dim : 2 * rotary_dim]
-     k_out_1, k_out_2 = apply_rotary_torch(k1, k2, cos, sin, conj)
-     k_out = torch.cat([k_out_1, k_out_2, k[..., 2 * rotary_dim:]], dim=-1)
-
-     return q_out, k_out
-
-
- def apply_rotary_kernel_wrapper(q, k, cos, sin, conj: bool = False):
-     """the wrapper for apply_rotary_kernel"""
-     rotary_dim = cos.shape[-1]
-
-     # apply rotation encoding to Q
-     q1 = q[..., :rotary_dim]
-     q2 = q[..., rotary_dim : 2 * rotary_dim]
-     rotary.apply_rotary(q1, q2, cos, sin, q1, q2, conj)
-
-     # apply rotation encoding to K
-     k1 = k[..., :rotary_dim]
-     k2 = k[..., rotary_dim : 2 * rotary_dim]
-     rotary.apply_rotary(k1, k2, cos, sin, k1, k2, conj)
-
-
- @pytest.mark.parametrize("batch_size", [1, 2])
- @pytest.mark.parametrize("nheads", [8, 16])
- @pytest.mark.parametrize("seqlen", [128, 256])
- @pytest.mark.parametrize("headdim, rotary_dim", [(64, 32), (128, 64), (64, 30)])
- @pytest.mark.parametrize("qk_dim", [3, 4])
- @pytest.mark.parametrize(
-     "dtype, atol, rtol",
-     [
-         (torch.float32, 1e-5, 1e-5),
-         pytest.param(
-             torch.bfloat16,
-             1e-1,
-             1e-5,
-             marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
-         ),
-     ],
- )
- @pytest.mark.parametrize("conj", [False, True])
- @pytest.mark.flaky(max_runs=2, min_passes=1)
- def test_rotary_equivalence(batch_size, nheads, seqlen, headdim, rotary_dim, qk_dim, dtype, atol, rtol, conj):
-     device = infer_device()
-     if device is None:
-         pytest.skip("No suitable device found for testing")
-
-     if qk_dim == 4:
-         q_shape = (batch_size, seqlen, nheads, headdim)
-         cos_sin_shape = (seqlen, 1, rotary_dim)
-     elif qk_dim == 3:
-         q_shape = (batch_size * seqlen, nheads, headdim)
-         cos_sin_shape = (batch_size * seqlen, 1, rotary_dim)
-
-     q_orig = torch.randn(q_shape, device=device, dtype=dtype)
-     k_orig = torch.randn(q_shape, device=device, dtype=dtype)
-     cos = torch.randn(cos_sin_shape, device=device, dtype=dtype)
-     sin = torch.randn(cos_sin_shape, device=device, dtype=dtype)
-
-     q_kernel, k_kernel = q_orig.clone(), k_orig.clone()
-     q_torch, k_torch = q_orig.clone(), k_orig.clone()
-
-     q_torch_out, k_torch_out = apply_rotary_torch_wrapper(q_torch, k_torch, cos, sin, conj)
-     apply_rotary_kernel_wrapper(q_kernel, k_kernel, cos, sin, conj)
-
-     # verify the rotation results of Q and K are consistent
-     try:
-         assert torch.allclose(q_torch_out, q_kernel, atol=atol, rtol=rtol), "Rotary transformation results for Q do not match"
-     except AssertionError:
-         diff_q = torch.abs(q_torch_out - q_kernel)
-         max_diff_q = torch.max(diff_q)
-         print(f"Max difference for Q: {max_diff_q}")
-         raise
-     try:
-         assert torch.allclose(k_torch_out, k_kernel, atol=atol, rtol=rtol), "Rotary transformation results for K do not match"
-     except AssertionError:
-         diff_k = torch.abs(k_torch_out - k_kernel)
-         max_diff_k = torch.max(diff_k)
-         print(f"Max difference for K: {max_diff_k}")
-         raise
-
-     # verify the non-rotated part of Q and K remains unchanged
-     if (2 * rotary_dim) < headdim:
-         assert torch.equal(
-             q_kernel[..., 2 * rotary_dim:], q_orig[..., 2 * rotary_dim:]
-         ), "Non-rotated part of Q should be unchanged"
-         assert torch.equal(
-             k_kernel[..., 2 * rotary_dim:], k_orig[..., 2 * rotary_dim:]
-         ), "Non-rotated part of K should be unchanged"
tests/utils.py DELETED
@@ -1,23 +0,0 @@
- import torch
-
-
- def infer_device():
-     """
-     Get current device name based on available devices
-     """
-     if torch.cuda.is_available():  # Works for both Nvidia and AMD
-         return "cuda"
-     elif torch.xpu.is_available():
-         return "xpu"
-     else:
-         return None
-
-
- def supports_bfloat16():
-     device = infer_device()
-     if device == "cuda":
-         return torch.cuda.get_device_capability() >= (8, 0)  # Ampere and newer
-     elif device == "xpu":
-         return True
-     else:
-         return False
torch-ext/rotary/__init__.py DELETED
@@ -1,19 +0,0 @@
- from typing import Tuple
- import torch
-
- from ._ops import ops
-
-
- def apply_rotary(
-     x1: torch.Tensor,
-     x2: torch.Tensor,
-     cos: torch.Tensor,
-     sin: torch.Tensor,
-     out1: torch.Tensor,
-     out2: torch.Tensor,
-     conj: bool,
- ):
-     ops.apply_rotary(x1, x2, cos, sin, out1, out2, conj)
-
-
- __all__ = ["apply_rotary"]
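The wrapper returns nothing and mutates `out1`/`out2`; because the outputs may alias the inputs, callers typically rotate in place, as the deleted tests do. A small sanity sketch of the in-place and `conj` semantics (repo name assumed as in the README note above; shapes illustrative): the forward rotation followed by the conjugate one recovers the input, since the two matrices are mutual inverses.

```python
import torch
from kernels import get_kernel

rotary = get_kernel("kernels-community/rotary")  # repo name assumed

# Use real angles so that cos^2 + sin^2 = 1 and conj is an exact inverse.
theta = torch.rand(128, 1, 32, device="cuda")
cos, sin = theta.cos(), theta.sin()

x1 = torch.randn(4, 128, 8, 32, device="cuda")
x2 = torch.randn(4, 128, 8, 32, device="cuda")
orig1, orig2 = x1.clone(), x2.clone()

rotary.apply_rotary(x1, x2, cos, sin, x1, x2, False)  # forward rotation, in place
rotary.apply_rotary(x1, x2, cos, sin, x1, x2, True)   # conjugate rotation undoes it
torch.testing.assert_close(x1, orig1)
torch.testing.assert_close(x2, orig2)
```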
torch-ext/torch_binding.cpp DELETED
@@ -1,54 +0,0 @@
- #include <torch/all.h>
-
- #if defined(CUDA_KERNEL)
- #include <c10/cuda/CUDAGuard.h>
- #elif defined(XPU_KERNEL)
- #include <c10/core/DeviceGuard.h>
- #endif
-
- #include "registration.h"
-
- #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA || x.device().type() == torch::kXPU, #x " must be on CUDA or XPU")
- #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
-
- void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
-                    torch::Tensor const &cos, torch::Tensor const &sin,
-                    torch::Tensor &out1, torch::Tensor &out2,
-                    bool const conj);
-
- void apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
-                   torch::Tensor const &cos, torch::Tensor const &sin,
-                   torch::Tensor &out1, torch::Tensor &out2,
-                   bool const conj) {
-   CHECK_DEVICE(x1); CHECK_DEVICE(x2);
-   CHECK_DEVICE(cos); CHECK_DEVICE(sin);
-   CHECK_DEVICE(out1); CHECK_DEVICE(out2);
-   TORCH_CHECK(x1.dtype() == x2.dtype());
-   TORCH_CHECK(cos.dtype() == sin.dtype());
-   TORCH_CHECK(out1.dtype() == out2.dtype());
-   TORCH_CHECK(x1.dtype() == cos.dtype());
-   TORCH_CHECK(x1.dtype() == out1.dtype());
-   TORCH_CHECK(x1.sizes() == x2.sizes());
-   TORCH_CHECK(cos.sizes() == sin.sizes());
-   TORCH_CHECK(out1.sizes() == out2.sizes());
-
-   #if defined(CUDA_KERNEL)
-   // Otherwise the kernel will be launched from cuda:0 device
-   at::cuda::CUDAGuard device_guard{x1.device()};
-   #elif defined(XPU_KERNEL)
-   c10::DeviceGuard device_guard{x1.device()};
-   #endif
-   _apply_rotary(x1, x2, cos, sin, out1, out2, conj);
- }
-
- TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-   ops.def("apply_rotary(Tensor x1, Tensor x2, Tensor cos, Tensor sin,"
-           "Tensor! out1, Tensor! out2, bool conj) -> ()");
-   #if defined(CUDA_KERNEL)
-   ops.impl("apply_rotary", torch::kCUDA, &apply_rotary);
-   #elif defined(XPU_KERNEL)
-   ops.impl("apply_rotary", torch::kXPU, &apply_rotary);
-   #endif
- }
-
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
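In the schema, `Tensor!` marks `out1`/`out2` as mutated in place, and the library is registered under the build-time `TORCH_EXTENSION_NAME` namespace. A hypothetical sketch of calling the raw registered op, assuming that name expanded to `rotary` and the shared library is already loaded (in a packaged build the generated `_ops` module resolves the real namespace, and the Python wrapper above is the supported entry point):

```python
import torch

# Hypothetical namespace: assumes TORCH_EXTENSION_NAME expanded to `rotary`
# and the extension's shared library has been loaded into this process.
x1 = torch.randn(128, 8, 32, device="cuda")
x2 = torch.randn(128, 8, 32, device="cuda")
cos = torch.randn(128, 1, 32, device="cuda")
sin = torch.randn(128, 1, 32, device="cuda")
out1, out2 = torch.empty_like(x1), torch.empty_like(x2)

torch.ops.rotary.apply_rotary(x1, x2, cos, sin, out1, out2, False)
```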