import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

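# Smoke test: MegaBlocksMoeMLP should be importable from megablocks.layers.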
def test_megablocks_moe_mlp_import():
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."

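# Body of the distributed test; executed once per spawned rank.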
def run_distributed_test(rank, world_size):
    from megablocks.layers import MegaBlocksMoeMLP

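    # Single-node rendezvous: every rank connects to the same local master.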
| os.environ["MASTER_ADDR"] = "localhost" | |
| os.environ["MASTER_PORT"] = "12355" | |
| os.environ["RANK"] = str(rank) | |
| os.environ["WORLD_SIZE"] = str(world_size) | |
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

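    # All ranks participate in a single expert-parallel group.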
    expert_parallel_group = dist.new_group(range(dist.get_world_size()))

    model = MegaBlocksMoeMLP()
    model.expert_parallel_group = expert_parallel_group

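    # Bare container exposing the expert attributes the layer reads at forward time.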
    class Experts:
        def __init__(self):
            self.gate_up_proj = None
            self.gate_up_proj_bias = None
            self.down_proj = None
            self.down_proj_bias = None
            self.hidden_size = None

    model.experts = Experts()

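    # Toy configuration: 128 experts split evenly across the ranks.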
    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072
    ne, hs, isz = num_experts, hidden_size, intermediate_size
    experts_per_rank = ne // world_size

| device = "cuda" if torch.cuda.is_available() else "cpu" | |
    model.router = torch.nn.Linear(hs, ne).to(device)
    model.router.weight.data.fill_(1)

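    # Expert weights are stacked tensors with one leading dimension per local expert.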
    e = model.experts
    e.gate_up_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, hs, isz, device=device)
    )
    e.gate_up_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, isz, device=device)
    )
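    # down_proj takes half of intermediate_size as input (3072 // 2 = 1536), consistent
    # with gate_up_proj packing both the gate and up projections into one tensor.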
    e.down_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, isz // 2, hs, device=device)
    )
    e.down_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, hs, device=device)
    )
    e.hidden_size = hs

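    # Forward a single token; the output must preserve the hidden size.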
    x = torch.randn(1, 1, hs).to(device)
    output, expert_weights_out = model(x)
    assert output.shape == (1, 1, hs), f"Output shape mismatch on rank {rank}."
    print(f"Rank {rank}: Test passed! Output shape: {output.shape}")

    dist.destroy_process_group()

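# mp.spawn prepends the process index to args, so each worker receives (rank, world_size).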
def test_megablocks_moe_mlp_functionality():
    world_size = 2
    mp.spawn(run_distributed_test, args=(world_size,), nprocs=world_size, join=True)
    print("Multi-process test completed successfully!")

if __name__ == "__main__":
    test_megablocks_moe_mlp_import()
    print("Import test passed!")
    test_megablocks_moe_mlp_functionality()