"""Tests for metal_flash_sdpa's variable-length Flash Attention kernels on the MPS backend."""

import pytest
import torch

import metal_flash_sdpa


def create_cu_seqlens(seq_lengths):
"""Create cumulative sequence lengths tensor.""" |
|
|
cu_seqlens = [0] |
|
|
for length in seq_lengths: |
|
|
cu_seqlens.append(cu_seqlens[-1] + length) |
|
|
return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps") |
|
|
|
|
|
|
|
|
def compute_attention_reference(query, key, value, scale, causal=False, softcapping=1.0, gqa_ratio=1): |
|
|
"""Compute reference attention output for validation.""" |
|
|
num_heads = query.shape[1] |
|
|
expected = torch.zeros_like(query) |
|
|
|
|
|
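    # Grouped-query attention: query head h reads key/value head h // gqa_ratio;
    # with gqa_ratio == 1 this reduces to standard multi-head attention.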
    for h in range(num_heads):
        kv_h = h // gqa_ratio if gqa_ratio > 1 else h
        q_h = query[:, h, :]
        k_h = key[:, kv_h, :]
        v_h = value[:, kv_h, :]

        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale

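        # Logit soft-capping: squash raw scores into (-softcapping, softcapping) via
        # softcapping * tanh(scores / softcapping). softcapping == 1.0 means disabled.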
        if softcapping != 1.0:
            scores = scores / softcapping
            scores = torch.tanh(scores) * softcapping

        if causal:
            # The mask is square, so this reference assumes q_len == k_len for causal runs.
            seq_len = query.shape[0]
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
            scores.masked_fill_(causal_mask, float("-inf"))

        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)

    return expected


def get_tolerance(dtype, head_dim):
"""Get appropriate tolerance based on dtype and head dimension.""" |
|
|
if dtype == torch.bfloat16: |
|
|
return (2e-2, 2e-2) if head_dim >= 96 else (1.6e-2, 1.6e-2) |
|
|
elif dtype == torch.float16: |
|
|
return (2e-3, 2e-3) |
|
|
else: |
|
|
return (1e-3, 1e-3) |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) |
|
|
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256]) |
|
|
@pytest.mark.parametrize("seq_config", [ |
|
|
|
|
|
([32], [32], "single_sequence"), |
|
|
([8, 16, 12], [10, 20, 15], "variable_lengths"), |
|
|
([16, 24], [16, 24], "multiple_sequences"), |
|
|
([2], [2], "small_sequence_2"), |
|
|
([4], [4], "small_sequence_4"), |
|
|
([8], [8], "small_sequence_8"), |
|
|
([16], [32], "cross_attention_q_lt_k"), |
|
|
([32], [16], "cross_attention_q_gt_k"), |
|
|
([8], [128], "cross_attention_large_diff"), |
|
|
([1], [64], "single_query_token"), |
|
|
]) |
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
def test_flash_attention_comprehensive(dtype, head_dim, seq_config, causal): |
|
|
"""Comprehensive test for Flash Attention with various configurations.""" |
|
|
torch.manual_seed(42) |
|
|
|
|
|
seq_lengths_q, seq_lengths_k, _ = seq_config |
|
|
|
|
|
|
|
|
if causal and seq_lengths_q != seq_lengths_k: |
|
|
pytest.skip("Causal attention only valid when q_seq == k_seq") |
|
|
|
|
|
|
|
|
num_heads = 4 |
|
|
|
|
|
|
|
|
cu_seqlens_q = create_cu_seqlens(seq_lengths_q) |
|
|
cu_seqlens_k = create_cu_seqlens(seq_lengths_k) |
|
|
|
|
|
total_q = sum(seq_lengths_q) |
|
|
total_k = sum(seq_lengths_k) |
|
|
max_seqlen_q = max(seq_lengths_q) |
|
|
max_seqlen_k = max(seq_lengths_k) |
|
|
|
|
|
|
|
|
query = torch.randn(total_q, num_heads, head_dim, dtype=dtype, device="mps") |
|
|
key = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps") |
|
|
value = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps") |
|
|
|
|
|
|
|
|
scale = 1.0 / (head_dim ** 0.5) |
|
|
|
|
|
|
|
|
out = torch.empty_like(query) |
|
|
metal_flash_sdpa.flash_attention_varlen( |
|
|
out=out, |
|
|
query=query, |
|
|
key=key, |
|
|
value=value, |
|
|
cu_seqlens_q=cu_seqlens_q, |
|
|
cu_seqlens_k=cu_seqlens_k, |
|
|
max_seqlen_q=max_seqlen_q, |
|
|
max_seqlen_k=max_seqlen_k, |
|
|
do_causal=causal, |
|
|
scale=scale, |
|
|
softcapping=1.0, |
|
|
) |
|
|
|
|
|
|
|
|
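    # Sequences are packed along dim 0, so the reference is computed slice by slice:
    # tokens from one cu_seqlens segment must not attend to tokens from another.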
    expected = torch.zeros_like(out)
    batch_size = len(seq_lengths_q)

    for i in range(batch_size):
        q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
        k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()

        if q_end > q_start and k_end > k_start:
            q_i = query[q_start:q_end]
            k_i = key[k_start:k_end]
            v_i = value[k_start:k_end]

            expected_i = compute_attention_reference(q_i, k_i, v_i, scale, causal=causal)
            expected[q_start:q_end] = expected_i

    rtol, atol = get_tolerance(dtype, head_dim)
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256]) |
|
|
@pytest.mark.parametrize("gqa_config", [ |
|
|
|
|
|
(8, 2, 32), |
|
|
(16, 4, 32), |
|
|
(16, 8, 32), |
|
|
(16, 2, 32), |
|
|
(16, 4, 128), |
|
|
]) |
|
|
def test_flash_attention_gqa(dtype, head_dim, gqa_config): |
|
|
"""Test Flash Attention with Grouped Query Attention configurations.""" |
|
|
torch.manual_seed(42) |
|
|
|
|
|
num_heads, num_kv_heads, seq_len = gqa_config |
|
|
gqa_ratio = num_heads // num_kv_heads |
|
|
|
|
|
|
|
|
cu_seqlens = create_cu_seqlens([seq_len]) |
|
|
|
|
|
|
|
|
query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps") |
|
|
key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps") |
|
|
value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps") |
|
|
|
|
|
|
|
|
scale = 1.0 / (head_dim ** 0.5) |
|
|
|
|
|
|
|
|
out = torch.empty_like(query) |
|
|
metal_flash_sdpa.flash_attention_varlen( |
|
|
out=out, |
|
|
query=query, |
|
|
key=key, |
|
|
value=value, |
|
|
cu_seqlens_q=cu_seqlens, |
|
|
cu_seqlens_k=cu_seqlens, |
|
|
max_seqlen_q=seq_len, |
|
|
max_seqlen_k=seq_len, |
|
|
do_causal=False, |
|
|
scale=scale, |
|
|
softcapping=1.0, |
|
|
) |
|
|
|
|
|
|
|
|
expected = compute_attention_reference(query, key, value, scale, gqa_ratio=gqa_ratio) |
|
|
|
|
|
|
|
|
rtol, atol = get_tolerance(dtype, head_dim) |
|
|
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol) |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) |
|
|
@pytest.mark.parametrize("softcapping_config", [ |
|
|
|
|
|
(1.0, [32], 64), |
|
|
(50.0, [32, 24], 64), |
|
|
(10.0, [16], 128), |
|
|
(1000.0, [16], 64), |
|
|
(30.0, [48], 96), |
|
|
]) |
|
|
def test_flash_attention_softcapping(dtype, softcapping_config): |
|
|
"""Test Flash Attention with various softcapping values.""" |
|
|
torch.manual_seed(42) |
|
|
|
|
|
softcapping, seq_lengths, head_dim = softcapping_config |
|
|
num_heads = 4 |
|
|
|
|
|
|
|
|
cu_seqlens = create_cu_seqlens(seq_lengths) |
|
|
total_tokens = sum(seq_lengths) |
|
|
max_seqlen = max(seq_lengths) |
|
|
|
|
|
|
|
|
query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps") |
|
|
key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps") |
|
|
value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps") |
|
|
|
|
|
|
|
|
scale = 1.0 / (head_dim ** 0.5) |
|
|
|
|
|
|
|
|
out = torch.empty_like(query) |
|
|
metal_flash_sdpa.flash_attention_varlen( |
|
|
out=out, |
|
|
query=query, |
|
|
key=key, |
|
|
value=value, |
|
|
cu_seqlens_q=cu_seqlens, |
|
|
cu_seqlens_k=cu_seqlens, |
|
|
max_seqlen_q=max_seqlen, |
|
|
max_seqlen_k=max_seqlen, |
|
|
do_causal=False, |
|
|
scale=scale, |
|
|
softcapping=softcapping, |
|
|
) |
|
|
|
|
|
|
|
|
expected = torch.zeros_like(query) |
|
|
|
|
|
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]): |
|
|
if end > start: |
|
|
q_seq = query[start:end] |
|
|
k_seq = key[start:end] |
|
|
v_seq = value[start:end] |
|
|
|
|
|
expected_seq = compute_attention_reference( |
|
|
q_seq, k_seq, v_seq, scale, softcapping=softcapping |
|
|
) |
|
|
expected[start:end] = expected_seq |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
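    # Soft-capped configurations (capping value < 50) are compared with looser
    # tolerances than get_tolerance() would give, presumably because the tanh
    # squashing amplifies small numeric differences between kernel and reference.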
    if dtype == torch.bfloat16:
        rtol, atol = (1.5e-1, 1.5e-1) if softcapping < 50 else (3e-2, 3e-2)
    elif dtype == torch.float16:
        rtol, atol = (1e-1, 1e-1) if softcapping < 50 else (2e-2, 2e-2)
    else:
        rtol, atol = (1.5e-1, 1.5e-1) if softcapping < 50 else (1e-2, 1e-2)
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


@pytest.mark.parametrize("large_seq_config", [
    # (q_seq, k_seq, head_dim, dtype)
    (32, 2048, 64, torch.float32),
    (16, 1024, 96, torch.float16),
    (64, 1536, 64, torch.bfloat16),
])
def test_flash_attention_large_sequences(large_seq_config):
    """Test Flash Attention with large key sequences (>= 1024 tokens)."""
    torch.manual_seed(42)

    q_seq, k_seq, head_dim, dtype = large_seq_config
    num_heads = 4

    cu_seqlens_q = create_cu_seqlens([q_seq])
    cu_seqlens_k = create_cu_seqlens([k_seq])

    query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
    key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
    value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim ** 0.5)

    out = torch.empty_like(query)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=q_seq,
        max_seqlen_k=k_seq,
        do_causal=False,
        scale=scale,
        softcapping=1.0,
    )

    expected = compute_attention_reference(query, key, value, scale)

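    # Tolerances below are looser than get_tolerance() for the same dtypes,
    # presumably to absorb extra rounding error from reducing over 1024+ key positions.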
    if dtype == torch.bfloat16:
        rtol, atol = 3e-2, 3e-2
    elif dtype == torch.float16:
        rtol, atol = 5e-3, 5e-3
    else:
        rtol, atol = 2e-3, 2e-3
    torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)


def test_flash_attention_edge_cases():
    """Test Flash Attention edge cases."""
    torch.manual_seed(42)

    # Single query/key token: softmax over a single score is 1, so the output must
    # equal the value row (up to rounding).
    query = torch.randn(1, 1, 64, device="mps")
    key = torch.randn(1, 1, 64, device="mps")
    value = torch.randn(1, 1, 64, device="mps")
    cu_seqlens = create_cu_seqlens([1])
    out = torch.empty_like(query)

    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=1,
        max_seqlen_k=1,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )

    torch.testing.assert_close(out, value, rtol=1e-5, atol=1e-5)

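    # A zero-length sequence in the middle of the batch: the call should complete
    # without raising; the empty slice contributes no output rows to check.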
    seq_lengths = [8, 0, 12]
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)

    query = torch.randn(total_tokens, 4, 64, device="mps")
    key = torch.randn(total_tokens, 4, 64, device="mps")
    value = torch.randn(total_tokens, 4, 64, device="mps")
    out = torch.empty_like(query)

    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lengths) if seq_lengths else 0,
        max_seqlen_k=max(seq_lengths) if seq_lengths else 0,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )


def test_flash_attention_unsupported_cases():
    """Test that unsupported inputs either raise an error or leave the output untouched."""
    # head_dim == 48 is not a supported head dimension and should raise.
    query = torch.randn(16, 4, 48, device="mps")
    key = torch.randn(16, 4, 48, device="mps")
    value = torch.randn(16, 4, 48, device="mps")
    cu_seqlens = create_cu_seqlens([16])
    out = torch.empty_like(query)

    with pytest.raises(RuntimeError, match="Head dimension .* is not supported"):
        metal_flash_sdpa.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=16,
            max_seqlen_k=16,
            do_causal=False,
            scale=0.144,
            softcapping=1.0,
        )

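    # cu_seqlens passed as int64 instead of the expected int32: per the assertion
    # below, the kernel should not launch, leaving `out` at its sentinel value.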
    cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")
    query = torch.randn(16, 4, 64, device="mps")
    key = torch.randn(16, 4, 64, device="mps")
    value = torch.randn(16, 4, 64, device="mps")

    out = torch.full_like(query, -999.0)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_wrong,
        cu_seqlens_k=cu_seqlens_wrong,
        max_seqlen_q=16,
        max_seqlen_k=16,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )

    assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"


def test_flash_attn_varlen_func():
    """Test the flash_attn_varlen_func compatibility function."""
    torch.manual_seed(42)

    seq_lengths = [8, 12]
    num_heads = 4
    head_dim = 64

    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)

    q = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    k = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    v = torch.randn(total_tokens, num_heads, head_dim, device="mps")

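    # softmax_scale=None leaves the scale choice to the wrapper (in flash-attn style
    # APIs this typically defaults to 1 / sqrt(head_dim)).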
    out = metal_flash_sdpa.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
    )

    assert out.shape == q.shape
    assert out.abs().max().item() > 0

    # Same inputs with an explicit scale and causal masking enabled.
    out_causal = metal_flash_sdpa.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=0.125,
        causal=True,
    )

    assert out_causal.shape == q.shape
    assert out_causal.abs().max().item() > 0