
[examples][CuTeDSL] add MoE dispatch+combine example with NVSHMEM #3221

Open
shubaoyu2 wants to merge 1 commit into NVIDIA:main from shubaoyu2:combine-and-dispatch-example-with-nvshmem
Conversation

@shubaoyu2
Contributor

Ports the Triton-distributed (ByteDance) DeepEP MoE all-to-all dispatch+combine kernel pair to CuTeDSL on top of NVSHMEM. The example runs both single-node multi-GPU and multi-node:

Single-node (1 host, 4 GPUs):

```bash
torchrun --nproc-per-node 4 dispatch_and_combine.py \
  --num_tokens 64 --hidden 256 --num_experts 32 --topk 2
```

Multi-node (2 hosts × 4 GPUs each, via SLURM + torchrun):

```bash
srun -p <partition> -N 2 --ntasks-per-node 1 --gres=gpu:4 \
  torchrun --nnodes 2 --nproc_per_node 4 \
  --rdzv_endpoint=$MASTER_ADDR:29500 dispatch_and_combine.py \
  --num_tokens 64 --hidden 256 --num_experts 64 --topk 2
```

Kernel structure (a rough host-side sketch of the routing logic follows the list):

  - dispatch: each rank routes its tokens to the topk experts via
    `nvshmem.core.device.cute.put_warp`. Two modes:
      * WITH_SCATTER_INDICES=False: kernel atomically allocates the
        destination slot via `cute.arch.atomic_add`.
      * WITH_SCATTER_INDICES=True: host pre-computes scatter indices.
  - combine: per-peer route mirroring DeepEP's `expert_node_idx ==
    node_id` predicate. Same-node experts go through a direct LDG
    via `nvshmem_ptr` at SYS-scope VOLATILE; cross-node experts use
    `nvshmem.core.device.cute.get_warp` into a symmetric staging
    buffer.
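
A minimal host-side sketch of the routing math the two bullets above describe, assuming integer `topk_ids` of shape [num_tokens, topk]; the helper names and tensor names are illustrative, not the identifiers used in dispatch_and_combine.py:

```python
import torch

def precompute_scatter_indices(topk_ids: torch.Tensor, expert_per_rank: int,
                               num_ranks: int):
    """WITH_SCATTER_INDICES=True mode: assign each (token, expert) pair a
    destination slot on the owning rank, so the kernel can skip the
    cute.arch.atomic_add slot allocation."""
    dest_rank = topk_ids // expert_per_rank              # [num_tokens, topk]
    scatter_idx = torch.empty_like(dest_rank)
    counts = torch.zeros(num_ranks, dtype=torch.long)    # next free slot per rank
    for t in range(dest_rank.shape[0]):
        for k in range(dest_rank.shape[1]):
            r = int(dest_rank[t, k])
            scatter_idx[t, k] = counts[r]
            counts[r] += 1
    return dest_rank, scatter_idx

def combine_route_class(expert_rank: int, node_id: int, local_world_size: int) -> str:
    """Mirror of the combine-side per-peer predicate: same-node experts are
    read with a direct volatile load through nvshmem_ptr, cross-node experts
    go through get_warp into the symmetric staging buffer."""
    expert_node_idx = expert_rank // local_world_size
    return "same_node_ldg" if expert_node_idx == node_id else "cross_node_get"
```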

Bootstrap: supports both launchers, torchrun (RANK / LOCAL_RANK) and srun (SLURM_PROCID / SLURM_LOCALID, with `tcp://master:port` init). The NVSHMEM UID is broadcast over the bootstrapped torch.distributed process group, followed by `nvshmem.core.init(initializer_method="uid")`.
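
A minimal sketch of the launcher detection, assuming the standard torchrun and SLURM environment variables; the exact bootstrap code in the example may differ:

```python
import os
import torch
import torch.distributed as dist

def bootstrap_process_group():
    """Detect the launcher from its environment and build the torch.distributed
    process group later used to broadcast the NVSHMEM UID."""
    if "RANK" in os.environ:
        # torchrun path: RANK / LOCAL_RANK / WORLD_SIZE (plus MASTER_ADDR/PORT) are set.
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    else:
        # srun path: derive rank info from SLURM and init over tcp://master:port.
        rank = int(os.environ["SLURM_PROCID"])
        local_rank = int(os.environ["SLURM_LOCALID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        master = os.environ["MASTER_ADDR"]          # exported by the job script (assumption)
        port = os.environ.get("MASTER_PORT", "29500")
        dist.init_process_group("nccl", init_method=f"tcp://{master}:{port}",
                                rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    # The example then broadcasts the NVSHMEM UID over this process group and
    # calls nvshmem.core.init(initializer_method="uid").
    return rank, local_rank, world_size
```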

Validators: `check_dispatch` and `check_combine` allgather the
per-rank state and verify correctness from rank 0; a route-class
coverage gate asserts the kernel's per-peer paths are all
exercised on multi-node runs. Optional `--save-baseline-to <path>`
writes a per-rank .npz that a subsequent run can compare against
for bit-identity.
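
A minimal sketch of the baseline round trip, assuming per-rank NumPy arrays; the helper and array names are illustrative, not the example's actual code:

```python
import numpy as np

def save_baseline(path: str, dispatch_out: np.ndarray, combine_out: np.ndarray) -> None:
    """Write the per-rank baseline .npz produced by --save-baseline-to <path>."""
    np.savez(path, dispatch_out=dispatch_out, combine_out=combine_out)

def check_against_baseline(path: str, dispatch_out: np.ndarray, combine_out: np.ndarray) -> None:
    """Bit-identity check: exact equality against the saved baseline, not allclose."""
    ref = np.load(path)
    assert np.array_equal(ref["dispatch_out"], dispatch_out), "dispatch output differs from baseline"
    assert np.array_equal(ref["combine_out"], combine_out), "combine output differs from baseline"
```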

Variable names follow the Triton-distributed DeepEP conventions
(`expert_node_idx`, `node_id`, `local_world_size`,
`expert_per_rank`) to ease cross-referencing with the public
reference (https://github.com/ByteDance-Seed/Triton-distributed).
