Mohamed Mekkouri committed
Commit · 60ccc25
1 Parent(s): 9ffd725
Clean

- .gitignore +2 -0
- CMakeLists.txt +0 -128
- build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc +0 -0
- build/torch28-metal-aarch64-darwin/gptoss_kernels/{_gptoss_kernels_9964bae_dirty.abi3.so → _gptoss_kernels_9ffd725_dirty.abi3.so} +1 -1
- build/torch28-metal-aarch64-darwin/gptoss_kernels/_ops.py +3 -3
- build/torch28-metal-aarch64-darwin/gptoss_kernels/test.py +0 -15
- build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc +0 -0
- build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc +0 -0
- build/torch29-metal-aarch64-darwin/gptoss_kernels/{_gptoss_kernels_9964bae_dirty.abi3.so → _gptoss_kernels_9ffd725_dirty.abi3.so} +1 -1
- build/torch29-metal-aarch64-darwin/gptoss_kernels/_ops.py +3 -3
- build/torch29-metal-aarch64-darwin/gptoss_kernels/test.py +0 -15
- cmake/compile-metal.cmake +0 -90
- cmake/metallib_to_header.py +0 -73
- cmake/utils.cmake +0 -557
- pyproject.toml +0 -10
- setup.py +0 -118
- torch-ext/gptoss_kernels/__pycache__/__init__.cpython-313.pyc +0 -0
- torch-ext/gptoss_kernels/__pycache__/_ops.cpython-313.pyc +0 -0
- torch-ext/gptoss_kernels/test.py +0 -15
- torch-ext/registration.h +0 -30

.gitignore ADDED
@@ -0,0 +1,2 @@
+.venv/
+**/_pycache__/

CMakeLists.txt DELETED
@@ -1,128 +0,0 @@
-cmake_minimum_required(VERSION 3.26)
-project(gptoss_kernels LANGUAGES CXX C OBJC OBJCXX)
-
-set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version")
-
-install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
-
-include(FetchContent)
-file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
-message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
-
-include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
-
-if(DEFINED Python3_EXECUTABLE)
-  # Allow passing through the interpreter (e.g. from setup.py).
-  find_package(Python3 COMPONENTS Development Development.SABIModule Interpreter)
-  if (NOT Python3_FOUND)
-    message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
-  endif()
-else()
-  find_package(Python3 REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
-endif()
-
-append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
-
-find_package(Torch REQUIRED)
-
-add_compile_definitions(METAL_KERNEL)
-
-# Initialize list for Metal shader sources
-set(ALL_METAL_SOURCES)
-#get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
-#list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
-
-set(TORCH_gptoss_kernels_SRC
-  torch-ext/torch_binding.cpp torch-ext/torch_binding.h
-)
-
-
-list(APPEND SRC "${TORCH_gptoss_kernels_SRC}")
-set(gptoss_kernels_SRC
-  "gptoss_kernels/include/gpt-oss.h"
-  "gptoss_kernels/include/gpt-oss/types.h"
-  "gptoss_kernels/include/gpt-oss/macros.h"
-  "gptoss_kernels/include/gpt-oss/functions.h"
-  "gptoss_kernels/source/accumulate.metal"
-  "gptoss_kernels/source/log.c"
-  "gptoss_kernels/source/expert_routing_metadata.metal"
-  "gptoss_kernels/source/metal.mm"
-  "gptoss_kernels/source/scatter.metal"
-  "gptoss_kernels/source/topk.metal"
-  "gptoss_kernels/source/embeddings.metal"
-  "gptoss_kernels/source/metal-kernels.c"
-  "gptoss_kernels/source/tensor_wrappers.cpp"
-  "gptoss_kernels/source/random.metal"
-  "gptoss_kernels/source/sdpa.metal"
-  "gptoss_kernels/source/matmul.metal"
-  "gptoss_kernels/source/rmsnorm.metal"
-  "gptoss_kernels/source/sample.metal"
-  "gptoss_kernels/source/moematmul.metal"
-  "gptoss_kernels/source/convert.metal"
-  "gptoss_kernels/source/rope.metal"
-  "gptoss_kernels/source/gather_and_accumulate.metal"
-  "gptoss_kernels/source/include/internal/uuid.h"
-  "gptoss_kernels/source/include/internal/metal.hpp"
-  "gptoss_kernels/source/include/internal/datatype.h"
-  "gptoss_kernels/source/include/internal/rng.h"
-  "gptoss_kernels/source/include/internal/rng.hpp"
-  "gptoss_kernels/source/include/internal/log.h"
-  "gptoss_kernels/source/include/internal/macros.h"
-  "gptoss_kernels/source/include/internal/storage.h"
-  "gptoss_kernels/source/include/internal/model.h"
-  "gptoss_kernels/source/include/internal/math.h"
-  "gptoss_kernels/source/include/internal/metal.h"
-  "gptoss_kernels/source/include/internal/kernel-args.h"
-  "gptoss_kernels/source/include/internal/datatype.hpp"
-  "gptoss_kernels/source/include/internal/metal-kernels.h"
-)
-
-# Separate Metal shader files from other sources
-set(gptoss_kernels_METAL_SRC)
-set(gptoss_kernels_CPP_SRC)
-
-foreach(src_file IN LISTS gptoss_kernels_SRC)
-  if(src_file MATCHES "\\.(metal|h)$")
-    list(APPEND gptoss_kernels_METAL_SRC ${src_file})
-  else()
-    list(APPEND gptoss_kernels_CPP_SRC ${src_file})
-  endif()
-endforeach()
-
-# TODO: check if CLion support this:
-# https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
-set_source_files_properties(
-  ${gptoss_kernels_CPP_SRC}
-  PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/gptoss_kernels/source/include;${CMAKE_SOURCE_DIR}/gptoss_kernels/include;${CMAKE_SOURCE_DIR}/.")
-
-
-# Add C++ sources to main source list
-list(APPEND SRC "${gptoss_kernels_CPP_SRC}")
-
-# Keep track of Metal sources for later compilation
-if(gptoss_kernels_METAL_SRC)
-  list(APPEND ALL_METAL_SOURCES "${gptoss_kernels_METAL_SRC}")
-endif()
-
-# Keep the includes directory for the Metal sources
-if(gptoss_kernels_METAL_SRC)
-  list(APPEND METAL_INCLUDE_DIRS ${CMAKE_SOURCE_DIR}/gptoss_kernels/source/include;${CMAKE_SOURCE_DIR}/gptoss_kernels/include;${CMAKE_SOURCE_DIR}/.)
-endif()
-
-# Include Metal shader compilation utilities
-include(${CMAKE_CURRENT_LIST_DIR}/cmake/compile-metal.cmake)
-
-define_gpu_extension_target(
-  _gptoss_kernels_3a886f8_dirty
-  DESTINATION _gptoss_kernels_3a886f8_dirty
-  LANGUAGE ${GPU_LANG}
-  SOURCES ${SRC}
-  COMPILE_FLAGS ${GPU_FLAGS}
-  ARCHITECTURES ${GPU_ARCHES}
-  USE_SABI 3
-  WITH_SOABI)
-
-# Compile Metal shaders if any were found
-if(ALL_METAL_SOURCES)
-  compile_metal_shaders(_gptoss_kernels_3a886f8_dirty "${ALL_METAL_SOURCES}" "${METAL_INCLUDE_DIRS}")
-endif()
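
The deleted CMakeLists.txt splits gptoss_kernels_SRC into Metal shader sources (.metal and .h) and everything else before handing them to define_gpu_extension_target and compile_metal_shaders. A minimal Python sketch of that same classification rule, for reference only; the file list below is a shortened, hypothetical subset of the real source list:

import re

# Shortened, hypothetical subset of the gptoss_kernels_SRC list above.
sources = [
    "gptoss_kernels/source/matmul.metal",
    "gptoss_kernels/source/metal.mm",
    "gptoss_kernels/source/include/internal/kernel-args.h",
    "gptoss_kernels/source/log.c",
]

# Same rule as the deleted foreach(): .metal and .h files feed the metallib,
# everything else is compiled into the Python extension itself.
metal_src = [s for s in sources if re.search(r"\.(metal|h)$", s)]
cpp_src = [s for s in sources if not re.search(r"\.(metal|h)$", s)]

print(metal_src)  # matmul.metal, kernel-args.h
print(cpp_src)    # metal.mm, log.c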

build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc and b/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc differ

build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc and b/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc differ

build/torch28-metal-aarch64-darwin/gptoss_kernels/{_gptoss_kernels_9964bae_dirty.abi3.so → _gptoss_kernels_9ffd725_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:8e9880a41249b5e1859134140111e0515925129b0c70d4587a70324b489bb9e3
 size 391752

build/torch28-metal-aarch64-darwin/gptoss_kernels/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _gptoss_kernels_9ffd725_dirty
+ops = torch.ops._gptoss_kernels_9ffd725_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_gptoss_kernels_9ffd725_dirty::{op_name}"
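
The only change in _ops.py is the hashed module name baked in at build time. A hedged sketch of how this module is typically consumed, assuming the built gptoss_kernels package is importable; the op name f32_bf16w_matmul is taken from the deleted test script and is not verified here:

import torch
from gptoss_kernels._ops import ops, add_op_namespace_prefix

# The native library registers its ops under the hashed namespace, so they are
# reachable either through the `ops` alias or via a fully qualified schema name.
qualified_name = add_op_namespace_prefix("f32_bf16w_matmul")
# -> "_gptoss_kernels_9ffd725_dirty::f32_bf16w_matmul"

# Equivalent lookups, assuming the op is actually registered under that name:
# op = ops.f32_bf16w_matmul
# op = torch.ops._gptoss_kernels_9ffd725_dirty.f32_bf16w_matmul
print(qualified_name)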

build/torch28-metal-aarch64-darwin/gptoss_kernels/test.py DELETED
@@ -1,15 +0,0 @@
-import _gptoss_kernels_931bc1b_dirty
-import torch
-
-print(dir(_gptoss_kernels_931bc1b_dirty))
-
-from gptoss_kernels import _gptoss_kernels_931bc1b_dirty
-
-print(dir(f32_bf16w_matmul))
-
-input = torch.randn(10, 10)
-weight_bf16 = torch.randn(10, 10)
-bias_bf16 = torch.randn(10)
-output = torch.randn(10, 10)
-f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, 10, 10, 10, 10)
-print(output)

build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc and b/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc differ

build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc and b/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc differ

build/torch29-metal-aarch64-darwin/gptoss_kernels/{_gptoss_kernels_9964bae_dirty.abi3.so → _gptoss_kernels_9ffd725_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:fd8bc3d9d6d40953740d4566aeb3543b1025a647f8df05a9d296299bafbd8c31
 size 392840

build/torch29-metal-aarch64-darwin/gptoss_kernels/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _gptoss_kernels_9ffd725_dirty
+ops = torch.ops._gptoss_kernels_9ffd725_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_gptoss_kernels_9ffd725_dirty::{op_name}"

build/torch29-metal-aarch64-darwin/gptoss_kernels/test.py DELETED
@@ -1,15 +0,0 @@
-import _gptoss_kernels_931bc1b_dirty
-import torch
-
-print(dir(_gptoss_kernels_931bc1b_dirty))
-
-from gptoss_kernels import _gptoss_kernels_931bc1b_dirty
-
-print(dir(f32_bf16w_matmul))
-
-input = torch.randn(10, 10)
-weight_bf16 = torch.randn(10, 10)
-bias_bf16 = torch.randn(10)
-output = torch.randn(10, 10)
-f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, 10, 10, 10, 10)
-print(output)

cmake/compile-metal.cmake DELETED
@@ -1,90 +0,0 @@
-# Metal shader compilation function
-function(compile_metal_shaders TARGET_NAME METAL_SOURCES EXTRA_INCLUDE_DIRS)
-  # Find the Metal compiler
-  find_program(METAL_COMPILER xcrun REQUIRED)
-
-  # Set Metal compiler flags
-  set(METAL_FLAGS "-std=metal3.2" "-O2")
-
-  # Output directory for compiled metallib
-  set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib")
-  file(MAKE_DIRECTORY ${METALLIB_OUTPUT_DIR})
-
-  foreach(INC ${EXTRA_INCLUDE_DIRS})
-    list(APPEND METAL_FLAGS "-I${INC}")
-  endforeach()
-
-  # Separate .metal files from .h files and compile .metal files to .air
-  set(AIR_FILES)
-  set(METAL_FILES)
-  set(HEADER_FILES)
-
-  foreach(SOURCE_FILE ${METAL_SOURCES})
-    if(SOURCE_FILE MATCHES "\\.metal$")
-      list(APPEND METAL_FILES ${SOURCE_FILE})
-    elseif(SOURCE_FILE MATCHES "\\.h$")
-      list(APPEND HEADER_FILES ${SOURCE_FILE})
-    endif()
-  endforeach()
-
-  foreach(METAL_FILE ${METAL_FILES})
-    get_filename_component(METAL_NAME ${METAL_FILE} NAME_WE)
-    set(AIR_FILE "${CMAKE_BINARY_DIR}/${METAL_NAME}.air")
-
-    # Include header files as dependencies
-    set(ALL_DEPENDENCIES ${CMAKE_CURRENT_SOURCE_DIR}/${METAL_FILE})
-    foreach(HEADER_FILE ${HEADER_FILES})
-      list(APPEND ALL_DEPENDENCIES ${CMAKE_CURRENT_SOURCE_DIR}/${HEADER_FILE})
-    endforeach()
-
-    add_custom_command(
-      OUTPUT ${AIR_FILE}
-      COMMAND ${METAL_COMPILER} -sdk macosx metal ${METAL_FLAGS}
-              -c ${CMAKE_CURRENT_SOURCE_DIR}/${METAL_FILE}
-              -o ${AIR_FILE}
-      DEPENDS ${ALL_DEPENDENCIES}
-      COMMENT "Compiling Metal shader ${METAL_FILE} to ${AIR_FILE}"
-      VERBATIM
-    )
-
-    list(APPEND AIR_FILES ${AIR_FILE})
-  endforeach()
-
-  # Link all .air files into a single .metallib
-  set(METALLIB_FILE "${METALLIB_OUTPUT_DIR}/${TARGET_NAME}.metallib")
-  add_custom_command(
-    OUTPUT ${METALLIB_FILE}
-    COMMAND ${METAL_COMPILER} -sdk macosx metallib ${AIR_FILES}
-            -o ${METALLIB_FILE}
-    DEPENDS ${AIR_FILES}
-    COMMENT "Linking Metal library ${METALLIB_FILE}"
-    VERBATIM
-  )
-
-  # Generate C++ header with embedded metallib data
-  set(METALLIB_HEADER "${CMAKE_BINARY_DIR}/${TARGET_NAME}_metallib.h")
-  set(METALLIB_TO_HEADER_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/cmake/metallib_to_header.py")
-
-  add_custom_command(
-    OUTPUT ${METALLIB_HEADER}
-    COMMAND ${Python_EXECUTABLE} ${METALLIB_TO_HEADER_SCRIPT} ${METALLIB_FILE} ${METALLIB_HEADER} ${TARGET_NAME}
-    DEPENDS ${METALLIB_FILE} ${METALLIB_TO_HEADER_SCRIPT}
-    COMMENT "Generating embedded Metal library header ${METALLIB_HEADER}"
-    VERBATIM
-  )
-
-  # Create a custom target for the metallib
-  add_custom_target(${TARGET_NAME}_metallib ALL DEPENDS ${METALLIB_FILE} ${METALLIB_HEADER})
-
-  # Add dependency to main target
-  add_dependencies(${TARGET_NAME} ${TARGET_NAME}_metallib)
-
-  # Add the generated header to include directories
-  target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_BINARY_DIR})
-
-  # Pass the metallib header and namespace as compile definitions
-  target_compile_definitions(${TARGET_NAME} PRIVATE
-    EMBEDDED_METALLIB_HEADER="${TARGET_NAME}_metallib.h"
-    EMBEDDED_METALLIB_NAMESPACE=${TARGET_NAME}_metal
-  )
-endfunction()
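
The deleted function drives a three-step pipeline: compile each .metal file to .air with xcrun's metal tool, link the .air files into a single .metallib with xcrun's metallib tool, then embed the result into a C++ header via cmake/metallib_to_header.py. A rough standalone sketch of the same pipeline in Python, reusing the exact compiler invocations from the deleted file; the shader list and paths are placeholders:

import subprocess
from pathlib import Path

# Placeholder inputs; in the deleted CMake these come from the target's sources.
metal_files = [Path("gptoss_kernels/source/matmul.metal")]
include_dirs = [Path("gptoss_kernels/source/include")]
build_dir = Path("build-metal")
build_dir.mkdir(exist_ok=True)

flags = ["-std=metal3.2", "-O2"] + [f"-I{d}" for d in include_dirs]

# 1. Compile each shader to an .air object.
air_files = []
for src in metal_files:
    air = build_dir / (src.stem + ".air")
    subprocess.run(["xcrun", "-sdk", "macosx", "metal", *flags,
                    "-c", str(src), "-o", str(air)], check=True)
    air_files.append(air)

# 2. Link all .air files into one .metallib.
metallib = build_dir / "gptoss_kernels.metallib"
subprocess.run(["xcrun", "-sdk", "macosx", "metallib",
                *map(str, air_files), "-o", str(metallib)], check=True)

# 3. Embed the metallib into a C++ header (the script is deleted in this commit,
#    so this step assumes a retained copy of cmake/metallib_to_header.py).
subprocess.run(["python", "cmake/metallib_to_header.py",
                str(metallib), str(build_dir / "gptoss_metallib.h"),
                "gptoss_kernels"], check=True)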

cmake/metallib_to_header.py DELETED
@@ -1,73 +0,0 @@
-#!/usr/bin/env python3
-import sys
-import os
-
-def convert_metallib_to_header(metallib_path: str, header_path: str, target_name: str) -> None:
-    """Convert a metallib binary file to a C++ header with embedded data."""
-
-    # Read the metallib binary data
-    with open(metallib_path, 'rb') as f:
-        data: bytes = f.read()
-
-    # Generate the header content
-    header_content: str = """// Auto-generated file containing embedded Metal library
-#pragma once
-#include <cstddef>
-#include <Metal/Metal.h>
-
-namespace """ + target_name + """_metal {
-static const unsigned char metallib_data[] = {
-"""
-
-    # Convert binary data to C array format
-    bytes_per_line: int = 16
-    for i in range(0, len(data), bytes_per_line):
-        chunk: bytes = data[i:i + bytes_per_line]
-        hex_values: str = ', '.join('0x{:02x}'.format(b) for b in chunk)
-        header_content += " " + hex_values + ","
-        if i + bytes_per_line < len(data):
-            header_content += "\n"
-
-    header_content += """
-};
-static const size_t metallib_data_len = """ + str(len(data)) + """;
-
-// Convenience function to create Metal library from embedded data
-inline id<MTLLibrary> createLibrary(id<MTLDevice> device, NSError** error = nullptr) {
-    dispatch_data_t libraryData = dispatch_data_create(
-        metallib_data,
-        metallib_data_len,
-        dispatch_get_main_queue(),
-        ^{ /* No cleanup needed for static data */ });
-
-    NSError* localError = nil;
-    id<MTLLibrary> library = [device newLibraryWithData:libraryData error:&localError];
-
-    if (error) {
-        *error = localError;
-    }
-
-    return library;
-}
-} // namespace """ + target_name + """_metal
-"""
-
-    # Write the header file
-    dir_path: str = os.path.dirname(header_path)
-    if dir_path:
-        os.makedirs(dir_path, exist_ok=True)
-    with open(header_path, 'w') as f:
-        f.write(header_content)
-
-    print("Generated {} ({} bytes)".format(header_path, len(data)))
-
-if __name__ == "__main__":
-    if len(sys.argv) != 4:
-        print("Usage: metallib_to_header.py <metallib_path> <header_path> <target_name>")
-        sys.exit(1)
-
-    metallib_path: str = sys.argv[1]
-    header_path: str = sys.argv[2]
-    target_name: str = sys.argv[3]
-
-    convert_metallib_to_header(metallib_path, header_path, target_name)
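
For reference, the deleted script can also be driven directly from Python rather than through CMake. A small hedged sketch on a dummy binary, assuming a copy of the deleted metallib_to_header.py is kept importable next to the snippet; the byte content is arbitrary and only illustrates the shape of the generated header:

from pathlib import Path

# Requires a retained copy of the deleted cmake/metallib_to_header.py on sys.path.
from metallib_to_header import convert_metallib_to_header

dummy = Path("dummy.metallib")
dummy.write_bytes(bytes(range(32)))  # arbitrary stand-in for a real metallib

# Produces dummy_metallib.h with `namespace demo_metal { ... metallib_data[] ... }`
# plus the createLibrary() helper shown in the deleted script above.
convert_metallib_to_header(str(dummy), "dummy_metallib.h", "demo")
print(Path("dummy_metallib.h").read_text().splitlines()[0])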

cmake/utils.cmake DELETED
@@ -1,557 +0,0 @@
-# Vendored from vLLM:
-#
-# https://github.com/vllm-project/vllm/blob/main/cmake/utils.cmake
-#
-# Attempt to find the python package that uses the same python executable as
-# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
-#
-macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
-  file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
-  set(Python3_EXECUTABLE ${EXECUTABLE})
-  find_package(Python3 COMPONENTS Interpreter Development.Module Development.SABIModule)
-  if (NOT Python3_FOUND)
-    message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
-  endif()
-  set(_VER "${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}")
-  set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
-  if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
-    message(FATAL_ERROR
-      "Python version (${_VER}) is not one of the supported versions: "
-      "${_SUPPORTED_VERSIONS_LIST}.")
-  endif()
-  message(STATUS "Found python matching: ${EXECUTABLE}.")
-endmacro()
-
-#
-# Run `EXPR` in python. The standard output of python is stored in `OUT` and
-# has trailing whitespace stripped. If an error is encountered when running
-# python, a fatal message `ERR_MSG` is issued.
-#
-function (run_python OUT EXPR ERR_MSG)
-  execute_process(
-    COMMAND
-    "${Python3_EXECUTABLE}" "-c" "${EXPR}"
-    OUTPUT_VARIABLE PYTHON_OUT
-    RESULT_VARIABLE PYTHON_ERROR_CODE
-    ERROR_VARIABLE PYTHON_STDERR
-    OUTPUT_STRIP_TRAILING_WHITESPACE)
-
-  if(NOT PYTHON_ERROR_CODE EQUAL 0)
-    message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
-  endif()
-  set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
-endfunction()
-
-# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
-# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
-macro (append_cmake_prefix_path PKG EXPR)
-  run_python(_PREFIX_PATH
-    "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
-  list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
-endmacro()
-
-#
-# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
-# of CUDA source files. The names of the corresponding "hipified" sources are
-# stored in `OUT_SRCS`.
-#
-function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
-  #
-  # Split into C++ and non-C++ (i.e. CUDA) sources.
-  #
-  set(NODUP_SRCS ${ORIG_SRCS})
-  list(REMOVE_DUPLICATES NODUP_SRCS)
-  set(SRCS ${NODUP_SRCS})
-  set(CXX_SRCS ${NODUP_SRCS})
-  list(FILTER SRCS INCLUDE REGEX "\.cu$")
-  list(FILTER CXX_SRCS EXCLUDE REGEX "\.cu$")
-
-  #
-  # Generate ROCm/HIP source file names from CUDA file names.
-  # Since HIP files are generated code, they will appear in the build area
-  # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
-  #
-  set(HIP_SRCS)
-  foreach (SRC ${SRCS})
-    get_source_file_property(include_dirs "${SRC}" INCLUDE_DIRECTORIES)
-    get_source_file_property(compile_options "${SRC}" COMPILE_OPTIONS)
-    string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
-    string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
-
-    if(include_dirs)
-      # Copy over include directories from the original CUDA file.
-      set_source_files_properties(
-        ${SRC}
-        PROPERTIES INCLUDE_DIRECTORIES "${include_dirs}")
-    endif()
-
-    if(compile_options)
-      set_source_files_properties(
-        ${SRC}
-        PROPERTIES COMPILE_OPTIONS "${compile_options}")
-    endif()
-
-    list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
-  endforeach()
-
-  add_custom_target(
-    hipify${NAME}
-    COMMAND "${Python3_EXECUTABLE}" ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} -o ${CMAKE_CURRENT_BINARY_DIR} ${SRCS}
-    DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
-    BYPRODUCTS ${HIP_SRCS}
-    COMMENT "Running hipify on ${NAME} extension source files.")
-
-  # Swap out original extension sources with hipified sources.
-  list(APPEND HIP_SRCS ${CXX_SRCS})
-  set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
-endfunction()
-
-#
-# Get additional GPU compiler flags from torch.
-#
-function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
-  if (${GPU_LANG} STREQUAL "CUDA")
-    #
-    # Get common NVCC flags from torch.
-    #
-    run_python(GPU_FLAGS
-      "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
-      "Failed to determine torch nvcc compiler flags")
-
-    if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
-      list(APPEND GPU_FLAGS "-DENABLE_FP8")
-      list(REMOVE_ITEM GPU_FLAGS
-        "-D__CUDA_NO_HALF_OPERATORS__"
-        "-D__CUDA_NO_HALF_CONVERSIONS__"
-        "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
-        "-D__CUDA_NO_HALF2_OPERATORS__")
-    endif()
-
-  elseif(${GPU_LANG} STREQUAL "HIP")
-    #
-    # Get common HIP/HIPCC flags from torch.
-    #
-    run_python(GPU_FLAGS
-      "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
-      "Failed to determine torch nvcc compiler flags")
-
-    list(APPEND GPU_FLAGS
-      "-DUSE_ROCM"
-      "-DENABLE_FP8"
-      "-U__HIP_NO_HALF_CONVERSIONS__"
-      "-U__HIP_NO_HALF_OPERATORS__"
-      "-fno-gpu-rdc")
-
-  endif()
-  set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
-endfunction()
-
-# Macro for converting a `gencode` version number to a cmake version number.
-macro(string_to_ver OUT_VER IN_STR)
-  string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
-endmacro()
-
-#
-# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
-# `CUDA_ARCH_FLAGS`.
-#
-# Example:
-#   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
-#   clear_cuda_arches(CUDA_ARCH_FLAGS)
-#   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
-#   CMAKE_CUDA_FLAGS="-Wall"
-#
-macro(clear_cuda_arches CUDA_ARCH_FLAGS)
-  # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
-  string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
-    ${CMAKE_CUDA_FLAGS})
-
-  # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
-  # and passed back via the `CUDA_ARCHITECTURES` property.
-  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
-    ${CMAKE_CUDA_FLAGS})
-endmacro()
-
-#
-# Extract unique CUDA architectures from a list of compute capabilities codes in
-# the form `<major><minor>[<letter>]`, convert them to the form sort
-# `<major>.<minor>`, dedupes them and then sorts them in ascending order and
-# stores them in `OUT_ARCHES`.
-#
-# Example:
-#   CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
-#   extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
-#   OUT_ARCHES="7.5;...;9.0"
-function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
-  set(_CUDA_ARCHES)
-  foreach(_ARCH ${CUDA_ARCH_FLAGS})
-    string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
-    if (_COMPUTE)
-      set(_COMPUTE ${CMAKE_MATCH_1})
-    endif()
-
-    string_to_ver(_COMPUTE_VER ${_COMPUTE})
-    list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
-  endforeach()
-
-  list(REMOVE_DUPLICATES _CUDA_ARCHES)
-  list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
-  set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
-endfunction()
-
-#
-# For a specific file set the `-gencode` flag in compile options conditionally
-# for the CUDA language.
-#
-# Example:
-#   set_gencode_flag_for_srcs(
-#     SRCS "foo.cu"
-#     ARCH "compute_75"
-#     CODE "sm_75")
-#   adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
-#   `foo.cu` (only for the CUDA language).
-#
-macro(set_gencode_flag_for_srcs)
-  set(options)
-  set(oneValueArgs ARCH CODE)
-  set(multiValueArgs SRCS)
-  cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
-    "${multiValueArgs}" ${ARGN} )
-  set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
-  set_property(
-    SOURCE ${arg_SRCS}
-    APPEND PROPERTY
-    COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
-  )
-
-  message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
-endmacro(set_gencode_flag_for_srcs)
-
-#
-# For a list of source files set the `-gencode` flags in the files specific
-# compile options (specifically for the CUDA language).
-#
-# arguments are:
-#   SRCS: list of source files
-#   CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
-#   BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
-#     for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
-#     that is larger than BUILD_PTX_FOR_ARCH.
-#
-macro(set_gencode_flags_for_srcs)
-  set(options)
-  set(oneValueArgs BUILD_PTX_FOR_ARCH)
-  set(multiValueArgs SRCS CUDA_ARCHS)
-  cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
-    "${multiValueArgs}" ${ARGN} )
-
-  foreach(_ARCH ${arg_CUDA_ARCHS})
-    # handle +PTX suffix: generate both sm and ptx codes if requested
-    string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
-    if(NOT _HAS_PTX EQUAL -1)
-      string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
-      string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
-      set_gencode_flag_for_srcs(
-        SRCS ${arg_SRCS}
-        ARCH "compute_${_STRIPPED_ARCH}"
-        CODE "sm_${_STRIPPED_ARCH}")
-      set_gencode_flag_for_srcs(
-        SRCS ${arg_SRCS}
-        ARCH "compute_${_STRIPPED_ARCH}"
-        CODE "compute_${_STRIPPED_ARCH}")
-    else()
-      string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
-      set_gencode_flag_for_srcs(
-        SRCS ${arg_SRCS}
-        ARCH "compute_${_STRIPPED_ARCH}"
-        CODE "sm_${_STRIPPED_ARCH}")
-    endif()
-  endforeach()
-
-  if (${arg_BUILD_PTX_FOR_ARCH})
-    list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
-    list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
-    if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
-      string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
-      set_gencode_flag_for_srcs(
-        SRCS ${arg_SRCS}
-        ARCH "compute_${_PTX_ARCH}"
-        CODE "compute_${_PTX_ARCH}")
-    endif()
-  endif()
-endmacro()
-
-#
-# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
-# `<major>.<minor>[letter]` compute the "loose intersection" with the
-# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
-# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
-# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
-# architecture in `SRC_CUDA_ARCHS`.
-# The loose intersection is defined as:
-#   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
-# where `<=` is the version comparison operator.
-# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
-# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
-# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
-# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
-# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
-# The result is stored in `OUT_CUDA_ARCHS`.
-#
-# Example:
-#   SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
-#   TGT_CUDA_ARCHS="8.0;8.9;9.0"
-#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
-#   OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
-#
-# Example With PTX:
-#   SRC_CUDA_ARCHS="8.0+PTX"
-#   TGT_CUDA_ARCHS="9.0"
-#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
-#   OUT_CUDA_ARCHS="8.0+PTX"
-#
-function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
-  set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
-  set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
-
-  # handle +PTX suffix: separate base arch for matching, record PTX requests
-  set(_PTX_ARCHS)
-  foreach(_arch ${_SRC_CUDA_ARCHS})
-    if(_arch MATCHES "\\+PTX$")
-      string(REPLACE "+PTX" "" _base "${_arch}")
-      list(APPEND _PTX_ARCHS "${_base}")
-      list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
-      list(APPEND _SRC_CUDA_ARCHS "${_base}")
-    endif()
-  endforeach()
-  list(REMOVE_DUPLICATES _PTX_ARCHS)
-  list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
-
-  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
-  # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
-  set(_CUDA_ARCHS)
-  foreach(_arch ${_SRC_CUDA_ARCHS})
-    if(_arch MATCHES "\\a$")
-      list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
-      string(REPLACE "a" "" _base "${_arch}")
-      if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
-        list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
-        list(APPEND _CUDA_ARCHS "${_arch}")
-      endif()
-    endif()
-  endforeach()
-
-  list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
-
-  # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
-  # is less or equal to ARCH (but has the same major version since SASS binary
-  # compatibility is only forward compatible within the same major version).
-  foreach(_ARCH ${_TGT_CUDA_ARCHS})
-    set(_TMP_ARCH)
-    # Extract the major version of the target arch
-    string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
-    foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
-      # Extract the major version of the source arch
-      string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
-      # Check version-less-or-equal, and allow PTX arches to match across majors
-      if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
-        if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
-          set(_TMP_ARCH "${_SRC_ARCH}")
-        endif()
-      else()
-        # If we hit a version greater than the target, we can break
-        break()
-      endif()
-    endforeach()
-
-    # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
-    if (_TMP_ARCH)
-      list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
-    endif()
-  endforeach()
-
-  list(REMOVE_DUPLICATES _CUDA_ARCHS)
-
-  # reapply +PTX suffix to architectures that requested PTX
-  set(_FINAL_ARCHS)
-  foreach(_arch ${_CUDA_ARCHS})
-    if(_arch IN_LIST _PTX_ARCHS)
-      list(APPEND _FINAL_ARCHS "${_arch}+PTX")
-    else()
-      list(APPEND _FINAL_ARCHS "${_arch}")
-    endif()
-  endforeach()
-  set(_CUDA_ARCHS ${_FINAL_ARCHS})
-
-  set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
-endfunction()
-
-#
-# For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
-# `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
-# The loose intersection is defined as:
-#   { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
-# where `<=` is the version comparison operator.
-# In other words, for each version in `TGT_ROCM_ARCHS` find the highest version
-# in `SRC_ROCM_ARCHS` that is less or equal to the version in `TGT_ROCM_ARCHS`.
-# The result is stored in `OUT_ROCM_ARCHS`.
-#
-# Example:
-#   SRC_ROCM_ARCHS="gfx900;gfx906;gfx908;gfx90a"
-#   TGT_ROCM_ARCHS="gfx906;gfx908;gfx1030"
-#   hip_archs_loose_intersection(OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
-#   OUT_ROCM_ARCHS="gfx906;gfx908"
-#
-function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
-  list(REMOVE_DUPLICATES SRC_ROCM_ARCHS)
-
-  # ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit
-  # and x is a letter. We can sort them by string comparison which works for this format.
-  list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING)
-
-  set(_ROCM_ARCHS)
-
-  # Find the intersection of supported architectures
-  foreach(_SRC_ARCH ${SRC_ROCM_ARCHS})
-    if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS)
-      list(APPEND _ROCM_ARCHS ${_SRC_ARCH})
-    endif()
-  endforeach()
-
-  list(REMOVE_DUPLICATES _ROCM_ARCHS)
-  set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE)
-endfunction()
-
-#
-# Override the GPU architectures detected by cmake/torch and filter them by
-# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
-# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
-# the architectures on a per file basis.
-#
-# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
-#
-macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
-  set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
-  message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
-
-  if (${GPU_LANG} STREQUAL "HIP")
-    #
-    # `GPU_ARCHES` controls the `--offload-arch` flags.
-    #
-    # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
-    # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
-    # "rocm_agent_enumerator" in "enable_language(HIP)"
-    # (in file Modules/CMakeDetermineHIPCompiler.cmake)
-    #
-    if(DEFINED ENV{PYTORCH_ROCM_ARCH})
-      set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
-    else()
-      set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
-    endif()
-    #
-    # Find the intersection of the supported + detected architectures to
-    # set the module architecture flags.
-    #
-    set(${GPU_ARCHES})
-    foreach (_ARCH ${HIP_ARCHITECTURES})
-      if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
-        list(APPEND ${GPU_ARCHES} ${_ARCH})
-      endif()
-    endforeach()
-
-    if(NOT ${GPU_ARCHES})
-      message(FATAL_ERROR
-        "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
-        " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
-    endif()
-  endif()
-endmacro()
-
-#
-# Define a target named `GPU_MOD_NAME` for a single extension. The
-# arguments are:
-#
-# DESTINATION <dest>         - Module destination directory.
-# LANGUAGE <lang>            - The GPU language for this module, e.g CUDA, HIP,
-#                              etc.
-# SOURCES <sources>          - List of source files relative to CMakeLists.txt
-#                              directory.
-#
-# Optional arguments:
-#
-# ARCHITECTURES <arches>     - A list of target GPU architectures in cmake
-#                              format.
-#                              Refer `CMAKE_CUDA_ARCHITECTURES` documentation
-#                              and `CMAKE_HIP_ARCHITECTURES` for more info.
-#                              ARCHITECTURES will use cmake's defaults if
-#                              not provided.
-# COMPILE_FLAGS <flags>      - Extra compiler flags passed to NVCC/hip.
-# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
-# LIBRARIES <libraries>      - Extra link libraries.
-# WITH_SOABI                 - Generate library with python SOABI suffix name.
-# USE_SABI <version>         - Use python stable api <version>
-#
-# Note: optimization level/debug info is set via cmake build type.
-#
-function (define_gpu_extension_target GPU_MOD_NAME)
-  cmake_parse_arguments(PARSE_ARGV 1
-    GPU
-    "WITH_SOABI"
-    "DESTINATION;LANGUAGE;USE_SABI"
-    "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
-
-  # Add hipify preprocessing step when building with HIP/ROCm.
-  if (GPU_LANGUAGE STREQUAL "HIP")
-    hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
-  endif()
-
-  if (GPU_WITH_SOABI)
-    set(GPU_WITH_SOABI WITH_SOABI)
-  else()
-    set(GPU_WITH_SOABI)
-  endif()
-
-  if (GPU_USE_SABI)
-    Python3_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
-  else()
-    Python3_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
-  endif()
-
-  if (GPU_LANGUAGE STREQUAL "HIP")
-    # Make this target dependent on the hipify preprocessor step.
-    add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
-  endif()
-
-  if (GPU_ARCHITECTURES)
-    if (GPU_LANGUAGE STREQUAL "HIP")
-      # Clear target architectures, we are passing arch flags per source file.
-      set_property(TARGET ${GPU_MOD_NAME} PROPERTY HIP_ARCHITECTURES off)
-    else()
-      set_target_properties(${GPU_MOD_NAME} PROPERTIES
-        ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
-    endif()
-  endif()
-
-  set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
-
-  target_compile_options(${GPU_MOD_NAME} PRIVATE
-    $<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
-
-  target_compile_definitions(${GPU_MOD_NAME} PRIVATE
-    "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
-
-  target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
-    ${GPU_INCLUDE_DIRECTORIES})
-
-  target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
-
-  # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
-  # dependencies that are not necessary and may not be installed.
-  if (GPU_LANGUAGE STREQUAL "CUDA")
-    target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart)
-  else()
-    target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
-  endif()
-
-  install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
-endfunction()
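
Most of this vendored file is CUDA/HIP machinery that the Metal build never exercised. The least obvious piece is the "loose intersection" used to pick build architectures; ignoring the +PTX and x.0a special cases, a small Python sketch of the rule described in the comments (for each target version, take the highest source version that does not exceed it, within the same major version). This models the documented behaviour only, not the exact CMake implementation:

def loose_intersection(src_archs, tgt_archs):
    """Simplified model of cuda_archs_loose_intersection (no +PTX / x.0a handling)."""
    def ver(a):
        major, minor = a.split(".")
        return int(major), int(minor)

    picked = []
    for tgt in tgt_archs:
        # candidates: source arches not newer than the target, same major version
        candidates = [s for s in src_archs
                      if ver(s) <= ver(tgt) and ver(s)[0] == ver(tgt)[0]]
        if candidates:
            picked.append(max(candidates, key=ver))
    # dedupe while keeping ascending order
    return sorted(set(picked), key=ver)

# e.g. kernels compiled for 7.5/8.0/8.6/9.0 resolved against 8.0, 8.9 and 9.0 targets:
print(loose_intersection(["7.5", "8.0", "8.6", "9.0"], ["8.0", "8.9", "9.0"]))
# -> ['8.0', '8.6', '9.0']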

pyproject.toml DELETED
@@ -1,10 +0,0 @@
-[build-system]
-requires = [
-    "cmake>=3.26",
-    "ninja",
-    "packaging",
-    "setuptools>=61",
-    "torch",
-    "wheel",
-]
-build-backend = "setuptools.build_meta"

setup.py DELETED
@@ -1,118 +0,0 @@
-import logging
-import os
-from shutil import which, move
-import subprocess
-import sys
-from pathlib import Path
-
-from setuptools import Extension, find_packages, setup
-from setuptools.command.build_ext import build_ext
-
-logger = logging.getLogger(__name__)
-
-
-def is_sccache_available() -> bool:
-    return which("sccache") is not None
-
-
-def is_ccache_available() -> bool:
-    return which("ccache") is not None
-
-
-def is_ninja_available() -> bool:
-    return which("ninja") is not None
-
-
-class CMakeExtension(Extension):
-    def __init__(self, name: str, sourcedir: str = "") -> None:
-        super().__init__(name, sources=[], py_limited_api=True)
-        self.sourcedir = os.fspath(Path(sourcedir).resolve())
-
-
-class CMakeBuild(build_ext):
-    def build_extension(self, ext: CMakeExtension) -> None:
-        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
-        extdir = ext_fullpath.parent.resolve()
-
-        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
-        cfg = "Debug" if debug else "Release"
-
-        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
-
-        # Set Python3_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
-        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
-        # from Python.
-        cmake_args = [
-            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
-            f"-DPython3_EXECUTABLE={sys.executable}",
-            f"-DCMAKE_BUILD_TYPE={cfg}",  # not used on MSVC, but no harm
-        ]
-        build_args = []
-        if "CMAKE_ARGS" in os.environ:
-            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
-
-        if not cmake_generator or cmake_generator == "Ninja":
-            try:
-                import ninja
-
-                ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
-                cmake_args += [
-                    "-GNinja",
-                    f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
-                ]
-            except ImportError:
-                pass
-
-        if is_sccache_available():
-            cmake_args += [
-                "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
-                "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
-                "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
-                "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache",
-                "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache",
-            ]
-        elif is_ccache_available():
-            cmake_args += [
-                "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
-                "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
-                "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
-                "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache",
-                "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache",
-            ]
-
-        num_jobs = os.getenv("MAX_JOBS", None)
-        if num_jobs is not None:
-            num_jobs = int(num_jobs)
-            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
-        else:
-            try:
-                # os.sched_getaffinity() isn't universally available, so fall
-                # back to os.cpu_count() if we get an error here.
-                num_jobs = len(os.sched_getaffinity(0))
-            except AttributeError:
-                num_jobs = os.cpu_count()
-
-        build_temp = Path(self.build_temp) / ext.name
-        if not build_temp.exists():
-            build_temp.mkdir(parents=True)
-
-        subprocess.run(
-            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
-        )
-        subprocess.run(
-            ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
-        )
-
-
-setup(
-    name="gptoss_kernels",
-    # The version is just a stub, it's not used by the final build artefact.
-    version="0.1.0",
-    ext_modules=[CMakeExtension("gptoss_kernels._gptoss_kernels_3a886f8_dirty")],
-    cmdclass={"build_ext": CMakeBuild},
-    packages=find_packages(where="torch-ext", include=["gptoss_kernels*"]),
-    package_dir={"": "torch-ext"},
-    zip_safe=False,
-    install_requires=["torch"],
-    python_requires=">=3.9",
-)
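
A condensed sketch of the two environment-driven knobs the deleted setup.py honoured, the sccache/ccache compiler launcher and the MAX_JOBS parallelism fallback, kept only as a reminder of how its CMake configure step was parameterised; the flag names are copied from the deleted file:

import os
from shutil import which

def compiler_launcher_args():
    # Prefer sccache, fall back to ccache, else pass no launcher flags at all.
    launcher = "sccache" if which("sccache") else "ccache" if which("ccache") else None
    if launcher is None:
        return []
    return [f"-DCMAKE_{lang}_COMPILER_LAUNCHER={launcher}"
            for lang in ("C", "CXX", "HIP", "OBJC", "OBJCXX")]

def parallel_jobs():
    # MAX_JOBS wins; otherwise use the CPU affinity mask where available.
    if os.getenv("MAX_JOBS"):
        return int(os.environ["MAX_JOBS"])
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:  # e.g. macOS has no sched_getaffinity
        return os.cpu_count()

print(compiler_launcher_args(), parallel_jobs())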

torch-ext/gptoss_kernels/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (868 Bytes)

torch-ext/gptoss_kernels/__pycache__/_ops.cpython-313.pyc DELETED
Binary file (552 Bytes)

torch-ext/gptoss_kernels/test.py DELETED
@@ -1,15 +0,0 @@
-import _gptoss_kernels_931bc1b_dirty
-import torch
-
-print(dir(_gptoss_kernels_931bc1b_dirty))
-
-from gptoss_kernels import _gptoss_kernels_931bc1b_dirty
-
-print(dir(f32_bf16w_matmul))
-
-input = torch.randn(10, 10)
-weight_bf16 = torch.randn(10, 10)
-bias_bf16 = torch.randn(10)
-output = torch.randn(10, 10)
-f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, 10, 10, 10, 10)
-print(output)

torch-ext/registration.h DELETED
@@ -1,30 +0,0 @@
-// Registration macros from vLLM:
-// https://github.com/vllm-project/vllm/blob/main/csrc/core/registration.h
-
-#pragma once
-
-#include <Python.h>
-
-#define _CONCAT(A, B) A##B
-#define CONCAT(A, B) _CONCAT(A, B)
-
-#define _STRINGIFY(A) #A
-#define STRINGIFY(A) _STRINGIFY(A)
-
-// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
-// could be a macro instead of a literal token.
-#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
-
-// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
-// could be a macro instead of a literal token.
-#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
-  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
-
-// REGISTER_EXTENSION allows the shared library to be loaded and initialized
-// via python's import statement.
-#define REGISTER_EXTENSION(NAME)                                               \
-  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
-    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
-                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
-    return PyModule_Create(&module);                                           \
-  }