Eric Buehler committed
Commit e7707ac · 1 Parent(s): bc6a74d
Better testing
tests/test_flash_attention.py  +135  -798
tests/test_flash_attention.py
CHANGED
@@ -11,76 +11,73 @@ def create_cu_seqlens(seq_lengths):
    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")


-
-
-
-
-    torch.manual_seed(42)
-
-    # Single sequence
-    seq_len = 32
-    num_heads = 4
-
-    # Create cumulative sequence lengths
-    cu_seqlens = create_cu_seqlens([seq_len])
-
-    # Create input tensors in Flash Attention format
-    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-
-    # Scale factor
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Call Flash Attention
-    out = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens,
-        cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=seq_len,
-        max_seqlen_k=seq_len,
-        do_causal=False,
-        scale=scale,
-        softcapping=1.0,
-    )

-    # Compute ground truth
-    # Flash Attention computes attention separately for each head
-    expected = torch.zeros_like(out)
    for h in range(num_heads):
-
-
-

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

-
    if dtype == torch.bfloat16:
-
-        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
    elif dtype == torch.float16:
-
    else:
-
-    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
-
-
    torch.manual_seed(42)

-
-
-
-
    num_heads = 4

    # Create cumulative sequence lengths
@@ -111,128 +108,48 @@ def test_flash_attention_variable_lengths(dtype, head_dim):
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
-        do_causal=
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth for each sequence
    expected = torch.zeros_like(out)
    for i in range(batch_size):
        q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i+1].item()
        k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i+1].item()

-
-
-
-
-        # Compute attention for each head separately
-        for h in range(num_heads):
-            q_h = q_i[:, h, :]  # [seq_len, head_dim]
-            k_h = k_i[:, h, :]
-            v_h = v_i[:, h, :]
-
-            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-            attn_weights = torch.softmax(scores, dim=-1)
-            expected[q_start:q_end, h, :] = torch.matmul(attn_weights, v_h)
-
-    # Check results (higher tolerance for bfloat16 and float16)
-    if dtype == torch.bfloat16:
-        # Higher tolerance for bfloat16 with variable length sequences
-        rtol, atol = 2e-2, 2e-2
-    elif dtype == torch.float16:
-        rtol, atol = 2e-3, 2e-3
-    else:
-        rtol, atol = 1e-3, 1e-3
-    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
-
-
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
-def test_flash_attention_causal(dtype, head_dim):
-    """Test Flash Attention with causal masking."""
-    torch.manual_seed(42)
-
-    # Test dimensions
-    seq_lengths = [16, 24]
-    batch_size = len(seq_lengths)
-    num_heads = 4
-
-    # Create cumulative sequence lengths
-    cu_seqlens = create_cu_seqlens(seq_lengths)
-    total_tokens = sum(seq_lengths)
-    max_seqlen = max(seq_lengths)
-
-    # Create input tensors
-    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-
-    # Scale factor
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Call Flash Attention with causal mask
-    out = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens,
-        cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=max_seqlen,
-        max_seqlen_k=max_seqlen,
-        do_causal=True,
-        scale=scale,
-        softcapping=1.0,
-    )
-
-    # Compute ground truth with causal mask
-    expected = torch.zeros_like(out)
-    for i in range(batch_size):
-        start, end = cu_seqlens[i].item(), cu_seqlens[i+1].item()
-        seq_len = end - start
-
-        q_i = query[start:end]
-        k_i = key[start:end]
-        v_i = value[start:end]
-
-        # Compute attention for each head separately
-        for h in range(num_heads):
-            q_h = q_i[:, h, :]  # [seq_len, head_dim]
-            k_h = k_i[:, h, :]
-            v_h = v_i[:, h, :]
-
-            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-
-            # Apply causal mask
-            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
-            scores.masked_fill_(causal_mask, float("-inf"))

-
-            expected[

-    # Check results
-
-        # Higher tolerance for head_dim=128 with bfloat16
-        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
-    elif dtype == torch.float16:
-        rtol, atol = 2e-3, 2e-3
-    else:
-        rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
-
-
    torch.manual_seed(42)

-
-
-    num_heads = 8
-    num_kv_heads = 2  # GQA with 4:1 ratio

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens([seq_len])
@@ -262,81 +179,28 @@ def test_flash_attention_gqa(dtype, head_dim):
    )

    # Compute ground truth with GQA
-
-    expected = torch.zeros_like(query)
-    gqa_factor = num_heads // num_kv_heads
-
-    for h in range(num_heads):
-        kv_h = h // gqa_factor
-        q_h = query[:, h, :]  # [seq_len, head_dim]
-        k_h = key[:, kv_h, :]
-        v_h = value[:, kv_h, :]
-
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
-        attn_weights = torch.softmax(scores, dim=-1)
-        expected[:, h, :] = torch.matmul(attn_weights, v_h)

-    # Check results
-
-        # Higher tolerance for head_dim=128 with bfloat16
-        rtol, atol = (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
-    elif dtype == torch.float16:
-        rtol, atol = 2e-3, 2e-3
-    else:
-        rtol, atol = 1e-3, 1e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


-@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
-def test_flash_attention_head_dimensions(head_dim):
-    """Test Flash Attention with different supported head dimensions."""
-    torch.manual_seed(42)
-
-    # Test dimensions
-    seq_len = 16
-    num_heads = 4
-
-    # Create cumulative sequence lengths
-    cu_seqlens = create_cu_seqlens([seq_len])
-
-    # Create input tensors
-    query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
-    key = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
-    value = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
-
-    # Scale factor
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Call Flash Attention
-    out = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens,
-        cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=seq_len,
-        max_seqlen_k=seq_len,
-        do_causal=False,
-        scale=scale,
-        softcapping=1.0,
-    )
-
-    # Basic check that output is not zeros
-    assert out.abs().max().item() > 0
-
-
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-
-
    torch.manual_seed(42)

-
-
-    batch_size = len(seq_lengths)
-    num_heads = 8
-    head_dim = 128

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens(seq_lengths)
@@ -351,7 +215,7 @@ def test_flash_attention_large_head_dim(dtype):
    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

-    # Call Flash Attention
    out = torch.empty_like(query)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
@@ -364,159 +228,87 @@ def test_flash_attention_large_head_dim(dtype):
        max_seqlen_k=max_seqlen,
        do_causal=False,
        scale=scale,
-        softcapping=
    )

-    # Compute ground truth
-    expected = torch.zeros_like(
-
-
-
-
-
-
-
-        # Compute attention for each head separately
-        for h in range(num_heads):
-            q_h = q_i[:, h, :]  # [seq_len, head_dim]
-            k_h = k_i[:, h, :]
-            v_h = v_i[:, h, :]

-
-
-

-    # Check results (higher tolerance for
    if dtype == torch.bfloat16:
-
-        rtol, atol = 2e-2, 2e-2
    elif dtype == torch.float16:
-        rtol, atol = 2e-
    else:
-        rtol, atol = 1e-
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


-@pytest.mark.parametrize("
-
-
    torch.manual_seed(42)

-
-    seq_len = 48
    num_heads = 4
-    head_dim = 128

    # Create cumulative sequence lengths
-

    # Create input tensors
-    query = torch.randn(
-    key = torch.randn(
-    value = torch.randn(

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

-    # Call Flash Attention
    out = torch.empty_like(query)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
-        cu_seqlens_q=
-        cu_seqlens_k=
-        max_seqlen_q=
-        max_seqlen_k=
-        do_causal=
        scale=scale,
        softcapping=1.0,
    )

-    # Compute ground truth
-    expected =
-
-    for h in range(num_heads):
-        q_h = query[:, h, :]  # [seq_len, head_dim]
-        k_h = key[:, h, :]
-        v_h = value[:, h, :]
-
-        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-
-        # Apply causal mask
-        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
-        scores.masked_fill_(causal_mask, float("-inf"))
-
-        attn_weights = torch.softmax(scores, dim=-1)
-        expected[:, h, :] = torch.matmul(attn_weights, v_h)

-    # Check results (higher tolerance for
    if dtype == torch.bfloat16:
-
-        rtol, atol = 2e-2, 2e-2
    elif dtype == torch.float16:
-        rtol, atol =
    else:
-        rtol, atol =
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


-def test_flash_attention_large_head_dim_gqa():
-    """Test Flash Attention with head_dim=128 and GQA."""
-    torch.manual_seed(42)
-
-    # Test dimensions
-    seq_len = 32
-    num_heads = 16
-    num_kv_heads = 4  # GQA with 4:1 ratio
-    head_dim = 128
-
-    # Create cumulative sequence lengths
-    cu_seqlens = create_cu_seqlens([seq_len])
-
-    # Create input tensors
-    query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float32, device="mps")
-    key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")
-    value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float32, device="mps")
-
-    # Scale factor
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Call Flash Attention
-    out = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens,
-        cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=seq_len,
-        max_seqlen_k=seq_len,
-        do_causal=False,
-        scale=scale,
-        softcapping=1.0,
-    )
-
-    # Compute ground truth with GQA
-    expected = torch.zeros_like(query)
-    gqa_factor = num_heads // num_kv_heads
-
-    for h in range(num_heads):
-        kv_h = h // gqa_factor
-        q_h = query[:, h, :]  # [seq_len, head_dim]
-        k_h = key[:, kv_h, :]
-        v_h = value[:, kv_h, :]
-
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1)) * scale
-        attn_weights = torch.softmax(scores, dim=-1)
-        expected[:, h, :] = torch.matmul(attn_weights, v_h)
-
-    # Check results
-    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
-
-
def test_flash_attention_edge_cases():
    """Test Flash Attention edge cases."""
    torch.manual_seed(42)
@@ -596,36 +388,13 @@ def test_flash_attention_unsupported_cases():
        softcapping=1.0,
    )

-    # Test 2:
    query = torch.randn(16, 4, 64, device="mps")
    key = torch.randn(16, 4, 64, device="mps")
    value = torch.randn(16, 4, 64, device="mps")
-    mask = torch.randn(1, 1, 16, 16, device="mps")
-    cu_seqlens = create_cu_seqlens([16])
-    out = torch.empty_like(query)
-
-    # The function signature no longer accepts mask parameter
-    with pytest.raises(TypeError):
-        metal_flash_sdpa.flash_attention_varlen(
-            out=out,
-            query=query,
-            key=key,
-            value=value,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=16,
-            max_seqlen_k=16,
-            mask=mask,  # This parameter doesn't exist anymore
-            do_causal=False,
-            scale=0.125,
-            softcapping=1.0,
-        )
-
-    # Test 3: Wrong dtype for cu_seqlens (should be int32)
-    cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")

    # This will silently fail (output will be unchanged)
-    # We can detect this by initializing output to a known value
    out = torch.full_like(query, -999.0)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
@@ -645,300 +414,6 @@ def test_flash_attention_unsupported_cases():
    assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"


-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
-def test_flash_attention_small_sequences(dtype, head_dim):
-    """Test Flash Attention with small sequence lengths (2-8)."""
-    torch.manual_seed(42)
-
-    # Test different small sequence lengths
-    for seq_len in [2, 4, 6, 8]:
-        num_heads = 4
-
-        # Create cumulative sequence lengths
-        cu_seqlens = create_cu_seqlens([seq_len])
-
-        # Create input tensors
-        query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-        key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-        value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-
-        # Scale factor
-        scale = 1.0 / (head_dim ** 0.5)
-
-        # Call Flash Attention
-        out = torch.empty_like(query)
-        metal_flash_sdpa.flash_attention_varlen(
-            out=out,
-            query=query,
-            key=key,
-            value=value,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=seq_len,
-            max_seqlen_k=seq_len,
-            do_causal=False,
-            scale=scale,
-            softcapping=1.0,
-        )
-
-        # Compute ground truth
-        expected = torch.zeros_like(out)
-        for h in range(num_heads):
-            q_h = query[:, h, :]  # [seq_len, head_dim]
-            k_h = key[:, h, :]
-            v_h = value[:, h, :]
-
-            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-            attn_weights = torch.softmax(scores, dim=-1)
-            expected[:, h, :] = torch.matmul(attn_weights, v_h)
-
-        # Check results (higher tolerance for bfloat16)
-        if dtype == torch.bfloat16:
-            rtol, atol = 2e-2, 2e-2
-        elif dtype == torch.float16:
-            rtol, atol = 2e-3, 2e-3
-        else:
-            rtol, atol = 1e-3, 1e-3
-        torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
-
-
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
-def test_flash_attention_cross_attention(dtype, head_dim):
-    """Test Flash Attention with different q_seq and k_seq (cross-attention)."""
-    torch.manual_seed(42)
-
-    # Test various q_seq, k_seq combinations
-    test_cases = [
-        (16, 32),  # q_seq < k_seq
-        (32, 16),  # q_seq > k_seq
-        (8, 128),  # large difference
-        (1, 64),  # single query token
-    ]
-
-    for q_seq, k_seq in test_cases:
-        num_heads = 4
-
-        # Create cumulative sequence lengths
-        cu_seqlens_q = create_cu_seqlens([q_seq])
-        cu_seqlens_k = create_cu_seqlens([k_seq])
-
-        # Create input tensors
-        query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
-        key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
-        value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
-
-        # Scale factor
-        scale = 1.0 / (head_dim ** 0.5)
-
-        # Call Flash Attention
-        out = torch.empty_like(query)
-        metal_flash_sdpa.flash_attention_varlen(
-            out=out,
-            query=query,
-            key=key,
-            value=value,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=q_seq,
-            max_seqlen_k=k_seq,
-            do_causal=False,
-            scale=scale,
-            softcapping=1.0,
-        )
-
-        # Compute ground truth
-        expected = torch.zeros_like(out)
-        for h in range(num_heads):
-            q_h = query[:, h, :]  # [q_seq, head_dim]
-            k_h = key[:, h, :]  # [k_seq, head_dim]
-            v_h = value[:, h, :]  # [k_seq, head_dim]
-
-            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-            attn_weights = torch.softmax(scores, dim=-1)
-            expected[:, h, :] = torch.matmul(attn_weights, v_h)
-
-        # Check results (higher tolerance for bfloat16)
-        if dtype == torch.bfloat16:
-            rtol, atol = 2e-2, 2e-2
-        elif dtype == torch.float16:
-            rtol, atol = 2e-3, 2e-3
-        else:
-            rtol, atol = 1e-3, 1e-3
-        torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
-
-
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-def test_flash_attention_large_sequences(dtype):
-    """Test Flash Attention with large k_seq (>= 1024)."""
-    torch.manual_seed(42)
-
-    # Test dimensions - large k_seq to test 2-pass algorithms
-    q_seq = 32
-    k_seq = 2048  # Large k_seq
-    num_heads = 4
-    head_dim = 64  # Use smaller head_dim to avoid memory issues
-
-    # Create cumulative sequence lengths
-    cu_seqlens_q = create_cu_seqlens([q_seq])
-    cu_seqlens_k = create_cu_seqlens([k_seq])
-
-    # Create input tensors
-    query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
-
-    # Scale factor
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Call Flash Attention
-    out = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=q_seq,
-        max_seqlen_k=k_seq,
-        do_causal=False,
-        scale=scale,
-        softcapping=1.0,
-    )
-
-    # Compute ground truth
-    expected = torch.zeros_like(out)
-    for h in range(num_heads):
-        q_h = query[:, h, :]  # [q_seq, head_dim]
-        k_h = key[:, h, :]  # [k_seq, head_dim]
-        v_h = value[:, h, :]  # [k_seq, head_dim]
-
-        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-        attn_weights = torch.softmax(scores, dim=-1)
-        expected[:, h, :] = torch.matmul(attn_weights, v_h)
-
-    # Check results (higher tolerance for large sequences)
-    if dtype == torch.bfloat16:
-        rtol, atol = 3e-2, 3e-2
-    elif dtype == torch.float16:
-        rtol, atol = 5e-3, 5e-3
-    else:
-        rtol, atol = 2e-3, 2e-3
-    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
-
-
-@pytest.mark.parametrize("gqa_ratio", [2, 4, 8])
-@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128])
-def test_flash_attention_gqa_ratios(gqa_ratio, head_dim):
-    """Test Flash Attention with different GQA ratios."""
-    torch.manual_seed(42)
-
-    # Test dimensions
-    seq_len = 32
-    num_heads = 16
-    num_kv_heads = num_heads // gqa_ratio
-    dtype = torch.float32
-
-    # Create cumulative sequence lengths
-    cu_seqlens = create_cu_seqlens([seq_len])
-
-    # Create input tensors
-    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
-
-    # Scale factor
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Call Flash Attention
-    out = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens,
-        cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=seq_len,
-        max_seqlen_k=seq_len,
-        do_causal=False,
-        scale=scale,
-        softcapping=1.0,
-    )
-
-    # Compute ground truth with GQA
-    expected = torch.zeros_like(query)
-    gqa_factor = num_heads // num_kv_heads
-
-    for h in range(num_heads):
-        kv_h = h // gqa_factor
-        q_h = query[:, h, :]  # [seq_len, head_dim]
-        k_h = key[:, kv_h, :]
-        v_h = value[:, kv_h, :]
-
-        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-        attn_weights = torch.softmax(scores, dim=-1)
-        expected[:, h, :] = torch.matmul(attn_weights, v_h)
-
-    # Check results
-    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
-
-
-def test_flash_attention_single_query_token():
-    """Test Flash Attention with single query token (q_seq = 1)."""
-    torch.manual_seed(42)
-
-    # Test dimensions - single query token
-    q_seq = 1
-    k_seq = 64
-    num_heads = 8
-    head_dim = 64
-    dtype = torch.float32
-
-    # Create cumulative sequence lengths
-    cu_seqlens_q = create_cu_seqlens([q_seq])
-    cu_seqlens_k = create_cu_seqlens([k_seq])
-
-    # Create input tensors
-    query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
-
-    # Scale factor
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Call Flash Attention
-    out = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=q_seq,
-        max_seqlen_k=k_seq,
-        do_causal=False,
-        scale=scale,
-        softcapping=1.0,
-    )
-
-    # With single token, output should be weighted average of values
-    expected = torch.zeros_like(out)
-    for h in range(num_heads):
-        q_h = query[:, h, :]  # [1, head_dim]
-        k_h = key[:, h, :]  # [k_seq, head_dim]
-        v_h = value[:, h, :]  # [k_seq, head_dim]
-
-        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-        attn_weights = torch.softmax(scores, dim=-1)
-        expected[:, h, :] = torch.matmul(attn_weights, v_h)
-
-    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
-
-
def test_flash_attn_varlen_func():
    """Test the flash_attn_varlen_func compatibility function."""
    torch.manual_seed(42)
@@ -992,141 +467,3 @@ def test_flash_attn_varlen_func():

    assert out_causal.shape == q.shape
    assert out_causal.abs().max().item() > 0
-
-
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
-def test_flash_attention_softcapping(dtype, head_dim):
-    """Test Flash Attention with softcapping."""
-    torch.manual_seed(42)
-
-    # Test dimensions
-    seq_lengths = [32, 24]
-    num_heads = 4
-    softcapping = 50.0
-
-    # Create cumulative sequence lengths
-    cu_seqlens = create_cu_seqlens(seq_lengths)
-    total_tokens = sum(seq_lengths)
-    max_seqlen = max(seq_lengths)
-
-    # Create input tensors
-    query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
-
-    # Scale factor
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # Call Flash Attention with softcapping
-    out = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens,
-        cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=max_seqlen,
-        max_seqlen_k=max_seqlen,
-        do_causal=False,
-        scale=scale,
-        softcapping=softcapping,
-    )
-
-    # Compute ground truth with softcapping
-    # The kernel applies: softmax(tanh(qk^T*scale/cap)*cap)v
-    expected = torch.zeros_like(query)
-
-    for i, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
-        q_seq = query[start:end]
-        k_seq = key[start:end]
-        v_seq = value[start:end]
-
-        for h in range(num_heads):
-            q_h = q_seq[:, h, :]
-            k_h = k_seq[:, h, :]
-            v_h = v_seq[:, h, :]
-
-            # Apply softcapping formula
-            scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * (scale / softcapping)
-            scores = torch.tanh(scores) * softcapping
-            attn_weights = torch.softmax(scores, dim=-1)
-            expected[start:end, h, :] = torch.matmul(attn_weights, v_h)
-
-    # Check results (higher tolerance for bfloat16 and softcapping)
-    if dtype == torch.bfloat16:
-        rtol, atol = 3e-2, 3e-2
-    elif dtype == torch.float16:
-        rtol, atol = 2e-2, 2e-2
-    else:
-        rtol, atol = 1e-2, 1e-2
-    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
-
-
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
-def test_flash_attention_softcapping_edge_cases(dtype):
-    """Test Flash Attention softcapping with edge cases."""
-    torch.manual_seed(42)
-
-    # Test with softcapping = 1.0 (no softcapping)
-    seq_len = 16
-    num_heads = 2
-    head_dim = 64
-
-    cu_seqlens = create_cu_seqlens([seq_len])
-    query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-    key = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-    value = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
-
-    scale = 1.0 / (head_dim ** 0.5)
-
-    # With softcapping = 1.0 (no effect)
-    out_no_cap = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out_no_cap,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens,
-        cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=seq_len,
-        max_seqlen_k=seq_len,
-        do_causal=False,
-        scale=scale,
-        softcapping=1.0,
-    )
-
-    # Regular computation without softcapping
-    expected = torch.zeros_like(query)
-    for h in range(num_heads):
-        q_h = query[:, h, :]
-        k_h = key[:, h, :]
-        v_h = value[:, h, :]
-
-        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
-        attn_weights = torch.softmax(scores, dim=-1)
-        expected[:, h, :] = torch.matmul(attn_weights, v_h)
-
-    # Should be identical when softcapping = 1.0
-    rtol, atol = (2e-2, 2e-2) if dtype != torch.float32 else (1e-3, 1e-3)
-    torch.testing.assert_close(out_no_cap, expected, rtol=rtol, atol=atol)
-
-    # Test with very large softcapping value
-    out_large_cap = torch.empty_like(query)
-    metal_flash_sdpa.flash_attention_varlen(
-        out=out_large_cap,
-        query=query,
-        key=key,
-        value=value,
-        cu_seqlens_q=cu_seqlens,
-        cu_seqlens_k=cu_seqlens,
-        max_seqlen_q=seq_len,
-        max_seqlen_k=seq_len,
-        do_causal=False,
-        scale=scale,
-        softcapping=1000.0,
-    )
-
-    # With very large softcapping, should be close to no softcapping
-    torch.testing.assert_close(out_large_cap, expected, rtol=rtol, atol=atol)
@@ -11,76 +11,73 @@ def create_cu_seqlens(seq_lengths):
    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")


+def compute_attention_reference(query, key, value, scale, causal=False, softcapping=1.0, gqa_ratio=1):
+    """Compute reference attention output for validation."""
+    num_heads = query.shape[1]
+    expected = torch.zeros_like(query)

    for h in range(num_heads):
+        kv_h = h // gqa_ratio if gqa_ratio > 1 else h
+        q_h = query[:, h, :]
+        k_h = key[:, kv_h, :]
+        v_h = value[:, kv_h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
+
+        # Apply softcapping if not 1.0
+        if softcapping != 1.0:
+            scores = scores / softcapping
+            scores = torch.tanh(scores) * softcapping
+
+        # Apply causal mask if needed
+        if causal:
+            seq_len = query.shape[0]
+            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
+            scores.masked_fill_(causal_mask, float("-inf"))
+
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

+    return expected
+
+
+def get_tolerance(dtype, head_dim):
+    """Get appropriate tolerance based on dtype and head dimension."""
    if dtype == torch.bfloat16:
+        return (2e-2, 2e-2) if head_dim >= 96 else (1e-2, 1e-2)
    elif dtype == torch.float16:
+        return (2e-3, 2e-3)
    else:
+        return (1e-3, 1e-3)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
+@pytest.mark.parametrize("seq_config", [
+    # (seq_lengths_q, seq_lengths_k, description)
+    ([32], [32], "single_sequence"),
+    ([8, 16, 12], [10, 20, 15], "variable_lengths"),
+    ([16, 24], [16, 24], "multiple_sequences"),
+    ([2], [2], "small_sequence_2"),
+    ([4], [4], "small_sequence_4"),
+    ([8], [8], "small_sequence_8"),
+    ([16], [32], "cross_attention_q_lt_k"),
+    ([32], [16], "cross_attention_q_gt_k"),
+    ([8], [128], "cross_attention_large_diff"),
+    ([1], [64], "single_query_token"),
+])
+@pytest.mark.parametrize("causal", [False, True])
+def test_flash_attention_comprehensive(dtype, head_dim, seq_config, causal):
+    """Comprehensive test for Flash Attention with various configurations."""
    torch.manual_seed(42)

+    seq_lengths_q, seq_lengths_k, _ = seq_config
+
+    # Skip causal tests for cross-attention cases
+    if causal and seq_lengths_q != seq_lengths_k:
+        pytest.skip("Causal attention only valid when q_seq == k_seq")
+
+    # Test parameters
    num_heads = 4

    # Create cumulative sequence lengths
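For orientation (not part of the commit): the new compute_attention_reference helper reproduces, per head, the math the kernel is expected to perform — GQA head mapping, optional softcapping, optional causal mask, softmax, weighted sum of values. A minimal sketch of how it could be cross-checked on CPU against PyTorch's built-in scaled_dot_product_attention, assuming torch >= 2.0 and the [tokens, heads, head_dim] layout used throughout this file:

import torch
import torch.nn.functional as F

q = torch.randn(32, 4, 64)
k = torch.randn(32, 4, 64)
v = torch.randn(32, 4, 64)
scale = 1.0 / (64 ** 0.5)

ref = compute_attention_reference(q, k, v, scale)  # per-head loop defined above
# SDPA expects [batch, heads, tokens, head_dim]; its default scale is 1/sqrt(head_dim),
# which matches the explicit scale passed to the reference helper here.
sdpa = F.scaled_dot_product_attention(
    q.transpose(0, 1).unsqueeze(0),
    k.transpose(0, 1).unsqueeze(0),
    v.transpose(0, 1).unsqueeze(0),
).squeeze(0).transpose(0, 1)
torch.testing.assert_close(ref, sdpa, rtol=1e-4, atol=1e-4)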
@@ -111,128 +108,48 @@ def test_flash_attention_variable_lengths(dtype, head_dim):
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
+        do_causal=causal,
        scale=scale,
        softcapping=1.0,
    )

    # Compute ground truth for each sequence
    expected = torch.zeros_like(out)
+    batch_size = len(seq_lengths_q)
+
    for i in range(batch_size):
        q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i+1].item()
        k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i+1].item()

+        if q_end > q_start and k_end > k_start:  # Skip empty sequences
+            q_i = query[q_start:q_end]
+            k_i = key[k_start:k_end]
+            v_i = value[k_start:k_end]

+            expected_i = compute_attention_reference(q_i, k_i, v_i, scale, causal=causal)
+            expected[q_start:q_end] = expected_i

+    # Check results
+    rtol, atol = get_tolerance(dtype, head_dim)
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
+@pytest.mark.parametrize("gqa_config", [
+    # (num_heads, num_kv_heads, seq_len)
+    (8, 2, 32),  # 4:1 ratio
+    (16, 4, 32),  # 4:1 ratio
+    (16, 8, 32),  # 2:1 ratio
+    (16, 2, 32),  # 8:1 ratio
+    (16, 4, 128),  # 4:1 ratio with larger sequence
+])
+def test_flash_attention_gqa(dtype, head_dim, gqa_config):
+    """Test Flash Attention with Grouped Query Attention configurations."""
    torch.manual_seed(42)

+    num_heads, num_kv_heads, seq_len = gqa_config
+    gqa_ratio = num_heads // num_kv_heads

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens([seq_len])
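For reference (not part of the diff): the comprehensive test relies on create_cu_seqlens to turn per-sequence lengths into cumulative offsets, so that sequence i of the packed [total_tokens, num_heads, head_dim] tensors occupies rows cu[i]:cu[i+1]. A hypothetical pure-Python equivalent and a worked example:

def cu_seqlens_list(seq_lengths):
    # Illustrative stand-in for create_cu_seqlens, minus the int32/MPS tensor conversion.
    cu = [0]
    for n in seq_lengths:
        cu.append(cu[-1] + n)
    return cu

# The "variable_lengths" case above: [8, 16, 12] -> [0, 8, 24, 36]
assert cu_seqlens_list([8, 16, 12]) == [0, 8, 24, 36]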
@@ -262,81 +179,28 @@ def test_flash_attention_gqa(dtype, head_dim):
    )

    # Compute ground truth with GQA
+    expected = compute_attention_reference(query, key, value, scale, gqa_ratio=gqa_ratio)

+    # Check results
+    rtol, atol = get_tolerance(dtype, head_dim)
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("softcapping_config", [
+    # (softcapping_value, seq_lengths, head_dim)
+    (1.0, [32], 64),  # No softcapping
+    (50.0, [32, 24], 64),  # Regular softcapping
+    (10.0, [16], 128),  # Strong softcapping
+    (1000.0, [16], 64),  # Very weak softcapping
+    (30.0, [48], 96),  # Medium softcapping
+])
+def test_flash_attention_softcapping(dtype, softcapping_config):
+    """Test Flash Attention with various softcapping values."""
    torch.manual_seed(42)

+    softcapping, seq_lengths, head_dim = softcapping_config
+    num_heads = 4

    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens(seq_lengths)
@@ -351,7 +215,7 @@ def test_flash_attention_large_head_dim(dtype):
    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

+    # Call Flash Attention with softcapping
    out = torch.empty_like(query)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
@@ -364,159 +228,87 @@ def test_flash_attention_large_head_dim(dtype):
        max_seqlen_k=max_seqlen,
        do_causal=False,
        scale=scale,
+        softcapping=softcapping,
    )

+    # Compute ground truth with softcapping
+    expected = torch.zeros_like(query)
+
+    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+        if end > start:
+            q_seq = query[start:end]
+            k_seq = key[start:end]
+            v_seq = value[start:end]

+            expected_seq = compute_attention_reference(
+                q_seq, k_seq, v_seq, scale, softcapping=softcapping
+            )
+            expected[start:end] = expected_seq

+    # Check results (higher tolerance for softcapping)
    if dtype == torch.bfloat16:
+        rtol, atol = 3e-2, 3e-2
    elif dtype == torch.float16:
+        rtol, atol = 2e-2, 2e-2
    else:
+        rtol, atol = 1e-2, 1e-2
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


+@pytest.mark.parametrize("large_seq_config", [
+    # (q_seq, k_seq, head_dim, dtype)
+    (32, 2048, 64, torch.float32),
+    (16, 1024, 96, torch.float16),
+    (64, 1536, 64, torch.bfloat16),
+])
+def test_flash_attention_large_sequences(large_seq_config):
+    """Test Flash Attention with large k sequences (>= 1024)."""
    torch.manual_seed(42)

+    q_seq, k_seq, head_dim, dtype = large_seq_config
    num_heads = 4

    # Create cumulative sequence lengths
+    cu_seqlens_q = create_cu_seqlens([q_seq])
+    cu_seqlens_k = create_cu_seqlens([k_seq])

    # Create input tensors
+    query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
+    key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
+    value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")

    # Scale factor
    scale = 1.0 / (head_dim ** 0.5)

+    # Call Flash Attention
    out = torch.empty_like(query)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=q_seq,
+        max_seqlen_k=k_seq,
+        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

+    # Compute ground truth
+    expected = compute_attention_reference(query, key, value, scale)

+    # Check results (higher tolerance for large sequences)
    if dtype == torch.bfloat16:
+        rtol, atol = 3e-2, 3e-2
    elif dtype == torch.float16:
+        rtol, atol = 5e-3, 5e-3
    else:
+        rtol, atol = 2e-3, 2e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


def test_flash_attention_edge_cases():
    """Test Flash Attention edge cases."""
    torch.manual_seed(42)
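For reference (not part of the diff): the softcapping branch in compute_attention_reference implements the softmax(tanh(q·kᵀ·scale / cap)·cap)·v formulation the removed test documented, which keeps every scaled logit strictly inside (-cap, cap) before the softmax. A small illustration with made-up logits:

import torch

cap = 50.0
logits = torch.tensor([10.0, 80.0, 200.0])  # arbitrary q·k^T * scale values
capped = torch.tanh(logits / cap) * cap     # roughly [9.9, 46.1, 50.0]
assert capped.abs().max() < cap
# With softcapping=1.0 both the kernel and the reference helper skip this transform.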
@@ -596,36 +388,13 @@ def test_flash_attention_unsupported_cases():
        softcapping=1.0,
    )

+    # Test 2: Wrong dtype for cu_seqlens (should be int32)
+    cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")
    query = torch.randn(16, 4, 64, device="mps")
    key = torch.randn(16, 4, 64, device="mps")
    value = torch.randn(16, 4, 64, device="mps")

    # This will silently fail (output will be unchanged)
    out = torch.full_like(query, -999.0)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
@@ -645,300 +414,6 @@ def test_flash_attention_unsupported_cases():
    assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"


def test_flash_attn_varlen_func():
    """Test the flash_attn_varlen_func compatibility function."""
    torch.manual_seed(42)
@@ -992,141 +467,3 @@ def test_flash_attn_varlen_func():

    assert out_causal.shape == q.shape
    assert out_causal.abs().max().item() > 0
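The suite runs under pytest as usual; for example, on an MPS-capable machine with metal_flash_sdpa installed:

pytest tests/test_flash_attention.py
pytest tests/test_flash_attention.py -k "comprehensive and bfloat16"
pytest tests/test_flash_attention.py -k gqa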