Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +10 -0
- build.toml +23 -0
- flake.nix +13 -0
- gptoss_kernels/CMakeLists.txt +191 -0
- gptoss_kernels/__init__.py +6 -0
- gptoss_kernels/examples/chat.py +104 -0
- gptoss_kernels/examples/generate.py +34 -0
- gptoss_kernels/include/gpt-oss.h +5 -0
- gptoss_kernels/include/gpt-oss/functions.h +401 -0
- gptoss_kernels/include/gpt-oss/macros.h +5 -0
- gptoss_kernels/include/gpt-oss/types.h +62 -0
- gptoss_kernels/source/accumulate.metal +59 -0
- gptoss_kernels/source/context.c +1115 -0
- gptoss_kernels/source/convert.metal +64 -0
- gptoss_kernels/source/embeddings.metal +29 -0
- gptoss_kernels/source/expert_routing_metadata.metal +41 -0
- gptoss_kernels/source/gather_and_accumulate.metal +74 -0
- gptoss_kernels/source/generate.c +317 -0
- gptoss_kernels/source/include/internal/datatype.h +41 -0
- gptoss_kernels/source/include/internal/datatype.hpp +87 -0
- gptoss_kernels/source/include/internal/kernel-args.h +201 -0
- gptoss_kernels/source/include/internal/log.h +20 -0
- gptoss_kernels/source/include/internal/macros.h +107 -0
- gptoss_kernels/source/include/internal/math.h +40 -0
- gptoss_kernels/source/include/internal/metal-kernels.h +486 -0
- gptoss_kernels/source/include/internal/metal.h +138 -0
- gptoss_kernels/source/include/internal/metal.hpp +342 -0
- gptoss_kernels/source/include/internal/model.h +178 -0
- gptoss_kernels/source/include/internal/rng.h +24 -0
- gptoss_kernels/source/include/internal/rng.hpp +32 -0
- gptoss_kernels/source/include/internal/storage.h +36 -0
- gptoss_kernels/source/include/internal/uuid.h +114 -0
- gptoss_kernels/source/log.c +50 -0
- gptoss_kernels/source/matmul.metal +422 -0
- gptoss_kernels/source/metal-kernels.c +1518 -0
- gptoss_kernels/source/metal.m +482 -0
- gptoss_kernels/source/model.c +581 -0
- gptoss_kernels/source/moematmul.metal +702 -0
- gptoss_kernels/source/random.metal +97 -0
- gptoss_kernels/source/rmsnorm.metal +58 -0
- gptoss_kernels/source/rope.metal +43 -0
- gptoss_kernels/source/sample.metal +209 -0
- gptoss_kernels/source/scatter.metal +65 -0
- gptoss_kernels/source/sdpa.metal +293 -0
- gptoss_kernels/source/tokenizer.c +106 -0
- gptoss_kernels/source/topk.metal +205 -0
- gptoss_kernels/test/bf16-f32-embeddings.cc +33 -0
- gptoss_kernels/test/embeddings-kernel-tester.hpp +123 -0
- gptoss_kernels/test/f32-bf16w-matmul.cc +87 -0
- gptoss_kernels/test/f32-bf16w-rmsnorm.cc +36 -0
README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- kernels
|
| 4 |
+
- gptoss
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
# gptoss_kernels
|
| 8 |
+
|
| 9 |
+
This is a build for some kernel released by OpenAI in the GPT-OSS repo : https://github.com/openai/gpt-oss
|
| 10 |
+
|
build.toml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[general]
|
| 2 |
+
name = "gptoss_kernels"
|
| 3 |
+
universal = false
|
| 4 |
+
|
| 5 |
+
[torch]
|
| 6 |
+
src = [
|
| 7 |
+
"torch-ext/torch_binding.cpp",
|
| 8 |
+
"torch-ext/torch_binding.h",
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
[kernel.gptoss_kernels]
|
| 12 |
+
depends = ["torch"]
|
| 13 |
+
backend = "cuda"
|
| 14 |
+
|
| 15 |
+
src = [
|
| 16 |
+
"gptoss_kernels/attention_cuda_fwd.cu",
|
| 17 |
+
"gptoss_kernels/attention_cuda_bwd.cu",
|
| 18 |
+
"gptoss_kernels/attention_cuda_utils.cu",
|
| 19 |
+
"gptoss_kernels/attention_cuda_utils.cuh",
|
| 20 |
+
"gptoss_kernels/attention_cuda.cuh",
|
| 21 |
+
"gptoss_kernels/attention.h",
|
| 22 |
+
"gptoss_kernels/cudamacro.h",
|
| 23 |
+
]
|
flake.nix
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
description = "Flake for Torch kernel extension";
|
| 3 |
+
|
| 4 |
+
inputs = {
|
| 5 |
+
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
+
};
|
| 7 |
+
|
| 8 |
+
outputs = { self, kernel-builder, }:
|
| 9 |
+
kernel-builder.lib.genFlakeOutputs {
|
| 10 |
+
inherit self;
|
| 11 |
+
path = ./.;
|
| 12 |
+
};
|
| 13 |
+
}
|
gptoss_kernels/CMakeLists.txt
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cmake_minimum_required(VERSION 3.24)
|
| 2 |
+
project(GPTOSS
|
| 3 |
+
VERSION 1.0
|
| 4 |
+
DESCRIPTION "Local GPT-OSS inference"
|
| 5 |
+
LANGUAGES C CXX OBJC)
|
| 6 |
+
|
| 7 |
+
set(CMAKE_C_STANDARD 11)
|
| 8 |
+
set(CMAKE_CXX_STANDARD 20)
|
| 9 |
+
set(CMAKE_OBJC_STANDARD 11)
|
| 10 |
+
set(CMAKE_OBJC_STANDARD_REQUIRED ON)
|
| 11 |
+
|
| 12 |
+
find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED)
|
| 13 |
+
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
| 14 |
+
find_library(IOKIT_FRAMEWORK IOKit REQUIRED)
|
| 15 |
+
|
| 16 |
+
set(METAL_SOURCES
|
| 17 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal
|
| 18 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal
|
| 19 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal
|
| 20 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal
|
| 21 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal
|
| 22 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal
|
| 23 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal
|
| 24 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal
|
| 25 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal
|
| 26 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal
|
| 27 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal
|
| 28 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal
|
| 29 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal
|
| 30 |
+
${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal
|
| 31 |
+
)
|
| 32 |
+
set(METAL_LIB default.metallib)
|
| 33 |
+
|
| 34 |
+
include_directories(BEFORE include source/include)
|
| 35 |
+
|
| 36 |
+
add_custom_command(
|
| 37 |
+
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
| 38 |
+
COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/source/"
|
| 39 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air"
|
| 40 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air"
|
| 41 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air"
|
| 42 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air"
|
| 43 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air"
|
| 44 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air"
|
| 45 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air"
|
| 46 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/random.air"
|
| 47 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air"
|
| 48 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air"
|
| 49 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air"
|
| 50 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air"
|
| 51 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air"
|
| 52 |
+
COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air"
|
| 53 |
+
COMMAND xcrun -sdk macosx metallib "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air" "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air" "${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air" "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/random.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air" "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air" "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air" -o "${METAL_LIB}"
|
| 54 |
+
DEPENDS ${METAL_SOURCES}
|
| 55 |
+
COMMENT "Compiling Metal compute library"
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
add_custom_target(build_metallib ALL
|
| 59 |
+
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB})
|
| 60 |
+
|
| 61 |
+
add_library(log OBJECT source/log.c)
|
| 62 |
+
|
| 63 |
+
add_library(metal-kernels STATIC source/metal.m source/metal-kernels.c)
|
| 64 |
+
target_link_libraries(metal-kernels PRIVATE log)
|
| 65 |
+
|
| 66 |
+
add_dependencies(metal-kernels build_metallib)
|
| 67 |
+
add_custom_command(TARGET metal-kernels POST_BUILD
|
| 68 |
+
COMMAND ${CMAKE_COMMAND} -E copy
|
| 69 |
+
${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
| 70 |
+
$<TARGET_FILE_DIR:metal-kernels>)
|
| 71 |
+
|
| 72 |
+
target_link_libraries(metal-kernels PRIVATE ${FOUNDATION_FRAMEWORK} ${METAL_FRAMEWORK} ${IOKIT_FRAMEWORK})
|
| 73 |
+
|
| 74 |
+
add_library(gptoss STATIC source/model.c source/tokenizer.c source/context.c)
|
| 75 |
+
target_link_libraries(gptoss PRIVATE log metal-kernels)
|
| 76 |
+
|
| 77 |
+
add_executable(generate source/generate.c)
|
| 78 |
+
target_link_libraries(generate gptoss)
|
| 79 |
+
|
| 80 |
+
# --- [ Tests
|
| 81 |
+
include(FetchContent)
|
| 82 |
+
FetchContent_Declare(
|
| 83 |
+
googletest
|
| 84 |
+
URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip
|
| 85 |
+
DOWNLOAD_EXTRACT_TIMESTAMP OFF
|
| 86 |
+
)
|
| 87 |
+
# For Windows: Prevent overriding the parent project's compiler/linker settings
|
| 88 |
+
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
|
| 89 |
+
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
|
| 90 |
+
FetchContent_MakeAvailable(googletest)
|
| 91 |
+
|
| 92 |
+
enable_testing()
|
| 93 |
+
|
| 94 |
+
add_executable(u32-random-test test/u32-random.cc)
|
| 95 |
+
target_link_libraries(u32-random-test PRIVATE GTest::gtest_main metal-kernels)
|
| 96 |
+
target_include_directories(u32-random-test PRIVATE source/include)
|
| 97 |
+
add_test(NAME u32-random-test COMMAND u32-random-test)
|
| 98 |
+
|
| 99 |
+
add_executable(f32-random-test test/f32-random.cc)
|
| 100 |
+
target_link_libraries(f32-random-test PRIVATE GTest::gtest_main metal-kernels)
|
| 101 |
+
target_include_directories(f32-random-test PRIVATE source/include)
|
| 102 |
+
add_test(NAME f32-random-test COMMAND f32-random-test)
|
| 103 |
+
|
| 104 |
+
add_executable(mf4-f32-convert-test test/mf4-f32-convert.cc)
|
| 105 |
+
target_link_libraries(mf4-f32-convert-test PRIVATE GTest::gtest_main metal-kernels)
|
| 106 |
+
target_include_directories(mf4-f32-convert-test PRIVATE source/include)
|
| 107 |
+
add_test(NAME mf4-f32-convert-test COMMAND mf4-f32-convert-test)
|
| 108 |
+
|
| 109 |
+
add_executable(bf16-f32-embeddings-test test/bf16-f32-embeddings.cc)
|
| 110 |
+
target_link_libraries(bf16-f32-embeddings-test PRIVATE GTest::gtest_main metal-kernels)
|
| 111 |
+
target_include_directories(bf16-f32-embeddings-test PRIVATE source/include)
|
| 112 |
+
add_test(NAME bf16-f32-embeddings-test COMMAND bf16-f32-embeddings-test)
|
| 113 |
+
|
| 114 |
+
add_executable(f32-bf16w-rmsnorm-test test/f32-bf16w-rmsnorm.cc)
|
| 115 |
+
target_link_libraries(f32-bf16w-rmsnorm-test PRIVATE GTest::gtest_main metal-kernels)
|
| 116 |
+
target_include_directories(f32-bf16w-rmsnorm-test PRIVATE source/include)
|
| 117 |
+
add_test(NAME f32-bf16w-rmsnorm-test COMMAND f32-bf16w-rmsnorm-test)
|
| 118 |
+
|
| 119 |
+
add_executable(f32-bf16w-matmul-test test/f32-bf16w-matmul.cc)
|
| 120 |
+
target_link_libraries(f32-bf16w-matmul-test PRIVATE GTest::gtest_main metal-kernels)
|
| 121 |
+
target_include_directories(f32-bf16w-matmul-test PRIVATE source/include)
|
| 122 |
+
add_test(NAME f32-bf16w-matmul-test COMMAND f32-bf16w-matmul-test)
|
| 123 |
+
|
| 124 |
+
add_executable(f32-rope-test test/f32-rope.cc)
|
| 125 |
+
target_link_libraries(f32-rope-test PRIVATE GTest::gtest_main metal-kernels)
|
| 126 |
+
target_include_directories(f32-rope-test PRIVATE source/include)
|
| 127 |
+
add_test(NAME f32-rope-test COMMAND f32-rope-test)
|
| 128 |
+
|
| 129 |
+
# --- [ Benchmarks
|
| 130 |
+
include(FetchContent)
|
| 131 |
+
set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable self-tests in Google Benchmark" FORCE)
|
| 132 |
+
set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable installation of Google Benchmark" FORCE)
|
| 133 |
+
FetchContent_Declare(
|
| 134 |
+
benchmark
|
| 135 |
+
URL https://github.com/google/benchmark/archive/refs/tags/v1.9.4.zip
|
| 136 |
+
DOWNLOAD_EXTRACT_TIMESTAMP OFF
|
| 137 |
+
)
|
| 138 |
+
FetchContent_MakeAvailable(benchmark)
|
| 139 |
+
|
| 140 |
+
add_executable(f32-random-bench benchmark/f32-random.cc)
|
| 141 |
+
target_link_libraries(f32-random-bench PRIVATE benchmark::benchmark metal-kernels)
|
| 142 |
+
target_include_directories(f32-random-bench PRIVATE source/include)
|
| 143 |
+
|
| 144 |
+
add_executable(u32-random-bench benchmark/u32-random.cc)
|
| 145 |
+
target_link_libraries(u32-random-bench PRIVATE benchmark::benchmark metal-kernels)
|
| 146 |
+
target_include_directories(u32-random-bench PRIVATE source/include)
|
| 147 |
+
|
| 148 |
+
add_executable(mf4-f32-convert-bench benchmark/mf4-f32-convert.cc)
|
| 149 |
+
target_link_libraries(mf4-f32-convert-bench PRIVATE benchmark::benchmark metal-kernels)
|
| 150 |
+
target_include_directories(mf4-f32-convert-bench PRIVATE source/include)
|
| 151 |
+
|
| 152 |
+
add_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc)
|
| 153 |
+
target_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels)
|
| 154 |
+
target_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include)
|
| 155 |
+
|
| 156 |
+
add_executable(end-to-end-bench benchmark/end-to-end.cc)
|
| 157 |
+
target_link_libraries(end-to-end-bench PRIVATE benchmark::benchmark gptoss)
|
| 158 |
+
target_include_directories(end-to-end-bench PRIVATE source/include)
|
| 159 |
+
|
| 160 |
+
add_executable(end-to-end-threadgroup-bench benchmark/end-to-end-threadgroup.cc)
|
| 161 |
+
target_link_libraries(end-to-end-threadgroup-bench PRIVATE benchmark::benchmark gptoss)
|
| 162 |
+
target_include_directories(end-to-end-threadgroup-bench PRIVATE source/include)
|
| 163 |
+
|
| 164 |
+
# --- [ Python extension ] -----------------------------------------------
|
| 165 |
+
find_package(pybind11 CONFIG REQUIRED) # provides pybind11_add_module
|
| 166 |
+
|
| 167 |
+
pybind11_add_module(_metal
|
| 168 |
+
python/module.c
|
| 169 |
+
python/context.c
|
| 170 |
+
python/model.c
|
| 171 |
+
python/tokenizer.c
|
| 172 |
+
)
|
| 173 |
+
set_target_properties(_metal PROPERTIES PREFIX "")
|
| 174 |
+
|
| 175 |
+
target_link_libraries(_metal PRIVATE gptoss)
|
| 176 |
+
add_dependencies(_metal build_metallib)
|
| 177 |
+
target_link_options(_metal PRIVATE
|
| 178 |
+
LINKER:-sectcreate,__METAL,__shaders,${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
| 179 |
+
)
|
| 180 |
+
add_custom_command(TARGET _metal POST_BUILD
|
| 181 |
+
COMMAND ${CMAKE_COMMAND} -E copy
|
| 182 |
+
${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
| 183 |
+
$<TARGET_FILE_DIR:_metal>)
|
| 184 |
+
|
| 185 |
+
# 1️⃣ install the extension module into the Python package
|
| 186 |
+
install(TARGETS _metal LIBRARY DESTINATION gpt_oss/metal)
|
| 187 |
+
|
| 188 |
+
# 2️⃣ make sure the Metal shader archive travels with it
|
| 189 |
+
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
|
| 190 |
+
DESTINATION gpt_oss/metal)
|
| 191 |
+
# ------------------------------------------------------------------------
|
gptoss_kernels/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from importlib import import_module as _im
|
| 2 |
+
|
| 3 |
+
# Load the compiled extension (gpt_oss.metal._metal)
|
| 4 |
+
_ext = _im(f"{__name__}._metal")
|
| 5 |
+
globals().update({k: v for k, v in _ext.__dict__.items() if not k.startswith("_")})
|
| 6 |
+
del _im, _ext
|
gptoss_kernels/examples/chat.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
from datetime import date
|
| 7 |
+
from gpt_oss.metal import Context, Model
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
DEFAULT_PROMPT = f"""You are ChatGPT, a large language model trained by OpenAI.
|
| 11 |
+
Knowledge cutoff: 2024-06
|
| 12 |
+
Current date: {date.today().isoformat()}
|
| 13 |
+
|
| 14 |
+
reasoning effort high
|
| 15 |
+
|
| 16 |
+
# Valid channels: analysis, final. Channel must be included for every message."""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
parser = argparse.ArgumentParser(description="Chat with gpt-oss", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 20 |
+
parser.add_argument("model", metavar="PATH", type=str, help="Path to gpt-oss model in Metal inference format")
|
| 21 |
+
parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="System prompt")
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--context-length", type=int, default=0, help="The maximum context length"
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--temperature", type=float, default=1.0, help="Sampling temperature"
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--seed", type=int, default=0, help="Sampling seed"
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
GREY = "\33[90m"
|
| 34 |
+
BOLD = "\33[1m"
|
| 35 |
+
RESET = "\33[0m"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main(args):
|
| 39 |
+
options = parser.parse_args(args)
|
| 40 |
+
model = Model(options.model)
|
| 41 |
+
tokenizer = model.tokenizer
|
| 42 |
+
start_token = tokenizer.encode_special_token("<|start|>")
|
| 43 |
+
message_token = tokenizer.encode_special_token("<|message|>")
|
| 44 |
+
end_token = tokenizer.encode_special_token("<|end|>")
|
| 45 |
+
return_token = tokenizer.encode_special_token("<|return|>")
|
| 46 |
+
channel_token = tokenizer.encode_special_token("<|channel|>")
|
| 47 |
+
|
| 48 |
+
context = Context(model, context_length=options.context_length)
|
| 49 |
+
context.append(start_token)
|
| 50 |
+
context.append("system")
|
| 51 |
+
context.append(message_token)
|
| 52 |
+
context.append(options.prompt)
|
| 53 |
+
context.append(end_token)
|
| 54 |
+
|
| 55 |
+
while True:
|
| 56 |
+
context.append(start_token)
|
| 57 |
+
context.append("user")
|
| 58 |
+
context.append(message_token)
|
| 59 |
+
message = input(f"{BOLD}User:{RESET} ").rstrip()
|
| 60 |
+
context.append(message)
|
| 61 |
+
context.append(end_token)
|
| 62 |
+
print(f"{BOLD}Assistant:{RESET} {GREY}", end="", flush=True)
|
| 63 |
+
context.append(start_token)
|
| 64 |
+
context.append("assistant")
|
| 65 |
+
context.append(channel_token)
|
| 66 |
+
|
| 67 |
+
inside_start_block = True
|
| 68 |
+
inside_channel_block = True
|
| 69 |
+
role = "assistant"
|
| 70 |
+
channel = ""
|
| 71 |
+
while True:
|
| 72 |
+
token = context.sample(
|
| 73 |
+
temperature=options.temperature,
|
| 74 |
+
seed=options.seed,
|
| 75 |
+
)
|
| 76 |
+
context.append(token)
|
| 77 |
+
if token == return_token:
|
| 78 |
+
print(flush=True)
|
| 79 |
+
break
|
| 80 |
+
elif token == start_token:
|
| 81 |
+
inside_start_block = True
|
| 82 |
+
role = ""
|
| 83 |
+
channel = ""
|
| 84 |
+
elif token == message_token:
|
| 85 |
+
inside_start_block = False
|
| 86 |
+
inside_channel_block = False
|
| 87 |
+
if channel == "analysis":
|
| 88 |
+
print(f"{GREY}", end="", flush=True)
|
| 89 |
+
elif token == end_token:
|
| 90 |
+
print(f"{RESET}", flush=True)
|
| 91 |
+
elif token == channel_token:
|
| 92 |
+
inside_channel_block = True
|
| 93 |
+
elif token < tokenizer.num_text_tokens:
|
| 94 |
+
if inside_channel_block:
|
| 95 |
+
channel += str(tokenizer.decode(token), encoding="utf-8")
|
| 96 |
+
elif inside_start_block:
|
| 97 |
+
role += str(tokenizer.decode(token), encoding="utf-8")
|
| 98 |
+
else:
|
| 99 |
+
sys.stdout.buffer.write(tokenizer.decode(token))
|
| 100 |
+
sys.stdout.buffer.flush()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
main(sys.argv[1:])
|
gptoss_kernels/examples/generate.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
from gpt_oss.metal import Context, Model
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
parser = argparse.ArgumentParser(description='Chat with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 10 |
+
parser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss checkpoint')
|
| 11 |
+
parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt')
|
| 12 |
+
parser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate')
|
| 13 |
+
parser.add_argument('--context-length', type=int, default=0, help='The maximum context length')
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main(args):
|
| 17 |
+
options = parser.parse_args(args)
|
| 18 |
+
model = Model(options.model)
|
| 19 |
+
|
| 20 |
+
context = Context(model, context_length=options.context_length)
|
| 21 |
+
context.append(options.prompt)
|
| 22 |
+
print(context.tokens)
|
| 23 |
+
prompt_tokens = context.num_tokens
|
| 24 |
+
|
| 25 |
+
tokenizer = model.tokenizer
|
| 26 |
+
|
| 27 |
+
while context.num_tokens - prompt_tokens < options.limit:
|
| 28 |
+
token = context.sample()
|
| 29 |
+
context.append(token)
|
| 30 |
+
print(str(tokenizer.decode(token), encoding="utf-8"), end='', flush=True)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
main(sys.argv[1:])
|
gptoss_kernels/include/gpt-oss.h
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <gpt-oss/macros.h>
|
| 4 |
+
#include <gpt-oss/types.h>
|
| 5 |
+
#include <gpt-oss/functions.h>
|
gptoss_kernels/include/gpt-oss/functions.h
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <stddef.h>
|
| 4 |
+
#include <stdint.h>
|
| 5 |
+
|
| 6 |
+
#include <gpt-oss/macros.h>
|
| 7 |
+
#include <gpt-oss/types.h>
|
| 8 |
+
|
| 9 |
+
#ifdef __cplusplus
|
| 10 |
+
extern "C" {
|
| 11 |
+
#endif
|
| 12 |
+
|
| 13 |
+
/*
|
| 14 |
+
* Creates a Model object from a file in the filesystem.
|
| 15 |
+
*
|
| 16 |
+
* @param path Path to the file containing the model in GPT-OSS format.
|
| 17 |
+
* @param model_out Pointer to the Model object that will be created. Must be released with gptoss_release_model.
|
| 18 |
+
*
|
| 19 |
+
* On success, returns gptoss_status_success and saves a pointer to the created Model in the model_out argument.
|
| 20 |
+
* On failure, returns an error code and stores null pointer in the model_out argument.
|
| 21 |
+
*/
|
| 22 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
|
| 23 |
+
const char* path,
|
| 24 |
+
gptoss_model_t* model_out);
|
| 25 |
+
|
| 26 |
+
/*
|
| 27 |
+
* Query the Tokenizer object associated with the Model.
|
| 28 |
+
*
|
| 29 |
+
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
|
| 30 |
+
* @param tokenizer_out Pointer to the variable where the Tokenizer reference will be stored.
|
| 31 |
+
*
|
| 32 |
+
* On success, returns gptoss_status_success and stores reference to the Tokenizer object in the tokenizer_out argument.
|
| 33 |
+
* On failure, returns an error code and stores NULL in the tokenizer_out argument.
|
| 34 |
+
*/
|
| 35 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
|
| 36 |
+
gptoss_model_t model,
|
| 37 |
+
gptoss_tokenizer_t* tokenizer_out);
|
| 38 |
+
|
| 39 |
+
/*
|
| 40 |
+
* Query the maximum context length supported by the Model.
|
| 41 |
+
*
|
| 42 |
+
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
|
| 43 |
+
* @param max_context_length_out Pointer to the variable where the maximum context length will be stored.
|
| 44 |
+
*
|
| 45 |
+
* On success, returns gptoss_status_success and stores maximum context length in the max_context_length_out argument.
|
| 46 |
+
* On failure, returns an error code and leaves the value specified by max_context_length_out unchanged.
|
| 47 |
+
*/
|
| 48 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
|
| 49 |
+
gptoss_model_t model,
|
| 50 |
+
size_t* max_context_length_out);
|
| 51 |
+
|
| 52 |
+
/*
|
| 53 |
+
* Increments a Model object's reference count.
|
| 54 |
+
*
|
| 55 |
+
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
|
| 56 |
+
*
|
| 57 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 58 |
+
*/
|
| 59 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_retain(
|
| 60 |
+
gptoss_model_t model);
|
| 61 |
+
|
| 62 |
+
/*
|
| 63 |
+
* Decrements a Model object's reference count and possibly release associated resources.
|
| 64 |
+
*
|
| 65 |
+
* @param model Pointer to the Model object created by gptoss_model_create_from_file.
|
| 66 |
+
*
|
| 67 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 68 |
+
*/
|
| 69 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_release(
|
| 70 |
+
gptoss_model_t model);
|
| 71 |
+
|
| 72 |
+
/*
|
| 73 |
+
* Query the token ID for a special token in the Tokenizer vocabulary.
|
| 74 |
+
*
|
| 75 |
+
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
|
| 76 |
+
* @param token_type Type of the special token to query an ID for.
|
| 77 |
+
* @param token_id_out Pointer to the variable where the token ID will be stored.
|
| 78 |
+
*
|
| 79 |
+
* On success, returns gptoss_status_success and stores the token ID in the token_id_out argument.
|
| 80 |
+
* On failure, returns an error code and leaves the value specified by token_id_out unchanged.
|
| 81 |
+
*/
|
| 82 |
+
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
|
| 83 |
+
gptoss_tokenizer_t tokenizer,
|
| 84 |
+
enum gptoss_special_token token_type,
|
| 85 |
+
uint32_t* token_id_out);
|
| 86 |
+
|
| 87 |
+
/*
|
| 88 |
+
* Query the number of text tokens in the Tokenizer vocabulary.
|
| 89 |
+
*
|
| 90 |
+
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
|
| 91 |
+
* @param num_text_tokens_out Pointer to the variable where the number of text tokens will be stored.
|
| 92 |
+
*
|
| 93 |
+
* On success, returns gptoss_status_success and stores the number of text tokens in the num_text_tokens_out argument.
|
| 94 |
+
* On failure, returns an error code and leaves the value specified by num_text_tokens_out unchanged.
|
| 95 |
+
*/
|
| 96 |
+
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
|
| 97 |
+
gptoss_tokenizer_t tokenizer,
|
| 98 |
+
uint32_t* num_text_tokens_out);
|
| 99 |
+
|
| 100 |
+
/*
|
| 101 |
+
* Query the number of special tokens in the Tokenizer vocabulary.
|
| 102 |
+
*
|
| 103 |
+
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
|
| 104 |
+
* @param num_special_tokens_out Pointer to the variable where the number of special tokens will be stored.
|
| 105 |
+
*
|
| 106 |
+
* On success, returns gptoss_status_success and stores the number of special tokens in the num_special_tokens_out argument.
|
| 107 |
+
* On failure, returns an error code and leaves the value specified by num_special_tokens_out unchanged.
|
| 108 |
+
*/
|
| 109 |
+
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
|
| 110 |
+
gptoss_tokenizer_t tokenizer,
|
| 111 |
+
uint32_t* num_special_tokens_out);
|
| 112 |
+
|
| 113 |
+
/*
|
| 114 |
+
* Query the total number of tokens in the Tokenizer vocabulary.
|
| 115 |
+
*
|
| 116 |
+
* @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
|
| 117 |
+
* @param num_tokens_out Pointer to the variable where the total number of tokens will be stored.
|
| 118 |
+
*
|
| 119 |
+
* On success, returns gptoss_status_success and stores the total number of tokens in the num_tokens_out argument.
|
| 120 |
+
* On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.
|
| 121 |
+
*/
|
| 122 |
+
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
|
| 123 |
+
gptoss_tokenizer_t tokenizer,
|
| 124 |
+
uint32_t* num_tokens_out);
|
| 125 |
+
|
| 126 |
+
/*
|
| 127 |
+
* Convert a text token ID to byte representation.
|
| 128 |
+
*
|
| 129 |
+
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer. The lifetime of the returned
|
| 130 |
+
* byte representation would match the lifetime of this Tokenizer object.
|
| 131 |
+
* @param token_ptr_out Pointer to the variable where the pointer to the byte representation of the token will be
|
| 132 |
+
* stored.
|
| 133 |
+
* @param token_size_out Pointer to the variable where the size of the byte representation of the token will be stored.
|
| 134 |
+
*
|
| 135 |
+
* On success, returns gptoss_status_success and stores pointer and size of the byte representation of the token in the
|
| 136 |
+
* token_ptr_out and token_size_out arguments.
|
| 137 |
+
* On failure, returns an error code and leaves the values specified in token_ptr_out and token_size_out unchanged.
|
| 138 |
+
*/
|
| 139 |
+
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
|
| 140 |
+
gptoss_tokenizer_t tokenizer,
|
| 141 |
+
uint32_t token_id,
|
| 142 |
+
const void** token_ptr_out,
|
| 143 |
+
size_t* token_size_out);
|
| 144 |
+
|
| 145 |
+
/*
|
| 146 |
+
* Increments a Tokenizer object's reference count.
|
| 147 |
+
*
|
| 148 |
+
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.
|
| 149 |
+
*
|
| 150 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 151 |
+
*/
|
| 152 |
+
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
|
| 153 |
+
gptoss_tokenizer_t tokenizer);
|
| 154 |
+
|
| 155 |
+
/*
|
| 156 |
+
* Decrements a Tokenizer object's reference count and possibly release associated resources.
|
| 157 |
+
*
|
| 158 |
+
* @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.
|
| 159 |
+
*
|
| 160 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 161 |
+
*/
|
| 162 |
+
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
|
| 163 |
+
gptoss_tokenizer_t tokenizer);
|
| 164 |
+
|
| 165 |
+
/*
|
| 166 |
+
* Creates a Context object for use with the particular Model object.
|
| 167 |
+
*
|
| 168 |
+
* @param model Model object to create a context for.
|
| 169 |
+
* @param context_length Maximum number of tokens in the context.
|
| 170 |
+
* Specify 0 to use the maximum context length supported by the model.
|
| 171 |
+
* @param max_batch_tokens Maximum number of tokens that can be processed in a single batch.
|
| 172 |
+
* Larger values may improve prefill performance, but require more memory.
|
| 173 |
+
* Specify 0 to use the default value.
|
| 174 |
+
* @param context_out Pointer to the Context object that will be created.
|
| 175 |
+
* Must be released with gptoss_context_release.
|
| 176 |
+
*
|
| 177 |
+
* On success, returns gptoss_status_success and saves a pointer to the created Context in the context_out argument.
|
| 178 |
+
* On failure, returns an error code and stores null pointer in the context_out argument.
|
| 179 |
+
*/
|
| 180 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_create(
|
| 181 |
+
gptoss_model_t model,
|
| 182 |
+
size_t context_length,
|
| 183 |
+
size_t max_batch_tokens,
|
| 184 |
+
gptoss_context_t* context_out);
|
| 185 |
+
|
| 186 |
+
/*
|
| 187 |
+
* Query the current number of tokens cached in the Context.
|
| 188 |
+
*
|
| 189 |
+
* @param context Pointer to the Context object created by gptoss_context_create.
|
| 190 |
+
* @param num_tokens_out Pointer to the variable where the current number of cached tokens will be stored.
|
| 191 |
+
*
|
| 192 |
+
* On success, returns gptoss_status_success and stores current number of cached tokens in the num_tokens_out argument.
|
| 193 |
+
* On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.
|
| 194 |
+
*/
|
| 195 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
|
| 196 |
+
gptoss_context_t context,
|
| 197 |
+
size_t* num_tokens_out);
|
| 198 |
+
|
| 199 |
+
/*
|
| 200 |
+
* Query the maximum number of tokens cached in the Context.
|
| 201 |
+
*
|
| 202 |
+
* @param context Pointer to the Context object created by gptoss_context_create.
|
| 203 |
+
* @param max_tokens_out Pointer to the variable where the maximum number of cached tokens will be stored.
|
| 204 |
+
*
|
| 205 |
+
* On success, returns gptoss_status_success and stores maximum number of cached tokens in the max_tokens_out argument.
|
| 206 |
+
* On failure, returns an error code and leaves the value specified by max_tokens_out unchanged.
|
| 207 |
+
*/
|
| 208 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
|
| 209 |
+
gptoss_context_t context,
|
| 210 |
+
size_t* max_tokens_out);
|
| 211 |
+
|
| 212 |
+
/*
|
| 213 |
+
* Query the list of token IDs cached in the Context.
|
| 214 |
+
*
|
| 215 |
+
* @param context Pointer to the Context object created by gptoss_context_create.
|
| 216 |
+
* @param tokens_out Pointer to the array where up to max_tokens_out of cached tokens will be stored.
|
| 217 |
+
* @param max_tokens Maximum capacity of the buffer specified by tokens_out.
|
| 218 |
+
* @param num_tokens_out Pointer to the variable where the actual number of cached tokens will be stored.
|
| 219 |
+
* This value can exceed max_tokens if the buffer capacity is insufficient.
|
| 220 |
+
*
|
| 221 |
+
* On success, returns gptoss_status_success and stores cached token IDs in the tokens_out argument and the number of
|
| 222 |
+
* cached tokens in the num_tokens_out argument.
|
| 223 |
+
* On failure, returns an error code and leaves the values specified by tokens_out and num_tokens_out unchanged.
|
| 224 |
+
*/
|
| 225 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
|
| 226 |
+
gptoss_context_t context,
|
| 227 |
+
uint32_t* tokens_out,
|
| 228 |
+
size_t max_tokens,
|
| 229 |
+
size_t* num_tokens_out);
|
| 230 |
+
|
| 231 |
+
/*
|
| 232 |
+
* Tokenizes and appends a character string to the Context object.
|
| 233 |
+
*
|
| 234 |
+
* @param context Context object created by gptoss_context_create.
|
| 235 |
+
* @param text Pointer to the character string to tokenize and append.
|
| 236 |
+
* @param text_length Length of the string, in chars.
|
| 237 |
+
* @param num_tokens_out Optional pointer to the variable where the number of appended tokens will be stored. Ignored if a null pointer is provided.
|
| 238 |
+
*
|
| 239 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 240 |
+
*/
|
| 241 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
|
| 242 |
+
gptoss_context_t context,
|
| 243 |
+
const char* text,
|
| 244 |
+
size_t text_length,
|
| 245 |
+
size_t* num_tokens_out);
|
| 246 |
+
|
| 247 |
+
/*
|
| 248 |
+
* Appends a list of tokens to the context.
|
| 249 |
+
*
|
| 250 |
+
* @param context Context object created by gptoss_context_create.
|
| 251 |
+
* @param num_tokens Number of tokens to be appended.
|
| 252 |
+
* @param tokens Pointer to the array of tokens to be appended.
|
| 253 |
+
*
|
| 254 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 255 |
+
*/
|
| 256 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
|
| 257 |
+
gptoss_context_t context,
|
| 258 |
+
size_t num_tokens,
|
| 259 |
+
const uint32_t* tokens);
|
| 260 |
+
|
| 261 |
+
/*
|
| 262 |
+
* Resets the context, clearing its state.
|
| 263 |
+
*
|
| 264 |
+
* @param context Context object created by gptoss_context_create.
|
| 265 |
+
*
|
| 266 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 267 |
+
*/
|
| 268 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_reset(
|
| 269 |
+
gptoss_context_t context);
|
| 270 |
+
|
| 271 |
+
/*
|
| 272 |
+
* Pre-process the tokens in the Context and generate probability distribution over the next token.
|
| 273 |
+
*
|
| 274 |
+
* @param context Context object created by gptoss_context_create.
|
| 275 |
+
*
|
| 276 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 277 |
+
*/
|
| 278 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_process(
|
| 279 |
+
gptoss_context_t context);
|
| 280 |
+
|
| 281 |
+
/*
|
| 282 |
+
* Generate a token probability distribution over the next token conditioned on the Context.
|
| 283 |
+
*
|
| 284 |
+
* @param context Context object created by gptoss_context_create.
|
| 285 |
+
* @param temperature Sampling temperature. Must be non-negative.
|
| 286 |
+
* @param seed Random number generator seed to use for sampling.
|
| 287 |
+
* @param max_tokens Maximum number of tokens to sample.
* @param tokens_out Pointer to the array where up to max_tokens sampled token IDs will be stored.
* @param num_tokens_out Pointer to the variable where the number of sampled tokens will be stored.
|
| 288 |
+
*
|
| 289 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 290 |
+
*/
|
| 291 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_sample(
|
| 292 |
+
gptoss_context_t context,
|
| 293 |
+
float temperature,
|
| 294 |
+
uint64_t seed,
|
| 295 |
+
size_t max_tokens,
|
| 296 |
+
uint32_t* tokens_out,
|
| 297 |
+
size_t* num_tokens_out);
|
| 298 |
+
|
| 299 |
+
/*
|
| 300 |
+
* Increments a Context object's reference count.
|
| 301 |
+
*
|
| 302 |
+
* @param context Pointer to the Context object created by gptoss_context_create.
|
| 303 |
+
*
|
| 304 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 305 |
+
*/
|
| 306 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_retain(
|
| 307 |
+
gptoss_context_t context);
|
| 308 |
+
|
| 309 |
+
/*
|
| 310 |
+
* Decrements a Context object's reference count and possibly release associated resources.
|
| 311 |
+
*
|
| 312 |
+
* @param context Pointer to the Context object created by gptoss_context_create.
|
| 313 |
+
*
|
| 314 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 315 |
+
*/
|
| 316 |
+
enum gptoss_status GPTOSS_ABI gptoss_context_release(
|
| 317 |
+
gptoss_context_t context);
|
| 318 |
+
|
| 319 |
+
/*
|
| 320 |
+
* Creates a Sampler object.
|
| 321 |
+
*
|
| 322 |
+
* @param sampler_out Pointer to the Sampler object that will be created.
|
| 323 |
+
* Must be released with gptoss_sampler_release.
|
| 324 |
+
*
|
| 325 |
+
* On success, returns gptoss_status_success and saves a pointer to the created Sampler in the sampler_out argument.
|
| 326 |
+
* On failure, returns an error code and stores a null pointer in the sampler_out argument.
|
| 327 |
+
*/
|
| 328 |
+
enum gptoss_status GPTOSS_ABI gptoss_sampler_create(
|
| 329 |
+
gptoss_sampler_t* sampler_out);
|
| 330 |
+
|
| 331 |
+
/*
|
| 332 |
+
* Sets the sampling temperature for the Sampler.
|
| 333 |
+
*
|
| 334 |
+
* @param sampler Sampler object created by gptoss_sampler_create.
|
| 335 |
+
* @param temperature Temperature value to be set. Must be in the [0.0, 1.0] range.
|
| 336 |
+
*
|
| 337 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 338 |
+
*/
|
| 339 |
+
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_temperature(
|
| 340 |
+
gptoss_sampler_t sampler,
|
| 341 |
+
float temperature);
|
| 342 |
+
|
| 343 |
+
/*
|
| 344 |
+
* Sets the Top-P nucleus sampling parameter for the Sampler.
|
| 345 |
+
*
|
| 346 |
+
* @param sampler Sampler object created by gptoss_sampler_create.
|
| 347 |
+
* @param top_p Top-P value to be set. Must be in the (0.0, 1.0] range.
|
| 348 |
+
*
|
| 349 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 350 |
+
*/
|
| 351 |
+
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_top_p(
|
| 352 |
+
gptoss_sampler_t sampler,
|
| 353 |
+
float top_p);
|
| 354 |
+
|
| 355 |
+
/*
|
| 356 |
+
* Sets the presence penalty for the Sampler.
|
| 357 |
+
*
|
| 358 |
+
* @param sampler Sampler object created by gptoss_sampler_create.
|
| 359 |
+
* @param presence_penalty Presence penalty value to be set. Must be in the [-2.0, 2.0] range.
|
| 360 |
+
*
|
| 361 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 362 |
+
*/
|
| 363 |
+
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_presence_penalty(
|
| 364 |
+
gptoss_sampler_t sampler,
|
| 365 |
+
float presence_penalty);
|
| 366 |
+
|
| 367 |
+
/*
|
| 368 |
+
* Sets the frequency penalty for the Sampler.
|
| 369 |
+
*
|
| 370 |
+
* @param sampler Sampler object created by gptoss_sampler_create.
|
| 371 |
+
* @param frequency_penalty Frequency penalty value to be set. Must be in the [-2.0, 2.0] range.
|
| 372 |
+
*
|
| 373 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 374 |
+
*/
|
| 375 |
+
enum gptoss_status GPTOSS_ABI gptoss_sampler_set_frequency_penalty(
|
| 376 |
+
gptoss_sampler_t sampler,
|
| 377 |
+
float frequency_penalty);
|
| 378 |
+
|
| 379 |
+
/*
|
| 380 |
+
* Increments a Sampler object's reference count.
|
| 381 |
+
*
|
| 382 |
+
* @param sampler Pointer to the Sampler object created by gptoss_sampler_create.
|
| 383 |
+
*
|
| 384 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 385 |
+
*/
|
| 386 |
+
enum gptoss_status GPTOSS_ABI gptoss_sampler_retain(
|
| 387 |
+
gptoss_sampler_t sampler);
|
| 388 |
+
|
| 389 |
+
/*
|
| 390 |
+
* Decrements a Sampler object's reference count and possibly releases associated resources.
|
| 391 |
+
*
|
| 392 |
+
* @param sampler Pointer to the Sampler object created by gptoss_sampler_create.
|
| 393 |
+
*
|
| 394 |
+
* On success, returns gptoss_status_success, otherwise returns an error code.
|
| 395 |
+
*/
|
| 396 |
+
enum gptoss_status GPTOSS_ABI gptoss_sampler_release(
|
| 397 |
+
gptoss_sampler_t sampler);
|
| 398 |
+
|
| 399 |
+
#ifdef __cplusplus
|
| 400 |
+
} // extern "C"
|
| 401 |
+
#endif
|
gptoss_kernels/include/gpt-oss/macros.h
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifndef GPTOSS_ABI
|
| 4 |
+
#define GPTOSS_ABI
|
| 5 |
+
#endif // GPTOSS_ABI
|
gptoss_kernels/include/gpt-oss/types.h
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
/*
|
| 4 |
+
* Status codes returned by GPT-OSS API functions.
|
| 5 |
+
*/
|
| 6 |
+
enum gptoss_status {
|
| 7 |
+
gptoss_status_success = 0,
|
| 8 |
+
gptoss_status_invalid_argument = 1,
|
| 9 |
+
gptoss_status_unsupported_argument = 2,
|
| 10 |
+
gptoss_status_invalid_state = 3,
|
| 11 |
+
gptoss_status_io_error = 4,
|
| 12 |
+
gptoss_status_insufficient_memory = 5,
|
| 13 |
+
gptoss_status_insufficient_resources = 6,
|
| 14 |
+
gptoss_status_unsupported_system = 7,
|
| 15 |
+
gptoss_status_context_overflow = 8,
|
| 16 |
+
};
|
| 17 |
+
|
| 18 |
+
enum gptoss_special_token {
|
| 19 |
+
gptoss_special_token_invalid = 0,
|
| 20 |
+
gptoss_special_token_return = 1,
|
| 21 |
+
gptoss_special_token_start = 2,
|
| 22 |
+
gptoss_special_token_message = 3,
|
| 23 |
+
gptoss_special_token_end = 4,
|
| 24 |
+
gptoss_special_token_refusal = 5,
|
| 25 |
+
gptoss_special_token_constrain = 6,
|
| 26 |
+
gptoss_special_token_channel = 7,
|
| 27 |
+
gptoss_special_token_call = 8,
|
| 28 |
+
gptoss_special_token_untrusted = 9,
|
| 29 |
+
gptoss_special_token_end_untrusted = 10,
|
| 30 |
+
gptoss_special_token_max,
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
/*
|
| 34 |
+
* Model object is an opaque container comprised of:
|
| 35 |
+
* - Weights
|
| 36 |
+
* - Temporary buffers required to run the model
|
| 37 |
+
* - Any other resources required to run the model
|
| 38 |
+
*/
|
| 39 |
+
typedef struct gptoss_model* gptoss_model_t;
|
| 40 |
+
|
| 41 |
+
typedef struct gptoss_tokenizer* gptoss_tokenizer_t;
|
| 42 |
+
|
| 43 |
+
/*
|
| 44 |
+
* Context is an opaque container comprised of:
|
| 45 |
+
* - Input tokens
|
| 46 |
+
* - Distribution over the output tokens
|
| 47 |
+
* - KV cache
|
| 48 |
+
*
|
| 49 |
+
* Multiple contexts can be created and used with the same model.
|
| 50 |
+
*/
|
| 51 |
+
typedef struct gptoss_context* gptoss_context_t;
|
| 52 |
+
|
| 53 |
+
/*
|
| 54 |
+
* Sampler is an opaque container for sampling parameters:
|
| 55 |
+
* - Temperature
|
| 56 |
+
* - Top-p (nucleus sampling)
|
| 57 |
+
* - Frequency penalty
|
| 58 |
+
* - Presence penalty
|
| 59 |
+
*
|
| 60 |
+
* Multiple samplers can be created and used with the same context.
|
| 61 |
+
*/
|
| 62 |
+
typedef struct gptoss_sampler* gptoss_sampler_t;
|
gptoss_kernels/source/accumulate.metal
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_integer>
|
| 2 |
+
#include <metal_math>
|
| 3 |
+
|
| 4 |
+
#include <internal/kernel-args.h>
|
| 5 |
+
|
| 6 |
+
#pragma METAL fp math_mode(safe)
|
| 7 |
+
#pragma METAL fp contract(off)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
kernel void gptoss_f32_accumulate_e4(
|
| 11 |
+
constant gptoss_accumulate_args& args [[ buffer(0) ]],
|
| 12 |
+
const device float4* input [[ buffer(1) ]],
|
| 13 |
+
const device gptoss_expert_prediction* expert [[ buffer(2) ]],
|
| 14 |
+
device float4* output [[ buffer(3) ]],
|
| 15 |
+
const device gptoss_control* control [[ buffer(4) ]],
|
| 16 |
+
uint2 gid [[threadgroup_position_in_grid]],
|
| 17 |
+
uint tid [[thread_index_in_threadgroup]],
|
| 18 |
+
uint2 threadgroup_size [[ threads_per_threadgroup ]])
|
| 19 |
+
{
|
| 20 |
+
const uint num_active_experts = 4;
|
| 21 |
+
if (control->abort != 0) {
|
| 22 |
+
return;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
|
| 26 |
+
const uint threadgroup_start = gid.x * num_vecs_per_threadgroup;
|
| 27 |
+
const uint num_vecs = args.num_vecs;
|
| 28 |
+
const uint threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, num_vecs);
|
| 29 |
+
const uint thread_start = threadgroup_start + tid;
|
| 30 |
+
uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size.x - 1)) / threadgroup_size.x);
|
| 31 |
+
|
| 32 |
+
const uint num_vecs_per_expert = args.num_vecs_per_expert;
|
| 33 |
+
const float scale0 = expert[gid.y * num_active_experts + 0].score;
|
| 34 |
+
const device float4* input0 = input + gid.y * num_vecs + thread_start;
|
| 35 |
+
const float scale1 = expert[gid.y * num_active_experts + 1].score;
|
| 36 |
+
const device float4* input1 = input0 + num_vecs_per_expert;
|
| 37 |
+
const float scale2 = expert[gid.y * num_active_experts + 2].score;
|
| 38 |
+
const device float4* input2 = input1 + num_vecs_per_expert;
|
| 39 |
+
const float scale3 = expert[gid.y * num_active_experts + 3].score;
|
| 40 |
+
const device float4* input3 = input2 + num_vecs_per_expert;
|
| 41 |
+
output += gid.y * num_vecs + thread_start;
|
| 42 |
+
for (; num_iter != 0; num_iter--) {
|
| 43 |
+
float4 acc = *output;
|
| 44 |
+
const float4 val0 = *input0;
|
| 45 |
+
const float4 val1 = *input1;
|
| 46 |
+
const float4 val2 = *input2;
|
| 47 |
+
const float4 val3 = *input3;
|
| 48 |
+
input0 += threadgroup_size.x;
|
| 49 |
+
acc = metal::fma(val0, scale0, acc);
|
| 50 |
+
input1 += threadgroup_size.x;
|
| 51 |
+
acc = metal::fma(val1, scale1, acc);
|
| 52 |
+
input2 += threadgroup_size.x;
|
| 53 |
+
acc = metal::fma(val2, scale2, acc);
|
| 54 |
+
input3 += threadgroup_size.x;
|
| 55 |
+
acc = metal::fma(val3, scale3, acc);
|
| 56 |
+
*output = acc;
|
| 57 |
+
output += threadgroup_size.x;
|
| 58 |
+
}
|
| 59 |
+
}
|
gptoss_kernels/source/context.c
ADDED
|
@@ -0,0 +1,1115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <assert.h>
|
| 2 |
+
#include <float.h>
|
| 3 |
+
#include <inttypes.h>
|
| 4 |
+
#include <stdbool.h>
|
| 5 |
+
#include <stdint.h>
|
| 6 |
+
#include <stdlib.h>
|
| 7 |
+
#include <string.h>
|
| 8 |
+
|
| 9 |
+
#include <gpt-oss.h>
|
| 10 |
+
|
| 11 |
+
#include "internal/datatype.h"
|
| 12 |
+
#include "internal/model.h"
|
| 13 |
+
#include "internal/metal.h"
|
| 14 |
+
#include "internal/metal-kernels.h"
|
| 15 |
+
#include "internal/log.h"
|
| 16 |
+
#include "internal/rng.h"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
/// Creates an inference context for the given model.
///
/// @param model            Model the context is bound to; retained on success.
/// @param context_length   Maximum number of tokens the context can hold.
///                         0 selects the model's full context length; values
///                         above the model's limit are rejected.
/// @param max_batch_tokens Maximum number of tokens processed per batch.
///                         0 selects GPTOSS_DEFAULT_BATCH_SIZE; values above
///                         context_length are rejected.
/// @param context_out      Receives the new context on success, NULL otherwise.
///
/// @return gptoss_status_success, gptoss_status_invalid_argument on bad sizes,
///         or gptoss_status_insufficient_memory if any allocation fails.
enum gptoss_status GPTOSS_ABI gptoss_context_create(
    gptoss_model_t model,
    size_t context_length,
    size_t max_batch_tokens,
    gptoss_context_t* context_out)
{
    *context_out = NULL;

    enum gptoss_status status = gptoss_status_success;
    struct gptoss_context* context = NULL;

    // Validate context_length: 0 means "use the model's full context length".
    if (context_length == 0) {
        context_length = model->context_length;
    } else if (context_length > model->context_length) {
        GPTOSS_LOG_ERROR("requested context length %zu exceeds model context length %" PRIu32,
            context_length, model->context_length);
        status = gptoss_status_invalid_argument;
        goto cleanup;
    }
    assert(context_length != 0);
    assert(context_length <= model->context_length);

    // Validate max_batch_tokens: 0 means "use the default batch size".
    if (max_batch_tokens == 0) {
        max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE;
    } else if (max_batch_tokens > context_length) {
        GPTOSS_LOG_ERROR("requested max batch tokens %zu exceeds context length %zu",
            max_batch_tokens, context_length);
        status = gptoss_status_invalid_argument;
        goto cleanup;
    }
    assert(max_batch_tokens != 0);
    assert(max_batch_tokens <= context_length);

    // calloc zero-initializes the struct in one call (replaces malloc+memset).
    context = calloc(1, sizeof(struct gptoss_context));
    if (context == NULL) {
        GPTOSS_LOG_ERROR("failed to allocate %zu bytes for Context object",
            sizeof(struct gptoss_context));
        status = gptoss_status_insufficient_memory;
        goto cleanup;
    }

    atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);
    context->max_tokens = context_length;
    context->max_batch_tokens = max_batch_tokens;

    // Activation buffers, all sized for a single batch of max_batch_tokens.
    // Residual stream: [max_batch_tokens, embedding_dim].
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Post-RMSNorm activations: [max_batch_tokens, embedding_dim].
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Fused QKV projections: Q heads plus K and V for each KV head.
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Attention output: [max_batch_tokens, num_heads * head_dim].
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // MoE router logits: [max_batch_tokens, num_experts].
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Top-K expert predictions per token.
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // The last entry will hold the total number of tokens.
    status = gptoss_metal_buffer_create(&model->device, (1 + model->num_experts) * sizeof(uint32_t), NULL, &context->expert_offset_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Per (token, active expert) routing indices used by the dense MoE path.
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * sizeof(uint32_t), NULL, &context->token_to_expert_routing_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Expert-grouped (scattered) inputs for the dense MoE SwiGLU matmul.
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->swiglu_input_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // SwiGLU outputs: [max_batch_tokens * num_active_experts, mlp_dim].
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Per-expert MoE projection outputs, accumulated back into the residual.
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }

    // Input/output buffers
    status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Token ids for the whole context (not just one batch).
    status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Unembedding logits: [max_batch_tokens, vocabulary_size].
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->score_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Softmax probabilities: [max_batch_tokens, vocabulary_size].
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->prob_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Per-threadgroup partial sums for softmax reduction.
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->max_threadgroups * sizeof(float), NULL, &context->sum_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // Packed (score, token) argmax result per output token.
    status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * sizeof(uint64_t), NULL, &context->argmax_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    // KV cache: per block, per KV head, context_length positions of K and V.
    status = gptoss_metal_buffer_create(&model->device, model->num_blocks * context_length * 2 * model->num_kv_heads * model->head_dim * sizeof(float), NULL, &context->kvcache_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }

    context->kvcache_size = context->kvcache_buffer.size;
    // NOTE(review): control_buffer, prob_buffer, and sum_buffer sizes are not
    // counted below — presumably intentional (or an oversight); confirm whether
    // allocation_size is meant to be exhaustive before relying on it.
    context->allocation_size =
        context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
        context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
        context->gate_activation_buffer.size + context->expert_activation_buffer.size +
        context->expert_offset_buffer.size + context->token_to_expert_routing_buffer.size + context->swiglu_input_buffer.size +
        context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
        context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;

    context->model = model;
    gptoss_model_retain(model);
    *context_out = context;
    context = NULL;

cleanup:
    // On success context is NULL here; on failure this releases the partially
    // constructed object and all buffers created so far.
    gptoss_context_release(context);
    return status;
}
|
| 162 |
+
|
| 163 |
+
/// Reports how many tokens are currently stored in the context.
///
/// @param context         Context to query.
/// @param num_tokens_out  Receives the current token count.
///
/// @return gptoss_status_success (the query cannot fail).
enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
    gptoss_context_t context,
    size_t* num_tokens_out)
{
    const size_t current_count = context->num_tokens;
    *num_tokens_out = current_count;
    return gptoss_status_success;
}
|
| 170 |
+
|
| 171 |
+
/// Reports the maximum number of tokens the context can hold
/// (fixed at creation time).
///
/// @param context         Context to query.
/// @param max_tokens_out  Receives the context capacity in tokens.
///
/// @return gptoss_status_success (the query cannot fail).
enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
    gptoss_context_t context,
    size_t* max_tokens_out)
{
    const size_t capacity = context->max_tokens;
    *max_tokens_out = capacity;
    return gptoss_status_success;
}
|
| 178 |
+
|
| 179 |
+
/// Copies the context's tokens into a caller-provided array.
///
/// @param context         Context to read from.
/// @param tokens_out      Destination array for the token ids.
/// @param max_tokens      Capacity of tokens_out, in tokens.
/// @param num_tokens_out  Always receives the context's token count, even
///                        when the destination is too small.
///
/// @return gptoss_status_success, or gptoss_status_insufficient_memory if
///         max_tokens is smaller than the number of stored tokens (nothing
///         is copied in that case).
enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
    gptoss_context_t context,
    uint32_t* tokens_out,
    size_t max_tokens,
    size_t* num_tokens_out)
{
    const size_t token_count = context->num_tokens;

    // Report the required size unconditionally so callers can retry
    // with a big-enough buffer.
    *num_tokens_out = token_count;
    if (max_tokens < token_count) {
        return gptoss_status_insufficient_memory;
    }

    if (token_count != 0) {
        memcpy(tokens_out, context->token_buffer.ptr, token_count * sizeof(uint32_t));
    }
    return gptoss_status_success;
}
|
| 195 |
+
|
| 196 |
+
// Prefill: input_tokens_offset = number of tokens in KV cache, num_input_tokens > 0, num_output_tokens = 0.
|
| 197 |
+
// Sampling: input_tokens_offset = number of tokens in the context - 1, num_input_tokens = 1, num_output_tokens = 1.
|
| 198 |
+
// Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens.
|
| 199 |
+
static enum gptoss_status process_tokens(
|
| 200 |
+
gptoss_context_t context,
|
| 201 |
+
struct gptoss_metal_command_buffer* command_buffer,
|
| 202 |
+
size_t input_tokens_offset,
|
| 203 |
+
size_t num_input_tokens,
|
| 204 |
+
size_t num_output_tokens)
|
| 205 |
+
{
|
| 206 |
+
assert(num_input_tokens != 0);
|
| 207 |
+
assert(num_input_tokens <= context->max_batch_tokens);
|
| 208 |
+
assert(num_output_tokens <= context->max_batch_tokens);
|
| 209 |
+
assert(num_input_tokens >= num_output_tokens);
|
| 210 |
+
const size_t dense_matmul_kernel_token_multiple_constraint = 64;
|
| 211 |
+
const size_t min_tokens_for_dense_moe_kernels = 64;
|
| 212 |
+
|
| 213 |
+
enum gptoss_status status = gptoss_status_success;
|
| 214 |
+
const struct gptoss_model* model = context->model;
|
| 215 |
+
|
| 216 |
+
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
|
| 217 |
+
|
| 218 |
+
const size_t input_tokens_end = input_tokens_offset + num_input_tokens;
|
| 219 |
+
for (size_t input_batch_start = input_tokens_offset;
|
| 220 |
+
input_batch_start < input_tokens_end;
|
| 221 |
+
input_batch_start += context->max_batch_tokens)
|
| 222 |
+
{
|
| 223 |
+
const size_t input_batch_size = math_min(context->max_batch_tokens, input_tokens_end - input_batch_start);
|
| 224 |
+
const size_t input_batch_end = input_batch_start + input_batch_size;
|
| 225 |
+
const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end);
|
| 226 |
+
|
| 227 |
+
status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
| 228 |
+
command_buffer,
|
| 229 |
+
&model->bf16_f32_embeddings_fn,
|
| 230 |
+
model->embeddings_threadgroup_size,
|
| 231 |
+
&context->token_buffer,
|
| 232 |
+
input_batch_start * sizeof(uint32_t),
|
| 233 |
+
&model->shared_weight_buffer,
|
| 234 |
+
/*weight_offset=*/0,
|
| 235 |
+
&context->residual_activation_buffer,
|
| 236 |
+
/*output_offset=*/0,
|
| 237 |
+
&context->control_buffer,
|
| 238 |
+
/*control_offset=*/0,
|
| 239 |
+
/*num_tokens=*/input_batch_size,
|
| 240 |
+
/*num_channels=*/model->embedding_dim);
|
| 241 |
+
if (status != gptoss_status_success) {
|
| 242 |
+
GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
|
| 243 |
+
return status;
|
| 244 |
+
}
|
| 245 |
+
for (uint32_t n = 0; n < model->num_blocks; n++) {
|
| 246 |
+
const bool last_block = n + 1 == model->num_blocks;
|
| 247 |
+
const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size;
|
| 248 |
+
|
| 249 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
| 250 |
+
command_buffer,
|
| 251 |
+
&model->f32_bf16w_rmsnorm_fn,
|
| 252 |
+
&context->residual_activation_buffer,
|
| 253 |
+
/*input_offset=*/0,
|
| 254 |
+
&model->shared_weight_buffer,
|
| 255 |
+
/*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
|
| 256 |
+
&context->rmsnorm_activation_buffer,
|
| 257 |
+
/*output_offset=*/0,
|
| 258 |
+
&context->control_buffer,
|
| 259 |
+
/*control_offset=*/0,
|
| 260 |
+
/*num_tokens=*/input_batch_size,
|
| 261 |
+
/*num_channels=*/model->embedding_dim,
|
| 262 |
+
model->rmsnorm_epsilon);
|
| 263 |
+
if (status != gptoss_status_success) {
|
| 264 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
|
| 265 |
+
return status;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
|
| 269 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
|
| 270 |
+
command_buffer,
|
| 271 |
+
&model->f32_bf16w_dense_matmul_qkv_fn,
|
| 272 |
+
&context->rmsnorm_activation_buffer,
|
| 273 |
+
/*input_offset=*/0,
|
| 274 |
+
&model->shared_weight_buffer,
|
| 275 |
+
/*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
|
| 276 |
+
&model->shared_weight_buffer,
|
| 277 |
+
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
|
| 278 |
+
&context->qkv_activation_buffer,
|
| 279 |
+
/*output_offset=*/0,
|
| 280 |
+
&context->control_buffer,
|
| 281 |
+
/*control_offset=*/0,
|
| 282 |
+
/*num_tokens=*/input_batch_size,
|
| 283 |
+
/*num_cols=*/model->embedding_dim,
|
| 284 |
+
/*num_rows=*/attn_qkv_dim);
|
| 285 |
+
if (status != gptoss_status_success) {
|
| 286 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch");
|
| 287 |
+
return status;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_rope(
|
| 291 |
+
command_buffer,
|
| 292 |
+
&model->f32_rope_fn,
|
| 293 |
+
/*threadgroup_size=*/32,
|
| 294 |
+
&context->qkv_activation_buffer,
|
| 295 |
+
/*input_offset=*/0,
|
| 296 |
+
&context->control_buffer,
|
| 297 |
+
/*control_offset=*/0,
|
| 298 |
+
model->rope_theta,
|
| 299 |
+
model->interpolation_scale,
|
| 300 |
+
model->yarn_offset,
|
| 301 |
+
model->yarn_scale,
|
| 302 |
+
model->yarn_multiplier,
|
| 303 |
+
input_batch_size,
|
| 304 |
+
model->num_heads,
|
| 305 |
+
model->num_kv_heads,
|
| 306 |
+
model->head_dim,
|
| 307 |
+
/*token_offset=*/input_batch_start);
|
| 308 |
+
if (status != gptoss_status_success) {
|
| 309 |
+
GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
|
| 310 |
+
return status;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
for (uint32_t t = 0; t < input_batch_size; t++) {
|
| 314 |
+
for (uint32_t kv = 0; kv < 2; kv++) {
|
| 315 |
+
for (uint32_t h = 0; h < model->num_kv_heads; h++) {
|
| 316 |
+
status = gptoss_metal_command_buffer_encode_copy_buffer(
|
| 317 |
+
command_buffer,
|
| 318 |
+
&context->qkv_activation_buffer,
|
| 319 |
+
/*input_offset=*/(t * attn_qkv_dim + (model->num_heads + kv * model->num_kv_heads + h) * model->head_dim) * sizeof(float),
|
| 320 |
+
&context->kvcache_buffer,
|
| 321 |
+
/*output_offset=*/(((n * model->num_kv_heads + h) * context->max_tokens + input_batch_start + t) * 2 + kv) * model->head_dim * sizeof(float),
|
| 322 |
+
/*size=*/model->head_dim * sizeof(float));
|
| 323 |
+
if (status != gptoss_status_success) {
|
| 324 |
+
GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
|
| 325 |
+
return status;
|
| 326 |
+
}
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
}
|
| 330 |
+
} else {
|
| 331 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
|
| 332 |
+
command_buffer,
|
| 333 |
+
&model->f32_bf16w_matmul_qkv_fn,
|
| 334 |
+
model->attn_qkv_threadgroup_size,
|
| 335 |
+
&context->rmsnorm_activation_buffer,
|
| 336 |
+
/*input_offset=*/0,
|
| 337 |
+
&model->shared_weight_buffer,
|
| 338 |
+
/*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
|
| 339 |
+
&model->shared_weight_buffer,
|
| 340 |
+
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
|
| 341 |
+
&context->qkv_activation_buffer,
|
| 342 |
+
/*output_offset=*/0,
|
| 343 |
+
&context->kvcache_buffer,
|
| 344 |
+
/*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
|
| 345 |
+
&context->control_buffer,
|
| 346 |
+
/*control_offset=*/0,
|
| 347 |
+
/*num_tokens=*/input_batch_size,
|
| 348 |
+
/*num_cols=*/model->embedding_dim,
|
| 349 |
+
/*num_q_heads=*/model->num_heads,
|
| 350 |
+
/*num_kv_heads=*/model->num_kv_heads,
|
| 351 |
+
/*attn_head_dim=*/model->head_dim,
|
| 352 |
+
/*token_offset=*/input_batch_start,
|
| 353 |
+
/*max_tokens=*/context->max_tokens,
|
| 354 |
+
/*rope_base=*/model->rope_theta,
|
| 355 |
+
/*interpolation_scale=*/model->interpolation_scale,
|
| 356 |
+
/*yarn_offset=*/model->yarn_offset,
|
| 357 |
+
/*yarn_scale=*/model->yarn_scale,
|
| 358 |
+
/*yarn_multiplier=*/model->yarn_multiplier);
|
| 359 |
+
if (status != gptoss_status_success) {
|
| 360 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch");
|
| 361 |
+
return status;
|
| 362 |
+
}
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
if (num_block_output_tokens != 0) {
|
| 366 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
|
| 367 |
+
command_buffer,
|
| 368 |
+
&model->f32_sdpa_q8_d64_fn,
|
| 369 |
+
&context->qkv_activation_buffer,
|
| 370 |
+
/*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 371 |
+
&context->kvcache_buffer,
|
| 372 |
+
/*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
|
| 373 |
+
&model->shared_weight_buffer,
|
| 374 |
+
/*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
|
| 375 |
+
&context->sdpa_activation_buffer,
|
| 376 |
+
/*output_offset=*/0,
|
| 377 |
+
&context->control_buffer,
|
| 378 |
+
/*control_offset=*/0,
|
| 379 |
+
/*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
|
| 380 |
+
/*kv_stride=*/2 * context->max_tokens * model->head_dim,
|
| 381 |
+
num_block_output_tokens,
|
| 382 |
+
input_batch_start + input_batch_size - num_block_output_tokens,
|
| 383 |
+
model->num_heads, model->num_kv_heads, model->head_dim);
|
| 384 |
+
if (status != gptoss_status_success) {
|
| 385 |
+
GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
|
| 386 |
+
return status;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
|
| 390 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
|
| 391 |
+
command_buffer,
|
| 392 |
+
&model->f32_bf16w_dense_matmul_attn_output_fn,
|
| 393 |
+
&context->sdpa_activation_buffer,
|
| 394 |
+
/*input_offset=*/0,
|
| 395 |
+
&model->shared_weight_buffer,
|
| 396 |
+
/*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
|
| 397 |
+
&model->shared_weight_buffer,
|
| 398 |
+
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
|
| 399 |
+
&context->residual_activation_buffer,
|
| 400 |
+
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 401 |
+
&context->control_buffer,
|
| 402 |
+
/*control_offset=*/0,
|
| 403 |
+
/*num_tokens=*/num_block_output_tokens,
|
| 404 |
+
/*num_cols=*/model->num_heads * model->head_dim,
|
| 405 |
+
/*num_rows=*/model->embedding_dim);
|
| 406 |
+
if (status != gptoss_status_success) {
|
| 407 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_attn_output kernel launch");
|
| 408 |
+
return status;
|
| 409 |
+
}
|
| 410 |
+
} else {
|
| 411 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
|
| 412 |
+
command_buffer,
|
| 413 |
+
&model->f32_bf16w_matmul_fn,
|
| 414 |
+
model->attn_out_threadgroup_size,
|
| 415 |
+
&context->sdpa_activation_buffer,
|
| 416 |
+
/*input_offset=*/0,
|
| 417 |
+
&model->shared_weight_buffer,
|
| 418 |
+
/*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
|
| 419 |
+
&model->shared_weight_buffer,
|
| 420 |
+
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
|
| 421 |
+
&context->residual_activation_buffer,
|
| 422 |
+
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 423 |
+
&context->control_buffer,
|
| 424 |
+
/*control_offset=*/0,
|
| 425 |
+
/*num_tokens=*/num_block_output_tokens,
|
| 426 |
+
/*num_cols=*/model->num_heads * model->head_dim,
|
| 427 |
+
/*num_rows=*/model->embedding_dim);
|
| 428 |
+
if (status != gptoss_status_success) {
|
| 429 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
|
| 430 |
+
return status;
|
| 431 |
+
}
|
| 432 |
+
}
|
| 433 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
| 434 |
+
command_buffer,
|
| 435 |
+
&model->f32_bf16w_rmsnorm_fn,
|
| 436 |
+
&context->residual_activation_buffer,
|
| 437 |
+
/*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 438 |
+
&model->shared_weight_buffer,
|
| 439 |
+
/*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
|
| 440 |
+
&context->rmsnorm_activation_buffer,
|
| 441 |
+
/*output_offset=*/0,
|
| 442 |
+
&context->control_buffer,
|
| 443 |
+
/*control_offset=*/0,
|
| 444 |
+
num_block_output_tokens,
|
| 445 |
+
model->embedding_dim,
|
| 446 |
+
model->rmsnorm_epsilon);
|
| 447 |
+
if (status != gptoss_status_success) {
|
| 448 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
|
| 449 |
+
return status;
|
| 450 |
+
}
|
| 451 |
+
if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
|
| 452 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
|
| 453 |
+
command_buffer,
|
| 454 |
+
&model->f32_bf16w_dense_matmul_mlp_gate_fn,
|
| 455 |
+
&context->rmsnorm_activation_buffer,
|
| 456 |
+
/*input_offset=*/0,
|
| 457 |
+
&model->shared_weight_buffer,
|
| 458 |
+
/*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
|
| 459 |
+
&model->shared_weight_buffer,
|
| 460 |
+
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
|
| 461 |
+
&context->gate_activation_buffer,
|
| 462 |
+
/*output_offset=*/0,
|
| 463 |
+
&context->control_buffer,
|
| 464 |
+
/*control_offset=*/0,
|
| 465 |
+
num_block_output_tokens,
|
| 466 |
+
model->embedding_dim,
|
| 467 |
+
model->num_experts);
|
| 468 |
+
if (status != gptoss_status_success) {
|
| 469 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_mlp_gate kernel launch");
|
| 470 |
+
return status;
|
| 471 |
+
}
|
| 472 |
+
} else {
|
| 473 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
| 474 |
+
command_buffer,
|
| 475 |
+
&model->f32_bf16w_matmul_fn,
|
| 476 |
+
model->mlp_gate_threadgroup_size,
|
| 477 |
+
&context->rmsnorm_activation_buffer,
|
| 478 |
+
/*input_offset=*/0,
|
| 479 |
+
&model->shared_weight_buffer,
|
| 480 |
+
/*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
|
| 481 |
+
&model->shared_weight_buffer,
|
| 482 |
+
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
|
| 483 |
+
&context->gate_activation_buffer,
|
| 484 |
+
/*output_offset=*/0,
|
| 485 |
+
&context->control_buffer,
|
| 486 |
+
/*control_offset=*/0,
|
| 487 |
+
/*num_tokens=*/num_block_output_tokens,
|
| 488 |
+
/*num_cols=*/model->embedding_dim,
|
| 489 |
+
/*num_rows=*/model->num_experts);
|
| 490 |
+
if (status != gptoss_status_success) {
|
| 491 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
|
| 492 |
+
return status;
|
| 493 |
+
}
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
const char* kernel_name = NULL;
|
| 497 |
+
switch (model->num_experts) {
|
| 498 |
+
case 32:
|
| 499 |
+
kernel_name = "f32_topk_softmax_e32_k4_fn";
|
| 500 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
|
| 501 |
+
command_buffer,
|
| 502 |
+
&model->f32_topk_softmax_e32_k4_fn,
|
| 503 |
+
&context->gate_activation_buffer, /*input_offset=*/0,
|
| 504 |
+
&context->expert_activation_buffer, /*output_offset=*/0,
|
| 505 |
+
&context->control_buffer, /*control_offset=*/0,
|
| 506 |
+
num_block_output_tokens,
|
| 507 |
+
model->num_experts,
|
| 508 |
+
model->num_active_experts);
|
| 509 |
+
break;
|
| 510 |
+
case 128:
|
| 511 |
+
kernel_name = "f32_topk_softmax_e128_k4_fn";
|
| 512 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
|
| 513 |
+
command_buffer,
|
| 514 |
+
&model->f32_topk_softmax_e128_k4_fn,
|
| 515 |
+
&context->gate_activation_buffer, /*input_offset=*/0,
|
| 516 |
+
&context->expert_activation_buffer, /*output_offset=*/0,
|
| 517 |
+
&context->control_buffer, /*control_offset=*/0,
|
| 518 |
+
num_block_output_tokens,
|
| 519 |
+
model->num_experts,
|
| 520 |
+
model->num_active_experts);
|
| 521 |
+
break;
|
| 522 |
+
default:
|
| 523 |
+
status = gptoss_status_unsupported_argument;
|
| 524 |
+
GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
|
| 525 |
+
return status;
|
| 526 |
+
}
|
| 527 |
+
if (status != gptoss_status_success) {
|
| 528 |
+
GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
|
| 529 |
+
return status;
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
// If we have enough tokens in prefill, we will pick the prefill-optimized kernels.
|
| 533 |
+
if (num_block_output_tokens >= min_tokens_for_dense_moe_kernels) {
|
| 534 |
+
status = gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
|
| 535 |
+
command_buffer,
|
| 536 |
+
&model->f32_expert_routing_metadata_fn,
|
| 537 |
+
&context->expert_activation_buffer,
|
| 538 |
+
/*expert_predictions_offset=*/0,
|
| 539 |
+
&context->expert_offset_buffer,
|
| 540 |
+
/*expert_offsets_offset=*/0,
|
| 541 |
+
&context->token_to_expert_routing_buffer,
|
| 542 |
+
/*intra_expert_offsets_offset=*/0,
|
| 543 |
+
num_block_output_tokens * model->num_active_experts,
|
| 544 |
+
model->num_experts);
|
| 545 |
+
if (status != gptoss_status_success) {
|
| 546 |
+
GPTOSS_LOG_ERROR("failed to encode f32_expert_routing_metadata kernel launch");
|
| 547 |
+
return status;
|
| 548 |
+
}
|
| 549 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_scatter(
|
| 550 |
+
command_buffer,
|
| 551 |
+
&model->f32_scatter_e4_fn,
|
| 552 |
+
&context->rmsnorm_activation_buffer,
|
| 553 |
+
/*input_offset=*/0,
|
| 554 |
+
&context->expert_activation_buffer,
|
| 555 |
+
/*expert_predictions_offset=*/0,
|
| 556 |
+
&context->expert_offset_buffer,
|
| 557 |
+
/*expert_offsets_offset=*/0,
|
| 558 |
+
&context->token_to_expert_routing_buffer,
|
| 559 |
+
/*intra_expert_offsets_offset=*/0,
|
| 560 |
+
&context->swiglu_input_buffer,
|
| 561 |
+
/*output_offset=*/0,
|
| 562 |
+
model->embedding_dim,
|
| 563 |
+
num_block_output_tokens,
|
| 564 |
+
model->num_active_experts);
|
| 565 |
+
if (status != gptoss_status_success) {
|
| 566 |
+
GPTOSS_LOG_ERROR("failed to encode f32_scatter kernel launch");
|
| 567 |
+
return status;
|
| 568 |
+
}
|
| 569 |
+
// Dense MoE SwiGLU matmul.
|
| 570 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
|
| 571 |
+
command_buffer,
|
| 572 |
+
&model->f32_mf4w_moe_dense_matmul_swiglu_fn,
|
| 573 |
+
&context->expert_offset_buffer,
|
| 574 |
+
/*expert_offsets_offset=*/0,
|
| 575 |
+
&context->swiglu_input_buffer,
|
| 576 |
+
/*input_offset=*/0,
|
| 577 |
+
&model->block_weight_buffers[n],
|
| 578 |
+
/*weight_block_offset=*/0,
|
| 579 |
+
&model->block_weight_buffers[n],
|
| 580 |
+
/*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
|
| 581 |
+
&model->block_weight_buffers[n],
|
| 582 |
+
/*bias_offset=*/model->mlp_swiglu_bias_offset,
|
| 583 |
+
&context->swiglu_activation_buffer,
|
| 584 |
+
/*output_offset=*/0,
|
| 585 |
+
model->swiglu_limit,
|
| 586 |
+
/*expert_stride_bytes=*/model->per_expert_block_weight_size,
|
| 587 |
+
num_block_output_tokens,
|
| 588 |
+
model->num_experts,
|
| 589 |
+
model->embedding_dim,
|
| 590 |
+
2 * model->mlp_dim);
|
| 591 |
+
if (status != gptoss_status_success) {
|
| 592 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
|
| 593 |
+
return status;
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
// Dense MoE proj matmul.
|
| 597 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
|
| 598 |
+
command_buffer,
|
| 599 |
+
&model->f32_mf4w_moe_dense_matmul_fn,
|
| 600 |
+
&context->expert_offset_buffer,
|
| 601 |
+
/*expert_offsets_offset=*/0,
|
| 602 |
+
&context->swiglu_activation_buffer,
|
| 603 |
+
/*input_offset=*/0,
|
| 604 |
+
&model->block_weight_buffers[n],
|
| 605 |
+
/*weight_block_offset=*/model->mlp_out_block_offset,
|
| 606 |
+
&model->block_weight_buffers[n],
|
| 607 |
+
/*weight_scale_offset=*/model->mlp_out_scale_offset,
|
| 608 |
+
&model->block_weight_buffers[n],
|
| 609 |
+
/*bias_offset=*/model->mlp_out_bias_offset,
|
| 610 |
+
&context->moe_activation_buffer,
|
| 611 |
+
/*output_offset=*/0,
|
| 612 |
+
/*expert_stride_bytes=*/model->per_expert_block_weight_size,
|
| 613 |
+
num_block_output_tokens,
|
| 614 |
+
model->num_experts,
|
| 615 |
+
model->mlp_dim,
|
| 616 |
+
model->embedding_dim);
|
| 617 |
+
if (status != gptoss_status_success) {
|
| 618 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
|
| 619 |
+
return status;
|
| 620 |
+
}
|
| 621 |
+
// Gather and accumulate.
|
| 622 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
|
| 623 |
+
command_buffer,
|
| 624 |
+
&model->f32_gather_and_accumulate_e4_fn,
|
| 625 |
+
&context->moe_activation_buffer,
|
| 626 |
+
/*input_offset=*/0,
|
| 627 |
+
&context->expert_activation_buffer,
|
| 628 |
+
/*expert_predictions_offset=*/0,
|
| 629 |
+
&context->expert_offset_buffer,
|
| 630 |
+
/*expert_offsets_offset=*/0,
|
| 631 |
+
&context->token_to_expert_routing_buffer,
|
| 632 |
+
/*intra_expert_offsets_offset=*/0,
|
| 633 |
+
&context->residual_activation_buffer,
|
| 634 |
+
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 635 |
+
model->embedding_dim,
|
| 636 |
+
num_block_output_tokens,
|
| 637 |
+
model->num_active_experts);
|
| 638 |
+
if (status != gptoss_status_success) {
|
| 639 |
+
GPTOSS_LOG_ERROR("failed to encode f32_gather_and_accumulate_e4 kernel launch");
|
| 640 |
+
return status;
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
} else {
|
| 644 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
|
| 645 |
+
command_buffer,
|
| 646 |
+
&model->f32_mf4w_moe_matmul_swiglu_fn,
|
| 647 |
+
model->mlp_swiglu_threadgroup_size,
|
| 648 |
+
&context->rmsnorm_activation_buffer,
|
| 649 |
+
/*input_offset=*/0,
|
| 650 |
+
&context->expert_activation_buffer,
|
| 651 |
+
/*expert_offset=*/0,
|
| 652 |
+
&model->block_weight_buffers[n],
|
| 653 |
+
/*weight_block_offset=*/0,
|
| 654 |
+
&model->block_weight_buffers[n],
|
| 655 |
+
/*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
|
| 656 |
+
&model->block_weight_buffers[n],
|
| 657 |
+
/*bias_offset=*/model->mlp_swiglu_bias_offset,
|
| 658 |
+
&context->swiglu_activation_buffer,
|
| 659 |
+
/*output_offset=*/0,
|
| 660 |
+
&context->control_buffer,
|
| 661 |
+
/*control_offset=*/0,
|
| 662 |
+
model->swiglu_limit,
|
| 663 |
+
model->per_expert_block_weight_size,
|
| 664 |
+
num_block_output_tokens,
|
| 665 |
+
model->num_active_experts,
|
| 666 |
+
model->embedding_dim,
|
| 667 |
+
model->mlp_dim);
|
| 668 |
+
if (status != gptoss_status_success) {
|
| 669 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
|
| 670 |
+
return status;
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
|
| 674 |
+
command_buffer,
|
| 675 |
+
&model->f32_mf4w_moe_matmul_fn,
|
| 676 |
+
model->mlp_out_threadgroup_size,
|
| 677 |
+
&context->swiglu_activation_buffer,
|
| 678 |
+
/*input_offset=*/0,
|
| 679 |
+
&context->expert_activation_buffer,
|
| 680 |
+
/*expert_offset=*/0,
|
| 681 |
+
&model->block_weight_buffers[n],
|
| 682 |
+
/*weight_block_offset=*/model->mlp_out_block_offset,
|
| 683 |
+
&model->block_weight_buffers[n],
|
| 684 |
+
/*weight_scale_offset=*/model->mlp_out_scale_offset,
|
| 685 |
+
&model->block_weight_buffers[n],
|
| 686 |
+
/*bias_offset=*/model->mlp_out_bias_offset,
|
| 687 |
+
&context->moe_activation_buffer,
|
| 688 |
+
/*output_offset=*/0,
|
| 689 |
+
&context->control_buffer,
|
| 690 |
+
/*control_offset=*/0,
|
| 691 |
+
model->per_expert_block_weight_size,
|
| 692 |
+
num_block_output_tokens,
|
| 693 |
+
model->num_active_experts,
|
| 694 |
+
model->mlp_dim,
|
| 695 |
+
model->embedding_dim);
|
| 696 |
+
if (status != gptoss_status_success) {
|
| 697 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
|
| 698 |
+
return status;
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
|
| 702 |
+
command_buffer,
|
| 703 |
+
&model->f32_accumulate_e4_fn,
|
| 704 |
+
model->mlp_acc_threadgroup_size,
|
| 705 |
+
model->max_threadgroups,
|
| 706 |
+
&context->moe_activation_buffer,
|
| 707 |
+
/*input_offset=*/0,
|
| 708 |
+
&context->expert_activation_buffer,
|
| 709 |
+
/*expert_offset=*/0,
|
| 710 |
+
&context->residual_activation_buffer,
|
| 711 |
+
/*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
|
| 712 |
+
&context->control_buffer,
|
| 713 |
+
/*control_offset=*/0,
|
| 714 |
+
model->embedding_dim,
|
| 715 |
+
num_block_output_tokens,
|
| 716 |
+
model->num_active_experts);
|
| 717 |
+
if (status != gptoss_status_success) {
|
| 718 |
+
GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
|
| 719 |
+
return status;
|
| 720 |
+
}
|
| 721 |
+
}
|
| 722 |
+
}
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
if (output_batch_size != 0) {
|
| 726 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
| 727 |
+
command_buffer,
|
| 728 |
+
&model->f32_bf16w_rmsnorm_fn,
|
| 729 |
+
&context->residual_activation_buffer,
|
| 730 |
+
/*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float),
|
| 731 |
+
&model->shared_weight_buffer,
|
| 732 |
+
/*weight_offset=*/model->rmsnorm_weight_offset,
|
| 733 |
+
&context->rmsnorm_activation_buffer,
|
| 734 |
+
/*output_offset=*/0,
|
| 735 |
+
&context->control_buffer,
|
| 736 |
+
/*control_offset=*/0,
|
| 737 |
+
/*num_tokens=*/output_batch_size,
|
| 738 |
+
/*num_channels=*/model->embedding_dim,
|
| 739 |
+
model->rmsnorm_epsilon);
|
| 740 |
+
if (status != gptoss_status_success) {
|
| 741 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
|
| 742 |
+
return status;
|
| 743 |
+
}
|
| 744 |
+
|
| 745 |
+
status = gptoss_metal_command_buffer_encode_fill_buffer(
|
| 746 |
+
command_buffer,
|
| 747 |
+
&context->argmax_buffer,
|
| 748 |
+
/*offset=*/0,
|
| 749 |
+
/*size=*/sizeof(uint64_t) * output_batch_size,
|
| 750 |
+
/*fill_value=*/0xFF);
|
| 751 |
+
if (status != gptoss_status_success) {
|
| 752 |
+
GPTOSS_LOG_ERROR("failed to encode fill buffer command");
|
| 753 |
+
return status;
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
|
| 757 |
+
command_buffer,
|
| 758 |
+
&model->f32_bf16w_unembedding_fn,
|
| 759 |
+
model->unembedding_threadgroup_size,
|
| 760 |
+
model->max_threadgroups,
|
| 761 |
+
&context->rmsnorm_activation_buffer,
|
| 762 |
+
/*input_offset=*/0,
|
| 763 |
+
&model->shared_weight_buffer,
|
| 764 |
+
/*weight_offset=*/model->unembedding_weight_offset,
|
| 765 |
+
&context->score_buffer,
|
| 766 |
+
/*output_offset=*/0,
|
| 767 |
+
&context->argmax_buffer,
|
| 768 |
+
/*argmax_offset=*/0,
|
| 769 |
+
&context->control_buffer,
|
| 770 |
+
/*control_offset=*/0,
|
| 771 |
+
/*num_tokens=*/output_batch_size,
|
| 772 |
+
/*num_cols=*/model->embedding_dim,
|
| 773 |
+
/*num_rows=*/model->vocabulary_size);
|
| 774 |
+
if (status != gptoss_status_success) {
|
| 775 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
|
| 776 |
+
return status;
|
| 777 |
+
}
|
| 778 |
+
}
|
| 779 |
+
}
|
| 780 |
+
return gptoss_status_success;
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
// Tokenizes a byte string by greedy longest-prefix (maximal-munch) matching
// against the tokenizer's packed text-token table and appends the resulting
// tokens to the context.
//
// The token table is a flat byte stream of {uint16 length, bytes...} entries,
// so each emitted token costs one O(vocabulary) linear scan.
//
// KV-cache reuse: if the position being written is already covered by the KV
// cache and the token matches, the cache stays valid; on a mismatch the cache
// is truncated at that position.
//
// Returns gptoss_status_context_overflow when the context fills up before the
// whole text is consumed (tokens appended so far are kept), and
// gptoss_status_invalid_argument if no vocabulary entry matches the remaining
// text. num_tokens_out (optional) receives the number of appended tokens.
enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
    gptoss_context_t context,
    const char* text,
    size_t text_length,
    size_t* num_tokens_out)
{
    enum gptoss_status status = gptoss_status_success;
    const struct gptoss_model* model = context->model;
    const struct gptoss_tokenizer* tokenizer = model->tokenizer;
    size_t num_appended_tokens = 0;
    while (text_length != 0) {
        if (context->num_tokens == context->max_tokens) {
            status = gptoss_status_context_overflow;
            break;
        }
        const char* tokens = tokenizer->tokens_ptr;
        uint32_t best_token = UINT32_MAX;
        uint32_t best_token_length = 0;
        for (size_t t = 0; t < tokenizer->num_text_tokens; t++) {
            uint16_t token_length;
            memcpy(&token_length, tokens, sizeof(uint16_t));  // table entries may be unaligned
            tokens += sizeof(uint16_t);
            // Only candidates that fit in the remaining text AND would improve
            // on the current best need a byte comparison.
            if (token_length <= text_length && token_length > best_token_length &&
                    memcmp(text, tokens, token_length) == 0) {
                best_token = (uint32_t) t;
                best_token_length = token_length;
            }
            tokens += token_length;
        }

        if (best_token == UINT32_MAX) {
            GPTOSS_LOG_ERROR("failed to tokenize text \"%.*s\"", (int) text_length, text);
            return gptoss_status_invalid_argument;
        }

        uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
        if (context->num_kv_tokens > context->num_tokens) {
            // This position is already covered by the KV cache: keep the cache
            // if the token matches, otherwise overwrite and truncate the cache.
            if (input_tokens[context->num_tokens] != best_token) {
                input_tokens[context->num_tokens] = best_token;

                // Invalidate the KV cache starting with the newly added token.
                context->num_kv_tokens = context->num_tokens;
            }
            context->num_tokens++;
        } else {
            input_tokens[context->num_tokens++] = best_token;
        }
        num_appended_tokens++;
        text += best_token_length;
        text_length -= best_token_length;
    }
    if (num_tokens_out != NULL) {
        *num_tokens_out = num_appended_tokens;
    }
    return status;
}
|
| 842 |
+
|
| 843 |
+
// Appends pre-tokenized input to the context.
//
// All token ids are validated against the vocabulary before any state is
// modified, so an invalid id leaves the context untouched. For positions
// already covered by the KV cache, appended tokens are verified against the
// stored ones and the cache is truncated at the first mismatch; beyond the
// cached region tokens are bulk-copied.
//
// Returns gptoss_status_context_overflow when the context fills up before all
// tokens are appended (tokens up to capacity are still appended), or
// gptoss_status_invalid_argument for an out-of-range token id.
enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
    gptoss_context_t context,
    size_t num_tokens,
    const uint32_t* tokens)
{
    const struct gptoss_model* model = context->model;

    // Validate all tokens
    for (size_t t = 0; t < num_tokens; t++) {
        const uint32_t token = tokens[t];
        if (token >= model->vocabulary_size) {
            // Use the already-bound `model` pointer (was context->model->...).
            GPTOSS_LOG_ERROR("token %" PRIu32 " at index %zu is out of bounds for vocabulary size %" PRIu32,
                token, t, model->vocabulary_size);
            return gptoss_status_invalid_argument;
        }
    }

    enum gptoss_status status = gptoss_status_success;
    uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
    while (num_tokens != 0) {
        if (context->num_tokens == context->max_tokens) {
            status = gptoss_status_context_overflow;
            break;
        }

        if (context->num_kv_tokens > context->num_tokens) {
            // Verify appended tokens against the region already in the KV cache.
            const size_t num_tokens_to_verify = math_min(context->num_kv_tokens - context->num_tokens, num_tokens);
            size_t num_verified_tokens = 0;
            for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) {
                if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) {
                    // Invalidate the KV cache starting with the newly added tokens.
                    context->num_kv_tokens = context->num_tokens + num_verified_tokens;
                    break;
                }
            }

            context->num_tokens += num_verified_tokens;
            tokens += num_verified_tokens;
            num_tokens -= num_verified_tokens;
        } else {
            // Past the cached region: bulk-copy as many tokens as fit.
            const size_t num_tokens_to_copy = math_min(context->max_tokens - context->num_tokens, num_tokens);
            memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t));
            context->num_tokens += num_tokens_to_copy;
            tokens += num_tokens_to_copy;
            num_tokens -= num_tokens_to_copy;
        }
    }

    return status;
}
|
| 893 |
+
|
| 894 |
+
// Synchronously prefills the KV cache for all appended-but-unprocessed tokens
// (context->num_tokens > context->num_kv_tokens). No logits are requested
// (num_output_tokens == 0); use gptoss_context_sample to generate.
// No-op returning success when the KV cache is already up to date.
enum gptoss_status GPTOSS_ABI gptoss_context_process(
    gptoss_context_t context)
{
    if (context->num_tokens > context->num_kv_tokens) {
        struct gptoss_metal_command_buffer command_buffer = {0};

        enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
        if (status != gptoss_status_success) {
            goto cleanup;
        }

        // Clear the GPU-visible abort flag before encoding any kernels.
        struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
        control->abort = 0;

        // Encode one forward pass over the entire unprocessed token range.
        status = process_tokens(
            context,
            &command_buffer,
            /*input_tokens_offset=*/context->num_kv_tokens,
            /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
            /*num_output_tokens=*/0);
        if (status != gptoss_status_success) {
            goto cleanup;
        }

        status = gptoss_metal_command_buffer_commit(&command_buffer);
        if (status != gptoss_status_success) {
            goto cleanup;
        }

        // Block until the GPU finishes; prefill is synchronous by design.
        status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
        if (status != gptoss_status_success) {
            goto cleanup;
        }

        // The KV cache now covers every appended token.
        context->num_kv_tokens = context->num_tokens;

cleanup:
        // Releasing a zero-initialized buffer (create failure path) is expected
        // to be safe -- TODO confirm gptoss_metal_command_buffer_release contract.
        gptoss_metal_command_buffer_release(&command_buffer);
        return status;
    }

    return gptoss_status_success;
}
|
| 937 |
+
|
| 938 |
+
// Generates up to max_tokens tokens and copies them to tokens_out, setting
// *num_tokens_out to the number generated.
//
// For each step: any pending tokens are processed (or the last token is
// re-processed to refresh logits), then either
//   - temperature == 0: the argmax token is copied into the token buffer, or
//   - temperature != 0: softmax + stochastic sampling kernels are encoded,
//     seeded from `seed` and the token position (deterministic per position).
// All steps are encoded into a single command buffer that is committed and
// awaited once at the end; generated tokens are then read back from the
// token buffer.
//
// Fix vs. previous revision: the commit/wait statuses are now checked before
// reading back tokens, so a GPU-side failure is reported instead of returning
// uninitialized token data.
enum gptoss_status GPTOSS_ABI gptoss_context_sample(
    gptoss_context_t context,
    float temperature,
    uint64_t seed,
    size_t max_tokens,
    uint32_t* tokens_out,
    size_t* num_tokens_out)
{
    enum gptoss_status status = gptoss_status_success;
    const struct gptoss_model* model = context->model;
    struct gptoss_metal_command_buffer command_buffer = {0};

    *num_tokens_out = 0;

    const uint32_t num_original_tokens = context->num_tokens;

    status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
    if (status != gptoss_status_success) {
        goto cleanup;
    }

    // Clear the GPU-visible abort flag before encoding any kernels.
    struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
    control->abort = 0;

    for (size_t t = 0; t < max_tokens; t++) {
        if (context->num_kv_tokens < context->num_tokens) {
            // Catch up on unprocessed tokens and request logits for the last one.
            status = process_tokens(
                context,
                &command_buffer,
                /*input_tokens_offset=*/context->num_kv_tokens,
                /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
                /*num_output_tokens=*/1);
            context->num_kv_tokens = context->num_tokens;
        } else {
            // KV cache is current: re-run only the last token to produce logits.
            status = process_tokens(
                context,
                &command_buffer,
                /*input_tokens_offset=*/context->num_tokens - 1,
                /*num_input_tokens=*/1,
                /*num_output_tokens=*/1);
        }
        if (status != gptoss_status_success) {
            goto cleanup;
        }

        if (temperature != 0.0f) {
            assert(context->num_processed_tokens != 0);
            uint32_t num_threadgroups = 0;
            uint32_t num_dims_per_threadgroup = 0;
            status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
                &command_buffer,
                &model->f32_softmax_fn,
                /*threadgroup_size=*/512,
                model->max_threadgroups,
                &context->score_buffer,
                /*score_offset=*/0,
                &context->argmax_buffer,
                /*argmax_offset=*/0,
                &context->prob_buffer,
                /*prob_offset=*/0,
                &context->sum_buffer,
                /*sum_offset=*/0,
                &context->control_buffer,
                /*control_offset=*/0,
                model->vocabulary_size,
                /*num_tokens=*/1,
                temperature,
                &num_threadgroups,
                &num_dims_per_threadgroup);
            if (status != gptoss_status_success) {
                GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
                goto cleanup;
            }

            // Sample directly into the token buffer at the next position.
            status = gptoss_metal_command_buffer_encode_launch_f32_sample(
                &command_buffer,
                &model->f32_sample_fn,
                /*min_threadgroup_size=*/512,
                &context->prob_buffer,
                /*prob_offset=*/0,
                &context->sum_buffer,
                /*sum_offset=*/0,
                &context->token_buffer,
                /*token_offset=*/context->num_tokens * sizeof(uint32_t),
                &context->control_buffer,
                /*control_offset=*/0,
                /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
                /*rng_offset=*/context->num_tokens,
                /*num_blocks=*/num_threadgroups,
                /*num_channels=*/model->vocabulary_size,
                /*num_channels_per_block=*/num_dims_per_threadgroup);
            if (status != gptoss_status_success) {
                GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
                goto cleanup;
            }
        } else {
            // Greedy decoding: the unembedding kernel already computed the
            // argmax; just copy it into the token buffer.
            status = gptoss_metal_command_buffer_encode_copy_buffer(
                &command_buffer,
                &context->argmax_buffer,
                /*input_offset=*/0,
                &context->token_buffer,
                /*output_offset=*/context->num_tokens * sizeof(uint32_t),
                /*size=*/sizeof(uint32_t));
            if (status != gptoss_status_success) {
                GPTOSS_LOG_ERROR("failed to encode copy buffer");
                goto cleanup;
            }
        }
        context->num_tokens += 1;
        context->num_kv_tokens = context->num_tokens;
    }

    // Check commit/wait results before touching the token buffer: on failure
    // the readback below would copy stale or uninitialized data.
    status = gptoss_metal_command_buffer_commit(&command_buffer);
    if (status != gptoss_status_success) {
        GPTOSS_LOG_ERROR("failed to commit command buffer");
        goto cleanup;
    }
    status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
    if (status != gptoss_status_success) {
        GPTOSS_LOG_ERROR("failed to wait for command buffer completion");
        goto cleanup;
    }

    const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr;
    const uint32_t num_generated_tokens = context->num_tokens - num_original_tokens;
    memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t));
    *num_tokens_out = num_generated_tokens;

cleanup:
    gptoss_metal_command_buffer_release(&command_buffer);
    return status;
}
|
| 1062 |
+
|
| 1063 |
+
// Logically empties the context by discarding all appended tokens.
// The KV cache (num_kv_tokens) and the token buffer are deliberately left
// intact: if the same token prefix is appended again later, the cached
// key/value entries are reused instead of being recomputed.
enum gptoss_status GPTOSS_ABI gptoss_context_reset(
    gptoss_context_t context)
{
    context->num_tokens = 0;
    return gptoss_status_success;
}
|
| 1073 |
+
|
| 1074 |
+
// Takes an additional strong reference on the context.
// A relaxed increment is sufficient for taking a reference: no other memory
// ordering is required until the matching release.
enum gptoss_status GPTOSS_ABI gptoss_context_retain(
    gptoss_context_t context)
{
    atomic_fetch_add_explicit(&context->ref_count, 1, memory_order_relaxed);
    return gptoss_status_success;
}
|
| 1080 |
+
|
| 1081 |
+
// Drops one strong reference on the context. When the last reference is
// released, frees every Metal buffer owned by the context, drops the
// context's reference on the model, scrubs the structure, and frees it.
// NULL is accepted and ignored.
enum gptoss_status GPTOSS_ABI gptoss_context_release(
    gptoss_context_t context)
{
    if (context != NULL) {
        // acq_rel: the release half publishes this thread's writes; the
        // acquire half (observed by the final decrementer) makes all prior
        // writes visible before the teardown below.
        if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) {
            // Activation buffers
            gptoss_metal_buffer_release(&context->residual_activation_buffer);
            gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer);
            gptoss_metal_buffer_release(&context->qkv_activation_buffer);
            gptoss_metal_buffer_release(&context->sdpa_activation_buffer);
            gptoss_metal_buffer_release(&context->gate_activation_buffer);
            gptoss_metal_buffer_release(&context->expert_activation_buffer);
            gptoss_metal_buffer_release(&context->swiglu_activation_buffer);
            gptoss_metal_buffer_release(&context->moe_activation_buffer);
            gptoss_metal_buffer_release(&context->expert_offset_buffer);
            gptoss_metal_buffer_release(&context->token_to_expert_routing_buffer);
            gptoss_metal_buffer_release(&context->swiglu_input_buffer);

            // Input/output buffers
            gptoss_metal_buffer_release(&context->control_buffer);
            gptoss_metal_buffer_release(&context->token_buffer);
            gptoss_metal_buffer_release(&context->score_buffer);
            gptoss_metal_buffer_release(&context->prob_buffer);
            gptoss_metal_buffer_release(&context->sum_buffer);
            gptoss_metal_buffer_release(&context->argmax_buffer);
            gptoss_metal_buffer_release(&context->kvcache_buffer);

            // Drop the reference taken on the model when the context was created.
            gptoss_model_release(context->model);

            // Zero the structure before freeing so dangling pointers fail fast.
            memset(context, 0, sizeof(struct gptoss_context));
            free(context);
        }
    }
    return gptoss_status_success;
}
|
gptoss_kernels/source/convert.metal
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_integer>
|
| 2 |
+
|
| 3 |
+
#include <internal/kernel-args.h>
|
| 4 |
+
|
| 5 |
+
#pragma METAL fp math_mode(safe)
|
| 6 |
+
#pragma METAL fp contract(off)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
// Dequantizes 4-bit block-quantized weights to f32.
// Each 16-byte `blocks` element holds 32 packed 4-bit codes; each block has
// one 8-bit exponent-only scale in `scales`. Output is 8 float4s (32 floats)
// per block. The nibble order interleaves even/odd indices (names like
// block02468... encode which source nibbles a lane carries).
kernel void gptoss_mf4_f32_convert(
    constant gptoss_convert_args& args [[ buffer(0) ]],
    const device uint4* blocks [[ buffer(1) ]],
    const device uchar* scales [[ buffer(2) ]],
    device float4* output [[ buffer(3) ]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threadgroup_size [[ threads_per_threadgroup ]])
{
    // Each threadgroup owns a contiguous range of blocks; threads within it
    // stride through the range threadgroup_size blocks apart.
    const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
    const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
    const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
    const ulong thread_start = threadgroup_start + tid;
    // NOTE(review): assumes thread_start <= threadgroup_end; otherwise the
    // unsigned subtraction wraps -- confirm the launch parameters guarantee it.
    uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);

    blocks += thread_start;
    scales += thread_start;
    output += 8 * thread_start;  // 8 float4s of output per input block
    for (; num_iter != 0; num_iter--) {
        const uint4 block = *blocks;
        // Build the f32 scale by writing (scale_byte + 14) into the f32
        // exponent field. The +14 presumably folds the format's exponent bias
        // together with the half->float reinterpretation below -- TODO confirm
        // against the quantization format spec.
        const float scale = as_type<float>((static_cast<uint>(*scales) + 14) << 23);
        // Split nibbles into even-indexed and odd-indexed streams:
        // block + block == per-lane << 1, placing each low nibble in bits 1..4
        // of its byte; >> 3 does the same for each high nibble.
        uint4 block02468ACEGIKMOQSU = block + block;
        uint4 block13579BDFHJLNPRTV = block >> 3;
        // Keep the 4-bit code in bits 1..4 of every byte.
        block02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
        block13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
        // Add a per-byte constant, then keep sign bit + selected bits, turning
        // each byte into the high byte of a half-precision value.
        block02468ACEGIKMOQSU += 0x70707070u;
        block13579BDFHJLNPRTV += 0x70707070u;
        block02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
        block13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
        // Select odd bytes in place / shift even bytes up, so each 16-bit lane
        // is a complete half with a zero low byte.
        const uint4 block26AEIMQU = block02468ACEGIKMOQSU & 0xFF00FF00u;
        const uint4 block048CGKOS = (block02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
        const uint4 block37BFJNRV = block13579BDFHJLNPRTV & 0xFF00FF00u;
        const uint4 block159DHLPT = (block13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
        // Reinterpret as halves, widen to float, and apply the block scale.
        const float4 block048C = static_cast<float4>(as_type<half4>(block048CGKOS.xy)) * scale;
        const float4 blockGKOS = static_cast<float4>(as_type<half4>(block048CGKOS.zw)) * scale;
        const float4 block26AE = static_cast<float4>(as_type<half4>(block26AEIMQU.xy)) * scale;
        const float4 blockIMQU = static_cast<float4>(as_type<half4>(block26AEIMQU.zw)) * scale;
        const float4 block159D = static_cast<float4>(as_type<half4>(block159DHLPT.xy)) * scale;
        const float4 blockHLPT = static_cast<float4>(as_type<half4>(block159DHLPT.zw)) * scale;
        const float4 block37BF = static_cast<float4>(as_type<half4>(block37BFJNRV.xy)) * scale;
        const float4 blockJNRV = static_cast<float4>(as_type<half4>(block37BFJNRV.zw)) * scale;

        // Re-interleave back into source nibble order (0,1,2,3 / 4,5,6,7 / ...).
        output[0] = (float4) { block048C.x, block159D.x, block26AE.x, block37BF.x };
        output[1] = (float4) { block048C.y, block159D.y, block26AE.y, block37BF.y };
        output[2] = (float4) { block048C.z, block159D.z, block26AE.z, block37BF.z };
        output[3] = (float4) { block048C.w, block159D.w, block26AE.w, block37BF.w };
        output[4] = (float4) { blockGKOS.x, blockHLPT.x, blockIMQU.x, blockJNRV.x };
        output[5] = (float4) { blockGKOS.y, blockHLPT.y, blockIMQU.y, blockJNRV.y };
        output[6] = (float4) { blockGKOS.z, blockHLPT.z, blockIMQU.z, blockJNRV.z };
        output[7] = (float4) { blockGKOS.w, blockHLPT.w, blockIMQU.w, blockJNRV.w };

        blocks += threadgroup_size;
        scales += threadgroup_size;
        output += 8 * threadgroup_size;
    }
}
|
gptoss_kernels/source/embeddings.metal
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <internal/kernel-args.h>
|
| 2 |
+
|
| 3 |
+
#pragma METAL fp math_mode(safe)
|
| 4 |
+
#pragma METAL fp contract(off)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
// Embedding lookup: one threadgroup per input token. The token id selects a
// row of the bf16 embedding table, and the threads of the group cooperatively
// widen that row to f32 into the output activation buffer.
kernel void gptoss_bf16_f32_embeddings(
    constant gptoss_embeddings_args& args [[ buffer(0) ]],
    const device uint* tokens [[ buffer(1) ]],
    const device bfloat4* weights [[ buffer(2) ]],
    device float4* output [[ buffer(3) ]],
    const device gptoss_control* control [[ buffer(4) ]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threadgroup_size [[ threads_per_threadgroup ]])
{
    // Honor a pending abort request before doing any work.
    if (control->abort != 0) {
        return;
    }

    const uint token_id = tokens[gid];

    // Row bases: the embedding row for this token, and this token's output row.
    const device bfloat4* embedding_row = weights + token_id * args.num_vecs;
    device float4* output_row = output + gid * args.num_vecs;

    // Strided copy-and-widen across the threadgroup.
    for (uint vec = tid; vec < args.num_vecs; vec += threadgroup_size) {
        output_row[vec] = static_cast<float4>(embedding_row[vec]);
    }
}
|
gptoss_kernels/source/expert_routing_metadata.metal
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <internal/kernel-args.h>
|
| 2 |
+
#include <metal_integer>
|
| 3 |
+
#include <metal_math>
|
| 4 |
+
#include <metal_stdlib>
|
| 5 |
+
|
| 6 |
+
constant uint kMaxExperts = 128;
|
| 7 |
+
|
| 8 |
+
// Builds MoE routing metadata in one pass over the expert predictions:
//   intra_expert_offsets[i] = rank of prediction i among all predictions
//     routed to the same expert (a dense 0..count-1 numbering; the order is
//     whatever the atomics produce, i.e. not deterministic across runs),
//   expert_offsets[e]       = exclusive prefix sum of per-expert counts,
//   expert_offsets[num_experts] = total number of predictions.
// Must be launched as a SINGLE threadgroup: counts live in threadgroup memory
// and the final scan runs on thread 0 only.
kernel void gptoss_f32_expert_routing_metadata(
    constant gptoss_expert_routing_metadata_args& args [[ buffer(0) ]],
    const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(1) ]],
    device uint* __restrict__ expert_offsets [[ buffer(2) ]],
    device uint* __restrict__ intra_expert_offsets [[ buffer(3) ]],
    uint tg_size [[threads_per_threadgroup]],
    uint tid [[thread_position_in_threadgroup]])
{
    assert(args.num_experts <= kMaxExperts);
    // Create threadgroup mem and initialize it to 0.
    threadgroup metal::atomic_uint tg_counts[kMaxExperts];
    for (uint e = tid; e < args.num_experts; e += tg_size) {
        metal::atomic_store_explicit(&tg_counts[e], 0u, metal::memory_order_relaxed);
    }

    threadgroup_barrier(metal::mem_flags::mem_threadgroup);

    // Count predictions per expert; the fetch_add return value doubles as the
    // prediction's rank within its expert bucket.
    for (uint i = tid; i < args.tokens; i += tg_size) {
        const uint e = expert_predictions[i].expert_id;
        const uint r = metal::atomic_fetch_add_explicit(&tg_counts[e], 1u, metal::memory_order_relaxed);
        intra_expert_offsets[i] = r;
    }
    threadgroup_barrier(metal::mem_flags::mem_threadgroup);

    // Serial exclusive scan over the (small, <= kMaxExperts) count array.
    if (tid == 0) {
        uint total = 0;
        for (uint e = 0; e < args.num_experts; ++e) {
            const uint bin = metal::atomic_load_explicit(&tg_counts[e], metal::memory_order_relaxed);
            expert_offsets[e] = total;
            total += bin;
        }
        expert_offsets[args.num_experts] = total;
    }
}
|
gptoss_kernels/source/gather_and_accumulate.metal
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <internal/kernel-args.h>
|
| 2 |
+
#include <metal_integer>
|
| 3 |
+
#include <metal_math>
|
| 4 |
+
#include <metal_stdlib>
|
| 5 |
+
|
| 6 |
+
// TODO(ibrahim): This is not optimal as each thread only gathers and accumulates a single float4. To amortize the
|
| 7 |
+
// cost of reading the expert, offset and scales for a token, we should let each thread gather and accumulate several
|
| 8 |
+
// float4s.
|
| 9 |
+
// Gathers expert outputs (stored grouped by expert in `in`) back into token
// order and accumulates them into `out`, each weighted by its routing score.
// Specialized for exactly 4 active experts per token (k == 4). One thread
// handles one float4 of one token: gid.y = token row, gid.x = float4 column.
// Row indices into `in` come from expert_offsets + intra_expert_offsets
// (produced by gptoss_f32_expert_routing_metadata).
kernel void gptoss_f32_gather_and_accumulate_e4(
    constant gptoss_gather_args& args [[ buffer(0) ]],
    const device float* in [[ buffer(1) ]],
    const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(2) ]],
    const device uint* expert_offsets [[ buffer(3) ]],
    const device uint* intra_expert_offsets [[ buffer(4) ]],
    device float* out [[ buffer(5) ]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint T = args.tokens;
    const uint k = args.active_experts_per_token;
    const uint D = args.token_stride;

    assert((D & 3u) == 0);  // float4 loads/stores below need 4-float granularity
    assert(k == 4);         // kernel is hard-wired to top-4 routing

    const uint row = gid.y;
    if (row >= T) {
        return;
    }

    const uint col_vec4 = gid.x;
    const uint col = col_vec4 * 4u;
    if (col >= D) {
        return;
    }

    device float4* dst4 = reinterpret_cast<device float4*>(out + row * D + col);

    // The 4 predictions for this token are contiguous at row * k.
    const uint base = row * k;
    const gptoss_expert_prediction expert0 = expert_predictions[base];
    const gptoss_expert_prediction expert1 = expert_predictions[base + 1];
    const gptoss_expert_prediction expert2 = expert_predictions[base + 2];
    const gptoss_expert_prediction expert3 = expert_predictions[base + 3];
    const uint expert0_id = expert0.expert_id;
    const uint expert1_id = expert1.expert_id;
    const uint expert2_id = expert2.expert_id;
    const uint expert3_id = expert3.expert_id;
    const float scale0 = expert0.score;
    const float scale1 = expert1.score;
    const float scale2 = expert2.score;
    const float scale3 = expert3.score;
    // Vector load of all 4 intra-expert ranks at once; base is a multiple of
    // k == 4, so the uint4 reinterpretation is aligned.
    const uint4 current_intra_expert_offsets =
        *reinterpret_cast<const device uint4*>(&intra_expert_offsets[base]);
    // Get the row indices for the current expert ids
    const uint r0 = expert_offsets[expert0_id] + current_intra_expert_offsets.x;
    const uint r1 = expert_offsets[expert1_id] + current_intra_expert_offsets.y;
    const uint r2 = expert_offsets[expert2_id] + current_intra_expert_offsets.z;
    const uint r3 = expert_offsets[expert3_id] + current_intra_expert_offsets.w;

    const device float4* src0 =
        reinterpret_cast<const device float4*>(in + r0 * D + col);
    const device float4* src1 =
        reinterpret_cast<const device float4*>(in + r1 * D + col);
    const device float4* src2 =
        reinterpret_cast<const device float4*>(in + r2 * D + col);
    const device float4* src3 =
        reinterpret_cast<const device float4*>(in + r3 * D + col);

    // Read-modify-write: accumulate on top of whatever is already in `out`.
    float4 acc = *dst4;
    acc = metal::fma(*src0, scale0, acc);
    acc = metal::fma(*src1, scale1, acc);
    acc = metal::fma(*src2, scale2, acc);
    acc = metal::fma(*src3, scale3, acc);
    *dst4 = acc;
}
|
gptoss_kernels/source/generate.c
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <assert.h>
|
| 2 |
+
#include <inttypes.h>
|
| 3 |
+
#include <math.h>
|
| 4 |
+
#include <signal.h>
|
| 5 |
+
#include <stdatomic.h>
|
| 6 |
+
#include <stdbool.h>
|
| 7 |
+
#include <stdio.h>
|
| 8 |
+
#include <stdint.h>
|
| 9 |
+
#include <stdlib.h>
|
| 10 |
+
#include <string.h>
|
| 11 |
+
|
| 12 |
+
#include <mach/mach_time.h>
|
| 13 |
+
|
| 14 |
+
#include <gpt-oss.h>
|
| 15 |
+
|
| 16 |
+
#include "internal/model.h"
|
| 17 |
+
|
| 18 |
+
// Process-wide performance counters, kept atomic so they can be updated and
// read across threads / from a signal handler without data races.
// NOTE(review): field meanings inferred from names; confirm at the write sites
// (not visible in this chunk).
struct {
    atomic_uint_least64_t inference_bytes;          // presumably bytes of model data read per inference
    atomic_size_t num_prefill_tokens;               // tokens consumed during prompt prefill
    atomic_uint_least64_t prefill_microseconds;     // wall-clock time spent in prefill
    atomic_size_t num_generated_tokens;             // tokens produced during generation
    atomic_uint_least64_t generation_microseconds;  // wall-clock time spent generating
} globals = {
    .inference_bytes = 0,
    .num_prefill_tokens = 0,
    .prefill_microseconds = 0,
    .num_generated_tokens = 0,
    .generation_microseconds = 0,
};
|
| 31 |
+
|
| 32 |
+
// Parsed command-line options; zero/NULL fields mean "not specified".
struct options {
    const char* model;        // path to the model file -- presumably the positional argument; confirm in parse_options
    const char* prompt;       // prompt text (-p / --prompt)
    size_t context_length;    // --context-length; 0 selects the default
    size_t max_tokens;        // -n / --max-tokens; semantics of 0 set by the generation loop -- TODO confirm
    float temperature;        // sampling temperature; 0.0f selects greedy argmax decoding
    bool verbose;             // enable extra diagnostics
};
|
| 40 |
+
|
| 41 |
+
static inline double mach_timestamp_diff_to_seconds(uint64_t start_timestamp, uint64_t end_timestamp) {
|
| 42 |
+
static mach_timebase_info_data_t timebase_info = {0};
|
| 43 |
+
if (timebase_info.denom == 0) {
|
| 44 |
+
mach_timebase_info(&timebase_info);
|
| 45 |
+
}
|
| 46 |
+
const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
|
| 47 |
+
return ((double) elapsed_mach_time * (double) timebase_info.numer) / ((double) timebase_info.denom * 1.0e+9);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
static inline uint64_t mach_timestamp_diff_to_microseconds(uint64_t start_timestamp, uint64_t end_timestamp) {
|
| 51 |
+
static mach_timebase_info_data_t timebase_info = {0};
|
| 52 |
+
if (timebase_info.denom == 0) {
|
| 53 |
+
mach_timebase_info(&timebase_info);
|
| 54 |
+
}
|
| 55 |
+
const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
|
| 56 |
+
const uint64_t denominator = timebase_info.denom * UINT64_C(1000);
|
| 57 |
+
return (elapsed_mach_time * timebase_info.numer + denominator / 2) / denominator;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
// Prints the command-line usage summary for this tool to stdout.
// program_name: argv[0], interpolated into the usage line.
static void print_usage(const char* program_name) {
    printf("Usage: %s <model-path> [-p <prompt>] [-n <tokens>]\n", program_name);
}
|
| 63 |
+
|
| 64 |
+
struct options parse_options(int argc, char** argv) {
|
| 65 |
+
struct options options = (struct options) {
|
| 66 |
+
.model = NULL,
|
| 67 |
+
.prompt = NULL,
|
| 68 |
+
.context_length = 0,
|
| 69 |
+
.max_tokens = 0,
|
| 70 |
+
.temperature = 0.0f,
|
| 71 |
+
.verbose = false,
|
| 72 |
+
};
|
| 73 |
+
if (argc < 2) {
|
| 74 |
+
fprintf(stderr, "Error: missing required command-line argument\n");
|
| 75 |
+
print_usage(argv[0]);
|
| 76 |
+
exit(EXIT_FAILURE);
|
| 77 |
+
}
|
| 78 |
+
for (int i = 1; i < argc; i++) {
|
| 79 |
+
if (strcmp(argv[i], "--help") == 0) {
|
| 80 |
+
print_usage(argv[0]);
|
| 81 |
+
exit(EXIT_SUCCESS);
|
| 82 |
+
} else if (strcmp(argv[i], "-p") == 0 || strcmp(argv[i], "--prompt") == 0) {
|
| 83 |
+
if (i + 1 >= argc) {
|
| 84 |
+
fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
|
| 85 |
+
print_usage(argv[0]);
|
| 86 |
+
exit(EXIT_FAILURE);
|
| 87 |
+
}
|
| 88 |
+
options.prompt = argv[++i];
|
| 89 |
+
} else if (strcmp(argv[i], "--context-length") == 0) {
|
| 90 |
+
if (i + 1 >= argc) {
|
| 91 |
+
fprintf(stderr, "Error: missing argument for --context-length\n");
|
| 92 |
+
print_usage(argv[0]);
|
| 93 |
+
exit(EXIT_FAILURE);
|
| 94 |
+
}
|
| 95 |
+
char* context_length_start = argv[++i];
|
| 96 |
+
char* context_length_end = context_length_start;
|
| 97 |
+
options.context_length = strtoul(context_length_start, &context_length_end, 10);
|
| 98 |
+
if (context_length_end == context_length_start || *context_length_end != 0) {
|
| 99 |
+
fprintf(stderr, "Error: failed to parse context length value \"%s\"\n", context_length_start);
|
| 100 |
+
exit(EXIT_FAILURE);
|
| 101 |
+
}
|
| 102 |
+
} else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--max-tokens") == 0) {
|
| 103 |
+
if (i + 1 >= argc) {
|
| 104 |
+
fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
|
| 105 |
+
print_usage(argv[0]);
|
| 106 |
+
exit(EXIT_FAILURE);
|
| 107 |
+
}
|
| 108 |
+
char* max_tokens_start = argv[++i];
|
| 109 |
+
char* max_tokens_end = max_tokens_start;
|
| 110 |
+
options.max_tokens = strtoul(max_tokens_start, &max_tokens_end, 10);
|
| 111 |
+
if (max_tokens_end == max_tokens_start || *max_tokens_end != 0) {
|
| 112 |
+
fprintf(stderr, "Error: failed to max tokens value \"%s\"\n", max_tokens_start);
|
| 113 |
+
exit(EXIT_FAILURE);
|
| 114 |
+
}
|
| 115 |
+
if (options.max_tokens == 0) {
|
| 116 |
+
fprintf(stderr, "Error: invalid max tokens value %zu\n", options.max_tokens);
|
| 117 |
+
exit(EXIT_FAILURE);
|
| 118 |
+
}
|
| 119 |
+
} else if (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--temperature") == 0) {
|
| 120 |
+
if (i + 1 >= argc) {
|
| 121 |
+
fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
|
| 122 |
+
print_usage(argv[0]);
|
| 123 |
+
exit(EXIT_FAILURE);
|
| 124 |
+
}
|
| 125 |
+
char* temperature_start = argv[++i];
|
| 126 |
+
char* temperature_end = temperature_start;
|
| 127 |
+
options.temperature = strtof(temperature_start, &temperature_end);
|
| 128 |
+
if (temperature_end == temperature_start || *temperature_end != 0) {
|
| 129 |
+
fprintf(stderr, "Error: failed to parse temperature value \"%s\"\n", temperature_start);
|
| 130 |
+
exit(EXIT_FAILURE);
|
| 131 |
+
}
|
| 132 |
+
if (signbit(options.temperature) != 0 || !(options.temperature <= 2.0f)) {
|
| 133 |
+
fprintf(stderr, "Error: invalid temperature value %f\n", options.temperature);
|
| 134 |
+
exit(EXIT_FAILURE);
|
| 135 |
+
}
|
| 136 |
+
} else if (strcmp(argv[i], "-v") == 0 || strcmp(argv[i], "--verbose") == 0) {
|
| 137 |
+
options.verbose = true;
|
| 138 |
+
} else {
|
| 139 |
+
if (options.model == NULL) {
|
| 140 |
+
options.model = argv[i];
|
| 141 |
+
} else {
|
| 142 |
+
fprintf(stderr, "Error: unexpected command-line argument %s\n", argv[i]);
|
| 143 |
+
print_usage(argv[0]);
|
| 144 |
+
exit(EXIT_FAILURE);
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
if (options.model == NULL) {
|
| 149 |
+
fprintf(stderr, "Error: missing required model argument\n");
|
| 150 |
+
print_usage(argv[0]);
|
| 151 |
+
exit(EXIT_FAILURE);
|
| 152 |
+
}
|
| 153 |
+
if (options.prompt == NULL) {
|
| 154 |
+
fprintf(stderr, "Error: missing required prompt argument\n");
|
| 155 |
+
print_usage(argv[0]);
|
| 156 |
+
exit(EXIT_FAILURE);
|
| 157 |
+
}
|
| 158 |
+
return options;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
static void print_profile() {
|
| 163 |
+
const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens);
|
| 164 |
+
const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds);
|
| 165 |
+
const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens);
|
| 166 |
+
const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds);
|
| 167 |
+
const uint64_t inference_bytes = atomic_load(&globals.inference_bytes);
|
| 168 |
+
if (num_prefill_tokens != 0 || num_generated_tokens != 0) {
|
| 169 |
+
printf("\n");
|
| 170 |
+
}
|
| 171 |
+
if (num_prefill_tokens != 0) {
|
| 172 |
+
printf("Prefill speed (%zu tokens): %.1f tokens/second\n",
|
| 173 |
+
num_prefill_tokens,
|
| 174 |
+
(double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6);
|
| 175 |
+
}
|
| 176 |
+
if (num_generated_tokens != 0) {
|
| 177 |
+
printf("Generation speed (%zu tokens): %.1f tokens/second\n",
|
| 178 |
+
num_generated_tokens,
|
| 179 |
+
(double) num_generated_tokens / (double) generation_microseconds * 1.0e+6);
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
// SIGINT handler: report throughput stats, then terminate successfully.
// NOTE(review): printf/exit are not async-signal-safe per POSIX; this is
// tolerated here because the process is exiting anyway — confirm this is
// acceptable, or switch to write()/_exit() for strict conformance.
static void ctrl_c_handler(int signum) {
    print_profile();
    exit(EXIT_SUCCESS);
}
|
| 187 |
+
|
| 188 |
+
int main(int argc, char *argv[]) {
|
| 189 |
+
enum gptoss_status status;
|
| 190 |
+
gptoss_model_t model = NULL;
|
| 191 |
+
gptoss_tokenizer_t tokenizer = NULL;
|
| 192 |
+
gptoss_context_t context = NULL;
|
| 193 |
+
|
| 194 |
+
struct sigaction act;
|
| 195 |
+
act.sa_handler = ctrl_c_handler;
|
| 196 |
+
sigaction(SIGINT, &act, NULL);
|
| 197 |
+
|
| 198 |
+
setvbuf(stdout, NULL, _IONBF, 0);
|
| 199 |
+
|
| 200 |
+
struct options options = parse_options(argc, argv);
|
| 201 |
+
|
| 202 |
+
const uint64_t load_start_time = mach_continuous_time();
|
| 203 |
+
status = gptoss_model_create_from_file(options.model, &model);
|
| 204 |
+
if (status != gptoss_status_success) {
|
| 205 |
+
fprintf(stderr, "Error: failed to load model from file %s\n", options.model);
|
| 206 |
+
goto error;
|
| 207 |
+
}
|
| 208 |
+
size_t max_model_context_length = 0;
|
| 209 |
+
status = gptoss_model_get_max_context_length(model, &max_model_context_length);
|
| 210 |
+
if (status != gptoss_status_success) {
|
| 211 |
+
fprintf(stderr, "Error: failed to query maximum context length\n");
|
| 212 |
+
goto error;
|
| 213 |
+
}
|
| 214 |
+
assert(max_model_context_length != 0);
|
| 215 |
+
if (options.context_length == 0) {
|
| 216 |
+
options.context_length = max_model_context_length;
|
| 217 |
+
} else if (options.context_length > max_model_context_length) {
|
| 218 |
+
fprintf(stderr, "Error: context length %zu exceeds maximum context length %zu supported by the model\n", options.context_length, max_model_context_length);
|
| 219 |
+
goto error;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
status = gptoss_model_get_tokenizer(model, &tokenizer);
|
| 223 |
+
if (status != gptoss_status_success) {
|
| 224 |
+
fprintf(stderr, "Error: failed to retrieve Tokenizer\n");
|
| 225 |
+
goto error;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
uint32_t return_token_id = UINT32_MAX;
|
| 229 |
+
status = gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_return, &return_token_id);
|
| 230 |
+
if (status != gptoss_status_success) {
|
| 231 |
+
fprintf(stderr, "Error: failed to query end-of-text token ID\n");
|
| 232 |
+
goto error;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
status = gptoss_context_create(model, options.context_length, /*max_batch_tokens=*/0, &context);
|
| 236 |
+
if (status != gptoss_status_success) {
|
| 237 |
+
fprintf(stderr, "Error: failed to create Context object\n");
|
| 238 |
+
goto error;
|
| 239 |
+
}
|
| 240 |
+
if (options.verbose) {
|
| 241 |
+
printf("Model weights size: %.2lf MB\n", (double) model->weights_size * 0x1.0p-20);
|
| 242 |
+
printf("Model allocation size: %.2lf MB\n", (double) model->allocation_size * 0x1.0p-20);
|
| 243 |
+
printf("Context allocation size: %.2lf MB\n", (double) context->allocation_size * 0x1.0p-20);
|
| 244 |
+
printf(" Including KV cache: %.2lf MB\n", (double) context->kvcache_size * 0x1.0p-20);
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
const uint64_t load_end_time = mach_continuous_time();
|
| 248 |
+
const double load_elapsed_seconds = mach_timestamp_diff_to_seconds(load_start_time, load_end_time);
|
| 249 |
+
if (options.verbose) {
|
| 250 |
+
printf("Loaded model in %.3f seconds\n", load_elapsed_seconds);
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
const uint64_t prefill_start_time = mach_continuous_time();
|
| 254 |
+
size_t num_prefill_tokens = 0;
|
| 255 |
+
status = gptoss_context_append_chars(context, options.prompt, strlen(options.prompt), &num_prefill_tokens);
|
| 256 |
+
if (status != gptoss_status_success) {
|
| 257 |
+
fprintf(stderr, "Error: failed to tokenize prompt \"%s\"\n", options.prompt);
|
| 258 |
+
goto error;
|
| 259 |
+
}
|
| 260 |
+
atomic_store(&globals.num_prefill_tokens, num_prefill_tokens);
|
| 261 |
+
status = gptoss_context_process(context);
|
| 262 |
+
if (status != gptoss_status_success) {
|
| 263 |
+
fprintf(stderr, "Error: failed to process Context object\n");
|
| 264 |
+
goto error;
|
| 265 |
+
}
|
| 266 |
+
const uint64_t prefill_end_time = mach_continuous_time();
|
| 267 |
+
|
| 268 |
+
while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {
|
| 269 |
+
|
| 270 |
+
uint32_t predicted_token = UINT32_MAX;
|
| 271 |
+
size_t num_predicted_tokens = 0;
|
| 272 |
+
const uint64_t inference_start_timestamp = mach_continuous_time();
|
| 273 |
+
status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, /*num_tokens=*/1, &predicted_token, &num_predicted_tokens);
|
| 274 |
+
if (status != gptoss_status_success) {
|
| 275 |
+
fprintf(stderr, "Error: failed to sample from the Context object\n");
|
| 276 |
+
goto error;
|
| 277 |
+
}
|
| 278 |
+
const uint64_t inference_end_timestamp = mach_continuous_time();
|
| 279 |
+
|
| 280 |
+
if (predicted_token == return_token_id) {
|
| 281 |
+
// Yield token -> stop generation
|
| 282 |
+
break;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
// Unembedding: detokenize
|
| 286 |
+
size_t token_size = 0;
|
| 287 |
+
const void* token_ptr = NULL;
|
| 288 |
+
status = gptoss_tokenizer_decode(tokenizer, predicted_token, &token_ptr, &token_size);
|
| 289 |
+
if (status != gptoss_status_success) {
|
| 290 |
+
fprintf(stderr, "Error: failed to detokenize predicted token %" PRIu32 "\n", predicted_token);
|
| 291 |
+
goto error;
|
| 292 |
+
}
|
| 293 |
+
const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1);
|
| 294 |
+
if (previous_num_generated_tokens == 0) {
|
| 295 |
+
atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time));
|
| 296 |
+
} else {
|
| 297 |
+
atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp));
|
| 298 |
+
}
|
| 299 |
+
printf("%.*s", (int) token_size, (const char*) token_ptr);
|
| 300 |
+
|
| 301 |
+
status = gptoss_context_append_tokens(context, 1, &predicted_token);
|
| 302 |
+
if (status != gptoss_status_success) {
|
| 303 |
+
fprintf(stderr, "Error: failed to append predicted token %" PRIu32 " to context\n", predicted_token);
|
| 304 |
+
goto error;
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
print_profile();
|
| 309 |
+
|
| 310 |
+
return EXIT_SUCCESS;
|
| 311 |
+
|
| 312 |
+
error:
|
| 313 |
+
gptoss_context_release(context);
|
| 314 |
+
gptoss_tokenizer_release(tokenizer);
|
| 315 |
+
gptoss_model_release(model);
|
| 316 |
+
return EXIT_FAILURE;
|
| 317 |
+
}
|
gptoss_kernels/source/include/internal/datatype.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <stdint.h>
|
| 4 |
+
|
| 5 |
+
#include <internal/macros.h>
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
// 16-bit brain floating-point (bfloat16) value stored as a raw bit pattern.
// All types in this header are packed single-field wrappers, so arrays of
// them can be overlaid directly on serialized weight buffers.
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
    GPTOSS_ALIGN(2) uint16_t bits;
} gptoss_bfloat16;
static_assert(sizeof(gptoss_bfloat16) == 2, "bfloat16 size is not 2 bytes");


// 16-bit IEEE half-precision (float16) value stored as a raw bit pattern.
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
    GPTOSS_ALIGN(2) uint16_t bits;
} gptoss_float16;
static_assert(sizeof(gptoss_float16) == 2, "float16 size is not 2 bytes");


// 8-bit value; per the ue8m0 name, an unsigned exponent-only scale format —
// TODO confirm exact semantics against the kernels that consume it.
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
    GPTOSS_ALIGN(1) uint8_t bits;
} gptoss_float8ue8m0;
static_assert(sizeof(gptoss_float8ue8m0) == 1, "gptoss_float8ue8m0 size is not 1 bytes");


// 8-bit floating-point value, e5m2 layout per the name.
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
    GPTOSS_ALIGN(1) uint8_t bits;
} gptoss_float8e5m2;
static_assert(sizeof(gptoss_float8e5m2) == 1, "float8e5m2 size is not 1 bytes");


// 8-bit floating-point value, e4m3 layout per the name
// (see upcast<float>(gptoss_float8e4m3) in datatype.hpp for decoding).
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
    GPTOSS_ALIGN(1) uint8_t bits;
} gptoss_float8e4m3;
static_assert(sizeof(gptoss_float8e4m3) == 1, "gptoss_float8e4m3 size is not 1 bytes");


// Two 4-bit e2m1 floating-point values packed into one byte (x2 suffix).
typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
    GPTOSS_ALIGN(1) uint8_t bits;
} gptoss_float4e2m1x2;
static_assert(sizeof(gptoss_float4e2m1x2) == 1, "gptoss_float4e2m1x2 size is not 1 bytes");
|
gptoss_kernels/source/include/internal/datatype.hpp
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <bit>
|
| 4 |
+
|
| 5 |
+
#include <internal/datatype.h>
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
namespace gptoss {

// upcast<Wide>(narrow): widens a narrow floating-point wrapper type to a
// wider arithmetic type. Only the explicit specializations below exist; the
// primary template is intentionally left undefined.
template <typename WideT, typename NarrowT>
WideT upcast(NarrowT);

// bfloat16 -> float: bf16 is the top 16 bits of an IEEE binary32, so the
// conversion is a left shift plus bit_cast (exact for all inputs).
template <>
inline float upcast<float>(gptoss_bfloat16 bf16_value) {
    const uint32_t bits = static_cast<uint32_t>(bf16_value.bits) << 16;
    return std::bit_cast<float>(bits);
}

// float16 -> float through the compiler's native _Float16 type (exact).
template <>
inline float upcast<float>(gptoss_float16 fp16_value) {
    return static_cast<float>(std::bit_cast<_Float16>(fp16_value.bits));
}

// float8 e4m3 -> float via a 256-entry lookup table.
// Note: despite the "_to_fp32" name, the table entries are bfloat16 bit
// patterns — each lookup result is wrapped in gptoss_bfloat16 and widened
// through upcast<float> below. Entries 0x7F and 0xFF (0x7FF0 / 0xFFF0)
// carry a nonzero mantissa with an all-ones exponent, i.e. NaN.
template <>
inline float upcast<float>(gptoss_float8e4m3 fp8_value) {
    static constexpr uint16_t fp8e4m3_to_fp32[256] = {
        0x0000, 0x3B00, 0x3B80, 0x3BC0, 0x3C00, 0x3C20, 0x3C40, 0x3C60,
        0x3C80, 0x3C90, 0x3CA0, 0x3CB0, 0x3CC0, 0x3CD0, 0x3CE0, 0x3CF0,
        0x3D00, 0x3D10, 0x3D20, 0x3D30, 0x3D40, 0x3D50, 0x3D60, 0x3D70,
        0x3D80, 0x3D90, 0x3DA0, 0x3DB0, 0x3DC0, 0x3DD0, 0x3DE0, 0x3DF0,
        0x3E00, 0x3E10, 0x3E20, 0x3E30, 0x3E40, 0x3E50, 0x3E60, 0x3E70,
        0x3E80, 0x3E90, 0x3EA0, 0x3EB0, 0x3EC0, 0x3ED0, 0x3EE0, 0x3EF0,
        0x3F00, 0x3F10, 0x3F20, 0x3F30, 0x3F40, 0x3F50, 0x3F60, 0x3F70,
        0x3F80, 0x3F90, 0x3FA0, 0x3FB0, 0x3FC0, 0x3FD0, 0x3FE0, 0x3FF0,
        0x4000, 0x4010, 0x4020, 0x4030, 0x4040, 0x4050, 0x4060, 0x4070,
        0x4080, 0x4090, 0x40A0, 0x40B0, 0x40C0, 0x40D0, 0x40E0, 0x40F0,
        0x4100, 0x4110, 0x4120, 0x4130, 0x4140, 0x4150, 0x4160, 0x4170,
        0x4180, 0x4190, 0x41A0, 0x41B0, 0x41C0, 0x41D0, 0x41E0, 0x41F0,
        0x4200, 0x4210, 0x4220, 0x4230, 0x4240, 0x4250, 0x4260, 0x4270,
        0x4280, 0x4290, 0x42A0, 0x42B0, 0x42C0, 0x42D0, 0x42E0, 0x42F0,
        0x4300, 0x4310, 0x4320, 0x4330, 0x4340, 0x4350, 0x4360, 0x4370,
        0x4380, 0x4390, 0x43A0, 0x43B0, 0x43C0, 0x43D0, 0x43E0, 0x7FF0,
        0x8000, 0xBB00, 0xBB80, 0xBBC0, 0xBC00, 0xBC20, 0xBC40, 0xBC60,
        0xBC80, 0xBC90, 0xBCA0, 0xBCB0, 0xBCC0, 0xBCD0, 0xBCE0, 0xBCF0,
        0xBD00, 0xBD10, 0xBD20, 0xBD30, 0xBD40, 0xBD50, 0xBD60, 0xBD70,
        0xBD80, 0xBD90, 0xBDA0, 0xBDB0, 0xBDC0, 0xBDD0, 0xBDE0, 0xBDF0,
        0xBE00, 0xBE10, 0xBE20, 0xBE30, 0xBE40, 0xBE50, 0xBE60, 0xBE70,
        0xBE80, 0xBE90, 0xBEA0, 0xBEB0, 0xBEC0, 0xBED0, 0xBEE0, 0xBEF0,
        0xBF00, 0xBF10, 0xBF20, 0xBF30, 0xBF40, 0xBF50, 0xBF60, 0xBF70,
        0xBF80, 0xBF90, 0xBFA0, 0xBFB0, 0xBFC0, 0xBFD0, 0xBFE0, 0xBFF0,
        0xC000, 0xC010, 0xC020, 0xC030, 0xC040, 0xC050, 0xC060, 0xC070,
        0xC080, 0xC090, 0xC0A0, 0xC0B0, 0xC0C0, 0xC0D0, 0xC0E0, 0xC0F0,
        0xC100, 0xC110, 0xC120, 0xC130, 0xC140, 0xC150, 0xC160, 0xC170,
        0xC180, 0xC190, 0xC1A0, 0xC1B0, 0xC1C0, 0xC1D0, 0xC1E0, 0xC1F0,
        0xC200, 0xC210, 0xC220, 0xC230, 0xC240, 0xC250, 0xC260, 0xC270,
        0xC280, 0xC290, 0xC2A0, 0xC2B0, 0xC2C0, 0xC2D0, 0xC2E0, 0xC2F0,
        0xC300, 0xC310, 0xC320, 0xC330, 0xC340, 0xC350, 0xC360, 0xC370,
        0xC380, 0xC390, 0xC3A0, 0xC3B0, 0xC3C0, 0xC3D0, 0xC3E0, 0xFFF0,
    };
    const gptoss_bfloat16 bf16_value{.bits = fp8e4m3_to_fp32[fp8_value.bits]};
    return upcast<float>(bf16_value);
}

// float -> double (always exact).
template <>
inline double upcast<double>(float fp32_value) {
    return static_cast<double>(fp32_value);
}

// The remaining -> double conversions widen through float first; since the
// float step is exact, the composition is exact too.
template <>
inline double upcast<double>(gptoss_bfloat16 bf16_value) {
    const float fp32_value = upcast<float>(bf16_value);
    return upcast<double>(fp32_value);
}

template <>
inline double upcast<double>(gptoss_float16 fp16_value) {
    const float fp32_value = upcast<float>(fp16_value);
    return upcast<double>(fp32_value);
}

template <>
inline double upcast<double>(gptoss_float8e4m3 fp8_value) {
    const float fp32_value = upcast<float>(fp8_value);
    return upcast<double>(fp32_value);
}

} // namespace gptoss
|
gptoss_kernels/source/include/internal/kernel-args.h
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#if !defined(__METAL_VERSION__)
|
| 4 |
+
#include <stdint.h>
|
| 5 |
+
#endif
|
| 6 |
+
|
| 7 |
+
// TODO(ibahmed): specialize using metal function constants.
// Tile-shape constants for the dense matmul kernels: Bm/Bn/Bk are the
// threadgroup tile dimensions, Sg_Bm/Sg_Bn the per-simdgroup sub-tiles.
#define QKV_Bm 64
#define QKV_Bn 64
#define QKV_Bk 32
#define QKV_Sg_Bm 32
#define QKV_Sg_Bn 32

#define ATTN_OUTPUT_Bm 32
#define ATTN_OUTPUT_Bn 64
#define ATTN_OUTPUT_Bk 64
#define ATTN_OUTPUT_Sg_Bm 32
#define ATTN_OUTPUT_Sg_Bn 16

#define MLP_GATE_Bm 64
#define MLP_GATE_Bn 16
#define MLP_GATE_Bk 64
#define MLP_GATE_Sg_Bm 16
#define MLP_GATE_Sg_Bn 16

#define MOE_DENSE_MATMUL_SWIGLU_Bm 32
#define MOE_DENSE_MATMUL_SWIGLU_Bn 64
#define MOE_DENSE_MATMUL_SWIGLU_Bk 16
#define MOE_DENSE_MATMUL_SWIGLU_Sg_Bm 32
#define MOE_DENSE_MATMUL_SWIGLU_Sg_Bn 16

#define MOE_DENSE_MATMUL_Bm 32
#define MOE_DENSE_MATMUL_Bn 64
#define MOE_DENSE_MATMUL_Bk 16
#define MOE_DENSE_MATMUL_Sg_Bm 32
#define MOE_DENSE_MATMUL_Sg_Bn 16

// The structs below are argument blocks shared between host code and the
// Metal kernels (see the __METAL_VERSION__ guard above); layouts must match
// on both sides.

// Per-token expert routing result: selected expert and its routing score.
struct gptoss_expert_prediction {
    uint32_t expert_id;
    float score;
};

// Host-visible control block; nonzero `abort` presumably signals kernels to
// stop early — confirm against the kernel sources.
struct gptoss_control {
    uint32_t abort;
};

// Arguments for the top-k expert selection kernel.
struct gptoss_topk_args {
    uint32_t num_vecs_per_token;
};

// Arguments for the scaled dot-product attention kernel.
struct gptoss_sdpa_args {
    uint32_t qkv_dim;
    uint32_t num_kv_tokens;
    uint32_t kv_stride;
    uint32_t window;
};

// Arguments for the u32 random-fill kernel.
struct gptoss_u32_fill_random_args {
    uint64_t num_vecs_per_threadgroup;
    uint64_t num_vecs;
    uint64_t offset;
    uint64_t seed;
};

// Arguments for the f32 random-fill kernel (output = random * scale + bias).
struct gptoss_f32_fill_random_args {
    uint64_t num_vecs_per_threadgroup;
    uint64_t num_vecs;
    uint64_t offset;
    uint64_t seed;
    float scale;
    float bias;
};

// Arguments for the expert-output accumulation kernel.
struct gptoss_accumulate_args {
    uint32_t num_vecs_per_expert;
    uint32_t num_vecs_per_threadgroup;
    uint32_t num_vecs;
};

// Arguments for the datatype-conversion kernel.
struct gptoss_convert_args {
    uint64_t num_vecs_per_threadgroup;
    uint64_t num_vecs;
};

// Arguments for the token-embedding lookup kernel.
struct gptoss_embeddings_args {
    uint32_t num_vecs;
};

// Arguments for the RMSNorm kernel. num_channels is pre-converted to float
// (presumably so the kernel avoids an integer-to-float conversion — confirm).
struct gptoss_rmsnorm_args {
    uint32_t num_vecs;
    float num_channels;
    float epsilon;
};

// Arguments for the vector matmul kernel; `add` selects accumulate-into-output.
struct gptoss_matmul_args {
    uint32_t num_column_vecs;
    uint32_t num_rows;
    uint32_t add;
};

// Problem shape (M x N x K) for the dense matmul kernels.
struct gptoss_dense_matmul_args {
    uint32_t m;
    uint32_t n;
    uint32_t k;
};

// Arguments for the token-scatter (routing) kernel.
struct gptoss_scatter_args {
    uint32_t tokens;
    uint32_t active_experts_per_token;
    uint32_t token_stride;
};

// Arguments for the MoE dense matmul + SwiGLU activation kernel.
struct gptoss_moe_dense_matmul_swiglu_args {
    uint32_t k;
    uint32_t n;
    uint32_t weight_blocks_expert_stride_bytes;
    uint32_t weight_scales_expert_stride_bytes;
    uint32_t bias_expert_stride_bytes;
    float swiglu_min;   // clamp bounds applied inside the activation
    float swiglu_max;
};
// Arguments for the MoE dense matmul kernel (no activation).
struct gptoss_moe_dense_matmul_args {
    uint32_t k;
    uint32_t n;
    uint32_t weight_blocks_expert_stride_bytes;
    uint32_t weight_scales_expert_stride_bytes;
    uint32_t bias_expert_stride_bytes;
};

// Arguments for the expert-routing metadata kernel.
struct gptoss_expert_routing_metadata_args {
    uint32_t tokens;
    uint32_t num_experts;
};

// Arguments for the token-gather (un-routing) kernel.
struct gptoss_gather_args {
    uint32_t tokens;
    uint32_t active_experts_per_token;
    uint32_t token_stride;
};

// Arguments for the unembedding (logits projection) kernel.
struct gptoss_unembedding_args {
    uint32_t num_column_vecs;
    uint32_t num_rows_per_threadgroup;
    uint32_t num_rows;
};

// Arguments for the MoE vector matmul + SwiGLU kernel.
struct gptoss_moe_matmul_swiglu_args {
    uint32_t num_column_vecs;
    uint32_t num_rows;
    uint32_t num_active_experts;
    uint32_t weight_expert_stride; // in bytes
    uint32_t output_expert_stride; // in elements
    float swiglu_min;
    float swiglu_max;
};

// Arguments for the MoE vector matmul kernel (no activation).
struct gptoss_moe_matmul_args {
    uint32_t num_column_vecs;
    uint32_t num_rows;
    uint32_t num_active_experts;
    uint32_t input_expert_stride; // in blocks of 32 elements
    uint32_t weight_expert_stride; // in bytes
    uint32_t output_expert_stride; // in elements
};

// Arguments for the rotary position embedding (RoPE) kernel; the yarn_*
// fields parameterize the YaRN frequency-scaling scheme.
struct gptoss_rope_args {
    uint32_t token_stride;
    uint32_t token_offset;
    float freq_scale;
    float interpolation_scale;
    float yarn_offset;
    float yarn_scale;
    float yarn_multiplier;
};

// Arguments for the fused QKV projection (+ RoPE parameters) kernel.
struct gptoss_qkv_args {
    uint32_t num_column_vecs;
    uint32_t num_rows;
    uint32_t token_offset;
    float freq_scale;
    float interpolation_scale;
    float yarn_offset;
    float yarn_scale;
    float yarn_multiplier;
    uint32_t max_tokens;
};

// Arguments for the temperature-scaled softmax kernel.
struct gptoss_softmax_args {
    uint32_t num_vecs;
    uint32_t num_vecs_per_threadgroup;
    uint32_t max_threadgroups;
    float temperature;
};

// Arguments for the token-sampling kernel.
struct gptoss_sample_args {
    uint64_t rng_seed;
    uint32_t rng_offset;
    uint32_t num_blocks;
    uint32_t num_dims;
    uint32_t num_dims_per_block;
};
|
gptoss_kernels/source/include/internal/log.h
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <stdarg.h>
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
// Formats and emits a log message (declared here, defined in log.c).
void gptoss_format_log(const char* format, va_list args);

// printf-style logging entry point; forwards its arguments to
// gptoss_format_log. The format attribute lets the compiler type-check
// the format string against the arguments.
__attribute__((__format__(__printf__, 1, 2)))
inline static void gptoss_log(const char* format, ...) {
    va_list args;
    va_start(args, format);
    gptoss_format_log(format, args);
    va_end(args);
}

// Severity-tagged convenience wrappers; each appends a newline.
// ##__VA_ARGS__ (GNU extension) swallows the trailing comma when the
// macro is invoked with no variadic arguments.
#define GPTOSS_LOG_ERROR(message, ...) \
    gptoss_log("Error: " message "\n", ##__VA_ARGS__)

#define GPTOSS_LOG_WARNING(message, ...) \
    gptoss_log("Warning: " message "\n", ##__VA_ARGS__)
|
gptoss_kernels/source/include/internal/macros.h
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
/***** Architecture detection macros *****/

// GPTOSS_ARCH_X86_64 / GPTOSS_ARCH_ARM64 may be predefined by the build
// system (validated to be 0/1); otherwise they are derived from compiler
// predefined macros. Exactly one of the two must end up set to 1.

#ifdef GPTOSS_ARCH_X86_64
    #if GPTOSS_ARCH_X86_64 != 0 && GPTOSS_ARCH_X86_64 != 1
        #error "Invalid GPTOSS_ARCH_X86_64 value: must be either 0 or 1"
    #endif
#else
    // NOTE(review): && binds tighter than ||, so this parses as
    // __x86_64__ || (_M_X64 && !_M_ARM64EC); equivalent in practice since
    // __x86_64__ and _M_ARM64EC never co-occur, but parentheses would be clearer.
    #if defined(__x86_64__) || defined(_M_X64) && !defined(_M_ARM64EC)
        #define GPTOSS_ARCH_X86_64 1
    #else
        #define GPTOSS_ARCH_X86_64 0
    #endif
#endif

#ifdef GPTOSS_ARCH_ARM64
    #if GPTOSS_ARCH_ARM64 != 0 && GPTOSS_ARCH_ARM64 != 1
        #error "Invalid GPTOSS_ARCH_ARM64 value: must be either 0 or 1"
    #endif
#else
    #if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
        #define GPTOSS_ARCH_ARM64 1
    #else
        #define GPTOSS_ARCH_ARM64 0
    #endif
#endif

#if GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 == 0
    #error "Unsupported architecture: neither x86-64 nor ARM64 detected"
#elif GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 != 1
    #error "Inconsistent architecture detection: both x86-64 and ARM64 detection macros are specified"
#endif

/***** Compiler portability macros *****/

// Branch-prediction hints; degrade to plain boolean coercion on compilers
// without __builtin_expect.
#ifndef GPTOSS_LIKELY
    #if defined(__GNUC__)
        #define GPTOSS_LIKELY(condition) (__builtin_expect(!!(condition), 1))
    #else
        #define GPTOSS_LIKELY(condition) (!!(condition))
    #endif
#endif

#ifndef GPTOSS_UNLIKELY
    #if defined(__GNUC__)
        #define GPTOSS_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
    #else
        #define GPTOSS_UNLIKELY(condition) (!!(condition))
    #endif
#endif

// Marks a branch with no predictable direction; prefers the dedicated
// builtin, falls back to a 50% probability hint, then to a no-op.
#ifndef GPTOSS_UNPREDICTABLE
    #if defined(__has_builtin)
        #if __has_builtin(__builtin_unpredictable)
            #define GPTOSS_UNPREDICTABLE(condition) (__builtin_unpredictable(!!(condition)))
        #endif
    #endif
#endif
#ifndef GPTOSS_UNPREDICTABLE
    #if defined(__GNUC__) && (__GNUC__ >= 9) && !defined(__INTEL_COMPILER)
        #define GPTOSS_UNPREDICTABLE(condition) (__builtin_expect_with_probability(!!(condition), 0, 0.5))
    #else
        #define GPTOSS_UNPREDICTABLE(condition) (!!(condition))
    #endif
#endif

// Disable padding for structure members.
#ifndef GPTOSS_DENSELY_PACKED_STRUCTURE
    #if defined(__GNUC__)
        #define GPTOSS_DENSELY_PACKED_STRUCTURE __attribute__((__packed__))
    #else
        #error "Compiler-specific implementation of GPTOSS_DENSELY_PACKED_STRUCTURE required"
    #endif
#endif

// Minimum-alignment specifier for types and members.
#ifndef GPTOSS_ALIGN
    #if defined(__GNUC__)
        #define GPTOSS_ALIGN(alignment) __attribute__((__aligned__(alignment)))
    #elif defined(_MSC_VER)
        #define GPTOSS_ALIGN(alignment) __declspec(align(alignment))
    #else
        #error "Compiler-specific implementation of GPTOSS_ALIGN required"
    #endif
#endif

// Strong inlining request; falls back to a plain `inline` hint.
#ifndef GPTOSS_FORCE_INLINE
    #if defined(__GNUC__)
        #define GPTOSS_FORCE_INLINE inline __attribute__((__always_inline__))
    #elif defined(_MSC_VER)
        #define GPTOSS_FORCE_INLINE __forceinline
    #else
        #define GPTOSS_FORCE_INLINE inline
    #endif
#endif

/***** Symbol visibility macros *****/

// Hides a symbol from the shared-library export table (Mach-O has no
// "internal" visibility, so "hidden" is used there instead).
#ifndef GPTOSS_INTERNAL_SYMBOL
    #if defined(__ELF__)
        #define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__("internal")))
    #elif defined(__MACH__)
        #define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__("hidden")))
    #else
        #define GPTOSS_INTERNAL_SYMBOL
    #endif
#endif
|
gptoss_kernels/source/include/internal/math.h
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <assert.h>
|
| 4 |
+
#include <stddef.h>
|
| 5 |
+
#include <stdint.h>
|
| 6 |
+
|
| 7 |
+
// Integer division of numer by denom, rounded towards positive infinity.
// denom must be non-zero.
// Computed as quotient plus a remainder correction so the result is correct
// even when numer + denom - 1 would overflow size_t (the previous form
// returned 0 for e.g. math_ceil_div(SIZE_MAX, 2)).
inline static size_t math_ceil_div(size_t numer, size_t denom) {
    return numer / denom + (size_t) (numer % denom != 0);
}
|
| 10 |
+
|
| 11 |
+
// Returns the larger of the two sizes.
inline static size_t math_max(size_t a, size_t b) {
    if (a < b) {
        return b;
    }
    return a;
}
|
| 14 |
+
|
| 15 |
+
// Returns the smaller of the two sizes.
inline static size_t math_min(size_t a, size_t b) {
    if (b < a) {
        return b;
    }
    return a;
}
|
| 18 |
+
|
| 19 |
+
// Saturating subtraction: returns a - b, clamped at zero instead of
// wrapping around when b exceeds a.
inline static size_t math_sub_sat(size_t a, size_t b) {
    if (a <= b) {
        return 0;
    }
    return a - b;
}
|
| 22 |
+
|
| 23 |
+
// Rounds number down to the nearest multiple of a power-of-2.
// multiple must be a non-zero power of 2 (checked by assert in debug builds).
// Declared inline static to match the other helpers in this header and to
// avoid unused-function warnings in translation units that include the
// header but never call it.
inline static size_t math_round_down_po2(size_t number, size_t multiple) {
    assert(multiple != 0);
    assert((multiple & (multiple - 1)) == 0);

    // Unary minus on the unsigned multiple produces a mask with the low
    // log2(multiple) bits cleared, so the AND truncates to the multiple.
    return number & -multiple;
}
|
| 29 |
+
|
| 30 |
+
// Rounds number up to the nearest multiple of a power-of-2.
// multiple must be a non-zero power of 2 (checked by assert in debug builds).
// Declared inline static to match the other helpers in this header and to
// avoid unused-function warnings in translation units that include the
// header but never call it.
inline static size_t math_round_up_po2(size_t number, size_t multiple) {
    assert(multiple != 0);
    assert((multiple & (multiple - 1)) == 0);

    const size_t multiple_mask = multiple - 1;
    // Only adjust when number is not already a multiple: saturate the low
    // bits, then carry into the next multiple with the +1.
    if ((number & multiple_mask) != 0) {
        number |= multiple_mask;
        number += 1;
    }
    return number;
}
|
gptoss_kernels/source/include/internal/metal-kernels.h
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <stddef.h>
|
| 4 |
+
#include <stdint.h>
|
| 5 |
+
|
| 6 |
+
#include <internal/metal.h>
|
| 7 |
+
|
| 8 |
+
#ifdef __cplusplus
|
| 9 |
+
extern "C" {
|
| 10 |
+
#endif
|
| 11 |
+
|
| 12 |
+
#include <stddef.h>
|
| 13 |
+
#include <stdint.h>
|
| 14 |
+
|
| 15 |
+
#include <internal/kernel-args.h>
|
| 16 |
+
#include <internal/math.h>
|
| 17 |
+
#include <internal/metal.h>
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
|
| 21 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 22 |
+
const struct gptoss_metal_function* u32_fill_random_fn,
|
| 23 |
+
size_t threadgroup_size,
|
| 24 |
+
size_t max_threadgroups,
|
| 25 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 26 |
+
size_t output_offset,
|
| 27 |
+
uint64_t num_elements,
|
| 28 |
+
uint64_t rng_seed,
|
| 29 |
+
uint64_t rng_offset);
|
| 30 |
+
|
| 31 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
| 32 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 33 |
+
const struct gptoss_metal_function* f32_fill_random_fn,
|
| 34 |
+
size_t threadgroup_size,
|
| 35 |
+
size_t max_threadgroups,
|
| 36 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 37 |
+
size_t output_offset,
|
| 38 |
+
uint64_t num_elements,
|
| 39 |
+
uint64_t rng_seed,
|
| 40 |
+
uint64_t rng_offset,
|
| 41 |
+
float rng_min,
|
| 42 |
+
float rng_max);
|
| 43 |
+
|
| 44 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
|
| 45 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 46 |
+
const struct gptoss_metal_function* bf16_fill_random_fn,
|
| 47 |
+
size_t threadgroup_size,
|
| 48 |
+
size_t max_threadgroups,
|
| 49 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 50 |
+
size_t output_offset,
|
| 51 |
+
uint64_t num_elements,
|
| 52 |
+
uint64_t rng_seed,
|
| 53 |
+
uint64_t rng_offset,
|
| 54 |
+
float rng_min,
|
| 55 |
+
float rng_max);
|
| 56 |
+
|
| 57 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
|
| 58 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 59 |
+
const struct gptoss_metal_function* mf4_f32_convert_fn,
|
| 60 |
+
size_t threadgroup_size,
|
| 61 |
+
size_t max_threadgroups,
|
| 62 |
+
const struct gptoss_metal_buffer* block_buffer,
|
| 63 |
+
const struct gptoss_metal_buffer* scale_buffer,
|
| 64 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 65 |
+
uint64_t num_elements);
|
| 66 |
+
|
| 67 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
| 68 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 69 |
+
const struct gptoss_metal_function* bf16_f32_embeddings_fn,
|
| 70 |
+
size_t threadgroup_size,
|
| 71 |
+
const struct gptoss_metal_buffer* token_buffer,
|
| 72 |
+
size_t token_offset,
|
| 73 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 74 |
+
size_t weight_offset,
|
| 75 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 76 |
+
size_t output_offset,
|
| 77 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 78 |
+
size_t control_offset,
|
| 79 |
+
uint32_t num_tokens,
|
| 80 |
+
uint32_t num_channels);
|
| 81 |
+
|
| 82 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
| 83 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 84 |
+
const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,
|
| 85 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 86 |
+
size_t input_offset,
|
| 87 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 88 |
+
size_t weight_offset,
|
| 89 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 90 |
+
size_t output_offset,
|
| 91 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 92 |
+
size_t control_offset,
|
| 93 |
+
uint32_t num_tokens,
|
| 94 |
+
uint32_t num_channels,
|
| 95 |
+
float epsilon);
|
| 96 |
+
|
| 97 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
| 98 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 99 |
+
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
| 100 |
+
size_t threadgroup_size,
|
| 101 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 102 |
+
size_t input_offset,
|
| 103 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 104 |
+
size_t weight_offset,
|
| 105 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 106 |
+
size_t bias_offset,
|
| 107 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 108 |
+
size_t output_offset,
|
| 109 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 110 |
+
size_t control_offset,
|
| 111 |
+
uint32_t num_tokens,
|
| 112 |
+
uint32_t num_cols,
|
| 113 |
+
uint32_t num_rows);
|
| 114 |
+
|
| 115 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
|
| 116 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 117 |
+
const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
|
| 118 |
+
size_t threadgroup_size,
|
| 119 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 120 |
+
size_t input_offset,
|
| 121 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 122 |
+
size_t weight_offset,
|
| 123 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 124 |
+
size_t bias_offset,
|
| 125 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 126 |
+
size_t output_offset,
|
| 127 |
+
const struct gptoss_metal_buffer* kv_buffer,
|
| 128 |
+
size_t kv_offset,
|
| 129 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 130 |
+
size_t control_offset,
|
| 131 |
+
uint32_t num_tokens,
|
| 132 |
+
uint32_t num_cols,
|
| 133 |
+
uint32_t num_q_heads,
|
| 134 |
+
uint32_t num_kv_heads,
|
| 135 |
+
uint32_t attn_head_dim,
|
| 136 |
+
uint32_t token_offset,
|
| 137 |
+
uint32_t max_tokens,
|
| 138 |
+
float rope_base,
|
| 139 |
+
float interpolation_scale,
|
| 140 |
+
float yarn_offset,
|
| 141 |
+
float yarn_scale,
|
| 142 |
+
float yarn_multiplier);
|
| 143 |
+
|
| 144 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
|
| 145 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 146 |
+
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
| 147 |
+
size_t threadgroup_size,
|
| 148 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 149 |
+
size_t input_offset,
|
| 150 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 151 |
+
size_t weight_offset,
|
| 152 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 153 |
+
size_t bias_offset,
|
| 154 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 155 |
+
size_t output_offset,
|
| 156 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 157 |
+
size_t control_offset,
|
| 158 |
+
uint32_t num_tokens,
|
| 159 |
+
uint32_t num_cols,
|
| 160 |
+
uint32_t num_rows);
|
| 161 |
+
|
| 162 |
+
enum gptoss_status
|
| 163 |
+
gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
|
| 164 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 165 |
+
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
|
| 166 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 167 |
+
size_t input_offset,
|
| 168 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 169 |
+
size_t weight_offset,
|
| 170 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 171 |
+
size_t bias_offset,
|
| 172 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 173 |
+
size_t output_offset,
|
| 174 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 175 |
+
size_t control_offset,
|
| 176 |
+
uint32_t num_tokens,
|
| 177 |
+
uint32_t num_cols,
|
| 178 |
+
uint32_t num_rows);
|
| 179 |
+
|
| 180 |
+
enum gptoss_status
|
| 181 |
+
gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
|
| 182 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 183 |
+
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
|
| 184 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 185 |
+
size_t input_offset,
|
| 186 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 187 |
+
size_t weight_offset,
|
| 188 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 189 |
+
size_t bias_offset,
|
| 190 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 191 |
+
size_t output_offset,
|
| 192 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 193 |
+
size_t control_offset,
|
| 194 |
+
uint32_t num_tokens,
|
| 195 |
+
uint32_t num_cols,
|
| 196 |
+
uint32_t num_rows);
|
| 197 |
+
|
| 198 |
+
enum gptoss_status
|
| 199 |
+
gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
|
| 200 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 201 |
+
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
|
| 202 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 203 |
+
size_t input_offset,
|
| 204 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 205 |
+
size_t weight_offset,
|
| 206 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 207 |
+
size_t bias_offset,
|
| 208 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 209 |
+
size_t output_offset,
|
| 210 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 211 |
+
size_t control_offset,
|
| 212 |
+
uint32_t num_tokens,
|
| 213 |
+
uint32_t num_cols,
|
| 214 |
+
uint32_t num_rows);
|
| 215 |
+
|
| 216 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
|
| 217 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 218 |
+
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
| 219 |
+
size_t threadgroup_size,
|
| 220 |
+
size_t max_threadgroups,
|
| 221 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 222 |
+
size_t input_offset,
|
| 223 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 224 |
+
size_t weight_offset,
|
| 225 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 226 |
+
size_t output_offset,
|
| 227 |
+
const struct gptoss_metal_buffer* argmax_buffer,
|
| 228 |
+
size_t argmax_offset,
|
| 229 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 230 |
+
size_t control_offset,
|
| 231 |
+
uint32_t num_tokens,
|
| 232 |
+
uint32_t num_cols,
|
| 233 |
+
uint32_t num_rows);
|
| 234 |
+
|
| 235 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
|
| 236 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 237 |
+
const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,
|
| 238 |
+
size_t threadgroup_size,
|
| 239 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 240 |
+
size_t input_offset,
|
| 241 |
+
const struct gptoss_metal_buffer* expert_buffer,
|
| 242 |
+
size_t expert_offset,
|
| 243 |
+
const struct gptoss_metal_buffer* weight_block_buffer,
|
| 244 |
+
size_t weight_block_offset,
|
| 245 |
+
const struct gptoss_metal_buffer* weight_scale_buffer,
|
| 246 |
+
size_t weight_scale_offset,
|
| 247 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 248 |
+
size_t bias_offset,
|
| 249 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 250 |
+
size_t output_offset,
|
| 251 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 252 |
+
size_t control_offset,
|
| 253 |
+
float swiglu_limit,
|
| 254 |
+
uint32_t expert_stride,
|
| 255 |
+
uint32_t num_tokens,
|
| 256 |
+
uint32_t num_active_experts,
|
| 257 |
+
uint32_t num_cols,
|
| 258 |
+
uint32_t num_rows);
|
| 259 |
+
|
| 260 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
|
| 261 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 262 |
+
const struct gptoss_metal_function* f32_mf4w_moe_matmul_fn,
|
| 263 |
+
size_t threadgroup_size,
|
| 264 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 265 |
+
size_t input_offset,
|
| 266 |
+
const struct gptoss_metal_buffer* expert_buffer,
|
| 267 |
+
size_t expert_offset,
|
| 268 |
+
const struct gptoss_metal_buffer* weight_block_buffer,
|
| 269 |
+
size_t weight_block_offset,
|
| 270 |
+
const struct gptoss_metal_buffer* weight_scale_buffer,
|
| 271 |
+
size_t weight_scale_offset,
|
| 272 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 273 |
+
size_t bias_offset,
|
| 274 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 275 |
+
size_t output_offset,
|
| 276 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 277 |
+
size_t control_offset,
|
| 278 |
+
uint32_t expert_stride,
|
| 279 |
+
uint32_t num_tokens,
|
| 280 |
+
uint32_t num_active_experts,
|
| 281 |
+
uint32_t num_cols,
|
| 282 |
+
uint32_t num_rows);
|
| 283 |
+
|
| 284 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
|
| 285 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 286 |
+
const struct gptoss_metal_function* f32_rope_fn,
|
| 287 |
+
size_t threadgroup_size,
|
| 288 |
+
const struct gptoss_metal_buffer* activations_buffer,
|
| 289 |
+
size_t activations_offset,
|
| 290 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 291 |
+
size_t control_offset,
|
| 292 |
+
float rope_base,
|
| 293 |
+
float interpolation_scale,
|
| 294 |
+
float yarn_offset,
|
| 295 |
+
float yarn_scale,
|
| 296 |
+
float yarn_multiplier,
|
| 297 |
+
uint32_t num_tokens,
|
| 298 |
+
uint32_t num_q_heads,
|
| 299 |
+
uint32_t num_kv_heads,
|
| 300 |
+
uint32_t attn_head_dim,
|
| 301 |
+
uint32_t token_offset);
|
| 302 |
+
|
| 303 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
|
| 304 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 305 |
+
const struct gptoss_metal_function* f32_accumulate_fn,
|
| 306 |
+
size_t threadgroup_size,
|
| 307 |
+
size_t max_threadgroups,
|
| 308 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 309 |
+
size_t input_offset,
|
| 310 |
+
const struct gptoss_metal_buffer* expert_buffer,
|
| 311 |
+
size_t expert_offset,
|
| 312 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 313 |
+
size_t output_offset,
|
| 314 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 315 |
+
size_t control_offset,
|
| 316 |
+
uint32_t num_channels,
|
| 317 |
+
uint32_t num_tokens,
|
| 318 |
+
uint32_t num_experts);
|
| 319 |
+
|
| 320 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
|
| 321 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 322 |
+
const struct gptoss_metal_function* expert_routing_metadata_fn,
|
| 323 |
+
const struct gptoss_metal_buffer* expert_predictions_buffer,
|
| 324 |
+
size_t expert_predictions_offset,
|
| 325 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 326 |
+
size_t expert_offsets_offset,
|
| 327 |
+
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
|
| 328 |
+
size_t intra_expert_offsets_offset,
|
| 329 |
+
uint32_t num_tokens,
|
| 330 |
+
uint32_t num_experts);
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter(
|
| 334 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 335 |
+
const struct gptoss_metal_function* f32_scatter_fn,
|
| 336 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 337 |
+
size_t input_offset,
|
| 338 |
+
const struct gptoss_metal_buffer* expert_predictions_buffer,
|
| 339 |
+
size_t expert_predictions_offset,
|
| 340 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 341 |
+
size_t expert_offsets_offset,
|
| 342 |
+
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
|
| 343 |
+
size_t intra_expert_offsets_offset,
|
| 344 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 345 |
+
size_t output_offset,
|
| 346 |
+
uint32_t num_channels,
|
| 347 |
+
uint32_t num_tokens,
|
| 348 |
+
uint32_t num_active_experts);
|
| 349 |
+
|
| 350 |
+
enum gptoss_status
|
| 351 |
+
gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
|
| 352 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 353 |
+
const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn,
|
| 354 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 355 |
+
size_t expert_offsets_offset,
|
| 356 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 357 |
+
size_t input_offset,
|
| 358 |
+
const struct gptoss_metal_buffer* weight_block_buffer,
|
| 359 |
+
size_t weight_block_offset,
|
| 360 |
+
const struct gptoss_metal_buffer* weight_scale_buffer,
|
| 361 |
+
size_t weight_scale_offset,
|
| 362 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 363 |
+
size_t bias_offset,
|
| 364 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 365 |
+
size_t output_offset,
|
| 366 |
+
float swiglu_limit,
|
| 367 |
+
uint32_t expert_stride_bytes,
|
| 368 |
+
uint32_t num_tokens,
|
| 369 |
+
uint32_t num_experts,
|
| 370 |
+
uint32_t num_cols,
|
| 371 |
+
uint32_t num_rows);
|
| 372 |
+
|
| 373 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
|
| 374 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 375 |
+
const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn,
|
| 376 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 377 |
+
size_t expert_offsets_offset,
|
| 378 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 379 |
+
size_t input_offset,
|
| 380 |
+
const struct gptoss_metal_buffer* weight_block_buffer,
|
| 381 |
+
size_t weight_block_offset,
|
| 382 |
+
const struct gptoss_metal_buffer* weight_scale_buffer,
|
| 383 |
+
size_t weight_scale_offset,
|
| 384 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 385 |
+
size_t bias_offset,
|
| 386 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 387 |
+
size_t output_offset,
|
| 388 |
+
uint32_t expert_stride_bytes,
|
| 389 |
+
uint32_t num_tokens,
|
| 390 |
+
uint32_t num_experts,
|
| 391 |
+
uint32_t num_cols,
|
| 392 |
+
uint32_t num_rows);
|
| 393 |
+
|
| 394 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
|
| 395 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 396 |
+
const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,
|
| 397 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 398 |
+
size_t input_offset,
|
| 399 |
+
const struct gptoss_metal_buffer* expert_predictions_buffer,
|
| 400 |
+
size_t expert_predictions_offset,
|
| 401 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 402 |
+
size_t expert_offsets_offset,
|
| 403 |
+
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
|
| 404 |
+
size_t intra_expert_offsets_offset,
|
| 405 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 406 |
+
size_t output_offset,
|
| 407 |
+
uint32_t num_channels,
|
| 408 |
+
uint32_t num_tokens,
|
| 409 |
+
uint32_t num_active_experts);
|
| 410 |
+
|
| 411 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
|
| 412 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 413 |
+
const struct gptoss_metal_function* f32_topk_fn,
|
| 414 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 415 |
+
size_t input_offset,
|
| 416 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 417 |
+
size_t output_offset,
|
| 418 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 419 |
+
size_t control_offset,
|
| 420 |
+
uint32_t num_tokens,
|
| 421 |
+
uint32_t num_experts,
|
| 422 |
+
uint32_t num_active_experts);
|
| 423 |
+
|
| 424 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
|
| 425 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 426 |
+
const struct gptoss_metal_function* f32_sdpa_fn,
|
| 427 |
+
const struct gptoss_metal_buffer* q_buffer,
|
| 428 |
+
size_t q_offset,
|
| 429 |
+
const struct gptoss_metal_buffer* kv_buffer,
|
| 430 |
+
size_t kv_offset,
|
| 431 |
+
const struct gptoss_metal_buffer* s_buffer,
|
| 432 |
+
size_t s_offset,
|
| 433 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 434 |
+
size_t output_offset,
|
| 435 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 436 |
+
size_t control_offset,
|
| 437 |
+
uint32_t window,
|
| 438 |
+
uint32_t kv_stride,
|
| 439 |
+
uint32_t num_q_tokens,
|
| 440 |
+
uint32_t num_kv_tokens,
|
| 441 |
+
uint32_t num_q_heads,
|
| 442 |
+
uint32_t num_kv_heads,
|
| 443 |
+
uint32_t head_dim);
|
| 444 |
+
|
| 445 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
|
| 446 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 447 |
+
const struct gptoss_metal_function* f32_softmax_fn,
|
| 448 |
+
size_t threadgroup_size,
|
| 449 |
+
size_t max_threadgroups,
|
| 450 |
+
const struct gptoss_metal_buffer* score_buffer,
|
| 451 |
+
size_t score_offset,
|
| 452 |
+
const struct gptoss_metal_buffer* argmax_buffer,
|
| 453 |
+
size_t argmax_offset,
|
| 454 |
+
const struct gptoss_metal_buffer* prob_buffer,
|
| 455 |
+
size_t prob_offset,
|
| 456 |
+
const struct gptoss_metal_buffer* sum_buffer,
|
| 457 |
+
size_t sum_offset,
|
| 458 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 459 |
+
size_t control_offset,
|
| 460 |
+
uint32_t num_channels,
|
| 461 |
+
uint32_t num_tokens,
|
| 462 |
+
float temperature,
|
| 463 |
+
uint32_t* num_threadgroups_out,
|
| 464 |
+
uint32_t* num_channels_per_threadgroup_out);
|
| 465 |
+
|
| 466 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
|
| 467 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 468 |
+
const struct gptoss_metal_function* f32_sample_fn,
|
| 469 |
+
size_t min_threadgroup_size,
|
| 470 |
+
const struct gptoss_metal_buffer* prob_buffer,
|
| 471 |
+
size_t prob_offset,
|
| 472 |
+
const struct gptoss_metal_buffer* sum_buffer,
|
| 473 |
+
size_t sum_offset,
|
| 474 |
+
const struct gptoss_metal_buffer* token_buffer,
|
| 475 |
+
size_t token_offset,
|
| 476 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 477 |
+
size_t control_offset,
|
| 478 |
+
uint64_t rng_seed,
|
| 479 |
+
uint32_t rng_offset,
|
| 480 |
+
uint32_t num_blocks,
|
| 481 |
+
uint32_t num_channels,
|
| 482 |
+
uint32_t num_channels_per_block);
|
| 483 |
+
|
| 484 |
+
#ifdef __cplusplus
|
| 485 |
+
} // extern "C"
|
| 486 |
+
#endif
|
gptoss_kernels/source/include/internal/metal.h
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <stddef.h>
|
| 4 |
+
|
| 5 |
+
#include <gpt-oss/types.h>
|
| 6 |
+
|
| 7 |
+
#ifdef __cplusplus
|
| 8 |
+
extern "C" {
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
struct gptoss_metal_device {
|
| 12 |
+
void* object; // id<MTLDevice>
|
| 13 |
+
size_t num_cores;
|
| 14 |
+
size_t max_buffer_size;
|
| 15 |
+
size_t max_threadgroup_memory;
|
| 16 |
+
size_t max_threadgroup_threads_x;
|
| 17 |
+
size_t max_threadgroup_threads_y;
|
| 18 |
+
size_t max_threadgroup_threads_z;
|
| 19 |
+
};
|
| 20 |
+
|
| 21 |
+
enum gptoss_status gptoss_metal_device_create_system_default(
|
| 22 |
+
struct gptoss_metal_device* device_out);
|
| 23 |
+
|
| 24 |
+
enum gptoss_status gptoss_metal_device_release(
|
| 25 |
+
struct gptoss_metal_device* device);
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
struct gptoss_metal_library {
|
| 29 |
+
void* object; // id<MTLLibrary>
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
enum gptoss_status gptoss_metal_library_create_default(
|
| 33 |
+
const struct gptoss_metal_device* device,
|
| 34 |
+
struct gptoss_metal_library* library_out);
|
| 35 |
+
|
| 36 |
+
enum gptoss_status gptoss_metal_library_release(
|
| 37 |
+
struct gptoss_metal_library* library);
|
| 38 |
+
|
| 39 |
+
struct gptoss_metal_function {
|
| 40 |
+
void* function_object; // id<MTLFunction>
|
| 41 |
+
void* pipeline_state_object; // id<MTLComputePipelineState>
|
| 42 |
+
size_t max_threadgroup_threads;
|
| 43 |
+
size_t simdgroup_threads;
|
| 44 |
+
size_t static_threadgroup_memory;
|
| 45 |
+
};
|
| 46 |
+
|
| 47 |
+
enum gptoss_status gptoss_metal_function_create(
|
| 48 |
+
const struct gptoss_metal_library* library,
|
| 49 |
+
const char* name,
|
| 50 |
+
struct gptoss_metal_function* function_out);
|
| 51 |
+
|
| 52 |
+
enum gptoss_status gptoss_metal_function_release(
|
| 53 |
+
struct gptoss_metal_function* function);
|
| 54 |
+
|
| 55 |
+
struct gptoss_metal_buffer {
|
| 56 |
+
void* object; // id<MTLBuffer>
|
| 57 |
+
size_t size;
|
| 58 |
+
void* ptr;
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
// Allocates a new Metal buffer of `size` bytes on `device` and returns it in
// `*buffer_out`. `data` may be NULL; when non-NULL it presumably provides the
// initial contents -- confirm against the implementation in metal.m.
enum gptoss_status gptoss_metal_buffer_create(
    const struct gptoss_metal_device* device,
    size_t size,
    const void* data,
    struct gptoss_metal_buffer* buffer_out);

// Creates a Metal buffer backed by caller-provided memory (`data`, `size`).
// NOTE(review): no-copy "wrap" semantics (vs. the copying "create" above) are
// inferred from the names only -- verify in metal.m.
enum gptoss_status gptoss_metal_buffer_wrap(
    const struct gptoss_metal_device* device,
    size_t size,
    const void* data,
    struct gptoss_metal_buffer* buffer_out);

// Releases the Metal buffer object held by `buffer`.
enum gptoss_status gptoss_metal_buffer_release(
    struct gptoss_metal_buffer* buffer);

// Wrapper around an Objective-C Metal command-queue object.
struct gptoss_metal_command_queue {
    void* object; // id<MTLCommandQueue>
};

// Creates a command queue on `device` and returns it in `*command_queue_out`.
enum gptoss_status gptoss_metal_command_queue_create(
    const struct gptoss_metal_device* device,
    struct gptoss_metal_command_queue* command_queue_out);

// Releases the command-queue object held by `command_queue`.
enum gptoss_status gptoss_metal_command_queue_release(
    struct gptoss_metal_command_queue* command_queue);

// Wrapper around an Objective-C Metal command-buffer object.
struct gptoss_metal_command_buffer {
    void* object; // id<MTLCommandBuffer>
};

// Creates a command buffer on `command_queue`.
enum gptoss_status gptoss_metal_command_buffer_create(
    const struct gptoss_metal_command_queue* command_queue,
    struct gptoss_metal_command_buffer* command_buffer_out);

// Encodes a fill command: sets `size` bytes of `buffer`, starting at
// `offset`, to `fill_value`.
enum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_buffer* buffer,
    size_t offset,
    size_t size,
    uint8_t fill_value);

// Encodes a copy of `size` bytes from `input_buffer` (starting at
// `input_offset`) into `output_buffer` (starting at `output_offset`).
enum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_buffer* input_buffer,
    size_t input_offset,
    const struct gptoss_metal_buffer* output_buffer,
    size_t output_offset,
    size_t size);

// Encodes a compute-kernel dispatch:
//   - threadgroup_size_{x,y,z}: threads per threadgroup,
//   - num_threadgroups_{x,y,z}: grid size in threadgroups,
//   - params/params_size: raw kernel-argument bytes,
//   - device_buffers/device_buffer_offsets: `num_device_buffers` buffer
//     bindings; `device_buffer_offsets` may be NULL (the C++ wrapper passes
//     NULL -- presumably meaning all-zero offsets; confirm in metal.m),
//   - threadgroup_buffer_size: size of a threadgroup-local memory binding.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_function* function,
    size_t threadgroup_size_x,
    size_t threadgroup_size_y,
    size_t threadgroup_size_z,
    size_t num_threadgroups_x,
    size_t num_threadgroups_y,
    size_t num_threadgroups_z,
    size_t params_size,
    const void* params,
    size_t num_device_buffers,
    const struct gptoss_metal_buffer** device_buffers,
    const size_t* device_buffer_offsets,
    size_t threadgroup_buffer_size);

// Submits the command buffer for execution.
enum gptoss_status gptoss_metal_command_buffer_commit(
    const struct gptoss_metal_command_buffer* command_buffer);

// Blocks until the command buffer finishes and writes the elapsed execution
// time in seconds to `*elapsed_seconds` (as used by metal.hpp's
// wait_completion).
enum gptoss_status gptoss_metal_command_buffer_wait_completion(
    const struct gptoss_metal_command_buffer* command_buffer,
    double* elapsed_seconds);

// Releases the command-buffer object held by `command_buffer`.
enum gptoss_status gptoss_metal_command_buffer_release(
    struct gptoss_metal_command_buffer* command_buffer);

#ifdef __cplusplus
} // extern "C"
#endif
|
gptoss_kernels/source/include/internal/metal.hpp
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <stdexcept>
#include <string>
#include <vector>

#include <gpt-oss/types.h>
#include <internal/metal.h>
#include <internal/metal-kernels.h>
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
namespace gptoss {
|
| 15 |
+
|
| 16 |
+
// Throws std::runtime_error if a gptoss C-API call failed.
// `what` names the failing operation. The numeric status code is appended to
// the message so callers can tell *why* the call failed, not only which call
// failed (the original message dropped the status entirely).
inline void Check(gptoss_status s, const char* what) {
    if (s != gptoss_status_success) {
        throw std::runtime_error(
            std::string(what) + " failed with status " + std::to_string(static_cast<int>(s)));
    }
}
|
| 21 |
+
|
| 22 |
+
// Rounds `p` up to the next multiple of `q` (q must be non-zero).
// Returns `p` unchanged when it is already a multiple of `q`.
inline std::size_t round_up(std::size_t p, std::size_t q) {
    const std::size_t remainder = p % q;
    return remainder == 0 ? p : p + (q - remainder);
}
|
| 30 |
+
|
| 31 |
+
namespace metal {
|
| 32 |
+
|
| 33 |
+
class Device {
|
| 34 |
+
public:
|
| 35 |
+
inline Device() {
|
| 36 |
+
Check(gptoss_metal_device_create_system_default(&device_), "create Device");
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
inline ~Device() {
|
| 40 |
+
gptoss_metal_device_release(&device_);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
Device(const Device&) = delete;
|
| 44 |
+
Device& operator=(const Device&) = delete;
|
| 45 |
+
|
| 46 |
+
inline Device(Device&& other) noexcept {
|
| 47 |
+
device_ = other.device_;
|
| 48 |
+
std::memset(&other.device_, 0, sizeof(other.device_));
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
inline Device& operator=(Device&& other) noexcept {
|
| 52 |
+
if (this != &other) {
|
| 53 |
+
gptoss_metal_device_release(&device_);
|
| 54 |
+
device_ = other.device_;
|
| 55 |
+
std::memset(&other.device_, 0, sizeof(other.device_));
|
| 56 |
+
}
|
| 57 |
+
return *this;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
inline const gptoss_metal_device* handle() const noexcept { return &device_; }
|
| 61 |
+
|
| 62 |
+
inline size_t max_buffer_size() const noexcept { return device_.max_buffer_size; }
|
| 63 |
+
inline size_t max_threadgroup_memory() const noexcept { return device_.max_threadgroup_memory; }
|
| 64 |
+
inline size_t max_threadgroup_threads_x() const noexcept { return device_.max_threadgroup_threads_x; }
|
| 65 |
+
inline size_t max_threadgroup_threads_y() const noexcept { return device_.max_threadgroup_threads_y; }
|
| 66 |
+
inline size_t max_threadgroup_threads_z() const noexcept { return device_.max_threadgroup_threads_z; }
|
| 67 |
+
|
| 68 |
+
private:
|
| 69 |
+
gptoss_metal_device device_{};
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
class Library {
|
| 73 |
+
public:
|
| 74 |
+
inline explicit Library(const Device& dev) {
|
| 75 |
+
Check(gptoss_metal_library_create_default(dev.handle(), &library_),
|
| 76 |
+
"gptoss_metal_library_create_default");
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
inline ~Library() {
|
| 80 |
+
gptoss_metal_library_release(&library_);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
Library(const Library&) = delete;
|
| 84 |
+
Library& operator=(const Library&) = delete;
|
| 85 |
+
|
| 86 |
+
inline Library(Library&& other) noexcept {
|
| 87 |
+
library_ = other.library_;
|
| 88 |
+
std::memset(&other.library_, 0, sizeof(other.library_));
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
inline Library& operator=(Library&& other) noexcept {
|
| 92 |
+
if (this != &other) {
|
| 93 |
+
gptoss_metal_library_release(&library_);
|
| 94 |
+
library_ = other.library_;
|
| 95 |
+
std::memset(&other.library_, 0, sizeof(other.library_));
|
| 96 |
+
}
|
| 97 |
+
return *this;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
inline const gptoss_metal_library* handle() const noexcept {
|
| 101 |
+
return &library_;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
private:
|
| 105 |
+
gptoss_metal_library library_{};
|
| 106 |
+
};
|
| 107 |
+
|
| 108 |
+
class Function {
|
| 109 |
+
public:
|
| 110 |
+
inline Function(const Library& library, const char* name) {
|
| 111 |
+
Check(gptoss_metal_function_create(library.handle(), name, &function_),
|
| 112 |
+
"gptoss_metal_function_create");
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
inline ~Function() {
|
| 116 |
+
gptoss_metal_function_release(&function_);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
Function(const Function&) = delete;
|
| 120 |
+
Function& operator=(const Function&) = delete;
|
| 121 |
+
|
| 122 |
+
inline Function(Function&& other) noexcept {
|
| 123 |
+
function_ = other.function_;
|
| 124 |
+
std::memset(&other.function_, 0, sizeof(other.function_));
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
inline Function& operator=(Function&& other) noexcept {
|
| 128 |
+
if (this != &other) {
|
| 129 |
+
gptoss_metal_function_release(&function_);
|
| 130 |
+
function_ = other.function_;
|
| 131 |
+
std::memset(&other.function_, 0, sizeof(other.function_));
|
| 132 |
+
}
|
| 133 |
+
return *this;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
inline const gptoss_metal_function* handle() const noexcept { return &function_; }
|
| 137 |
+
|
| 138 |
+
inline size_t max_threadgroup_threads() const noexcept { return function_.max_threadgroup_threads; }
|
| 139 |
+
inline size_t simdgroup_threads() const noexcept { return function_.simdgroup_threads; }
|
| 140 |
+
inline size_t static_threadgroup_memory() const noexcept { return function_.static_threadgroup_memory; }
|
| 141 |
+
|
| 142 |
+
private:
|
| 143 |
+
gptoss_metal_function function_{};
|
| 144 |
+
};
|
| 145 |
+
|
| 146 |
+
class Buffer {
|
| 147 |
+
public:
|
| 148 |
+
inline Buffer(const Device& dev, size_t size, const void* data = nullptr) {
|
| 149 |
+
Check(gptoss_metal_buffer_create(dev.handle(), size, data, &buffer_), "create buffer");
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
inline ~Buffer() {
|
| 153 |
+
gptoss_metal_buffer_release(&buffer_);
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
Buffer(const Buffer&) = delete;
|
| 157 |
+
Buffer& operator=(const Buffer&) = delete;
|
| 158 |
+
|
| 159 |
+
inline Buffer(Buffer&& other) noexcept {
|
| 160 |
+
buffer_ = other.buffer_;
|
| 161 |
+
std::memset(&other.buffer_, 0, sizeof(other.buffer_));
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
inline Buffer& operator=(Buffer&& other) noexcept {
|
| 165 |
+
if (this != &other) {
|
| 166 |
+
gptoss_metal_buffer_release(&buffer_);
|
| 167 |
+
buffer_ = other.buffer_;
|
| 168 |
+
std::memset(&other.buffer_, 0, sizeof(other.buffer_));
|
| 169 |
+
}
|
| 170 |
+
return *this;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
inline size_t size() const noexcept { return buffer_.size; }
|
| 174 |
+
inline void* ptr() const noexcept { return buffer_.ptr; }
|
| 175 |
+
|
| 176 |
+
inline const gptoss_metal_buffer* handle() const noexcept { return &buffer_; }
|
| 177 |
+
|
| 178 |
+
private:
|
| 179 |
+
gptoss_metal_buffer buffer_{};
|
| 180 |
+
};
|
| 181 |
+
|
| 182 |
+
class CommandQueue {
|
| 183 |
+
public:
|
| 184 |
+
inline explicit CommandQueue(const Device& dev) {
|
| 185 |
+
Check(gptoss_metal_command_queue_create(dev.handle(), &command_queue_),
|
| 186 |
+
"gptoss_metal_command_queue_create");
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
inline ~CommandQueue() {
|
| 190 |
+
gptoss_metal_command_queue_release(&command_queue_);
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
CommandQueue(const CommandQueue&) = delete;
|
| 194 |
+
CommandQueue& operator=(const CommandQueue&) = delete;
|
| 195 |
+
|
| 196 |
+
inline CommandQueue(CommandQueue&& other) noexcept {
|
| 197 |
+
command_queue_ = other.command_queue_;
|
| 198 |
+
std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
inline CommandQueue& operator=(CommandQueue&& other) noexcept {
|
| 202 |
+
if (this != &other) {
|
| 203 |
+
gptoss_metal_command_queue_release(&command_queue_);
|
| 204 |
+
command_queue_ = other.command_queue_;
|
| 205 |
+
std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));
|
| 206 |
+
}
|
| 207 |
+
return *this;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
inline const gptoss_metal_command_queue* handle() const noexcept {
|
| 211 |
+
return &command_queue_;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
private:
|
| 215 |
+
gptoss_metal_command_queue command_queue_{};
|
| 216 |
+
};
|
| 217 |
+
|
| 218 |
+
class CommandBuffer {
|
| 219 |
+
public:
|
| 220 |
+
inline explicit CommandBuffer(const CommandQueue& command_queue) {
|
| 221 |
+
Check(gptoss_metal_command_buffer_create(command_queue.handle(), &command_buffer_),
|
| 222 |
+
"gptoss_metal_command_buffer_create");
|
| 223 |
+
}
|
| 224 |
+
inline ~CommandBuffer() {
|
| 225 |
+
gptoss_metal_command_buffer_release(&command_buffer_);
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
CommandBuffer(const CommandBuffer&) = delete;
|
| 229 |
+
CommandBuffer& operator=(const CommandBuffer&) = delete;
|
| 230 |
+
|
| 231 |
+
inline CommandBuffer(CommandBuffer&& other) noexcept {
|
| 232 |
+
command_buffer_ = other.command_buffer_;
|
| 233 |
+
std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
inline CommandBuffer& operator=(CommandBuffer&& other) noexcept {
|
| 237 |
+
if (this != &other) {
|
| 238 |
+
gptoss_metal_command_buffer_release(&command_buffer_);
|
| 239 |
+
command_buffer_ = other.command_buffer_;
|
| 240 |
+
std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));
|
| 241 |
+
}
|
| 242 |
+
return *this;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
inline void encode_launch_kernel(const Function& function,
|
| 246 |
+
const std::array<size_t, 3>& threadgroup_size,
|
| 247 |
+
const std::array<size_t, 3>& num_threadgroups,
|
| 248 |
+
size_t params_size, const void* params,
|
| 249 |
+
std::initializer_list<const Buffer*> device_buffers = {},
|
| 250 |
+
size_t threadgroup_buffer_size = 0)
|
| 251 |
+
{
|
| 252 |
+
std::vector<const gptoss_metal_buffer*> buffer_handles(device_buffers.size());
|
| 253 |
+
std::transform(device_buffers.begin(), device_buffers.end(), buffer_handles.begin(),
|
| 254 |
+
[](const Buffer* buffer) -> const gptoss_metal_buffer* { return buffer->handle(); });
|
| 255 |
+
Check(gptoss_metal_command_buffer_encode_launch_kernel(
|
| 256 |
+
&command_buffer_, function.handle(),
|
| 257 |
+
threadgroup_size[0], threadgroup_size[1], threadgroup_size[2],
|
| 258 |
+
num_threadgroups[0], num_threadgroups[1], num_threadgroups[2],
|
| 259 |
+
params_size, params,
|
| 260 |
+
buffer_handles.size(),
|
| 261 |
+
buffer_handles.data(),
|
| 262 |
+
/*buffer_offsets=*/nullptr,
|
| 263 |
+
threadgroup_buffer_size),
|
| 264 |
+
"gptoss_metal_command_buffer_encode_launch_kernel");
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
inline void encode_launch_f32_fill_random(const Function& f32_fill_random_fn,
|
| 268 |
+
size_t threadgroup_size,
|
| 269 |
+
size_t num_threadgroups,
|
| 270 |
+
const Buffer& output_buffer,
|
| 271 |
+
size_t output_offset,
|
| 272 |
+
size_t num_channels,
|
| 273 |
+
uint64_t rng_seed,
|
| 274 |
+
uint64_t rng_offset,
|
| 275 |
+
float rng_min,
|
| 276 |
+
float rng_max)
|
| 277 |
+
{
|
| 278 |
+
Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
| 279 |
+
&command_buffer_, f32_fill_random_fn.handle(),
|
| 280 |
+
threadgroup_size, num_threadgroups,
|
| 281 |
+
output_buffer.handle(), output_offset,
|
| 282 |
+
num_channels,
|
| 283 |
+
rng_seed, rng_offset, rng_min, rng_max),
|
| 284 |
+
"gptoss_metal_command_buffer_encode_launch_f32_fill_random");
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
inline void encode_launch_bf16_fill_random(const Function& bf16_fill_random_fn,
|
| 288 |
+
size_t threadgroup_size,
|
| 289 |
+
size_t num_threadgroups,
|
| 290 |
+
const Buffer& output_buffer,
|
| 291 |
+
size_t output_offset,
|
| 292 |
+
size_t num_channels,
|
| 293 |
+
uint64_t rng_seed,
|
| 294 |
+
uint64_t rng_offset,
|
| 295 |
+
float rng_min,
|
| 296 |
+
float rng_max)
|
| 297 |
+
{
|
| 298 |
+
Check(gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
|
| 299 |
+
&command_buffer_, bf16_fill_random_fn.handle(),
|
| 300 |
+
threadgroup_size, num_threadgroups,
|
| 301 |
+
output_buffer.handle(), output_offset,
|
| 302 |
+
num_channels,
|
| 303 |
+
rng_seed, rng_offset, rng_min, rng_max),
|
| 304 |
+
"gptoss_metal_command_buffer_encode_launch_bf16_fill_random");
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
inline void encode_launch_u32_fill_random(const Function& u32_fill_random_fn,
|
| 308 |
+
size_t threadgroup_size,
|
| 309 |
+
size_t num_threadgroups,
|
| 310 |
+
const Buffer& output_buffer,
|
| 311 |
+
size_t output_offset,
|
| 312 |
+
size_t num_channels,
|
| 313 |
+
uint64_t rng_seed,
|
| 314 |
+
uint64_t rng_offset)
|
| 315 |
+
{
|
| 316 |
+
Check(gptoss_metal_command_buffer_encode_launch_u32_fill_random(
|
| 317 |
+
&command_buffer_, u32_fill_random_fn.handle(),
|
| 318 |
+
threadgroup_size, num_threadgroups,
|
| 319 |
+
output_buffer.handle(), output_offset,
|
| 320 |
+
num_channels,
|
| 321 |
+
rng_seed, rng_offset),
|
| 322 |
+
"gptoss_metal_command_buffer_encode_launch_u32_fill_random");
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
inline void commit() {
|
| 326 |
+
Check(gptoss_metal_command_buffer_commit(&command_buffer_), "commit");
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
inline double wait_completion() {
|
| 330 |
+
double secs = 0.0;
|
| 331 |
+
Check(gptoss_metal_command_buffer_wait_completion(&command_buffer_, &secs), "wait completion");
|
| 332 |
+
return secs;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
inline const gptoss_metal_command_buffer* handle() const noexcept { return &command_buffer_; }
|
| 336 |
+
|
| 337 |
+
private:
|
| 338 |
+
gptoss_metal_command_buffer command_buffer_{};
|
| 339 |
+
};
|
| 340 |
+
|
| 341 |
+
} // namespace metal
|
| 342 |
+
} // namespace gptoss
|
gptoss_kernels/source/include/internal/model.h
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#ifndef __cplusplus
|
| 4 |
+
#include <stdatomic.h>
|
| 5 |
+
#endif
|
| 6 |
+
#include <stdbool.h>
|
| 7 |
+
#include <stddef.h>
|
| 8 |
+
#include <stdint.h>
|
| 9 |
+
|
| 10 |
+
#include "internal/metal.h"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
// Reference-counted tokenizer state backed by a memory-mapped tokenizer
// section of the model file.
struct gptoss_tokenizer {
#ifndef __cplusplus
    atomic_uint_least64_t ref_count;
#else
    uint_least64_t ref_count;
#endif

    // Memory mapping that owns the data pointed to by regex_ptr/tokens_ptr.
    void* mapping_ptr;
    size_t mapping_size;

    // Pointers into the mapping: the tokenization regex and the token table
    // (layout defined by the tiktoken tokenizer section -- see storage.h).
    const char* regex_ptr;
    const char* tokens_ptr;

    uint32_t num_text_tokens;
    uint32_t num_special_tokens;

    // Token IDs of the recognized special tokens, indexed by
    // enum gptoss_special_token (offset-by-one indexing presumed from the
    // "- 1" array bound -- confirm against model.c).
    uint32_t special_token_id[gptoss_special_token_max - 1];
};
|
| 31 |
+
|
| 32 |
+
// Reference-counted, immutable model object: hyper-parameters, Metal device
// objects and pre-created kernel functions, and the weight buffers. Shared by
// any number of gptoss_context instances.
struct gptoss_model {
#ifndef __cplusplus
    atomic_uint_least64_t ref_count;
#else
    uint_least64_t ref_count;
#endif

    struct gptoss_tokenizer* tokenizer;

    // Memory-mapped model file.
    void* mapping_ptr;
    size_t mapping_size;

    // Hyper-parameters; these mirror the fields of
    // struct gptoss_gptoss_model_header in storage.h.
    uint32_t context_length;
    uint32_t num_blocks;
    uint32_t num_experts;
    uint32_t num_active_experts;
    uint32_t embedding_dim;
    uint32_t mlp_dim;
    float swiglu_limit;
    uint32_t head_dim;
    uint32_t num_heads;
    uint32_t num_kv_heads;
    uint32_t attention_window;
    float rope_theta;
    float interpolation_scale;
    float yarn_offset;
    float yarn_scale;
    float yarn_multiplier;
    float rmsnorm_epsilon;

    uint32_t vocabulary_size;

    // Presumably controls wiring/locking weight memory in RAM -- confirm in
    // model.c.
    bool lock_memory;

    size_t weights_size;
    size_t allocation_size;

    // Metal objects
    struct gptoss_metal_device device;
    size_t max_threadgroups;
    struct gptoss_metal_command_queue command_queue;
    struct gptoss_metal_library library;
    // Pre-created kernel functions, one per Metal shader entry point.
    struct gptoss_metal_function bf16_f32_embeddings_fn;
    struct gptoss_metal_function f32_bf16w_rmsnorm_fn;
    struct gptoss_metal_function f32_bf16w_matmul_fn;
    struct gptoss_metal_function f32_bf16w_matmul_qkv_fn;
    struct gptoss_metal_function f32_bf16w_dense_matmul_qkv_fn;
    struct gptoss_metal_function f32_bf16w_dense_matmul_attn_output_fn;
    struct gptoss_metal_function f32_bf16w_dense_matmul_mlp_gate_fn;
    struct gptoss_metal_function f32_bf16w_unembedding_fn;
    struct gptoss_metal_function f32_rope_fn;
    struct gptoss_metal_function f32_mf4w_moe_matmul_swiglu_fn;
    struct gptoss_metal_function f32_mf4w_moe_matmul_fn;
    struct gptoss_metal_function f32_accumulate_e4_fn;
    struct gptoss_metal_function f32_scatter_e4_fn;
    struct gptoss_metal_function f32_mf4w_moe_dense_matmul_swiglu_fn;
    struct gptoss_metal_function f32_mf4w_moe_dense_matmul_fn;
    struct gptoss_metal_function f32_gather_and_accumulate_e4_fn;
    struct gptoss_metal_function f32_expert_routing_metadata_fn;
    struct gptoss_metal_function f32_topk_softmax_e32_k4_fn;
    struct gptoss_metal_function f32_topk_softmax_e128_k4_fn;
    struct gptoss_metal_function f32_sdpa_q8_d64_fn;
    struct gptoss_metal_function f32_softmax_fn;
    struct gptoss_metal_function f32_sample_fn;

    // Per-block weight sizes (shared vs. per-expert partitioning).
    size_t per_block_shared_weights_size;
    size_t per_expert_block_weight_size;

    // Threadgroup sizes chosen for each dispatch site.
    size_t embeddings_threadgroup_size;
    size_t attn_qkv_threadgroup_size;
    size_t attn_out_threadgroup_size;
    size_t mlp_gate_threadgroup_size;
    size_t mlp_swiglu_threadgroup_size;
    size_t mlp_out_threadgroup_size;
    size_t mlp_acc_threadgroup_size;
    size_t unembedding_threadgroup_size;

    // Byte offsets of individual tensors inside the weight buffers.
    size_t attn_rmsnorm_gain_offset;
    size_t attn_qkv_weight_offset;
    size_t attn_qkv_bias_offset;
    size_t attn_sdpa_sink_offset;
    size_t attn_out_weight_offset;
    size_t attn_out_bias_offset;
    size_t mlp_rmsnorm_gain_offset;
    size_t mlp_gate_weight_offset;
    size_t mlp_gate_bias_offset;
    size_t mlp_swiglu_scale_offset;
    size_t mlp_swiglu_bias_offset;
    size_t mlp_out_block_offset;
    size_t mlp_out_scale_offset;
    size_t mlp_out_bias_offset;
    size_t rmsnorm_weight_offset;
    size_t unembedding_weight_offset;

    // Buffer with non-MoE weights. Includes MoE gates, embeddings/unembeddings.
    struct gptoss_metal_buffer shared_weight_buffer;
    // num_blocks per-block buffers with MoE weights to follow.
    struct gptoss_metal_buffer block_weight_buffers[];
};
|
| 131 |
+
|
| 132 |
+
#define GPTOSS_DEFAULT_BATCH_SIZE 128
|
| 133 |
+
|
| 134 |
+
// Reference-counted inference context: token history, KV cache, and all
// per-context GPU activation buffers for one conversation/stream over a
// shared gptoss_model.
struct gptoss_context {
#ifndef __cplusplus
    atomic_uint_least64_t ref_count;
#else
    uint_least64_t ref_count;
#endif

    struct gptoss_model* model;
    // Number of tokens processed in the context.
    size_t num_tokens;
    // Number of tokens in the KV cache.
    size_t num_kv_tokens;
    // Length of the context.
    size_t max_tokens;
    // Maximum number of tokens that can be processed in a single batch.
    // Activation buffers are allocated with this size.
    size_t max_batch_tokens;

    size_t kvcache_size;
    size_t allocation_size;

    // Activation buffers.
    // TODO: merge into a single buffer.
    struct gptoss_metal_buffer residual_activation_buffer; // Residual stream
    struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output
    struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output
    struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output
    struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output
    struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions
    struct gptoss_metal_buffer expert_offset_buffer; // MoE expert histograms cumsum
    struct gptoss_metal_buffer token_to_expert_routing_buffer; // MoE token to expert routing
    struct gptoss_metal_buffer swiglu_input_buffer; // MLP+SwiGLU input for prefill.
    struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output
    struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert)

    // Input/output buffers.
    struct gptoss_metal_buffer control_buffer;
    struct gptoss_metal_buffer token_buffer; // uint32 token IDs
    struct gptoss_metal_buffer score_buffer; // unembedding outputs
    struct gptoss_metal_buffer prob_buffer;
    struct gptoss_metal_buffer sum_buffer;
    struct gptoss_metal_buffer argmax_buffer;
    struct gptoss_metal_buffer kvcache_buffer;
};
|
gptoss_kernels/source/include/internal/rng.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <stdint.h>
|
| 4 |
+
|
| 5 |
+
/*
 * Squares counter-based RNG: four rounds of square-and-add with a 32-bit
 * rotation between rounds. `offset` is the counter, `seed` is the key; the
 * result is the high 32 bits of the final state.
 */
inline static uint32_t rng_squares32(uint64_t offset, uint64_t seed) {
    const uint64_t key_product = offset * seed;           /* y in the reference code */
    const uint64_t shifted_key = key_product + seed;      /* z in the reference code */
    /* Per-round addends alternate between the two derived keys. */
    const uint64_t addend[4] = { key_product, shifted_key, key_product, shifted_key };

    uint64_t state = key_product;
    for (int round = 0; round < 4; round++) {
        state = state * state + addend[round];
        if (round != 3) {
            /* Swap the 32-bit halves (rotate by 32). */
            state = (state >> 32) | (state << 32);
        }
    }
    return (uint32_t) (state >> 32);
}
|
gptoss_kernels/source/include/internal/rng.hpp
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <cstdint>
|
| 4 |
+
|
| 5 |
+
namespace gptoss {

namespace rng {

// Squares counter-based RNG: four square-and-add rounds keyed by
// (offset, seed), with a 32-bit half-swap between rounds; returns the high
// 32 bits of the final state.
inline static std::uint32_t squares32(std::uint64_t offset, std::uint64_t seed) {
    const std::uint64_t key_product = offset * seed;       // y in the reference code
    const std::uint64_t shifted_key = key_product + seed;  // z in the reference code

    const auto rotate_halves = [](std::uint64_t v) -> std::uint64_t {
        return (v >> 32) | (v << 32);
    };

    std::uint64_t state = key_product * key_product + key_product;  // round 1
    state = rotate_halves(state);
    state = rotate_halves(state * state + shifted_key);             // round 2
    state = rotate_halves(state * state + key_product);             // round 3
    state = state * state + shifted_key;                            // round 4
    return static_cast<uint32_t>(state >> 32);
}

} // namespace rng

} // namespace gptoss
|
gptoss_kernels/source/include/internal/storage.h
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <stdbool.h>
|
| 4 |
+
#include <stdint.h>
|
| 5 |
+
|
| 6 |
+
// Fixed 16-byte header at the start of a gpt-oss model file.
struct gptoss_file_header {
    char magic[12];  // file-format magic; exact expected value checked by the loader -- not visible here
    uint32_t zero;   // presumably reserved/must-be-zero (the name suggests it) -- confirm in model.c
};
| 10 |
+
|
| 11 |
+
// On-disk model hyper-parameter header. Field-for-field source of the
// same-named members of struct gptoss_model (see internal/model.h).
struct gptoss_gptoss_model_header {
    uint32_t context_length;
    uint32_t num_blocks;
    uint32_t num_experts;
    uint32_t num_active_experts;
    uint32_t embedding_dim;
    uint32_t mlp_dim;
    float swiglu_limit;
    uint32_t head_dim;
    uint32_t num_heads;
    uint32_t num_kv_heads;
    uint32_t attention_window;
    float rope_theta;
    float interpolation_scale;
    float yarn_offset;
    float yarn_scale;
    float yarn_multiplier;
    float rmsnorm_epsilon;
};
|
| 30 |
+
|
| 31 |
+
// On-disk header of the tiktoken tokenizer section; the variable-length
// regex and token-table payloads (regex_size / tokens_size bytes) follow.
struct gptoss_tiktoken_tokenizer_header {
    uint32_t num_special_tokens;
    uint32_t num_text_tokens;
    uint32_t regex_size;   // byte size of the tokenization regex payload
    uint32_t tokens_size;  // byte size of the token-table payload
};
|
gptoss_kernels/source/include/internal/uuid.h
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <stdbool.h>
|
| 4 |
+
#include <stdint.h>
|
| 5 |
+
#include <string.h>
|
| 6 |
+
|
| 7 |
+
#include "internal/macros.h"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
// 16-byte UUID stored as raw bytes; packed so it can be read directly from
// the memory-mapped file.
struct GPTOSS_DENSELY_PACKED_STRUCTURE gptoss_uuid {
    uint8_t bytes[16];
};
static_assert(sizeof(struct gptoss_uuid) == 16, "UUID size is not 16 bytes");


// printf-style format string and matching argument list for logging a UUID
// in the canonical 8-4-4-4-12 uppercase-hex form.
#define UUID_FORMAT "%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X"
#define UUID_ARGS(uuid) (uuid).bytes[0], (uuid).bytes[1], (uuid).bytes[2], (uuid).bytes[3], \
    (uuid).bytes[4], (uuid).bytes[5], (uuid).bytes[6], (uuid).bytes[7], (uuid).bytes[8], (uuid).bytes[9], \
    (uuid).bytes[10], (uuid).bytes[11], (uuid).bytes[12], (uuid).bytes[13], (uuid).bytes[14], (uuid).bytes[15]
|
| 20 |
+
|
| 21 |
+
static inline bool gptoss_is_gptoss_model_uuid(const struct gptoss_uuid* uuid) {
|
| 22 |
+
return memcmp(
|
| 23 |
+
&(struct gptoss_uuid) {0xDF, 0x52, 0xDC, 0x86, 0x17, 0x89, 0x4E, 0xD0, 0xA2, 0x95, 0x66, 0xF1, 0x05, 0x08, 0x14, 0x5B},
|
| 24 |
+
uuid,
|
| 25 |
+
sizeof(struct gptoss_uuid)) == 0;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
static inline bool gptoss_is_applegpu_layout_uuid(const struct gptoss_uuid* uuid) {
|
| 29 |
+
return memcmp(
|
| 30 |
+
&(struct gptoss_uuid) {0x22, 0x91, 0x77, 0xA8, 0x57, 0x75, 0x42, 0x68, 0xBF, 0xD8, 0xD5, 0x88, 0xB3, 0x51, 0xC5, 0x6D},
|
| 31 |
+
uuid,
|
| 32 |
+
sizeof(struct gptoss_uuid)) == 0;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
static inline bool gptoss_is_tiktoken_tokenizer_uuid(const struct gptoss_uuid* uuid) {
|
| 36 |
+
return memcmp(
|
| 37 |
+
&(struct gptoss_uuid) {0x74, 0x01, 0xAD, 0xED, 0x2A, 0x95, 0x40, 0xCB, 0xB7, 0x82, 0x9C, 0xCE, 0xBA, 0xAF, 0xE7, 0x2B},
|
| 38 |
+
uuid,
|
| 39 |
+
sizeof(struct gptoss_uuid)) == 0;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
static inline enum gptoss_special_token gptoss_special_token_decode_uuid(const struct gptoss_uuid* uuid) {
|
| 43 |
+
if (memcmp(
|
| 44 |
+
&(struct gptoss_uuid) {0x55, 0xA7, 0x7C, 0x2F, 0x8A, 0x01, 0x4C, 0x54, 0x8A, 0xC2, 0x31, 0x3B, 0xFC, 0x7E, 0x20, 0x8D},
|
| 45 |
+
uuid,
|
| 46 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 47 |
+
{
|
| 48 |
+
return gptoss_special_token_start;
|
| 49 |
+
} else if (memcmp(
|
| 50 |
+
&(struct gptoss_uuid) {0x16, 0xE4, 0x04, 0x31, 0xF4, 0x7F, 0x4B, 0x22, 0xB5, 0x9B, 0x8B, 0x27, 0x8F, 0xC3, 0x0A, 0x54},
|
| 51 |
+
uuid,
|
| 52 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 53 |
+
{
|
| 54 |
+
return gptoss_special_token_message;
|
| 55 |
+
} else if (memcmp(
|
| 56 |
+
&(struct gptoss_uuid) {0xFC, 0xAC, 0x2F, 0x6D, 0x47, 0x05, 0x4F, 0x6B, 0xB2, 0x28, 0x64, 0x2A, 0xCC, 0xAC, 0x72, 0x38},
|
| 57 |
+
uuid,
|
| 58 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 59 |
+
{
|
| 60 |
+
return gptoss_special_token_end;
|
| 61 |
+
} else if (memcmp(
|
| 62 |
+
&(struct gptoss_uuid) {0xF7, 0x99, 0xFF, 0x69, 0x19, 0x92, 0x43, 0xC4, 0xA3, 0xD8, 0xD8, 0x31, 0xF4, 0x75, 0xDC, 0x75},
|
| 63 |
+
uuid,
|
| 64 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 65 |
+
{
|
| 66 |
+
return gptoss_special_token_return;
|
| 67 |
+
} else if (memcmp(
|
| 68 |
+
&(struct gptoss_uuid) {0xE1, 0x5B, 0xA7, 0x02, 0x28, 0xC4, 0x42, 0x92, 0xAB, 0x8F, 0xFF, 0xA4, 0x34, 0x70, 0x91, 0x28},
|
| 69 |
+
uuid,
|
| 70 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 71 |
+
{
|
| 72 |
+
return gptoss_special_token_refusal;
|
| 73 |
+
} else if (memcmp(
|
| 74 |
+
&(struct gptoss_uuid) {0xC0, 0xBB, 0x14, 0xC7, 0x60, 0x22, 0x49, 0xDA, 0xAD, 0x08, 0x79, 0x2D, 0x67, 0xE8, 0xB4, 0x70},
|
| 75 |
+
uuid,
|
| 76 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 77 |
+
{
|
| 78 |
+
return gptoss_special_token_constrain;
|
| 79 |
+
} else if (memcmp(
|
| 80 |
+
&(struct gptoss_uuid) {0xFD, 0x3D, 0xDA, 0x11, 0xC8, 0xAB, 0x40, 0x33, 0x87, 0x6E, 0xD9, 0x3D, 0xEB, 0x17, 0x2C, 0x93},
|
| 81 |
+
uuid,
|
| 82 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 83 |
+
{
|
| 84 |
+
return gptoss_special_token_channel;
|
| 85 |
+
} else if (memcmp(
|
| 86 |
+
&(struct gptoss_uuid) {0x12, 0x20, 0xF7, 0x96, 0xE3, 0x88, 0x4D, 0xE5, 0xB4, 0x87, 0xFE, 0x2E, 0xB5, 0xFE, 0x03, 0xC0},
|
| 87 |
+
uuid,
|
| 88 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 89 |
+
{
|
| 90 |
+
return gptoss_special_token_call;
|
| 91 |
+
} else if (memcmp(
|
| 92 |
+
&(struct gptoss_uuid) {0x07, 0xD7, 0xDA, 0x55, 0xB3, 0x46, 0x4C, 0xFF, 0x8B, 0x37, 0x7C, 0xEF, 0xAC, 0xF8, 0xA3, 0xE8},
|
| 93 |
+
uuid,
|
| 94 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 95 |
+
{
|
| 96 |
+
return gptoss_special_token_untrusted;
|
| 97 |
+
} else if (memcmp(
|
| 98 |
+
&(struct gptoss_uuid) {0xF2, 0x65, 0xBD, 0x9C, 0xC7, 0x17, 0x46, 0x9E, 0xA4, 0x47, 0x92, 0x06, 0x87, 0xD6, 0x5D, 0x90},
|
| 99 |
+
uuid,
|
| 100 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 101 |
+
{
|
| 102 |
+
return gptoss_special_token_end_untrusted;
|
| 103 |
+
} else if (memcmp(
|
| 104 |
+
&(struct gptoss_uuid) {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
| 105 |
+
uuid,
|
| 106 |
+
sizeof(struct gptoss_uuid)) == 0)
|
| 107 |
+
{
|
| 108 |
+
// Suppress warning
|
| 109 |
+
return gptoss_special_token_invalid;
|
| 110 |
+
} else {
|
| 111 |
+
GPTOSS_LOG_WARNING("unsupported special token " UUID_FORMAT, UUID_ARGS(*uuid));
|
| 112 |
+
return gptoss_special_token_invalid;
|
| 113 |
+
}
|
| 114 |
+
}
|
gptoss_kernels/source/log.c
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <assert.h> // assert
|
| 2 |
+
#include <stdarg.h> // va_list, va_copy, va_end
|
| 3 |
+
#include <stdio.h> // vsnprintf
|
| 4 |
+
#include <stdlib.h> // malloc, free
|
| 5 |
+
|
| 6 |
+
#include <unistd.h> // STDERR_FILENO
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
#define GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE 16384

// Formats a printf-style log message and writes it to stderr.
// Messages that fit in GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE (including the
// terminating NUL) are formatted in an on-stack buffer; longer messages are
// re-formatted into a heap allocation. If that allocation fails, the
// truncated on-stack contents are written instead.
void gptoss_format_log(const char* format, va_list args) {
    char stack_buffer[GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE];
    char* heap_buffer = NULL;

    // Keep a second copy of the argument list: `args` is consumed by the
    // first vsnprintf and cannot be replayed for the heap-buffer retry.
    va_list args_copy;
    va_copy(args_copy, args);

    const int vsnprintf_result = vsnprintf(stack_buffer, GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE, format, args);
    assert(vsnprintf_result >= 0);

    // At least a partially formatted buffer is ready.
    char* message_buffer = &stack_buffer[0];
    size_t message_size = (size_t) vsnprintf_result;
    // vsnprintf returns the full message length EXCLUDING the terminating
    // NUL, so the on-stack result is truncated whenever the length is >=
    // (not just >) the buffer size.
    if (message_size >= GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE) {
        // +1 for the terminating NUL that vsnprintf always writes; without it
        // the last character of the message would be truncated.
        heap_buffer = malloc(message_size + 1);
        if (heap_buffer == NULL) {
            // Fall back to the truncated message in the on-stack buffer.
            // The buffer holds SIZE-1 message characters plus a NUL; do not
            // write the NUL byte to stderr.
            message_size = GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE - 1;
        } else {
            // Use the full message in the in-heap buffer.
            vsnprintf(heap_buffer, message_size + 1, format, args_copy);
            message_buffer = heap_buffer;
        }
    }

    // write(2) may perform a partial write; loop until the whole message is
    // out or an error (negative return) occurs.
    ssize_t bytes_written;
    do {
        bytes_written = write(STDERR_FILENO, message_buffer, message_size);
        if (bytes_written > 0) {
            assert((size_t) bytes_written <= message_size);
            message_buffer += bytes_written;
            message_size -= bytes_written;
        }
    } while (bytes_written >= 0 && message_size != 0);

    free(heap_buffer);
    va_end(args_copy);
}
|
gptoss_kernels/source/matmul.metal
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_atomic>
|
| 2 |
+
#include <metal_compute>
|
| 3 |
+
#include <metal_integer>
|
| 4 |
+
#include <metal_math>
|
| 5 |
+
#include <metal_simdgroup>
|
| 6 |
+
#include <metal_stdlib>
|
| 7 |
+
|
| 8 |
+
#include <internal/kernel-args.h>
|
| 9 |
+
|
| 10 |
+
#pragma METAL fp math_mode(safe)
|
| 11 |
+
#pragma METAL fp contract(off)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
// Matrix-vector style matmul with f32 activations and bf16 weights:
//   output[b][row] = dot(input[b][:], weight[row][:]) + bias[row]
// Each simdgroup reduces all channels of the input and computes a single channel of the output
// + Efficient synchronization
// + Sequential memory access within a warp
// Each threadgroup computes (simdgroups_per_threadgroup) consecutive output channels
// + Reuse input vector from threadgroup memory
// + Avoid synchronization across warps when doing reduction

kernel void gptoss_f32_bf16w_matmul(
    constant gptoss_matmul_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device bfloat4* weight [[ buffer(2) ]],
    const device bfloat* bias [[ buffer(3) ]],
    device float* output [[ buffer(4) ]],
    const device gptoss_control* control [[ buffer(5) ]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    const uint simdgroup_size = 32;
    // Cooperative abort: the host sets control->abort to cancel in-flight work.
    if (control->abort != 0) {
        return;
    }

    const uint num_column_vecs = args.num_column_vecs;
    // Output row (channel) produced by this simdgroup; gid.y selects the
    // batch row (token).
    const uint row = gid.x * num_simdgroups + simdgroup_idx;

    // Each lane starts at a consecutive float4/bfloat4 and strides by the
    // simdgroup width, so loads within a warp are sequential.
    input += gid.y * num_column_vecs + simdgroup_tid;
    weight += num_column_vecs * row + simdgroup_tid;
    bias += row;
    output += gid.y * args.num_rows + row;

    // Number of strided vec4 loads this lane performs.
    // NOTE(review): assumes num_column_vecs >= simdgroup_size; otherwise the
    // unsigned subtraction wraps for high lanes — confirm with dispatch code.
    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;

    float4 sum4 = 0.0f;
    do {
        const bfloat4 w = *weight;
        const float4 i = *input;
        sum4 = metal::fma(static_cast<float4>(w), i, sum4);

        weight += simdgroup_size;
        input += simdgroup_size;
    } while (--num_iter != 0);
    // Horizontal reduction: vec4 -> vec2 -> scalar, then across the simdgroup.
    const float2 sum2 = sum4.xy + sum4.zw;
    float sum = sum2.x + sum2.y;
    sum = metal::simd_sum(sum);
    if (metal::simd_is_first()) {
        sum += static_cast<float>(*bias);
        // args.add selects accumulate-into vs. overwrite semantics.
        if (args.add) {
            *output += sum;
        } else {
            *output = sum;
        }
    }
}
|
| 69 |
+
|
| 70 |
+
// Fused QKV projection: computes one output channel per simdgroup (same
// scheme as gptoss_f32_bf16w_matmul), then the first simdgroup applies the
// rotary embedding (YaRN-adjusted RoPE) to adjacent channel pairs and
// scatters the results into the Q activations and the KV cache.
kernel void gptoss_f32_bf16w_matmul_qkv(
    constant gptoss_qkv_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device bfloat4* weight [[ buffer(2) ]],
    const device bfloat* bias [[ buffer(3) ]],
    device float* q [[ buffer(4) ]],
    device float* kv [[ buffer(5) ]],
    const device gptoss_control* control [[ buffer(6) ]],
    threadgroup void* scratch [[ threadgroup(0) ]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    const uint simdgroup_size = 32;
    // Hard-coded GPT-OSS attention geometry: 64-dim heads, 64 query heads,
    // 8 shared key/value heads (grouped-query attention).
    const uint head_dim = 64;
    const uint num_q_heads = 64;
    const uint num_kv_heads = 8;
    // Cooperative abort requested by the host.
    if (control->abort != 0) {
        return;
    }

    const uint num_column_vecs = args.num_column_vecs;
    // Output channel produced by this simdgroup; gid.y selects the token.
    const uint row = gid.x * num_simdgroups + simdgroup_idx;

    input += gid.y * num_column_vecs + simdgroup_tid;
    weight += num_column_vecs * row + simdgroup_tid;
    bias += row;
    q += gid.y * args.num_rows;

    // Strided vec4 loads per lane, as in the plain matmul kernel.
    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;

    float4 sum4 = 0.0f;
    do {
        const bfloat4 w = *weight;
        const float4 i = *input;
        sum4 = metal::fma(static_cast<float4>(w), i, sum4);

        weight += simdgroup_size;
        input += simdgroup_size;
    } while (--num_iter != 0);
    const float2 sum2 = sum4.xy + sum4.zw;
    float sum = sum2.x + sum2.y;
    sum = metal::simd_sum(sum);
    if (metal::simd_is_first()) {
        sum += static_cast<float>(*bias);
        // Stash this simdgroup's channel value so simdgroup 0 can rotate
        // channel pairs after the barrier.
        static_cast<threadgroup float*>(scratch)[simdgroup_idx] = sum;
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    if (simdgroup_idx == 0) {
        // Each lane of simdgroup 0 handles one (even, odd) channel pair.
        const uint num_half_simdgroups = num_simdgroups / 2;
        if (simdgroup_tid < num_half_simdgroups) {
            float2 vals = static_cast<const threadgroup float2*>(scratch)[simdgroup_tid];
            const uint idx = gid.x * num_half_simdgroups + simdgroup_tid;
            const uint head_idx = idx / (head_dim / 2);
            const uint token_idx = args.token_offset + gid.y;
            const uint dim_idx = idx % (head_dim / 2);
            // RoPE is applied to Q and K heads only; V heads (beyond
            // num_q_heads + num_kv_heads) pass through unrotated.
            if (head_idx < num_q_heads + num_kv_heads) {
                // YaRN frequency schedule: blend extrapolation and
                // interpolation frequencies by a per-dimension ramp `alpha`.
                const float dim_idx_val = static_cast<float>(dim_idx);
                const float inv_extrapolation_freq = metal::precise::exp(dim_idx_val * args.freq_scale);
                const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
                const float alpha = metal::saturate(metal::fma(dim_idx_val, args.yarn_scale, args.yarn_offset));
                const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);

                const float phi = static_cast<float>(token_idx) * inv_freq;
                const float yarn_multiplier = args.yarn_multiplier;
                float cosphi;
                const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
                cosphi *= yarn_multiplier;

                // 2D rotation of the channel pair by angle phi.
                vals = (float2) {
                    vals.x * cosphi - vals.y * sinphi,
                    vals.x * sinphi + vals.y * cosphi,
                };
            }
            if (head_idx < num_q_heads) {
                // Query activations, laid out densely per token.
                reinterpret_cast<device float2*>(q)[idx] = vals;
            } else if (head_idx < num_q_heads + num_kv_heads) {
                // Key: KV cache stores K and V interleaved per (head, token)
                // in 2*head_dim-sized slots; K occupies the first half.
                const uint h = head_idx - num_q_heads;
                reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim)[dim_idx] = vals;
            } else {
                // Value: second half of the (head, token) slot.
                const uint h = head_idx - num_q_heads - num_kv_heads;
                reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim + head_dim)[dim_idx] = vals;
            }
        }
    }
}
|
| 157 |
+
|
| 158 |
+
// Unembedding (logit) projection that also tracks the argmax logit per token.
// Each simdgroup produces multiple rows (strided by num_simdgroups) and keeps
// a running (row, encoded-logit) pair packed in a ulong so that an unsigned
// min selects the maximum logit (ties resolve to the lowest row index).
kernel void gptoss_f32_bf16w_unembedding(
    constant gptoss_unembedding_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device bfloat4* weight [[ buffer(2) ]],
    device float* output [[ buffer(3) ]],
    device metal::atomic_ulong* argmax [[ buffer(4) ]],
    const device gptoss_control* control [[ buffer(5) ]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    const uint simdgroup_size = 32;
    // One (row, encoded-logit) slot per simdgroup for the final reduction.
    threadgroup uint2 threadgroup_buffer[32];
    // Cooperative abort requested by the host.
    if (control->abort != 0) {
        return;
    }

    const uint num_column_vecs = args.num_column_vecs;
    // This threadgroup covers rows [gid.x * num_rows_per_threadgroup, row_end);
    // each simdgroup walks its own row subset, strided by num_simdgroups.
    const uint row_start = gid.x * args.num_rows_per_threadgroup + simdgroup_idx;
    const uint row_end = metal::min(gid.x * args.num_rows_per_threadgroup + args.num_rows_per_threadgroup, args.num_rows);
    const uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;

    input += gid.y * num_column_vecs + simdgroup_tid;
    weight += num_column_vecs * row_start + simdgroup_tid;
    output += gid.y * args.num_rows + row_start;

    // Running minimum key, initialized to the largest possible encoding.
    // Layout (little-endian ulong): row in .x (low bits), encoded sum in .y.
    uint2 row_sum{0xFFFFFFFFul, 0xFFFFFFFFul};
    for (uint row = row_start; row < row_end; row += num_simdgroups) {
        uint n = num_iter;

        float4 sum4 = 0.0f;
        do {
            const bfloat4 w = *weight;
            const float4 i = *input;

            sum4 = metal::fma(static_cast<float4>(w), i, sum4);

            weight += simdgroup_size;
            input += simdgroup_size;
        } while (--n != 0);
        // Rewind to the start of the column range for the next row.
        input -= num_iter * simdgroup_size;
        weight -= num_iter * simdgroup_size;

        const float2 sum2 = sum4.xy + sum4.zw;
        float sum = sum2.x + sum2.y;
        sum = metal::simd_sum(sum);
        // Encode the float so unsigned-min ordering equals descending float
        // order: positive floats get their magnitude bits inverted (larger
        // value -> smaller key), negative floats keep their large sign-set
        // bit patterns and therefore sort after all positives.
        uint sum_bits = as_type<uint>(sum);
        if (static_cast<int>(sum_bits) >= 0) {
            sum_bits ^= 0x7FFFFFFFu;
        }
        // Unsigned 64-bit min compares encoded sum first, then row index.
        row_sum = as_type<uint2>(metal::min(as_type<ulong>(row_sum), as_type<ulong>(uint2{row, sum_bits})));
        if (metal::simd_is_first()) {
            *output = sum;
        }

        weight += num_column_vecs * num_simdgroups;
        output += num_simdgroups;
    }
    if (metal::simd_is_first()) {
        threadgroup_buffer[simdgroup_idx] = row_sum;
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    if (simdgroup_idx == 0) {
        // Min-Reduce threadgroup_buffer
        if (simdgroup_tid < num_simdgroups) {
            row_sum = threadgroup_buffer[simdgroup_tid];
        }
        const uint sum_bits = row_sum.y;
        const uint sum_bits_min = metal::simd_min(sum_bits);
        // Among lanes holding the minimal key, pick the smallest row index.
        const uint row_min = metal::simd_min(sum_bits == sum_bits_min ? row_sum.x : 0xFFFFFFFFu);
        if (metal::simd_is_first()) {
            // Merge with other threadgroups' results via a relaxed atomic min
            // on the packed (row, encoded-logit) value.
            const uint2 threadgroup_output{row_min, sum_bits_min};
            atomic_min_explicit(&argmax[gid.y], as_type<ulong>(threadgroup_output), metal::memory_order_relaxed);
        }
    }
}
|
| 235 |
+
|
| 236 |
+
// Tiled dense GEMM (f32 activations x bf16 weights) built on Metal
// simdgroup 8x8 matrix operations. Each threadgroup computes a Bm x Bn
// output tile; each simdgroup within it computes a Sg_Bm x Sg_Bn sub-tile,
// accumulating over K in Bk-sized steps. Results are staged through
// threadgroup `scratch`, a bias row is staged through `bias_tile`, and the
// `add` template flag selects accumulate-into vs. overwrite of `out`.
//
// Current constraints for the dense matmul kernel:
// 1- All B* and Sg_* are a multiple of 8.
// 2- Bm is divisible by Sg_n and Bn is divisible by Sg_n.
// 3- M, N and K are all divisible by 8..
template <uint Bm, uint Bn, uint Bk, uint Sg_Bm, uint Sg_Bn, uint add = 0>
inline void _gptoss_f32_bf16w_dense_matmul_impl(
    constant gptoss_dense_matmul_args& args, const device float* lhs,
    const device bfloat* rhs, const device bfloat* __restrict__ bias,
    device float* out, const device gptoss_control* control, threadgroup float* scratch, threadgroup float* bias_tile,
    uint sg_id, uint sg_count_per_tg, uint3 gid, uint3 tg_id, uint3 local_tid,
    uint3 threadgroup_size) {

    // Cooperative abort requested by the host.
    if (control->abort != 0) {
        return;
    }

    // The kernel assumes that M, K, and N are divisible by 8.
    const uint M = args.m;
    const uint K = args.k;
    const uint N = args.n;
    static_assert((Bm % 8u) == 0u, "Bm must be a multiple of 8");
    static_assert((Bn % 8u) == 0u, "Bn must be a multiple of 8");
    static_assert((Bk % 8u) == 0u, "Bk must be a multiple of 8");
    static_assert((Sg_Bm % 8u) == 0u, "Bk must be a multiple of 8");
    static_assert((Sg_Bn % 8u) == 0u, "Bk must be a multiple of 8");
    static_assert((Bn % Sg_Bn) == 0u, "Bn must be a multiple of Sg_Bn");
    static_assert((Bm % Sg_Bm) == 0u, "Bm must be a multiple of Sg_Bm");

    // Get row and col tg.
    const uint row_tg = tg_id.y;
    const uint col_tg = tg_id.x;
    // Get row and col local tid.
    // Top-left corner of this threadgroup's output tile.
    const uint row_tg_offset = row_tg * Bm;
    const uint col_tg_offset = col_tg * Bn;

    // Map the flat simdgroup id onto a (row_sg, col_sg) grid of sub-tiles.
    const uint sg_col_count = Bn / Sg_Bn;
    const uint row_sg = sg_id / sg_col_count;
    const uint col_sg = sg_id % sg_col_count;

    const uint row_sg_offset = row_sg * Sg_Bm;
    const uint col_sg_offset = col_sg * Sg_Bn;
    constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);
    // Create an array of simdgroup_float8x8 to hold temp results.
    metal::simdgroup_float8x8 OutTiles[temp_result_size];
#pragma clang loop unroll(full)
    for (uint i = 0; i < temp_result_size; i++) {
        OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(
            static_cast<float>(0.0));
    }

    // Main K loop: accumulate 8x8 simdgroup-matrix products into OutTiles.
    for (uint k_offset = 0; k_offset < K; k_offset += Bk) {
#pragma clang loop unroll(full)
        for (uint k = 0; k < Bk; k += 8) {
#pragma clang loop unroll(full)
            for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
                // const uint m_subtile = row_sg_offset + m_subtile_;
                // const uint row_index_in_out_tile = (m_subtile - row_sg_offset) / 8;
                const uint row_index_in_out_tile = m_subtile_ / 8;
                metal::simdgroup_float8x8 LHStile;
                const uint k_id = k + k_offset;
                const uint row_offset = row_tg_offset + row_sg_offset + m_subtile_;
                metal::simdgroup_load(LHStile, lhs, K, ulong2(k_id, row_offset));
                metal::simdgroup_bfloat8x8 RHStile;
#pragma clang loop unroll(full)
                for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
                    const uint col_index_in_out_tile = n_subtile_ / 8;
                    const uint current_index_out_tile =
                        row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
                    const uint col_offset = col_tg_offset + col_sg_offset + n_subtile_;
                    // rhs is stored transposed (N x K), hence transpose=true.
                    simdgroup_load(RHStile, rhs, K, ulong2(k_id, col_offset), /*transpose=*/true);
                    // If rhs was not transposed, use the following instead:
                    // simdgroup_load(RHStile, rhs, N, ulong2(col_offset, k_id));
                    simdgroup_multiply_accumulate(OutTiles[current_index_out_tile],
                                                  LHStile, RHStile,
                                                  OutTiles[current_index_out_tile]);
                }
            }
        }
    }
    // Epilogue.
    // Stage each simdgroup's accumulated sub-tiles into threadgroup scratch
    // (laid out as a dense Bm x Bn tile) so any thread can apply bias below.
#pragma clang loop unroll(full)
    for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
        const uint col_index_in_out_tile = n_subtile_ / 8;
        const uint local_col_offset = col_sg_offset + n_subtile_;
#pragma clang loop unroll(full)
        for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
            const uint row_index_in_out_tile = m_subtile_ / 8;
            const uint local_row_offset = row_sg_offset + m_subtile_;
            const uint current_index_out_tile =
                row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
            simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,
                            ulong2(local_col_offset, local_row_offset));
        }
    }
    // TODO(ibahmed): vectorize these loads an maybe unroll the loop.
    // Cooperative load of the bias row for this tile's columns.
    const uint thread_count_per_tg =
        threadgroup_size.x * threadgroup_size.y * threadgroup_size.z;
    for (uint c_local = local_tid.x; c_local < Bn;
         c_local += thread_count_per_tg) {
        const uint c_global = col_tg_offset + c_local;
        bias_tile[c_local] =
            (c_global < N) ? static_cast<float>(bias[c_global]) : 0.0f;
    }

    // Make staged tile and bias visible to all threads in the threadgroup.
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);

    // TODO(ibahmed): vectorize these stores and maybe unroll the loop.
    // Cooperative write-out: add bias (and optionally the previous output
    // value when `add` is set), guarding against out-of-range rows/cols.
    for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) {
        const uint r = idx / Bn;
        const uint c = idx % Bn;

        const uint out_row = row_tg_offset + r;
        const uint out_col = col_tg_offset + c;

        if (out_row < M && out_col < N) {
            float acc = scratch[idx] + bias_tile[c];
            if (add) {
                acc += out[out_row * N + out_col];
            }
            out[out_row * N + out_col] = acc;
        }
    }
}
|
| 359 |
+
|
| 360 |
+
// QKV projection entry point: instantiates the dense matmul template with
// the QKV tile sizes (overwrite semantics, add=0).
kernel void gptoss_f32_bf16w_dense_matmul_qkv(
    constant gptoss_dense_matmul_args& args [[buffer(0)]],
    const device float* lhs [[buffer(1)]],
    const device bfloat* rhs [[buffer(2)]],
    const device bfloat* __restrict__ bias [[buffer(3)]],
    device float* out [[buffer(4)]],
    const device gptoss_control* control [[buffer(5)]],
    uint sg_id [[simdgroup_index_in_threadgroup]],
    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 tg_id [[threadgroup_position_in_grid]],
    uint3 local_tid [[thread_position_in_threadgroup]],
    uint3 threadgroup_size [[threads_per_threadgroup]]) {
    // Threadgroup staging for the output tile and the bias row.
    threadgroup float scratch[QKV_Bm * QKV_Bn];
    threadgroup float bias_tile[QKV_Bn];
    _gptoss_f32_bf16w_dense_matmul_impl<QKV_Bm, QKV_Bn, QKV_Bk, QKV_Sg_Bm,
                                        QKV_Sg_Bn>(
        args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
        gid, tg_id, local_tid, threadgroup_size);
}
|
| 380 |
+
|
| 381 |
+
// Attention-output projection entry point: instantiates the dense matmul
// template with the attention-output tile sizes and add=1 (accumulates into
// the residual stream in `out`).
kernel void gptoss_f32_bf16w_dense_matmul_attn_output(
    constant gptoss_dense_matmul_args& args [[buffer(0)]],
    const device float* lhs [[buffer(1)]],
    const device bfloat* rhs [[buffer(2)]],
    const device bfloat* __restrict__ bias [[buffer(3)]],
    device float* out [[buffer(4)]],
    const device gptoss_control* control [[buffer(5)]],
    uint sg_id [[simdgroup_index_in_threadgroup]],
    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 tg_id [[threadgroup_position_in_grid]],
    uint3 local_tid [[thread_position_in_threadgroup]],
    uint3 threadgroup_size [[threads_per_threadgroup]]) {
    // Threadgroup staging for the output tile and the bias row.
    threadgroup float scratch[ATTN_OUTPUT_Bm * ATTN_OUTPUT_Bn];
    threadgroup float bias_tile[ATTN_OUTPUT_Bn];
    _gptoss_f32_bf16w_dense_matmul_impl<ATTN_OUTPUT_Bm, ATTN_OUTPUT_Bn,
                                        ATTN_OUTPUT_Bk, ATTN_OUTPUT_Sg_Bm,
                                        ATTN_OUTPUT_Sg_Bn, /*add=*/1>(
        args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
        gid, tg_id, local_tid, threadgroup_size);
}
|
| 402 |
+
|
| 403 |
+
// MLP gate projection entry point: instantiates the dense matmul template
// with the MLP-gate tile sizes (overwrite semantics, add=0).
kernel void gptoss_f32_bf16w_dense_matmul_mlp_gate(
    constant gptoss_dense_matmul_args& args [[buffer(0)]],
    const device float* lhs [[buffer(1)]],
    const device bfloat* rhs [[buffer(2)]],
    const device bfloat* __restrict__ bias [[buffer(3)]],
    device float* out [[buffer(4)]],
    const device gptoss_control* control [[buffer(5)]],
    uint sg_id [[simdgroup_index_in_threadgroup]],
    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 tg_id [[threadgroup_position_in_grid]],
    uint3 local_tid [[thread_position_in_threadgroup]],
    uint3 threadgroup_size [[threads_per_threadgroup]]) {
    // Threadgroup staging for the output tile and the bias row.
    threadgroup float scratch[MLP_GATE_Bm * MLP_GATE_Bn];
    threadgroup float bias_tile[MLP_GATE_Bn];
    _gptoss_f32_bf16w_dense_matmul_impl<MLP_GATE_Bm, MLP_GATE_Bn, MLP_GATE_Bk,
                                        MLP_GATE_Sg_Bm, MLP_GATE_Sg_Bn>(
        args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
        gid, tg_id, local_tid, threadgroup_size);
}
|
gptoss_kernels/source/metal-kernels.c
ADDED
|
@@ -0,0 +1,1518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <inttypes.h>
|
| 2 |
+
#include <stddef.h>
|
| 3 |
+
#include <stdint.h>
|
| 4 |
+
#include <math.h>
|
| 5 |
+
|
| 6 |
+
#include <internal/kernel-args.h>
|
| 7 |
+
#include <internal/log.h>
|
| 8 |
+
#include <internal/math.h>
|
| 9 |
+
#include <internal/metal.h>
|
| 10 |
+
#include <internal/metal-kernels.h>
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
|
| 14 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 15 |
+
const struct gptoss_metal_function* u32_fill_random_fn,
|
| 16 |
+
size_t threadgroup_size,
|
| 17 |
+
size_t max_threadgroups,
|
| 18 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 19 |
+
size_t output_offset,
|
| 20 |
+
uint64_t num_elements,
|
| 21 |
+
uint64_t rng_seed,
|
| 22 |
+
uint64_t rng_offset)
|
| 23 |
+
{
|
| 24 |
+
if (command_buffer->object == NULL || u32_fill_random_fn->pipeline_state_object == NULL) {
|
| 25 |
+
return gptoss_status_invalid_state;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
if (threadgroup_size == 0) {
|
| 29 |
+
threadgroup_size = u32_fill_random_fn->max_threadgroup_threads;
|
| 30 |
+
} else if (threadgroup_size > u32_fill_random_fn->max_threadgroup_threads) {
|
| 31 |
+
return gptoss_status_invalid_argument;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
const size_t num_vecs = num_elements;
|
| 35 |
+
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
| 36 |
+
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
| 37 |
+
const struct gptoss_u32_fill_random_args args = {
|
| 38 |
+
.num_vecs = num_vecs,
|
| 39 |
+
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
| 40 |
+
.seed = rng_seed,
|
| 41 |
+
.offset = rng_offset,
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 45 |
+
command_buffer, u32_fill_random_fn,
|
| 46 |
+
threadgroup_size, 1, 1,
|
| 47 |
+
num_threadgroups, 1, 1,
|
| 48 |
+
sizeof(args), &args,
|
| 49 |
+
1, &output_buffer, &output_offset,
|
| 50 |
+
/*threadgroup_buffer_size=*/0);
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
|
| 54 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 55 |
+
const struct gptoss_metal_function* f32_fill_random_fn,
|
| 56 |
+
size_t threadgroup_size,
|
| 57 |
+
size_t max_threadgroups,
|
| 58 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 59 |
+
size_t output_offset,
|
| 60 |
+
uint64_t num_elements,
|
| 61 |
+
uint64_t rng_seed,
|
| 62 |
+
uint64_t rng_offset,
|
| 63 |
+
float rng_min,
|
| 64 |
+
float rng_max)
|
| 65 |
+
{
|
| 66 |
+
if (command_buffer->object == NULL || f32_fill_random_fn->pipeline_state_object == NULL) {
|
| 67 |
+
return gptoss_status_invalid_state;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
if (threadgroup_size == 0) {
|
| 71 |
+
threadgroup_size = f32_fill_random_fn->max_threadgroup_threads;
|
| 72 |
+
} else if (threadgroup_size > f32_fill_random_fn->max_threadgroup_threads) {
|
| 73 |
+
return gptoss_status_invalid_argument;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
if (rng_min >= rng_max) {
|
| 77 |
+
return gptoss_status_invalid_argument;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
const size_t num_vecs = num_elements;
|
| 81 |
+
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
| 82 |
+
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
| 83 |
+
const struct gptoss_f32_fill_random_args args = {
|
| 84 |
+
.num_vecs = num_vecs,
|
| 85 |
+
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
| 86 |
+
.seed = rng_seed,
|
| 87 |
+
.offset = rng_offset,
|
| 88 |
+
.scale = (rng_max - rng_min) * 0x1.0p-32f,
|
| 89 |
+
.bias = (rng_min + rng_max) * 0.5f,
|
| 90 |
+
};
|
| 91 |
+
|
| 92 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 93 |
+
command_buffer, f32_fill_random_fn,
|
| 94 |
+
threadgroup_size, 1, 1,
|
| 95 |
+
num_threadgroups, 1, 1,
|
| 96 |
+
sizeof(args), &args,
|
| 97 |
+
1, &output_buffer, &output_offset,
|
| 98 |
+
/*threadgroup_buffer_size=*/0);
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
|
| 102 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 103 |
+
const struct gptoss_metal_function* bf16_fill_random_fn,
|
| 104 |
+
size_t threadgroup_size,
|
| 105 |
+
size_t max_threadgroups,
|
| 106 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 107 |
+
size_t output_offset,
|
| 108 |
+
uint64_t num_elements,
|
| 109 |
+
uint64_t rng_seed,
|
| 110 |
+
uint64_t rng_offset,
|
| 111 |
+
float rng_min,
|
| 112 |
+
float rng_max)
|
| 113 |
+
{
|
| 114 |
+
if (command_buffer->object == NULL || bf16_fill_random_fn->pipeline_state_object == NULL) {
|
| 115 |
+
return gptoss_status_invalid_state;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
if (threadgroup_size == 0) {
|
| 119 |
+
threadgroup_size = bf16_fill_random_fn->max_threadgroup_threads;
|
| 120 |
+
} else if (threadgroup_size > bf16_fill_random_fn->max_threadgroup_threads) {
|
| 121 |
+
return gptoss_status_invalid_argument;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
if (rng_min >= rng_max) {
|
| 125 |
+
return gptoss_status_invalid_argument;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
const size_t num_vecs = num_elements;
|
| 129 |
+
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
| 130 |
+
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
| 131 |
+
const struct gptoss_f32_fill_random_args args = {
|
| 132 |
+
.num_vecs = num_vecs,
|
| 133 |
+
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
| 134 |
+
.seed = rng_seed,
|
| 135 |
+
.offset = rng_offset,
|
| 136 |
+
.scale = (rng_max - rng_min) * 0x1.0p-32f,
|
| 137 |
+
.bias = (rng_min + rng_max) * 0.5f,
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 141 |
+
command_buffer, bf16_fill_random_fn,
|
| 142 |
+
threadgroup_size, 1, 1,
|
| 143 |
+
num_threadgroups, 1, 1,
|
| 144 |
+
sizeof(args), &args,
|
| 145 |
+
1, &output_buffer, &output_offset,
|
| 146 |
+
/*threadgroup_buffer_size=*/0);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
|
| 150 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 151 |
+
const struct gptoss_metal_function* mf4_f32_convert_fn,
|
| 152 |
+
size_t threadgroup_size,
|
| 153 |
+
size_t max_threadgroups,
|
| 154 |
+
const struct gptoss_metal_buffer* block_buffer,
|
| 155 |
+
const struct gptoss_metal_buffer* scale_buffer,
|
| 156 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 157 |
+
uint64_t num_elements)
|
| 158 |
+
{
|
| 159 |
+
if (command_buffer->object == NULL || mf4_f32_convert_fn->pipeline_state_object == NULL) {
|
| 160 |
+
return gptoss_status_invalid_state;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
if (num_elements % 32 != 0) {
|
| 164 |
+
return gptoss_status_invalid_argument;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
if (threadgroup_size == 0) {
|
| 168 |
+
threadgroup_size = mf4_f32_convert_fn->max_threadgroup_threads;
|
| 169 |
+
} else if (threadgroup_size > mf4_f32_convert_fn->max_threadgroup_threads) {
|
| 170 |
+
return gptoss_status_invalid_argument;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
const size_t num_vecs = num_elements / 32;
|
| 174 |
+
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
| 175 |
+
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
| 176 |
+
const struct gptoss_convert_args args = {
|
| 177 |
+
.num_vecs = num_vecs,
|
| 178 |
+
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
| 179 |
+
};
|
| 180 |
+
|
| 181 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 182 |
+
command_buffer, mf4_f32_convert_fn,
|
| 183 |
+
threadgroup_size, 1, 1,
|
| 184 |
+
num_threadgroups, 1, 1,
|
| 185 |
+
sizeof(args), &args,
|
| 186 |
+
3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL,
|
| 187 |
+
/*threadgroup_buffer_size=*/0);
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
| 191 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 192 |
+
const struct gptoss_metal_function* bf16_f32_embeddings_fn,
|
| 193 |
+
size_t threadgroup_size,
|
| 194 |
+
const struct gptoss_metal_buffer* token_buffer,
|
| 195 |
+
size_t token_offset,
|
| 196 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 197 |
+
size_t weight_offset,
|
| 198 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 199 |
+
size_t output_offset,
|
| 200 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 201 |
+
size_t control_offset,
|
| 202 |
+
uint32_t num_tokens,
|
| 203 |
+
uint32_t num_channels)
|
| 204 |
+
{
|
| 205 |
+
if (command_buffer->object == NULL || bf16_f32_embeddings_fn->pipeline_state_object == NULL) {
|
| 206 |
+
return gptoss_status_invalid_state;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
if (num_channels % 4 != 0) {
|
| 210 |
+
return gptoss_status_invalid_argument;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
if (threadgroup_size == 0) {
|
| 214 |
+
threadgroup_size = bf16_f32_embeddings_fn->max_threadgroup_threads;
|
| 215 |
+
} else if (threadgroup_size > bf16_f32_embeddings_fn->max_threadgroup_threads) {
|
| 216 |
+
return gptoss_status_invalid_argument;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
const uint32_t num_vecs = num_channels / 4;
|
| 220 |
+
const struct gptoss_embeddings_args args = {
|
| 221 |
+
.num_vecs = num_vecs,
|
| 222 |
+
};
|
| 223 |
+
|
| 224 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 225 |
+
command_buffer, bf16_f32_embeddings_fn,
|
| 226 |
+
threadgroup_size, 1, 1,
|
| 227 |
+
num_tokens, 1, 1,
|
| 228 |
+
sizeof(args), &args,
|
| 229 |
+
4,
|
| 230 |
+
(const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer, control_buffer},
|
| 231 |
+
(const size_t[]) {token_offset, weight_offset, output_offset, control_offset},
|
| 232 |
+
/*threadgroup_buffer_size=*/0);
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
|
| 236 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 237 |
+
const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,
|
| 238 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 239 |
+
size_t input_offset,
|
| 240 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 241 |
+
size_t weight_offset,
|
| 242 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 243 |
+
size_t output_offset,
|
| 244 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 245 |
+
size_t control_offset,
|
| 246 |
+
uint32_t num_tokens,
|
| 247 |
+
uint32_t num_channels,
|
| 248 |
+
float epsilon)
|
| 249 |
+
{
|
| 250 |
+
if (command_buffer->object == NULL || f32_bf16w_rmsnorm_fn->pipeline_state_object == NULL) {
|
| 251 |
+
return gptoss_status_invalid_state;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
if (num_channels % 4 != 0) {
|
| 255 |
+
return gptoss_status_invalid_argument;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
if (f32_bf16w_rmsnorm_fn->max_threadgroup_threads < 1024) {
|
| 259 |
+
return gptoss_status_unsupported_system;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
if (f32_bf16w_rmsnorm_fn->simdgroup_threads != 32) {
|
| 263 |
+
return gptoss_status_unsupported_system;
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
const uint32_t num_vecs = num_channels / 4;
|
| 267 |
+
const struct gptoss_rmsnorm_args args = {
|
| 268 |
+
.num_vecs = num_vecs,
|
| 269 |
+
.num_channels = (float) num_channels,
|
| 270 |
+
.epsilon = epsilon,
|
| 271 |
+
};
|
| 272 |
+
|
| 273 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 274 |
+
command_buffer, f32_bf16w_rmsnorm_fn,
|
| 275 |
+
/*threadgroup_size=*/1024, 1, 1,
|
| 276 |
+
num_tokens, 1, 1,
|
| 277 |
+
sizeof(args), &args,
|
| 278 |
+
4,
|
| 279 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, control_buffer},
|
| 280 |
+
(const size_t[]) {input_offset, weight_offset, output_offset, control_offset},
|
| 281 |
+
/*threadgroup_buffer_size=*/0);
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
|
| 285 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 286 |
+
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
| 287 |
+
size_t threadgroup_size,
|
| 288 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 289 |
+
size_t input_offset,
|
| 290 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 291 |
+
size_t weight_offset,
|
| 292 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 293 |
+
size_t bias_offset,
|
| 294 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 295 |
+
size_t output_offset,
|
| 296 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 297 |
+
size_t control_offset,
|
| 298 |
+
uint32_t num_tokens,
|
| 299 |
+
uint32_t num_cols,
|
| 300 |
+
uint32_t num_rows)
|
| 301 |
+
{
|
| 302 |
+
if (command_buffer->object == NULL || f32_bf16w_matmul_fn->pipeline_state_object == NULL) {
|
| 303 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: invalid command buffer or pipeline state object");
|
| 304 |
+
return gptoss_status_invalid_state;
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
if (threadgroup_size == 0) {
|
| 308 |
+
threadgroup_size = f32_bf16w_matmul_fn->simdgroup_threads;
|
| 309 |
+
} else if (threadgroup_size > f32_bf16w_matmul_fn->max_threadgroup_threads) {
|
| 310 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 311 |
+
threadgroup_size, f32_bf16w_matmul_fn->max_threadgroup_threads);
|
| 312 |
+
return gptoss_status_invalid_argument;
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
if (num_cols % 4 != 0) {
|
| 316 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
|
| 317 |
+
num_cols);
|
| 318 |
+
return gptoss_status_invalid_argument;
|
| 319 |
+
}
|
| 320 |
+
const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_fn->simdgroup_threads;
|
| 321 |
+
if (num_rows % num_simdgroups != 0) {
|
| 322 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
|
| 323 |
+
num_rows, num_simdgroups);
|
| 324 |
+
return gptoss_status_invalid_argument;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
const struct gptoss_matmul_args args = {
|
| 328 |
+
.num_column_vecs = num_cols / 4,
|
| 329 |
+
.num_rows = num_rows,
|
| 330 |
+
.add = 0,
|
| 331 |
+
};
|
| 332 |
+
|
| 333 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 334 |
+
command_buffer, f32_bf16w_matmul_fn,
|
| 335 |
+
threadgroup_size, 1, 1,
|
| 336 |
+
num_rows / num_simdgroups, num_tokens, 1,
|
| 337 |
+
sizeof(args), &args,
|
| 338 |
+
5,
|
| 339 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
|
| 340 |
+
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},
|
| 341 |
+
/*threadgroup_buffer_size=*/0);
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
|
| 345 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 346 |
+
const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
|
| 347 |
+
size_t threadgroup_size,
|
| 348 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 349 |
+
size_t input_offset,
|
| 350 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 351 |
+
size_t weight_offset,
|
| 352 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 353 |
+
size_t bias_offset,
|
| 354 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 355 |
+
size_t output_offset,
|
| 356 |
+
const struct gptoss_metal_buffer* kv_buffer,
|
| 357 |
+
size_t kv_offset,
|
| 358 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 359 |
+
size_t control_offset,
|
| 360 |
+
uint32_t num_tokens,
|
| 361 |
+
uint32_t num_cols,
|
| 362 |
+
uint32_t num_q_heads,
|
| 363 |
+
uint32_t num_kv_heads,
|
| 364 |
+
uint32_t attn_head_dim,
|
| 365 |
+
uint32_t token_offset,
|
| 366 |
+
uint32_t max_tokens,
|
| 367 |
+
float rope_base,
|
| 368 |
+
float interpolation_scale,
|
| 369 |
+
float yarn_offset,
|
| 370 |
+
float yarn_scale,
|
| 371 |
+
float yarn_multiplier)
|
| 372 |
+
{
|
| 373 |
+
if (command_buffer->object == NULL || f32_bf16w_matmul_qkv_fn->pipeline_state_object == NULL) {
|
| 374 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: invalid command buffer or pipeline state object");
|
| 375 |
+
return gptoss_status_invalid_state;
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
if (threadgroup_size == 0) {
|
| 379 |
+
threadgroup_size = f32_bf16w_matmul_qkv_fn->simdgroup_threads;
|
| 380 |
+
} else if (threadgroup_size > f32_bf16w_matmul_qkv_fn->max_threadgroup_threads) {
|
| 381 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 382 |
+
threadgroup_size, f32_bf16w_matmul_qkv_fn->max_threadgroup_threads);
|
| 383 |
+
return gptoss_status_invalid_argument;
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
if (num_cols % 4 != 0) {
|
| 387 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
|
| 388 |
+
num_cols);
|
| 389 |
+
return gptoss_status_invalid_argument;
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
if (num_q_heads != 64) {
|
| 393 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of Q heads (%" PRIu32 ") must be 64",
|
| 394 |
+
num_q_heads);
|
| 395 |
+
return gptoss_status_invalid_argument;
|
| 396 |
+
}
|
| 397 |
+
if (num_kv_heads != 8) {
|
| 398 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of KV heads (%" PRIu32 ") must be 8",
|
| 399 |
+
num_kv_heads);
|
| 400 |
+
return gptoss_status_invalid_argument;
|
| 401 |
+
}
|
| 402 |
+
if (attn_head_dim != 64) {
|
| 403 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: attention head dimension (%" PRIu32 ") must be 64",
|
| 404 |
+
attn_head_dim);
|
| 405 |
+
return gptoss_status_invalid_argument;
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_qkv_fn->simdgroup_threads;
|
| 409 |
+
const uint32_t num_rows = (num_q_heads + 2 * num_kv_heads) * attn_head_dim;
|
| 410 |
+
if (num_rows % num_simdgroups != 0) {
|
| 411 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
|
| 412 |
+
num_rows, num_simdgroups);
|
| 413 |
+
return gptoss_status_invalid_argument;
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
const struct gptoss_qkv_args args = {
|
| 417 |
+
.num_column_vecs = num_cols / 4,
|
| 418 |
+
.num_rows = num_rows,
|
| 419 |
+
.token_offset = token_offset,
|
| 420 |
+
.freq_scale = -logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),
|
| 421 |
+
.interpolation_scale = interpolation_scale,
|
| 422 |
+
.yarn_offset = yarn_offset,
|
| 423 |
+
.yarn_scale = yarn_scale,
|
| 424 |
+
.yarn_multiplier = yarn_multiplier,
|
| 425 |
+
.max_tokens = max_tokens,
|
| 426 |
+
};
|
| 427 |
+
|
| 428 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 429 |
+
command_buffer, f32_bf16w_matmul_qkv_fn,
|
| 430 |
+
threadgroup_size, 1, 1,
|
| 431 |
+
num_rows / num_simdgroups, num_tokens, 1,
|
| 432 |
+
sizeof(args), &args,
|
| 433 |
+
6,
|
| 434 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, kv_buffer, control_buffer},
|
| 435 |
+
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, kv_offset, control_offset},
|
| 436 |
+
/*threadgroup_buffer_size=*/num_simdgroups * sizeof(float));
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
|
| 440 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 441 |
+
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
|
| 442 |
+
size_t threadgroup_size,
|
| 443 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 444 |
+
size_t input_offset,
|
| 445 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 446 |
+
size_t weight_offset,
|
| 447 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 448 |
+
size_t bias_offset,
|
| 449 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 450 |
+
size_t output_offset,
|
| 451 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 452 |
+
size_t control_offset,
|
| 453 |
+
uint32_t num_tokens,
|
| 454 |
+
uint32_t num_cols,
|
| 455 |
+
uint32_t num_rows)
|
| 456 |
+
{
|
| 457 |
+
if (command_buffer->object == NULL || f32_bf16w_matmul_fn->pipeline_state_object == NULL) {
|
| 458 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: invalid command buffer or pipeline state object");
|
| 459 |
+
return gptoss_status_invalid_state;
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
if (threadgroup_size == 0) {
|
| 463 |
+
threadgroup_size = f32_bf16w_matmul_fn->simdgroup_threads;
|
| 464 |
+
} else if (threadgroup_size > f32_bf16w_matmul_fn->max_threadgroup_threads) {
|
| 465 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 466 |
+
threadgroup_size, f32_bf16w_matmul_fn->max_threadgroup_threads);
|
| 467 |
+
return gptoss_status_invalid_argument;
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
if (num_cols % 4 != 0) {
|
| 471 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
|
| 472 |
+
num_cols);
|
| 473 |
+
return gptoss_status_invalid_argument;
|
| 474 |
+
}
|
| 475 |
+
const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_fn->simdgroup_threads;
|
| 476 |
+
if (num_rows % num_simdgroups != 0) {
|
| 477 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
|
| 478 |
+
num_rows, num_simdgroups);
|
| 479 |
+
return gptoss_status_invalid_argument;
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
const struct gptoss_matmul_args args = {
|
| 483 |
+
.num_column_vecs = num_cols / 4,
|
| 484 |
+
.num_rows = num_rows,
|
| 485 |
+
.add = 1,
|
| 486 |
+
};
|
| 487 |
+
|
| 488 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 489 |
+
command_buffer, f32_bf16w_matmul_fn,
|
| 490 |
+
threadgroup_size, 1, 1,
|
| 491 |
+
num_rows / num_simdgroups, num_tokens, 1,
|
| 492 |
+
sizeof(args), &args,
|
| 493 |
+
5,
|
| 494 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
|
| 495 |
+
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},
|
| 496 |
+
/*threadgroup_buffer_size=*/0);
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
enum gptoss_status _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
|
| 500 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 501 |
+
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
|
| 502 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 503 |
+
size_t input_offset,
|
| 504 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 505 |
+
size_t weight_offset,
|
| 506 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 507 |
+
size_t bias_offset,
|
| 508 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 509 |
+
size_t output_offset,
|
| 510 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 511 |
+
size_t control_offset,
|
| 512 |
+
uint32_t num_tokens,
|
| 513 |
+
uint32_t num_cols,
|
| 514 |
+
uint32_t num_rows,
|
| 515 |
+
uint32_t Bm,
|
| 516 |
+
uint32_t Bn,
|
| 517 |
+
uint32_t Bk,
|
| 518 |
+
uint32_t Sg_Bm,
|
| 519 |
+
uint32_t Sg_Bn)
|
| 520 |
+
{
|
| 521 |
+
|
| 522 |
+
if (command_buffer->object == NULL || f32_bf16w_dense_matmul_fn->pipeline_state_object == NULL) {
|
| 523 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: invalid command buffer or pipeline state object");
|
| 524 |
+
return gptoss_status_invalid_state;
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
if (num_cols % 8 != 0) {
|
| 528 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 8",
|
| 529 |
+
num_cols);
|
| 530 |
+
return gptoss_status_invalid_argument;
|
| 531 |
+
}
|
| 532 |
+
if (num_rows % 8 != 0) {
|
| 533 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of rows (%" PRIu32 ") is not divisible by 8",
|
| 534 |
+
num_rows);
|
| 535 |
+
return gptoss_status_invalid_argument;
|
| 536 |
+
}
|
| 537 |
+
if (num_tokens % 8 != 0) {
|
| 538 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of tokens (%" PRIu32 ") is not divisible by 8",
|
| 539 |
+
num_tokens);
|
| 540 |
+
return gptoss_status_invalid_argument;
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
const struct gptoss_dense_matmul_args args = {
|
| 544 |
+
.m = num_tokens,
|
| 545 |
+
.n = num_rows,
|
| 546 |
+
.k = num_cols,
|
| 547 |
+
};
|
| 548 |
+
const size_t threads_per_simdgroup = f32_bf16w_dense_matmul_fn->simdgroup_threads;
|
| 549 |
+
const uint32_t m = args.m;
|
| 550 |
+
const uint32_t n = args.n;
|
| 551 |
+
const uint32_t k = args.k;
|
| 552 |
+
if (Bm % Sg_Bm != 0) {
|
| 553 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
|
| 554 |
+
Bm, Sg_Bm);
|
| 555 |
+
return gptoss_status_invalid_argument;
|
| 556 |
+
}
|
| 557 |
+
if (Bn % Sg_Bn != 0) {
|
| 558 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
|
| 559 |
+
Bn, Sg_Bn);
|
| 560 |
+
return gptoss_status_invalid_argument;
|
| 561 |
+
}
|
| 562 |
+
const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
|
| 563 |
+
const size_t threadgroup_size_y = 1;
|
| 564 |
+
const size_t threadgroup_size_z = 1;
|
| 565 |
+
const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
|
| 566 |
+
if (total_threadgroup_size > f32_bf16w_dense_matmul_fn->max_threadgroup_threads) {
|
| 567 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 568 |
+
total_threadgroup_size, f32_bf16w_dense_matmul_fn->max_threadgroup_threads);
|
| 569 |
+
return gptoss_status_invalid_argument;
|
| 570 |
+
}
|
| 571 |
+
if (m % Bm != 0) {
|
| 572 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: m (%" PRIu32 ") is not divisible by Bm (%" PRIu32 ")",
|
| 573 |
+
m, Bm);
|
| 574 |
+
return gptoss_status_invalid_argument;
|
| 575 |
+
}
|
| 576 |
+
if (n % Bn != 0) {
|
| 577 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
|
| 578 |
+
n, Bn);
|
| 579 |
+
return gptoss_status_invalid_argument;
|
| 580 |
+
}
|
| 581 |
+
if (k % Bk != 0) {
|
| 582 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
|
| 583 |
+
k, Bk);
|
| 584 |
+
return gptoss_status_invalid_argument;
|
| 585 |
+
}
|
| 586 |
+
const size_t grid_x = n / Bn;
|
| 587 |
+
const size_t grid_y = m / Bm;
|
| 588 |
+
const size_t grid_z = 1;
|
| 589 |
+
|
| 590 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 591 |
+
command_buffer, f32_bf16w_dense_matmul_fn,
|
| 592 |
+
threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
|
| 593 |
+
grid_x, grid_y, grid_z,
|
| 594 |
+
sizeof(args), &args,
|
| 595 |
+
5,
|
| 596 |
+
(const struct gptoss_metal_buffer *[]){input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
|
| 597 |
+
(const size_t[]){input_offset, weight_offset, bias_offset, output_offset, control_offset},
|
| 598 |
+
/*threadgroup_buffer_size=*/0);
|
| 599 |
+
return gptoss_status_success;
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
|
| 603 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 604 |
+
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
|
| 605 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 606 |
+
size_t input_offset,
|
| 607 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 608 |
+
size_t weight_offset,
|
| 609 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 610 |
+
size_t bias_offset,
|
| 611 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 612 |
+
size_t output_offset,
|
| 613 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 614 |
+
size_t control_offset,
|
| 615 |
+
uint32_t num_tokens,
|
| 616 |
+
uint32_t num_cols,
|
| 617 |
+
uint32_t num_rows)
|
| 618 |
+
{
|
| 619 |
+
return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
|
| 620 |
+
command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
|
| 621 |
+
weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
|
| 622 |
+
output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, QKV_Bm, QKV_Bn, QKV_Bk,
|
| 623 |
+
QKV_Sg_Bm, QKV_Sg_Bn);
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
|
| 627 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 628 |
+
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
|
| 629 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 630 |
+
size_t input_offset,
|
| 631 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 632 |
+
size_t weight_offset,
|
| 633 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 634 |
+
size_t bias_offset,
|
| 635 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 636 |
+
size_t output_offset,
|
| 637 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 638 |
+
size_t control_offset,
|
| 639 |
+
uint32_t num_tokens,
|
| 640 |
+
uint32_t num_cols,
|
| 641 |
+
uint32_t num_rows)
|
| 642 |
+
{
|
| 643 |
+
return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
|
| 644 |
+
command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
|
| 645 |
+
weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
|
| 646 |
+
output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, ATTN_OUTPUT_Bm,
|
| 647 |
+
ATTN_OUTPUT_Bn, ATTN_OUTPUT_Bk, ATTN_OUTPUT_Sg_Bm, ATTN_OUTPUT_Sg_Bn);
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
|
| 651 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 652 |
+
const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
|
| 653 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 654 |
+
size_t input_offset,
|
| 655 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 656 |
+
size_t weight_offset,
|
| 657 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 658 |
+
size_t bias_offset,
|
| 659 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 660 |
+
size_t output_offset,
|
| 661 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 662 |
+
size_t control_offset,
|
| 663 |
+
uint32_t num_tokens,
|
| 664 |
+
uint32_t num_cols,
|
| 665 |
+
uint32_t num_rows)
|
| 666 |
+
{
|
| 667 |
+
return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
|
| 668 |
+
command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
|
| 669 |
+
weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
|
| 670 |
+
output_offset, control_buffer, control_offset, num_tokens, num_cols,
|
| 671 |
+
num_rows, MLP_GATE_Bm, MLP_GATE_Bn, MLP_GATE_Bk, MLP_GATE_Sg_Bm,
|
| 672 |
+
MLP_GATE_Sg_Bn);
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
|
| 676 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 677 |
+
const struct gptoss_metal_function* f32_bf16w_unembedding_fn,
|
| 678 |
+
size_t threadgroup_size,
|
| 679 |
+
size_t max_threadgroups,
|
| 680 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 681 |
+
size_t input_offset,
|
| 682 |
+
const struct gptoss_metal_buffer* weight_buffer,
|
| 683 |
+
size_t weight_offset,
|
| 684 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 685 |
+
size_t output_offset,
|
| 686 |
+
const struct gptoss_metal_buffer* argmax_buffer,
|
| 687 |
+
size_t argmax_offset,
|
| 688 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 689 |
+
size_t control_offset,
|
| 690 |
+
uint32_t num_tokens,
|
| 691 |
+
uint32_t num_cols,
|
| 692 |
+
uint32_t num_rows)
|
| 693 |
+
{
|
| 694 |
+
if (command_buffer->object == NULL || f32_bf16w_unembedding_fn->pipeline_state_object == NULL) {
|
| 695 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: invalid command buffer or pipeline state object");
|
| 696 |
+
return gptoss_status_invalid_state;
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
if (threadgroup_size == 0) {
|
| 700 |
+
threadgroup_size = f32_bf16w_unembedding_fn->simdgroup_threads;
|
| 701 |
+
} else if (threadgroup_size > f32_bf16w_unembedding_fn->max_threadgroup_threads) {
|
| 702 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 703 |
+
threadgroup_size, f32_bf16w_unembedding_fn->max_threadgroup_threads);
|
| 704 |
+
return gptoss_status_invalid_argument;
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
if (num_cols % 4 != 0) {
|
| 708 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
|
| 709 |
+
num_cols);
|
| 710 |
+
return gptoss_status_invalid_argument;
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
const size_t num_simdgroups = threadgroup_size / f32_bf16w_unembedding_fn->simdgroup_threads;
|
| 714 |
+
const size_t num_rows_per_threadgroup = math_ceil_div(num_rows, max_threadgroups * num_simdgroups) * num_simdgroups;
|
| 715 |
+
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_rows, num_rows_per_threadgroup));
|
| 716 |
+
const struct gptoss_unembedding_args args = {
|
| 717 |
+
.num_column_vecs = num_cols / 4,
|
| 718 |
+
.num_rows_per_threadgroup = num_rows_per_threadgroup,
|
| 719 |
+
.num_rows = num_rows,
|
| 720 |
+
};
|
| 721 |
+
|
| 722 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 723 |
+
command_buffer, f32_bf16w_unembedding_fn,
|
| 724 |
+
threadgroup_size, 1, 1,
|
| 725 |
+
num_threadgroups, num_tokens, 1,
|
| 726 |
+
sizeof(args), &args,
|
| 727 |
+
5,
|
| 728 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer, control_buffer},
|
| 729 |
+
(const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset, control_offset},
|
| 730 |
+
/*threadgroup_buffer_size=*/0);
|
| 731 |
+
}
|
| 732 |
+
|
| 733 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
|
| 734 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 735 |
+
const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,
|
| 736 |
+
size_t threadgroup_size,
|
| 737 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 738 |
+
size_t input_offset,
|
| 739 |
+
const struct gptoss_metal_buffer* expert_buffer,
|
| 740 |
+
size_t expert_offset,
|
| 741 |
+
const struct gptoss_metal_buffer* weight_block_buffer,
|
| 742 |
+
size_t weight_block_offset,
|
| 743 |
+
const struct gptoss_metal_buffer* weight_scale_buffer,
|
| 744 |
+
size_t weight_scale_offset,
|
| 745 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 746 |
+
size_t bias_offset,
|
| 747 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 748 |
+
size_t output_offset,
|
| 749 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 750 |
+
size_t control_offset,
|
| 751 |
+
float swiglu_limit,
|
| 752 |
+
uint32_t expert_stride,
|
| 753 |
+
uint32_t num_tokens,
|
| 754 |
+
uint32_t num_active_experts,
|
| 755 |
+
uint32_t num_cols,
|
| 756 |
+
uint32_t num_rows)
|
| 757 |
+
{
|
| 758 |
+
if (command_buffer->object == NULL || f32_mf4w_moe_matmul_swiglu_fn->pipeline_state_object == NULL) {
|
| 759 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: invalid command buffer or pipeline state object");
|
| 760 |
+
return gptoss_status_invalid_state;
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
+
if (threadgroup_size == 0) {
|
| 764 |
+
threadgroup_size = 2 * f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads;
|
| 765 |
+
} else if (threadgroup_size > f32_mf4w_moe_matmul_swiglu_fn->max_threadgroup_threads) {
|
| 766 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 767 |
+
threadgroup_size, f32_mf4w_moe_matmul_swiglu_fn->max_threadgroup_threads);
|
| 768 |
+
return gptoss_status_invalid_argument;
|
| 769 |
+
} else if (threadgroup_size % (2 * f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads)) {
|
| 770 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: threadgroup size (%zu) is not divisible by simdgroup size (%zu) multiplied by 2X",
|
| 771 |
+
threadgroup_size, f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads);
|
| 772 |
+
return gptoss_status_invalid_argument;
|
| 773 |
+
}
|
| 774 |
+
|
| 775 |
+
if (num_cols % 32 != 0) {
|
| 776 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
|
| 777 |
+
num_cols);
|
| 778 |
+
return gptoss_status_invalid_argument;
|
| 779 |
+
}
|
| 780 |
+
const size_t num_simdgroups = threadgroup_size / f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads;
|
| 781 |
+
if ((2 * num_rows) % num_simdgroups != 0) {
|
| 782 |
+
GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: "
|
| 783 |
+
"the number of rows (%" PRIu32 ") multiplied by 2X is not divisible by the number of simdgroups (%zu)",
|
| 784 |
+
num_rows, num_simdgroups);
|
| 785 |
+
return gptoss_status_invalid_argument;
|
| 786 |
+
}
|
| 787 |
+
|
| 788 |
+
const struct gptoss_moe_matmul_swiglu_args args = {
|
| 789 |
+
.num_column_vecs = num_cols / 32,
|
| 790 |
+
.num_rows = num_rows,
|
| 791 |
+
.num_active_experts = num_active_experts,
|
| 792 |
+
.weight_expert_stride = expert_stride,
|
| 793 |
+
.output_expert_stride = num_rows * num_tokens,
|
| 794 |
+
.swiglu_min = -swiglu_limit,
|
| 795 |
+
.swiglu_max = swiglu_limit,
|
| 796 |
+
};
|
| 797 |
+
|
| 798 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 799 |
+
command_buffer, f32_mf4w_moe_matmul_swiglu_fn,
|
| 800 |
+
threadgroup_size, 1, 1,
|
| 801 |
+
(2 * num_rows) / num_simdgroups, num_tokens, num_active_experts,
|
| 802 |
+
sizeof(args), &args,
|
| 803 |
+
7,
|
| 804 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},
|
| 805 |
+
(const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},
|
| 806 |
+
/*threadgroup_buffer_size=*/0);
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
|
| 810 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 811 |
+
const struct gptoss_metal_function* f32_mf4w_moe_matmul_fn,
|
| 812 |
+
size_t threadgroup_size,
|
| 813 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 814 |
+
size_t input_offset,
|
| 815 |
+
const struct gptoss_metal_buffer* expert_buffer,
|
| 816 |
+
size_t expert_offset,
|
| 817 |
+
const struct gptoss_metal_buffer* weight_block_buffer,
|
| 818 |
+
size_t weight_block_offset,
|
| 819 |
+
const struct gptoss_metal_buffer* weight_scale_buffer,
|
| 820 |
+
size_t weight_scale_offset,
|
| 821 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 822 |
+
size_t bias_offset,
|
| 823 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 824 |
+
size_t output_offset,
|
| 825 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 826 |
+
size_t control_offset,
|
| 827 |
+
uint32_t expert_stride,
|
| 828 |
+
uint32_t num_tokens,
|
| 829 |
+
uint32_t num_active_experts,
|
| 830 |
+
uint32_t num_cols,
|
| 831 |
+
uint32_t num_rows)
|
| 832 |
+
{
|
| 833 |
+
if (command_buffer->object == NULL || f32_mf4w_moe_matmul_fn->pipeline_state_object == NULL) {
|
| 834 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: invalid command buffer or pipeline state object");
|
| 835 |
+
return gptoss_status_invalid_state;
|
| 836 |
+
}
|
| 837 |
+
|
| 838 |
+
if (threadgroup_size == 0) {
|
| 839 |
+
threadgroup_size = f32_mf4w_moe_matmul_fn->simdgroup_threads;
|
| 840 |
+
} else if (threadgroup_size > f32_mf4w_moe_matmul_fn->max_threadgroup_threads) {
|
| 841 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 842 |
+
threadgroup_size, f32_mf4w_moe_matmul_fn->max_threadgroup_threads);
|
| 843 |
+
return gptoss_status_invalid_argument;
|
| 844 |
+
} else if (threadgroup_size % f32_mf4w_moe_matmul_fn->simdgroup_threads) {
|
| 845 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: threadgroup size (%zu) is not divisible by simdgroup size (%zu)",
|
| 846 |
+
threadgroup_size, f32_mf4w_moe_matmul_fn->simdgroup_threads);
|
| 847 |
+
return gptoss_status_invalid_argument;
|
| 848 |
+
}
|
| 849 |
+
|
| 850 |
+
if (num_cols % 32 != 0) {
|
| 851 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
|
| 852 |
+
num_cols);
|
| 853 |
+
return gptoss_status_invalid_argument;
|
| 854 |
+
}
|
| 855 |
+
const size_t num_simdgroups = threadgroup_size / f32_mf4w_moe_matmul_fn->simdgroup_threads;
|
| 856 |
+
if (num_rows % num_simdgroups != 0) {
|
| 857 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: "
|
| 858 |
+
"the number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
|
| 859 |
+
num_rows, num_simdgroups);
|
| 860 |
+
return gptoss_status_invalid_argument;
|
| 861 |
+
}
|
| 862 |
+
|
| 863 |
+
const struct gptoss_moe_matmul_args args = {
|
| 864 |
+
.num_column_vecs = num_cols / 32,
|
| 865 |
+
.num_rows = num_rows,
|
| 866 |
+
.num_active_experts = num_active_experts,
|
| 867 |
+
.input_expert_stride = num_tokens * (num_cols / 32),
|
| 868 |
+
.weight_expert_stride = expert_stride,
|
| 869 |
+
.output_expert_stride = num_rows * num_tokens,
|
| 870 |
+
};
|
| 871 |
+
|
| 872 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 873 |
+
command_buffer, f32_mf4w_moe_matmul_fn,
|
| 874 |
+
threadgroup_size, 1, 1,
|
| 875 |
+
num_rows / num_simdgroups, num_tokens, num_active_experts,
|
| 876 |
+
sizeof(args), &args,
|
| 877 |
+
7,
|
| 878 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},
|
| 879 |
+
(const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},
|
| 880 |
+
/*threadgroup_buffer_size=*/0);
|
| 881 |
+
}
|
| 882 |
+
|
| 883 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
|
| 884 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 885 |
+
const struct gptoss_metal_function* f32_rope_fn,
|
| 886 |
+
size_t threadgroup_size,
|
| 887 |
+
const struct gptoss_metal_buffer* activations_buffer,
|
| 888 |
+
size_t activations_offset,
|
| 889 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 890 |
+
size_t control_offset,
|
| 891 |
+
float rope_base,
|
| 892 |
+
float interpolation_scale,
|
| 893 |
+
float yarn_offset,
|
| 894 |
+
float yarn_scale,
|
| 895 |
+
float yarn_multiplier,
|
| 896 |
+
uint32_t num_tokens,
|
| 897 |
+
uint32_t num_q_heads,
|
| 898 |
+
uint32_t num_kv_heads,
|
| 899 |
+
uint32_t attn_head_dim,
|
| 900 |
+
uint32_t token_offset)
|
| 901 |
+
{
|
| 902 |
+
if (command_buffer->object == NULL || f32_rope_fn->pipeline_state_object == NULL) {
|
| 903 |
+
return gptoss_status_invalid_state;
|
| 904 |
+
}
|
| 905 |
+
|
| 906 |
+
if (threadgroup_size == 0) {
|
| 907 |
+
threadgroup_size = f32_rope_fn->max_threadgroup_threads;
|
| 908 |
+
} else if (threadgroup_size > f32_rope_fn->max_threadgroup_threads) {
|
| 909 |
+
return gptoss_status_invalid_argument;
|
| 910 |
+
}
|
| 911 |
+
|
| 912 |
+
const size_t num_simdgroups = threadgroup_size / f32_rope_fn->simdgroup_threads;
|
| 913 |
+
const uint32_t num_qk_heads = num_q_heads + num_kv_heads;
|
| 914 |
+
if (num_qk_heads % num_simdgroups != 0) {
|
| 915 |
+
return gptoss_status_invalid_argument;
|
| 916 |
+
}
|
| 917 |
+
|
| 918 |
+
const struct gptoss_rope_args args = {
|
| 919 |
+
.token_stride = (num_q_heads + 2 * num_kv_heads) * (attn_head_dim / 2),
|
| 920 |
+
.token_offset = token_offset,
|
| 921 |
+
.freq_scale = -logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),
|
| 922 |
+
.interpolation_scale = interpolation_scale,
|
| 923 |
+
.yarn_offset = yarn_offset,
|
| 924 |
+
.yarn_scale = yarn_scale,
|
| 925 |
+
.yarn_multiplier = yarn_multiplier,
|
| 926 |
+
};
|
| 927 |
+
|
| 928 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 929 |
+
command_buffer, f32_rope_fn,
|
| 930 |
+
threadgroup_size, 1, 1,
|
| 931 |
+
num_qk_heads / num_simdgroups, num_tokens, 1,
|
| 932 |
+
sizeof(args), &args,
|
| 933 |
+
2,
|
| 934 |
+
(const struct gptoss_metal_buffer *[]) {activations_buffer, control_buffer},
|
| 935 |
+
(const size_t[]) {activations_offset, control_offset},
|
| 936 |
+
/*threadgroup_buffer_size=*/0);
|
| 937 |
+
}
|
| 938 |
+
|
| 939 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
|
| 940 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 941 |
+
const struct gptoss_metal_function* expert_routing_metadata_fn,
|
| 942 |
+
const struct gptoss_metal_buffer* expert_predictions_buffer,
|
| 943 |
+
size_t expert_predictions_offset,
|
| 944 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 945 |
+
size_t expert_offsets_offset,
|
| 946 |
+
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
|
| 947 |
+
size_t intra_expert_offsets_offset,
|
| 948 |
+
uint32_t num_tokens,
|
| 949 |
+
uint32_t num_experts)
|
| 950 |
+
{
|
| 951 |
+
if (command_buffer->object == NULL || expert_routing_metadata_fn->pipeline_state_object == NULL) {
|
| 952 |
+
return gptoss_status_invalid_state;
|
| 953 |
+
}
|
| 954 |
+
|
| 955 |
+
const struct gptoss_expert_routing_metadata_args args = {
|
| 956 |
+
.tokens = num_tokens,
|
| 957 |
+
.num_experts = num_experts,
|
| 958 |
+
};
|
| 959 |
+
const uint32_t threadgroup_size = 256;
|
| 960 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 961 |
+
command_buffer, expert_routing_metadata_fn,
|
| 962 |
+
threadgroup_size, 1, 1,
|
| 963 |
+
/*num_threadgroups_x=*/1, /*num_threadgroups_y=*/1, /*num_threadgroups_z=*/1,
|
| 964 |
+
sizeof(args), &args,
|
| 965 |
+
3,
|
| 966 |
+
(const struct gptoss_metal_buffer *[]) {expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer},
|
| 967 |
+
(const size_t[]) {expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset},
|
| 968 |
+
/*threadgroup_buffer_size=*/0);
|
| 969 |
+
}
|
| 970 |
+
|
| 971 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter(
|
| 972 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 973 |
+
const struct gptoss_metal_function* f32_scatter_fn,
|
| 974 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 975 |
+
size_t input_offset,
|
| 976 |
+
const struct gptoss_metal_buffer* expert_predictions_buffer,
|
| 977 |
+
size_t expert_predictions_offset,
|
| 978 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 979 |
+
size_t expert_offsets_offset,
|
| 980 |
+
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
|
| 981 |
+
size_t intra_expert_offsets_offset,
|
| 982 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 983 |
+
size_t output_offset,
|
| 984 |
+
uint32_t num_channels,
|
| 985 |
+
uint32_t num_tokens,
|
| 986 |
+
uint32_t num_active_experts)
|
| 987 |
+
{
|
| 988 |
+
if (command_buffer->object == NULL || f32_scatter_fn->pipeline_state_object == NULL) {
|
| 989 |
+
return gptoss_status_invalid_state;
|
| 990 |
+
}
|
| 991 |
+
|
| 992 |
+
if (num_channels % 4 != 0) {
|
| 993 |
+
return gptoss_status_invalid_argument;
|
| 994 |
+
}
|
| 995 |
+
|
| 996 |
+
const size_t num_vecs = num_channels / 4;
|
| 997 |
+
const size_t tgx = math_min(num_vecs, 64);
|
| 998 |
+
const size_t tgy = 1;
|
| 999 |
+
const size_t tgz = 1;
|
| 1000 |
+
const size_t grid_x = math_ceil_div(num_vecs, tgx);
|
| 1001 |
+
const size_t grid_y = num_tokens;
|
| 1002 |
+
const size_t grid_z = 1;
|
| 1003 |
+
const size_t total_threadgroup_size = tgx * tgy * tgz;
|
| 1004 |
+
if (total_threadgroup_size > f32_scatter_fn->max_threadgroup_threads) {
|
| 1005 |
+
return gptoss_status_invalid_argument;
|
| 1006 |
+
}
|
| 1007 |
+
const struct gptoss_scatter_args args = {
|
| 1008 |
+
.tokens = num_tokens,
|
| 1009 |
+
.active_experts_per_token = num_active_experts,
|
| 1010 |
+
.token_stride = num_channels,
|
| 1011 |
+
};
|
| 1012 |
+
|
| 1013 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1014 |
+
command_buffer, f32_scatter_fn,
|
| 1015 |
+
tgx, tgy, tgz,
|
| 1016 |
+
grid_x, grid_y, grid_z,
|
| 1017 |
+
sizeof(args), &args,
|
| 1018 |
+
5,
|
| 1019 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer, output_buffer},
|
| 1020 |
+
(const size_t[]) {input_offset, expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset, output_offset},
|
| 1021 |
+
/*threadgroup_buffer_size=*/0);
|
| 1022 |
+
}
|
| 1023 |
+
|
| 1024 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
|
| 1025 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 1026 |
+
const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,
|
| 1027 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 1028 |
+
size_t input_offset,
|
| 1029 |
+
const struct gptoss_metal_buffer* expert_predictions_buffer,
|
| 1030 |
+
size_t expert_predictions_offset,
|
| 1031 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 1032 |
+
size_t expert_offsets_offset,
|
| 1033 |
+
const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
|
| 1034 |
+
size_t intra_expert_offsets_offset,
|
| 1035 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 1036 |
+
size_t output_offset,
|
| 1037 |
+
uint32_t num_channels,
|
| 1038 |
+
uint32_t num_tokens,
|
| 1039 |
+
uint32_t num_active_experts)
|
| 1040 |
+
{
|
| 1041 |
+
if (command_buffer->object == NULL || f32_gather_and_accumulate_e4_fn->pipeline_state_object == NULL) {
|
| 1042 |
+
return gptoss_status_invalid_state;
|
| 1043 |
+
}
|
| 1044 |
+
|
| 1045 |
+
if (num_channels % 4 != 0) {
|
| 1046 |
+
return gptoss_status_invalid_argument;
|
| 1047 |
+
}
|
| 1048 |
+
|
| 1049 |
+
const size_t num_vecs = num_channels / 4;
|
| 1050 |
+
const size_t tgx = math_min(num_vecs, 64);
|
| 1051 |
+
const size_t tgy = 1;
|
| 1052 |
+
const size_t tgz = 1;
|
| 1053 |
+
const size_t grid_x = math_ceil_div(num_vecs, tgx);
|
| 1054 |
+
const size_t grid_y = num_tokens;
|
| 1055 |
+
const size_t grid_z = 1;
|
| 1056 |
+
const size_t total_threadgroup_size = tgx * tgy * tgz;
|
| 1057 |
+
if (total_threadgroup_size > f32_gather_and_accumulate_e4_fn->max_threadgroup_threads) {
|
| 1058 |
+
return gptoss_status_invalid_argument;
|
| 1059 |
+
}
|
| 1060 |
+
const struct gptoss_gather_args args = {
|
| 1061 |
+
.tokens = num_tokens,
|
| 1062 |
+
.active_experts_per_token = num_active_experts,
|
| 1063 |
+
.token_stride = num_channels,
|
| 1064 |
+
};
|
| 1065 |
+
|
| 1066 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1067 |
+
command_buffer, f32_gather_and_accumulate_e4_fn,
|
| 1068 |
+
tgx, tgy, tgz,
|
| 1069 |
+
grid_x, grid_y, grid_z,
|
| 1070 |
+
sizeof(args), &args,
|
| 1071 |
+
5,
|
| 1072 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer, output_buffer},
|
| 1073 |
+
(const size_t[]) {input_offset, expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset, output_offset},
|
| 1074 |
+
/*threadgroup_buffer_size=*/0);
|
| 1075 |
+
}
|
| 1076 |
+
|
| 1077 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
|
| 1078 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 1079 |
+
const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn,
|
| 1080 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 1081 |
+
size_t expert_offsets_offset,
|
| 1082 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 1083 |
+
size_t input_offset,
|
| 1084 |
+
const struct gptoss_metal_buffer* weight_block_buffer,
|
| 1085 |
+
size_t weight_block_offset,
|
| 1086 |
+
const struct gptoss_metal_buffer* weight_scale_buffer,
|
| 1087 |
+
size_t weight_scale_offset,
|
| 1088 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 1089 |
+
size_t bias_offset,
|
| 1090 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 1091 |
+
size_t output_offset,
|
| 1092 |
+
float swiglu_limit,
|
| 1093 |
+
uint32_t expert_stride_bytes,
|
| 1094 |
+
uint32_t num_tokens,
|
| 1095 |
+
uint32_t num_experts,
|
| 1096 |
+
uint32_t num_cols,
|
| 1097 |
+
uint32_t num_rows)
|
| 1098 |
+
{
|
| 1099 |
+
if (command_buffer->object == NULL || f32_mf4w_moe_dense_matmul_swiglu_fn->pipeline_state_object == NULL) {
|
| 1100 |
+
return gptoss_status_invalid_state;
|
| 1101 |
+
}
|
| 1102 |
+
|
| 1103 |
+
if (num_cols % 32 != 0) {
|
| 1104 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
|
| 1105 |
+
num_cols);
|
| 1106 |
+
return gptoss_status_invalid_argument;
|
| 1107 |
+
}
|
| 1108 |
+
|
| 1109 |
+
const struct gptoss_moe_dense_matmul_swiglu_args args = {
|
| 1110 |
+
.n = num_rows,
|
| 1111 |
+
.k = num_cols,
|
| 1112 |
+
.weight_blocks_expert_stride_bytes = expert_stride_bytes,
|
| 1113 |
+
.weight_scales_expert_stride_bytes = expert_stride_bytes,
|
| 1114 |
+
.bias_expert_stride_bytes = expert_stride_bytes,
|
| 1115 |
+
.swiglu_min = -swiglu_limit,
|
| 1116 |
+
.swiglu_max = swiglu_limit,
|
| 1117 |
+
};
|
| 1118 |
+
const size_t threads_per_simdgroup = f32_mf4w_moe_dense_matmul_swiglu_fn->simdgroup_threads;
|
| 1119 |
+
const uint32_t m = num_tokens;
|
| 1120 |
+
const uint32_t n = args.n;
|
| 1121 |
+
const uint32_t k = args.k;
|
| 1122 |
+
const uint32_t Bm = MOE_DENSE_MATMUL_SWIGLU_Bm;
|
| 1123 |
+
const uint32_t Bn = MOE_DENSE_MATMUL_SWIGLU_Bn;
|
| 1124 |
+
const uint32_t Bk = MOE_DENSE_MATMUL_SWIGLU_Bk;
|
| 1125 |
+
const uint32_t Sg_Bm = MOE_DENSE_MATMUL_SWIGLU_Sg_Bm;
|
| 1126 |
+
const uint32_t Sg_Bn = MOE_DENSE_MATMUL_SWIGLU_Sg_Bn;
|
| 1127 |
+
if (Bm % Sg_Bm != 0) {
|
| 1128 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
|
| 1129 |
+
Bm, Sg_Bm);
|
| 1130 |
+
return gptoss_status_invalid_argument;
|
| 1131 |
+
}
|
| 1132 |
+
if (Bn % Sg_Bn != 0) {
|
| 1133 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
|
| 1134 |
+
Bn, Sg_Bn);
|
| 1135 |
+
return gptoss_status_invalid_argument;
|
| 1136 |
+
}
|
| 1137 |
+
|
| 1138 |
+
const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
|
| 1139 |
+
const size_t threadgroup_size_y = 1;
|
| 1140 |
+
const size_t threadgroup_size_z = 1;
|
| 1141 |
+
const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
|
| 1142 |
+
if (total_threadgroup_size > f32_mf4w_moe_dense_matmul_swiglu_fn->max_threadgroup_threads) {
|
| 1143 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 1144 |
+
total_threadgroup_size, f32_mf4w_moe_dense_matmul_swiglu_fn->max_threadgroup_threads);
|
| 1145 |
+
return gptoss_status_invalid_argument;
|
| 1146 |
+
}
|
| 1147 |
+
if (n % Bn != 0) {
|
| 1148 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
|
| 1149 |
+
n, Bn);
|
| 1150 |
+
return gptoss_status_invalid_argument;
|
| 1151 |
+
}
|
| 1152 |
+
if (k % Bk != 0) {
|
| 1153 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
|
| 1154 |
+
k, Bk);
|
| 1155 |
+
return gptoss_status_invalid_argument;
|
| 1156 |
+
}
|
| 1157 |
+
const size_t grid_x = n / Bn;
|
| 1158 |
+
const size_t grid_y = math_ceil_div(m, Bm);
|
| 1159 |
+
const size_t grid_z = num_experts;
|
| 1160 |
+
|
| 1161 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1162 |
+
command_buffer, f32_mf4w_moe_dense_matmul_swiglu_fn,
|
| 1163 |
+
threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
|
| 1164 |
+
grid_x, grid_y, grid_z,
|
| 1165 |
+
sizeof(args), &args,
|
| 1166 |
+
6,
|
| 1167 |
+
(const struct gptoss_metal_buffer *[]) {expert_offsets_buffer, input_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
|
| 1168 |
+
(const size_t[]) {expert_offsets_offset, input_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
|
| 1169 |
+
/*threadgroup_buffer_size=*/0);
|
| 1170 |
+
|
| 1171 |
+
}
|
| 1172 |
+
|
| 1173 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
|
| 1174 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 1175 |
+
const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn,
|
| 1176 |
+
const struct gptoss_metal_buffer* expert_offsets_buffer,
|
| 1177 |
+
size_t expert_offsets_offset,
|
| 1178 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 1179 |
+
size_t input_offset,
|
| 1180 |
+
const struct gptoss_metal_buffer* weight_block_buffer,
|
| 1181 |
+
size_t weight_block_offset,
|
| 1182 |
+
const struct gptoss_metal_buffer* weight_scale_buffer,
|
| 1183 |
+
size_t weight_scale_offset,
|
| 1184 |
+
const struct gptoss_metal_buffer* bias_buffer,
|
| 1185 |
+
size_t bias_offset,
|
| 1186 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 1187 |
+
size_t output_offset,
|
| 1188 |
+
uint32_t expert_stride_bytes,
|
| 1189 |
+
uint32_t num_tokens,
|
| 1190 |
+
uint32_t num_experts,
|
| 1191 |
+
uint32_t num_cols,
|
| 1192 |
+
uint32_t num_rows)
|
| 1193 |
+
{
|
| 1194 |
+
if (command_buffer->object == NULL || f32_mf4w_moe_dense_matmul_fn->pipeline_state_object == NULL) {
|
| 1195 |
+
return gptoss_status_invalid_state;
|
| 1196 |
+
}
|
| 1197 |
+
|
| 1198 |
+
if (num_cols % 32 != 0) {
|
| 1199 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
|
| 1200 |
+
num_cols);
|
| 1201 |
+
return gptoss_status_invalid_argument;
|
| 1202 |
+
}
|
| 1203 |
+
const struct gptoss_moe_dense_matmul_args args = {
|
| 1204 |
+
.k = num_cols,
|
| 1205 |
+
.n = num_rows,
|
| 1206 |
+
.weight_blocks_expert_stride_bytes = expert_stride_bytes,
|
| 1207 |
+
.weight_scales_expert_stride_bytes = expert_stride_bytes,
|
| 1208 |
+
.bias_expert_stride_bytes = expert_stride_bytes,
|
| 1209 |
+
};
|
| 1210 |
+
|
| 1211 |
+
const size_t threads_per_simdgroup = f32_mf4w_moe_dense_matmul_fn->simdgroup_threads;
|
| 1212 |
+
const uint32_t m = num_tokens;
|
| 1213 |
+
const uint32_t n = args.n;
|
| 1214 |
+
const uint32_t k = args.k;
|
| 1215 |
+
const uint32_t Bm = MOE_DENSE_MATMUL_Bm;
|
| 1216 |
+
const uint32_t Bn = MOE_DENSE_MATMUL_Bn;
|
| 1217 |
+
const uint32_t Bk = MOE_DENSE_MATMUL_Bk;
|
| 1218 |
+
const uint32_t Sg_Bm = MOE_DENSE_MATMUL_Sg_Bm;
|
| 1219 |
+
const uint32_t Sg_Bn = MOE_DENSE_MATMUL_Sg_Bn;
|
| 1220 |
+
if (Bm % Sg_Bm != 0) {
|
| 1221 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
|
| 1222 |
+
Bm, Sg_Bm);
|
| 1223 |
+
return gptoss_status_invalid_argument;
|
| 1224 |
+
}
|
| 1225 |
+
if (Bn % Sg_Bn != 0) {
|
| 1226 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
|
| 1227 |
+
Bn, Sg_Bn);
|
| 1228 |
+
return gptoss_status_invalid_argument;
|
| 1229 |
+
}
|
| 1230 |
+
|
| 1231 |
+
const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
|
| 1232 |
+
const size_t threadgroup_size_y = 1;
|
| 1233 |
+
const size_t threadgroup_size_z = 1;
|
| 1234 |
+
const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
|
| 1235 |
+
if (total_threadgroup_size > f32_mf4w_moe_dense_matmul_fn->max_threadgroup_threads) {
|
| 1236 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
|
| 1237 |
+
total_threadgroup_size, f32_mf4w_moe_dense_matmul_fn->max_threadgroup_threads);
|
| 1238 |
+
return gptoss_status_invalid_argument;
|
| 1239 |
+
}
|
| 1240 |
+
if (n % Bn != 0) {
|
| 1241 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
|
| 1242 |
+
n, Bn);
|
| 1243 |
+
return gptoss_status_invalid_argument;
|
| 1244 |
+
}
|
| 1245 |
+
if (k % Bk != 0) {
|
| 1246 |
+
GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
|
| 1247 |
+
k, Bk);
|
| 1248 |
+
return gptoss_status_invalid_argument;
|
| 1249 |
+
}
|
| 1250 |
+
|
| 1251 |
+
const size_t grid_y = math_ceil_div(m, Bm);
|
| 1252 |
+
const size_t grid_x = n / Bn;
|
| 1253 |
+
const size_t grid_z = num_experts;
|
| 1254 |
+
|
| 1255 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1256 |
+
command_buffer, f32_mf4w_moe_dense_matmul_fn,
|
| 1257 |
+
threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
|
| 1258 |
+
grid_x, grid_y, grid_z,
|
| 1259 |
+
sizeof(args), &args,
|
| 1260 |
+
6,
|
| 1261 |
+
(const struct gptoss_metal_buffer *[]) {expert_offsets_buffer, input_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
|
| 1262 |
+
(const size_t[]) {expert_offsets_offset, input_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
|
| 1263 |
+
/*threadgroup_buffer_size=*/0);
|
| 1264 |
+
}
|
| 1265 |
+
|
| 1266 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
|
| 1267 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 1268 |
+
const struct gptoss_metal_function* f32_accumulate_fn,
|
| 1269 |
+
size_t threadgroup_size,
|
| 1270 |
+
size_t max_threadgroups,
|
| 1271 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 1272 |
+
size_t input_offset,
|
| 1273 |
+
const struct gptoss_metal_buffer* expert_buffer,
|
| 1274 |
+
size_t expert_offset,
|
| 1275 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 1276 |
+
size_t output_offset,
|
| 1277 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 1278 |
+
size_t control_offset,
|
| 1279 |
+
uint32_t num_channels,
|
| 1280 |
+
uint32_t num_tokens,
|
| 1281 |
+
uint32_t num_experts)
|
| 1282 |
+
{
|
| 1283 |
+
if (command_buffer->object == NULL || f32_accumulate_fn->pipeline_state_object == NULL) {
|
| 1284 |
+
return gptoss_status_invalid_state;
|
| 1285 |
+
}
|
| 1286 |
+
|
| 1287 |
+
if (num_channels% 4 != 0) {
|
| 1288 |
+
return gptoss_status_invalid_argument;
|
| 1289 |
+
}
|
| 1290 |
+
|
| 1291 |
+
if (threadgroup_size == 0) {
|
| 1292 |
+
threadgroup_size = f32_accumulate_fn->max_threadgroup_threads;
|
| 1293 |
+
} else if (threadgroup_size > f32_accumulate_fn->max_threadgroup_threads) {
|
| 1294 |
+
return gptoss_status_invalid_argument;
|
| 1295 |
+
}
|
| 1296 |
+
|
| 1297 |
+
const size_t num_vecs = num_channels / 4;
|
| 1298 |
+
const size_t num_vecs_per_expert = num_vecs * num_tokens;
|
| 1299 |
+
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
| 1300 |
+
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
| 1301 |
+
const struct gptoss_accumulate_args args = {
|
| 1302 |
+
.num_vecs_per_expert = num_vecs_per_expert,
|
| 1303 |
+
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
| 1304 |
+
.num_vecs = num_vecs,
|
| 1305 |
+
};
|
| 1306 |
+
|
| 1307 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1308 |
+
command_buffer, f32_accumulate_fn,
|
| 1309 |
+
threadgroup_size, 1, 1,
|
| 1310 |
+
num_threadgroups, num_tokens, 1,
|
| 1311 |
+
sizeof(args), &args,
|
| 1312 |
+
4,
|
| 1313 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer, control_buffer},
|
| 1314 |
+
(const size_t[]) {input_offset, expert_offset, output_offset, control_offset},
|
| 1315 |
+
/*threadgroup_buffer_size=*/0);
|
| 1316 |
+
}
|
| 1317 |
+
|
| 1318 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
|
| 1319 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 1320 |
+
const struct gptoss_metal_function* f32_topk_fn,
|
| 1321 |
+
const struct gptoss_metal_buffer* input_buffer,
|
| 1322 |
+
size_t input_offset,
|
| 1323 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 1324 |
+
size_t output_offset,
|
| 1325 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 1326 |
+
size_t control_offset,
|
| 1327 |
+
uint32_t num_tokens,
|
| 1328 |
+
uint32_t num_experts,
|
| 1329 |
+
uint32_t num_active_experts)
|
| 1330 |
+
{
|
| 1331 |
+
if (command_buffer->object == NULL || f32_topk_fn->pipeline_state_object == NULL) {
|
| 1332 |
+
return gptoss_status_invalid_state;
|
| 1333 |
+
}
|
| 1334 |
+
|
| 1335 |
+
if (num_experts != 32 && num_experts != 128) {
|
| 1336 |
+
return gptoss_status_invalid_argument;
|
| 1337 |
+
}
|
| 1338 |
+
|
| 1339 |
+
if (num_active_experts != 4) {
|
| 1340 |
+
return gptoss_status_invalid_argument;
|
| 1341 |
+
}
|
| 1342 |
+
|
| 1343 |
+
const struct gptoss_topk_args args = { 0 };
|
| 1344 |
+
|
| 1345 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1346 |
+
command_buffer, f32_topk_fn,
|
| 1347 |
+
/*threadgroup_size=*/32, 1, 1,
|
| 1348 |
+
num_tokens, 1, 1,
|
| 1349 |
+
sizeof(args), &args,
|
| 1350 |
+
3,
|
| 1351 |
+
(const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer, control_buffer},
|
| 1352 |
+
(const size_t[]) {input_offset, output_offset, control_offset},
|
| 1353 |
+
/*threadgroup_buffer_size=*/0);
|
| 1354 |
+
}
|
| 1355 |
+
|
| 1356 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
|
| 1357 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 1358 |
+
const struct gptoss_metal_function* f32_sdpa_fn,
|
| 1359 |
+
const struct gptoss_metal_buffer* q_buffer,
|
| 1360 |
+
size_t q_offset,
|
| 1361 |
+
const struct gptoss_metal_buffer* kv_buffer,
|
| 1362 |
+
size_t kv_offset,
|
| 1363 |
+
const struct gptoss_metal_buffer* s_buffer,
|
| 1364 |
+
size_t s_offset,
|
| 1365 |
+
const struct gptoss_metal_buffer* output_buffer,
|
| 1366 |
+
size_t output_offset,
|
| 1367 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 1368 |
+
size_t control_offset,
|
| 1369 |
+
uint32_t window,
|
| 1370 |
+
uint32_t kv_stride,
|
| 1371 |
+
uint32_t num_q_tokens,
|
| 1372 |
+
uint32_t num_kv_tokens,
|
| 1373 |
+
uint32_t num_q_heads,
|
| 1374 |
+
uint32_t num_kv_heads,
|
| 1375 |
+
uint32_t head_dim)
|
| 1376 |
+
{
|
| 1377 |
+
if (command_buffer->object == NULL || f32_sdpa_fn->pipeline_state_object == NULL) {
|
| 1378 |
+
return gptoss_status_invalid_state;
|
| 1379 |
+
}
|
| 1380 |
+
|
| 1381 |
+
if (num_q_heads != num_kv_heads * 8) {
|
| 1382 |
+
GPTOSS_LOG_ERROR("number of Q heads (%" PRIu32 ") must be 8 times the number of KV heads (%" PRIu32 ")",
|
| 1383 |
+
num_q_heads, num_kv_heads);
|
| 1384 |
+
return gptoss_status_invalid_argument;
|
| 1385 |
+
}
|
| 1386 |
+
|
| 1387 |
+
if (head_dim != 64) {
|
| 1388 |
+
GPTOSS_LOG_ERROR("attention head dimension (%" PRIu32 ") must be 64", head_dim);
|
| 1389 |
+
return gptoss_status_invalid_argument;
|
| 1390 |
+
}
|
| 1391 |
+
|
| 1392 |
+
const size_t max_context_tokens = math_min(num_q_tokens + num_kv_tokens + 1, window);
|
| 1393 |
+
const size_t threadgroup_size = math_min(f32_sdpa_fn->max_threadgroup_threads,
|
| 1394 |
+
max_context_tokens * f32_sdpa_fn->simdgroup_threads);
|
| 1395 |
+
const size_t half_threadgroup_size = math_round_down_po2(threadgroup_size / 2, f32_sdpa_fn->simdgroup_threads);
|
| 1396 |
+
|
| 1397 |
+
const struct gptoss_sdpa_args args = {
|
| 1398 |
+
.qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads),
|
| 1399 |
+
.num_kv_tokens = num_kv_tokens,
|
| 1400 |
+
.kv_stride = kv_stride,
|
| 1401 |
+
.window = window,
|
| 1402 |
+
};
|
| 1403 |
+
|
| 1404 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1405 |
+
command_buffer, f32_sdpa_fn,
|
| 1406 |
+
threadgroup_size, 1, 1,
|
| 1407 |
+
num_q_tokens, num_kv_heads, 1,
|
| 1408 |
+
sizeof(args), &args,
|
| 1409 |
+
5,
|
| 1410 |
+
(const struct gptoss_metal_buffer *[]) {q_buffer, kv_buffer, s_buffer, output_buffer, control_buffer},
|
| 1411 |
+
(const size_t[]) {q_offset, kv_offset, s_offset, output_offset, control_offset},
|
| 1412 |
+
/*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float));
|
| 1413 |
+
}
|
| 1414 |
+
|
| 1415 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
|
| 1416 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 1417 |
+
const struct gptoss_metal_function* f32_softmax_fn,
|
| 1418 |
+
size_t threadgroup_size,
|
| 1419 |
+
size_t max_threadgroups,
|
| 1420 |
+
const struct gptoss_metal_buffer* score_buffer,
|
| 1421 |
+
size_t score_offset,
|
| 1422 |
+
const struct gptoss_metal_buffer* argmax_buffer,
|
| 1423 |
+
size_t argmax_offset,
|
| 1424 |
+
const struct gptoss_metal_buffer* prob_buffer,
|
| 1425 |
+
size_t prob_offset,
|
| 1426 |
+
const struct gptoss_metal_buffer* sum_buffer,
|
| 1427 |
+
size_t sum_offset,
|
| 1428 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 1429 |
+
size_t control_offset,
|
| 1430 |
+
uint32_t num_channels,
|
| 1431 |
+
uint32_t num_tokens,
|
| 1432 |
+
float temperature,
|
| 1433 |
+
uint32_t* num_threadgroups_out,
|
| 1434 |
+
uint32_t* num_channels_per_threadgroup_out)
|
| 1435 |
+
{
|
| 1436 |
+
*num_threadgroups_out = 0;
|
| 1437 |
+
*num_channels_per_threadgroup_out = 0;
|
| 1438 |
+
if (command_buffer->object == NULL || f32_softmax_fn->pipeline_state_object == NULL) {
|
| 1439 |
+
return gptoss_status_invalid_state;
|
| 1440 |
+
}
|
| 1441 |
+
|
| 1442 |
+
const size_t num_vecs = num_channels;
|
| 1443 |
+
const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
|
| 1444 |
+
const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
|
| 1445 |
+
const struct gptoss_softmax_args args = {
|
| 1446 |
+
.num_vecs = num_vecs,
|
| 1447 |
+
.num_vecs_per_threadgroup = num_vecs_per_threadgroup,
|
| 1448 |
+
.max_threadgroups = max_threadgroups,
|
| 1449 |
+
.temperature = temperature,
|
| 1450 |
+
};
|
| 1451 |
+
|
| 1452 |
+
*num_threadgroups_out = num_threadgroups;
|
| 1453 |
+
*num_channels_per_threadgroup_out = num_vecs_per_threadgroup;
|
| 1454 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1455 |
+
command_buffer, f32_softmax_fn,
|
| 1456 |
+
threadgroup_size, 1, 1,
|
| 1457 |
+
num_threadgroups, num_tokens, 1,
|
| 1458 |
+
sizeof(args), &args,
|
| 1459 |
+
5,
|
| 1460 |
+
(const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer, control_buffer},
|
| 1461 |
+
(const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset, control_offset},
|
| 1462 |
+
/*threadgroup_buffer_size=*/0);
|
| 1463 |
+
}
|
| 1464 |
+
|
| 1465 |
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
|
| 1466 |
+
const struct gptoss_metal_command_buffer* command_buffer,
|
| 1467 |
+
const struct gptoss_metal_function* f32_sample_fn,
|
| 1468 |
+
size_t min_threadgroup_size,
|
| 1469 |
+
const struct gptoss_metal_buffer* prob_buffer,
|
| 1470 |
+
size_t prob_offset,
|
| 1471 |
+
const struct gptoss_metal_buffer* sum_buffer,
|
| 1472 |
+
size_t sum_offset,
|
| 1473 |
+
const struct gptoss_metal_buffer* token_buffer,
|
| 1474 |
+
size_t token_offset,
|
| 1475 |
+
const struct gptoss_metal_buffer* control_buffer,
|
| 1476 |
+
size_t control_offset,
|
| 1477 |
+
uint64_t rng_seed,
|
| 1478 |
+
uint32_t rng_offset,
|
| 1479 |
+
uint32_t num_blocks,
|
| 1480 |
+
uint32_t num_channels,
|
| 1481 |
+
uint32_t num_channels_per_block)
|
| 1482 |
+
{
|
| 1483 |
+
if (command_buffer->object == NULL || f32_sample_fn->pipeline_state_object == NULL) {
|
| 1484 |
+
return gptoss_status_invalid_state;
|
| 1485 |
+
}
|
| 1486 |
+
|
| 1487 |
+
if (min_threadgroup_size > f32_sample_fn->max_threadgroup_threads) {
|
| 1488 |
+
return gptoss_status_invalid_argument;
|
| 1489 |
+
}
|
| 1490 |
+
|
| 1491 |
+
if (min_threadgroup_size % f32_sample_fn->simdgroup_threads != 0) {
|
| 1492 |
+
return gptoss_status_invalid_argument;
|
| 1493 |
+
}
|
| 1494 |
+
|
| 1495 |
+
if (num_blocks > f32_sample_fn->max_threadgroup_threads) {
|
| 1496 |
+
return gptoss_status_invalid_argument;
|
| 1497 |
+
}
|
| 1498 |
+
|
| 1499 |
+
const struct gptoss_sample_args args = {
|
| 1500 |
+
.rng_seed = rng_seed,
|
| 1501 |
+
.rng_offset = rng_offset,
|
| 1502 |
+
.num_blocks = num_blocks,
|
| 1503 |
+
.num_dims = num_channels,
|
| 1504 |
+
.num_dims_per_block = num_channels_per_block,
|
| 1505 |
+
};
|
| 1506 |
+
|
| 1507 |
+
const size_t threadgroup_size = math_max(min_threadgroup_size,
|
| 1508 |
+
math_round_up_po2(num_blocks, f32_sample_fn->simdgroup_threads));
|
| 1509 |
+
return gptoss_metal_command_buffer_encode_launch_kernel(
|
| 1510 |
+
command_buffer, f32_sample_fn,
|
| 1511 |
+
threadgroup_size, 1, 1,
|
| 1512 |
+
1, 1, 1,
|
| 1513 |
+
sizeof(args), &args,
|
| 1514 |
+
4,
|
| 1515 |
+
(const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, token_buffer, control_buffer},
|
| 1516 |
+
(const size_t[]) {prob_offset, sum_offset, token_offset, control_offset},
|
| 1517 |
+
/*threadgroup_buffer_size=*/0);
|
| 1518 |
+
}
|
gptoss_kernels/source/metal.m
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

#include <IOKit/IOKitLib.h>
#include <dispatch/dispatch.h>
#include <mach-o/getsect.h>

#include <gpt-oss/types.h>

#include <internal/log.h>
#include <internal/metal.h>
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
// Queries the IORegistry for the "gpu-core-count" property of the
// IOAccelerator node whose registry ID matches the given Metal device.
// Returns 0 when the device is nil or the property cannot be found.
static size_t gptoss_metal_device_get_core_count(id<MTLDevice> device) {
    if (!device) {
        return 0;
    }

    const uint64_t target_registry_id = [device registryID];

    io_iterator_t it = IO_OBJECT_NULL;
    const kern_return_t kr = IOServiceGetMatchingServices(
        kIOMainPortDefault,
        IOServiceMatching("IOAccelerator"),
        &it
    );
    if (kr != KERN_SUCCESS) {
        GPTOSS_LOG_ERROR("failed to find IOAccelerator objects: error %d", kr);
        return 0;
    }

    size_t result = 0;
    for (io_object_t obj = IOIteratorNext(it); obj != IO_OBJECT_NULL; obj = IOIteratorNext(it)) {
        uint64_t registry_id = 0;
        if (IORegistryEntryGetRegistryEntryID(obj, &registry_id) == KERN_SUCCESS &&
            registry_id == target_registry_id)
        {
            // Matching accelerator node: read "gpu-core-count" from it.
            const CFTypeRef value = IORegistryEntryCreateCFProperty(
                obj, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
            if (value != NULL) {
                if (CFGetTypeID(value) == CFNumberGetTypeID()) {
                    int32_t n = -1;
                    if (CFNumberGetValue((CFNumberRef) value, kCFNumberSInt32Type, &n) && n > 0) {
                        result = (size_t) n;
                    }
                }
                CFRelease(value);
            }
            IOObjectRelease(obj);
            break;
        }
        // Non-matching node: release and keep scanning.
        IOObjectRelease(obj);
    }

    IOObjectRelease(it);
    return result;
}
|
| 58 |
+
|
| 59 |
+
// Creates the system-default Metal device and captures its capability limits
// (core count, max buffer size, threadgroup memory, threadgroup dimensions)
// into the output struct.
enum gptoss_status gptoss_metal_device_create_system_default(
    struct gptoss_metal_device* device_out)
{
    id<MTLDevice> device_obj = MTLCreateSystemDefaultDevice();
    if (device_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal device");
        return gptoss_status_unsupported_system;
    }

    device_out->object = (void*) device_obj;
    device_out->num_cores = gptoss_metal_device_get_core_count(device_obj);
    device_out->max_buffer_size = (size_t) [device_obj maxBufferLength];
    device_out->max_threadgroup_memory = (size_t) [device_obj maxThreadgroupMemoryLength];
    const MTLSize max_threadgroup_threads = [device_obj maxThreadsPerThreadgroup];
    device_out->max_threadgroup_threads_x = (size_t) max_threadgroup_threads.width;
    device_out->max_threadgroup_threads_y = (size_t) max_threadgroup_threads.height;
    device_out->max_threadgroup_threads_z = (size_t) max_threadgroup_threads.depth;
    return gptoss_status_success;
}
|
| 78 |
+
|
| 79 |
+
// Releases the wrapped MTLDevice (if any) and zeroes the wrapper so repeated
// release is a safe no-op.
enum gptoss_status gptoss_metal_device_release(
    struct gptoss_metal_device* device)
{
    if (device->object != NULL) {
        id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
        [device_obj release];
    }
    memset(device, 0, sizeof(struct gptoss_metal_device));
    return gptoss_status_success;
}
|
| 89 |
+
|
| 90 |
+
extern const struct mach_header_64 __dso_handle;
|
| 91 |
+
|
| 92 |
+
// Loads the Metal shader library. Preferred source is the "__METAL,__shaders"
// section embedded in this binary image (via __dso_handle); if that section
// is absent, falls back to the app bundle's default library.
enum gptoss_status gptoss_metal_library_create_default(
    const struct gptoss_metal_device* device,
    struct gptoss_metal_library* library_out)
{
    enum gptoss_status status = gptoss_status_success;
    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
    id<MTLLibrary> library_obj = nil;
    NSAutoreleasePool* autorelease_pool = nil;
    dispatch_data_t library_blob = NULL;

    unsigned long library_size = 0;
    uint8_t* library_data = getsectiondata(&__dso_handle, "__METAL", "__shaders", &library_size);
    if (library_data != NULL) {
        // Wrap the embedded section in a dispatch_data blob for Metal.
        library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);

        autorelease_pool = [[NSAutoreleasePool alloc] init];
        NSError* error_obj = nil;
        library_obj = [device_obj newLibraryWithData:library_blob error:&error_obj];
        if (library_obj == nil) {
            GPTOSS_LOG_ERROR("failed to create Metal library: %s", [[error_obj localizedDescription] UTF8String]);
            status = gptoss_status_unsupported_system;
            goto cleanup;
        }
    } else {
        // Fall-back to loading from the bundle
        library_obj = [device_obj newDefaultLibrary];
        if (library_obj == nil) {
            GPTOSS_LOG_ERROR("failed to create Metal default library");
            status = gptoss_status_unsupported_system;
            goto cleanup;
        }
    }

    *library_out = (struct gptoss_metal_library) {
        .object = (void*) library_obj,
    };

cleanup:
    if (library_blob != NULL) {
        dispatch_release(library_blob);
    }
    if (autorelease_pool != nil) {
        [autorelease_pool drain];
    }
    return status;
}
|
| 138 |
+
|
| 139 |
+
// Releases the wrapped MTLLibrary (if any) and zeroes the wrapper so repeated
// release is a safe no-op.
enum gptoss_status gptoss_metal_library_release(
    struct gptoss_metal_library* library)
{
    if (library->object != NULL) {
        id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;
        [library_obj release];
    }
    memset(library, 0, sizeof(struct gptoss_metal_library));
    return gptoss_status_success;
}
|
| 149 |
+
|
| 150 |
+
// Looks up a kernel function by name in the library and builds its compute
// pipeline state. The pipeline is compiled via the asynchronous Metal API and
// awaited on a semaphore so the call is synchronous for the caller. On
// success, records the pipeline's dispatch limits (max threadgroup threads,
// simdgroup width, static threadgroup memory) in function_out.
enum gptoss_status gptoss_metal_function_create(
    const struct gptoss_metal_library* library,
    const char* name,
    struct gptoss_metal_function* function_out)
{
    // __block: written from the completion handler below.
    __block NSString* error_string_obj = nil;
    id<MTLFunction> function_obj = nil;
    MTLComputePipelineDescriptor* pipeline_descriptor_obj = nil;
    __block id<MTLComputePipelineState> pipeline_state_obj = nil;
    dispatch_semaphore_t pipeline_build_semaphore = NULL;
    enum gptoss_status status = gptoss_status_success;

    NSAutoreleasePool* autorelease_pool = [[NSAutoreleasePool alloc] init];
    id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;
    NSString* name_obj = [NSString stringWithUTF8String:name];
    function_obj = [library_obj newFunctionWithName:name_obj];
    if (function_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal function %s", name);
        status = gptoss_status_unsupported_system;
        goto cleanup;
    }
    id<MTLDevice> device_obj = [library_obj device];
    pipeline_descriptor_obj = [[MTLComputePipelineDescriptor alloc] init];
    [pipeline_descriptor_obj setComputeFunction:function_obj];
    [pipeline_descriptor_obj setThreadGroupSizeIsMultipleOfThreadExecutionWidth:YES];

    // Kick off the async pipeline build and block until it completes.
    pipeline_build_semaphore = dispatch_semaphore_create(/*value=*/0);
    [device_obj newComputePipelineStateWithDescriptor:pipeline_descriptor_obj
        options:MTLPipelineOptionNone
        completionHandler:^(id<MTLComputePipelineState> _Nullable new_state,
                            MTLComputePipelineReflection* _Nullable reflection,
                            NSError* _Nullable error_obj) {
            if (new_state != nil) {
                pipeline_state_obj = [new_state retain];
            }
            if (error_obj != nil) {
                error_string_obj = [[error_obj localizedDescription] copy];
            }
            dispatch_semaphore_signal(pipeline_build_semaphore);
        }];
    dispatch_semaphore_wait(pipeline_build_semaphore, DISPATCH_TIME_FOREVER);

    if (pipeline_state_obj == nil) {
        const char* error_string = "unknown error";
        if (error_string_obj != nil) {
            error_string = [error_string_obj UTF8String];
        }
        GPTOSS_LOG_ERROR("failed to create Metal compute pipeline state for function %s: %s",
            name, error_string);
        status = gptoss_status_unsupported_system;
        goto cleanup;
    }

    // Commit: transfer ownership of the function and pipeline to the output,
    // then null the locals so cleanup does not release them.
    function_out->function_object = function_obj;
    function_out->pipeline_state_object = pipeline_state_obj;
    function_out->max_threadgroup_threads = (size_t) [pipeline_state_obj maxTotalThreadsPerThreadgroup];
    function_out->simdgroup_threads = (size_t) [pipeline_state_obj threadExecutionWidth];
    function_out->static_threadgroup_memory = (size_t) [pipeline_state_obj staticThreadgroupMemoryLength];

    function_obj = nil;
    pipeline_state_obj = nil;

cleanup:
    if (function_obj != nil) {
        [function_obj release];
    }
    if (pipeline_descriptor_obj != nil) {
        [pipeline_descriptor_obj release];
    }
    if (error_string_obj != nil) {
        [error_string_obj release];
    }
    if (pipeline_build_semaphore != NULL) {
        dispatch_release(pipeline_build_semaphore);
    }
    if (autorelease_pool != nil) {
        [autorelease_pool drain];
    }
    return status;
}
|
| 231 |
+
|
| 232 |
+
// Releases the pipeline state and function objects (if present) and zeroes
// the wrapper so repeated release is a safe no-op.
enum gptoss_status gptoss_metal_function_release(
    struct gptoss_metal_function* function)
{
    if (function->pipeline_state_object != NULL) {
        id<MTLComputePipelineState> pipeline_state_obj = (id<MTLComputePipelineState>) function->pipeline_state_object;
        [pipeline_state_obj release];
    }
    if (function->function_object != NULL) {
        id<MTLFunction> function_obj = (id<MTLFunction>) function->function_object;
        [function_obj release];
    }
    memset(function, 0, sizeof(struct gptoss_metal_function));
    return gptoss_status_success;
}
|
| 246 |
+
|
| 247 |
+
// Allocates a shared-storage-mode Metal buffer of the given size, optionally
// initialized from `data`, and exposes its CPU-visible pointer via `ptr`.
enum gptoss_status gptoss_metal_buffer_create(
    const struct gptoss_metal_device* device,
    size_t size,
    const void* data,
    struct gptoss_metal_buffer* buffer_out)
{
    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
    id<MTLBuffer> buffer_obj = nil;
    if (data != NULL) {
        // Copy-initialize from the caller's data.
        buffer_obj = [device_obj newBufferWithBytes:data length:size options:MTLResourceStorageModeShared];
    } else {
        buffer_obj = [device_obj newBufferWithLength:size options:MTLResourceStorageModeShared];
    }
    if (buffer_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal buffer of size %zu", size);
        return gptoss_status_unsupported_system;
    }
    buffer_out->object = (void*) buffer_obj;
    buffer_out->size = size;
    buffer_out->ptr = [buffer_obj contents];
    return gptoss_status_success;
}
|
| 269 |
+
|
| 270 |
+
// Wraps caller-owned memory in a Metal buffer without copying.
//
// The caller retains ownership of `data` (deallocator is nil) and must keep
// it alive for the lifetime of the buffer.
// NOTE(review): newBufferWithBytesNoCopy requires `data` to be page-aligned
// and `size` a multiple of the page size — presumably callers guarantee
// this; confirm at call sites.
enum gptoss_status gptoss_metal_buffer_wrap(
    const struct gptoss_metal_device* device,
    size_t size,
    const void* data,
    struct gptoss_metal_buffer* buffer_out)
{
    id<MTLDevice> mtl_device = (id<MTLDevice>) device->object;
    id<MTLBuffer> mtl_buffer = [mtl_device newBufferWithBytesNoCopy:(void*) data
                                                             length:size
                                                            options:MTLResourceStorageModeShared
                                                        deallocator:nil];
    if (mtl_buffer == nil) {
        GPTOSS_LOG_ERROR("failed to wrap Metal buffer of size %zu", size);
        return gptoss_status_unsupported_system;
    }

    buffer_out->object = (void*) mtl_buffer;
    buffer_out->size = size;
    buffer_out->ptr = (void*) data;
    return gptoss_status_success;
}
|
| 287 |
+
|
| 288 |
+
// Releases the Metal buffer object (if any) and zeroes the wrapper.
// Safe on an already-released or zero-initialized structure; always succeeds.
enum gptoss_status gptoss_metal_buffer_release(
    struct gptoss_metal_buffer* buffer)
{
    id<MTLBuffer> mtl_buffer = (id<MTLBuffer>) buffer->object;
    if (mtl_buffer != nil) {
        [mtl_buffer release];
    }
    memset(buffer, 0, sizeof(struct gptoss_metal_buffer));
    return gptoss_status_success;
}
|
| 298 |
+
|
| 299 |
+
// Creates a new Metal command queue on `device` and stores it in
// `command_queue_out`. Returns gptoss_status_unsupported_system on failure.
enum gptoss_status gptoss_metal_command_queue_create(
    const struct gptoss_metal_device* device,
    struct gptoss_metal_command_queue* command_queue_out)
{
    id<MTLDevice> mtl_device = (id<MTLDevice>) device->object;
    id<MTLCommandQueue> mtl_queue = [mtl_device newCommandQueue];
    if (mtl_queue == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal command queue");
        return gptoss_status_unsupported_system;
    }

    command_queue_out->object = (void*) mtl_queue;
    return gptoss_status_success;
}
|
| 312 |
+
|
| 313 |
+
// Releases the Metal command queue (if any) and zeroes the wrapper.
// Safe on an already-released structure; always succeeds.
enum gptoss_status gptoss_metal_command_queue_release(
    struct gptoss_metal_command_queue* command_queue)
{
    id<MTLCommandQueue> mtl_queue = (id<MTLCommandQueue>) command_queue->object;
    if (mtl_queue != nil) {
        [mtl_queue release];
    }
    memset(command_queue, 0, sizeof(struct gptoss_metal_command_queue));
    return gptoss_status_success;
}
|
| 323 |
+
|
| 324 |
+
// Creates a command buffer on `command_queue` and stores it in
// `command_buffer_out`.
//
// [MTLCommandQueue commandBuffer] returns an autoreleased object, so it is
// explicitly retained here; the matching release happens in
// gptoss_metal_command_buffer_release.
enum gptoss_status gptoss_metal_command_buffer_create(
    const struct gptoss_metal_command_queue* command_queue,
    struct gptoss_metal_command_buffer* command_buffer_out)
{
    id<MTLCommandQueue> mtl_queue = (id<MTLCommandQueue>) command_queue->object;
    id<MTLCommandBuffer> mtl_command_buffer = [mtl_queue commandBuffer];
    if (mtl_command_buffer == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal command buffer");
        return gptoss_status_unsupported_system;
    }

    // Take ownership of the autoreleased command buffer.
    [mtl_command_buffer retain];
    command_buffer_out->object = (void*) mtl_command_buffer;
    return gptoss_status_success;
}
|
| 338 |
+
|
| 339 |
+
// Encodes a blit command that fills `size` bytes of `buffer`, starting at
// `offset`, with `fill_value`.
//
// Returns:
//   gptoss_status_invalid_state      if the command buffer is not initialized.
//   gptoss_status_invalid_argument   if the target buffer is not initialized.
//   gptoss_status_unsupported_system if a blit encoder cannot be created.
enum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_buffer* buffer,
    size_t offset,
    size_t size,
    uint8_t fill_value)
{
    if (command_buffer->object == NULL) {
        return gptoss_status_invalid_state;
    }
    if (buffer->object == NULL) {
        return gptoss_status_invalid_argument;
    }

    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
    id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffer->object;

    id<MTLBlitCommandEncoder> command_encoder_obj = [command_buffer_obj blitCommandEncoder];
    // Fix: blitCommandEncoder returns nil on an already-committed command
    // buffer or resource exhaustion; messaging nil is a no-op in Objective-C,
    // so previously the fill was silently skipped while success was returned.
    if (command_encoder_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal blit command encoder");
        return gptoss_status_unsupported_system;
    }

    const NSRange range = NSMakeRange((NSUInteger) offset, (NSUInteger) size);
    [command_encoder_obj fillBuffer:buffer_obj range:range value:fill_value];
    [command_encoder_obj endEncoding];

    return gptoss_status_success;
}
|
| 364 |
+
|
| 365 |
+
// Encodes a blit command copying `size` bytes from
// `input_buffer[input_offset]` to `output_buffer[output_offset]`.
//
// Returns:
//   gptoss_status_invalid_state      if the command buffer is not initialized.
//   gptoss_status_invalid_argument   if either buffer is not initialized.
//   gptoss_status_unsupported_system if a blit encoder cannot be created.
enum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_buffer* input_buffer,
    size_t input_offset,
    const struct gptoss_metal_buffer* output_buffer,
    size_t output_offset,
    size_t size)
{
    if (command_buffer->object == NULL) {
        return gptoss_status_invalid_state;
    }
    if (input_buffer->object == NULL) {
        return gptoss_status_invalid_argument;
    }
    if (output_buffer->object == NULL) {
        return gptoss_status_invalid_argument;
    }

    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
    id<MTLBuffer> input_buffer_obj = (id<MTLBuffer>) input_buffer->object;
    id<MTLBuffer> output_buffer_obj = (id<MTLBuffer>) output_buffer->object;

    id<MTLBlitCommandEncoder> command_encoder_obj = [command_buffer_obj blitCommandEncoder];
    // Fix: a nil encoder previously made every message below a silent no-op
    // (Objective-C nil messaging), returning success without encoding the copy.
    if (command_encoder_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal blit command encoder");
        return gptoss_status_unsupported_system;
    }

    [command_encoder_obj copyFromBuffer:input_buffer_obj sourceOffset:(NSUInteger) input_offset
        toBuffer:output_buffer_obj destinationOffset:(NSUInteger) output_offset
        size:(NSUInteger) size];
    [command_encoder_obj endEncoding];

    return gptoss_status_success;
}
|
| 396 |
+
|
| 397 |
+
// Encodes a compute kernel dispatch into the command buffer.
//
// Argument binding convention: `params` is bound at buffer index 0 via
// setBytes; the `num_device_buffers` device buffers occupy indices 1..N
// (with optional per-buffer offsets from `device_buffer_offsets`). If
// `threadgroup_buffer_size` is non-zero, that many bytes of threadgroup
// memory are bound at threadgroup index 0.
//
// Returns:
//   gptoss_status_invalid_state      if the command buffer or pipeline state
//                                    is not initialized.
//   gptoss_status_unsupported_system if a compute encoder cannot be created.
enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_function* function,
    size_t threadgroup_size_x,
    size_t threadgroup_size_y,
    size_t threadgroup_size_z,
    size_t num_threadgroups_x,
    size_t num_threadgroups_y,
    size_t num_threadgroups_z,
    size_t params_size,
    const void* params,
    size_t num_device_buffers,
    const struct gptoss_metal_buffer** device_buffers,
    const size_t* device_buffer_offsets,
    size_t threadgroup_buffer_size)
{
    if (command_buffer->object == NULL || function->pipeline_state_object == NULL) {
        return gptoss_status_invalid_state;
    }

    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
    id<MTLComputePipelineState> pipeline_state_obj = (id<MTLComputePipelineState>) function->pipeline_state_object;

    id<MTLComputeCommandEncoder> command_encoder_obj = [command_buffer_obj computeCommandEncoder];
    // Fix: computeCommandEncoder returns nil on an already-committed command
    // buffer; messaging nil is a no-op, so previously the whole dispatch was
    // silently dropped while success was returned.
    if (command_encoder_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal compute command encoder");
        return gptoss_status_unsupported_system;
    }

    // Set kernel arguments
    [command_encoder_obj setComputePipelineState:pipeline_state_obj];
    [command_encoder_obj setBytes:params length:params_size atIndex:0];
    for (size_t i = 0; i < num_device_buffers; ++i) {
        id<MTLBuffer> buffer_obj = (id<MTLBuffer>) device_buffers[i]->object;
        const NSUInteger offset = device_buffer_offsets == NULL ? 0 : (NSUInteger) device_buffer_offsets[i];
        [command_encoder_obj setBuffer:buffer_obj offset:offset atIndex:i + 1];
    }
    if (threadgroup_buffer_size != 0) {
        [command_encoder_obj setThreadgroupMemoryLength:threadgroup_buffer_size atIndex:0];
    }

    // Dispatch kernel
    const MTLSize threadgroup_size = MTLSizeMake(threadgroup_size_x, threadgroup_size_y, threadgroup_size_z);
    const MTLSize num_threadgroups = MTLSizeMake(num_threadgroups_x, num_threadgroups_y, num_threadgroups_z);
    [command_encoder_obj dispatchThreadgroups:num_threadgroups threadsPerThreadgroup:threadgroup_size];
    [command_encoder_obj endEncoding];

    return gptoss_status_success;
}
|
| 442 |
+
|
| 443 |
+
// Submits the command buffer to the GPU for execution.
// Returns gptoss_status_invalid_state if the wrapper is not initialized.
enum gptoss_status gptoss_metal_command_buffer_commit(
    const struct gptoss_metal_command_buffer* command_buffer)
{
    id<MTLCommandBuffer> mtl_command_buffer = (id<MTLCommandBuffer>) command_buffer->object;
    if (mtl_command_buffer == nil) {
        return gptoss_status_invalid_state;
    }

    [mtl_command_buffer commit];
    return gptoss_status_success;
}
|
| 454 |
+
|
| 455 |
+
// Blocks until the (previously committed) command buffer finishes on the GPU.
//
// If `elapsed_seconds` is non-NULL, it receives the GPU execution time
// (GPUEndTime - GPUStartTime) in seconds.
enum gptoss_status gptoss_metal_command_buffer_wait_completion(
    const struct gptoss_metal_command_buffer* command_buffer,
    double* elapsed_seconds)
{
    id<MTLCommandBuffer> mtl_command_buffer = (id<MTLCommandBuffer>) command_buffer->object;
    if (mtl_command_buffer == nil) {
        return gptoss_status_invalid_state;
    }

    [mtl_command_buffer waitUntilCompleted];

    if (elapsed_seconds != NULL) {
        const double gpu_start = (double) [mtl_command_buffer GPUStartTime];
        const double gpu_end = (double) [mtl_command_buffer GPUEndTime];
        *elapsed_seconds = gpu_end - gpu_start;
    }
    return gptoss_status_success;
}
|
| 472 |
+
|
| 473 |
+
// Releases the retained Metal command buffer (if any) and zeroes the wrapper.
// Safe on an already-released structure; always succeeds.
enum gptoss_status gptoss_metal_command_buffer_release(
    struct gptoss_metal_command_buffer* command_buffer)
{
    id<MTLCommandBuffer> mtl_command_buffer = (id<MTLCommandBuffer>) command_buffer->object;
    if (mtl_command_buffer != nil) {
        [mtl_command_buffer release];
    }
    memset(command_buffer, 0, sizeof(struct gptoss_metal_command_buffer));
    return gptoss_status_success;
}
|
gptoss_kernels/source/model.c
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <assert.h>
#include <inttypes.h>
#include <limits.h>  // INT_MAX (used by prefetch_fd)
#include <stdatomic.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <errno.h>              // errno, EISDIR, ENOENT, ENOTDIR
#include <fcntl.h>              // open
#include <mach/vm_page_size.h>  // vm_page_size
#include <sys/mman.h>           // mmap, PROT_READ, MAP_PRIVATE
#include <sys/stat.h>           // fstat, stat
#include <sys/types.h>          // off_t, ssize_t
#include <unistd.h>             // close

#include <gpt-oss.h>

#include "internal/datatype.h"
#include "internal/kernel-args.h"  // gptoss_expert_prediction
#include "internal/log.h"
#include "internal/math.h"
#include "internal/model.h"
#include "internal/storage.h"
#include "internal/uuid.h"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
static size_t round_up_to_page_size(size_t bytes) {
|
| 28 |
+
const size_t page_size_mask = (size_t) vm_page_size - 1;
|
| 29 |
+
if ((bytes & page_size_mask) != 0) {
|
| 30 |
+
bytes |= page_size_mask;
|
| 31 |
+
bytes += 1;
|
| 32 |
+
}
|
| 33 |
+
return bytes;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
static size_t round_down_to_page_size(size_t bytes) {
|
| 37 |
+
const size_t page_size_mask = (size_t) vm_page_size - 1;
|
| 38 |
+
return bytes & ~page_size_mask;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
static enum gptoss_status read_fd(int fd, void* data, size_t size, const char* path) {
|
| 42 |
+
assert(fd != -1);
|
| 43 |
+
assert(data != NULL);
|
| 44 |
+
assert(size != 0);
|
| 45 |
+
|
| 46 |
+
size_t bytes_to_read = size;
|
| 47 |
+
char* current_byte = (char*) data;
|
| 48 |
+
do {
|
| 49 |
+
const ssize_t read_result = read(fd, current_byte, bytes_to_read);
|
| 50 |
+
if (read_result < 0) {
|
| 51 |
+
GPTOSS_LOG_ERROR("reading %zu bytes from file %s failed with error %d",
|
| 52 |
+
size, path, errno);
|
| 53 |
+
return gptoss_status_io_error;
|
| 54 |
+
}
|
| 55 |
+
current_byte += (size_t) read_result;
|
| 56 |
+
bytes_to_read -= (size_t) read_result;
|
| 57 |
+
} while (bytes_to_read != 0);
|
| 58 |
+
return gptoss_status_success;
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
static void prefetch_fd(int fd, size_t offset, size_t size, const char* path) {
|
| 62 |
+
// radvisory.ra_count is int, so we can't prefetch 2GB+ at once
|
| 63 |
+
const size_t prefetch_max = round_down_to_page_size((size_t) INT_MAX);
|
| 64 |
+
do {
|
| 65 |
+
const size_t prefetch_size = math_min(size, prefetch_max);
|
| 66 |
+
const struct radvisory ra = {
|
| 67 |
+
.ra_offset = offset,
|
| 68 |
+
.ra_count = (int) prefetch_size,
|
| 69 |
+
};
|
| 70 |
+
if (fcntl(fd, F_RDADVISE, &ra) == -1) {
|
| 71 |
+
GPTOSS_LOG_WARNING("fcntl(%s, F_RDADVISE, .ra_offset=%zu, .ra_count=%d) failed with error %d\n",
|
| 72 |
+
path, (size_t) ra.ra_offset, ra.ra_count, errno);
|
| 73 |
+
return;
|
| 74 |
+
}
|
| 75 |
+
offset += prefetch_size;
|
| 76 |
+
size -= prefetch_size;
|
| 77 |
+
} while (size != 0);
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
|
| 81 |
+
const char* path,
|
| 82 |
+
gptoss_model_t* model_out)
|
| 83 |
+
{
|
| 84 |
+
*model_out = NULL;
|
| 85 |
+
|
| 86 |
+
enum gptoss_status status = gptoss_status_success;
|
| 87 |
+
struct gptoss_model* model = NULL;
|
| 88 |
+
struct gptoss_tokenizer* tokenizer = NULL;
|
| 89 |
+
int fd = -1;
|
| 90 |
+
size_t file_offset = 0;
|
| 91 |
+
|
| 92 |
+
fd = open(path, O_RDONLY);
|
| 93 |
+
if (fd == -1) {
|
| 94 |
+
GPTOSS_LOG_ERROR("open(%s) failed with error %d", path, errno);
|
| 95 |
+
switch (errno) {
|
| 96 |
+
case EISDIR:
|
| 97 |
+
case ENOENT:
|
| 98 |
+
case ENOTDIR:
|
| 99 |
+
status = gptoss_status_invalid_argument;
|
| 100 |
+
break;
|
| 101 |
+
default:
|
| 102 |
+
status = gptoss_status_io_error;
|
| 103 |
+
break;
|
| 104 |
+
}
|
| 105 |
+
goto cleanup;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
struct gptoss_file_header file_header;
|
| 109 |
+
status = read_fd(fd, &file_header, sizeof(file_header), path);
|
| 110 |
+
if (status != gptoss_status_success) {
|
| 111 |
+
goto cleanup;
|
| 112 |
+
}
|
| 113 |
+
file_offset += sizeof(file_header);
|
| 114 |
+
|
| 115 |
+
if (file_header.magic[0] != 'G' ||
|
| 116 |
+
file_header.magic[1] != 'P' ||
|
| 117 |
+
file_header.magic[2] != 'T' ||
|
| 118 |
+
file_header.magic[3] != '-' ||
|
| 119 |
+
file_header.magic[4] != 'O' ||
|
| 120 |
+
file_header.magic[5] != 'S' ||
|
| 121 |
+
file_header.magic[6] != 'S' ||
|
| 122 |
+
file_header.magic[7] != ' ' ||
|
| 123 |
+
file_header.magic[8] != 'v' ||
|
| 124 |
+
file_header.magic[9] != '1' ||
|
| 125 |
+
file_header.magic[10] != '.' ||
|
| 126 |
+
file_header.magic[11] != '0' ||
|
| 127 |
+
file_header.zero != 0)
|
| 128 |
+
{
|
| 129 |
+
GPTOSS_LOG_ERROR("invalid magic in file %s", path);
|
| 130 |
+
status = gptoss_status_invalid_argument;
|
| 131 |
+
goto cleanup;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
struct gptoss_uuid model_uuid;
|
| 135 |
+
status = read_fd(fd, &model_uuid, sizeof(model_uuid), path);
|
| 136 |
+
if (status != gptoss_status_success) {
|
| 137 |
+
goto cleanup;
|
| 138 |
+
}
|
| 139 |
+
file_offset += sizeof(model_uuid);
|
| 140 |
+
|
| 141 |
+
if (!gptoss_is_gptoss_model_uuid(&model_uuid)) {
|
| 142 |
+
GPTOSS_LOG_ERROR("unsupported model UUID " UUID_FORMAT, UUID_ARGS(model_uuid));
|
| 143 |
+
status = gptoss_status_invalid_argument;
|
| 144 |
+
goto cleanup;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
struct gptoss_gptoss_model_header model_header;
|
| 148 |
+
status = read_fd(fd, &model_header, sizeof(model_header), path);
|
| 149 |
+
if (status != gptoss_status_success) {
|
| 150 |
+
goto cleanup;
|
| 151 |
+
}
|
| 152 |
+
file_offset += sizeof(model_header);
|
| 153 |
+
|
| 154 |
+
struct gptoss_uuid layout_uuid;
|
| 155 |
+
status = read_fd(fd, &layout_uuid, sizeof(layout_uuid), path);
|
| 156 |
+
if (status != gptoss_status_success) {
|
| 157 |
+
goto cleanup;
|
| 158 |
+
}
|
| 159 |
+
file_offset += sizeof(layout_uuid);
|
| 160 |
+
|
| 161 |
+
if (!gptoss_is_applegpu_layout_uuid(&layout_uuid)) {
|
| 162 |
+
GPTOSS_LOG_ERROR("unsupported layout UUID " UUID_FORMAT, UUID_ARGS(layout_uuid));
|
| 163 |
+
status = gptoss_status_invalid_argument;
|
| 164 |
+
goto cleanup;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
const size_t model_size = sizeof(struct gptoss_model) + model_header.num_blocks * sizeof(struct gptoss_metal_buffer);
|
| 168 |
+
model = malloc(model_size);
|
| 169 |
+
if (model == NULL) {
|
| 170 |
+
GPTOSS_LOG_ERROR("failed to allocate %zu bytes for model descriptor", model_size);
|
| 171 |
+
status = gptoss_status_insufficient_memory;
|
| 172 |
+
goto cleanup;
|
| 173 |
+
}
|
| 174 |
+
memset(model, 0, model_size);
|
| 175 |
+
|
| 176 |
+
atomic_store_explicit(&model->ref_count, 1, memory_order_relaxed);
|
| 177 |
+
model->context_length = model_header.context_length;
|
| 178 |
+
model->num_blocks = model_header.num_blocks;
|
| 179 |
+
model->num_experts = model_header.num_experts;
|
| 180 |
+
model->num_active_experts = model_header.num_active_experts;
|
| 181 |
+
model->embedding_dim = model_header.embedding_dim;
|
| 182 |
+
model->mlp_dim = model_header.mlp_dim;
|
| 183 |
+
model->swiglu_limit = model_header.swiglu_limit;
|
| 184 |
+
model->head_dim = model_header.head_dim;
|
| 185 |
+
model->num_heads = model_header.num_heads;
|
| 186 |
+
model->num_kv_heads = model_header.num_kv_heads;
|
| 187 |
+
model->attention_window = model_header.attention_window;
|
| 188 |
+
model->rope_theta = model_header.rope_theta;
|
| 189 |
+
model->interpolation_scale = model_header.interpolation_scale;
|
| 190 |
+
model->yarn_offset = model_header.yarn_offset;
|
| 191 |
+
model->yarn_scale = model_header.yarn_scale;
|
| 192 |
+
model->yarn_multiplier = model_header.yarn_multiplier;
|
| 193 |
+
model->rmsnorm_epsilon = model_header.rmsnorm_epsilon;
|
| 194 |
+
|
| 195 |
+
struct gptoss_uuid tokenizer_uuid;
|
| 196 |
+
status = read_fd(fd, &tokenizer_uuid, sizeof(tokenizer_uuid), path);
|
| 197 |
+
if (status != gptoss_status_success) {
|
| 198 |
+
goto cleanup;
|
| 199 |
+
}
|
| 200 |
+
file_offset += sizeof(tokenizer_uuid);
|
| 201 |
+
|
| 202 |
+
if (!gptoss_is_tiktoken_tokenizer_uuid(&tokenizer_uuid)) {
|
| 203 |
+
GPTOSS_LOG_ERROR("unsupported tokenizer UUID " UUID_FORMAT, UUID_ARGS(tokenizer_uuid));
|
| 204 |
+
status = gptoss_status_invalid_argument;
|
| 205 |
+
goto cleanup;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
struct gptoss_tiktoken_tokenizer_header tokenizer_header;
|
| 209 |
+
status = read_fd(fd, &tokenizer_header, sizeof(tokenizer_header), path);
|
| 210 |
+
if (status != gptoss_status_success) {
|
| 211 |
+
goto cleanup;
|
| 212 |
+
}
|
| 213 |
+
file_offset += sizeof(tokenizer_header);
|
| 214 |
+
|
| 215 |
+
tokenizer = malloc(sizeof(struct gptoss_tokenizer));
|
| 216 |
+
if (tokenizer == NULL) {
|
| 217 |
+
GPTOSS_LOG_ERROR("failed to allocate %zu bytes for tokenizer descriptor", sizeof(struct gptoss_tokenizer));
|
| 218 |
+
status = gptoss_status_insufficient_memory;
|
| 219 |
+
goto cleanup;
|
| 220 |
+
}
|
| 221 |
+
memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
|
| 222 |
+
// Initialize all special token IDs to UINT32_MAX (0xFF in all bytes)
|
| 223 |
+
memset(tokenizer->special_token_id, 0xFF, sizeof(tokenizer->special_token_id));
|
| 224 |
+
|
| 225 |
+
atomic_store_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
|
| 226 |
+
tokenizer->num_special_tokens = tokenizer_header.num_special_tokens;
|
| 227 |
+
tokenizer->num_text_tokens = tokenizer_header.num_text_tokens;
|
| 228 |
+
model->vocabulary_size = tokenizer_header.num_special_tokens + tokenizer_header.num_text_tokens;
|
| 229 |
+
for (uint32_t t = 0; t < tokenizer_header.num_special_tokens; t++) {
|
| 230 |
+
struct gptoss_uuid token_uuid;
|
| 231 |
+
status = read_fd(fd, &token_uuid, sizeof(token_uuid), path);
|
| 232 |
+
if (status != gptoss_status_success) {
|
| 233 |
+
goto cleanup;
|
| 234 |
+
}
|
| 235 |
+
file_offset += sizeof(token_uuid);
|
| 236 |
+
|
| 237 |
+
const enum gptoss_special_token token = gptoss_special_token_decode_uuid(&token_uuid);
|
| 238 |
+
if (token != gptoss_special_token_invalid) {
|
| 239 |
+
tokenizer->special_token_id[token - 1] = tokenizer_header.num_text_tokens + t;
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
const size_t tokenizer_start_offset = file_offset;
|
| 244 |
+
const size_t tokenizer_end_offset = tokenizer_start_offset + tokenizer_header.regex_size + tokenizer_header.tokens_size;
|
| 245 |
+
const size_t tokenizer_mapping_start = round_down_to_page_size(tokenizer_start_offset);
|
| 246 |
+
const size_t tokenizer_mapping_size = round_up_to_page_size(tokenizer_end_offset) - tokenizer_mapping_start;
|
| 247 |
+
void* tokenizer_mapping_ptr = mmap(NULL, tokenizer_mapping_size, PROT_READ, MAP_PRIVATE, fd, tokenizer_mapping_start);
|
| 248 |
+
if (tokenizer_mapping_ptr == (void*) -1) {
|
| 249 |
+
GPTOSS_LOG_ERROR("failed to mmap(%s) tokenizer at offset %zu size %zu",
|
| 250 |
+
path, tokenizer_mapping_start, tokenizer_mapping_size);
|
| 251 |
+
status = gptoss_status_io_error;
|
| 252 |
+
goto cleanup;
|
| 253 |
+
}
|
| 254 |
+
tokenizer->mapping_ptr = tokenizer_mapping_ptr;
|
| 255 |
+
tokenizer->mapping_size = tokenizer_mapping_size;
|
| 256 |
+
tokenizer->regex_ptr = (const char*) tokenizer_mapping_ptr + (tokenizer_start_offset - tokenizer_mapping_start);
|
| 257 |
+
tokenizer->tokens_ptr = tokenizer->regex_ptr + tokenizer_header.regex_size;
|
| 258 |
+
|
| 259 |
+
if (madvise(tokenizer_mapping_ptr, tokenizer_mapping_size, MADV_RANDOM | MADV_WILLNEED) != 0) {
|
| 260 |
+
GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, tokenizer_mapping_size, errno);
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
prefetch_fd(fd, tokenizer_mapping_start, tokenizer_mapping_size, path);
|
| 264 |
+
|
| 265 |
+
struct stat model_stat = {0};
|
| 266 |
+
int stat_result = fstat(fd, &model_stat);
|
| 267 |
+
if (stat_result != 0) {
|
| 268 |
+
GPTOSS_LOG_ERROR("stat(%s) failed with error %d", path, errno);
|
| 269 |
+
status = gptoss_status_io_error;
|
| 270 |
+
goto cleanup;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
const size_t model_mapping_start = round_up_to_page_size(tokenizer_end_offset);
|
| 274 |
+
const size_t model_mapping_size = round_up_to_page_size((size_t) model_stat.st_size) - model_mapping_start;
|
| 275 |
+
void* model_mapping_ptr = mmap(NULL, model_mapping_size, PROT_READ, MAP_PRIVATE, fd, model_mapping_start);
|
| 276 |
+
if (model_mapping_ptr == (void*) -1) {
|
| 277 |
+
GPTOSS_LOG_ERROR("failed to mmap(%s) model weights at offset %zu size %zu",
|
| 278 |
+
path, model_mapping_start, model_mapping_size);
|
| 279 |
+
status = gptoss_status_io_error;
|
| 280 |
+
goto cleanup;
|
| 281 |
+
}
|
| 282 |
+
model->mapping_ptr = model_mapping_ptr;
|
| 283 |
+
model->mapping_size = model_mapping_size;
|
| 284 |
+
|
| 285 |
+
if (madvise(model_mapping_ptr, model_mapping_size, MADV_SEQUENTIAL | MADV_WILLNEED) != 0) {
|
| 286 |
+
GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
prefetch_fd(fd, model_mapping_start, model_mapping_size, path);
|
| 290 |
+
|
| 291 |
+
if (mlock(model_mapping_ptr, model_mapping_size) != 0) {
|
| 292 |
+
GPTOSS_LOG_WARNING("mlock(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
|
| 293 |
+
} else {
|
| 294 |
+
model->lock_memory = true;
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
// Initialize Metal
|
| 298 |
+
status = gptoss_metal_device_create_system_default(&model->device);
|
| 299 |
+
if (status != gptoss_status_success) {
|
| 300 |
+
goto cleanup;
|
| 301 |
+
}
|
| 302 |
+
model->max_threadgroups = model->device.num_cores * 3;
|
| 303 |
+
status = gptoss_metal_command_queue_create(&model->device, &model->command_queue);
|
| 304 |
+
if (status != gptoss_status_success) {
|
| 305 |
+
goto cleanup;
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
// Metal kernels
|
| 309 |
+
status = gptoss_metal_library_create_default(&model->device, &model->library);
|
| 310 |
+
if (status != gptoss_status_success) {
|
| 311 |
+
goto cleanup;
|
| 312 |
+
}
|
| 313 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_bf16_f32_embeddings", &model->bf16_f32_embeddings_fn);
|
| 314 |
+
if (status != gptoss_status_success) {
|
| 315 |
+
goto cleanup;
|
| 316 |
+
}
|
| 317 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_rmsnorm", &model->f32_bf16w_rmsnorm_fn);
|
| 318 |
+
if (status != gptoss_status_success) {
|
| 319 |
+
goto cleanup;
|
| 320 |
+
}
|
| 321 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul", &model->f32_bf16w_matmul_fn);
|
| 322 |
+
if (status != gptoss_status_success) {
|
| 323 |
+
goto cleanup;
|
| 324 |
+
}
|
| 325 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul_qkv", &model->f32_bf16w_matmul_qkv_fn);
|
| 326 |
+
if (status != gptoss_status_success) {
|
| 327 |
+
goto cleanup;
|
| 328 |
+
}
|
| 329 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_qkv", &model->f32_bf16w_dense_matmul_qkv_fn);
|
| 330 |
+
if (status != gptoss_status_success) {
|
| 331 |
+
goto cleanup;
|
| 332 |
+
}
|
| 333 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_attn_output", &model->f32_bf16w_dense_matmul_attn_output_fn);
|
| 334 |
+
if (status != gptoss_status_success) {
|
| 335 |
+
goto cleanup;
|
| 336 |
+
}
|
| 337 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_mlp_gate", &model->f32_bf16w_dense_matmul_mlp_gate_fn);
|
| 338 |
+
if (status != gptoss_status_success) {
|
| 339 |
+
goto cleanup;
|
| 340 |
+
}
|
| 341 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_unembedding", &model->f32_bf16w_unembedding_fn);
|
| 342 |
+
if (status != gptoss_status_success) {
|
| 343 |
+
goto cleanup;
|
| 344 |
+
}
|
| 345 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_rope", &model->f32_rope_fn);
|
| 346 |
+
if (status != gptoss_status_success) {
|
| 347 |
+
goto cleanup;
|
| 348 |
+
}
|
| 349 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_expert_routing_metadata", &model->f32_expert_routing_metadata_fn);
|
| 350 |
+
if (status != gptoss_status_success) {
|
| 351 |
+
goto cleanup;
|
| 352 |
+
}
|
| 353 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_scatter_e4", &model->f32_scatter_e4_fn);
|
| 354 |
+
if (status != gptoss_status_success) {
|
| 355 |
+
goto cleanup;
|
| 356 |
+
}
|
| 357 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul_swiglu", &model->f32_mf4w_moe_dense_matmul_swiglu_fn);
|
| 358 |
+
if (status != gptoss_status_success) {
|
| 359 |
+
goto cleanup;
|
| 360 |
+
}
|
| 361 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul", &model->f32_mf4w_moe_dense_matmul_fn);
|
| 362 |
+
if (status != gptoss_status_success) {
|
| 363 |
+
goto cleanup;
|
| 364 |
+
}
|
| 365 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_gather_and_accumulate_e4", &model->f32_gather_and_accumulate_e4_fn);
|
| 366 |
+
if (status != gptoss_status_success) {
|
| 367 |
+
goto cleanup;
|
| 368 |
+
}
|
| 369 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul_swiglu", &model->f32_mf4w_moe_matmul_swiglu_fn);
|
| 370 |
+
if (status != gptoss_status_success) {
|
| 371 |
+
goto cleanup;
|
| 372 |
+
}
|
| 373 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul", &model->f32_mf4w_moe_matmul_fn);
|
| 374 |
+
if (status != gptoss_status_success) {
|
| 375 |
+
goto cleanup;
|
| 376 |
+
}
|
| 377 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_accumulate_e4", &model->f32_accumulate_e4_fn);
|
| 378 |
+
if (status != gptoss_status_success) {
|
| 379 |
+
goto cleanup;
|
| 380 |
+
}
|
| 381 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e32_k4", &model->f32_topk_softmax_e32_k4_fn);
|
| 382 |
+
if (status != gptoss_status_success) {
|
| 383 |
+
goto cleanup;
|
| 384 |
+
}
|
| 385 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e128_k4", &model->f32_topk_softmax_e128_k4_fn);
|
| 386 |
+
if (status != gptoss_status_success) {
|
| 387 |
+
goto cleanup;
|
| 388 |
+
}
|
| 389 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_softmax", &model->f32_softmax_fn);
|
| 390 |
+
if (status != gptoss_status_success) {
|
| 391 |
+
goto cleanup;
|
| 392 |
+
}
|
| 393 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_sample", &model->f32_sample_fn);
|
| 394 |
+
if (status != gptoss_status_success) {
|
| 395 |
+
goto cleanup;
|
| 396 |
+
}
|
| 397 |
+
status = gptoss_metal_function_create(&model->library, "gptoss_f32_sdpa_q8_d64", &model->f32_sdpa_q8_d64_fn);
|
| 398 |
+
if (status != gptoss_status_success) {
|
| 399 |
+
goto cleanup;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
// Kernel launch parameters
|
| 403 |
+
model->embeddings_threadgroup_size = 512;
|
| 404 |
+
model->attn_qkv_threadgroup_size = 1024;
|
| 405 |
+
model->attn_out_threadgroup_size = 768;
|
| 406 |
+
model->mlp_gate_threadgroup_size = 256;
|
| 407 |
+
model->mlp_swiglu_threadgroup_size = 192;
|
| 408 |
+
model->mlp_out_threadgroup_size = 192;
|
| 409 |
+
model->mlp_acc_threadgroup_size = 768;
|
| 410 |
+
model->unembedding_threadgroup_size = 416;
|
| 411 |
+
|
| 412 |
+
// Weight buffers
|
| 413 |
+
const char* current_ptr = (const char*) model->mapping_ptr;
|
| 414 |
+
|
| 415 |
+
const size_t embedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 416 |
+
model->attn_rmsnorm_gain_offset = embedding_weight_size;
|
| 417 |
+
const size_t rmsnorm_weight_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 418 |
+
model->attn_qkv_weight_offset = model->attn_rmsnorm_gain_offset + rmsnorm_weight_size;
|
| 419 |
+
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
|
| 420 |
+
const size_t attn_qkv_weight_size = math_round_up_po2(attn_qkv_dim * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 421 |
+
model->attn_qkv_bias_offset = model->attn_qkv_weight_offset + attn_qkv_weight_size;
|
| 422 |
+
const size_t attn_qkv_bias_size = math_round_up_po2(attn_qkv_dim * sizeof(gptoss_bfloat16), 16);
|
| 423 |
+
model->attn_sdpa_sink_offset = model->attn_qkv_bias_offset + attn_qkv_bias_size;
|
| 424 |
+
const size_t attn_sink_weight_size = math_round_up_po2(model->num_heads * sizeof(gptoss_bfloat16), 16);
|
| 425 |
+
model->attn_out_weight_offset = model->attn_sdpa_sink_offset + attn_sink_weight_size;
|
| 426 |
+
const size_t attn_out_weight_size = math_round_up_po2(model->embedding_dim * model->num_heads * model->head_dim * sizeof(gptoss_bfloat16), 16);
|
| 427 |
+
model->attn_out_bias_offset = model->attn_out_weight_offset + attn_out_weight_size;
|
| 428 |
+
const size_t attn_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 429 |
+
model->mlp_rmsnorm_gain_offset = model->attn_out_bias_offset + attn_out_bias_size;
|
| 430 |
+
model->mlp_gate_weight_offset = model->mlp_rmsnorm_gain_offset + rmsnorm_weight_size;
|
| 431 |
+
const size_t mlp_gate_weight_size = math_round_up_po2(model->num_experts * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 432 |
+
model->mlp_gate_bias_offset = model->mlp_gate_weight_offset + mlp_gate_weight_size;
|
| 433 |
+
const size_t mlp_gate_bias_size = math_round_up_po2(model->num_experts * sizeof(gptoss_bfloat16), 16);
|
| 434 |
+
const size_t per_block_shared_weights_size =
|
| 435 |
+
rmsnorm_weight_size + attn_qkv_weight_size + attn_qkv_bias_size + attn_sink_weight_size + attn_out_weight_size + attn_out_bias_size +
|
| 436 |
+
rmsnorm_weight_size + mlp_gate_weight_size + mlp_gate_bias_size;
|
| 437 |
+
model->rmsnorm_weight_offset = embedding_weight_size + model->num_blocks * per_block_shared_weights_size;
|
| 438 |
+
model->unembedding_weight_offset = model->rmsnorm_weight_offset + rmsnorm_weight_size;
|
| 439 |
+
const size_t unembedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 440 |
+
|
| 441 |
+
model->per_block_shared_weights_size = per_block_shared_weights_size;
|
| 442 |
+
const size_t shared_weights_size =
|
| 443 |
+
round_up_to_page_size(embedding_weight_size + rmsnorm_weight_size + unembedding_weight_size + model->num_blocks * per_block_shared_weights_size);
|
| 444 |
+
|
| 445 |
+
status = gptoss_metal_buffer_wrap(&model->device, shared_weights_size, current_ptr, &model->shared_weight_buffer);
|
| 446 |
+
if (status != gptoss_status_success) {
|
| 447 |
+
GPTOSS_LOG_ERROR("failed to map expert-shared weight of size %zu onto a Metal buffer", shared_weights_size);
|
| 448 |
+
goto cleanup;
|
| 449 |
+
}
|
| 450 |
+
current_ptr += shared_weights_size;
|
| 451 |
+
model->weights_size += shared_weights_size;
|
| 452 |
+
|
| 453 |
+
const size_t mlp_swiglu_weight_block_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 2, 16);
|
| 454 |
+
model->mlp_swiglu_scale_offset = mlp_swiglu_weight_block_size;
|
| 455 |
+
const size_t mlp_swiglu_weight_scale_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 32, 16);
|
| 456 |
+
model->mlp_swiglu_bias_offset = model->mlp_swiglu_scale_offset + mlp_swiglu_weight_scale_size;
|
| 457 |
+
const size_t mlp_swiglu_bias_size = math_round_up_po2(2 * model->mlp_dim * sizeof(gptoss_bfloat16), 16);
|
| 458 |
+
model->mlp_out_block_offset = model->mlp_swiglu_bias_offset + mlp_swiglu_bias_size;
|
| 459 |
+
const size_t mlp_out_weight_block_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 2, 16);
|
| 460 |
+
model->mlp_out_scale_offset = model->mlp_out_block_offset + mlp_out_weight_block_size;
|
| 461 |
+
const size_t mlp_out_weight_scale_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 32, 16);
|
| 462 |
+
model->mlp_out_bias_offset = model->mlp_out_scale_offset + mlp_out_weight_scale_size;
|
| 463 |
+
const size_t mlp_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
|
| 464 |
+
model->per_expert_block_weight_size =
|
| 465 |
+
mlp_swiglu_weight_block_size + mlp_swiglu_weight_scale_size + mlp_swiglu_bias_size + mlp_out_weight_block_size + mlp_out_weight_scale_size + mlp_out_bias_size;
|
| 466 |
+
const size_t moe_block_weight_size = round_up_to_page_size(model->num_experts * model->per_expert_block_weight_size);
|
| 467 |
+
for (uint32_t n = 0; n < model->num_blocks; n++) {
|
| 468 |
+
status = gptoss_metal_buffer_wrap(&model->device, moe_block_weight_size, current_ptr, &model->block_weight_buffers[n]);
|
| 469 |
+
if (status != gptoss_status_success) {
|
| 470 |
+
GPTOSS_LOG_ERROR("failed to map block #%" PRIu32 " MoE weight of size %zu onto a Metal buffer",
|
| 471 |
+
n, moe_block_weight_size);
|
| 472 |
+
goto cleanup;
|
| 473 |
+
}
|
| 474 |
+
current_ptr += moe_block_weight_size;
|
| 475 |
+
model->weights_size += moe_block_weight_size;
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
// Commit tokenizer
|
| 479 |
+
model->tokenizer = tokenizer;
|
| 480 |
+
tokenizer = NULL;
|
| 481 |
+
|
| 482 |
+
// Commit model
|
| 483 |
+
*model_out = model;
|
| 484 |
+
model = NULL;
|
| 485 |
+
|
| 486 |
+
cleanup:
|
| 487 |
+
if (fd != -1) {
|
| 488 |
+
close(fd);
|
| 489 |
+
fd = -1;
|
| 490 |
+
}
|
| 491 |
+
gptoss_model_release(model); // does nothing if model is NULL
|
| 492 |
+
gptoss_tokenizer_release(tokenizer); // does nothing if tokenizer is NULL
|
| 493 |
+
return status;
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
    gptoss_model_t model,
    gptoss_tokenizer_t* tokenizer_out)
{
    // Hand out a new strong reference to the tokenizer owned by the model.
    // The caller is responsible for balancing this with gptoss_tokenizer_release.
    gptoss_tokenizer_t shared_tokenizer = model->tokenizer;
    // Relaxed ordering is sufficient for a pure reference-count increment.
    atomic_fetch_add_explicit(&shared_tokenizer->ref_count, 1, memory_order_relaxed);
    *tokenizer_out = shared_tokenizer;
    return gptoss_status_success;
}
|
| 505 |
+
|
| 506 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
    gptoss_model_t model,
    size_t* max_context_length_out)
{
    // Report the maximum context length this model was configured with.
    const size_t context_length = model->context_length;
    *max_context_length_out = context_length;
    return gptoss_status_success;
}
|
| 513 |
+
|
| 514 |
+
enum gptoss_status GPTOSS_ABI gptoss_model_retain(
    gptoss_model_t model)
{
    // Take an additional strong reference on the model.
    // Incrementing a reference count needs no ordering guarantees beyond atomicity.
    atomic_fetch_add_explicit(&model->ref_count, 1, memory_order_relaxed);
    return gptoss_status_success;
}
|
| 520 |
+
|
| 521 |
+
// Drops one strong reference on the model; when the last reference is dropped,
// tears down every resource the model owns: the tokenizer reference, Metal
// buffers/kernels/library/queue/device, the mmap'ed weight file, and finally
// the model struct itself. Safe to call with NULL (no-op).
enum gptoss_status GPTOSS_ABI gptoss_model_release(
    gptoss_model_t model)
{
    if (model != NULL) {
        // acq_rel: the releasing thread's writes must be visible to the thread
        // that performs the final teardown, and vice versa.
        if (atomic_fetch_sub_explicit(&model->ref_count, 1, memory_order_acq_rel) == 1) {
            gptoss_tokenizer_release(model->tokenizer);

            // Weight buffers (Metal wrappers around the mmap'ed file regions).
            gptoss_metal_buffer_release(&model->shared_weight_buffer);
            for (uint32_t n = 0; n < model->num_blocks; n++) {
                gptoss_metal_buffer_release(&model->block_weight_buffers[n]);
            }

            // Metal kernels
            gptoss_metal_function_release(&model->bf16_f32_embeddings_fn);
            gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn);
            gptoss_metal_function_release(&model->f32_bf16w_matmul_fn);
            gptoss_metal_function_release(&model->f32_bf16w_matmul_qkv_fn);
            gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_qkv_fn);
            gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_attn_output_fn);
            gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn);
            gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn);
            gptoss_metal_function_release(&model->f32_rope_fn);
            gptoss_metal_function_release(&model->f32_expert_routing_metadata_fn);
            gptoss_metal_function_release(&model->f32_scatter_e4_fn);
            gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_swiglu_fn);
            gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_fn);
            gptoss_metal_function_release(&model->f32_gather_and_accumulate_e4_fn);
            gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn);
            gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_fn);
            gptoss_metal_function_release(&model->f32_accumulate_e4_fn);
            gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn);
            gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn);
            gptoss_metal_function_release(&model->f32_softmax_fn);
            gptoss_metal_function_release(&model->f32_sample_fn);
            gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn);
            gptoss_metal_library_release(&model->library);

            gptoss_metal_command_queue_release(&model->command_queue);
            gptoss_metal_device_release(&model->device);
            // Weight file mapping (released after the Metal buffers that wrap it).

            if (model->mapping_ptr != NULL && model->mapping_size != 0) {
                if (model->lock_memory) {
                    // Best-effort: mapping was mlock'ed at load time; a failed
                    // munlock is logged but does not abort the teardown.
                    if (munlock(model->mapping_ptr, model->mapping_size) != 0) {
                        GPTOSS_LOG_WARNING("munlock for model weight mapping failed with error %d", errno);
                    }
                }

                if (munmap(model->mapping_ptr, model->mapping_size) != 0) {
                    GPTOSS_LOG_WARNING("munmap for model weight mapping failed with error %d", errno);
                }
            }

            // The struct was allocated with a trailing flexible array of
            // num_blocks Metal buffers; zero the full allocation before freeing
            // so stale handles cannot be reused after release.
            const size_t model_size = sizeof(struct gptoss_model) + model->num_blocks * sizeof(struct gptoss_metal_buffer);
            memset(model, 0, model_size);
            free(model);
        }
    }
    return gptoss_status_success;
}
|
gptoss_kernels/source/moematmul.metal
ADDED
|
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <internal/kernel-args.h>
|
| 2 |
+
#include <metal_common>
|
| 3 |
+
#include <metal_compute>
|
| 4 |
+
#include <metal_math>
|
| 5 |
+
#include <metal_simdgroup>
|
| 6 |
+
#include <metal_stdlib>
|
| 7 |
+
|
| 8 |
+
#pragma METAL fp math_mode(safe)
|
| 9 |
+
#pragma METAL fp contract(off)
|
| 10 |
+
#define ceil_div(a, b) (((a) + (b) - 1) / (b))
|
| 11 |
+
|
| 12 |
+
// Each simdgroup reduces all channels of the input and computes a single channel of the output
|
| 13 |
+
// + Efficient synchronization
|
| 14 |
+
// + Sequential memory access within a warp
|
| 15 |
+
// Each threadgroup computes (simdgroups_per_threadgroup) consecutive output channels
|
| 16 |
+
// + Reuse input vector from threadgroup memory
|
| 17 |
+
// + Avoid synchronization across warps when doing reduction
|
| 18 |
+
|
| 19 |
+
// MoE matmul with fused SwiGLU activation over MXFP4-quantized weights.
// Each simdgroup (32 lanes) reduces the full input row to produce one
// pre-activation channel; pairs of adjacent channels (gate, linear) are then
// combined into one SwiGLU output, so the kernel emits num_simdgroups/2
// outputs per threadgroup.
kernel void gptoss_f32_mf4w_moe_matmul_swiglu(
    constant gptoss_moe_matmul_swiglu_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device gptoss_expert_prediction* expert [[ buffer(2) ]],
    const device uint4* weight_blocks [[ buffer(3) ]],
    const device uchar* weight_scales [[ buffer(4) ]],
    const device bfloat* bias [[ buffer(5) ]],
    device float* output [[ buffer(6) ]],
    const device gptoss_control* control [[ buffer(7) ]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    const uint simdgroup_size = 32;
    // One partial sum slot per simdgroup (assumes num_simdgroups <= 32).
    threadgroup float threadgroup_buffer[32];
    // Cooperative early-out when the host has requested an abort.
    if (control->abort != 0) {
        return;
    }

    const uint num_column_vecs = args.num_column_vecs;
    // Output row (pre-activation channel) computed by this simdgroup.
    const uint row = gid.x * num_simdgroups + simdgroup_idx;
    // gid.y = token index, gid.z = active-expert slot for that token.
    const uint expert_id = expert[gid.y * args.num_active_experts + gid.z].expert_id;

    // Each lane starts at its own column vector and strides by the simdgroup size.
    input += 8 * (gid.y * num_column_vecs + simdgroup_tid);
    // NOTE(review): weight_expert_stride is a byte offset (hence the uintptr_t
    // arithmetic) and the same stride is applied to blocks, scales, and bias —
    // presumably all per-expert tensors live in one packed per-expert blob; confirm.
    weight_blocks = (const device uint4*) ((uintptr_t) (weight_blocks + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
    weight_scales = (const device uchar*) ((uintptr_t) (weight_scales + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
    bias = (const device bfloat*) ((uintptr_t) (bias + row) + expert_id * args.weight_expert_stride);
    output += gid.y * args.num_rows + gid.x * (num_simdgroups / 2) + gid.z * args.output_expert_stride;

    // Ceil-divide the columns remaining from this lane's start position.
    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;

    float4 sum4 = 0.0f;
    do {
        // One uint4 block packs 32 4-bit weight codes; one uchar holds the
        // shared block scale exponent.
        const uint4 wblock = *weight_blocks;
        // Scale byte is an exponent: shifting it into the float exponent field
        // (bit 23) reconstructs the power-of-two block scale.
        const float wscale = as_type<float>(static_cast<uint>(*weight_scales) << 23);
        // Bit-twiddling dequantization: split even/odd nibbles, re-bias them
        // into fp16 sign+exponent bit patterns (digits in the names track which
        // of the 32 packed weight indices each lane of the vector carries).
        uint4 wblock02468ACEGIKMOQSU = wblock + wblock;
        uint4 wblock13579BDFHJLNPRTV = wblock >> 3;
        wblock02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
        wblock13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
        wblock02468ACEGIKMOQSU += 0x70707070u;
        wblock13579BDFHJLNPRTV += 0x70707070u;
        wblock02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
        wblock13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
        const uint4 wblock26AEIMQU = wblock02468ACEGIKMOQSU & 0xFF00FF00u;
        const uint4 wblock048CGKOS = (wblock02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
        const uint4 wblock37BFJNRV = wblock13579BDFHJLNPRTV & 0xFF00FF00u;
        const uint4 wblock159DHLPT = (wblock13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
        // Reinterpret the assembled 16-bit patterns as half-precision values.
        const float4 w048C = static_cast<float4>(as_type<half4>(wblock048CGKOS.xy));
        const float4 wGKOS = static_cast<float4>(as_type<half4>(wblock048CGKOS.zw));
        const float4 w26AE = static_cast<float4>(as_type<half4>(wblock26AEIMQU.xy));
        const float4 wIMQU = static_cast<float4>(as_type<half4>(wblock26AEIMQU.zw));
        const float4 w159D = static_cast<float4>(as_type<half4>(wblock159DHLPT.xy));
        const float4 wHLPT = static_cast<float4>(as_type<half4>(wblock159DHLPT.zw));
        const float4 w37BF = static_cast<float4>(as_type<half4>(wblock37BFJNRV.xy));
        const float4 wJNRV = static_cast<float4>(as_type<half4>(wblock37BFJNRV.zw));

        // Transpose the strided lanes back into sequential weight order.
        const float4 w0123 = (float4) { w048C.x, w159D.x, w26AE.x, w37BF.x };
        const float4 w4567 = (float4) { w048C.y, w159D.y, w26AE.y, w37BF.y };
        const float4 w89AB = (float4) { w048C.z, w159D.z, w26AE.z, w37BF.z };
        const float4 wCDEF = (float4) { w048C.w, w159D.w, w26AE.w, w37BF.w };
        const float4 wGHIJ = (float4) { wGKOS.x, wHLPT.x, wIMQU.x, wJNRV.x };
        const float4 wKLMN = (float4) { wGKOS.y, wHLPT.y, wIMQU.y, wJNRV.y };
        const float4 wOPQR = (float4) { wGKOS.z, wHLPT.z, wIMQU.z, wJNRV.z };
        const float4 wSTUV = (float4) { wGKOS.w, wHLPT.w, wIMQU.w, wJNRV.w };

        // 32 input activations matching this weight block (8 x float4).
        const float4 i0123 = input[0];
        const float4 i4567 = input[1];
        const float4 i89AB = input[2];
        const float4 iCDEF = input[3];
        const float4 iGHIJ = input[4];
        const float4 iKLMN = input[5];
        const float4 iOPQR = input[6];
        const float4 iSTUV = input[7];

        // Two interleaved partial-sum chains to expose instruction parallelism.
        float4 psum0 = i0123 * w0123;
        float4 psum1 = i4567 * w4567;
        psum0 = metal::fma(i89AB, w89AB, psum0);
        psum1 = metal::fma(iCDEF, wCDEF, psum1);
        psum0 = metal::fma(iGHIJ, wGHIJ, psum0);
        psum1 = metal::fma(iKLMN, wKLMN, psum1);
        psum0 = metal::fma(iOPQR, wOPQR, psum0);
        psum1 = metal::fma(iSTUV, wSTUV, psum1);
        // Apply the block scale once per 32-weight block.
        sum4 = metal::fma(psum0, wscale, sum4);
        sum4 = metal::fma(psum1, wscale, sum4);

        weight_blocks += simdgroup_size;
        weight_scales += simdgroup_size;
        input += 8 * simdgroup_size;
    } while (--num_iter != 0);
    // Horizontal reduction: float4 -> float2 -> float, then across the simdgroup.
    const float2 sum2 = sum4.xy + sum4.zw;
    float sum = sum2.x + sum2.y;
    sum = metal::simd_sum(sum);
    if (metal::simd_is_first()) {
        sum += static_cast<float>(*bias);
        // Publish this simdgroup's pre-activation channel for the fused SwiGLU.
        threadgroup_buffer[simdgroup_idx] = sum;
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    // Each participating thread fuses one (gate, linear) channel pair.
    if (tid * 2 < num_simdgroups) {
        const float2 x = reinterpret_cast<const threadgroup float2*>(threadgroup_buffer)[tid];
        // Gate input is only upper-clamped; linear input is clamped on both sides.
        const float swish_x = metal::min(x.x, args.swiglu_max);
        const float linear_x = metal::clamp(x.y, args.swiglu_min, args.swiglu_max);
        const float alpha = 1.702f;
        // Sigmoid-approximated swish: x * sigmoid(alpha * x).
        const float swish_y = swish_x / (1.0f + metal::precise::exp(-alpha * swish_x));
        // fma form computes swish_y * (linear_x + 1): linear path carries a +1 offset.
        const float swiglu_y = metal::fma(swish_y, linear_x, swish_y);
        output[tid] = swiglu_y;
    }
}
|
| 128 |
+
|
| 129 |
+
// MoE matmul over MXFP4-quantized weights (no activation). Each simdgroup
// reduces the full input row and writes one output channel. Shares the same
// dequantization core as gptoss_f32_mf4w_moe_matmul_swiglu, but the input may
// additionally be strided per active-expert slot (input_expert_stride).
kernel void gptoss_f32_mf4w_moe_matmul(
    constant gptoss_moe_matmul_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device gptoss_expert_prediction* expert [[ buffer(2) ]],
    const device uint4* weight_blocks [[ buffer(3) ]],
    const device uchar* weight_scales [[ buffer(4) ]],
    const device bfloat* bias [[ buffer(5) ]],
    device float* output [[ buffer(6) ]],
    const device gptoss_control* control [[ buffer(7) ]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    const uint simdgroup_size = 32;
    // Cooperative early-out when the host has requested an abort.
    if (control->abort != 0) {
        return;
    }

    const uint num_column_vecs = args.num_column_vecs;
    // Output row computed by this simdgroup.
    const uint row = gid.x * num_simdgroups + simdgroup_idx;
    // gid.y = token index, gid.z = active-expert slot for that token.
    const uint expert_id = expert[gid.y * args.num_active_experts + gid.z].expert_id;

    // Each lane starts at its own column vector; input is also offset by the
    // active-expert slot (per-expert input rows produced by the swiglu kernel).
    input += 8 * (gid.y * num_column_vecs + simdgroup_tid + gid.z * args.input_expert_stride);
    // NOTE(review): weight_expert_stride is a byte offset (hence the uintptr_t
    // arithmetic) and the same stride is applied to blocks, scales, and bias —
    // presumably all per-expert tensors live in one packed per-expert blob; confirm.
    weight_blocks = (const device uint4*) ((uintptr_t) (weight_blocks + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
    weight_scales = (const device uchar*) ((uintptr_t) (weight_scales + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
    bias = (const device bfloat*) ((uintptr_t) (bias + row) + expert_id * args.weight_expert_stride);
    output += gid.y * args.num_rows + row + gid.z * args.output_expert_stride;

    // Ceil-divide the columns remaining from this lane's start position.
    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;

    float4 sum4 = 0.0f;
    do {
        // One uint4 block packs 32 4-bit weight codes; one uchar holds the
        // shared block scale exponent.
        const uint4 wblock = *weight_blocks;
        // Scale byte is an exponent: shifting it into the float exponent field
        // (bit 23) reconstructs the power-of-two block scale.
        const float wscale = as_type<float>(static_cast<uint>(*weight_scales) << 23);
        // Bit-twiddling dequantization: split even/odd nibbles, re-bias them
        // into fp16 sign+exponent bit patterns (digits in the names track which
        // of the 32 packed weight indices each lane of the vector carries).
        uint4 wblock02468ACEGIKMOQSU = wblock + wblock;
        uint4 wblock13579BDFHJLNPRTV = wblock >> 3;
        wblock02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
        wblock13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
        wblock02468ACEGIKMOQSU += 0x70707070u;
        wblock13579BDFHJLNPRTV += 0x70707070u;
        wblock02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
        wblock13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
        const uint4 wblock26AEIMQU = wblock02468ACEGIKMOQSU & 0xFF00FF00u;
        const uint4 wblock048CGKOS = (wblock02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
        const uint4 wblock37BFJNRV = wblock13579BDFHJLNPRTV & 0xFF00FF00u;
        const uint4 wblock159DHLPT = (wblock13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
        // Reinterpret the assembled 16-bit patterns as half-precision values.
        const float4 w048C = static_cast<float4>(as_type<half4>(wblock048CGKOS.xy));
        const float4 wGKOS = static_cast<float4>(as_type<half4>(wblock048CGKOS.zw));
        const float4 w26AE = static_cast<float4>(as_type<half4>(wblock26AEIMQU.xy));
        const float4 wIMQU = static_cast<float4>(as_type<half4>(wblock26AEIMQU.zw));
        const float4 w159D = static_cast<float4>(as_type<half4>(wblock159DHLPT.xy));
        const float4 wHLPT = static_cast<float4>(as_type<half4>(wblock159DHLPT.zw));
        const float4 w37BF = static_cast<float4>(as_type<half4>(wblock37BFJNRV.xy));
        const float4 wJNRV = static_cast<float4>(as_type<half4>(wblock37BFJNRV.zw));

        // Transpose the strided lanes back into sequential weight order.
        const float4 w0123 = (float4) { w048C.x, w159D.x, w26AE.x, w37BF.x };
        const float4 w4567 = (float4) { w048C.y, w159D.y, w26AE.y, w37BF.y };
        const float4 w89AB = (float4) { w048C.z, w159D.z, w26AE.z, w37BF.z };
        const float4 wCDEF = (float4) { w048C.w, w159D.w, w26AE.w, w37BF.w };
        const float4 wGHIJ = (float4) { wGKOS.x, wHLPT.x, wIMQU.x, wJNRV.x };
        const float4 wKLMN = (float4) { wGKOS.y, wHLPT.y, wIMQU.y, wJNRV.y };
        const float4 wOPQR = (float4) { wGKOS.z, wHLPT.z, wIMQU.z, wJNRV.z };
        const float4 wSTUV = (float4) { wGKOS.w, wHLPT.w, wIMQU.w, wJNRV.w };

        // 32 input activations matching this weight block (8 x float4).
        const float4 i0123 = input[0];
        const float4 i4567 = input[1];
        const float4 i89AB = input[2];
        const float4 iCDEF = input[3];
        const float4 iGHIJ = input[4];
        const float4 iKLMN = input[5];
        const float4 iOPQR = input[6];
        const float4 iSTUV = input[7];

        // Two interleaved partial-sum chains to expose instruction parallelism.
        float4 psum0 = i0123 * w0123;
        float4 psum1 = i4567 * w4567;
        psum0 = metal::fma(i89AB, w89AB, psum0);
        psum1 = metal::fma(iCDEF, wCDEF, psum1);
        psum0 = metal::fma(iGHIJ, wGHIJ, psum0);
        psum1 = metal::fma(iKLMN, wKLMN, psum1);
        psum0 = metal::fma(iOPQR, wOPQR, psum0);
        psum1 = metal::fma(iSTUV, wSTUV, psum1);
        // Apply the block scale once per 32-weight block.
        sum4 = metal::fma(psum0, wscale, sum4);
        sum4 = metal::fma(psum1, wscale, sum4);

        weight_blocks += simdgroup_size;
        weight_scales += simdgroup_size;
        input += 8 * simdgroup_size;
    } while (--num_iter != 0);
    // Horizontal reduction: float4 -> float2 -> float, then across the simdgroup.
    const float2 sum2 = sum4.xy + sum4.zw;
    float sum = sum2.x + sum2.y;
    sum = metal::simd_sum(sum);
    if (metal::simd_is_first()) {
        // Lane 0 adds the row bias and writes the single output channel.
        sum += static_cast<float>(*bias);
        *output = sum;
    }
}
|
| 227 |
+
|
| 228 |
+
kernel void gptoss_f32_mf4w_moe_dense_matmul_swiglu(
|
| 229 |
+
constant gptoss_moe_dense_matmul_swiglu_args& params [[ buffer(0) ]],
|
| 230 |
+
const device uint* __restrict__ expert_offsets [[ buffer(1) ]],
|
| 231 |
+
const device float* lhs [[ buffer(2) ]],
|
| 232 |
+
const device uint* weight_blocks [[ buffer(3) ]],
|
| 233 |
+
const device uchar* weight_scales [[ buffer(4) ]],
|
| 234 |
+
const device bfloat* __restrict__ bias [[ buffer(5) ]],
|
| 235 |
+
device float* out [[ buffer(6) ]],
|
| 236 |
+
uint sg_id [[simdgroup_index_in_threadgroup]],
|
| 237 |
+
uint3 threads_per_tg [[threads_per_threadgroup]],
|
| 238 |
+
uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
|
| 239 |
+
uint3 gid [[thread_position_in_grid]],
|
| 240 |
+
uint3 tg_id [[threadgroup_position_in_grid]],
|
| 241 |
+
uint3 local_tid [[thread_position_in_threadgroup]])
|
| 242 |
+
{
|
| 243 |
+
constexpr uint Bm = MOE_DENSE_MATMUL_SWIGLU_Bm;
|
| 244 |
+
constexpr uint Bn = MOE_DENSE_MATMUL_SWIGLU_Bn;
|
| 245 |
+
constexpr uint Bk = MOE_DENSE_MATMUL_SWIGLU_Bk;
|
| 246 |
+
constexpr uint Sg_Bm = MOE_DENSE_MATMUL_SWIGLU_Sg_Bm;
|
| 247 |
+
constexpr uint Sg_Bn = MOE_DENSE_MATMUL_SWIGLU_Sg_Bn;
|
| 248 |
+
|
| 249 |
+
// Assumptions about shapes.
|
| 250 |
+
assert(Bm % 8 == 0);
|
| 251 |
+
assert(Bn % 8 == 0);
|
| 252 |
+
assert(Bk % 8 == 0);
|
| 253 |
+
assert(Sg_Bm % 8 == 0);
|
| 254 |
+
assert(Sg_Bn % 8 == 0);
|
| 255 |
+
assert(Bm % Sg_Bm == 0);
|
| 256 |
+
assert(Bn % Sg_Bn == 0);
|
| 257 |
+
|
| 258 |
+
const uint K = params.k;
|
| 259 |
+
const uint N = params.n;
|
| 260 |
+
const uint M = expert_offsets[tg_id.z + 1] - expert_offsets[tg_id.z];
|
| 261 |
+
assert((K % 32) == 0);
|
| 262 |
+
assert((K % 8) == 0);
|
| 263 |
+
assert(N % Bn == 0);
|
| 264 |
+
assert(K % Bk == 0);
|
| 265 |
+
// Get row and col tg.
|
| 266 |
+
const uint row_tg = tg_id.y;
|
| 267 |
+
const uint col_tg = tg_id.x;
|
| 268 |
+
// Get row and col local tid.
|
| 269 |
+
const uint row_tg_offset = row_tg * Bm;
|
| 270 |
+
const uint col_tg_offset = col_tg * Bn;
|
| 271 |
+
if (row_tg_offset >= M || col_tg_offset >= N) {
|
| 272 |
+
return;
|
| 273 |
+
}
|
| 274 |
+
// Move lhs and output according to the passed offset.
|
| 275 |
+
const uint expert_offset = expert_offsets[tg_id.z];
|
| 276 |
+
lhs += expert_offset * K;
|
| 277 |
+
const uint N_output = N / 2;
|
| 278 |
+
out += expert_offset * N_output;
|
| 279 |
+
|
| 280 |
+
const uint S = params.weight_blocks_expert_stride_bytes;
|
| 281 |
+
const uint S_scales = params.weight_scales_expert_stride_bytes;
|
| 282 |
+
const uint S_bias = params.bias_expert_stride_bytes;
|
| 283 |
+
|
| 284 |
+
const device char* wb0 = reinterpret_cast<const device char*>(weight_blocks);
|
| 285 |
+
const device char* sc0 = reinterpret_cast<const device char*>(weight_scales);
|
| 286 |
+
const device char* bi0 = reinterpret_cast<const device char*>(bias);
|
| 287 |
+
|
| 288 |
+
weight_blocks = reinterpret_cast<const device uint*>(wb0 + tg_id.z * S);
|
| 289 |
+
weight_scales = reinterpret_cast<const device uchar*>(sc0 + tg_id.z * S_scales);
|
| 290 |
+
bias = reinterpret_cast<const device bfloat*>(bi0 + tg_id.z * S_bias);
|
| 291 |
+
|
| 292 |
+
const uint sg_col_count = Bn / Sg_Bn;
|
| 293 |
+
const uint row_sg = sg_id / sg_col_count;
|
| 294 |
+
const uint col_sg = sg_id % sg_col_count;
|
| 295 |
+
|
| 296 |
+
const uint row_sg_offset = row_sg * Sg_Bm;
|
| 297 |
+
const uint col_sg_offset = col_sg * Sg_Bn;
|
| 298 |
+
// Declare threadgroup blocks.
|
| 299 |
+
threadgroup float lhs_block[Bm * Bk];
|
| 300 |
+
// rhs_block will hold the scaled fp32 weights.
|
| 301 |
+
threadgroup float rhs_block[Bn * Bk];
|
| 302 |
+
|
| 303 |
+
constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);
|
| 304 |
+
// Create an array of simdgroup_float8x8 to hold temp results.
|
| 305 |
+
metal::simdgroup_float8x8 OutTiles[temp_result_size];
|
| 306 |
+
for (uint i = 0; i < temp_result_size; i++) {
|
| 307 |
+
OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(0.0);
|
| 308 |
+
}
|
| 309 |
+
// Linear thread id within TG (we launch 1-D TGs)
|
| 310 |
+
const uint lin_tid = local_tid.x;
|
| 311 |
+
const uint thread_count_per_tg = threads_per_tg.x * threads_per_tg.y * threads_per_tg.z;
|
| 312 |
+
|
| 313 |
+
// Iterate over all Bk blocks.
|
| 314 |
+
for (uint k_offset = 0; k_offset < K; k_offset += Bk) {
|
| 315 |
+
constexpr uint lhs_row_stride = Bk;
|
| 316 |
+
constexpr uint lhs_vec_cols = Bk / 4;
|
| 317 |
+
constexpr uint lhs_vec_total = Bm * lhs_vec_cols;
|
| 318 |
+
|
| 319 |
+
const uint LHS_ITERS = ceil_div(lhs_vec_total, thread_count_per_tg);
|
| 320 |
+
|
| 321 |
+
// #pragma clang loop unroll(full)
|
| 322 |
+
for (uint t = 0; t < LHS_ITERS; ++t) {
|
| 323 |
+
const uint i = t * thread_count_per_tg + lin_tid;
|
| 324 |
+
if (i < lhs_vec_total) {
|
| 325 |
+
const uint r = i / lhs_vec_cols;
|
| 326 |
+
const uint c4 = i % lhs_vec_cols;
|
| 327 |
+
|
| 328 |
+
const uint gr = row_tg_offset + r;
|
| 329 |
+
const uint gc4 = (k_offset / 4) + c4;
|
| 330 |
+
|
| 331 |
+
threadgroup float4* dst4 =
|
| 332 |
+
reinterpret_cast<threadgroup float4*>(lhs_block + r * lhs_row_stride + (c4 << 2));
|
| 333 |
+
if (gr < M) {
|
| 334 |
+
const device float4* src4 =
|
| 335 |
+
reinterpret_cast<const device float4*>(lhs + gr * K + (gc4 << 2));
|
| 336 |
+
|
| 337 |
+
*dst4 = *src4;
|
| 338 |
+
} else {
|
| 339 |
+
*dst4 = float4(0.0);
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
// Load weights with vector loads.
|
| 345 |
+
constexpr uint rhs_row_stride = Bk;
|
| 346 |
+
constexpr uint weights_per_elem = 8;
|
| 347 |
+
constexpr uint rhs_loads_per_col = Bk / weights_per_elem;
|
| 348 |
+
constexpr uint rhs_loads_total = Bn * rhs_loads_per_col;
|
| 349 |
+
const uint RHS_ITERS = ceil_div(rhs_loads_total, thread_count_per_tg);
|
| 350 |
+
// #pragma clang loop unroll(full)
|
| 351 |
+
for (uint t = 0; t < RHS_ITERS; ++t) {
|
| 352 |
+
const uint i = t * thread_count_per_tg + lin_tid;
|
| 353 |
+
if (i < rhs_loads_total) {
|
| 354 |
+
const uint r = i / rhs_loads_per_col;
|
| 355 |
+
const uint c = i % rhs_loads_per_col;
|
| 356 |
+
|
| 357 |
+
const uint gr = col_tg_offset + r;
|
| 358 |
+
const uint gc = (k_offset / weights_per_elem) + c;
|
| 359 |
+
const uint gc_scale = (k_offset / 32) + (c >> 2);
|
| 360 |
+
|
| 361 |
+
const uint wblock = weight_blocks[gr * (K / weights_per_elem) + gc];
|
| 362 |
+
const float scale =
|
| 363 |
+
as_type<float>(static_cast<uint>(weight_scales[gr * (K / 32) + gc_scale]) << 23);
|
| 364 |
+
uint wblock0246 = (wblock + wblock);
|
| 365 |
+
uint wblock1357 = (wblock >> 3);
|
| 366 |
+
wblock0246 &= 0x1E1E1E1Eu;
|
| 367 |
+
wblock1357 &= 0x1E1E1E1Eu;
|
| 368 |
+
|
| 369 |
+
wblock0246 += 0x70707070u;
|
| 370 |
+
wblock1357 += 0x70707070u;
|
| 371 |
+
wblock0246 &= 0x8E8E8E8Eu;
|
| 372 |
+
wblock1357 &= 0x8E8E8E8Eu;
|
| 373 |
+
|
| 374 |
+
uint wblock26 = (wblock0246) & 0xFF00FF00u;
|
| 375 |
+
uint wblock04 = ((wblock0246 << 8)) & 0xFF00FF00u;
|
| 376 |
+
uint wblock37 = (wblock1357) & 0xFF00FF00u;
|
| 377 |
+
uint wblock15 = ((wblock1357 << 8)) & 0xFF00FF00u;
|
| 378 |
+
|
| 379 |
+
half4 wblock0426 = as_type<half4>(uint2(wblock04, wblock26));
|
| 380 |
+
half4 wblock1537 = as_type<half4>(uint2(wblock15, wblock37));
|
| 381 |
+
|
| 382 |
+
// Convert to float scalars and apply scale
|
| 383 |
+
const float w0 = float(wblock0426.x) * scale;
|
| 384 |
+
const float w1 = float(wblock1537.x) * scale;
|
| 385 |
+
const float w2 = float(wblock0426.z) * scale;
|
| 386 |
+
const float w3 = float(wblock1537.z) * scale;
|
| 387 |
+
const float w4 = float(wblock0426.y) * scale;
|
| 388 |
+
const float w5 = float(wblock1537.y) * scale;
|
| 389 |
+
const float w6 = float(wblock0426.w) * scale;
|
| 390 |
+
const float w7 = float(wblock1537.w) * scale;
|
| 391 |
+
const uint rhs_offset = r * rhs_row_stride + c * 8;
|
| 392 |
+
rhs_block[rhs_offset] = w0;
|
| 393 |
+
rhs_block[rhs_offset + 1] = w1;
|
| 394 |
+
rhs_block[rhs_offset + 2] = w2;
|
| 395 |
+
rhs_block[rhs_offset + 3] = w3;
|
| 396 |
+
rhs_block[rhs_offset + 4] = w4;
|
| 397 |
+
rhs_block[rhs_offset + 5] = w5;
|
| 398 |
+
rhs_block[rhs_offset + 6] = w6;
|
| 399 |
+
rhs_block[rhs_offset + 7] = w7;
|
| 400 |
+
}
|
| 401 |
+
}
|
| 402 |
+
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
|
| 403 |
+
#pragma clang loop unroll(full)
|
| 404 |
+
for (uint k = 0; k < Bk; k += 8) {
|
| 405 |
+
#pragma clang loop unroll(full)
|
| 406 |
+
for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
|
| 407 |
+
const uint row_index_in_out_tile = m_subtile_ / 8;
|
| 408 |
+
metal::simdgroup_float8x8 lhs_frag;
|
| 409 |
+
|
| 410 |
+
simdgroup_load(lhs_frag, lhs_block, Bk, ulong2(k, m_subtile_ + row_sg_offset));
|
| 411 |
+
#pragma clang loop unroll(full)
|
| 412 |
+
for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
|
| 413 |
+
const uint col_index_in_out_tile = n_subtile_ / 8;
|
| 414 |
+
const uint current_index_out_tile =
|
| 415 |
+
row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
|
| 416 |
+
metal::simdgroup_float8x8 rhs_frag;
|
| 417 |
+
simdgroup_load(rhs_frag, rhs_block, Bk, ulong2(k, n_subtile_ + col_sg_offset), true);
|
| 418 |
+
|
| 419 |
+
simdgroup_multiply_accumulate(OutTiles[current_index_out_tile], lhs_frag, rhs_frag,
|
| 420 |
+
OutTiles[current_index_out_tile]);
|
| 421 |
+
}
|
| 422 |
+
}
|
| 423 |
+
}
|
| 424 |
+
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
// Epilogue.
|
| 428 |
+
threadgroup float scratch[Bm * Bn];
|
| 429 |
+
#pragma clang loop unroll(full)
|
| 430 |
+
for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
|
| 431 |
+
const uint col_index_in_out_tile = n_subtile_ / 8;
|
| 432 |
+
const uint local_col_offset = col_sg_offset + n_subtile_;
|
| 433 |
+
#pragma clang loop unroll(full)
|
| 434 |
+
for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
|
| 435 |
+
const uint row_index_in_out_tile = m_subtile_ / 8;
|
| 436 |
+
const uint local_row_offset = row_sg_offset + m_subtile_;
|
| 437 |
+
const uint current_index_out_tile =
|
| 438 |
+
row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
|
| 439 |
+
simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,
|
| 440 |
+
ulong2(local_col_offset, local_row_offset));
|
| 441 |
+
}
|
| 442 |
+
}
|
| 443 |
+
threadgroup float bias_tile[Bn];
|
| 444 |
+
// TODO(ibahmed): vectorize these loads an maybe unroll the loop.
|
| 445 |
+
for (uint c_local = local_tid.x; c_local < Bn; c_local += thread_count_per_tg) {
|
| 446 |
+
const uint c_global = col_tg_offset + c_local;
|
| 447 |
+
bias_tile[c_local] = (c_global < N) ? static_cast<float>(bias[c_global]) : 0.0f;
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
threadgroup_barrier(metal::mem_flags::mem_threadgroup);
|
| 451 |
+
const float alpha = 1.702f;
|
| 452 |
+
// TODO(ibahmed): vectorize these stores and maybe unroll the loop.
|
| 453 |
+
for (uint idx = local_tid.x; idx < Bm * Bn / 2; idx += thread_count_per_tg) {
|
| 454 |
+
const uint idx_swish = idx * 2;
|
| 455 |
+
const uint r = idx_swish / Bn;
|
| 456 |
+
const uint c_swish = idx_swish % Bn;
|
| 457 |
+
|
| 458 |
+
const uint out_row = row_tg_offset + r;
|
| 459 |
+
const uint out_col = (col_tg_offset / 2) + (c_swish / 2);
|
| 460 |
+
|
| 461 |
+
if (out_row < M && out_col < N_output) {
|
| 462 |
+
float acc_swish = scratch[idx_swish] + bias_tile[c_swish];
|
| 463 |
+
float acc_linear = scratch[idx_swish + 1] + bias_tile[c_swish + 1];
|
| 464 |
+
const float swish = metal::min(acc_swish, params.swiglu_max);
|
| 465 |
+
const float linear = metal::clamp(acc_linear, params.swiglu_min, params.swiglu_max);
|
| 466 |
+
const float swish_y = swish / (1.0f + metal::precise::exp(-alpha * swish));
|
| 467 |
+
const float swiglu_y = metal::fma(swish_y, linear, swish_y);
|
| 468 |
+
out[out_row * N_output + out_col] = swiglu_y;
|
| 469 |
+
}
|
| 470 |
+
}
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
// Dense MoE matmul: out[M x N] = lhs[M x K] * W[K x N]^T + bias, for one expert
// per tg_id.z. Weights are 4-bit codes packed 8-per-uint ("mf4"), with one
// uchar scale per 32 weights along K. Activations and output are fp32.
//
// Grid mapping:
//   tg_id.z — expert slot; rows for this expert are the half-open range
//             [expert_offsets[z], expert_offsets[z + 1]) of lhs/out.
//   tg_id.y — Bm-row output tile, tg_id.x — Bn-column output tile.
// Each threadgroup cooperatively stages an lhs tile and a dequantized rhs tile
// in threadgroup memory, then simdgroups accumulate 8x8 sub-tiles via
// simdgroup_multiply_accumulate.
//
// NOTE(review): the early `return` below is uniform per threadgroup (it depends
// only on tg_id), so no thread diverges across the threadgroup_barrier calls.
kernel void gptoss_f32_mf4w_moe_dense_matmul(
    constant gptoss_moe_dense_matmul_args& params [[ buffer(0) ]],
    const device uint* __restrict__ expert_offsets [[ buffer(1) ]],
    const device float* lhs [[ buffer(2) ]],
    const device uint* weight_blocks [[ buffer(3) ]],
    const device uchar* weight_scales [[ buffer(4) ]],
    const device bfloat* __restrict__ bias [[ buffer(5) ]],
    device float* out [[ buffer(6) ]],
    uint sg_id [[simdgroup_index_in_threadgroup]],
    uint3 threads_per_tg [[threads_per_threadgroup]],
    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 tg_id [[threadgroup_position_in_grid]],
    uint3 local_tid [[thread_position_in_threadgroup]])
{
    // Tile sizes: Bm x Bn output tile per threadgroup, Sg_Bm x Sg_Bn per
    // simdgroup, Bk K-slab staged per iteration. All must be multiples of 8
    // to align with the 8x8 simdgroup matrix primitives.
    const uint Bm = MOE_DENSE_MATMUL_Bm;
    const uint Bn = MOE_DENSE_MATMUL_Bn;
    const uint Bk = MOE_DENSE_MATMUL_Bk;
    const uint Sg_Bm = MOE_DENSE_MATMUL_Sg_Bm;
    const uint Sg_Bn = MOE_DENSE_MATMUL_Sg_Bn;
    assert(Bm % 8 == 0);
    assert(Bn % 8 == 0);
    assert(Bk % 8 == 0);
    assert(Sg_Bm % 8 == 0);
    assert(Sg_Bn % 8 == 0);
    assert(Bm % Sg_Bm == 0);
    assert(Bn % Sg_Bn == 0);

    const uint K = params.k;
    const uint N = params.n;
    // Number of rows (tokens) routed to this expert.
    const uint M = expert_offsets[tg_id.z + 1] - expert_offsets[tg_id.z];
    // K must cover whole scale groups (32) and whole weight blocks (8), and
    // N/K must tile exactly — only M is allowed a ragged edge.
    assert((K % 32) == 0);
    assert((K % 8) == 0);
    assert(N % Bn == 0);
    assert(K % Bk == 0);
    // Get row and col tg.
    const uint row_tg = tg_id.y;
    const uint col_tg = tg_id.x;
    // Get row and col local tid.
    const uint row_tg_offset = row_tg * Bm;
    const uint col_tg_offset = col_tg * Bn;
    if (row_tg_offset >= M || col_tg_offset >= N) {
        return;
    }
    // Move lhs and output according to the passed offset.
    const uint expert_offset = expert_offsets[tg_id.z];
    lhs += expert_offset * K;
    out += expert_offset * N;

    // Per-expert byte strides for the three weight tensors.
    const uint S = params.weight_blocks_expert_stride_bytes;
    const uint S_scales = params.weight_scales_expert_stride_bytes;
    const uint S_bias = params.bias_expert_stride_bytes;

    const device char* wb0 = reinterpret_cast<const device char*>(weight_blocks);
    const device char* sc0 = reinterpret_cast<const device char*>(weight_scales);
    const device char* bi0 = reinterpret_cast<const device char*>(bias);

    // Rebase the weight pointers onto this expert's slab.
    weight_blocks = reinterpret_cast<const device uint*>(wb0 + tg_id.z * S);
    weight_scales = reinterpret_cast<const device uchar*>(sc0 + tg_id.z * S_scales);
    bias = reinterpret_cast<const device bfloat*>(bi0 + tg_id.z * S_bias);

    // Arrange simdgroups into a (Bm/Sg_Bm) x (Bn/Sg_Bn) grid over the tile.
    const uint sg_col_count = Bn / Sg_Bn;
    const uint row_sg = sg_id / sg_col_count;
    const uint col_sg = sg_id % sg_col_count;

    const uint row_sg_offset = row_sg * Sg_Bm;
    const uint col_sg_offset = col_sg * Sg_Bn;
    // Declare threadgroup blocks.
    threadgroup float lhs_block[Bm * Bk];
    // rhs_block will hold the scaled fp32 weights.
    threadgroup float rhs_block[Bn * Bk];

    constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);
    // Create an array of simdgroup_float8x8 to hold temp results.
    metal::simdgroup_float8x8 OutTiles[temp_result_size];
    for (uint i = 0; i < temp_result_size; i++) {
        OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(0.0);
    }
    // Linear thread id within TG (we launch 1-D TGs)
    const uint lin_tid = local_tid.x;

    const uint thread_count_per_tg = threads_per_tg.x * threads_per_tg.y * threads_per_tg.z;
    // Iterate over all Bk blocks.
    for (uint k_offset = 0; k_offset < K; k_offset += Bk) {
        // --- Stage the lhs tile (float4 vector loads, zero-pad ragged rows) ---
        constexpr uint lhs_row_stride = Bk;
        constexpr uint lhs_vec_cols = Bk / 4;
        constexpr uint lhs_vec_total = Bm * lhs_vec_cols;

        const uint LHS_ITERS = ceil_div(lhs_vec_total, thread_count_per_tg);

        for (uint t = 0; t < LHS_ITERS; ++t) {
            const uint i = t * thread_count_per_tg + lin_tid;
            if (i < lhs_vec_total) {
                const uint r = i / lhs_vec_cols;
                const uint c4 = i % lhs_vec_cols;

                const uint gr = row_tg_offset + r;
                const uint gc4 = (k_offset / 4) + c4;

                threadgroup float4* dst4 =
                    reinterpret_cast<threadgroup float4*>(lhs_block + r * lhs_row_stride + (c4 << 2));
                if (gr < M) {
                    const device float4* src4 =
                        reinterpret_cast<const device float4*>(lhs + gr * K + (gc4 << 2));

                    *dst4 = *src4;
                } else {
                    // Rows past M contribute zeros so the matmul result is unaffected.
                    *dst4 = float4(0.0);
                }
            }
        }

        // Load weights with vector loads.
        // Each uint holds 8 consecutive 4-bit weight codes; each uchar scale
        // covers 32 weights (4 packed uints) along K.
        constexpr uint rhs_row_stride = Bk;
        constexpr uint weights_per_elem = 8;
        constexpr uint rhs_loads_per_col = Bk / weights_per_elem;
        constexpr uint rhs_loads_total = Bn * rhs_loads_per_col;
        const uint RHS_ITERS = ceil_div(rhs_loads_total, thread_count_per_tg);
        // #pragma clang loop unroll(full)
        for (uint t = 0; t < RHS_ITERS; ++t) {
            const uint i = t * thread_count_per_tg + lin_tid;
            if (i < rhs_loads_total) {
                const uint r = i / rhs_loads_per_col;
                const uint c = i % rhs_loads_per_col;

                const uint gr = col_tg_offset + r;
                const uint gc = (k_offset / weights_per_elem) + c;
                const uint gc_scale = (k_offset / 32) + (c >> 2);

                const uint wblock = weight_blocks[gr * (K / weights_per_elem) + gc];
                // Placing the scale byte directly into the float exponent field
                // (<< 23) yields a power-of-two scale factor.
                const float scale =
                    as_type<float>(static_cast<uint>(weight_scales[gr * (K / 32) + gc_scale]) << 23);

                // Bit-twiddled fp4 decode: split even/odd nibbles, then
                // add/mask them into IEEE half bit patterns (sign + rebuilt
                // exponent) so 8 codes become 8 halves without a lookup table.
                // NOTE(review): assumes the fp4 (e2m1-style) code layout this
                // arithmetic implies — confirm against the weight packer.
                uint wblock0246 = (wblock + wblock);
                uint wblock1357 = (wblock >> 3);
                wblock0246 &= 0x1E1E1E1Eu;
                wblock1357 &= 0x1E1E1E1Eu;

                wblock0246 += 0x70707070u;
                wblock1357 += 0x70707070u;
                wblock0246 &= 0x8E8E8E8Eu;
                wblock1357 &= 0x8E8E8E8Eu;

                // Spread the per-byte half patterns into 16-bit lanes.
                uint wblock26 = (wblock0246) & 0xFF00FF00u;
                uint wblock04 = ((wblock0246 << 8)) & 0xFF00FF00u;
                uint wblock37 = (wblock1357) & 0xFF00FF00u;
                uint wblock15 = ((wblock1537 << 8)) & 0xFF00FF00u;

                half4 wblock0426 = as_type<half4>(uint2(wblock04, wblock26));
                half4 wblock1537 = as_type<half4>(uint2(wblock15, wblock37));

                // Apply the shared scale and store the 8 dequantized weights
                // back in original (0..7) order.
                const float w0 = float(wblock0426.x) * scale;
                const float w1 = float(wblock1537.x) * scale;
                const float w2 = float(wblock0426.z) * scale;
                const float w3 = float(wblock1537.z) * scale;
                const float w4 = float(wblock0426.y) * scale;
                const float w5 = float(wblock1537.y) * scale;
                const float w6 = float(wblock0426.w) * scale;
                const float w7 = float(wblock1537.w) * scale;
                const uint rhs_offset = r * rhs_row_stride + c * 8;
                rhs_block[rhs_offset] = w0;
                rhs_block[rhs_offset + 1] = w1;
                rhs_block[rhs_offset + 2] = w2;
                rhs_block[rhs_offset + 3] = w3;
                rhs_block[rhs_offset + 4] = w4;
                rhs_block[rhs_offset + 5] = w5;
                rhs_block[rhs_offset + 6] = w6;
                rhs_block[rhs_offset + 7] = w7;
            }
        }
        // Both staging loops must finish before any simdgroup reads the tiles.
        threadgroup_barrier(metal::mem_flags::mem_threadgroup);
#pragma clang loop unroll(full)
        for (uint k = 0; k < Bk; k += 8) {
#pragma clang loop unroll(full)
            for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
                const uint row_index_in_out_tile = m_subtile_ / 8;
                metal::simdgroup_float8x8 lhs_frag;

                simdgroup_load(lhs_frag, lhs_block, Bk, ulong2(k, m_subtile_ + row_sg_offset));
#pragma clang loop unroll(full)
                for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
                    const uint col_index_in_out_tile = n_subtile_ / 8;
                    const uint current_index_out_tile =
                        row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
                    metal::simdgroup_float8x8 rhs_frag;
                    // `true` transposes: rhs_block is stored column-major
                    // relative to the output (N rows of Bk weights).
                    simdgroup_load(rhs_frag, rhs_block, Bk, ulong2(k, n_subtile_ + col_sg_offset), true);
                    simdgroup_multiply_accumulate(OutTiles[current_index_out_tile], lhs_frag, rhs_frag,
                                                  OutTiles[current_index_out_tile]);
                }
            }
        }
        // Keep the tiles stable until every simdgroup is done before restaging.
        threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    }

    // Epilogue.
    // Spill accumulators to threadgroup memory so bias-add and the bounds-
    // checked store can be done with plain per-thread indexing.
    threadgroup float scratch[Bm * Bn];
#pragma clang loop unroll(full)
    for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
        const uint col_index_in_out_tile = n_subtile_ / 8;
        const uint local_col_offset = col_sg_offset + n_subtile_;
#pragma clang loop unroll(full)
        for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
            const uint row_index_in_out_tile = m_subtile_ / 8;
            const uint local_row_offset = row_sg_offset + m_subtile_;
            const uint current_index_out_tile =
                row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
            simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,
                            ulong2(local_col_offset, local_row_offset));
        }
    }
    // Stage this tile's bias slice (bf16 -> fp32) once.
    threadgroup float bias_tile[Bn];
    for (uint c_local = local_tid.x; c_local < Bn; c_local += thread_count_per_tg) {
        const uint c_global = col_tg_offset + c_local;
        bias_tile[c_local] = (c_global < N) ? static_cast<float>(bias[c_global]) : 0.0f;
    }

    // Single barrier covers both the scratch spill and the bias staging.
    threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) {
        const uint r = idx / Bn;
        const uint c = idx % Bn;

        const uint out_row = row_tg_offset + r;
        const uint out_col = col_tg_offset + c;

        if (out_row < M && out_col < N) {
            float acc = scratch[idx] + bias_tile[c];
            out[out_row * N + out_col] = acc;
        }
    }
}
|
gptoss_kernels/source/random.metal
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_integer>
|
| 2 |
+
#include <metal_math>
|
| 3 |
+
|
| 4 |
+
#include <internal/kernel-args.h>
|
| 5 |
+
|
| 6 |
+
#pragma METAL fp math_mode(safe)
|
| 7 |
+
#pragma METAL fp contract(off)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
// Squares counter-based RNG (Widynski): four rounds of square-and-rotate over
// a 64-bit counter/key pair, returning the high 32 bits of the final state.
// Pure function of (offset, seed); the same pair always yields the same word.
inline static uint rng_squares32(ulong offset, ulong seed) {
    const ulong keyed = offset * seed;
    const ulong keyed_plus = keyed + seed;

    /* Round 1: square, add counter, swap halves. */
    ulong state = keyed * keyed + keyed;
    state = metal::rotate(state, 32ul);

    /* Round 2 */
    state = state * state + keyed_plus;
    state = metal::rotate(state, 32ul);

    /* Round 3 */
    state = state * state + keyed;
    state = metal::rotate(state, 32ul);

    /* Round 4: no final rotate — take the upper 32 bits directly. */
    state = state * state + keyed_plus;
    return as_type<uint2>(state).y;
}
|
| 30 |
+
|
| 31 |
+
// Fills `output` with pseudo-random 32-bit words from the Squares RNG.
// Each threadgroup owns a contiguous slice of args.num_vecs_per_threadgroup
// elements (the last slice may be shorter); threads within a group stride by
// threadgroup_size so writes are coalesced.
//
// Fix: guard the iteration-count computation. The original computed
// (threadgroup_end - thread_start + ...) unconditionally; when the tail slice
// holds fewer elements than threadgroup_size, threads with
// thread_start >= threadgroup_end underflowed the unsigned subtraction and
// looped ~2^32 times, writing out of bounds. Behavior is unchanged whenever
// thread_start < threadgroup_end.
kernel void gptoss_u32_fill_random(
    constant gptoss_u32_fill_random_args& args [[ buffer(0) ]],
    device uint* output [[ buffer(1) ]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threadgroup_size [[ threads_per_threadgroup ]])
{
    const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
    const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
    const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
    const ulong thread_start = threadgroup_start + tid;
    // Threads beyond the (possibly short) tail slice do no work.
    uint num_iter = 0;
    if (thread_start < threadgroup_end) {
        num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
    }

    output += thread_start;
    // The RNG counter is the global element index plus a caller-chosen offset,
    // so repeated dispatches can continue the same stream.
    ulong offset = args.offset + thread_start;
    for (; num_iter != 0; num_iter--) {
        *output = rng_squares32(offset, args.seed);
        output += threadgroup_size;
        offset += threadgroup_size;
    }
}
|
| 52 |
+
|
| 53 |
+
// Fills `output` with pseudo-random floats: each 32-bit RNG word is
// reinterpreted as a signed int and affinely mapped via
// fma(float(word), args.scale, args.bias). Work distribution matches
// gptoss_u32_fill_random: one contiguous slice per threadgroup, threads
// striding by threadgroup_size.
//
// Fix: guard the iteration-count computation against unsigned underflow when
// thread_start >= threadgroup_end (short tail slice), which previously caused
// a huge num_iter and out-of-bounds writes. Behavior is unchanged for threads
// that have work.
kernel void gptoss_f32_fill_random(
    constant gptoss_f32_fill_random_args& args [[ buffer(0) ]],
    device float* output [[ buffer(1) ]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threadgroup_size [[ threads_per_threadgroup ]])
{
    const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
    const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
    const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
    const ulong thread_start = threadgroup_start + tid;
    // Threads beyond the (possibly short) tail slice do no work.
    uint num_iter = 0;
    if (thread_start < threadgroup_end) {
        num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
    }

    output += thread_start;
    ulong offset = args.offset + thread_start;
    for (; num_iter != 0; num_iter--) {
        const uint word = rng_squares32(offset, args.seed);
        // Signed reinterpretation centers the distribution around zero before
        // the affine map.
        *output = metal::fma(static_cast<float>(as_type<int>(word)), args.scale, args.bias);
        output += threadgroup_size;
        offset += threadgroup_size;
    }
}
|
| 75 |
+
|
| 76 |
+
// bf16 variant of gptoss_f32_fill_random: same RNG and affine map, computed in
// fp32 and rounded to bfloat on store. Reuses gptoss_f32_fill_random_args
// since the parameters are identical.
//
// Fix: guard the iteration-count computation against unsigned underflow when
// thread_start >= threadgroup_end (short tail slice), which previously caused
// a huge num_iter and out-of-bounds writes. Behavior is unchanged for threads
// that have work.
kernel void gptoss_bf16_fill_random(
    constant gptoss_f32_fill_random_args& args [[ buffer(0) ]],
    device bfloat* output [[ buffer(1) ]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threadgroup_size [[ threads_per_threadgroup ]])
{
    const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
    const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
    const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
    const ulong thread_start = threadgroup_start + tid;
    // Threads beyond the (possibly short) tail slice do no work.
    uint num_iter = 0;
    if (thread_start < threadgroup_end) {
        num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
    }

    output += thread_start;
    ulong offset = args.offset + thread_start;
    for (; num_iter != 0; num_iter--) {
        const uint word = rng_squares32(offset, args.seed);
        *output = static_cast<bfloat>(metal::fma(static_cast<float>(as_type<int>(word)), args.scale, args.bias));
        output += threadgroup_size;
        offset += threadgroup_size;
    }
}
|
gptoss_kernels/source/rmsnorm.metal
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_compute>
|
| 2 |
+
#include <metal_math>
|
| 3 |
+
#include <metal_simdgroup>
|
| 4 |
+
|
| 5 |
+
#include <internal/kernel-args.h>
|
| 6 |
+
|
| 7 |
+
#pragma METAL fp math_mode(safe)
|
| 8 |
+
#pragma METAL fp contract(off)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
// RMSNorm over one row of args.num_vecs float4 vectors per threadgroup:
// out = x * rsqrt(mean(x^2) + eps) * weight, with bf16 weights upcast to fp32.
// Honors the cooperative-abort flag in `control`.
//
// Fix: the cross-simdgroup reduction previously read all 32 slots of
// threadgroup_buffer, but only one slot per resident simdgroup is ever
// written — with a threadgroup smaller than 1024 threads, uninitialized
// threadgroup memory was summed into the norm. Reads are now limited to the
// actual simdgroup count (zero otherwise), which is identical behavior at the
// original 1024-thread size and correct for any multiple of 32.
[[max_total_threads_per_threadgroup(1024)]]
kernel void gptoss_f32_bf16w_rmsnorm(
    constant gptoss_rmsnorm_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device bfloat4* weights [[ buffer(2) ]],
    device float4* output [[ buffer(3) ]],
    const device gptoss_control* control [[ buffer(4) ]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threadgroup_size [[ threads_per_threadgroup ]])
{
    const uint simdgroup_size = 32;
    threadgroup float threadgroup_buffer[32];
    if (control->abort != 0) {
        return;
    }

    // One row (token) per threadgroup.
    input += gid * args.num_vecs;
    output += gid * args.num_vecs;

    // Per-thread partial sum of squares, 4 lanes at a time.
    float4 sumsq4 = 0.0f;
    for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
        const float4 val = input[i];
        sumsq4 = metal::fma(val, val, sumsq4);
    }

    // Tree-reduce sumsq within thread, then all-reduce within threadgroup.
    const float2 sumsq2 = sumsq4.xy + sumsq4.zw;
    float sumsq = sumsq2.x + sumsq2.y;
    // Stage one partial per simdgroup, then have every simdgroup re-reduce the
    // staged partials so all threads end up with the full sum.
    // Requires simdgroups of 32 threads and threadgroup_size a multiple of 32.
    sumsq = metal::simd_sum(sumsq);
    if (metal::simd_is_first()) {
        const uint simdgroup_idx = tid / simdgroup_size;
        threadgroup_buffer[simdgroup_idx] = sumsq;
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    const uint num_simdgroups = (threadgroup_size + simdgroup_size - 1) / simdgroup_size;
    const uint simdgroup_tid = tid % simdgroup_size;
    // Only slots [0, num_simdgroups) were written; treat the rest as zero.
    sumsq = (simdgroup_tid < num_simdgroups) ? threadgroup_buffer[simdgroup_tid] : 0.0f;
    sumsq = metal::simd_sum(sumsq);

    const float avgsq = sumsq / args.num_channels;
    const float scale = metal::precise::rsqrt(avgsq + args.epsilon);
    for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
        const float4 val = input[i] * scale;
        const float4 weight_val = static_cast<float4>(weights[i]);
        output[i] = val * weight_val;
    }
}
|
gptoss_kernels/source/rope.metal
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_common>
|
| 2 |
+
#include <metal_math>
|
| 3 |
+
|
| 4 |
+
#include <internal/kernel-args.h>
|
| 5 |
+
|
| 6 |
+
#pragma METAL fp math_mode(safe)
|
| 7 |
+
#pragma METAL fp contract(off)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
// Each thread handles 2 head elements.
|
| 11 |
+
// Each simdgroup handles one head (64 head elements).
|
| 12 |
+
|
| 13 |
+
// In-place YaRN-scaled rotary position embedding.
// Each thread rotates one complex pair (float2) of head activations:
// gid.x selects the pair within the row, gid.y selects the token.
// Each simdgroup covers one 64-dim head (32 pairs).
kernel void gptoss_f32_rope(
    constant gptoss_rope_args& args [[ buffer(0) ]],
    device float2* activations [[ buffer(1) ]],
    const device gptoss_control* control [[ buffer(2) ]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint num_head_dims = 64;
    // Cooperative abort: bail out before touching activations.
    if (control->abort != 0) {
        return;
    }

    // Pair index within the head, as float for the frequency formulas.
    const float pair_idx = static_cast<float>(gid.x % (num_head_dims / 2));
    const uint abs_token = args.token_offset + gid.y;
    activations += gid.y * args.token_stride + gid.x;

    const float2 pair = *activations;
    // YaRN blends extrapolated and interpolated inverse frequencies with a
    // per-dimension ramp `blend` clamped to [0, 1].
    const float inv_extrapolation_freq = metal::precise::exp(pair_idx * args.freq_scale);
    const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
    const float blend = metal::saturate(metal::fma(pair_idx, args.yarn_scale, args.yarn_offset));
    const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, blend);

    // Rotation angle for this token/dimension; sincos computes both terms,
    // each attenuated by the YaRN magnitude multiplier.
    const float phi = static_cast<float>(abs_token) * inv_freq;
    const float attn_factor = args.yarn_multiplier;
    float cos_term;
    const float sin_term = metal::precise::sincos(phi, cos_term) * attn_factor;
    cos_term *= attn_factor;

    // Complex multiply (re + i*im) * (cos + i*sin), written back in place.
    const float rotated_re = pair.x * cos_term - pair.y * sin_term;
    const float rotated_im = pair.x * sin_term + pair.y * cos_term;
    *activations = (float2) { rotated_re, rotated_im };
}
|
gptoss_kernels/source/sample.metal
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_compute>
|
| 2 |
+
#include <metal_integer>
|
| 3 |
+
#include <metal_math>
|
| 4 |
+
#include <metal_simdgroup>
|
| 5 |
+
|
| 6 |
+
#include <internal/kernel-args.h>
|
| 7 |
+
|
| 8 |
+
#pragma METAL fp math_mode(safe)
|
| 9 |
+
#pragma METAL fp contract(off)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
// Squares counter-based RNG (Widynski), duplicated here for sample.metal:
// four square/rotate rounds keyed by (offset, seed); deterministic per pair.
inline static uint rng_squares32(ulong offset, ulong seed) {
    const ulong ctr = offset * seed;
    const ulong ctr_key = ctr + seed;

    /* Round 1 */
    ulong s = ctr * ctr + ctr;
    s = metal::rotate(s, 32ul);

    /* Round 2 */
    s = s * s + ctr_key;
    s = metal::rotate(s, 32ul);

    /* Round 3 */
    s = s * s + ctr;
    s = metal::rotate(s, 32ul);

    /* Round 4: upper half of the final square is the output word. */
    s = s * s + ctr_key;
    return as_type<uint2>(s).y;
}
|
| 32 |
+
|
| 33 |
+
// Partial softmax over one row of scores, split across threadgroups along x.
// Each threadgroup writes exp((score - max) * temperature) for its slice of
// `prob` and one partial sum-of-exponents into `sum[gid.x]`; a later kernel
// (the sampler) combines the per-threadgroup partial sums.
// gid.y selects the row (token); gid.x selects the slice within the row.
kernel void gptoss_f32_softmax(
    constant gptoss_softmax_args& args [[ buffer(0) ]],
    const device float* score [[ buffer(1) ]],
    const device uint2* argmax [[ buffer(2) ]],
    device float* prob [[ buffer(3) ]],
    device float* sum [[ buffer(4) ]],
    const device gptoss_control* control [[ buffer(5) ]],
    uint tidx [[thread_index_in_threadgroup]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 threadgroup_size [[threads_per_threadgroup]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    // One partial sum per simdgroup; 32 covers the maximum simdgroups per threadgroup.
    threadgroup float threadgroup_sumexp[32];
    if (control->abort != 0) {
        return;
    }

    // Advance to this threadgroup's slice of the row.
    score += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
    prob += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
    sum += gid.y * args.max_threadgroups;

    // argmax[gid.y].y holds the row maximum in an order-preserving integer
    // encoding (produced by the argmax kernel); undo the transform for
    // non-negative patterns to recover the float bits.
    uint max_bits = argmax[gid.y].y;
    if (static_cast<int>(max_bits) >= 0) {
        max_bits ^= 0x7FFFFFFFu;
    }
    const float max_val = as_type<float>(max_bits);
    float sum_exp = 0.0f;
    // Last threadgroup in the row may cover a short tail.
    const uint num_vecs_per_threadgroup = metal::min(args.num_vecs - gid.x * args.num_vecs_per_threadgroup, args.num_vecs_per_threadgroup);
    for (uint i = tidx; i < num_vecs_per_threadgroup; i += threadgroup_size.x) {
        const float score_val = score[i];
        // NOTE(review): temperature multiplies the shifted score, so
        // args.temperature is presumably the *inverse* temperature — confirm
        // against the host-side setup.
        const float prob_val = metal::precise::exp((score_val - max_val) * args.temperature);
        prob[i] = prob_val;
        sum_exp += prob_val;
    }
    // Two-level reduction: within each simdgroup, then across simdgroups.
    sum_exp = metal::simd_sum(sum_exp);
    if (metal::simd_is_first()) {
        threadgroup_sumexp[simdgroup_idx] = sum_exp;
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    if (simdgroup_idx == 0) {
        // Sum-Reduce threadgroup_sumexp
        sum_exp = 0.0f;
        if (simdgroup_tid < num_simdgroups) {
            sum_exp = threadgroup_sumexp[simdgroup_tid];
        }
        sum_exp = metal::simd_sum(sum_exp);
        if (metal::simd_is_first()) {
            // One partial sum per threadgroup, consumed by the sampler.
            sum[gid.x] = sum_exp;
        }
    }
}
|
| 86 |
+
|
| 87 |
+
// Samples one token from the (unnormalized) probabilities produced by the
// softmax kernel, using inverse-CDF sampling. Runs as a single threadgroup:
// first locates the block (softmax threadgroup slice) whose cumulative mass
// crosses the sample point, then scans inside that block for the token.
// All control flow that gates barriers is threadgroup-uniform.
[[max_total_threads_per_threadgroup(1024)]]
kernel void gptoss_f32_sample(
    constant gptoss_sample_args& args [[ buffer(0) ]],
    device const float* prob [[ buffer(1) ]],
    device const float* sum [[ buffer(2) ]],
    device uint* prediction [[ buffer(3) ]],
    device gptoss_control* control [[ buffer(4) ]],
    uint tid [[thread_position_in_threadgroup]],
    uint threadgroup_size [[threads_per_threadgroup]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    threadgroup float threadgroup_sum_buffer[32];
    threadgroup uint threadgroup_idx_buffer[32];
    threadgroup float threadgroup_cumsum_buffer[32];
    if (control->abort != 0) {
        return;
    }

    // Deterministic per-step random word; low 24 bits -> uniform in [0, 1).
    const uint sample_word = rng_squares32(args.rng_offset, args.rng_seed);
    float sample_cdf = static_cast<float>(sample_word & 0x00FFFFFFu) * 0x1.0p-24f;

    // Inclusive prefix sum of the per-block partial sums across the whole
    // threadgroup (simdgroup prefix + cross-simdgroup fixup below).
    float cumsum = 0.0f;
    if (tid < args.num_blocks) {
        cumsum = sum[tid];
    }
    cumsum = metal::simd_prefix_inclusive_sum(cumsum);
    if (simdgroup_tid == 31) {
        // Lane 31 holds the simdgroup total after the inclusive prefix sum.
        threadgroup_sum_buffer[simdgroup_idx] = cumsum;
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;
    if (simdgroup_tid < num_simdgroups) {
        threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];
        // Only totals of preceding simdgroups contribute to this lane's prefix.
        if (simdgroup_tid < simdgroup_idx) {
            threadgroup_cumsum = threadgroup_sum;
        }
    }
    threadgroup_sum = metal::simd_sum(threadgroup_sum);  // grand total of all blocks
    cumsum += metal::simd_sum(threadgroup_cumsum);       // global inclusive prefix

    // Scale the sample into the unnormalized CDF domain; clamp away from zero
    // so a sample of exactly 0 still selects a token.
    sample_cdf *= threadgroup_sum;
    sample_cdf = metal::max(sample_cdf, 0x1.0p-149f);

    // Find the block: the smallest tid where cumsum >= sample_cdf.
    // block_sum ends up (via simd_max) as the cumulative mass of all blocks
    // preceding the chosen one: candidate threads contribute 0, threads
    // strictly before the crossing contribute their (smaller) prefix.
    uint block_idx = args.num_blocks;
    float block_sum = cumsum;
    if (tid >= args.num_blocks - 1) {
        // Fallback: last block is selected if no earlier crossing exists.
        block_idx = args.num_blocks - 1;
        block_sum = 0.0f;
    } else if (cumsum >= sample_cdf) {
        block_idx = tid;
        block_sum = 0.0f;
    }
    block_idx = metal::simd_min(block_idx);
    block_sum = metal::simd_max(block_sum);
    if (simdgroup_tid == 0) {
        threadgroup_idx_buffer[simdgroup_idx] = block_idx;
        threadgroup_cumsum_buffer[simdgroup_idx] = block_sum;
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    // Reduce the per-simdgroup candidates across the threadgroup.
    if (simdgroup_tid < num_simdgroups) {
        block_idx = threadgroup_idx_buffer[simdgroup_tid];
        block_sum = threadgroup_cumsum_buffer[simdgroup_tid];
    }
    block_idx = metal::simd_min(block_idx);
    block_sum = metal::simd_max(block_sum);

    // Scan the chosen block, threadgroup_size tokens per iteration, carrying
    // the running mass in accumulated_sum.
    const uint block_start = args.num_dims_per_block * block_idx;
    const uint block_end = metal::min(block_start + args.num_dims_per_block, args.num_dims);
    uint offset = block_start + tid;
    float accumulated_sum = block_sum;
    uint sample_idx;

    // This loop must be threadgroup-uniform.
    do {
        // Find the token: the smallest offset where cumsum >= sample_cdf.
        float cumsum = 0.0f;
        if (offset < block_end) {
            cumsum = prob[offset];
        }
        // Same two-level inclusive prefix sum as for the blocks above.
        cumsum = metal::simd_prefix_inclusive_sum(cumsum);
        if (simdgroup_tid == 31) {
            threadgroup_sum_buffer[simdgroup_idx] = cumsum;
        }
        metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
        float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;
        if (simdgroup_tid < num_simdgroups) {
            threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];
            if (simdgroup_tid < simdgroup_idx) {
                threadgroup_cumsum = threadgroup_sum;
            }
        }
        threadgroup_sum = metal::simd_sum(threadgroup_sum);
        cumsum += metal::simd_sum(threadgroup_cumsum);
        cumsum += accumulated_sum;

        sample_idx = block_end;
        if (offset >= block_end) {
            // Trigger loop exit, with the last token in the block being sampled if no other candidate was found.
            sample_idx = block_end - 1;
        } else if (cumsum >= sample_cdf) {
            sample_idx = offset;
        }
        sample_idx = metal::simd_min(sample_idx);
        if (simdgroup_tid == 0) {
            threadgroup_idx_buffer[simdgroup_idx] = sample_idx;
        }
        metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
        if (simdgroup_tid < num_simdgroups) {
            sample_idx = threadgroup_idx_buffer[simdgroup_tid];
        }
        sample_idx = metal::simd_min(sample_idx);

        offset += threadgroup_size;
        accumulated_sum += threadgroup_sum;
    } while (sample_idx == block_end);

    if (tid == 0) {
        *prediction = sample_idx;
    }
}
|
gptoss_kernels/source/scatter.metal
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <internal/kernel-args.h>
|
| 2 |
+
#include <metal_integer>
|
| 3 |
+
#include <metal_math>
|
| 4 |
+
#include <metal_stdlib>
|
| 5 |
+
|
| 6 |
+
// Scatters each token's embedding to the rows owned by its top-4 routed
// experts. Grid layout: gid.y = token, gid.x = float4 column within the
// embedding. Destination row for expert slot e is
//   expert_offsets[expert_id] + intra_expert_offsets[token * 4 + e].
// TODO(ibrahim): one float4 per thread amortizes the routing-metadata reads
// poorly; each thread should copy several float4s per token.
kernel void gptoss_f32_scatter_e4(
    constant gptoss_scatter_args& args [[ buffer(0) ]],
    const device float* in [[ buffer(1) ]],
    const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(2) ]],
    const device uint* __restrict__ expert_offsets [[ buffer(3) ]],
    const device uint* __restrict__ intra_expert_offsets [[ buffer(4) ]],
    device float* out [[ buffer(5) ]],
    uint3 gid [[thread_position_in_grid]])
{
    const uint num_tokens = args.tokens;
    const uint experts_per_token = args.active_experts_per_token;
    const uint dim = args.token_stride;
    assert(dim % 4 == 0);
    // Hard coded to top4 for now.
    assert(experts_per_token == 4);

    const uint token = gid.y;
    if (token >= num_tokens) {
        return;
    }
    // Consecutive threads handle consecutive float4 columns of the input row.
    const uint col = gid.x * 4;
    if (col >= dim) {
        return;
    }

    // Load the token's float4 once, then replicate it to all 4 expert rows.
    const float4 value =
        *reinterpret_cast<const device float4*>(in + token * dim + col);
    const uint pred_base = token * experts_per_token;
    for (uint e = 0; e < 4; e++) {
        const uint expert_id = expert_predictions[pred_base + e].expert_id;
        const uint row_out = expert_offsets[expert_id] + intra_expert_offsets[pred_base + e];
        *reinterpret_cast<device float4*>(out + row_out * dim + col) = value;
    }
}
|
gptoss_kernels/source/sdpa.metal
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_geometric>
|
| 2 |
+
#include <metal_integer>
|
| 3 |
+
#include <metal_math>
|
| 4 |
+
#include <metal_compute>
|
| 5 |
+
#include <metal_simdgroup>
|
| 6 |
+
|
| 7 |
+
#include <internal/kernel-args.h>
|
| 8 |
+
|
| 9 |
+
#pragma METAL fp math_mode(safe)
|
| 10 |
+
#pragma METAL fp contract(off)
|
| 11 |
+
|
| 12 |
+
// Each threadgroup handles 8 Q heads / 1 KV head for 1 token.
// Single-token (decode-path) scaled-dot-product attention with head_dim=64
// and 8 query heads per KV head (GQA). Uses a streaming ("online") softmax:
// running max m_i and running denominator l_i are rescaled as new KV tokens
// arrive, so no score buffer is materialized. Each simdgroup walks a strided
// subset of KV tokens; results are merged across simdgroups at the end.
// `s` supplies per-head attention-sink logits used to seed m (with weight 1
// in simdgroup 0, hence l = 1 there and 0 elsewhere).
kernel void gptoss_f32_sdpa_q8_d64(
    constant gptoss_sdpa_args& args [[ buffer(0) ]],
    const device float* q [[ buffer(1) ]],
    const device float* kv [[ buffer(2) ]],
    const device bfloat* s [[ buffer(3) ]],
    device float* output [[ buffer(4) ]],
    const device gptoss_control* control [[ buffer(6) ]],
    threadgroup void* threadgroup_buffer [[ threadgroup(0) ]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    const uint simdgroup_size = 32;
    if (control->abort != 0) {
        return;
    }

    const uint num_q_heads = 64;
    const uint head_dim = 64;
    const uint qmul = 8;  // query heads per KV head

    // K and V for one token are stored back-to-back (K at +0, V at +head_dim).
    const uint token_stride = 2 * head_dim;

    const uint qt = gid.x; // Q token index
    const uint h = gid.y;  // KV head index

    q += qt * args.qkv_dim + h * (qmul * head_dim);
    kv += h * args.kv_stride;
    output += qt * (num_q_heads * head_dim) + h * (qmul * head_dim);

    // Running max per query head, seeded with the head's sink logit.
    float m0 = static_cast<float>(s[h * qmul + 0]);
    float m1 = static_cast<float>(s[h * qmul + 1]);
    float m2 = static_cast<float>(s[h * qmul + 2]);
    float m3 = static_cast<float>(s[h * qmul + 3]);
    float m4 = static_cast<float>(s[h * qmul + 4]);
    float m5 = static_cast<float>(s[h * qmul + 5]);
    float m6 = static_cast<float>(s[h * qmul + 6]);
    float m7 = static_cast<float>(s[h * qmul + 7]);

    // Running softmax denominator: the sink contributes exp(s - m) = 1, and is
    // counted only once (in simdgroup 0) to avoid duplication across simdgroups.
    float l0 = simdgroup_idx == 0 ? 1.0f : 0.0f;
    float l1 = simdgroup_idx == 0 ? 1.0f : 0.0f;
    float l2 = simdgroup_idx == 0 ? 1.0f : 0.0f;
    float l3 = simdgroup_idx == 0 ? 1.0f : 0.0f;
    float l4 = simdgroup_idx == 0 ? 1.0f : 0.0f;
    float l5 = simdgroup_idx == 0 ? 1.0f : 0.0f;
    float l6 = simdgroup_idx == 0 ? 1.0f : 0.0f;
    float l7 = simdgroup_idx == 0 ? 1.0f : 0.0f;

    // Unnormalized weighted V accumulators; each lane owns 2 of the 64 dims.
    float2 out0 = 0.0f;
    float2 out1 = 0.0f;
    float2 out2 = 0.0f;
    float2 out3 = 0.0f;
    float2 out4 = 0.0f;
    float2 out5 = 0.0f;
    float2 out6 = 0.0f;
    float2 out7 = 0.0f;

    // Each lane loads its float2 slice of all 8 query heads.
    float2 q0 = reinterpret_cast<const device float2*>(q + 0 * head_dim)[simdgroup_tid];
    float2 q1 = reinterpret_cast<const device float2*>(q + 1 * head_dim)[simdgroup_tid];
    float2 q2 = reinterpret_cast<const device float2*>(q + 2 * head_dim)[simdgroup_tid];
    float2 q3 = reinterpret_cast<const device float2*>(q + 3 * head_dim)[simdgroup_tid];
    float2 q4 = reinterpret_cast<const device float2*>(q + 4 * head_dim)[simdgroup_tid];
    float2 q5 = reinterpret_cast<const device float2*>(q + 5 * head_dim)[simdgroup_tid];
    float2 q6 = reinterpret_cast<const device float2*>(q + 6 * head_dim)[simdgroup_tid];
    float2 q7 = reinterpret_cast<const device float2*>(q + 7 * head_dim)[simdgroup_tid];

    // Causal + sliding-window range: [kt_end - window, kt_end), clamped at 0.
    // Simdgroups interleave over KV tokens with stride num_simdgroups.
    const uint kt_end = qt + args.num_kv_tokens + 1;
    const uint kt_start = metal::subsat(kt_end, args.window) + simdgroup_idx;
    kv += token_stride * kt_start;
    for (uint kt = kt_start; kt < kt_end; kt += num_simdgroups) {
        const float2 kval = reinterpret_cast<const device float2*>(kv)[simdgroup_tid];

        // Partial Q.K per lane; full dot products after simd_sum below.
        float qk0 = metal::dot(q0, kval);
        float qk1 = metal::dot(q1, kval);
        float qk2 = metal::dot(q2, kval);
        float qk3 = metal::dot(q3, kval);
        float qk4 = metal::dot(q4, kval);
        float qk5 = metal::dot(q5, kval);
        float qk6 = metal::dot(q6, kval);
        float qk7 = metal::dot(q7, kval);

        qk0 = metal::simd_sum(qk0);
        qk1 = metal::simd_sum(qk1);
        qk2 = metal::simd_sum(qk2);
        qk3 = metal::simd_sum(qk3);
        qk4 = metal::simd_sum(qk4);
        qk5 = metal::simd_sum(qk5);
        qk6 = metal::simd_sum(qk6);
        qk7 = metal::simd_sum(qk7);

        // Online-softmax update: new max, rescale factor for old state.
        const float new_m0 = metal::max(m0, qk0);
        const float new_m1 = metal::max(m1, qk1);
        const float new_m2 = metal::max(m2, qk2);
        const float new_m3 = metal::max(m3, qk3);
        const float new_m4 = metal::max(m4, qk4);
        const float new_m5 = metal::max(m5, qk5);
        const float new_m6 = metal::max(m6, qk6);
        const float new_m7 = metal::max(m7, qk7);

        const float alpha0 = metal::fast::exp(m0 - new_m0);
        const float alpha1 = metal::fast::exp(m1 - new_m1);
        const float alpha2 = metal::fast::exp(m2 - new_m2);
        const float alpha3 = metal::fast::exp(m3 - new_m3);
        const float alpha4 = metal::fast::exp(m4 - new_m4);
        const float alpha5 = metal::fast::exp(m5 - new_m5);
        const float alpha6 = metal::fast::exp(m6 - new_m6);
        const float alpha7 = metal::fast::exp(m7 - new_m7);

        qk0 = metal::fast::exp(qk0 - new_m0);
        qk1 = metal::fast::exp(qk1 - new_m1);
        qk2 = metal::fast::exp(qk2 - new_m2);
        qk3 = metal::fast::exp(qk3 - new_m3);
        qk4 = metal::fast::exp(qk4 - new_m4);
        qk5 = metal::fast::exp(qk5 - new_m5);
        qk6 = metal::fast::exp(qk6 - new_m6);
        qk7 = metal::fast::exp(qk7 - new_m7);

        // l = l * alpha + exp(qk - m): rescaled denominator plus new term.
        l0 = metal::fma(l0, alpha0, qk0);
        l1 = metal::fma(l1, alpha1, qk1);
        l2 = metal::fma(l2, alpha2, qk2);
        l3 = metal::fma(l3, alpha3, qk3);
        l4 = metal::fma(l4, alpha4, qk4);
        l5 = metal::fma(l5, alpha5, qk5);
        l6 = metal::fma(l6, alpha6, qk6);
        l7 = metal::fma(l7, alpha7, qk7);

        m0 = new_m0;
        m1 = new_m1;
        m2 = new_m2;
        m3 = new_m3;
        m4 = new_m4;
        m5 = new_m5;
        m6 = new_m6;
        m7 = new_m7;

        // V lives head_dim floats after K for the same token.
        const float2 vval = reinterpret_cast<const device float2*>(kv + head_dim)[simdgroup_tid];
        kv += token_stride * num_simdgroups;
        // out = v * weight + out * alpha (rescale old accumulator, add new term).
        out0 = metal::fma(vval, qk0, out0 * alpha0);
        out1 = metal::fma(vval, qk1, out1 * alpha1);
        out2 = metal::fma(vval, qk2, out2 * alpha2);
        out3 = metal::fma(vval, qk3, out3 * alpha3);
        out4 = metal::fma(vval, qk4, out4 * alpha4);
        out5 = metal::fma(vval, qk5, out5 * alpha5);
        out6 = metal::fma(vval, qk6, out6 * alpha6);
        out7 = metal::fma(vval, qk7, out7 * alpha7);
    }
    if (num_simdgroups > 1) {
        // Merge per-simdgroup (m, l) states through threadgroup memory.
        if (metal::simd_is_first()) {
            static_cast<threadgroup float*>(threadgroup_buffer)[0 * num_simdgroups + simdgroup_idx] = m0;
            static_cast<threadgroup float*>(threadgroup_buffer)[1 * num_simdgroups + simdgroup_idx] = m1;
            static_cast<threadgroup float*>(threadgroup_buffer)[2 * num_simdgroups + simdgroup_idx] = m2;
            static_cast<threadgroup float*>(threadgroup_buffer)[3 * num_simdgroups + simdgroup_idx] = m3;
            static_cast<threadgroup float*>(threadgroup_buffer)[4 * num_simdgroups + simdgroup_idx] = m4;
            static_cast<threadgroup float*>(threadgroup_buffer)[5 * num_simdgroups + simdgroup_idx] = m5;
            static_cast<threadgroup float*>(threadgroup_buffer)[6 * num_simdgroups + simdgroup_idx] = m6;
            static_cast<threadgroup float*>(threadgroup_buffer)[7 * num_simdgroups + simdgroup_idx] = m7;

            static_cast<threadgroup float*>(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_idx] = l0;
            static_cast<threadgroup float*>(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_idx] = l1;
            static_cast<threadgroup float*>(threadgroup_buffer)[10 * num_simdgroups + simdgroup_idx] = l2;
            static_cast<threadgroup float*>(threadgroup_buffer)[11 * num_simdgroups + simdgroup_idx] = l3;
            static_cast<threadgroup float*>(threadgroup_buffer)[12 * num_simdgroups + simdgroup_idx] = l4;
            static_cast<threadgroup float*>(threadgroup_buffer)[13 * num_simdgroups + simdgroup_idx] = l5;
            static_cast<threadgroup float*>(threadgroup_buffer)[14 * num_simdgroups + simdgroup_idx] = l6;
            static_cast<threadgroup float*>(threadgroup_buffer)[15 * num_simdgroups + simdgroup_idx] = l7;
        }
        metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
        // Note: simdgroup refers not to the thread's current simdgroup, but to one with simdgroup_idx == thread's simdgroup_tid.
        float simdgroup_m0 = m0;
        float simdgroup_m1 = m1;
        float simdgroup_m2 = m2;
        float simdgroup_m3 = m3;
        float simdgroup_m4 = m4;
        float simdgroup_m5 = m5;
        float simdgroup_m6 = m6;
        float simdgroup_m7 = m7;
        if (simdgroup_tid < num_simdgroups) {
            simdgroup_m0 = static_cast<const threadgroup float*>(threadgroup_buffer)[0 * num_simdgroups + simdgroup_tid];
            simdgroup_m1 = static_cast<const threadgroup float*>(threadgroup_buffer)[1 * num_simdgroups + simdgroup_tid];
            simdgroup_m2 = static_cast<const threadgroup float*>(threadgroup_buffer)[2 * num_simdgroups + simdgroup_tid];
            simdgroup_m3 = static_cast<const threadgroup float*>(threadgroup_buffer)[3 * num_simdgroups + simdgroup_tid];
            simdgroup_m4 = static_cast<const threadgroup float*>(threadgroup_buffer)[4 * num_simdgroups + simdgroup_tid];
            simdgroup_m5 = static_cast<const threadgroup float*>(threadgroup_buffer)[5 * num_simdgroups + simdgroup_tid];
            simdgroup_m6 = static_cast<const threadgroup float*>(threadgroup_buffer)[6 * num_simdgroups + simdgroup_tid];
            simdgroup_m7 = static_cast<const threadgroup float*>(threadgroup_buffer)[7 * num_simdgroups + simdgroup_tid];
        }

        // Global max per head across all simdgroups.
        const float threadgroup_m0 = metal::simd_max(simdgroup_m0);
        const float threadgroup_m1 = metal::simd_max(simdgroup_m1);
        const float threadgroup_m2 = metal::simd_max(simdgroup_m2);
        const float threadgroup_m3 = metal::simd_max(simdgroup_m3);
        const float threadgroup_m4 = metal::simd_max(simdgroup_m4);
        const float threadgroup_m5 = metal::simd_max(simdgroup_m5);
        const float threadgroup_m6 = metal::simd_max(simdgroup_m6);
        const float threadgroup_m7 = metal::simd_max(simdgroup_m7);

        // Rescale this simdgroup's accumulators to the global max.
        out0 *= metal::fast::exp(m0 - threadgroup_m0);
        out1 *= metal::fast::exp(m1 - threadgroup_m1);
        out2 *= metal::fast::exp(m2 - threadgroup_m2);
        out3 *= metal::fast::exp(m3 - threadgroup_m3);
        out4 *= metal::fast::exp(m4 - threadgroup_m4);
        out5 *= metal::fast::exp(m5 - threadgroup_m5);
        out6 *= metal::fast::exp(m6 - threadgroup_m6);
        out7 *= metal::fast::exp(m7 - threadgroup_m7);

        if (simdgroup_idx == 0) {
            // Simdgroup 0 combines all denominators, rescaled to the global max.
            l0 = 0.0f;
            l1 = 0.0f;
            l2 = 0.0f;
            l3 = 0.0f;
            l4 = 0.0f;
            l5 = 0.0f;
            l6 = 0.0f;
            l7 = 0.0f;
            if (simdgroup_tid < num_simdgroups) {
                l0 = static_cast<const threadgroup float*>(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_tid];
                l1 = static_cast<const threadgroup float*>(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_tid];
                l2 = static_cast<const threadgroup float*>(threadgroup_buffer)[10 * num_simdgroups + simdgroup_tid];
                l3 = static_cast<const threadgroup float*>(threadgroup_buffer)[11 * num_simdgroups + simdgroup_tid];
                l4 = static_cast<const threadgroup float*>(threadgroup_buffer)[12 * num_simdgroups + simdgroup_tid];
                l5 = static_cast<const threadgroup float*>(threadgroup_buffer)[13 * num_simdgroups + simdgroup_tid];
                l6 = static_cast<const threadgroup float*>(threadgroup_buffer)[14 * num_simdgroups + simdgroup_tid];
                l7 = static_cast<const threadgroup float*>(threadgroup_buffer)[15 * num_simdgroups + simdgroup_tid];
            }

            l0 = metal::simd_sum(l0 * metal::fast::exp(simdgroup_m0 - threadgroup_m0));
            l1 = metal::simd_sum(l1 * metal::fast::exp(simdgroup_m1 - threadgroup_m1));
            l2 = metal::simd_sum(l2 * metal::fast::exp(simdgroup_m2 - threadgroup_m2));
            l3 = metal::simd_sum(l3 * metal::fast::exp(simdgroup_m3 - threadgroup_m3));
            l4 = metal::simd_sum(l4 * metal::fast::exp(simdgroup_m4 - threadgroup_m4));
            l5 = metal::simd_sum(l5 * metal::fast::exp(simdgroup_m5 - threadgroup_m5));
            l6 = metal::simd_sum(l6 * metal::fast::exp(simdgroup_m6 - threadgroup_m6));
            l7 = metal::simd_sum(l7 * metal::fast::exp(simdgroup_m7 - threadgroup_m7));
        }

        // Pairwise tree reduction of the out accumulators: the upper half of
        // threads stages its values in threadgroup memory, the lower half adds
        // them, halving the active thread count until one simdgroup remains.
        uint num_threads = num_simdgroups * simdgroup_size;
        do {
            const uint num_smem_threads = (num_threads / 2) & -simdgroup_size;
            const uint num_half_threads = num_threads - num_smem_threads;

            metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
            const uint smem_tid = tid.x - num_half_threads;
            if (smem_tid < num_smem_threads) {
                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 0 + smem_tid] = out0;
                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 1 + smem_tid] = out1;
                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 2 + smem_tid] = out2;
                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 3 + smem_tid] = out3;
                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 4 + smem_tid] = out4;
                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 5 + smem_tid] = out5;
                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 6 + smem_tid] = out6;
                static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 7 + smem_tid] = out7;
            }
            metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
            if (tid.x < num_smem_threads) {
                out0 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 0 + tid.x];
                out1 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 1 + tid.x];
                out2 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 2 + tid.x];
                out3 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 3 + tid.x];
                out4 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 4 + tid.x];
                out5 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 5 + tid.x];
                out6 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 6 + tid.x];
                out7 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 7 + tid.x];
            }

            num_threads = num_half_threads;
        } while (num_threads > simdgroup_size);
    }
    if (simdgroup_idx == 0) {
        // Normalize and write the 8 attention outputs for this KV head.
        reinterpret_cast<device float2*>(output + 0 * head_dim)[simdgroup_tid] = out0 / l0;
        reinterpret_cast<device float2*>(output + 1 * head_dim)[simdgroup_tid] = out1 / l1;
        reinterpret_cast<device float2*>(output + 2 * head_dim)[simdgroup_tid] = out2 / l2;
        reinterpret_cast<device float2*>(output + 3 * head_dim)[simdgroup_tid] = out3 / l3;
        reinterpret_cast<device float2*>(output + 4 * head_dim)[simdgroup_tid] = out4 / l4;
        reinterpret_cast<device float2*>(output + 5 * head_dim)[simdgroup_tid] = out5 / l5;
        reinterpret_cast<device float2*>(output + 6 * head_dim)[simdgroup_tid] = out6 / l6;
        reinterpret_cast<device float2*>(output + 7 * head_dim)[simdgroup_tid] = out7 / l7;
    }
}
|
gptoss_kernels/source/tokenizer.c
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <assert.h>
|
| 2 |
+
#include <stdatomic.h>
|
| 3 |
+
#include <stddef.h>
|
| 4 |
+
#include <stdint.h>
|
| 5 |
+
#include <stdlib.h>
|
| 6 |
+
#include <string.h>
|
| 7 |
+
|
| 8 |
+
#include <errno.h>
|
| 9 |
+
#include <sys/mman.h>
|
| 10 |
+
|
| 11 |
+
#include <gpt-oss.h>
|
| 12 |
+
|
| 13 |
+
#include "internal/log.h"
|
| 14 |
+
#include "internal/model.h"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
/*
 * Looks up the numeric token id of a special token.
 *
 * Returns gptoss_status_invalid_argument when token_type is out of range or
 * when the tokenizer does not define that special token; on success stores
 * the id in *token_id_out and returns gptoss_status_success.
 */
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
    gptoss_tokenizer_t tokenizer,
    enum gptoss_special_token token_type,
    uint32_t* token_id_out)
{
    /* Reject out-of-range token types up front. */
    if (token_type == gptoss_special_token_invalid || token_type >= gptoss_special_token_max) {
        return gptoss_status_invalid_argument;
    }

    /* special_token_id[] is indexed from 0; valid enum values start at 1. */
    const uint32_t token_id = tokenizer->special_token_id[(uint32_t) token_type - 1];
    if (token_id == UINT32_MAX) {
        /* UINT32_MAX marks a special token absent from this tokenizer. */
        return gptoss_status_invalid_argument;
    }

    *token_id_out = token_id;
    return gptoss_status_success;
}
|
| 34 |
+
|
| 35 |
+
/*
 * Queries the number of regular (text) tokens in the tokenizer vocabulary.
 * Always succeeds.
 */
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
    gptoss_tokenizer_t tokenizer,
    uint32_t* num_text_tokens_out)
{
    const uint32_t num_text_tokens = tokenizer->num_text_tokens;
    *num_text_tokens_out = num_text_tokens;
    return gptoss_status_success;
}
|
| 42 |
+
|
| 43 |
+
/*
 * Queries the number of special (control) tokens in the tokenizer vocabulary.
 * Always succeeds.
 */
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
    gptoss_tokenizer_t tokenizer,
    uint32_t* num_special_tokens_out)
{
    const uint32_t num_special_tokens = tokenizer->num_special_tokens;
    *num_special_tokens_out = num_special_tokens;
    return gptoss_status_success;
}
|
| 50 |
+
|
| 51 |
+
/*
 * Queries the total vocabulary size: text tokens plus special tokens.
 * Always succeeds.
 */
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
    gptoss_tokenizer_t tokenizer,
    uint32_t* num_tokens_out)
{
    const uint32_t total = tokenizer->num_text_tokens + tokenizer->num_special_tokens;
    *num_tokens_out = total;
    return gptoss_status_success;
}
|
| 58 |
+
|
| 59 |
+
/*
 * Returns a pointer to, and the byte length of, the UTF-8 text of a text token.
 *
 * The token table (tokenizer->tokens_ptr) is a packed sequence of
 * [uint16 length][length bytes] records, so lookup is a linear walk over the
 * first token_id records — O(token_id).
 *
 * Returns gptoss_status_invalid_argument for ids outside the text-token range
 * (special tokens have no byte representation here).
 */
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
    gptoss_tokenizer_t tokenizer,
    uint32_t token_id,
    const void** token_ptr_out,
    size_t* token_size_out)
{
    if (token_id >= tokenizer->num_text_tokens) {
        return gptoss_status_invalid_argument;
    }

    const char* token_ptr = (const char*) tokenizer->tokens_ptr;
    for (uint32_t t = 0; t < token_id; t++) {
        // Reading unaligned uint16_t
        uint16_t token_length;
        memcpy(&token_length, token_ptr, sizeof(token_length));

        token_ptr += (size_t) token_length + sizeof(uint16_t);
    }

    // Bug fix: read the full 16-bit length with memcpy, the same way the walk
    // above does. The previous `(size_t) *token_ptr` dereferenced only the
    // first byte of the length through a (possibly signed) char — wrong for
    // lengths >= 256 and sign-extending for bytes >= 128 when char is signed.
    uint16_t token_length;
    memcpy(&token_length, token_ptr, sizeof(token_length));

    *token_ptr_out = (const void*) (token_ptr + sizeof(uint16_t));
    *token_size_out = (size_t) token_length;
    return gptoss_status_success;
}
|
| 82 |
+
|
| 83 |
+
/*
 * Increments the tokenizer's atomic reference count.
 * Relaxed ordering suffices here: acquiring an additional reference does not
 * need to synchronize with any other memory operations.
 * Always succeeds.
 */
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
    gptoss_tokenizer_t tokenizer)
{
    atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
    return gptoss_status_success;
}
|
| 89 |
+
|
| 90 |
+
/*
 * Decrements the tokenizer's reference count and destroys the object when the
 * count reaches zero: unmaps the backing file mapping (if any), scrubs the
 * struct, and frees it. NULL is accepted and is a no-op.
 * Always returns gptoss_status_success.
 */
enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
    gptoss_tokenizer_t tokenizer)
{
    if (tokenizer == NULL) {
        return gptoss_status_success;
    }

    // Acquire ordering makes all prior uses of the object visible to the
    // thread that performs the destruction below.
    if (atomic_fetch_sub_explicit(&tokenizer->ref_count, 1, memory_order_acquire) != 1) {
        return gptoss_status_success;
    }

    if (tokenizer->mapping_ptr != NULL && tokenizer->mapping_size != 0) {
        if (munmap(tokenizer->mapping_ptr, tokenizer->mapping_size) != 0) {
            // Best-effort cleanup: log and continue freeing the struct.
            GPTOSS_LOG_WARNING("munmap for tokenizer mapping failed with error %d", errno);
        }
    }

    // Zero the struct before freeing so stale pointers fail fast.
    memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
    free(tokenizer);
    return gptoss_status_success;
}
|
gptoss_kernels/source/topk.metal
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_compute>
|
| 2 |
+
#include <metal_integer>
|
| 3 |
+
#include <metal_math>
|
| 4 |
+
#include <metal_simdgroup>
|
| 5 |
+
|
| 6 |
+
#include <internal/kernel-args.h>
|
| 7 |
+
|
| 8 |
+
#pragma METAL fp math_mode(safe)
|
| 9 |
+
#pragma METAL fp contract(off)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
// Top-4 + softmax over 128 expert scores, one simdgroup (32 threads) per token.
// Each thread owns 4 consecutive scores (float4); the 4 rounds below each find
// the current global maximum with a simdgroup reduction, record its index, and
// mask it out with -INFINITY before the next round.
// Tie-breaking is deterministic: within a thread the x/y/z/w checks run in
// w->x order so the lowest component wins, and simd_min across lanes then
// picks the lowest lane — i.e. ties resolve to the smallest expert id.
[[max_total_threads_per_threadgroup(32)]]
kernel void gptoss_f32_topk_softmax_e128_k4(
    constant gptoss_topk_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    device gptoss_expert_prediction* output [[ buffer(2) ]],
    const device gptoss_control* control [[ buffer(3) ]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]])
{
    const uint num_experts = 128;
    const uint num_active_experts = 4;
    // Cooperative abort: bail out if a previous kernel flagged failure.
    if (control->abort != 0) {
        return;
    }

    // gid selects the token; input is float4-packed, hence num_experts / 4.
    input += gid * (num_experts / 4);
    output += gid * num_active_experts;

    // Global expert indices for this thread's 4 scores.
    uint4 idx = tid * 4 + (uint4) {0, 1, 2, 3};
    float4 val = input[tid];

    // Round 0: global max across all 128 scores.
    const float topval0 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
    uint idx0 = 0xFFFFFFFFu;
    if (val.w == topval0) {
        idx0 = idx.w;
    }
    if (val.z == topval0) {
        idx0 = idx.z;
    }
    if (val.y == topval0) {
        idx0 = idx.y;
    }
    if (val.x == topval0) {
        idx0 = idx.x;
    }
    const uint topidx0 = metal::simd_min(idx0);
    // Remove the winner so it cannot be selected again.
    const bool4 is_topidx0 = idx == topidx0;
    val = metal::select(val, -INFINITY, is_topidx0);
    idx = metal::select(idx, 0xFFFFFFFFu, is_topidx0);

    // Round 1: second-largest score.
    const float topval1 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
    uint idx1 = 0xFFFFFFFFu;
    if (val.w == topval1) {
        idx1 = idx.w;
    }
    if (val.z == topval1) {
        idx1 = idx.z;
    }
    if (val.y == topval1) {
        idx1 = idx.y;
    }
    if (val.x == topval1) {
        idx1 = idx.x;
    }
    const uint topidx1 = metal::simd_min(idx1);
    const bool4 is_topidx1 = idx == topidx1;
    val = metal::select(val, -INFINITY, is_topidx1);
    idx = metal::select(idx, 0xFFFFFFFFu, is_topidx1);

    // Round 2: third-largest score.
    const float topval2 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
    uint idx2 = 0xFFFFFFFFu;
    if (val.w == topval2) {
        idx2 = idx.w;
    }
    if (val.z == topval2) {
        idx2 = idx.z;
    }
    if (val.y == topval2) {
        idx2 = idx.y;
    }
    if (val.x == topval2) {
        idx2 = idx.x;
    }
    const uint topidx2 = metal::simd_min(idx2);
    const bool4 is_topidx2 = idx == topidx2;
    val = metal::select(val, -INFINITY, is_topidx2);
    idx = metal::select(idx, 0xFFFFFFFFu, is_topidx2);

    // Round 3: fourth-largest score; no masking needed afterwards.
    const float topval3 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
    uint idx3 = 0xFFFFFFFFu;
    if (val.w == topval3) {
        idx3 = idx.w;
    }
    if (val.z == topval3) {
        idx3 = idx.z;
    }
    if (val.y == topval3) {
        idx3 = idx.y;
    }
    if (val.x == topval3) {
        idx3 = idx.x;
    }
    const uint topidx3 = metal::simd_min(idx3);

    // Lane 0 computes the softmax over the 4 winners and writes the output.
    // Values are shifted by the maximum (topval0), so the first term is
    // exp(0) == 1 exactly.
    if (metal::simd_is_first()) {
        const float topexp0 = 1.0f;
        const float topexp1 = metal::precise::exp(topval1 - topval0);
        const float topexp2 = metal::precise::exp(topval2 - topval0);
        const float topexp3 = metal::precise::exp(topval3 - topval0);

        const float sum = (topexp0 + topexp1) + (topexp2 + topexp3);
        const float scale = 1.0 / sum;

        output[0] = (gptoss_expert_prediction) {
            .expert_id = topidx0,
            .score = topexp0 * scale,
        };
        output[1] = (gptoss_expert_prediction) {
            .expert_id = topidx1,
            .score = topexp1 * scale,
        };
        output[2] = (gptoss_expert_prediction) {
            .expert_id = topidx2,
            .score = topexp2 * scale,
        };
        output[3] = (gptoss_expert_prediction) {
            .expert_id = topidx3,
            .score = topexp3 * scale,
        };
    }
}
|
| 133 |
+
|
| 134 |
+
// Top-4 + softmax over 32 expert scores, one simdgroup (32 threads) per token;
// each lane holds exactly one score. The four extraction rounds are expressed
// as a loop: each round reduces to the current simdgroup maximum, resolves
// ties to the smallest expert index via simd_min, and masks the winner with
// -INFINITY so the next round finds the runner-up.
[[max_total_threads_per_threadgroup(32)]]
kernel void gptoss_f32_topk_softmax_e32_k4(
    constant gptoss_topk_args& args [[ buffer(0) ]],
    const device float* input [[ buffer(1) ]],
    device gptoss_expert_prediction* output [[ buffer(2) ]],
    const device gptoss_control* control [[ buffer(3) ]],
    uint gid [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]])
{
    const uint num_experts = 32;
    const uint num_active_experts = 4;
    // Cooperative abort: bail out if a previous kernel flagged failure.
    if (control->abort != 0) {
        return;
    }

    input += gid * num_experts;
    output += gid * num_active_experts;

    float val = input[tid];
    uint idx = tid;

    float topval[num_active_experts];
    uint topidx[num_active_experts];
    for (uint k = 0; k < num_active_experts; k++) {
        topval[k] = metal::simd_max(val);
        topidx[k] = metal::simd_min(val == topval[k] ? idx : 0xFFFFFFFFu);
        if (idx == topidx[k]) {
            // This lane won round k: exclude it from subsequent rounds.
            val = -INFINITY;
            idx = 0xFFFFFFFFu;
        }
    }

    // Lane 0 performs the softmax over the 4 winners. Scores are shifted by
    // the maximum, so the leading term is exp(0) == 1 exactly; the pairwise
    // summation order matches the original unrolled code bit-for-bit.
    if (metal::simd_is_first()) {
        float topexp[num_active_experts];
        topexp[0] = 1.0f;
        for (uint k = 1; k < num_active_experts; k++) {
            topexp[k] = metal::precise::exp(topval[k] - topval[0]);
        }

        const float sum = (topexp[0] + topexp[1]) + (topexp[2] + topexp[3]);
        const float scale = 1.0 / sum;

        for (uint k = 0; k < num_active_experts; k++) {
            output[k] = (gptoss_expert_prediction) {
                .expert_id = topidx[k],
                .score = topexp[k] * scale,
            };
        }
    }
}
|
gptoss_kernels/test/bf16-f32-embeddings.cc
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <gtest/gtest.h>

#include <cstddef>

#include "embeddings-kernel-tester.hpp"


using gptoss::EmbeddingsKernelTester;

// Threadgroup size used by all tests below; the tester's Validate() requires
// it to be a multiple of the 32-lane simdgroup.
constexpr std::size_t kThreadgroupSize = 64;


// One token, channel count exactly equal to one threadgroup-sized tile.
TEST(BF16_F32_EMBEDDINGS, single_token_single_tile) {
    EmbeddingsKernelTester()
        .num_channels(kThreadgroupSize)
        .threadgroup_size(kThreadgroupSize)
        .TestBF16_F32();
}

// One token, channel count spanning several tiles plus a partial remainder.
TEST(BF16_F32_EMBEDDINGS, single_token_multi_tile) {
    EmbeddingsKernelTester()
        .num_channels(kThreadgroupSize * 4 + 16)
        .threadgroup_size(kThreadgroupSize)
        .TestBF16_F32();
}

// Several tokens with multi-tile channel counts.
TEST(BF16_F32_EMBEDDINGS, multiple_tokens) {
    EmbeddingsKernelTester()
        .num_channels(kThreadgroupSize * 4 + 16)
        .num_tokens(3)
        .threadgroup_size(kThreadgroupSize)
        .TestBF16_F32();
}
|
gptoss_kernels/test/embeddings-kernel-tester.hpp
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <gtest/gtest.h>

#include <cstddef>
#include <cstdint>
#include <cstring>

#include <internal/datatype.hpp>
#include <internal/metal.hpp>
#include <internal/metal-kernels.h>
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
namespace gptoss {
|
| 14 |
+
|
| 15 |
+
class EmbeddingsKernelTester {
|
| 16 |
+
public:
|
| 17 |
+
EmbeddingsKernelTester() { }
|
| 18 |
+
|
| 19 |
+
EmbeddingsKernelTester(const EmbeddingsKernelTester&) = delete;
|
| 20 |
+
EmbeddingsKernelTester(EmbeddingsKernelTester&&) = delete;
|
| 21 |
+
EmbeddingsKernelTester& operator=(const EmbeddingsKernelTester&) = delete;
|
| 22 |
+
EmbeddingsKernelTester& operator=(EmbeddingsKernelTester&&) = delete;
|
| 23 |
+
|
| 24 |
+
[[nodiscard]]
|
| 25 |
+
EmbeddingsKernelTester& num_channels(std::uint32_t num_channels) {
|
| 26 |
+
num_channels_ = num_channels;
|
| 27 |
+
return *this;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
std::uint32_t num_channels() const {
|
| 31 |
+
return num_channels_;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
[[nodiscard]]
|
| 35 |
+
EmbeddingsKernelTester& num_tokens(std::uint32_t num_tokens) {
|
| 36 |
+
num_tokens_ = num_tokens;
|
| 37 |
+
return *this;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
std::uint32_t num_tokens() const {
|
| 41 |
+
return num_tokens_;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
std::uint32_t vocabulary_size() const {
|
| 45 |
+
return num_tokens() + 1;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
[[nodiscard]]
|
| 49 |
+
EmbeddingsKernelTester& threadgroup_size(std::size_t threadgroup_size) {
|
| 50 |
+
threadgroup_size_ = threadgroup_size;
|
| 51 |
+
return *this;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
std::size_t threadgroup_size() const {
|
| 55 |
+
return threadgroup_size_;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
void Validate() const {
|
| 59 |
+
ASSERT_NE(num_channels(), 0);
|
| 60 |
+
ASSERT_NE(num_tokens(), 0);
|
| 61 |
+
ASSERT_NE(threadgroup_size(), 0);
|
| 62 |
+
ASSERT_EQ(threadgroup_size() % 32, 0);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
void TestBF16_F32() const {
|
| 66 |
+
Validate();
|
| 67 |
+
|
| 68 |
+
metal::CommandBuffer command_buffer{command_queue_};
|
| 69 |
+
metal::Buffer token_buffer{device_, sizeof(std::uint32_t)};
|
| 70 |
+
metal::Buffer weight_buffer{device_, vocabulary_size() * num_channels() * sizeof(gptoss_bfloat16)};
|
| 71 |
+
metal::Buffer output_buffer{device_, num_channels() * sizeof(float)};
|
| 72 |
+
metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
|
| 73 |
+
std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
|
| 74 |
+
|
| 75 |
+
std::uint32_t* token_ptr = static_cast<std::uint32_t*>(token_buffer.ptr());
|
| 76 |
+
for (std::uint32_t t = 0; t < num_tokens(); t++) {
|
| 77 |
+
token_ptr[t] = t + 1;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
Check(gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
|
| 81 |
+
command_buffer.handle(),
|
| 82 |
+
bf16_f32_embeddings_fn.handle(),
|
| 83 |
+
threadgroup_size(),
|
| 84 |
+
token_buffer.handle(),
|
| 85 |
+
/*token_offset=*/0,
|
| 86 |
+
weight_buffer.handle(),
|
| 87 |
+
/*weight_offset=*/0,
|
| 88 |
+
output_buffer.handle(),
|
| 89 |
+
/*output_offset=*/0,
|
| 90 |
+
control_buffer.handle(),
|
| 91 |
+
/*control_offset=*/0,
|
| 92 |
+
num_tokens(),
|
| 93 |
+
num_channels()),
|
| 94 |
+
"gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings");
|
| 95 |
+
|
| 96 |
+
command_buffer.commit();
|
| 97 |
+
command_buffer.wait_completion();
|
| 98 |
+
|
| 99 |
+
const gptoss_bfloat16* weight_ptr = static_cast<const gptoss_bfloat16*>(weight_buffer.ptr());
|
| 100 |
+
const float* output_ptr = static_cast<const float*>(output_buffer.ptr());
|
| 101 |
+
for (std::uint32_t t = 0; t < num_tokens(); t++) {
|
| 102 |
+
const std::uint32_t token = token_ptr[t];
|
| 103 |
+
for (std::uint32_t i = 0; i < num_channels(); i++) {
|
| 104 |
+
const gptoss_bfloat16 input_val = weight_ptr[token * num_channels() + i];
|
| 105 |
+
const float ref_output = upcast<float>(input_val);
|
| 106 |
+
const float output = output_ptr[t * num_channels() + i];
|
| 107 |
+
ASSERT_EQ(output, ref_output)
|
| 108 |
+
<< "at token " << t << ", position " << i << " / " << num_channels() << ", input " << std::uint32_t(input_val.bits);
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
private:
|
| 114 |
+
metal::Device device_{};
|
| 115 |
+
metal::CommandQueue command_queue_{device_};
|
| 116 |
+
metal::Library library_{device_};
|
| 117 |
+
metal::Function bf16_f32_embeddings_fn{library_, "gptoss_bf16_f32_embeddings"};
|
| 118 |
+
std::uint32_t num_tokens_{1};
|
| 119 |
+
std::uint32_t num_channels_{1};
|
| 120 |
+
std::size_t threadgroup_size_{32};
|
| 121 |
+
};
|
| 122 |
+
|
| 123 |
+
} // namespace gptoss
|
gptoss_kernels/test/f32-bf16w-matmul.cc
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <gtest/gtest.h>

#include <cstddef>
#include <cstdint>

#include "matmul-kernel-tester.hpp"


using gptoss::MatMulKernelTester;

constexpr size_t kSimdgroupSize = 32; // fixed in the kernel

// Smallest case: one simdgroup, one accumulation iteration.
TEST(F32_BF16W_MATMUL, single_simdgroup_single_iteration) {
    MatMulKernelTester()
        .num_rows(1)
        .num_cols(kSimdgroupSize * 4)
        .threadgroup_size(kSimdgroupSize)
        .TestF32_BF16W();
}

// One simdgroup with a column count that needs multiple iterations,
// including a partial one (the +1 below).
TEST(F32_BF16W_MATMUL, single_simdgroup_multiple_iteration) {
    MatMulKernelTester()
        .num_rows(1)
        .num_cols((2 * kSimdgroupSize + 1) * 4)
        .threadgroup_size(kSimdgroupSize)
        .TestF32_BF16W();
}

// One threadgroup containing several simdgroups (one output row each).
TEST(F32_BF16W_MATMUL, single_threadgroup) {
    constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;

    MatMulKernelTester()
        .num_rows(threadgroup_size / kSimdgroupSize)
        .num_cols((2 * kSimdgroupSize + 1) * 4)
        .threadgroup_size(threadgroup_size)
        .TestF32_BF16W();
}

// Multiple threadgroups covering the output rows.
TEST(F32_BF16W_MATMUL, multiple_threadgroups) {
    constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;
    constexpr std::uint32_t num_threadgroups = 3;

    MatMulKernelTester()
        .num_rows(num_threadgroups * threadgroup_size / kSimdgroupSize)
        .num_cols((2 * kSimdgroupSize + 1) * 4)
        .threadgroup_size(threadgroup_size)
        .TestF32_BF16W();
}

// Multiple input tokens (batch dimension) on top of the multi-threadgroup case.
TEST(F32_BF16W_MATMUL, multiple_tokens) {
    constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;
    constexpr std::uint32_t num_threadgroups = 3;

    MatMulKernelTester()
        .num_rows(num_threadgroups * threadgroup_size / kSimdgroupSize)
        .num_cols((2 * kSimdgroupSize + 1) * 4)
        .num_tokens(2)
        .threadgroup_size(threadgroup_size)
        .TestF32_BF16W();
}

// Prefill-shaped dense matmul variants at production dimensions.
TEST(F32_BF16W_DENSE_MATMUL_QKV, seq_len_1024) {
    MatMulKernelTester()
        .num_tokens(1024)
        .num_rows(5120)
        .num_cols(2880)
        .TestF32_BF16W(
            MatMulKernelTester::MatMulKernelType::PREFILL_QKV_OPTIMIZED);
}

TEST(F32_BF16W_DENSE_MATMUL_ATTN_OUTPUT, seq_len_1024) {
    MatMulKernelTester()
        .num_tokens(1024)
        .num_rows(2880)
        .num_cols(4096)
        .TestF32_BF16W(
            MatMulKernelTester::MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED);
}

TEST(F32_BF16W_DENSE_MATMUL_MLP_GATE, seq_len_1024) {
    MatMulKernelTester()
        .num_tokens(1024)
        .num_rows(128)
        .num_cols(2880)
        .TestF32_BF16W(
            MatMulKernelTester::MatMulKernelType::PREFILL_MLP_GATE_OPTIMIZED);
}
|
gptoss_kernels/test/f32-bf16w-rmsnorm.cc
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <gtest/gtest.h>

#include <cstdint>

#include "rmsnorm-kernel-tester.hpp"


using gptoss::RMSNormKernelTester;

constexpr std::uint32_t kThreadgroupSize = 1024; // fixed in the kernel
constexpr std::uint32_t kVectorSize = 4; // fixed in the kernel

// Channel count exactly one threadgroup-wide pass.
TEST(F32_BF16W_RMSNORM, single_iteration) {
    RMSNormKernelTester()
        .num_channels(kThreadgroupSize)
        .TestF32_BF16W();
}

// Channel count requiring two full passes.
TEST(F32_BF16W_RMSNORM, multiple_iterations) {
    RMSNormKernelTester()
        .num_channels(kThreadgroupSize * 2)
        .TestF32_BF16W();
}

// Two full passes plus a partial trailing vector.
TEST(F32_BF16W_RMSNORM, partial_iteration) {
    RMSNormKernelTester()
        .num_channels(kThreadgroupSize * 2 + kVectorSize)
        .TestF32_BF16W();
}

// Several tokens with the partial-iteration channel count.
TEST(F32_BF16W_RMSNORM, multiple_tokens) {
    RMSNormKernelTester()
        .num_tokens(3)
        .num_channels(kThreadgroupSize * 2 + kVectorSize)
        .TestF32_BF16W();
}
|