// Backend portability shim.
//
// Selects the GPU backend at compile time:
//   * TEST_ON_CUDA defined  -> CUDA runtime + nvcuda::wmma
//   * otherwise             -> HIP runtime  + rocwmma
// and exposes a backend-neutral surface to the rest of the file:
//   * `wmma`         namespace alias for the selected WMMA implementation
//   * HOST_TYPE(x)   pastes the runtime prefix onto x (cuda##x / hip##x)
//   * LIB_CALL(call) checks a runtime-API return code and aborts on failure
//
// NOTE(review): in the source this was reconstructed from, every `#include`
// had lost its header name and all directives were collapsed onto one line.
// The headers below are inferred from the APIs this file references
// (nvcuda::wmma, rocwmma, cuda*/hip* runtime types) — confirm against the
// project's build before relying on the exact set.

#include <cstdio>   // std::fprintf for LIB_CALL diagnostics
#include <cstdlib>  // std::abort

#ifdef TEST_ON_CUDA

#include <cuda_runtime.h>  // cudaError_t, cudaSuccess, cudaGetErrorString
#include <cuda_fp16.h>     // NOTE(review): assumed; half types for WMMA fragments
#include <mma.h>           // nvcuda::wmma

namespace wmma = nvcuda::wmma;

// Pastes the CUDA runtime prefix onto x, e.g. HOST_TYPE(Malloc) -> cudaMalloc.
#define HOST_TYPE(x) cuda##x

#else  // HIP / ROCm backend

// Include guard so the HIP headers are pulled in only once, even if another
// part of the project has already included them under the same macro.
// NOTE(review): identifiers containing "__" are reserved in C++; the name is
// kept unchanged here because other code may test for it.
#ifndef HIP_HEADERS__
#include <hip/hip_runtime.h>    // hipError_t, hipSuccess, hipGetErrorString
#include <hip/hip_fp16.h>       // NOTE(review): assumed; half types
#include <hip/hip_ext.h>        // NOTE(review): assumed; HIP extensions
#include <rocwmma/rocwmma.hpp>  // rocwmma
#define HIP_HEADERS__
#endif

namespace wmma = rocwmma;

// Pastes the HIP runtime prefix onto x, e.g. HOST_TYPE(Malloc) -> hipMalloc.
#define HOST_TYPE(x) hip##x

#endif  // TEST_ON_CUDA

// Evaluates a runtime-API call exactly once. If it does not return the
// backend's Success code, prints the failing expression, the runtime's error
// string and the source location to stderr, then aborts.
//
// Defined once for both backends via HOST_TYPE, so it expands to
// cudaError_t / cudaSuccess / cudaGetErrorString or their hip* equivalents.
// `call` is parenthesized and the local uses a suffixed name so it cannot
// collide with an identifier appearing inside `call`.
#define LIB_CALL(call)                                                     \
    do {                                                                   \
        HOST_TYPE(Error_t) lib_call_err_ = (call);                         \
        if (lib_call_err_ != HOST_TYPE(Success)) {                         \
            std::fprintf(stderr, "%s:%d: %s failed: %s\n", __FILE__,       \
                         __LINE__, #call,                                  \
                         HOST_TYPE(GetErrorString)(lib_call_err_));        \
            std::abort();                                                  \
        }                                                                  \
    } while (0)