medmekk (HF Staff) committed
Commit 51250cb · unverified · 1 Parent(s): 560f73a

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. README.md +10 -0
  2. build.toml +23 -0
  3. flake.nix +13 -0
  4. gptoss_kernels/CMakeLists.txt +191 -0
  5. gptoss_kernels/__init__.py +6 -0
  6. gptoss_kernels/examples/chat.py +104 -0
  7. gptoss_kernels/examples/generate.py +34 -0
  8. gptoss_kernels/include/gpt-oss.h +5 -0
  9. gptoss_kernels/include/gpt-oss/functions.h +401 -0
  10. gptoss_kernels/include/gpt-oss/macros.h +5 -0
  11. gptoss_kernels/include/gpt-oss/types.h +62 -0
  12. gptoss_kernels/source/accumulate.metal +59 -0
  13. gptoss_kernels/source/context.c +1115 -0
  14. gptoss_kernels/source/convert.metal +64 -0
  15. gptoss_kernels/source/embeddings.metal +29 -0
  16. gptoss_kernels/source/expert_routing_metadata.metal +41 -0
  17. gptoss_kernels/source/gather_and_accumulate.metal +74 -0
  18. gptoss_kernels/source/generate.c +317 -0
  19. gptoss_kernels/source/include/internal/datatype.h +41 -0
  20. gptoss_kernels/source/include/internal/datatype.hpp +87 -0
  21. gptoss_kernels/source/include/internal/kernel-args.h +201 -0
  22. gptoss_kernels/source/include/internal/log.h +20 -0
  23. gptoss_kernels/source/include/internal/macros.h +107 -0
  24. gptoss_kernels/source/include/internal/math.h +40 -0
  25. gptoss_kernels/source/include/internal/metal-kernels.h +486 -0
  26. gptoss_kernels/source/include/internal/metal.h +138 -0
  27. gptoss_kernels/source/include/internal/metal.hpp +342 -0
  28. gptoss_kernels/source/include/internal/model.h +178 -0
  29. gptoss_kernels/source/include/internal/rng.h +24 -0
  30. gptoss_kernels/source/include/internal/rng.hpp +32 -0
  31. gptoss_kernels/source/include/internal/storage.h +36 -0
  32. gptoss_kernels/source/include/internal/uuid.h +114 -0
  33. gptoss_kernels/source/log.c +50 -0
  34. gptoss_kernels/source/matmul.metal +422 -0
  35. gptoss_kernels/source/metal-kernels.c +1518 -0
  36. gptoss_kernels/source/metal.m +482 -0
  37. gptoss_kernels/source/model.c +581 -0
  38. gptoss_kernels/source/moematmul.metal +702 -0
  39. gptoss_kernels/source/random.metal +97 -0
  40. gptoss_kernels/source/rmsnorm.metal +58 -0
  41. gptoss_kernels/source/rope.metal +43 -0
  42. gptoss_kernels/source/sample.metal +209 -0
  43. gptoss_kernels/source/scatter.metal +65 -0
  44. gptoss_kernels/source/sdpa.metal +293 -0
  45. gptoss_kernels/source/tokenizer.c +106 -0
  46. gptoss_kernels/source/topk.metal +205 -0
  47. gptoss_kernels/test/bf16-f32-embeddings.cc +33 -0
  48. gptoss_kernels/test/embeddings-kernel-tester.hpp +123 -0
  49. gptoss_kernels/test/f32-bf16w-matmul.cc +87 -0
  50. gptoss_kernels/test/f32-bf16w-rmsnorm.cc +36 -0
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ tags:
+ - kernels
+ - gptoss
+ ---
+
+ # gptoss_kernels
+
+ This is a build of the kernels released by OpenAI in the GPT-OSS repo: https://github.com/openai/gpt-oss
+
build.toml ADDED
@@ -0,0 +1,23 @@
+ [general]
+ name = "gptoss_kernels"
+ universal = false
+
+ [torch]
+ src = [
+     "torch-ext/torch_binding.cpp",
+     "torch-ext/torch_binding.h",
+ ]
+
+ [kernel.gptoss_kernels]
+ depends = ["torch"]
+ backend = "cuda"
+
+ src = [
+     "gptoss_kernels/attention_cuda_fwd.cu",
+     "gptoss_kernels/attention_cuda_bwd.cu",
+     "gptoss_kernels/attention_cuda_utils.cu",
+     "gptoss_kernels/attention_cuda_utils.cuh",
+     "gptoss_kernels/attention_cuda.cuh",
+     "gptoss_kernels/attention.h",
+     "gptoss_kernels/cudamacro.h",
+ ]
flake.nix ADDED
@@ -0,0 +1,13 @@
+ {
+   description = "Flake for Torch kernel extension";
+
+   inputs = {
+     kernel-builder.url = "github:huggingface/kernel-builder";
+   };
+
+   outputs = { self, kernel-builder, }:
+     kernel-builder.lib.genFlakeOutputs {
+       inherit self;
+       path = ./.;
+     };
+ }
gptoss_kernels/CMakeLists.txt ADDED
@@ -0,0 +1,191 @@
+ cmake_minimum_required(VERSION 3.24)
+ project(GPTOSS
+     VERSION 1.0
+     DESCRIPTION "Local GPT-OSS inference"
+     LANGUAGES C CXX OBJC)
+
+ set(CMAKE_C_STANDARD 11)
+ set(CMAKE_CXX_STANDARD 20)
+ set(CMAKE_OBJC_STANDARD 11)
+ set(CMAKE_OBJC_STANDARD_REQUIRED ON)
+
+ find_library(FOUNDATION_FRAMEWORK Foundation REQUIRED)
+ find_library(METAL_FRAMEWORK Metal REQUIRED)
+ find_library(IOKIT_FRAMEWORK IOKit REQUIRED)
+
+ set(METAL_SOURCES
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal
+     ${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal
+ )
+ set(METAL_LIB default.metallib)
+
+ include_directories(BEFORE include source/include)
+
+ add_custom_command(
+     OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
+     COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/source/"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/convert.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/embeddings.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/expert_routing_metadata.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/matmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/moematmul.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/gather_and_accumulate.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/random.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/random.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rmsnorm.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/rope.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sample.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/scatter.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/sdpa.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air"
+     COMMAND xcrun -sdk macosx metal -g "-I${CMAKE_CURRENT_SOURCE_DIR}/source/include" -c "${CMAKE_CURRENT_SOURCE_DIR}/source/topk.metal" -o "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air"
+     COMMAND xcrun -sdk macosx metallib "${CMAKE_CURRENT_BINARY_DIR}/source/accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/convert.air" "${CMAKE_CURRENT_BINARY_DIR}/source/embeddings.air" "${CMAKE_CURRENT_BINARY_DIR}/source/expert_routing_metadata.air" "${CMAKE_CURRENT_BINARY_DIR}/source/gather_and_accumulate.air" "${CMAKE_CURRENT_BINARY_DIR}/source/matmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/moematmul.air" "${CMAKE_CURRENT_BINARY_DIR}/source/random.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rmsnorm.air" "${CMAKE_CURRENT_BINARY_DIR}/source/rope.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sample.air" "${CMAKE_CURRENT_BINARY_DIR}/source/scatter.air" "${CMAKE_CURRENT_BINARY_DIR}/source/sdpa.air" "${CMAKE_CURRENT_BINARY_DIR}/source/topk.air" -o "${METAL_LIB}"
+     DEPENDS ${METAL_SOURCES}
+     COMMENT "Compiling Metal compute library"
+ )
+
+ add_custom_target(build_metallib ALL
+     DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB})
+
+ add_library(log OBJECT source/log.c)
+
+ add_library(metal-kernels STATIC source/metal.m source/metal-kernels.c)
+ target_link_libraries(metal-kernels PRIVATE log)
+
+ add_dependencies(metal-kernels build_metallib)
+ add_custom_command(TARGET metal-kernels POST_BUILD
+     COMMAND ${CMAKE_COMMAND} -E copy
+         ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
+         $<TARGET_FILE_DIR:metal-kernels>)
+
+ target_link_libraries(metal-kernels PRIVATE ${FOUNDATION_FRAMEWORK} ${METAL_FRAMEWORK} ${IOKIT_FRAMEWORK})
+
+ add_library(gptoss STATIC source/model.c source/tokenizer.c source/context.c)
+ target_link_libraries(gptoss PRIVATE log metal-kernels)
+
+ add_executable(generate source/generate.c)
+ target_link_libraries(generate gptoss)
+
+ # --- [ Tests
+ include(FetchContent)
+ FetchContent_Declare(
+     googletest
+     URL https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip
+     DOWNLOAD_EXTRACT_TIMESTAMP OFF
+ )
+ # For Windows: Prevent overriding the parent project's compiler/linker settings
+ set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+ set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
+ FetchContent_MakeAvailable(googletest)
+
+ enable_testing()
+
+ add_executable(u32-random-test test/u32-random.cc)
+ target_link_libraries(u32-random-test PRIVATE GTest::gtest_main metal-kernels)
+ target_include_directories(u32-random-test PRIVATE source/include)
+ add_test(NAME u32-random-test COMMAND u32-random-test)
+
+ add_executable(f32-random-test test/f32-random.cc)
+ target_link_libraries(f32-random-test PRIVATE GTest::gtest_main metal-kernels)
+ target_include_directories(f32-random-test PRIVATE source/include)
+ add_test(NAME f32-random-test COMMAND f32-random-test)
+
+ add_executable(mf4-f32-convert-test test/mf4-f32-convert.cc)
+ target_link_libraries(mf4-f32-convert-test PRIVATE GTest::gtest_main metal-kernels)
+ target_include_directories(mf4-f32-convert-test PRIVATE source/include)
+ add_test(NAME mf4-f32-convert-test COMMAND mf4-f32-convert-test)
+
+ add_executable(bf16-f32-embeddings-test test/bf16-f32-embeddings.cc)
+ target_link_libraries(bf16-f32-embeddings-test PRIVATE GTest::gtest_main metal-kernels)
+ target_include_directories(bf16-f32-embeddings-test PRIVATE source/include)
+ add_test(NAME bf16-f32-embeddings-test COMMAND bf16-f32-embeddings-test)
+
+ add_executable(f32-bf16w-rmsnorm-test test/f32-bf16w-rmsnorm.cc)
+ target_link_libraries(f32-bf16w-rmsnorm-test PRIVATE GTest::gtest_main metal-kernels)
+ target_include_directories(f32-bf16w-rmsnorm-test PRIVATE source/include)
+ add_test(NAME f32-bf16w-rmsnorm-test COMMAND f32-bf16w-rmsnorm-test)
+
+ add_executable(f32-bf16w-matmul-test test/f32-bf16w-matmul.cc)
+ target_link_libraries(f32-bf16w-matmul-test PRIVATE GTest::gtest_main metal-kernels)
+ target_include_directories(f32-bf16w-matmul-test PRIVATE source/include)
+ add_test(NAME f32-bf16w-matmul-test COMMAND f32-bf16w-matmul-test)
+
+ add_executable(f32-rope-test test/f32-rope.cc)
+ target_link_libraries(f32-rope-test PRIVATE GTest::gtest_main metal-kernels)
+ target_include_directories(f32-rope-test PRIVATE source/include)
+ add_test(NAME f32-rope-test COMMAND f32-rope-test)
+
+ # --- [ Benchmarks
+ include(FetchContent)
+ set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable self-tests in Google Benchmark" FORCE)
+ set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable installation of Google Benchmark" FORCE)
+ FetchContent_Declare(
+     benchmark
+     URL https://github.com/google/benchmark/archive/refs/tags/v1.9.4.zip
+     DOWNLOAD_EXTRACT_TIMESTAMP OFF
+ )
+ FetchContent_MakeAvailable(benchmark)
+
+ add_executable(f32-random-bench benchmark/f32-random.cc)
+ target_link_libraries(f32-random-bench PRIVATE benchmark::benchmark metal-kernels)
+ target_include_directories(f32-random-bench PRIVATE source/include)
+
+ add_executable(u32-random-bench benchmark/u32-random.cc)
+ target_link_libraries(u32-random-bench PRIVATE benchmark::benchmark metal-kernels)
+ target_include_directories(u32-random-bench PRIVATE source/include)
+
+ add_executable(mf4-f32-convert-bench benchmark/mf4-f32-convert.cc)
+ target_link_libraries(mf4-f32-convert-bench PRIVATE benchmark::benchmark metal-kernels)
+ target_include_directories(mf4-f32-convert-bench PRIVATE source/include)
+
+ add_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc)
+ target_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels)
+ target_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include)
+
+ add_executable(end-to-end-bench benchmark/end-to-end.cc)
+ target_link_libraries(end-to-end-bench PRIVATE benchmark::benchmark gptoss)
+ target_include_directories(end-to-end-bench PRIVATE source/include)
+
+ add_executable(end-to-end-threadgroup-bench benchmark/end-to-end-threadgroup.cc)
+ target_link_libraries(end-to-end-threadgroup-bench PRIVATE benchmark::benchmark gptoss)
+ target_include_directories(end-to-end-threadgroup-bench PRIVATE source/include)
+
+ # --- [ Python extension ] -----------------------------------------------
+ find_package(pybind11 CONFIG REQUIRED)  # provides pybind11_add_module
+
+ pybind11_add_module(_metal
+     python/module.c
+     python/context.c
+     python/model.c
+     python/tokenizer.c
+ )
+ set_target_properties(_metal PROPERTIES PREFIX "")
+
+ target_link_libraries(_metal PRIVATE gptoss)
+ add_dependencies(_metal build_metallib)
+ target_link_options(_metal PRIVATE
+     LINKER:-sectcreate,__METAL,__shaders,${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
+ )
+ add_custom_command(TARGET _metal POST_BUILD
+     COMMAND ${CMAKE_COMMAND} -E copy
+         ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
+         $<TARGET_FILE_DIR:_metal>)
+
+ # 1️⃣ install the extension module into the Python package
+ install(TARGETS _metal LIBRARY DESTINATION gpt_oss/metal)
+
+ # 2️⃣ make sure the Metal shader archive travels with it
+ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${METAL_LIB}
+     DESTINATION gpt_oss/metal)
+ # ------------------------------------------------------------------------
gptoss_kernels/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from importlib import import_module as _im
+
+ # Load the compiled extension (gpt_oss.metal._metal)
+ _ext = _im(f"{__name__}._metal")
+ globals().update({k: v for k, v in _ext.__dict__.items() if not k.startswith("_")})
+ del _im, _ext
gptoss_kernels/examples/chat.py ADDED
@@ -0,0 +1,104 @@
+ #!/usr/bin/env python
+
+ import argparse
+ import sys
+
+ from datetime import date
+ from gpt_oss.metal import Context, Model
+
+
+ DEFAULT_PROMPT = f"""You are ChatGPT, a large language model trained by OpenAI.
+ Knowledge cutoff: 2024-06
+ Current date: {date.today().isoformat()}
+
+ reasoning effort high
+
+ # Valid channels: analysis, final. Channel must be included for every message."""
+
+
+ parser = argparse.ArgumentParser(description="Chat with gpt-oss", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument("model", metavar="PATH", type=str, help="Path to gpt-oss model in Metal inference format")
+ parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="System prompt")
+ parser.add_argument(
+     "--context-length", type=int, default=0, help="The maximum context length"
+ )
+ parser.add_argument(
+     "--temperature", type=float, default=1.0, help="Sampling temperature"
+ )
+ parser.add_argument(
+     "--seed", type=int, default=0, help="Sampling seed"
+ )
+
+
+ GREY = "\33[90m"
+ BOLD = "\33[1m"
+ RESET = "\33[0m"
+
+
+ def main(args):
+     options = parser.parse_args(args)
+     model = Model(options.model)
+     tokenizer = model.tokenizer
+     start_token = tokenizer.encode_special_token("<|start|>")
+     message_token = tokenizer.encode_special_token("<|message|>")
+     end_token = tokenizer.encode_special_token("<|end|>")
+     return_token = tokenizer.encode_special_token("<|return|>")
+     channel_token = tokenizer.encode_special_token("<|channel|>")
+
+     context = Context(model, context_length=options.context_length)
+     context.append(start_token)
+     context.append("system")
+     context.append(message_token)
+     context.append(options.prompt)
+     context.append(end_token)
+
+     while True:
+         context.append(start_token)
+         context.append("user")
+         context.append(message_token)
+         message = input(f"{BOLD}User:{RESET} ").rstrip()
+         context.append(message)
+         context.append(end_token)
+         print(f"{BOLD}Assistant:{RESET} {GREY}", end="", flush=True)
+         context.append(start_token)
+         context.append("assistant")
+         context.append(channel_token)
+
+         inside_start_block = True
+         inside_channel_block = True
+         role = "assistant"
+         channel = ""
+         while True:
+             token = context.sample(
+                 temperature=options.temperature,
+                 seed=options.seed,
+             )
+             context.append(token)
+             if token == return_token:
+                 print(flush=True)
+                 break
+             elif token == start_token:
+                 inside_start_block = True
+                 role = ""
+                 channel = ""
+             elif token == message_token:
+                 inside_start_block = False
+                 inside_channel_block = False
+                 if channel == "analysis":
+                     print(f"{GREY}", end="", flush=True)
+             elif token == end_token:
+                 print(f"{RESET}", flush=True)
+             elif token == channel_token:
+                 inside_channel_block = True
+             elif token < tokenizer.num_text_tokens:
+                 if inside_channel_block:
+                     channel += str(tokenizer.decode(token), encoding="utf-8")
+                 elif inside_start_block:
+                     role += str(tokenizer.decode(token), encoding="utf-8")
+                 else:
+                     sys.stdout.buffer.write(tokenizer.decode(token))
+                     sys.stdout.buffer.flush()
+
+
+ if __name__ == "__main__":
+     main(sys.argv[1:])
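
chat.py frames every message as <|start|>role<|message|>content<|end|> before sampling. The same framing can be expressed against the C API declared in gptoss_kernels/include/gpt-oss/functions.h; a minimal sketch with error checks omitted for brevity (the helper name is hypothetical):

    #include <gpt-oss.h>

    // Hypothetical helper: append one system message to the context the way
    // chat.py does, i.e. <|start|>system<|message|>text<|end|>.
    static enum gptoss_status append_system_message(
        gptoss_context_t context,
        gptoss_tokenizer_t tokenizer,
        const char* text,
        size_t text_length)
    {
        uint32_t start_id, message_id, end_id;
        gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_start, &start_id);
        gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_message, &message_id);
        gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_end, &end_id);

        gptoss_context_append_tokens(context, 1, &start_id);
        gptoss_context_append_chars(context, "system", 6, /*num_tokens_out=*/NULL);
        gptoss_context_append_tokens(context, 1, &message_id);
        gptoss_context_append_chars(context, text, text_length, /*num_tokens_out=*/NULL);
        return gptoss_context_append_tokens(context, 1, &end_id);
    }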
gptoss_kernels/examples/generate.py ADDED
@@ -0,0 +1,34 @@
+ #!/usr/bin/env python
+
+ import argparse
+ import sys
+
+ from gpt_oss.metal import Context, Model
+
+
+ parser = argparse.ArgumentParser(description='Generate text with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('model', metavar='PATH', type=str, help='Path to gpt-oss checkpoint')
+ parser.add_argument('-p', '--prompt', type=str, required=True, help='Prompt')
+ parser.add_argument('-l', '--limit', type=int, default=100, help='Number of tokens to generate')
+ parser.add_argument('--context-length', type=int, default=0, help='The maximum context length')
+
+
+ def main(args):
+     options = parser.parse_args(args)
+     model = Model(options.model)
+
+     context = Context(model, context_length=options.context_length)
+     context.append(options.prompt)
+     print(context.tokens)
+     prompt_tokens = context.num_tokens
+
+     tokenizer = model.tokenizer
+
+     while context.num_tokens - prompt_tokens < options.limit:
+         token = context.sample()
+         context.append(token)
+         print(str(tokenizer.decode(token), encoding="utf-8"), end='', flush=True)
+
+
+ if __name__ == '__main__':
+     main(sys.argv[1:])
gptoss_kernels/include/gpt-oss.h ADDED
@@ -0,0 +1,5 @@
+ #pragma once
+
+ #include <gpt-oss/macros.h>
+ #include <gpt-oss/types.h>
+ #include <gpt-oss/functions.h>
gptoss_kernels/include/gpt-oss/functions.h ADDED
@@ -0,0 +1,401 @@
+ #pragma once
+
+ #include <stddef.h>
+ #include <stdint.h>
+
+ #include <gpt-oss/macros.h>
+ #include <gpt-oss/types.h>
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ /*
+  * Creates a Model object from a file in the filesystem.
+  *
+  * @param path Path to the file containing the model in GPT-OSS format.
+  * @param model_out Pointer to the Model object that will be created. Must be released with gptoss_model_release.
+  *
+  * On success, returns gptoss_status_success and saves a pointer to the created Model in the model_out argument.
+  * On failure, returns an error code and stores a null pointer in the model_out argument.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
+     const char* path,
+     gptoss_model_t* model_out);
+
+ /*
+  * Queries the Tokenizer object associated with the Model.
+  *
+  * @param model Pointer to the Model object created by gptoss_model_create_from_file.
+  * @param tokenizer_out Pointer to the variable where the Tokenizer reference will be stored.
+  *
+  * On success, returns gptoss_status_success and stores a reference to the Tokenizer object in the tokenizer_out argument.
+  * On failure, returns an error code and stores NULL in the tokenizer_out argument.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
+     gptoss_model_t model,
+     gptoss_tokenizer_t* tokenizer_out);
+
+ /*
+  * Queries the maximum context length supported by the Model.
+  *
+  * @param model Pointer to the Model object created by gptoss_model_create_from_file.
+  * @param max_context_length_out Pointer to the variable where the maximum context length will be stored.
+  *
+  * On success, returns gptoss_status_success and stores the maximum context length in the max_context_length_out argument.
+  * On failure, returns an error code and leaves the value specified by max_context_length_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
+     gptoss_model_t model,
+     size_t* max_context_length_out);
+
+ /*
+  * Increments a Model object's reference count.
+  *
+  * @param model Pointer to the Model object created by gptoss_model_create_from_file.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_model_retain(
+     gptoss_model_t model);
+
+ /*
+  * Decrements a Model object's reference count and possibly releases associated resources.
+  *
+  * @param model Pointer to the Model object created by gptoss_model_create_from_file.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_model_release(
+     gptoss_model_t model);
+
+ /*
+  * Queries the token ID for a special token in the Tokenizer vocabulary.
+  *
+  * @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
+  * @param token_type Type of the special token to query an ID for.
+  * @param token_id_out Pointer to the variable where the token ID will be stored.
+  *
+  * On success, returns gptoss_status_success and stores the token ID in the token_id_out argument.
+  * On failure, returns an error code and leaves the value specified by token_id_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
+     gptoss_tokenizer_t tokenizer,
+     enum gptoss_special_token token_type,
+     uint32_t* token_id_out);
+
+ /*
+  * Queries the number of text tokens in the Tokenizer vocabulary.
+  *
+  * @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
+  * @param num_text_tokens_out Pointer to the variable where the number of text tokens will be stored.
+  *
+  * On success, returns gptoss_status_success and stores the number of text tokens in the num_text_tokens_out argument.
+  * On failure, returns an error code and leaves the value specified by num_text_tokens_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
+     gptoss_tokenizer_t tokenizer,
+     uint32_t* num_text_tokens_out);
+
+ /*
+  * Queries the number of special tokens in the Tokenizer vocabulary.
+  *
+  * @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
+  * @param num_special_tokens_out Pointer to the variable where the number of special tokens will be stored.
+  *
+  * On success, returns gptoss_status_success and stores the number of special tokens in the num_special_tokens_out argument.
+  * On failure, returns an error code and leaves the value specified by num_special_tokens_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
+     gptoss_tokenizer_t tokenizer,
+     uint32_t* num_special_tokens_out);
+
+ /*
+  * Queries the total number of tokens in the Tokenizer vocabulary.
+  *
+  * @param tokenizer Pointer to the Tokenizer object created by gptoss_model_get_tokenizer.
+  * @param num_tokens_out Pointer to the variable where the total number of tokens will be stored.
+  *
+  * On success, returns gptoss_status_success and stores the total number of tokens in the num_tokens_out argument.
+  * On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
+     gptoss_tokenizer_t tokenizer,
+     uint32_t* num_tokens_out);
+
+ /*
+  * Converts a text token ID to its byte representation.
+  *
+  * @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer. The lifetime of the returned
+  *                  byte representation matches the lifetime of this Tokenizer object.
+  * @param token_id Token ID to convert.
+  * @param token_ptr_out Pointer to the variable where the pointer to the byte representation of the token will be
+  *                      stored.
+  * @param token_size_out Pointer to the variable where the size of the byte representation of the token will be stored.
+  *
+  * On success, returns gptoss_status_success and stores the pointer and size of the byte representation of the token in
+  * the token_ptr_out and token_size_out arguments.
+  * On failure, returns an error code and leaves the values specified in token_ptr_out and token_size_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
+     gptoss_tokenizer_t tokenizer,
+     uint32_t token_id,
+     const void** token_ptr_out,
+     size_t* token_size_out);
+
+ /*
+  * Increments a Tokenizer object's reference count.
+  *
+  * @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
+     gptoss_tokenizer_t tokenizer);
+
+ /*
+  * Decrements a Tokenizer object's reference count and possibly releases associated resources.
+  *
+  * @param tokenizer Pointer to the Tokenizer object returned by gptoss_model_get_tokenizer.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
+     gptoss_tokenizer_t tokenizer);
+
+ /*
+  * Creates a Context object for use with the particular Model object.
+  *
+  * @param model Model object to create a context for.
+  * @param context_length Maximum number of tokens in the context.
+  *                       Specify 0 to use the maximum context length supported by the model.
+  * @param max_batch_tokens Maximum number of tokens that can be processed in a single batch.
+  *                         Larger values may improve prefill performance, but require more memory.
+  *                         Specify 0 to use the default value.
+  * @param context_out Pointer to the Context object that will be created.
+  *                    Must be released with gptoss_context_release.
+  *
+  * On success, returns gptoss_status_success and saves a pointer to the created Context in the context_out argument.
+  * On failure, returns an error code and stores a null pointer in the context_out argument.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_create(
+     gptoss_model_t model,
+     size_t context_length,
+     size_t max_batch_tokens,
+     gptoss_context_t* context_out);
+
+ /*
+  * Queries the current number of tokens cached in the Context.
+  *
+  * @param context Pointer to the Context object created by gptoss_context_create.
+  * @param num_tokens_out Pointer to the variable where the current number of cached tokens will be stored.
+  *
+  * On success, returns gptoss_status_success and stores the current number of cached tokens in the num_tokens_out argument.
+  * On failure, returns an error code and leaves the value specified by num_tokens_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
+     gptoss_context_t context,
+     size_t* num_tokens_out);
+
+ /*
+  * Queries the maximum number of tokens cached in the Context.
+  *
+  * @param context Pointer to the Context object created by gptoss_context_create.
+  * @param max_tokens_out Pointer to the variable where the maximum number of cached tokens will be stored.
+  *
+  * On success, returns gptoss_status_success and stores the maximum number of cached tokens in the max_tokens_out argument.
+  * On failure, returns an error code and leaves the value specified by max_tokens_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
+     gptoss_context_t context,
+     size_t* max_tokens_out);
+
+ /*
+  * Queries the list of token IDs cached in the Context.
+  *
+  * @param context Pointer to the Context object created by gptoss_context_create.
+  * @param tokens_out Pointer to the array where up to max_tokens cached tokens will be stored.
+  * @param max_tokens Maximum capacity of the buffer specified by tokens_out.
+  * @param num_tokens_out Pointer to the variable where the actual number of cached tokens will be stored.
+  *                       This value can exceed max_tokens if the buffer capacity is insufficient.
+  *
+  * On success, returns gptoss_status_success and stores cached token IDs in the tokens_out argument and the number of
+  * cached tokens in the num_tokens_out argument.
+  * On failure, returns an error code and leaves the values specified by tokens_out and num_tokens_out unchanged.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
+     gptoss_context_t context,
+     uint32_t* tokens_out,
+     size_t max_tokens,
+     size_t* num_tokens_out);
+
+ /*
+  * Tokenizes and appends a character string to the Context object.
+  *
+  * @param context Context object created by gptoss_context_create.
+  * @param text Pointer to the character string to tokenize and append.
+  * @param text_length Length of the string, in chars.
+  * @param num_tokens_out Optional pointer to the variable where the number of appended tokens will be stored.
+  *                       Ignored if a null pointer is provided.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
+     gptoss_context_t context,
+     const char* text,
+     size_t text_length,
+     size_t* num_tokens_out);
+
+ /*
+  * Appends a list of tokens to the context.
+  *
+  * @param context Context object created by gptoss_context_create.
+  * @param num_tokens Number of tokens to be appended.
+  * @param tokens Pointer to the array of tokens to be appended.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
+     gptoss_context_t context,
+     size_t num_tokens,
+     const uint32_t* tokens);
+
+ /*
+  * Resets the context, clearing its state.
+  *
+  * @param context Context object created by gptoss_context_create.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_reset(
+     gptoss_context_t context);
+
+ /*
+  * Pre-processes the tokens in the Context and generates a probability distribution over the next token.
+  *
+  * @param context Context object created by gptoss_context_create.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_process(
+     gptoss_context_t context);
+
+ /*
+  * Samples tokens from the probability distribution over the next token, conditioned on the Context.
+  *
+  * @param context Context object created by gptoss_context_create.
+  * @param temperature Sampling temperature. Must be non-negative.
+  * @param seed Random number generator seed to use for sampling.
+  * @param max_tokens Maximum number of tokens to sample.
+  * @param tokens_out Pointer to the array where the sampled token IDs will be stored.
+  * @param num_tokens_out Pointer to the variable where the number of sampled tokens will be stored.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
+     gptoss_context_t context,
+     float temperature,
+     uint64_t seed,
+     size_t max_tokens,
+     uint32_t* tokens_out,
+     size_t* num_tokens_out);
+
+ /*
+  * Increments a Context object's reference count.
+  *
+  * @param context Pointer to the Context object created by gptoss_context_create.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_retain(
+     gptoss_context_t context);
+
+ /*
+  * Decrements a Context object's reference count and possibly releases associated resources.
+  *
+  * @param context Pointer to the Context object created by gptoss_context_create.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_context_release(
+     gptoss_context_t context);
+
+ /*
+  * Creates a Sampler object.
+  *
+  * @param sampler_out Pointer to the Sampler object that will be created.
+  *                    Must be released with gptoss_sampler_release.
+  *
+  * On success, returns gptoss_status_success and saves a pointer to the created Sampler in the sampler_out argument.
+  * On failure, returns an error code and stores a null pointer in the sampler_out argument.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_sampler_create(
+     gptoss_sampler_t* sampler_out);
+
+ /*
+  * Sets the sampling temperature for the Sampler.
+  *
+  * @param sampler Sampler object created by gptoss_sampler_create.
+  * @param temperature Temperature value to be set. Must be in the [0.0, 1.0] range.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_sampler_set_temperature(
+     gptoss_sampler_t sampler,
+     float temperature);
+
+ /*
+  * Sets the Top-P nucleus sampling parameter for the Sampler.
+  *
+  * @param sampler Sampler object created by gptoss_sampler_create.
+  * @param top_p Top-P value to be set. Must be in the (0.0, 1.0] range.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_sampler_set_top_p(
+     gptoss_sampler_t sampler,
+     float top_p);
+
+ /*
+  * Sets the presence penalty for the Sampler.
+  *
+  * @param sampler Sampler object created by gptoss_sampler_create.
+  * @param presence_penalty Presence penalty value to be set. Must be in the [-2.0, 2.0] range.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_sampler_set_presence_penalty(
+     gptoss_sampler_t sampler,
+     float presence_penalty);
+
+ /*
+  * Sets the frequency penalty for the Sampler.
+  *
+  * @param sampler Sampler object created by gptoss_sampler_create.
+  * @param frequency_penalty Frequency penalty value to be set. Must be in the [-2.0, 2.0] range.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_sampler_set_frequency_penalty(
+     gptoss_sampler_t sampler,
+     float frequency_penalty);
+
+ /*
+  * Increments a Sampler object's reference count.
+  *
+  * @param sampler Pointer to the Sampler object created by gptoss_sampler_create.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_sampler_retain(
+     gptoss_sampler_t sampler);
+
+ /*
+  * Decrements a Sampler object's reference count and possibly releases associated resources.
+  *
+  * @param sampler Pointer to the Sampler object created by gptoss_sampler_create.
+  *
+  * On success, returns gptoss_status_success, otherwise returns an error code.
+  */
+ enum gptoss_status GPTOSS_ABI gptoss_sampler_release(
+     gptoss_sampler_t sampler);
+
+ #ifdef __cplusplus
+ }  // extern "C"
+ #endif
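
Taken together, these declarations define the intended lifecycle: load a model, fetch its tokenizer, create a context, append text, sample, and release everything in reverse order. A minimal usage sketch under that reading; the model path and prompt are placeholders, and error handling is reduced to a single early return:

    // Hypothetical end-to-end use of the C API declared above.
    #include <stdio.h>
    #include <gpt-oss.h>

    int main(void) {
        gptoss_model_t model = NULL;
        gptoss_tokenizer_t tokenizer = NULL;
        gptoss_context_t context = NULL;

        if (gptoss_model_create_from_file("gpt-oss.bin", &model) != gptoss_status_success) {
            return 1;
        }
        gptoss_model_get_tokenizer(model, &tokenizer);

        // 0 selects the model's maximum context length and the default batch size.
        gptoss_context_create(model, /*context_length=*/0, /*max_batch_tokens=*/0, &context);

        const char prompt[] = "Once upon a time";
        gptoss_context_append_chars(context, prompt, sizeof(prompt) - 1, /*num_tokens_out=*/NULL);

        // Sample one token conditioned on the appended prompt.
        uint32_t token;
        size_t num_sampled;
        gptoss_context_sample(context, /*temperature=*/1.0f, /*seed=*/0, /*max_tokens=*/1, &token, &num_sampled);

        // Decode the sampled token to bytes and print it.
        const void* token_ptr;
        size_t token_size;
        gptoss_tokenizer_decode(tokenizer, token, &token_ptr, &token_size);
        fwrite(token_ptr, 1, token_size, stdout);

        gptoss_context_release(context);
        gptoss_tokenizer_release(tokenizer);
        gptoss_model_release(model);
        return 0;
    }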
gptoss_kernels/include/gpt-oss/macros.h ADDED
@@ -0,0 +1,5 @@
+ #pragma once
+
+ #ifndef GPTOSS_ABI
+ #define GPTOSS_ABI
+ #endif  // GPTOSS_ABI
gptoss_kernels/include/gpt-oss/types.h ADDED
@@ -0,0 +1,62 @@
+ #pragma once
+
+ /*
+  * Status codes returned by GPT-OSS API functions.
+  */
+ enum gptoss_status {
+     gptoss_status_success = 0,
+     gptoss_status_invalid_argument = 1,
+     gptoss_status_unsupported_argument = 2,
+     gptoss_status_invalid_state = 3,
+     gptoss_status_io_error = 4,
+     gptoss_status_insufficient_memory = 5,
+     gptoss_status_insufficient_resources = 6,
+     gptoss_status_unsupported_system = 7,
+     gptoss_status_context_overflow = 8,
+ };
+
+ enum gptoss_special_token {
+     gptoss_special_token_invalid = 0,
+     gptoss_special_token_return = 1,
+     gptoss_special_token_start = 2,
+     gptoss_special_token_message = 3,
+     gptoss_special_token_end = 4,
+     gptoss_special_token_refusal = 5,
+     gptoss_special_token_constrain = 6,
+     gptoss_special_token_channel = 7,
+     gptoss_special_token_call = 8,
+     gptoss_special_token_untrusted = 9,
+     gptoss_special_token_end_untrusted = 10,
+     gptoss_special_token_max,
+ };
+
+ /*
+  * A Model object is an opaque container comprising:
+  * - Weights
+  * - Temporary buffers required to run the model
+  * - Any other resources required to run the model
+  */
+ typedef struct gptoss_model* gptoss_model_t;
+
+ typedef struct gptoss_tokenizer* gptoss_tokenizer_t;
+
+ /*
+  * A Context is an opaque container comprising:
+  * - Input tokens
+  * - Distribution over the output tokens
+  * - KV cache
+  *
+  * Multiple contexts can be created and used with the same model.
+  */
+ typedef struct gptoss_context* gptoss_context_t;
+
+ /*
+  * A Sampler is an opaque container for sampling parameters:
+  * - Temperature
+  * - Top-p (nucleus sampling)
+  * - Frequency penalty
+  * - Presence penalty
+  *
+  * Multiple samplers can be created and used with the same context.
+  */
+ typedef struct gptoss_sampler* gptoss_sampler_t;
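
Every call in functions.h returns one of these status codes, so callers usually want a readable rendering for logs. A small hypothetical helper one might write on top of the enum (not part of the library):

    // Maps gptoss_status values to short human-readable strings.
    static const char* gptoss_status_to_string(enum gptoss_status status) {
        switch (status) {
            case gptoss_status_success:                return "success";
            case gptoss_status_invalid_argument:       return "invalid argument";
            case gptoss_status_unsupported_argument:   return "unsupported argument";
            case gptoss_status_invalid_state:          return "invalid state";
            case gptoss_status_io_error:               return "I/O error";
            case gptoss_status_insufficient_memory:    return "insufficient memory";
            case gptoss_status_insufficient_resources: return "insufficient resources";
            case gptoss_status_unsupported_system:     return "unsupported system";
            case gptoss_status_context_overflow:       return "context overflow";
            default:                                   return "unknown status";
        }
    }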
gptoss_kernels/source/accumulate.metal ADDED
@@ -0,0 +1,59 @@
+ #include <metal_integer>
+ #include <metal_math>
+
+ #include <internal/kernel-args.h>
+
+ #pragma METAL fp math_mode(safe)
+ #pragma METAL fp contract(off)
+
+
+ kernel void gptoss_f32_accumulate_e4(
+     constant gptoss_accumulate_args& args [[ buffer(0) ]],
+     const device float4* input [[ buffer(1) ]],
+     const device gptoss_expert_prediction* expert [[ buffer(2) ]],
+     device float4* output [[ buffer(3) ]],
+     const device gptoss_control* control [[ buffer(4) ]],
+     uint2 gid [[threadgroup_position_in_grid]],
+     uint tid [[thread_index_in_threadgroup]],
+     uint2 threadgroup_size [[ threads_per_threadgroup ]])
+ {
+     const uint num_active_experts = 4;
+     if (control->abort != 0) {
+         return;
+     }
+
+     const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
+     const uint threadgroup_start = gid.x * num_vecs_per_threadgroup;
+     const uint num_vecs = args.num_vecs;
+     const uint threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, num_vecs);
+     const uint thread_start = threadgroup_start + tid;
+     uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size.x - 1)) / threadgroup_size.x);
+
+     const uint num_vecs_per_expert = args.num_vecs_per_expert;
+     const float scale0 = expert[gid.y * num_active_experts + 0].score;
+     const device float4* input0 = input + gid.y * num_vecs + thread_start;
+     const float scale1 = expert[gid.y * num_active_experts + 1].score;
+     const device float4* input1 = input0 + num_vecs_per_expert;
+     const float scale2 = expert[gid.y * num_active_experts + 2].score;
+     const device float4* input2 = input1 + num_vecs_per_expert;
+     const float scale3 = expert[gid.y * num_active_experts + 3].score;
+     const device float4* input3 = input2 + num_vecs_per_expert;
+     output += gid.y * num_vecs + thread_start;
+     for (; num_iter != 0; num_iter--) {
+         float4 acc = *output;
+         const float4 val0 = *input0;
+         const float4 val1 = *input1;
+         const float4 val2 = *input2;
+         const float4 val3 = *input3;
+         input0 += threadgroup_size.x;
+         acc = metal::fma(val0, scale0, acc);
+         input1 += threadgroup_size.x;
+         acc = metal::fma(val1, scale1, acc);
+         input2 += threadgroup_size.x;
+         acc = metal::fma(val2, scale2, acc);
+         input3 += threadgroup_size.x;
+         acc = metal::fma(val3, scale3, acc);
+         *output = acc;
+         output += threadgroup_size.x;
+     }
+ }
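
For each token row gid.y, the kernel above computes output += score_e * input_e over the four active experts, with threads striding through float4 vectors; the fma chain simply interleaves loads and multiply-adds. A scalar C sketch of the same reduction may make the math easier to follow (the struct layout beyond the score field is an assumption, and the per-expert slicing is simplified):

    #include <stddef.h>

    // Assumed shape of gptoss_expert_prediction; only .score is used here.
    struct expert_prediction { unsigned expert_id; float score; };

    // Scalar equivalent of gptoss_f32_accumulate_e4 for one token row:
    // accumulate the four expert activations into output, scaled by their scores.
    static void f32_accumulate_e4_ref(
        const float* input,                      // 4 * num_elements values, one slice per expert
        const struct expert_prediction* experts, // 4 active experts for this token
        float* output,                           // num_elements values, updated in place
        size_t num_elements)
    {
        for (size_t e = 0; e < 4; e++) {
            const float scale = experts[e].score;
            const float* expert_input = input + e * num_elements;
            for (size_t i = 0; i < num_elements; i++) {
                output[i] += expert_input[i] * scale;  // metal::fma in the kernel
            }
        }
    }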
gptoss_kernels/source/context.c ADDED
@@ -0,0 +1,1115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <assert.h>
2
+ #include <float.h>
3
+ #include <inttypes.h>
4
+ #include <stdbool.h>
5
+ #include <stdint.h>
6
+ #include <stdlib.h>
7
+ #include <string.h>
8
+
9
+ #include <gpt-oss.h>
10
+
11
+ #include "internal/datatype.h"
12
+ #include "internal/model.h"
13
+ #include "internal/metal.h"
14
+ #include "internal/metal-kernels.h"
15
+ #include "internal/log.h"
16
+ #include "internal/rng.h"
17
+
18
+
19
+ enum gptoss_status GPTOSS_ABI gptoss_context_create(
20
+ gptoss_model_t model,
21
+ size_t context_length,
22
+ size_t max_batch_tokens,
23
+ gptoss_context_t* context_out)
24
+ {
25
+ *context_out = NULL;
26
+
27
+ enum gptoss_status status = gptoss_status_success;
28
+ struct gptoss_context* context = NULL;
29
+
30
+ // Validate context_length
31
+ if (context_length == 0) {
32
+ context_length = model->context_length;
33
+ } else if (context_length > model->context_length) {
34
+ GPTOSS_LOG_ERROR("requested context length %zu exceeds model context length %" PRIu32,
35
+ context_length, model->context_length);
36
+ status = gptoss_status_invalid_argument;
37
+ goto cleanup;
38
+ }
39
+ assert(context_length != 0);
40
+ assert(context_length <= model->context_length);
41
+
42
+ // Validate max_batch_tokens
43
+ if (max_batch_tokens == 0) {
44
+ max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE;
45
+ } else if (max_batch_tokens > context_length) {
46
+ GPTOSS_LOG_ERROR("requested max batch tokens %zu exceeds context length %zu",
47
+ max_batch_tokens, context_length);
48
+ status = gptoss_status_invalid_argument;
49
+ goto cleanup;
50
+ }
51
+ assert(max_batch_tokens != 0);
52
+ assert(max_batch_tokens <= context_length);
53
+
54
+ context = malloc(sizeof(struct gptoss_context));
55
+ if (context == NULL) {
56
+ GPTOSS_LOG_ERROR("failed to allocate %zu bytes for Context object",
57
+ sizeof(struct gptoss_context));
58
+ status = gptoss_status_insufficient_memory;
59
+ goto cleanup;
60
+ }
61
+ memset(context, 0, sizeof(struct gptoss_context));
62
+
63
+ atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);
64
+ context->max_tokens = context_length;
65
+ context->max_batch_tokens = max_batch_tokens;
66
+
67
+ // Activation buffers
68
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer);
69
+ if (status != gptoss_status_success) {
70
+ goto cleanup;
71
+ }
72
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer);
73
+ if (status != gptoss_status_success) {
74
+ goto cleanup;
75
+ }
76
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer);
77
+ if (status != gptoss_status_success) {
78
+ goto cleanup;
79
+ }
80
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer);
81
+ if (status != gptoss_status_success) {
82
+ goto cleanup;
83
+ }
84
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer);
85
+ if (status != gptoss_status_success) {
86
+ goto cleanup;
87
+ }
88
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer);
89
+ if (status != gptoss_status_success) {
90
+ goto cleanup;
91
+ }
92
+ // The last entry will hold the total number of tokens.
93
+ status = gptoss_metal_buffer_create(&model->device, (1 + model->num_experts) * sizeof(uint32_t), NULL, &context->expert_offset_buffer);
94
+ if (status != gptoss_status_success) {
95
+ goto cleanup;
96
+ }
97
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * sizeof(uint32_t), NULL, &context->token_to_expert_routing_buffer);
98
+ if (status != gptoss_status_success) {
99
+ goto cleanup;
100
+ }
101
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->swiglu_input_buffer);
102
+ if (status != gptoss_status_success) {
103
+ goto cleanup;
104
+ }
105
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
106
+ if (status != gptoss_status_success) {
107
+ goto cleanup;
108
+ }
109
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer);
110
+ if (status != gptoss_status_success) {
111
+ goto cleanup;
112
+ }
113
+
114
+ // Input/output buffers
115
+ status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer);
116
+ if (status != gptoss_status_success) {
117
+ goto cleanup;
118
+ }
119
+ status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
120
+ if (status != gptoss_status_success) {
121
+ goto cleanup;
122
+ }
123
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->score_buffer);
124
+ if (status != gptoss_status_success) {
125
+ goto cleanup;
126
+ }
127
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->vocabulary_size * sizeof(float), NULL, &context->prob_buffer);
128
+ if (status != gptoss_status_success) {
129
+ goto cleanup;
130
+ }
131
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * model->max_threadgroups * sizeof(float), NULL, &context->sum_buffer);
132
+ if (status != gptoss_status_success) {
133
+ goto cleanup;
134
+ }
135
+ status = gptoss_metal_buffer_create(&model->device, max_batch_tokens * sizeof(uint64_t), NULL, &context->argmax_buffer);
136
+ if (status != gptoss_status_success) {
137
+ goto cleanup;
138
+ }
139
+ status = gptoss_metal_buffer_create(&model->device, model->num_blocks * context_length * 2 * model->num_kv_heads * model->head_dim * sizeof(float), NULL, &context->kvcache_buffer);
140
+ if (status != gptoss_status_success) {
141
+ goto cleanup;
142
+ }
143
+
144
+ context->kvcache_size = context->kvcache_buffer.size;
145
+ context->allocation_size =
146
+ context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
147
+ context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
148
+ context->gate_activation_buffer.size + context->expert_activation_buffer.size +
149
+ context->expert_offset_buffer.size + context->token_to_expert_routing_buffer.size + context->swiglu_input_buffer.size +
150
+ context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
151
+ context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;
152
+
153
+ context->model = model;
154
+ gptoss_model_retain(model);
155
+ *context_out = context;
156
+ context = NULL;
157
+
158
+ cleanup:
159
+ gptoss_context_release(context);
160
+ return status;
161
+ }
162
+
163
+ enum gptoss_status GPTOSS_ABI gptoss_context_get_num_tokens(
164
+ gptoss_context_t context,
165
+ size_t* num_tokens_out)
166
+ {
167
+ *num_tokens_out = context->num_tokens;
168
+ return gptoss_status_success;
169
+ }
170
+
171
+ enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
172
+ gptoss_context_t context,
173
+ size_t* max_tokens_out)
174
+ {
175
+ *max_tokens_out = context->max_tokens;
176
+ return gptoss_status_success;
177
+ }
178
+
179
+ enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
180
+ gptoss_context_t context,
181
+ uint32_t* tokens_out,
182
+ size_t max_tokens,
183
+ size_t* num_tokens_out)
184
+ {
185
+ *num_tokens_out = context->num_tokens;
186
+ if (max_tokens < context->num_tokens) {
187
+ return gptoss_status_insufficient_memory;
188
+ }
189
+
190
+ if (context->num_tokens != 0) {
191
+ memcpy(tokens_out, context->token_buffer.ptr, context->num_tokens * sizeof(uint32_t));
192
+ }
193
+ return gptoss_status_success;
194
+ }
195
+
196
+ // Prefill: input_tokens_offset = number of tokens in KV cache, num_input_tokens > 0, num_output_tokens = 0.
197
+ // Sampling: input_tokens_offset = number of tokens in the context - 1, num_input_tokens = 1, num_output_tokens = 1.
198
+ // Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens.
199
+ static enum gptoss_status process_tokens(
200
+ gptoss_context_t context,
201
+ struct gptoss_metal_command_buffer* command_buffer,
202
+ size_t input_tokens_offset,
203
+ size_t num_input_tokens,
204
+ size_t num_output_tokens)
205
+ {
206
+ assert(num_input_tokens != 0);
207
+ assert(num_input_tokens <= context->max_batch_tokens);
208
+ assert(num_output_tokens <= context->max_batch_tokens);
209
+ assert(num_input_tokens >= num_output_tokens);
210
+ const size_t dense_matmul_kernel_token_multiple_constraint = 64;
211
+ const size_t min_tokens_for_dense_moe_kernels = 64;
212
+
213
+ enum gptoss_status status = gptoss_status_success;
214
+ const struct gptoss_model* model = context->model;
215
+
216
+ const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
217
+
218
+ const size_t input_tokens_end = input_tokens_offset + num_input_tokens;
219
+ for (size_t input_batch_start = input_tokens_offset;
220
+ input_batch_start < input_tokens_end;
221
+ input_batch_start += context->max_batch_tokens)
222
+ {
223
+ const size_t input_batch_size = math_min(context->max_batch_tokens, input_tokens_end - input_batch_start);
224
+ const size_t input_batch_end = input_batch_start + input_batch_size;
225
+ const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end);
226
+
227
+ status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
228
+ command_buffer,
229
+ &model->bf16_f32_embeddings_fn,
230
+ model->embeddings_threadgroup_size,
231
+ &context->token_buffer,
232
+ input_batch_start * sizeof(uint32_t),
233
+ &model->shared_weight_buffer,
234
+ /*weight_offset=*/0,
235
+ &context->residual_activation_buffer,
236
+ /*output_offset=*/0,
237
+ &context->control_buffer,
238
+ /*control_offset=*/0,
239
+ /*num_tokens=*/input_batch_size,
240
+ /*num_channels=*/model->embedding_dim);
241
+ if (status != gptoss_status_success) {
242
+ GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
243
+ return status;
244
+ }
245
+ for (uint32_t n = 0; n < model->num_blocks; n++) {
246
+ const bool last_block = n + 1 == model->num_blocks;
247
+ const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size;
248
+
249
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
250
+ command_buffer,
251
+ &model->f32_bf16w_rmsnorm_fn,
252
+ &context->residual_activation_buffer,
253
+ /*input_offset=*/0,
254
+ &model->shared_weight_buffer,
255
+ /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
256
+ &context->rmsnorm_activation_buffer,
257
+ /*output_offset=*/0,
258
+ &context->control_buffer,
259
+ /*control_offset=*/0,
260
+ /*num_tokens=*/input_batch_size,
261
+ /*num_channels=*/model->embedding_dim,
262
+ model->rmsnorm_epsilon);
263
+ if (status != gptoss_status_success) {
264
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
265
+ return status;
266
+ }
267
+
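+ // The dense QKV matmul kernel requires the batch to be a multiple of 64 tokens; smaller
+ // batches (e.g. single-token decode) fall back to the vector kernel below, which also fuses
+ // RoPE and the KV-cache update instead of running them as separate passes.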
268
+ if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
269
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
270
+ command_buffer,
271
+ &model->f32_bf16w_dense_matmul_qkv_fn,
272
+ &context->rmsnorm_activation_buffer,
273
+ /*input_offset=*/0,
274
+ &model->shared_weight_buffer,
275
+ /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
276
+ &model->shared_weight_buffer,
277
+ /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
278
+ &context->qkv_activation_buffer,
279
+ /*output_offset=*/0,
280
+ &context->control_buffer,
281
+ /*control_offset=*/0,
282
+ /*num_tokens=*/input_batch_size,
283
+ /*num_cols=*/model->embedding_dim,
284
+ /*num_rows=*/attn_qkv_dim);
285
+ if (status != gptoss_status_success) {
286
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch");
287
+ return status;
288
+ }
289
+
290
+ status = gptoss_metal_command_buffer_encode_launch_f32_rope(
291
+ command_buffer,
292
+ &model->f32_rope_fn,
293
+ /*threadgroup_size=*/32,
294
+ &context->qkv_activation_buffer,
295
+ /*input_offset=*/0,
296
+ &context->control_buffer,
297
+ /*control_offset=*/0,
298
+ model->rope_theta,
299
+ model->interpolation_scale,
300
+ model->yarn_offset,
301
+ model->yarn_scale,
302
+ model->yarn_multiplier,
303
+ input_batch_size,
304
+ model->num_heads,
305
+ model->num_kv_heads,
306
+ model->head_dim,
307
+ /*token_offset=*/input_batch_start);
308
+ if (status != gptoss_status_success) {
309
+ GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
310
+ return status;
311
+ }
312
+
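+ // Copy K and V rows into the cache, laid out as [block][kv_head][token][k|v][head_dim]
+ // to match the offsets expected by the SDPA kernel below.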
313
+ for (uint32_t t = 0; t < input_batch_size; t++) {
314
+ for (uint32_t kv = 0; kv < 2; kv++) {
315
+ for (uint32_t h = 0; h < model->num_kv_heads; h++) {
316
+ status = gptoss_metal_command_buffer_encode_copy_buffer(
317
+ command_buffer,
318
+ &context->qkv_activation_buffer,
319
+ /*input_offset=*/(t * attn_qkv_dim + (model->num_heads + kv * model->num_kv_heads + h) * model->head_dim) * sizeof(float),
320
+ &context->kvcache_buffer,
321
+ /*output_offset=*/(((n * model->num_kv_heads + h) * context->max_tokens + input_batch_start + t) * 2 + kv) * model->head_dim * sizeof(float),
322
+ /*size=*/model->head_dim * sizeof(float));
323
+ if (status != gptoss_status_success) {
324
+ GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
325
+ return status;
326
+ }
327
+ }
328
+ }
329
+ }
330
+ } else {
331
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
332
+ command_buffer,
333
+ &model->f32_bf16w_matmul_qkv_fn,
334
+ model->attn_qkv_threadgroup_size,
335
+ &context->rmsnorm_activation_buffer,
336
+ /*input_offset=*/0,
337
+ &model->shared_weight_buffer,
338
+ /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
339
+ &model->shared_weight_buffer,
340
+ /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
341
+ &context->qkv_activation_buffer,
342
+ /*output_offset=*/0,
343
+ &context->kvcache_buffer,
344
+ /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
345
+ &context->control_buffer,
346
+ /*control_offset=*/0,
347
+ /*num_tokens=*/input_batch_size,
348
+ /*num_cols=*/model->embedding_dim,
349
+ /*num_q_heads=*/model->num_heads,
350
+ /*num_kv_heads=*/model->num_kv_heads,
351
+ /*attn_head_dim=*/model->head_dim,
352
+ /*token_offset=*/input_batch_start,
353
+ /*max_tokens=*/context->max_tokens,
354
+ /*rope_base=*/model->rope_theta,
355
+ /*interpolation_scale=*/model->interpolation_scale,
356
+ /*yarn_offset=*/model->yarn_offset,
357
+ /*yarn_scale=*/model->yarn_scale,
358
+ /*yarn_multiplier=*/model->yarn_multiplier);
359
+ if (status != gptoss_status_success) {
360
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch");
361
+ return status;
362
+ }
363
+ }
364
+
365
+ if (num_block_output_tokens != 0) {
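+ // Blocks alternate attention patterns: even-indexed blocks use a sliding window of
+ // model->attention_window tokens, odd-indexed blocks attend over the full context.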
366
+ status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
367
+ command_buffer,
368
+ &model->f32_sdpa_q8_d64_fn,
369
+ &context->qkv_activation_buffer,
370
+ /*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
371
+ &context->kvcache_buffer,
372
+ /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
373
+ &model->shared_weight_buffer,
374
+ /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
375
+ &context->sdpa_activation_buffer,
376
+ /*output_offset=*/0,
377
+ &context->control_buffer,
378
+ /*control_offset=*/0,
379
+ /*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
380
+ /*kv_stride=*/2 * context->max_tokens * model->head_dim,
381
+ num_block_output_tokens,
382
+ input_batch_start + input_batch_size - num_block_output_tokens,
383
+ model->num_heads, model->num_kv_heads, model->head_dim);
384
+ if (status != gptoss_status_success) {
385
+ GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
386
+ return status;
387
+ }
388
+
389
+ if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
390
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
391
+ command_buffer,
392
+ &model->f32_bf16w_dense_matmul_attn_output_fn,
393
+ &context->sdpa_activation_buffer,
394
+ /*input_offset=*/0,
395
+ &model->shared_weight_buffer,
396
+ /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
397
+ &model->shared_weight_buffer,
398
+ /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
399
+ &context->residual_activation_buffer,
400
+ /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
401
+ &context->control_buffer,
402
+ /*control_offset=*/0,
403
+ /*num_tokens=*/num_block_output_tokens,
404
+ /*num_cols=*/model->num_heads * model->head_dim,
405
+ /*num_rows=*/model->embedding_dim);
406
+ if (status != gptoss_status_success) {
407
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_attn_output kernel launch");
408
+ return status;
409
+ }
410
+ } else {
411
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
412
+ command_buffer,
413
+ &model->f32_bf16w_matmul_fn,
414
+ model->attn_out_threadgroup_size,
415
+ &context->sdpa_activation_buffer,
416
+ /*input_offset=*/0,
417
+ &model->shared_weight_buffer,
418
+ /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
419
+ &model->shared_weight_buffer,
420
+ /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
421
+ &context->residual_activation_buffer,
422
+ /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
423
+ &context->control_buffer,
424
+ /*control_offset=*/0,
425
+ /*num_tokens=*/num_block_output_tokens,
426
+ /*num_cols=*/model->num_heads * model->head_dim,
427
+ /*num_rows=*/model->embedding_dim);
428
+ if (status != gptoss_status_success) {
429
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
430
+ return status;
431
+ }
432
+ }
433
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
434
+ command_buffer,
435
+ &model->f32_bf16w_rmsnorm_fn,
436
+ &context->residual_activation_buffer,
437
+ /*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
438
+ &model->shared_weight_buffer,
439
+ /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
440
+ &context->rmsnorm_activation_buffer,
441
+ /*output_offset=*/0,
442
+ &context->control_buffer,
443
+ /*control_offset=*/0,
444
+ num_block_output_tokens,
445
+ model->embedding_dim,
446
+ model->rmsnorm_epsilon);
447
+ if (status != gptoss_status_success) {
448
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
449
+ return status;
450
+ }
451
+ if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
452
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
453
+ command_buffer,
454
+ &model->f32_bf16w_dense_matmul_mlp_gate_fn,
455
+ &context->rmsnorm_activation_buffer,
456
+ /*input_offset=*/0,
457
+ &model->shared_weight_buffer,
458
+ /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
459
+ &model->shared_weight_buffer,
460
+ /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
461
+ &context->gate_activation_buffer,
462
+ /*output_offset=*/0,
463
+ &context->control_buffer,
464
+ /*control_offset=*/0,
465
+ num_block_output_tokens,
466
+ model->embedding_dim,
467
+ model->num_experts);
468
+ if (status != gptoss_status_success) {
469
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_mlp_gate kernel launch");
470
+ return status;
471
+ }
472
+ } else {
473
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
474
+ command_buffer,
475
+ &model->f32_bf16w_matmul_fn,
476
+ model->mlp_gate_threadgroup_size,
477
+ &context->rmsnorm_activation_buffer,
478
+ /*input_offset=*/0,
479
+ &model->shared_weight_buffer,
480
+ /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
481
+ &model->shared_weight_buffer,
482
+ /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
483
+ &context->gate_activation_buffer,
484
+ /*output_offset=*/0,
485
+ &context->control_buffer,
486
+ /*control_offset=*/0,
487
+ /*num_tokens=*/num_block_output_tokens,
488
+ /*num_cols=*/model->embedding_dim,
489
+ /*num_rows=*/model->num_experts);
490
+ if (status != gptoss_status_success) {
491
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
492
+ return status;
493
+ }
494
+ }
495
+
496
+ const char* kernel_name = NULL;
497
+ switch (model->num_experts) {
498
+ case 32:
499
+ kernel_name = "f32_topk_softmax_e32_k4_fn";
500
+ status = gptoss_metal_command_buffer_encode_launch_f32_topk(
501
+ command_buffer,
502
+ &model->f32_topk_softmax_e32_k4_fn,
503
+ &context->gate_activation_buffer, /*input_offset=*/0,
504
+ &context->expert_activation_buffer, /*output_offset=*/0,
505
+ &context->control_buffer, /*control_offset=*/0,
506
+ num_block_output_tokens,
507
+ model->num_experts,
508
+ model->num_active_experts);
509
+ break;
510
+ case 128:
511
+ kernel_name = "f32_topk_softmax_e128_k4_fn";
512
+ status = gptoss_metal_command_buffer_encode_launch_f32_topk(
513
+ command_buffer,
514
+ &model->f32_topk_softmax_e128_k4_fn,
515
+ &context->gate_activation_buffer, /*input_offset=*/0,
516
+ &context->expert_activation_buffer, /*output_offset=*/0,
517
+ &context->control_buffer, /*control_offset=*/0,
518
+ num_block_output_tokens,
519
+ model->num_experts,
520
+ model->num_active_experts);
521
+ break;
522
+ default:
523
+ status = gptoss_status_unsupported_argument;
524
+ GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
525
+ return status;
526
+ }
527
+ if (status != gptoss_status_success) {
528
+ GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
529
+ return status;
530
+ }
531
+
532
+ // If we have enough tokens in prefill, we will pick the prefill-optimized kernels.
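+ // Dense path: build per-expert routing metadata, scatter token rows into contiguous
+ // per-expert groups, run the dense SwiGLU and projection matmuls, then gather the expert
+ // outputs back into the residual stream scaled by their routing scores.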
533
+ if (num_block_output_tokens >= min_tokens_for_dense_moe_kernels) {
534
+ status = gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
535
+ command_buffer,
536
+ &model->f32_expert_routing_metadata_fn,
537
+ &context->expert_activation_buffer,
538
+ /*expert_predictions_offset=*/0,
539
+ &context->expert_offset_buffer,
540
+ /*expert_offsets_offset=*/0,
541
+ &context->token_to_expert_routing_buffer,
542
+ /*intra_expert_offsets_offset=*/0,
543
+ num_block_output_tokens * model->num_active_experts,
544
+ model->num_experts);
545
+ if (status != gptoss_status_success) {
546
+ GPTOSS_LOG_ERROR("failed to encode f32_expert_routing_metadata kernel launch");
547
+ return status;
548
+ }
549
+ status = gptoss_metal_command_buffer_encode_launch_f32_scatter(
550
+ command_buffer,
551
+ &model->f32_scatter_e4_fn,
552
+ &context->rmsnorm_activation_buffer,
553
+ /*input_offset=*/0,
554
+ &context->expert_activation_buffer,
555
+ /*expert_predictions_offset=*/0,
556
+ &context->expert_offset_buffer,
557
+ /*expert_offsets_offset=*/0,
558
+ &context->token_to_expert_routing_buffer,
559
+ /*intra_expert_offsets_offset=*/0,
560
+ &context->swiglu_input_buffer,
561
+ /*output_offset=*/0,
562
+ model->embedding_dim,
563
+ num_block_output_tokens,
564
+ model->num_active_experts);
565
+ if (status != gptoss_status_success) {
566
+ GPTOSS_LOG_ERROR("failed to encode f32_scatter kernel launch");
567
+ return status;
568
+ }
569
+ // Dense MoE SwiGLU matmul.
570
+ status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
571
+ command_buffer,
572
+ &model->f32_mf4w_moe_dense_matmul_swiglu_fn,
573
+ &context->expert_offset_buffer,
574
+ /*expert_offsets_offset=*/0,
575
+ &context->swiglu_input_buffer,
576
+ /*input_offset=*/0,
577
+ &model->block_weight_buffers[n],
578
+ /*weight_block_offset=*/0,
579
+ &model->block_weight_buffers[n],
580
+ /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
581
+ &model->block_weight_buffers[n],
582
+ /*bias_offset=*/model->mlp_swiglu_bias_offset,
583
+ &context->swiglu_activation_buffer,
584
+ /*output_offset=*/0,
585
+ model->swiglu_limit,
586
+ /*expert_stride_bytes=*/model->per_expert_block_weight_size,
587
+ num_block_output_tokens,
588
+ model->num_experts,
589
+ model->embedding_dim,
590
+ 2 * model->mlp_dim);
591
+ if (status != gptoss_status_success) {
592
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch");
593
+ return status;
594
+ }
595
+
596
+ // Dense MoE proj matmul.
597
+ status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
598
+ command_buffer,
599
+ &model->f32_mf4w_moe_dense_matmul_fn,
600
+ &context->expert_offset_buffer,
601
+ /*expert_offsets_offset=*/0,
602
+ &context->swiglu_activation_buffer,
603
+ /*input_offset=*/0,
604
+ &model->block_weight_buffers[n],
605
+ /*weight_block_offset=*/model->mlp_out_block_offset,
606
+ &model->block_weight_buffers[n],
607
+ /*weight_scale_offset=*/model->mlp_out_scale_offset,
608
+ &model->block_weight_buffers[n],
609
+ /*bias_offset=*/model->mlp_out_bias_offset,
610
+ &context->moe_activation_buffer,
611
+ /*output_offset=*/0,
612
+ /*expert_stride_bytes=*/model->per_expert_block_weight_size,
613
+ num_block_output_tokens,
614
+ model->num_experts,
615
+ model->mlp_dim,
616
+ model->embedding_dim);
617
+ if (status != gptoss_status_success) {
618
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch");
619
+ return status;
620
+ }
621
+ // Gather and accumulate.
622
+ status = gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
623
+ command_buffer,
624
+ &model->f32_gather_and_accumulate_e4_fn,
625
+ &context->moe_activation_buffer,
626
+ /*input_offset=*/0,
627
+ &context->expert_activation_buffer,
628
+ /*expert_predictions_offset=*/0,
629
+ &context->expert_offset_buffer,
630
+ /*expert_offsets_offset=*/0,
631
+ &context->token_to_expert_routing_buffer,
632
+ /*intra_expert_offsets_offset=*/0,
633
+ &context->residual_activation_buffer,
634
+ /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
635
+ model->embedding_dim,
636
+ num_block_output_tokens,
637
+ model->num_active_experts);
638
+ if (status != gptoss_status_success) {
639
+ GPTOSS_LOG_ERROR("failed to encode f32_gather_and_accumulate_e4 kernel launch");
640
+ return status;
641
+ }
642
+
643
+ } else {
644
+ status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
645
+ command_buffer,
646
+ &model->f32_mf4w_moe_matmul_swiglu_fn,
647
+ model->mlp_swiglu_threadgroup_size,
648
+ &context->rmsnorm_activation_buffer,
649
+ /*input_offset=*/0,
650
+ &context->expert_activation_buffer,
651
+ /*expert_offset=*/0,
652
+ &model->block_weight_buffers[n],
653
+ /*weight_block_offset=*/0,
654
+ &model->block_weight_buffers[n],
655
+ /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
656
+ &model->block_weight_buffers[n],
657
+ /*bias_offset=*/model->mlp_swiglu_bias_offset,
658
+ &context->swiglu_activation_buffer,
659
+ /*output_offset=*/0,
660
+ &context->control_buffer,
661
+ /*control_offset=*/0,
662
+ model->swiglu_limit,
663
+ model->per_expert_block_weight_size,
664
+ num_block_output_tokens,
665
+ model->num_active_experts,
666
+ model->embedding_dim,
667
+ model->mlp_dim);
668
+ if (status != gptoss_status_success) {
669
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
670
+ return status;
671
+ }
672
+
673
+ status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
674
+ command_buffer,
675
+ &model->f32_mf4w_moe_matmul_fn,
676
+ model->mlp_out_threadgroup_size,
677
+ &context->swiglu_activation_buffer,
678
+ /*input_offset=*/0,
679
+ &context->expert_activation_buffer,
680
+ /*expert_offset=*/0,
681
+ &model->block_weight_buffers[n],
682
+ /*weight_block_offset=*/model->mlp_out_block_offset,
683
+ &model->block_weight_buffers[n],
684
+ /*weight_scale_offset=*/model->mlp_out_scale_offset,
685
+ &model->block_weight_buffers[n],
686
+ /*bias_offset=*/model->mlp_out_bias_offset,
687
+ &context->moe_activation_buffer,
688
+ /*output_offset=*/0,
689
+ &context->control_buffer,
690
+ /*control_offset=*/0,
691
+ model->per_expert_block_weight_size,
692
+ num_block_output_tokens,
693
+ model->num_active_experts,
694
+ model->mlp_dim,
695
+ model->embedding_dim);
696
+ if (status != gptoss_status_success) {
697
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
698
+ return status;
699
+ }
700
+
701
+ status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
702
+ command_buffer,
703
+ &model->f32_accumulate_e4_fn,
704
+ model->mlp_acc_threadgroup_size,
705
+ model->max_threadgroups,
706
+ &context->moe_activation_buffer,
707
+ /*input_offset=*/0,
708
+ &context->expert_activation_buffer,
709
+ /*expert_offset=*/0,
710
+ &context->residual_activation_buffer,
711
+ /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
712
+ &context->control_buffer,
713
+ /*control_offset=*/0,
714
+ model->embedding_dim,
715
+ num_block_output_tokens,
716
+ model->num_active_experts);
717
+ if (status != gptoss_status_success) {
718
+ GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
719
+ return status;
720
+ }
721
+ }
722
+ }
723
+ }
724
+
725
+ if (output_batch_size != 0) {
726
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
727
+ command_buffer,
728
+ &model->f32_bf16w_rmsnorm_fn,
729
+ &context->residual_activation_buffer,
730
+ /*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float),
731
+ &model->shared_weight_buffer,
732
+ /*weight_offset=*/model->rmsnorm_weight_offset,
733
+ &context->rmsnorm_activation_buffer,
734
+ /*output_offset=*/0,
735
+ &context->control_buffer,
736
+ /*control_offset=*/0,
737
+ /*num_tokens=*/output_batch_size,
738
+ /*num_channels=*/model->embedding_dim,
739
+ model->rmsnorm_epsilon);
740
+ if (status != gptoss_status_success) {
741
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
742
+ return status;
743
+ }
744
+
745
+ status = gptoss_metal_command_buffer_encode_fill_buffer(
746
+ command_buffer,
747
+ &context->argmax_buffer,
748
+ /*offset=*/0,
749
+ /*size=*/sizeof(uint64_t) * output_batch_size,
750
+ /*fill_value=*/0xFF);
751
+ if (status != gptoss_status_success) {
752
+ GPTOSS_LOG_ERROR("failed to encode fill buffer command");
753
+ return status;
754
+ }
755
+
756
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
757
+ command_buffer,
758
+ &model->f32_bf16w_unembedding_fn,
759
+ model->unembedding_threadgroup_size,
760
+ model->max_threadgroups,
761
+ &context->rmsnorm_activation_buffer,
762
+ /*input_offset=*/0,
763
+ &model->shared_weight_buffer,
764
+ /*weight_offset=*/model->unembedding_weight_offset,
765
+ &context->score_buffer,
766
+ /*output_offset=*/0,
767
+ &context->argmax_buffer,
768
+ /*argmax_offset=*/0,
769
+ &context->control_buffer,
770
+ /*control_offset=*/0,
771
+ /*num_tokens=*/output_batch_size,
772
+ /*num_cols=*/model->embedding_dim,
773
+ /*num_rows=*/model->vocabulary_size);
774
+ if (status != gptoss_status_success) {
775
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
776
+ return status;
777
+ }
778
+ }
779
+ }
780
+ return gptoss_status_success;
781
+ }
782
+
783
+ enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
784
+ gptoss_context_t context,
785
+ const char* text,
786
+ size_t text_length,
787
+ size_t* num_tokens_out)
788
+ {
789
+ enum gptoss_status status = gptoss_status_success;
790
+ const struct gptoss_model* model = context->model;
791
+ const struct gptoss_tokenizer* tokenizer = model->tokenizer;
792
+ size_t num_appended_tokens = 0;
793
+ while (text_length != 0) {
794
+ if (context->num_tokens == context->max_tokens) {
795
+ status = gptoss_status_context_overflow;
796
+ break;
797
+ }
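+ // Greedy longest-prefix match over the variable-length token table: each entry is a
+ // uint16 byte length followed by the token bytes.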
798
+ const char* tokens = tokenizer->tokens_ptr;
799
+ uint32_t best_token = UINT32_MAX;
800
+ uint32_t best_token_length = 0;
801
+ for (size_t t = 0; t < tokenizer->num_text_tokens; t++) {
802
+ uint16_t token_length;
803
+ memcpy(&token_length, tokens, sizeof(uint16_t));
804
+ tokens += sizeof(uint16_t);
805
+ if (token_length <= text_length && token_length > best_token_length) {
806
+ if (memcmp(text, tokens, token_length) == 0) {
807
+ if (token_length > best_token_length) {
808
+ best_token = (uint32_t) t;
809
+ best_token_length = token_length;
810
+ }
811
+ }
812
+ }
813
+ tokens += token_length;
814
+ }
815
+
816
+ if (best_token == UINT32_MAX) {
817
+ GPTOSS_LOG_ERROR("failed to tokenize text \"%.*s\"", (int) text_length, text);
818
+ return gptoss_status_invalid_argument;
819
+ }
820
+
821
+ uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
822
+ if (context->num_kv_tokens > context->num_tokens) {
823
+ if (input_tokens[context->num_tokens] != best_token) {
824
+ input_tokens[context->num_tokens] = best_token;
825
+
826
+ // Invalidate the KV cache starting with the newly added token.
827
+ context->num_kv_tokens = context->num_tokens;
828
+ }
829
+ context->num_tokens++;
830
+ } else {
831
+ input_tokens[context->num_tokens++] = best_token;
832
+ }
833
+ num_appended_tokens++;
834
+ text += best_token_length;
835
+ text_length -= best_token_length;
836
+ }
837
+ if (num_tokens_out != NULL) {
838
+ *num_tokens_out = num_appended_tokens;
839
+ }
840
+ return status;
841
+ }
842
+
843
+ enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
844
+ gptoss_context_t context,
845
+ size_t num_tokens,
846
+ const uint32_t* tokens)
847
+ {
848
+ const struct gptoss_model* model = context->model;
849
+
850
+ // Validate all tokens
851
+ for (size_t t = 0; t < num_tokens; t++) {
852
+ const uint32_t token = tokens[t];
853
+ if (token >= model->vocabulary_size) {
854
+ GPTOSS_LOG_ERROR("token %" PRIu32 " at index %zu is out of bounds for vocabulary size %" PRIu32,
855
+ token, t, context->model->vocabulary_size);
856
+ return gptoss_status_invalid_argument;
857
+ }
858
+ }
859
+
860
+ enum gptoss_status status = gptoss_status_success;
861
+ uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
862
+ while (num_tokens != 0) {
863
+ if (context->num_tokens == context->max_tokens) {
864
+ status = gptoss_status_context_overflow;
865
+ break;
866
+ }
867
+
868
+ if (context->num_kv_tokens > context->num_tokens) {
869
+ const size_t num_tokens_to_verify = math_min(context->num_kv_tokens - context->num_tokens, num_tokens);
870
+ size_t num_verified_tokens = 0;
871
+ for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) {
872
+ if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) {
873
+ // Invalidate the KV cache starting with the newly added tokens.
874
+ context->num_kv_tokens = context->num_tokens + num_verified_tokens;
875
+ break;
876
+ }
877
+ }
878
+
879
+ context->num_tokens += num_verified_tokens;
880
+ tokens += num_verified_tokens;
881
+ num_tokens -= num_verified_tokens;
882
+ } else {
883
+ const size_t num_tokens_to_copy = math_min(context->max_tokens - context->num_tokens, num_tokens);
884
+ memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t));
885
+ context->num_tokens += num_tokens_to_copy;
886
+ tokens += num_tokens_to_copy;
887
+ num_tokens -= num_tokens_to_copy;
888
+ }
889
+ }
890
+
891
+ return status;
892
+ }
893
+
894
+ enum gptoss_status GPTOSS_ABI gptoss_context_process(
895
+ gptoss_context_t context)
896
+ {
897
+ if (context->num_tokens > context->num_kv_tokens) {
898
+ struct gptoss_metal_command_buffer command_buffer = {0};
899
+
900
+ enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
901
+ if (status != gptoss_status_success) {
902
+ goto cleanup;
903
+ }
904
+
905
+ struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
906
+ control->abort = 0;
907
+
908
+ status = process_tokens(
909
+ context,
910
+ &command_buffer,
911
+ /*input_tokens_offset=*/context->num_kv_tokens,
912
+ /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
913
+ /*num_output_tokens=*/0);
914
+ if (status != gptoss_status_success) {
915
+ goto cleanup;
916
+ }
917
+
918
+ status = gptoss_metal_command_buffer_commit(&command_buffer);
919
+ if (status != gptoss_status_success) {
920
+ goto cleanup;
921
+ }
922
+
923
+ status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
924
+ if (status != gptoss_status_success) {
925
+ goto cleanup;
926
+ }
927
+
928
+ context->num_kv_tokens = context->num_tokens;
929
+
930
+ cleanup:
931
+ gptoss_metal_command_buffer_release(&command_buffer);
932
+ return status;
933
+ }
934
+
935
+ return gptoss_status_success;
936
+ }
937
+
938
+ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
939
+ gptoss_context_t context,
940
+ float temperature,
941
+ uint64_t seed,
942
+ size_t max_tokens,
943
+ uint32_t* tokens_out,
944
+ size_t* num_tokens_out)
945
+ {
946
+ enum gptoss_status status = gptoss_status_success;
947
+ const struct gptoss_model* model = context->model;
948
+ struct gptoss_metal_command_buffer command_buffer = {0};
949
+
950
+ *num_tokens_out = 0;
951
+
952
+ const uint32_t num_original_tokens = context->num_tokens;
953
+
954
+ status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
955
+ if (status != gptoss_status_success) {
956
+ goto cleanup;
957
+ }
958
+
959
+ struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
960
+ control->abort = 0;
961
+
962
+ for (size_t t = 0; t < max_tokens; t++) {
963
+ if (context->num_kv_tokens < context->num_tokens) {
964
+ status = process_tokens(
965
+ context,
966
+ &command_buffer,
967
+ /*input_tokens_offset=*/context->num_kv_tokens,
968
+ /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
969
+ /*num_output_tokens=*/1);
970
+ context->num_kv_tokens = context->num_tokens;
971
+ } else {
972
+ status = process_tokens(
973
+ context,
974
+ &command_buffer,
975
+ /*input_tokens_offset=*/context->num_tokens - 1,
976
+ /*num_input_tokens=*/1,
977
+ /*num_output_tokens=*/1);
978
+ }
979
+ if (status != gptoss_status_success) {
980
+ goto cleanup;
981
+ }
982
+
983
+ if (temperature != 0.0f) {
984
+ assert(context->num_processed_tokens != 0);
985
+ uint32_t num_threadgroups = 0;
986
+ uint32_t num_dims_per_threadgroup = 0;
987
+ status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
988
+ &command_buffer,
989
+ &model->f32_softmax_fn,
990
+ /*threadgroup_size=*/512,
991
+ model->max_threadgroups,
992
+ &context->score_buffer,
993
+ /*score_offset=*/0,
994
+ &context->argmax_buffer,
995
+ /*argmax_offset=*/0,
996
+ &context->prob_buffer,
997
+ /*prob_offset=*/0,
998
+ &context->sum_buffer,
999
+ /*sum_offset=*/0,
1000
+ &context->control_buffer,
1001
+ /*control_offset=*/0,
1002
+ model->vocabulary_size,
1003
+ /*num_tokens=*/1,
1004
+ temperature,
1005
+ &num_threadgroups,
1006
+ &num_dims_per_threadgroup);
1007
+ if (status != gptoss_status_success) {
1008
+ GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
1009
+ goto cleanup;
1010
+ }
1011
+
1012
+ status = gptoss_metal_command_buffer_encode_launch_f32_sample(
1013
+ &command_buffer,
1014
+ &model->f32_sample_fn,
1015
+ /*min_threadgroup_size=*/512,
1016
+ &context->prob_buffer,
1017
+ /*prob_offset=*/0,
1018
+ &context->sum_buffer,
1019
+ /*sum_offset=*/0,
1020
+ &context->token_buffer,
1021
+ /*token_offset=*/context->num_tokens * sizeof(uint32_t),
1022
+ &context->control_buffer,
1023
+ /*control_offset=*/0,
1024
+ /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
1025
+ /*rng_offset=*/context->num_tokens,
1026
+ /*num_blocks=*/num_threadgroups,
1027
+ /*num_channels=*/model->vocabulary_size,
1028
+ /*num_channels_per_block=*/num_dims_per_threadgroup);
1029
+ if (status != gptoss_status_success) {
1030
+ GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
1031
+ goto cleanup;
1032
+ }
1033
+ } else {
1034
+ status = gptoss_metal_command_buffer_encode_copy_buffer(
1035
+ &command_buffer,
1036
+ &context->argmax_buffer,
1037
+ /*input_offset=*/0,
1038
+ &context->token_buffer,
1039
+ /*output_offset=*/context->num_tokens * sizeof(uint32_t),
1040
+ /*size=*/sizeof(uint32_t));
1041
+ if (status != gptoss_status_success) {
1042
+ GPTOSS_LOG_ERROR("failed to encode copy buffer");
1043
+ goto cleanup;
1044
+ }
1045
+ }
1046
+ context->num_tokens += 1;
1047
+ context->num_kv_tokens = context->num_tokens;
1048
+ }
1049
+
1050
+ gptoss_metal_command_buffer_commit(&command_buffer);
1051
+ gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
1052
+
1053
+ const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr;
1054
+ const uint32_t num_generated_tokens = context->num_tokens - num_original_tokens;
1055
+ memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t));
1056
+ *num_tokens_out = num_generated_tokens;
1057
+
1058
+ cleanup:
1059
+ gptoss_metal_command_buffer_release(&command_buffer);
1060
+ return status;
1061
+ }
1062
+
1063
+ enum gptoss_status GPTOSS_ABI gptoss_context_reset(
1064
+ gptoss_context_t context)
1065
+ {
1066
+ context->num_tokens = 0;
1067
+
1068
+ // Note: context->num_kv_tokens is not reset and context->token_buffer is not cleared.
1069
+ // If the subsequently added tokens match the tokens already in the KV cache, we reuse the KV cache.
1070
+
1071
+ return gptoss_status_success;
1072
+ }
1073
+
1074
+ enum gptoss_status GPTOSS_ABI gptoss_context_retain(
1075
+ gptoss_context_t context)
1076
+ {
1077
+ atomic_fetch_add_explicit(&context->ref_count, 1, memory_order_relaxed);
1078
+ return gptoss_status_success;
1079
+ }
1080
+
1081
+ enum gptoss_status GPTOSS_ABI gptoss_context_release(
1082
+ gptoss_context_t context)
1083
+ {
1084
+ if (context != NULL) {
1085
+ if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) {
1086
+ // Activation buffers
1087
+ gptoss_metal_buffer_release(&context->residual_activation_buffer);
1088
+ gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer);
1089
+ gptoss_metal_buffer_release(&context->qkv_activation_buffer);
1090
+ gptoss_metal_buffer_release(&context->sdpa_activation_buffer);
1091
+ gptoss_metal_buffer_release(&context->gate_activation_buffer);
1092
+ gptoss_metal_buffer_release(&context->expert_activation_buffer);
1093
+ gptoss_metal_buffer_release(&context->swiglu_activation_buffer);
1094
+ gptoss_metal_buffer_release(&context->moe_activation_buffer);
1095
+ gptoss_metal_buffer_release(&context->expert_offset_buffer);
1096
+ gptoss_metal_buffer_release(&context->token_to_expert_routing_buffer);
1097
+ gptoss_metal_buffer_release(&context->swiglu_input_buffer);
1098
+
1099
+ // Input/output buffers
1100
+ gptoss_metal_buffer_release(&context->control_buffer);
1101
+ gptoss_metal_buffer_release(&context->token_buffer);
1102
+ gptoss_metal_buffer_release(&context->score_buffer);
1103
+ gptoss_metal_buffer_release(&context->prob_buffer);
1104
+ gptoss_metal_buffer_release(&context->sum_buffer);
1105
+ gptoss_metal_buffer_release(&context->argmax_buffer);
1106
+ gptoss_metal_buffer_release(&context->kvcache_buffer);
1107
+
1108
+ gptoss_model_release(context->model);
1109
+
1110
+ memset(context, 0, sizeof(struct gptoss_context));
1111
+ free(context);
1112
+ }
1113
+ }
1114
+ return gptoss_status_success;
1115
+ }
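The reset/append logic above keeps the KV cache warm across turns: gptoss_context_reset only rewinds num_tokens, and gptoss_context_append_tokens re-validates the retained prefix before invalidating anything. A minimal sketch of the intended usage, assuming the public API declared in gpt-oss.h (error handling elided):

#include <gpt-oss.h>

// Hypothetical two-turn loop: the shared prompt prefix is verified against the
// cached tokens, so only the new suffix is re-processed on the second turn.
void run_turn(gptoss_context_t context, const uint32_t* prefix, size_t prefix_len,
              const uint32_t* turn, size_t turn_len) {
    gptoss_context_reset(context);                              // rewinds num_tokens only
    gptoss_context_append_tokens(context, prefix_len, prefix);  // matches cache -> no invalidation
    gptoss_context_append_tokens(context, turn_len, turn);      // first mismatch truncates the cache
    gptoss_context_process(context);                            // prefills only the non-cached suffix
}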
gptoss_kernels/source/convert.metal ADDED
@@ -0,0 +1,64 @@
1
+ #include <metal_integer>
2
+
3
+ #include <internal/kernel-args.h>
4
+
5
+ #pragma METAL fp math_mode(safe)
6
+ #pragma METAL fp contract(off)
7
+
8
+
9
+ kernel void gptoss_mf4_f32_convert(
10
+ constant gptoss_convert_args& args [[ buffer(0) ]],
11
+ const device uint4* blocks [[ buffer(1) ]],
12
+ const device uchar* scales [[ buffer(2) ]],
13
+ device float4* output [[ buffer(3) ]],
14
+ uint gid [[threadgroup_position_in_grid]],
15
+ uint tid [[thread_position_in_threadgroup]],
16
+ uint threadgroup_size [[ threads_per_threadgroup ]])
17
+ {
18
+ const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
19
+ const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
20
+ const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
21
+ const ulong thread_start = threadgroup_start + tid;
22
+ uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
23
+
24
+ blocks += thread_start;
25
+ scales += thread_start;
26
+ output += 8 * thread_start;
27
+ for (; num_iter != 0; num_iter--) {
28
+ const uint4 block = *blocks;
29
+ const float scale = as_type<float>((static_cast<uint>(*scales) + 14) << 23);
30
+ uint4 block02468ACEGIKMOQSU = block + block;
31
+ uint4 block13579BDFHJLNPRTV = block >> 3;
32
+ block02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
33
+ block13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
34
+ block02468ACEGIKMOQSU += 0x70707070u;
35
+ block13579BDFHJLNPRTV += 0x70707070u;
36
+ block02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
37
+ block13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
38
+ const uint4 block26AEIMQU = block02468ACEGIKMOQSU & 0xFF00FF00u;
39
+ const uint4 block048CGKOS = (block02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
40
+ const uint4 block37BFJNRV = block13579BDFHJLNPRTV & 0xFF00FF00u;
41
+ const uint4 block159DHLPT = (block13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
42
+ const float4 block048C = static_cast<float4>(as_type<half4>(block048CGKOS.xy)) * scale;
43
+ const float4 blockGKOS = static_cast<float4>(as_type<half4>(block048CGKOS.zw)) * scale;
44
+ const float4 block26AE = static_cast<float4>(as_type<half4>(block26AEIMQU.xy)) * scale;
45
+ const float4 blockIMQU = static_cast<float4>(as_type<half4>(block26AEIMQU.zw)) * scale;
46
+ const float4 block159D = static_cast<float4>(as_type<half4>(block159DHLPT.xy)) * scale;
47
+ const float4 blockHLPT = static_cast<float4>(as_type<half4>(block159DHLPT.zw)) * scale;
48
+ const float4 block37BF = static_cast<float4>(as_type<half4>(block37BFJNRV.xy)) * scale;
49
+ const float4 blockJNRV = static_cast<float4>(as_type<half4>(block37BFJNRV.zw)) * scale;
50
+
51
+ output[0] = (float4) { block048C.x, block159D.x, block26AE.x, block37BF.x };
52
+ output[1] = (float4) { block048C.y, block159D.y, block26AE.y, block37BF.y };
53
+ output[2] = (float4) { block048C.z, block159D.z, block26AE.z, block37BF.z };
54
+ output[3] = (float4) { block048C.w, block159D.w, block26AE.w, block37BF.w };
55
+ output[4] = (float4) { blockGKOS.x, blockHLPT.x, blockIMQU.x, blockJNRV.x };
56
+ output[5] = (float4) { blockGKOS.y, blockHLPT.y, blockIMQU.y, blockJNRV.y };
57
+ output[6] = (float4) { blockGKOS.z, blockHLPT.z, blockIMQU.z, blockJNRV.z };
58
+ output[7] = (float4) { blockGKOS.w, blockHLPT.w, blockIMQU.w, blockJNRV.w };
59
+
60
+ blocks += threadgroup_size;
61
+ scales += threadgroup_size;
62
+ output += 8 * threadgroup_size;
63
+ }
64
+ }
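The vectorized shuffle above is easier to follow against a scalar reference. A sketch of what each 4-bit lane decodes to, assuming the standard FP4 E2M1 value set and the UE8M0 scale handling noted in the kernel (mf4_decode is a hypothetical helper, not part of the library):

#include <math.h>
#include <stdint.h>

// Hypothetical scalar reference: decode one FP4 E2M1 nibble with a UE8M0 scale.
static float mf4_decode(uint8_t nibble, uint8_t scale) {
    static const float lut[16] = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
    };
    // as_type<float>((scale + 14) << 23) in the kernel equals 2^(scale - 113).
    return lut[nibble & 0xF] * ldexpf(1.0f, (int) scale - 113);
}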
gptoss_kernels/source/embeddings.metal ADDED
@@ -0,0 +1,29 @@
1
+ #include <internal/kernel-args.h>
2
+
3
+ #pragma METAL fp math_mode(safe)
4
+ #pragma METAL fp contract(off)
5
+
6
+
7
+ kernel void gptoss_bf16_f32_embeddings(
8
+ constant gptoss_embeddings_args& args [[ buffer(0) ]],
9
+ const device uint* tokens [[ buffer(1) ]],
10
+ const device bfloat4* weights [[ buffer(2) ]],
11
+ device float4* output [[ buffer(3) ]],
12
+ const device gptoss_control* control [[ buffer(4) ]],
13
+ uint gid [[threadgroup_position_in_grid]],
14
+ uint tid [[thread_position_in_threadgroup]],
15
+ uint threadgroup_size [[ threads_per_threadgroup ]])
16
+ {
17
+ if (control->abort != 0) {
18
+ return;
19
+ }
20
+
21
+ const uint t = tokens[gid];
22
+
23
+ weights += t * args.num_vecs;
24
+ output += gid * args.num_vecs;
25
+ for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
26
+ const bfloat4 w = weights[i];
27
+ output[i] = static_cast<float4>(w);
28
+ }
29
+ }
gptoss_kernels/source/expert_routing_metadata.metal ADDED
@@ -0,0 +1,41 @@
1
+ #include <internal/kernel-args.h>
2
+ #include <metal_integer>
3
+ #include <metal_math>
4
+ #include <metal_stdlib>
5
+
6
+ constant uint kMaxExperts = 128;
7
+
8
+ kernel void gptoss_f32_expert_routing_metadata(
9
+ constant gptoss_expert_routing_metadata_args& args [[ buffer(0) ]],
10
+ const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(1) ]],
11
+ device uint* __restrict__ expert_offsets [[ buffer(2) ]],
12
+ device uint* __restrict__ intra_expert_offsets [[ buffer(3) ]],
13
+ uint tg_size [[threads_per_threadgroup]],
14
+ uint tid [[thread_position_in_threadgroup]])
15
+ {
16
+ assert(args.num_experts <= kMaxExperts);
17
+ // Create threadgroup mem and initialize it to 0.
18
+ threadgroup metal::atomic_uint tg_counts[kMaxExperts];
19
+ for (uint e = tid; e < args.num_experts; e += tg_size) {
20
+ metal::atomic_store_explicit(&tg_counts[e], 0u, metal::memory_order_relaxed);
21
+ }
22
+
23
+ threadgroup_barrier(metal::mem_flags::mem_threadgroup);
24
+
25
+ for (uint i = tid; i < args.tokens; i += tg_size) {
26
+ const uint e = expert_predictions[i].expert_id;
27
+ const uint r = metal::atomic_fetch_add_explicit(&tg_counts[e], 1u, metal::memory_order_relaxed);
28
+ intra_expert_offsets[i] = r;
29
+ }
30
+ threadgroup_barrier(metal::mem_flags::mem_threadgroup);
31
+
32
+ if (tid == 0) {
33
+ uint total = 0;
34
+ for (uint e = 0; e < args.num_experts; ++e) {
35
+ const uint bin = metal::atomic_load_explicit(&tg_counts[e], metal::memory_order_relaxed);
36
+ expert_offsets[e] = total;
37
+ total += bin;
38
+ }
39
+ expert_offsets[args.num_experts] = total;
40
+ }
41
+ }
gptoss_kernels/source/gather_and_accumulate.metal ADDED
@@ -0,0 +1,74 @@
1
+ #include <internal/kernel-args.h>
2
+ #include <metal_integer>
3
+ #include <metal_math>
4
+ #include <metal_stdlib>
5
+
6
+ // TODO(ibrahim): This is not optimal as each thread only gathers and accumulates a single float4. To amortize the
7
+ // cost of reading the expert, offset and scales for a token, we should let each thread gather and accumulate several
8
+ // float4s.
9
+ kernel void gptoss_f32_gather_and_accumulate_e4(
10
+ constant gptoss_gather_args& args [[ buffer(0) ]],
11
+ const device float* in [[ buffer(1) ]],
12
+ const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(2) ]],
13
+ const device uint* expert_offsets [[ buffer(3) ]],
14
+ const device uint* intra_expert_offsets [[ buffer(4) ]],
15
+ device float* out [[ buffer(5) ]],
16
+ uint3 gid [[thread_position_in_grid]])
17
+ {
18
+ const uint T = args.tokens;
19
+ const uint k = args.active_experts_per_token;
20
+ const uint D = args.token_stride;
21
+
22
+ assert((D & 3u) == 0);
23
+ assert(k == 4);
24
+
25
+ const uint row = gid.y;
26
+ if (row >= T) {
27
+ return;
28
+ }
29
+
30
+ const uint col_vec4 = gid.x;
31
+ const uint col = col_vec4 * 4u;
32
+ if (col >= D) {
33
+ return;
34
+ }
35
+
36
+ device float4* dst4 = reinterpret_cast<device float4*>(out + row * D + col);
37
+
38
+ const uint base = row * k;
39
+ const gptoss_expert_prediction expert0 = expert_predictions[base];
40
+ const gptoss_expert_prediction expert1 = expert_predictions[base + 1];
41
+ const gptoss_expert_prediction expert2 = expert_predictions[base + 2];
42
+ const gptoss_expert_prediction expert3 = expert_predictions[base + 3];
43
+ const uint expert0_id = expert0.expert_id;
44
+ const uint expert1_id = expert1.expert_id;
45
+ const uint expert2_id = expert2.expert_id;
46
+ const uint expert3_id = expert3.expert_id;
47
+ const float scale0 = expert0.score;
48
+ const float scale1 = expert1.score;
49
+ const float scale2 = expert2.score;
50
+ const float scale3 = expert3.score;
51
+ const uint4 current_intra_expert_offsets =
52
+ *reinterpret_cast<const device uint4*>(&intra_expert_offsets[base]);
53
+ // Get the row indices for the current expert ids
54
+ const uint r0 = expert_offsets[expert0_id] + current_intra_expert_offsets.x;
55
+ const uint r1 = expert_offsets[expert1_id] + current_intra_expert_offsets.y;
56
+ const uint r2 = expert_offsets[expert2_id] + current_intra_expert_offsets.z;
57
+ const uint r3 = expert_offsets[expert3_id] + current_intra_expert_offsets.w;
58
+
59
+ const device float4* src0 =
60
+ reinterpret_cast<const device float4*>(in + r0 * D + col);
61
+ const device float4* src1 =
62
+ reinterpret_cast<const device float4*>(in + r1 * D + col);
63
+ const device float4* src2 =
64
+ reinterpret_cast<const device float4*>(in + r2 * D + col);
65
+ const device float4* src3 =
66
+ reinterpret_cast<const device float4*>(in + r3 * D + col);
67
+
68
+ float4 acc = *dst4;
69
+ acc = metal::fma(*src0, scale0, acc);
70
+ acc = metal::fma(*src1, scale1, acc);
71
+ acc = metal::fma(*src2, scale2, acc);
72
+ acc = metal::fma(*src3, scale3, acc);
73
+ *dst4 = acc;
74
+ }
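In scalar terms, the kernel above computes, for every token row t and channel c, a weighted sum over the token's k routed experts. A reference sketch under the same layout assumptions (row-major in/out with stride D; scalar_gather_accumulate is hypothetical):

#include <stddef.h>

// Hypothetical scalar reference for the gather-and-accumulate step:
// out[t] += sum over the k experts routed to token t of score * expert_output_row.
void scalar_gather_accumulate(
    float* out, const float* in, size_t T, size_t D, size_t k,
    const unsigned* expert_ids, const float* scores,  // [T * k]
    const unsigned* expert_offsets,                   // [num_experts + 1]
    const unsigned* intra_expert_offsets)             // [T * k]
{
    for (size_t t = 0; t < T; t++) {
        for (size_t e = 0; e < k; e++) {
            const size_t idx = t * k + e;
            const size_t row = expert_offsets[expert_ids[idx]] + intra_expert_offsets[idx];
            for (size_t c = 0; c < D; c++) {
                out[t * D + c] += scores[idx] * in[row * D + c];
            }
        }
    }
}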
gptoss_kernels/source/generate.c ADDED
@@ -0,0 +1,317 @@
1
+ #include <assert.h>
2
+ #include <inttypes.h>
3
+ #include <math.h>
4
+ #include <signal.h>
5
+ #include <stdatomic.h>
6
+ #include <stdbool.h>
7
+ #include <stdio.h>
8
+ #include <stdint.h>
9
+ #include <stdlib.h>
10
+ #include <string.h>
11
+
12
+ #include <mach/mach_time.h>
13
+
14
+ #include <gpt-oss.h>
15
+
16
+ #include "internal/model.h"
17
+
18
+ struct {
19
+ atomic_uint_least64_t inference_bytes;
20
+ atomic_size_t num_prefill_tokens;
21
+ atomic_uint_least64_t prefill_microseconds;
22
+ atomic_size_t num_generated_tokens;
23
+ atomic_uint_least64_t generation_microseconds;
24
+ } globals = {
25
+ .inference_bytes = 0,
26
+ .num_prefill_tokens = 0,
27
+ .prefill_microseconds = 0,
28
+ .num_generated_tokens = 0,
29
+ .generation_microseconds = 0,
30
+ };
31
+
32
+ struct options {
33
+ const char* model;
34
+ const char* prompt;
35
+ size_t context_length;
36
+ size_t max_tokens;
37
+ float temperature;
38
+ bool verbose;
39
+ };
40
+
41
+ static inline double mach_timestamp_diff_to_seconds(uint64_t start_timestamp, uint64_t end_timestamp) {
42
+ static mach_timebase_info_data_t timebase_info = {0};
43
+ if (timebase_info.denom == 0) {
44
+ mach_timebase_info(&timebase_info);
45
+ }
46
+ const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
47
+ return ((double) elapsed_mach_time * (double) timebase_info.numer) / ((double) timebase_info.denom * 1.0e+9);
48
+ }
49
+
50
+ static inline uint64_t mach_timestamp_diff_to_microseconds(uint64_t start_timestamp, uint64_t end_timestamp) {
51
+ static mach_timebase_info_data_t timebase_info = {0};
52
+ if (timebase_info.denom == 0) {
53
+ mach_timebase_info(&timebase_info);
54
+ }
55
+ const uint64_t elapsed_mach_time = end_timestamp - start_timestamp;
56
+ const uint64_t denominator = timebase_info.denom * UINT64_C(1000);
57
+ return (elapsed_mach_time * timebase_info.numer + denominator / 2) / denominator;
58
+ }
59
+
60
+ static void print_usage(const char* program_name) {
61
+ printf("Usage: %s <model-path> [-p <prompt>] [-n <tokens>]\n", program_name);
62
+ }
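+ // Example invocation (hypothetical model path):
+ //   ./generate gpt-oss-20b.bin -p "Once upon a time" -n 128 -t 0.7 -v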
63
+
64
+ struct options parse_options(int argc, char** argv) {
65
+ struct options options = (struct options) {
66
+ .model = NULL,
67
+ .prompt = NULL,
68
+ .context_length = 0,
69
+ .max_tokens = 0,
70
+ .temperature = 0.0f,
71
+ .verbose = false,
72
+ };
73
+ if (argc < 2) {
74
+ fprintf(stderr, "Error: missing required command-line argument\n");
75
+ print_usage(argv[0]);
76
+ exit(EXIT_FAILURE);
77
+ }
78
+ for (int i = 1; i < argc; i++) {
79
+ if (strcmp(argv[i], "--help") == 0) {
80
+ print_usage(argv[0]);
81
+ exit(EXIT_SUCCESS);
82
+ } else if (strcmp(argv[i], "-p") == 0 || strcmp(argv[i], "--prompt") == 0) {
83
+ if (i + 1 >= argc) {
84
+ fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
85
+ print_usage(argv[0]);
86
+ exit(EXIT_FAILURE);
87
+ }
88
+ options.prompt = argv[++i];
89
+ } else if (strcmp(argv[i], "--context-length") == 0) {
90
+ if (i + 1 >= argc) {
91
+ fprintf(stderr, "Error: missing argument for --context-length\n");
92
+ print_usage(argv[0]);
93
+ exit(EXIT_FAILURE);
94
+ }
95
+ char* context_length_start = argv[++i];
96
+ char* context_length_end = context_length_start;
97
+ options.context_length = strtoul(context_length_start, &context_length_end, 10);
98
+ if (context_length_end == context_length_start || *context_length_end != 0) {
99
+ fprintf(stderr, "Error: failed to parse context length value \"%s\"\n", context_length_start);
100
+ exit(EXIT_FAILURE);
101
+ }
102
+ } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--max-tokens") == 0) {
103
+ if (i + 1 >= argc) {
104
+ fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
105
+ print_usage(argv[0]);
106
+ exit(EXIT_FAILURE);
107
+ }
108
+ char* max_tokens_start = argv[++i];
109
+ char* max_tokens_end = max_tokens_start;
110
+ options.max_tokens = strtoul(max_tokens_start, &max_tokens_end, 10);
111
+ if (max_tokens_end == max_tokens_start || *max_tokens_end != 0) {
112
+ fprintf(stderr, "Error: failed to parse max tokens value \"%s\"\n", max_tokens_start);
113
+ exit(EXIT_FAILURE);
114
+ }
115
+ if (options.max_tokens == 0) {
116
+ fprintf(stderr, "Error: invalid max tokens value %zu\n", options.max_tokens);
117
+ exit(EXIT_FAILURE);
118
+ }
119
+ } else if (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--temperature") == 0) {
120
+ if (i + 1 >= argc) {
121
+ fprintf(stderr, "Error: missing argument for %s\n", argv[i]);
122
+ print_usage(argv[0]);
123
+ exit(EXIT_FAILURE);
124
+ }
125
+ char* temperature_start = argv[++i];
126
+ char* temperature_end = temperature_start;
127
+ options.temperature = strtof(temperature_start, &temperature_end);
128
+ if (temperature_end == temperature_start || *temperature_end != 0) {
129
+ fprintf(stderr, "Error: failed to parse temperature value \"%s\"\n", temperature_start);
130
+ exit(EXIT_FAILURE);
131
+ }
132
+ if (signbit(options.temperature) != 0 || !(options.temperature <= 2.0f)) {
133
+ fprintf(stderr, "Error: invalid temperature value %f\n", options.temperature);
134
+ exit(EXIT_FAILURE);
135
+ }
136
+ } else if (strcmp(argv[i], "-v") == 0 || strcmp(argv[i], "--verbose") == 0) {
137
+ options.verbose = true;
138
+ } else {
139
+ if (options.model == NULL) {
140
+ options.model = argv[i];
141
+ } else {
142
+ fprintf(stderr, "Error: unexpected command-line argument %s\n", argv[i]);
143
+ print_usage(argv[0]);
144
+ exit(EXIT_FAILURE);
145
+ }
146
+ }
147
+ }
148
+ if (options.model == NULL) {
149
+ fprintf(stderr, "Error: missing required model argument\n");
150
+ print_usage(argv[0]);
151
+ exit(EXIT_FAILURE);
152
+ }
153
+ if (options.prompt == NULL) {
154
+ fprintf(stderr, "Error: missing required prompt argument\n");
155
+ print_usage(argv[0]);
156
+ exit(EXIT_FAILURE);
157
+ }
158
+ return options;
159
+ }
160
+
161
+
162
+ static void print_profile() {
163
+ const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens);
164
+ const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds);
165
+ const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens);
166
+ const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds);
167
+ const uint64_t inference_bytes = atomic_load(&globals.inference_bytes);
168
+ if (num_prefill_tokens != 0 || num_generated_tokens != 0) {
169
+ printf("\n");
170
+ }
171
+ if (num_prefill_tokens != 0) {
172
+ printf("Prefill speed (%zu tokens): %.1f tokens/second\n",
173
+ num_prefill_tokens,
174
+ (double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6);
175
+ }
176
+ if (num_generated_tokens != 0) {
177
+ printf("Generation speed (%zu tokens): %.1f tokens/second\n",
178
+ num_generated_tokens,
179
+ (double) num_generated_tokens / (double) generation_microseconds * 1.0e+6);
180
+ }
181
+ }
182
+
183
+ static void ctrl_c_handler(int signum) {
184
+ print_profile();
185
+ exit(EXIT_SUCCESS);
186
+ }
187
+
188
+ int main(int argc, char *argv[]) {
189
+ enum gptoss_status status;
190
+ gptoss_model_t model = NULL;
191
+ gptoss_tokenizer_t tokenizer = NULL;
192
+ gptoss_context_t context = NULL;
193
+
194
+ struct sigaction act = {0};
195
+ act.sa_handler = ctrl_c_handler;
196
+ sigaction(SIGINT, &act, NULL);
197
+
198
+ setvbuf(stdout, NULL, _IONBF, 0);
199
+
200
+ struct options options = parse_options(argc, argv);
201
+
202
+ const uint64_t load_start_time = mach_continuous_time();
203
+ status = gptoss_model_create_from_file(options.model, &model);
204
+ if (status != gptoss_status_success) {
205
+ fprintf(stderr, "Error: failed to load model from file %s\n", options.model);
206
+ goto error;
207
+ }
208
+ size_t max_model_context_length = 0;
209
+ status = gptoss_model_get_max_context_length(model, &max_model_context_length);
210
+ if (status != gptoss_status_success) {
211
+ fprintf(stderr, "Error: failed to query maximum context length\n");
212
+ goto error;
213
+ }
214
+ assert(max_model_context_length != 0);
215
+ if (options.context_length == 0) {
216
+ options.context_length = max_model_context_length;
217
+ } else if (options.context_length > max_model_context_length) {
218
+ fprintf(stderr, "Error: context length %zu exceeds maximum context length %zu supported by the model\n", options.context_length, max_model_context_length);
219
+ goto error;
220
+ }
221
+
222
+ status = gptoss_model_get_tokenizer(model, &tokenizer);
223
+ if (status != gptoss_status_success) {
224
+ fprintf(stderr, "Error: failed to retrieve Tokenizer\n");
225
+ goto error;
226
+ }
227
+
228
+ uint32_t return_token_id = UINT32_MAX;
229
+ status = gptoss_tokenizer_get_special_token_id(tokenizer, gptoss_special_token_return, &return_token_id);
230
+ if (status != gptoss_status_success) {
231
+ fprintf(stderr, "Error: failed to query return token ID\n");
232
+ goto error;
233
+ }
234
+
235
+ status = gptoss_context_create(model, options.context_length, /*max_batch_tokens=*/0, &context);
236
+ if (status != gptoss_status_success) {
237
+ fprintf(stderr, "Error: failed to create Context object\n");
238
+ goto error;
239
+ }
240
+ if (options.verbose) {
241
+ printf("Model weights size: %.2lf MB\n", (double) model->weights_size * 0x1.0p-20);
242
+ printf("Model allocation size: %.2lf MB\n", (double) model->allocation_size * 0x1.0p-20);
243
+ printf("Context allocation size: %.2lf MB\n", (double) context->allocation_size * 0x1.0p-20);
244
+ printf(" Including KV cache: %.2lf MB\n", (double) context->kvcache_size * 0x1.0p-20);
245
+ }
246
+
247
+ const uint64_t load_end_time = mach_continuous_time();
248
+ const double load_elapsed_seconds = mach_timestamp_diff_to_seconds(load_start_time, load_end_time);
249
+ if (options.verbose) {
250
+ printf("Loaded model in %.3f seconds\n", load_elapsed_seconds);
251
+ }
252
+
253
+ const uint64_t prefill_start_time = mach_continuous_time();
254
+ size_t num_prefill_tokens = 0;
255
+ status = gptoss_context_append_chars(context, options.prompt, strlen(options.prompt), &num_prefill_tokens);
256
+ if (status != gptoss_status_success) {
257
+ fprintf(stderr, "Error: failed to tokenize prompt \"%s\"\n", options.prompt);
258
+ goto error;
259
+ }
260
+ atomic_store(&globals.num_prefill_tokens, num_prefill_tokens);
261
+ status = gptoss_context_process(context);
262
+ if (status != gptoss_status_success) {
263
+ fprintf(stderr, "Error: failed to process Context object\n");
264
+ goto error;
265
+ }
266
+ const uint64_t prefill_end_time = mach_continuous_time();
267
+
268
+ while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {
269
+
270
+ uint32_t predicted_token = UINT32_MAX;
271
+ size_t num_predicted_tokens = 0;
272
+ const uint64_t inference_start_timestamp = mach_continuous_time();
273
+ status = gptoss_context_sample(context, options.temperature, /*seed=*/0, /*max_tokens=*/1, &predicted_token, &num_predicted_tokens);
274
+ if (status != gptoss_status_success) {
275
+ fprintf(stderr, "Error: failed to sample from the Context object\n");
276
+ goto error;
277
+ }
278
+ const uint64_t inference_end_timestamp = mach_continuous_time();
279
+
280
+ if (predicted_token == return_token_id) {
281
+ // Yield token -> stop generation
282
+ break;
283
+ }
284
+
285
+ // Unembedding: detokenize
286
+ size_t token_size = 0;
287
+ const void* token_ptr = NULL;
288
+ status = gptoss_tokenizer_decode(tokenizer, predicted_token, &token_ptr, &token_size);
289
+ if (status != gptoss_status_success) {
290
+ fprintf(stderr, "Error: failed to detokenize predicted token %" PRIu32 "\n", predicted_token);
291
+ goto error;
292
+ }
293
+ const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1);
294
+ if (previous_num_generated_tokens == 0) {
295
+ atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time));
296
+ } else {
297
+ atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp));
298
+ }
299
+ printf("%.*s", (int) token_size, (const char*) token_ptr);
300
+
301
+ status = gptoss_context_append_tokens(context, 1, &predicted_token);
302
+ if (status != gptoss_status_success) {
303
+ fprintf(stderr, "Error: failed to append predicted token %" PRIu32 " to context\n", predicted_token);
304
+ goto error;
305
+ }
306
+ }
307
+
308
+ print_profile();
309
+
310
+ return EXIT_SUCCESS;
311
+
312
+ error:
313
+ gptoss_context_release(context);
314
+ gptoss_tokenizer_release(tokenizer);
315
+ gptoss_model_release(model);
316
+ return EXIT_FAILURE;
317
+ }
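Note: the counters above charge the first sampled token to prefill, so a profile printout in the spirit of print_profile (defined earlier in this file, not shown in this hunk) would reduce them to throughput roughly as in the following sketch; the variable names are illustrative, not the actual implementation:

// Hypothetical sketch of the throughput arithmetic implied by the counters above.
// generation_microseconds covers num_generated_tokens - 1 sampling steps, since
// the first generated token is accounted against prefill.
if (prefill_microseconds != 0) {
    printf("Prefill: %.1f tokens/second\n",
        (double) num_prefill_tokens / ((double) prefill_microseconds * 1.0e-6));
}
if (generation_microseconds != 0 && num_generated_tokens > 1) {
    printf("Generation: %.1f tokens/second\n",
        (double) (num_generated_tokens - 1) / ((double) generation_microseconds * 1.0e-6));
}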
gptoss_kernels/source/include/internal/datatype.h ADDED
@@ -0,0 +1,41 @@
+ #pragma once
+
+ #include <stdint.h>
+
+ #include <internal/macros.h>
+
+
+ typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
+     GPTOSS_ALIGN(2) uint16_t bits;
+ } gptoss_bfloat16;
+ static_assert(sizeof(gptoss_bfloat16) == 2, "bfloat16 size is not 2 bytes");
+
+
+ typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
+     GPTOSS_ALIGN(2) uint16_t bits;
+ } gptoss_float16;
+ static_assert(sizeof(gptoss_float16) == 2, "float16 size is not 2 bytes");
+
+
+ typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
+     GPTOSS_ALIGN(1) uint8_t bits;
+ } gptoss_float8ue8m0;
+ static_assert(sizeof(gptoss_float8ue8m0) == 1, "gptoss_float8ue8m0 size is not 1 byte");
+
+
+ typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
+     GPTOSS_ALIGN(1) uint8_t bits;
+ } gptoss_float8e5m2;
+ static_assert(sizeof(gptoss_float8e5m2) == 1, "float8e5m2 size is not 1 byte");
+
+
+ typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
+     GPTOSS_ALIGN(1) uint8_t bits;
+ } gptoss_float8e4m3;
+ static_assert(sizeof(gptoss_float8e4m3) == 1, "gptoss_float8e4m3 size is not 1 byte");
+
+
+ typedef struct GPTOSS_DENSELY_PACKED_STRUCTURE {
+     GPTOSS_ALIGN(1) uint8_t bits;
+ } gptoss_float4e2m1x2;
+ static_assert(sizeof(gptoss_float4e2m1x2) == 1, "gptoss_float4e2m1x2 size is not 1 byte");
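Side note: gptoss_bfloat16 is simply the upper half of an IEEE binary32, which is what makes the widening conversions in datatype.hpp below a single shift. A minimal truncating sketch of the opposite direction (assuming C++20 std::bit_cast; a round-to-nearest-even version would add a rounding term):

#include <bit>
#include <cstdint>

// Truncating float32 -> bfloat16: keep the 16 high bits, drop the low mantissa bits.
inline gptoss_bfloat16 bf16_from_f32_truncate(float value) {
    const uint32_t bits = std::bit_cast<uint32_t>(value);
    return gptoss_bfloat16{.bits = static_cast<uint16_t>(bits >> 16)};
}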
gptoss_kernels/source/include/internal/datatype.hpp ADDED
@@ -0,0 +1,87 @@
+ #pragma once
+
+ #include <bit>
+
+ #include <internal/datatype.h>
+
+
+ namespace gptoss {
+
+ template <typename WideT, typename NarrowT>
+ WideT upcast(NarrowT);
+
+ template <>
+ inline float upcast<float>(gptoss_bfloat16 bf16_value) {
+     const uint32_t bits = static_cast<uint32_t>(bf16_value.bits) << 16;
+     return std::bit_cast<float>(bits);
+ }
+
+ template <>
+ inline float upcast<float>(gptoss_float16 fp16_value) {
+     return static_cast<float>(std::bit_cast<_Float16>(fp16_value.bits));
+ }
+
+ template <>
+ inline float upcast<float>(gptoss_float8e4m3 fp8_value) {
+     static constexpr uint16_t fp8e4m3_to_fp32[256] = {  // bfloat16 bit patterns, indexed by the fp8 e4m3 byte
+         0x0000, 0x3B00, 0x3B80, 0x3BC0, 0x3C00, 0x3C20, 0x3C40, 0x3C60,
+         0x3C80, 0x3C90, 0x3CA0, 0x3CB0, 0x3CC0, 0x3CD0, 0x3CE0, 0x3CF0,
+         0x3D00, 0x3D10, 0x3D20, 0x3D30, 0x3D40, 0x3D50, 0x3D60, 0x3D70,
+         0x3D80, 0x3D90, 0x3DA0, 0x3DB0, 0x3DC0, 0x3DD0, 0x3DE0, 0x3DF0,
+         0x3E00, 0x3E10, 0x3E20, 0x3E30, 0x3E40, 0x3E50, 0x3E60, 0x3E70,
+         0x3E80, 0x3E90, 0x3EA0, 0x3EB0, 0x3EC0, 0x3ED0, 0x3EE0, 0x3EF0,
+         0x3F00, 0x3F10, 0x3F20, 0x3F30, 0x3F40, 0x3F50, 0x3F60, 0x3F70,
+         0x3F80, 0x3F90, 0x3FA0, 0x3FB0, 0x3FC0, 0x3FD0, 0x3FE0, 0x3FF0,
+         0x4000, 0x4010, 0x4020, 0x4030, 0x4040, 0x4050, 0x4060, 0x4070,
+         0x4080, 0x4090, 0x40A0, 0x40B0, 0x40C0, 0x40D0, 0x40E0, 0x40F0,
+         0x4100, 0x4110, 0x4120, 0x4130, 0x4140, 0x4150, 0x4160, 0x4170,
+         0x4180, 0x4190, 0x41A0, 0x41B0, 0x41C0, 0x41D0, 0x41E0, 0x41F0,
+         0x4200, 0x4210, 0x4220, 0x4230, 0x4240, 0x4250, 0x4260, 0x4270,
+         0x4280, 0x4290, 0x42A0, 0x42B0, 0x42C0, 0x42D0, 0x42E0, 0x42F0,
+         0x4300, 0x4310, 0x4320, 0x4330, 0x4340, 0x4350, 0x4360, 0x4370,
+         0x4380, 0x4390, 0x43A0, 0x43B0, 0x43C0, 0x43D0, 0x43E0, 0x7FF0,
+         0x8000, 0xBB00, 0xBB80, 0xBBC0, 0xBC00, 0xBC20, 0xBC40, 0xBC60,
+         0xBC80, 0xBC90, 0xBCA0, 0xBCB0, 0xBCC0, 0xBCD0, 0xBCE0, 0xBCF0,
+         0xBD00, 0xBD10, 0xBD20, 0xBD30, 0xBD40, 0xBD50, 0xBD60, 0xBD70,
+         0xBD80, 0xBD90, 0xBDA0, 0xBDB0, 0xBDC0, 0xBDD0, 0xBDE0, 0xBDF0,
+         0xBE00, 0xBE10, 0xBE20, 0xBE30, 0xBE40, 0xBE50, 0xBE60, 0xBE70,
+         0xBE80, 0xBE90, 0xBEA0, 0xBEB0, 0xBEC0, 0xBED0, 0xBEE0, 0xBEF0,
+         0xBF00, 0xBF10, 0xBF20, 0xBF30, 0xBF40, 0xBF50, 0xBF60, 0xBF70,
+         0xBF80, 0xBF90, 0xBFA0, 0xBFB0, 0xBFC0, 0xBFD0, 0xBFE0, 0xBFF0,
+         0xC000, 0xC010, 0xC020, 0xC030, 0xC040, 0xC050, 0xC060, 0xC070,
+         0xC080, 0xC090, 0xC0A0, 0xC0B0, 0xC0C0, 0xC0D0, 0xC0E0, 0xC0F0,
+         0xC100, 0xC110, 0xC120, 0xC130, 0xC140, 0xC150, 0xC160, 0xC170,
+         0xC180, 0xC190, 0xC1A0, 0xC1B0, 0xC1C0, 0xC1D0, 0xC1E0, 0xC1F0,
+         0xC200, 0xC210, 0xC220, 0xC230, 0xC240, 0xC250, 0xC260, 0xC270,
+         0xC280, 0xC290, 0xC2A0, 0xC2B0, 0xC2C0, 0xC2D0, 0xC2E0, 0xC2F0,
+         0xC300, 0xC310, 0xC320, 0xC330, 0xC340, 0xC350, 0xC360, 0xC370,
+         0xC380, 0xC390, 0xC3A0, 0xC3B0, 0xC3C0, 0xC3D0, 0xC3E0, 0xFFF0,
+     };
+     const gptoss_bfloat16 bf16_value{.bits = fp8e4m3_to_fp32[fp8_value.bits]};
+     return upcast<float>(bf16_value);
+ }
+
+ template <>
+ inline double upcast<double>(float fp32_value) {
+     return static_cast<double>(fp32_value);
+ }
+
+ template <>
+ inline double upcast<double>(gptoss_bfloat16 bf16_value) {
+     const float fp32_value = upcast<float>(bf16_value);
+     return upcast<double>(fp32_value);
+ }
+
+ template <>
+ inline double upcast<double>(gptoss_float16 fp16_value) {
+     const float fp32_value = upcast<float>(fp16_value);
+     return upcast<double>(fp32_value);
+ }
+
+ template <>
+ inline double upcast<double>(gptoss_float8e4m3 fp8_value) {
+     const float fp32_value = upcast<float>(fp8_value);
+     return upcast<double>(fp32_value);
+ }
+
+ } // namespace gptoss
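Host-side usage of these helpers looks like the following sketch (the two literal bit patterns both decode to 1.0f; designated initializers assume C++20):

// Decode an fp8 e4m3 byte and a bfloat16 value to wider types for host-side checks.
const gptoss_float8e4m3 q{.bits = 0x38};   // e4m3 bit pattern for 1.0f
const gptoss_bfloat16 h{.bits = 0x3F80};   // bfloat16 bit pattern for 1.0f
const float a = gptoss::upcast<float>(q);
const double b = gptoss::upcast<double>(h);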
gptoss_kernels/source/include/internal/kernel-args.h ADDED
@@ -0,0 +1,201 @@
+ #pragma once
+
+ #if !defined(__METAL_VERSION__)
+ #include <stdint.h>
+ #endif
+
+ // TODO(ibahmed): specialize using metal function constants.
+ #define QKV_Bm 64
+ #define QKV_Bn 64
+ #define QKV_Bk 32
+ #define QKV_Sg_Bm 32
+ #define QKV_Sg_Bn 32
+
+ #define ATTN_OUTPUT_Bm 32
+ #define ATTN_OUTPUT_Bn 64
+ #define ATTN_OUTPUT_Bk 64
+ #define ATTN_OUTPUT_Sg_Bm 32
+ #define ATTN_OUTPUT_Sg_Bn 16
+
+ #define MLP_GATE_Bm 64
+ #define MLP_GATE_Bn 16
+ #define MLP_GATE_Bk 64
+ #define MLP_GATE_Sg_Bm 16
+ #define MLP_GATE_Sg_Bn 16
+
+ #define MOE_DENSE_MATMUL_SWIGLU_Bm 32
+ #define MOE_DENSE_MATMUL_SWIGLU_Bn 64
+ #define MOE_DENSE_MATMUL_SWIGLU_Bk 16
+ #define MOE_DENSE_MATMUL_SWIGLU_Sg_Bm 32
+ #define MOE_DENSE_MATMUL_SWIGLU_Sg_Bn 16
+
+ #define MOE_DENSE_MATMUL_Bm 32
+ #define MOE_DENSE_MATMUL_Bn 64
+ #define MOE_DENSE_MATMUL_Bk 16
+ #define MOE_DENSE_MATMUL_Sg_Bm 32
+ #define MOE_DENSE_MATMUL_Sg_Bn 16
+
+ struct gptoss_expert_prediction {
+     uint32_t expert_id;
+     float score;
+ };
+
+ struct gptoss_control {
+     uint32_t abort;
+ };
+
+ struct gptoss_topk_args {
+     uint32_t num_vecs_per_token;
+ };
+
+ struct gptoss_sdpa_args {
+     uint32_t qkv_dim;
+     uint32_t num_kv_tokens;
+     uint32_t kv_stride;
+     uint32_t window;
+ };
+
+ struct gptoss_u32_fill_random_args {
+     uint64_t num_vecs_per_threadgroup;
+     uint64_t num_vecs;
+     uint64_t offset;
+     uint64_t seed;
+ };
+
+ struct gptoss_f32_fill_random_args {
+     uint64_t num_vecs_per_threadgroup;
+     uint64_t num_vecs;
+     uint64_t offset;
+     uint64_t seed;
+     float scale;
+     float bias;
+ };
+
+ struct gptoss_accumulate_args {
+     uint32_t num_vecs_per_expert;
+     uint32_t num_vecs_per_threadgroup;
+     uint32_t num_vecs;
+ };
+
+ struct gptoss_convert_args {
+     uint64_t num_vecs_per_threadgroup;
+     uint64_t num_vecs;
+ };
+
+ struct gptoss_embeddings_args {
+     uint32_t num_vecs;
+ };
+
+ struct gptoss_rmsnorm_args {
+     uint32_t num_vecs;
+     float num_channels;
+     float epsilon;
+ };
+
+ struct gptoss_matmul_args {
+     uint32_t num_column_vecs;
+     uint32_t num_rows;
+     uint32_t add;
+ };
+
+ struct gptoss_dense_matmul_args {
+     uint32_t m;
+     uint32_t n;
+     uint32_t k;
+ };
+
+ struct gptoss_scatter_args {
+     uint32_t tokens;
+     uint32_t active_experts_per_token;
+     uint32_t token_stride;
+ };
+
+ struct gptoss_moe_dense_matmul_swiglu_args {
+     uint32_t k;
+     uint32_t n;
+     uint32_t weight_blocks_expert_stride_bytes;
+     uint32_t weight_scales_expert_stride_bytes;
+     uint32_t bias_expert_stride_bytes;
+     float swiglu_min;
+     float swiglu_max;
+ };
+ struct gptoss_moe_dense_matmul_args {
+     uint32_t k;
+     uint32_t n;
+     uint32_t weight_blocks_expert_stride_bytes;
+     uint32_t weight_scales_expert_stride_bytes;
+     uint32_t bias_expert_stride_bytes;
+ };
+
+ struct gptoss_expert_routing_metadata_args {
+     uint32_t tokens;
+     uint32_t num_experts;
+ };
+
+ struct gptoss_gather_args {
+     uint32_t tokens;
+     uint32_t active_experts_per_token;
+     uint32_t token_stride;
+ };
+
+ struct gptoss_unembedding_args {
+     uint32_t num_column_vecs;
+     uint32_t num_rows_per_threadgroup;
+     uint32_t num_rows;
+ };
+
+ struct gptoss_moe_matmul_swiglu_args {
+     uint32_t num_column_vecs;
+     uint32_t num_rows;
+     uint32_t num_active_experts;
+     uint32_t weight_expert_stride; // in bytes
+     uint32_t output_expert_stride; // in elements
+     float swiglu_min;
+     float swiglu_max;
+ };
+
+ struct gptoss_moe_matmul_args {
+     uint32_t num_column_vecs;
+     uint32_t num_rows;
+     uint32_t num_active_experts;
+     uint32_t input_expert_stride; // in blocks of 32 elements
+     uint32_t weight_expert_stride; // in bytes
+     uint32_t output_expert_stride; // in elements
+ };
+
+ struct gptoss_rope_args {
+     uint32_t token_stride;
+     uint32_t token_offset;
+     float freq_scale;
+     float interpolation_scale;
+     float yarn_offset;
+     float yarn_scale;
+     float yarn_multiplier;
+ };
+
+ struct gptoss_qkv_args {
+     uint32_t num_column_vecs;
+     uint32_t num_rows;
+     uint32_t token_offset;
+     float freq_scale;
+     float interpolation_scale;
+     float yarn_offset;
+     float yarn_scale;
+     float yarn_multiplier;
+     uint32_t max_tokens;
+ };
+
+ struct gptoss_softmax_args {
+     uint32_t num_vecs;
+     uint32_t num_vecs_per_threadgroup;
+     uint32_t max_threadgroups;
+     float temperature;
+ };
+
+ struct gptoss_sample_args {
+     uint64_t rng_seed;
+     uint32_t rng_offset;
+     uint32_t num_blocks;
+     uint32_t num_dims;
+     uint32_t num_dims_per_block;
+ };
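The Bm/Bn/Bk macros above are threadgroup tile shapes, so the host side presumably derives its dispatch grid from them by ceiling division over the problem shape; a sketch under that assumption, with m and n as illustrative problem dimensions:

// Hypothetical grid computation for the dense QKV matmul: one threadgroup
// per QKV_Bm x QKV_Bn output tile of an m x n result.
const size_t num_threadgroups_m = (m + QKV_Bm - 1) / QKV_Bm;
const size_t num_threadgroups_n = (n + QKV_Bn - 1) / QKV_Bn;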
gptoss_kernels/source/include/internal/log.h ADDED
@@ -0,0 +1,20 @@
+ #pragma once
+
+ #include <stdarg.h>
+
+
+ void gptoss_format_log(const char* format, va_list args);
+
+ __attribute__((__format__(__printf__, 1, 2)))
+ inline static void gptoss_log(const char* format, ...) {
+     va_list args;
+     va_start(args, format);
+     gptoss_format_log(format, args);
+     va_end(args);
+ }
+
+ #define GPTOSS_LOG_ERROR(message, ...) \
+     gptoss_log("Error: " message "\n", ##__VA_ARGS__)
+
+ #define GPTOSS_LOG_WARNING(message, ...) \
+     gptoss_log("Warning: " message "\n", ##__VA_ARGS__)
gptoss_kernels/source/include/internal/macros.h ADDED
@@ -0,0 +1,107 @@
+ #pragma once
+
+ /***** Architecture detection macros *****/
+
+ #ifdef GPTOSS_ARCH_X86_64
+     #if GPTOSS_ARCH_X86_64 != 0 && GPTOSS_ARCH_X86_64 != 1
+         #error "Invalid GPTOSS_ARCH_X86_64 value: must be either 0 or 1"
+     #endif
+ #else
+     #if defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))
+         #define GPTOSS_ARCH_X86_64 1
+     #else
+         #define GPTOSS_ARCH_X86_64 0
+     #endif
+ #endif
+
+ #ifdef GPTOSS_ARCH_ARM64
+     #if GPTOSS_ARCH_ARM64 != 0 && GPTOSS_ARCH_ARM64 != 1
+         #error "Invalid GPTOSS_ARCH_ARM64 value: must be either 0 or 1"
+     #endif
+ #else
+     #if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
+         #define GPTOSS_ARCH_ARM64 1
+     #else
+         #define GPTOSS_ARCH_ARM64 0
+     #endif
+ #endif
+
+ #if GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 == 0
+     #error "Unsupported architecture: neither x86-64 nor ARM64 detected"
+ #elif GPTOSS_ARCH_X86_64 + GPTOSS_ARCH_ARM64 != 1
+     #error "Inconsistent architecture detection: both x86-64 and ARM64 detection macros are specified"
+ #endif
+
+ /***** Compiler portability macros *****/
+
+ #ifndef GPTOSS_LIKELY
+     #if defined(__GNUC__)
+         #define GPTOSS_LIKELY(condition) (__builtin_expect(!!(condition), 1))
+     #else
+         #define GPTOSS_LIKELY(condition) (!!(condition))
+     #endif
+ #endif
+
+ #ifndef GPTOSS_UNLIKELY
+     #if defined(__GNUC__)
+         #define GPTOSS_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
+     #else
+         #define GPTOSS_UNLIKELY(condition) (!!(condition))
+     #endif
+ #endif
+
+ #ifndef GPTOSS_UNPREDICTABLE
+     #if defined(__has_builtin)
+         #if __has_builtin(__builtin_unpredictable)
+             #define GPTOSS_UNPREDICTABLE(condition) (__builtin_unpredictable(!!(condition)))
+         #endif
+     #endif
+ #endif
+ #ifndef GPTOSS_UNPREDICTABLE
+     #if defined(__GNUC__) && (__GNUC__ >= 9) && !defined(__INTEL_COMPILER)
+         #define GPTOSS_UNPREDICTABLE(condition) (__builtin_expect_with_probability(!!(condition), 0, 0.5))
+     #else
+         #define GPTOSS_UNPREDICTABLE(condition) (!!(condition))
+     #endif
+ #endif
+
+ // Disable padding for structure members.
+ #ifndef GPTOSS_DENSELY_PACKED_STRUCTURE
+     #if defined(__GNUC__)
+         #define GPTOSS_DENSELY_PACKED_STRUCTURE __attribute__((__packed__))
+     #else
+         #error "Compiler-specific implementation of GPTOSS_DENSELY_PACKED_STRUCTURE required"
+     #endif
+ #endif
+
+ #ifndef GPTOSS_ALIGN
+     #if defined(__GNUC__)
+         #define GPTOSS_ALIGN(alignment) __attribute__((__aligned__(alignment)))
+     #elif defined(_MSC_VER)
+         #define GPTOSS_ALIGN(alignment) __declspec(align(alignment))
+     #else
+         #error "Compiler-specific implementation of GPTOSS_ALIGN required"
+     #endif
+ #endif
+
+ #ifndef GPTOSS_FORCE_INLINE
+     #if defined(__GNUC__)
+         #define GPTOSS_FORCE_INLINE inline __attribute__((__always_inline__))
+     #elif defined(_MSC_VER)
+         #define GPTOSS_FORCE_INLINE __forceinline
+     #else
+         #define GPTOSS_FORCE_INLINE inline
+     #endif
+ #endif
+
+ /***** Symbol visibility macros *****/
+
+ #ifndef GPTOSS_INTERNAL_SYMBOL
+     #if defined(__ELF__)
+         #define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__("internal")))
+     #elif defined(__MACH__)
+         #define GPTOSS_INTERNAL_SYMBOL __attribute__((__visibility__("hidden")))
+     #else
+         #define GPTOSS_INTERNAL_SYMBOL
+     #endif
+ #endif
gptoss_kernels/source/include/internal/math.h ADDED
@@ -0,0 +1,40 @@
+ #pragma once
+
+ #include <assert.h>
+ #include <stddef.h>
+ #include <stdint.h>
+
+ inline static size_t math_ceil_div(size_t numer, size_t denom) {
+     return (numer + denom - 1) / denom;
+ }
+
+ inline static size_t math_max(size_t a, size_t b) {
+     return a >= b ? a : b;
+ }
+
+ inline static size_t math_min(size_t a, size_t b) {
+     return a < b ? a : b;
+ }
+
+ inline static size_t math_sub_sat(size_t a, size_t b) {
+     return a > b ? a - b : 0;
+ }
+
+ inline static size_t math_round_down_po2(size_t number, size_t multiple) {
+     assert(multiple != 0);
+     assert((multiple & (multiple - 1)) == 0);
+
+     return number & -multiple;
+ }
+
+ inline static size_t math_round_up_po2(size_t number, size_t multiple) {
+     assert(multiple != 0);
+     assert((multiple & (multiple - 1)) == 0);
+
+     const size_t multiple_mask = multiple - 1;
+     if ((number & multiple_mask) != 0) {
+         number |= multiple_mask;
+         number += 1;
+     }
+     return number;
+ }
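A quick sanity check of the power-of-two rounding helpers (the values are illustrative):

// math_round_up_po2 first saturates the low bits (number |= multiple - 1)
// and then increments, landing on the next multiple of a power of two.
assert(math_round_down_po2(37, 16) == 32);
assert(math_round_up_po2(37, 16) == 48);
assert(math_round_up_po2(48, 16) == 48);  // already aligned: unchanged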
gptoss_kernels/source/include/internal/metal-kernels.h ADDED
@@ -0,0 +1,486 @@
+ #pragma once
+
+ #include <stddef.h>
+ #include <stdint.h>
+
+ #include <internal/kernel-args.h>
+ #include <internal/math.h>
+ #include <internal/metal.h>
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* u32_fill_random_fn,
+     size_t threadgroup_size,
+     size_t max_threadgroups,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     uint64_t num_elements,
+     uint64_t rng_seed,
+     uint64_t rng_offset);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_fill_random_fn,
+     size_t threadgroup_size,
+     size_t max_threadgroups,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     uint64_t num_elements,
+     uint64_t rng_seed,
+     uint64_t rng_offset,
+     float rng_min,
+     float rng_max);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* bf16_fill_random_fn,
+     size_t threadgroup_size,
+     size_t max_threadgroups,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     uint64_t num_elements,
+     uint64_t rng_seed,
+     uint64_t rng_offset,
+     float rng_min,
+     float rng_max);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* mf4_f32_convert_fn,
+     size_t threadgroup_size,
+     size_t max_threadgroups,
+     const struct gptoss_metal_buffer* block_buffer,
+     const struct gptoss_metal_buffer* scale_buffer,
+     const struct gptoss_metal_buffer* output_buffer,
+     uint64_t num_elements);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* bf16_f32_embeddings_fn,
+     size_t threadgroup_size,
+     const struct gptoss_metal_buffer* token_buffer,
+     size_t token_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_channels);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_channels,
+     float epsilon);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_bf16w_matmul_fn,
+     size_t threadgroup_size,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
+     size_t threadgroup_size,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* kv_buffer,
+     size_t kv_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_cols,
+     uint32_t num_q_heads,
+     uint32_t num_kv_heads,
+     uint32_t attn_head_dim,
+     uint32_t token_offset,
+     uint32_t max_tokens,
+     float rope_base,
+     float interpolation_scale,
+     float yarn_offset,
+     float yarn_scale,
+     float yarn_multiplier);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_bf16w_matmul_fn,
+     size_t threadgroup_size,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_bf16w_matmul_fn,
+     size_t threadgroup_size,
+     size_t max_threadgroups,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_buffer,
+     size_t weight_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* argmax_buffer,
+     size_t argmax_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,
+     size_t threadgroup_size,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* expert_buffer,
+     size_t expert_offset,
+     const struct gptoss_metal_buffer* weight_block_buffer,
+     size_t weight_block_offset,
+     const struct gptoss_metal_buffer* weight_scale_buffer,
+     size_t weight_scale_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     float swiglu_limit,
+     uint32_t expert_stride,
+     uint32_t num_tokens,
+     uint32_t num_active_experts,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_mf4w_moe_matmul_fn,
+     size_t threadgroup_size,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* expert_buffer,
+     size_t expert_offset,
+     const struct gptoss_metal_buffer* weight_block_buffer,
+     size_t weight_block_offset,
+     const struct gptoss_metal_buffer* weight_scale_buffer,
+     size_t weight_scale_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t expert_stride,
+     uint32_t num_tokens,
+     uint32_t num_active_experts,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_rope_fn,
+     size_t threadgroup_size,
+     const struct gptoss_metal_buffer* activations_buffer,
+     size_t activations_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     float rope_base,
+     float interpolation_scale,
+     float yarn_offset,
+     float yarn_scale,
+     float yarn_multiplier,
+     uint32_t num_tokens,
+     uint32_t num_q_heads,
+     uint32_t num_kv_heads,
+     uint32_t attn_head_dim,
+     uint32_t token_offset);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_accumulate_fn,
+     size_t threadgroup_size,
+     size_t max_threadgroups,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* expert_buffer,
+     size_t expert_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_channels,
+     uint32_t num_tokens,
+     uint32_t num_experts);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* expert_routing_metadata_fn,
+     const struct gptoss_metal_buffer* expert_predictions_buffer,
+     size_t expert_predictions_offset,
+     const struct gptoss_metal_buffer* expert_offsets_buffer,
+     size_t expert_offsets_offset,
+     const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
+     size_t intra_expert_offsets_offset,
+     uint32_t num_tokens,
+     uint32_t num_experts);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_scatter_fn,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* expert_predictions_buffer,
+     size_t expert_predictions_offset,
+     const struct gptoss_metal_buffer* expert_offsets_buffer,
+     size_t expert_offsets_offset,
+     const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
+     size_t intra_expert_offsets_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     uint32_t num_channels,
+     uint32_t num_tokens,
+     uint32_t num_active_experts);
+
+ enum gptoss_status
+ gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn,
+     const struct gptoss_metal_buffer* expert_offsets_buffer,
+     size_t expert_offsets_offset,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_block_buffer,
+     size_t weight_block_offset,
+     const struct gptoss_metal_buffer* weight_scale_buffer,
+     size_t weight_scale_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     float swiglu_limit,
+     uint32_t expert_stride_bytes,
+     uint32_t num_tokens,
+     uint32_t num_experts,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn,
+     const struct gptoss_metal_buffer* expert_offsets_buffer,
+     size_t expert_offsets_offset,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* weight_block_buffer,
+     size_t weight_block_offset,
+     const struct gptoss_metal_buffer* weight_scale_buffer,
+     size_t weight_scale_offset,
+     const struct gptoss_metal_buffer* bias_buffer,
+     size_t bias_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     uint32_t expert_stride_bytes,
+     uint32_t num_tokens,
+     uint32_t num_experts,
+     uint32_t num_cols,
+     uint32_t num_rows);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* expert_predictions_buffer,
+     size_t expert_predictions_offset,
+     const struct gptoss_metal_buffer* expert_offsets_buffer,
+     size_t expert_offsets_offset,
+     const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
+     size_t intra_expert_offsets_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     uint32_t num_channels,
+     uint32_t num_tokens,
+     uint32_t num_active_experts);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_topk_fn,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_tokens,
+     uint32_t num_experts,
+     uint32_t num_active_experts);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_sdpa_fn,
+     const struct gptoss_metal_buffer* q_buffer,
+     size_t q_offset,
+     const struct gptoss_metal_buffer* kv_buffer,
+     size_t kv_offset,
+     const struct gptoss_metal_buffer* s_buffer,
+     size_t s_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t window,
+     uint32_t kv_stride,
+     uint32_t num_q_tokens,
+     uint32_t num_kv_tokens,
+     uint32_t num_q_heads,
+     uint32_t num_kv_heads,
+     uint32_t head_dim);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_softmax_fn,
+     size_t threadgroup_size,
+     size_t max_threadgroups,
+     const struct gptoss_metal_buffer* score_buffer,
+     size_t score_offset,
+     const struct gptoss_metal_buffer* argmax_buffer,
+     size_t argmax_offset,
+     const struct gptoss_metal_buffer* prob_buffer,
+     size_t prob_offset,
+     const struct gptoss_metal_buffer* sum_buffer,
+     size_t sum_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint32_t num_channels,
+     uint32_t num_tokens,
+     float temperature,
+     uint32_t* num_threadgroups_out,
+     uint32_t* num_channels_per_threadgroup_out);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* f32_sample_fn,
+     size_t min_threadgroup_size,
+     const struct gptoss_metal_buffer* prob_buffer,
+     size_t prob_offset,
+     const struct gptoss_metal_buffer* sum_buffer,
+     size_t sum_offset,
+     const struct gptoss_metal_buffer* token_buffer,
+     size_t token_offset,
+     const struct gptoss_metal_buffer* control_buffer,
+     size_t control_offset,
+     uint64_t rng_seed,
+     uint32_t rng_offset,
+     uint32_t num_blocks,
+     uint32_t num_channels,
+     uint32_t num_channels_per_block);
+
+ #ifdef __cplusplus
+ } // extern "C"
+ #endif
gptoss_kernels/source/include/internal/metal.h ADDED
@@ -0,0 +1,138 @@
+ #pragma once
+
+ #include <stddef.h>
+ #include <stdint.h>
+
+ #include <gpt-oss/types.h>
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ struct gptoss_metal_device {
+     void* object; // id<MTLDevice>
+     size_t num_cores;
+     size_t max_buffer_size;
+     size_t max_threadgroup_memory;
+     size_t max_threadgroup_threads_x;
+     size_t max_threadgroup_threads_y;
+     size_t max_threadgroup_threads_z;
+ };
+
+ enum gptoss_status gptoss_metal_device_create_system_default(
+     struct gptoss_metal_device* device_out);
+
+ enum gptoss_status gptoss_metal_device_release(
+     struct gptoss_metal_device* device);
+
+
+ struct gptoss_metal_library {
+     void* object; // id<MTLLibrary>
+ };
+
+ enum gptoss_status gptoss_metal_library_create_default(
+     const struct gptoss_metal_device* device,
+     struct gptoss_metal_library* library_out);
+
+ enum gptoss_status gptoss_metal_library_release(
+     struct gptoss_metal_library* library);
+
+ struct gptoss_metal_function {
+     void* function_object; // id<MTLFunction>
+     void* pipeline_state_object; // id<MTLComputePipelineState>
+     size_t max_threadgroup_threads;
+     size_t simdgroup_threads;
+     size_t static_threadgroup_memory;
+ };
+
+ enum gptoss_status gptoss_metal_function_create(
+     const struct gptoss_metal_library* library,
+     const char* name,
+     struct gptoss_metal_function* function_out);
+
+ enum gptoss_status gptoss_metal_function_release(
+     struct gptoss_metal_function* function);
+
+ struct gptoss_metal_buffer {
+     void* object; // id<MTLBuffer>
+     size_t size;
+     void* ptr;
+ };
+
+ enum gptoss_status gptoss_metal_buffer_create(
+     const struct gptoss_metal_device* device,
+     size_t size,
+     const void* data,
+     struct gptoss_metal_buffer* buffer_out);
+
+ enum gptoss_status gptoss_metal_buffer_wrap(
+     const struct gptoss_metal_device* device,
+     size_t size,
+     const void* data,
+     struct gptoss_metal_buffer* buffer_out);
+
+ enum gptoss_status gptoss_metal_buffer_release(
+     struct gptoss_metal_buffer* buffer);
+
+ struct gptoss_metal_command_queue {
+     void* object; // id<MTLCommandQueue>
+ };
+
+ enum gptoss_status gptoss_metal_command_queue_create(
+     const struct gptoss_metal_device* device,
+     struct gptoss_metal_command_queue* command_queue_out);
+
+ enum gptoss_status gptoss_metal_command_queue_release(
+     struct gptoss_metal_command_queue* command_queue);
+
+ struct gptoss_metal_command_buffer {
+     void* object; // id<MTLCommandBuffer>
+ };
+
+ enum gptoss_status gptoss_metal_command_buffer_create(
+     const struct gptoss_metal_command_queue* command_queue,
+     struct gptoss_metal_command_buffer* command_buffer_out);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_buffer* buffer,
+     size_t offset,
+     size_t size,
+     uint8_t fill_value);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_buffer* input_buffer,
+     size_t input_offset,
+     const struct gptoss_metal_buffer* output_buffer,
+     size_t output_offset,
+     size_t size);
+
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     const struct gptoss_metal_function* function,
+     size_t threadgroup_size_x,
+     size_t threadgroup_size_y,
+     size_t threadgroup_size_z,
+     size_t num_threadgroups_x,
+     size_t num_threadgroups_y,
+     size_t num_threadgroups_z,
+     size_t params_size,
+     const void* params,
+     size_t num_device_buffers,
+     const struct gptoss_metal_buffer** device_buffers,
+     const size_t* device_buffer_offsets,
+     size_t threadgroup_buffer_size);
+
+ enum gptoss_status gptoss_metal_command_buffer_commit(
+     const struct gptoss_metal_command_buffer* command_buffer);
+
+ enum gptoss_status gptoss_metal_command_buffer_wait_completion(
+     const struct gptoss_metal_command_buffer* command_buffer,
+     double* elapsed_seconds);
+
+ enum gptoss_status gptoss_metal_command_buffer_release(
+     struct gptoss_metal_command_buffer* command_buffer);
+
+ #ifdef __cplusplus
+ } // extern "C"
+ #endif
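Taken together, the C API implies the following object lifecycle; this is a hedged sketch with error handling elided, and the kernel name is an assumption based on the function table in model.h below:

struct gptoss_metal_device device;
struct gptoss_metal_library library;
struct gptoss_metal_function fn;
struct gptoss_metal_command_queue queue;
struct gptoss_metal_command_buffer command_buffer;
double elapsed_seconds = 0.0;

gptoss_metal_device_create_system_default(&device);
gptoss_metal_library_create_default(&device, &library);
gptoss_metal_function_create(&library, "gptoss_f32_bf16w_rmsnorm", &fn);  // kernel name: assumed
gptoss_metal_command_queue_create(&device, &queue);
gptoss_metal_command_buffer_create(&queue, &command_buffer);
// ... encode fills, copies, and kernel launches here ...
gptoss_metal_command_buffer_commit(&command_buffer);
gptoss_metal_command_buffer_wait_completion(&command_buffer, &elapsed_seconds);

gptoss_metal_command_buffer_release(&command_buffer);
gptoss_metal_command_queue_release(&queue);
gptoss_metal_function_release(&fn);
gptoss_metal_library_release(&library);
gptoss_metal_device_release(&device);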
gptoss_kernels/source/include/internal/metal.hpp ADDED
@@ -0,0 +1,342 @@
+ #pragma once
+
+ #include <algorithm>
+ #include <array>
+ #include <cstring>
+ #include <initializer_list>
+ #include <stdexcept>
+ #include <vector>
+
+ #include <gpt-oss/types.h>
+ #include <internal/metal.h>
+ #include <internal/metal-kernels.h>
+
+
+ namespace gptoss {
+
+ inline void Check(gptoss_status s, const char* what) {
+     if (s != gptoss_status_success) {
+         throw std::runtime_error(what);
+     }
+ }
+
+ inline std::size_t round_up(std::size_t p, std::size_t q) {
+     const std::size_t r = p % q;
+     if (r == 0) {
+         return p;
+     } else {
+         return p - r + q;
+     }
+ }
+
+ namespace metal {
+
+ class Device {
+ public:
+     inline Device() {
+         Check(gptoss_metal_device_create_system_default(&device_), "create Device");
+     }
+
+     inline ~Device() {
+         gptoss_metal_device_release(&device_);
+     }
+
+     Device(const Device&) = delete;
+     Device& operator=(const Device&) = delete;
+
+     inline Device(Device&& other) noexcept {
+         device_ = other.device_;
+         std::memset(&other.device_, 0, sizeof(other.device_));
+     }
+
+     inline Device& operator=(Device&& other) noexcept {
+         if (this != &other) {
+             gptoss_metal_device_release(&device_);
+             device_ = other.device_;
+             std::memset(&other.device_, 0, sizeof(other.device_));
+         }
+         return *this;
+     }
+
+     inline const gptoss_metal_device* handle() const noexcept { return &device_; }
+
+     inline size_t max_buffer_size() const noexcept { return device_.max_buffer_size; }
+     inline size_t max_threadgroup_memory() const noexcept { return device_.max_threadgroup_memory; }
+     inline size_t max_threadgroup_threads_x() const noexcept { return device_.max_threadgroup_threads_x; }
+     inline size_t max_threadgroup_threads_y() const noexcept { return device_.max_threadgroup_threads_y; }
+     inline size_t max_threadgroup_threads_z() const noexcept { return device_.max_threadgroup_threads_z; }
+
+ private:
+     gptoss_metal_device device_{};
+ };
+
+ class Library {
+ public:
+     inline explicit Library(const Device& dev) {
+         Check(gptoss_metal_library_create_default(dev.handle(), &library_),
+               "gptoss_metal_library_create_default");
+     }
+
+     inline ~Library() {
+         gptoss_metal_library_release(&library_);
+     }
+
+     Library(const Library&) = delete;
+     Library& operator=(const Library&) = delete;
+
+     inline Library(Library&& other) noexcept {
+         library_ = other.library_;
+         std::memset(&other.library_, 0, sizeof(other.library_));
+     }
+
+     inline Library& operator=(Library&& other) noexcept {
+         if (this != &other) {
+             gptoss_metal_library_release(&library_);
+             library_ = other.library_;
+             std::memset(&other.library_, 0, sizeof(other.library_));
+         }
+         return *this;
+     }
+
+     inline const gptoss_metal_library* handle() const noexcept {
+         return &library_;
+     }
+
+ private:
+     gptoss_metal_library library_{};
+ };
+
+ class Function {
+ public:
+     inline Function(const Library& library, const char* name) {
+         Check(gptoss_metal_function_create(library.handle(), name, &function_),
+               "gptoss_metal_function_create");
+     }
+
+     inline ~Function() {
+         gptoss_metal_function_release(&function_);
+     }
+
+     Function(const Function&) = delete;
+     Function& operator=(const Function&) = delete;
+
+     inline Function(Function&& other) noexcept {
+         function_ = other.function_;
+         std::memset(&other.function_, 0, sizeof(other.function_));
+     }
+
+     inline Function& operator=(Function&& other) noexcept {
+         if (this != &other) {
+             gptoss_metal_function_release(&function_);
+             function_ = other.function_;
+             std::memset(&other.function_, 0, sizeof(other.function_));
+         }
+         return *this;
+     }
+
+     inline const gptoss_metal_function* handle() const noexcept { return &function_; }
+
+     inline size_t max_threadgroup_threads() const noexcept { return function_.max_threadgroup_threads; }
+     inline size_t simdgroup_threads() const noexcept { return function_.simdgroup_threads; }
+     inline size_t static_threadgroup_memory() const noexcept { return function_.static_threadgroup_memory; }
+
+ private:
+     gptoss_metal_function function_{};
+ };
+
+ class Buffer {
+ public:
+     inline Buffer(const Device& dev, size_t size, const void* data = nullptr) {
+         Check(gptoss_metal_buffer_create(dev.handle(), size, data, &buffer_), "create buffer");
+     }
+
+     inline ~Buffer() {
+         gptoss_metal_buffer_release(&buffer_);
+     }
+
+     Buffer(const Buffer&) = delete;
+     Buffer& operator=(const Buffer&) = delete;
+
+     inline Buffer(Buffer&& other) noexcept {
+         buffer_ = other.buffer_;
+         std::memset(&other.buffer_, 0, sizeof(other.buffer_));
+     }
+
+     inline Buffer& operator=(Buffer&& other) noexcept {
+         if (this != &other) {
+             gptoss_metal_buffer_release(&buffer_);
+             buffer_ = other.buffer_;
+             std::memset(&other.buffer_, 0, sizeof(other.buffer_));
+         }
+         return *this;
+     }
+
+     inline size_t size() const noexcept { return buffer_.size; }
+     inline void* ptr() const noexcept { return buffer_.ptr; }
+
+     inline const gptoss_metal_buffer* handle() const noexcept { return &buffer_; }
+
+ private:
+     gptoss_metal_buffer buffer_{};
+ };
+
+ class CommandQueue {
+ public:
+     inline explicit CommandQueue(const Device& dev) {
+         Check(gptoss_metal_command_queue_create(dev.handle(), &command_queue_),
+               "gptoss_metal_command_queue_create");
+     }
+
+     inline ~CommandQueue() {
+         gptoss_metal_command_queue_release(&command_queue_);
+     }
+
+     CommandQueue(const CommandQueue&) = delete;
+     CommandQueue& operator=(const CommandQueue&) = delete;
+
+     inline CommandQueue(CommandQueue&& other) noexcept {
+         command_queue_ = other.command_queue_;
+         std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));
+     }
+
+     inline CommandQueue& operator=(CommandQueue&& other) noexcept {
+         if (this != &other) {
+             gptoss_metal_command_queue_release(&command_queue_);
+             command_queue_ = other.command_queue_;
+             std::memset(&other.command_queue_, 0, sizeof(other.command_queue_));
+         }
+         return *this;
+     }
+
+     inline const gptoss_metal_command_queue* handle() const noexcept {
+         return &command_queue_;
+     }
+
+ private:
+     gptoss_metal_command_queue command_queue_{};
+ };
+
+ class CommandBuffer {
+ public:
+     inline explicit CommandBuffer(const CommandQueue& command_queue) {
+         Check(gptoss_metal_command_buffer_create(command_queue.handle(), &command_buffer_),
+               "gptoss_metal_command_buffer_create");
+     }
+     inline ~CommandBuffer() {
+         gptoss_metal_command_buffer_release(&command_buffer_);
+     }
+
+     CommandBuffer(const CommandBuffer&) = delete;
+     CommandBuffer& operator=(const CommandBuffer&) = delete;
+
+     inline CommandBuffer(CommandBuffer&& other) noexcept {
+         command_buffer_ = other.command_buffer_;
+         std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));
+     }
+
+     inline CommandBuffer& operator=(CommandBuffer&& other) noexcept {
+         if (this != &other) {
+             gptoss_metal_command_buffer_release(&command_buffer_);
+             command_buffer_ = other.command_buffer_;
+             std::memset(&other.command_buffer_, 0, sizeof(other.command_buffer_));
+         }
+         return *this;
+     }
+
+     inline void encode_launch_kernel(const Function& function,
+                                      const std::array<size_t, 3>& threadgroup_size,
+                                      const std::array<size_t, 3>& num_threadgroups,
+                                      size_t params_size, const void* params,
+                                      std::initializer_list<const Buffer*> device_buffers = {},
+                                      size_t threadgroup_buffer_size = 0)
+     {
+         std::vector<const gptoss_metal_buffer*> buffer_handles(device_buffers.size());
+         std::transform(device_buffers.begin(), device_buffers.end(), buffer_handles.begin(),
+                        [](const Buffer* buffer) -> const gptoss_metal_buffer* { return buffer->handle(); });
+         Check(gptoss_metal_command_buffer_encode_launch_kernel(
+                   &command_buffer_, function.handle(),
+                   threadgroup_size[0], threadgroup_size[1], threadgroup_size[2],
+                   num_threadgroups[0], num_threadgroups[1], num_threadgroups[2],
+                   params_size, params,
+                   buffer_handles.size(),
+                   buffer_handles.data(),
+                   /*buffer_offsets=*/nullptr,
+                   threadgroup_buffer_size),
+               "gptoss_metal_command_buffer_encode_launch_kernel");
+     }
+
+     inline void encode_launch_f32_fill_random(const Function& f32_fill_random_fn,
+                                               size_t threadgroup_size,
+                                               size_t num_threadgroups,
+                                               const Buffer& output_buffer,
+                                               size_t output_offset,
+                                               size_t num_channels,
+                                               uint64_t rng_seed,
+                                               uint64_t rng_offset,
+                                               float rng_min,
+                                               float rng_max)
+     {
+         Check(gptoss_metal_command_buffer_encode_launch_f32_fill_random(
+                   &command_buffer_, f32_fill_random_fn.handle(),
+                   threadgroup_size, num_threadgroups,
+                   output_buffer.handle(), output_offset,
+                   num_channels,
+                   rng_seed, rng_offset, rng_min, rng_max),
+               "gptoss_metal_command_buffer_encode_launch_f32_fill_random");
+     }
+
+     inline void encode_launch_bf16_fill_random(const Function& bf16_fill_random_fn,
+                                                size_t threadgroup_size,
+                                                size_t num_threadgroups,
+                                                const Buffer& output_buffer,
+                                                size_t output_offset,
+                                                size_t num_channels,
+                                                uint64_t rng_seed,
+                                                uint64_t rng_offset,
+                                                float rng_min,
+                                                float rng_max)
+     {
+         Check(gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
+                   &command_buffer_, bf16_fill_random_fn.handle(),
+                   threadgroup_size, num_threadgroups,
+                   output_buffer.handle(), output_offset,
+                   num_channels,
+                   rng_seed, rng_offset, rng_min, rng_max),
+               "gptoss_metal_command_buffer_encode_launch_bf16_fill_random");
+     }
+
+     inline void encode_launch_u32_fill_random(const Function& u32_fill_random_fn,
+                                               size_t threadgroup_size,
+                                               size_t num_threadgroups,
+                                               const Buffer& output_buffer,
+                                               size_t output_offset,
+                                               size_t num_channels,
+                                               uint64_t rng_seed,
+                                               uint64_t rng_offset)
+     {
+         Check(gptoss_metal_command_buffer_encode_launch_u32_fill_random(
+                   &command_buffer_, u32_fill_random_fn.handle(),
+                   threadgroup_size, num_threadgroups,
+                   output_buffer.handle(), output_offset,
+                   num_channels,
+                   rng_seed, rng_offset),
+               "gptoss_metal_command_buffer_encode_launch_u32_fill_random");
+     }
+
+     inline void commit() {
+         Check(gptoss_metal_command_buffer_commit(&command_buffer_), "commit");
+     }
+
+     inline double wait_completion() {
+         double secs = 0.0;
+         Check(gptoss_metal_command_buffer_wait_completion(&command_buffer_, &secs), "wait completion");
+         return secs;
+     }
+
+     inline const gptoss_metal_command_buffer* handle() const noexcept { return &command_buffer_; }
+
+ private:
+     gptoss_metal_command_buffer command_buffer_{};
+ };
+
+ } // namespace metal
+ } // namespace gptoss
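The RAII wrappers compose the same lifecycle without manual releases; a sketch in the spirit of the tests, where the kernel name, threadgroup size, and buffer size are all assumptions:

using namespace gptoss;

metal::Device device;
metal::Library library{device};
metal::Function fill_fn{library, "f32_fill_random"};  // kernel name: assumed
metal::Buffer output{device, 1024 * sizeof(float)};
metal::CommandQueue queue{device};
metal::CommandBuffer command_buffer{queue};

command_buffer.encode_launch_f32_fill_random(
    fill_fn, /*threadgroup_size=*/256, /*num_threadgroups=*/1,
    output, /*output_offset=*/0, /*num_channels=*/1024,
    /*rng_seed=*/42, /*rng_offset=*/0, /*rng_min=*/-1.0f, /*rng_max=*/1.0f);
command_buffer.commit();
const double elapsed_seconds = command_buffer.wait_completion();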
gptoss_kernels/source/include/internal/model.h ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ #pragma once
+
+ #ifndef __cplusplus
+ #include <stdatomic.h>
+ #endif
+ #include <stdbool.h>
+ #include <stddef.h>
+ #include <stdint.h>
+
+ #include "internal/metal.h"
+
+
+ struct gptoss_tokenizer {
+ #ifndef __cplusplus
+     atomic_uint_least64_t ref_count;
+ #else
+     uint_least64_t ref_count;
+ #endif
+
+     void* mapping_ptr;
+     size_t mapping_size;
+
+     const char* regex_ptr;
+     const char* tokens_ptr;
+
+     uint32_t num_text_tokens;
+     uint32_t num_special_tokens;
+
+     uint32_t special_token_id[gptoss_special_token_max - 1];
+ };
+
+ struct gptoss_model {
+ #ifndef __cplusplus
+     atomic_uint_least64_t ref_count;
+ #else
+     uint_least64_t ref_count;
+ #endif
+
+     struct gptoss_tokenizer* tokenizer;
+
+     void* mapping_ptr;
+     size_t mapping_size;
+
+     uint32_t context_length;
+     uint32_t num_blocks;
+     uint32_t num_experts;
+     uint32_t num_active_experts;
+     uint32_t embedding_dim;
+     uint32_t mlp_dim;
+     float swiglu_limit;
+     uint32_t head_dim;
+     uint32_t num_heads;
+     uint32_t num_kv_heads;
+     uint32_t attention_window;
+     float rope_theta;
+     float interpolation_scale;
+     float yarn_offset;
+     float yarn_scale;
+     float yarn_multiplier;
+     float rmsnorm_epsilon;
+
+     uint32_t vocabulary_size;
+
+     bool lock_memory;
+
+     size_t weights_size;
+     size_t allocation_size;
+
+     // Metal objects
+     struct gptoss_metal_device device;
+     size_t max_threadgroups;
+     struct gptoss_metal_command_queue command_queue;
+     struct gptoss_metal_library library;
+     struct gptoss_metal_function bf16_f32_embeddings_fn;
+     struct gptoss_metal_function f32_bf16w_rmsnorm_fn;
+     struct gptoss_metal_function f32_bf16w_matmul_fn;
+     struct gptoss_metal_function f32_bf16w_matmul_qkv_fn;
+     struct gptoss_metal_function f32_bf16w_dense_matmul_qkv_fn;
+     struct gptoss_metal_function f32_bf16w_dense_matmul_attn_output_fn;
+     struct gptoss_metal_function f32_bf16w_dense_matmul_mlp_gate_fn;
+     struct gptoss_metal_function f32_bf16w_unembedding_fn;
+     struct gptoss_metal_function f32_rope_fn;
+     struct gptoss_metal_function f32_mf4w_moe_matmul_swiglu_fn;
+     struct gptoss_metal_function f32_mf4w_moe_matmul_fn;
+     struct gptoss_metal_function f32_accumulate_e4_fn;
+     struct gptoss_metal_function f32_scatter_e4_fn;
+     struct gptoss_metal_function f32_mf4w_moe_dense_matmul_swiglu_fn;
+     struct gptoss_metal_function f32_mf4w_moe_dense_matmul_fn;
+     struct gptoss_metal_function f32_gather_and_accumulate_e4_fn;
+     struct gptoss_metal_function f32_expert_routing_metadata_fn;
+     struct gptoss_metal_function f32_topk_softmax_e32_k4_fn;
+     struct gptoss_metal_function f32_topk_softmax_e128_k4_fn;
+     struct gptoss_metal_function f32_sdpa_q8_d64_fn;
+     struct gptoss_metal_function f32_softmax_fn;
+     struct gptoss_metal_function f32_sample_fn;
+
+     size_t per_block_shared_weights_size;
+     size_t per_expert_block_weight_size;
+
+     size_t embeddings_threadgroup_size;
+     size_t attn_qkv_threadgroup_size;
+     size_t attn_out_threadgroup_size;
+     size_t mlp_gate_threadgroup_size;
+     size_t mlp_swiglu_threadgroup_size;
+     size_t mlp_out_threadgroup_size;
+     size_t mlp_acc_threadgroup_size;
+     size_t unembedding_threadgroup_size;
+
+     size_t attn_rmsnorm_gain_offset;
+     size_t attn_qkv_weight_offset;
+     size_t attn_qkv_bias_offset;
+     size_t attn_sdpa_sink_offset;
+     size_t attn_out_weight_offset;
+     size_t attn_out_bias_offset;
+     size_t mlp_rmsnorm_gain_offset;
+     size_t mlp_gate_weight_offset;
+     size_t mlp_gate_bias_offset;
+     size_t mlp_swiglu_scale_offset;
+     size_t mlp_swiglu_bias_offset;
+     size_t mlp_out_block_offset;
+     size_t mlp_out_scale_offset;
+     size_t mlp_out_bias_offset;
+     size_t rmsnorm_weight_offset;
+     size_t unembedding_weight_offset;
+
+     // Buffer with non-MoE weights. Includes MoE gates, embeddings/unembeddings.
+     struct gptoss_metal_buffer shared_weight_buffer;
+     // num_blocks per-block buffers with MoE weights to follow.
+     struct gptoss_metal_buffer block_weight_buffers[];
+ };
+
+ #define GPTOSS_DEFAULT_BATCH_SIZE 128
+
+ struct gptoss_context {
+ #ifndef __cplusplus
+     atomic_uint_least64_t ref_count;
+ #else
+     uint_least64_t ref_count;
+ #endif
+
+     struct gptoss_model* model;
+     // Number of tokens processed in the context.
+     size_t num_tokens;
+     // Number of tokens in the KV cache.
+     size_t num_kv_tokens;
+     // Length of the context.
+     size_t max_tokens;
+     // Maximum number of tokens that can be processed in a single batch.
+     // Activation buffers are allocated with this size.
+     size_t max_batch_tokens;
+
+     size_t kvcache_size;
+     size_t allocation_size;
+
+     // Activation buffers.
+     // TODO: merge into a single buffer.
+     struct gptoss_metal_buffer residual_activation_buffer;      // Residual stream
+     struct gptoss_metal_buffer rmsnorm_activation_buffer;       // Both attention & MLP RMSNorm output
+     struct gptoss_metal_buffer qkv_activation_buffer;           // QKV projection output
+     struct gptoss_metal_buffer sdpa_activation_buffer;          // SDPA output
+     struct gptoss_metal_buffer gate_activation_buffer;          // MoE gating output
+     struct gptoss_metal_buffer expert_activation_buffer;        // MoE expert predictions
+     struct gptoss_metal_buffer expert_offset_buffer;            // MoE expert histograms cumsum
+     struct gptoss_metal_buffer token_to_expert_routing_buffer;  // MoE token-to-expert routing
+     struct gptoss_metal_buffer swiglu_input_buffer;             // MLP+SwiGLU input for prefill
+     struct gptoss_metal_buffer swiglu_activation_buffer;        // MLP+SwiGLU output
+     struct gptoss_metal_buffer moe_activation_buffer;           // MoE MLP output (per-active expert)
+
+     // Input/output buffers.
+     struct gptoss_metal_buffer control_buffer;
+     struct gptoss_metal_buffer token_buffer;   // uint32 token IDs
+     struct gptoss_metal_buffer score_buffer;   // unembedding outputs
+     struct gptoss_metal_buffer prob_buffer;
+     struct gptoss_metal_buffer sum_buffer;
+     struct gptoss_metal_buffer argmax_buffer;
+     struct gptoss_metal_buffer kvcache_buffer;
+ };
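
One field worth a worked example is kvcache_buffer. The QKV kernel later in this diff (matmul.metal) writes keys at (head * max_tokens + token) * 2 * head_dim and values head_dim floats further in, which suggests a per-block cache laid out as [num_kv_heads][max_tokens][2][head_dim] in fp32. A back-of-the-envelope sketch under that assumption (not verified against the allocation logic in context.c):

// Hedged estimate of the KV-cache footprint, assuming fp32 entries and the
// per-block layout implied by the QKV kernel's indexing.
#include <cstddef>
#include <cstdint>

constexpr std::size_t kv_cache_bytes(std::uint32_t num_blocks,
                                     std::uint32_t num_kv_heads,
                                     std::uint32_t head_dim,
                                     std::size_t max_tokens) {
    return std::size_t{num_blocks} * num_kv_heads * max_tokens * 2 * head_dim * sizeof(float);
}

// Example shapes matching the constants hard-coded in the QKV kernels
// (8 KV heads, head_dim 64), with 36 blocks and a 4096-token context:
// 36 * 8 * 4096 * 2 * 64 * 4 bytes = 603,979,776 bytes (576 MiB).
static_assert(kv_cache_bytes(36, 8, 64, 4096) == 603979776ull, "KV cache size");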
gptoss_kernels/source/include/internal/rng.h ADDED
@@ -0,0 +1,24 @@
+ #pragma once
+
+ #include <stdint.h>
+
+ // Counter-based generator in the style of Widynski's "Squares" RNG: hashing
+ // (offset, seed) directly means any element of the stream can be computed
+ // independently, which is what lets the fill-random kernels run in parallel.
+ inline static uint32_t rng_squares32(uint64_t offset, uint64_t seed) {
+     const uint64_t y = offset * seed;
+     const uint64_t z = y + seed;
+
+     /* Round 1 */
+     uint64_t x = y * y + y;
+     x = (x >> 32) | (x << 32);
+
+     /* Round 2 */
+     x = x * x + z;
+     x = (x >> 32) | (x << 32);
+
+     /* Round 3 */
+     x = x * x + y;
+     x = (x >> 32) | (x << 32);
+
+     /* Round 4 */
+     x = x * x + z;
+     return (uint32_t) (x >> 32);
+ }
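
Because the generator is a pure function of (offset, seed), every element of the stream is reproducible and order-independent; that is the property the fill-random launchers in metal-kernels.c below rely on when each GPU thread hashes its own element index. A small host-side check (the seed value is illustrative only):

#include <cassert>
#include <cstdint>
#include <cstdio>

#include <internal/rng.h>  // rng_squares32

int main() {
    const std::uint64_t seed = 0x9E3779B97F4A7C15ull;  // illustrative seed only
    // The same (offset, seed) pair always produces the same draw...
    assert(rng_squares32(42, seed) == rng_squares32(42, seed));
    // ...and offsets need not be visited in order, which is why each GPU thread
    // in the fill-random kernels can evaluate its own element independently.
    for (std::uint64_t offset = 0; offset < 4; offset++) {
        std::printf("squares32(offset=%llu) = 0x%08X\n",
                    (unsigned long long) offset,
                    (unsigned) rng_squares32(offset, seed));
    }
    return 0;
}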
gptoss_kernels/source/include/internal/rng.hpp ADDED
@@ -0,0 +1,32 @@
+ #pragma once
+
+ #include <cstdint>
+
+ namespace gptoss {
+
+ namespace rng {
+
+ inline static std::uint32_t squares32(std::uint64_t offset, std::uint64_t seed) {
+     const std::uint64_t y = offset * seed;
+     const std::uint64_t z = y + seed;
+
+     /* Round 1 */
+     std::uint64_t x = y * y + y;
+     x = (x >> 32) | (x << 32);
+
+     /* Round 2 */
+     x = x * x + z;
+     x = (x >> 32) | (x << 32);
+
+     /* Round 3 */
+     x = x * x + y;
+     x = (x >> 32) | (x << 32);
+
+     /* Round 4 */
+     x = x * x + z;
+     return static_cast<uint32_t>(x >> 32);
+ }
+
+ }  // namespace rng
+
+ }  // namespace gptoss
gptoss_kernels/source/include/internal/storage.h ADDED
@@ -0,0 +1,36 @@
+ #pragma once
+
+ #include <stdbool.h>
+ #include <stdint.h>
+
+ struct gptoss_file_header {
+     char magic[12];
+     uint32_t zero;
+ };
+
+ struct gptoss_gptoss_model_header {
+     uint32_t context_length;
+     uint32_t num_blocks;
+     uint32_t num_experts;
+     uint32_t num_active_experts;
+     uint32_t embedding_dim;
+     uint32_t mlp_dim;
+     float swiglu_limit;
+     uint32_t head_dim;
+     uint32_t num_heads;
+     uint32_t num_kv_heads;
+     uint32_t attention_window;
+     float rope_theta;
+     float interpolation_scale;
+     float yarn_offset;
+     float yarn_scale;
+     float yarn_multiplier;
+     float rmsnorm_epsilon;
+ };
+
+ struct gptoss_tiktoken_tokenizer_header {
+     uint32_t num_special_tokens;
+     uint32_t num_text_tokens;
+     uint32_t regex_size;
+     uint32_t tokens_size;
+ };
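
These plain-old-data headers are read directly out of the memory-mapped weights file, so the field order here is the on-disk layout. The sketch below only illustrates how the structs map onto raw bytes; it assumes (without verifying against model.c, where the authoritative parsing lives) that the model header immediately follows the file header and a 16-byte model UUID:

#include <cstdint>
#include <cstdio>
#include <cstring>

#include <internal/storage.h>

// Illustrative only: decode headers from a memory-mapped gpt-oss weights file,
// under the assumed layout [file header][16-byte model UUID][model header].
void dump_model_header(const void* mapping_ptr) {
    const char* ptr = (const char*) mapping_ptr;

    struct gptoss_file_header file_header;
    std::memcpy(&file_header, ptr, sizeof(file_header));
    ptr += sizeof(file_header) + 16;  // skip the 16-byte model UUID

    struct gptoss_gptoss_model_header model_header;
    std::memcpy(&model_header, ptr, sizeof(model_header));
    std::printf("context_length=%u num_blocks=%u num_experts=%u head_dim=%u\n",
                model_header.context_length, model_header.num_blocks,
                model_header.num_experts, model_header.head_dim);
}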
gptoss_kernels/source/include/internal/uuid.h ADDED
@@ -0,0 +1,114 @@
+ #pragma once
+
+ #include <stdbool.h>
+ #include <stdint.h>
+ #include <string.h>
+
+ #include "internal/macros.h"
+
+
+ struct GPTOSS_DENSELY_PACKED_STRUCTURE gptoss_uuid {
+     uint8_t bytes[16];
+ };
+ static_assert(sizeof(struct gptoss_uuid) == 16, "UUID size is not 16 bytes");
+
+
+ #define UUID_FORMAT "%02X%02X%02X%02X-%02X%02X-%02X%02X-%02X%02X-%02X%02X%02X%02X%02X%02X"
+ #define UUID_ARGS(uuid) (uuid).bytes[0], (uuid).bytes[1], (uuid).bytes[2], (uuid).bytes[3], \
+     (uuid).bytes[4], (uuid).bytes[5], (uuid).bytes[6], (uuid).bytes[7], (uuid).bytes[8], (uuid).bytes[9], \
+     (uuid).bytes[10], (uuid).bytes[11], (uuid).bytes[12], (uuid).bytes[13], (uuid).bytes[14], (uuid).bytes[15]
+
+ static inline bool gptoss_is_gptoss_model_uuid(const struct gptoss_uuid* uuid) {
+     return memcmp(
+         &(struct gptoss_uuid) {0xDF, 0x52, 0xDC, 0x86, 0x17, 0x89, 0x4E, 0xD0, 0xA2, 0x95, 0x66, 0xF1, 0x05, 0x08, 0x14, 0x5B},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0;
+ }
+
+ static inline bool gptoss_is_applegpu_layout_uuid(const struct gptoss_uuid* uuid) {
+     return memcmp(
+         &(struct gptoss_uuid) {0x22, 0x91, 0x77, 0xA8, 0x57, 0x75, 0x42, 0x68, 0xBF, 0xD8, 0xD5, 0x88, 0xB3, 0x51, 0xC5, 0x6D},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0;
+ }
+
+ static inline bool gptoss_is_tiktoken_tokenizer_uuid(const struct gptoss_uuid* uuid) {
+     return memcmp(
+         &(struct gptoss_uuid) {0x74, 0x01, 0xAD, 0xED, 0x2A, 0x95, 0x40, 0xCB, 0xB7, 0x82, 0x9C, 0xCE, 0xBA, 0xAF, 0xE7, 0x2B},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0;
+ }
+
+ static inline enum gptoss_special_token gptoss_special_token_decode_uuid(const struct gptoss_uuid* uuid) {
+     if (memcmp(
+         &(struct gptoss_uuid) {0x55, 0xA7, 0x7C, 0x2F, 0x8A, 0x01, 0x4C, 0x54, 0x8A, 0xC2, 0x31, 0x3B, 0xFC, 0x7E, 0x20, 0x8D},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_start;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0x16, 0xE4, 0x04, 0x31, 0xF4, 0x7F, 0x4B, 0x22, 0xB5, 0x9B, 0x8B, 0x27, 0x8F, 0xC3, 0x0A, 0x54},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_message;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0xFC, 0xAC, 0x2F, 0x6D, 0x47, 0x05, 0x4F, 0x6B, 0xB2, 0x28, 0x64, 0x2A, 0xCC, 0xAC, 0x72, 0x38},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_end;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0xF7, 0x99, 0xFF, 0x69, 0x19, 0x92, 0x43, 0xC4, 0xA3, 0xD8, 0xD8, 0x31, 0xF4, 0x75, 0xDC, 0x75},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_return;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0xE1, 0x5B, 0xA7, 0x02, 0x28, 0xC4, 0x42, 0x92, 0xAB, 0x8F, 0xFF, 0xA4, 0x34, 0x70, 0x91, 0x28},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_refusal;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0xC0, 0xBB, 0x14, 0xC7, 0x60, 0x22, 0x49, 0xDA, 0xAD, 0x08, 0x79, 0x2D, 0x67, 0xE8, 0xB4, 0x70},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_constrain;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0xFD, 0x3D, 0xDA, 0x11, 0xC8, 0xAB, 0x40, 0x33, 0x87, 0x6E, 0xD9, 0x3D, 0xEB, 0x17, 0x2C, 0x93},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_channel;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0x12, 0x20, 0xF7, 0x96, 0xE3, 0x88, 0x4D, 0xE5, 0xB4, 0x87, 0xFE, 0x2E, 0xB5, 0xFE, 0x03, 0xC0},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_call;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0x07, 0xD7, 0xDA, 0x55, 0xB3, 0x46, 0x4C, 0xFF, 0x8B, 0x37, 0x7C, 0xEF, 0xAC, 0xF8, 0xA3, 0xE8},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_untrusted;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0xF2, 0x65, 0xBD, 0x9C, 0xC7, 0x17, 0x46, 0x9E, 0xA4, 0x47, 0x92, 0x06, 0x87, 0xD6, 0x5D, 0x90},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         return gptoss_special_token_end_untrusted;
+     } else if (memcmp(
+         &(struct gptoss_uuid) {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+         uuid,
+         sizeof(struct gptoss_uuid)) == 0)
+     {
+         // Suppress warning
+         return gptoss_special_token_invalid;
+     } else {
+         GPTOSS_LOG_WARNING("unsupported special token " UUID_FORMAT, UUID_ARGS(*uuid));
+         return gptoss_special_token_invalid;
+     }
+ }
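
UUID_FORMAT and UUID_ARGS expand a 16-byte UUID into the sixteen %02X arguments printf expects, which is handy outside the decoder too. A self-contained example using the gptoss-model UUID from above; it assumes compilation with clang (which also accepts this header's C compound literals in C++) and includes gpt-oss/types.h and internal/log.h first, since uuid.h uses declarations from both:

#include <cstdio>

#include <gpt-oss/types.h>   // declares enum gptoss_special_token used by uuid.h
#include <internal/log.h>    // declares GPTOSS_LOG_WARNING used by uuid.h
#include <internal/uuid.h>

int main() {
    // The gptoss-model UUID recognized by gptoss_is_gptoss_model_uuid() above.
    const struct gptoss_uuid uuid = {
        {0xDF, 0x52, 0xDC, 0x86, 0x17, 0x89, 0x4E, 0xD0,
         0xA2, 0x95, 0x66, 0xF1, 0x05, 0x08, 0x14, 0x5B}};
    // Prints "DF52DC86-1789-4ED0-A295-66F10508145B".
    std::printf("model UUID: " UUID_FORMAT "\n", UUID_ARGS(uuid));
    std::printf("is gpt-oss model: %d\n", (int) gptoss_is_gptoss_model_uuid(&uuid));
    return 0;
}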
gptoss_kernels/source/log.c ADDED
@@ -0,0 +1,50 @@
+ #include <assert.h>  // assert
+ #include <stdarg.h>  // va_list, va_copy, va_end
+ #include <stdio.h>   // vsnprintf
+ #include <stdlib.h>  // malloc, free
+
+ #include <unistd.h>  // STDERR_FILENO, write
+
+
+ #define GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE 16384
+
+ void gptoss_format_log(const char* format, va_list args) {
+     char stack_buffer[GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE];
+     char* heap_buffer = NULL;
+
+     va_list args_copy;
+     va_copy(args_copy, args);
+
+     const int vsnprintf_result = vsnprintf(stack_buffer, GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE, format, args);
+     assert(vsnprintf_result >= 0);
+
+     // At least a partially formatted buffer is ready.
+     char* message_buffer = &stack_buffer[0];
+     size_t message_size = (size_t) vsnprintf_result;
+     // vsnprintf returns the length excluding the terminating NUL, so the message
+     // was truncated whenever it does not fit into the buffer together with the NUL.
+     if (message_size >= GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE) {
+         heap_buffer = malloc(message_size + 1);
+         if (heap_buffer == NULL) {
+             // Fall back to the truncated message in the on-stack buffer.
+             message_size = GPTOSS_ON_STACK_FORMAT_BUFFER_SIZE - 1;
+         } else {
+             // Use the full message in the in-heap buffer.
+             vsnprintf(heap_buffer, message_size + 1, format, args_copy);
+             message_buffer = heap_buffer;
+         }
+     }
+
+     ssize_t bytes_written;
+     do {
+         bytes_written = write(STDERR_FILENO, message_buffer, message_size);
+         if (bytes_written > 0) {
+             assert((size_t) bytes_written <= message_size);
+             message_buffer += bytes_written;
+             message_size -= bytes_written;
+         }
+     } while (bytes_written >= 0 && message_size != 0);
+
+     free(heap_buffer);
+     va_end(args_copy);
+ }
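
gptoss_format_log consumes an already-started va_list, so the printf-style entry points (the GPTOSS_LOG_* macros from internal/log.h used throughout metal-kernels.c below) are expected to wrap it. A hedged sketch of a wrapper of that shape; the actual names and macros in log.h may differ:

#include <cstdarg>

extern "C" void gptoss_format_log(const char* format, va_list args);

// Hypothetical convenience wrapper; the real entry points are the GPTOSS_LOG_*
// macros in internal/log.h, which this sketch does not reproduce.
extern "C" void example_log(const char* format, ...) {
    va_list args;
    va_start(args, format);
    gptoss_format_log(format, args);
    va_end(args);
}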
gptoss_kernels/source/matmul.metal ADDED
@@ -0,0 +1,422 @@
+ #include <metal_atomic>
+ #include <metal_compute>
+ #include <metal_integer>
+ #include <metal_math>
+ #include <metal_simdgroup>
+ #include <metal_stdlib>
+
+ #include <internal/kernel-args.h>
+
+ #pragma METAL fp math_mode(safe)
+ #pragma METAL fp contract(off)
+
+
+ // Each simdgroup reduces all channels of the input and computes a single channel of the output
+ // + Efficient synchronization
+ // + Sequential memory access within a warp
+ // Each threadgroup computes (simdgroups_per_threadgroup) consecutive output channels
+ // + Reuse input vector from threadgroup memory
+ // + Avoid synchronization across warps when doing reduction
+
+ kernel void gptoss_f32_bf16w_matmul(
+     constant gptoss_matmul_args& args [[ buffer(0) ]],
+     const device float4* input [[ buffer(1) ]],
+     const device bfloat4* weight [[ buffer(2) ]],
+     const device bfloat* bias [[ buffer(3) ]],
+     device float* output [[ buffer(4) ]],
+     const device gptoss_control* control [[ buffer(5) ]],
+     uint2 gid [[threadgroup_position_in_grid]],
+     uint simdgroup_tid [[thread_index_in_simdgroup]],
+     uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+     uint num_simdgroups [[simdgroups_per_threadgroup]])
+ {
+     const uint simdgroup_size = 32;
+     if (control->abort != 0) {
+         return;
+     }
+
+     const uint num_column_vecs = args.num_column_vecs;
+     const uint row = gid.x * num_simdgroups + simdgroup_idx;
+
+     input += gid.y * num_column_vecs + simdgroup_tid;
+     weight += num_column_vecs * row + simdgroup_tid;
+     bias += row;
+     output += gid.y * args.num_rows + row;
+
+     uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+
+     float4 sum4 = 0.0f;
+     do {
+         const bfloat4 w = *weight;
+         const float4 i = *input;
+         sum4 = metal::fma(static_cast<float4>(w), i, sum4);
+
+         weight += simdgroup_size;
+         input += simdgroup_size;
+     } while (--num_iter != 0);
+     const float2 sum2 = sum4.xy + sum4.zw;
+     float sum = sum2.x + sum2.y;
+     sum = metal::simd_sum(sum);
+     if (metal::simd_is_first()) {
+         sum += static_cast<float>(*bias);
+         if (args.add) {
+             *output += sum;
+         } else {
+             *output = sum;
+         }
+     }
+ }
+
+ kernel void gptoss_f32_bf16w_matmul_qkv(
+     constant gptoss_qkv_args& args [[ buffer(0) ]],
+     const device float4* input [[ buffer(1) ]],
+     const device bfloat4* weight [[ buffer(2) ]],
+     const device bfloat* bias [[ buffer(3) ]],
+     device float* q [[ buffer(4) ]],
+     device float* kv [[ buffer(5) ]],
+     const device gptoss_control* control [[ buffer(6) ]],
+     threadgroup void* scratch [[ threadgroup(0) ]],
+     uint2 gid [[threadgroup_position_in_grid]],
+     uint simdgroup_tid [[thread_index_in_simdgroup]],
+     uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+     uint num_simdgroups [[simdgroups_per_threadgroup]])
+ {
+     const uint simdgroup_size = 32;
+     const uint head_dim = 64;
+     const uint num_q_heads = 64;
+     const uint num_kv_heads = 8;
+     if (control->abort != 0) {
+         return;
+     }
+
+     const uint num_column_vecs = args.num_column_vecs;
+     const uint row = gid.x * num_simdgroups + simdgroup_idx;
+
+     input += gid.y * num_column_vecs + simdgroup_tid;
+     weight += num_column_vecs * row + simdgroup_tid;
+     bias += row;
+     q += gid.y * args.num_rows;
+
+     uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+
+     float4 sum4 = 0.0f;
+     do {
+         const bfloat4 w = *weight;
+         const float4 i = *input;
+         sum4 = metal::fma(static_cast<float4>(w), i, sum4);
+
+         weight += simdgroup_size;
+         input += simdgroup_size;
+     } while (--num_iter != 0);
+     const float2 sum2 = sum4.xy + sum4.zw;
+     float sum = sum2.x + sum2.y;
+     sum = metal::simd_sum(sum);
+     if (metal::simd_is_first()) {
+         sum += static_cast<float>(*bias);
+         static_cast<threadgroup float*>(scratch)[simdgroup_idx] = sum;
+     }
+     metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+     if (simdgroup_idx == 0) {
+         const uint num_half_simdgroups = num_simdgroups / 2;
+         if (simdgroup_tid < num_half_simdgroups) {
+             float2 vals = static_cast<const threadgroup float2*>(scratch)[simdgroup_tid];
+             const uint idx = gid.x * num_half_simdgroups + simdgroup_tid;
+             const uint head_idx = idx / (head_dim / 2);
+             const uint token_idx = args.token_offset + gid.y;
+             const uint dim_idx = idx % (head_dim / 2);
+             if (head_idx < num_q_heads + num_kv_heads) {
+                 const float dim_idx_val = static_cast<float>(dim_idx);
+                 const float inv_extrapolation_freq = metal::precise::exp(dim_idx_val * args.freq_scale);
+                 const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
+                 const float alpha = metal::saturate(metal::fma(dim_idx_val, args.yarn_scale, args.yarn_offset));
+                 const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);
+
+                 const float phi = static_cast<float>(token_idx) * inv_freq;
+                 const float yarn_multiplier = args.yarn_multiplier;
+                 float cosphi;
+                 const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
+                 cosphi *= yarn_multiplier;
+
+                 vals = (float2) {
+                     vals.x * cosphi - vals.y * sinphi,
+                     vals.x * sinphi + vals.y * cosphi,
+                 };
+             }
+             if (head_idx < num_q_heads) {
+                 reinterpret_cast<device float2*>(q)[idx] = vals;
+             } else if (head_idx < num_q_heads + num_kv_heads) {
+                 const uint h = head_idx - num_q_heads;
+                 reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim)[dim_idx] = vals;
+             } else {
+                 const uint h = head_idx - num_q_heads - num_kv_heads;
+                 reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim + head_dim)[dim_idx] = vals;
+             }
+         }
+     }
+ }
+
+ kernel void gptoss_f32_bf16w_unembedding(
+     constant gptoss_unembedding_args& args [[ buffer(0) ]],
+     const device float4* input [[ buffer(1) ]],
+     const device bfloat4* weight [[ buffer(2) ]],
+     device float* output [[ buffer(3) ]],
+     device metal::atomic_ulong* argmax [[ buffer(4) ]],
+     const device gptoss_control* control [[ buffer(5) ]],
+     uint2 gid [[threadgroup_position_in_grid]],
+     uint simdgroup_tid [[thread_index_in_simdgroup]],
+     uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+     uint num_simdgroups [[simdgroups_per_threadgroup]])
+ {
+     const uint simdgroup_size = 32;
+     threadgroup uint2 threadgroup_buffer[32];
+     if (control->abort != 0) {
+         return;
+     }
+
+     const uint num_column_vecs = args.num_column_vecs;
+     const uint row_start = gid.x * args.num_rows_per_threadgroup + simdgroup_idx;
+     const uint row_end = metal::min(gid.x * args.num_rows_per_threadgroup + args.num_rows_per_threadgroup, args.num_rows);
+     const uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+
+     input += gid.y * num_column_vecs + simdgroup_tid;
+     weight += num_column_vecs * row_start + simdgroup_tid;
+     output += gid.y * args.num_rows + row_start;
+
+     uint2 row_sum{0xFFFFFFFFul, 0xFFFFFFFFul};
+     for (uint row = row_start; row < row_end; row += num_simdgroups) {
+         uint n = num_iter;
+
+         float4 sum4 = 0.0f;
+         do {
+             const bfloat4 w = *weight;
+             const float4 i = *input;
+
+             sum4 = metal::fma(static_cast<float4>(w), i, sum4);
+
+             weight += simdgroup_size;
+             input += simdgroup_size;
+         } while (--n != 0);
+         input -= num_iter * simdgroup_size;
+         weight -= num_iter * simdgroup_size;
+
+         const float2 sum2 = sum4.xy + sum4.zw;
+         float sum = sum2.x + sum2.y;
+         sum = metal::simd_sum(sum);
+         uint sum_bits = as_type<uint>(sum);
+         if (static_cast<int>(sum_bits) >= 0) {
+             sum_bits ^= 0x7FFFFFFFu;
+         }
+         row_sum = as_type<uint2>(metal::min(as_type<ulong>(row_sum), as_type<ulong>(uint2{row, sum_bits})));
+         if (metal::simd_is_first()) {
+             *output = sum;
+         }
+
+         weight += num_column_vecs * num_simdgroups;
+         output += num_simdgroups;
+     }
+     if (metal::simd_is_first()) {
+         threadgroup_buffer[simdgroup_idx] = row_sum;
+     }
+     metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+     if (simdgroup_idx == 0) {
+         // Min-reduce threadgroup_buffer
+         if (simdgroup_tid < num_simdgroups) {
+             row_sum = threadgroup_buffer[simdgroup_tid];
+         }
+         const uint sum_bits = row_sum.y;
+         const uint sum_bits_min = metal::simd_min(sum_bits);
+         const uint row_min = metal::simd_min(sum_bits == sum_bits_min ? row_sum.x : 0xFFFFFFFFu);
+         if (metal::simd_is_first()) {
+             const uint2 threadgroup_output{row_min, sum_bits_min};
+             atomic_min_explicit(&argmax[gid.y], as_type<ulong>(threadgroup_output), metal::memory_order_relaxed);
+         }
+     }
+ }
+
+ // Current constraints for the dense matmul kernel:
+ // 1- All B* and Sg_* are a multiple of 8.
+ // 2- Bm is divisible by Sg_Bm and Bn is divisible by Sg_Bn.
+ // 3- M, N and K are all divisible by 8.
+ template <uint Bm, uint Bn, uint Bk, uint Sg_Bm, uint Sg_Bn, uint add = 0>
+ inline void _gptoss_f32_bf16w_dense_matmul_impl(
+     constant gptoss_dense_matmul_args& args, const device float* lhs,
+     const device bfloat* rhs, const device bfloat* __restrict__ bias,
+     device float* out, const device gptoss_control* control, threadgroup float* scratch, threadgroup float* bias_tile,
+     uint sg_id, uint sg_count_per_tg, uint3 gid, uint3 tg_id, uint3 local_tid,
+     uint3 threadgroup_size) {
+
+     if (control->abort != 0) {
+         return;
+     }
+
+     // The kernel assumes that M, K, and N are divisible by 8.
+     const uint M = args.m;
+     const uint K = args.k;
+     const uint N = args.n;
+     static_assert((Bm % 8u) == 0u, "Bm must be a multiple of 8");
+     static_assert((Bn % 8u) == 0u, "Bn must be a multiple of 8");
+     static_assert((Bk % 8u) == 0u, "Bk must be a multiple of 8");
+     static_assert((Sg_Bm % 8u) == 0u, "Sg_Bm must be a multiple of 8");
+     static_assert((Sg_Bn % 8u) == 0u, "Sg_Bn must be a multiple of 8");
+     static_assert((Bn % Sg_Bn) == 0u, "Bn must be a multiple of Sg_Bn");
+     static_assert((Bm % Sg_Bm) == 0u, "Bm must be a multiple of Sg_Bm");
+
+     // Get row and col tg.
+     const uint row_tg = tg_id.y;
+     const uint col_tg = tg_id.x;
+     // Get row and col local tid.
+     const uint row_tg_offset = row_tg * Bm;
+     const uint col_tg_offset = col_tg * Bn;
+
+     const uint sg_col_count = Bn / Sg_Bn;
+     const uint row_sg = sg_id / sg_col_count;
+     const uint col_sg = sg_id % sg_col_count;
+
+     const uint row_sg_offset = row_sg * Sg_Bm;
+     const uint col_sg_offset = col_sg * Sg_Bn;
+     constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);
+     // Create an array of simdgroup_float8x8 to hold temp results.
+     metal::simdgroup_float8x8 OutTiles[temp_result_size];
+     #pragma clang loop unroll(full)
+     for (uint i = 0; i < temp_result_size; i++) {
+         OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(
+             static_cast<float>(0.0));
+     }
+
+     for (uint k_offset = 0; k_offset < K; k_offset += Bk) {
+         #pragma clang loop unroll(full)
+         for (uint k = 0; k < Bk; k += 8) {
+             #pragma clang loop unroll(full)
+             for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
+                 // const uint m_subtile = row_sg_offset + m_subtile_;
+                 // const uint row_index_in_out_tile = (m_subtile - row_sg_offset) / 8;
+                 const uint row_index_in_out_tile = m_subtile_ / 8;
+                 metal::simdgroup_float8x8 LHStile;
+                 const uint k_id = k + k_offset;
+                 const uint row_offset = row_tg_offset + row_sg_offset + m_subtile_;
+                 metal::simdgroup_load(LHStile, lhs, K, ulong2(k_id, row_offset));
+                 metal::simdgroup_bfloat8x8 RHStile;
+                 #pragma clang loop unroll(full)
+                 for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
+                     const uint col_index_in_out_tile = n_subtile_ / 8;
+                     const uint current_index_out_tile =
+                         row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
+                     const uint col_offset = col_tg_offset + col_sg_offset + n_subtile_;
+                     simdgroup_load(RHStile, rhs, K, ulong2(k_id, col_offset), /*transpose=*/true);
+                     // If rhs was not transposed, use the following instead:
+                     // simdgroup_load(RHStile, rhs, N, ulong2(col_offset, k_id));
+                     simdgroup_multiply_accumulate(OutTiles[current_index_out_tile],
+                                                   LHStile, RHStile,
+                                                   OutTiles[current_index_out_tile]);
+                 }
+             }
+         }
+     }
+     // Epilogue.
+     #pragma clang loop unroll(full)
+     for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
+         const uint col_index_in_out_tile = n_subtile_ / 8;
+         const uint local_col_offset = col_sg_offset + n_subtile_;
+         #pragma clang loop unroll(full)
+         for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
+             const uint row_index_in_out_tile = m_subtile_ / 8;
+             const uint local_row_offset = row_sg_offset + m_subtile_;
+             const uint current_index_out_tile =
+                 row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
+             simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,
+                             ulong2(local_col_offset, local_row_offset));
+         }
+     }
+     // TODO(ibahmed): vectorize these loads and maybe unroll the loop.
+     const uint thread_count_per_tg =
+         threadgroup_size.x * threadgroup_size.y * threadgroup_size.z;
+     for (uint c_local = local_tid.x; c_local < Bn;
+          c_local += thread_count_per_tg) {
+         const uint c_global = col_tg_offset + c_local;
+         bias_tile[c_local] =
+             (c_global < N) ? static_cast<float>(bias[c_global]) : 0.0f;
+     }
+
+     metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+
+     // TODO(ibahmed): vectorize these stores and maybe unroll the loop.
+     for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) {
+         const uint r = idx / Bn;
+         const uint c = idx % Bn;
+
+         const uint out_row = row_tg_offset + r;
+         const uint out_col = col_tg_offset + c;
+
+         if (out_row < M && out_col < N) {
+             float acc = scratch[idx] + bias_tile[c];
+             if (add) {
+                 acc += out[out_row * N + out_col];
+             }
+             out[out_row * N + out_col] = acc;
+         }
+     }
+ }
+
+ kernel void gptoss_f32_bf16w_dense_matmul_qkv(
+     constant gptoss_dense_matmul_args& args [[buffer(0)]],
+     const device float* lhs [[buffer(1)]],
+     const device bfloat* rhs [[buffer(2)]],
+     const device bfloat* __restrict__ bias [[buffer(3)]],
+     device float* out [[buffer(4)]],
+     const device gptoss_control* control [[buffer(5)]],
+     uint sg_id [[simdgroup_index_in_threadgroup]],
+     uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
+     uint3 gid [[thread_position_in_grid]],
+     uint3 tg_id [[threadgroup_position_in_grid]],
+     uint3 local_tid [[thread_position_in_threadgroup]],
+     uint3 threadgroup_size [[threads_per_threadgroup]]) {
+     threadgroup float scratch[QKV_Bm * QKV_Bn];
+     threadgroup float bias_tile[QKV_Bn];
+     _gptoss_f32_bf16w_dense_matmul_impl<QKV_Bm, QKV_Bn, QKV_Bk, QKV_Sg_Bm,
+                                         QKV_Sg_Bn>(
+         args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
+         gid, tg_id, local_tid, threadgroup_size);
+ }
+
+ kernel void gptoss_f32_bf16w_dense_matmul_attn_output(
+     constant gptoss_dense_matmul_args& args [[buffer(0)]],
+     const device float* lhs [[buffer(1)]],
+     const device bfloat* rhs [[buffer(2)]],
+     const device bfloat* __restrict__ bias [[buffer(3)]],
+     device float* out [[buffer(4)]],
+     const device gptoss_control* control [[buffer(5)]],
+     uint sg_id [[simdgroup_index_in_threadgroup]],
+     uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
+     uint3 gid [[thread_position_in_grid]],
+     uint3 tg_id [[threadgroup_position_in_grid]],
+     uint3 local_tid [[thread_position_in_threadgroup]],
+     uint3 threadgroup_size [[threads_per_threadgroup]]) {
+     threadgroup float scratch[ATTN_OUTPUT_Bm * ATTN_OUTPUT_Bn];
+     threadgroup float bias_tile[ATTN_OUTPUT_Bn];
+     _gptoss_f32_bf16w_dense_matmul_impl<ATTN_OUTPUT_Bm, ATTN_OUTPUT_Bn,
+                                         ATTN_OUTPUT_Bk, ATTN_OUTPUT_Sg_Bm,
+                                         ATTN_OUTPUT_Sg_Bn, /*add=*/1>(
+         args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
+         gid, tg_id, local_tid, threadgroup_size);
+ }
+
+ kernel void gptoss_f32_bf16w_dense_matmul_mlp_gate(
+     constant gptoss_dense_matmul_args& args [[buffer(0)]],
+     const device float* lhs [[buffer(1)]],
+     const device bfloat* rhs [[buffer(2)]],
+     const device bfloat* __restrict__ bias [[buffer(3)]],
+     device float* out [[buffer(4)]],
+     const device gptoss_control* control [[buffer(5)]],
+     uint sg_id [[simdgroup_index_in_threadgroup]],
+     uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
+     uint3 gid [[thread_position_in_grid]],
+     uint3 tg_id [[threadgroup_position_in_grid]],
+     uint3 local_tid [[thread_position_in_threadgroup]],
+     uint3 threadgroup_size [[threads_per_threadgroup]]) {
+     threadgroup float scratch[MLP_GATE_Bm * MLP_GATE_Bn];
+     threadgroup float bias_tile[MLP_GATE_Bn];
+     _gptoss_f32_bf16w_dense_matmul_impl<MLP_GATE_Bm, MLP_GATE_Bn, MLP_GATE_Bk,
+                                         MLP_GATE_Sg_Bm, MLP_GATE_Sg_Bn>(
+         args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
+         gid, tg_id, local_tid, threadgroup_size);
+ }
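
The least obvious part of gptoss_f32_bf16w_unembedding above is the argmax reduction: it packs (row, transformed score bits) into a 64-bit word and takes atomic_min, relying on a transform that makes "larger float" coincide with "smaller unsigned integer". A host-side C++ demonstration of that transform (names are local to this sketch):

#include <cassert>
#include <cstdint>
#include <cstring>

// Mirror of the score transform in gptoss_f32_bf16w_unembedding: map a float's
// bit pattern to an unsigned key whose *minimum* identifies the *maximum* float.
static std::uint32_t score_to_min_key(float score) {
    std::uint32_t bits;
    std::memcpy(&bits, &score, sizeof(bits));  // as_type<uint>(sum) on the GPU
    if ((std::int32_t) bits >= 0) {            // non-negative float: flip the
        bits ^= 0x7FFFFFFFu;                   // magnitude bits, so larger floats
    }                                          // become smaller keys
    // Negative floats keep their pattern: the sign bit places them above every
    // transformed non-negative key, and "more negative" is already "larger".
    return bits;
}

int main() {
    // Keys must be strictly decreasing as scores increase.
    const float scores[] = {-3.0f, -0.5f, 0.0f, 0.25f, 7.0f};
    for (int i = 0; i + 1 < 5; i++) {
        assert(score_to_min_key(scores[i]) > score_to_min_key(scores[i + 1]));
    }
    // The kernel packs the key into the high 32 bits of a 64-bit word with the
    // row in the low 32 bits, so a single atomic_min yields both the best score
    // and, on ties, the smallest row index.
    return 0;
}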
gptoss_kernels/source/metal-kernels.c ADDED
@@ -0,0 +1,1518 @@
1
+ #include <inttypes.h>
2
+ #include <stddef.h>
3
+ #include <stdint.h>
4
+ #include <math.h>
5
+
6
+ #include <internal/kernel-args.h>
7
+ #include <internal/log.h>
8
+ #include <internal/math.h>
9
+ #include <internal/metal.h>
10
+ #include <internal/metal-kernels.h>
11
+
12
+
13
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
14
+ const struct gptoss_metal_command_buffer* command_buffer,
15
+ const struct gptoss_metal_function* u32_fill_random_fn,
16
+ size_t threadgroup_size,
17
+ size_t max_threadgroups,
18
+ const struct gptoss_metal_buffer* output_buffer,
19
+ size_t output_offset,
20
+ uint64_t num_elements,
21
+ uint64_t rng_seed,
22
+ uint64_t rng_offset)
23
+ {
24
+ if (command_buffer->object == NULL || u32_fill_random_fn->pipeline_state_object == NULL) {
25
+ return gptoss_status_invalid_state;
26
+ }
27
+
28
+ if (threadgroup_size == 0) {
29
+ threadgroup_size = u32_fill_random_fn->max_threadgroup_threads;
30
+ } else if (threadgroup_size > u32_fill_random_fn->max_threadgroup_threads) {
31
+ return gptoss_status_invalid_argument;
32
+ }
33
+
34
+ const size_t num_vecs = num_elements;
35
+ const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
36
+ const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
37
+ const struct gptoss_u32_fill_random_args args = {
38
+ .num_vecs = num_vecs,
39
+ .num_vecs_per_threadgroup = num_vecs_per_threadgroup,
40
+ .seed = rng_seed,
41
+ .offset = rng_offset,
42
+ };
43
+
44
+ return gptoss_metal_command_buffer_encode_launch_kernel(
45
+ command_buffer, u32_fill_random_fn,
46
+ threadgroup_size, 1, 1,
47
+ num_threadgroups, 1, 1,
48
+ sizeof(args), &args,
49
+ 1, &output_buffer, &output_offset,
50
+ /*threadgroup_buffer_size=*/0);
51
+ }
52
+
53
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
54
+ const struct gptoss_metal_command_buffer* command_buffer,
55
+ const struct gptoss_metal_function* f32_fill_random_fn,
56
+ size_t threadgroup_size,
57
+ size_t max_threadgroups,
58
+ const struct gptoss_metal_buffer* output_buffer,
59
+ size_t output_offset,
60
+ uint64_t num_elements,
61
+ uint64_t rng_seed,
62
+ uint64_t rng_offset,
63
+ float rng_min,
64
+ float rng_max)
65
+ {
66
+ if (command_buffer->object == NULL || f32_fill_random_fn->pipeline_state_object == NULL) {
67
+ return gptoss_status_invalid_state;
68
+ }
69
+
70
+ if (threadgroup_size == 0) {
71
+ threadgroup_size = f32_fill_random_fn->max_threadgroup_threads;
72
+ } else if (threadgroup_size > f32_fill_random_fn->max_threadgroup_threads) {
73
+ return gptoss_status_invalid_argument;
74
+ }
75
+
76
+ if (rng_min >= rng_max) {
77
+ return gptoss_status_invalid_argument;
78
+ }
79
+
80
+ const size_t num_vecs = num_elements;
81
+ const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
82
+ const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
83
+ const struct gptoss_f32_fill_random_args args = {
84
+ .num_vecs = num_vecs,
85
+ .num_vecs_per_threadgroup = num_vecs_per_threadgroup,
86
+ .seed = rng_seed,
87
+ .offset = rng_offset,
88
+ .scale = (rng_max - rng_min) * 0x1.0p-32f,
89
+ .bias = (rng_min + rng_max) * 0.5f,
90
+ };
91
+
92
+ return gptoss_metal_command_buffer_encode_launch_kernel(
93
+ command_buffer, f32_fill_random_fn,
94
+ threadgroup_size, 1, 1,
95
+ num_threadgroups, 1, 1,
96
+ sizeof(args), &args,
97
+ 1, &output_buffer, &output_offset,
98
+ /*threadgroup_buffer_size=*/0);
99
+ }
100
+
101
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
102
+ const struct gptoss_metal_command_buffer* command_buffer,
103
+ const struct gptoss_metal_function* bf16_fill_random_fn,
104
+ size_t threadgroup_size,
105
+ size_t max_threadgroups,
106
+ const struct gptoss_metal_buffer* output_buffer,
107
+ size_t output_offset,
108
+ uint64_t num_elements,
109
+ uint64_t rng_seed,
110
+ uint64_t rng_offset,
111
+ float rng_min,
112
+ float rng_max)
113
+ {
114
+ if (command_buffer->object == NULL || bf16_fill_random_fn->pipeline_state_object == NULL) {
115
+ return gptoss_status_invalid_state;
116
+ }
117
+
118
+ if (threadgroup_size == 0) {
119
+ threadgroup_size = bf16_fill_random_fn->max_threadgroup_threads;
120
+ } else if (threadgroup_size > bf16_fill_random_fn->max_threadgroup_threads) {
121
+ return gptoss_status_invalid_argument;
122
+ }
123
+
124
+ if (rng_min >= rng_max) {
125
+ return gptoss_status_invalid_argument;
126
+ }
127
+
128
+ const size_t num_vecs = num_elements;
129
+ const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
130
+ const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
131
+ const struct gptoss_f32_fill_random_args args = {
132
+ .num_vecs = num_vecs,
133
+ .num_vecs_per_threadgroup = num_vecs_per_threadgroup,
134
+ .seed = rng_seed,
135
+ .offset = rng_offset,
136
+ .scale = (rng_max - rng_min) * 0x1.0p-32f,
137
+ .bias = (rng_min + rng_max) * 0.5f,
138
+ };
139
+
140
+ return gptoss_metal_command_buffer_encode_launch_kernel(
141
+ command_buffer, bf16_fill_random_fn,
142
+ threadgroup_size, 1, 1,
143
+ num_threadgroups, 1, 1,
144
+ sizeof(args), &args,
145
+ 1, &output_buffer, &output_offset,
146
+ /*threadgroup_buffer_size=*/0);
147
+ }
148
+
149
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
150
+ const struct gptoss_metal_command_buffer* command_buffer,
151
+ const struct gptoss_metal_function* mf4_f32_convert_fn,
152
+ size_t threadgroup_size,
153
+ size_t max_threadgroups,
154
+ const struct gptoss_metal_buffer* block_buffer,
155
+ const struct gptoss_metal_buffer* scale_buffer,
156
+ const struct gptoss_metal_buffer* output_buffer,
157
+ uint64_t num_elements)
158
+ {
159
+ if (command_buffer->object == NULL || mf4_f32_convert_fn->pipeline_state_object == NULL) {
160
+ return gptoss_status_invalid_state;
161
+ }
162
+
163
+ if (num_elements % 32 != 0) {
164
+ return gptoss_status_invalid_argument;
165
+ }
166
+
167
+ if (threadgroup_size == 0) {
168
+ threadgroup_size = mf4_f32_convert_fn->max_threadgroup_threads;
169
+ } else if (threadgroup_size > mf4_f32_convert_fn->max_threadgroup_threads) {
170
+ return gptoss_status_invalid_argument;
171
+ }
172
+
173
+ const size_t num_vecs = num_elements / 32;
174
+ const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
175
+ const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
176
+ const struct gptoss_convert_args args = {
177
+ .num_vecs = num_vecs,
178
+ .num_vecs_per_threadgroup = num_vecs_per_threadgroup,
179
+ };
180
+
181
+ return gptoss_metal_command_buffer_encode_launch_kernel(
182
+ command_buffer, mf4_f32_convert_fn,
183
+ threadgroup_size, 1, 1,
184
+ num_threadgroups, 1, 1,
185
+ sizeof(args), &args,
186
+ 3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL,
187
+ /*threadgroup_buffer_size=*/0);
188
+ }
189
+
190
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
191
+ const struct gptoss_metal_command_buffer* command_buffer,
192
+ const struct gptoss_metal_function* bf16_f32_embeddings_fn,
193
+ size_t threadgroup_size,
194
+ const struct gptoss_metal_buffer* token_buffer,
195
+ size_t token_offset,
196
+ const struct gptoss_metal_buffer* weight_buffer,
197
+ size_t weight_offset,
198
+ const struct gptoss_metal_buffer* output_buffer,
199
+ size_t output_offset,
200
+ const struct gptoss_metal_buffer* control_buffer,
201
+ size_t control_offset,
202
+ uint32_t num_tokens,
203
+ uint32_t num_channels)
204
+ {
205
+ if (command_buffer->object == NULL || bf16_f32_embeddings_fn->pipeline_state_object == NULL) {
206
+ return gptoss_status_invalid_state;
207
+ }
208
+
209
+ if (num_channels % 4 != 0) {
210
+ return gptoss_status_invalid_argument;
211
+ }
212
+
213
+ if (threadgroup_size == 0) {
214
+ threadgroup_size = bf16_f32_embeddings_fn->max_threadgroup_threads;
215
+ } else if (threadgroup_size > bf16_f32_embeddings_fn->max_threadgroup_threads) {
216
+ return gptoss_status_invalid_argument;
217
+ }
218
+
219
+ const uint32_t num_vecs = num_channels / 4;
220
+ const struct gptoss_embeddings_args args = {
221
+ .num_vecs = num_vecs,
222
+ };
223
+
224
+ return gptoss_metal_command_buffer_encode_launch_kernel(
225
+ command_buffer, bf16_f32_embeddings_fn,
226
+ threadgroup_size, 1, 1,
227
+ num_tokens, 1, 1,
228
+ sizeof(args), &args,
229
+ 4,
230
+ (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer, control_buffer},
231
+ (const size_t[]) {token_offset, weight_offset, output_offset, control_offset},
232
+ /*threadgroup_buffer_size=*/0);
233
+ }
234
+
235
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
236
+ const struct gptoss_metal_command_buffer* command_buffer,
237
+ const struct gptoss_metal_function* f32_bf16w_rmsnorm_fn,
238
+ const struct gptoss_metal_buffer* input_buffer,
239
+ size_t input_offset,
240
+ const struct gptoss_metal_buffer* weight_buffer,
241
+ size_t weight_offset,
242
+ const struct gptoss_metal_buffer* output_buffer,
243
+ size_t output_offset,
244
+ const struct gptoss_metal_buffer* control_buffer,
245
+ size_t control_offset,
246
+ uint32_t num_tokens,
247
+ uint32_t num_channels,
248
+ float epsilon)
249
+ {
250
+ if (command_buffer->object == NULL || f32_bf16w_rmsnorm_fn->pipeline_state_object == NULL) {
251
+ return gptoss_status_invalid_state;
252
+ }
253
+
254
+ if (num_channels % 4 != 0) {
255
+ return gptoss_status_invalid_argument;
256
+ }
257
+
258
+ if (f32_bf16w_rmsnorm_fn->max_threadgroup_threads < 1024) {
259
+ return gptoss_status_unsupported_system;
260
+ }
261
+
262
+ if (f32_bf16w_rmsnorm_fn->simdgroup_threads != 32) {
263
+ return gptoss_status_unsupported_system;
264
+ }
265
+
266
+ const uint32_t num_vecs = num_channels / 4;
267
+ const struct gptoss_rmsnorm_args args = {
268
+ .num_vecs = num_vecs,
269
+ .num_channels = (float) num_channels,
270
+ .epsilon = epsilon,
271
+ };
272
+
273
+ return gptoss_metal_command_buffer_encode_launch_kernel(
274
+ command_buffer, f32_bf16w_rmsnorm_fn,
275
+ /*threadgroup_size=*/1024, 1, 1,
276
+ num_tokens, 1, 1,
277
+ sizeof(args), &args,
278
+ 4,
279
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, control_buffer},
280
+ (const size_t[]) {input_offset, weight_offset, output_offset, control_offset},
281
+ /*threadgroup_buffer_size=*/0);
282
+ }
283
+
284
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
285
+ const struct gptoss_metal_command_buffer* command_buffer,
286
+ const struct gptoss_metal_function* f32_bf16w_matmul_fn,
287
+ size_t threadgroup_size,
288
+ const struct gptoss_metal_buffer* input_buffer,
289
+ size_t input_offset,
290
+ const struct gptoss_metal_buffer* weight_buffer,
291
+ size_t weight_offset,
292
+ const struct gptoss_metal_buffer* bias_buffer,
293
+ size_t bias_offset,
294
+ const struct gptoss_metal_buffer* output_buffer,
295
+ size_t output_offset,
296
+ const struct gptoss_metal_buffer* control_buffer,
297
+ size_t control_offset,
298
+ uint32_t num_tokens,
299
+ uint32_t num_cols,
300
+ uint32_t num_rows)
301
+ {
302
+ if (command_buffer->object == NULL || f32_bf16w_matmul_fn->pipeline_state_object == NULL) {
303
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: invalid command buffer or pipeline state object");
304
+ return gptoss_status_invalid_state;
305
+ }
306
+
307
+ if (threadgroup_size == 0) {
308
+ threadgroup_size = f32_bf16w_matmul_fn->simdgroup_threads;
309
+ } else if (threadgroup_size > f32_bf16w_matmul_fn->max_threadgroup_threads) {
310
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
311
+ threadgroup_size, f32_bf16w_matmul_fn->max_threadgroup_threads);
312
+ return gptoss_status_invalid_argument;
313
+ }
314
+
315
+ if (num_cols % 4 != 0) {
316
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
317
+ num_cols);
318
+ return gptoss_status_invalid_argument;
319
+ }
320
+ const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_fn->simdgroup_threads;
321
+ if (num_rows % num_simdgroups != 0) {
322
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
323
+ num_rows, num_simdgroups);
324
+ return gptoss_status_invalid_argument;
325
+ }
326
+
327
+ const struct gptoss_matmul_args args = {
328
+ .num_column_vecs = num_cols / 4,
329
+ .num_rows = num_rows,
330
+ .add = 0,
331
+ };
332
+
333
+ return gptoss_metal_command_buffer_encode_launch_kernel(
334
+ command_buffer, f32_bf16w_matmul_fn,
335
+ threadgroup_size, 1, 1,
336
+ num_rows / num_simdgroups, num_tokens, 1,
337
+ sizeof(args), &args,
338
+ 5,
339
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
340
+ (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},
341
+ /*threadgroup_buffer_size=*/0);
342
+ }
343
+
344
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
345
+ const struct gptoss_metal_command_buffer* command_buffer,
346
+ const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
347
+ size_t threadgroup_size,
348
+ const struct gptoss_metal_buffer* input_buffer,
349
+ size_t input_offset,
350
+ const struct gptoss_metal_buffer* weight_buffer,
351
+ size_t weight_offset,
352
+ const struct gptoss_metal_buffer* bias_buffer,
353
+ size_t bias_offset,
354
+ const struct gptoss_metal_buffer* output_buffer,
355
+ size_t output_offset,
356
+ const struct gptoss_metal_buffer* kv_buffer,
357
+ size_t kv_offset,
358
+ const struct gptoss_metal_buffer* control_buffer,
359
+ size_t control_offset,
360
+ uint32_t num_tokens,
361
+ uint32_t num_cols,
362
+ uint32_t num_q_heads,
363
+ uint32_t num_kv_heads,
364
+ uint32_t attn_head_dim,
365
+ uint32_t token_offset,
366
+ uint32_t max_tokens,
367
+ float rope_base,
368
+ float interpolation_scale,
369
+ float yarn_offset,
370
+ float yarn_scale,
371
+ float yarn_multiplier)
372
+ {
373
+ if (command_buffer->object == NULL || f32_bf16w_matmul_qkv_fn->pipeline_state_object == NULL) {
374
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: invalid command buffer or pipeline state object");
375
+ return gptoss_status_invalid_state;
376
+ }
377
+
378
+ if (threadgroup_size == 0) {
379
+ threadgroup_size = f32_bf16w_matmul_qkv_fn->simdgroup_threads;
380
+ } else if (threadgroup_size > f32_bf16w_matmul_qkv_fn->max_threadgroup_threads) {
381
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
382
+ threadgroup_size, f32_bf16w_matmul_qkv_fn->max_threadgroup_threads);
383
+ return gptoss_status_invalid_argument;
384
+ }
385
+
386
+ if (num_cols % 4 != 0) {
387
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
388
+ num_cols);
389
+ return gptoss_status_invalid_argument;
390
+ }
391
+
392
+ if (num_q_heads != 64) {
393
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of Q heads (%" PRIu32 ") must be 64",
394
+ num_q_heads);
395
+ return gptoss_status_invalid_argument;
396
+ }
397
+ if (num_kv_heads != 8) {
398
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of KV heads (%" PRIu32 ") must be 8",
399
+ num_kv_heads);
400
+ return gptoss_status_invalid_argument;
401
+ }
402
+ if (attn_head_dim != 64) {
403
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: attention head dimension (%" PRIu32 ") must be 64",
404
+ attn_head_dim);
405
+ return gptoss_status_invalid_argument;
406
+ }
407
+
408
+ const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_qkv_fn->simdgroup_threads;
409
+ const uint32_t num_rows = (num_q_heads + 2 * num_kv_heads) * attn_head_dim;
410
+ if (num_rows % num_simdgroups != 0) {
411
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
412
+ num_rows, num_simdgroups);
413
+ return gptoss_status_invalid_argument;
414
+ }
415
+
416
+ const struct gptoss_qkv_args args = {
417
+ .num_column_vecs = num_cols / 4,
418
+ .num_rows = num_rows,
419
+ .token_offset = token_offset,
420
+ .freq_scale = -logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),
421
+ .interpolation_scale = interpolation_scale,
422
+ .yarn_offset = yarn_offset,
423
+ .yarn_scale = yarn_scale,
424
+ .yarn_multiplier = yarn_multiplier,
425
+ .max_tokens = max_tokens,
426
+ };
427
+
428
+ return gptoss_metal_command_buffer_encode_launch_kernel(
429
+ command_buffer, f32_bf16w_matmul_qkv_fn,
430
+ threadgroup_size, 1, 1,
431
+ num_rows / num_simdgroups, num_tokens, 1,
432
+ sizeof(args), &args,
433
+ 6,
434
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, kv_buffer, control_buffer},
435
+ (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, kv_offset, control_offset},
436
+ /*threadgroup_buffer_size=*/num_simdgroups * sizeof(float));
437
+ }
438
+
439
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
440
+ const struct gptoss_metal_command_buffer* command_buffer,
441
+ const struct gptoss_metal_function* f32_bf16w_matmul_fn,
442
+ size_t threadgroup_size,
443
+ const struct gptoss_metal_buffer* input_buffer,
444
+ size_t input_offset,
445
+ const struct gptoss_metal_buffer* weight_buffer,
446
+ size_t weight_offset,
447
+ const struct gptoss_metal_buffer* bias_buffer,
448
+ size_t bias_offset,
449
+ const struct gptoss_metal_buffer* output_buffer,
450
+ size_t output_offset,
451
+ const struct gptoss_metal_buffer* control_buffer,
452
+ size_t control_offset,
453
+ uint32_t num_tokens,
454
+ uint32_t num_cols,
455
+ uint32_t num_rows)
456
+ {
457
+ if (command_buffer->object == NULL || f32_bf16w_matmul_fn->pipeline_state_object == NULL) {
458
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: invalid command buffer or pipeline state object");
459
+ return gptoss_status_invalid_state;
460
+ }
461
+
462
+ if (threadgroup_size == 0) {
463
+ threadgroup_size = f32_bf16w_matmul_fn->simdgroup_threads;
464
+ } else if (threadgroup_size > f32_bf16w_matmul_fn->max_threadgroup_threads) {
465
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
466
+ threadgroup_size, f32_bf16w_matmul_fn->max_threadgroup_threads);
467
+ return gptoss_status_invalid_argument;
468
+ }
469
+
470
+ if (num_cols % 4 != 0) {
471
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
472
+ num_cols);
473
+ return gptoss_status_invalid_argument;
474
+ }
475
+ const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_fn->simdgroup_threads;
476
+ if (num_rows % num_simdgroups != 0) {
477
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
478
+ num_rows, num_simdgroups);
479
+ return gptoss_status_invalid_argument;
480
+ }
481
+
482
+ const struct gptoss_matmul_args args = {
483
+ .num_column_vecs = num_cols / 4,
484
+ .num_rows = num_rows,
485
+ .add = 1,
486
+ };
487
+
488
+ return gptoss_metal_command_buffer_encode_launch_kernel(
489
+ command_buffer, f32_bf16w_matmul_fn,
490
+ threadgroup_size, 1, 1,
491
+ num_rows / num_simdgroups, num_tokens, 1,
492
+ sizeof(args), &args,
493
+ 5,
494
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
495
+ (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},
496
+ /*threadgroup_buffer_size=*/0);
497
+ }
498
+
499
+ enum gptoss_status _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
500
+ const struct gptoss_metal_command_buffer* command_buffer,
501
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
502
+ const struct gptoss_metal_buffer* input_buffer,
503
+ size_t input_offset,
504
+ const struct gptoss_metal_buffer* weight_buffer,
505
+ size_t weight_offset,
506
+ const struct gptoss_metal_buffer* bias_buffer,
507
+ size_t bias_offset,
508
+ const struct gptoss_metal_buffer* output_buffer,
509
+ size_t output_offset,
510
+ const struct gptoss_metal_buffer* control_buffer,
511
+ size_t control_offset,
512
+ uint32_t num_tokens,
513
+ uint32_t num_cols,
514
+ uint32_t num_rows,
515
+ uint32_t Bm,
516
+ uint32_t Bn,
517
+ uint32_t Bk,
518
+ uint32_t Sg_Bm,
519
+ uint32_t Sg_Bn)
520
+ {
522
+ if (command_buffer->object == NULL || f32_bf16w_dense_matmul_fn->pipeline_state_object == NULL) {
523
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: invalid command buffer or pipeline state object");
524
+ return gptoss_status_invalid_state;
525
+ }
526
+
527
+ if (num_cols % 8 != 0) {
528
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 8",
529
+ num_cols);
530
+ return gptoss_status_invalid_argument;
531
+ }
532
+ if (num_rows % 8 != 0) {
533
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of rows (%" PRIu32 ") is not divisible by 8",
534
+ num_rows);
535
+ return gptoss_status_invalid_argument;
536
+ }
537
+ if (num_tokens % 8 != 0) {
538
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of tokens (%" PRIu32 ") is not divisible by 8",
539
+ num_tokens);
540
+ return gptoss_status_invalid_argument;
541
+ }
542
+
543
+ const struct gptoss_dense_matmul_args args = {
544
+ .m = num_tokens,
545
+ .n = num_rows,
546
+ .k = num_cols,
547
+ };
548
+ const size_t threads_per_simdgroup = f32_bf16w_dense_matmul_fn->simdgroup_threads;
549
+ const uint32_t m = args.m;
550
+ const uint32_t n = args.n;
551
+ const uint32_t k = args.k;
552
+ if (Bm % Sg_Bm != 0) {
553
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
554
+ Bm, Sg_Bm);
555
+ return gptoss_status_invalid_argument;
556
+ }
557
+ if (Bn % Sg_Bn != 0) {
558
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
559
+ Bn, Sg_Bn);
560
+ return gptoss_status_invalid_argument;
561
+ }
562
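+ // Each threadgroup computes one Bm x Bn output tile, with each simdgroup
+ // covering an Sg_Bm x Sg_Bn sub-tile, so the threadgroup needs
+ // (Bm / Sg_Bm) * (Bn / Sg_Bn) simdgroups' worth of threads; the grid below
+ // then tiles the m x n output as (n / Bn) x (m / Bm) threadgroups.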
+ const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
563
+ const size_t threadgroup_size_y = 1;
564
+ const size_t threadgroup_size_z = 1;
565
+ const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
566
+ if (total_threadgroup_size > f32_bf16w_dense_matmul_fn->max_threadgroup_threads) {
567
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
568
+ total_threadgroup_size, f32_bf16w_dense_matmul_fn->max_threadgroup_threads);
569
+ return gptoss_status_invalid_argument;
570
+ }
571
+ if (m % Bm != 0) {
572
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: m (%" PRIu32 ") is not divisible by Bm (%" PRIu32 ")",
573
+ m, Bm);
574
+ return gptoss_status_invalid_argument;
575
+ }
576
+ if (n % Bn != 0) {
577
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
578
+ n, Bn);
579
+ return gptoss_status_invalid_argument;
580
+ }
581
+ if (k % Bk != 0) {
582
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
583
+ k, Bk);
584
+ return gptoss_status_invalid_argument;
585
+ }
586
+ const size_t grid_x = n / Bn;
587
+ const size_t grid_y = m / Bm;
588
+ const size_t grid_z = 1;
589
+
590
+ return gptoss_metal_command_buffer_encode_launch_kernel(
591
+ command_buffer, f32_bf16w_dense_matmul_fn,
592
+ threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
593
+ grid_x, grid_y, grid_z,
594
+ sizeof(args), &args,
595
+ 5,
596
+ (const struct gptoss_metal_buffer *[]){input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
597
+ (const size_t[]){input_offset, weight_offset, bias_offset, output_offset, control_offset},
598
+ /*threadgroup_buffer_size=*/0);
600
+ }
601
+
602
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
603
+ const struct gptoss_metal_command_buffer* command_buffer,
604
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
605
+ const struct gptoss_metal_buffer* input_buffer,
606
+ size_t input_offset,
607
+ const struct gptoss_metal_buffer* weight_buffer,
608
+ size_t weight_offset,
609
+ const struct gptoss_metal_buffer* bias_buffer,
610
+ size_t bias_offset,
611
+ const struct gptoss_metal_buffer* output_buffer,
612
+ size_t output_offset,
613
+ const struct gptoss_metal_buffer* control_buffer,
614
+ size_t control_offset,
615
+ uint32_t num_tokens,
616
+ uint32_t num_cols,
617
+ uint32_t num_rows)
618
+ {
619
+ return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
620
+ command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
621
+ weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
622
+ output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, QKV_Bm, QKV_Bn, QKV_Bk,
623
+ QKV_Sg_Bm, QKV_Sg_Bn);
624
+ }
625
+
626
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
627
+ const struct gptoss_metal_command_buffer* command_buffer,
628
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
629
+ const struct gptoss_metal_buffer* input_buffer,
630
+ size_t input_offset,
631
+ const struct gptoss_metal_buffer* weight_buffer,
632
+ size_t weight_offset,
633
+ const struct gptoss_metal_buffer* bias_buffer,
634
+ size_t bias_offset,
635
+ const struct gptoss_metal_buffer* output_buffer,
636
+ size_t output_offset,
637
+ const struct gptoss_metal_buffer* control_buffer,
638
+ size_t control_offset,
639
+ uint32_t num_tokens,
640
+ uint32_t num_cols,
641
+ uint32_t num_rows)
642
+ {
643
+ return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
644
+ command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
645
+ weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
646
+ output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, ATTN_OUTPUT_Bm,
647
+ ATTN_OUTPUT_Bn, ATTN_OUTPUT_Bk, ATTN_OUTPUT_Sg_Bm, ATTN_OUTPUT_Sg_Bn);
648
+ }
649
+
650
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
651
+ const struct gptoss_metal_command_buffer* command_buffer,
652
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
653
+ const struct gptoss_metal_buffer* input_buffer,
654
+ size_t input_offset,
655
+ const struct gptoss_metal_buffer* weight_buffer,
656
+ size_t weight_offset,
657
+ const struct gptoss_metal_buffer* bias_buffer,
658
+ size_t bias_offset,
659
+ const struct gptoss_metal_buffer* output_buffer,
660
+ size_t output_offset,
661
+ const struct gptoss_metal_buffer* control_buffer,
662
+ size_t control_offset,
663
+ uint32_t num_tokens,
664
+ uint32_t num_cols,
665
+ uint32_t num_rows)
666
+ {
667
+ return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
668
+ command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
669
+ weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
670
+ output_offset, control_buffer, control_offset, num_tokens, num_cols,
671
+ num_rows, MLP_GATE_Bm, MLP_GATE_Bn, MLP_GATE_Bk, MLP_GATE_Sg_Bm,
672
+ MLP_GATE_Sg_Bn);
673
+ }
674
+
675
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
676
+ const struct gptoss_metal_command_buffer* command_buffer,
677
+ const struct gptoss_metal_function* f32_bf16w_unembedding_fn,
678
+ size_t threadgroup_size,
679
+ size_t max_threadgroups,
680
+ const struct gptoss_metal_buffer* input_buffer,
681
+ size_t input_offset,
682
+ const struct gptoss_metal_buffer* weight_buffer,
683
+ size_t weight_offset,
684
+ const struct gptoss_metal_buffer* output_buffer,
685
+ size_t output_offset,
686
+ const struct gptoss_metal_buffer* argmax_buffer,
687
+ size_t argmax_offset,
688
+ const struct gptoss_metal_buffer* control_buffer,
689
+ size_t control_offset,
690
+ uint32_t num_tokens,
691
+ uint32_t num_cols,
692
+ uint32_t num_rows)
693
+ {
694
+ if (command_buffer->object == NULL || f32_bf16w_unembedding_fn->pipeline_state_object == NULL) {
695
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: invalid command buffer or pipeline state object");
696
+ return gptoss_status_invalid_state;
697
+ }
698
+
699
+ if (threadgroup_size == 0) {
700
+ threadgroup_size = f32_bf16w_unembedding_fn->simdgroup_threads;
701
+ } else if (threadgroup_size > f32_bf16w_unembedding_fn->max_threadgroup_threads) {
702
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
703
+ threadgroup_size, f32_bf16w_unembedding_fn->max_threadgroup_threads);
704
+ return gptoss_status_invalid_argument;
705
+ }
706
+
707
+ if (num_cols % 4 != 0) {
708
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
709
+ num_cols);
710
+ return gptoss_status_invalid_argument;
711
+ }
712
+
713
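+ // Split the num_rows output rows across at most max_threadgroups
+ // threadgroups, rounding each threadgroup's share of rows up to a multiple
+ // of its simdgroup count.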
+ const size_t num_simdgroups = threadgroup_size / f32_bf16w_unembedding_fn->simdgroup_threads;
714
+ const size_t num_rows_per_threadgroup = math_ceil_div(num_rows, max_threadgroups * num_simdgroups) * num_simdgroups;
715
+ const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_rows, num_rows_per_threadgroup));
716
+ const struct gptoss_unembedding_args args = {
717
+ .num_column_vecs = num_cols / 4,
718
+ .num_rows_per_threadgroup = num_rows_per_threadgroup,
719
+ .num_rows = num_rows,
720
+ };
721
+
722
+ return gptoss_metal_command_buffer_encode_launch_kernel(
723
+ command_buffer, f32_bf16w_unembedding_fn,
724
+ threadgroup_size, 1, 1,
725
+ num_threadgroups, num_tokens, 1,
726
+ sizeof(args), &args,
727
+ 5,
728
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer, control_buffer},
729
+ (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset, control_offset},
730
+ /*threadgroup_buffer_size=*/0);
731
+ }
732
+
733
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
734
+ const struct gptoss_metal_command_buffer* command_buffer,
735
+ const struct gptoss_metal_function* f32_mf4w_moe_matmul_swiglu_fn,
736
+ size_t threadgroup_size,
737
+ const struct gptoss_metal_buffer* input_buffer,
738
+ size_t input_offset,
739
+ const struct gptoss_metal_buffer* expert_buffer,
740
+ size_t expert_offset,
741
+ const struct gptoss_metal_buffer* weight_block_buffer,
742
+ size_t weight_block_offset,
743
+ const struct gptoss_metal_buffer* weight_scale_buffer,
744
+ size_t weight_scale_offset,
745
+ const struct gptoss_metal_buffer* bias_buffer,
746
+ size_t bias_offset,
747
+ const struct gptoss_metal_buffer* output_buffer,
748
+ size_t output_offset,
749
+ const struct gptoss_metal_buffer* control_buffer,
750
+ size_t control_offset,
751
+ float swiglu_limit,
752
+ uint32_t expert_stride,
753
+ uint32_t num_tokens,
754
+ uint32_t num_active_experts,
755
+ uint32_t num_cols,
756
+ uint32_t num_rows)
757
+ {
758
+ if (command_buffer->object == NULL || f32_mf4w_moe_matmul_swiglu_fn->pipeline_state_object == NULL) {
759
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: invalid command buffer or pipeline state object");
760
+ return gptoss_status_invalid_state;
761
+ }
762
+
763
+ if (threadgroup_size == 0) {
764
+ threadgroup_size = 2 * f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads;
765
+ } else if (threadgroup_size > f32_mf4w_moe_matmul_swiglu_fn->max_threadgroup_threads) {
766
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
767
+ threadgroup_size, f32_mf4w_moe_matmul_swiglu_fn->max_threadgroup_threads);
768
+ return gptoss_status_invalid_argument;
769
+ } else if (threadgroup_size % (2 * f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads)) {
770
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: threadgroup size (%zu) is not divisible by simdgroup size (%zu) multiplied by 2X",
771
+ threadgroup_size, f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads);
772
+ return gptoss_status_invalid_argument;
773
+ }
774
+
775
+ if (num_cols % 32 != 0) {
776
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
777
+ num_cols);
778
+ return gptoss_status_invalid_argument;
779
+ }
780
+ const size_t num_simdgroups = threadgroup_size / f32_mf4w_moe_matmul_swiglu_fn->simdgroup_threads;
781
+ if ((2 * num_rows) % num_simdgroups != 0) {
782
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch: "
783
+ "the number of rows (%" PRIu32 ") multiplied by 2X is not divisible by the number of simdgroups (%zu)",
784
+ num_rows, num_simdgroups);
785
+ return gptoss_status_invalid_argument;
786
+ }
787
+
788
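+ // The 2x factors above reflect the kernel consuming weight rows in pairs,
+ // presumably one gate row and one linear row per SwiGLU output row.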
+ const struct gptoss_moe_matmul_swiglu_args args = {
789
+ .num_column_vecs = num_cols / 32,
790
+ .num_rows = num_rows,
791
+ .num_active_experts = num_active_experts,
792
+ .weight_expert_stride = expert_stride,
793
+ .output_expert_stride = num_rows * num_tokens,
794
+ .swiglu_min = -swiglu_limit,
795
+ .swiglu_max = swiglu_limit,
796
+ };
797
+
798
+ return gptoss_metal_command_buffer_encode_launch_kernel(
799
+ command_buffer, f32_mf4w_moe_matmul_swiglu_fn,
800
+ threadgroup_size, 1, 1,
801
+ (2 * num_rows) / num_simdgroups, num_tokens, num_active_experts,
802
+ sizeof(args), &args,
803
+ 7,
804
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},
805
+ (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},
806
+ /*threadgroup_buffer_size=*/0);
807
+ }
808
+
809
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
810
+ const struct gptoss_metal_command_buffer* command_buffer,
811
+ const struct gptoss_metal_function* f32_mf4w_moe_matmul_fn,
812
+ size_t threadgroup_size,
813
+ const struct gptoss_metal_buffer* input_buffer,
814
+ size_t input_offset,
815
+ const struct gptoss_metal_buffer* expert_buffer,
816
+ size_t expert_offset,
817
+ const struct gptoss_metal_buffer* weight_block_buffer,
818
+ size_t weight_block_offset,
819
+ const struct gptoss_metal_buffer* weight_scale_buffer,
820
+ size_t weight_scale_offset,
821
+ const struct gptoss_metal_buffer* bias_buffer,
822
+ size_t bias_offset,
823
+ const struct gptoss_metal_buffer* output_buffer,
824
+ size_t output_offset,
825
+ const struct gptoss_metal_buffer* control_buffer,
826
+ size_t control_offset,
827
+ uint32_t expert_stride,
828
+ uint32_t num_tokens,
829
+ uint32_t num_active_experts,
830
+ uint32_t num_cols,
831
+ uint32_t num_rows)
832
+ {
833
+ if (command_buffer->object == NULL || f32_mf4w_moe_matmul_fn->pipeline_state_object == NULL) {
834
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: invalid command buffer or pipeline state object");
835
+ return gptoss_status_invalid_state;
836
+ }
837
+
838
+ if (threadgroup_size == 0) {
839
+ threadgroup_size = f32_mf4w_moe_matmul_fn->simdgroup_threads;
840
+ } else if (threadgroup_size > f32_mf4w_moe_matmul_fn->max_threadgroup_threads) {
841
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
842
+ threadgroup_size, f32_mf4w_moe_matmul_fn->max_threadgroup_threads);
843
+ return gptoss_status_invalid_argument;
844
+ } else if (threadgroup_size % f32_mf4w_moe_matmul_fn->simdgroup_threads) {
845
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: threadgroup size (%zu) is not divisible by simdgroup size (%zu)",
846
+ threadgroup_size, f32_mf4w_moe_matmul_fn->simdgroup_threads);
847
+ return gptoss_status_invalid_argument;
848
+ }
849
+
850
+ if (num_cols % 32 != 0) {
851
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
852
+ num_cols);
853
+ return gptoss_status_invalid_argument;
854
+ }
855
+ const size_t num_simdgroups = threadgroup_size / f32_mf4w_moe_matmul_fn->simdgroup_threads;
856
+ if (num_rows % num_simdgroups != 0) {
857
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch: "
858
+ "the number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
859
+ num_rows, num_simdgroups);
860
+ return gptoss_status_invalid_argument;
861
+ }
862
+
863
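+ // Per-expert strides: the input stride is counted in 32-element column
+ // vectors, the output stride in elements; each active expert owns a
+ // num_tokens-sized slab of both.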
+ const struct gptoss_moe_matmul_args args = {
864
+ .num_column_vecs = num_cols / 32,
865
+ .num_rows = num_rows,
866
+ .num_active_experts = num_active_experts,
867
+ .input_expert_stride = num_tokens * (num_cols / 32),
868
+ .weight_expert_stride = expert_stride,
869
+ .output_expert_stride = num_rows * num_tokens,
870
+ };
871
+
872
+ return gptoss_metal_command_buffer_encode_launch_kernel(
873
+ command_buffer, f32_mf4w_moe_matmul_fn,
874
+ threadgroup_size, 1, 1,
875
+ num_rows / num_simdgroups, num_tokens, num_active_experts,
876
+ sizeof(args), &args,
877
+ 7,
878
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},
879
+ (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},
880
+ /*threadgroup_buffer_size=*/0);
881
+ }
882
+
883
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
884
+ const struct gptoss_metal_command_buffer* command_buffer,
885
+ const struct gptoss_metal_function* f32_rope_fn,
886
+ size_t threadgroup_size,
887
+ const struct gptoss_metal_buffer* activations_buffer,
888
+ size_t activations_offset,
889
+ const struct gptoss_metal_buffer* control_buffer,
890
+ size_t control_offset,
891
+ float rope_base,
892
+ float interpolation_scale,
893
+ float yarn_offset,
894
+ float yarn_scale,
895
+ float yarn_multiplier,
896
+ uint32_t num_tokens,
897
+ uint32_t num_q_heads,
898
+ uint32_t num_kv_heads,
899
+ uint32_t attn_head_dim,
900
+ uint32_t token_offset)
901
+ {
902
+ if (command_buffer->object == NULL || f32_rope_fn->pipeline_state_object == NULL) {
903
+ return gptoss_status_invalid_state;
904
+ }
905
+
906
+ if (threadgroup_size == 0) {
907
+ threadgroup_size = f32_rope_fn->max_threadgroup_threads;
908
+ } else if (threadgroup_size > f32_rope_fn->max_threadgroup_threads) {
909
+ return gptoss_status_invalid_argument;
910
+ }
911
+
912
+ const size_t num_simdgroups = threadgroup_size / f32_rope_fn->simdgroup_threads;
913
+ const uint32_t num_qk_heads = num_q_heads + num_kv_heads;
914
+ if (num_qk_heads % num_simdgroups != 0) {
915
+ return gptoss_status_invalid_argument;
916
+ }
917
+
918
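+ // freq_scale = -ln(rope_base) / (head_dim / 2): assuming the kernel
+ // computes exp(freq_scale * i) for dimension pair i, this reproduces the
+ // standard RoPE inverse frequency rope_base^(-i / (head_dim / 2)).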
+ const struct gptoss_rope_args args = {
919
+ .token_stride = (num_q_heads + 2 * num_kv_heads) * (attn_head_dim / 2),
920
+ .token_offset = token_offset,
921
+ .freq_scale = -logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),
922
+ .interpolation_scale = interpolation_scale,
923
+ .yarn_offset = yarn_offset,
924
+ .yarn_scale = yarn_scale,
925
+ .yarn_multiplier = yarn_multiplier,
926
+ };
927
+
928
+ return gptoss_metal_command_buffer_encode_launch_kernel(
929
+ command_buffer, f32_rope_fn,
930
+ threadgroup_size, 1, 1,
931
+ num_qk_heads / num_simdgroups, num_tokens, 1,
932
+ sizeof(args), &args,
933
+ 2,
934
+ (const struct gptoss_metal_buffer *[]) {activations_buffer, control_buffer},
935
+ (const size_t[]) {activations_offset, control_offset},
936
+ /*threadgroup_buffer_size=*/0);
937
+ }
938
+
939
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
940
+ const struct gptoss_metal_command_buffer* command_buffer,
941
+ const struct gptoss_metal_function* expert_routing_metadata_fn,
942
+ const struct gptoss_metal_buffer* expert_predictions_buffer,
943
+ size_t expert_predictions_offset,
944
+ const struct gptoss_metal_buffer* expert_offsets_buffer,
945
+ size_t expert_offsets_offset,
946
+ const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
947
+ size_t intra_expert_offsets_offset,
948
+ uint32_t num_tokens,
949
+ uint32_t num_experts)
950
+ {
951
+ if (command_buffer->object == NULL || expert_routing_metadata_fn->pipeline_state_object == NULL) {
952
+ return gptoss_status_invalid_state;
953
+ }
954
+
955
+ const struct gptoss_expert_routing_metadata_args args = {
956
+ .tokens = num_tokens,
957
+ .num_experts = num_experts,
958
+ };
959
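+ // A single 256-thread threadgroup scans all token-to-expert assignments to
+ // build the per-expert and intra-expert offsets consumed by the scatter and
+ // gather kernels below.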
+ const uint32_t threadgroup_size = 256;
960
+ return gptoss_metal_command_buffer_encode_launch_kernel(
961
+ command_buffer, expert_routing_metadata_fn,
962
+ threadgroup_size, 1, 1,
963
+ /*num_threadgroups_x=*/1, /*num_threadgroups_y=*/1, /*num_threadgroups_z=*/1,
964
+ sizeof(args), &args,
965
+ 3,
966
+ (const struct gptoss_metal_buffer *[]) {expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer},
967
+ (const size_t[]) {expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset},
968
+ /*threadgroup_buffer_size=*/0);
969
+ }
970
+
971
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_scatter(
972
+ const struct gptoss_metal_command_buffer* command_buffer,
973
+ const struct gptoss_metal_function* f32_scatter_fn,
974
+ const struct gptoss_metal_buffer* input_buffer,
975
+ size_t input_offset,
976
+ const struct gptoss_metal_buffer* expert_predictions_buffer,
977
+ size_t expert_predictions_offset,
978
+ const struct gptoss_metal_buffer* expert_offsets_buffer,
979
+ size_t expert_offsets_offset,
980
+ const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
981
+ size_t intra_expert_offsets_offset,
982
+ const struct gptoss_metal_buffer* output_buffer,
983
+ size_t output_offset,
984
+ uint32_t num_channels,
985
+ uint32_t num_tokens,
986
+ uint32_t num_active_experts)
987
+ {
988
+ if (command_buffer->object == NULL || f32_scatter_fn->pipeline_state_object == NULL) {
989
+ return gptoss_status_invalid_state;
990
+ }
991
+
992
+ if (num_channels % 4 != 0) {
993
+ return gptoss_status_invalid_argument;
994
+ }
995
+
996
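+ // Channels are processed as 4-element vectors: up to 64 threads per
+ // threadgroup along x, with the remaining vectors tiled across the grid and
+ // one grid row per token.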
+ const size_t num_vecs = num_channels / 4;
997
+ const size_t tgx = math_min(num_vecs, 64);
998
+ const size_t tgy = 1;
999
+ const size_t tgz = 1;
1000
+ const size_t grid_x = math_ceil_div(num_vecs, tgx);
1001
+ const size_t grid_y = num_tokens;
1002
+ const size_t grid_z = 1;
1003
+ const size_t total_threadgroup_size = tgx * tgy * tgz;
1004
+ if (total_threadgroup_size > f32_scatter_fn->max_threadgroup_threads) {
1005
+ return gptoss_status_invalid_argument;
1006
+ }
1007
+ const struct gptoss_scatter_args args = {
1008
+ .tokens = num_tokens,
1009
+ .active_experts_per_token = num_active_experts,
1010
+ .token_stride = num_channels,
1011
+ };
1012
+
1013
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1014
+ command_buffer, f32_scatter_fn,
1015
+ tgx, tgy, tgz,
1016
+ grid_x, grid_y, grid_z,
1017
+ sizeof(args), &args,
1018
+ 5,
1019
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer, output_buffer},
1020
+ (const size_t[]) {input_offset, expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset, output_offset},
1021
+ /*threadgroup_buffer_size=*/0);
1022
+ }
1023
+
1024
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_gather_and_accumulate_e4(
1025
+ const struct gptoss_metal_command_buffer* command_buffer,
1026
+ const struct gptoss_metal_function* f32_gather_and_accumulate_e4_fn,
1027
+ const struct gptoss_metal_buffer* input_buffer,
1028
+ size_t input_offset,
1029
+ const struct gptoss_metal_buffer* expert_predictions_buffer,
1030
+ size_t expert_predictions_offset,
1031
+ const struct gptoss_metal_buffer* expert_offsets_buffer,
1032
+ size_t expert_offsets_offset,
1033
+ const struct gptoss_metal_buffer* intra_expert_offsets_buffer,
1034
+ size_t intra_expert_offsets_offset,
1035
+ const struct gptoss_metal_buffer* output_buffer,
1036
+ size_t output_offset,
1037
+ uint32_t num_channels,
1038
+ uint32_t num_tokens,
1039
+ uint32_t num_active_experts)
1040
+ {
1041
+ if (command_buffer->object == NULL || f32_gather_and_accumulate_e4_fn->pipeline_state_object == NULL) {
1042
+ return gptoss_status_invalid_state;
1043
+ }
1044
+
1045
+ if (num_channels % 4 != 0) {
1046
+ return gptoss_status_invalid_argument;
1047
+ }
1048
+
1049
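+ // Same launch geometry as the scatter kernel; the _e4 suffix presumably
+ // denotes a specialization for 4 active experts per token.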
+ const size_t num_vecs = num_channels / 4;
1050
+ const size_t tgx = math_min(num_vecs, 64);
1051
+ const size_t tgy = 1;
1052
+ const size_t tgz = 1;
1053
+ const size_t grid_x = math_ceil_div(num_vecs, tgx);
1054
+ const size_t grid_y = num_tokens;
1055
+ const size_t grid_z = 1;
1056
+ const size_t total_threadgroup_size = tgx * tgy * tgz;
1057
+ if (total_threadgroup_size > f32_gather_and_accumulate_e4_fn->max_threadgroup_threads) {
1058
+ return gptoss_status_invalid_argument;
1059
+ }
1060
+ const struct gptoss_gather_args args = {
1061
+ .tokens = num_tokens,
1062
+ .active_experts_per_token = num_active_experts,
1063
+ .token_stride = num_channels,
1064
+ };
1065
+
1066
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1067
+ command_buffer, f32_gather_and_accumulate_e4_fn,
1068
+ tgx, tgy, tgz,
1069
+ grid_x, grid_y, grid_z,
1070
+ sizeof(args), &args,
1071
+ 5,
1072
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_predictions_buffer, expert_offsets_buffer, intra_expert_offsets_buffer, output_buffer},
1073
+ (const size_t[]) {input_offset, expert_predictions_offset, expert_offsets_offset, intra_expert_offsets_offset, output_offset},
1074
+ /*threadgroup_buffer_size=*/0);
1075
+ }
1076
+
1077
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul_swiglu(
1078
+ const struct gptoss_metal_command_buffer* command_buffer,
1079
+ const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_swiglu_fn,
1080
+ const struct gptoss_metal_buffer* expert_offsets_buffer,
1081
+ size_t expert_offsets_offset,
1082
+ const struct gptoss_metal_buffer* input_buffer,
1083
+ size_t input_offset,
1084
+ const struct gptoss_metal_buffer* weight_block_buffer,
1085
+ size_t weight_block_offset,
1086
+ const struct gptoss_metal_buffer* weight_scale_buffer,
1087
+ size_t weight_scale_offset,
1088
+ const struct gptoss_metal_buffer* bias_buffer,
1089
+ size_t bias_offset,
1090
+ const struct gptoss_metal_buffer* output_buffer,
1091
+ size_t output_offset,
1092
+ float swiglu_limit,
1093
+ uint32_t expert_stride_bytes,
1094
+ uint32_t num_tokens,
1095
+ uint32_t num_experts,
1096
+ uint32_t num_cols,
1097
+ uint32_t num_rows)
1098
+ {
1099
+ if (command_buffer->object == NULL || f32_mf4w_moe_dense_matmul_swiglu_fn->pipeline_state_object == NULL) {
1100
+ return gptoss_status_invalid_state;
1101
+ }
1102
+
1103
+ if (num_cols % 32 != 0) {
1104
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
1105
+ num_cols);
1106
+ return gptoss_status_invalid_argument;
1107
+ }
1108
+
1109
+ const struct gptoss_moe_dense_matmul_swiglu_args args = {
1110
+ .n = num_rows,
1111
+ .k = num_cols,
1112
+ .weight_blocks_expert_stride_bytes = expert_stride_bytes,
1113
+ .weight_scales_expert_stride_bytes = expert_stride_bytes,
1114
+ .bias_expert_stride_bytes = expert_stride_bytes,
1115
+ .swiglu_min = -swiglu_limit,
1116
+ .swiglu_max = swiglu_limit,
1117
+ };
1118
+ const size_t threads_per_simdgroup = f32_mf4w_moe_dense_matmul_swiglu_fn->simdgroup_threads;
1119
+ const uint32_t m = num_tokens;
1120
+ const uint32_t n = args.n;
1121
+ const uint32_t k = args.k;
1122
+ const uint32_t Bm = MOE_DENSE_MATMUL_SWIGLU_Bm;
1123
+ const uint32_t Bn = MOE_DENSE_MATMUL_SWIGLU_Bn;
1124
+ const uint32_t Bk = MOE_DENSE_MATMUL_SWIGLU_Bk;
1125
+ const uint32_t Sg_Bm = MOE_DENSE_MATMUL_SWIGLU_Sg_Bm;
1126
+ const uint32_t Sg_Bn = MOE_DENSE_MATMUL_SWIGLU_Sg_Bn;
1127
+ if (Bm % Sg_Bm != 0) {
1128
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
1129
+ Bm, Sg_Bm);
1130
+ return gptoss_status_invalid_argument;
1131
+ }
1132
+ if (Bn % Sg_Bn != 0) {
1133
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
1134
+ Bn, Sg_Bn);
1135
+ return gptoss_status_invalid_argument;
1136
+ }
1137
+
1138
+ const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
1139
+ const size_t threadgroup_size_y = 1;
1140
+ const size_t threadgroup_size_z = 1;
1141
+ const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
1142
+ if (total_threadgroup_size > f32_mf4w_moe_dense_matmul_swiglu_fn->max_threadgroup_threads) {
1143
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
1144
+ total_threadgroup_size, f32_mf4w_moe_dense_matmul_swiglu_fn->max_threadgroup_threads);
1145
+ return gptoss_status_invalid_argument;
1146
+ }
1147
+ if (n % Bn != 0) {
1148
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
1149
+ n, Bn);
1150
+ return gptoss_status_invalid_argument;
1151
+ }
1152
+ if (k % Bk != 0) {
1153
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul_swiglu kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
1154
+ k, Bk);
1155
+ return gptoss_status_invalid_argument;
1156
+ }
1157
+ const size_t grid_x = n / Bn;
1158
+ const size_t grid_y = math_ceil_div(m, Bm);
1159
+ const size_t grid_z = num_experts;
1160
+
1161
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1162
+ command_buffer, f32_mf4w_moe_dense_matmul_swiglu_fn,
1163
+ threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
1164
+ grid_x, grid_y, grid_z,
1165
+ sizeof(args), &args,
1166
+ 6,
1167
+ (const struct gptoss_metal_buffer *[]) {expert_offsets_buffer, input_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
1168
+ (const size_t[]) {expert_offsets_offset, input_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
1169
+ /*threadgroup_buffer_size=*/0);
1171
+ }
1172
+
1173
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_dense_matmul(
1174
+ const struct gptoss_metal_command_buffer* command_buffer,
1175
+ const struct gptoss_metal_function* f32_mf4w_moe_dense_matmul_fn,
1176
+ const struct gptoss_metal_buffer* expert_offsets_buffer,
1177
+ size_t expert_offsets_offset,
1178
+ const struct gptoss_metal_buffer* input_buffer,
1179
+ size_t input_offset,
1180
+ const struct gptoss_metal_buffer* weight_block_buffer,
1181
+ size_t weight_block_offset,
1182
+ const struct gptoss_metal_buffer* weight_scale_buffer,
1183
+ size_t weight_scale_offset,
1184
+ const struct gptoss_metal_buffer* bias_buffer,
1185
+ size_t bias_offset,
1186
+ const struct gptoss_metal_buffer* output_buffer,
1187
+ size_t output_offset,
1188
+ uint32_t expert_stride_bytes,
1189
+ uint32_t num_tokens,
1190
+ uint32_t num_experts,
1191
+ uint32_t num_cols,
1192
+ uint32_t num_rows)
1193
+ {
1194
+ if (command_buffer->object == NULL || f32_mf4w_moe_dense_matmul_fn->pipeline_state_object == NULL) {
1195
+ return gptoss_status_invalid_state;
1196
+ }
1197
+
1198
+ if (num_cols % 32 != 0) {
1199
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 32",
1200
+ num_cols);
1201
+ return gptoss_status_invalid_argument;
1202
+ }
1203
+ const struct gptoss_moe_dense_matmul_args args = {
1204
+ .k = num_cols,
1205
+ .n = num_rows,
1206
+ .weight_blocks_expert_stride_bytes = expert_stride_bytes,
1207
+ .weight_scales_expert_stride_bytes = expert_stride_bytes,
1208
+ .bias_expert_stride_bytes = expert_stride_bytes,
1209
+ };
1210
+
1211
+ const size_t threads_per_simdgroup = f32_mf4w_moe_dense_matmul_fn->simdgroup_threads;
1212
+ const uint32_t m = num_tokens;
1213
+ const uint32_t n = args.n;
1214
+ const uint32_t k = args.k;
1215
+ const uint32_t Bm = MOE_DENSE_MATMUL_Bm;
1216
+ const uint32_t Bn = MOE_DENSE_MATMUL_Bn;
1217
+ const uint32_t Bk = MOE_DENSE_MATMUL_Bk;
1218
+ const uint32_t Sg_Bm = MOE_DENSE_MATMUL_Sg_Bm;
1219
+ const uint32_t Sg_Bn = MOE_DENSE_MATMUL_Sg_Bn;
1220
+ if (Bm % Sg_Bm != 0) {
1221
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
1222
+ Bm, Sg_Bm);
1223
+ return gptoss_status_invalid_argument;
1224
+ }
1225
+ if (Bn % Sg_Bn != 0) {
1226
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
1227
+ Bn, Sg_Bn);
1228
+ return gptoss_status_invalid_argument;
1229
+ }
1230
+
1231
+ const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
1232
+ const size_t threadgroup_size_y = 1;
1233
+ const size_t threadgroup_size_z = 1;
1234
+ const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
1235
+ if (total_threadgroup_size > f32_mf4w_moe_dense_matmul_fn->max_threadgroup_threads) {
1236
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
1237
+ total_threadgroup_size, f32_mf4w_moe_dense_matmul_fn->max_threadgroup_threads);
1238
+ return gptoss_status_invalid_argument;
1239
+ }
1240
+ if (n % Bn != 0) {
1241
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
1242
+ n, Bn);
1243
+ return gptoss_status_invalid_argument;
1244
+ }
1245
+ if (k % Bk != 0) {
1246
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_dense_matmul kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
1247
+ k, Bk);
1248
+ return gptoss_status_invalid_argument;
1249
+ }
1250
+
1251
+ const size_t grid_y = math_ceil_div(m, Bm);
1252
+ const size_t grid_x = n / Bn;
1253
+ const size_t grid_z = num_experts;
1254
+
1255
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1256
+ command_buffer, f32_mf4w_moe_dense_matmul_fn,
1257
+ threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
1258
+ grid_x, grid_y, grid_z,
1259
+ sizeof(args), &args,
1260
+ 6,
1261
+ (const struct gptoss_metal_buffer *[]) {expert_offsets_buffer, input_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
1262
+ (const size_t[]) {expert_offsets_offset, input_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
1263
+ /*threadgroup_buffer_size=*/0);
1264
+ }
1265
+
1266
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
1267
+ const struct gptoss_metal_command_buffer* command_buffer,
1268
+ const struct gptoss_metal_function* f32_accumulate_fn,
1269
+ size_t threadgroup_size,
1270
+ size_t max_threadgroups,
1271
+ const struct gptoss_metal_buffer* input_buffer,
1272
+ size_t input_offset,
1273
+ const struct gptoss_metal_buffer* expert_buffer,
1274
+ size_t expert_offset,
1275
+ const struct gptoss_metal_buffer* output_buffer,
1276
+ size_t output_offset,
1277
+ const struct gptoss_metal_buffer* control_buffer,
1278
+ size_t control_offset,
1279
+ uint32_t num_channels,
1280
+ uint32_t num_tokens,
1281
+ uint32_t num_experts)
1282
+ {
1283
+ if (command_buffer->object == NULL || f32_accumulate_fn->pipeline_state_object == NULL) {
1284
+ return gptoss_status_invalid_state;
1285
+ }
1286
+
1287
+ if (num_channels % 4 != 0) {
1288
+ return gptoss_status_invalid_argument;
1289
+ }
1290
+
1291
+ if (threadgroup_size == 0) {
1292
+ threadgroup_size = f32_accumulate_fn->max_threadgroup_threads;
1293
+ } else if (threadgroup_size > f32_accumulate_fn->max_threadgroup_threads) {
1294
+ return gptoss_status_invalid_argument;
1295
+ }
1296
+
1297
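+ // Accumulate expert outputs in 4-element vectors, splitting each token's
+ // vectors across at most max_threadgroups threadgroups.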
+ const size_t num_vecs = num_channels / 4;
1298
+ const size_t num_vecs_per_expert = num_vecs * num_tokens;
1299
+ const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
1300
+ const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
1301
+ const struct gptoss_accumulate_args args = {
1302
+ .num_vecs_per_expert = num_vecs_per_expert,
1303
+ .num_vecs_per_threadgroup = num_vecs_per_threadgroup,
1304
+ .num_vecs = num_vecs,
1305
+ };
1306
+
1307
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1308
+ command_buffer, f32_accumulate_fn,
1309
+ threadgroup_size, 1, 1,
1310
+ num_threadgroups, num_tokens, 1,
1311
+ sizeof(args), &args,
1312
+ 4,
1313
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer, control_buffer},
1314
+ (const size_t[]) {input_offset, expert_offset, output_offset, control_offset},
1315
+ /*threadgroup_buffer_size=*/0);
1316
+ }
1317
+
1318
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
1319
+ const struct gptoss_metal_command_buffer* command_buffer,
1320
+ const struct gptoss_metal_function* f32_topk_fn,
1321
+ const struct gptoss_metal_buffer* input_buffer,
1322
+ size_t input_offset,
1323
+ const struct gptoss_metal_buffer* output_buffer,
1324
+ size_t output_offset,
1325
+ const struct gptoss_metal_buffer* control_buffer,
1326
+ size_t control_offset,
1327
+ uint32_t num_tokens,
1328
+ uint32_t num_experts,
1329
+ uint32_t num_active_experts)
1330
+ {
1331
+ if (command_buffer->object == NULL || f32_topk_fn->pipeline_state_object == NULL) {
1332
+ return gptoss_status_invalid_state;
1333
+ }
1334
+
1335
+ if (num_experts != 32 && num_experts != 128) {
1336
+ return gptoss_status_invalid_argument;
1337
+ }
1338
+
1339
+ if (num_active_experts != 4) {
1340
+ return gptoss_status_invalid_argument;
1341
+ }
1342
+
1343
+ const struct gptoss_topk_args args = { 0 };
1344
+
1345
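+ // One 32-thread threadgroup (a single simdgroup) per token selects its
+ // top-k experts.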
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1346
+ command_buffer, f32_topk_fn,
1347
+ /*threadgroup_size=*/32, 1, 1,
1348
+ num_tokens, 1, 1,
1349
+ sizeof(args), &args,
1350
+ 3,
1351
+ (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer, control_buffer},
1352
+ (const size_t[]) {input_offset, output_offset, control_offset},
1353
+ /*threadgroup_buffer_size=*/0);
1354
+ }
1355
+
1356
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
1357
+ const struct gptoss_metal_command_buffer* command_buffer,
1358
+ const struct gptoss_metal_function* f32_sdpa_fn,
1359
+ const struct gptoss_metal_buffer* q_buffer,
1360
+ size_t q_offset,
1361
+ const struct gptoss_metal_buffer* kv_buffer,
1362
+ size_t kv_offset,
1363
+ const struct gptoss_metal_buffer* s_buffer,
1364
+ size_t s_offset,
1365
+ const struct gptoss_metal_buffer* output_buffer,
1366
+ size_t output_offset,
1367
+ const struct gptoss_metal_buffer* control_buffer,
1368
+ size_t control_offset,
1369
+ uint32_t window,
1370
+ uint32_t kv_stride,
1371
+ uint32_t num_q_tokens,
1372
+ uint32_t num_kv_tokens,
1373
+ uint32_t num_q_heads,
1374
+ uint32_t num_kv_heads,
1375
+ uint32_t head_dim)
1376
+ {
1377
+ if (command_buffer->object == NULL || f32_sdpa_fn->pipeline_state_object == NULL) {
1378
+ return gptoss_status_invalid_state;
1379
+ }
1380
+
1381
+ if (num_q_heads != num_kv_heads * 8) {
1382
+ GPTOSS_LOG_ERROR("number of Q heads (%" PRIu32 ") must be 8 times the number of KV heads (%" PRIu32 ")",
1383
+ num_q_heads, num_kv_heads);
1384
+ return gptoss_status_invalid_argument;
1385
+ }
1386
+
1387
+ if (head_dim != 64) {
1388
+ GPTOSS_LOG_ERROR("attention head dimension (%" PRIu32 ") must be 64", head_dim);
1389
+ return gptoss_status_invalid_argument;
1390
+ }
1391
+
1392
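+ // Allocate roughly one simdgroup per visible context token, capped at the
+ // pipeline maximum; half the threadgroup (rounded down to whole simdgroups)
+ // sizes the shared-memory scratch buffer passed to the launch below.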
+ const size_t max_context_tokens = math_min(num_q_tokens + num_kv_tokens + 1, window);
1393
+ const size_t threadgroup_size = math_min(f32_sdpa_fn->max_threadgroup_threads,
1394
+ max_context_tokens * f32_sdpa_fn->simdgroup_threads);
1395
+ const size_t half_threadgroup_size = math_round_down_po2(threadgroup_size / 2, f32_sdpa_fn->simdgroup_threads);
1396
+
1397
+ const struct gptoss_sdpa_args args = {
1398
+ .qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads),
1399
+ .num_kv_tokens = num_kv_tokens,
1400
+ .kv_stride = kv_stride,
1401
+ .window = window,
1402
+ };
1403
+
1404
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1405
+ command_buffer, f32_sdpa_fn,
1406
+ threadgroup_size, 1, 1,
1407
+ num_q_tokens, num_kv_heads, 1,
1408
+ sizeof(args), &args,
1409
+ 5,
1410
+ (const struct gptoss_metal_buffer *[]) {q_buffer, kv_buffer, s_buffer, output_buffer, control_buffer},
1411
+ (const size_t[]) {q_offset, kv_offset, s_offset, output_offset, control_offset},
1412
+ /*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float));
1413
+ }
1414
+
1415
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
1416
+ const struct gptoss_metal_command_buffer* command_buffer,
1417
+ const struct gptoss_metal_function* f32_softmax_fn,
1418
+ size_t threadgroup_size,
1419
+ size_t max_threadgroups,
1420
+ const struct gptoss_metal_buffer* score_buffer,
1421
+ size_t score_offset,
1422
+ const struct gptoss_metal_buffer* argmax_buffer,
1423
+ size_t argmax_offset,
1424
+ const struct gptoss_metal_buffer* prob_buffer,
1425
+ size_t prob_offset,
1426
+ const struct gptoss_metal_buffer* sum_buffer,
1427
+ size_t sum_offset,
1428
+ const struct gptoss_metal_buffer* control_buffer,
1429
+ size_t control_offset,
1430
+ uint32_t num_channels,
1431
+ uint32_t num_tokens,
1432
+ float temperature,
1433
+ uint32_t* num_threadgroups_out,
1434
+ uint32_t* num_channels_per_threadgroup_out)
1435
+ {
1436
+ *num_threadgroups_out = 0;
1437
+ *num_channels_per_threadgroup_out = 0;
1438
+ if (command_buffer->object == NULL || f32_softmax_fn->pipeline_state_object == NULL) {
1439
+ return gptoss_status_invalid_state;
1440
+ }
1441
+
1442
+ const size_t num_vecs = num_channels;
1443
+ const size_t num_vecs_per_threadgroup = math_ceil_div(num_vecs, max_threadgroups * threadgroup_size) * threadgroup_size;
1444
+ const size_t num_threadgroups = math_min(max_threadgroups, math_ceil_div(num_vecs, num_vecs_per_threadgroup));
1445
+ const struct gptoss_softmax_args args = {
1446
+ .num_vecs = num_vecs,
1447
+ .num_vecs_per_threadgroup = num_vecs_per_threadgroup,
1448
+ .max_threadgroups = max_threadgroups,
1449
+ .temperature = temperature,
1450
+ };
1451
+
1452
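+ // Report the actual partitioning to the caller, which presumably feeds it
+ // to the follow-up sampling kernel as num_blocks / num_channels_per_block.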
+ *num_threadgroups_out = num_threadgroups;
1453
+ *num_channels_per_threadgroup_out = num_vecs_per_threadgroup;
1454
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1455
+ command_buffer, f32_softmax_fn,
1456
+ threadgroup_size, 1, 1,
1457
+ num_threadgroups, num_tokens, 1,
1458
+ sizeof(args), &args,
1459
+ 5,
1460
+ (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer, control_buffer},
1461
+ (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset, control_offset},
1462
+ /*threadgroup_buffer_size=*/0);
1463
+ }
1464
+
1465
+ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
1466
+ const struct gptoss_metal_command_buffer* command_buffer,
1467
+ const struct gptoss_metal_function* f32_sample_fn,
1468
+ size_t min_threadgroup_size,
1469
+ const struct gptoss_metal_buffer* prob_buffer,
1470
+ size_t prob_offset,
1471
+ const struct gptoss_metal_buffer* sum_buffer,
1472
+ size_t sum_offset,
1473
+ const struct gptoss_metal_buffer* token_buffer,
1474
+ size_t token_offset,
1475
+ const struct gptoss_metal_buffer* control_buffer,
1476
+ size_t control_offset,
1477
+ uint64_t rng_seed,
1478
+ uint32_t rng_offset,
1479
+ uint32_t num_blocks,
1480
+ uint32_t num_channels,
1481
+ uint32_t num_channels_per_block)
1482
+ {
1483
+ if (command_buffer->object == NULL || f32_sample_fn->pipeline_state_object == NULL) {
1484
+ return gptoss_status_invalid_state;
1485
+ }
1486
+
1487
+ if (min_threadgroup_size > f32_sample_fn->max_threadgroup_threads) {
1488
+ return gptoss_status_invalid_argument;
1489
+ }
1490
+
1491
+ if (min_threadgroup_size % f32_sample_fn->simdgroup_threads != 0) {
1492
+ return gptoss_status_invalid_argument;
1493
+ }
1494
+
1495
+ if (num_blocks > f32_sample_fn->max_threadgroup_threads) {
1496
+ return gptoss_status_invalid_argument;
1497
+ }
1498
+
1499
+ const struct gptoss_sample_args args = {
1500
+ .rng_seed = rng_seed,
1501
+ .rng_offset = rng_offset,
1502
+ .num_blocks = num_blocks,
1503
+ .num_dims = num_channels,
1504
+ .num_dims_per_block = num_channels_per_block,
1505
+ };
1506
+
1507
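+ // Use at least min_threadgroup_size threads and at least one thread per
+ // block, rounded up to a multiple of the simdgroup width, in a single
+ // threadgroup.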
+ const size_t threadgroup_size = math_max(min_threadgroup_size,
1508
+ math_round_up_po2(num_blocks, f32_sample_fn->simdgroup_threads));
1509
+ return gptoss_metal_command_buffer_encode_launch_kernel(
1510
+ command_buffer, f32_sample_fn,
1511
+ threadgroup_size, 1, 1,
1512
+ 1, 1, 1,
1513
+ sizeof(args), &args,
1514
+ 4,
1515
+ (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, token_buffer, control_buffer},
1516
+ (const size_t[]) {prob_offset, sum_offset, token_offset, control_offset},
1517
+ /*threadgroup_buffer_size=*/0);
1518
+ }
gptoss_kernels/source/metal.m ADDED
@@ -0,0 +1,482 @@
1
+ #import <Foundation/Foundation.h>
2
+ #import <Metal/Metal.h>
3
+
4
+ #include <dispatch/dispatch.h>
5
+ #include <mach-o/getsect.h>
+ #include <string.h>
+
+ // Required for the IORegistry queries in gptoss_metal_device_get_core_count.
+ #include <IOKit/IOKitLib.h>
6
+
7
+ #include <gpt-oss/types.h>
8
+
9
+ #include <internal/log.h>
10
+ #include <internal/metal.h>
11
+
12
+
13
+ static size_t gptoss_metal_device_get_core_count(id<MTLDevice> device) {
14
+ if (!device) {
15
+ return 0;
16
+ }
17
+
18
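+ // Match this MTLDevice to its IOAccelerator node in the IORegistry via
+ // registryID, then read the node's "gpu-core-count" property.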
+ const uint64_t target_registry_id = [device registryID];
19
+
20
+ io_iterator_t it = IO_OBJECT_NULL;
21
+ const kern_return_t kr = IOServiceGetMatchingServices(
22
+ kIOMainPortDefault,
23
+ IOServiceMatching("IOAccelerator"),
24
+ &it
25
+ );
26
+ if (kr != KERN_SUCCESS) {
27
+ GPTOSS_LOG_ERROR("failed to find IOAccelerator objects: error %d", kr);
28
+ return 0;
29
+ }
30
+
31
+ size_t result = 0;
32
+ for (io_object_t obj = IOIteratorNext(it); obj != IO_OBJECT_NULL; obj = IOIteratorNext(it)) {
33
+ uint64_t registry_id = 0;
34
+ if (IORegistryEntryGetRegistryEntryID(obj, &registry_id) == KERN_SUCCESS &&
35
+ registry_id == target_registry_id)
36
+ {
37
+ // Read "gpu-core-count" from this accelerator node
38
+ const CFTypeRef value = IORegistryEntryCreateCFProperty(
39
+ obj, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
40
+ if (value != NULL) {
41
+ if (CFGetTypeID(value) == CFNumberGetTypeID()) {
42
+ int32_t n = -1;
43
+ if (CFNumberGetValue((CFNumberRef) value, kCFNumberSInt32Type, &n) && n > 0) {
44
+ result = (size_t) n;
45
+ }
46
+ }
47
+ CFRelease(value);
48
+ }
49
+ IOObjectRelease(obj);
50
+ break;
51
+ }
52
+ IOObjectRelease(obj);
53
+ }
54
+
55
+ IOObjectRelease(it);
56
+ return result;
57
+ }
58
+
59
+ enum gptoss_status gptoss_metal_device_create_system_default(
60
+ struct gptoss_metal_device* device_out)
61
+ {
62
+ id<MTLDevice> device_obj = MTLCreateSystemDefaultDevice();
63
+ if (device_obj == nil) {
64
+ GPTOSS_LOG_ERROR("failed to create Metal device");
65
+ return gptoss_status_unsupported_system;
66
+ }
67
+
68
+ device_out->object = (void*) device_obj;
69
+ device_out->num_cores = gptoss_metal_device_get_core_count(device_obj);
70
+ device_out->max_buffer_size = (size_t) [device_obj maxBufferLength];
71
+ device_out->max_threadgroup_memory = (size_t) [device_obj maxThreadgroupMemoryLength];
72
+ const MTLSize max_threadgroup_threads = [device_obj maxThreadsPerThreadgroup];
73
+ device_out->max_threadgroup_threads_x = (size_t) max_threadgroup_threads.width;
74
+ device_out->max_threadgroup_threads_y = (size_t) max_threadgroup_threads.height;
75
+ device_out->max_threadgroup_threads_z = (size_t) max_threadgroup_threads.depth;
76
+ return gptoss_status_success;
77
+ }
78
+
79
+ enum gptoss_status gptoss_metal_device_release(
80
+ struct gptoss_metal_device* device)
81
+ {
82
+ if (device->object != NULL) {
83
+ id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
84
+ [device_obj release];
85
+ }
86
+ memset(device, 0, sizeof(struct gptoss_metal_device));
87
+ return gptoss_status_success;
88
+ }
89
+
90
+ extern const struct mach_header_64 __dso_handle;
91
+
92
+ enum gptoss_status gptoss_metal_library_create_default(
93
+ const struct gptoss_metal_device* device,
94
+ struct gptoss_metal_library* library_out)
95
+ {
96
+ enum gptoss_status status = gptoss_status_success;
97
+ id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
98
+ id<MTLLibrary> library_obj = nil;
99
+ NSAutoreleasePool* autorelease_pool = nil;
100
+ dispatch_data_t library_blob = NULL;
101
+
102
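+ // Prefer a metallib embedded in this binary's "__METAL,__shaders" section;
+ // fall back to the bundle's default library below if it is absent.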
+ unsigned long library_size = 0;
103
+ uint8_t* library_data = getsectiondata(&__dso_handle, "__METAL", "__shaders", &library_size);
104
+ if (library_data != NULL) {
105
+ library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);
106
+
107
+ autorelease_pool = [[NSAutoreleasePool alloc] init];
108
+ NSError* error_obj = nil;
109
+ library_obj = [device_obj newLibraryWithData:library_blob error:&error_obj];
110
+ if (library_obj == nil) {
111
+ GPTOSS_LOG_ERROR("failed to create Metal library: %s", [[error_obj localizedDescription] UTF8String]);
112
+ status = gptoss_status_unsupported_system;
113
+ goto cleanup;
114
+ }
115
+ } else {
116
+ // Fall-back to loading from the bundle
117
+ library_obj = [device_obj newDefaultLibrary];
118
+ if (library_obj == nil) {
119
+ GPTOSS_LOG_ERROR("failed to create Metal default library");
120
+ status = gptoss_status_unsupported_system;
121
+ goto cleanup;
122
+ }
123
+ }
124
+
125
+ *library_out = (struct gptoss_metal_library) {
126
+ .object = (void*) library_obj,
127
+ };
128
+
129
+ cleanup:
130
+ if (library_blob != NULL) {
131
+ dispatch_release(library_blob);
132
+ }
133
+ if (autorelease_pool != nil) {
134
+ [autorelease_pool drain];
135
+ }
136
+ return status;
137
+ }
138
+
139
+ enum gptoss_status gptoss_metal_library_release(
140
+ struct gptoss_metal_library* library)
141
+ {
142
+ if (library->object != NULL) {
143
+ id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;
144
+ [library_obj release];
145
+ }
146
+ memset(library, 0, sizeof(struct gptoss_metal_library));
147
+ return gptoss_status_success;
148
+ }
149
+
150
+ enum gptoss_status gptoss_metal_function_create(
151
+ const struct gptoss_metal_library* library,
152
+ const char* name,
153
+ struct gptoss_metal_function* function_out)
154
+ {
155
+ __block NSString* error_string_obj = nil;
156
+ id<MTLFunction> function_obj = nil;
157
+ MTLComputePipelineDescriptor* pipeline_descriptor_obj = nil;
158
+ __block id<MTLComputePipelineState> pipeline_state_obj = nil;
159
+ dispatch_semaphore_t pipeline_build_semaphore = NULL;
160
+ enum gptoss_status status = gptoss_status_success;
161
+
162
+ NSAutoreleasePool* autorelease_pool = [[NSAutoreleasePool alloc] init];
163
+ id<MTLLibrary> library_obj = (id<MTLLibrary>) library->object;
164
+ NSString* name_obj = [NSString stringWithUTF8String:name];
165
+ function_obj = [library_obj newFunctionWithName:name_obj];
166
+ if (function_obj == nil) {
167
+ GPTOSS_LOG_ERROR("failed to create Metal function %s", name);
168
+ status = gptoss_status_unsupported_system;
169
+ goto cleanup;
170
+ }
171
+ id<MTLDevice> device_obj = [library_obj device];
172
+ pipeline_descriptor_obj = [[MTLComputePipelineDescriptor alloc] init];
173
+ [pipeline_descriptor_obj setComputeFunction:function_obj];
174
+ [pipeline_descriptor_obj setThreadGroupSizeIsMultipleOfThreadExecutionWidth:YES];
175
+
176
+ pipeline_build_semaphore = dispatch_semaphore_create(/*value=*/0);
177
+ [device_obj newComputePipelineStateWithDescriptor:pipeline_descriptor_obj
178
+ options:MTLPipelineOptionNone
179
+ completionHandler:^(id<MTLComputePipelineState> _Nullable new_state,
180
+ MTLComputePipelineReflection* _Nullable reflection,
181
+ NSError* _Nullable error_obj) {
182
+ if (new_state != nil) {
183
+ pipeline_state_obj = [new_state retain];
184
+ }
185
+ if (error_obj != nil) {
186
+ error_string_obj = [[error_obj localizedDescription] copy];
187
+ }
188
+ dispatch_semaphore_signal(pipeline_build_semaphore);
189
+ }];
190
+ dispatch_semaphore_wait(pipeline_build_semaphore, DISPATCH_TIME_FOREVER);
191
+
192
+ if (pipeline_state_obj == nil) {
193
+ const char* error_string = "unknown error";
194
+ if (error_string_obj != nil) {
195
+ error_string = [error_string_obj UTF8String];
196
+ }
197
+ GPTOSS_LOG_ERROR("failed to create Metal compute pipeline state for function %s: %s",
198
+ name, error_string);
199
+ status = gptoss_status_unsupported_system;
200
+ goto cleanup;
201
+ }
202
+
203
+ // Commit
204
+ function_out->function_object = function_obj;
205
+ function_out->pipeline_state_object = pipeline_state_obj;
206
+ function_out->max_threadgroup_threads = (size_t) [pipeline_state_obj maxTotalThreadsPerThreadgroup];
207
+ function_out->simdgroup_threads = (size_t) [pipeline_state_obj threadExecutionWidth];
208
+ function_out->static_threadgroup_memory = (size_t) [pipeline_state_obj staticThreadgroupMemoryLength];

    function_obj = nil;
    pipeline_state_obj = nil;

cleanup:
    if (function_obj != nil) {
        [function_obj release];
    }
    if (pipeline_descriptor_obj != nil) {
        [pipeline_descriptor_obj release];
    }
    if (error_string_obj != nil) {
        [error_string_obj release];
    }
    if (pipeline_build_semaphore != NULL) {
        dispatch_release(pipeline_build_semaphore);
    }
    if (autorelease_pool != nil) {
        [autorelease_pool drain];
    }
    return status;
}

enum gptoss_status gptoss_metal_function_release(
    struct gptoss_metal_function* function)
{
    if (function->pipeline_state_object != NULL) {
        id<MTLComputePipelineState> pipeline_state_obj = (id<MTLComputePipelineState>) function->pipeline_state_object;
        [pipeline_state_obj release];
    }
    if (function->function_object != NULL) {
        id<MTLFunction> function_obj = (id<MTLFunction>) function->function_object;
        [function_obj release];
    }
    memset(function, 0, sizeof(struct gptoss_metal_function));
    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_buffer_create(
    const struct gptoss_metal_device* device,
    size_t size,
    const void* data,
    struct gptoss_metal_buffer* buffer_out)
{
    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
    id<MTLBuffer> buffer_obj = nil;
    if (data != NULL) {
        buffer_obj = [device_obj newBufferWithBytes:data length:size options:MTLResourceStorageModeShared];
    } else {
        buffer_obj = [device_obj newBufferWithLength:size options:MTLResourceStorageModeShared];
    }
    if (buffer_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal buffer of size %zu", size);
        return gptoss_status_unsupported_system;
    }
    buffer_out->object = (void*) buffer_obj;
    buffer_out->size = size;
    buffer_out->ptr = [buffer_obj contents];
    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_buffer_wrap(
    const struct gptoss_metal_device* device,
    size_t size,
    const void* data,
    struct gptoss_metal_buffer* buffer_out)
{
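    // Note: newBufferWithBytesNoCopy: requires a page-aligned pointer and
    // length; callers such as the model loader satisfy this by mmap()-ing
    // weights at page boundaries (see model.c).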
    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
    id<MTLBuffer> buffer_obj = [device_obj newBufferWithBytesNoCopy:(void*) data length:size options:MTLResourceStorageModeShared deallocator:nil];
    if (buffer_obj == nil) {
        GPTOSS_LOG_ERROR("failed to wrap Metal buffer of size %zu", size);
        return gptoss_status_unsupported_system;
    }
    buffer_out->object = (void*) buffer_obj;
    buffer_out->size = size;
    buffer_out->ptr = (void*) data;
    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_buffer_release(
    struct gptoss_metal_buffer* buffer)
{
    if (buffer->object != NULL) {
        id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffer->object;
        [buffer_obj release];
    }
    memset(buffer, 0, sizeof(struct gptoss_metal_buffer));
    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_command_queue_create(
    const struct gptoss_metal_device* device,
    struct gptoss_metal_command_queue* command_queue_out)
{
    id<MTLDevice> device_obj = (id<MTLDevice>) device->object;
    id<MTLCommandQueue> command_queue_obj = [device_obj newCommandQueue];
    if (command_queue_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal command queue");
        return gptoss_status_unsupported_system;
    }
    command_queue_out->object = (void*) command_queue_obj;
    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_command_queue_release(
    struct gptoss_metal_command_queue* command_queue)
{
    if (command_queue->object != NULL) {
        id<MTLCommandQueue> command_queue_obj = (id<MTLCommandQueue>) command_queue->object;
        [command_queue_obj release];
    }
    memset(command_queue, 0, sizeof(struct gptoss_metal_command_queue));
    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_command_buffer_create(
    const struct gptoss_metal_command_queue* command_queue,
    struct gptoss_metal_command_buffer* command_buffer_out)
{
    id<MTLCommandQueue> command_queue_obj = (id<MTLCommandQueue>) command_queue->object;
    id<MTLCommandBuffer> command_buffer_obj = [command_queue_obj commandBuffer];
    if (command_buffer_obj == nil) {
        GPTOSS_LOG_ERROR("failed to create Metal command buffer");
        return gptoss_status_unsupported_system;
    }
    [command_buffer_obj retain];
    command_buffer_out->object = (void*) command_buffer_obj;
    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_command_buffer_encode_fill_buffer(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_buffer* buffer,
    size_t offset,
    size_t size,
    uint8_t fill_value)
{
    if (command_buffer->object == NULL) {
        return gptoss_status_invalid_state;
    }
    if (buffer->object == NULL) {
        return gptoss_status_invalid_argument;
    }

    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
    id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffer->object;

    id<MTLBlitCommandEncoder> command_encoder_obj = [command_buffer_obj blitCommandEncoder];

    const NSRange range = NSMakeRange((NSUInteger) offset, (NSUInteger) size);
    [command_encoder_obj fillBuffer:buffer_obj range:range value:fill_value];
    [command_encoder_obj endEncoding];

    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_command_buffer_encode_copy_buffer(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_buffer* input_buffer,
    size_t input_offset,
    const struct gptoss_metal_buffer* output_buffer,
    size_t output_offset,
    size_t size)
{
    if (command_buffer->object == NULL) {
        return gptoss_status_invalid_state;
    }
    if (input_buffer->object == NULL) {
        return gptoss_status_invalid_argument;
    }
    if (output_buffer->object == NULL) {
        return gptoss_status_invalid_argument;
    }

    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
    id<MTLBuffer> input_buffer_obj = (id<MTLBuffer>) input_buffer->object;
    id<MTLBuffer> output_buffer_obj = (id<MTLBuffer>) output_buffer->object;

    id<MTLBlitCommandEncoder> command_encoder_obj = [command_buffer_obj blitCommandEncoder];

    [command_encoder_obj copyFromBuffer:input_buffer_obj sourceOffset:(NSUInteger) input_offset
                               toBuffer:output_buffer_obj destinationOffset:(NSUInteger) output_offset
                                   size:(NSUInteger) size];
    [command_encoder_obj endEncoding];

    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
    const struct gptoss_metal_command_buffer* command_buffer,
    const struct gptoss_metal_function* function,
    size_t threadgroup_size_x,
    size_t threadgroup_size_y,
    size_t threadgroup_size_z,
    size_t num_threadgroups_x,
    size_t num_threadgroups_y,
    size_t num_threadgroups_z,
    size_t params_size,
    const void* params,
    size_t num_device_buffers,
    const struct gptoss_metal_buffer** device_buffers,
    const size_t* device_buffer_offsets,
    size_t threadgroup_buffer_size)
{
    if (command_buffer->object == NULL || function->pipeline_state_object == NULL) {
        return gptoss_status_invalid_state;
    }

    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
    id<MTLComputePipelineState> pipeline_state_obj = (id<MTLComputePipelineState>) function->pipeline_state_object;

    id<MTLComputeCommandEncoder> command_encoder_obj = [command_buffer_obj computeCommandEncoder];

    // Set kernel arguments
    [command_encoder_obj setComputePipelineState:pipeline_state_obj];
    [command_encoder_obj setBytes:params length:params_size atIndex:0];
    for (size_t i = 0; i < num_device_buffers; ++i) {
        id<MTLBuffer> buffer_obj = (id<MTLBuffer>) device_buffers[i]->object;
        const NSUInteger offset = device_buffer_offsets == NULL ? 0 : (NSUInteger) device_buffer_offsets[i];
        [command_encoder_obj setBuffer:buffer_obj offset:offset atIndex:i + 1];
    }
    if (threadgroup_buffer_size != 0) {
        [command_encoder_obj setThreadgroupMemoryLength:threadgroup_buffer_size atIndex:0];
    }

    // Dispatch kernel
    const MTLSize threadgroup_size = MTLSizeMake(threadgroup_size_x, threadgroup_size_y, threadgroup_size_z);
    const MTLSize num_threadgroups = MTLSizeMake(num_threadgroups_x, num_threadgroups_y, num_threadgroups_z);
    [command_encoder_obj dispatchThreadgroups:num_threadgroups threadsPerThreadgroup:threadgroup_size];
    [command_encoder_obj endEncoding];

    return gptoss_status_success;
}
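// Illustrative call (hypothetical variable names): encode a kernel over
// 128 threadgroups of 256 threads with one argument struct and two buffers.
// Arguments land at buffer index 0 (params) and indices 1..n (device buffers):
//
//   gptoss_metal_command_buffer_encode_launch_kernel(
//       &command_buffer, &function,
//       /*threadgroup size=*/256, 1, 1,
//       /*num threadgroups=*/128, 1, 1,
//       sizeof(args), &args,
//       2, (const struct gptoss_metal_buffer*[]) {&input, &output},
//       /*device_buffer_offsets=*/NULL,
//       /*threadgroup_buffer_size=*/0);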

enum gptoss_status gptoss_metal_command_buffer_commit(
    const struct gptoss_metal_command_buffer* command_buffer)
{
    if (command_buffer->object == NULL) {
        return gptoss_status_invalid_state;
    }

    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
    [command_buffer_obj commit];
    return gptoss_status_success;
}

enum gptoss_status gptoss_metal_command_buffer_wait_completion(
    const struct gptoss_metal_command_buffer* command_buffer,
    double* elapsed_seconds)
{
    if (command_buffer->object == NULL) {
        return gptoss_status_invalid_state;
    }

    id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
    [command_buffer_obj waitUntilCompleted];
    if (elapsed_seconds != NULL) {
        const CFTimeInterval start_time = [command_buffer_obj GPUStartTime];
        const CFTimeInterval end_time = [command_buffer_obj GPUEndTime];
        *elapsed_seconds = (double) end_time - (double) start_time;
    }
    return gptoss_status_success;
}
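// GPUStartTime/GPUEndTime bracket on-GPU execution only, so the elapsed time
// reported above excludes CPU-side encoding and queueing overhead.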

enum gptoss_status gptoss_metal_command_buffer_release(
    struct gptoss_metal_command_buffer* command_buffer)
{
    if (command_buffer->object != NULL) {
        id<MTLCommandBuffer> command_buffer_obj = (id<MTLCommandBuffer>) command_buffer->object;
        [command_buffer_obj release];
    }
    memset(command_buffer, 0, sizeof(struct gptoss_metal_command_buffer));
    return gptoss_status_success;
}
gptoss_kernels/source/model.c ADDED
@@ -0,0 +1,581 @@
#include <assert.h>
#include <inttypes.h>
#include <limits.h> // INT_MAX
#include <stdatomic.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <errno.h> // errno, EISDIR, ENOENT, ENOTDIR
#include <fcntl.h> // open
#include <mach/vm_page_size.h> // vm_page_size
#include <sys/mman.h> // mmap, PROT_READ, MAP_PRIVATE
#include <sys/stat.h> // fstat, stat
#include <sys/types.h> // off_t, ssize_t
#include <unistd.h> // close

#include <gpt-oss.h>

#include "internal/datatype.h"
#include "internal/kernel-args.h" // gptoss_expert_prediction
#include "internal/log.h"
#include "internal/uuid.h"
#include "internal/storage.h"
#include "internal/math.h"
#include "internal/model.h"

static size_t round_up_to_page_size(size_t bytes) {
    const size_t page_size_mask = (size_t) vm_page_size - 1;
    if ((bytes & page_size_mask) != 0) {
        bytes |= page_size_mask;
        bytes += 1;
    }
    return bytes;
}

static size_t round_down_to_page_size(size_t bytes) {
    const size_t page_size_mask = (size_t) vm_page_size - 1;
    return bytes & ~page_size_mask;
}
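// Example: with the 16 KiB pages used on Apple silicon, round_up_to_page_size(1)
// returns 16384 and round_down_to_page_size(16385) returns 16384.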

static enum gptoss_status read_fd(int fd, void* data, size_t size, const char* path) {
    assert(fd != -1);
    assert(data != NULL);
    assert(size != 0);

    size_t bytes_to_read = size;
    char* current_byte = (char*) data;
    do {
        const ssize_t read_result = read(fd, current_byte, bytes_to_read);
        if (read_result < 0) {
            GPTOSS_LOG_ERROR("reading %zu bytes from file %s failed with error %d",
                size, path, errno);
            return gptoss_status_io_error;
        }
        if (read_result == 0) {
            // EOF before all bytes were read: the file is truncated.
            GPTOSS_LOG_ERROR("unexpected end of file %s: %zu bytes left to read",
                path, bytes_to_read);
            return gptoss_status_io_error;
        }
        current_byte += (size_t) read_result;
        bytes_to_read -= (size_t) read_result;
    } while (bytes_to_read != 0);
    return gptoss_status_success;
}

static void prefetch_fd(int fd, size_t offset, size_t size, const char* path) {
    // radvisory.ra_count is int, so we can't prefetch 2GB+ at once
    const size_t prefetch_max = round_down_to_page_size((size_t) INT_MAX);
    do {
        const size_t prefetch_size = math_min(size, prefetch_max);
        const struct radvisory ra = {
            .ra_offset = offset,
            .ra_count = (int) prefetch_size,
        };
        if (fcntl(fd, F_RDADVISE, &ra) == -1) {
            GPTOSS_LOG_WARNING("fcntl(%s, F_RDADVISE, .ra_offset=%zu, .ra_count=%d) failed with error %d\n",
                path, (size_t) ra.ra_offset, ra.ra_count, errno);
            return;
        }
        offset += prefetch_size;
        size -= prefetch_size;
    } while (size != 0);
}

enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
    const char* path,
    gptoss_model_t* model_out)
{
    *model_out = NULL;

    enum gptoss_status status = gptoss_status_success;
    struct gptoss_model* model = NULL;
    struct gptoss_tokenizer* tokenizer = NULL;
    int fd = -1;
    size_t file_offset = 0;

    fd = open(path, O_RDONLY);
    if (fd == -1) {
        GPTOSS_LOG_ERROR("open(%s) failed with error %d", path, errno);
        switch (errno) {
            case EISDIR:
            case ENOENT:
            case ENOTDIR:
                status = gptoss_status_invalid_argument;
                break;
            default:
                status = gptoss_status_io_error;
                break;
        }
        goto cleanup;
    }

    struct gptoss_file_header file_header;
    status = read_fd(fd, &file_header, sizeof(file_header), path);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    file_offset += sizeof(file_header);

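    // A valid file begins with the 12-byte magic "GPT-OSS v1.0" followed by a
    // zero word; anything else is rejected below.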
    if (file_header.magic[0] != 'G' ||
        file_header.magic[1] != 'P' ||
        file_header.magic[2] != 'T' ||
        file_header.magic[3] != '-' ||
        file_header.magic[4] != 'O' ||
        file_header.magic[5] != 'S' ||
        file_header.magic[6] != 'S' ||
        file_header.magic[7] != ' ' ||
        file_header.magic[8] != 'v' ||
        file_header.magic[9] != '1' ||
        file_header.magic[10] != '.' ||
        file_header.magic[11] != '0' ||
        file_header.zero != 0)
    {
        GPTOSS_LOG_ERROR("invalid magic in file %s", path);
        status = gptoss_status_invalid_argument;
        goto cleanup;
    }

    struct gptoss_uuid model_uuid;
    status = read_fd(fd, &model_uuid, sizeof(model_uuid), path);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    file_offset += sizeof(model_uuid);

    if (!gptoss_is_gptoss_model_uuid(&model_uuid)) {
        GPTOSS_LOG_ERROR("unsupported model UUID " UUID_FORMAT, UUID_ARGS(model_uuid));
        status = gptoss_status_invalid_argument;
        goto cleanup;
    }

    struct gptoss_gptoss_model_header model_header;
    status = read_fd(fd, &model_header, sizeof(model_header), path);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    file_offset += sizeof(model_header);

    struct gptoss_uuid layout_uuid;
    status = read_fd(fd, &layout_uuid, sizeof(layout_uuid), path);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    file_offset += sizeof(layout_uuid);

    if (!gptoss_is_applegpu_layout_uuid(&layout_uuid)) {
        GPTOSS_LOG_ERROR("unsupported layout UUID " UUID_FORMAT, UUID_ARGS(layout_uuid));
        status = gptoss_status_invalid_argument;
        goto cleanup;
    }

    const size_t model_size = sizeof(struct gptoss_model) + model_header.num_blocks * sizeof(struct gptoss_metal_buffer);
    model = malloc(model_size);
    if (model == NULL) {
        GPTOSS_LOG_ERROR("failed to allocate %zu bytes for model descriptor", model_size);
        status = gptoss_status_insufficient_memory;
        goto cleanup;
    }
    memset(model, 0, model_size);

    atomic_store_explicit(&model->ref_count, 1, memory_order_relaxed);
    model->context_length = model_header.context_length;
    model->num_blocks = model_header.num_blocks;
    model->num_experts = model_header.num_experts;
    model->num_active_experts = model_header.num_active_experts;
    model->embedding_dim = model_header.embedding_dim;
    model->mlp_dim = model_header.mlp_dim;
    model->swiglu_limit = model_header.swiglu_limit;
    model->head_dim = model_header.head_dim;
    model->num_heads = model_header.num_heads;
    model->num_kv_heads = model_header.num_kv_heads;
    model->attention_window = model_header.attention_window;
    model->rope_theta = model_header.rope_theta;
    model->interpolation_scale = model_header.interpolation_scale;
    model->yarn_offset = model_header.yarn_offset;
    model->yarn_scale = model_header.yarn_scale;
    model->yarn_multiplier = model_header.yarn_multiplier;
    model->rmsnorm_epsilon = model_header.rmsnorm_epsilon;

    struct gptoss_uuid tokenizer_uuid;
    status = read_fd(fd, &tokenizer_uuid, sizeof(tokenizer_uuid), path);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    file_offset += sizeof(tokenizer_uuid);

    if (!gptoss_is_tiktoken_tokenizer_uuid(&tokenizer_uuid)) {
        GPTOSS_LOG_ERROR("unsupported tokenizer UUID " UUID_FORMAT, UUID_ARGS(tokenizer_uuid));
        status = gptoss_status_invalid_argument;
        goto cleanup;
    }

    struct gptoss_tiktoken_tokenizer_header tokenizer_header;
    status = read_fd(fd, &tokenizer_header, sizeof(tokenizer_header), path);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    file_offset += sizeof(tokenizer_header);

    tokenizer = malloc(sizeof(struct gptoss_tokenizer));
    if (tokenizer == NULL) {
        GPTOSS_LOG_ERROR("failed to allocate %zu bytes for tokenizer descriptor", sizeof(struct gptoss_tokenizer));
        status = gptoss_status_insufficient_memory;
        goto cleanup;
    }
    memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
    // Initialize all special token IDs to UINT32_MAX (0xFF in all bytes)
    memset(tokenizer->special_token_id, 0xFF, sizeof(tokenizer->special_token_id));

    atomic_store_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
    tokenizer->num_special_tokens = tokenizer_header.num_special_tokens;
    tokenizer->num_text_tokens = tokenizer_header.num_text_tokens;
    model->vocabulary_size = tokenizer_header.num_special_tokens + tokenizer_header.num_text_tokens;
    for (uint32_t t = 0; t < tokenizer_header.num_special_tokens; t++) {
        struct gptoss_uuid token_uuid;
        status = read_fd(fd, &token_uuid, sizeof(token_uuid), path);
        if (status != gptoss_status_success) {
            goto cleanup;
        }
        file_offset += sizeof(token_uuid);

        const enum gptoss_special_token token = gptoss_special_token_decode_uuid(&token_uuid);
        if (token != gptoss_special_token_invalid) {
            tokenizer->special_token_id[token - 1] = tokenizer_header.num_text_tokens + t;
        }
    }
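    // Special tokens are assigned IDs right after the text tokens; the
    // special_token_id table is indexed with (token - 1) to skip the invalid
    // enumerator, and entries for unrecognized UUIDs stay at UINT32_MAX.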

    const size_t tokenizer_start_offset = file_offset;
    const size_t tokenizer_end_offset = tokenizer_start_offset + tokenizer_header.regex_size + tokenizer_header.tokens_size;
    const size_t tokenizer_mapping_start = round_down_to_page_size(tokenizer_start_offset);
    const size_t tokenizer_mapping_size = round_up_to_page_size(tokenizer_end_offset) - tokenizer_mapping_start;
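    // mmap() requires a page-aligned file offset, so the mapping starts at the
    // beginning of the page containing the tokenizer data; regex_ptr below
    // adds the intra-page remainder back.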
    void* tokenizer_mapping_ptr = mmap(NULL, tokenizer_mapping_size, PROT_READ, MAP_PRIVATE, fd, tokenizer_mapping_start);
    if (tokenizer_mapping_ptr == (void*) -1) {
        GPTOSS_LOG_ERROR("failed to mmap(%s) tokenizer at offset %zu size %zu",
            path, tokenizer_mapping_start, tokenizer_mapping_size);
        status = gptoss_status_io_error;
        goto cleanup;
    }
    tokenizer->mapping_ptr = tokenizer_mapping_ptr;
    tokenizer->mapping_size = tokenizer_mapping_size;
    tokenizer->regex_ptr = (const char*) tokenizer_mapping_ptr + (tokenizer_start_offset - tokenizer_mapping_start);
    tokenizer->tokens_ptr = tokenizer->regex_ptr + tokenizer_header.regex_size;

    if (madvise(tokenizer_mapping_ptr, tokenizer_mapping_size, MADV_RANDOM | MADV_WILLNEED) != 0) {
        GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, tokenizer_mapping_size, errno);
    }

    prefetch_fd(fd, tokenizer_mapping_start, tokenizer_mapping_size, path);

    struct stat model_stat = {0};
    int stat_result = fstat(fd, &model_stat);
    if (stat_result != 0) {
        GPTOSS_LOG_ERROR("fstat(%s) failed with error %d", path, errno);
        status = gptoss_status_io_error;
        goto cleanup;
    }

    const size_t model_mapping_start = round_up_to_page_size(tokenizer_end_offset);
    const size_t model_mapping_size = round_up_to_page_size((size_t) model_stat.st_size) - model_mapping_start;
    void* model_mapping_ptr = mmap(NULL, model_mapping_size, PROT_READ, MAP_PRIVATE, fd, model_mapping_start);
    if (model_mapping_ptr == (void*) -1) {
        GPTOSS_LOG_ERROR("failed to mmap(%s) model weights at offset %zu size %zu",
            path, model_mapping_start, model_mapping_size);
        status = gptoss_status_io_error;
        goto cleanup;
    }
    model->mapping_ptr = model_mapping_ptr;
    model->mapping_size = model_mapping_size;

    if (madvise(model_mapping_ptr, model_mapping_size, MADV_SEQUENTIAL | MADV_WILLNEED) != 0) {
        GPTOSS_LOG_WARNING("madvise(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
    }

    prefetch_fd(fd, model_mapping_start, model_mapping_size, path);

    if (mlock(model_mapping_ptr, model_mapping_size) != 0) {
        GPTOSS_LOG_WARNING("mlock(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
    } else {
        model->lock_memory = true;
    }

    // Initialize Metal
    status = gptoss_metal_device_create_system_default(&model->device);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    model->max_threadgroups = model->device.num_cores * 3;
    status = gptoss_metal_command_queue_create(&model->device, &model->command_queue);
    if (status != gptoss_status_success) {
        goto cleanup;
    }

    // Metal kernels
    status = gptoss_metal_library_create_default(&model->device, &model->library);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_bf16_f32_embeddings", &model->bf16_f32_embeddings_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_rmsnorm", &model->f32_bf16w_rmsnorm_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul", &model->f32_bf16w_matmul_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul_qkv", &model->f32_bf16w_matmul_qkv_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_qkv", &model->f32_bf16w_dense_matmul_qkv_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_attn_output", &model->f32_bf16w_dense_matmul_attn_output_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_mlp_gate", &model->f32_bf16w_dense_matmul_mlp_gate_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_unembedding", &model->f32_bf16w_unembedding_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_rope", &model->f32_rope_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_expert_routing_metadata", &model->f32_expert_routing_metadata_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_scatter_e4", &model->f32_scatter_e4_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul_swiglu", &model->f32_mf4w_moe_dense_matmul_swiglu_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_dense_matmul", &model->f32_mf4w_moe_dense_matmul_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_gather_and_accumulate_e4", &model->f32_gather_and_accumulate_e4_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul_swiglu", &model->f32_mf4w_moe_matmul_swiglu_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_mf4w_moe_matmul", &model->f32_mf4w_moe_matmul_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_accumulate_e4", &model->f32_accumulate_e4_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e32_k4", &model->f32_topk_softmax_e32_k4_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_topk_softmax_e128_k4", &model->f32_topk_softmax_e128_k4_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_softmax", &model->f32_softmax_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_sample", &model->f32_sample_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }
    status = gptoss_metal_function_create(&model->library, "gptoss_f32_sdpa_q8_d64", &model->f32_sdpa_q8_d64_fn);
    if (status != gptoss_status_success) {
        goto cleanup;
    }

    // Kernel launch parameters
    model->embeddings_threadgroup_size = 512;
    model->attn_qkv_threadgroup_size = 1024;
    model->attn_out_threadgroup_size = 768;
    model->mlp_gate_threadgroup_size = 256;
    model->mlp_swiglu_threadgroup_size = 192;
    model->mlp_out_threadgroup_size = 192;
    model->mlp_acc_threadgroup_size = 768;
    model->unembedding_threadgroup_size = 416;

    // Weight buffers
    const char* current_ptr = (const char*) model->mapping_ptr;
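    // Expected layout of the mapped weights (each tensor padded to 16 bytes):
    // [token embeddings][num_blocks x shared attention/MLP-gate weights]
    // [final rmsnorm gain][unembedding matrix], followed by one page-aligned
    // MoE weight region (MXFP4 blocks, scales, bf16 biases) per block.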
414
+
415
+ const size_t embedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
416
+ model->attn_rmsnorm_gain_offset = embedding_weight_size;
417
+ const size_t rmsnorm_weight_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
418
+ model->attn_qkv_weight_offset = model->attn_rmsnorm_gain_offset + rmsnorm_weight_size;
419
+ const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
420
+ const size_t attn_qkv_weight_size = math_round_up_po2(attn_qkv_dim * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
421
+ model->attn_qkv_bias_offset = model->attn_qkv_weight_offset + attn_qkv_weight_size;
422
+ const size_t attn_qkv_bias_size = math_round_up_po2(attn_qkv_dim * sizeof(gptoss_bfloat16), 16);
423
+ model->attn_sdpa_sink_offset = model->attn_qkv_bias_offset + attn_qkv_bias_size;
424
+ const size_t attn_sink_weight_size = math_round_up_po2(model->num_heads * sizeof(gptoss_bfloat16), 16);
425
+ model->attn_out_weight_offset = model->attn_sdpa_sink_offset + attn_sink_weight_size;
426
+ const size_t attn_out_weight_size = math_round_up_po2(model->embedding_dim * model->num_heads * model->head_dim * sizeof(gptoss_bfloat16), 16);
427
+ model->attn_out_bias_offset = model->attn_out_weight_offset + attn_out_weight_size;
428
+ const size_t attn_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
429
+ model->mlp_rmsnorm_gain_offset = model->attn_out_bias_offset + attn_out_bias_size;
430
+ model->mlp_gate_weight_offset = model->mlp_rmsnorm_gain_offset + rmsnorm_weight_size;
431
+ const size_t mlp_gate_weight_size = math_round_up_po2(model->num_experts * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
432
+ model->mlp_gate_bias_offset = model->mlp_gate_weight_offset + mlp_gate_weight_size;
433
+ const size_t mlp_gate_bias_size = math_round_up_po2(model->num_experts * sizeof(gptoss_bfloat16), 16);
434
+ const size_t per_block_shared_weights_size =
435
+ rmsnorm_weight_size + attn_qkv_weight_size + attn_qkv_bias_size + attn_sink_weight_size + attn_out_weight_size + attn_out_bias_size +
436
+ rmsnorm_weight_size + mlp_gate_weight_size + mlp_gate_bias_size;
437
+ model->rmsnorm_weight_offset = embedding_weight_size + model->num_blocks * per_block_shared_weights_size;
438
+ model->unembedding_weight_offset = model->rmsnorm_weight_offset + rmsnorm_weight_size;
439
+ const size_t unembedding_weight_size = math_round_up_po2(model->vocabulary_size * model->embedding_dim * sizeof(gptoss_bfloat16), 16);
440
+
441
+ model->per_block_shared_weights_size = per_block_shared_weights_size;
442
+ const size_t shared_weights_size =
443
+ round_up_to_page_size(embedding_weight_size + rmsnorm_weight_size + unembedding_weight_size + model->num_blocks * per_block_shared_weights_size);
444
+
445
+ status = gptoss_metal_buffer_wrap(&model->device, shared_weights_size, current_ptr, &model->shared_weight_buffer);
446
+ if (status != gptoss_status_success) {
447
+ GPTOSS_LOG_ERROR("failed to map expert-shared weight of size %zu onto a Metal buffer", shared_weights_size);
448
+ goto cleanup;
449
+ }
450
+ current_ptr += shared_weights_size;
451
+ model->weights_size += shared_weights_size;
452
+
453
+ const size_t mlp_swiglu_weight_block_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 2, 16);
454
+ model->mlp_swiglu_scale_offset = mlp_swiglu_weight_block_size;
455
+ const size_t mlp_swiglu_weight_scale_size = math_round_up_po2(2 * model->mlp_dim * model->embedding_dim / 32, 16);
456
+ model->mlp_swiglu_bias_offset = model->mlp_swiglu_scale_offset + mlp_swiglu_weight_scale_size;
457
+ const size_t mlp_swiglu_bias_size = math_round_up_po2(2 * model->mlp_dim * sizeof(gptoss_bfloat16), 16);
458
+ model->mlp_out_block_offset = model->mlp_swiglu_bias_offset + mlp_swiglu_bias_size;
459
+ const size_t mlp_out_weight_block_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 2, 16);
460
+ model->mlp_out_scale_offset = model->mlp_out_block_offset + mlp_out_weight_block_size;
461
+ const size_t mlp_out_weight_scale_size = math_round_up_po2(model->embedding_dim * model->mlp_dim / 32, 16);
462
+ model->mlp_out_bias_offset = model->mlp_out_scale_offset + mlp_out_weight_scale_size;
463
+ const size_t mlp_out_bias_size = math_round_up_po2(model->embedding_dim * sizeof(gptoss_bfloat16), 16);
464
+ model->per_expert_block_weight_size =
465
+ mlp_swiglu_weight_block_size + mlp_swiglu_weight_scale_size + mlp_swiglu_bias_size + mlp_out_weight_block_size + mlp_out_weight_scale_size + mlp_out_bias_size;
466
+ const size_t moe_block_weight_size = round_up_to_page_size(model->num_experts * model->per_expert_block_weight_size);
467
+ for (uint32_t n = 0; n < model->num_blocks; n++) {
468
+ status = gptoss_metal_buffer_wrap(&model->device, moe_block_weight_size, current_ptr, &model->block_weight_buffers[n]);
469
+ if (status != gptoss_status_success) {
470
+ GPTOSS_LOG_ERROR("failed to map block #%" PRIu32 " MoE weight of size %zu onto a Metal buffer",
471
+ n, moe_block_weight_size);
472
+ goto cleanup;
473
+ }
474
+ current_ptr += moe_block_weight_size;
475
+ model->weights_size += moe_block_weight_size;
476
+ }
477
+
478
+ // Commit tokenizer
479
+ model->tokenizer = tokenizer;
480
+ tokenizer = NULL;
481
+
482
+ // Commit model
483
+ *model_out = model;
484
+ model = NULL;
485
+
486
+ cleanup:
487
+ if (fd != -1) {
488
+ close(fd);
489
+ fd = -1;
490
+ }
491
+ gptoss_model_release(model); // does nothing if model is NULL
492
+ gptoss_tokenizer_release(tokenizer); // does nothing if tokenizer is NULL
493
+ return status;
494
+ }
495
+
496
+ enum gptoss_status GPTOSS_ABI gptoss_model_get_tokenizer(
497
+ gptoss_model_t model,
498
+ gptoss_tokenizer_t* tokenizer_out)
499
+ {
500
+ gptoss_tokenizer_t tokenizer = model->tokenizer;
501
+ atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
502
+ *tokenizer_out = tokenizer;
503
+ return gptoss_status_success;
504
+ }
505
+
506
+ enum gptoss_status GPTOSS_ABI gptoss_model_get_max_context_length(
507
+ gptoss_model_t model,
508
+ size_t* max_context_length_out)
509
+ {
510
+ *max_context_length_out = model->context_length;
511
+ return gptoss_status_success;
512
+ }
513
+
514
+ enum gptoss_status GPTOSS_ABI gptoss_model_retain(
515
+ gptoss_model_t model)
516
+ {
517
+ atomic_fetch_add_explicit(&model->ref_count, 1, memory_order_relaxed);
518
+ return gptoss_status_success;
519
+ }
520
+
521
+ enum gptoss_status GPTOSS_ABI gptoss_model_release(
522
+ gptoss_model_t model)
523
+ {
524
+ if (model != NULL) {
525
+ if (atomic_fetch_sub_explicit(&model->ref_count, 1, memory_order_acq_rel) == 1) {
526
+ gptoss_tokenizer_release(model->tokenizer);
527
+
528
+ // Weight buffers
529
+ gptoss_metal_buffer_release(&model->shared_weight_buffer);
530
+ for (uint32_t n = 0; n < model->num_blocks; n++) {
531
+ gptoss_metal_buffer_release(&model->block_weight_buffers[n]);
532
+ }
533
+
534
+ // Metal kernels
535
+ gptoss_metal_function_release(&model->bf16_f32_embeddings_fn);
536
+ gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn);
537
+ gptoss_metal_function_release(&model->f32_bf16w_matmul_fn);
538
+ gptoss_metal_function_release(&model->f32_bf16w_matmul_qkv_fn);
539
+ gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_qkv_fn);
540
+ gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_attn_output_fn);
541
+ gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn);
542
+ gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn);
543
+ gptoss_metal_function_release(&model->f32_rope_fn);
544
+ gptoss_metal_function_release(&model->f32_expert_routing_metadata_fn);
545
+ gptoss_metal_function_release(&model->f32_scatter_e4_fn);
546
+ gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_swiglu_fn);
547
+ gptoss_metal_function_release(&model->f32_mf4w_moe_dense_matmul_fn);
548
+ gptoss_metal_function_release(&model->f32_gather_and_accumulate_e4_fn);
549
+ gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn);
550
+ gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_fn);
551
+ gptoss_metal_function_release(&model->f32_accumulate_e4_fn);
552
+ gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn);
553
+ gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn);
554
+ gptoss_metal_function_release(&model->f32_softmax_fn);
555
+ gptoss_metal_function_release(&model->f32_sample_fn);
556
+ gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn);
557
+ gptoss_metal_library_release(&model->library);
558
+
559
+ gptoss_metal_command_queue_release(&model->command_queue);
560
+ gptoss_metal_device_release(&model->device);
561
+ // Weight buffers
562
+
563
+ if (model->mapping_ptr != NULL && model->mapping_size != 0) {
564
+ if (model->lock_memory) {
565
+ if (munlock(model->mapping_ptr, model->mapping_size) != 0) {
566
+ GPTOSS_LOG_WARNING("munlock for model weight mapping failed with error %d", errno);
567
+ }
568
+ }
569
+
570
+ if (munmap(model->mapping_ptr, model->mapping_size) != 0) {
571
+ GPTOSS_LOG_WARNING("munmap for model weight mapping failed with error %d", errno);
572
+ }
573
+ }
574
+
575
+ const size_t model_size = sizeof(struct gptoss_model) + model->num_blocks * sizeof(struct gptoss_metal_buffer);
576
+ memset(model, 0, model_size);
577
+ free(model);
578
+ }
579
+ }
580
+ return gptoss_status_success;
581
+ }
gptoss_kernels/source/moematmul.metal ADDED
@@ -0,0 +1,702 @@
#include <internal/kernel-args.h>
#include <metal_common>
#include <metal_compute>
#include <metal_math>
#include <metal_simdgroup>
#include <metal_stdlib>

#pragma METAL fp math_mode(safe)
#pragma METAL fp contract(off)
#define ceil_div(a, b) (((a) + (b) - 1) / (b))

// Each simdgroup reduces all channels of the input and computes a single channel of the output
// + Efficient synchronization
// + Sequential memory access within a warp
// Each threadgroup computes (simdgroups_per_threadgroup) consecutive output channels
// + Reuse input vector from threadgroup memory
// + Avoid synchronization across warps when doing reduction

kernel void gptoss_f32_mf4w_moe_matmul_swiglu(
    constant gptoss_moe_matmul_swiglu_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device gptoss_expert_prediction* expert [[ buffer(2) ]],
    const device uint4* weight_blocks [[ buffer(3) ]],
    const device uchar* weight_scales [[ buffer(4) ]],
    const device bfloat* bias [[ buffer(5) ]],
    device float* output [[ buffer(6) ]],
    const device gptoss_control* control [[ buffer(7) ]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    const uint simdgroup_size = 32;
    threadgroup float threadgroup_buffer[32];
    if (control->abort != 0) {
        return;
    }

    const uint num_column_vecs = args.num_column_vecs;
    const uint row = gid.x * num_simdgroups + simdgroup_idx;
    const uint expert_id = expert[gid.y * args.num_active_experts + gid.z].expert_id;

    input += 8 * (gid.y * num_column_vecs + simdgroup_tid);
    weight_blocks = (const device uint4*) ((uintptr_t) (weight_blocks + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
    weight_scales = (const device uchar*) ((uintptr_t) (weight_scales + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
    bias = (const device bfloat*) ((uintptr_t) (bias + row) + expert_id * args.weight_expert_stride);
    output += gid.y * args.num_rows + gid.x * (num_simdgroups / 2) + gid.z * args.output_expert_stride;

    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;

    float4 sum4 = 0.0f;
    do {
        const uint4 wblock = *weight_blocks;
        const float wscale = as_type<float>(static_cast<uint>(*weight_scales) << 23);
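        // The 8-bit block scale is a power-of-two exponent: shifting it into
        // the float32 exponent field yields 2^(scale - 127) without a multiply.
        // The bit tricks below expand 32 packed 4-bit (MXFP4/e2m1) weights into
        // half-precision bit patterns, which are then widened to float.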
        uint4 wblock02468ACEGIKMOQSU = wblock + wblock;
        uint4 wblock13579BDFHJLNPRTV = wblock >> 3;
        wblock02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
        wblock13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
        wblock02468ACEGIKMOQSU += 0x70707070u;
        wblock13579BDFHJLNPRTV += 0x70707070u;
        wblock02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
        wblock13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
        const uint4 wblock26AEIMQU = wblock02468ACEGIKMOQSU & 0xFF00FF00u;
        const uint4 wblock048CGKOS = (wblock02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
        const uint4 wblock37BFJNRV = wblock13579BDFHJLNPRTV & 0xFF00FF00u;
        const uint4 wblock159DHLPT = (wblock13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
        const float4 w048C = static_cast<float4>(as_type<half4>(wblock048CGKOS.xy));
        const float4 wGKOS = static_cast<float4>(as_type<half4>(wblock048CGKOS.zw));
        const float4 w26AE = static_cast<float4>(as_type<half4>(wblock26AEIMQU.xy));
        const float4 wIMQU = static_cast<float4>(as_type<half4>(wblock26AEIMQU.zw));
        const float4 w159D = static_cast<float4>(as_type<half4>(wblock159DHLPT.xy));
        const float4 wHLPT = static_cast<float4>(as_type<half4>(wblock159DHLPT.zw));
        const float4 w37BF = static_cast<float4>(as_type<half4>(wblock37BFJNRV.xy));
        const float4 wJNRV = static_cast<float4>(as_type<half4>(wblock37BFJNRV.zw));

        const float4 w0123 = (float4) { w048C.x, w159D.x, w26AE.x, w37BF.x };
        const float4 w4567 = (float4) { w048C.y, w159D.y, w26AE.y, w37BF.y };
        const float4 w89AB = (float4) { w048C.z, w159D.z, w26AE.z, w37BF.z };
        const float4 wCDEF = (float4) { w048C.w, w159D.w, w26AE.w, w37BF.w };
        const float4 wGHIJ = (float4) { wGKOS.x, wHLPT.x, wIMQU.x, wJNRV.x };
        const float4 wKLMN = (float4) { wGKOS.y, wHLPT.y, wIMQU.y, wJNRV.y };
        const float4 wOPQR = (float4) { wGKOS.z, wHLPT.z, wIMQU.z, wJNRV.z };
        const float4 wSTUV = (float4) { wGKOS.w, wHLPT.w, wIMQU.w, wJNRV.w };

        const float4 i0123 = input[0];
        const float4 i4567 = input[1];
        const float4 i89AB = input[2];
        const float4 iCDEF = input[3];
        const float4 iGHIJ = input[4];
        const float4 iKLMN = input[5];
        const float4 iOPQR = input[6];
        const float4 iSTUV = input[7];

        float4 psum0 = i0123 * w0123;
        float4 psum1 = i4567 * w4567;
        psum0 = metal::fma(i89AB, w89AB, psum0);
        psum1 = metal::fma(iCDEF, wCDEF, psum1);
        psum0 = metal::fma(iGHIJ, wGHIJ, psum0);
        psum1 = metal::fma(iKLMN, wKLMN, psum1);
        psum0 = metal::fma(iOPQR, wOPQR, psum0);
        psum1 = metal::fma(iSTUV, wSTUV, psum1);
        sum4 = metal::fma(psum0, wscale, sum4);
        sum4 = metal::fma(psum1, wscale, sum4);

        weight_blocks += simdgroup_size;
        weight_scales += simdgroup_size;
        input += 8 * simdgroup_size;
    } while (--num_iter != 0);
    const float2 sum2 = sum4.xy + sum4.zw;
    float sum = sum2.x + sum2.y;
    sum = metal::simd_sum(sum);
    if (metal::simd_is_first()) {
        sum += static_cast<float>(*bias);
        threadgroup_buffer[simdgroup_idx] = sum;
    }
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    if (tid * 2 < num_simdgroups) {
        const float2 x = reinterpret_cast<const threadgroup float2*>(threadgroup_buffer)[tid];
        const float swish_x = metal::min(x.x, args.swiglu_max);
        const float linear_x = metal::clamp(x.y, args.swiglu_min, args.swiglu_max);
        const float alpha = 1.702f;
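        // Clamped SwiGLU: the fma below evaluates swish_y * linear_x + swish_y,
        // i.e. swish(x) * (linear + 1).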
        const float swish_y = swish_x / (1.0f + metal::precise::exp(-alpha * swish_x));
        const float swiglu_y = metal::fma(swish_y, linear_x, swish_y);
        output[tid] = swiglu_y;
    }
}

kernel void gptoss_f32_mf4w_moe_matmul(
    constant gptoss_moe_matmul_args& args [[ buffer(0) ]],
    const device float4* input [[ buffer(1) ]],
    const device gptoss_expert_prediction* expert [[ buffer(2) ]],
    const device uint4* weight_blocks [[ buffer(3) ]],
    const device uchar* weight_scales [[ buffer(4) ]],
    const device bfloat* bias [[ buffer(5) ]],
    device float* output [[ buffer(6) ]],
    const device gptoss_control* control [[ buffer(7) ]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simdgroup_tid [[thread_index_in_simdgroup]],
    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
    uint num_simdgroups [[simdgroups_per_threadgroup]])
{
    const uint simdgroup_size = 32;
    if (control->abort != 0) {
        return;
    }

    const uint num_column_vecs = args.num_column_vecs;
    const uint row = gid.x * num_simdgroups + simdgroup_idx;
    const uint expert_id = expert[gid.y * args.num_active_experts + gid.z].expert_id;

    input += 8 * (gid.y * num_column_vecs + simdgroup_tid + gid.z * args.input_expert_stride);
    weight_blocks = (const device uint4*) ((uintptr_t) (weight_blocks + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
    weight_scales = (const device uchar*) ((uintptr_t) (weight_scales + num_column_vecs * row + simdgroup_tid) + expert_id * args.weight_expert_stride);
    bias = (const device bfloat*) ((uintptr_t) (bias + row) + expert_id * args.weight_expert_stride);
    output += gid.y * args.num_rows + row + gid.z * args.output_expert_stride;

    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;

    float4 sum4 = 0.0f;
    do {
        const uint4 wblock = *weight_blocks;
        const float wscale = as_type<float>(static_cast<uint>(*weight_scales) << 23);
        uint4 wblock02468ACEGIKMOQSU = wblock + wblock;
        uint4 wblock13579BDFHJLNPRTV = wblock >> 3;
        wblock02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
        wblock13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
        wblock02468ACEGIKMOQSU += 0x70707070u;
        wblock13579BDFHJLNPRTV += 0x70707070u;
        wblock02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
        wblock13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
        const uint4 wblock26AEIMQU = wblock02468ACEGIKMOQSU & 0xFF00FF00u;
        const uint4 wblock048CGKOS = (wblock02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
        const uint4 wblock37BFJNRV = wblock13579BDFHJLNPRTV & 0xFF00FF00u;
        const uint4 wblock159DHLPT = (wblock13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
        const float4 w048C = static_cast<float4>(as_type<half4>(wblock048CGKOS.xy));
        const float4 wGKOS = static_cast<float4>(as_type<half4>(wblock048CGKOS.zw));
        const float4 w26AE = static_cast<float4>(as_type<half4>(wblock26AEIMQU.xy));
        const float4 wIMQU = static_cast<float4>(as_type<half4>(wblock26AEIMQU.zw));
        const float4 w159D = static_cast<float4>(as_type<half4>(wblock159DHLPT.xy));
        const float4 wHLPT = static_cast<float4>(as_type<half4>(wblock159DHLPT.zw));
        const float4 w37BF = static_cast<float4>(as_type<half4>(wblock37BFJNRV.xy));
        const float4 wJNRV = static_cast<float4>(as_type<half4>(wblock37BFJNRV.zw));

        const float4 w0123 = (float4) { w048C.x, w159D.x, w26AE.x, w37BF.x };
        const float4 w4567 = (float4) { w048C.y, w159D.y, w26AE.y, w37BF.y };
        const float4 w89AB = (float4) { w048C.z, w159D.z, w26AE.z, w37BF.z };
        const float4 wCDEF = (float4) { w048C.w, w159D.w, w26AE.w, w37BF.w };
        const float4 wGHIJ = (float4) { wGKOS.x, wHLPT.x, wIMQU.x, wJNRV.x };
        const float4 wKLMN = (float4) { wGKOS.y, wHLPT.y, wIMQU.y, wJNRV.y };
        const float4 wOPQR = (float4) { wGKOS.z, wHLPT.z, wIMQU.z, wJNRV.z };
        const float4 wSTUV = (float4) { wGKOS.w, wHLPT.w, wIMQU.w, wJNRV.w };

        const float4 i0123 = input[0];
        const float4 i4567 = input[1];
        const float4 i89AB = input[2];
        const float4 iCDEF = input[3];
        const float4 iGHIJ = input[4];
        const float4 iKLMN = input[5];
        const float4 iOPQR = input[6];
        const float4 iSTUV = input[7];

        float4 psum0 = i0123 * w0123;
        float4 psum1 = i4567 * w4567;
        psum0 = metal::fma(i89AB, w89AB, psum0);
        psum1 = metal::fma(iCDEF, wCDEF, psum1);
        psum0 = metal::fma(iGHIJ, wGHIJ, psum0);
        psum1 = metal::fma(iKLMN, wKLMN, psum1);
        psum0 = metal::fma(iOPQR, wOPQR, psum0);
        psum1 = metal::fma(iSTUV, wSTUV, psum1);
        sum4 = metal::fma(psum0, wscale, sum4);
        sum4 = metal::fma(psum1, wscale, sum4);

        weight_blocks += simdgroup_size;
        weight_scales += simdgroup_size;
        input += 8 * simdgroup_size;
    } while (--num_iter != 0);
    const float2 sum2 = sum4.xy + sum4.zw;
    float sum = sum2.x + sum2.y;
    sum = metal::simd_sum(sum);
    if (metal::simd_is_first()) {
        sum += static_cast<float>(*bias);
        *output = sum;
    }
}

kernel void gptoss_f32_mf4w_moe_dense_matmul_swiglu(
    constant gptoss_moe_dense_matmul_swiglu_args& params [[ buffer(0) ]],
    const device uint* __restrict__ expert_offsets [[ buffer(1) ]],
    const device float* lhs [[ buffer(2) ]],
    const device uint* weight_blocks [[ buffer(3) ]],
    const device uchar* weight_scales [[ buffer(4) ]],
    const device bfloat* __restrict__ bias [[ buffer(5) ]],
    device float* out [[ buffer(6) ]],
    uint sg_id [[simdgroup_index_in_threadgroup]],
    uint3 threads_per_tg [[threads_per_threadgroup]],
    uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
    uint3 gid [[thread_position_in_grid]],
    uint3 tg_id [[threadgroup_position_in_grid]],
    uint3 local_tid [[thread_position_in_threadgroup]])
{
    constexpr uint Bm = MOE_DENSE_MATMUL_SWIGLU_Bm;
    constexpr uint Bn = MOE_DENSE_MATMUL_SWIGLU_Bn;
    constexpr uint Bk = MOE_DENSE_MATMUL_SWIGLU_Bk;
    constexpr uint Sg_Bm = MOE_DENSE_MATMUL_SWIGLU_Sg_Bm;
    constexpr uint Sg_Bn = MOE_DENSE_MATMUL_SWIGLU_Sg_Bn;

    // Assumptions about shapes.
    assert(Bm % 8 == 0);
    assert(Bn % 8 == 0);
    assert(Bk % 8 == 0);
    assert(Sg_Bm % 8 == 0);
    assert(Sg_Bn % 8 == 0);
    assert(Bm % Sg_Bm == 0);
    assert(Bn % Sg_Bn == 0);

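    // Tiling scheme: each threadgroup computes a Bm x Bn output tile of one
    // expert's token batch; each simdgroup accumulates a Sg_Bm x Sg_Bn subtile
    // out of 8x8 simdgroup matrices, staging lhs and dequantized rhs through
    // threadgroup memory.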
    const uint K = params.k;
    const uint N = params.n;
    const uint M = expert_offsets[tg_id.z + 1] - expert_offsets[tg_id.z];
    assert((K % 32) == 0);
    assert((K % 8) == 0);
    assert(N % Bn == 0);
    assert(K % Bk == 0);
    // Get row and col tg.
    const uint row_tg = tg_id.y;
    const uint col_tg = tg_id.x;
    // Get row and col local tid.
    const uint row_tg_offset = row_tg * Bm;
    const uint col_tg_offset = col_tg * Bn;
    if (row_tg_offset >= M || col_tg_offset >= N) {
        return;
    }
    // Move lhs and output according to the passed offset.
    const uint expert_offset = expert_offsets[tg_id.z];
    lhs += expert_offset * K;
    const uint N_output = N / 2;
    out += expert_offset * N_output;

    const uint S = params.weight_blocks_expert_stride_bytes;
    const uint S_scales = params.weight_scales_expert_stride_bytes;
    const uint S_bias = params.bias_expert_stride_bytes;

    const device char* wb0 = reinterpret_cast<const device char*>(weight_blocks);
    const device char* sc0 = reinterpret_cast<const device char*>(weight_scales);
    const device char* bi0 = reinterpret_cast<const device char*>(bias);

    weight_blocks = reinterpret_cast<const device uint*>(wb0 + tg_id.z * S);
    weight_scales = reinterpret_cast<const device uchar*>(sc0 + tg_id.z * S_scales);
    bias = reinterpret_cast<const device bfloat*>(bi0 + tg_id.z * S_bias);

    const uint sg_col_count = Bn / Sg_Bn;
    const uint row_sg = sg_id / sg_col_count;
    const uint col_sg = sg_id % sg_col_count;

    const uint row_sg_offset = row_sg * Sg_Bm;
    const uint col_sg_offset = col_sg * Sg_Bn;
    // Declare threadgroup blocks.
    threadgroup float lhs_block[Bm * Bk];
    // rhs_block will hold the scaled fp32 weights.
    threadgroup float rhs_block[Bn * Bk];

    constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);
    // Create an array of simdgroup_float8x8 to hold temp results.
    metal::simdgroup_float8x8 OutTiles[temp_result_size];
    for (uint i = 0; i < temp_result_size; i++) {
        OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(0.0);
    }
    // Linear thread id within TG (we launch 1-D TGs)
    const uint lin_tid = local_tid.x;
    const uint thread_count_per_tg = threads_per_tg.x * threads_per_tg.y * threads_per_tg.z;

    // Iterate over all Bk blocks.
    for (uint k_offset = 0; k_offset < K; k_offset += Bk) {
        constexpr uint lhs_row_stride = Bk;
        constexpr uint lhs_vec_cols = Bk / 4;
        constexpr uint lhs_vec_total = Bm * lhs_vec_cols;

        const uint LHS_ITERS = ceil_div(lhs_vec_total, thread_count_per_tg);

        // #pragma clang loop unroll(full)
        for (uint t = 0; t < LHS_ITERS; ++t) {
            const uint i = t * thread_count_per_tg + lin_tid;
            if (i < lhs_vec_total) {
                const uint r = i / lhs_vec_cols;
                const uint c4 = i % lhs_vec_cols;

                const uint gr = row_tg_offset + r;
                const uint gc4 = (k_offset / 4) + c4;

                threadgroup float4* dst4 =
                    reinterpret_cast<threadgroup float4*>(lhs_block + r * lhs_row_stride + (c4 << 2));
                if (gr < M) {
                    const device float4* src4 =
                        reinterpret_cast<const device float4*>(lhs + gr * K + (gc4 << 2));

                    *dst4 = *src4;
                } else {
                    *dst4 = float4(0.0);
                }
            }
        }

        // Load weights with vector loads.
        constexpr uint rhs_row_stride = Bk;
        constexpr uint weights_per_elem = 8;
        constexpr uint rhs_loads_per_col = Bk / weights_per_elem;
        constexpr uint rhs_loads_total = Bn * rhs_loads_per_col;
        const uint RHS_ITERS = ceil_div(rhs_loads_total, thread_count_per_tg);
        // #pragma clang loop unroll(full)
        for (uint t = 0; t < RHS_ITERS; ++t) {
            const uint i = t * thread_count_per_tg + lin_tid;
            if (i < rhs_loads_total) {
                const uint r = i / rhs_loads_per_col;
                const uint c = i % rhs_loads_per_col;

                const uint gr = col_tg_offset + r;
                const uint gc = (k_offset / weights_per_elem) + c;
                const uint gc_scale = (k_offset / 32) + (c >> 2);

                const uint wblock = weight_blocks[gr * (K / weights_per_elem) + gc];
                const float scale =
                    as_type<float>(static_cast<uint>(weight_scales[gr * (K / 32) + gc_scale]) << 23);
                uint wblock0246 = (wblock + wblock);
                uint wblock1357 = (wblock >> 3);
                wblock0246 &= 0x1E1E1E1Eu;
                wblock1357 &= 0x1E1E1E1Eu;

                wblock0246 += 0x70707070u;
                wblock1357 += 0x70707070u;
                wblock0246 &= 0x8E8E8E8Eu;
                wblock1357 &= 0x8E8E8E8Eu;

                uint wblock26 = (wblock0246) & 0xFF00FF00u;
                uint wblock04 = ((wblock0246 << 8)) & 0xFF00FF00u;
                uint wblock37 = (wblock1357) & 0xFF00FF00u;
                uint wblock15 = ((wblock1357 << 8)) & 0xFF00FF00u;

                half4 wblock0426 = as_type<half4>(uint2(wblock04, wblock26));
                half4 wblock1537 = as_type<half4>(uint2(wblock15, wblock37));

                // Convert to float scalars and apply scale
                const float w0 = float(wblock0426.x) * scale;
                const float w1 = float(wblock1537.x) * scale;
                const float w2 = float(wblock0426.z) * scale;
                const float w3 = float(wblock1537.z) * scale;
                const float w4 = float(wblock0426.y) * scale;
                const float w5 = float(wblock1537.y) * scale;
                const float w6 = float(wblock0426.w) * scale;
                const float w7 = float(wblock1537.w) * scale;
                const uint rhs_offset = r * rhs_row_stride + c * 8;
                rhs_block[rhs_offset] = w0;
                rhs_block[rhs_offset + 1] = w1;
                rhs_block[rhs_offset + 2] = w2;
                rhs_block[rhs_offset + 3] = w3;
                rhs_block[rhs_offset + 4] = w4;
                rhs_block[rhs_offset + 5] = w5;
                rhs_block[rhs_offset + 6] = w6;
                rhs_block[rhs_offset + 7] = w7;
            }
        }
        threadgroup_barrier(metal::mem_flags::mem_threadgroup);
        #pragma clang loop unroll(full)
        for (uint k = 0; k < Bk; k += 8) {
            #pragma clang loop unroll(full)
            for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
                const uint row_index_in_out_tile = m_subtile_ / 8;
                metal::simdgroup_float8x8 lhs_frag;

                simdgroup_load(lhs_frag, lhs_block, Bk, ulong2(k, m_subtile_ + row_sg_offset));
                #pragma clang loop unroll(full)
                for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
                    const uint col_index_in_out_tile = n_subtile_ / 8;
                    const uint current_index_out_tile =
                        row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
                    metal::simdgroup_float8x8 rhs_frag;
                    simdgroup_load(rhs_frag, rhs_block, Bk, ulong2(k, n_subtile_ + col_sg_offset), true);

                    simdgroup_multiply_accumulate(OutTiles[current_index_out_tile], lhs_frag, rhs_frag,
                        OutTiles[current_index_out_tile]);
                }
            }
        }
        threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    }

    // Epilogue.
    threadgroup float scratch[Bm * Bn];
    #pragma clang loop unroll(full)
    for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
        const uint col_index_in_out_tile = n_subtile_ / 8;
        const uint local_col_offset = col_sg_offset + n_subtile_;
        #pragma clang loop unroll(full)
        for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
            const uint row_index_in_out_tile = m_subtile_ / 8;
            const uint local_row_offset = row_sg_offset + m_subtile_;
            const uint current_index_out_tile =
                row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
            simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,
                ulong2(local_col_offset, local_row_offset));
        }
    }
    threadgroup float bias_tile[Bn];
    // TODO(ibahmed): vectorize these loads and maybe unroll the loop.
    for (uint c_local = local_tid.x; c_local < Bn; c_local += thread_count_per_tg) {
        const uint c_global = col_tg_offset + c_local;
        bias_tile[c_local] = (c_global < N) ? static_cast<float>(bias[c_global]) : 0.0f;
    }

    threadgroup_barrier(metal::mem_flags::mem_threadgroup);
    const float alpha = 1.702f;
    // TODO(ibahmed): vectorize these stores and maybe unroll the loop.
    for (uint idx = local_tid.x; idx < Bm * Bn / 2; idx += thread_count_per_tg) {
        const uint idx_swish = idx * 2;
        const uint r = idx_swish / Bn;
        const uint c_swish = idx_swish % Bn;

        const uint out_row = row_tg_offset + r;
        const uint out_col = (col_tg_offset / 2) + (c_swish / 2);

        if (out_row < M && out_col < N_output) {
            float acc_swish = scratch[idx_swish] + bias_tile[c_swish];
            float acc_linear = scratch[idx_swish + 1] + bias_tile[c_swish + 1];
            const float swish = metal::min(acc_swish, params.swiglu_max);
            const float linear = metal::clamp(acc_linear, params.swiglu_min, params.swiglu_max);
            const float swish_y = swish / (1.0f + metal::precise::exp(-alpha * swish));
            const float swiglu_y = metal::fma(swish_y, linear, swish_y);
            out[out_row * N_output + out_col] = swiglu_y;
        }
    }
}

kernel void gptoss_f32_mf4w_moe_dense_matmul(
    constant gptoss_moe_dense_matmul_args& params [[ buffer(0) ]],
475
+ const device uint* __restrict__ expert_offsets [[ buffer(1) ]],
476
+ const device float* lhs [[ buffer(2) ]],
477
+ const device uint* weight_blocks [[ buffer(3) ]],
478
+ const device uchar* weight_scales [[ buffer(4) ]],
479
+ const device bfloat* __restrict__ bias [[ buffer(5) ]],
480
+ device float* out [[ buffer(6) ]],
481
+ uint sg_id [[simdgroup_index_in_threadgroup]],
482
+ uint3 threads_per_tg [[threads_per_threadgroup]],
483
+ uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
484
+ uint3 gid [[thread_position_in_grid]],
485
+ uint3 tg_id [[threadgroup_position_in_grid]],
486
+ uint3 local_tid [[thread_position_in_threadgroup]])
487
+ {
488
+ const uint Bm = MOE_DENSE_MATMUL_Bm;
489
+ const uint Bn = MOE_DENSE_MATMUL_Bn;
490
+ const uint Bk = MOE_DENSE_MATMUL_Bk;
491
+ const uint Sg_Bm = MOE_DENSE_MATMUL_Sg_Bm;
492
+ const uint Sg_Bn = MOE_DENSE_MATMUL_Sg_Bn;
493
+ assert(Bm % 8 == 0);
494
+ assert(Bn % 8 == 0);
495
+ assert(Bk % 8 == 0);
496
+ assert(Sg_Bm % 8 == 0);
497
+ assert(Sg_Bn % 8 == 0);
498
+ assert(Bm % Sg_Bm == 0);
499
+ assert(Bn % Sg_Bn == 0);
500
+
501
+ const uint K = params.k;
502
+ const uint N = params.n;
503
+ const uint M = expert_offsets[tg_id.z + 1] - expert_offsets[tg_id.z];
504
+ assert((K % 32) == 0);
505
+ assert((K % 8) == 0);
506
+ assert(N % Bn == 0);
507
+ assert(K % Bk == 0);
508
+ // Get row and col tg.
509
+ const uint row_tg = tg_id.y;
510
+ const uint col_tg = tg_id.x;
511
+ // Get row and col local tid.
512
+ const uint row_tg_offset = row_tg * Bm;
513
+ const uint col_tg_offset = col_tg * Bn;
514
+ if (row_tg_offset >= M || col_tg_offset >= N) {
515
+ return;
516
+ }
517
+ // Move lhs and output according to the passed offset.
518
+ const uint expert_offset = expert_offsets[tg_id.z];
519
+ lhs += expert_offset * K;
520
+ out += expert_offset * N;
521
+
522
+ const uint S = params.weight_blocks_expert_stride_bytes;
523
+ const uint S_scales = params.weight_scales_expert_stride_bytes;
524
+ const uint S_bias = params.bias_expert_stride_bytes;
525
+
526
+ const device char* wb0 = reinterpret_cast<const device char*>(weight_blocks);
527
+ const device char* sc0 = reinterpret_cast<const device char*>(weight_scales);
528
+ const device char* bi0 = reinterpret_cast<const device char*>(bias);
529
+
530
+ weight_blocks = reinterpret_cast<const device uint*>(wb0 + tg_id.z * S);
531
+ weight_scales = reinterpret_cast<const device uchar*>(sc0 + tg_id.z * S_scales);
532
+ bias = reinterpret_cast<const device bfloat*>(bi0 + tg_id.z * S_bias);
533
+
534
+ const uint sg_col_count = Bn / Sg_Bn;
535
+ const uint row_sg = sg_id / sg_col_count;
536
+ const uint col_sg = sg_id % sg_col_count;
537
+
538
+ const uint row_sg_offset = row_sg * Sg_Bm;
539
+ const uint col_sg_offset = col_sg * Sg_Bn;
540
+ // Declare threadgroup blocks.
541
+ threadgroup float lhs_block[Bm * Bk];
542
+ // rhs_block will hold the scaled fp32 weights.
543
+ threadgroup float rhs_block[Bn * Bk];
544
+
545
+ constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);
546
+ // Create an array of simdgroup_float8x8 to hold temp results.
547
+ metal::simdgroup_float8x8 OutTiles[temp_result_size];
548
+ for (uint i = 0; i < temp_result_size; i++) {
549
+ OutTiles[i] = metal::make_filled_simdgroup_matrix<float, 8, 8>(0.0);
550
+ }
551
+ // Linear thread id within TG (we launch 1-D TGs)
552
+ const uint lin_tid = local_tid.x;
553
+
554
+ const uint thread_count_per_tg = threads_per_tg.x * threads_per_tg.y * threads_per_tg.z;
555
+ // Iterate over all Bk blocks.
556
+ for (uint k_offset = 0; k_offset < K; k_offset += Bk) {
557
+ constexpr uint lhs_row_stride = Bk;
558
+ constexpr uint lhs_vec_cols = Bk / 4;
559
+ constexpr uint lhs_vec_total = Bm * lhs_vec_cols;
560
+
561
+ const uint LHS_ITERS = ceil_div(lhs_vec_total, thread_count_per_tg);
562
+
563
+ for (uint t = 0; t < LHS_ITERS; ++t) {
564
+ const uint i = t * thread_count_per_tg + lin_tid;
565
+ if (i < lhs_vec_total) {
566
+ const uint r = i / lhs_vec_cols;
567
+ const uint c4 = i % lhs_vec_cols;
568
+
569
+ const uint gr = row_tg_offset + r;
570
+ const uint gc4 = (k_offset / 4) + c4;
571
+
572
+ threadgroup float4* dst4 =
573
+ reinterpret_cast<threadgroup float4*>(lhs_block + r * lhs_row_stride + (c4 << 2));
574
+ if (gr < M) {
575
+ const device float4* src4 =
576
+ reinterpret_cast<const device float4*>(lhs + gr * K + (gc4 << 2));
577
+
578
+ *dst4 = *src4;
579
+ } else {
580
+ *dst4 = float4(0.0);
581
+ }
582
+ }
583
+ }
584
+
585
+ // Load weights with vector loads.
586
+ constexpr uint rhs_row_stride = Bk;
587
+ constexpr uint weights_per_elem = 8;
588
+ constexpr uint rhs_loads_per_col = Bk / weights_per_elem;
589
+ constexpr uint rhs_loads_total = Bn * rhs_loads_per_col;
590
+ const uint RHS_ITERS = ceil_div(rhs_loads_total, thread_count_per_tg);
591
+ // #pragma clang loop unroll(full)
592
+ for (uint t = 0; t < RHS_ITERS; ++t) {
593
+ const uint i = t * thread_count_per_tg + lin_tid;
594
+ if (i < rhs_loads_total) {
595
+ const uint r = i / rhs_loads_per_col;
596
+ const uint c = i % rhs_loads_per_col;
597
+
598
+ const uint gr = col_tg_offset + r;
599
+ const uint gc = (k_offset / weights_per_elem) + c;
600
+ const uint gc_scale = (k_offset / 32) + (c >> 2);
601
+
602
+ const uint wblock = weight_blocks[gr * (K / weights_per_elem) + gc];
603
+ const float scale =
604
+ as_type<float>(static_cast<uint>(weight_scales[gr * (K / 32) + gc_scale]) << 23);
605
+
606
+ uint wblock0246 = (wblock + wblock);
607
+ uint wblock1357 = (wblock >> 3);
608
+ wblock0246 &= 0x1E1E1E1Eu;
609
+ wblock1357 &= 0x1E1E1E1Eu;
610
+
611
+ wblock0246 += 0x70707070u;
612
+ wblock1357 += 0x70707070u;
613
+ wblock0246 &= 0x8E8E8E8Eu;
614
+ wblock1357 &= 0x8E8E8E8Eu;
615
+
616
+ uint wblock26 = (wblock0246) & 0xFF00FF00u;
617
+ uint wblock04 = ((wblock0246 << 8)) & 0xFF00FF00u;
618
+ uint wblock37 = (wblock1357) & 0xFF00FF00u;
619
+ uint wblock15 = ((wblock1357 << 8)) & 0xFF00FF00u;
620
+
621
+ half4 wblock0426 = as_type<half4>(uint2(wblock04, wblock26));
622
+ half4 wblock1537 = as_type<half4>(uint2(wblock15, wblock37));
623
+
624
+ const float w0 = float(wblock0426.x) * scale;
625
+ const float w1 = float(wblock1537.x) * scale;
626
+ const float w2 = float(wblock0426.z) * scale;
627
+ const float w3 = float(wblock1537.z) * scale;
628
+ const float w4 = float(wblock0426.y) * scale;
629
+ const float w5 = float(wblock1537.y) * scale;
630
+ const float w6 = float(wblock0426.w) * scale;
631
+ const float w7 = float(wblock1537.w) * scale;
632
+ const uint rhs_offset = r * rhs_row_stride + c * 8;
633
+ rhs_block[rhs_offset] = w0;
634
+ rhs_block[rhs_offset + 1] = w1;
635
+ rhs_block[rhs_offset + 2] = w2;
636
+ rhs_block[rhs_offset + 3] = w3;
637
+ rhs_block[rhs_offset + 4] = w4;
638
+ rhs_block[rhs_offset + 5] = w5;
639
+ rhs_block[rhs_offset + 6] = w6;
640
+ rhs_block[rhs_offset + 7] = w7;
641
+ }
642
+ }
643
+ threadgroup_barrier(metal::mem_flags::mem_threadgroup);
644
+ #pragma clang loop unroll(full)
645
+ for (uint k = 0; k < Bk; k += 8) {
646
+ #pragma clang loop unroll(full)
647
+ for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
648
+ const uint row_index_in_out_tile = m_subtile_ / 8;
649
+ metal::simdgroup_float8x8 lhs_frag;
650
+
651
+ simdgroup_load(lhs_frag, lhs_block, Bk, ulong2(k, m_subtile_ + row_sg_offset));
652
+ #pragma clang loop unroll(full)
653
+ for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
654
+ const uint col_index_in_out_tile = n_subtile_ / 8;
655
+ const uint current_index_out_tile =
656
+ row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
657
+ metal::simdgroup_float8x8 rhs_frag;
658
+ simdgroup_load(rhs_frag, rhs_block, Bk, ulong2(k, n_subtile_ + col_sg_offset), true);
659
+ simdgroup_multiply_accumulate(OutTiles[current_index_out_tile], lhs_frag, rhs_frag,
660
+ OutTiles[current_index_out_tile]);
661
+ }
662
+ }
663
+ }
664
+ threadgroup_barrier(metal::mem_flags::mem_threadgroup);
665
+ }
666
+
667
+ // Epilogue.
668
+ threadgroup float scratch[Bm * Bn];
669
+ #pragma clang loop unroll(full)
670
+ for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
671
+ const uint col_index_in_out_tile = n_subtile_ / 8;
672
+ const uint local_col_offset = col_sg_offset + n_subtile_;
673
+ #pragma clang loop unroll(full)
674
+ for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
675
+ const uint row_index_in_out_tile = m_subtile_ / 8;
676
+ const uint local_row_offset = row_sg_offset + m_subtile_;
677
+ const uint current_index_out_tile =
678
+ row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
679
+ simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,
680
+ ulong2(local_col_offset, local_row_offset));
681
+ }
682
+ }
683
+ threadgroup float bias_tile[Bn];
684
+ for (uint c_local = local_tid.x; c_local < Bn; c_local += thread_count_per_tg) {
685
+ const uint c_global = col_tg_offset + c_local;
686
+ bias_tile[c_local] = (c_global < N) ? static_cast<float>(bias[c_global]) : 0.0f;
687
+ }
688
+
689
+ threadgroup_barrier(metal::mem_flags::mem_threadgroup);
690
+ for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) {
691
+ const uint r = idx / Bn;
692
+ const uint c = idx % Bn;
693
+
694
+ const uint out_row = row_tg_offset + r;
695
+ const uint out_col = col_tg_offset + c;
696
+
697
+ if (out_row < M && out_col < N) {
698
+ float acc = scratch[idx] + bias_tile[c];
699
+ out[out_row * N + out_col] = acc;
700
+ }
701
+ }
702
+ }
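
The bit-twiddling in the weight-load loops above unpacks eight 4-bit e2m1 (MXFP4-style) codes per `uint` directly into `half` lanes, with the per-32-weight `uchar` scale placed into a float's exponent field via the `<< 23` trick. A plain host-side reference of the same decode, useful for checking the kernel, can be written with a lookup table; this is a sketch assuming the standard e2m1 value table, and `decode_mxfp4_block` is a hypothetical helper, not part of the repo:

```cpp
#include <cstdint>
#include <cstring>

// Standard e2m1 code -> value table: bit 3 is the sign,
// bits 2:1 the exponent, bit 0 the mantissa.
static const float kE2M1Table[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// Decodes one uint of 8 packed codes; scale_byte is a raw exponent byte, so
// placing it in a float's exponent field yields 2^(scale_byte - 127), which
// mirrors the `static_cast<uint>(...) << 23` + as_type<float> in the kernel.
void decode_mxfp4_block(uint32_t wblock, uint8_t scale_byte, float out[8]) {
    const uint32_t scale_bits = static_cast<uint32_t>(scale_byte) << 23;
    float scale;
    std::memcpy(&scale, &scale_bits, sizeof(scale));
    for (int i = 0; i < 8; i++) {
        out[i] = kE2M1Table[(wblock >> (4 * i)) & 0xFu] * scale;
    }
}
```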
gptoss_kernels/source/random.metal ADDED
@@ -0,0 +1,97 @@
+ #include <metal_integer>
+ #include <metal_math>
+
+ #include <internal/kernel-args.h>
+
+ #pragma METAL fp math_mode(safe)
+ #pragma METAL fp contract(off)
+
+
+ inline static uint rng_squares32(ulong offset, ulong seed) {
+ const ulong y = offset * seed;
+ const ulong z = y + seed;
+
+ /* Round 1 */
+ ulong x = y * y + y;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 2 */
+ x = x * x + z;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 3 */
+ x = x * x + y;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 4 */
+ x = x * x + z;
+ return as_type<uint2>(x).y;
+ }
+
+ kernel void gptoss_u32_fill_random(
+ constant gptoss_u32_fill_random_args& args [[ buffer(0) ]],
+ device uint* output [[ buffer(1) ]],
+ uint gid [[threadgroup_position_in_grid]],
+ uint tid [[thread_position_in_threadgroup]],
+ uint threadgroup_size [[ threads_per_threadgroup ]])
+ {
+ const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
+ const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
+ const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
+ const ulong thread_start = threadgroup_start + tid;
+ uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
+
+ output += thread_start;
+ ulong offset = args.offset + thread_start;
+ for (; num_iter != 0; num_iter--) {
+ *output = rng_squares32(offset, args.seed);
+ output += threadgroup_size;
+ offset += threadgroup_size;
+ }
+ }
+
+ kernel void gptoss_f32_fill_random(
+ constant gptoss_f32_fill_random_args& args [[ buffer(0) ]],
+ device float* output [[ buffer(1) ]],
+ uint gid [[threadgroup_position_in_grid]],
+ uint tid [[thread_position_in_threadgroup]],
+ uint threadgroup_size [[ threads_per_threadgroup ]])
+ {
+ const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
+ const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
+ const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
+ const ulong thread_start = threadgroup_start + tid;
+ uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
+
+ output += thread_start;
+ ulong offset = args.offset + thread_start;
+ for (; num_iter != 0; num_iter--) {
+ const uint word = rng_squares32(offset, args.seed);
+ *output = metal::fma(static_cast<float>(as_type<int>(word)), args.scale, args.bias);
+ output += threadgroup_size;
+ offset += threadgroup_size;
+ }
+ }
+
+ kernel void gptoss_bf16_fill_random(
+ constant gptoss_f32_fill_random_args& args [[ buffer(0) ]],
+ device bfloat* output [[ buffer(1) ]],
+ uint gid [[threadgroup_position_in_grid]],
+ uint tid [[thread_position_in_threadgroup]],
+ uint threadgroup_size [[ threads_per_threadgroup ]])
+ {
+ const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
+ const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
+ const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
+ const ulong thread_start = threadgroup_start + tid;
+ uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
+
+ output += thread_start;
+ ulong offset = args.offset + thread_start;
+ for (; num_iter != 0; num_iter--) {
+ const uint word = rng_squares32(offset, args.seed);
+ *output = static_cast<bfloat>(metal::fma(static_cast<float>(as_type<int>(word)), args.scale, args.bias));
+ output += threadgroup_size;
+ offset += threadgroup_size;
+ }
+ }
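
`rng_squares32` is a counter-based generator in the style of Widynski's "Squares" RNG: four rounds of squaring with 32-bit rotations, keyed by `(offset, seed)`, returning the upper half of the final state. A host-side mirror, a sketch for checking bit-exact parity with the kernel (not part of the repo), is straightforward:

```cpp
#include <cstdint>

// Host-side mirror of rng_squares32: metal::rotate(x, 32) on a 64-bit value
// simply swaps the two 32-bit halves, and as_type<uint2>(x).y is the high word.
uint32_t rng_squares32_ref(uint64_t offset, uint64_t seed) {
    const uint64_t y = offset * seed;
    const uint64_t z = y + seed;
    uint64_t x = y * y + y;        // round 1
    x = (x << 32) | (x >> 32);
    x = x * x + z;                 // round 2
    x = (x << 32) | (x >> 32);
    x = x * x + y;                 // round 3
    x = (x << 32) | (x >> 32);
    x = x * x + z;                 // round 4
    return static_cast<uint32_t>(x >> 32);
}
```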
gptoss_kernels/source/rmsnorm.metal ADDED
@@ -0,0 +1,58 @@
+ #include <metal_compute>
+ #include <metal_math>
+ #include <metal_simdgroup>
+
+ #include <internal/kernel-args.h>
+
+ #pragma METAL fp math_mode(safe)
+ #pragma METAL fp contract(off)
+
+
+ [[max_total_threads_per_threadgroup(1024)]]
+ kernel void gptoss_f32_bf16w_rmsnorm(
+ constant gptoss_rmsnorm_args& args [[ buffer(0) ]],
+ const device float4* input [[ buffer(1) ]],
+ const device bfloat4* weights [[ buffer(2) ]],
+ device float4* output [[ buffer(3) ]],
+ const device gptoss_control* control [[ buffer(4) ]],
+ uint gid [[threadgroup_position_in_grid]],
+ uint tid [[thread_position_in_threadgroup]],
+ uint threadgroup_size [[ threads_per_threadgroup ]])
+ {
+ const uint simdgroup_size = 32;
+ threadgroup float threadgroup_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
+
+ input += gid * args.num_vecs;
+ output += gid * args.num_vecs;
+
+ float4 sumsq4 = 0.0f;
+ for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
+ const float4 val = input[i];
+ sumsq4 = metal::fma(val, val, sumsq4);
+ }
+
+ // Tree-reduce sumsq within the thread, then all-reduce within the threadgroup.
+ const float2 sumsq2 = sumsq4.xy + sumsq4.zw;
+ float sumsq = sumsq2.x + sumsq2.y;
+ // Warning: this all-reduce is only valid for a simdgroup size of 32 and a threadgroup of exactly 32*32=1024 threads.
+ sumsq = metal::simd_sum(sumsq);
+ if (metal::simd_is_first()) {
+ const uint simdgroup_idx = tid / simdgroup_size;
+ threadgroup_buffer[simdgroup_idx] = sumsq;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ const uint simdgroup_tid = tid % simdgroup_size;
+ sumsq = threadgroup_buffer[simdgroup_tid];
+ sumsq = metal::simd_sum(sumsq);
+
+ const float avgsq = sumsq / args.num_channels;
+ const float scale = metal::precise::rsqrt(avgsq + args.epsilon);
+ for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
+ const float4 val = input[i] * scale;
+ const float4 weight_val = static_cast<float4>(weights[i]);
+ output[i] = val * weight_val;
+ }
+ }
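
For reference, the reduction above computes root-mean-square normalization with a bf16 weight vector: `y = x * rsqrt(mean(x^2) + eps) * w`. A scalar sketch of the same math (a hypothetical helper with weights assumed pre-converted to float, not part of the repo):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for gptoss_f32_bf16w_rmsnorm: mean of squares over all
// channels, then one scale applied element-wise together with the weights.
std::vector<float> rmsnorm_ref(const std::vector<float>& x,
                               const std::vector<float>& w, float epsilon) {
    double sumsq = 0.0;
    for (float v : x) sumsq += double(v) * double(v);
    const float scale =
        1.0f / std::sqrt(static_cast<float>(sumsq / x.size()) + epsilon);
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); i++) y[i] = x[i] * scale * w[i];
    return y;
}
```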
gptoss_kernels/source/rope.metal ADDED
@@ -0,0 +1,43 @@
+ #include <metal_common>
+ #include <metal_math>
+
+ #include <internal/kernel-args.h>
+
+ #pragma METAL fp math_mode(safe)
+ #pragma METAL fp contract(off)
+
+
+ // Each thread handles 2 head elements.
+ // Each simdgroup handles one head (64 head elements).
+
+ kernel void gptoss_f32_rope(
+ constant gptoss_rope_args& args [[ buffer(0) ]],
+ device float2* activations [[ buffer(1) ]],
+ const device gptoss_control* control [[ buffer(2) ]],
+ uint2 gid [[thread_position_in_grid]])
+ {
+ const uint num_head_dims = 64;
+ if (control->abort != 0) {
+ return;
+ }
+
+ const float dim_idx = static_cast<float>(gid.x % (num_head_dims / 2));
+ const uint token_idx = args.token_offset + gid.y;
+ activations += gid.y * args.token_stride + gid.x;
+
+ const float2 input_vals = *activations;
+ const float inv_extrapolation_freq = metal::precise::exp(dim_idx * args.freq_scale);
+ const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
+ const float alpha = metal::saturate(metal::fma(dim_idx, args.yarn_scale, args.yarn_offset));
+ const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);
+
+ const float phi = static_cast<float>(token_idx) * inv_freq;
+ const float yarn_multiplier = args.yarn_multiplier;
+ float cosphi;
+ const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
+ cosphi *= yarn_multiplier;
+
+ const float output_re = input_vals.x * cosphi - input_vals.y * sinphi;
+ const float output_im = input_vals.x * sinphi + input_vals.y * cosphi;
+ *activations = (float2) { output_re, output_im };
+ }
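
The kernel blends an extrapolation frequency with an interpolated one per YaRN, clamps the blend factor to [0, 1], and then applies a plain complex rotation scaled by `yarn_multiplier`. A per-pair scalar sketch of the same math (parameter names taken from the kernel args; the helper itself is hypothetical):

```cpp
#include <cmath>
#include <utility>

// One (re, im) pair of one head, matching gptoss_f32_rope: inv_freq is a
// YaRN mix of extrapolation and interpolation frequencies.
std::pair<float, float> rope_pair_ref(
    float re, float im, float dim_idx, float token_idx,
    float freq_scale, float interpolation_scale,
    float yarn_scale, float yarn_offset, float yarn_multiplier) {
    const float inv_extrapolation_freq = std::exp(dim_idx * freq_scale);
    const float inv_interpolation_freq = inv_extrapolation_freq * interpolation_scale;
    const float alpha =
        std::fmin(std::fmax(dim_idx * yarn_scale + yarn_offset, 0.0f), 1.0f);
    const float inv_freq = inv_extrapolation_freq
        + (inv_interpolation_freq - inv_extrapolation_freq) * alpha;  // metal::mix
    const float phi = token_idx * inv_freq;
    const float cosphi = std::cos(phi) * yarn_multiplier;
    const float sinphi = std::sin(phi) * yarn_multiplier;
    return { re * cosphi - im * sinphi, re * sinphi + im * cosphi };
}
```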
gptoss_kernels/source/sample.metal ADDED
@@ -0,0 +1,209 @@
+ #include <metal_compute>
+ #include <metal_integer>
+ #include <metal_math>
+ #include <metal_simdgroup>
+
+ #include <internal/kernel-args.h>
+
+ #pragma METAL fp math_mode(safe)
+ #pragma METAL fp contract(off)
+
+
+ inline static uint rng_squares32(ulong offset, ulong seed) {
+ const ulong y = offset * seed;
+ const ulong z = y + seed;
+
+ /* Round 1 */
+ ulong x = y * y + y;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 2 */
+ x = x * x + z;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 3 */
+ x = x * x + y;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 4 */
+ x = x * x + z;
+ return as_type<uint2>(x).y;
+ }
+
+ kernel void gptoss_f32_softmax(
+ constant gptoss_softmax_args& args [[ buffer(0) ]],
+ const device float* score [[ buffer(1) ]],
+ const device uint2* argmax [[ buffer(2) ]],
+ device float* prob [[ buffer(3) ]],
+ device float* sum [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(5) ]],
+ uint tidx [[thread_index_in_threadgroup]],
+ uint2 gid [[threadgroup_position_in_grid]],
+ uint2 threadgroup_size [[threads_per_threadgroup]],
+ uint simdgroup_tid [[thread_index_in_simdgroup]],
+ uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+ uint num_simdgroups [[simdgroups_per_threadgroup]])
+ {
+ threadgroup float threadgroup_sumexp[32];
+ if (control->abort != 0) {
+ return;
+ }
+
+ score += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
+ prob += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
+ sum += gid.y * args.max_threadgroups;
+
+ uint max_bits = argmax[gid.y].y;
+ if (static_cast<int>(max_bits) >= 0) {
+ max_bits ^= 0x7FFFFFFFu;
+ }
+ const float max_val = as_type<float>(max_bits);
+ float sum_exp = 0.0f;
+ const uint num_vecs_per_threadgroup = metal::min(args.num_vecs - gid.x * args.num_vecs_per_threadgroup, args.num_vecs_per_threadgroup);
+ for (uint i = tidx; i < num_vecs_per_threadgroup; i += threadgroup_size.x) {
+ const float score_val = score[i];
+ const float prob_val = metal::precise::exp((score_val - max_val) * args.temperature);
+ prob[i] = prob_val;
+ sum_exp += prob_val;
+ }
+ sum_exp = metal::simd_sum(sum_exp);
+ if (metal::simd_is_first()) {
+ threadgroup_sumexp[simdgroup_idx] = sum_exp;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (simdgroup_idx == 0) {
+ // Sum-reduce threadgroup_sumexp.
+ sum_exp = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ sum_exp = threadgroup_sumexp[simdgroup_tid];
+ }
+ sum_exp = metal::simd_sum(sum_exp);
+ if (metal::simd_is_first()) {
+ sum[gid.x] = sum_exp;
+ }
+ }
+ }
+
+ [[max_total_threads_per_threadgroup(1024)]]
+ kernel void gptoss_f32_sample(
+ constant gptoss_sample_args& args [[ buffer(0) ]],
+ device const float* prob [[ buffer(1) ]],
+ device const float* sum [[ buffer(2) ]],
+ device uint* prediction [[ buffer(3) ]],
+ device gptoss_control* control [[ buffer(4) ]],
+ uint tid [[thread_position_in_threadgroup]],
+ uint threadgroup_size [[threads_per_threadgroup]],
+ uint simdgroup_tid [[thread_index_in_simdgroup]],
+ uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+ uint num_simdgroups [[simdgroups_per_threadgroup]])
+ {
+ threadgroup float threadgroup_sum_buffer[32];
+ threadgroup uint threadgroup_idx_buffer[32];
+ threadgroup float threadgroup_cumsum_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
+
+ const uint sample_word = rng_squares32(args.rng_offset, args.rng_seed);
+ float sample_cdf = static_cast<float>(sample_word & 0x00FFFFFFu) * 0x1.0p-24f;
+
+ float cumsum = 0.0f;
+ if (tid < args.num_blocks) {
+ cumsum = sum[tid];
+ }
+ cumsum = metal::simd_prefix_inclusive_sum(cumsum);
+ if (simdgroup_tid == 31) {
+ threadgroup_sum_buffer[simdgroup_idx] = cumsum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];
+ if (simdgroup_tid < simdgroup_idx) {
+ threadgroup_cumsum = threadgroup_sum;
+ }
+ }
+ threadgroup_sum = metal::simd_sum(threadgroup_sum);
+ cumsum += metal::simd_sum(threadgroup_cumsum);
+
+ sample_cdf *= threadgroup_sum;
+ sample_cdf = metal::max(sample_cdf, 0x1.0p-149f);
+
+ // Find the block: the smallest block index whose inclusive prefix sum reaches sample_cdf.
+ uint block_idx = args.num_blocks;
+ float block_sum = cumsum;
+ if (tid >= args.num_blocks - 1) {
+ block_idx = args.num_blocks - 1;
+ block_sum = 0.0f;
+ } else if (cumsum >= sample_cdf) {
+ block_idx = tid;
+ block_sum = 0.0f;
+ }
+ block_idx = metal::simd_min(block_idx);
+ block_sum = metal::simd_max(block_sum);
+ if (simdgroup_tid == 0) {
+ threadgroup_idx_buffer[simdgroup_idx] = block_idx;
+ threadgroup_cumsum_buffer[simdgroup_idx] = block_sum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (simdgroup_tid < num_simdgroups) {
+ block_idx = threadgroup_idx_buffer[simdgroup_tid];
+ block_sum = threadgroup_cumsum_buffer[simdgroup_tid];
+ }
+ block_idx = metal::simd_min(block_idx);
+ block_sum = metal::simd_max(block_sum);
+
+ const uint block_start = args.num_dims_per_block * block_idx;
+ const uint block_end = metal::min(block_start + args.num_dims_per_block, args.num_dims);
+ uint offset = block_start + tid;
+ float accumulated_sum = block_sum;
+ uint sample_idx;
+
+ // This loop must be threadgroup-uniform.
+ do {
+ // Find the token: the smallest offset whose running sum reaches sample_cdf.
+ float cumsum = 0.0f;
+ if (offset < block_end) {
+ cumsum = prob[offset];
+ }
+ cumsum = metal::simd_prefix_inclusive_sum(cumsum);
+ if (simdgroup_tid == 31) {
+ threadgroup_sum_buffer[simdgroup_idx] = cumsum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];
+ if (simdgroup_tid < simdgroup_idx) {
+ threadgroup_cumsum = threadgroup_sum;
+ }
+ }
+ threadgroup_sum = metal::simd_sum(threadgroup_sum);
+ cumsum += metal::simd_sum(threadgroup_cumsum);
+ cumsum += accumulated_sum;
+
+ sample_idx = block_end;
+ if (offset >= block_end) {
+ // Trigger loop exit, with the last token in the block being sampled if no other candidate was found.
+ sample_idx = block_end - 1;
+ } else if (cumsum >= sample_cdf) {
+ sample_idx = offset;
+ }
+ sample_idx = metal::simd_min(sample_idx);
+ if (simdgroup_tid == 0) {
+ threadgroup_idx_buffer[simdgroup_idx] = sample_idx;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (simdgroup_tid < num_simdgroups) {
+ sample_idx = threadgroup_idx_buffer[simdgroup_tid];
+ }
+ sample_idx = metal::simd_min(sample_idx);
+
+ offset += threadgroup_size;
+ accumulated_sum += threadgroup_sum;
+ } while (sample_idx == block_end);
+
+ if (tid == 0) {
+ *prediction = sample_idx;
+ }
+ }
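
End to end, the two kernels implement inverse-CDF sampling from a max-shifted softmax: `gptoss_f32_softmax` writes unnormalized probabilities plus per-block sums, and `gptoss_f32_sample` scans the block prefix sums, then the within-block prefix sums, until the running mass reaches the sampled target. A sequential sketch of the same contract (the helper is hypothetical; note that `temperature` multiplies the shifted scores exactly as in the kernel, so it behaves as an inverse temperature):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sequential reference: u is a uniform sample in [0, 1), max_val the
// precomputed maximum score (produced by an earlier argmax pass).
uint32_t sample_ref(const std::vector<float>& score, float temperature,
                    float max_val, float u) {
    std::vector<double> prob(score.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < score.size(); i++) {
        prob[i] = std::exp((score[i] - max_val) * temperature);
        sum += prob[i];
    }
    const double target = u * sum;
    double cumsum = 0.0;
    for (std::size_t i = 0; i < prob.size(); i++) {
        cumsum += prob[i];
        if (cumsum >= target) return static_cast<uint32_t>(i);
    }
    return static_cast<uint32_t>(prob.size() - 1);  // guard against rounding
}
```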
gptoss_kernels/source/scatter.metal ADDED
@@ -0,0 +1,65 @@
+ #include <internal/kernel-args.h>
+ #include <metal_integer>
+ #include <metal_math>
+ #include <metal_stdlib>
+
+ // TODO(ibrahim): This is not optimal as each thread only scatters a single float4. To amortize the
+ // cost of reading the expert id and offset for a token, we should let each thread scatter several
+ // float4s.
+ kernel void gptoss_f32_scatter_e4(
+ constant gptoss_scatter_args& args [[ buffer(0) ]],
+ const device float* in [[ buffer(1) ]],
+ const device gptoss_expert_prediction* __restrict__ expert_predictions [[ buffer(2) ]],
+ const device uint* __restrict__ expert_offsets [[ buffer(3) ]],
+ const device uint* __restrict__ intra_expert_offsets [[ buffer(4) ]],
+ device float* out [[ buffer(5) ]],
+ uint3 gid [[thread_position_in_grid]])
+ {
+ const uint total_tokens = args.tokens;
+ const uint active_experts_per_token = args.active_experts_per_token;
+ const uint embedding_dim = args.token_stride;
+ assert(embedding_dim % 4 == 0);
+ // Hard-coded to top-4 for now.
+ assert(active_experts_per_token == 4);
+ const uint row_in = gid.y;
+ if (row_in >= total_tokens) {
+ return;
+ }
+ // Consecutive threads in a tg read consecutive columns of the input.
+ const uint col_in_vec4 = gid.x;
+ const uint col_in = col_in_vec4 * 4;
+ if (col_in >= embedding_dim) {
+ return;
+ }
+ // Pointer to the piece of the input that we will copy to the top-4 experts.
+ const device float4* src4 =
+ reinterpret_cast<const device float4*>(in + row_in * embedding_dim + col_in);
+
+ // Get the 4 destinations -- 4 experts.
+ const uint base = row_in * active_experts_per_token;
+ const uint expert0_id = expert_predictions[base].expert_id;
+ const uint expert1_id = expert_predictions[base + 1].expert_id;
+ const uint expert2_id = expert_predictions[base + 2].expert_id;
+ const uint expert3_id = expert_predictions[base + 3].expert_id;
+ const uint expert0_offset = expert_offsets[expert0_id];
+ const uint expert1_offset = expert_offsets[expert1_id];
+ const uint expert2_offset = expert_offsets[expert2_id];
+ const uint expert3_offset = expert_offsets[expert3_id];
+ const uint expert0_intra_expert_offset = intra_expert_offsets[base];
+ const uint expert1_intra_expert_offset = intra_expert_offsets[base + 1];
+ const uint expert2_intra_expert_offset = intra_expert_offsets[base + 2];
+ const uint expert3_intra_expert_offset = intra_expert_offsets[base + 3];
+ device float4* dst4_0 = reinterpret_cast<device float4*>(
+ out + (expert0_offset + expert0_intra_expert_offset) * embedding_dim + col_in);
+ device float4* dst4_1 = reinterpret_cast<device float4*>(
+ out + (expert1_offset + expert1_intra_expert_offset) * embedding_dim + col_in);
+ device float4* dst4_2 = reinterpret_cast<device float4*>(
+ out + (expert2_offset + expert2_intra_expert_offset) * embedding_dim + col_in);
+ device float4* dst4_3 = reinterpret_cast<device float4*>(
+ out + (expert3_offset + expert3_intra_expert_offset) * embedding_dim + col_in);
+ const float4 data = *src4;
+ *dst4_0 = data;
+ *dst4_1 = data;
+ *dst4_2 = data;
+ *dst4_3 = data;
+ }
gptoss_kernels/source/sdpa.metal ADDED
@@ -0,0 +1,293 @@
+ #include <metal_geometric>
+ #include <metal_integer>
+ #include <metal_math>
+ #include <metal_compute>
+ #include <metal_simdgroup>
+
+ #include <internal/kernel-args.h>
+
+ #pragma METAL fp math_mode(safe)
+ #pragma METAL fp contract(off)
+
+ // Each threadgroup handles 8 Q heads / 1 KV head for 1 token.
+
+ kernel void gptoss_f32_sdpa_q8_d64(
+ constant gptoss_sdpa_args& args [[ buffer(0) ]],
+ const device float* q [[ buffer(1) ]],
+ const device float* kv [[ buffer(2) ]],
+ const device bfloat* s [[ buffer(3) ]],
+ device float* output [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(6) ]],
+ threadgroup void* threadgroup_buffer [[ threadgroup(0) ]],
+ uint2 gid [[threadgroup_position_in_grid]],
+ uint2 tid [[thread_position_in_threadgroup]],
+ uint simdgroup_tid [[thread_index_in_simdgroup]],
+ uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+ uint num_simdgroups [[simdgroups_per_threadgroup]])
+ {
+ const uint simdgroup_size = 32;
+ if (control->abort != 0) {
+ return;
+ }
+
+ const uint num_q_heads = 64;
+ const uint head_dim = 64;
+ const uint qmul = 8;
+
+ const uint token_stride = 2 * head_dim;
+
+ const uint qt = gid.x; // Q token index
+ const uint h = gid.y; // KV head index
+
+ q += qt * args.qkv_dim + h * (qmul * head_dim);
+ kv += h * args.kv_stride;
+ output += qt * (num_q_heads * head_dim) + h * (qmul * head_dim);
+
+ float m0 = static_cast<float>(s[h * qmul + 0]);
+ float m1 = static_cast<float>(s[h * qmul + 1]);
+ float m2 = static_cast<float>(s[h * qmul + 2]);
+ float m3 = static_cast<float>(s[h * qmul + 3]);
+ float m4 = static_cast<float>(s[h * qmul + 4]);
+ float m5 = static_cast<float>(s[h * qmul + 5]);
+ float m6 = static_cast<float>(s[h * qmul + 6]);
+ float m7 = static_cast<float>(s[h * qmul + 7]);
+
+ float l0 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l1 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l2 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l3 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l4 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l5 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l6 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l7 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+
+ float2 out0 = 0.0f;
+ float2 out1 = 0.0f;
+ float2 out2 = 0.0f;
+ float2 out3 = 0.0f;
+ float2 out4 = 0.0f;
+ float2 out5 = 0.0f;
+ float2 out6 = 0.0f;
+ float2 out7 = 0.0f;
+
+ float2 q0 = reinterpret_cast<const device float2*>(q + 0 * head_dim)[simdgroup_tid];
+ float2 q1 = reinterpret_cast<const device float2*>(q + 1 * head_dim)[simdgroup_tid];
+ float2 q2 = reinterpret_cast<const device float2*>(q + 2 * head_dim)[simdgroup_tid];
+ float2 q3 = reinterpret_cast<const device float2*>(q + 3 * head_dim)[simdgroup_tid];
+ float2 q4 = reinterpret_cast<const device float2*>(q + 4 * head_dim)[simdgroup_tid];
+ float2 q5 = reinterpret_cast<const device float2*>(q + 5 * head_dim)[simdgroup_tid];
+ float2 q6 = reinterpret_cast<const device float2*>(q + 6 * head_dim)[simdgroup_tid];
+ float2 q7 = reinterpret_cast<const device float2*>(q + 7 * head_dim)[simdgroup_tid];
+
+ const uint kt_end = qt + args.num_kv_tokens + 1;
+ const uint kt_start = metal::subsat(kt_end, args.window) + simdgroup_idx;
+ kv += token_stride * kt_start;
+ for (uint kt = kt_start; kt < kt_end; kt += num_simdgroups) {
+ const float2 kval = reinterpret_cast<const device float2*>(kv)[simdgroup_tid];
+
+ float qk0 = metal::dot(q0, kval);
+ float qk1 = metal::dot(q1, kval);
+ float qk2 = metal::dot(q2, kval);
+ float qk3 = metal::dot(q3, kval);
+ float qk4 = metal::dot(q4, kval);
+ float qk5 = metal::dot(q5, kval);
+ float qk6 = metal::dot(q6, kval);
+ float qk7 = metal::dot(q7, kval);
+
+ qk0 = metal::simd_sum(qk0);
+ qk1 = metal::simd_sum(qk1);
+ qk2 = metal::simd_sum(qk2);
+ qk3 = metal::simd_sum(qk3);
+ qk4 = metal::simd_sum(qk4);
+ qk5 = metal::simd_sum(qk5);
+ qk6 = metal::simd_sum(qk6);
+ qk7 = metal::simd_sum(qk7);
+
+ const float new_m0 = metal::max(m0, qk0);
+ const float new_m1 = metal::max(m1, qk1);
+ const float new_m2 = metal::max(m2, qk2);
+ const float new_m3 = metal::max(m3, qk3);
+ const float new_m4 = metal::max(m4, qk4);
+ const float new_m5 = metal::max(m5, qk5);
+ const float new_m6 = metal::max(m6, qk6);
+ const float new_m7 = metal::max(m7, qk7);
+
+ const float alpha0 = metal::fast::exp(m0 - new_m0);
+ const float alpha1 = metal::fast::exp(m1 - new_m1);
+ const float alpha2 = metal::fast::exp(m2 - new_m2);
+ const float alpha3 = metal::fast::exp(m3 - new_m3);
+ const float alpha4 = metal::fast::exp(m4 - new_m4);
+ const float alpha5 = metal::fast::exp(m5 - new_m5);
+ const float alpha6 = metal::fast::exp(m6 - new_m6);
+ const float alpha7 = metal::fast::exp(m7 - new_m7);
+
+ qk0 = metal::fast::exp(qk0 - new_m0);
+ qk1 = metal::fast::exp(qk1 - new_m1);
+ qk2 = metal::fast::exp(qk2 - new_m2);
+ qk3 = metal::fast::exp(qk3 - new_m3);
+ qk4 = metal::fast::exp(qk4 - new_m4);
+ qk5 = metal::fast::exp(qk5 - new_m5);
+ qk6 = metal::fast::exp(qk6 - new_m6);
+ qk7 = metal::fast::exp(qk7 - new_m7);
+
+ l0 = metal::fma(l0, alpha0, qk0);
+ l1 = metal::fma(l1, alpha1, qk1);
+ l2 = metal::fma(l2, alpha2, qk2);
+ l3 = metal::fma(l3, alpha3, qk3);
+ l4 = metal::fma(l4, alpha4, qk4);
+ l5 = metal::fma(l5, alpha5, qk5);
+ l6 = metal::fma(l6, alpha6, qk6);
+ l7 = metal::fma(l7, alpha7, qk7);
+
+ m0 = new_m0;
+ m1 = new_m1;
+ m2 = new_m2;
+ m3 = new_m3;
+ m4 = new_m4;
+ m5 = new_m5;
+ m6 = new_m6;
+ m7 = new_m7;
+
+ const float2 vval = reinterpret_cast<const device float2*>(kv + head_dim)[simdgroup_tid];
+ kv += token_stride * num_simdgroups;
+ out0 = metal::fma(vval, qk0, out0 * alpha0);
+ out1 = metal::fma(vval, qk1, out1 * alpha1);
+ out2 = metal::fma(vval, qk2, out2 * alpha2);
+ out3 = metal::fma(vval, qk3, out3 * alpha3);
+ out4 = metal::fma(vval, qk4, out4 * alpha4);
+ out5 = metal::fma(vval, qk5, out5 * alpha5);
+ out6 = metal::fma(vval, qk6, out6 * alpha6);
+ out7 = metal::fma(vval, qk7, out7 * alpha7);
+ }
+ if (num_simdgroups > 1) {
+ if (metal::simd_is_first()) {
+ static_cast<threadgroup float*>(threadgroup_buffer)[0 * num_simdgroups + simdgroup_idx] = m0;
+ static_cast<threadgroup float*>(threadgroup_buffer)[1 * num_simdgroups + simdgroup_idx] = m1;
+ static_cast<threadgroup float*>(threadgroup_buffer)[2 * num_simdgroups + simdgroup_idx] = m2;
+ static_cast<threadgroup float*>(threadgroup_buffer)[3 * num_simdgroups + simdgroup_idx] = m3;
+ static_cast<threadgroup float*>(threadgroup_buffer)[4 * num_simdgroups + simdgroup_idx] = m4;
+ static_cast<threadgroup float*>(threadgroup_buffer)[5 * num_simdgroups + simdgroup_idx] = m5;
+ static_cast<threadgroup float*>(threadgroup_buffer)[6 * num_simdgroups + simdgroup_idx] = m6;
+ static_cast<threadgroup float*>(threadgroup_buffer)[7 * num_simdgroups + simdgroup_idx] = m7;
+
+ static_cast<threadgroup float*>(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_idx] = l0;
+ static_cast<threadgroup float*>(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_idx] = l1;
+ static_cast<threadgroup float*>(threadgroup_buffer)[10 * num_simdgroups + simdgroup_idx] = l2;
+ static_cast<threadgroup float*>(threadgroup_buffer)[11 * num_simdgroups + simdgroup_idx] = l3;
+ static_cast<threadgroup float*>(threadgroup_buffer)[12 * num_simdgroups + simdgroup_idx] = l4;
+ static_cast<threadgroup float*>(threadgroup_buffer)[13 * num_simdgroups + simdgroup_idx] = l5;
+ static_cast<threadgroup float*>(threadgroup_buffer)[14 * num_simdgroups + simdgroup_idx] = l6;
+ static_cast<threadgroup float*>(threadgroup_buffer)[15 * num_simdgroups + simdgroup_idx] = l7;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ // Note: the simdgroup_m* values below come from the simdgroup whose simdgroup_idx equals this thread's simdgroup_tid, not from the thread's own simdgroup.
+ float simdgroup_m0 = m0;
+ float simdgroup_m1 = m1;
+ float simdgroup_m2 = m2;
+ float simdgroup_m3 = m3;
+ float simdgroup_m4 = m4;
+ float simdgroup_m5 = m5;
+ float simdgroup_m6 = m6;
+ float simdgroup_m7 = m7;
+ if (simdgroup_tid < num_simdgroups) {
+ simdgroup_m0 = static_cast<const threadgroup float*>(threadgroup_buffer)[0 * num_simdgroups + simdgroup_tid];
+ simdgroup_m1 = static_cast<const threadgroup float*>(threadgroup_buffer)[1 * num_simdgroups + simdgroup_tid];
+ simdgroup_m2 = static_cast<const threadgroup float*>(threadgroup_buffer)[2 * num_simdgroups + simdgroup_tid];
+ simdgroup_m3 = static_cast<const threadgroup float*>(threadgroup_buffer)[3 * num_simdgroups + simdgroup_tid];
+ simdgroup_m4 = static_cast<const threadgroup float*>(threadgroup_buffer)[4 * num_simdgroups + simdgroup_tid];
+ simdgroup_m5 = static_cast<const threadgroup float*>(threadgroup_buffer)[5 * num_simdgroups + simdgroup_tid];
+ simdgroup_m6 = static_cast<const threadgroup float*>(threadgroup_buffer)[6 * num_simdgroups + simdgroup_tid];
+ simdgroup_m7 = static_cast<const threadgroup float*>(threadgroup_buffer)[7 * num_simdgroups + simdgroup_tid];
+ }
+
+ const float threadgroup_m0 = metal::simd_max(simdgroup_m0);
+ const float threadgroup_m1 = metal::simd_max(simdgroup_m1);
+ const float threadgroup_m2 = metal::simd_max(simdgroup_m2);
+ const float threadgroup_m3 = metal::simd_max(simdgroup_m3);
+ const float threadgroup_m4 = metal::simd_max(simdgroup_m4);
+ const float threadgroup_m5 = metal::simd_max(simdgroup_m5);
+ const float threadgroup_m6 = metal::simd_max(simdgroup_m6);
+ const float threadgroup_m7 = metal::simd_max(simdgroup_m7);
+
+ out0 *= metal::fast::exp(m0 - threadgroup_m0);
+ out1 *= metal::fast::exp(m1 - threadgroup_m1);
+ out2 *= metal::fast::exp(m2 - threadgroup_m2);
+ out3 *= metal::fast::exp(m3 - threadgroup_m3);
+ out4 *= metal::fast::exp(m4 - threadgroup_m4);
+ out5 *= metal::fast::exp(m5 - threadgroup_m5);
+ out6 *= metal::fast::exp(m6 - threadgroup_m6);
+ out7 *= metal::fast::exp(m7 - threadgroup_m7);
+
+ if (simdgroup_idx == 0) {
+ l0 = 0.0f;
+ l1 = 0.0f;
+ l2 = 0.0f;
+ l3 = 0.0f;
+ l4 = 0.0f;
+ l5 = 0.0f;
+ l6 = 0.0f;
+ l7 = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ l0 = static_cast<const threadgroup float*>(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_tid];
+ l1 = static_cast<const threadgroup float*>(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_tid];
+ l2 = static_cast<const threadgroup float*>(threadgroup_buffer)[10 * num_simdgroups + simdgroup_tid];
+ l3 = static_cast<const threadgroup float*>(threadgroup_buffer)[11 * num_simdgroups + simdgroup_tid];
+ l4 = static_cast<const threadgroup float*>(threadgroup_buffer)[12 * num_simdgroups + simdgroup_tid];
+ l5 = static_cast<const threadgroup float*>(threadgroup_buffer)[13 * num_simdgroups + simdgroup_tid];
+ l6 = static_cast<const threadgroup float*>(threadgroup_buffer)[14 * num_simdgroups + simdgroup_tid];
+ l7 = static_cast<const threadgroup float*>(threadgroup_buffer)[15 * num_simdgroups + simdgroup_tid];
+ }
+
+ l0 = metal::simd_sum(l0 * metal::fast::exp(simdgroup_m0 - threadgroup_m0));
+ l1 = metal::simd_sum(l1 * metal::fast::exp(simdgroup_m1 - threadgroup_m1));
+ l2 = metal::simd_sum(l2 * metal::fast::exp(simdgroup_m2 - threadgroup_m2));
+ l3 = metal::simd_sum(l3 * metal::fast::exp(simdgroup_m3 - threadgroup_m3));
+ l4 = metal::simd_sum(l4 * metal::fast::exp(simdgroup_m4 - threadgroup_m4));
+ l5 = metal::simd_sum(l5 * metal::fast::exp(simdgroup_m5 - threadgroup_m5));
+ l6 = metal::simd_sum(l6 * metal::fast::exp(simdgroup_m6 - threadgroup_m6));
+ l7 = metal::simd_sum(l7 * metal::fast::exp(simdgroup_m7 - threadgroup_m7));
+ }
+
+ uint num_threads = num_simdgroups * simdgroup_size;
+ do {
+ const uint num_smem_threads = (num_threads / 2) & -simdgroup_size;
+ const uint num_half_threads = num_threads - num_smem_threads;
+
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ const uint smem_tid = tid.x - num_half_threads;
+ if (smem_tid < num_smem_threads) {
+ static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 0 + smem_tid] = out0;
+ static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 1 + smem_tid] = out1;
+ static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 2 + smem_tid] = out2;
+ static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 3 + smem_tid] = out3;
+ static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 4 + smem_tid] = out4;
+ static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 5 + smem_tid] = out5;
+ static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 6 + smem_tid] = out6;
+ static_cast<threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 7 + smem_tid] = out7;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (tid.x < num_smem_threads) {
+ out0 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 0 + tid.x];
+ out1 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 1 + tid.x];
+ out2 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 2 + tid.x];
+ out3 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 3 + tid.x];
+ out4 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 4 + tid.x];
+ out5 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 5 + tid.x];
+ out6 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 6 + tid.x];
+ out7 += static_cast<const threadgroup float2*>(threadgroup_buffer)[num_smem_threads * 7 + tid.x];
+ }
+
+ num_threads = num_half_threads;
+ } while (num_threads > simdgroup_size);
+ }
+ if (simdgroup_idx == 0) {
+ reinterpret_cast<device float2*>(output + 0 * head_dim)[simdgroup_tid] = out0 / l0;
+ reinterpret_cast<device float2*>(output + 1 * head_dim)[simdgroup_tid] = out1 / l1;
+ reinterpret_cast<device float2*>(output + 2 * head_dim)[simdgroup_tid] = out2 / l2;
+ reinterpret_cast<device float2*>(output + 3 * head_dim)[simdgroup_tid] = out3 / l3;
+ reinterpret_cast<device float2*>(output + 4 * head_dim)[simdgroup_tid] = out4 / l4;
+ reinterpret_cast<device float2*>(output + 5 * head_dim)[simdgroup_tid] = out5 / l5;
+ reinterpret_cast<device float2*>(output + 6 * head_dim)[simdgroup_tid] = out6 / l6;
+ reinterpret_cast<device float2*>(output + 7 * head_dim)[simdgroup_tid] = out7 / l7;
+ }
+ }
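
The kernel is a streaming (flash-attention-style) softmax: each simdgroup keeps a running maximum `m`, normalizer `l`, and value accumulator per head, rescaling them whenever a larger logit arrives; the `s` buffer seeds `m` with a per-head attention-sink logit (paired with `l = 1`), and partial results are merged across simdgroups at the end. The per-token update for a single head, in scalar form (a sketch, not part of the repo):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// One online-softmax step for one query head: qk is the new logit, v the
// value row. Initialize m to the head's sink logit, l to 1, out to zeros;
// after the last token, divide out by l.
void sdpa_online_step(float qk, const std::vector<float>& v,
                      float& m, float& l, std::vector<float>& out) {
    const float new_m = std::fmax(m, qk);
    const float alpha = std::exp(m - new_m);  // rescales everything seen so far
    const float p = std::exp(qk - new_m);
    l = l * alpha + p;
    for (std::size_t i = 0; i < out.size(); i++)
        out[i] = out[i] * alpha + v[i] * p;
    m = new_m;
}
```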
gptoss_kernels/source/tokenizer.c ADDED
@@ -0,0 +1,109 @@
+ #include <assert.h>
+ #include <stdatomic.h>
+ #include <stddef.h>
+ #include <stdint.h>
+ #include <stdlib.h>
+ #include <string.h>
+
+ #include <errno.h>
+ #include <sys/mman.h>
+
+ #include <gpt-oss.h>
+
+ #include "internal/log.h"
+ #include "internal/model.h"
+
+
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_special_token_id(
+ gptoss_tokenizer_t tokenizer,
+ enum gptoss_special_token token_type,
+ uint32_t* token_id_out)
+ {
+ uint32_t token_id = UINT32_MAX;
+ if (token_type != gptoss_special_token_invalid && token_type < gptoss_special_token_max)
+ {
+ token_id = tokenizer->special_token_id[(uint32_t) token_type - 1];
+ }
+ if (token_id == UINT32_MAX) {
+ return gptoss_status_invalid_argument;
+ }
+
+ *token_id_out = token_id;
+ return gptoss_status_success;
+ }
+
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_text_tokens(
+ gptoss_tokenizer_t tokenizer,
+ uint32_t* num_text_tokens_out)
+ {
+ *num_text_tokens_out = tokenizer->num_text_tokens;
+ return gptoss_status_success;
+ }
+
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_special_tokens(
+ gptoss_tokenizer_t tokenizer,
+ uint32_t* num_special_tokens_out)
+ {
+ *num_special_tokens_out = tokenizer->num_special_tokens;
+ return gptoss_status_success;
+ }
+
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_get_num_tokens(
+ gptoss_tokenizer_t tokenizer,
+ uint32_t* num_tokens_out)
+ {
+ *num_tokens_out = tokenizer->num_text_tokens + tokenizer->num_special_tokens;
+ return gptoss_status_success;
+ }
+
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_decode(
+ gptoss_tokenizer_t tokenizer,
+ uint32_t token_id,
+ const void** token_ptr_out,
+ size_t* token_size_out)
+ {
+ if (token_id >= tokenizer->num_text_tokens) {
+ return gptoss_status_invalid_argument;
+ }
+
+ const char* token_ptr = (const char*) tokenizer->tokens_ptr;
+ for (uint32_t t = 0; t < token_id; t++) {
+ // Reading unaligned uint16_t
+ uint16_t token_length;
+ memcpy(&token_length, token_ptr, sizeof(token_length));
+
+ token_ptr += (size_t) token_length + sizeof(uint16_t);
+ }
+
+ // Read the full unaligned uint16_t length; dereferencing token_ptr directly
+ // would truncate the length to a single (signed) byte.
+ uint16_t token_length;
+ memcpy(&token_length, token_ptr, sizeof(token_length));
+ *token_ptr_out = (const void*) (token_ptr + sizeof(uint16_t));
+ *token_size_out = (size_t) token_length;
+ return gptoss_status_success;
+ }
+
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_retain(
+ gptoss_tokenizer_t tokenizer)
+ {
+ atomic_fetch_add_explicit(&tokenizer->ref_count, 1, memory_order_relaxed);
+ return gptoss_status_success;
+ }
+
+ enum gptoss_status GPTOSS_ABI gptoss_tokenizer_release(
+ gptoss_tokenizer_t tokenizer)
+ {
+ if (tokenizer != NULL) {
+ if (atomic_fetch_sub_explicit(&tokenizer->ref_count, 1, memory_order_acquire) == 1) {
+ if (tokenizer->mapping_ptr != NULL && tokenizer->mapping_size != 0) {
+ if (munmap(tokenizer->mapping_ptr, tokenizer->mapping_size) != 0) {
+ GPTOSS_LOG_WARNING("munmap for tokenizer mapping failed with error %d", errno);
+ }
+ }
+
+ memset(tokenizer, 0, sizeof(struct gptoss_tokenizer));
+ free(tokenizer);
+ }
+ }
+ return gptoss_status_success;
+ }
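
Decoding is a linear scan over length-prefixed byte strings, so `gptoss_tokenizer_decode` is O(token_id); callers decoding many tokens may want to cache offsets. A minimal usage sketch (the `print_token` wrapper is hypothetical; only the status check is handled):

```cpp
#include <cstdint>
#include <cstdio>

#include <gpt-oss.h>

// Hypothetical call site: write the raw bytes of one text token to stdout.
void print_token(gptoss_tokenizer_t tokenizer, uint32_t token_id) {
    const void* token_ptr = nullptr;
    size_t token_size = 0;
    if (gptoss_tokenizer_decode(tokenizer, token_id, &token_ptr, &token_size) ==
            gptoss_status_success) {
        std::fwrite(token_ptr, 1, token_size, stdout);
    }
}
```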
gptoss_kernels/source/topk.metal ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_compute>
2
+ #include <metal_integer>
3
+ #include <metal_math>
4
+ #include <metal_simdgroup>
5
+
6
+ #include <internal/kernel-args.h>
7
+
8
+ #pragma METAL fp math_mode(safe)
9
+ #pragma METAL fp contract(off)
10
+
11
+
12
+ [[max_total_threads_per_threadgroup(32)]]
13
+ kernel void gptoss_f32_topk_softmax_e128_k4(
14
+ constant gptoss_topk_args& args [[ buffer(0) ]],
15
+ const device float4* input [[ buffer(1) ]],
16
+ device gptoss_expert_prediction* output [[ buffer(2) ]],
17
+ const device gptoss_control* control [[ buffer(3) ]],
18
+ uint gid [[threadgroup_position_in_grid]],
19
+ uint tid [[thread_position_in_threadgroup]])
20
+ {
21
+ const uint num_experts = 128;
22
+ const uint num_active_experts = 4;
23
+ if (control->abort != 0) {
24
+ return;
25
+ }
26
+
27
+ input += gid * (num_experts / 4);
28
+ output += gid * num_active_experts;
29
+
30
+ uint4 idx = tid * 4 + (uint4) {0, 1, 2, 3};
31
+ float4 val = input[tid];
32
+
33
+ const float topval0 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
34
+ uint idx0 = 0xFFFFFFFFu;
35
+ if (val.w == topval0) {
36
+ idx0 = idx.w;
37
+ }
38
+ if (val.z == topval0) {
39
+ idx0 = idx.z;
40
+ }
41
+ if (val.y == topval0) {
42
+ idx0 = idx.y;
43
+ }
44
+ if (val.x == topval0) {
45
+ idx0 = idx.x;
46
+ }
47
+ const uint topidx0 = metal::simd_min(idx0);
48
+ const bool4 is_topidx0 = idx == topidx0;
49
+ val = metal::select(val, -INFINITY, is_topidx0);
50
+ idx = metal::select(idx, 0xFFFFFFFFu, is_topidx0);
51
+
52
+ const float topval1 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
53
+ uint idx1 = 0xFFFFFFFFu;
54
+ if (val.w == topval1) {
55
+ idx1 = idx.w;
56
+ }
57
+ if (val.z == topval1) {
58
+ idx1 = idx.z;
59
+ }
60
+ if (val.y == topval1) {
61
+ idx1 = idx.y;
62
+ }
63
+ if (val.x == topval1) {
64
+ idx1 = idx.x;
65
+ }
66
+ const uint topidx1 = metal::simd_min(idx1);
67
+ const bool4 is_topidx1 = idx == topidx1;
68
+ val = metal::select(val, -INFINITY, is_topidx1);
69
+ idx = metal::select(idx, 0xFFFFFFFFu, is_topidx1);
70
+
71
+ const float topval2 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
72
+ uint idx2 = 0xFFFFFFFFu;
73
+ if (val.w == topval2) {
74
+ idx2 = idx.w;
75
+ }
76
+ if (val.z == topval2) {
77
+ idx2 = idx.z;
78
+ }
79
+ if (val.y == topval2) {
80
+ idx2 = idx.y;
81
+ }
82
+ if (val.x == topval2) {
83
+ idx2 = idx.x;
84
+ }
85
+ const uint topidx2 = metal::simd_min(idx2);
86
+ const bool4 is_topidx2 = idx == topidx2;
87
+ val = metal::select(val, -INFINITY, is_topidx2);
88
+ idx = metal::select(idx, 0xFFFFFFFFu, is_topidx2);
89
+
90
+ const float topval3 = metal::simd_max(metal::max3(metal::max(val.x, val.y), val.z, val.w));
91
+ uint idx3 = 0xFFFFFFFFu;
92
+ if (val.w == topval3) {
93
+ idx3 = idx.w;
94
+ }
95
+ if (val.z == topval3) {
96
+ idx3 = idx.z;
97
+ }
98
+ if (val.y == topval3) {
99
+ idx3 = idx.y;
100
+ }
101
+ if (val.x == topval3) {
102
+ idx3 = idx.x;
103
+ }
104
+ const uint topidx3 = metal::simd_min(idx3);
105
+
106
+ if (metal::simd_is_first()) {
107
+ const float topexp0 = 1.0f;
108
+ const float topexp1 = metal::precise::exp(topval1 - topval0);
109
+ const float topexp2 = metal::precise::exp(topval2 - topval0);
110
+ const float topexp3 = metal::precise::exp(topval3 - topval0);
111
+
112
+ const float sum = (topexp0 + topexp1) + (topexp2 + topexp3);
113
+ const float scale = 1.0 / sum;
114
+
115
+ output[0] = (gptoss_expert_prediction) {
116
+ .expert_id = topidx0,
117
+ .score = topexp0 * scale,
118
+ };
119
+ output[1] = (gptoss_expert_prediction) {
120
+ .expert_id = topidx1,
121
+ .score = topexp1 * scale,
122
+ };
123
+ output[2] = (gptoss_expert_prediction) {
124
+ .expert_id = topidx2,
125
+ .score = topexp2 * scale,
126
+ };
127
+ output[3] = (gptoss_expert_prediction) {
128
+ .expert_id = topidx3,
129
+ .score = topexp3 * scale,
130
+ };
131
+ }
132
+ }
+
+ [[max_total_threads_per_threadgroup(32)]]
+ kernel void gptoss_f32_topk_softmax_e32_k4(
+     constant gptoss_topk_args& args [[ buffer(0) ]],
+     const device float* input [[ buffer(1) ]],
+     device gptoss_expert_prediction* output [[ buffer(2) ]],
+     const device gptoss_control* control [[ buffer(3) ]],
+     uint gid [[threadgroup_position_in_grid]],
+     uint tid [[thread_position_in_threadgroup]])
+ {
+     const uint num_experts = 32;
+     const uint num_active_experts = 4;
+     if (control->abort != 0) {
+         return;
+     }
+
+     input += gid * num_experts;
+     output += gid * num_active_experts;
+
+     float val = input[tid];
+     uint idx = tid;
+
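+     // With 32 experts and a 32-thread simdgroup, each thread holds exactly one
+     // score, so scalar simd_max/simd_min reductions replace the float4 logic
+     // of the 128-expert variant above.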
+     const float topval0 = metal::simd_max(val);
+     const uint topidx0 = metal::simd_min(val == topval0 ? idx : 0xFFFFFFFFu);
+     if (idx == topidx0) {
+         val = -INFINITY;
+         idx = 0xFFFFFFFFu;
+     }
+
+     const float topval1 = metal::simd_max(val);
+     const uint topidx1 = metal::simd_min(val == topval1 ? idx : 0xFFFFFFFFu);
+     if (idx == topidx1) {
+         val = -INFINITY;
+         idx = 0xFFFFFFFFu;
+     }
+
+     const float topval2 = metal::simd_max(val);
+     const uint topidx2 = metal::simd_min(val == topval2 ? idx : 0xFFFFFFFFu);
+     if (idx == topidx2) {
+         val = -INFINITY;
+         idx = 0xFFFFFFFFu;
+     }
+
+     const float topval3 = metal::simd_max(val);
+     const uint topidx3 = metal::simd_min(val == topval3 ? idx : 0xFFFFFFFFu);
+
+     if (metal::simd_is_first()) {
+         const float topexp0 = 1.0f;
+         const float topexp1 = metal::precise::exp(topval1 - topval0);
+         const float topexp2 = metal::precise::exp(topval2 - topval0);
+         const float topexp3 = metal::precise::exp(topval3 - topval0);
+
+         const float sum = (topexp0 + topexp1) + (topexp2 + topexp3);
+         const float scale = 1.0f / sum;
+
+         output[0] = (gptoss_expert_prediction) {
+             .expert_id = topidx0,
+             .score = topexp0 * scale,
+         };
+         output[1] = (gptoss_expert_prediction) {
+             .expert_id = topidx1,
+             .score = topexp1 * scale,
+         };
+         output[2] = (gptoss_expert_prediction) {
+             .expert_id = topidx2,
+             .score = topexp2 * scale,
+         };
+         output[3] = (gptoss_expert_prediction) {
+             .expert_id = topidx3,
+             .score = topexp3 * scale,
+         };
+     }
+ }
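Both top-k kernels implement the same per-token contract: select the four highest expert scores (ties go to the lowest index) and softmax-normalize them. As a minimal host-side reference of that contract, assuming k = 4 and at least four scores — `ExpertPrediction` and `topk_softmax_ref` are illustrative names here, not repo APIs:

```cpp
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

struct ExpertPrediction {  // mirrors gptoss_expert_prediction
    std::uint32_t expert_id;
    float score;
};

// Reference top-4 selection plus softmax over the selected scores.
// Strict '>' comparison keeps the lowest index on ties, matching the kernels.
std::vector<ExpertPrediction> topk_softmax_ref(std::vector<float> scores) {
    std::vector<ExpertPrediction> out;
    for (int k = 0; k < 4; k++) {
        std::uint32_t best = 0;
        for (std::uint32_t i = 1; i < scores.size(); i++) {
            if (scores[i] > scores[best]) best = i;
        }
        out.push_back({best, scores[best]});
        scores[best] = -std::numeric_limits<float>::infinity();  // mask winner
    }
    // Softmax over the four winners; subtract the max for stability.
    const float top = out[0].score;
    float sum = 0.0f;
    for (auto& e : out) { e.score = std::exp(e.score - top); sum += e.score; }
    for (auto& e : out) { e.score /= sum; }
    return out;
}
```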
gptoss_kernels/test/bf16-f32-embeddings.cc ADDED
@@ -0,0 +1,33 @@
+ #include <gtest/gtest.h>
+
+ #include <cstddef>
+
+ #include "embeddings-kernel-tester.hpp"
+
+
+ using gptoss::EmbeddingsKernelTester;
+
+ constexpr std::size_t kThreadgroupSize = 64;
+
+
+ TEST(BF16_F32_EMBEDDINGS, single_token_single_tile) {
+     EmbeddingsKernelTester()
+         .num_channels(kThreadgroupSize)
+         .threadgroup_size(kThreadgroupSize)
+         .TestBF16_F32();
+ }
+
+ TEST(BF16_F32_EMBEDDINGS, single_token_multi_tile) {
+     EmbeddingsKernelTester()
+         .num_channels(kThreadgroupSize * 4 + 16)
+         .threadgroup_size(kThreadgroupSize)
+         .TestBF16_F32();
+ }
+
+ TEST(BF16_F32_EMBEDDINGS, multiple_tokens) {
+     EmbeddingsKernelTester()
+         .num_channels(kThreadgroupSize * 4 + 16)
+         .num_tokens(3)
+         .threadgroup_size(kThreadgroupSize)
+         .TestBF16_F32();
+ }
gptoss_kernels/test/embeddings-kernel-tester.hpp ADDED
@@ -0,0 +1,123 @@
+ #pragma once
+
+ #include <gtest/gtest.h>
+
+ #include <cstddef>
+ #include <cstdint>
+ #include <cstring>
+
+ #include <internal/datatype.hpp>
+ #include <internal/metal.hpp>
+ #include <internal/metal-kernels.h>
+
+
+ namespace gptoss {
+
+ class EmbeddingsKernelTester {
+ public:
+     EmbeddingsKernelTester() { }
+
+     EmbeddingsKernelTester(const EmbeddingsKernelTester&) = delete;
+     EmbeddingsKernelTester(EmbeddingsKernelTester&&) = delete;
+     EmbeddingsKernelTester& operator=(const EmbeddingsKernelTester&) = delete;
+     EmbeddingsKernelTester& operator=(EmbeddingsKernelTester&&) = delete;
+
+     [[nodiscard]]
+     EmbeddingsKernelTester& num_channels(std::uint32_t num_channels) {
+         num_channels_ = num_channels;
+         return *this;
+     }
+
+     std::uint32_t num_channels() const {
+         return num_channels_;
+     }
+
+     [[nodiscard]]
+     EmbeddingsKernelTester& num_tokens(std::uint32_t num_tokens) {
+         num_tokens_ = num_tokens;
+         return *this;
+     }
+
+     std::uint32_t num_tokens() const {
+         return num_tokens_;
+     }
+
+     std::uint32_t vocabulary_size() const {
+         return num_tokens() + 1;
+     }
+
+     [[nodiscard]]
+     EmbeddingsKernelTester& threadgroup_size(std::size_t threadgroup_size) {
+         threadgroup_size_ = threadgroup_size;
+         return *this;
+     }
+
+     std::size_t threadgroup_size() const {
+         return threadgroup_size_;
+     }
+
+     void Validate() const {
+         ASSERT_NE(num_channels(), 0);
+         ASSERT_NE(num_tokens(), 0);
+         ASSERT_NE(threadgroup_size(), 0);
+         ASSERT_EQ(threadgroup_size() % 32, 0);
+     }
+
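+     // Token ids written below start at 1 and vocabulary_size() is
+     // num_tokens() + 1, so row 0 of the embedding table is never looked up.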
65
+ void TestBF16_F32() const {
66
+ Validate();
67
+
68
+ metal::CommandBuffer command_buffer{command_queue_};
69
+ metal::Buffer token_buffer{device_, sizeof(std::uint32_t)};
70
+ metal::Buffer weight_buffer{device_, vocabulary_size() * num_channels() * sizeof(gptoss_bfloat16)};
71
+ metal::Buffer output_buffer{device_, num_channels() * sizeof(float)};
72
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
73
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
74
+
75
+ std::uint32_t* token_ptr = static_cast<std::uint32_t*>(token_buffer.ptr());
76
+ for (std::uint32_t t = 0; t < num_tokens(); t++) {
77
+ token_ptr[t] = t + 1;
78
+ }
79
+
80
+ Check(gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
81
+ command_buffer.handle(),
82
+ bf16_f32_embeddings_fn.handle(),
83
+ threadgroup_size(),
84
+ token_buffer.handle(),
85
+ /*token_offset=*/0,
86
+ weight_buffer.handle(),
87
+ /*weight_offset=*/0,
88
+ output_buffer.handle(),
89
+ /*output_offset=*/0,
90
+ control_buffer.handle(),
91
+ /*control_offset=*/0,
92
+ num_tokens(),
93
+ num_channels()),
94
+ "gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings");
95
+
96
+ command_buffer.commit();
97
+ command_buffer.wait_completion();
98
+
99
+ const gptoss_bfloat16* weight_ptr = static_cast<const gptoss_bfloat16*>(weight_buffer.ptr());
100
+ const float* output_ptr = static_cast<const float*>(output_buffer.ptr());
101
+ for (std::uint32_t t = 0; t < num_tokens(); t++) {
102
+ const std::uint32_t token = token_ptr[t];
103
+ for (std::uint32_t i = 0; i < num_channels(); i++) {
104
+ const gptoss_bfloat16 input_val = weight_ptr[token * num_channels() + i];
105
+ const float ref_output = upcast<float>(input_val);
106
+ const float output = output_ptr[t * num_channels() + i];
107
+ ASSERT_EQ(output, ref_output)
108
+ << "at token " << t << ", position " << i << " / " << num_channels() << ", input " << std::uint32_t(input_val.bits);
109
+ }
110
+ }
111
+ }
+ private:
+     metal::Device device_{};
+     metal::CommandQueue command_queue_{device_};
+     metal::Library library_{device_};
+     metal::Function bf16_f32_embeddings_fn{library_, "gptoss_bf16_f32_embeddings"};
+     std::uint32_t num_tokens_{1};
+     std::uint32_t num_channels_{1};
+     std::size_t threadgroup_size_{32};
+ };
+
+ } // namespace gptoss
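The tester's correctness check hinges on `upcast<float>` from `internal/datatype.hpp`. Assuming the standard bfloat16 layout (the high 16 bits of an IEEE-754 float32), an equivalent standalone conversion would look like this sketch; `bf16_to_f32` is an illustrative name, not a repo function:

```cpp
#include <cstdint>
#include <cstring>

// bf16 -> f32: place the 16 stored bits into the high half of a float.
float bf16_to_f32(std::uint16_t bits) {
    const std::uint32_t u = static_cast<std::uint32_t>(bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));  // bit-exact, no rounding involved
    return f;
}
```

Because the conversion is exact, the test can use `ASSERT_EQ` on floats instead of a tolerance-based comparison.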
gptoss_kernels/test/f32-bf16w-matmul.cc ADDED
@@ -0,0 +1,87 @@
+ #include <gtest/gtest.h>
+
+ #include <cstddef>
+ #include <cstdint>
+
+ #include "matmul-kernel-tester.hpp"
+
+
+ using gptoss::MatMulKernelTester;
+
+ constexpr std::size_t kSimdgroupSize = 32; // fixed in the kernel
+
+ TEST(F32_BF16W_MATMUL, single_simdgroup_single_iteration) {
+     MatMulKernelTester()
+         .num_rows(1)
+         .num_cols(kSimdgroupSize * 4)
+         .threadgroup_size(kSimdgroupSize)
+         .TestF32_BF16W();
+ }
+
+ TEST(F32_BF16W_MATMUL, single_simdgroup_multiple_iteration) {
+     MatMulKernelTester()
+         .num_rows(1)
+         .num_cols((2 * kSimdgroupSize + 1) * 4)
+         .threadgroup_size(kSimdgroupSize)
+         .TestF32_BF16W();
+ }
+
+ TEST(F32_BF16W_MATMUL, single_threadgroup) {
+     constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;
+
+     MatMulKernelTester()
+         .num_rows(threadgroup_size / kSimdgroupSize)
+         .num_cols((2 * kSimdgroupSize + 1) * 4)
+         .threadgroup_size(threadgroup_size)
+         .TestF32_BF16W();
+ }
+
+ TEST(F32_BF16W_MATMUL, multiple_threadgroups) {
+     constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;
+     constexpr std::uint32_t num_threadgroups = 3;
+
+     MatMulKernelTester()
+         .num_rows(num_threadgroups * threadgroup_size / kSimdgroupSize)
+         .num_cols((2 * kSimdgroupSize + 1) * 4)
+         .threadgroup_size(threadgroup_size)
+         .TestF32_BF16W();
+ }
+
+ TEST(F32_BF16W_MATMUL, multiple_tokens) {
+     constexpr std::size_t threadgroup_size = 2 * kSimdgroupSize;
+     constexpr std::uint32_t num_threadgroups = 3;
+
+     MatMulKernelTester()
+         .num_rows(num_threadgroups * threadgroup_size / kSimdgroupSize)
+         .num_cols((2 * kSimdgroupSize + 1) * 4)
+         .num_tokens(2)
+         .threadgroup_size(threadgroup_size)
+         .TestF32_BF16W();
+ }
+
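+ // The dense-matmul shapes below appear to mirror the GPT-OSS projection
+ // sizes (hidden dim 2880; QKV rows 5120; attention-output input 4096;
+ // 128 MLP-gate rows), so these cases exercise realistic prefill workloads.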
+ TEST(F32_BF16W_DENSE_MATMUL_QKV, seq_len_1024) {
+     MatMulKernelTester()
+         .num_tokens(1024)
+         .num_rows(5120)
+         .num_cols(2880)
+         .TestF32_BF16W(
+             MatMulKernelTester::MatMulKernelType::PREFILL_QKV_OPTIMIZED);
+ }
+
+ TEST(F32_BF16W_DENSE_MATMUL_ATTN_OUTPUT, seq_len_1024) {
+     MatMulKernelTester()
+         .num_tokens(1024)
+         .num_rows(2880)
+         .num_cols(4096)
+         .TestF32_BF16W(
+             MatMulKernelTester::MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED);
+ }
+
+ TEST(F32_BF16W_DENSE_MATMUL_MLP_GATE, seq_len_1024) {
+     MatMulKernelTester()
+         .num_tokens(1024)
+         .num_rows(128)
+         .num_cols(2880)
+         .TestF32_BF16W(
+             MatMulKernelTester::MatMulKernelType::PREFILL_MLP_GATE_OPTIMIZED);
+ }
gptoss_kernels/test/f32-bf16w-rmsnorm.cc ADDED
@@ -0,0 +1,36 @@
+ #include <gtest/gtest.h>
+
+ #include <cstdint>
+
+ #include "rmsnorm-kernel-tester.hpp"
+
+
+ using gptoss::RMSNormKernelTester;
+
+ constexpr std::uint32_t kThreadgroupSize = 1024; // fixed in the kernel
+ constexpr std::uint32_t kVectorSize = 4; // fixed in the kernel
+
+ TEST(F32_BF16W_RMSNORM, single_iteration) {
+     RMSNormKernelTester()
+         .num_channels(kThreadgroupSize)
+         .TestF32_BF16W();
+ }
+
+ TEST(F32_BF16W_RMSNORM, multiple_iterations) {
+     RMSNormKernelTester()
+         .num_channels(kThreadgroupSize * 2)
+         .TestF32_BF16W();
+ }
+
+ TEST(F32_BF16W_RMSNORM, partial_iteration) {
+     RMSNormKernelTester()
+         .num_channels(kThreadgroupSize * 2 + kVectorSize)
+         .TestF32_BF16W();
+ }
+
+ TEST(F32_BF16W_RMSNORM, multiple_tokens) {
+     RMSNormKernelTester()
+         .num_tokens(3)
+         .num_channels(kThreadgroupSize * 2 + kVectorSize)
+         .TestF32_BF16W();
+ }
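For context, a scalar sketch of the RMSNorm these tests exercise: normalize each token's channel vector by its root-mean-square, then scale by the weights (bf16 in the kernel, upcast to f32). The epsilon value and placement here are assumptions; rmsnorm.metal defines the kernel's exact formula.

```cpp
#include <cmath>
#include <cstddef>

// Scalar RMSNorm reference: y[i] = x[i] / rms(x) * w[i].
// eps placement inside the sqrt is illustrative, not taken from the kernel.
void rmsnorm_ref(const float* x, const float* w, float* y,
                 std::size_t n, float eps) {
    float sumsq = 0.0f;
    for (std::size_t i = 0; i < n; i++) {
        sumsq += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sumsq / static_cast<float>(n) + eps);
    for (std::size_t i = 0; i < n; i++) {
        y[i] = x[i] * scale * w[i];
    }
}
```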