# metal-flash-sdpa/tests/test_flash_attention.py
import pytest
import torch

import metal_flash_sdpa
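

# Layout note: these tests use the packed "varlen" layout. Query/key/value tensors are
# shaped (total_tokens, num_heads, head_dim) on the MPS device, and cu_seqlens_q /
# cu_seqlens_k are int32 prefix sums marking where each packed sequence starts and ends.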


def create_cu_seqlens(seq_lengths):
    """Create cumulative sequence lengths tensor."""
    cu_seqlens = [0]
    for length in seq_lengths:
        cu_seqlens.append(cu_seqlens[-1] + length)
    return torch.tensor(cu_seqlens, dtype=torch.int32, device="mps")
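
# Example: create_cu_seqlens([8, 16, 12]) -> tensor([0, 8, 24, 36], dtype=torch.int32),
# i.e. prefix sums giving each packed sequence's start and end offsets.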


def compute_attention_reference(query, key, value, scale, causal=False, softcapping=1.0, gqa_ratio=1):
    """Compute reference attention output for validation."""
    num_heads = query.shape[1]
    expected = torch.zeros_like(query)
    for h in range(num_heads):
        kv_h = h // gqa_ratio if gqa_ratio > 1 else h
        q_h = query[:, h, :]
        k_h = key[:, kv_h, :]
        v_h = value[:, kv_h, :]
        scores = torch.matmul(q_h, k_h.transpose(-1, -2)) * scale
        # Apply softcapping if not 1.0
        if softcapping != 1.0:
            scores = scores / softcapping
            scores = torch.tanh(scores) * softcapping
        # Apply causal mask if needed
        if causal:
            seq_len = query.shape[0]
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
            scores.masked_fill_(causal_mask, float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        expected[:, h, :] = torch.matmul(attn_weights, v_h)
    return expected
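

# Optional vectorized cross-check (sketch): assumes torch.nn.functional.scaled_dot_product_attention
# supports the `scale` keyword (PyTorch >= 2.1). Covers only the non-GQA, softcapping == 1.0
# case, treating heads as the batch dimension.
def compute_attention_sdpa(query, key, value, scale, causal=False):
    """Reference attention via torch SDPA for the non-GQA, no-softcapping case."""
    q = query.transpose(0, 1)  # (num_heads, seq_q, head_dim)
    k = key.transpose(0, 1)    # (num_heads, seq_k, head_dim)
    v = value.transpose(0, 1)  # (num_heads, seq_k, head_dim)
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=causal, scale=scale
    )
    return out.transpose(0, 1)  # back to (seq_q, num_heads, head_dim)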


def get_tolerance(dtype, head_dim):
    """Get appropriate (rtol, atol) tolerance based on dtype and head dimension."""
    if dtype == torch.bfloat16:
        return (2e-2, 2e-2) if head_dim >= 96 else (1.6e-2, 1.6e-2)
    elif dtype == torch.float16:
        return (2e-3, 2e-3)
    else:
        return (1e-3, 1e-3)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
@pytest.mark.parametrize("seq_config", [
# (seq_lengths_q, seq_lengths_k, description)
([32], [32], "single_sequence"),
([8, 16, 12], [10, 20, 15], "variable_lengths"),
([16, 24], [16, 24], "multiple_sequences"),
([2], [2], "small_sequence_2"),
([4], [4], "small_sequence_4"),
([8], [8], "small_sequence_8"),
([16], [32], "cross_attention_q_lt_k"),
([32], [16], "cross_attention_q_gt_k"),
([8], [128], "cross_attention_large_diff"),
([1], [64], "single_query_token"),
])
@pytest.mark.parametrize("causal", [False, True])
def test_flash_attention_comprehensive(dtype, head_dim, seq_config, causal):
"""Comprehensive test for Flash Attention with various configurations."""
torch.manual_seed(42)
seq_lengths_q, seq_lengths_k, _ = seq_config
# Skip causal tests for cross-attention cases
if causal and seq_lengths_q != seq_lengths_k:
pytest.skip("Causal attention only valid when q_seq == k_seq")
# Test parameters
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens_q = create_cu_seqlens(seq_lengths_q)
cu_seqlens_k = create_cu_seqlens(seq_lengths_k)
total_q = sum(seq_lengths_q)
total_k = sum(seq_lengths_k)
max_seqlen_q = max(seq_lengths_q)
max_seqlen_k = max(seq_lengths_k)
# Create input tensors
query = torch.randn(total_q, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(total_k, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
metal_flash_sdpa.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
do_causal=causal,
scale=scale,
softcapping=1.0,
)
# Compute ground truth for each sequence
expected = torch.zeros_like(out)
batch_size = len(seq_lengths_q)
for i in range(batch_size):
q_start, q_end = cu_seqlens_q[i].item(), cu_seqlens_q[i+1].item()
k_start, k_end = cu_seqlens_k[i].item(), cu_seqlens_k[i+1].item()
if q_end > q_start and k_end > k_start: # Skip empty sequences
q_i = query[q_start:q_end]
k_i = key[k_start:k_end]
v_i = value[k_start:k_end]
expected_i = compute_attention_reference(q_i, k_i, v_i, scale, causal=causal)
expected[q_start:q_end] = expected_i
# Check results
rtol, atol = get_tolerance(dtype, head_dim)
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("head_dim", [32, 64, 72, 80, 96, 128, 256])
@pytest.mark.parametrize("gqa_config", [
# (num_heads, num_kv_heads, seq_len)
(8, 2, 32), # 4:1 ratio
(16, 4, 32), # 4:1 ratio
(16, 8, 32), # 2:1 ratio
(16, 2, 32), # 8:1 ratio
(16, 4, 128), # 4:1 ratio with larger sequence
])
def test_flash_attention_gqa(dtype, head_dim, gqa_config):
"""Test Flash Attention with Grouped Query Attention configurations."""
torch.manual_seed(42)
num_heads, num_kv_heads, seq_len = gqa_config
gqa_ratio = num_heads // num_kv_heads
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens([seq_len])
# Create input tensors
query = torch.randn(seq_len, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(seq_len, num_kv_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
metal_flash_sdpa.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth with GQA
expected = compute_attention_reference(query, key, value, scale, gqa_ratio=gqa_ratio)
# Check results
rtol, atol = get_tolerance(dtype, head_dim)
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
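

# Illustrative helper (sketch): grouped-query attention can equivalently be reduced to
# standard multi-head attention by expanding each KV head gqa_ratio times, which mirrors
# the kv_h = h // gqa_ratio mapping used in compute_attention_reference.
def expand_kv_for_gqa(key, value, gqa_ratio):
    """Repeat each KV head gqa_ratio times along the head dimension."""
    return (
        key.repeat_interleave(gqa_ratio, dim=1),
        value.repeat_interleave(gqa_ratio, dim=1),
    )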
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("softcapping_config", [
# (softcapping_value, seq_lengths, head_dim)
(1.0, [32], 64), # No softcapping
(50.0, [32, 24], 64), # Regular softcapping
(10.0, [16], 128), # Strong softcapping
(1000.0, [16], 64), # Very weak softcapping
(30.0, [48], 96), # Medium softcapping
])
def test_flash_attention_softcapping(dtype, softcapping_config):
"""Test Flash Attention with various softcapping values."""
torch.manual_seed(42)
softcapping, seq_lengths, head_dim = softcapping_config
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens = create_cu_seqlens(seq_lengths)
total_tokens = sum(seq_lengths)
max_seqlen = max(seq_lengths)
# Create input tensors
query = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention with softcapping
out = torch.empty_like(query)
metal_flash_sdpa.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
do_causal=False,
scale=scale,
softcapping=softcapping,
)
# Compute ground truth with softcapping
expected = torch.zeros_like(query)
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
if end > start:
q_seq = query[start:end]
k_seq = key[start:end]
v_seq = value[start:end]
expected_seq = compute_attention_reference(
q_seq, k_seq, v_seq, scale, softcapping=softcapping
)
expected[start:end] = expected_seq
# Check results (higher tolerance for softcapping)
# Note: Softcapping with strong values (< 50) has higher error due to
# the interaction between tanh transformation and exp2-based softmax
if dtype == torch.bfloat16:
if softcapping < 50:
rtol, atol = 1.5e-1, 1.5e-1 # Higher tolerance for strong softcapping
else:
rtol, atol = 3e-2, 3e-2
elif dtype == torch.float16:
if softcapping < 50:
rtol, atol = 1e-1, 1e-1
else:
rtol, atol = 2e-2, 2e-2
else:
if softcapping < 50:
rtol, atol = 1.5e-1, 1.5e-1 # Higher tolerance for strong softcapping with float32
else:
rtol, atol = 1e-2, 1e-2
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
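
# Worked example of the softcap transform used above: with softcapping = 10, a raw
# score of 30 becomes 10 * tanh(30 / 10) ≈ 9.95, so all logits are squashed into
# the interval (-softcapping, softcapping) before the softmax.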
@pytest.mark.parametrize("large_seq_config", [
# (q_seq, k_seq, head_dim, dtype)
(32, 2048, 64, torch.float32),
(16, 1024, 96, torch.float16),
(64, 1536, 64, torch.bfloat16),
])
def test_flash_attention_large_sequences(large_seq_config):
"""Test Flash Attention with large k sequences (>= 1024)."""
torch.manual_seed(42)
q_seq, k_seq, head_dim, dtype = large_seq_config
num_heads = 4
# Create cumulative sequence lengths
cu_seqlens_q = create_cu_seqlens([q_seq])
cu_seqlens_k = create_cu_seqlens([k_seq])
# Create input tensors
query = torch.randn(q_seq, num_heads, head_dim, dtype=dtype, device="mps")
key = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
value = torch.randn(k_seq, num_heads, head_dim, dtype=dtype, device="mps")
# Scale factor
scale = 1.0 / (head_dim ** 0.5)
# Call Flash Attention
out = torch.empty_like(query)
metal_flash_sdpa.flash_attention_varlen(
out=out,
query=query,
key=key,
value=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_seq,
max_seqlen_k=k_seq,
do_causal=False,
scale=scale,
softcapping=1.0,
)
# Compute ground truth
expected = compute_attention_reference(query, key, value, scale)
# Check results (higher tolerance for large sequences)
if dtype == torch.bfloat16:
rtol, atol = 3e-2, 3e-2
elif dtype == torch.float16:
rtol, atol = 5e-3, 5e-3
else:
rtol, atol = 2e-3, 2e-3
torch.testing.assert_close(out, expected, rtol=rtol, atol=atol)
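
# The looser tolerances above reflect that longer key sequences mean larger reductions
# in the softmax normalizer and the value-weighted sum, which accumulate more rounding
# error, especially in the half-precision configurations.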


def test_flash_attention_edge_cases():
    """Test Flash Attention edge cases."""
    torch.manual_seed(42)
    # Test 1: Single token sequence. With a single key, the softmax weight is exactly 1.0,
    # so the output must equal the value row regardless of the scale.
    query = torch.randn(1, 1, 64, device="mps")
    key = torch.randn(1, 1, 64, device="mps")
    value = torch.randn(1, 1, 64, device="mps")
    cu_seqlens = create_cu_seqlens([1])
    out = torch.empty_like(query)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=1,
        max_seqlen_k=1,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )
    # With a single token, the output should equal the value
    torch.testing.assert_close(out, value, rtol=1e-5, atol=1e-5)
    # Test 2: Empty sequence in batch
    seq_lengths = [8, 0, 12]  # Middle sequence is empty
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    query = torch.randn(total_tokens, 4, 64, device="mps")
    key = torch.randn(total_tokens, 4, 64, device="mps")
    value = torch.randn(total_tokens, 4, 64, device="mps")
    out = torch.empty_like(query)
    # This should handle empty sequences gracefully
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lengths) if seq_lengths else 0,
        max_seqlen_k=max(seq_lengths) if seq_lengths else 0,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )


def test_flash_attention_unsupported_cases():
    """Test that unsupported cases raise appropriate errors."""
    # Test 1: Unsupported head dimension
    query = torch.randn(16, 4, 48, device="mps")  # head_dim = 48 (not supported)
    key = torch.randn(16, 4, 48, device="mps")
    value = torch.randn(16, 4, 48, device="mps")
    cu_seqlens = create_cu_seqlens([16])
    out = torch.empty_like(query)
    with pytest.raises(RuntimeError, match="Head dimension .* is not supported"):
        metal_flash_sdpa.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=16,
            max_seqlen_k=16,
            do_causal=False,
            scale=0.144,
            softcapping=1.0,
        )
    # Test 2: Wrong dtype for cu_seqlens (should be int32)
    cu_seqlens_wrong = torch.tensor([0, 16], dtype=torch.int64, device="mps")
    query = torch.randn(16, 4, 64, device="mps")
    key = torch.randn(16, 4, 64, device="mps")
    value = torch.randn(16, 4, 64, device="mps")
    # This fails silently: the kernel does not run and the output is left unchanged
    out = torch.full_like(query, -999.0)
    metal_flash_sdpa.flash_attention_varlen(
        out=out,
        query=query,
        key=key,
        value=value,
        cu_seqlens_q=cu_seqlens_wrong,
        cu_seqlens_k=cu_seqlens_wrong,
        max_seqlen_q=16,
        max_seqlen_k=16,
        do_causal=False,
        scale=0.125,
        softcapping=1.0,
    )
    # Check that the output wasn't modified (kernel didn't run)
    assert (out == -999.0).all(), "cu_seqlens with wrong dtype should cause kernel to not run"
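
# Note (sketch): callers can avoid the silent no-op exercised in Test 2 by converting the
# offsets up front, e.g. cu_seqlens = cu_seqlens.to(torch.int32), matching the int32
# dtype produced by create_cu_seqlens.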


def test_flash_attn_varlen_func():
    """Test the flash_attn_varlen_func compatibility function."""
    torch.manual_seed(42)
    # Test dimensions
    seq_lengths = [8, 12]
    num_heads = 4
    head_dim = 64
    # Create cumulative sequence lengths
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)
    max_seqlen = max(seq_lengths)
    # Create input tensors
    q = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    k = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    v = torch.randn(total_tokens, num_heads, head_dim, device="mps")
    # Call the compatibility function
    out = metal_flash_sdpa.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=None,  # Will use 1/sqrt(head_dim)
        causal=False,
    )
    # Check that the output has the correct shape and is not all zeros
    assert out.shape == q.shape
    assert out.abs().max().item() > 0
    # Test with causal
    out_causal = metal_flash_sdpa.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=0.125,
        causal=True,
    )
    assert out_causal.shape == q.shape
    assert out_causal.abs().max().item() > 0
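
    # Numerical cross-check (sketch, assuming the compatibility wrapper dispatches to the
    # same varlen kernel and that softmax_scale=None means 1/sqrt(head_dim)): the
    # non-causal output above should match the per-sequence reference in float32.
    expected = torch.zeros_like(q)
    for i in range(len(seq_lengths)):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        expected[start:end] = compute_attention_reference(
            q[start:end], k[start:end], v[start:end], 1.0 / (head_dim ** 0.5)
        )
    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)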