import copy
import logging
import time
from contextlib import nullcontext
import pytest
import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from torch.distributed.tensor import DTensor, Replicate
from torch.profiler import ProfilerActivity, profile
from .utils import (ParallelDims, assert_params_equal, parallelize_motif,
parallelize_qk_logits)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def apply_muon_step(
model: torch.nn.Module,
parallel_dims: ParallelDims | None,
grads: list[torch.Tensor],
warmup_step: int,
chunk_size: int,
qk_logits: dict[int, torch.Tensor] | None = None,
use_distributed_muon: bool = False,
measure_perf: bool = False,
do_profile: bool = False,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
""" apply single Muon step with optional QK clipping """
# 1. Apply gradients to model parameters
assert len(grads) == len(list(model.parameters()))
for grad, param in zip(grads, model.parameters()):
grad = grad.to(param.device)
if isinstance(param.data, DTensor):
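            # The provided gradient is an unsharded tensor: wrap it as a fully
            # replicated DTensor on the parameter's mesh, then redistribute it
            # to the parameter's own placements so it matches the local shard.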
unsharded_grad = DTensor.from_local(
grad,
device_mesh=param.data.device_mesh,
placements=[Replicate()] * param.data.device_mesh.ndim,
)
sharded_grad = unsharded_grad.redistribute(
device_mesh=param.data.device_mesh,
placements=param.data.placements)
param.grad = sharded_grad
else:
param.grad = grad
# 2. Setup Muon optimizer
params = get_default_muon_param_groups(model)
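    # QK-clip settings covering every attention head with a 0.5 threshold;
    # only forwarded to Muon below when qk_logits is provided.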
    clip_config = {
        "q_indices": list(range(model.config.num_attention_heads)),
        "k_indices": list(range(model.config.num_attention_heads)),
        "head_dim": model.config.hidden_size // model.config.num_attention_heads,
        "threshold": 0.5,
    }
optim = Muon(
params=params,
clip_config=clip_config if qk_logits is not None else None,
none_grad=False,
warmup_step=warmup_step,
chunk_size=chunk_size,
use_distributed_muon=use_distributed_muon,
)
optim.step(qk_logits=qk_logits)
timing_result: tuple[float, float] | None = None
if measure_perf:
# extra warm up
optim.step(qk_logits=qk_logits)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
num_iters = 20
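        # Snapshot the current allocation so the reported peak memory is
        # measured relative to this point.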
current_mem = torch.cuda.memory_allocated()
if do_profile:
context = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True)
else:
context = nullcontext()
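        # nullcontext() yields None, so `prof` stays None (and no trace is
        # exported) when profiling is disabled.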
with context as prof:
for _i in range(num_iters):
optim.step(qk_logits=qk_logits)
end.record()
end.synchronize()
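        # Block until the end event (and all preceding optimizer kernels) has
        # completed, so the elapsed time read below is accurate.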
if prof is not None and dist.get_rank() == 0:
date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
profile_name = "trace"
profile_name += f"_{date}"
profile_name += f"_{parallel_dims}"
profile_name += f"_{chunk_size}"
profile_name += f"_{warmup_step}"
profile_name += f"_{qk_logits is not None}"
profile_name += f"_{use_distributed_muon}"
prof.export_chrome_trace(f"{profile_name}.json")
peak_memory = torch.cuda.max_memory_allocated() - current_mem
elapsed_time_ms = start.elapsed_time(end) / num_iters
timing_result = (elapsed_time_ms, peak_memory)
return model, timing_result
@pytest.fixture(scope="session")
def sequential_muon_result(
skip_verify, # from conftest.py
inputs # from conftest.py
) -> dict[bool, torch.nn.Module] | None:
"""Run Muon optimizer to sequential model for baseline results."""
if skip_verify:
logger.info("Skipping verification tests as per user request")
return None
model, grads, qk_logits = inputs
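    # Build two baselines keyed by apply_qk_clip: one plain Muon step and one
    # with QK clipping enabled via qk_logits.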
result = apply_muon_step(
model=copy.deepcopy(model).cuda(),
parallel_dims=None,
grads=grads,
warmup_step=-1,
chunk_size=-1,
qk_logits=None,
)[0].cpu()
result_qk_clip = apply_muon_step(
model=copy.deepcopy(model).cuda(),
parallel_dims=None,
grads=grads,
warmup_step=-1,
chunk_size=-1,
qk_logits=qk_logits,
)[0].cpu()
return {
False: result,
True: result_qk_clip,
}
OVERLAP_STEPS = [5]
CHUNK_SIZES = [8]
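# Every ParallelDims configuration below multiplies out to a world size of 8;
# judging by the test ids, the three axes appear to correspond to replicated
# data parallel, sharded data parallel (FSDP), and tensor parallel degrees.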
@pytest.mark.parametrize("parallel_dims", [
pytest.param(ParallelDims(8, 1, 1), id="base"),
pytest.param(ParallelDims(1, 8, 1), id="fsdp"),
pytest.param(ParallelDims(2, 4, 1), id="hsdp"),
pytest.param(ParallelDims(1, 1, 8), id="tp"),
pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"),
pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"),
])
@pytest.mark.parametrize("apply_qk_clip", [False, True])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
def test_parallel_muon(
request,
sequential_muon_result: dict[bool, torch.nn.Module],
parallel_dims: ParallelDims,
apply_qk_clip: bool,
use_distributed_muon: bool,
warmup_step: int,
chunk_size: int,
inputs: tuple[torch.nn.Module, list[torch.Tensor],
dict[int, torch.Tensor]], # from conftest.py
measure_perf, # from conftest.py
do_profile, # from conftest.py
) -> None:
if use_distributed_muon and chunk_size != CHUNK_SIZES[0]:
pytest.skip("Distributed Muon does not effected by chunk size")
if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]:
pytest.skip("Distributed Muon does not effected by warmup step")
model, grads, qk_logits = inputs
if not apply_qk_clip:
qk_logits = None
# Deepcopy the model to avoid in-place modification
model = copy.deepcopy(model).cuda()
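    # Apply the parallelism layout described by parallel_dims to the model
    # (see parallelize_motif in .utils).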
parallelized_model = parallelize_motif(model, parallel_dims)
if qk_logits is not None:
# Deepcopy the qk logits to avoid in-place modification
qk_logits = copy.deepcopy(qk_logits)
qk_logits = parallelize_qk_logits(qk_logits, parallel_dims)
parallelized_model, timing_result = apply_muon_step(
model=parallelized_model,
parallel_dims=parallel_dims,
grads=grads,
warmup_step=warmup_step,
chunk_size=chunk_size,
qk_logits=qk_logits,
use_distributed_muon=use_distributed_muon,
measure_perf=measure_perf,
do_profile=do_profile,
)
if measure_perf:
assert timing_result is not None
avg_time_ms, peak_memory = timing_result
logger.info(
f"\nParallel dims: {parallel_dims}, "
f"\nUse distributed Muon: {use_distributed_muon}, "
f"\nApply QK clip: {apply_qk_clip} => "
f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):"
f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, {peak_memory / (1024**2):.2f},"
)
if sequential_muon_result is None:
logger.info("Skipping correctness check as sequential result is None")
elif measure_perf:
logger.info("Skipping correctness check as timing is enabled")
else:
assert_params_equal(parallelized_model,
sequential_muon_result[apply_qk_clip])