diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 3d73c62f5a8a..d2a6c089e8e7 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -215,7 +215,7 @@ def top1gating(logits: Tensor,
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
         # Communicate across expert processes to pick the maximum capacity.
-        if ep_group is not None:
+        if ep_group is not None and dist.get_world_size(group=ep_group) > 1:
             dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
@@ -335,7 +335,7 @@ def top2gating(logits: Tensor,
     else:
         # Do not drop tokens - set capacity according to current expert assignments
         new_capacity = torch.max(exp_counts)
-        if ep_group is not None:
+        if ep_group is not None and dist.get_world_size(group=ep_group) > 1:
             dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
@@ -421,7 +421,7 @@ def topkgating(
     else:
         # Do not drop tokens - set capacity according to current expert assignments
         new_capacity = torch.max(exp_counts)
-        if ep_group is not None:
+        if ep_group is not None and dist.get_world_size(group=ep_group) > 1:
             dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group)
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
@@ -628,7 +628,10 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
             # an allgather to ensure correctness,
             dispatched_input = drop_tokens(dispatched_input, dim=1)

-        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
+        if self.ep_size > 1:
+            dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
+        else:
+            dispatched_input = dispatched_input.contiguous()

         if self.wall_clock_breakdown:
             self.timers(FIRST_ALLTOALL_TIMER).stop()
@@ -654,7 +657,10 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).start()

-        expert_output = _AllToAll.apply(self.ep_group, expert_output)
+        if self.ep_size > 1:
+            expert_output = _AllToAll.apply(self.ep_group, expert_output)
+        else:
+            expert_output = expert_output.contiguous()

         if self.wall_clock_breakdown:
             self.timers(SECOND_ALLTOALL_TIMER).stop()
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index fc7faceaed12..90c728ef62b7 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -11,8 +11,10 @@
 from unit.common import DistributedTest
 from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
 import deepspeed.comm as dist
+import deepspeed.moe.sharded_moe as sharded_moe
 from deepspeed import get_accelerator
-from deepspeed.moe.sharded_moe import top1gating, topkgating
+from deepspeed.moe.layer import MoE
+from deepspeed.moe.sharded_moe import top1gating, top2gating, topkgating
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
 from deepspeed.utils.torch import required_torch_version

@@ -209,6 +211,93 @@ def test(self):
                    use_tutel=False)


+class TestMoESingleton(DistributedTest):
+    world_size = 2
+
+    @pytest.mark.parametrize("ep_size, expected_calls", [(1, 0), (2, 2)], ids=["single", "multi"])
+    def test_all_to_all(self, monkeypatch, ep_size, expected_calls):
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        config_dict = {"train_micro_batch_size_per_gpu": 1, "steps_per_print": 1}
+        hidden_dim = 8
+        expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim))
+        model = MoE(hidden_size=hidden_dim, expert=expert, num_experts=2, ep_size=ep_size, k=1, min_capacity=0)
+        optimizer = torch.optim.AdamW(params=model.parameters())
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
+                                              model=model,
+                                              optimizer=optimizer,
+                                              dist_init_required=False)
+
+        all_to_all_calls = []
+
+        def counted_all_to_all(group, input):
+            all_to_all_calls.append((group, input.shape))
+            return input
+
+        monkeypatch.setattr(sharded_moe._AllToAll, "apply", counted_all_to_all)
+
+        x = torch.randn(1, 4, hidden_dim, device=model.device, requires_grad=True)
+        output, l_aux, _ = model(x)
+        assert len(all_to_all_calls) == expected_calls
+
+        loss = output.float().sum() + l_aux.float()
+        model.backward(loss)
+        assert len(all_to_all_calls) == expected_calls
+        assert x.grad is not None
+        assert any(param.grad is not None for param in model.module.parameters())
+
+    @pytest.mark.parametrize("gate_fn, capacity_args", [(top1gating, (1, 0)), (top2gating, (1, 0)),
+                                                        (topkgating, (3, 1, 0))],
+                             ids=["top1", "top2", "topk"])
+    @pytest.mark.parametrize("ep_world_size, expected_calls", [(1, 0), (2, 1)], ids=["single", "multi"])
+    def test_capacity(self, monkeypatch, gate_fn, capacity_args, ep_world_size, expected_calls):
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        ep_group = None
+        if ep_world_size == 1:
+            for rank in range(dist.get_world_size()):
+                group = dist.new_group([rank])
+                if rank == dist.get_rank():
+                    ep_group = group
+        else:
+            ep_group = dist.new_group(list(range(dist.get_world_size())))
+
+        all_reduce_calls = []
+        original_all_reduce = sharded_moe.dist.all_reduce
+
+        def counted_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None):
+            all_reduce_calls.append((tensor, op, group))
+            return original_all_reduce(tensor, op=op, group=group)
+
+        monkeypatch.setattr(sharded_moe.dist, "all_reduce", counted_all_reduce)
+
+        device = get_accelerator().current_device_name()
+        logits = torch.randn(8, 4, device=device)
+        gate_fn(logits, *capacity_args, drop_tokens=False, ep_group=ep_group)
+
+        assert len(all_reduce_calls) == expected_calls
+        if all_reduce_calls:
+            _, op, group = all_reduce_calls[0]
+            assert op == dist.ReduceOp.MAX
+            assert group is ep_group
+
+    def test_no_ep_group(self, monkeypatch):
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        def fail_collective(*args, **kwargs):
+            raise AssertionError("ep_group=None should not enter expert-parallel collective code")
+
+        monkeypatch.setattr(sharded_moe.dist, "get_world_size", fail_collective)
+        monkeypatch.setattr(sharded_moe.dist, "all_reduce", fail_collective)
+
+        device = get_accelerator().current_device_name()
+        logits = torch.randn(8, 4, device=device)
+        top2gating(logits, 1, 0, drop_tokens=False, ep_group=None, top2_2nd_expert_sampling=False)
+
+
 class TestTopkGate(DistributedTest):

     def test(self):