[examples][CuTeDSL] add MoE dispatch+combine example with NVSHMEM #3221
Ports the Triton-distributed (ByteDance) DeepEP MoE all-to-all dispatch+combine kernel pair to CuTeDSL on top of NVSHMEM. Runs single-node multi-GPU and multi-node:
- dispatch: each rank routes its tokens to the topk experts via `nvshmem.core.device.cute.put_warp`. Two modes:
  * WITH_SCATTER_INDICES=False: the kernel atomically allocates the destination slot via `cute.arch.atomic_add`.
  * WITH_SCATTER_INDICES=True: the host pre-computes the scatter indices.
- combine: per-peer routing mirrors DeepEP's `expert_node_idx == node_id` predicate (a host-side sketch of the classification follows this list). Same-node experts are read with a direct LDG via `nvshmem_ptr` at SYS scope, VOLATILE; cross-node experts use `nvshmem.core.device.cute.get_warp` into a symmetric staging buffer.
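For orientation, here is a minimal host-side sketch of that peer-classification predicate. It assumes the contiguous expert-to-rank layout implied by the `expert_per_rank` / `local_world_size` naming; the kernel itself evaluates the predicate per peer on-device, so this is illustrative only.

```python
# Host-side sketch of the combine route classification.
# Assumption (not taken verbatim from the kernel): experts are laid out
# contiguously across ranks (`expert_per_rank` experts per rank) and ranks are
# grouped into nodes of `local_world_size` GPUs, matching the DeepEP naming.

def classify_peer(expert_id: int, expert_per_rank: int,
                  local_world_size: int, node_id: int) -> str:
    """Return which combine path a given expert takes for this rank."""
    expert_rank = expert_id // expert_per_rank          # rank owning the expert
    expert_node_idx = expert_rank // local_world_size   # node hosting that rank
    if expert_node_idx == node_id:
        # Same node: direct load through nvshmem_ptr (SYS scope, volatile).
        return "direct-ldg"
    # Different node: get_warp into a symmetric staging buffer.
    return "get-warp-staging"

# Example: 4 GPUs per node, 8 experts per rank, this rank on node 0.
assert classify_peer(expert_id=5,  expert_per_rank=8, local_world_size=4, node_id=0) == "direct-ldg"
assert classify_peer(expert_id=40, expert_per_rank=8, local_world_size=4, node_id=0) == "get-warp-staging"
```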
Bootstrap: supports both launchers, torchrun (RANK / LOCAL_RANK) and srun (SLURM_PROCID / SLURM_LOCALID, with a `tcp://master:port` init). The NVSHMEM UID is broadcast over the bootstrapped torch.distributed process group, followed by `nvshmem.core.init(initializer_method="uid")`; a sketch of this flow follows below.
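A rough sketch of the bootstrap path under both launchers. Only `initializer_method="uid"` and the environment-variable fallback are stated in this PR; the `get_unique_id` helper, the keyword names passed to `nvshmem.core.init`, and broadcasting the UID with `broadcast_object_list` are assumptions about the nvshmem4py API and may not match the example's actual code.

```python
import os
import torch
import torch.distributed as dist
import nvshmem.core  # nvshmem4py

def bootstrap():
    # torchrun sets RANK/LOCAL_RANK; srun sets SLURM_PROCID/SLURM_LOCALID.
    rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0")))
    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0")))
    world_size = int(os.environ.get("WORLD_SIZE", os.environ.get("SLURM_NTASKS", "1")))
    torch.cuda.set_device(local_rank)

    # torchrun provides the rendezvous via env://; under srun fall back to an
    # explicit tcp://master:port init (MASTER_ADDR supplied by the job script).
    init_method = "env://" if "RANK" in os.environ else f"tcp://{os.environ['MASTER_ADDR']}:29500"
    dist.init_process_group("nccl", init_method=init_method, rank=rank, world_size=world_size)

    # Rank 0 creates the NVSHMEM unique id; the other ranks receive it over the
    # freshly bootstrapped torch.distributed process group.
    uid = nvshmem.core.get_unique_id(empty=(rank != 0))   # assumed nvshmem4py helper
    holder = [uid]
    dist.broadcast_object_list(holder, src=0)              # assumes the uid object is picklable
    nvshmem.core.init(device=torch.device("cuda", local_rank), uid=holder[0],
                      rank=rank, nranks=world_size, initializer_method="uid")
    return rank, local_rank, world_size
```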
Validators: `check_dispatch` and `check_combine` allgather the per-rank state and verify correctness from rank 0; a route-class coverage gate asserts that all of the kernel's per-peer paths are exercised on multi-node runs. An optional `--save-baseline-to <path>` writes a per-rank .npz that a subsequent run can compare against for bit-identity (see the sketch below).
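The baseline round trip can be pictured like this; `save_baseline` / `check_against_baseline` and the array keys are hypothetical names for illustration, not the example's actual helpers.

```python
import numpy as np

# Hypothetical helpers for the --save-baseline-to round trip. One .npz per rank
# so concurrent processes never write to the same file.

def save_baseline(path: str, rank: int, dispatched: np.ndarray, combined: np.ndarray) -> None:
    np.savez(f"{path}.rank{rank}.npz", dispatched=dispatched, combined=combined)

def check_against_baseline(path: str, rank: int, dispatched: np.ndarray, combined: np.ndarray) -> None:
    ref = np.load(f"{path}.rank{rank}.npz")
    # Bit-identity means exact element-wise equality, not an allclose tolerance.
    assert np.array_equal(ref["dispatched"], dispatched)
    assert np.array_equal(ref["combined"], combined)
```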
Variable names follow the Triton-distributed DeepEP convention (`expert_node_idx`, `node_id`, `local_world_size`, `expert_per_rank`) to ease cross-referencing with the public reference at https://github.com/ByteDance-Seed/Triton-distributed.
Example invocations, single-node multi-GPU and multi-node:

Single-node (1 host, 4 GPUs):

    torchrun --nproc-per-node 4 dispatch_and_combine.py \
        --num_tokens 64 --hidden 256 --num_experts 32 --topk 2

Multi-node (2 hosts × 4 GPUs each, via SLURM + torchrun):

    srun -p -N 2 --ntasks-per-node 1 --gres=gpu:4 \
        torchrun --nnodes 2 --nproc_per_node 4 \
        --rdzv_endpoint=$MASTER_ADDR:29500 dispatch_and_combine.py \
        --num_tokens 64 --hidden 256 --num_experts 64 --topk 2