Kernels
Eric Buehler committed
Commit e7707ac
1 Parent(s): bc6a74d

Better testing

Files changed (1)
  1. tests/test_flash_attention.py +135 -798
tests/test_flash_attention.py CHANGED
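For orientation (an illustrative sketch, not part of the diff): the tests below pack variable-length sequences into one tensor and mark their boundaries with cumulative sequence lengths. The sketch assumes create_cu_seqlens builds a running prefix sum over the per-sequence lengths; only its return line appears in the diff, and an MPS device is required, as in the tests themselves.

# Illustrative sketch, not part of the commit. Assumes create_cu_seqlens builds a
# running prefix sum over the per-sequence lengths (only its return line is shown
# in the diff). Requires an MPS device, as do the tests themselves.
import torch

def create_cu_seqlens(seq_lengths):
    cu_seqlens = [0]
    for length in seq_lengths:
        cu_seqlens.append(cu_seqlens[-1] + length)
    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")

seq_lengths = [8, 16]                 # two sequences packed into one batch
num_heads, head_dim = 4, 64
total_tokens = sum(seq_lengths)       # 24 rows in the packed tensors
query = torch.randn(total_tokens, num_heads, head_dim, device="mps")
cu_seqlens = create_cu_seqlens(seq_lengths)   # tensor([0, 8, 24], dtype=torch.int32)
# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of query/key/value/out.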
@@ -11,76 +11,73 @@ def create_cu_seqlens(seq_lengths):
11
  return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")
12
 
13
 
14
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
15
- @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
16
- def test_flash_attention_single_sequence(dtype, head_dim):
17
- """Test Flash Attention with a single sequence."""
18
- torch.manual_seed(42)
19
-
20
- # Single sequence
21
- seq_len = 32
22
- num_heads = 4
23
-
24
- # Create cumulative sequence lengths
25
- cu_seqlens = create_cu_seqlens([seq_len])
26
-
27
- # Create input tensors in Flash Attention format
28
- query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
29
- key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
30
- value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
31
-
32
- # Scale factor
33
- scale = 1.0 / (head_dim ** 0.5)
34
-
35
- # Call Flash Attention
36
- out = torch.empty_like(query)
37
- metal_flash_sdpa.flash_attention_varlen(
38
- out=out,
39
- query=query,
40
- key=key,
41
- value=value,
42
- cu_seqlens_q=cu_seqlens,
43
- cu_seqlens_k=cu_seqlens,
44
- max_seqlen_q=seq_len,
45
- max_seqlen_k=seq_len,
46
- do_causal=False,
47
- scale=scale,
48
- softcapping=1.0,
49
- )
50
 
51
- # Compute ground truth
52
- # Flash Attention computes attention separately for each head
53
- expected = torch.zeros_like(out)
54
  for h in range(num_heads):
55
- q_h = query[:, h, :] # [seq_len, head_dim]
56
- k_h = key[:, h, :]
57
- v_h = value[:, h, :]
 
58
 
59
  scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
60
  attn_weights = torch.softmax(scores, dim=-1)
61
  expected[:, h, :] = torch.matmul(attn_weights, v_h)
62
 
63
- # Check results (higher tolerance for bfloat16 and float16)
64
  if dtype == torch.bfloat16:
65
- # Higher tolerance for head_dim=128 with bfloat16
66
- rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
67
  elif dtype == torch.float16:
68
- rtol, atol = 2e-3, 2e-3
69
  else:
70
- rtol, atol = 1e-3, 1e-3
71
- torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
72
 
73
 
74
  @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
75
  @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
76
- def test_flash_attention_variable_lengths(dtype, head_dim):
77
- """Test Flash Attention with variable-length sequences."""
78
  torch.manual_seed(42)
79
 
80
- # Variable sequence lengths
81
- seq_lengths_q = [8, 16, 12]
82
- seq_lengths_k = [10, 20, 15]
83
- batch_size = len(seq_lengths_q)
84
  num_heads = 4
85
 
86
  # Create cumulative sequence lengths
@@ -111,128 +108,48 @@ def test_flash_attention_variable_lengths(dtype, head_dim):
111
  cu_seqlens_k=cu_seqlens_k,
112
  max_seqlen_q=max_seqlen_q,
113
  max_seqlen_k=max_seqlen_k,
114
- do_causal=False,
115
  scale=scale,
116
  softcapping=1.0,
117
  )
118
 
119
  # Compute ground truth for each sequence
120
  expected = torch.zeros_like(out)
121
  for i in range(batch_size):
122
  q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i+1].item()
123
  k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i+1].item()
124
 
125
- q_i = query[q_start:q_end]
126
- k_i = key[k_start:k_end]
127
- v_i = value[k_start:k_end]
128
-
129
- # Compute attention for each head separately
130
- for h in range(num_heads):
131
- q_h = q_i[:, h, :] # [seq_len, head_dim]
132
- k_h = k_i[:, h, :]
133
- v_h = v_i[:, h, :]
134
-
135
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
136
- attn_weights = torch.softmax(scores, dim=-1)
137
- expected[q_start:q_end, h, :] = torch.matmul(attn_weights, v_h)
138
-
139
- # Check results (higher tolerance for bfloat16 and float16)
140
- if dtype == torch.bfloat16:
141
- # Higher tolerance for bfloat16 with variable length sequences
142
- rtol, atol = 2e-2, 2e-2
143
- elif dtype == torch.float16:
144
- rtol, atol = 2e-3, 2e-3
145
- else:
146
- rtol, atol = 1e-3, 1e-3
147
- torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
148
-
149
-
150
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
151
- @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
152
- def test_flash_attention_causal(dtype, head_dim):
153
- """Test Flash Attention with causal masking."""
154
- torch.manual_seed(42)
155
-
156
- # Test dimensions
157
- seq_lengths = [16, 24]
158
- batch_size = len(seq_lengths)
159
- num_heads = 4
160
-
161
- # Create cumulative sequence lengths
162
- cu_seqlens = create_cu_seqlens(seq_lengths)
163
- total_tokens = sum(seq_lengths)
164
- max_seqlen = max(seq_lengths)
165
-
166
- # Create input tensors
167
- query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
168
- key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
169
- value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
170
-
171
- # Scale factor
172
- scale = 1.0 / (head_dim ** 0.5)
173
-
174
- # Call Flash Attention with causal mask
175
- out = torch.empty_like(query)
176
- metal_flash_sdpa.flash_attention_varlen(
177
- out=out,
178
- query=query,
179
- key=key,
180
- value=value,
181
- cu_seqlens_q=cu_seqlens,
182
- cu_seqlens_k=cu_seqlens,
183
- max_seqlen_q=max_seqlen,
184
- max_seqlen_k=max_seqlen,
185
- do_causal=True,
186
- scale=scale,
187
- softcapping=1.0,
188
- )
189
-
190
- # Compute ground truth with causal mask
191
- expected = torch.zeros_like(out)
192
- for i in range(batch_size):
193
- start, end = cu_seqlens[i].item(), cu_seqlens[i+1].item()
194
- seq_len = end - start
195
-
196
- q_i = query[start:end]
197
- k_i = key[start:end]
198
- v_i = value[start:end]
199
-
200
- # Compute attention for each head separately
201
- for h in range(num_heads):
202
- q_h = q_i[:, h, :] # [seq_len, head_dim]
203
- k_h = k_i[:, h, :]
204
- v_h = v_i[:, h, :]
205
-
206
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
207
-
208
- # Apply causal mask
209
- causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
210
- scores.masked_fill_(causal_mask, float("-inf"))
211
 
212
- attn_weights = torch.softmax(scores, dim=-1)
213
- expected[start:end, h, :] = torch.matmul(attn_weights, v_h)
214
 
215
- # Check results (higher tolerance for bfloat16 and float16)
216
- if dtype == torch.bfloat16:
217
- # Higher tolerance for head_dim=128 with bfloat16
218
- rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
219
- elif dtype == torch.float16:
220
- rtol, atol = 2e-3, 2e-3
221
- else:
222
- rtol, atol = 1e-3, 1e-3
223
  torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
224
 
225
 
226
  @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
227
  @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
228
- def test_flash_attention_gqa(dtype, head_dim):
229
- """Test Flash Attention with Grouped Query Attention."""
230
  torch.manual_seed(42)
231
 
232
- # Test dimensions
233
- seq_len = 32
234
- num_heads = 8
235
- num_kv_heads = 2 # GQA with 4:1 ratio
236
 
237
  # Create cumulative sequence lengths
238
  cu_seqlens = create_cu_seqlens([seq_len])
@@ -262,81 +179,28 @@ def test_flash_attention_gqa(dtype, head_dim):
262
  )
263
 
264
  # Compute ground truth with GQA
265
- # Each query head attends to its corresponding kv head (with repetition)
266
- expected = torch.zeros_like(query)
267
- gqa_factor = num_heads // num_kv_heads
268
-
269
- for h in range(num_heads):
270
- kv_h = h // gqa_factor
271
- q_h = query[:, h, :] # [seq_len, head_dim]
272
- k_h = key[:, kv_h, :]
273
- v_h = value[:, kv_h, :]
274
-
275
- scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
276
- attn_weights = torch.softmax(scores, dim=-1)
277
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
278
 
279
- # Check results (higher tolerance for bfloat16 and float16)
280
- if dtype == torch.bfloat16:
281
- # Higher tolerance for head_dim=128 with bfloat16
282
- rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
283
- elif dtype == torch.float16:
284
- rtol, atol = 2e-3, 2e-3
285
- else:
286
- rtol, atol = 1e-3, 1e-3
287
  torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
288
 
289
 
290
- @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
291
- def test_flash_attention_head_dimensions(head_dim):
292
- """Test Flash Attention with different supported head dimensions."""
293
- torch.manual_seed(42)
294
-
295
- # Test dimensions
296
- seq_len = 16
297
- num_heads = 4
298
-
299
- # Create cumulative sequence lengths
300
- cu_seqlens = create_cu_seqlens([seq_len])
301
-
302
- # Create input tensors
303
- query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
304
- key = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
305
- value = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
306
-
307
- # Scale factor
308
- scale = 1.0 / (head_dim ** 0.5)
309
-
310
- # Call Flash Attention
311
- out = torch.empty_like(query)
312
- metal_flash_sdpa.flash_attention_varlen(
313
- out=out,
314
- query=query,
315
- key=key,
316
- value=value,
317
- cu_seqlens_q=cu_seqlens,
318
- cu_seqlens_k=cu_seqlens,
319
- max_seqlen_q=seq_len,
320
- max_seqlen_k=seq_len,
321
- do_causal=False,
322
- scale=scale,
323
- softcapping=1.0,
324
- )
325
-
326
- # Basic check that output is not zeros
327
- assert out.abs().max().item() > 0
328
-
329
-
330
  @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
331
- def test_flash_attention_large_head_dim(dtype):
332
- """Test Flash Attention with head_dim=128 specifically."""
333
  torch.manual_seed(42)
334
 
335
- # Test dimensions with head_dim=128
336
- seq_lengths = [32, 64]
337
- batch_size = len(seq_lengths)
338
- num_heads = 8
339
- head_dim = 128
340
 
341
  # Create cumulative sequence lengths
342
  cu_seqlens = create_cu_seqlens(seq_lengths)
@@ -351,7 +215,7 @@ def test_flash_attention_large_head_dim(dtype):
351
  # Scale factor
352
  scale = 1.0 / (head_dim ** 0.5)
353
 
354
- # Call Flash Attention
355
  out = torch.empty_like(query)
356
  metal_flash_sdpa.flash_attention_varlen(
357
  out=out,
@@ -364,159 +228,87 @@ def test_flash_attention_large_head_dim(dtype):
364
  max_seqlen_k=max_seqlen,
365
  do_causal=False,
366
  scale=scale,
367
- softcapping=1.0,
368
  )
369
 
370
- # Compute ground truth
371
- expected = torch.zeros_like(out)
372
- for i in range(batch_size):
373
- start, end = cu_seqlens[i].item(), cu_seqlens[i+1].item()
374
-
375
- q_i = query[start:end]
376
- k_i = key[start:end]
377
- v_i = value[start:end]
378
-
379
- # Compute attention for each head separately
380
- for h in range(num_heads):
381
- q_h = q_i[:, h, :] # [seq_len, head_dim]
382
- k_h = k_i[:, h, :]
383
- v_h = v_i[:, h, :]
384
 
385
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
386
- attn_weights = torch.softmax(scores, dim=-1)
387
- expected[start:end, h, :] = torch.matmul(attn_weights, v_h)
 
388
 
389
- # Check results (higher tolerance for bfloat16 with head_dim=128)
390
  if dtype == torch.bfloat16:
391
- # bfloat16 with head_dim=128 has known precision issues
392
- rtol, atol = 2e-2, 2e-2
393
  elif dtype == torch.float16:
394
- rtol, atol = 2e-3, 2e-3
395
  else:
396
- rtol, atol = 1e-3, 1e-3
397
  torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
398
 
399
 
400
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
401
- def test_flash_attention_large_head_dim_causal(dtype):
402
- """Test Flash Attention with head_dim=128 and causal masking."""
403
  torch.manual_seed(42)
404
 
405
- # Test dimensions
406
- seq_len = 48
407
  num_heads = 4
408
- head_dim = 128
409
 
410
  # Create cumulative sequence lengths
411
- cu_seqlens = create_cu_seqlens([seq_len])
 
412
 
413
  # Create input tensors
414
- query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
415
- key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
416
- value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
417
 
418
  # Scale factor
419
  scale = 1.0 / (head_dim ** 0.5)
420
 
421
- # Call Flash Attention with causal mask
422
  out = torch.empty_like(query)
423
  metal_flash_sdpa.flash_attention_varlen(
424
  out=out,
425
  query=query,
426
  key=key,
427
  value=value,
428
- cu_seqlens_q=cu_seqlens,
429
- cu_seqlens_k=cu_seqlens,
430
- max_seqlen_q=seq_len,
431
- max_seqlen_k=seq_len,
432
- do_causal=True,
433
  scale=scale,
434
  softcapping=1.0,
435
  )
436
 
437
- # Compute ground truth with causal mask
438
- expected = torch.zeros_like(out)
439
-
440
- for h in range(num_heads):
441
- q_h = query[:, h, :] # [seq_len, head_dim]
442
- k_h = key[:, h, :]
443
- v_h = value[:, h, :]
444
-
445
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
446
-
447
- # Apply causal mask
448
- causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
449
- scores.masked_fill_(causal_mask, float("-inf"))
450
-
451
- attn_weights = torch.softmax(scores, dim=-1)
452
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
453
 
454
- # Check results (higher tolerance for bfloat16 with head_dim=128)
455
  if dtype == torch.bfloat16:
456
- # bfloat16 with head_dim=128 has known precision issues
457
- rtol, atol = 2e-2, 2e-2
458
  elif dtype == torch.float16:
459
- rtol, atol = 2e-3, 2e-3
460
  else:
461
- rtol, atol = 1e-3, 1e-3
462
  torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
463
 
464
 
465
- def test_flash_attention_large_head_dim_gqa():
466
- """Test Flash Attention with head_dim=128 and GQA."""
467
- torch.manual_seed(42)
468
-
469
- # Test dimensions
470
- seq_len = 32
471
- num_heads = 16
472
- num_kv_heads = 4 # GQA with 4:1 ratio
473
- head_dim = 128
474
-
475
- # Create cumulative sequence lengths
476
- cu_seqlens = create_cu_seqlens([seq_len])
477
-
478
- # Create input tensors
479
- query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
480
- key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")
481
- value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")
482
-
483
- # Scale factor
484
- scale = 1.0 / (head_dim ** 0.5)
485
-
486
- # Call Flash Attention
487
- out = torch.empty_like(query)
488
- metal_flash_sdpa.flash_attention_varlen(
489
- out=out,
490
- query=query,
491
- key=key,
492
- value=value,
493
- cu_seqlens_q=cu_seqlens,
494
- cu_seqlens_k=cu_seqlens,
495
- max_seqlen_q=seq_len,
496
- max_seqlen_k=seq_len,
497
- do_causal=False,
498
- scale=scale,
499
- softcapping=1.0,
500
- )
501
-
502
- # Compute ground truth with GQA
503
- expected = torch.zeros_like(query)
504
- gqa_factor = num_heads // num_kv_heads
505
-
506
- for h in range(num_heads):
507
- kv_h = h // gqa_factor
508
- q_h = query[:, h, :] # [seq_len, head_dim]
509
- k_h = key[:, kv_h, :]
510
- v_h = value[:, kv_h, :]
511
-
512
- scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
513
- attn_weights = torch.softmax(scores, dim=-1)
514
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
515
-
516
- # Check results
517
- torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
518
-
519
-
520
  def test_flash_attention_edge_cases():
521
  """Test Flash Attention edge cases."""
522
  torch.manual_seed(42)
@@ -596,36 +388,13 @@ def test_flash_attention_unsupported_cases():
596
  softcapping=1.0,
597
  )
598
 
599
- # Test 2: Calling function with wrong number of arguments
 
600
  query = torch.randn(16, 4, 64, device="mps")
601
  key = torch.randn(16, 4, 64, device="mps")
602
  value = torch.randn(16, 4, 64, device="mps")
603
- mask = torch.randn(1, 1, 16, 16, device="mps")
604
- cu_seqlens = create_cu_seqlens([16])
605
- out = torch.empty_like(query)
606
-
607
- # The function signature no longer accepts mask parameter
608
- with pytest.raises(TypeError):
609
- metal_flash_sdpa.flash_attention_varlen(
610
- out=out,
611
- query=query,
612
- key=key,
613
- value=value,
614
- cu_seqlens_q=cu_seqlens,
615
- cu_seqlens_k=cu_seqlens,
616
- max_seqlen_q=16,
617
- max_seqlen_k=16,
618
- mask=mask, # This parameter doesn't exist anymore
619
- do_causal=False,
620
- scale=0.125,
621
- softcapping=1.0,
622
- )
623
-
624
- # Test 3: Wrong dtype for cu_seqlens (should be int32)
625
- cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")
626
 
627
  # This will silently fail (output will be unchanged)
628
- # We can detect this by initializing output to a known value
629
  out = torch.full_like(query, -999.0)
630
  metal_flash_sdpa.flash_attention_varlen(
631
  out=out,
@@ -645,300 +414,6 @@ def test_flash_attention_unsupported_cases():
645
  assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"
646
 
647
 
648
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
649
- @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
650
- def test_flash_attention_small_sequences(dtype, head_dim):
651
- """Test Flash Attention with small sequence lengths (2-8)."""
652
- torch.manual_seed(42)
653
-
654
- # Test different small sequence lengths
655
- for seq_len in [2, 4, 6, 8]:
656
- num_heads = 4
657
-
658
- # Create cumulative sequence lengths
659
- cu_seqlens = create_cu_seqlens([seq_len])
660
-
661
- # Create input tensors
662
- query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
663
- key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
664
- value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
665
-
666
- # Scale factor
667
- scale = 1.0 / (head_dim ** 0.5)
668
-
669
- # Call Flash Attention
670
- out = torch.empty_like(query)
671
- metal_flash_sdpa.flash_attention_varlen(
672
- out=out,
673
- query=query,
674
- key=key,
675
- value=value,
676
- cu_seqlens_q=cu_seqlens,
677
- cu_seqlens_k=cu_seqlens,
678
- max_seqlen_q=seq_len,
679
- max_seqlen_k=seq_len,
680
- do_causal=False,
681
- scale=scale,
682
- softcapping=1.0,
683
- )
684
-
685
- # Compute ground truth
686
- expected = torch.zeros_like(out)
687
- for h in range(num_heads):
688
- q_h = query[:, h, :] # [seq_len, head_dim]
689
- k_h = key[:, h, :]
690
- v_h = value[:, h, :]
691
-
692
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
693
- attn_weights = torch.softmax(scores, dim=-1)
694
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
695
-
696
- # Check results (higher tolerance for bfloat16)
697
- if dtype == torch.bfloat16:
698
- rtol, atol = 2e-2, 2e-2
699
- elif dtype == torch.float16:
700
- rtol, atol = 2e-3, 2e-3
701
- else:
702
- rtol, atol = 1e-3, 1e-3
703
- torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
704
-
705
-
706
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
707
- @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
708
- def test_flash_attention_cross_attention(dtype, head_dim):
709
- """Test Flash Attention with different q_seq and k_seq (cross-attention)."""
710
- torch.manual_seed(42)
711
-
712
- # Test various q_seq, k_seq combinations
713
- test_cases = [
714
- (16, 32), # q_seq < k_seq
715
- (32, 16), # q_seq > k_seq
716
- (8, 128), # large difference
717
- (1, 64), # single query token
718
- ]
719
-
720
- for q_seq, k_seq in test_cases:
721
- num_heads = 4
722
-
723
- # Create cumulative sequence lengths
724
- cu_seqlens_q = create_cu_seqlens([q_seq])
725
- cu_seqlens_k = create_cu_seqlens([k_seq])
726
-
727
- # Create input tensors
728
- query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
729
- key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
730
- value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
731
-
732
- # Scale factor
733
- scale = 1.0 / (head_dim ** 0.5)
734
-
735
- # Call Flash Attention
736
- out = torch.empty_like(query)
737
- metal_flash_sdpa.flash_attention_varlen(
738
- out=out,
739
- query=query,
740
- key=key,
741
- value=value,
742
- cu_seqlens_q=cu_seqlens_q,
743
- cu_seqlens_k=cu_seqlens_k,
744
- max_seqlen_q=q_seq,
745
- max_seqlen_k=k_seq,
746
- do_causal=False,
747
- scale=scale,
748
- softcapping=1.0,
749
- )
750
-
751
- # Compute ground truth
752
- expected = torch.zeros_like(out)
753
- for h in range(num_heads):
754
- q_h = query[:, h, :] # [q_seq, head_dim]
755
- k_h = key[:, h, :] # [k_seq, head_dim]
756
- v_h = value[:, h, :] # [k_seq, head_dim]
757
-
758
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
759
- attn_weights = torch.softmax(scores, dim=-1)
760
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
761
-
762
- # Check results (higher tolerance for bfloat16)
763
- if dtype == torch.bfloat16:
764
- rtol, atol = 2e-2, 2e-2
765
- elif dtype == torch.float16:
766
- rtol, atol = 2e-3, 2e-3
767
- else:
768
- rtol, atol = 1e-3, 1e-3
769
- torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
770
-
771
-
772
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
773
- def test_flash_attention_large_sequences(dtype):
774
- """Test Flash Attention with large k_seq (>= 1024)."""
775
- torch.manual_seed(42)
776
-
777
- # Test dimensions - large k_seq to test 2-pass algorithms
778
- q_seq = 32
779
- k_seq = 2048 # Large k_seq
780
- num_heads = 4
781
- head_dim = 64 # Use smaller head_dim to avoid memory issues
782
-
783
- # Create cumulative sequence lengths
784
- cu_seqlens_q = create_cu_seqlens([q_seq])
785
- cu_seqlens_k = create_cu_seqlens([k_seq])
786
-
787
- # Create input tensors
788
- query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
789
- key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
790
- value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
791
-
792
- # Scale factor
793
- scale = 1.0 / (head_dim ** 0.5)
794
-
795
- # Call Flash Attention
796
- out = torch.empty_like(query)
797
- metal_flash_sdpa.flash_attention_varlen(
798
- out=out,
799
- query=query,
800
- key=key,
801
- value=value,
802
- cu_seqlens_q=cu_seqlens_q,
803
- cu_seqlens_k=cu_seqlens_k,
804
- max_seqlen_q=q_seq,
805
- max_seqlen_k=k_seq,
806
- do_causal=False,
807
- scale=scale,
808
- softcapping=1.0,
809
- )
810
-
811
- # Compute ground truth
812
- expected = torch.zeros_like(out)
813
- for h in range(num_heads):
814
- q_h = query[:, h, :] # [q_seq, head_dim]
815
- k_h = key[:, h, :] # [k_seq, head_dim]
816
- v_h = value[:, h, :] # [k_seq, head_dim]
817
-
818
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
819
- attn_weights = torch.softmax(scores, dim=-1)
820
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
821
-
822
- # Check results (higher tolerance for large sequences)
823
- if dtype == torch.bfloat16:
824
- rtol, atol = 3e-2, 3e-2
825
- elif dtype == torch.float16:
826
- rtol, atol = 5e-3, 5e-3
827
- else:
828
- rtol, atol = 2e-3, 2e-3
829
- torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
830
-
831
-
832
- @pytest.mark.parametrize("gqa_ratio", [2, 4, 8])
833
- @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128])
834
- def test_flash_attention_gqa_ratios(gqa_ratio, head_dim):
835
- """Test Flash Attention with different GQA ratios."""
836
- torch.manual_seed(42)
837
-
838
- # Test dimensions
839
- seq_len = 32
840
- num_heads = 16
841
- num_kv_heads = num_heads // gqa_ratio
842
- dtype = torch.float32
843
-
844
- # Create cumulative sequence lengths
845
- cu_seqlens = create_cu_seqlens([seq_len])
846
-
847
- # Create input tensors
848
- query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
849
- key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
850
- value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
851
-
852
- # Scale factor
853
- scale = 1.0 / (head_dim ** 0.5)
854
-
855
- # Call Flash Attention
856
- out = torch.empty_like(query)
857
- metal_flash_sdpa.flash_attention_varlen(
858
- out=out,
859
- query=query,
860
- key=key,
861
- value=value,
862
- cu_seqlens_q=cu_seqlens,
863
- cu_seqlens_k=cu_seqlens,
864
- max_seqlen_q=seq_len,
865
- max_seqlen_k=seq_len,
866
- do_causal=False,
867
- scale=scale,
868
- softcapping=1.0,
869
- )
870
-
871
- # Compute ground truth with GQA
872
- expected = torch.zeros_like(query)
873
- gqa_factor = num_heads // num_kv_heads
874
-
875
- for h in range(num_heads):
876
- kv_h = h // gqa_factor
877
- q_h = query[:, h, :] # [seq_len, head_dim]
878
- k_h = key[:, kv_h, :]
879
- v_h = value[:, kv_h, :]
880
-
881
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
882
- attn_weights = torch.softmax(scores, dim=-1)
883
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
884
-
885
- # Check results
886
- torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
887
-
888
-
889
- def test_flash_attention_single_query_token():
890
- """Test Flash Attention with single query token (q_seq = 1)."""
891
- torch.manual_seed(42)
892
-
893
- # Test dimensions - single query token
894
- q_seq = 1
895
- k_seq = 64
896
- num_heads = 8
897
- head_dim = 64
898
- dtype = torch.float32
899
-
900
- # Create cumulative sequence lengths
901
- cu_seqlens_q = create_cu_seqlens([q_seq])
902
- cu_seqlens_k = create_cu_seqlens([k_seq])
903
-
904
- # Create input tensors
905
- query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
906
- key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
907
- value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
908
-
909
- # Scale factor
910
- scale = 1.0 / (head_dim ** 0.5)
911
-
912
- # Call Flash Attention
913
- out = torch.empty_like(query)
914
- metal_flash_sdpa.flash_attention_varlen(
915
- out=out,
916
- query=query,
917
- key=key,
918
- value=value,
919
- cu_seqlens_q=cu_seqlens_q,
920
- cu_seqlens_k=cu_seqlens_k,
921
- max_seqlen_q=q_seq,
922
- max_seqlen_k=k_seq,
923
- do_causal=False,
924
- scale=scale,
925
- softcapping=1.0,
926
- )
927
-
928
- # With single token, output should be weighted average of values
929
- expected = torch.zeros_like(out)
930
- for h in range(num_heads):
931
- q_h = query[:, h, :] # [1, head_dim]
932
- k_h = key[:, h, :] # [k_seq, head_dim]
933
- v_h = value[:, h, :] # [k_seq, head_dim]
934
-
935
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
936
- attn_weights = torch.softmax(scores, dim=-1)
937
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
938
-
939
- torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
940
-
941
-
942
  def test_flash_attn_varlen_func():
943
  """Test the flash_attn_varlen_func compatibility function."""
944
  torch.manual_seed(42)
@@ -992,141 +467,3 @@ def test_flash_attn_varlen_func():
992
 
993
  assert out_causal.shape == q.shape
994
  assert out_causal.abs().max().item() > 0
995
-
996
-
997
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
998
- @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
999
- def test_flash_attention_softcapping(dtype, head_dim):
1000
- """Test Flash Attention with softcapping."""
1001
- torch.manual_seed(42)
1002
-
1003
- # Test dimensions
1004
- seq_lengths = [32, 24]
1005
- num_heads = 4
1006
- softcapping = 50.0
1007
-
1008
- # Create cumulative sequence lengths
1009
- cu_seqlens = create_cu_seqlens(seq_lengths)
1010
- total_tokens = sum(seq_lengths)
1011
- max_seqlen = max(seq_lengths)
1012
-
1013
- # Create input tensors
1014
- query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
1015
- key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
1016
- value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
1017
-
1018
- # Scale factor
1019
- scale = 1.0 / (head_dim ** 0.5)
1020
-
1021
- # Call Flash Attention with softcapping
1022
- out = torch.empty_like(query)
1023
- metal_flash_sdpa.flash_attention_varlen(
1024
- out=out,
1025
- query=query,
1026
- key=key,
1027
- value=value,
1028
- cu_seqlens_q=cu_seqlens,
1029
- cu_seqlens_k=cu_seqlens,
1030
- max_seqlen_q=max_seqlen,
1031
- max_seqlen_k=max_seqlen,
1032
- do_causal=False,
1033
- scale=scale,
1034
- softcapping=softcapping,
1035
- )
1036
-
1037
- # Compute ground truth with softcapping
1038
- # The kernel applies: softmax(tanh(qk^T*scale/cap)*cap)v
1039
- expected = torch.zeros_like(query)
1040
-
1041
- for i, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
1042
- q_seq = query[start:end]
1043
- k_seq = key[start:end]
1044
- v_seq = value[start:end]
1045
-
1046
- for h in range(num_heads):
1047
- q_h = q_seq[:, h, :]
1048
- k_h = k_seq[:, h, :]
1049
- v_h = v_seq[:, h, :]
1050
-
1051
- # Apply softcapping formula
1052
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * (scale / softcapping)
1053
- scores = torch.tanh(scores) * softcapping
1054
- attn_weights = torch.softmax(scores, dim=-1)
1055
- expected[start:end, h, :] = torch.matmul(attn_weights, v_h)
1056
-
1057
- # Check results (higher tolerance for bfloat16 and softcapping)
1058
- if dtype == torch.bfloat16:
1059
- rtol, atol = 3e-2, 3e-2
1060
- elif dtype == torch.float16:
1061
- rtol, atol = 2e-2, 2e-2
1062
- else:
1063
- rtol, atol = 1e-2, 1e-2
1064
- torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
1065
-
1066
-
1067
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
1068
- def test_flash_attention_softcapping_edge_cases(dtype):
1069
- """Test Flash Attention softcapping with edge cases."""
1070
- torch.manual_seed(42)
1071
-
1072
- # Test with softcapping = 1.0 (no softcapping)
1073
- seq_len = 16
1074
- num_heads = 2
1075
- head_dim = 64
1076
-
1077
- cu_seqlens = create_cu_seqlens([seq_len])
1078
- query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
1079
- key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
1080
- value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
1081
-
1082
- scale = 1.0 / (head_dim ** 0.5)
1083
-
1084
- # With softcapping = 1.0 (no effect)
1085
- out_no_cap = torch.empty_like(query)
1086
- metal_flash_sdpa.flash_attention_varlen(
1087
- out=out_no_cap,
1088
- query=query,
1089
- key=key,
1090
- value=value,
1091
- cu_seqlens_q=cu_seqlens,
1092
- cu_seqlens_k=cu_seqlens,
1093
- max_seqlen_q=seq_len,
1094
- max_seqlen_k=seq_len,
1095
- do_causal=False,
1096
- scale=scale,
1097
- softcapping=1.0,
1098
- )
1099
-
1100
- # Regular computation without softcapping
1101
- expected = torch.zeros_like(query)
1102
- for h in range(num_heads):
1103
- q_h = query[:, h, :]
1104
- k_h = key[:, h, :]
1105
- v_h = value[:, h, :]
1106
-
1107
- scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
1108
- attn_weights = torch.softmax(scores, dim=-1)
1109
- expected[:, h, :] = torch.matmul(attn_weights, v_h)
1110
-
1111
- # Should be identical when softcapping = 1.0
1112
- rtol, atol = (2e-2, 2e-2) if dtype != torch.float32 else (1e-3, 1e-3)
1113
- torch.testing.assert_close(out_no_cap, expected, rtol=rtol, atol=atol)
1114
-
1115
- # Test with very large softcapping value
1116
- out_large_cap = torch.empty_like(query)
1117
- metal_flash_sdpa.flash_attention_varlen(
1118
- out=out_large_cap,
1119
- query=query,
1120
- key=key,
1121
- value=value,
1122
- cu_seqlens_q=cu_seqlens,
1123
- cu_seqlens_k=cu_seqlens,
1124
- max_seqlen_q=seq_len,
1125
- max_seqlen_k=seq_len,
1126
- do_causal=False,
1127
- scale=scale,
1128
- softcapping=1000.0,
1129
- )
1130
-
1131
- # With very large softcapping, should be close to no softcapping
1132
- torch.testing.assert_close(out_large_cap, expected, rtol=rtol, atol=atol)
14
+ def compute_attention_reference(query, key, value, scale, causal=False, softcapping=1.0, gqa_ratio=1):
15
+ """Compute reference attention output for validation."""
16
+ num_heads = query.shape[1]
17
+ expected = torch.zeros_like(query)
18
 
19
  for h in range(num_heads):
20
+ kv_h = h // gqa_ratio if gqa_ratio > 1 else h
21
+ q_h = query[:, h, :]
22
+ k_h = key[:, kv_h, :]
23
+ v_h = value[:, kv_h, :]
24
 
25
  scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
26
+
27
+ # Apply softcapping if not 1.0
28
+ if softcapping != 1.0:
29
+ scores = scores / softcapping
30
+ scores = torch.tanh(scores) * softcapping
31
+
32
+ # Apply causal mask if needed
33
+ if causal:
34
+ seq_len = query.shape[0]
35
+ causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
36
+ scores.masked_fill_(causal_mask, float("-inf"))
37
+
38
  attn_weights = torch.softmax(scores, dim=-1)
39
  expected[:, h, :] = torch.matmul(attn_weights, v_h)
40
 
41
+ return expected
42
+
43
+
44
+ def get_tolerance(dtype, head_dim):
45
+ """Get appropriate tolerance based on dtype and head dimension."""
46
  if dtype == torch.bfloat16:
47
+ return (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
 
48
  elif dtype == torch.float16:
49
+ return (2e-3, 2e-3)
50
  else:
51
+ return (1e-3, 1e-3)
 
52
 
53
 
54
  @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
55
  @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
56
+ @pytest.mark.parametrize("seq_config", [
57
+ # (seq_lengths_q, seq_lengths_k, description)
58
+ ([32], [32], "single_sequence"),
59
+ ([8, 16, 12], [10, 20, 15], "variable_lengths"),
60
+ ([16, 24], [16, 24], "multiple_sequences"),
61
+ ([2], [2], "small_sequence_2"),
62
+ ([4], [4], "small_sequence_4"),
63
+ ([8], [8], "small_sequence_8"),
64
+ ([16], [32], "cross_attention_q_lt_k"),
65
+ ([32], [16], "cross_attention_q_gt_k"),
66
+ ([8], [128], "cross_attention_large_diff"),
67
+ ([1], [64], "single_query_token"),
68
+ ])
69
+ @pytest.mark.parametrize("causal", [False, True])
70
+ def test_flash_attention_comprehensive(dtype, head_dim, seq_config, causal):
71
+ """Comprehensive test for Flash Attention with various configurations."""
72
  torch.manual_seed(42)
73
 
74
+ seq_lengths_q, seq_lengths_k, _ = seq_config
75
+
76
+ # Skip causal tests for cross-attention cases
77
+ if causal and seq_lengths_q != seq_lengths_k:
78
+ pytest.skip("Causal attention only valid when q_seq == k_seq")
79
+
80
+ # Test parameters
81
  num_heads = 4
82
 
83
  # Create cumulative sequence lengths
 
108
  cu_seqlens_k=cu_seqlens_k,
109
  max_seqlen_q=max_seqlen_q,
110
  max_seqlen_k=max_seqlen_k,
111
+ do_causal=causal,
112
  scale=scale,
113
  softcapping=1.0,
114
  )
115
 
116
  # Compute ground truth for each sequence
117
  expected = torch.zeros_like(out)
118
+ batch_size = len(seq_lengths_q)
119
+
120
  for i in range(batch_size):
121
  q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i+1].item()
122
  k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i+1].item()
123
 
124
+ if q_end > q_start and k_end > k_start: # Skip empty sequences
125
+ q_i = query[q_start:q_end]
126
+ k_i = key[k_start:k_end]
127
+ v_i = value[k_start:k_end]
128
 
129
+ expected_i = compute_attention_reference(q_i, k_i, v_i, scale, causal=causal)
130
+ expected[q_start:q_end] = expected_i
131
 
132
+ # Check results
133
+ rtol, atol = get_tolerance(dtype, head_dim)
134
  torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
135
 
136
 
137
  @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
138
  @pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
139
+ @pytest.mark.parametrize("gqa_config", [
140
+ # (num_heads, num_kv_heads, seq_len)
141
+ (8, 2, 32), # 4:1 ratio
142
+ (16, 4, 32), # 4:1 ratio
143
+ (16, 8, 32), # 2:1 ratio
144
+ (16, 2, 32), # 8:1 ratio
145
+ (16, 4, 128), # 4:1 ratio with larger sequence
146
+ ])
147
+ def test_flash_attention_gqa(dtype, head_dim, gqa_config):
148
+ """Test Flash Attention with Grouped Query Attention configurations."""
149
  torch.manual_seed(42)
150
 
151
+ num_heads, num_kv_heads, seq_len = gqa_config
152
+ gqa_ratio = num_heads // num_kv_heads
153
 
154
  # Create cumulative sequence lengths
155
  cu_seqlens = create_cu_seqlens([seq_len])
 
179
  )
180
 
181
  # Compute ground truth with GQA
182
+ expected = compute_attention_reference(query, key, value, scale, gqa_ratio=gqa_ratio)
183
 
184
+ # Check results
185
+ rtol, atol = get_tolerance(dtype, head_dim)
186
  torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
187
 
188
 
189
  @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
190
+ @pytest.mark.parametrize("softcapping_config", [
191
+ # (softcapping_value, seq_lengths, head_dim)
192
+ (1.0, [32], 64), # No softcapping
193
+ (50.0, [32, 24], 64), # Regular softcapping
194
+ (10.0, [16], 128), # Strong softcapping
195
+ (1000.0, [16], 64), # Very weak softcapping
196
+ (30.0, [48], 96), # Medium softcapping
197
+ ])
198
+ def test_flash_attention_softcapping(dtype, softcapping_config):
199
+ """Test Flash Attention with various softcapping values."""
200
  torch.manual_seed(42)
201
 
202
+ softcapping, seq_lengths, head_dim = softcapping_config
203
+ num_heads = 4
204
 
205
  # Create cumulative sequence lengths
206
  cu_seqlens = create_cu_seqlens(seq_lengths)
 
215
  # Scale factor
216
  scale = 1.0 / (head_dim ** 0.5)
217
 
218
+ # Call Flash Attention with softcapping
219
  out = torch.empty_like(query)
220
  metal_flash_sdpa.flash_attention_varlen(
221
  out=out,
 
228
  max_seqlen_k=max_seqlen,
229
  do_causal=False,
230
  scale=scale,
231
+ softcapping=softcapping,
232
  )
233
 
234
+ # Compute ground truth with softcapping
235
+ expected = torch.zeros_like(query)
236
+
237
+ for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
238
+ if end > start:
239
+ q_seq = query[start:end]
240
+ k_seq = key[start:end]
241
+ v_seq = value[start:end]
242
 
243
+ expected_seq = compute_attention_reference(
244
+ q_seq, k_seq, v_seq, scale, softcapping=softcapping
245
+ )
246
+ expected[start:end] = expected_seq
247
 
248
+ # Check results (higher tolerance for softcapping)
249
  if dtype == torch.bfloat16:
250
+ rtol, atol = 3e-2, 3e-2
 
251
  elif dtype == torch.float16:
252
+ rtol, atol = 2e-2, 2e-2
253
  else:
254
+ rtol, atol = 1e-2, 1e-2
255
  torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
256
 
257
 
258
+ @pytest.mark.parametrize("large_seq_config", [
259
+ # (q_seq, k_seq, head_dim, dtype)
260
+ (32, 2048, 64, torch.float32),
261
+ (16, 1024, 96, torch.float16),
262
+ (64, 1536, 64, torch.bfloat16),
263
+ ])
264
+ def test_flash_attention_large_sequences(large_seq_config):
265
+ """Test Flash Attention with large k sequences (>= 1024)."""
266
  torch.manual_seed(42)
267
 
268
+ q_seq, k_seq, head_dim, dtype = large_seq_config
 
269
  num_heads = 4
 
270
 
271
  # Create cumulative sequence lengths
272
+ cu_seqlens_q = create_cu_seqlens([q_seq])
273
+ cu_seqlens_k = create_cu_seqlens([k_seq])
274
 
275
  # Create input tensors
276
+ query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
277
+ key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
278
+ value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
279
 
280
  # Scale factor
281
  scale = 1.0 / (head_dim ** 0.5)
282
 
283
+ # Call Flash Attention
284
  out = torch.empty_like(query)
285
  metal_flash_sdpa.flash_attention_varlen(
286
  out=out,
287
  query=query,
288
  key=key,
289
  value=value,
290
+ cu_seqlens_q=cu_seqlens_q,
291
+ cu_seqlens_k=cu_seqlens_k,
292
+ max_seqlen_q=q_seq,
293
+ max_seqlen_k=k_seq,
294
+ do_causal=False,
295
  scale=scale,
296
  softcapping=1.0,
297
  )
298
 
299
+ # Compute ground truth
300
+ expected = compute_attention_reference(query, key, value, scale)
301
 
302
+ # Check results (higher tolerance for large sequences)
303
  if dtype == torch.bfloat16:
304
+ rtol, atol = 3e-2, 3e-2
 
305
  elif dtype == torch.float16:
306
+ rtol, atol = 5e-3, 5e-3
307
  else:
308
+ rtol, atol = 2e-3, 2e-3
309
  torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
310
 
311
 
312
  def test_flash_attention_edge_cases():
313
  """Test Flash Attention edge cases."""
314
  torch.manual_seed(42)
 
388
  softcapping=1.0,
389
  )
390
 
391
+ # Test 2: Wrong dtype for cu_seqlens (should be int32)
392
+ cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")
393
  query = torch.randn(16, 4, 64, device="mps")
394
  key = torch.randn(16, 4, 64, device="mps")
395
  value = torch.randn(16, 4, 64, device="mps")
396
 
397
  # This will silently fail (output will be unchanged)
 
398
  out = torch.full_like(query, -999.0)
399
  metal_flash_sdpa.flash_attention_varlen(
400
  out=out,
 
414
  assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"
415
 
416
 
417
  def test_flash_attn_varlen_func():
418
  """Test the flash_attn_varlen_func compatibility function."""
419
  torch.manual_seed(42)
 
467
 
468
  assert out_causal.shape == q.shape
469
  assert out_causal.abs().max().item() > 0
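As an aside (a hypothetical cross-check, not part of this commit): with softcapping=1.0, no causal mask, and no GQA, the per-head reference computation used throughout these tests should agree with PyTorch's built-in scaled_dot_product_attention. A minimal CPU sketch, assuming PyTorch 2.1+ for the explicit scale keyword:

# Hypothetical cross-check, not part of this commit. With softcapping=1.0, no
# causal mask, and no GQA, the per-head reference used in these tests should
# match torch.nn.functional.scaled_dot_product_attention. Runs on CPU for
# portability; assumes PyTorch 2.1+ for the explicit scale keyword.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, num_heads, head_dim = 16, 4, 64
scale = 1.0 / (head_dim ** 0.5)
q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
v = torch.randn(seq_len, num_heads, head_dim)

# Per-head loop, mirroring compute_attention_reference with default arguments.
expected = torch.zeros_like(q)
for h in range(num_heads):
    scores = q[:, h, :] @ k[:, h, :].transpose(-1, -2) * scale
    expected[:, h, :] = torch.softmax(scores, dim=-1) @ v[:, h, :]

# Same computation through SDPA, using a [num_heads, seq_len, head_dim] layout.
sdpa_out = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), scale=scale
).transpose(0, 1)
torch.testing.assert_close(expected, sdpa_out, rtol=1e-4, atol=1e-4)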