#include #include "ATen/ATen.h" #include #include "registration.h" typedef at::BFloat16 bf16; void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y); void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s); void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s); void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv); void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { const int B = k.size(0); const int T = k.size(1); const int C = k.size(2); cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); } void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { const int B = k.size(0); const int T = k.size(1); const int C = k.size(2); cuda_forward_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); } void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) { const int B = k.size(0); const int T = k.size(1); const int C = k.size(2); cuda_forward_with_state(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), s.data_ptr()); } void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) { const int B = k.size(0); const int T = k.size(1); const int C = k.size(2); cuda_forward_with_state_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), s.data_ptr()); } void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { const int B = k.size(0); const int T = k.size(1); const int C = k.size(2); cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); } void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { const int B = k.size(0); const int T = k.size(1); const int C = k.size(2); cuda_backward_bf16(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); } TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("forward", forward); ops.impl("forward", torch::kCUDA, &forward); ops.def("forward_bf16", forward_bf16); ops.impl("forward_bf16", torch::kCUDA, &forward_bf16); ops.def("forward_with_state", forward_with_state); ops.impl("forward_with_state", torch::kCUDA, &forward_with_state); ops.def("forward_with_state_bf16", forward_with_state_bf16); ops.impl("forward_with_state_bf16", torch::kCUDA, &forward_with_state_bf16); ops.def("backward", backward); ops.impl("backward", torch::kCUDA, &backward); ops.def("backward_bf16", backward_bf16); ops.impl("backward_bf16", torch::kCUDA, 
&backward_bf16); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)