import copy
import logging
import time
from contextlib import nullcontext

import pytest
import torch
import torch.distributed as dist
from optimizer.muon import Muon, get_default_muon_param_groups
from torch.distributed.tensor import DTensor, Replicate
from torch.profiler import ProfilerActivity, profile

from .utils import (ParallelDims, assert_params_equal, parallelize_motif,
                    parallelize_qk_logits)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def apply_muon_step(
    model: torch.nn.Module,
    parallel_dims: ParallelDims | None,
    grads: list[torch.Tensor],
    warmup_step: int,
    chunk_size: int,
    qk_logits: dict[int, torch.Tensor] | None = None,
    use_distributed_muon: bool = False,
    measure_perf: bool = False,
    do_profile: bool = False,
) -> tuple[torch.nn.Module, tuple[float, float] | None]:
    """Apply a single Muon step with optional QK clipping."""
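
    # Load the provided gradients onto the model. For DTensor parameters, the
    # full gradient is wrapped as a replicated DTensor on the parameter's
    # device mesh and redistributed to match the parameter's placements.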
    assert len(grads) == len(list(model.parameters()))
    for grad, param in zip(grads, model.parameters()):
        grad = grad.to(param.device)
        if isinstance(param.data, DTensor):
            unsharded_grad = DTensor.from_local(
                grad,
                device_mesh=param.data.device_mesh,
                placements=[Replicate()] * param.data.device_mesh.ndim,
            )
            sharded_grad = unsharded_grad.redistribute(
                device_mesh=param.data.device_mesh,
                placements=param.data.placements)
            param.grad = sharded_grad
        else:
            param.grad = grad
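
    # Configure the optimizer. The QK-clip config targets every attention head
    # with a threshold of 0.5 and is only passed when qk_logits are provided.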
    params = get_default_muon_param_groups(model)
    clip_config = {
        "q_indices": list(range(model.config.num_attention_heads)),
        "k_indices": list(range(model.config.num_attention_heads)),
        "head_dim": model.config.hidden_size // model.config.num_attention_heads,
        "threshold": 0.5,
    }
    optim = Muon(
        params=params,
        clip_config=clip_config if qk_logits is not None else None,
        none_grad=False,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        use_distributed_muon=use_distributed_muon,
    )

    optim.step(qk_logits=qk_logits)

    timing_result: tuple[float, float] | None = None
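
    # Optionally measure the average per-step latency with CUDA events and the
    # peak memory consumed by the timed steps, optionally exporting a
    # torch.profiler trace.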
    if measure_perf:
        # One warm-up step so one-time setup cost is excluded from the timing.
        optim.step(qk_logits=qk_logits)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        num_iters = 20
        torch.cuda.reset_peak_memory_stats()
        current_mem = torch.cuda.memory_allocated()

        if do_profile:
            context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True)
        else:
            context = nullcontext()

        with context as prof:
            # Time only the optimizer steps; profiler start/stop stays outside
            # the region bracketed by the CUDA events.
            start.record()
            for _i in range(num_iters):
                optim.step(qk_logits=qk_logits)
            end.record()
        end.synchronize()
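
        # Export a Chrome trace from rank 0 only, encoding the test
        # configuration in the file name.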
        if prof is not None and dist.get_rank() == 0:
            date = time.strftime("%Y%m%d_%H%M%S", time.localtime())
            profile_name = "trace"
            profile_name += f"_{date}"
            profile_name += f"_{parallel_dims}"
            profile_name += f"_{chunk_size}"
            profile_name += f"_{warmup_step}"
            profile_name += f"_{qk_logits is not None}"
            profile_name += f"_{use_distributed_muon}"

            prof.export_chrome_trace(f"{profile_name}.json")

        peak_memory = torch.cuda.max_memory_allocated() - current_mem

        elapsed_time_ms = start.elapsed_time(end) / num_iters

        timing_result = (elapsed_time_ms, peak_memory)

    return model, timing_result


@pytest.fixture(scope="session")
def sequential_muon_result(
    skip_verify,
    inputs,
) -> dict[bool, torch.nn.Module] | None:
    """Run Muon on the unparallelized model to produce baseline results."""
    if skip_verify:
        logger.info("Skipping verification tests as per user request")
        return None

    model, grads, qk_logits = inputs

    result = apply_muon_step(
        model=copy.deepcopy(model).cuda(),
        parallel_dims=None,
        grads=grads,
        warmup_step=-1,
        chunk_size=-1,
        qk_logits=None,
    )[0].cpu()

    result_qk_clip = apply_muon_step(
        model=copy.deepcopy(model).cuda(),
        parallel_dims=None,
        grads=grads,
        warmup_step=-1,
        chunk_size=-1,
        qk_logits=qk_logits,
    )[0].cpu()

    return {
        False: result,
        True: result_qk_clip,
    }


OVERLAP_STEPS = [5]
CHUNK_SIZES = [8]
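

# The three ParallelDims factors multiply to 8 ranks in every case below; the
# ids label the parallelism scheme under test.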
@pytest.mark.parametrize("parallel_dims", [
    pytest.param(ParallelDims(8, 1, 1), id="base"),
    pytest.param(ParallelDims(1, 8, 1), id="fsdp"),
    pytest.param(ParallelDims(2, 4, 1), id="hsdp"),
    pytest.param(ParallelDims(1, 1, 8), id="tp"),
    pytest.param(ParallelDims(2, 2, 2), id="hsdp+tp"),
    pytest.param(ParallelDims(1, 2, 4), id="fsdp+tp"),
])
@pytest.mark.parametrize("apply_qk_clip", [False, True])
@pytest.mark.parametrize("use_distributed_muon", [False])
@pytest.mark.parametrize("warmup_step", OVERLAP_STEPS)
@pytest.mark.parametrize("chunk_size", CHUNK_SIZES)
def test_parallel_muon(
    request,
    sequential_muon_result: dict[bool, torch.nn.Module] | None,
    parallel_dims: ParallelDims,
    apply_qk_clip: bool,
    use_distributed_muon: bool,
    warmup_step: int,
    chunk_size: int,
    inputs: tuple[torch.nn.Module, list[torch.Tensor],
                  dict[int, torch.Tensor]],
    measure_perf,
    do_profile,
) -> None:
    if use_distributed_muon and chunk_size != CHUNK_SIZES[0]:
        pytest.skip("Distributed Muon is not affected by chunk size")
    if use_distributed_muon and warmup_step != OVERLAP_STEPS[0]:
        pytest.skip("Distributed Muon is not affected by warmup step")

    model, grads, qk_logits = inputs

    if not apply_qk_clip:
        qk_logits = None
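
    # Parallelize the model, and the QK logits when clipping is enabled,
    # according to the requested parallel_dims before applying the Muon step.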
    model = copy.deepcopy(model).cuda()
    parallelized_model = parallelize_motif(model, parallel_dims)

    if qk_logits is not None:
        qk_logits = copy.deepcopy(qk_logits)
        qk_logits = parallelize_qk_logits(qk_logits, parallel_dims)

    parallelized_model, timing_result = apply_muon_step(
        model=parallelized_model,
        parallel_dims=parallel_dims,
        grads=grads,
        warmup_step=warmup_step,
        chunk_size=chunk_size,
        qk_logits=qk_logits,
        use_distributed_muon=use_distributed_muon,
        measure_perf=measure_perf,
        do_profile=do_profile,
    )

    if measure_perf:
        assert timing_result is not None
        avg_time_ms, peak_memory = timing_result
        logger.info(
            f"\nParallel dims: {parallel_dims}, "
            f"\nUse distributed Muon: {use_distributed_muon}, "
            f"\nApply QK clip: {apply_qk_clip} => "
            f"\nChunk Size, Warmup Step, Avg Time (ms), Peak Memory (MB):"
            f"\n{chunk_size}, {warmup_step}, {avg_time_ms:.2f}, "
            f"{peak_memory / (1024**2):.2f}")
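
    # Correctness check: the parallelized result must match the sequential
    # baseline for the same QK-clip setting; skipped when the baseline is
    # unavailable or when this run only measures performance.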
    if sequential_muon_result is None:
        logger.info("Skipping correctness check as sequential result is None")
    elif measure_perf:
        logger.info("Skipping correctness check as timing is enabled")
    else:
        assert_params_equal(parallelized_model,
                            sequential_muon_result[apply_qk_clip])