drbh committed

Commit 3224250 · Parent(s): a585153

feat: vendor grouped gemm
Files changed:

- build.toml +4 -0
- csrc/grouped_gemm/fill_arguments.cuh +141 -0
- csrc/grouped_gemm/grouped_gemm.cu +567 -0
- csrc/grouped_gemm/grouped_gemm.h +20 -0
- csrc/grouped_gemm/ops.cu +11 -0
- tests/ops_test.py +170 -0
- tests/test_gg.py +57 -0
- torch-ext/megablocks/__init__.py +9 -5
- torch-ext/megablocks/grouped_gemm/__init__.py +2 -0
- torch-ext/megablocks/grouped_gemm/backend.py +32 -0
- torch-ext/megablocks/grouped_gemm/ops.py +33 -0
- torch-ext/megablocks/grouped_gemm_util.py +8 -3
- torch-ext/megablocks/layers/__init__.py +1 -1
- torch-ext/torch_binding.cpp +12 -0

build.toml CHANGED
@@ -35,4 +35,8 @@ src = [
   "csrc/new_replicate.h",
   "csrc/new_sort.h",
   "csrc/new_sort.cu",
+  # vendored grouped gemm
+  "csrc/grouped_gemm/fill_arguments.cuh",
+  "csrc/grouped_gemm/grouped_gemm.cu",
+  "csrc/grouped_gemm/grouped_gemm.h",
 ]

csrc/grouped_gemm/fill_arguments.cuh ADDED

@@ -0,0 +1,141 @@
+#pragma once
+
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <cub/cub.cuh>
+#include <cutlass/bfloat16.h>
+#include <cutlass/gemm_coord.h>
+
+namespace grouped_gemm {
+
+constexpr int kDynamicDim = -1;
+constexpr int kMaxExperts = 512;
+
+struct GemmProblem {
+  ::cutlass::gemm::GemmCoord dims;
+  int64_t lda, ldb, ldc;
+  // All offsets are in elements.
+  int64_t a_offset, b_offset, c_offset;
+};
+
+// TODO: revisit `ExtractGemmProblemK` struct
+// struct ExtractGemmProblemK {
+//   __device__ ::cuda::std::tuple<int&> operator()(GemmProblem& problem) const {
+//       return {problem.dims.k()};
+//   }
+// };
+
+template <
+    // If `k` is dynamic, we sort the problems by `k` in descending order.
+    // Otherwise, `m` is dynamic, and no sorting happens.
+    bool kDynamicK,
+    typename ElementA, typename ElementB, typename ElementC,
+    typename LayoutA, typename LayoutB, typename LayoutC,
+    typename Args
+>
+__global__ void FillArguments(
+    int num_experts, const int64_t* batch_sizes,
+    ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
+    Args args, ::cutlass::gemm::GemmCoord dims
+) {
+  const int expert_idx = threadIdx.x;
+  const int batch_size = expert_idx < num_experts ? batch_sizes[expert_idx] : -1;
+
+  if (kDynamicK) {
+    assert(dims.k() == kDynamicDim);
+    dims.k() = batch_size;
+  } else {
+    assert(dims.m() == kDynamicDim);
+    dims.m() = batch_size;
+  }
+
+  using BlockScan = cub::BlockScan<int, kMaxExperts>;
+  using BlockSort = cub::BlockRadixSort<int, kMaxExperts, 1, GemmProblem>;
+
+  union SharedMemory {
+    typename BlockScan::TempStorage scan_storage;
+    typename BlockSort::TempStorage sort_storage;
+  };
+  __shared__ SharedMemory shared_memory;
+
+  int dynamic_dim = kDynamicK ? dims.k() : dims.m();
+  int dynamic_dim_cumsum;
+  BlockScan(shared_memory.scan_storage).ExclusiveSum(dynamic_dim, dynamic_dim_cumsum);
+  __syncthreads();
+
+  // We have to use `GemmProblem[1]` here instead of just `GemmProblem` because `SortDescending()` expects
+  // `KeyT (&)[ITEMS_PER_THREAD]` for the `keys` argument (i.e., `GemmProblem (&keys)[1]` in our case).
+  GemmProblem problem[1] = {
+    GemmProblem {
+      .dims = dims,
+      .lda = LayoutA::packed({dims.m(), dims.k()}).stride(0),
+      .ldb = LayoutB::packed({dims.k(), dims.n()}).stride(0),
+      .ldc = LayoutC::packed({dims.m(), dims.n()}).stride(0),
+      .a_offset = kDynamicK
+          ? (dims.m() * dynamic_dim_cumsum)
+          : (dynamic_dim_cumsum * dims.k()),
+      .b_offset = (kDynamicK ? dynamic_dim_cumsum : expert_idx * dims.k()) * dims.n(),
+      .c_offset = (kDynamicK ? expert_idx * dims.m() : dynamic_dim_cumsum) * dims.n(),
+    },
+  };
+
+  if constexpr (kDynamicK) {
+    // Sort by k dimension in descending order
+    // We need to extract the key (k value) for sorting
+    int k_keys[1] = { problem[0].dims.k() };
+
+    BlockSort(shared_memory.sort_storage).SortDescending(k_keys, problem);
+
+    // TODO: revisit original impl without `__syncthreads()`
+    // BlockSort(shared_memory.sort_storage).SortDescending(problem, ExtractGemmProblemK{});
+    // Quoting the CUB documentation (https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockRadixSort.html):
+    // > A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage [...]
+    // > is **to be reused or repurposed**.
+    // We don't need `__syncthreads()` here, since we don't do either of these things.
+  }
+
+  if (expert_idx < num_experts) {
+    args.problem_sizes[expert_idx] = problem[0].dims;
+    args.lda[expert_idx] = problem[0].lda;
+    args.ldb[expert_idx] = problem[0].ldb;
+    args.ldc[expert_idx] = problem[0].ldc;
+
+    args.ptr_A[expert_idx] = ptr_a + problem[0].a_offset;
+    args.ptr_B[expert_idx] = ptr_b + problem[0].b_offset;
+    args.ptr_C[expert_idx] = ptr_c + problem[0].c_offset;
+  }
+}
+
+template <typename Args>
+__global__ void ZeroOutK0Outputs(int num_experts, Args args) {
+  const int64_t start_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t delta     = (int64_t)gridDim.x * blockDim.x;
+  for (int ei = 0; ei < num_experts; ++ei) {
+    auto& dims = args.problem_sizes[ei];
+    // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
+    // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
+    //   * (here) set the output to zero
+    //   * (in `IgnoreK0Problems`) make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
+    if (dims.k() == 0) {
+      // Assume packed layout, run a grid-strided loop over the output.
+      int64_t total_elems = (int64_t)dims.m() * dims.n();
+      auto* out           = args.ptr_C[ei];
+      for (int64_t idx = start_idx; idx < total_elems; idx += delta) {
+        out[idx] = {};
+      }
+    }
+  }
+}
+
+template <typename Args>
+__global__ void IgnoreK0Problems(int num_experts, Args args) {
+  const int expert_idx = threadIdx.x;
+  if (expert_idx < num_experts) {
+    auto& dims = args.problem_sizes[expert_idx];
+    if (dims.k() == 0) {
+      dims.m() = 0;
+      dims.n() = 0;
+    }
+  }
+}
+
+}  // namespace grouped_gemm
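For orientation (not part of this commit): in the variable-`k` case (`kDynamicK`), `FillArguments` derives each expert's `a_offset`/`b_offset`/`c_offset` from the exclusive prefix sum over `batch_sizes` computed by `cub::BlockScan`: expert `e`'s `A` block starts `m * cumsum_k` elements in, its `B` block `cumsum_k * n` elements in, and its `C` tile at `e * m * n`. A minimal host-side C++ sketch of the same offset arithmetic, with made-up toy sizes, looks like this:

// Toy illustration only: reproduces the kDynamicK offset arithmetic from
// FillArguments on the host, for three experts with per-expert k sizes.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t m = 4, n = 8;                 // fixed dims shared by all experts
  const std::vector<int64_t> k = {3, 0, 5};   // per-expert batch sizes (k_e)

  int64_t k_cumsum = 0;                       // exclusive prefix sum of k_e
  for (size_t e = 0; e < k.size(); ++e) {
    const int64_t a_offset = m * k_cumsum;        // into A (blocks of m x k_e)
    const int64_t b_offset = k_cumsum * n;        // into B (blocks of k_e x n)
    const int64_t c_offset = (int64_t)e * m * n;  // into C (one m x n tile per expert)
    std::printf("expert %zu: a=%lld b=%lld c=%lld\n",
                e, (long long)a_offset, (long long)b_offset, (long long)c_offset);
    k_cumsum += k[e];
  }
  return 0;
}

The printed offsets correspond to the `a_offset`, `b_offset`, and `c_offset` fields filled in per expert above.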
    	
csrc/grouped_gemm/grouped_gemm.cu ADDED

@@ -0,0 +1,567 @@
| 1 | 
         
            +
            #include "grouped_gemm.h"
         
     | 
| 2 | 
         
            +
            #include "fill_arguments.cuh"
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            #include <ATen/cuda/CUDAContext.h>
         
     | 
| 5 | 
         
            +
            #include <ATen/cuda/detail/KernelUtils.h>
         
     | 
| 6 | 
         
            +
            #include <c10/util/BFloat16.h>
         
     | 
| 7 | 
         
            +
            #include <c10/cuda/CUDAStream.h>
         
     | 
| 8 | 
         
            +
            #include <cub/cub.cuh>
         
     | 
| 9 | 
         
            +
            #include <torch/torch.h>
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            #include "cutlass/bfloat16.h"
         
     | 
| 12 | 
         
            +
            #include "cutlass/complex.h"
         
     | 
| 13 | 
         
            +
            #include "cutlass/gemm/kernel/gemm_grouped.h"
         
     | 
| 14 | 
         
            +
            #include "cutlass/gemm/kernel/default_gemm_grouped.h"
         
     | 
| 15 | 
         
            +
            #include "cutlass/gemm/device/gemm_grouped.h"
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
            #include <type_traits>
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            namespace grouped_gemm {
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            #define CUDA_CALL(code)					    \
         
     | 
| 22 | 
         
            +
              do {                                                      \
         
     | 
| 23 | 
         
            +
                cudaError_t status = code;                              \
         
     | 
| 24 | 
         
            +
                std::string err = cudaGetErrorString(status);           \
         
     | 
| 25 | 
         
            +
                TORCH_CHECK(status == cudaSuccess, err);		    \
         
     | 
| 26 | 
         
            +
              } while (0)
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
            #define CUBLAS_CALL(code)					  \
         
     | 
| 29 | 
         
            +
              do {								  \
         
     | 
| 30 | 
         
            +
                cublasStatus_t status = code;				  \
         
     | 
| 31 | 
         
            +
                TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error"); \
         
     | 
| 32 | 
         
            +
              } while (0)
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            #define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
         
     | 
| 35 | 
         
            +
            #define GROUPED_GEMM_STRINGIFY(x) \
         
     | 
| 36 | 
         
            +
              GROUPED_GEMM_STRINGIFY_HELPER(x)
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
            template <bool trans>
         
     | 
| 39 | 
         
            +
            using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
            using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
         
     | 
| 42 | 
         
            +
              ::cutlass::arch::OpClassTensorOp,
         
     | 
| 43 | 
         
            +
              ::cutlass::arch::Sm80,
         
     | 
| 44 | 
         
            +
              ::cutlass::bfloat16_t,
         
     | 
| 45 | 
         
            +
              ::cutlass::bfloat16_t,
         
     | 
| 46 | 
         
            +
              ::cutlass::bfloat16_t,
         
     | 
| 47 | 
         
            +
              float
         
     | 
| 48 | 
         
            +
            >;
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            // TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
         
     | 
| 51 | 
         
            +
            template <bool trans_a, bool trans_b>
         
     | 
| 52 | 
         
            +
            using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
         
     | 
| 53 | 
         
            +
              // A operand.
         
     | 
| 54 | 
         
            +
              ::cutlass::bfloat16_t,
         
     | 
| 55 | 
         
            +
              GroupedGemmInputLayout<trans_a>,
         
     | 
| 56 | 
         
            +
              ::cutlass::ComplexTransform::kNone,
         
     | 
| 57 | 
         
            +
              GroupedGemmConfig::kAlignmentA,
         
     | 
| 58 | 
         
            +
              // B operand.
         
     | 
| 59 | 
         
            +
              ::cutlass::bfloat16_t,
         
     | 
| 60 | 
         
            +
              GroupedGemmInputLayout<trans_b>,
         
     | 
| 61 | 
         
            +
              ::cutlass::ComplexTransform::kNone,
         
     | 
| 62 | 
         
            +
              GroupedGemmConfig::kAlignmentB,
         
     | 
| 63 | 
         
            +
              // C operand.
         
     | 
| 64 | 
         
            +
              ::cutlass::bfloat16_t,
         
     | 
| 65 | 
         
            +
              ::cutlass::layout::RowMajor,
         
     | 
| 66 | 
         
            +
              float,
         
     | 
| 67 | 
         
            +
              ::cutlass::arch::OpClassTensorOp,
         
     | 
| 68 | 
         
            +
              ::cutlass::arch::Sm80,
         
     | 
| 69 | 
         
            +
              GroupedGemmConfig::ThreadblockShape,
         
     | 
| 70 | 
         
            +
              GroupedGemmConfig::WarpShape,
         
     | 
| 71 | 
         
            +
              GroupedGemmConfig::InstructionShape,
         
     | 
| 72 | 
         
            +
              GroupedGemmConfig::EpilogueOutputOp,
         
     | 
| 73 | 
         
            +
              // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
         
     | 
| 74 | 
         
            +
              // This parameter is passed in at present to match the APIs of other kernels. The parameter
         
     | 
| 75 | 
         
            +
              // is unused within the kernel.
         
     | 
| 76 | 
         
            +
              ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
         
     | 
| 77 | 
         
            +
              // TODO(tgale): Tune this for SM90.
         
     | 
| 78 | 
         
            +
              GroupedGemmConfig::kStages>::GemmKernel;
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
            template <bool trans_a, bool trans_b>
         
     | 
| 81 | 
         
            +
            using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
            template <typename T>
         
     | 
| 84 | 
         
            +
            torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
         
     | 
| 85 | 
         
            +
              size_t bytes = x.size() * sizeof(T);
         
     | 
| 86 | 
         
            +
              auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
         
     | 
| 87 | 
         
            +
              torch::Tensor out = torch::empty(bytes, options);
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
              CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
         
     | 
| 90 | 
         
            +
            			    x.data(), bytes,
         
     | 
| 91 | 
         
            +
            			    cudaMemcpyHostToDevice,
         
     | 
| 92 | 
         
            +
            			    c10::cuda::getCurrentCUDAStream()));
         
     | 
| 93 | 
         
            +
              return out;
         
     | 
| 94 | 
         
            +
            }
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
            template <typename T>
         
     | 
| 97 | 
         
            +
            static void ReorderArray(T* data, const std::vector<size_t>& indices) {
         
     | 
| 98 | 
         
            +
                // For now, simply create a copy of the data and then copy over to the original.
         
     | 
| 99 | 
         
            +
                std::vector<T> copy(data, data + indices.size());
         
     | 
| 100 | 
         
            +
                for (size_t i = 0; i < indices.size(); ++i) {
         
     | 
| 101 | 
         
            +
                    data[i] = copy.at(indices[i]);
         
     | 
| 102 | 
         
            +
                }
         
     | 
| 103 | 
         
            +
            }
         
     | 
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
            template <typename T>
         
     | 
| 106 | 
         
            +
            torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
         
     | 
| 107 | 
         
            +
                return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device));
         
     | 
| 108 | 
         
            +
            }
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
            struct RawGemmArguments {
         
     | 
| 111 | 
         
            +
              torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes;
         
     | 
| 112 | 
         
            +
              int threadblock_count{};
         
     | 
| 113 | 
         
            +
            };
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
            template <
         
     | 
| 116 | 
         
            +
              typename Gemm,
         
     | 
| 117 | 
         
            +
              typename ElementA, typename ElementB, typename ElementC
         
     | 
| 118 | 
         
            +
            >
         
     | 
| 119 | 
         
            +
            RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) {
         
     | 
| 120 | 
         
            +
                TORCH_CHECK(
         
     | 
| 121 | 
         
            +
                    num_experts <= kMaxExperts,
         
     | 
| 122 | 
         
            +
                    "At most ", kMaxExperts,
         
     | 
| 123 | 
         
            +
                    " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts
         
     | 
| 124 | 
         
            +
                );
         
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
                return RawGemmArguments {
         
     | 
| 127 | 
         
            +
                  .lda = TypedEmpty<int64_t>(num_experts, device),
         
     | 
| 128 | 
         
            +
                  .ldb = TypedEmpty<int64_t>(num_experts, device),
         
     | 
| 129 | 
         
            +
                  .ldc = TypedEmpty<int64_t>(num_experts, device),
         
     | 
| 130 | 
         
            +
                  .ptr_a = TypedEmpty<ElementA*>(num_experts, device),
         
     | 
| 131 | 
         
            +
                  .ptr_b = TypedEmpty<ElementB*>(num_experts, device),
         
     | 
| 132 | 
         
            +
                  .ptr_c = TypedEmpty<ElementC*>(num_experts, device),
         
     | 
| 133 | 
         
            +
                  .problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device),
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                  // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
         
     | 
| 136 | 
         
            +
                  .threadblock_count = Gemm::sufficient(),
         
     | 
| 137 | 
         
            +
                };
         
     | 
| 138 | 
         
            +
            }
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
            template <
         
     | 
| 141 | 
         
            +
              bool kDynamicK,
         
     | 
| 142 | 
         
            +
              typename Gemm,
         
     | 
| 143 | 
         
            +
              typename ElementA, typename ElementB, typename ElementC,
         
     | 
| 144 | 
         
            +
              typename LayoutA, typename LayoutB, typename LayoutC
         
     | 
| 145 | 
         
            +
            >
         
     | 
| 146 | 
         
            +
            RawGemmArguments MakeArgumentsOnHost(torch::Tensor a,
         
     | 
| 147 | 
         
            +
            				     torch::Tensor b,
         
     | 
| 148 | 
         
            +
            				     torch::Tensor c,
         
     | 
| 149 | 
         
            +
            				     torch::Tensor batch_sizes,
         
     | 
| 150 | 
         
            +
            				     ::cutlass::gemm::GemmCoord coord_template,
         
     | 
| 151 | 
         
            +
            				     int64_t num_experts) {
         
     | 
| 152 | 
         
            +
              std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts);
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
              // Create the host arrays of leading dimension data and pointer data.
         
     | 
| 155 | 
         
            +
              std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts);
         
     | 
| 156 | 
         
            +
              int64_t elements_a = 0, elements_b = 0, elements_c = 0;
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
              std::vector<ElementA *> ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts);
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
              for (int i = 0; i < num_experts; ++i) {
         
     | 
| 161 | 
         
            +
                auto& problem = problem_sizes_host[i];
         
     | 
| 162 | 
         
            +
                problem = coord_template;
         
     | 
| 163 | 
         
            +
                (kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr<int64_t>()[i];
         
     | 
| 164 | 
         
            +
             
     | 
| 165 | 
         
            +
                lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
         
     | 
| 166 | 
         
            +
                ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
         
     | 
| 167 | 
         
            +
                ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
         
     | 
| 168 | 
         
            +
             
     | 
| 169 | 
         
            +
                ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a;
         
     | 
| 170 | 
         
            +
                ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b;
         
     | 
| 171 | 
         
            +
                ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c;
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                elements_a += problem.m() * problem.k();
         
     | 
| 174 | 
         
            +
                elements_b += problem.k() * problem.n();
         
     | 
| 175 | 
         
            +
                elements_c += problem.m() * problem.n();
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                if (problem.k() == 0) {
         
     | 
| 178 | 
         
            +
                  // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
         
     | 
| 179 | 
         
            +
                  // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
         
     | 
| 180 | 
         
            +
                  //   * set the output to zero with `cudaMemsetAsync()`
         
     | 
| 181 | 
         
            +
                  //   * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
         
     | 
| 182 | 
         
            +
                  CUDA_CALL(cudaMemsetAsync(ptr_c_host[i],
         
     | 
| 183 | 
         
            +
                    0,
         
     | 
| 184 | 
         
            +
                    problem.m() * problem.n() * sizeof(ElementC),
         
     | 
| 185 | 
         
            +
                    c10::cuda::getCurrentCUDAStream()));
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
                  problem.m() = 0;
         
     | 
| 188 | 
         
            +
                  problem.n() = 0;
         
     | 
| 189 | 
         
            +
                }
         
     | 
| 190 | 
         
            +
              }
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
              // Only sort problems when K are different
         
     | 
| 193 | 
         
            +
              if (kDynamicK) {
         
     | 
| 194 | 
         
            +
                  std::vector<size_t> indices(num_experts);
         
     | 
| 195 | 
         
            +
                  std::iota(indices.begin(), indices.end(), 0);
         
     | 
| 196 | 
         
            +
                  std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
         
     | 
| 197 | 
         
            +
                      return problem_sizes_host[i].k() > problem_sizes_host[j].k();
         
     | 
| 198 | 
         
            +
                  });
         
     | 
| 199 | 
         
            +
             
     | 
| 200 | 
         
            +
                  ReorderArray(problem_sizes_host.data(), indices);
         
     | 
| 201 | 
         
            +
                  ReorderArray(lda_host.data(), indices);
         
     | 
| 202 | 
         
            +
                  ReorderArray(ldb_host.data(), indices);
         
     | 
| 203 | 
         
            +
                  ReorderArray(ldc_host.data(), indices);
         
     | 
| 204 | 
         
            +
                  ReorderArray(ptr_a_host.data(), indices);
         
     | 
| 205 | 
         
            +
                  ReorderArray(ptr_b_host.data(), indices);
         
     | 
| 206 | 
         
            +
                  ReorderArray(ptr_c_host.data(), indices);
         
     | 
| 207 | 
         
            +
              }
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
              // Copy the problem sizes, pointers and leading dimension data to the device.
         
     | 
| 210 | 
         
            +
              return RawGemmArguments {
         
     | 
| 211 | 
         
            +
                .lda = CopyToDevice(lda_host, a.device()),
         
     | 
| 212 | 
         
            +
                .ldb = CopyToDevice(ldb_host, a.device()),
         
     | 
| 213 | 
         
            +
                .ldc = CopyToDevice(ldc_host, a.device()),
         
     | 
| 214 | 
         
            +
                .ptr_a = CopyToDevice(ptr_a_host, a.device()),
         
     | 
| 215 | 
         
            +
                .ptr_b = CopyToDevice(ptr_b_host, a.device()),
         
     | 
| 216 | 
         
            +
                .ptr_c = CopyToDevice(ptr_c_host, a.device()),
         
     | 
| 217 | 
         
            +
                .problem_sizes = CopyToDevice(problem_sizes_host, a.device()),
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
                // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that.
         
     | 
| 220 | 
         
            +
                .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
         
     | 
| 221 | 
         
            +
              };
         
     | 
| 222 | 
         
            +
            }
         
     | 
| 223 | 
         
            +
             
     | 
| 224 | 
         
            +
            template <
         
     | 
| 225 | 
         
            +
              bool kDynamicK,
         
     | 
| 226 | 
         
            +
              typename Gemm,
         
     | 
| 227 | 
         
            +
              typename ElementA, typename ElementB, typename ElementC,
         
     | 
| 228 | 
         
            +
              typename LayoutA, typename LayoutB, typename LayoutC
         
     | 
| 229 | 
         
            +
            >
         
     | 
| 230 | 
         
            +
            typename Gemm::Arguments MakeArguments(torch::Tensor a,
         
     | 
| 231 | 
         
            +
            				       torch::Tensor b,
         
     | 
| 232 | 
         
            +
            				       torch::Tensor c,
         
     | 
| 233 | 
         
            +
            				       torch::Tensor batch_sizes,
         
     | 
| 234 | 
         
            +
            				       ::cutlass::gemm::GemmCoord coord_template,
         
     | 
| 235 | 
         
            +
            				       int64_t num_experts) {
         
     | 
| 236 | 
         
            +
              RawGemmArguments raw_args;
         
     | 
| 237 | 
         
            +
              if (batch_sizes.is_cuda()) {
         
     | 
| 238 | 
         
            +
                raw_args = MakeArgumentsOnDevice<
         
     | 
| 239 | 
         
            +
                  Gemm, ElementA, ElementB, ElementC
         
     | 
| 240 | 
         
            +
                >(num_experts, a.device());
         
     | 
| 241 | 
         
            +
              } else {
         
     | 
| 242 | 
         
            +
                raw_args = MakeArgumentsOnHost<
         
     | 
| 243 | 
         
            +
                  kDynamicK,
         
     | 
| 244 | 
         
            +
                  Gemm,
         
     | 
| 245 | 
         
            +
                  ElementA, ElementB, ElementC,
         
     | 
| 246 | 
         
            +
                  LayoutA, LayoutB, LayoutC
         
     | 
| 247 | 
         
            +
                >(a, b, c, batch_sizes, coord_template, num_experts);
         
     | 
| 248 | 
         
            +
              }
         
     | 
| 249 | 
         
            +
             
     | 
| 250 | 
         
            +
              printf("Using %d threadblocks for grouped GEMM.\n", raw_args.threadblock_count);
         
     | 
| 251 | 
         
            +
              // Validate the result.
         
     | 
| 252 | 
         
            +
              if (!raw_args.threadblock_count) {
         
     | 
| 253 | 
         
            +
                TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
         
     | 
| 254 | 
         
            +
              }
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
              typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
         
     | 
| 257 | 
         
            +
              // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all,
         
     | 
| 258 | 
         
            +
  // so we can safely pass `nullptr` for `host_problem_sizes`.
  // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we
  // know the problem dimensions on the host.
  typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(),
                                     (int)num_experts,
                                     (int)raw_args.threadblock_count,
                                     epilogue_op,
                                     (ElementA**)raw_args.ptr_a.data_ptr(),
                                     (ElementB**)raw_args.ptr_b.data_ptr(),
                                     (ElementC**)raw_args.ptr_c.data_ptr(),
                                     (ElementC**)raw_args.ptr_c.data_ptr(),
                                     /*lda=*/(int64_t*)raw_args.lda.data_ptr(),
                                     /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(),
                                     /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(),
                                     /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(),
                                     /*host_problem_sizes=*/nullptr);
  return arguments;
}

template <
  bool trans_a,
  typename ElementA, typename ElementB, typename ElementC,
  typename LayoutA, typename LayoutB, typename LayoutC,
  typename Arguments
>
void FillCutlassArguments(int num_experts,
                          torch::Tensor batch_sizes,
                          torch::Tensor a,
                          torch::Tensor b,
                          torch::Tensor c,
                          const Arguments& arguments,
                          ::cutlass::gemm::GemmCoord coord_template) {
  // Convert the batch sizes to the format CUTLASS understands on the device.
  // Use a single block here because:
  //   * the number of elements to process is microscopically small
  //   * we don't need any additional global memory
  FillArguments<
      /*kDynamicK*/trans_a,
      ElementA, ElementB, ElementC,
      LayoutA, LayoutB, LayoutC
  ><<<1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()>>>(
      num_experts, batch_sizes.data_ptr<int64_t>(),
      (ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(),
      arguments, coord_template
  );
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename Args>
void RemoveK0Problems(int num_experts, const Args& arguments) {
  // For zeroing out the outputs (which might be arbitrarily large), we want to use
  // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth.
  // `arguments.threadblock_count`, which we will use for the grouped GEMM proper,
  // should be a good approximation for this.
  // When the `k=0` case is fixed in CUTLASS, we can completely remove this function.
  ZeroOutK0Outputs<><<<
    arguments.threadblock_count, at::cuda::detail::CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()
  >>>(
    num_experts, arguments
  );
  IgnoreK0Problems<><<<
    1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()
  >>>(
    num_experts, arguments
  );
}

template <bool trans_a, bool trans_b>
torch::Tensor CutlassGroupedGemm(torch::Tensor a,
                                 torch::Tensor b,
                                 torch::Tensor c,
                                 torch::Tensor batch_sizes,
                                 ::cutlass::gemm::GemmCoord coord_template) {
  using Gemm = GemmGrouped<trans_a, trans_b>;
  using LayoutA = typename Gemm::LayoutA;
  using LayoutB = typename Gemm::LayoutB;
  using LayoutC = typename Gemm::LayoutC;

  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementC = typename Gemm::ElementC;

  Gemm gemm;
  int64_t num_experts = batch_sizes.size(0);
  auto arguments = MakeArguments<
    /*kDynamicK*/trans_a,
    Gemm,
    ElementA, ElementB, ElementC,
    LayoutA, LayoutB, LayoutC
  >(a, b, c, batch_sizes, coord_template, num_experts);
  int64_t workspace_size = gemm.get_workspace_size(arguments);
  auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
  torch::Tensor workspace = torch::empty(workspace_size, options);

  if (batch_sizes.is_cuda()) {
      FillCutlassArguments<
        trans_a,
        ElementA, ElementB, ElementC,
        LayoutA, LayoutB, LayoutC
      >(num_experts, batch_sizes, a, b, c, arguments, coord_template);

      RemoveK0Problems<>(num_experts, arguments);
  }

  // Initialize the kernel.
  if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
  }

  // Execute the kernel in the current stream.
  if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) {
    TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
  }
  return c;
}

void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
                c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
                c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
  int m = trans_b ? b_rows : b_cols;
  int k = trans_b ? b_cols : b_rows;
  int n = trans_a ? a_cols : a_rows;

  int lda = trans_a ? n : k;
  int ldb = trans_b ? k : m;
  cublasOperation_t transpose_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;

  float alpha = 1.0, beta = 0.0;
  CUBLAS_CALL(cublasGemmEx(at::cuda::getCurrentCUDABlasHandle(),
                           transpose_b, transpose_a,
                           m, n, k, &alpha,
                           b, CUDA_R_16BF, ldb,
                           a, CUDA_R_16BF, lda,
                           &beta,
                           c, CUDA_R_16BF, c_cols, CUDA_R_32F,
                           CUBLAS_GEMM_DEFAULT));
}

void CublasGroupedGemm(torch::Tensor a,
                       torch::Tensor b,
                       torch::Tensor c,
                       torch::Tensor batch_sizes,
                       bool trans_b) {
  int64_t bs = batch_sizes.size(0), k = a.size(1);
  int64_t n = trans_b ? b.size(1) : b.size(2);
  int64_t b_rows = b.size(1), b_cols = b.size(2);
  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
  for (int i = 0; i < bs; ++i) {
    int64_t m = batch_sizes.data_ptr<int64_t>()[i];
    CublasGemm(a_ptr, m, k, /*trans_a=*/false,
               b_ptr, b_rows, b_cols, trans_b,
               c_ptr, m, n);
    a_ptr += m * k;
    b_ptr += b_rows * b_cols;
    c_ptr += m * n;
  }
}

void CublasGroupedGemmVariableK(torch::Tensor a,
                                torch::Tensor b,
                                torch::Tensor c,
                                torch::Tensor batch_sizes) {
  int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
  for (int i = 0; i < bs; ++i) {
    int64_t k = batch_sizes.data_ptr<int64_t>()[i];
    CublasGemm(a_ptr, k, m, /*trans_a=*/true,
               b_ptr, k, n, /*trans_b=*/false,
               c_ptr, m, n);
    a_ptr += k * m;
    b_ptr += k * n;
    c_ptr += m * n;
  }
}

void GroupedGemmVariableK(torch::Tensor a,
                          torch::Tensor b,
                          torch::Tensor c,
                          torch::Tensor batch_sizes) {
  // We expected a CUDA tensor with two dimensions and shape
  // (tokens, hidden_out) for 'b'.
  TORCH_CHECK(b.is_cuda());
  TORCH_CHECK(b.ndimension() == 2);
  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);

  // Validate the dimensions.
  int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
  int64_t m = a.size(1), n = b.size(1);

  // Validate that we have the same contraction dimension.
  TORCH_CHECK(tokens == b.size(0));

  // Validate the output shape.
  TORCH_CHECK(c.is_cuda());
  TORCH_CHECK(c.ndimension() == 3);
  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
  TORCH_CHECK(c.size(0) == num_experts);
  TORCH_CHECK(c.size(1) == m);
  TORCH_CHECK(c.size(2) == n);

  // Run the computation.
  CublasGroupedGemmVariableK(a, b, c, batch_sizes);
}

// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
// assumed to be batched with fixed sized batches.
//
// TODO(tgale): Validate alignment is true for every batch element.
void GroupedGemm(torch::Tensor a,
                 torch::Tensor b,
                 torch::Tensor c,
                 torch::Tensor batch_sizes,
                 bool trans_a, bool trans_b) {
  // NOTE: We only support 'trans_a' or 'trans_b', not both.
  TORCH_CHECK(!(trans_a && trans_b));

#if !defined(GROUPED_GEMM_CUTLASS)
  // No way to run cuBLAS kernels if the problem dimensions are not known on the host.
  TORCH_CHECK(batch_sizes.is_cpu());
#else
  // CUTLASS can handle both CPU- and CUDA-resident problem dimensions.
  TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
#endif
  TORCH_CHECK(batch_sizes.ndimension() == 1);
  TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);

  // We expected a CUDA tensor with two dimensions and shape
  // (tokens, hidden_in) for 'a'.
  TORCH_CHECK(a.is_cuda());
  TORCH_CHECK(a.ndimension() == 2);
  TORCH_CHECK(a.scalar_type() == torch::kBFloat16);

#if !defined(GROUPED_GEMM_CUTLASS)
  if (trans_a) {
    // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
    // for the rest of the op.
    GroupedGemmVariableK(a, b, c, batch_sizes);
    return;
  }
#endif

  TORCH_CHECK(b.is_cuda());
  TORCH_CHECK(c.is_cuda());
  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);

  // The expected shapes of 'b' and 'c' are:
  //   * when 'trans_a' is set: b=(tokens, hidden_out),                 c=(num_experts, hidden_in, hidden_out)
  //   * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
  //   * otherwise:             b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
  size_t hidden_in{}, hidden_out{};
  if (trans_a) {
    hidden_in = a.size(1);
    hidden_out = b.size(1);

    TORCH_CHECK(b.ndimension() == 2);
    TORCH_CHECK(c.ndimension() == 3);
    TORCH_CHECK(b.size(0) == a.size(0));
    TORCH_CHECK(c.size(0) == batch_sizes.size(0));
    TORCH_CHECK(c.size(1) == hidden_in);
    TORCH_CHECK(c.size(2) == hidden_out);
  } else {
    TORCH_CHECK(b.ndimension() == 3);
    TORCH_CHECK(c.ndimension() == 2);

    // Validate the contraction dimensions match.
    int64_t tokens = a.size(0), num_experts = b.size(0);
    hidden_in = trans_b ? b.size(2) : b.size(1);
    hidden_out = trans_b ? b.size(1) : b.size(2);
    TORCH_CHECK(hidden_in == a.size(1));

    // Validate that we have one size per expert.
    TORCH_CHECK(batch_sizes.size(0) == num_experts);
  }

  // NOTE: We support transposition through the 'trans_b' flag.
  TORCH_CHECK(a.is_contiguous());
  TORCH_CHECK(b.is_contiguous());
  TORCH_CHECK(c.is_contiguous());

#if !defined(GROUPED_GEMM_CUTLASS)
  CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
  return;
#else
  // The `coord_template` argument contains `kDynamicDim` as one of its dimensions
  // as a placeholder. This placeholder is later expanded into the actual dimension
  // for every element of the batch, either on the host or on the device
  // (if we can't do it on the host).
  const auto coord_template = trans_a
    ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim)
    : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in);
  if (trans_a) {
    CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template);
    return;
  }
  if (trans_b) {
    CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template);
    return;
  }
  CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template);
  return;
#endif
}

}  // namespace grouped_gemm
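For orientation, the per-expert computation performed by the forward (non-transposed) path above is equivalent to the following PyTorch reference. This is an illustrative sketch only, not part of the vendored sources; the names `a`, `b`, and `batch_sizes` follow the kernel's arguments.

import torch

def grouped_gemm_reference(a, b, batch_sizes):
    # a: (tokens, hidden_in), rows grouped by expert.
    # b: (num_experts, hidden_in, hidden_out).
    # batch_sizes: (num_experts,) int64, tokens per expert; sums to a.size(0).
    out, start = [], 0
    for i, size in enumerate(batch_sizes.tolist()):
        out.append(a[start:start + size] @ b[i])  # (size, hidden_out)
        start += size
    return torch.cat(out)  # (tokens, hidden_out)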
    	
csrc/grouped_gemm/grouped_gemm.h
ADDED
@@ -0,0 +1,20 @@
#pragma once

// // Set default if not already defined
// #ifndef GROUPED_GEMM_CUTLASS
// #define GROUPED_GEMM_CUTLASS 0
// #endif

// #include <torch/extension.h>
#include <torch/torch.h>

namespace grouped_gemm {

void GroupedGemm(torch::Tensor a,
                 torch::Tensor b,
                 torch::Tensor c,
                 torch::Tensor batch_sizes,
                 bool trans_a, bool trans_b);

}  // namespace grouped_gemm
    	
csrc/grouped_gemm/ops.cu
ADDED
@@ -0,0 +1,11 @@
#include "grouped_gemm.h"

#include <torch/extension.h>

namespace grouped_gemm {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gmm", &GroupedGemm, "Grouped GEMM.");
}

}  // namespace grouped_gemm
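The binding exposes only the in-place entry point: `gmm(a, b, c, batch_sizes, trans_a, trans_b)` fills a preallocated `c`. A minimal sketch of the variable-k (`trans_a`) call as it is exercised in `tests/ops_test.py` below, assuming a CUDA build of the vendored extension:

import torch
import megablocks

tokens, hidden = 192, 128
a = torch.ones(tokens, hidden, dtype=torch.bfloat16, device="cuda")
b = torch.ones(tokens, hidden, dtype=torch.bfloat16, device="cuda")
c = torch.empty(4, hidden, hidden, dtype=torch.bfloat16, device="cuda")
batch_sizes = torch.tensor([0, 128, 0, 64], dtype=torch.long)  # tokens per expert

# Writes a_i^T @ b_i for each expert's slice of rows into c[i] (trans_a path).
megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)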
    	
tests/ops_test.py
ADDED
@@ -0,0 +1,170 @@
import unittest
import itertools

from absl.testing import parameterized
import megablocks
import numpy as np
import torch


def allclose(x, y, pct=2.0):
    mask = torch.isclose(x, y, rtol=1e-5)
    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
    if pct_diff > pct:
        print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
        print("{:.2f}% of values not close.".format(pct_diff))
        return False
    return True


def add_flags(x):
    out = []
    for y in x:
        for trans_b in (False, True):
            out.append(y + (trans_b, False))

            # TODO: Revisit enabling batch_sizes_on_device
            # for batch_sizes_on_device in (False, True):
            #     out.append(y + (trans_b, batch_sizes_on_device))
    return out


_TEST_PROBLEMS = add_flags((
    (1, 128, 128, 128),
    (8, 128, 128, 128),
    (16, 128, 128, 128),
    (1, 128, 256, 512),
    (8, 128, 256, 512),
    (16, 128, 256, 512),
))


def randn(bs, x, y):
    out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start:start + size, :] @ rhs)
        start += size
    return torch.cat(out)


@parameterized.parameters(*_TEST_PROBLEMS)
class OpsTest(parameterized.TestCase):

    def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        batch_sizes = torch.tensor([m] * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        # out = ops.gmm(a, b, batch_sizes, trans_b)
        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        # print("out", out)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)

        dist = torch.rand(z, )
        dist /= dist.sum()
        batch_sizes = (dist * m).to(torch.long)
        error = m * z - batch_sizes.sum()
        batch_sizes[-1] += error
        assert batch_sizes.sum() == (m * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))

        # TODO: Review to ensure that the gradients are correct.
        # self.assertTrue(allclose(b.grad, b_ref.grad))


# @parameterized.parameters(False, True)
@parameterized.parameters(False, False)
class EdgeCasesTest(unittest.TestCase):

    def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
        torch.manual_seed(0)
        m = 16384
        k = 4096
        n = 14336
        num_experts = 8

        a = randn(num_experts, m // num_experts, k).view(-1, k)
        b = randn(num_experts, k, n)
        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes)
        expected_out = gmm(a_ref, b_ref, batch_sizes)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
        sz = 128
        total_tokens = 192

        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
        self.assertTrue((c[0] == 0).all())
        self.assertTrue((c[1] == 128).all())
        self.assertTrue((c[2] == 0).all())
        self.assertTrue((c[3] == 64).all())


if __name__ == '__main__':
    unittest.main()
    	
tests/test_gg.py
ADDED
@@ -0,0 +1,57 @@
import torch
import megablocks


def randn(bs, x, y):
    out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)


def test_gmm():
    z = 1
    m = 128
    n = 128
    k = 128
    trans_b = False
    batch_sizes_on_device = False
    # TODO: fix to enable batch_sizes_on_device
    # batch_sizes_on_device = True

    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k) if trans_b else randn(z, k, n)
    batch_sizes = torch.tensor([m] * z)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    # out = ops.gmm(a, b, batch_sizes, trans_b)
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("out", out)

    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)

    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"

    out.sum().backward()

    expected_out.sum().backward()
    assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
    assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
    print("Test passed successfully!")
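Both test files assume a CUDA device and a built extension. A quick smoke run would presumably be:

pytest -q tests/test_gg.py
python tests/ops_test.py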
    	
torch-ext/megablocks/__init__.py
CHANGED
@@ -5,11 +5,15 @@ import torch
 
 from ._ops import ops
 
-from …
-from …
-
-
-from …
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from .layers.arguments import Arguments
+from .layers.dmoe import ParallelDroplessMLP, dMoE
+from .layers.glu import SparseGLU
+from .layers.mlp import MLP, SparseMLP
+from .layers.moe import MoE, ParallelMLP, get_load_balancing_loss
 
 # This section contains the direct kernel exports (not inlcuded in the original code)
 def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
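With the two new re-exports, the vendored grouped GEMM is reachable directly from the package namespace. A minimal usage sketch mirroring the tests, assuming a CUDA device and bfloat16 inputs:

import torch
import megablocks

num_experts, tokens_per_expert, hidden_in, hidden_out = 2, 128, 128, 64
a = torch.randn(num_experts * tokens_per_expert, hidden_in, device="cuda", dtype=torch.bfloat16)
b = torch.randn(num_experts, hidden_in, hidden_out, device="cuda", dtype=torch.bfloat16)
batch_sizes = torch.tensor([tokens_per_expert] * num_experts)  # int64 on CPU, one entry per expert
trans_b = False

out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)  # (tokens, hidden_out)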
    	
torch-ext/megablocks/grouped_gemm/__init__.py  ADDED

@@ -0,0 +1,2 @@
+from . import ops
+from . import backend
    	
torch-ext/megablocks/grouped_gemm/backend.py  ADDED

@@ -0,0 +1,32 @@
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# # TODO(tgale): Wrap this in a try-block with better
+# # error message and instructions for building the
+# # c++ operations.
+# import grouped_gemm_backend as backend
+
+# We import the backend operations from the megablocks package as
+# grouped_gemm is vendored in megablocks in this repository.
+# from ... import _ops as backend
+from megablocks._ops import ops as backend  # type: ignore
+
+def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
+    assert not (trans_a and trans_b)
+    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
+    assert a.ndim == 2, "Expected 2d tensor for 'a'"
+    assert b.ndim == (2 if trans_a else 3)
+
+    shape = (
+        (batch_sizes.shape[0], a.shape[1], b.shape[1])
+        if trans_a else
+        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
+    )
+    return torch.empty(*shape, device=a.device, dtype=a.dtype)
+
+def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
+    if c is None:
+        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
+    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
+    return c
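As a shape sanity check for `_allocate_output` above: with `trans_a=False` the 2-D `a` holds rows grouped by expert and the 3-D `b` holds one matrix per group, while `trans_a=True` (the weight-gradient case) takes two 2-D inputs and produces an output with a leading group dimension. A CPU-only sketch of just the shape arithmetic, no kernel call (sizes are illustrative):

import torch

num_groups, k, n = 4, 16, 32
batch_sizes = torch.tensor([3, 5, 2, 6], dtype=torch.int64)
tokens = int(batch_sizes.sum())            # 16 rows in total, grouped by expert

a = torch.randn(tokens, k)                 # 2-D input, rows grouped by expert
b = torch.randn(num_groups, k, n)          # one (k, n) matrix per group

# Mirrors _allocate_output with trans_a=False, trans_b=False:
out_shape = (a.shape[0], b.shape[2])       # -> (16, 32)

# Mirrors _allocate_output with trans_a=True (b would then be 2-D, e.g. (tokens, n)):
grad_shape = (batch_sizes.shape[0], a.shape[1], n)  # -> (4, 16, 32)
print(out_shape, grad_shape)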
    	
torch-ext/megablocks/grouped_gemm/ops.py  ADDED

@@ -0,0 +1,33 @@
+from . import backend
+import torch
+
+
+class GroupedGemm(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, a, b, batch_sizes, trans_b):
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        agrad = None
+        if ctx.needs_input_grad[0]:
+            agrad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        bgrad = None
+        if ctx.needs_input_grad[1]:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            bgrad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return agrad, bgrad, None, None
+
+
+def gmm(a, b, batch_sizes, trans_b=False):
+    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
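The forward/backward wiring above is easiest to read against a loop-of-matmuls reference: the grouped GEMM multiplies each contiguous row block of `a` by its group's matrix in `b`. A pure-PyTorch sketch for illustration (similar in spirit to the reference the tests compare against, not part of the vendored code):

import torch

def gmm_reference(a, b, batch_sizes, trans_b=False):
    """Loop-of-matmuls equivalent of the grouped GEMM, for illustration only."""
    out, start = [], 0
    for i, size in enumerate(batch_sizes.tolist()):
        rhs = b[i].t() if trans_b else b[i]
        out.append(a[start:start + size] @ rhs)
        start += size
    return torch.cat(out)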
    	
torch-ext/megablocks/grouped_gemm_util.py  CHANGED

@@ -4,7 +4,8 @@ import warnings
 
 _grouped_gemm_is_available: bool = False
 try:
-    import grouped_gemm
+    # import grouped_gemm
+    pass
     _grouped_gemm_is_available = True
 except ImportError as error:
     warnings.warn('Grouped GEMM not available.')
@@ -22,5 +23,9 @@ def assert_grouped_gemm_is_available():
     assert _grouped_gemm_is_available, msg
 
 
-backend = grouped_gemm.backend if grouped_gemm_is_available() else None
-ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
+# ops = grouped_gemm.ops if grouped_gemm_is_available() else None
+
+
+from .grouped_gemm import backend as ops
+from .grouped_gemm import ops as backend
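Note that after this change the try block always succeeds, so the availability flag is effectively true whenever the module imports, and `ops`/`backend` are now re-exported from the vendored package (with the aliasing shown in the last two added lines). Code that goes through this shim would look roughly like the following sketch (hypothetical usage, not part of the commit):

import megablocks.grouped_gemm_util as gg

gg.assert_grouped_gemm_is_available()   # passes now that the flag is always set
if gg.grouped_gemm_is_available():
    print(gg.ops, gg.backend)           # the vendored modules re-exported above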
    	
torch-ext/megablocks/layers/__init__.py  CHANGED

@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # from megablocks.layers.dmoe import dMoE
-from …
+from .moe import MoE
 
 __all__ = [
     'MoE',
    	
torch-ext/torch_binding.cpp  CHANGED

@@ -9,6 +9,8 @@
 #include "new_replicate.h"
 #include "new_sort.h"
 
+#include "grouped_gemm/grouped_gemm.h"
+
 // void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
 torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) {
   megablocks::exclusive_cumsum(x, dim, out);
@@ -70,6 +72,12 @@ torch::Tensor sort_wrapper(torch::Tensor x, int64_t end_bit, torch::Tensor x_out
   return x_out;
 }
 
+// GroupedGemm operation
+torch::Tensor gmm(torch::Tensor a, torch::Tensor b, torch::Tensor c, torch::Tensor batch_sizes, bool trans_a, bool trans_b) {
+  grouped_gemm::GroupedGemm(a, b, c, batch_sizes, trans_a, trans_b);
+  return c;
+}
+
 // Reference implementation:
 //
 // m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
@@ -101,6 +109,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   ops.def("sort(Tensor x, int end_bit, Tensor x_out, Tensor iota_out) -> Tensor(x_out)");
   ops.impl("sort", torch::kCUDA, &sort_wrapper);
+
+  // Register the gmm GroupedGemm operation
+  ops.def("gmm(Tensor (a!) a, Tensor (b!) b, Tensor(c!) c, Tensor batch_sizes, bool trans_a, bool trans_b) -> Tensor(c!)");
+  ops.impl("gmm", torch::kCUDA, &gmm);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
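The registered `gmm` op writes into a caller-provided output tensor, which is how backend.py drives it. A hedged sketch of invoking the raw binding directly (assumes the extension is built; the bfloat16 dtype and sizes are illustrative assumptions):

import torch

try:
    from megablocks._ops import ops as _ops  # same namespace backend.py imports as `backend`
except ImportError:
    _ops = None  # extension not built in this environment

if _ops is not None and torch.cuda.is_available():
    batch_sizes = torch.tensor([3, 5], dtype=torch.int64)
    a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16)
    c = torch.empty(8, 32, device="cuda", dtype=torch.bfloat16)
    _ops.gmm(a, b, c, batch_sizes, False, False)  # fills c in place and returns it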