Mohamed Mekkouri committed on
Commit 60ccc25 · 1 Parent(s): 9ffd725
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .venv/
+ **/_pycache__/
CMakeLists.txt DELETED
@@ -1,128 +0,0 @@
- cmake_minimum_required(VERSION 3.26)
- project(gptoss_kernels LANGUAGES CXX C OBJC OBJCXX)
-
- set(CMAKE_OSX_DEPLOYMENT_TARGET "15.0" CACHE STRING "Minimum macOS deployment version")
-
- install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
-
- include(FetchContent)
- file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
- message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
-
- include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
-
- if(DEFINED Python3_EXECUTABLE)
-   # Allow passing through the interpreter (e.g. from setup.py).
-   find_package(Python3 COMPONENTS Development Development.SABIModule Interpreter)
-   if (NOT Python3_FOUND)
-     message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
-   endif()
- else()
-   find_package(Python3 REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
- endif()
-
- append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
-
- find_package(Torch REQUIRED)
-
- add_compile_definitions(METAL_KERNEL)
-
- # Initialize list for Metal shader sources
- set(ALL_METAL_SOURCES)
- #get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
- #list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
-
- set(TORCH_gptoss_kernels_SRC
-   torch-ext/torch_binding.cpp torch-ext/torch_binding.h
- )
-
-
- list(APPEND SRC "${TORCH_gptoss_kernels_SRC}")
- set(gptoss_kernels_SRC
-   "gptoss_kernels/include/gpt-oss.h"
-   "gptoss_kernels/include/gpt-oss/types.h"
-   "gptoss_kernels/include/gpt-oss/macros.h"
-   "gptoss_kernels/include/gpt-oss/functions.h"
-   "gptoss_kernels/source/accumulate.metal"
-   "gptoss_kernels/source/log.c"
-   "gptoss_kernels/source/expert_routing_metadata.metal"
-   "gptoss_kernels/source/metal.mm"
-   "gptoss_kernels/source/scatter.metal"
-   "gptoss_kernels/source/topk.metal"
-   "gptoss_kernels/source/embeddings.metal"
-   "gptoss_kernels/source/metal-kernels.c"
-   "gptoss_kernels/source/tensor_wrappers.cpp"
-   "gptoss_kernels/source/random.metal"
-   "gptoss_kernels/source/sdpa.metal"
-   "gptoss_kernels/source/matmul.metal"
-   "gptoss_kernels/source/rmsnorm.metal"
-   "gptoss_kernels/source/sample.metal"
-   "gptoss_kernels/source/moematmul.metal"
-   "gptoss_kernels/source/convert.metal"
-   "gptoss_kernels/source/rope.metal"
-   "gptoss_kernels/source/gather_and_accumulate.metal"
-   "gptoss_kernels/source/include/internal/uuid.h"
-   "gptoss_kernels/source/include/internal/metal.hpp"
-   "gptoss_kernels/source/include/internal/datatype.h"
-   "gptoss_kernels/source/include/internal/rng.h"
-   "gptoss_kernels/source/include/internal/rng.hpp"
-   "gptoss_kernels/source/include/internal/log.h"
-   "gptoss_kernels/source/include/internal/macros.h"
-   "gptoss_kernels/source/include/internal/storage.h"
-   "gptoss_kernels/source/include/internal/model.h"
-   "gptoss_kernels/source/include/internal/math.h"
-   "gptoss_kernels/source/include/internal/metal.h"
-   "gptoss_kernels/source/include/internal/kernel-args.h"
-   "gptoss_kernels/source/include/internal/datatype.hpp"
-   "gptoss_kernels/source/include/internal/metal-kernels.h"
- )
-
- # Separate Metal shader files from other sources
- set(gptoss_kernels_METAL_SRC)
- set(gptoss_kernels_CPP_SRC)
-
- foreach(src_file IN LISTS gptoss_kernels_SRC)
-   if(src_file MATCHES "\\.(metal|h)$")
-     list(APPEND gptoss_kernels_METAL_SRC ${src_file})
-   else()
-     list(APPEND gptoss_kernels_CPP_SRC ${src_file})
-   endif()
- endforeach()
-
- # TODO: check if CLion support this:
- # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
- set_source_files_properties(
-   ${gptoss_kernels_CPP_SRC}
-   PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/gptoss_kernels/source/include;${CMAKE_SOURCE_DIR}/gptoss_kernels/include;${CMAKE_SOURCE_DIR}/.")
-
-
- # Add C++ sources to main source list
- list(APPEND SRC "${gptoss_kernels_CPP_SRC}")
-
- # Keep track of Metal sources for later compilation
- if(gptoss_kernels_METAL_SRC)
-   list(APPEND ALL_METAL_SOURCES "${gptoss_kernels_METAL_SRC}")
- endif()
-
- # Keep the includes directory for the Metal sources
- if(gptoss_kernels_METAL_SRC)
-   list(APPEND METAL_INCLUDE_DIRS ${CMAKE_SOURCE_DIR}/gptoss_kernels/source/include;${CMAKE_SOURCE_DIR}/gptoss_kernels/include;${CMAKE_SOURCE_DIR}/.)
- endif()
-
- # Include Metal shader compilation utilities
- include(${CMAKE_CURRENT_LIST_DIR}/cmake/compile-metal.cmake)
-
- define_gpu_extension_target(
-   _gptoss_kernels_3a886f8_dirty
-   DESTINATION _gptoss_kernels_3a886f8_dirty
-   LANGUAGE ${GPU_LANG}
-   SOURCES ${SRC}
-   COMPILE_FLAGS ${GPU_FLAGS}
-   ARCHITECTURES ${GPU_ARCHES}
-   USE_SABI 3
-   WITH_SOABI)
-
- # Compile Metal shaders if any were found
- if(ALL_METAL_SOURCES)
-   compile_metal_shaders(_gptoss_kernels_3a886f8_dirty "${ALL_METAL_SOURCES}" "${METAL_INCLUDE_DIRS}")
- endif()
 
build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc and b/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc and b/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc differ
 
build/torch28-metal-aarch64-darwin/gptoss_kernels/{_gptoss_kernels_9964bae_dirty.abi3.so → _gptoss_kernels_9ffd725_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:b52d3924ac74e614664fd9ec72e9673807ed170e57277b81c1922c0b54a88a6a
+ oid sha256:8e9880a41249b5e1859134140111e0515925129b0c70d4587a70324b489bb9e3
  size 391752
build/torch28-metal-aarch64-darwin/gptoss_kernels/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _gptoss_kernels_9964bae_dirty
- ops = torch.ops._gptoss_kernels_9964bae_dirty
+ from . import _gptoss_kernels_9ffd725_dirty
+ ops = torch.ops._gptoss_kernels_9ffd725_dirty
 
  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_gptoss_kernels_9964bae_dirty::{op_name}"
+     return f"_gptoss_kernels_9ffd725_dirty::{op_name}"
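For orientation, a minimal sketch of how this module is typically consumed after the rename. The op name f32_bf16w_matmul is taken from the deleted test.py in this commit; whether it is exposed on the ops handle, and its exact schema, are assumptions here.

import torch
from gptoss_kernels._ops import ops, add_op_namespace_prefix

# `ops` is the torch.ops namespace bound to the renamed extension, so kernels
# would be reached as attributes of it, e.g. ops.f32_bf16w_matmul(...).

# add_op_namespace_prefix() yields the fully qualified op name under the new
# suffix, e.g. for registering meta/fake implementations against the schema:
print(add_op_namespace_prefix("f32_bf16w_matmul"))
# -> "_gptoss_kernels_9ffd725_dirty::f32_bf16w_matmul"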
build/torch28-metal-aarch64-darwin/gptoss_kernels/test.py DELETED
@@ -1,15 +0,0 @@
- import _gptoss_kernels_931bc1b_dirty
- import torch
-
- print(dir(_gptoss_kernels_931bc1b_dirty))
-
- from gptoss_kernels import _gptoss_kernels_931bc1b_dirty
-
- print(dir(f32_bf16w_matmul))
-
- input = torch.randn(10, 10)
- weight_bf16 = torch.randn(10, 10)
- bias_bf16 = torch.randn(10)
- output = torch.randn(10, 10)
- f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, 10, 10, 10, 10)
- print(output)
 
build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc and b/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc differ
 
build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc and b/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc differ
 
build/torch29-metal-aarch64-darwin/gptoss_kernels/{_gptoss_kernels_9964bae_dirty.abi3.so → _gptoss_kernels_9ffd725_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:dc170dbf45587f9a1091e9b6c92ab02ebe4dc3cdd13be8e56a9a8d3a353d8c86
+ oid sha256:fd8bc3d9d6d40953740d4566aeb3543b1025a647f8df05a9d296299bafbd8c31
  size 392840
build/torch29-metal-aarch64-darwin/gptoss_kernels/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _gptoss_kernels_9964bae_dirty
- ops = torch.ops._gptoss_kernels_9964bae_dirty
+ from . import _gptoss_kernels_9ffd725_dirty
+ ops = torch.ops._gptoss_kernels_9ffd725_dirty
 
  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_gptoss_kernels_9964bae_dirty::{op_name}"
+     return f"_gptoss_kernels_9ffd725_dirty::{op_name}"
build/torch29-metal-aarch64-darwin/gptoss_kernels/test.py DELETED
@@ -1,15 +0,0 @@
- import _gptoss_kernels_931bc1b_dirty
- import torch
-
- print(dir(_gptoss_kernels_931bc1b_dirty))
-
- from gptoss_kernels import _gptoss_kernels_931bc1b_dirty
-
- print(dir(f32_bf16w_matmul))
-
- input = torch.randn(10, 10)
- weight_bf16 = torch.randn(10, 10)
- bias_bf16 = torch.randn(10)
- output = torch.randn(10, 10)
- f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, 10, 10, 10, 10)
- print(output)
 
cmake/compile-metal.cmake DELETED
@@ -1,90 +0,0 @@
1
- # Metal shader compilation function
2
- function(compile_metal_shaders TARGET_NAME METAL_SOURCES EXTRA_INCLUDE_DIRS)
3
- # Find the Metal compiler
4
- find_program(METAL_COMPILER xcrun REQUIRED)
5
-
6
- # Set Metal compiler flags
7
- set(METAL_FLAGS "-std=metal3.2" "-O2")
8
-
9
- # Output directory for compiled metallib
10
- set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib")
11
- file(MAKE_DIRECTORY ${METALLIB_OUTPUT_DIR})
12
-
13
- foreach(INC ${EXTRA_INCLUDE_DIRS})
14
- list(APPEND METAL_FLAGS "-I${INC}")
15
- endforeach()
16
-
17
- # Separate .metal files from .h files and compile .metal files to .air
18
- set(AIR_FILES)
19
- set(METAL_FILES)
20
- set(HEADER_FILES)
21
-
22
- foreach(SOURCE_FILE ${METAL_SOURCES})
23
- if(SOURCE_FILE MATCHES "\\.metal$")
24
- list(APPEND METAL_FILES ${SOURCE_FILE})
25
- elseif(SOURCE_FILE MATCHES "\\.h$")
26
- list(APPEND HEADER_FILES ${SOURCE_FILE})
27
- endif()
28
- endforeach()
29
-
30
- foreach(METAL_FILE ${METAL_FILES})
31
- get_filename_component(METAL_NAME ${METAL_FILE} NAME_WE)
32
- set(AIR_FILE "${CMAKE_BINARY_DIR}/${METAL_NAME}.air")
33
-
34
- # Include header files as dependencies
35
- set(ALL_DEPENDENCIES ${CMAKE_CURRENT_SOURCE_DIR}/${METAL_FILE})
36
- foreach(HEADER_FILE ${HEADER_FILES})
37
- list(APPEND ALL_DEPENDENCIES ${CMAKE_CURRENT_SOURCE_DIR}/${HEADER_FILE})
38
- endforeach()
39
-
40
- add_custom_command(
41
- OUTPUT ${AIR_FILE}
42
- COMMAND ${METAL_COMPILER} -sdk macosx metal ${METAL_FLAGS}
43
- -c ${CMAKE_CURRENT_SOURCE_DIR}/${METAL_FILE}
44
- -o ${AIR_FILE}
45
- DEPENDS ${ALL_DEPENDENCIES}
46
- COMMENT "Compiling Metal shader ${METAL_FILE} to ${AIR_FILE}"
47
- VERBATIM
48
- )
49
-
50
- list(APPEND AIR_FILES ${AIR_FILE})
51
- endforeach()
52
-
53
- # Link all .air files into a single .metallib
54
- set(METALLIB_FILE "${METALLIB_OUTPUT_DIR}/${TARGET_NAME}.metallib")
55
- add_custom_command(
56
- OUTPUT ${METALLIB_FILE}
57
- COMMAND ${METAL_COMPILER} -sdk macosx metallib ${AIR_FILES}
58
- -o ${METALLIB_FILE}
59
- DEPENDS ${AIR_FILES}
60
- COMMENT "Linking Metal library ${METALLIB_FILE}"
61
- VERBATIM
62
- )
63
-
64
- # Generate C++ header with embedded metallib data
65
- set(METALLIB_HEADER "${CMAKE_BINARY_DIR}/${TARGET_NAME}_metallib.h")
66
- set(METALLIB_TO_HEADER_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/cmake/metallib_to_header.py")
67
-
68
- add_custom_command(
69
- OUTPUT ${METALLIB_HEADER}
70
- COMMAND ${Python_EXECUTABLE} ${METALLIB_TO_HEADER_SCRIPT} ${METALLIB_FILE} ${METALLIB_HEADER} ${TARGET_NAME}
71
- DEPENDS ${METALLIB_FILE} ${METALLIB_TO_HEADER_SCRIPT}
72
- COMMENT "Generating embedded Metal library header ${METALLIB_HEADER}"
73
- VERBATIM
74
- )
75
-
76
- # Create a custom target for the metallib
77
- add_custom_target(${TARGET_NAME}_metallib ALL DEPENDS ${METALLIB_FILE} ${METALLIB_HEADER})
78
-
79
- # Add dependency to main target
80
- add_dependencies(${TARGET_NAME} ${TARGET_NAME}_metallib)
81
-
82
- # Add the generated header to include directories
83
- target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_BINARY_DIR})
84
-
85
- # Pass the metallib header and namespace as compile definitions
86
- target_compile_definitions(${TARGET_NAME} PRIVATE
87
- EMBEDDED_METALLIB_HEADER="${TARGET_NAME}_metallib.h"
88
- EMBEDDED_METALLIB_NAMESPACE=${TARGET_NAME}_metal
89
- )
90
- endfunction()
 
cmake/metallib_to_header.py DELETED
@@ -1,73 +0,0 @@
1
- #!/usr/bin/env python3
2
- import sys
3
- import os
4
-
5
- def convert_metallib_to_header(metallib_path: str, header_path: str, target_name: str) -> None:
6
- """Convert a metallib binary file to a C++ header with embedded data."""
7
-
8
- # Read the metallib binary data
9
- with open(metallib_path, 'rb') as f:
10
- data: bytes = f.read()
11
-
12
- # Generate the header content
13
- header_content: str = """// Auto-generated file containing embedded Metal library
14
- #pragma once
15
- #include <cstddef>
16
- #include <Metal/Metal.h>
17
-
18
- namespace """ + target_name + """_metal {
19
- static const unsigned char metallib_data[] = {
20
- """
21
-
22
- # Convert binary data to C array format
23
- bytes_per_line: int = 16
24
- for i in range(0, len(data), bytes_per_line):
25
- chunk: bytes = data[i:i + bytes_per_line]
26
- hex_values: str = ', '.join('0x{:02x}'.format(b) for b in chunk)
27
- header_content += " " + hex_values + ","
28
- if i + bytes_per_line < len(data):
29
- header_content += "\n"
30
-
31
- header_content += """
32
- };
33
- static const size_t metallib_data_len = """ + str(len(data)) + """;
34
-
35
- // Convenience function to create Metal library from embedded data
36
- inline id<MTLLibrary> createLibrary(id<MTLDevice> device, NSError** error = nullptr) {
37
- dispatch_data_t libraryData = dispatch_data_create(
38
- metallib_data,
39
- metallib_data_len,
40
- dispatch_get_main_queue(),
41
- ^{ /* No cleanup needed for static data */ });
42
-
43
- NSError* localError = nil;
44
- id<MTLLibrary> library = [device newLibraryWithData:libraryData error:&localError];
45
-
46
- if (error) {
47
- *error = localError;
48
- }
49
-
50
- return library;
51
- }
52
- } // namespace """ + target_name + """_metal
53
- """
54
-
55
- # Write the header file
56
- dir_path: str = os.path.dirname(header_path)
57
- if dir_path:
58
- os.makedirs(dir_path, exist_ok=True)
59
- with open(header_path, 'w') as f:
60
- f.write(header_content)
61
-
62
- print("Generated {} ({} bytes)".format(header_path, len(data)))
63
-
64
- if __name__ == "__main__":
65
- if len(sys.argv) != 4:
66
- print("Usage: metallib_to_header.py <metallib_path> <header_path> <target_name>")
67
- sys.exit(1)
68
-
69
- metallib_path: str = sys.argv[1]
70
- header_path: str = sys.argv[2]
71
- target_name: str = sys.argv[3]
72
-
73
- convert_metallib_to_header(metallib_path, header_path, target_name)
 
cmake/utils.cmake DELETED
@@ -1,557 +0,0 @@
1
- # Vendored from vLLM:
2
- #
3
- # https://github.com/vllm-project/vllm/blob/main/cmake/utils.cmake
4
- #
5
- # Attempt to find the python package that uses the same python executable as
6
- # `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
7
- #
8
- macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
9
- file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
10
- set(Python3_EXECUTABLE ${EXECUTABLE})
11
- find_package(Python3 COMPONENTS Interpreter Development.Module Development.SABIModule)
12
- if (NOT Python3_FOUND)
13
- message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
14
- endif()
15
- set(_VER "${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}")
16
- set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
17
- if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
18
- message(FATAL_ERROR
19
- "Python version (${_VER}) is not one of the supported versions: "
20
- "${_SUPPORTED_VERSIONS_LIST}.")
21
- endif()
22
- message(STATUS "Found python matching: ${EXECUTABLE}.")
23
- endmacro()
24
-
25
- #
26
- # Run `EXPR` in python. The standard output of python is stored in `OUT` and
27
- # has trailing whitespace stripped. If an error is encountered when running
28
- # python, a fatal message `ERR_MSG` is issued.
29
- #
30
- function (run_python OUT EXPR ERR_MSG)
31
- execute_process(
32
- COMMAND
33
- "${Python3_EXECUTABLE}" "-c" "${EXPR}"
34
- OUTPUT_VARIABLE PYTHON_OUT
35
- RESULT_VARIABLE PYTHON_ERROR_CODE
36
- ERROR_VARIABLE PYTHON_STDERR
37
- OUTPUT_STRIP_TRAILING_WHITESPACE)
38
-
39
- if(NOT PYTHON_ERROR_CODE EQUAL 0)
40
- message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
41
- endif()
42
- set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
43
- endfunction()
44
-
45
- # Run `EXPR` in python after importing `PKG`. Use the result of this to extend
46
- # `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
47
- macro (append_cmake_prefix_path PKG EXPR)
48
- run_python(_PREFIX_PATH
49
- "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
50
- list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
51
- endmacro()
52
-
53
- #
54
- # Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
55
- # of CUDA source files. The names of the corresponding "hipified" sources are
56
- # stored in `OUT_SRCS`.
57
- #
58
- function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
59
- #
60
- # Split into C++ and non-C++ (i.e. CUDA) sources.
61
- #
62
- set(NODUP_SRCS ${ORIG_SRCS})
63
- list(REMOVE_DUPLICATES NODUP_SRCS)
64
- set(SRCS ${NODUP_SRCS})
65
- set(CXX_SRCS ${NODUP_SRCS})
66
- list(FILTER SRCS INCLUDE REGEX "\.cu$")
67
- list(FILTER CXX_SRCS EXCLUDE REGEX "\.cu$")
68
-
69
- #
70
- # Generate ROCm/HIP source file names from CUDA file names.
71
- # Since HIP files are generated code, they will appear in the build area
72
- # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
73
- #
74
- set(HIP_SRCS)
75
- foreach (SRC ${SRCS})
76
- get_source_file_property(include_dirs "${SRC}" INCLUDE_DIRECTORIES)
77
- get_source_file_property(compile_options "${SRC}" COMPILE_OPTIONS)
78
- string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
79
- string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
80
-
81
- if(include_dirs)
82
- # Copy over include directories from the original CUDA file.
83
- set_source_files_properties(
84
- ${SRC}
85
- PROPERTIES INCLUDE_DIRECTORIES "${include_dirs}")
86
- endif()
87
-
88
- if(compile_options)
89
- set_source_files_properties(
90
- ${SRC}
91
- PROPERTIES COMPILE_OPTIONS "${compile_options}")
92
- endif()
93
-
94
- list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
95
- endforeach()
96
-
97
- add_custom_target(
98
- hipify${NAME}
99
- COMMAND "${Python3_EXECUTABLE}" ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR} -o ${CMAKE_CURRENT_BINARY_DIR} ${SRCS}
100
- DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
101
- BYPRODUCTS ${HIP_SRCS}
102
- COMMENT "Running hipify on ${NAME} extension source files.")
103
-
104
- # Swap out original extension sources with hipified sources.
105
- list(APPEND HIP_SRCS ${CXX_SRCS})
106
- set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
107
- endfunction()
108
-
109
- #
110
- # Get additional GPU compiler flags from torch.
111
- #
112
- function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
113
- if (${GPU_LANG} STREQUAL "CUDA")
114
- #
115
- # Get common NVCC flags from torch.
116
- #
117
- run_python(GPU_FLAGS
118
- "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
119
- "Failed to determine torch nvcc compiler flags")
120
-
121
- if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
122
- list(APPEND GPU_FLAGS "-DENABLE_FP8")
123
- list(REMOVE_ITEM GPU_FLAGS
124
- "-D__CUDA_NO_HALF_OPERATORS__"
125
- "-D__CUDA_NO_HALF_CONVERSIONS__"
126
- "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
127
- "-D__CUDA_NO_HALF2_OPERATORS__")
128
- endif()
129
-
130
- elseif(${GPU_LANG} STREQUAL "HIP")
131
- #
132
- # Get common HIP/HIPCC flags from torch.
133
- #
134
- run_python(GPU_FLAGS
135
- "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
136
- "Failed to determine torch nvcc compiler flags")
137
-
138
- list(APPEND GPU_FLAGS
139
- "-DUSE_ROCM"
140
- "-DENABLE_FP8"
141
- "-U__HIP_NO_HALF_CONVERSIONS__"
142
- "-U__HIP_NO_HALF_OPERATORS__"
143
- "-fno-gpu-rdc")
144
-
145
- endif()
146
- set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
147
- endfunction()
148
-
149
- # Macro for converting a `gencode` version number to a cmake version number.
150
- macro(string_to_ver OUT_VER IN_STR)
151
- string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
152
- endmacro()
153
-
154
- #
155
- # Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
156
- # `CUDA_ARCH_FLAGS`.
157
- #
158
- # Example:
159
- # CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
160
- # clear_cuda_arches(CUDA_ARCH_FLAGS)
161
- # CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
162
- # CMAKE_CUDA_FLAGS="-Wall"
163
- #
164
- macro(clear_cuda_arches CUDA_ARCH_FLAGS)
165
- # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
166
- string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
167
- ${CMAKE_CUDA_FLAGS})
168
-
169
- # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
170
- # and passed back via the `CUDA_ARCHITECTURES` property.
171
- string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
172
- ${CMAKE_CUDA_FLAGS})
173
- endmacro()
174
-
175
- #
176
- # Extract unique CUDA architectures from a list of compute capabilities codes in
177
- # the form `<major><minor>[<letter>]`, convert them to the form sort
178
- # `<major>.<minor>`, dedupes them and then sorts them in ascending order and
179
- # stores them in `OUT_ARCHES`.
180
- #
181
- # Example:
182
- # CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
183
- # extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
184
- # OUT_ARCHES="7.5;...;9.0"
185
- function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
186
- set(_CUDA_ARCHES)
187
- foreach(_ARCH ${CUDA_ARCH_FLAGS})
188
- string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
189
- if (_COMPUTE)
190
- set(_COMPUTE ${CMAKE_MATCH_1})
191
- endif()
192
-
193
- string_to_ver(_COMPUTE_VER ${_COMPUTE})
194
- list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
195
- endforeach()
196
-
197
- list(REMOVE_DUPLICATES _CUDA_ARCHES)
198
- list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
199
- set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
200
- endfunction()
201
-
202
- #
203
- # For a specific file set the `-gencode` flag in compile options conditionally
204
- # for the CUDA language.
205
- #
206
- # Example:
207
- # set_gencode_flag_for_srcs(
208
- # SRCS "foo.cu"
209
- # ARCH "compute_75"
210
- # CODE "sm_75")
211
- # adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
212
- # `foo.cu` (only for the CUDA language).
213
- #
214
- macro(set_gencode_flag_for_srcs)
215
- set(options)
216
- set(oneValueArgs ARCH CODE)
217
- set(multiValueArgs SRCS)
218
- cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
219
- "${multiValueArgs}" ${ARGN} )
220
- set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
221
- set_property(
222
- SOURCE ${arg_SRCS}
223
- APPEND PROPERTY
224
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
225
- )
226
-
227
- message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
228
- endmacro(set_gencode_flag_for_srcs)
229
-
230
- #
231
- # For a list of source files set the `-gencode` flags in the files specific
232
- # compile options (specifically for the CUDA language).
233
- #
234
- # arguments are:
235
- # SRCS: list of source files
236
- # CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
237
- # BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
238
- # for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
239
- # that is larger than BUILD_PTX_FOR_ARCH.
240
- #
241
- macro(set_gencode_flags_for_srcs)
242
- set(options)
243
- set(oneValueArgs BUILD_PTX_FOR_ARCH)
244
- set(multiValueArgs SRCS CUDA_ARCHS)
245
- cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
246
- "${multiValueArgs}" ${ARGN} )
247
-
248
- foreach(_ARCH ${arg_CUDA_ARCHS})
249
- # handle +PTX suffix: generate both sm and ptx codes if requested
250
- string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
251
- if(NOT _HAS_PTX EQUAL -1)
252
- string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
253
- string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
254
- set_gencode_flag_for_srcs(
255
- SRCS ${arg_SRCS}
256
- ARCH "compute_${_STRIPPED_ARCH}"
257
- CODE "sm_${_STRIPPED_ARCH}")
258
- set_gencode_flag_for_srcs(
259
- SRCS ${arg_SRCS}
260
- ARCH "compute_${_STRIPPED_ARCH}"
261
- CODE "compute_${_STRIPPED_ARCH}")
262
- else()
263
- string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
264
- set_gencode_flag_for_srcs(
265
- SRCS ${arg_SRCS}
266
- ARCH "compute_${_STRIPPED_ARCH}"
267
- CODE "sm_${_STRIPPED_ARCH}")
268
- endif()
269
- endforeach()
270
-
271
- if (${arg_BUILD_PTX_FOR_ARCH})
272
- list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
273
- list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
274
- if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
275
- string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
276
- set_gencode_flag_for_srcs(
277
- SRCS ${arg_SRCS}
278
- ARCH "compute_${_PTX_ARCH}"
279
- CODE "compute_${_PTX_ARCH}")
280
- endif()
281
- endif()
282
- endmacro()
283
-
284
- #
285
- # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
286
- # `<major>.<minor>[letter]` compute the "loose intersection" with the
287
- # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
288
- # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
289
- # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
290
- # architecture in `SRC_CUDA_ARCHS`.
291
- # The loose intersection is defined as:
292
- # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
293
- # where `<=` is the version comparison operator.
294
- # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
295
- # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
296
- # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
297
- # in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
298
- # x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
299
- # The result is stored in `OUT_CUDA_ARCHS`.
300
- #
301
- # Example:
302
- # SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
303
- # TGT_CUDA_ARCHS="8.0;8.9;9.0"
304
- # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
305
- # OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
306
- #
307
- # Example With PTX:
308
- # SRC_CUDA_ARCHS="8.0+PTX"
309
- # TGT_CUDA_ARCHS="9.0"
310
- # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
311
- # OUT_CUDA_ARCHS="8.0+PTX"
312
- #
313
- function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
314
- set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
315
- set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
316
-
317
- # handle +PTX suffix: separate base arch for matching, record PTX requests
318
- set(_PTX_ARCHS)
319
- foreach(_arch ${_SRC_CUDA_ARCHS})
320
- if(_arch MATCHES "\\+PTX$")
321
- string(REPLACE "+PTX" "" _base "${_arch}")
322
- list(APPEND _PTX_ARCHS "${_base}")
323
- list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
324
- list(APPEND _SRC_CUDA_ARCHS "${_base}")
325
- endif()
326
- endforeach()
327
- list(REMOVE_DUPLICATES _PTX_ARCHS)
328
- list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
329
-
330
- # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
331
- # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
332
- set(_CUDA_ARCHS)
333
- foreach(_arch ${_SRC_CUDA_ARCHS})
334
- if(_arch MATCHES "\\a$")
335
- list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
336
- string(REPLACE "a" "" _base "${_arch}")
337
- if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
338
- list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
339
- list(APPEND _CUDA_ARCHS "${_arch}")
340
- endif()
341
- endif()
342
- endforeach()
343
-
344
- list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
345
-
346
- # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
347
- # is less or equal to ARCH (but has the same major version since SASS binary
348
- # compatibility is only forward compatible within the same major version).
349
- foreach(_ARCH ${_TGT_CUDA_ARCHS})
350
- set(_TMP_ARCH)
351
- # Extract the major version of the target arch
352
- string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
353
- foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
354
- # Extract the major version of the source arch
355
- string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
356
- # Check version-less-or-equal, and allow PTX arches to match across majors
357
- if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
358
- if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
359
- set(_TMP_ARCH "${_SRC_ARCH}")
360
- endif()
361
- else()
362
- # If we hit a version greater than the target, we can break
363
- break()
364
- endif()
365
- endforeach()
366
-
367
- # If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
368
- if (_TMP_ARCH)
369
- list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
370
- endif()
371
- endforeach()
372
-
373
- list(REMOVE_DUPLICATES _CUDA_ARCHS)
374
-
375
- # reapply +PTX suffix to architectures that requested PTX
376
- set(_FINAL_ARCHS)
377
- foreach(_arch ${_CUDA_ARCHS})
378
- if(_arch IN_LIST _PTX_ARCHS)
379
- list(APPEND _FINAL_ARCHS "${_arch}+PTX")
380
- else()
381
- list(APPEND _FINAL_ARCHS "${_arch}")
382
- endif()
383
- endforeach()
384
- set(_CUDA_ARCHS ${_FINAL_ARCHS})
385
-
386
- set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
387
- endfunction()
388
-
389
- #
390
- # For the given `SRC_ROCM_ARCHS` list of architecture versions in the form
391
- # `<name>` compute the "loose intersection" with the `TGT_ROCM_ARCHS` list.
392
- # The loose intersection is defined as:
393
- # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
394
- # where `<=` is the version comparison operator.
395
- # In other words, for each version in `TGT_ROCM_ARCHS` find the highest version
396
- # in `SRC_ROCM_ARCHS` that is less or equal to the version in `TGT_ROCM_ARCHS`.
397
- # The result is stored in `OUT_ROCM_ARCHS`.
398
- #
399
- # Example:
400
- # SRC_ROCM_ARCHS="gfx900;gfx906;gfx908;gfx90a"
401
- # TGT_ROCM_ARCHS="gfx906;gfx908;gfx1030"
402
- # hip_archs_loose_intersection(OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
403
- # OUT_ROCM_ARCHS="gfx906;gfx908"
404
- #
405
- function(hip_archs_loose_intersection OUT_ROCM_ARCHS SRC_ROCM_ARCHS TGT_ROCM_ARCHS)
406
- list(REMOVE_DUPLICATES SRC_ROCM_ARCHS)
407
-
408
- # ROCm architectures are typically in format gfxNNN or gfxNNNx where N is a digit
409
- # and x is a letter. We can sort them by string comparison which works for this format.
410
- list(SORT SRC_ROCM_ARCHS COMPARE STRING ORDER ASCENDING)
411
-
412
- set(_ROCM_ARCHS)
413
-
414
- # Find the intersection of supported architectures
415
- foreach(_SRC_ARCH ${SRC_ROCM_ARCHS})
416
- if(_SRC_ARCH IN_LIST TGT_ROCM_ARCHS)
417
- list(APPEND _ROCM_ARCHS ${_SRC_ARCH})
418
- endif()
419
- endforeach()
420
-
421
- list(REMOVE_DUPLICATES _ROCM_ARCHS)
422
- set(${OUT_ROCM_ARCHS} ${_ROCM_ARCHS} PARENT_SCOPE)
423
- endfunction()
424
-
425
- #
426
- # Override the GPU architectures detected by cmake/torch and filter them by
427
- # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
428
- # `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
429
- # the architectures on a per file basis.
430
- #
431
- # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
432
- #
433
- macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
434
- set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
435
- message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
436
-
437
- if (${GPU_LANG} STREQUAL "HIP")
438
- #
439
- # `GPU_ARCHES` controls the `--offload-arch` flags.
440
- #
441
- # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
442
- # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
443
- # "rocm_agent_enumerator" in "enable_language(HIP)"
444
- # (in file Modules/CMakeDetermineHIPCompiler.cmake)
445
- #
446
- if(DEFINED ENV{PYTORCH_ROCM_ARCH})
447
- set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
448
- else()
449
- set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
450
- endif()
451
- #
452
- # Find the intersection of the supported + detected architectures to
453
- # set the module architecture flags.
454
- #
455
- set(${GPU_ARCHES})
456
- foreach (_ARCH ${HIP_ARCHITECTURES})
457
- if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
458
- list(APPEND ${GPU_ARCHES} ${_ARCH})
459
- endif()
460
- endforeach()
461
-
462
- if(NOT ${GPU_ARCHES})
463
- message(FATAL_ERROR
464
- "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
465
- " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
466
- endif()
467
- endif()
468
- endmacro()
469
-
470
- #
471
- # Define a target named `GPU_MOD_NAME` for a single extension. The
472
- # arguments are:
473
- #
474
- # DESTINATION <dest> - Module destination directory.
475
- # LANGUAGE <lang> - The GPU language for this module, e.g CUDA, HIP,
476
- # etc.
477
- # SOURCES <sources> - List of source files relative to CMakeLists.txt
478
- # directory.
479
- #
480
- # Optional arguments:
481
- #
482
- # ARCHITECTURES <arches> - A list of target GPU architectures in cmake
483
- # format.
484
- # Refer `CMAKE_CUDA_ARCHITECTURES` documentation
485
- # and `CMAKE_HIP_ARCHITECTURES` for more info.
486
- # ARCHITECTURES will use cmake's defaults if
487
- # not provided.
488
- # COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
489
- # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
490
- # LIBRARIES <libraries> - Extra link libraries.
491
- # WITH_SOABI - Generate library with python SOABI suffix name.
492
- # USE_SABI <version> - Use python stable api <version>
493
- #
494
- # Note: optimization level/debug info is set via cmake build type.
495
- #
496
- function (define_gpu_extension_target GPU_MOD_NAME)
497
- cmake_parse_arguments(PARSE_ARGV 1
498
- GPU
499
- "WITH_SOABI"
500
- "DESTINATION;LANGUAGE;USE_SABI"
501
- "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
502
-
503
- # Add hipify preprocessing step when building with HIP/ROCm.
504
- if (GPU_LANGUAGE STREQUAL "HIP")
505
- hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
506
- endif()
507
-
508
- if (GPU_WITH_SOABI)
509
- set(GPU_WITH_SOABI WITH_SOABI)
510
- else()
511
- set(GPU_WITH_SOABI)
512
- endif()
513
-
514
- if (GPU_USE_SABI)
515
- Python3_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
516
- else()
517
- Python3_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
518
- endif()
519
-
520
- if (GPU_LANGUAGE STREQUAL "HIP")
521
- # Make this target dependent on the hipify preprocessor step.
522
- add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
523
- endif()
524
-
525
- if (GPU_ARCHITECTURES)
526
- if (GPU_LANGUAGE STREQUAL "HIP")
527
- # Clear target architectures, we are passing arch flags per source file.
528
- set_property(TARGET ${GPU_MOD_NAME} PROPERTY HIP_ARCHITECTURES off)
529
- else()
530
- set_target_properties(${GPU_MOD_NAME} PROPERTIES
531
- ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
532
- endif()
533
- endif()
534
-
535
- set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
536
-
537
- target_compile_options(${GPU_MOD_NAME} PRIVATE
538
- $<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
539
-
540
- target_compile_definitions(${GPU_MOD_NAME} PRIVATE
541
- "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
542
-
543
- target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
544
- ${GPU_INCLUDE_DIRECTORIES})
545
-
546
- target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
547
-
548
- # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
549
- # dependencies that are not necessary and may not be installed.
550
- if (GPU_LANGUAGE STREQUAL "CUDA")
551
- target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart)
552
- else()
553
- target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
554
- endif()
555
-
556
- install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
557
- endfunction()
 
pyproject.toml DELETED
@@ -1,10 +0,0 @@
- [build-system]
- requires = [
-     "cmake>=3.26",
-     "ninja",
-     "packaging",
-     "setuptools>=61",
-     "torch",
-     "wheel",
- ]
- build-backend = "setuptools.build_meta"
 
setup.py DELETED
@@ -1,118 +0,0 @@
1
- import logging
2
- import os
3
- from shutil import which, move
4
- import subprocess
5
- import sys
6
- from pathlib import Path
7
-
8
- from setuptools import Extension, find_packages, setup
9
- from setuptools.command.build_ext import build_ext
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
-
14
- def is_sccache_available() -> bool:
15
- return which("sccache") is not None
16
-
17
-
18
- def is_ccache_available() -> bool:
19
- return which("ccache") is not None
20
-
21
-
22
- def is_ninja_available() -> bool:
23
- return which("ninja") is not None
24
-
25
-
26
- class CMakeExtension(Extension):
27
- def __init__(self, name: str, sourcedir: str = "") -> None:
28
- super().__init__(name, sources=[], py_limited_api=True)
29
- self.sourcedir = os.fspath(Path(sourcedir).resolve())
30
-
31
-
32
- class CMakeBuild(build_ext):
33
- def build_extension(self, ext: CMakeExtension) -> None:
34
- ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)
35
- extdir = ext_fullpath.parent.resolve()
36
-
37
- debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
38
- cfg = "Debug" if debug else "Release"
39
-
40
- cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
41
-
42
- # Set Python3_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
43
- # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
44
- # from Python.
45
- cmake_args = [
46
- f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}",
47
- f"-DPython3_EXECUTABLE={sys.executable}",
48
- f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
49
- ]
50
- build_args = []
51
- if "CMAKE_ARGS" in os.environ:
52
- cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
53
-
54
- if not cmake_generator or cmake_generator == "Ninja":
55
- try:
56
- import ninja
57
-
58
- ninja_executable_path = Path(ninja.BIN_DIR) / "ninja"
59
- cmake_args += [
60
- "-GNinja",
61
- f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
62
- ]
63
- except ImportError:
64
- pass
65
-
66
- if is_sccache_available():
67
- cmake_args += [
68
- "-DCMAKE_C_COMPILER_LAUNCHER=sccache",
69
- "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache",
70
- "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache",
71
- "-DCMAKE_OBJC_COMPILER_LAUNCHER=sccache",
72
- "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=sccache",
73
- ]
74
- elif is_ccache_available():
75
- cmake_args += [
76
- "-DCMAKE_C_COMPILER_LAUNCHER=ccache",
77
- "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
78
- "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache",
79
- "-DCMAKE_OBJC_COMPILER_LAUNCHER=ccache",
80
- "-DCMAKE_OBJCXX_COMPILER_LAUNCHER=ccache",
81
- ]
82
-
83
- num_jobs = os.getenv("MAX_JOBS", None)
84
- if num_jobs is not None:
85
- num_jobs = int(num_jobs)
86
- logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
87
- else:
88
- try:
89
- # os.sched_getaffinity() isn't universally available, so fall
90
- # back to os.cpu_count() if we get an error here.
91
- num_jobs = len(os.sched_getaffinity(0))
92
- except AttributeError:
93
- num_jobs = os.cpu_count()
94
-
95
- build_temp = Path(self.build_temp) / ext.name
96
- if not build_temp.exists():
97
- build_temp.mkdir(parents=True)
98
-
99
- subprocess.run(
100
- ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
101
- )
102
- subprocess.run(
103
- ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
104
- )
105
-
106
-
107
- setup(
108
- name="gptoss_kernels",
109
- # The version is just a stub, it's not used by the final build artefact.
110
- version="0.1.0",
111
- ext_modules=[CMakeExtension("gptoss_kernels._gptoss_kernels_3a886f8_dirty")],
112
- cmdclass={"build_ext": CMakeBuild},
113
- packages=find_packages(where="torch-ext", include=["gptoss_kernels*"]),
114
- package_dir={"": "torch-ext"},
115
- zip_safe=False,
116
- install_requires=["torch"],
117
- python_requires=">=3.9",
118
- )
 
torch-ext/gptoss_kernels/__pycache__/__init__.cpython-313.pyc DELETED
Binary file (868 Bytes)
 
torch-ext/gptoss_kernels/__pycache__/_ops.cpython-313.pyc DELETED
Binary file (552 Bytes)
 
torch-ext/gptoss_kernels/test.py DELETED
@@ -1,15 +0,0 @@
- import _gptoss_kernels_931bc1b_dirty
- import torch
-
- print(dir(_gptoss_kernels_931bc1b_dirty))
-
- from gptoss_kernels import _gptoss_kernels_931bc1b_dirty
-
- print(dir(f32_bf16w_matmul))
-
- input = torch.randn(10, 10)
- weight_bf16 = torch.randn(10, 10)
- bias_bf16 = torch.randn(10)
- output = torch.randn(10, 10)
- f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, 10, 10, 10, 10)
- print(output)
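Note that the deleted script above was broken as committed: f32_bf16w_matmul is called but never imported into scope. A hedged sketch of an equivalent smoke test against the renamed extension, assuming the op is reachable through the ops handle from _ops.py and keeping the original argument order (the trailing integers are copied verbatim; their meaning is not documented in this commit):

import torch
from gptoss_kernels._ops import ops

# Shapes mirror the deleted script; it created plain float32 tensors despite
# the *_bf16 names, so the same is done here.
input = torch.randn(10, 10)
weight_bf16 = torch.randn(10, 10)
bias_bf16 = torch.randn(10)
output = torch.empty(10, 10)
ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, 10, 10, 10, 10)
print(output)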
 
torch-ext/registration.h DELETED
@@ -1,30 +0,0 @@
- // Registration macros from vLLM:
- // https://github.com/vllm-project/vllm/blob/main/csrc/core/registration.h
-
- #pragma once
-
- #include <Python.h>
-
- #define _CONCAT(A, B) A##B
- #define CONCAT(A, B) _CONCAT(A, B)
-
- #define _STRINGIFY(A) #A
- #define STRINGIFY(A) _STRINGIFY(A)
-
- // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
- // could be a macro instead of a literal token.
- #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
-
- // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
- // could be a macro instead of a literal token.
- #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
-   TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
-
- // REGISTER_EXTENSION allows the shared library to be loaded and initialized
- // via python's import statement.
- #define REGISTER_EXTENSION(NAME)                                             \
-   PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                   \
-     static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,              \
-                                         STRINGIFY(NAME), nullptr, 0, nullptr}; \
-     return PyModule_Create(&module);                                         \
-   }