Kernels
EricB (HF Staff) committed
Commit e1b9104 · Parent: c0b4ecd

Add benchmark

Files changed (2)
  1. benchmark.py +304 -0
  2. benchmark_flash_sdpa.py +0 -301
benchmark.py ADDED
@@ -0,0 +1,304 @@
+#!/usr/bin/env python3
+"""Benchmark causal mask performance scaling with sequence length"""
+
+import torch
+import time
+import matplotlib.pyplot as plt
+import numpy as np
+from typing import List
+import kernels
+
+metal_flash_sdpa = kernels.get_kernel("kernels-community/metal-flash-sdpa")
+
+
+def create_cu_seqlens(seq_lengths: List[int]) -> torch.Tensor:
+    """Create cumulative sequence lengths tensor."""
+    cu_seqlens = [0]
+    for length in seq_lengths:
+        cu_seqlens.append(cu_seqlens[-1] + length)
+    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")
+
+
+def benchmark_flash_sdpa_causal(
+    batch_size: int,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    num_iterations: int = 20,
+) -> float:
+    """Benchmark Flash SDPA with causal mask"""
+
+    seq_lengths = [seq_len] * batch_size
+    cu_seqlens = create_cu_seqlens(seq_lengths)
+    total_tokens = sum(seq_lengths)
+
+    # Create input tensors
+    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
+    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
+    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
+    out = torch.empty_like(query)
+
+    scale = 1.0 / (head_dim**0.5)
+
+    # Warmup
+    for _ in range(5):
+        metal_flash_sdpa.flash_attention_varlen(
+            out=out,
+            query=query,
+            key=key,
+            value=value,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=seq_len,
+            max_seqlen_k=seq_len,
+            do_causal=True,
+            scale=scale,
+            softcapping=1.0,
+        )
+    torch.mps.synchronize()
+
+    # Benchmark
+    start_time = time.perf_counter()
+    for _ in range(num_iterations):
+        metal_flash_sdpa.flash_attention_varlen(
+            out=out,
+            query=query,
+            key=key,
+            value=value,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=seq_len,
+            max_seqlen_k=seq_len,
+            do_causal=True,
+            scale=scale,
+            softcapping=1.0,
+        )
+    torch.mps.synchronize()
+    end_time = time.perf_counter()
+
+    return (end_time - start_time) * 1000 / num_iterations
+
+
+def benchmark_naive_sdpa_causal(
+    batch_size: int,
+    num_heads: int,
+    seq_len: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    num_iterations: int = 20,
+) -> float:
+    """Benchmark naive SDPA with causal mask"""
+
+    # Create input tensors
+    query = torch.randn(
+        batch_size, num_heads, seq_len, head_dim, dtype=dtype, device="mps"
+    )
+    key = torch.randn(
+        batch_size, num_heads, seq_len, head_dim, dtype=dtype, device="mps"
+    )
+    value = torch.randn(
+        batch_size, num_heads, seq_len, head_dim, dtype=dtype, device="mps"
+    )
+
+    scale = 1.0 / (head_dim**0.5)
+
+    # Precompute causal mask
+    mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
+
+    # Warmup
+    for _ in range(5):
+        scores = torch.matmul(query, key.transpose(-2, -1)) * scale
+        scores = scores.masked_fill(mask, float("-inf"))
+        attn_weights = torch.softmax(scores, dim=-1)
+        out = torch.matmul(attn_weights, value)
+    torch.mps.synchronize()
+
+    # Benchmark
+    start_time = time.perf_counter()
+    for _ in range(num_iterations):
+        scores = torch.matmul(query, key.transpose(-2, -1)) * scale
+        scores = scores.masked_fill(mask, float("-inf"))
+        attn_weights = torch.softmax(scores, dim=-1)
+        out = torch.matmul(attn_weights, value)
+    torch.mps.synchronize()
+    end_time = time.perf_counter()
+
+    return (end_time - start_time) * 1000 / num_iterations
+
+
+def run_scaling_benchmark():
+    """Run causal mask scaling benchmark"""
+
+    print("=" * 80)
+    print("Causal Mask Performance Scaling Benchmark")
+    print("Batch Size: 4, Head Dimension: 64")
+    print("=" * 80)
+
+    # Configuration
+    batch_size = 4
+    num_heads = 16
+    head_dim = 64
+    dtype = torch.float16
+
+    # Sequence lengths from 512 to 4096
+    seq_lengths = [512, 768, 1024, 1536, 2048, 3072, 4096]
+
+    flash_times = []
+    naive_times = []
+    speedups = []
+
+    print(f"{'Seq Len':<8} {'Flash (ms)':<12} {'Naive (ms)':<12} {'Speedup':<10}")
+    print("-" * 50)
+
+    for seq_len in seq_lengths:
+        # Benchmark Flash SDPA
+        flash_time = benchmark_flash_sdpa_causal(
+            batch_size, num_heads, seq_len, head_dim, dtype
+        )
+        flash_times.append(flash_time)
+
+        # Benchmark Naive SDPA
+        naive_time = benchmark_naive_sdpa_causal(
+            batch_size, num_heads, seq_len, head_dim, dtype
+        )
+        naive_times.append(naive_time)
+
+        speedup = naive_time / flash_time
+        speedups.append(speedup)
+
+        print(f"{seq_len:<8} {flash_time:<12.2f} {naive_time:<12.2f} {speedup:<10.2f}x")
+
+    return seq_lengths, flash_times, naive_times, speedups
+
+
+def create_line_plot(seq_lengths, flash_times, naive_times, speedups):
+    """Create line graph visualization"""
+
+    # Create figure with single plot
+    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+    fig.suptitle(
+        "Causal Mask Performance Scaling\n(Batch Size: 4, Head Dimension: 64)",
+        fontsize=16,
+    )
+
+    # Plot execution times
+    ax.plot(
+        seq_lengths,
+        flash_times,
+        marker="o",
+        linewidth=3,
+        markersize=10,
+        label="Flash SDPA",
+        color="blue",
+    )
+    ax.plot(
+        seq_lengths,
+        naive_times,
+        marker="s",
+        linewidth=3,
+        markersize=10,
+        label="Naive SDPA",
+        color="red",
+    )
+
+    ax.set_xlabel("Sequence Length", fontsize=14)
+    ax.set_ylabel("Time (ms)", fontsize=14)
+    ax.set_title("Execution Time vs Sequence Length", fontsize=16)
+    ax.grid(True, alpha=0.3)
+    ax.legend(fontsize=12)
+
+    # Add value annotations for all points
+    for i, (seq_len, flash_time, naive_time) in enumerate(
+        zip(seq_lengths, flash_times, naive_times)
+    ):
+        ax.annotate(
+            f"{flash_time:.1f}ms",
+            xy=(seq_len, flash_time),
+            xytext=(5, 5),
+            textcoords="offset points",
+            fontsize=10,
+            color="blue",
+        )
+        ax.annotate(
+            f"{naive_time:.1f}ms",
+            xy=(seq_len, naive_time),
+            xytext=(5, 5),
+            textcoords="offset points",
+            fontsize=10,
+            color="red",
+        )
+
+    # Set axis limits to better show the data
+    ax.set_xlim(seq_lengths[0] - 100, seq_lengths[-1] + 100)
+    ax.set_ylim(0, max(naive_times) * 1.1)
+
+    plt.tight_layout()
+    plt.savefig("benchmark.png", dpi=300, bbox_inches="tight")
+    plt.show()
+
+
+def print_analysis(seq_lengths, flash_times, naive_times, speedups):
+    """Print detailed analysis of the results"""
+
+    print("\n" + "=" * 80)
+    print("DETAILED ANALYSIS")
+    print("=" * 80)
+
+    # Performance scaling analysis
+    print("\n1. Performance Scaling:")
+    print(
+        f" • Flash SDPA: {flash_times[0]:.2f}ms → {flash_times[-1]:.2f}ms ({flash_times[-1] / flash_times[0]:.1f}x increase)"
+    )
+    print(
+        f" • Naive SDPA: {naive_times[0]:.2f}ms → {naive_times[-1]:.2f}ms ({naive_times[-1] / naive_times[0]:.1f}x increase)"
+    )
+
+    # Speedup analysis
+    print("\n2. Speedup Analysis:")
+    print(f" • Average Speedup: {np.mean(speedups):.2f}x")
+    print(
+        f" • Max Speedup: {np.max(speedups):.2f}x (at seq_len={seq_lengths[np.argmax(speedups)]})"
+    )
+    print(
+        f" • Min Speedup: {np.min(speedups):.2f}x (at seq_len={seq_lengths[np.argmin(speedups)]})"
+    )
+
+    # Efficiency analysis
+    print("\n3. Efficiency Analysis:")
+    speedup_improvement = speedups[-1] / speedups[0]
+    print(f" • Speedup improvement from 512→4096: {speedup_improvement:.2f}x")
+
+    if speedup_improvement > 1.1:
+        print(" • Flash SDPA becomes MORE efficient at longer sequences")
+    elif speedup_improvement < 0.9:
+        print(" • Flash SDPA becomes LESS efficient at longer sequences")
+    else:
+        print(" • Flash SDPA maintains consistent efficiency across sequence lengths")
+
+    # Memory complexity analysis
+    print("\n4. Theoretical Complexity:")
+    print(f" • Sequence length increased by: {seq_lengths[-1] / seq_lengths[0]:.1f}x")
+    print(
+        f" • Theoretical O(n²) complexity increase: {(seq_lengths[-1] / seq_lengths[0]) ** 2:.1f}x"
+    )
+    print(f" • Actual Flash SDPA increase: {flash_times[-1] / flash_times[0]:.1f}x")
+    efficiency_ratio = (flash_times[-1] / flash_times[0]) / (
+        (seq_lengths[-1] / seq_lengths[0]) ** 2
+    )
+    print(f" • Flash SDPA efficiency ratio: {efficiency_ratio:.3f} (lower is better)")
+
+
+def main():
+    # Run the scaling benchmark
+    seq_lengths, flash_times, naive_times, speedups = run_scaling_benchmark()
+
+    # Create line plot visualization
+    create_line_plot(seq_lengths, flash_times, naive_times, speedups)
+
+    # Print detailed analysis
+    print_analysis(seq_lengths, flash_times, naive_times, speedups)
+
+
+if __name__ == "__main__":
+    main()
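
The new benchmark leans on two pieces of machinery: the packed varlen layout delimited by create_cu_seqlens, and the mask-based "naive SDPA" reference math. The following is a minimal CPU sketch, not part of the commit, that illustrates both: it shows how the cumulative offsets slice the packed (total_tokens, num_heads, head_dim) tensor back into per-sequence views, and checks that the masked-softmax attention above agrees with PyTorch's built-in torch.nn.functional.scaled_dot_product_attention with is_causal=True. All shapes and values here are arbitrary examples.

# Illustrative sketch only (CPU, float32); not part of the commit.
import torch
import torch.nn.functional as F

# (a) cu_seqlens delimits sequences inside the packed layout.
seq_lengths = [3, 5]
cu_seqlens = [0]
for length in seq_lengths:
    cu_seqlens.append(cu_seqlens[-1] + length)
# cu_seqlens == [0, 3, 8]; sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1]
num_heads, head_dim = 2, 4
packed_q = torch.randn(sum(seq_lengths), num_heads, head_dim)
for i, length in enumerate(seq_lengths):
    q_i = packed_q[cu_seqlens[i]:cu_seqlens[i + 1]]
    assert q_i.shape == (length, num_heads, head_dim)

# (b) The naive causal attention used in benchmark_naive_sdpa_causal matches
# the built-in reference on a (batch, heads, seq, head_dim) tensor.
b, h, s, d = 1, 2, 5, 4
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
scale = 1.0 / (d ** 0.5)
mask = torch.triu(torch.ones(s, s), diagonal=1).bool()
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
scores = scores.masked_fill(mask, float("-inf"))
naive = torch.matmul(torch.softmax(scores, dim=-1), v)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(naive, ref, atol=1e-5)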
benchmark_flash_sdpa.py DELETED
@@ -1,301 +0,0 @@
-#!/usr/bin/env python3
-"""Benchmark script for metal-sdpa-flash (Flash SDPA)"""
-
-import torch
-import time
-import metal_flash_sdpa
-from typing import List, Tuple
-import numpy as np
-
-
-def create_cu_seqlens(seq_lengths: List[int]) -> torch.Tensor:
-    """Create cumulative sequence lengths tensor."""
-    cu_seqlens = [0]
-    for length in seq_lengths:
-        cu_seqlens.append(cu_seqlens[-1] + length)
-    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")
-
-
-def warmup(func, *args, num_warmup=10):
-    """Warmup the GPU by running the function multiple times"""
-    for _ in range(num_warmup):
-        func(*args)
-    torch.mps.synchronize()
-
-
-def benchmark_flash_sdpa(
-    batch_size: int,
-    num_heads: int,
-    seq_len: int,
-    head_dim: int,
-    dtype: torch.dtype,
-    causal: bool = False,
-    num_iterations: int = 100,
-) -> float:
-    """Benchmark Flash SDPA with given parameters"""
-
-    # Create sequence lengths (all equal for fair comparison)
-    seq_lengths = [seq_len] * batch_size
-    cu_seqlens = create_cu_seqlens(seq_lengths)
-    total_tokens = sum(seq_lengths)
-
-    # Create input tensors in Flash format (total_tokens, num_heads, head_dim)
-    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    out = torch.empty_like(query)
-
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Define the function to benchmark
-    def run_flash_sdpa():
-        metal_flash_sdpa.flash_attention_varlen(
-            out=out,
-            query=query,
-            key=key,
-            value=value,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=seq_len,
-            max_seqlen_k=seq_len,
-            mask=None,
-            do_causal=causal,
-            scale=scale,
-            softcapping=1.0,
-        )
-
-    # Warmup
-    warmup(run_flash_sdpa, num_warmup=10)
-
-    # Benchmark
-    torch.mps.synchronize()
-    start_time = time.perf_counter()
-
-    for _ in range(num_iterations):
-        run_flash_sdpa()
-
-    torch.mps.synchronize()
-    end_time = time.perf_counter()
-
-    avg_time_ms = (end_time - start_time) * 1000 / num_iterations
-    return avg_time_ms
-
-
-def benchmark_flash_gqa(
-    batch_size: int,
-    num_heads_q: int,
-    num_heads_kv: int,
-    seq_len: int,
-    head_dim: int,
-    dtype: torch.dtype,
-    causal: bool = False,
-    num_iterations: int = 100,
-) -> float:
-    """Benchmark Flash Attention with Grouped Query Attention"""
-
-    # Create sequence lengths
-    seq_lengths = [seq_len] * batch_size
-    cu_seqlens = create_cu_seqlens(seq_lengths)
-    total_tokens = sum(seq_lengths)
-
-    # Create input tensors with different head counts
-    query = torch.randn(total_tokens, num_heads_q, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(total_tokens, num_heads_kv, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(total_tokens, num_heads_kv, head_dim, dtype=dtype, device="mps")
-    out = torch.empty_like(query)
-
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Define the function to benchmark
-    def run_flash_gqa():
-        metal_flash_sdpa.flash_attention_varlen(
-            out=out,
-            query=query,
-            key=key,
-            value=value,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=seq_len,
-            max_seqlen_k=seq_len,
-            mask=None,
-            do_causal=causal,
-            scale=scale,
-            softcapping=1.0,
-        )
-
-    # Warmup
-    warmup(run_flash_gqa, num_warmup=10)
-
-    # Benchmark
-    torch.mps.synchronize()
-    start_time = time.perf_counter()
-
-    for _ in range(num_iterations):
-        run_flash_gqa()
-
-    torch.mps.synchronize()
-    end_time = time.perf_counter()
-
-    avg_time_ms = (end_time - start_time) * 1000 / num_iterations
-    return avg_time_ms
-
-
-def benchmark_variable_length(
-    seq_lengths: List[int],
-    num_heads: int,
-    head_dim: int,
-    dtype: torch.dtype,
-    causal: bool = False,
-    num_iterations: int = 100,
-) -> float:
-    """Benchmark Flash Attention with variable sequence lengths"""
-
-    cu_seqlens = create_cu_seqlens(seq_lengths)
-    total_tokens = sum(seq_lengths)
-    max_seqlen = max(seq_lengths)
-
-    # Create input tensors
-    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    out = torch.empty_like(query)
-
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Define the function to benchmark
-    def run_varlen():
-        metal_flash_sdpa.flash_attention_varlen(
-            out=out,
-            query=query,
-            key=key,
-            value=value,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=max_seqlen,
-            max_seqlen_k=max_seqlen,
-            mask=None,
-            do_causal=causal,
-            scale=scale,
-            softcapping=1.0,
-        )
-
-    # Warmup
-    warmup(run_varlen, num_warmup=10)
-
-    # Benchmark
-    torch.mps.synchronize()
-    start_time = time.perf_counter()
-
-    for _ in range(num_iterations):
-        run_varlen()
-
-    torch.mps.synchronize()
-    end_time = time.perf_counter()
-
-    avg_time_ms = (end_time - start_time) * 1000 / num_iterations
-    return avg_time_ms
-
-
-def main():
-    print("=" * 80)
-    print("Metal Flash SDPA Benchmark")
-    print("=" * 80)
-
-    # Test configurations (matching the plain SDPA benchmark)
-    configs = [
-        # (batch_size, num_heads, seq_len, head_dim, dtype, causal, name)
-        (1, 32, 512, 64, torch.float32, False, "Small seq, float32"),
-        (1, 32, 512, 64, torch.float16, False, "Small seq, float16"),
-        (1, 32, 512, 64, torch.bfloat16, False, "Small seq, bfloat16"),
-
-        (4, 32, 2048, 64, torch.float16, False, "Medium seq, float16"),
-        (4, 32, 2048, 64, torch.float16, True, "Medium seq, float16, causal"),
-
-        (2, 32, 4096, 64, torch.float16, False, "Large seq, float16"),
-        (2, 32, 4096, 64, torch.float16, True, "Large seq, float16, causal"),
-
-        # Different head dimensions
-        (2, 32, 2048, 32, torch.float16, False, "head_dim=32"),
-        (2, 32, 2048, 64, torch.float16, False, "head_dim=64"),
-        (2, 32, 2048, 128, torch.float16, False, "head_dim=128"),
-
-        # Vector kernel cases (q_seq=1) - Flash doesn't have a special vector kernel
-        # but we benchmark these cases for fair comparison with plain SDPA
-        (16, 32, 1, 64, torch.float16, False, "Vector kernel (q_seq=1)"),
-        (16, 32, 1, 128, torch.float16, False, "Vector kernel (q_seq=1, head_dim=128)"),
-    ]
-
-    print("\nFlash Attention Benchmarks:")
-    print("-" * 80)
-    print(f"{'Config':<40} {'Time (ms)':<15} {'TFLOPS':<15}")
-    print("-" * 80)
-
-    for batch_size, num_heads, seq_len, head_dim, dtype, causal, name in configs:
-        time_ms = benchmark_flash_sdpa(
-            batch_size, num_heads, seq_len, head_dim, dtype, causal
-        )
-
-        # Calculate FLOPS (approximate)
-        # Attention: 2 * batch * heads * seq_len^2 * head_dim
-        flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim
-        tflops = (flops / 1e12) / (time_ms / 1000)
-
-        print(f"{name:<40} {time_ms:<15.3f} {tflops:<15.2f}")
-
-    # GQA benchmarks
-    print("\n\nGrouped Query Attention (GQA) Benchmarks:")
-    print("-" * 80)
-    print(f"{'Config':<40} {'Time (ms)':<15} {'TFLOPS':<15}")
-    print("-" * 80)
-
-    gqa_configs = [
-        # (batch_size, num_heads_q, num_heads_kv, seq_len, head_dim, dtype, causal, name)
-        (2, 32, 8, 2048, 64, torch.float16, False, "GQA 4:1 ratio"),
-        (2, 32, 4, 2048, 64, torch.float16, False, "GQA 8:1 ratio"),
-        (2, 32, 1, 2048, 64, torch.float16, False, "MQA (32:1 ratio)"),
-        (2, 32, 8, 2048, 128, torch.float16, False, "GQA 4:1, head_dim=128"),
-    ]
-
-    for batch_size, num_heads_q, num_heads_kv, seq_len, head_dim, dtype, causal, name in gqa_configs:
-        time_ms = benchmark_flash_gqa(
-            batch_size, num_heads_q, num_heads_kv, seq_len, head_dim, dtype, causal
-        )
-
-        # Calculate FLOPS for GQA
-        flops = 2 * batch_size * num_heads_q * seq_len * seq_len * head_dim
-        tflops = (flops / 1e12) / (time_ms / 1000)
-
-        print(f"{name:<40} {time_ms:<15.3f} {tflops:<15.2f}")
-
-    # Variable length sequences (unique to Flash Attention)
-    print("\n\nVariable Length Sequence Benchmarks:")
-    print("-" * 80)
-    print(f"{'Config':<40} {'Time (ms)':<15} {'TFLOPS':<15}")
-    print("-" * 80)
-
-    varlen_configs = [
-        # (seq_lengths, num_heads, head_dim, dtype, causal, name)
-        ([512, 1024, 2048, 4096], 32, 64, torch.float16, False, "Variable [512-4096]"),
-        ([128, 256, 512, 1024, 2048], 32, 64, torch.float16, False, "Variable [128-2048]"),
-        ([2048, 2048, 2048, 2048], 32, 64, torch.float16, False, "Fixed 4x2048 (baseline)"),
-    ]
-
-    for seq_lengths, num_heads, head_dim, dtype, causal, name in varlen_configs:
-        time_ms = benchmark_variable_length(
-            seq_lengths, num_heads, head_dim, dtype, causal
-        )
-
-        # Calculate FLOPS for variable length
-        total_flops = 0
-        for seq_len in seq_lengths:
-            total_flops += 2 * num_heads * seq_len * seq_len * head_dim
-        tflops = (total_flops / 1e12) / (time_ms / 1000)
-
-        print(f"{name:<40} {time_ms:<15.3f} {tflops:<15.2f}")
-
-    print("\n" + "=" * 80)
-    print("Benchmark completed!")
-
-
-if __name__ == "__main__":
-    main()
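
The removed script reported TFLOPS from the approximate FLOP count noted in its own comments (2 · batch · heads · seq_len² · head_dim) divided by the measured per-iteration time. A minimal sketch of that arithmetic, not part of the commit; the 5 ms timing below is a hypothetical figure chosen purely for illustration.

# Illustrative sketch only; reproduces the removed script's TFLOPS arithmetic.
def approx_attention_tflops(batch_size: int, num_heads: int, seq_len: int,
                            head_dim: int, time_ms: float) -> float:
    # Approximate FLOP count used by the old benchmark: 2 * b * h * s^2 * d,
    # converted to TFLOPS over the per-iteration time in milliseconds.
    flops = 2 * batch_size * num_heads * seq_len * seq_len * head_dim
    return (flops / 1e12) / (time_ms / 1000)


# Hypothetical example: the "Medium seq, float16" shape (4, 32, 2048, 64) is
# about 6.9e10 FLOPs by this count, so 5.0 ms/iteration ≈ 13.7 TFLOPS.
print(f"{approx_attention_tflops(4, 32, 2048, 64, 5.0):.2f} TFLOPS")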