From d9817ba11ac9f77c69c8fbfe92427d0ea5f897b7 Mon Sep 17 00:00:00 2001
From: irexyc
Date: Fri, 27 Mar 2026 02:55:58 +0000
Subject: [PATCH 01/21] integrate deep-ep nccl backend (intranode + low_latency kernels)

---
 lmdeploy/cli/serve.py                         | 4 +-
 lmdeploy/messages.py                          | 1 +
 lmdeploy/turbomind/deploy/config.py           | 1 +
 lmdeploy/turbomind/deploy/converter.py        | 1 +
 lmdeploy/turbomind/deploy/module.py           | 2 +-
 lmdeploy/turbomind/turbomind.py               | 21 +-
 src/turbomind/comm/device_comm.h              | 84 +
 src/turbomind/comm/nccl/CMakeLists.txt        | 20 +-
 src/turbomind/comm/nccl/deep_ep/config.hpp    | 193 +++
 src/turbomind/comm/nccl/deep_ep/deep_ep.cpp   | 1119 ++++++++++++++
 src/turbomind/comm/nccl/deep_ep/deep_ep.hpp   | 225 +++
 .../comm/nccl/deep_ep/gin_backend.cu          | 244 +++
 src/turbomind/comm/nccl/deep_ep/gin_backend.h | 82 +
 .../comm/nccl/deep_ep/kernels/api.cuh         | 380 +++++
 .../comm/nccl/deep_ep/kernels/buffer.cuh      | 134 ++
 .../comm/nccl/deep_ep/kernels/configs.cuh     | 81 +
 .../comm/nccl/deep_ep/kernels/exception.cuh   | 76 +
 .../comm/nccl/deep_ep/kernels/internode_ll.cu | 1348 +++++++++++++++++
 .../comm/nccl/deep_ep/kernels/intranode.cu    | 1110 ++++++++++++++
 .../comm/nccl/deep_ep/kernels/launch.cuh      | 138 ++
 .../comm/nccl/deep_ep/kernels/layout.cu       | 153 ++
 .../comm/nccl/deep_ep/kernels/runtime.cu      | 96 ++
 .../comm/nccl/deep_ep/kernels/utils.cuh       | 640 ++++++++
 src/turbomind/comm/nccl/nccl.cu               | 554 +++---
 src/turbomind/comm/nccl/nccl_comm.h           | 116 ++
 src/turbomind/comm/nccl/nccl_ep.cu            | 254 ++++
 src/turbomind/kernels/gemm/CMakeLists.txt     | 1 +
 src/turbomind/kernels/gemm/moe_ep_utils.cu    | 701 +++++++++
 src/turbomind/kernels/gemm/moe_ep_utils.h     | 63 +
 .../models/llama/FusedRMSNormLayer.h          | 172 +++
 .../models/llama/LlamaDecoderLayerWeight.cc   | 8 +-
 .../models/llama/LlamaDecoderLayerWeight.h    | 2 +
 .../models/llama/LlamaDenseWeight.cc          | 10 +-
 src/turbomind/models/llama/LlamaDenseWeight.h | 2 +
 src/turbomind/models/llama/llama_params.h     | 5 +
 src/turbomind/models/llama/llama_utils.cu     | 1 +
 src/turbomind/models/llama/moe_ffn_layer.cc   | 296 +++-
 src/turbomind/models/llama/moe_ffn_layer.h    | 24 +
 src/turbomind/models/llama/unified_decoder.cc | 240 +--
 src/turbomind/models/llama/unified_decoder.h  | 15 +-
 src/turbomind/turbomind.cc                    | 19 +-
 41 files changed, 8189 insertions(+), 447 deletions(-)
 create mode 100644 src/turbomind/comm/nccl/deep_ep/config.hpp
 create mode 100644 src/turbomind/comm/nccl/deep_ep/deep_ep.cpp
 create mode 100644 src/turbomind/comm/nccl/deep_ep/deep_ep.hpp
 create mode 100644 src/turbomind/comm/nccl/deep_ep/gin_backend.cu
 create mode 100644 src/turbomind/comm/nccl/deep_ep/gin_backend.h
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/api.cuh
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/buffer.cuh
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/configs.cuh
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/exception.cuh
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/internode_ll.cu
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/intranode.cu
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/launch.cuh
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/layout.cu
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/runtime.cu
 create mode 100644 src/turbomind/comm/nccl/deep_ep/kernels/utils.cuh
 create mode 100644 src/turbomind/comm/nccl/nccl_comm.h
 create mode 100644 src/turbomind/comm/nccl/nccl_ep.cu
 create mode 100644 src/turbomind/kernels/gemm/moe_ep_utils.cu
 create mode 100644 src/turbomind/kernels/gemm/moe_ep_utils.h
 create mode 100644
src/turbomind/models/llama/FusedRMSNormLayer.h diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 155392f4a7..bd47ba33ea 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -121,7 +121,7 @@ def add_parser_api_server(): hf_overrides = ArgumentHelper.hf_overrides(pt_group) disable_metrics = ArgumentHelper.disable_metrics(pt_group) dp = ArgumentHelper.dp(pt_group) - ArgumentHelper.ep(pt_group) + ep = ArgumentHelper.ep(pt_group) ArgumentHelper.enable_microbatch(pt_group) ArgumentHelper.enable_eplb(pt_group) ArgumentHelper.role(pt_group) @@ -148,6 +148,7 @@ def add_parser_api_server(): tb_group._group_actions.append(hf_overrides) tb_group._group_actions.append(disable_metrics) tb_group._group_actions.append(dp) + tb_group._group_actions.append(ep) ArgumentHelper.cp(tb_group) ArgumentHelper.rope_scaling_factor(tb_group) ArgumentHelper.num_tokens_per_iter(tb_group) @@ -255,6 +256,7 @@ def api_server(args): tp=args.tp, dp=args.dp, cp=args.cp, + ep=args.ep, nnodes=args.nnodes, node_rank=args.node_rank, dist_init_addr=args.dist_init_addr, diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index d6cd1a3329..540ee5a0a6 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -261,6 +261,7 @@ class TurbomindEngineConfig: tp: int = 1 dp: int = 1 cp: int = 1 + ep: int = 1 device_num: int = None attn_tp_size: int = None attn_cp_size: int = None diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py index 8fdb95ac78..949f06ea6e 100644 --- a/lmdeploy/turbomind/deploy/config.py +++ b/lmdeploy/turbomind/deploy/config.py @@ -71,6 +71,7 @@ class ModelConfig: attn_tp_size: int = 1 attn_cp_size: int = 1 mlp_tp_size: int = 1 + ep_size: int = 1 model_format: str = 'hf' expert_num: list[int] = field(default_factory=list) expert_router_bias: bool = False diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 05b1ba526f..0021c5caca 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -276,6 +276,7 @@ def get_tm_model(model_path, tm_cfg.model_config.attn_cp_size = engine_config.attn_cp_size if engine_config.mlp_tp_size is not None: tm_cfg.model_config.mlp_tp_size = engine_config.mlp_tp_size + tm_cfg.model_config.ep_size = engine_config.ep output_model = OUTPUT_MODELS.get(output_model_name)(input_model=input_model, cfg=tm_cfg, diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py index 330fbacc9e..f8e06c56fd 100644 --- a/lmdeploy/turbomind/deploy/module.py +++ b/lmdeploy/turbomind/deploy/module.py @@ -140,7 +140,7 @@ class Ffn(Module): def __init__(self, model: BaseOutputModel): self.model = model - self.tp = model.mlp_tp_size + self.tp = model.mlp_tp_size if model.model_config.ep_size == 1 else 1 # inter_sizes in config are padded and may be different from what's # in the weights self.inter_size = model.model_config.inter_size diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index f95b2b93ca..d8b2b5a3b2 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -86,7 +86,26 @@ def complete_parallel_config(cfg: TurbomindEngineConfig): def update_parallel_config(cfg: TurbomindEngineConfig): cfg.device_num = len(cfg.devices) * cfg.nnodes if cfg.devices else cfg.device_num - if not complete_parallel_config(cfg): + if not complete_parallel_config(cfg) and cfg.ep > 1: + if cfg.communicator in ['cuda-ipc', 'native']: + assert cfg.nnodes == 1, 'TurboMind does not support 
multi-node with ep > 1' + total = cfg.dp * cfg.ep + if not cfg.device_num: + count = torch.cuda.device_count() * cfg.nnodes + if total < count: + count = total + cfg.device_num = count + assert total % cfg.device_num == 0 + overlap = total // cfg.device_num + attn_dp_size = overlap + inner_tp_size = cfg.ep // overlap + cfg.outer_dp_size = cfg.dp // overlap + cfg.attn_dp_size = overlap // cfg.nnodes + cfg.attn_tp_size = inner_tp_size // cfg.cp + cfg.attn_cp_size = cfg.cp + cfg.mlp_dp_size = 1 + cfg.mlp_tp_size = cfg.attn_dp_size * cfg.attn_tp_size * cfg.attn_cp_size + elif not complete_parallel_config(cfg): total = cfg.dp * cfg.tp if not cfg.device_num: count = torch.cuda.device_count() * cfg.nnodes diff --git a/src/turbomind/comm/device_comm.h b/src/turbomind/comm/device_comm.h index a6948762df..0c85ac7ebf 100644 --- a/src/turbomind/comm/device_comm.h +++ b/src/turbomind/comm/device_comm.h @@ -9,9 +9,58 @@ #include #include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/core/buffer.h" +#include "src/turbomind/core/tensor.h" namespace turbomind::comm { +struct EpConfig { + int num_nodes; + int num_experts; + int hidden; + int ll_max_tokens_per_rank; +}; + +enum EpMode +{ + kNull, + kHighThroughput, + kLowLatency, +}; + +struct EpDispatchInput { + EpMode& mode; + core::Tensor& x; + core::Tensor_& topk_weights; + core::Tensor_& topk_idx; +}; + +struct EpDispatchOutput { + core::Tensor out_x; + core::Tensor out_topk_weights; + core::Buffer_& f2n; + core::Buffer_& f2E; + core::Buffer_& en2f; + core::Buffer_& offsets; + + std::vector handle; + + int out_token_num; + int out_expert_token_num; +}; + +struct EpCombineInput { + EpMode& mode; + core::Tensor& x; + std::vector& handle; + std::optional topk_weights; + std::optional topk_idx; +}; + +struct EpCombineOutput { + core::Tensor out_x; +}; + enum QueryAttr { kHasAllGather2D @@ -117,6 +166,41 @@ class DeviceCommImpl { { throw std::runtime_error("not implemented"); } + + virtual void ReduceScatterV(const void* sendbuff, // + void* recvbuff, + const size_t* counts, + DataType type, + int group, + cudaStream_t stream) + { + throw std::runtime_error("not implemented"); + } + + virtual void AllGatherV(const void* sendbuff, // + void* recvbuff, + const size_t* counts, + DataType type, + int group, + cudaStream_t stream) + { + throw std::runtime_error("not implemented"); + } + + virtual void InitializeEp(const EpConfig& config) + { + throw std::runtime_error("ep not implemented"); + } + + virtual void Dispatch(const EpDispatchInput& input, EpDispatchOutput& output, int group) + { + throw std::runtime_error("not implemented"); + } + + virtual void Combine(const EpCombineInput& input, EpCombineOutput& output, int group) + { + throw std::runtime_error("not implemented"); + } }; class DeviceComm { diff --git a/src/turbomind/comm/nccl/CMakeLists.txt b/src/turbomind/comm/nccl/CMakeLists.txt index 373558c84e..2c63c0a122 100644 --- a/src/turbomind/comm/nccl/CMakeLists.txt +++ b/src/turbomind/comm/nccl/CMakeLists.txt @@ -2,8 +2,24 @@ cmake_minimum_required(VERSION 3.11) -add_library(nccl_comm STATIC nccl.cu) -target_link_libraries(nccl_comm PRIVATE rms_norm core ${NCCL_LIBRARIES} logger) +set(DEEP_EP_SOURCE_FILES + deep_ep/deep_ep.cpp + deep_ep/gin_backend.cu + deep_ep/kernels/runtime.cu + deep_ep/kernels/layout.cu + deep_ep/kernels/intranode.cu + deep_ep/kernels/internode_ll.cu +) + +add_library(deepep STATIC ${DEEP_EP_SOURCE_FILES}) +target_link_libraries(deepep PRIVATE ${NCCL_LIBRARIES} CUDA::cudart) +set_property(TARGET deepep PROPERTY 
CUDA_ARCHITECTURES 90) +target_include_directories(deepep PRIVATE ${NCCL_INCLUDE_DIRS}) +set_property(TARGET deepep PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET deepep PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(nccl_comm STATIC nccl.cu nccl_ep.cu) +target_link_libraries(nccl_comm PRIVATE rms_norm core ${NCCL_LIBRARIES} logger deepep) target_include_directories(nccl_comm PRIVATE ${NCCL_INCLUDE_DIRS}) set_property(TARGET nccl_comm PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/turbomind/comm/nccl/deep_ep/config.hpp b/src/turbomind/comm/nccl/deep_ep/config.hpp new file mode 100644 index 0000000000..0839265799 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/config.hpp @@ -0,0 +1,193 @@ +// clang-format off +#pragma once + +#include "kernels/api.cuh" +#include "kernels/exception.cuh" + +namespace deep_ep { + +template +dtype_t ceil_div(dtype_t a, dtype_t b) { + return (a + b - 1) / b; +} + +template +dtype_t align_up(dtype_t a, dtype_t b) { + return ceil_div(a, b) * b; +} + +template +dtype_t align_down(dtype_t a, dtype_t b) { + return a / b * b; +} + +struct Config { + int num_sms; + int num_max_nvl_chunked_send_tokens; + int num_max_nvl_chunked_recv_tokens; + int num_max_rdma_chunked_send_tokens; + int num_max_rdma_chunked_recv_tokens; + + Config(int num_sms, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens) + : num_sms(num_sms), + num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens), + num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens), + num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens), + num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) { + EP_HOST_ASSERT(num_sms >= 0); + EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0); + EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0); + + // Ceil up RDMA buffer size + this->num_max_rdma_chunked_recv_tokens = align_up(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens); + // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); + } + + size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { + // Below are some assumptions + // TODO: add assertions + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0); + const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); + const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + const int num_channels = num_sms / 2; + + size_t num_bytes = 0; + num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t); + num_bytes += num_channels * 
num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; + } + + size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { + // Legacy mode + if (num_ranks <= NUM_MAX_NVL_PEERS) + return 0; + + // Below are some assumptions + // TODO: add assertions + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); + EP_HOST_ASSERT(num_sms % 2 == 0); + const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + const int num_channels = num_sms / 2; + + size_t num_bytes = 0; + num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(topk_idx_t) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; + } +}; + +struct LowLatencyBuffer { + int num_clean_int = 0; + + void* dispatch_rdma_send_buffer = nullptr; + void* dispatch_rdma_recv_data_buffer = nullptr; + int* dispatch_rdma_recv_count_buffer = nullptr; + + void* combine_rdma_send_buffer = nullptr; + void* combine_rdma_recv_data_buffer = nullptr; + int* combine_rdma_recv_flag_buffer = nullptr; + + void* combine_rdma_send_buffer_data_start = nullptr; + size_t num_bytes_per_combine_msg = 0; + + std::pair clean_meta() { + EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); + return {dispatch_rdma_recv_count_buffer, num_clean_int}; + } +}; + +struct LowLatencyLayout { + void* rdma_buffer = nullptr; + size_t total_bytes = 0; + LowLatencyBuffer buffers[2]; + + template + out_ptr_t advance(const in_ptr_t& ptr, size_t count) { + return reinterpret_cast(reinterpret_cast(ptr) + count); + } + + LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts): rdma_buffer(rdma_buffer) { + const int num_scales = hidden / 128; + + // Dispatch and combine layout: + // - 2 symmetric odd/even send buffer + // - 2 symmetric odd/even receive buffers + // - 2 symmetric odd/even signaling buffers + + // Message sizes + // NOTES: you should add a control `int4` for combine messages if you want to do data transformation + // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max + EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); + size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); + size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16); + + // Send buffer + size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; 
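        // [Editorial note, not part of the original patch] Rough accounting of what the lines
        // below accumulate into `total_bytes`, as a sanity-check sketch: every region is counted
        // twice because the layout keeps symmetric odd/even buffers, i.e.
        //   total_bytes = 2 * max(dispatch_send_bytes, combine_send_bytes)
        //               + 2 * max(dispatch_recv_bytes, combine_recv_bytes)
        //               + 2 * align_up(num_experts * sizeof(int), 128)
        // so buffers[0]/buffers[1] can be used alternately across consecutive dispatch/combine calls.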
+ size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); + EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); + total_bytes += send_buffer_bytes * 2; + + // Symmetric receive buffers + // TODO: optimize memory usages + size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; + size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); + EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); + total_bytes += recv_buffer_bytes * 2; + + // Symmetric signaling buffers + size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); + size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; + size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); + size_t signaling_buffer_bytes_aligned = align_up(signaling_buffer_bytes, 128); + total_bytes += signaling_buffer_bytes_aligned * 2; + + // Assign pointers + // NOTES: we still leave some space for distinguishing dispatch/combine buffer, + // so you may see some parameters are duplicated + for (int i = 0; i < 2; ++i) { + buffers[i] = {static_cast(signaling_buffer_bytes / sizeof(int)), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * i), + advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i), + num_bytes_per_combine_msg}; + } + } +}; + +inline size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { + auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; + return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; +} + +} // namespace deep_ep + +// clang-format on diff --git a/src/turbomind/comm/nccl/deep_ep/deep_ep.cpp b/src/turbomind/comm/nccl/deep_ep/deep_ep.cpp new file mode 100644 index 0000000000..ee2b469176 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/deep_ep.cpp @@ -0,0 +1,1119 @@ +#include "src/turbomind/comm/nccl/deep_ep/deep_ep.hpp" + +#include "kernels/api.cuh" +#include "kernels/exception.cuh" +#include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/core/allocator.h" +#include "src/turbomind/core/context.h" +#include "src/turbomind/core/data_type.h" +#include "src/turbomind/kernels/core/math.h" +#include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/string_utils.h" + +#include +#include +#include +#include +#include +#include +#include + +using turbomind::fmtstr; +using turbomind::round_up; + +namespace shared_memory { +void cu_mem_set_access_all(void* ptr, size_t size) +{ + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); + + CUmemAccessDesc access_desc[device_count]; + for (int idx = 0; idx < device_count; ++idx) { + access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc[idx].location.id = 
idx; + access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + } + + CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count)); +} + +void cu_mem_free(void* ptr) +{ + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); + + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemRelease(handle)); +} + +size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) +{ + size_t size = (size_raw + granularity - 1) & ~(granularity - 1); + if (size == 0) + size = granularity; + return size; +} + +SharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric): use_fabric(use_fabric) {} + +void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) +{ + if (use_fabric) { + CUdevice device; + CU_CHECK(cuCtxGetDevice(&device)); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; + prop.location.id = device; + + size_t granularity = 0; + CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + size_t size = get_size_align_to_granularity(size_raw, granularity); + + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemCreate(&handle, size, &prop, 0)); + + CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0)); + CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); + cu_mem_set_access_all(*ptr, size); + } + else { + CUDA_CHECK(cudaMalloc(ptr, size_raw)); + } +} + +void SharedMemoryAllocator::free(void* ptr) +{ + if (use_fabric) { + cu_mem_free(ptr); + } + else { + CUDA_CHECK(cudaFree(ptr)); + } +} + +void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) +{ + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + mem_handle->size = size; + + if (use_fabric) { + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); + + CU_CHECK(cuMemExportToShareableHandle( + &mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); + } + else { + CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr)); + } +} + +void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) +{ + if (use_fabric) { + size_t size = mem_handle->size; + + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemImportFromShareableHandle( + &handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC)); + + CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0)); + CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); + cu_mem_set_access_all(*ptr, size); + } + else { + CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess)); + } +} + +void SharedMemoryAllocator::close_mem_handle(void* ptr) +{ + if (use_fabric) { + cu_mem_free(ptr); + } + else { + CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); + } +} +} // namespace shared_memory + +namespace deep_ep { + +Buffer::Buffer(int rank, + int num_ranks, + int64_t num_nvl_bytes, + int64_t num_rdma_bytes, + int64_t num_ll_rdma_bytes, + bool low_latency_mode, + bool enable_shrink, + bool use_fabric, + int qps_per_rank, + HostComm h_comm): + rank(rank), + num_ranks(num_ranks), + num_nvl_bytes(num_nvl_bytes), + low_latency_mode(low_latency_mode), + 
num_rdma_bytes(num_rdma_bytes), + num_ll_rdma_bytes(num_ll_rdma_bytes), + enable_shrink(enable_shrink), + shared_memory_allocator(use_fabric), + qps_per_rank(qps_per_rank), + h_comm(h_comm) +{ + // move to turbomind.py + setenv("NCCL_GIN_GDAKI_QP_DEPTH", "1024", 0); + + // Common checks + EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment"); + EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 + and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); + EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 + and (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < std::numeric_limits::max()); + EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits::max()); + EP_HOST_ASSERT(0 <= rank and rank < num_ranks + and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + if (num_rdma_bytes > 0) { + EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode); + } + + // Get ranks + CUDA_CHECK(cudaGetDevice(&device_id)); + rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + + // Get device info + cudaDeviceProp device_prop = {}; + CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); + num_device_sms = device_prop.multiProcessorCount; + + // Number of per-channel bytes cannot be large + EP_HOST_ASSERT(ceil_div(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits::max()); + EP_HOST_ASSERT(ceil_div(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits::max()); + + auto comm_stream = turbomind::core::Context::stream().handle(); + + // Create 32 MiB workspace + CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); + + // MoE counter + CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); + *moe_recv_counter = -1; + + // MoE expert-level counter + CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast(moe_recv_expert_counter), 0)); + for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i) + moe_recv_expert_counter[i] = -1; + + // MoE RDMA-level counter + if (num_rdma_ranks > 0) { + CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); + *moe_recv_rdma_counter = -1; + } + + // NVLink + if (num_nvl_bytes > 0) { + allocate_sync_nvl_buffer(); + } + + // RDMA + if (num_rdma_bytes || num_ll_rdma_bytes) { + allocate_rdma_buffer(); + } + + turbomind::core::Context::stream().Sync(); + h_comm->Sync(); + + // Ready to use + available = true; +} + +void Buffer::allocate_sync_nvl_buffer() +{ + // Metadata memory + int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); + int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); + int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); + + auto stream = turbomind::core::Context::stream().handle(); + + HostComm h_nvl_comm = h_comm->Split(rdma_rank, 0); 
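    // [Editorial note, not part of the original patch] Assuming MPI-like split semantics
    // (color = rdma_rank, key = 0), this groups the host communicator by node: ranks that share
    // the same `rdma_rank` (i.e. the same NUM_MAX_NVL_PEERS-GPU NVLink island) land in one
    // sub-communicator, which is then used below to build the "cuda-ipc" device communicator
    // and to exchange the NVLink buffer pointers via AllGather.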
+ + ipc_comm = CreateDeviceCommunicator("cuda-ipc", h_nvl_comm->n_ranks(), nvl_rank, h_nvl_comm); + + buffer_ptrs[nvl_rank] = + ipc_comm->Allocate(num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes); + + buffer_ptrs_gpu = + reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); + + // Set barrier signals + barrier_signal_ptrs[nvl_rank] = + reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + barrier_signal_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + + barrier_signal_bytes + buffer_ptr_bytes); + + // No need to synchronize, will do a full device sync during `sync` + CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, stream)); + + AllGather(h_nvl_comm, buffer_ptrs, 1); + + for (int i = 0; i < num_nvl_ranks; ++i) { + if (i != nvl_rank) { + barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); + } + } + + // Copy all buffer and barrier signal pointers to GPU + CUDA_CHECK(cudaMemcpyAsync( + buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(barrier_signal_ptrs_gpu, + barrier_signal_ptrs, + sizeof(int*) * NUM_MAX_NVL_PEERS, + cudaMemcpyHostToDevice, + stream)); +} + +void Buffer::allocate_rdma_buffer() +{ + TM_CHECK_EQ(comm, nullptr); + if ((not low_latency_mode) and (num_rdma_ranks == 1)) { + return; + } + + std::vector unique_ids; + if (rank == 0) { + unique_ids = deep_ep::internode::get_unique_id(); + } + Broadcast(h_comm, unique_ids, 0); + + comm = std::make_shared(); + comm->init(unique_ids, rank, num_ranks, low_latency_mode, qps_per_rank); + internode::barrier(comm.get()); + + auto stream = turbomind::core::Context::stream().handle(); + + if (num_rdma_bytes) { + // Allocate High-Throughput RDMA buffer + rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES, comm.get()); + // Register memory with NCCL communicators (sets up windows for RDMA) + internode::register_memory(rdma_buffer_ptr, num_rdma_bytes, comm.get()); + } + + if (num_ll_rdma_bytes) { + // Allocate Low-Latency RDMA buffer + rdma_ll_buffer_ptr = internode::alloc(num_ll_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES, comm.get()); + // Register memory with NCCL communicators (sets up windows for RDMA) + internode::register_memory(rdma_ll_buffer_ptr, num_ll_rdma_bytes, comm.get()); + + // Clean buffer (mainly for low-latency mode) + CUDA_CHECK(cudaMemsetAsync(rdma_ll_buffer_ptr, 0, num_ll_rdma_bytes, stream)); + + internode_ll::set_p2p_disabled_flag(comm->is_p2p_disabled()); + } + + // Allocate and clean shrink buffer + if (enable_shrink) { + int num_mask_buffer_bytes = num_ranks * sizeof(int); + int num_sync_buffer_bytes = num_ranks * sizeof(int); + mask_buffer_ptr = + reinterpret_cast(internode::alloc(num_mask_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES, comm.get())); + sync_buffer_ptr = + reinterpret_cast(internode::alloc(num_sync_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES, comm.get())); + CUDA_CHECK(cudaMemsetAsync(mask_buffer_ptr, 0, num_mask_buffer_bytes, stream)); + CUDA_CHECK(cudaMemset(sync_buffer_ptr, 0, num_sync_buffer_bytes)); + } + + // Barrier + internode::barrier(comm.get()); +} + +bool Buffer::is_available() const +{ + return available; +} + +bool Buffer::is_internode_available() const +{ + return is_available() and num_ranks > NUM_MAX_NVL_PEERS; +} + +int Buffer::get_num_rdma_ranks() const +{ + return num_rdma_ranks; +} + +int 
Buffer::get_rdma_rank() const +{ + return rdma_rank; +} + +int Buffer::get_root_rdma_rank(bool global) const +{ + return global ? nvl_rank : 0; +} + +int Buffer::get_local_device_id() const +{ + return device_id; +} + +void Buffer::destroy() +{ + TM_LOG_DEBUG("[NCCLEP][%d] Destroying buffer", rank); + EP_HOST_ASSERT(not destroyed); + + // Synchronize + auto comm_stream = turbomind::core::Context::stream().handle(); + + if (num_nvl_bytes > 0 && ipc_comm) { + turbomind::core::Context::stream().Sync(); + ipc_comm->Free(buffer_ptrs[nvl_rank]); + ipc_comm = {}; + } + + // Free NVSHMEM + if (is_available()) { + turbomind::core::Context::stream().Sync(); + if (num_rdma_bytes > 0) { + internode::free(rdma_buffer_ptr, comm.get()); + } + if (num_ll_rdma_bytes > 0) { + internode::free(rdma_ll_buffer_ptr, comm.get()); + } + if (enable_shrink) { + internode::free(mask_buffer_ptr, comm.get()); + internode::free(sync_buffer_ptr, comm.get()); + } + internode::finalize(comm.get()); + } + + // Free workspace and MoE counter + CUDA_CHECK(cudaFree(workspace)); + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); + + // Free chunked mode staffs + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); + + destroyed = true; + available = false; +} + +std::tuple, Tensor, Tensor> // +Buffer::get_dispatch_layout(const Tensor& topk_idx, int num_experts) +{ + + auto num_tokens = static_cast(topk_idx.shape(0)); + auto num_topk = static_cast(topk_idx.shape(1)); + auto num_tokens_per_rank = Tensor_{{num_ranks}, turbomind::kDEVICE}; + auto num_tokens_per_rdma_rank = std::optional(); + auto num_tokens_per_expert = Tensor_{{num_experts}, turbomind::kDEVICE}; + auto is_token_in_rank = Tensor_{{num_tokens, num_ranks}, turbomind::kDEVICE}; + if (is_internode_available()) { + num_tokens_per_rdma_rank = Buffer_{num_rdma_ranks, turbomind::kDEVICE}; + } + static_assert(sizeof(topk_idx_t) == sizeof(int64_t), "topk_idx_t must be int64_t"); + + auto stream = turbomind::core::Context::stream().handle(); + layout::get_dispatch_layout(topk_idx.data(), + num_tokens_per_rank.data(), + num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data() : + nullptr, + num_tokens_per_expert.data(), + is_token_in_rank.data_or((bool*)nullptr), // num_tokens may be zero + num_tokens, + num_topk, + num_ranks, + num_experts, + stream); + + return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank}; +} + +std::tuple, + std::optional, + std::optional, + std::vector, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor> +Buffer::intranode_dispatch(const Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, + int expert_alignment, + int num_worst_tokens, + const Config& config) +{ + bool cached_mode = cached_rank_prefix_matrix.has_value(); + + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. 
+ EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); + } + else { + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + } + // Type checks + EP_HOST_ASSERT(is_token_in_rank.dtype() == turbomind::kBool); + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(cached_channel_prefix_matrix->dtype() == turbomind::kInt32); + } + else { + EP_HOST_ASSERT(num_tokens_per_expert->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(num_tokens_per_rank->dtype() == turbomind::kInt32); + } + + // Shape and contiguous checks + EP_HOST_ASSERT(x.ndim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT((x.shape(1) * byte_size(x.dtype())) % sizeof(int4) == 0); + EP_HOST_ASSERT(is_token_in_rank.ndim() == 2 and is_token_in_rank.is_contiguous()); + EP_HOST_ASSERT(is_token_in_rank.shape(0) == x.shape(0) and is_token_in_rank.shape(1) == num_ranks); + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix->ndim() == 2 and cached_rank_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_rank_prefix_matrix->shape(0) == num_ranks + and cached_rank_prefix_matrix->shape(1) == num_ranks); + EP_HOST_ASSERT(cached_channel_prefix_matrix->ndim() == 2 and cached_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_channel_prefix_matrix->shape(0) == num_ranks + and cached_channel_prefix_matrix->shape(1) == num_channels); + } + else { + EP_HOST_ASSERT(num_tokens_per_expert->ndim() == 1 and num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->shape(0) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->shape(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); + EP_HOST_ASSERT(num_tokens_per_rank->ndim() == 1 and num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->shape(0) == num_ranks); + } + + auto num_tokens = static_cast(x.shape(0)); + auto hidden = static_cast(x.shape(1)); + auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->shape(0)); + auto num_local_experts = num_experts / num_ranks; + + // Top-k checks + int num_topk = 0; + const topk_idx_t* topk_idx_ptr = nullptr; + const float* topk_weights_ptr = nullptr; + EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->shape(1)); + EP_HOST_ASSERT(num_experts > 0); + EP_HOST_ASSERT(topk_idx->ndim() == 2 and topk_idx->is_contiguous()); + EP_HOST_ASSERT(topk_weights->ndim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->shape(0) and num_tokens == topk_weights->shape(0)); + EP_HOST_ASSERT(num_topk == topk_weights->shape(1)); + EP_HOST_ASSERT(topk_weights->dtype() == turbomind::kFloat32); + topk_idx_ptr = topk_idx->data_or((topk_idx_t*)nullptr); + topk_weights_ptr = topk_weights->data_or((float*)nullptr); + } + + // FP8 scales checks + const float* x_scales_ptr = nullptr; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(byte_size(x.dtype()) == 1); + EP_HOST_ASSERT(x_scales->dtype() == turbomind::kFloat32 or x_scales->dtype() == turbomind::kInt32); + EP_HOST_ASSERT(x_scales->ndim() == 2); + EP_HOST_ASSERT(x_scales->shape(0) == num_tokens); + num_scales = x_scales->ndim() == 1 ? 
1 : static_cast(x_scales->shape(1)); + x_scales_ptr = x_scales->data_or((float*)nullptr); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); + } + + // Create handles (only return for non-cached mode) + int num_recv_tokens = -1; + auto rank_prefix_matrix = Tensor(); + auto channel_prefix_matrix = Tensor(); + std::vector num_recv_tokens_per_expert_list; + + // used to compute offsets in MoeFfnLayer + auto moe_recv_expert_counter_ten = Tensor({num_local_experts}, turbomind::kInt32, turbomind::kDEVICE); + + // Barrier or send sizes + // To clean: channel start/end offset, head and tail + int num_memset_int = num_channels * num_ranks * 4; + if (cached_mode) { + EP_HOST_ASSERT(0); + // num_recv_tokens = cached_num_recv_tokens; + // rank_prefix_matrix = cached_rank_prefix_matrix.value(); + // channel_prefix_matrix = cached_channel_prefix_matrix.value(); + + // // Copy rank prefix matrix and clean flags + // intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr(), + // num_memset_int, + // buffer_ptrs_gpu, + // barrier_signal_ptrs_gpu, + // rank, + // num_ranks, + // comm_stream); + } + else { + rank_prefix_matrix = Tensor({num_ranks, num_ranks}, turbomind::kInt32, turbomind::kDEVICE); + channel_prefix_matrix = Tensor({num_ranks, num_channels}, turbomind::kInt32, turbomind::kDEVICE); + + // Send sizes + // Meta information: + // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` + // - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]` + // NOTES: no more token dropping in this version + *moe_recv_counter = -1; + for (int i = 0; i < num_local_experts; ++i) + moe_recv_expert_counter[i] = -1; + EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes); + intranode::notify_dispatch(num_tokens_per_rank->data(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_expert->data(), + moe_recv_expert_counter_mapped, + moe_recv_expert_counter_ten.data(), + num_experts, + num_tokens, + is_token_in_rank.data_or((bool*)nullptr), // num_tokens may be zero + channel_prefix_matrix.data(), + rank_prefix_matrix.data(), + num_memset_int, + expert_alignment, + buffer_ptrs_gpu, + barrier_signal_ptrs_gpu, + rank, + turbomind::core::Context::stream().handle(), + num_channels); + + if (num_worst_tokens > 0) { + // No CPU sync, just allocate the worst case + num_recv_tokens = num_worst_tokens; + + // Must be forward with top-k stuffs + EP_HOST_ASSERT(topk_idx.has_value()); + EP_HOST_ASSERT(topk_weights.has_value()); + } + else { + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) + ready &= moe_recv_expert_counter[i] >= 0; + + if (ready) + break; + + // Timeout check + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() + - start_time) + .count() + > NUM_CPU_TIMEOUT_SECS) + throw std::runtime_error("DeepEP error: CPU recv timeout"); + } + num_recv_tokens_per_expert_list = + std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); + } + } + + // Allocate new tensors + auto recv_x = Tensor({num_recv_tokens, hidden}, x.dtype(), turbomind::kDEVICE); + auto recv_src_idx = Tensor({num_recv_tokens}, turbomind::kInt32, turbomind::kDEVICE); + auto 
recv_topk_idx = std::optional(); + auto recv_topk_weights = std::optional(); + auto recv_x_scales = std::optional(); + auto recv_channel_prefix_matrix = Tensor({num_ranks, num_channels}, turbomind::kInt32, turbomind::kDEVICE); + auto send_head = Tensor({num_tokens, num_ranks}, turbomind::kInt32, turbomind::kDEVICE); + + // Assign pointers + topk_idx_t* recv_topk_idx_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + float* recv_x_scales_ptr = nullptr; + if (topk_idx.has_value()) { + recv_topk_idx = Tensor({num_recv_tokens, num_topk}, topk_idx->dtype(), topk_idx->device()); + recv_topk_weights = Tensor({num_recv_tokens, num_topk}, topk_weights->dtype(), topk_weights->device()); + recv_topk_idx_ptr = recv_topk_idx->data_or((topk_idx_t*)nullptr); + recv_topk_weights_ptr = recv_topk_weights->data_or((float*)nullptr); + } + if (x_scales.has_value()) { + recv_x_scales = x_scales->ndim() == 1 ? + Tensor({num_recv_tokens}, x_scales->dtype(), x_scales->device()) : + Tensor({num_recv_tokens, num_scales}, x_scales->dtype(), x_scales->device()); + recv_x_scales_ptr = recv_x_scales->data_or((float*)nullptr); + } + + // Dispatch + EP_HOST_ASSERT( + num_ranks * num_ranks * sizeof(int) + // Size prefix matrix + num_channels * num_ranks * sizeof(int) + // Channel start offset + num_channels * num_ranks * sizeof(int) + // Channel end offset + num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * byte_size(recv_x.dtype()) + + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) + + // Top-k index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + + // Top-k weight buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) + * num_scales // FP8 scale buffer + <= num_nvl_bytes); + + intranode::dispatch(recv_x.data_or((void*)nullptr), + recv_x_scales_ptr, + recv_src_idx.data_or((int*)nullptr), + recv_topk_idx_ptr, + recv_topk_weights_ptr, + recv_channel_prefix_matrix.data(), + send_head.data_or((int*)nullptr), + x.data_or((void*)nullptr), + x_scales_ptr, + topk_idx_ptr, + topk_weights_ptr, + is_token_in_rank.data_or((bool*)nullptr), + channel_prefix_matrix.data(), + num_tokens, + num_worst_tokens, + static_cast(hidden * byte_size(recv_x.dtype()) / sizeof(int4)), + num_topk, + num_experts, + num_scales, + scale_token_stride, + scale_hidden_stride, + buffer_ptrs_gpu, + rank, + num_ranks, + turbomind::core::Context::stream().handle(), + config.num_sms, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens); + + // Return values + return {recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + moe_recv_expert_counter_ten, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + recv_src_idx, + send_head}; +} + +std::tuple> +Buffer::intranode_combine(const Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const Tensor& src_idx, + const Tensor& rank_prefix_matrix, + const Tensor& channel_prefix_matrix, + Tensor& send_head, + const Config& config) +{ + EP_HOST_ASSERT(x.ndim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT(src_idx.ndim() == 1 and src_idx.is_contiguous() and src_idx.dtype() == turbomind::kInt32); + 
EP_HOST_ASSERT(send_head.ndim() == 2 and send_head.is_contiguous() and send_head.dtype() == turbomind::kInt32); + EP_HOST_ASSERT(rank_prefix_matrix.ndim() == 2 and rank_prefix_matrix.is_contiguous() + and rank_prefix_matrix.dtype() == turbomind::kInt32); + EP_HOST_ASSERT(channel_prefix_matrix.ndim() == 2 and channel_prefix_matrix.is_contiguous() + and channel_prefix_matrix.dtype() == turbomind::kInt32); + + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. + EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + + auto num_tokens = static_cast(x.shape(0)), hidden = static_cast(x.shape(1)); + auto num_recv_tokens = static_cast(send_head.shape(0)); + EP_HOST_ASSERT(src_idx.shape(0) == num_tokens); + EP_HOST_ASSERT(send_head.shape(1) == num_ranks); + EP_HOST_ASSERT(rank_prefix_matrix.shape(0) == num_ranks and rank_prefix_matrix.shape(1) == num_ranks); + EP_HOST_ASSERT(channel_prefix_matrix.shape(0) == num_ranks and channel_prefix_matrix.shape(1) == num_channels); + EP_HOST_ASSERT((hidden * byte_size(x.dtype())) % sizeof(int4) == 0); + + int num_topk = 0; + auto recv_topk_weights = std::optional(); + const float* topk_weights_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + if (topk_weights.has_value()) { + EP_HOST_ASSERT(topk_weights->ndim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(topk_weights->shape(0) == num_tokens); + EP_HOST_ASSERT(topk_weights->dtype() == turbomind::kFloat32); + num_topk = static_cast(topk_weights->shape(1)); + topk_weights_ptr = topk_weights->data_or((float*)nullptr); + recv_topk_weights = Tensor({num_recv_tokens, num_topk}, turbomind::kFloat32, turbomind::kDEVICE); + recv_topk_weights_ptr = recv_topk_weights->data_or((float*)nullptr); + } + + // Launch barrier and reset queue head and tail + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes); + intranode::cached_notify_combine(buffer_ptrs_gpu, + send_head.data_or((int*)nullptr), + num_channels, + num_recv_tokens, + num_channels * num_ranks * 2, + barrier_signal_ptrs_gpu, + rank, + num_ranks, + turbomind::core::Context::stream().handle()); + + // Assign bias pointers + auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++i) + if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.ndim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.dtype() == x.dtype()); + EP_HOST_ASSERT(bias.shape(0) == num_recv_tokens and bias.shape(1) == hidden); + bias_ptrs[i] = bias.data_or((void*)nullptr); + } + + // Combine data + auto recv_x = Tensor({num_recv_tokens, hidden}, x.dtype(), turbomind::kDEVICE); + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * byte_size(x.dtype()) + + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk + * sizeof(float) // Top-k weight buffer + <= num_nvl_bytes); + intranode::combine(CUDA_R_16BF, + recv_x.data_or((void*)nullptr), + recv_topk_weights_ptr, + x.data_or((void*)nullptr), + topk_weights_ptr, + bias_ptrs[0], + bias_ptrs[1], + src_idx.data_or((int*)nullptr), + rank_prefix_matrix.data(), + channel_prefix_matrix.data(), + send_head.data_or((int*)nullptr), + num_tokens, + num_recv_tokens, + hidden, + num_topk, + 
buffer_ptrs_gpu, + rank, + num_ranks, + turbomind::core::Context::stream().handle(), + config.num_sms, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens); + + return {recv_x, std::nullopt}; +} + +std::tuple, + Tensor, + Tensor, + Tensor> +Buffer::low_latency_dispatch(const Tensor& x, + const Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + const std::optional& dispatch_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_fp8, + bool round_scale, + bool use_ue8m0) +{ + // Tensor checks + // By default using `ptp128c` FP8 cast + EP_HOST_ASSERT(x.ndim() == 2 and x.is_contiguous() and x.dtype() == turbomind::kBfloat16); + EP_HOST_ASSERT(x.shape(1) % sizeof(int4) == 0 and x.shape(1) % 128 == 0); + EP_HOST_ASSERT(topk_idx.ndim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(x.shape(0) == topk_idx.shape(0) and x.shape(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_idx.dtype() == turbomind::kInt64); + EP_HOST_ASSERT(num_experts % num_ranks == 0); + + // Diagnosis tensors + EP_HOST_ASSERT(not cumulative_local_expert_recv_stats.has_value()); + EP_HOST_ASSERT(not dispatch_wait_recv_cost_stats.has_value()); + // if (cumulative_local_expert_recv_stats.has_value()) { + // EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dtype() == turbomind::kInt32); + // EP_HOST_ASSERT(cumulative_local_expert_recv_stats->ndim() == 1 + // and cumulative_local_expert_recv_stats->is_contiguous()); + // EP_HOST_ASSERT(cumulative_local_expert_recv_stats->shape(0) == num_experts / num_ranks); + // } + // if (dispatch_wait_recv_cost_stats.has_value()) { + // EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dtype() == turbomind::kInt64); + // EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->ndim() == 1 and + // dispatch_wait_recv_cost_stats->is_contiguous()); EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->shape(0) == + // num_ranks); + // } + + auto num_tokens = static_cast(x.shape(0)); + auto hidden = static_cast(x.shape(1)); + auto num_topk = static_cast(topk_idx.shape(1)); + auto num_local_experts = num_experts / num_ranks; + + // Buffer control + LowLatencyLayout layout(rdma_ll_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_ll_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + + // Allocate packed tensors + auto packed_recv_x = Tensor( + {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.dtype(), turbomind::kDEVICE); + auto packed_recv_src_info = Tensor( + {num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, turbomind::kInt32, turbomind::kDEVICE); + auto packed_recv_layout_range = Tensor({num_local_experts, num_ranks}, turbomind::kInt64, turbomind::kDEVICE); + auto packed_recv_count = Tensor({num_local_experts}, turbomind::kInt32, turbomind::kDEVICE); + + // Allocate column-majored scales + auto packed_recv_x_scales = std::optional(); + void* packed_recv_x_scales_ptr = nullptr; + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 + and "TMA requires the number of tokens to be multiple of 4"); + + if (use_fp8) { + // TODO: support unaligned cases + EP_HOST_ASSERT(hidden % 512 == 0); + if (not use_ue8m0) { + packed_recv_x_scales = + Tensor({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank}, + turbomind::kFloat32, + turbomind::kDEVICE); + } + else { + 
EP_HOST_ASSERT(round_scale); + packed_recv_x_scales = + Tensor({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank}, + turbomind::kInt32, + turbomind::kDEVICE); + } + packed_recv_x_scales = packed_recv_x_scales->transpose(1, 2); + packed_recv_x_scales_ptr = packed_recv_x_scales->data_or((float*)nullptr); + } + + // Kernel launch + auto next_clean_meta = next_buffer.clean_meta(); + const int phases = LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE; + auto dev_comm = comm->get_device_communicator(true); + auto nccl_win = comm->get_device_nccl_window(rdma_ll_buffer_ptr); + auto signals_base = comm->get_signals_base(low_latency_buffer_idx, true); + + internode_ll::dispatch( + packed_recv_x.raw_data(), + packed_recv_x_scales_ptr, + packed_recv_src_info.data(), + packed_recv_layout_range.data(), + packed_recv_count.data(), + mask_buffer_ptr, + nullptr, + nullptr, + buffer.dispatch_rdma_recv_data_buffer, + buffer.dispatch_rdma_recv_count_buffer, + buffer.dispatch_rdma_send_buffer, + reinterpret_cast(buffer.dispatch_rdma_recv_data_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + reinterpret_cast(buffer.dispatch_rdma_recv_count_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + reinterpret_cast(buffer.dispatch_rdma_send_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + x.raw_data(), + topk_idx.data(), + next_clean_meta.first, + next_clean_meta.second, + num_tokens, + hidden, + num_max_dispatch_tokens_per_rank, + num_topk, + num_experts, + rank, + num_ranks, + use_fp8, + round_scale, + use_ue8m0, + workspace, + num_device_sms, + nccl_win, + dev_comm, + signals_base, + turbomind::core::Context::stream().handle(), + phases); + + return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range}; +} + +std::tuple // +Buffer::low_latency_combine(const Tensor& x, + const Tensor& topk_idx, + const Tensor& topk_weights, + const Tensor& src_info, + const Tensor& layout_range, + const std::optional& combine_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_logfmt, + bool zero_copy, + const std::optional& out) +{ + EP_HOST_ASSERT(low_latency_mode); + + // Tensor checks + EP_HOST_ASSERT(x.ndim() == 3 and x.is_contiguous() and x.dtype() == turbomind::kBfloat16); + EP_HOST_ASSERT(x.shape(0) == num_experts / num_ranks); + EP_HOST_ASSERT(x.shape(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(x.shape(2) % sizeof(int4) == 0 and x.shape(2) % 128 == 0); + EP_HOST_ASSERT(topk_idx.ndim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(topk_idx.shape(0) == topk_weights.shape(0) and topk_idx.shape(1) == topk_weights.shape(1)); + EP_HOST_ASSERT(topk_idx.dtype() == turbomind::kInt64); + EP_HOST_ASSERT(topk_weights.ndim() == 2 and topk_weights.is_contiguous()); + EP_HOST_ASSERT(topk_weights.shape(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_weights.dtype() == turbomind::kFloat32); + EP_HOST_ASSERT(src_info.ndim() == 2 and src_info.is_contiguous()); + EP_HOST_ASSERT(src_info.dtype() == turbomind::kInt32 and x.shape(0) == src_info.shape(0)); + EP_HOST_ASSERT(layout_range.ndim() == 2 and layout_range.is_contiguous()); + EP_HOST_ASSERT(layout_range.dtype() == turbomind::kInt64); + EP_HOST_ASSERT(layout_range.shape(0) == num_experts / num_ranks and layout_range.shape(1) == num_ranks); + + EP_HOST_ASSERT(not combine_wait_recv_cost_stats.has_value()); + // if (combine_wait_recv_cost_stats.has_value()) { + // EP_HOST_ASSERT(combine_wait_recv_cost_stats->dtype() == 
turbomind::kInt64); + // EP_HOST_ASSERT(combine_wait_recv_cost_stats->ndim() == 1 and combine_wait_recv_cost_stats->is_contiguous()); + // EP_HOST_ASSERT(combine_wait_recv_cost_stats->shape(0) == num_ranks); + // } + + auto hidden = static_cast(x.shape(2)); + auto num_topk = static_cast(topk_weights.shape(1)); + auto num_combined_tokens = static_cast(topk_weights.shape(0)); + + // Buffer control + LowLatencyLayout layout(rdma_ll_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_ll_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + + // Allocate output tensor + Tensor combined_x; + if (out.has_value()) { + EP_HOST_ASSERT(out->ndim() == 2 and out->is_contiguous()); + EP_HOST_ASSERT(out->shape(0) == num_combined_tokens and out->shape(1) == hidden); + EP_HOST_ASSERT(out->dtype() == x.dtype()); + combined_x = out.value(); + } + else { + combined_x = Tensor({num_combined_tokens, hidden}, x.dtype(), turbomind::kDEVICE); + } + + // Kernel launch + auto next_clean_meta = next_buffer.clean_meta(); + const int phases = LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE; + auto dev_comm = comm->get_device_communicator(true); + auto nccl_win = comm->get_device_nccl_window(rdma_ll_buffer_ptr); + auto signals_base = comm->get_signals_base(low_latency_buffer_idx, true); + + internode_ll::combine( + combined_x.data_or((void*)nullptr), + buffer.combine_rdma_recv_data_buffer, + buffer.combine_rdma_recv_flag_buffer, + buffer.combine_rdma_send_buffer, + reinterpret_cast(buffer.combine_rdma_recv_data_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + reinterpret_cast(buffer.combine_rdma_recv_flag_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + reinterpret_cast(buffer.combine_rdma_send_buffer) - reinterpret_cast(rdma_ll_buffer_ptr), + x.raw_data(), + topk_idx.data_or((topk_idx_t*)nullptr), + topk_weights.data_or((float*)nullptr), + src_info.data(), + layout_range.data(), + mask_buffer_ptr, + nullptr, + next_clean_meta.first, + next_clean_meta.second, + num_combined_tokens, + hidden, + num_max_dispatch_tokens_per_rank, + num_topk, + num_experts, + rank, + num_ranks, + use_logfmt, + workspace, + num_device_sms, + nccl_win, + dev_comm, + signals_base, + turbomind::core::Context::stream().handle(), + phases, + zero_copy); + + return {combined_x}; +} + +Config Buffer::get_dispatch_config() +{ + static std::unordered_map config_map = { + {2, Config(num_sms, 24, 256, 6, 128)}, + {4, Config(num_sms, 6, 256, 6, 128)}, + {8, Config(num_sms, 6, 256, 6, 128)}, + {16, Config(num_sms, 36, 288, 20, 128)}, + {24, Config(num_sms, 32, 288, 8, 128)}, + {32, Config(num_sms, 32, 288, 8, 128)}, + {48, Config(num_sms, 32, 288, 8, 128)}, + {64, Config(num_sms, 32, 288, 8, 128)}, + {96, Config(num_sms, 20, 480, 12, 128)}, + {128, Config(num_sms, 20, 560, 12, 128)}, + {144, Config(num_sms, 32, 720, 12, 128)}, + {160, Config(num_sms, 28, 720, 12, 128)}, + }; + const auto it = config_map.find(num_ranks); + TM_CHECK(it != config_map.end()); + return it->second; +} + +Config Buffer::get_combine_config() +{ + static std::unordered_map config_map = { + {2, Config(num_sms, 10, 256, 6, 128)}, + {4, Config(num_sms, 9, 256, 6, 128)}, + {8, Config(num_sms, 4, 256, 6, 128)}, + {16, Config(num_sms, 4, 288, 12, 128)}, + {24, Config(num_sms, 1, 288, 8, 128)}, + {32, Config(num_sms, 1, 288, 8, 128)}, + {48, Config(num_sms, 1, 288, 8, 128)}, + {64, Config(num_sms, 1, 288, 8, 128)}, + {96, 
Config(num_sms, 1, 480, 8, 128)}, + {128, Config(num_sms, 1, 560, 8, 128)}, + {144, Config(num_sms, 2, 720, 8, 128)}, + {160, Config(num_sms, 2, 720, 8, 128)}, + }; + const auto it = config_map.find(num_ranks); + TM_CHECK(it != config_map.end()); + return it->second; +} + +}; // namespace deep_ep diff --git a/src/turbomind/comm/nccl/deep_ep/deep_ep.hpp b/src/turbomind/comm/nccl/deep_ep/deep_ep.hpp new file mode 100644 index 0000000000..2015030336 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/deep_ep.hpp @@ -0,0 +1,225 @@ +#pragma once + +#include "config.hpp" + +#include "gin_backend.h" +#include "kernels/configs.cuh" +#include "kernels/exception.cuh" +#include "src/turbomind/comm/device_comm.h" +#include "src/turbomind/comm/host_comm.h" +#include "src/turbomind/core/tensor.h" + +#include + +#include +#include +#include + +using turbomind::comm::HostComm; +using turbomind::comm::DeviceComm; +using turbomind::core::Tensor; +using turbomind::core::Tensor_; +using turbomind::core::Buffer_; + +namespace shared_memory { + +union MemHandleInner { + cudaIpcMemHandle_t cuda_ipc_mem_handle; + CUmemFabricHandle cu_mem_fabric_handle; +}; + +struct MemHandle { + MemHandleInner inner; + size_t size; +}; + +constexpr size_t HANDLE_SIZE = sizeof(MemHandle); + +class SharedMemoryAllocator { +public: + SharedMemoryAllocator(bool use_fabric); + void malloc(void** ptr, size_t size); + void free(void* ptr); + void get_mem_handle(MemHandle* mem_handle, void* ptr); + void open_mem_handle(void** ptr, MemHandle* mem_handle); + void close_mem_handle(void* ptr); + +private: + bool use_fabric; +}; +} // namespace shared_memory + +namespace deep_ep { + +class Buffer { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8"); + +public: + HostComm h_comm; + DeviceComm ipc_comm; + int num_sms{24}; + + std::shared_ptr comm; + + // Low-latency mode buffer + int low_latency_buffer_idx = 0; + bool low_latency_mode = false; + + // NVLink Buffer + int64_t num_nvl_bytes; + void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + void** buffer_ptrs_gpu = nullptr; + + // NVSHMEM Buffer + int64_t num_rdma_bytes; + int64_t num_ll_rdma_bytes; + void* rdma_buffer_ptr = nullptr; + void* rdma_ll_buffer_ptr = nullptr; + + // Shrink mode buffer + bool enable_shrink = false; + int* mask_buffer_ptr = nullptr; + int* sync_buffer_ptr = nullptr; + + // Device info and communication + int device_id; + int num_device_sms; + int rank, rdma_rank, nvl_rank; + int num_ranks, num_rdma_ranks, num_nvl_ranks; + int qps_per_rank; + shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS]; + + // After IPC/NVSHMEM synchronization, this flag will be true + bool available = false; + + // After `destroy()` be called, this flag will be true + bool destroyed = false; + + // Barrier signals + int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + int** barrier_signal_ptrs_gpu = nullptr; + + // Workspace + void* workspace = nullptr; + + // Host-side MoE info + volatile int* moe_recv_counter = nullptr; + int* moe_recv_counter_mapped = nullptr; + + // Host-side expert-level MoE info + volatile int* moe_recv_expert_counter = nullptr; + int* moe_recv_expert_counter_mapped = nullptr; + + // Host-side RDMA-level MoE info + volatile int* moe_recv_rdma_counter = nullptr; + int* moe_recv_rdma_counter_mapped = nullptr; + + shared_memory::SharedMemoryAllocator shared_memory_allocator; + + Buffer(int rank, // + int num_ranks, + int64_t num_nvl_bytes, + int64_t num_rdma_bytes, + int64_t num_ll_rdma_bytes, + bool low_latency_mode, + 
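// --- Illustration (editor's sketch, not part of the patch): both low-latency calls above
// use the same double-buffering scheme: consume `layout.buffers[low_latency_buffer_idx]`,
// flip the index, and let the kernel clean the newly selected buffer's metadata
// (`next_clean_meta`) so the next dispatch/combine can start without an extra sync.
// The struct below is illustrative only.
struct LowLatencyPingPong {
    int idx = 0;
    int current() const { return idx; }      // buffer used by this dispatch/combine
    int flip_to_next() { return idx ^= 1; }  // buffer whose metadata gets pre-cleaned
};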
bool enable_shrink, + bool use_fabric, + int qps_per_rank, + HostComm h_comm); + + Buffer(): shared_memory_allocator{false} {}; + + ~Buffer() = default; + + void allocate_sync_nvl_buffer(); + + void allocate_rdma_buffer(); + + bool is_available() const; + + bool is_internode_available() const; + + int get_num_rdma_ranks() const; + + int get_rdma_rank() const; + + int get_root_rdma_rank(bool global) const; + + int get_local_device_id() const; + + void destroy(); + + std::tuple, Tensor, Tensor> // + get_dispatch_layout(const Tensor& topk_idx, int num_experts); + + std::tuple, + std::optional, + std::optional, + std::vector, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor, + Tensor> + intranode_dispatch(const Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, + int expert_alignment, + int num_worst_tokens, + const Config& config); + + std::tuple> + intranode_combine(const Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const Tensor& src_idx, + const Tensor& rank_prefix_matrix, + const Tensor& channel_prefix_matrix, + Tensor& send_head, + const Config& config); + + std::tuple, + Tensor, + Tensor, + Tensor> + low_latency_dispatch(const Tensor& x, + const Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + const std::optional& dispatch_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_fp8, + bool round_scale, + bool use_ue8m0); + + std::tuple // + low_latency_combine(const Tensor& x, + const Tensor& topk_idx, + const Tensor& topk_weights, + const Tensor& src_info, + const Tensor& layout_range, + const std::optional& combine_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_logfmt, + bool zero_copy, + const std::optional& out = std::nullopt); + + Config get_dispatch_config(); + + Config get_combine_config(); +}; + +}; // namespace deep_ep diff --git a/src/turbomind/comm/nccl/deep_ep/gin_backend.cu b/src/turbomind/comm/nccl/deep_ep/gin_backend.cu new file mode 100644 index 0000000000..4d7834c5b0 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/gin_backend.cu @@ -0,0 +1,244 @@ +#include "src/turbomind/comm/nccl/deep_ep/gin_backend.h" + +#include "src/turbomind/comm/nccl/deep_ep/kernels/configs.cuh" +#include "src/turbomind/comm/nccl/deep_ep/kernels/exception.cuh" +#include "src/turbomind/core/check.h" +#include "src/turbomind/core/context.h" +#include "src/turbomind/utils/logger.h" + +#include + +namespace deep_ep { +namespace internode { + +NCCLGINBackend::~NCCLGINBackend() +{ + if (initialized_) { + finalize(); + } +} + +int NCCLGINBackend::init( + const std::vector& root_unique_id_val, int rank, int num_ranks, bool low_latency_mode, int qps_per_rank) +{ + if (initialized_) { + return rank_; + } + TM_CHECK_EQ(low_latency_mode, true); // compatible with low latency mode + + // Check if P2P/NVLink is disabled via environment variable + const char* nccl_disable_p2p = std::getenv("NCCL_P2P_DISABLE"); + p2p_disabled_ = (nccl_disable_p2p != nullptr && std::string(nccl_disable_p2p) == "1"); + + // Determine communication topology based on mode + const int gpus_per_server = NUM_MAX_NVL_PEERS; + int comm_rank; // 
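// --- Illustration (editor's sketch, not part of the patch): intended call order for the
// low-latency API declared in deep_ep.hpp above. The function name, the elided expert
// computation and the placeholder `expert_out` are illustrative; the two Buffer calls use
// the signatures as declared. Assumes deep_ep.hpp is included and C++17 structured bindings.
static Tensor moe_forward_sketch(deep_ep::Buffer& buffer,
                                 const Tensor&    x,
                                 const Tensor&    topk_idx,
                                 const Tensor&    topk_weights,
                                 int              num_max_dispatch_tokens_per_rank,
                                 int              num_experts)
{
    // 1. Scatter each token to the ranks owning its top-k experts (no FP8 in this sketch).
    auto [recv_x, recv_scales, recv_count, src_info, layout_range] =
        buffer.low_latency_dispatch(x, topk_idx, std::nullopt, std::nullopt,
                                    num_max_dispatch_tokens_per_rank, num_experts,
                                    /*use_fp8=*/false, /*round_scale=*/false, /*use_ue8m0=*/false);

    // 2. Run the local experts on `recv_x` / `recv_count` here (grouped GEMM, omitted).
    Tensor expert_out = recv_x;  // placeholder for the expert outputs

    // 3. Gather the expert outputs back to the token owners, weighted by `topk_weights`.
    auto [combined_x] = buffer.low_latency_combine(expert_out, topk_idx, topk_weights,
                                                   src_info, layout_range, std::nullopt,
                                                   num_max_dispatch_tokens_per_rank, num_experts,
                                                   /*use_logfmt=*/false, /*zero_copy=*/false);
    return combined_x;
}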
Rank to use for NCCL initialization + int comm_nranks; // Number of ranks in communicator + int color = -1; // Symmetric group ID (only for high throughput mode) + int group_rank = -1; // Rank within symmetric group + + if (low_latency_mode) { + // LOW LATENCY MODE: Connect to all ranks + comm_rank = rank; + comm_nranks = num_ranks; + } + else { + // HIGH THROUGHPUT MODE: Connect only to symmetric RDMA ranks + color = rank % gpus_per_server; + group_rank = rank / gpus_per_server; + comm_nranks = (num_ranks + gpus_per_server - 1) / gpus_per_server; + comm_rank = group_rank; + } + + size_t single_id_size = sizeof(ncclUniqueId); + size_t expected_ids = gpus_per_server; + EP_HOST_ASSERT(root_unique_id_val.size() == expected_ids * single_id_size + && "Number of unique IDs doesn't match NUM_MAX_NVL_PEERS * qps_per_rank"); + + if (rank == 0) { + // Print NCCL version from the actually loaded library + int nccl_version; + NCCL_CHECK(ncclGetVersion(&nccl_version)); + TM_LOG_DEBUG("[NCCLEP] NCCL version: %d.%d.%d (loaded library)", + nccl_version / 10000, + (nccl_version % 10000) / 100, + nccl_version % 100); + } + + // All gpus form a group for low latency compatible, + // otherwise, gpus with the same index across different nodes form a group. + ncclUniqueId id; + const int id_offset = (low_latency_mode) ? 0 : color * single_id_size; + std::memcpy(&id, root_unique_id_val.data() + id_offset, single_id_size); + NCCL_CHECK(ncclCommInitRank(&nccl_comm_, comm_nranks, id, comm_rank)); + + // The assumption is that kDecoupled is false when initializing SymBuffers in internode.cu + // IMPORTANT: Use global num_ranks, not comm_nranks, because kernels use global topology + const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); + int rdma_channel_head_signals = num_rdma_ranks * DEEP_EP_NCCL_MAX_NUM_CHANNELS; + int rdma_channel_tail_signals = num_rdma_ranks * DEEP_EP_NCCL_MAX_NUM_CHANNELS; + // + num_ht_signals_ = rdma_channel_head_signals + rdma_channel_tail_signals; + num_ll_signals_ = qps_per_rank * comm_nranks * 2; + + // Initialize Device Communicators + auto CreateDevComm = [&](ncclDevComm_t& comm, int signals) { + ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER; + reqs.barrierCount = MAX_BARRIER_SESSIONS; + reqs.ginSignalCount = signals + MAX_BARRIER_SESSIONS; + reqs.ginConnectionType = NCCL_GIN_CONNECTION_FULL; + reqs.ginContextCount = qps_per_rank; + NCCL_CHECK(ncclDevCommCreate(nccl_comm_, &reqs, &comm)); + }; + CreateDevComm(dev_ll_comm_, num_ll_signals_); // low latency mode + CreateDevComm(dev_ht_comm_, num_ht_signals_); // high throughput mode + + // Allocate barrier dummy variable + CUDA_CHECK(cudaMalloc(reinterpret_cast(&d_barrier_var_), sizeof(int))); + CUDA_CHECK(cudaMemset(d_barrier_var_, 0, sizeof(int))); + + // Store global rank and num_ranks (for external API) + rank_ = rank; + num_ranks_ = num_ranks; + + // Store communicator-specific ranks for internal use + comm_rank_ = comm_rank; + comm_nranks_ = comm_nranks; + + initialized_ = true; + TM_LOG_DEBUG( + "[NCCLEP] Initialized global rank %d/%d (comm rank %d/%d)", rank_, num_ranks_, comm_rank_, comm_nranks_); + + return rank_; +} + +void NCCLGINBackend::finalize() +{ + TM_LOG_DEBUG("[NCCLEP][%d] Finalizing", rank_); + if (!initialized_) { + return; + } + + // Destroy device communicators + auto DestroyDevComm = [&](ncclDevComm_t& comm, std::string_view key) { + ncclResult_t res = ncclDevCommDestroy(nccl_comm_, &comm); + if (res != ncclSuccess) { + TM_LOG_ERROR("[NCCLEP][%d] Failed to destroy device 
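// --- Illustration (editor's sketch, not part of the patch): worked example of the
// topology math above for high-throughput mode, where GPUs with the same local index on
// every node form one sub-communicator. Numbers (rank 13 of 32, 8 GPUs per node) are
// illustrative only.
constexpr int kGpusPerServerExample = 8;  // NUM_MAX_NVL_PEERS
static_assert(13 % kGpusPerServerExample == 5, "color: which sub-communicator rank 13 joins");
static_assert(13 / kGpusPerServerExample == 1, "comm_rank: node index inside that group");
static_assert((32 + kGpusPerServerExample - 1) / kGpusPerServerExample == 4,
              "comm_nranks: one rank per node in the group");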
communication %s: %s", + rank_, + key.data(), + ncclGetErrorString(res)); + } + }; + DestroyDevComm(dev_ll_comm_, "low latency mode"); + DestroyDevComm(dev_ht_comm_, "high throughput mode"); + + for (auto& [ptr, win] : wins_) { + TM_LOG_WARNING("[NCCLEP][%d] Memory %p is not deregistered", rank_, ptr); + } + for (auto& [ptr, size] : buffers_) { + TM_LOG_WARNING("[NCCLEP][%d] Allocation (%p, %lu) is not freed", rank_, ptr, size); + } + + // Free barrier dummy variable + if (d_barrier_var_ != nullptr) { + cudaFree(d_barrier_var_); + d_barrier_var_ = nullptr; + } + // Destroy all communicators + ncclCommFinalize(nccl_comm_); + ncclCommDestroy(nccl_comm_); + + TM_LOG_DEBUG("[NCCLEP][%d] Destroyed NCCL communicator", rank_); + initialized_ = false; +} + +void NCCLGINBackend::barrier() +{ + TM_CHECK_EQ(initialized_, true); + TM_CHECK_NE(d_barrier_var_, nullptr); + + cudaStream_t stream = turbomind::core::Context::stream().handle(); + NCCL_CHECK(ncclGroupStart()); + NCCL_CHECK(ncclAllReduce(d_barrier_var_, d_barrier_var_, 1, ncclInt, ncclSum, nccl_comm_, stream)); + NCCL_CHECK(ncclGroupEnd()); +} + +void* NCCLGINBackend::alloc(size_t size, size_t /*alignment*/) +{ + TM_CHECK_EQ(initialized_, true); + + void* ptr = nullptr; + // NCCL memory is already aligned to page size, so alignment parameter is ignored for now. + NCCL_CHECK(ncclMemAlloc(&ptr, size)); + buffers_.emplace(ptr, size); + return ptr; +} + +void NCCLGINBackend::register_memory(void* ptr, size_t size) +{ + TM_CHECK_EQ(initialized_, true); + TM_CHECK_EQ(buffers_.find(ptr) != buffers_.end(), true); + TM_CHECK_EQ(wins_.find(ptr) == wins_.end(), true); + ncclWindow_t win{}; + NCCL_CHECK(ncclCommWindowRegister(nccl_comm_, ptr, size, &win, 0)); + wins_.emplace(ptr, win); +} + +void NCCLGINBackend::free(void* ptr) +{ + TM_CHECK_EQ(initialized_, true); + auto it = wins_.find(ptr); + TM_CHECK_EQ(it != wins_.end(), true); + NCCL_CHECK(ncclCommWindowDeregister(nccl_comm_, it->second)); + NCCL_CHECK(ncclMemFree(ptr)); + wins_.erase(it); + buffers_.erase(ptr); +} + +int NCCLGINBackend::get_rank() const +{ + TM_CHECK_NE(rank_, -1); + return rank_; +} + +int NCCLGINBackend::get_num_ranks() const +{ + TM_CHECK_NE(num_ranks_, -1); + return num_ranks_; +} + +bool NCCLGINBackend::is_p2p_disabled() const +{ + return p2p_disabled_; +} + +unsigned NCCLGINBackend::get_signals_base(int buffer_idx, bool low_latency_mode) const +{ + if (low_latency_mode) { + EP_HOST_ASSERT(buffer_idx == 0 || buffer_idx == 1); + TM_CHECK_NE(num_ll_signals_, 0); + return buffer_idx * num_ll_signals_ / 2; + } + else { + EP_HOST_ASSERT(buffer_idx == 0); + TM_CHECK_NE(num_ht_signals_, 0); + return 0; + } +} + +ncclWindow_t NCCLGINBackend::get_device_nccl_window(void* ptr) +{ + TM_CHECK_EQ(initialized_, true); + auto it = wins_.find(ptr); + TM_CHECK_EQ(it != wins_.end(), true); + return it->second; +} + +ncclDevComm NCCLGINBackend::get_device_communicator(bool low_latency_mode) const +{ + TM_CHECK_EQ(initialized_, true); + return low_latency_mode ? 
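// --- Illustration (editor's sketch, not part of the patch): lifecycle of an RDMA-visible
// buffer through the backend methods defined above. Function name and the 1 MiB size are
// illustrative; the calls map onto ncclMemAlloc / ncclCommWindowRegister /
// ncclCommWindowDeregister / ncclMemFree exactly as implemented above.
static void buffer_lifecycle_sketch(deep_ep::internode::NCCLGINBackend& backend)
{
    const size_t bytes = 1 << 20;                          // hypothetical size
    void* ptr = backend.alloc(bytes, /*alignment=*/128);   // ncclMemAlloc
    backend.register_memory(ptr, bytes);                   // ncclCommWindowRegister -> ncclWindow_t
    // ... kernels address the buffer via backend.get_device_nccl_window(ptr) ...
    backend.free(ptr);                                     // deregister window, then ncclMemFree
}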
dev_ll_comm_ : dev_ht_comm_; +} + +} // namespace internode +} // namespace deep_ep diff --git a/src/turbomind/comm/nccl/deep_ep/gin_backend.h b/src/turbomind/comm/nccl/deep_ep/gin_backend.h new file mode 100644 index 0000000000..a10de528c3 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/gin_backend.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include +#include + +#include + +#define DEEP_EP_GIN_MAX_CONTEXTS 32 +#define DEEP_EP_NCCL_GIN_CTXS_PER_COMM 4 +#define DEEP_EP_NCCL_MAX_NUM_CHANNELS 32 // Max number of local experts per GPU + +namespace deep_ep { +namespace internode { + +struct NcclGinMemHandle { + void* ptr = nullptr; +}; + +class NCCLGINBackend { +public: + NCCLGINBackend(): initialized_(false), rank_(-1), num_ranks_(-1) {} + + ~NCCLGINBackend(); + + // Required interface methods + int init(const std::vector& root_unique_id_val, + int rank, + int num_ranks, + bool low_latency_mode, + int qps_per_rank); + + void finalize(); + void barrier(); + + // Memory management interface methods + void* alloc(size_t size, size_t alignment); + void register_memory(void* ptr, size_t size); // NCCL-specific: register allocated memory with communicators + void free(void* ptr); + + int get_rank() const; + int get_num_ranks() const; + + // NCCL-specific methods + bool is_p2p_disabled() const; + + // NCCL specific methods + unsigned get_signals_base(int buffer_idx, bool low_latency_mode) const; + + // Device arrays for kernels + ncclWindow_t get_device_nccl_window(void* ptr); + ncclDevComm get_device_communicator(bool low_latency_mode) const; + +private: + bool initialized_ = false; + bool p2p_disabled_ = false; // True if P2P/NVLink is disabled + int rank_ = -1; // Global rank (for external API) + int num_ranks_ = -1; // Global num_ranks (for external API) + int comm_rank_ = -1; // Rank within NCCL communicator + int comm_nranks_ = -1; // Number of ranks in NCCL communicator + + ncclComm_t nccl_comm_; + + ncclDevComm_t dev_ht_comm_{}; + ncclDevComm_t dev_ll_comm_{}; + + std::unordered_map wins_; + std::unordered_map buffers_; + + // GIN signal management + int num_ht_signals_ = 0; + int num_ll_signals_ = 0; + + // GIN barriers -- assume 32 rdma ranks + const int MAX_BARRIER_SESSIONS = 32; + + // Barrier variable + int* d_barrier_var_ = nullptr; +}; + +} // namespace internode +} // namespace deep_ep diff --git a/src/turbomind/comm/nccl/deep_ep/kernels/api.cuh b/src/turbomind/comm/nccl/deep_ep/kernels/api.cuh new file mode 100644 index 0000000000..fe0d734a61 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/kernels/api.cuh @@ -0,0 +1,380 @@ +// clang-format off +#pragma once + +#include +#include + +#include + +#include "configs.cuh" + +namespace deep_ep { + +// Intranode runtime +namespace intranode { + +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); + +} // namespace intranode + +// Internode runtime +namespace internode { + +class NCCLGINBackend; + +std::vector get_unique_id(); + +int init(const std::vector& root_unique_id_val, + int rank, + int num_ranks, + bool low_latency_mode, + int qps_per_rank, + NCCLGINBackend* comm); + +void* alloc(size_t size, size_t alignment, NCCLGINBackend* comm); + +void register_memory(void* ptr, size_t size, NCCLGINBackend* comm); + +void free(void* ptr, NCCLGINBackend* comm); + +void barrier(NCCLGINBackend* comm); + +void finalize(NCCLGINBackend* comm); + +} // namespace internode + +// Layout kernels +namespace layout { + +void get_dispatch_layout(const topk_idx_t* topk_idx, + int* num_tokens_per_rank, + int* 
num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int num_tokens, + int num_topk, + int num_ranks, + int num_experts, + cudaStream_t stream); + +} // namespace layout + +// Intranode kernels +namespace intranode { + +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int* moe_recv_expert_counter_ten, + int num_experts, + int num_tokens, + const bool* is_token_in_rank, + int* channel_prefix_matrix, + int* rank_prefix_matrix_copy, + int num_memset_int, + int expert_alignment, + void** buffer_ptrs, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int num_sms); + +void cached_notify_dispatch(const int* rank_prefix_matrix, + int num_memset_int, + void** buffer_ptrs, + int** barrier_signal_ptrs, + int rank, + int num_ranks, + cudaStream_t stream); + +void dispatch(void* recv_x, + float* recv_x_scales, + int* recv_src_idx, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + int* recv_channel_offset, + int* send_head, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + const bool* is_token_in_rank, + const int* channel_prefix_matrix, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_topk, + int num_experts, + int num_scales, + int scale_token_stride, + int scale_hidden_stride, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); + +void cached_notify_combine(void** buffer_ptrs, + int* send_head, + int num_channels, + int num_recv_tokens, + int num_memset_int, + int** barrier_signal_ptrs, + int rank, + int num_ranks, + cudaStream_t stream); + +void combine(cudaDataType_t type, + void* recv_x, + float* recv_topk_weights, + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* src_idx, + const int* rank_prefix_matrix, + const int* channel_prefix_matrix, + int* send_head, + int num_tokens, + int num_recv_tokens, + int hidden, + int num_topk, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); + +} // namespace intranode + +// Internode kernels +namespace internode { + +int get_source_meta_bytes(); + +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_worst_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode); + +void dispatch(void* recv_x, + float* recv_x_scales, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + void* recv_src_meta, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + int* send_rdma_head, + int* send_nvl_head, + int* 
recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_scales, + int num_topk, + int num_experts, + int scale_token_stride, + int scale_hidden_stride, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + bool is_cached_dispatch, + cudaStream_t stream, + int num_channels, + bool low_latency_mode); + +void cached_notify(int hidden_int4, + int num_scales, + int num_topk_idx, + int num_topk_weights, + int num_ranks, + int num_channels, + int num_combined_tokens, + int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + int* combined_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool is_cached_dispatch, + bool low_latency_mode); + +void combine(cudaDataType_t type, + void* combined_x, + float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* combined_rdma_head, + const int* combined_nvl_head, + const void* src_meta, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + int num_tokens, + int num_combined_tokens, + int hidden, + int num_topk, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + cudaStream_t stream, + int num_channels, + bool low_latency_mode); + +} // namespace internode + +// Internode low-latency kernels +namespace internode_ll { + +void clean_low_latency_buffer(int* clean_0, + int num_clean_int_0, + int* clean_1, + int num_clean_int_1, + int rank, + int num_ranks, + int* mask_buffer, + int* sync_buffer, + cudaStream_t stream); + +void dispatch(void* packed_recv_x, + void* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* mask_buffer, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int* rdma_recv_count, + void* rdma_x, + size_t rdma_recv_x_offset, + size_t rdma_recv_count_offset, + size_t rdma_x_offset, + const void* x, + const topk_idx_t* topk_idx, + int* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + void* workspace, + int num_device_sms, + ncclWindow_t nccl_win, + ncclDevComm dev_comm, + unsigned signals_base, + cudaStream_t stream, + int phases); + +void combine(void* combined_x, + void* rdma_recv_x, + int* rdma_recv_flag, + void* rdma_send_x, + size_t rdma_recv_x_offset, + size_t rdma_recv_flag_offset, + size_t rdma_send_x_offset, + const void* x, + const topk_idx_t* topk_idx, + const float* 
topk_weights, + const int* src_info, + const int64_t* layout_range, + int* mask_buffer, + int64_t* combine_wait_recv_cost_stats, + int* next_clean, + int num_next_clean_int, + int num_combined_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_logfmt, + void* workspace, + int num_device_sms, + ncclWindow_t nccl_win, + ncclDevComm dev_comm, + unsigned signals_base, + cudaStream_t stream, + int phases, + bool zero_copy); + +void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, cudaStream_t stream); + +void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, cudaStream_t stream); + +void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream); + +void set_p2p_disabled_flag(bool disabled); + +} // namespace internode_ll + +} // namespace deep_ep + +// clang-format on diff --git a/src/turbomind/comm/nccl/deep_ep/kernels/buffer.cuh b/src/turbomind/comm/nccl/deep_ep/kernels/buffer.cuh new file mode 100644 index 0000000000..673fc86ae4 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/kernels/buffer.cuh @@ -0,0 +1,134 @@ +// clang-format off +#pragma once + +#include "configs.cuh" +#include "exception.cuh" + +namespace deep_ep { + +template +struct Buffer { +private: + uint8_t* ptr; + +public: + int64_t total_bytes; + + __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} + + __device__ __forceinline__ Buffer(void*& gbl_ptr, int num_elems, int offset = 0) { + total_bytes = num_elems * sizeof(dtype_t); + ptr = static_cast(gbl_ptr) + offset * sizeof(dtype_t); + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ Buffer advance_also(void*& gbl_ptr) { + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + return *this; + } + + __device__ __forceinline__ dtype_t* buffer() { return reinterpret_cast(ptr); } + + __device__ __forceinline__ dtype_t& operator[](int idx) { return buffer()[idx]; } +}; + +template +struct AsymBuffer { +private: + uint8_t* ptrs[kNumRanks]; + int64_t num_bytes; + +public: + int64_t total_bytes; + + __device__ __forceinline__ AsymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { + EP_STATIC_ASSERT(kNumRanks == 1, ""); + num_bytes = num_elems * sizeof(dtype_t); + + int64_t per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms; + ptrs[0] = static_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1, int offset = 0) { + EP_STATIC_ASSERT(kNumRanks > 1, ""); + num_bytes = num_elems * sizeof(dtype_t); + + int64_t per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms; + for (int i = 0; i < kNumRanks; ++i) { + ptrs[i] = static_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptrs[i] = static_cast(gbl_ptrs[i]) + total_bytes; + } + } + + __device__ __forceinline__ void advance(int shift) { + #pragma unroll + for (int i = 0; i < kNumRanks; ++i) + ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); + } + + __device__ __forceinline__ AsymBuffer advance_also(void*& gbl_ptr) { + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + return *this; + } + + template + __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { + for (int i = 0; i < kNumAlsoRanks; ++i) + gbl_ptrs[i] = 
static_cast(gbl_ptrs[i]) + total_bytes; + return *this; + } + + __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); + return reinterpret_cast(ptrs[0] + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { + EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case"); + return reinterpret_cast(ptrs[rank_idx] + num_bytes * idx); + } +}; + +template +struct SymBuffer { +private: + // NOTES: for non-decoupled case, `recv_ptr` is not used + uint8_t* send_ptr; + uint8_t* recv_ptr; + int64_t num_bytes; + +public: + int64_t total_bytes; + + __device__ __forceinline__ SymBuffer(void*& gbl_ptr, int num_elems, int num_ranks, int sm_id = 0, int num_sms = 1) { + num_bytes = num_elems * sizeof(dtype_t); + + int64_t per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); + send_ptr = static_cast(gbl_ptr) + per_channel_bytes * sm_id; + recv_ptr = static_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); + gbl_ptr = static_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { + EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case"); + return reinterpret_cast(send_ptr + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { + EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case"); + return reinterpret_cast(recv_ptr + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case"); + return reinterpret_cast(send_ptr + num_bytes * idx); + } +}; + +} // namespace deep_ep + +// clang-format on diff --git a/src/turbomind/comm/nccl/deep_ep/kernels/configs.cuh b/src/turbomind/comm/nccl/deep_ep/kernels/configs.cuh new file mode 100644 index 0000000000..9669120dcf --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/kernels/configs.cuh @@ -0,0 +1,81 @@ +#pragma once + +#define NUM_MAX_NVL_PEERS 8 +#define NUM_MAX_RDMA_PEERS 20 +#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) +#define NUM_MAX_LOCAL_EXPERTS 1024 +#define NUM_BUFFER_ALIGNMENT_BYTES 128 + +#define FINISHED_SUM_TAG 1024 +#define NUM_WAIT_NANOSECONDS 500 + +#ifndef ENABLE_FAST_DEBUG +#define NUM_CPU_TIMEOUT_SECS 100 +#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s +#else +#define NUM_CPU_TIMEOUT_SECS 10 +#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s +#endif + +#define LOW_LATENCY_SEND_PHASE 1 +#define LOW_LATENCY_RECV_PHASE 2 + +// Make CLion CUDA indexing work +#ifdef __CLION_IDE__ +#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier) +#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) +#endif + +// Define __CUDACC_RDC__ to ensure proper extern declarations for NVSHMEM device symbols +#ifndef DISABLE_NVSHMEM +#ifndef __CUDACC_RDC__ +#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) +#endif +#endif + +// Remove Torch restrictions +#ifdef __CUDA_NO_HALF_CONVERSIONS__ +#undef __CUDA_NO_HALF_CONVERSIONS__ +#endif +#ifdef __CUDA_NO_HALF_OPERATORS__ +#undef __CUDA_NO_HALF_OPERATORS__ +#endif +#ifdef __CUDA_NO_HALF2_OPERATORS__ +#undef __CUDA_NO_HALF2_OPERATORS__ +#endif +#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__ +#undef __CUDA_NO_BFLOAT16_CONVERSIONS__ +#endif +#ifdef __CUDA_NO_BFLOAT162_OPERATORS__ +#undef __CUDA_NO_BFLOAT162_OPERATORS__ 
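// --- Illustration (editor's sketch, not part of the patch): the Buffer/AsymBuffer/SymBuffer
// helpers above carve typed sub-regions out of one flat allocation by advancing a shared
// base pointer. Minimal device-side use of Buffer<T> only (assumes kernels/buffer.cuh is
// included; names and sizes are illustrative):
__device__ void carve_regions_sketch(void* base, int num_tokens, int hidden_int4)
{
    void* ptr = base;                                             // advanced by each constructor
    deep_ep::Buffer<int>  counters(ptr, num_tokens);              // first num_tokens ints
    deep_ep::Buffer<int4> payload(ptr, num_tokens * hidden_int4); // packed right after
    counters[0] = 0;                                              // element access via operator[]
}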
+#endif + +#include +#include + +#include + +#ifndef DISABLE_SM90_FEATURES +#include +#else +// Ampere does not support FP8 features +#define __NV_E4M3 0 +#define __NV_E5M2 1 +typedef int __nv_fp8_interpretation_t; +typedef int __nv_fp8x4_e4m3; +typedef uint8_t __nv_fp8_storage_t; +#endif + +namespace deep_ep { + +#ifndef TOPK_IDX_BITS +#define TOPK_IDX_BITS 64 +#endif + +#define INT_BITS_T2(bits) int##bits##_t +#define INT_BITS_T(bits) INT_BITS_T2(bits) +typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t; // int32_t or int64_t +#undef INT_BITS_T +#undef INT_BITS_T2 + +} // namespace deep_ep diff --git a/src/turbomind/comm/nccl/deep_ep/kernels/exception.cuh b/src/turbomind/comm/nccl/deep_ep/kernels/exception.cuh new file mode 100644 index 0000000000..d6086f4343 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/kernels/exception.cuh @@ -0,0 +1,76 @@ +// clang-format off +#pragma once + +#include +#include + +#include "configs.cuh" + +#ifndef EP_STATIC_ASSERT +#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason) +#endif + +class EPException : public std::exception { +private: + std::string message = {}; + +public: + explicit EPException(const char* name, const char* file, const int line, const std::string& error) { + message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; + } + + const char* what() const noexcept override { return message.c_str(); } +}; + +#ifndef CUDA_CHECK +#define CUDA_CHECK(cmd) \ + do { \ + cudaError_t e = (cmd); \ + if (e != cudaSuccess) { \ + throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ + } \ + } while (0) +#endif + +#ifndef CU_CHECK +#define CU_CHECK(cmd) \ + do { \ + CUresult e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + const char* error_str = NULL; \ + cuGetErrorString(e, &error_str); \ + throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \ + } \ + } while (0) +#endif + +#ifndef EP_HOST_ASSERT +#define EP_HOST_ASSERT(cond) \ + do { \ + if (not(cond)) { \ + throw EPException("Assertion", __FILE__, __LINE__, #cond); \ + } \ + } while (0) +#endif + +#ifndef EP_DEVICE_ASSERT +#define EP_DEVICE_ASSERT(cond) \ + do { \ + if (not(cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while (0) +#endif + +#ifndef NCCL_CHECK +#define NCCL_CHECK(cmd) \ + do { \ + ncclResult_t e = (cmd); \ + if (e != ncclSuccess) { \ + throw EPException("NCCL", __FILE__, __LINE__, ncclGetErrorString(e)); \ + } \ + } while (0) +#endif + +// clang-format on diff --git a/src/turbomind/comm/nccl/deep_ep/kernels/internode_ll.cu b/src/turbomind/comm/nccl/deep_ep/kernels/internode_ll.cu new file mode 100644 index 0000000000..7bae1073e9 --- /dev/null +++ b/src/turbomind/comm/nccl/deep_ep/kernels/internode_ll.cu @@ -0,0 +1,1348 @@ +// clang-format off +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "utils.cuh" + +#include +#include +#include +#include + +using namespace cooperative_groups; +namespace cg = cooperative_groups; +#define ENABLE_NCCL 1 + +namespace deep_ep { + +namespace internode_ll { + +template +__forceinline__ __device__ bool is_rank_masked(int* mask_buffer_ptr, int rank) { + if (mask_buffer_ptr == nullptr) { + return false; + } + if constexpr (use_warp_sync) { + return __shfl_sync(0xffffffff, ld_acquire_global(mask_buffer_ptr + rank), 0) != 0; + } else { + return ld_acquire_global(mask_buffer_ptr + rank) != 0; + } +} + +// Device constant for P2P/NVLink disabled flag +// Set to 
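// --- Illustration (editor's sketch, not part of the patch): `is_rank_masked` above makes
// the whole warp agree on the mask value by broadcasting lane 0's result with
// __shfl_sync, so all 32 lanes take the same branch. Generic form of the pattern
// (plain load here; the code above uses an acquire load):
__device__ __forceinline__ bool warp_uniform_nonzero(const int* flag)
{
    return __shfl_sync(0xffffffff, *flag, 0) != 0;  // every lane keeps lane 0's value
}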
true to force RDMA path, false to allow P2P when available +// Default is false (P2P enabled), updated from host via CLI option +__device__ __constant__ bool d_p2p_disabled = false; + +// Get peer-to-peer pointer for NCCL +// Returns dst_ptr if NVLink is available, 0 otherwise +// offset parameter allows callers to pass a pre-calculated offset for the destination +__device__ __forceinline__ uint64_t nccl_get_p2p_ptr(const uint64_t& dst_ptr, + const size_t& offset, + const int& rank, + const int& dst_rank, + const ncclWindow_t dev_win, + ncclDevComm dev_comm) +{ + // Local rank, no need for peer mapping + if (rank == dst_rank) + return dst_ptr; + + // If P2P is globally disabled, always use RDMA path + if (d_p2p_disabled) + return 0; + + // P2P/NVLink only works between ranks on the same node (LSA team) + // Use NCCL team APIs to check if dst_rank is in the same LSA team + ncclTeam lsa = ncclTeamLsa(dev_comm); + ncclTeam world = ncclTeamWorld(dev_comm); + if (!ncclTeamRankIsMember(lsa, world, dst_rank)) { + return 0; // Different nodes (not in same LSA team), must use RDMA + + } + + auto const p2p_ptr = reinterpret_cast(ncclGetPeerPointer(dev_win, offset, dst_rank)); + return p2p_ptr ? p2p_ptr : 0; +} + + +template +__global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, + void* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* mask_buffer_ptr, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int* rdma_recv_count, + void* rdma_x, + size_t rdma_recv_x_offset, /* nccl backend*/ + size_t rdma_recv_count_offset, + size_t rdma_x_offset, + const void* x, + const topk_idx_t* topk_idx, + int* atomic_counter_per_expert, + int* atomic_finish_counter_per_expert, + int* next_clean, + int num_next_clean_int, + int num_tokens, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + int num_warp_groups, + int num_warps_per_group, + bool round_scale, + int phases, + ncclDevComm dev_comm, + const ncclWindow_t nccl_win, + unsigned signals_base +) { + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x); + const auto warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_sms = static_cast(gridDim.x); + const auto num_warps = num_warp_groups * num_warps_per_group; + const auto num_local_experts = num_experts / num_ranks; + const auto warp_group_id = warp_id / num_warps_per_group; + const auto sub_warp_id = warp_id % num_warps_per_group; + const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + + // May extract UE8M0 from the scales + using scale_t = std::conditional_t; + using packed_t = std::conditional_t; + EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length"); + + // FP8 staffs + constexpr int kNumPerChannels = 128; + const int num_scales = kHidden / kNumPerChannels; + const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16)); + const size_t hidden_int4 = hidden_bytes / sizeof(int4); + + // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales + // NOTES: currently we have 3 reserved int fields for future use + using vec_t = std::conditional_t; + const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? 
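// --- Illustration (editor's sketch, not part of the patch): size of one dispatched token
// message as laid out above: a 16-byte header (source index plus 3 reserved ints), the
// hidden payload, and one float scale per 128 channels when FP8 is used. hidden = 7168
// is just an example value.
constexpr size_t ll_msg_bytes_sketch(size_t hidden, bool use_fp8)
{
    const size_t header  = 16;                                      // sizeof(int4)
    const size_t payload = use_fp8 ? hidden : hidden * 2;           // 1 byte FP8 vs 2 bytes BF16
    const size_t scales  = use_fp8 ? (hidden / 128) * sizeof(float) : 0;
    return header + payload + scales;
}
static_assert(ll_msg_bytes_sketch(7168, true) == 7408, "FP8: 16 + 7168 + 56 * 4");
static_assert(ll_msg_bytes_sketch(7168, false) == 14352, "BF16: 16 + 7168 * 2");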
(kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16))); + const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); + EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); + + // Expert counts + constexpr int kNumMaxWarpGroups = 32; + __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; + + // Sending phase + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) + goto LOW_LATENCY_DISPATCH_RECV; + + // There are 2 kinds of warps in this part: + // 1. The first-kind warps for FP8 cast and sending top-k tokens + // 2. The last warp for reading `topk_idx` and count for per-expert information + if (warp_id < num_warps - 1) { + constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16); + EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerRead) == 0, "Invalid hidden"); + EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization"); + const auto num_threads = (num_warps - 1) * 32; + const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; + + for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { + const auto x_int4 = static_cast(x) + token_idx * hidden_bf16_int4; + const auto rdma_x_src_idx = reinterpret_cast(static_cast(rdma_x) + token_idx * num_bytes_per_msg); + const auto rdma_x_vec = reinterpret_cast(reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); + const auto rdma_x_scales = reinterpret_cast(reinterpret_cast(rdma_x_vec) + hidden_bytes); + + // Overlap top-k index read and source token index writes + auto dst_expert_idx = warp_id < num_topk ? static_cast(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1; + thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0; + + // FP8 cast + EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce"); + #pragma unroll + for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { + // Read + auto int4_value = __ldg(x_int4 + i); + + if constexpr (kUseFP8) { + // Calculate local amax + auto bf16_values = reinterpret_cast(&int4_value); + float fp32_values[kNumElemsPerRead]; + float amax = kFP8Margin, scale, scale_inv; + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; ++j) { + fp32_values[j] = static_cast(bf16_values[j]); + amax = fmaxf(amax, fabsf(fp32_values[j])); + } + + // Reduce amax and scale + EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); + amax = warp_reduce_max<16>(amax); + calculate_fp8_scales(amax, scale, scale_inv, round_scale); + if (lane_id == 0 or lane_id == 16) + rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; + + // Cast into send buffer + vec_t int2_value; + auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; j += 2) { + float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; + fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); + } + rdma_x_vec[i] = int2_value; + } else { + // Reinterpret-cast is for C++14 compatibility + rdma_x_vec[i] = *reinterpret_cast(&int4_value); + } + } + asm volatile("bar.sync 1, %0;" ::"r"(num_threads)); + + // Issue IBGDA sends + if (dst_expert_idx >= 0) { + int slot_idx = lane_id == 0 ? 
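// --- Illustration (editor's sketch, not part of the patch): the cast loop above reduces a
// per-128-channel amax across the warp and turns it into a scale / scale_inv pair via
// `calculate_fp8_scales` (utils.cuh, not shown here). A common formulation, assuming the
// E4M3 finite maximum of 448; whether the real helper also rounds to a power of two when
// `round_scale` is set is an assumption of this sketch. amax is initialized to a small
// positive margin upstream, so the division is safe.
__device__ __forceinline__ void fp8_scales_sketch(float amax, bool round_scale,
                                                  float& scale, float& scale_inv)
{
    scale = 448.0f / amax;                    // map the largest magnitude onto the E4M3 range
    if (round_scale)
        scale = exp2f(floorf(log2f(scale)));  // power-of-two scale (needed for UE8M0 export)
    scale_inv = 1.0f / scale;                 // stored next to the payload for dequantization
}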
atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; + slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); + const auto dst_rank = dst_expert_idx / num_local_experts; + const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; + const auto src_ptr = reinterpret_cast(rdma_x_src_idx); + const auto dst_ptr = reinterpret_cast(rdma_recv_x) + + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg; + + size_t expected_dst_offset = rdma_recv_x_offset + + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + slot_idx * num_bytes_per_msg; + const auto dst_p2p_ptr = + nccl_get_p2p_ptr(dst_ptr, expected_dst_offset, rank, dst_rank, nccl_win, dev_comm); + + if (not is_rank_masked(mask_buffer_ptr, dst_rank)) { + if (dst_p2p_ptr == 0) { + size_t expected_src_offset = rdma_x_offset + token_idx * num_bytes_per_msg; + ncclGin net(dev_comm, dst_expert_local_idx); + ncclTeam world = ncclTeamWorld(dev_comm); + net.put(world, + dst_rank, + nccl_win, + expected_dst_offset, + nccl_win, + expected_src_offset, + num_bytes_per_msg, + ncclGin_None{}, // no signal + ncclGin_None{}, // no counter + ncclCoopWarp()); + } else { + // NOTES: only 2 load iterations for 7K hidden with 8 unrolls + const auto* src_int4_ptr = reinterpret_cast(src_ptr); + const auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); + UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); + } + } + + // Increase counter after finishing + __syncwarp(); + lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; + } + } + } else if (warp_id == num_warps - 1) { + EP_DEVICE_ASSERT(num_sms > 1); + if (sm_id == 0) { + // The first SM is also responsible for cleaning the next buffer + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + // Notify before executing `int_p` + __syncwarp(); + #pragma unroll + for (int i = lane_id; i < num_experts; i += 32) + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); + } + + // This SM should be responsible for some destination experts, read `topk_idx` for them + int expert_count[kNumMaxWarpGroups] = {0}; + const auto expert_begin_idx = sm_id * num_warp_groups; + const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts); + + // Per lane count + #pragma unroll 8 + for (int i = lane_id; i < num_tokens * num_topk; i += 32) { + auto idx = static_cast(__ldg(topk_idx + i)); + if (idx >= expert_begin_idx and idx < expert_end_idx) + expert_count[idx - expert_begin_idx]++; + } + + // Warp reduce + #pragma unroll + for (int i = expert_begin_idx; i < expert_end_idx; ++i) { + auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); + if (lane_id == 0) { + shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); + } + } + } + __syncthreads(); + + // Issue count sends + if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { + const auto dst_rank = responsible_expert_idx / num_local_experts; + const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; + const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * 
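// --- Illustration (editor's sketch, not part of the patch): the send path above claims one
// buffer slot per (token, destination expert): lane 0 performs the atomicAdd and the result
// is broadcast so the whole warp writes into the same slot. The bare pattern:
__device__ __forceinline__ int claim_slot_sketch(int* counter_for_expert)
{
    const int lane = threadIdx.x & 31;
    const int slot = (lane == 0) ? atomicAdd(counter_for_expert, 1) : 0;
    return __shfl_sync(0xffffffff, slot, 0);  // all 32 lanes use lane 0's slot index
}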
num_warp_groups]; + + // Wait local sends issued and send expert counts + while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2) + ; + auto dst_ptr = reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank); + + size_t dst_offset = rdma_recv_count_offset + (dst_expert_local_idx * num_ranks + rank) * sizeof(int); + const auto dst_p2p_ptr = nccl_get_p2p_ptr(dst_ptr, dst_offset, rank, dst_rank, nccl_win, dev_comm); + + if (not is_rank_masked(mask_buffer_ptr, dst_rank)) { + if (dst_p2p_ptr == 0) { // if (rank != dst_rank) { + auto signal_id = signals_base + dst_expert_local_idx * num_ranks + rank; + ncclGin net(dev_comm, dst_expert_local_idx); + ncclTeam world = ncclTeamWorld(dev_comm); + // NOTE: net.signal() is semantically cleaner but adds latency to Dispatch-Send + // and Combine-Send compared to net.put() with 0 bytes + // net.signal(world, + // dst_rank, + // ncclGin_SignalAdd{signal_id, (uint64_t)num_tokens_sent + 1}, + // ncclCoopThread(), + // ncclGin_None(), + // cuda::thread_scope_system); + net.put(world, + dst_rank, + nccl_win, + dst_offset, + nccl_win, + 0, + 0, // 0 bytes transfer + ncclGin_SignalAdd{signal_id, (uint64_t)num_tokens_sent + 1}, + ncclGin_None{}, // no counter + ncclCoopThread()); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), -num_tokens_sent - 1); + } + } + + // Clean workspace for next use + atomic_counter_per_expert[responsible_expert_idx] = 0; + atomic_finish_counter_per_expert[responsible_expert_idx] = 0; + + // Clean `packed_recv_count` + if (dst_rank == 0) + packed_recv_count[dst_expert_local_idx] = 0; + } + __syncwarp(); + +// Receiving phase +LOW_LATENCY_DISPATCH_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + return; + + // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible + if (phases & LOW_LATENCY_SEND_PHASE) + cg::this_grid().sync(); + + // Receiving and packing + if (responsible_expert_idx < num_experts) { + const auto src_rank = responsible_expert_idx / num_local_experts; + const auto local_expert_idx = responsible_expert_idx % num_local_experts; + const auto rdma_recv_x_uint8 = static_cast(rdma_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; + const auto recv_x_int4 = + static_cast(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; + const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; + const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; + const auto num_aligned_scales = align_up(num_scales, sizeof(float) / sizeof(scale_t)); + const auto recv_x_scales = static_cast(packed_recv_x_scales) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales; + + // Shared between sub-warps in warp groups + __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; + + // Wait tokens to arrive + // NOTES: using sub-warp 1 to overlap with sub-warp 0 + int num_recv_tokens = 0, recv_token_begin_idx; + EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15); + if (sub_warp_id == 1 and lane_id == 0) { + auto start_time = clock64(); + uint64_t wait_recv_cost = 0; + if (not is_rank_masked(mask_buffer_ptr, src_rank)) { + size_t src_offset = rdma_recv_count_offset + (local_expert_idx * num_ranks + 
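// --- Illustration (editor's sketch, not part of the patch): the per-expert token counts
// published above are offset by one so that a value of 0 can still mean "not arrived yet";
// the P2P store uses -(count + 1) and the GIN signal path adds count + 1, both decoded to
// the same value on the receive side below.
constexpr int encode_count_sketch(int count) { return -(count + 1); }
constexpr int decode_count_sketch(int encoded) { return -encoded - 1; }
static_assert(decode_count_sketch(encode_count_sketch(0)) == 0,
              "zero tokens still produces a non-zero flag on the wire");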
src_rank) * sizeof(int); + auto src_p2p_ptr = nccl_get_p2p_ptr(0x01, src_offset, rank, src_rank, nccl_win, dev_comm); + if (src_p2p_ptr == 0) { + ncclGin net(dev_comm, local_expert_idx); + uint64_t cur_value; + do { + cur_value = net.readSignal(signals_base + local_expert_idx * num_ranks + src_rank); + } while (cur_value < 1 // data not arrived + && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout + ); + net.resetSignal(signals_base + local_expert_idx * num_ranks + src_rank); + num_recv_tokens = -(int)cur_value; + } else { + while ((num_recv_tokens = ld_acquire_sys_global((rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == + 0 // data not arrived + && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES // not timeout + ); + } + } + // Do not receive tokens if rank timeout or masked + if (num_recv_tokens == 0) + num_recv_tokens = -1; + // Mask rank if timeout + if (wait_recv_cost > NUM_TIMEOUT_CYCLES) { + printf("Warning: DeepEP timeout for dispatch receive, rank %d, local_expert_idx %d, src_rank %d\n", + rank, + local_expert_idx, + src_rank); + if (mask_buffer_ptr == nullptr) + trap(); + atomicExch(mask_buffer_ptr + src_rank, 1); + } + + num_recv_tokens = -num_recv_tokens - 1; + recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); + shared_num_recv_tokens[warp_group_id] = num_recv_tokens; + shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; + recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); + + // Add stats for diagnosis + if (cumulative_local_expert_recv_stats != nullptr) + atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens); + if (dispatch_wait_recv_cost_stats != nullptr) + atomicAdd(reinterpret_cast(dispatch_wait_recv_cost_stats + src_rank), wait_recv_cost); + } + asm volatile("bar.sync %0, %1;" ::"r"(warp_group_id + 2), "r"(num_warps_per_group * 32)); + num_recv_tokens = shared_num_recv_tokens[warp_group_id]; + recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; + + // Copy tokens + EP_DEVICE_ASSERT(num_scales <= 64); + for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) { + // Copy source info + const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); + if (lane_id == 0) + recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); + __syncwarp(); + + // Copy data + // NOTES: only 2 load iterations for 7K hidden with 7 unrolls + const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); + const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; + UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); + + // Copy scales + if constexpr (kUseFP8) { + // Equivalent CuTe layout: + // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1)) + const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes); + const auto num_elems_per_pack = static_cast(sizeof(packed_t) / sizeof(scale_t)); + const auto token_idx = recv_token_begin_idx + i; + const auto token_stride = num_elems_per_pack; + const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack; + if (lane_id < num_scales) { + const auto pack_idx = lane_id / num_elems_per_pack; + const auto elem_idx = lane_id % num_elems_per_pack; + auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id)); + 
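// --- Illustration (editor's sketch, not part of the patch): both receive paths above poll
// until the count flag (or GIN signal) becomes non-zero, bounded by NUM_TIMEOUT_CYCLES;
// on timeout the source rank is recorded in the mask buffer instead of hanging forever.
// Skeleton of that bounded wait, with a plain volatile load instead of the acquire load:
__device__ bool wait_nonzero_sketch(volatile int* flag, unsigned long long timeout_cycles, int* out)
{
    const long long start = clock64();
    int value;
    while ((value = *flag) == 0) {
        if (clock64() - start > static_cast<long long>(timeout_cycles))
            return false;  // caller masks the peer and treats the transfer as 0 tokens
    }
    *out = value;
    return true;
}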
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; + } + if (lane_id + 32 < num_scales) { + const auto pack_idx = (lane_id + 32) / num_elems_per_pack; + const auto elem_idx = (lane_id + 32) % num_elems_per_pack; + auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id + 32)); + recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; + } + } + } + } +} + +void dispatch(void* packed_recv_x, + void* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* mask_buffer_ptr, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int* rdma_recv_count, + void* rdma_x, + size_t rdma_recv_x_offset, + size_t rdma_recv_count_offset, + size_t rdma_x_offset, + const void* x, + const topk_idx_t* topk_idx, + int* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + void* workspace, + int num_device_sms, + ncclWindow_t nccl_win, + ncclDevComm dev_comm, + unsigned signals_base, + cudaStream_t stream, + int phases) +{ + constexpr int kNumMaxTopK = 11; + const int num_warp_groups = ceil_div(num_experts, num_device_sms); + const int num_warps_per_group = 32 / num_warp_groups; + EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0); + EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group); + + const auto num_warps = num_warp_groups * num_warps_per_group; + const auto num_sms = ceil_div(num_experts, num_warp_groups); + EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + + // Workspace checks + auto atomic_counter_per_expert = static_cast(workspace); + auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; + EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); + + // FP8 checks + if (use_ue8m0) + EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`"); + +#define DISPATCH_LAUNCH_CASE(hidden) \ + { \ + auto dispatch_func = dispatch
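// --- Illustration (editor's sketch, not part of the patch): worked example of the warp-group
// sizing above. With 256 experts on a 132-SM GPU (numbers illustrative), each block owns
// num_warp_groups experts and spends 32 / num_warp_groups warps per expert:
constexpr int ceil_div_sketch(int a, int b) { return (a + b - 1) / b; }
static_assert(ceil_div_sketch(256, 132) == 2, "num_warp_groups");
static_assert(32 / ceil_div_sketch(256, 132) == 16, "num_warps_per_group");
static_assert(ceil_div_sketch(256, ceil_div_sketch(256, 132)) == 128, "blocks launched (num_sms)");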