diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 3a5149d6ed..1bf38f1fc5 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -164,13 +164,22 @@ class GemmUniversal< static constexpr uint32_t NumMmaThreads = size(TiledMma{}); static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup; static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; +#ifndef CUTLASS_W4A8_MIN_BLOCKS_PER_SM +#define CUTLASS_W4A8_MIN_BLOCKS_PER_SM 1 +#endif + static constexpr uint32_t MinBlocksPerMultiprocessor = CUTLASS_W4A8_MIN_BLOCKS_PER_SM; static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v; /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; +#ifndef CUTLASS_W4A8_LOAD_REG_REQUIREMENT +#define CUTLASS_W4A8_LOAD_REG_REQUIREMENT 40 +#endif +#ifndef CUTLASS_W4A8_MMA_REG_REQUIREMENT +#define CUTLASS_W4A8_MMA_REG_REQUIREMENT 232 +#endif + static constexpr uint32_t LoadRegisterRequirement = CUTLASS_W4A8_LOAD_REG_REQUIREMENT; + static constexpr uint32_t MmaRegisterRequirement = CUTLASS_W4A8_MMA_REG_REQUIREMENT; // 1 stage ordered sequence between mainloop and epilogue producer load threads using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; diff --git a/microbenchmarks/w4a8/baseline/.gitignore b/microbenchmarks/w4a8/baseline/.gitignore new file mode 100644 index 0000000000..787dd46710 --- /dev/null +++ b/microbenchmarks/w4a8/baseline/.gitignore @@ -0,0 +1,2 @@ 
+build/ +profiles/ diff --git a/microbenchmarks/w4a8/baseline/CMakeLists.txt b/microbenchmarks/w4a8/baseline/CMakeLists.txt new file mode 100644 index 0000000000..468d8eb1de --- /dev/null +++ b/microbenchmarks/w4a8/baseline/CMakeLists.txt @@ -0,0 +1,39 @@ +# Standalone build for the W4A8 grouped GEMM tile/schedule sweep (Hopper SM90a). +cmake_minimum_required(VERSION 3.19) +project(w4a8_microbench_sweep LANGUAGES CXX CUDA) + +find_package(CUDAToolkit REQUIRED) + +# Locate CUTLASS three levels up by default: /microbenchmarks/w4a8/baseline -> CUTLASS_DIR. +get_filename_component(_default_cutlass_dir "${CMAKE_CURRENT_SOURCE_DIR}/../../.." ABSOLUTE) +set(CUTLASS_DIR "${_default_cutlass_dir}" CACHE PATH "CUTLASS tree") + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +# sm_90a: arch-conditional MMA and full Hopper ISA (H100/H200). +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES 90a) +endif() + +set(CUTLASS_INCLUDES + "${CUTLASS_DIR}/include" + "${CUTLASS_DIR}/tools/util/include" + "${CUTLASS_DIR}/examples/common" + "${CUTLASS_DIR}/examples/55_hopper_mixed_dtype_gemm" +) + +add_executable(w4a8_baseline_sweep w4a8_baseline_sweep.cu) +target_include_directories(w4a8_baseline_sweep PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/.." # for #include "common/..." 
+ ${CUTLASS_INCLUDES} +) +target_compile_definitions(w4a8_baseline_sweep PRIVATE INT4FP8_GROUPED_SUPPORTED) +target_compile_options(w4a8_baseline_sweep PRIVATE + $<$:--expt-relaxed-constexpr> + $<$:-ftemplate-backtrace-limit=0> +) +target_link_libraries(w4a8_baseline_sweep PRIVATE CUDA::cudart) diff --git a/microbenchmarks/w4a8/baseline/profile.sh b/microbenchmarks/w4a8/baseline/profile.sh new file mode 100755 index 0000000000..44eea03bbe --- /dev/null +++ b/microbenchmarks/w4a8/baseline/profile.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash +# Profile a chosen variant (default C1) of the W4A8 grouped GEMM with Nsight Compute. +# Builds the sweep binary if needed, then runs ncu on --only={VARIANT} with a curated section set, +# saving the report to profiles/{variant}[suffix].ncu-rep ({variant} = lowercased VARIANT), a details dump, and a raw-metrics CSV. +# +# Modes: +# default : SpeedOfLight + memory/scheduler/warp/instruction/launch/occupancy sections. +# PMSAMPLE=1 : adds SourceCounters + PmSampling + PmSampling_WarpStates so the SASS view +# in ncu-ui can break down stalls per source/SASS line. +# Writes to profiles/{variant}_pmsample.ncu-rep so the default report is +# preserved untouched. +# +# Requires CUDA 12.3+ and a Hopper (SM90) GPU. Pinned to GPU 4 by default. +# Needs access to NVIDIA GPU perf counters - either run as root (sudo -E ...) or set +# NVreg_RestrictProfilingToAdminUsers=0 on the nvidia kernel module. +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${BUILD_DIR:-${ROOT}/build}" +PROFILES_DIR="${ROOT}/profiles" +CUTLASS_DIR="${CUTLASS_DIR:-$(cd "${ROOT}/../../.." && pwd)}" +VARIANT="${VARIANT:-C1}" +PMSAMPLE="${PMSAMPLE:-0}" + +if [[ -n "${CUDA_HOME:-}" && -x "${CUDA_HOME}/bin/nvcc" ]]; then + export PATH="${CUDA_HOME}/bin:${PATH}" +elif [[ -x "/usr/local/cuda/bin/nvcc" ]]; then + export CUDA_HOME="/usr/local/cuda" + export PATH="${CUDA_HOME}/bin:${PATH}" +fi +export CMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER:-$(command -v nvcc)}" + +if ! 
command -v nvcc >/dev/null 2>&1; then + echo "nvcc not found. Set CUDA_HOME or add CUDA toolkit to PATH." >&2 + exit 1 +fi +if ! command -v ncu >/dev/null 2>&1; then + echo "ncu (Nsight Compute) not found. Add CUDA toolkit bin/ to PATH." >&2 + exit 1 +fi + +cmake -S "${ROOT}" -B "${BUILD_DIR}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCUTLASS_DIR="${CUTLASS_DIR}" \ + -DCMAKE_CUDA_ARCHITECTURES=90a \ + -DCMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER}" >/dev/null + +cmake --build "${BUILD_DIR}" -j "${JOBS:-$(nproc 2>/dev/null || echo 64)}" --target w4a8_baseline_sweep + +mkdir -p "${PROFILES_DIR}" +REPORT_SUFFIX="" +NCU_LAUNCH_COUNT=3 +NCU_BENCH_ITERATIONS=4 + +# ncu kernel filter: the cooperative TMA-WS grouped-gemm entry. --launch-skip 1 skips the warm-up launch made by run_one(); +# the first captured launch is then profile_variant()'s own --warmup=1 iteration, followed by measured ones (assuming one matching kernel per gemm.run()). +NCU_SECTIONS=( + --section SpeedOfLight + --section SpeedOfLight_RooflineChart + --section LaunchStats + --section Occupancy + --section MemoryWorkloadAnalysis + --section MemoryWorkloadAnalysis_Chart + --section MemoryWorkloadAnalysis_Tables + --section SchedulerStats + --section WarpStateStats + --section ComputeWorkloadAnalysis + --section InstructionStats +) + +NCU_EXTRA_ARGS=() # stays empty in the default mode; expanded below with a ${arr[@]+...} guard so 'set -u' on bash < 4.4 does not abort +if [[ "${PMSAMPLE}" == "1" ]]; then + # PC sampling: attribute warp stalls to specific SASS / source lines so we can tell whether + # the ALU stalls cluster on PRMT / LOP3 (op-bound), on LDSM (smem-bound), on WGMMA (tensor), + # or on long_scoreboard (waiting for global / L2 fills). Re-encodes report with --import-source + # so ncu-ui can show the kernel's source alongside SASS. + NCU_SECTIONS+=( + --section SourceCounters + --section PmSampling + --section PmSampling_WarpStates + ) + NCU_EXTRA_ARGS+=(--import-source yes) + REPORT_SUFFIX="_pmsample" + # PC sampling: run more benchmark iterations (8 instead of 4) but capture fewer launches (2 instead of 3). NOTE(review): confirm the lower launch count is intended. 
+ NCU_LAUNCH_COUNT=2 + NCU_BENCH_ITERATIONS=8 +fi +REPORT_BASE="${PROFILES_DIR}/${VARIANT,,}${REPORT_SUFFIX}" + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-4}" + +ncu \ + --target-processes all \ + --kernel-name "regex:.*device_kernel.*" \ + --launch-skip 1 \ + --launch-count "${NCU_LAUNCH_COUNT}" \ + "${NCU_SECTIONS[@]}" \ + ${NCU_EXTRA_ARGS[@]+"${NCU_EXTRA_ARGS[@]}"} \ + --export "${REPORT_BASE}" --force-overwrite \ + "${BUILD_DIR}/w4a8_baseline_sweep" \ + --only="${VARIANT}" --m=1 --n=1536 --k=4096 --groups=128 --c=128 \ + --warmup=1 --iterations="${NCU_BENCH_ITERATIONS}" + +# Text summary for human reading: page=details, full per-metric breakdown. +ncu --import "${REPORT_BASE}.ncu-rep" --page details --print-units base \ + > "${REPORT_BASE}_details.txt" 2>&1 || true + +# CSV with all metrics, one row per kernel launch. +ncu --import "${REPORT_BASE}.ncu-rep" --csv --page raw \ + > "${REPORT_BASE}_metrics.csv" 2>&1 || true + +echo +echo "ncu report : ${REPORT_BASE}.ncu-rep" +echo "ncu details : ${REPORT_BASE}_details.txt" +echo "ncu raw CSV : ${REPORT_BASE}_metrics.csv" +if [[ "${PMSAMPLE}" == "1" ]]; then + echo + echo "PC sampling enabled. Open the report in Nsight Compute UI and switch to" + echo "the Source page, then sort SASS by smsp__warp_issue_stalled_long_scoreboard" + echo "(or stall_short_scoreboard / stall_mio_throttle) to see which SASS lines" + echo "are absorbing the stalls." +fi diff --git a/microbenchmarks/w4a8/baseline/run.sh b/microbenchmarks/w4a8/baseline/run.sh new file mode 100755 index 0000000000..f60772fd97 --- /dev/null +++ b/microbenchmarks/w4a8/baseline/run.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Build and run the W4A8 grouped GEMM tile/schedule sweep. +# Requires CUDA 12.3+ nvcc and a Hopper (SM90) GPU. +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${BUILD_DIR:-${ROOT}/build}" +CUTLASS_DIR="${CUTLASS_DIR:-$(cd "${ROOT}/../../.." 
&& pwd)}" + +if [[ -n "${CUDA_HOME:-}" && -x "${CUDA_HOME}/bin/nvcc" ]]; then + export PATH="${CUDA_HOME}/bin:${PATH}" +elif [[ -x "/usr/local/cuda/bin/nvcc" ]]; then + export CUDA_HOME="/usr/local/cuda" + export PATH="${CUDA_HOME}/bin:${PATH}" +fi +export CMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER:-$(command -v nvcc)}" + +if ! command -v nvcc >/dev/null 2>&1; then + echo "nvcc not found. Set CUDA_HOME or add CUDA toolkit to PATH." >&2 + exit 1 +fi + +cmake -S "${ROOT}" -B "${BUILD_DIR}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCUTLASS_DIR="${CUTLASS_DIR}" \ + -DCMAKE_CUDA_ARCHITECTURES=90a \ + -DCMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER}" + +cmake --build "${BUILD_DIR}" -j "${JOBS:-$(nproc 2>/dev/null || echo 64)}" --target w4a8_baseline_sweep + +# Default decode shape: 128 experts, M=1, N=1536, K=4096, c=128. +exec env CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" \ + "${BUILD_DIR}/w4a8_baseline_sweep" \ + --groups=128 \ + --m=1 \ + --n=1536 \ + --k=4096 \ + --c=128 \ + --alpha=1 \ + --beta=0 \ + --warmup=20 \ + --iterations=200 \ + "$@" diff --git a/microbenchmarks/w4a8/baseline/w4a8_baseline_sweep.cu b/microbenchmarks/w4a8/baseline/w4a8_baseline_sweep.cu new file mode 100644 index 0000000000..74e4265871 --- /dev/null +++ b/microbenchmarks/w4a8/baseline/w4a8_baseline_sweep.cu @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * W4A8 grouped GEMM tile/schedule sweep for the decode regime (small M, many experts). + * + * For every (TileShape, ClusterShape, KernelSchedule) candidate we instantiate the same + * Int4Fp8GemmGivenSchedule<...> from common/sm90_int4_fp8_grouped_baseline.cuh, run it + * through profile_variant(), and print a single ranked table. + * + * The sweep itself is the baseline experiment; per the project convention it does + * not perform an output-equality check (every other experiment does). 
+ **************************************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/w4a8_bench_common.hpp" +#include "common/w4a8_kernel_common.cuh" +#include "common/sm90_int4_fp8_grouped_baseline.cuh" +#include "common/w4a8_grouped_setup.cuh" + +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" + +#include "helper.h" + +#if !defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) +int main() { + std::cerr << "This benchmark requires CUDA 12.3+ (modifiable TMA) and CUTLASS SM90 support.\n"; + return 0; +} +#else + +template +static SweepResult run_one(const std::string &name, W4A8BenchOptions const &opt, + W4A8SharedInputs &shared) { + using Gemm = typename Int4Fp8GemmGivenSchedule::GemmScaleOnly; + SweepResult r; + r.name = name; + + W4A8GemmContext ctx; + ctx.shared = &shared; + ctx.allocate(opt); + + Gemm gemm; + auto arguments = ctx.make_arguments(opt); + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + auto status_ci = gemm.can_implement(arguments); + if (status_ci != cutlass::Status::kSuccess) { + r.ok = false; + r.status = std::string("SKIP can_implement: ") + cutlass::cutlassGetStatusString(status_ci); + return r; + } + auto status_init = gemm.initialize(arguments, workspace.get()); + if (status_init != cutlass::Status::kSuccess) { + r.ok = false; + r.status = std::string("SKIP initialize: ") + cutlass::cutlassGetStatusString(status_init); + return r; + } + + // Warm-up launch outside the timer to surface any runtime error. 
+ auto status_run = gemm.run(); + if (status_run != cutlass::Status::kSuccess) { + r.ok = false; + r.status = std::string("SKIP run: ") + cutlass::cutlassGetStatusString(status_run); + return r; + } + W4A8_CUDA_SYNC(); + + r = profile_variant(name, opt, [&]() { CUTLASS_CHECK(gemm.run()); }); + return r; +} + +// ============================================================================ +// Candidate tile/cluster/schedule configurations. +// ============================================================================ + +namespace sched { + using Coop = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using Ping = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; +} + +namespace cfg { + template + using Tile = cute::Shape, cute::Int, cute::Int>; + + using C0_Tile = Tile<128, 128>; + using C1_Tile = Tile<128, 16>; + using C2_Tile = Tile<128, 16>; + using C3_Tile = Tile<128, 32>; + using C4_Tile = Tile< 64, 16>; + using C5_Tile = Tile< 64, 32>; + using C6_Tile = Tile< 64, 64>; + using C7_Tile = Tile<128, 16>; + + using Cluster_1_1_1 = cute::Shape; + using Cluster_2_1_1 = cute::Shape; + using Cluster_1_2_1 = cute::Shape; + + using C0 = GemmConfig; + using C1 = GemmConfig; + using C2 = GemmConfig; + using C3 = GemmConfig; + using C4 = GemmConfig; + using C5 = GemmConfig; + using C6 = GemmConfig; + using C7 = GemmConfig; +} + +struct Variant { + const char *id; + const char *name; + SweepResult (*run)(W4A8BenchOptions const &, W4A8SharedInputs &); +}; + +template +static SweepResult run_named(W4A8BenchOptions const &opt, W4A8SharedInputs &shared, const char *name) { + return run_one(std::string(name), opt, shared); +} + +#define VARIANT(ID, NAME, CFG, SCH) \ + Variant { \ + ID, NAME, \ + [](W4A8BenchOptions const &o, W4A8SharedInputs &s) { \ + return run_named(o, s, NAME); \ + } \ + } + +static const Variant kVariants[] = { + VARIANT("C0", "C0 coop 128x128x128 cluster1x1x1", cfg::C0, sched::Coop), + VARIANT("C1", "C1 coop 128x 16x128 cluster1x1x1", 
cfg::C1, sched::Coop), + VARIANT("C2", "C2 coop 128x 16x128 cluster2x1x1", cfg::C2, sched::Coop), + VARIANT("C3", "C3 coop 128x 32x128 cluster1x1x1", cfg::C3, sched::Coop), + VARIANT("C4", "C4 ping 64x 16x128 cluster1x1x1", cfg::C4, sched::Ping), + VARIANT("C5", "C5 ping 64x 32x128 cluster1x1x1", cfg::C5, sched::Ping), + VARIANT("C6", "C6 ping 64x 64x128 cluster1x1x1", cfg::C6, sched::Ping), + VARIANT("C7", "C7 coop 128x 16x128 cluster1x2x1", cfg::C7, sched::Coop), +}; + +int main(int argc, char const **argv) { + if (!cuda_toolkit_at_least_12_3()) { + std::cerr << "CUDA 12.3+ required.\n"; + return 0; + } + if (!device_is_hopper_sm90()) { + std::cerr << "Hopper (SM90) GPU required.\n"; + return 1; + } + + W4A8BenchOptions opt; + opt.parse(argc, argv); + + std::cout << "W4A8 grouped GEMM sweep\n"; + std::cout << " groups : " << opt.groups << "\n"; + std::cout << " per-group MNK: " << opt.m << " x " << opt.n << " x " << opt.k << "\n"; + std::cout << " scale chunk c: " << opt.c << "\n"; + std::cout << " warmup/iters : " << opt.warmup << " / " << opt.iterations << "\n"; + std::cout << " total math : " << opt.total_gemm_gflops() << " GFLOPs/iter\n"; + std::cout << " weight bytes : " << (opt.total_b_bytes() / (1024.0 * 1024.0)) << " MiB/iter\n"; + std::cout << " HBM3e peak : " << H200_HBM_PEAK_GIB_S << " GiB/s\n"; + if (!opt.only.empty()) { + std::cout << " filter --only: " << opt.only << "\n"; + } + std::cout << "\n"; + + // Allocate the inputs once and reuse them across every variant. + W4A8SharedInputs shared; + std::vector problem_host; + shared.allocate_and_init(opt, problem_host); + + std::vector results; + for (auto const &v : kVariants) { + if (!opt.only.empty() && opt.only != v.id) continue; + results.push_back(v.run(opt, shared)); + } + + if (results.empty()) { + std::cerr << "No variants matched --only='" << opt.only << "'.\n"; + return 2; + } + + // Print as-run order. 
+ std::cout << "\n=== Sweep results (as-run order) ===\n"; + std::cout << std::left << std::setw(38) << "config" + << std::right << std::setw(11) << "time(ms)" + << std::setw(12) << "GFLOP/s" + << std::setw(13) << "GiB/s(B)" + << std::setw(9) << "%HBM" + << " status\n"; + for (auto const &r : results) { + std::cout << std::left << std::setw(38) << r.name; + if (r.ok) { + std::cout << std::right << std::fixed << std::setprecision(3) << std::setw(11) << r.avg_ms + << std::setprecision(1) << std::setw(12) << r.gflops + << std::setprecision(1) << std::setw(13) << r.bw_gib_s + << std::setprecision(1) << std::setw(9) << r.bw_pct_peak; + } else { + std::cout << std::right << std::setw(11) << "-" + << std::setw(12) << "-" + << std::setw(13) << "-" + << std::setw(9) << "-"; + } + std::cout << " " << r.status << "\n"; + } + + // Sort OK rows fastest first; failed rows pushed to the end. + std::vector sorted = results; + std::sort(sorted.begin(), sorted.end(), [](SweepResult const &a, SweepResult const &b) { + if (a.ok != b.ok) return a.ok && !b.ok; + if (!a.ok) return false; + return a.avg_ms < b.avg_ms; + }); + + std::cout << "\n=== Ranked (best -> worst by avg ms) ===\n"; + std::cout << std::left << std::setw(38) << "config" + << std::right << std::setw(11) << "time(ms)" + << std::setw(12) << "GFLOP/s" + << std::setw(13) << "GiB/s(B)" + << std::setw(9) << "%HBM" + << " status\n"; + for (auto const &r : sorted) { + std::cout << std::left << std::setw(38) << r.name; + if (r.ok) { + std::cout << std::right << std::fixed << std::setprecision(3) << std::setw(11) << r.avg_ms + << std::setprecision(1) << std::setw(12) << r.gflops + << std::setprecision(1) << std::setw(13) << r.bw_gib_s + << std::setprecision(1) << std::setw(9) << r.bw_pct_peak; + } else { + std::cout << std::right << std::setw(11) << "-" + << std::setw(12) << "-" + << std::setw(13) << "-" + << std::setw(9) << "-"; + } + std::cout << " " << r.status << "\n"; + } + + // Highlight winner and speedup vs baseline 
(C0). + SweepResult const *baseline = nullptr; + SweepResult const *winner = nullptr; + for (auto const &r : results) { + if (r.name.rfind("C0 ", 0) == 0 && r.ok) baseline = &r; + } + for (auto const &r : sorted) { + if (r.ok) { winner = &r; break; } + } + if (winner && baseline && winner != baseline) { + double speedup = baseline->avg_ms / winner->avg_ms; + std::cout << "\nBest config: " << winner->name + << " (" << std::setprecision(2) << speedup << "x speedup vs C0 baseline)\n"; + } else if (winner) { + std::cout << "\nBest config: " << winner->name << "\n"; + } + return 0; +} + +#endif diff --git a/microbenchmarks/w4a8/common/sm90_int4_fp8_grouped_baseline.cuh b/microbenchmarks/w4a8/common/sm90_int4_fp8_grouped_baseline.cuh new file mode 100644 index 0000000000..ec37f0ef6c --- /dev/null +++ b/microbenchmarks/w4a8/common/sm90_int4_fp8_grouped_baseline.cuh @@ -0,0 +1,77 @@ +#pragma once + +// Slim baseline schedule: only the Int4Fp8GemmGivenSchedule<...> template and +// canonical configs. All shared types come from w4a8_kernel_common.cuh. +// +// This header is the canonical reference used by every non-baseline experiment +// for the equality check + fair-baseline timing. 
+ +#include "common/w4a8_kernel_common.cuh" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#ifdef INT4FP8_GROUPED_SUPPORTED + +template +struct Int4Fp8GemmGivenSchedule { + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = typename TConfig::TileShape; + using ClusterShape = typename TConfig::ClusterShape; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using KernelSchedule = KernelScheduleTag; + using EpilogueSchedule = typename PtrArrayEpilogueScheduleFor::type; + + using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom()); + using LayoutQ_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout>, StrideQ>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentD, + ElementD, LayoutD*, AlignmentD, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, LayoutQ_Reordered*, AlignmentB, + ElementF, LayoutF_Transpose*, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopScaleOnly, + CollectiveEpilogue>; + + using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename 
GemmKernelScaleOnly::InternalStrideC; + using StrideD = typename GemmKernelScaleOnly::InternalStrideD; + using StrideC_ref = cutlass::detail::TagToStrideC_t; + using StrideD_ref = cutlass::detail::TagToStrideC_t; + using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; + using StrideS_ref = cutlass::detail::TagToStrideB_t; +}; + +// Default 128x128x128 cooperative (matches the original CUTLASS example). +using DefaultBaselineConfig = GemmConfig>, Shape<_1, _1, _1>>; +using Int4Fp8DefaultSchedule = Int4Fp8GemmGivenSchedule; + +// C1 winner from the baseline sweep: 128x16x128 cooperative. This is the +// canonical reference schedule for every non-baseline experiment that targets +// the M=1 decode regime. +using C1BaselineConfig = GemmConfig>, Shape<_1, _1, _1>>; +using Int4Fp8C1BaselineSchedule = Int4Fp8GemmGivenSchedule; + +#endif // INT4FP8_GROUPED_SUPPORTED diff --git a/microbenchmarks/w4a8/common/w4a8_bench_common.hpp b/microbenchmarks/w4a8/common/w4a8_bench_common.hpp new file mode 100644 index 0000000000..873c47a6bc --- /dev/null +++ b/microbenchmarks/w4a8/common/w4a8_bench_common.hpp @@ -0,0 +1,132 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +inline void w4a8_cuda_check(cudaError_t e, int line) { + if (e != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(e) << " at line " << line << std::endl; + std::exit(EXIT_FAILURE); + } +} +#define W4A8_CUDA_SYNC() w4a8_cuda_check(cudaDeviceSynchronize(), __LINE__) + +/// Shared CLI and timing for all W4A8 grouped microbenchmarks. 
+struct W4A8BenchOptions { + int groups = 128; + int m = 1; + int n = 1536; + int k = 4096; + int c = 128; + float alpha = 1.f; + float beta = 0.f; + int warmup = 20; + int iterations = 200; + std::string only; + + void parse(int argc, char const **argv) { + cutlass::CommandLine cmd(argc, argv); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("c", c); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("warmup", warmup); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("only", only); + } + + /// Total GEMM FLOPs across the grouped batch in GFLOP units (2*M*N*K per group). + double total_gemm_gflops() const { + uint64_t fmas = static_cast(groups) * static_cast(m) + * static_cast(n) * static_cast(k); + return (2.0 * static_cast(fmas)) / 1e9; + } + + /// Total INT4 weight bytes read across the grouped batch (groups * K * N * 4 bits). + double total_b_bytes() const { + return 0.5 * static_cast(groups) * static_cast(k) * static_cast(n); + } +}; + +inline bool environment_supports_modifiable_tma_w4a8() { +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + return true; +#else + return false; +#endif +} + +inline bool cuda_toolkit_at_least_12_3() { + return !(__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)); +} + +inline bool device_is_hopper_sm90(int device = 0) { + cudaDeviceProp props{}; + if (cudaGetDeviceProperties(&props, device) != cudaSuccess) { + return false; + } + return props.major == 9 && props.minor == 0; +} + +/// Result row from a single timed sweep variant. 
+struct SweepResult { + std::string name; + bool ok = false; + std::string status; + float avg_ms = 0.f; + double gflops = 0.0; + double bw_gib_s = 0.0; + double bw_pct_peak = 0.0; +}; + +/// Reference HBM3e peak BW for H200, GiB/s (4.8 TB/s = 4800 GB/s ~ 4470 GiB/s). +inline constexpr double H200_HBM_PEAK_GIB_S = 4470.0; + +/// Time a callable with CUDA events, computing avg_ms and derived throughput metrics. Runs opt.warmup discarded iterations (negative values count as 0) followed by opt.iterations measured ones, each bracketed by its own event pair and synchronized; returns an ok=false row when opt.iterations <= 0. Every CUDA runtime call goes through w4a8_cuda_check, so a failing call aborts the process instead of recording a bogus 0 ms sample that would rank first. +template +SweepResult profile_variant(const std::string &name, W4A8BenchOptions const &opt, Fn &&fn) { + SweepResult r; + r.name = name; + if (opt.iterations <= 0) { + r.status = "skip: iterations<=0"; + return r; + } + cudaEvent_t start{}, stop{}; + w4a8_cuda_check(cudaEventCreate(&start), __LINE__); + w4a8_cuda_check(cudaEventCreate(&stop), __LINE__); + std::vector times; + times.reserve(opt.iterations); + + for (int it = 0; it < (opt.warmup > 0 ? opt.warmup : 0) + opt.iterations; ++it) { // clamp: a negative warmup must not shrink the measured loop + w4a8_cuda_check(cudaEventRecord(start), __LINE__); + fn(); + w4a8_cuda_check(cudaEventRecord(stop), __LINE__); + w4a8_cuda_check(cudaEventSynchronize(stop), __LINE__); // also surfaces errors raised asynchronously by the work fn() enqueued + float ms = 0.f; + w4a8_cuda_check(cudaEventElapsedTime(&ms, start, stop), __LINE__); + if (it >= opt.warmup) { + times.push_back(ms); + } + } + w4a8_cuda_check(cudaEventDestroy(start), __LINE__); + w4a8_cuda_check(cudaEventDestroy(stop), __LINE__); + + r.avg_ms = std::accumulate(times.begin(), times.end(), 0.f) / + static_cast(times.size() > 0 ? times.size() : 1); + double t_s = static_cast(r.avg_ms) / 1000.0; + r.gflops = opt.total_gemm_gflops() / t_s; + r.bw_gib_s = (opt.total_b_bytes() / (1024.0 * 1024.0 * 1024.0)) / t_s; + r.bw_pct_peak = 100.0 * r.bw_gib_s / H200_HBM_PEAK_GIB_S; + r.ok = true; + r.status = "OK"; + return r; +} diff --git a/microbenchmarks/w4a8/common/w4a8_correctness.hpp b/microbenchmarks/w4a8/common/w4a8_correctness.hpp new file mode 100644 index 0000000000..bb94609310 --- /dev/null +++ b/microbenchmarks/w4a8/common/w4a8_correctness.hpp @@ -0,0 +1,96 @@ +#pragma once + +// bf16 output equality check shared by every non-baseline experiment. 
+// +// compare_block_D(ref_ctx, test_ctx, atol, rtol) +// - copies both block_D allocations to host +// - tracks max abs diff and max rel diff over all elements +// - declares passed = true iff every element satisfies +// |ref - test| <= atol OR |ref - test|/|ref| <= rtol (a NaN difference or ratio satisfies neither and counts as a mismatch) +// - records both diff numbers in EqResult; print_eq_result() prints them so the caller sees the actual gap +// +// Tolerances default to (atol=1e-2, rtol=5e-3), tuned for bf16 outputs of the +// W4A8 grouped GEMM at the target problem shapes. + +#include +#include +#include +#include +#include +#include + +#include "common/w4a8_grouped_setup.cuh" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#ifdef INT4FP8_GROUPED_SUPPORTED + +struct EqResult { + double max_abs_diff = 0.0; + double max_rel_diff = 0.0; + size_t total_elements = 0; + size_t mismatched = 0; + bool passed = true; + double atol = 0.0; + double rtol = 0.0; +}; + +template +inline EqResult compare_block_D(W4A8GemmContext const &ref, + W4A8GemmContext const &test, + double atol = 1e-2, + double rtol = 5e-3) { + using ElementRef = typename W4A8GemmContext::ElementOut; + using ElementTest = typename W4A8GemmContext::ElementOut; + + EqResult r; + r.atol = atol; + r.rtol = rtol; + + if (ref.block_D.size() != test.block_D.size()) { + std::cerr << "compare_block_D: size mismatch ref=" << ref.block_D.size() + << " test=" << test.block_D.size() << "\n"; + r.passed = false; + return r; + } + + size_t n = ref.block_D.size(); + r.total_elements = n; + + std::vector h_ref(n); + std::vector h_test(n); + ref.block_D.copy_to_host(h_ref.data()); + test.block_D.copy_to_host(h_test.data()); + + for (size_t i = 0; i < n; ++i) { + double a = static_cast(static_cast(h_ref[i])); + double b = static_cast(static_cast(h_test[i])); + double abs_diff = std::abs(a - b); + double denom = std::max(std::abs(a), 1e-12); // floor keeps the ratio finite when the reference value is 0 + double rel_diff = abs_diff / denom; + + if (abs_diff > r.max_abs_diff) r.max_abs_diff = abs_diff; + if (rel_diff > r.max_rel_diff) r.max_rel_diff 
= rel_diff; + + if (!(abs_diff <= atol || rel_diff <= rtol)) { // negated pass test: with NaN both comparisons are false, so it is counted instead of silently passing + r.mismatched += 1; + } + } + + r.passed = (r.mismatched == 0); + return r; +} + +inline void print_eq_result(EqResult const &eq) { + std::cout << "Equality check (" + << "atol=" << std::scientific << std::setprecision(1) << eq.atol + << ", rtol=" << eq.rtol << "): " + << (eq.passed ? "PASS" : "FAIL") + << " max_abs_diff=" << std::scientific << std::setprecision(3) << eq.max_abs_diff + << " max_rel_diff=" << eq.max_rel_diff + << " mismatched=" << eq.mismatched << "/" << eq.total_elements + << std::defaultfloat << "\n"; +} + +#endif // INT4FP8_GROUPED_SUPPORTED diff --git a/microbenchmarks/w4a8/common/w4a8_grouped_setup.cuh b/microbenchmarks/w4a8/common/w4a8_grouped_setup.cuh new file mode 100644 index 0000000000..324034100d --- /dev/null +++ b/microbenchmarks/w4a8/common/w4a8_grouped_setup.cuh @@ -0,0 +1,290 @@ +#pragma once + +// Host-side data layout for the W4A8 grouped Int4xFP8 GEMM microbenchmarks. +// +// Two pieces: +// +// W4A8SharedInputs<> Owns A/B/scale/problem-size device buffers and the +// reordered B layout. NOT parameterized on the output +// buffer, so a single instance can feed multiple Gemm +// variants that share the W4A8 epilogue contract +// (same ColumnMajor LayoutC/D, same bf16 ElementD). +// Stride types are taken from the canonical baseline +// C1 schedule; every variant we ship uses an identical +// InternalStrideC/D. +// +// W4A8GemmContext Owns block_D + ptr_D for one Gemm variant; constructs +// Gemm::Arguments wired to a W4A8SharedInputs<>. +// +// This split lets a non-baseline experiment instantiate two contexts (one for +// the C1 reference, one for the experimental schedule) over a single shared +// input set, run both, and compare D buffers element-wise. 
+ +#include +#include + +#include "common/w4a8_bench_common.hpp" +#include "common/w4a8_kernel_common.cuh" +#include "common/sm90_int4_fp8_grouped_baseline.cuh" + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "helper.h" +#include "mixed_dtype_utils.hpp" + +#ifdef INT4FP8_GROUPED_SUPPORTED + +using namespace cute; + +// Canonical Gemm whose stride types and reordered-B layout we reuse for every +// W4A8 variant. All schedules in this microbenchmark share these types. +using W4A8CanonicalSchedule = Int4Fp8C1BaselineSchedule; +using W4A8CanonicalGemm = typename W4A8CanonicalSchedule::GemmScaleOnly; +using W4A8CanonicalKernel = typename W4A8CanonicalGemm::GemmKernel; + +using W4A8StrideA = StrideF; +using W4A8StrideB = StrideQ; +using W4A8StrideC = typename W4A8CanonicalKernel::InternalStrideC; +using W4A8StrideD = typename W4A8CanonicalKernel::InternalStrideD; +using W4A8StrideS = typename W4A8CanonicalKernel::CollectiveMainloop::StrideScale; +using W4A8LayoutB_Reordered = typename W4A8CanonicalSchedule::LayoutQ_Reordered; + +/// Device buffers + per-group strides that are independent of the output Gemm. 
+struct W4A8SharedInputs { + std::vector offset_A; + std::vector offset_B; + std::vector offset_D; + std::vector offset_scale; + + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_C_host; + std::vector stride_D_host; + std::vector stride_S_host; + + cutlass::DeviceAllocation problem_sizes_device; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_B_modified; + cutlass::DeviceAllocation block_scale; + cutlass::DeviceAllocation> block_scale_packed; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation *> ptr_scale_packed; + cutlass::DeviceAllocation ptr_C; + + cutlass::DeviceAllocation stride_A; + cutlass::DeviceAllocation stride_B; + cutlass::DeviceAllocation layout_B_reordered; + cutlass::DeviceAllocation stride_C; + cutlass::DeviceAllocation stride_D; + cutlass::DeviceAllocation stride_S; + + // Sizes (in elements) of the per-group output, useful for callers that want + // to allocate their own block_D buffers. 
+ int64_t total_D_elements = 0; + int64_t total_A_elements = 0; + int64_t total_B_elements = 0; + int64_t total_scale_elements = 0; + + void allocate_and_init(W4A8BenchOptions const &opt, + std::vector &problem_host) { + using US = typename ProblemShape::UnderlyingProblemShape; + offset_A.clear(); + offset_B.clear(); + offset_D.clear(); + offset_scale.clear(); + stride_A_host.clear(); + stride_B_host.clear(); + stride_C_host.clear(); + stride_D_host.clear(); + stride_S_host.clear(); + + int64_t total_A = 0; + int64_t total_B = 0; + int64_t total_D = 0; + int64_t total_scale = 0; + + problem_host.resize(opt.groups); + for (int i = 0; i < opt.groups; ++i) { + int M = opt.m; + int N = opt.n; + int K = opt.k; + problem_host[i] = US{M, N, K}; + + int const scale_k = cutlass::ceil_div(K, opt.c); + + offset_A.push_back(total_A); + offset_B.push_back(total_B * cutlass::sizeof_bits::value / 8); + offset_D.push_back(total_D); + offset_scale.push_back(total_scale); + + int64_t el_A = static_cast(M) * K; + int64_t el_B = static_cast(K) * N; + int64_t el_D = static_cast(M) * N; + int64_t el_scale = static_cast(scale_k) * N; + + total_A += el_A; + total_B += el_B; + total_D += el_D; + total_scale += el_scale; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(W4A8StrideA{}, cute::make_shape(M, K, 1))); + stride_B_host.push_back(cutlass::make_cute_packed_stride(W4A8StrideB{}, cute::make_shape(N, K, 1))); + stride_C_host.push_back(cutlass::make_cute_packed_stride(W4A8StrideC{}, cute::make_shape(N, M, 1))); + stride_D_host.push_back(cutlass::make_cute_packed_stride(W4A8StrideD{}, cute::make_shape(N, M, 1))); + stride_S_host.push_back(cutlass::make_cute_packed_stride(W4A8StrideS{}, cute::make_shape(N, scale_k, 1))); + } + + total_A_elements = total_A; + total_B_elements = total_B; + total_D_elements = total_D; + total_scale_elements = total_scale; + + block_A.reset(total_A); + block_B.reset(total_B); + block_B_modified.reset(total_B); + block_scale.reset(total_scale); + 
block_scale_packed.reset(total_scale); + + std::vector pA(opt.groups); + std::vector pB(opt.groups); + std::vector *> pS(opt.groups); + std::vector pC(opt.groups, nullptr); + + for (int i = 0; i < opt.groups; ++i) { + pA[i] = block_A.get() + offset_A[i]; + pB[i] = block_B_modified.get() + offset_B[i]; + pS[i] = block_scale_packed.get() + offset_scale[i]; + } + + ptr_A.reset(opt.groups); + ptr_A.copy_from_host(pA.data()); + ptr_B.reset(opt.groups); + ptr_B.copy_from_host(pB.data()); + ptr_scale_packed.reset(opt.groups); + ptr_scale_packed.copy_from_host(pS.data()); + ptr_C.reset(opt.groups); + ptr_C.copy_from_host(pC.data()); + + stride_A.reset(opt.groups); + stride_A.copy_from_host(stride_A_host.data()); + stride_B.reset(opt.groups); + stride_B.copy_from_host(stride_B_host.data()); + stride_C.reset(opt.groups); + stride_C.copy_from_host(stride_C_host.data()); + stride_D.reset(opt.groups); + stride_D.copy_from_host(stride_D_host.data()); + stride_S.reset(opt.groups); + stride_S.copy_from_host(stride_S_host.data()); + + uint64_t seed = 2026; + initialize_tensor(block_A, seed + 3); + initialize_tensor(block_B, seed + 2); + cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size()); + + MixedDtypeOptions mixed_opts; + mixed_opts.mode = 1; + initialize_scale(block_scale, mixed_opts, seed + 1); + cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size()); + + using LayoutAtomQuant = typename W4A8CanonicalSchedule::LayoutAtomQuant; + std::vector layout_B_host(opt.groups); + for (int i = 0; i < opt.groups; ++i) { + int N = opt.n; + int K = opt.k; + auto shape_B = cute::make_shape(N, K, Int<1>{}); + auto layout_B = make_layout(shape_B, stride_B_host[i]); + layout_B_host[i] = cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + cutlass::reorder_tensor(block_B_modified.get() + offset_B[i], layout_B, layout_B_host[i]); + } + layout_B_reordered.reset(opt.groups); + layout_B_reordered.copy_from_host(layout_B_host.data()); + + 
// The CUTLASS grouped GEMM expects (N, M, K) per group. + for (int i = 0; i < opt.groups; ++i) { + auto [M, N, K] = problem_host[i]; + problem_host[i] = make_tuple(N, M, K); + } + problem_sizes_device.reset(opt.groups); + problem_sizes_device.copy_from_host(problem_host.data()); + } +}; + +/// Per-Gemm output buffer + Arguments builder. Borrows the W4A8SharedInputs +/// for inputs and strides; allocates its own block_D / ptr_D. +template +struct W4A8GemmContext { + using ElementOut = typename Gemm::EpilogueOutputOp::ElementOutput; + + W4A8SharedInputs *shared = nullptr; + + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation ptr_D; + + void allocate(W4A8BenchOptions const &opt) { + if (shared == nullptr) { + std::cerr << "W4A8GemmContext::allocate: shared inputs not set\n"; + std::exit(EXIT_FAILURE); + } + block_D.reset(shared->total_D_elements); + + std::vector pD(opt.groups); + for (int i = 0; i < opt.groups; ++i) { + pD[i] = block_D.get() + shared->offset_D[i]; + } + ptr_D.reset(opt.groups); + ptr_D.copy_from_host(pD.data()); + } + + void zero_outputs() { + if (block_D.size() == 0) return; + w4a8_cuda_check( + cudaMemsetAsync(block_D.get(), 0, + block_D.size() * sizeof(ElementOut), + /*stream=*/0), + __LINE__); + } + + typename Gemm::Arguments make_arguments(W4A8BenchOptions const &opt) const { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + // The cooperative grouped-GEMM scheduler launches a persistent grid of + // exactly `sm_count` CTAs. With the device-queried SM count (132 on H200) + // the runtime block scheduler places exactly 1 CTA per SM, even when + // resource limits would allow 2. Experiments that want to exercise 2 + // CTAs/SM must inflate `sm_count` so the grid is large enough for the + // runtime to co-locate two CTAs on each SM. 
+#ifdef CUTLASS_W4A8_SM_COUNT_OVERRIDE + hw_info.sm_count = CUTLASS_W4A8_SM_COUNT_OVERRIDE; +#else + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); +#endif + + decltype(std::declval().epilogue.thread) fusion_args{}; + fusion_args.alpha = opt.alpha; + fusion_args.beta = opt.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + return typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {opt.groups, shared->problem_sizes_device.get(), nullptr}, + {shared->ptr_B.get(), shared->layout_B_reordered.get(), shared->ptr_A.get(), + shared->stride_A.get(), shared->ptr_scale_packed.get(), + shared->stride_S.get(), opt.c}, + {fusion_args, shared->ptr_C.get(), shared->stride_C.get(), ptr_D.get(), + shared->stride_D.get()}, + hw_info}; + } +}; + +#endif // INT4FP8_GROUPED_SUPPORTED diff --git a/microbenchmarks/w4a8/common/w4a8_kernel_common.cuh b/microbenchmarks/w4a8/common/w4a8_kernel_common.cuh new file mode 100644 index 0000000000..6548fc448a --- /dev/null +++ b/microbenchmarks/w4a8/common/w4a8_kernel_common.cuh @@ -0,0 +1,91 @@ +#pragma once + +// Type aliases shared across every W4A8 grouped-GEMM microbenchmark variant. +// Including this header is enough to use ProblemShape / MmaType / QuantType / +// GemmConfig / PtrArrayEpilogueScheduleFor without dragging in a specific +// schedule template. +// +// Each kernel-config header (sm90_int4_fp8_grouped_baseline.cuh, +// sm90_int4_fp8_grouped_2cta.cuh, ...) #includes this file and then defines +// only the schedule struct it owns. This split lets a single .cu translation +// unit instantiate two different schedule templates (e.g. baseline C1 + the +// experimental decode_2cta) without ODR/redefinition errors. 
+ +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "mixed_dtype_utils.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/integral_constant.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/layout.h" +#include "cutlass/integer_subbyte.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" + +using namespace cute; + +#ifdef INT4FP8_GROUPED_SUPPORTED + +// minimal alignment for hopper GroupedGemm https://github.com/NVIDIA/cutlass/issues/2042 +static const constexpr size_t ALIGNMENT = 64; +static const constexpr size_t MAX_PROBLEM_COUNT = 512; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +// A matrix configuration +using ElementF = MmaType; +using LayoutF = cutlass::layout::RowMajor; +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + +// B matrix configuration +using ElementQ = QuantType; +using LayoutQ = cutlass::layout::ColumnMajor; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + +// Manual swap-and-transpose, so keep transposed input layouts for the kernel. 
+using LayoutF_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutQ_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using StrideF = cute::remove_pointer_t>; +using StrideQ = cute::remove_pointer_t>; + +using ElementScale = cutlass::float_e4m3_t; +using LayoutScale = cutlass::layout::ColumnMajor; + +// C/D matrix configuration +using ElementC = void; +using ElementD = cutlass::bfloat16_t; + +using LayoutC = cutlass::layout::ColumnMajor; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// (Threadblock tile, cluster shape) pair used by every Int4Fp8 schedule template. +template +struct GemmConfig { + using TileShape = ThreadBlockShape; + using ClusterShape = TClusterShape; +}; + +// Map a mainloop KernelSchedule to its matching epilogue schedule for ptr-array +// (grouped) GEMM. Two schedules supported: cooperative + pingpong. +template +struct PtrArrayEpilogueScheduleFor { + using type = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; +}; +template <> +struct PtrArrayEpilogueScheduleFor { + using type = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; +}; + +#endif // INT4FP8_GROUPED_SUPPORTED diff --git a/microbenchmarks/w4a8/decode_2cta/.gitignore b/microbenchmarks/w4a8/decode_2cta/.gitignore new file mode 100644 index 0000000000..787dd46710 --- /dev/null +++ b/microbenchmarks/w4a8/decode_2cta/.gitignore @@ -0,0 +1,2 @@ +build/ +profiles/ diff --git a/microbenchmarks/w4a8/decode_2cta/CMakeLists.txt b/microbenchmarks/w4a8/decode_2cta/CMakeLists.txt new file mode 100644 index 0000000000..98ce18b1af --- /dev/null +++ b/microbenchmarks/w4a8/decode_2cta/CMakeLists.txt @@ -0,0 +1,90 @@ +# Standalone build for the W4A8 grouped GEMM "decode 2 CTAs/SM" experiment (Hopper SM90a). +# +# Three translation units: +# w4a8_2cta_ref.cu -- reference C1 schedule, NO macro overrides ("stock C1"). 
+# w4a8_2cta_test.cu -- experimental schedule, WITH the three CUTLASS_W4A8_* +# macros that force 2 CTAs/SM and reduced register +# requirements. +# w4a8_2cta_main.cu -- driver: parses CLI, owns shared inputs, runs both +# runners, executes the bf16 equality check. +# +# Per-source COMPILE_DEFINITIONS are used to keep the macro overrides scoped +# to the test TU only, so the reference C1 timing is the genuine CUTLASS +# default (__launch_bounds__(384, 1), setmaxnreg.dec/inc(40, 232)). +cmake_minimum_required(VERSION 3.19) +project(w4a8_microbench_decode_2cta LANGUAGES CXX CUDA) + +find_package(CUDAToolkit REQUIRED) + +# Locate CUTLASS three levels up by default: /microbenchmarks/w4a8/decode_2cta -> CUTLASS_DIR. +get_filename_component(_default_cutlass_dir "${CMAKE_CURRENT_SOURCE_DIR}/../../.." ABSOLUTE) +set(CUTLASS_DIR "${_default_cutlass_dir}" CACHE PATH "CUTLASS tree") + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +# sm_90a: arch-conditional MMA and full Hopper ISA (H100/H200). +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES 90a) +endif() + +set(CUTLASS_INCLUDES + "${CUTLASS_DIR}/include" + "${CUTLASS_DIR}/tools/util/include" + "${CUTLASS_DIR}/examples/common" + "${CUTLASS_DIR}/examples/55_hopper_mixed_dtype_gemm" +) + +add_executable(w4a8_2cta + w4a8_2cta_main.cu + w4a8_2cta_ref.cu + w4a8_2cta_test.cu +) + +target_include_directories(w4a8_2cta PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/.." # for #include "common/..." + ${CUTLASS_INCLUDES} +) + +# The CUTLASS_W4A8_* macros are kernel-class-scoped (templated on KernelSchedule +# and read in the GemmUniversal kernel header). Defining them at the target level +# would leak into the ref TU as well and contaminate its "stock C1" timing. +# Apply them ONLY to the test source. 
+# +# Why (40, 88) and not the originally planned (40, 108): +# * PTX requires the operand to be a multiple of 8. +# * Empirically (40, 104) deadlocks the kernel under 2 CTAs/SM: the producer +# setmaxnreg.dec releases too few regs for two consumer warpgroups (per CTA) +# to satisfy their setmaxnreg.inc when 2 CTAs/SM are co-resident, and the +# HW reg-pool arbitration stalls forever. +# * (40, 88) ptxas-allocates REG=80 with STACK=0 (no spills, no deadlock). +# Register accounting at (40, 88) with MinBlocks=2: +# per-CTA = 128 * 40 + 256 * 88 = 27 648 regs +# per-SM = 2 * 27 648 = 55 296 regs (slack = 10 240 below 65 536) +set_source_files_properties(w4a8_2cta_test.cu PROPERTIES + COMPILE_DEFINITIONS + "INT4FP8_GROUPED_SUPPORTED;CUTLASS_W4A8_MIN_BLOCKS_PER_SM=2;CUTLASS_W4A8_MMA_REG_REQUIREMENT=88;CUTLASS_W4A8_LOAD_REG_REQUIREMENT=40;CUTLASS_W4A8_SM_COUNT_OVERRIDE=264" +) +# CUTLASS_W4A8_SM_COUNT_OVERRIDE=264: +# The cooperative grouped-GEMM scheduler launches exactly `sm_count` CTAs +# (persistent grid). At sm_count=132 (the actual H200 SM count) the runtime +# places 1 CTA per SM regardless of resource budget, so the 2-CTA launch +# bounds + Stages=9 fix is wasted (Block Limit SMEM = 2 but Achieved +# Occupancy stays at 17%). Setting sm_count=264 grows the persistent grid to +# 2x SM count, letting the runtime block scheduler co-locate two CTAs per +# SM. The ref TU keeps the device-queried 132 (no macro), so its launch +# characteristics match stock CUTLASS exactly. + +# Ref + driver TUs: only INT4FP8_GROUPED_SUPPORTED, no kernel-constant overrides. 
+set_source_files_properties(w4a8_2cta_ref.cu w4a8_2cta_main.cu PROPERTIES + COMPILE_DEFINITIONS "INT4FP8_GROUPED_SUPPORTED" +) + +target_compile_options(w4a8_2cta PRIVATE + $<$:--expt-relaxed-constexpr> + $<$:-ftemplate-backtrace-limit=0> +) +target_link_libraries(w4a8_2cta PRIVATE CUDA::cudart) diff --git a/microbenchmarks/w4a8/decode_2cta/profile.sh b/microbenchmarks/w4a8/decode_2cta/profile.sh new file mode 100755 index 0000000000..29d0bd759e --- /dev/null +++ b/microbenchmarks/w4a8/decode_2cta/profile.sh @@ -0,0 +1,133 @@ +#!/usr/bin/env bash +# Profile the decode_2cta experiment kernel with Nsight Compute. +# Builds the experiment binary if needed, runs ncu against w4a8_2cta filtered to +# the StageCount=9 (test) kernel only, and saves the report + a details dump +# + a raw-metrics CSV. +# +# Usage: +# CUDA_VISIBLE_DEVICES=4 ./profile.sh # default report +# CUDA_VISIBLE_DEVICES=4 PMSAMPLE=1 ./profile.sh # adds PC sampling +# +# Modes: +# default : SpeedOfLight + memory/scheduler/warp/instruction/launch/occupancy sections. +# PMSAMPLE=1 : adds SourceCounters + PmSampling + PmSampling_WarpStates so the SASS view +# in ncu-ui can break down stalls per source/SASS line. +# Writes to profiles/decode_2cta_pmsample.ncu-rep so the default report is +# preserved untouched. +# +# Note: the binary launches BOTH the reference C1 kernel (StageCount=19, stock +# CUTLASS launch bounds) and the experimental decode_2cta kernel (StageCount=9, +# 2 CTAs/SM launch bounds). The ncu --kernel-name regex below filters down to the +# StageCount=9 instance only (mangled name fragment "MainloopSm90Array...Li9E"), +# so the captured launches are unambiguously the experiment kernel. +# +# Requires CUDA 12.3+ and a Hopper (SM90) GPU. Pinned to GPU 4 by default. +# Needs access to NVIDIA GPU perf counters - either run as root (sudo -E ...) or set +# NVreg_RestrictProfilingToAdminUsers=0 on the nvidia kernel module. 
+set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${BUILD_DIR:-${ROOT}/build}" +PROFILES_DIR="${ROOT}/profiles" +CUTLASS_DIR="${CUTLASS_DIR:-$(cd "${ROOT}/../../.." && pwd)}" +PMSAMPLE="${PMSAMPLE:-0}" + +# Filter to the StageCount=9 (test) kernel by default. The cooperative +# CollectiveMma carries StageCount as the first non-type template arg, mangled +# as Li9E for the test kernel and Li19E for the ref. Override via env if the +# StageCount parameter changes. +NCU_KERNEL_REGEX="${NCU_KERNEL_REGEX:-regex:.*MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInputILi9E.*}" + +if [[ -n "${CUDA_HOME:-}" && -x "${CUDA_HOME}/bin/nvcc" ]]; then + export PATH="${CUDA_HOME}/bin:${PATH}" +elif [[ -x "/usr/local/cuda/bin/nvcc" ]]; then + export CUDA_HOME="/usr/local/cuda" + export PATH="${CUDA_HOME}/bin:${PATH}" +fi +export CMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER:-$(command -v nvcc)}" + +if ! command -v nvcc >/dev/null 2>&1; then + echo "nvcc not found. Set CUDA_HOME or add CUDA toolkit to PATH." >&2 + exit 1 +fi +if ! command -v ncu >/dev/null 2>&1; then + echo "ncu (Nsight Compute) not found. Add CUDA toolkit bin/ to PATH." 
>&2 + exit 1 +fi + +cmake -S "${ROOT}" -B "${BUILD_DIR}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCUTLASS_DIR="${CUTLASS_DIR}" \ + -DCMAKE_CUDA_ARCHITECTURES=90a \ + -DCMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER}" >/dev/null + +cmake --build "${BUILD_DIR}" -j "${JOBS:-$(nproc 2>/dev/null || echo 64)}" --target w4a8_2cta + +mkdir -p "${PROFILES_DIR}" +REPORT_SUFFIX="" +NCU_LAUNCH_COUNT=3 +NCU_BENCH_ITERATIONS=4 + +NCU_SECTIONS=( + --section SpeedOfLight + --section SpeedOfLight_RooflineChart + --section LaunchStats + --section Occupancy + --section MemoryWorkloadAnalysis + --section MemoryWorkloadAnalysis_Chart + --section MemoryWorkloadAnalysis_Tables + --section SchedulerStats + --section WarpStateStats + --section ComputeWorkloadAnalysis + --section InstructionStats +) + +NCU_EXTRA_ARGS=() +if [[ "${PMSAMPLE}" == "1" ]]; then + NCU_SECTIONS+=( + --section SourceCounters + --section PmSampling + --section PmSampling_WarpStates + ) + NCU_EXTRA_ARGS+=(--import-source yes) + REPORT_SUFFIX="_pmsample" + NCU_LAUNCH_COUNT=2 + NCU_BENCH_ITERATIONS=8 +fi +REPORT_BASE="${PROFILES_DIR}/decode_2cta${REPORT_SUFFIX}" + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-4}" + +# --launch-skip 1: drop the runner's first warmup launch (issued before the +# profile_variant timing window). 
+ncu \ + --target-processes all \ + --kernel-name-base mangled \ + --kernel-name "${NCU_KERNEL_REGEX}" \ + --launch-skip 1 \ + --launch-count "${NCU_LAUNCH_COUNT}" \ + "${NCU_SECTIONS[@]}" \ + "${NCU_EXTRA_ARGS[@]}" \ + --export "${REPORT_BASE}" --force-overwrite \ + "${BUILD_DIR}/w4a8_2cta" \ + --m=1 --n=1536 --k=4096 --groups=128 --c=128 \ + --warmup=1 --iterations="${NCU_BENCH_ITERATIONS}" + +ncu --import "${REPORT_BASE}.ncu-rep" --page details --print-units base \ + > "${REPORT_BASE}_details.txt" 2>&1 || true + +ncu --import "${REPORT_BASE}.ncu-rep" --csv --page raw \ + > "${REPORT_BASE}_metrics.csv" 2>&1 || true + +echo +echo "ncu report : ${REPORT_BASE}.ncu-rep" +echo "ncu details : ${REPORT_BASE}_details.txt" +echo "ncu raw CSV : ${REPORT_BASE}_metrics.csv" +echo "kernel filter: ${NCU_KERNEL_REGEX}" +if [[ "${PMSAMPLE}" == "1" ]]; then + echo + echo "PC sampling enabled. Open the report in Nsight Compute UI and switch to" + echo "the Source page, then sort SASS by smsp__warp_issue_stalled_long_scoreboard" + echo "(or stall_short_scoreboard / stall_mio_throttle) to see which SASS lines" + echo "are absorbing the stalls." +fi diff --git a/microbenchmarks/w4a8/decode_2cta/run.sh b/microbenchmarks/w4a8/decode_2cta/run.sh new file mode 100755 index 0000000000..e57e19a37b --- /dev/null +++ b/microbenchmarks/w4a8/decode_2cta/run.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +# Build and run the W4A8 grouped GEMM "decode 2 CTAs/SM" experiment. +# Requires CUDA 12.3+ nvcc and a Hopper (SM90) GPU. +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BUILD_DIR="${BUILD_DIR:-${ROOT}/build}" +CUTLASS_DIR="${CUTLASS_DIR:-$(cd "${ROOT}/../../.." 
&& pwd)}" + +if [[ -n "${CUDA_HOME:-}" && -x "${CUDA_HOME}/bin/nvcc" ]]; then + export PATH="${CUDA_HOME}/bin:${PATH}" +elif [[ -x "/usr/local/cuda/bin/nvcc" ]]; then + export CUDA_HOME="/usr/local/cuda" + export PATH="${CUDA_HOME}/bin:${PATH}" +fi +export CMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER:-$(command -v nvcc)}" + +if ! command -v nvcc >/dev/null 2>&1; then + echo "nvcc not found. Set CUDA_HOME or add CUDA toolkit to PATH." >&2 + exit 1 +fi + +cmake -S "${ROOT}" -B "${BUILD_DIR}" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCUTLASS_DIR="${CUTLASS_DIR}" \ + -DCMAKE_CUDA_ARCHITECTURES=90a \ + -DCMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER}" + +cmake --build "${BUILD_DIR}" -j "${JOBS:-$(nproc 2>/dev/null || echo 64)}" --target w4a8_2cta + +# Default decode shape: 128 experts, M=1, N=1536, K=4096, c=128. GPU4 by default +# (override with CUDA_VISIBLE_DEVICES=...) per the established convention while +# other GPUs on this host are in use. +exec env CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-4}" \ + "${BUILD_DIR}/w4a8_2cta" \ + --groups=128 \ + --m=1 \ + --n=1536 \ + --k=4096 \ + --c=128 \ + --alpha=1 \ + --beta=0 \ + --warmup=20 \ + --iterations=200 \ + "$@" diff --git a/microbenchmarks/w4a8/decode_2cta/sm90_int4_fp8_grouped_2cta.cuh b/microbenchmarks/w4a8/decode_2cta/sm90_int4_fp8_grouped_2cta.cuh new file mode 100644 index 0000000000..7420ce3b5d --- /dev/null +++ b/microbenchmarks/w4a8/decode_2cta/sm90_int4_fp8_grouped_2cta.cuh @@ -0,0 +1,69 @@ +#pragma once + +// Decode 2 CTAs/SM variant of the W4A8 grouped GEMM schedule. +// +// Identical to the baseline Int4Fp8GemmGivenSchedule except the mainloop pipeline +// depth is set explicitly via cutlass::gemm::collective::StageCount +// instead of StageCountAutoCarveout. 
Capping the depth shrinks the per-CTA SMEM +// footprint, which together with the macro-overridden MmaRegisterRequirement and +// MinBlocksPerMultiprocessor (defined in this experiment's CMakeLists.txt) is +// intended to enable 2 CTAs per SM for the M=1 decode regime. + +#include "common/w4a8_kernel_common.cuh" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#ifdef INT4FP8_GROUPED_SUPPORTED + +template +struct Int4Fp8GemmGivenScheduleStaged { + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = typename TConfig::TileShape; + using ClusterShape = typename TConfig::ClusterShape; + using KernelSchedule = KernelScheduleTag; + using EpilogueSchedule = typename PtrArrayEpilogueScheduleFor::type; + + using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom()); + using LayoutQ_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout>, StrideQ>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentD, + ElementD, LayoutD*, AlignmentD, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, LayoutQ_Reordered*, AlignmentB, + ElementF, LayoutF_Transpose*, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCount, + KernelSchedule>::CollectiveOp; + + using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopScaleOnly, + 
CollectiveEpilogue>; + + using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelScaleOnly::InternalStrideC; + using StrideD = typename GemmKernelScaleOnly::InternalStrideD; + using StrideC_ref = cutlass::detail::TagToStrideC_t; + using StrideD_ref = cutlass::detail::TagToStrideC_t; + using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; + using StrideS_ref = cutlass::detail::TagToStrideB_t; +}; + +#endif // INT4FP8_GROUPED_SUPPORTED diff --git a/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_main.cu b/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_main.cu new file mode 100644 index 0000000000..2a32404caa --- /dev/null +++ b/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_main.cu @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Driver for the decode_2cta experiment. + * + * Owns CLI parsing + W4A8SharedInputs (input data is allocated and initialized + * once). Calls the reference and test runners, runs the bf16 output equality + * check on the captured host D buffers, and prints a 2-row comparison plus the + * measured speedup. 
+ **************************************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include "common/w4a8_bench_common.hpp" +#include "common/w4a8_kernel_common.cuh" +#include "common/sm90_int4_fp8_grouped_baseline.cuh" +#include "common/w4a8_grouped_setup.cuh" + +#include "w4a8_2cta_runners.h" + +#if !defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) +int main() { + std::cerr << "This benchmark requires CUDA 12.3+ (modifiable TMA) and CUTLASS SM90 support.\n"; + return 0; +} +#else + +namespace { + +void print_result_row(SweepResult const &r) { + std::cout << std::left << std::setw(54) << r.name; + if (r.ok) { + std::cout << std::right << std::fixed << std::setprecision(3) << std::setw(11) << r.avg_ms + << std::setprecision(1) << std::setw(12) << r.gflops + << std::setprecision(1) << std::setw(13) << r.bw_gib_s + << std::setprecision(1) << std::setw(9) << r.bw_pct_peak; + } else { + std::cout << std::right << std::setw(11) << "-" + << std::setw(12) << "-" + << std::setw(13) << "-" + << std::setw(9) << "-"; + } + std::cout << " " << r.status << "\n"; +} + +struct EqResultHost { + double max_abs_diff = 0.0; + double max_rel_diff = 0.0; + size_t total_elements = 0; + size_t mismatched = 0; + bool passed = true; + double atol = 0.0; + double rtol = 0.0; +}; + +EqResultHost compare_host_bf16(std::vector const &ref, + std::vector const &test, + double atol, double rtol) { + EqResultHost r; + r.atol = atol; + r.rtol = rtol; + if (ref.size() != test.size()) { + std::cerr << "Equality check: size mismatch ref=" << ref.size() + << " test=" << test.size() << "\n"; + r.passed = false; + return r; + } + r.total_elements = ref.size(); + for (size_t i = 0; i < ref.size(); ++i) { + double a = static_cast(static_cast(ref[i])); + double b = static_cast(static_cast(test[i])); + double abs_diff = std::abs(a - b); + double denom = std::max(std::abs(a), 1e-12); + double rel_diff = 
abs_diff / denom; + if (abs_diff > r.max_abs_diff) r.max_abs_diff = abs_diff; + if (rel_diff > r.max_rel_diff) r.max_rel_diff = rel_diff; + if (abs_diff > atol && rel_diff > rtol) { + r.mismatched += 1; + } + } + r.passed = (r.mismatched == 0); + return r; +} + +} // namespace + +int main(int argc, char const **argv) { + if (!cuda_toolkit_at_least_12_3()) { + std::cerr << "CUDA 12.3+ required.\n"; + return 0; + } + if (!device_is_hopper_sm90()) { + std::cerr << "Hopper (SM90) GPU required.\n"; + return 1; + } + + W4A8BenchOptions opt; + opt.parse(argc, argv); + + std::cout << "W4A8 grouped GEMM (decode 2 CTAs/SM)\n"; + std::cout << " groups : " << opt.groups << "\n"; + std::cout << " per-group MNK: " << opt.m << " x " << opt.n << " x " << opt.k << "\n"; + std::cout << " scale chunk c: " << opt.c << "\n"; + std::cout << " warmup/iters : " << opt.warmup << " / " << opt.iterations << "\n"; + std::cout << " total math : " << opt.total_gemm_gflops() << " GFLOPs/iter\n"; + std::cout << " weight bytes : " << (opt.total_b_bytes() / (1024.0 * 1024.0)) << " MiB/iter\n"; + std::cout << " HBM3e peak : " << H200_HBM_PEAK_GIB_S << " GiB/s\n"; + std::cout << "\n"; + + W4A8SharedInputs shared; + std::vector problem_host; + shared.allocate_and_init(opt, problem_host); + + // Reference first; test second. Both runners do their own warmup + timing + // window inside the same binary, on the same GPU, with the same warmup/iters. 
+ W4A8RunResult ref = run_ref(shared, opt); + W4A8RunResult test = run_test(shared, opt); + + std::cout << "=== Results ===\n"; + std::cout << std::left << std::setw(54) << "config" + << std::right << std::setw(11) << "time(ms)" + << std::setw(12) << "GFLOP/s" + << std::setw(13) << "GiB/s(B)" + << std::setw(9) << "%HBM" + << " status\n"; + print_result_row(ref.result); + print_result_row(test.result); + + std::cout << "\n"; + if (ref.result.ok && test.result.ok) { + EqResultHost eq = compare_host_bf16(ref.host_D, test.host_D, + /*atol=*/1e-2, /*rtol=*/5e-3); + std::cout << "Equality check (atol=" << std::scientific << std::setprecision(1) << eq.atol + << ", rtol=" << eq.rtol << "): " + << (eq.passed ? "PASS" : "FAIL") + << " max_abs_diff=" << std::scientific << std::setprecision(3) << eq.max_abs_diff + << " max_rel_diff=" << eq.max_rel_diff + << " mismatched=" << eq.mismatched << "/" << eq.total_elements + << std::defaultfloat << "\n"; + + double speedup = static_cast(ref.result.avg_ms) / + static_cast(test.result.avg_ms); + std::cout << "Speedup vs measured baseline: " + << std::fixed << std::setprecision(2) << speedup << "x\n"; + return eq.passed ? 0 : 3; + } else { + std::cerr << "One or both runs failed; skipping equality check.\n"; + return 2; + } +} + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_ref.cu b/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_ref.cu new file mode 100644 index 0000000000..04b6bc255f --- /dev/null +++ b/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_ref.cu @@ -0,0 +1,89 @@ +/*************************************************************************************************** + * Reference TU for the decode_2cta experiment. + * + * Compiled WITHOUT the CUTLASS_W4A8_* macro overrides, so the cooperative grouped + * GEMM kernel keeps its CUTLASS-default __launch_bounds__(384, 1) and + * setmaxnreg.dec/inc(40, 232). 
This produces the genuine "stock C1" timing that
+ * the experimental decode_2cta variant is measured against.
+ **************************************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+
+#include "common/w4a8_bench_common.hpp"
+#include "common/w4a8_kernel_common.cuh"
+#include "common/sm90_int4_fp8_grouped_baseline.cuh"
+#include "common/w4a8_grouped_setup.cuh"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/util/device_memory.h"
+
+#include "helper.h"
+
+#include "w4a8_2cta_runners.h"
+
+#if !defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+W4A8RunResult run_ref(W4A8SharedInputs & /*shared*/, W4A8BenchOptions const & /*opt*/) {  // Fallback built when the SM90 feature macro is undefined: launches nothing and reports SKIP.
+  W4A8RunResult r;
+  r.result.name = "baseline C1 (StageCountAuto, MinBlocks=1)";
+  r.result.ok = false;  // the driver only runs the equality check when both results are ok
+  r.result.status = "SKIP: requires CUDA 12.3+ and CUTLASS SM90";
+  return r;  // host_D is left empty
+}
+
+#else
+
+using RefGemm = typename Int4Fp8C1BaselineSchedule::GemmScaleOnly;  // GEMM type that run_ref instantiates and times
+
+W4A8RunResult run_ref(W4A8SharedInputs &shared, W4A8BenchOptions const &opt) {  // Runs the baseline GEMM: validate, launch once, profile, then re-run once to capture D on the host.
+  W4A8RunResult r;
+  r.result.name = "baseline C1 (StageCountAuto, MinBlocks=1)";
+
+  W4A8GemmContext ctx;
+  ctx.shared = &shared;  // ctx only stores a pointer; the caller keeps ownership of the shared inputs
+  ctx.allocate(opt);
+
+  RefGemm gemm;
+  auto arguments = ctx.make_arguments(opt);
+  size_t workspace_size = RefGemm::get_workspace_size(arguments);
+  cutlass::device_memory::allocation workspace(workspace_size);
+
+  auto status_ci = gemm.can_implement(arguments);  // each stage below (can_implement / initialize / first run) returns early with ok=false on failure
+  if (status_ci != cutlass::Status::kSuccess) {
+    r.result.ok = false;
+    r.result.status = std::string("SKIP can_implement: ") + cutlass::cutlassGetStatusString(status_ci);
+    return r;
+  }
+  auto status_init = gemm.initialize(arguments, workspace.get());
+  if (status_init != cutlass::Status::kSuccess) {
+    r.result.ok = false;
+    r.result.status = std::string("SKIP initialize: ") + cutlass::cutlassGetStatusString(status_init);
+    return r;
+  }
+
+  auto status_run = gemm.run();  // first launch; its status is checked here, before any timing starts
+  if (status_run != cutlass::Status::kSuccess) {
+    r.result.ok = false;
+    r.result.status = std::string("SKIP run: ") + cutlass::cutlassGetStatusString(status_run);
+    return r;
+  }
+  W4A8_CUDA_SYNC();  // presumably a device-wide sync (macro defined elsewhere) - confirm
+
+  r.result = profile_variant(r.result.name, opt, [&]() { CUTLASS_CHECK(gemm.run()); });  // overwrites r.result wholesale; the label is handed back in as the first argument
+
+  // Deterministic re-run for the equality check: zero D, run once, sync, copy out.
+  ctx.zero_outputs();
+  W4A8_CUDA_SYNC();
+  CUTLASS_CHECK(gemm.run());
+  W4A8_CUDA_SYNC();
+
+  r.host_D.resize(ctx.block_D.size());  // size the host buffer to block_D's element count before copying
+  ctx.block_D.copy_to_host(r.host_D.data());
+  return r;
+}
+
+#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_runners.h b/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_runners.h
new file mode 100644
index 0000000000..a5c9e934d8
--- /dev/null
+++ b/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_runners.h
@@ -0,0 +1,43 @@
+#pragma once
+
+// Cross-TU interface for the decode_2cta experiment driver.
+//
+// The experiment ships three translation units. Two run the GEMM kernel:
+//
+//   w4a8_2cta_ref.cu    compiled WITHOUT the CUTLASS_W4A8_* macro overrides;
+//                       instantiates and runs the canonical baseline C1 schedule.
+//                       Time + output captured here is what every speedup number
+//                       is measured against.
+//
+//   w4a8_2cta_test.cu   compiled WITH the macro overrides
+//                       (CUTLASS_W4A8_MIN_BLOCKS_PER_SM=2,
+//                       CUTLASS_W4A8_MMA_REG_REQUIREMENT=88,
+//                       CUTLASS_W4A8_LOAD_REG_REQUIREMENT=40);
+//                       instantiates and runs the experimental decode_2cta
+//                       schedule (StageCount=9 + 2 CTAs/SM via launch bounds).
+//
+//   w4a8_2cta_main.cu   drives both: parses CLI, owns W4A8SharedInputs, calls
+//                       run_ref / run_test, runs the equality check and prints
+//                       the comparison.
+//
+// Each runner returns its measured timing AND copies its bf16 block_D out to a
+// caller-owned host vector so the main TU can compare without any cross-TU
+// W4A8GemmContext type leakage.
+
+#include
+#include
+
+#include "cutlass/numeric_types.h"
+
+#include "common/w4a8_bench_common.hpp"
+
+// Forward declaration; full definition lives in common/w4a8_grouped_setup.cuh.
+struct W4A8SharedInputs;
+
+struct W4A8RunResult {  // What each runner returns to the driver TU: a timing/status record plus a host copy of D.
+  SweepResult result;  // name, ok flag, status text and timing; the driver checks result.ok on both runs before comparing outputs
+  std::vector host_D; // host copy of the runner's block_D; populated by a deterministic re-run and left empty on the SKIP paths.
+};
+
+W4A8RunResult run_ref(W4A8SharedInputs &shared, W4A8BenchOptions const &opt);  // defined in w4a8_2cta_ref.cu (baseline schedule)
+W4A8RunResult run_test(W4A8SharedInputs &shared, W4A8BenchOptions const &opt);  // defined in w4a8_2cta_test.cu (decode_2cta schedule)
diff --git a/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_test.cu b/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_test.cu
new file mode 100644
index 0000000000..83a9f0472c
--- /dev/null
+++ b/microbenchmarks/w4a8/decode_2cta/w4a8_2cta_test.cu
@@ -0,0 +1,101 @@
+/***************************************************************************************************
+ * Test TU for the decode_2cta experiment.
+ *
+ * Compiled WITH the three CUTLASS_W4A8_* macro overrides, so the cooperative
+ * grouped GEMM kernel uses __launch_bounds__(384, 2) and setmaxnreg.dec/inc(40, 88).
+ * The instantiated schedule is Int4Fp8GemmGivenScheduleStaged.
+ **************************************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+
+#include "common/w4a8_bench_common.hpp"
+#include "common/w4a8_kernel_common.cuh"
+#include "common/sm90_int4_fp8_grouped_baseline.cuh"
+#include "common/w4a8_grouped_setup.cuh"
+
+#include "sm90_int4_fp8_grouped_2cta.cuh"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/util/device_memory.h"
+
+#include "helper.h"
+
+#include "w4a8_2cta_runners.h"
+
+#if !defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+W4A8RunResult run_test(W4A8SharedInputs & /*shared*/, W4A8BenchOptions const & /*opt*/) {  // Fallback built when the SM90 feature macro is undefined: launches nothing and reports SKIP.
+  W4A8RunResult r;
+  r.result.name = "decode_2cta C1 (Stages=9, MinBlocks=2, sm_count=264)";
+  r.result.ok = false;  // the driver only runs the equality check when both results are ok
+  r.result.status = "SKIP: requires CUDA 12.3+ and CUTLASS SM90";
+  return r;  // host_D is left empty
+}
+
+#else
+
+namespace {
+using TestConfig = GemmConfig>,
+                              cute::Shape>;
+// kStages=9: chosen so per-block dynamic SMEM drops below the H200 228 KiB / 2
+// budget, letting the runtime block scheduler actually pack 2 CTAs/SM. At
+// kStages=10 the SMEM was 121 KiB/block (2x = 242 KiB > 228 KiB usable) and
+// `Block Limit Shared Mem` stayed at 1, defeating the entire experiment.
+using TestSchedule = Int4Fp8GemmGivenScheduleStaged<
+    TestConfig,
+    cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative,
+    /*kStages=*/9>;
+using TestGemm = typename TestSchedule::GemmScaleOnly;  // GEMM type that run_test instantiates and times
+}
+
+W4A8RunResult run_test(W4A8SharedInputs &shared, W4A8BenchOptions const &opt) {  // Runs the decode_2cta GEMM: validate, launch once, profile, then re-run once to capture D on the host.
+  W4A8RunResult r;
+  r.result.name = "decode_2cta C1 (Stages=9, MinBlocks=2, sm_count=264)";  // NOTE(review): "sm_count=264" is hard-coded in this label; nothing in this TU sets or checks it - confirm against make_arguments
+
+  W4A8GemmContext ctx;
+  ctx.shared = &shared;  // ctx only stores a pointer; the caller keeps ownership of the shared inputs
+  ctx.allocate(opt);
+
+  TestGemm gemm;
+  auto arguments = ctx.make_arguments(opt);
+  size_t workspace_size = TestGemm::get_workspace_size(arguments);
+  cutlass::device_memory::allocation workspace(workspace_size);
+
+  auto status_ci = gemm.can_implement(arguments);  // each stage below (can_implement / initialize / first run) returns early with ok=false on failure
+  if (status_ci != cutlass::Status::kSuccess) {
+    r.result.ok = false;
+    r.result.status = std::string("SKIP can_implement: ") + cutlass::cutlassGetStatusString(status_ci);
+    return r;
+  }
+  auto status_init = gemm.initialize(arguments, workspace.get());
+  if (status_init != cutlass::Status::kSuccess) {
+    r.result.ok = false;
+    r.result.status = std::string("SKIP initialize: ") + cutlass::cutlassGetStatusString(status_init);
+    return r;
+  }
+
+  auto status_run = gemm.run();  // first launch; its status is checked here, before any timing starts
+  if (status_run != cutlass::Status::kSuccess) {
+    r.result.ok = false;
+    r.result.status = std::string("SKIP run: ") + cutlass::cutlassGetStatusString(status_run);
+    return r;
+  }
+  W4A8_CUDA_SYNC();  // presumably a device-wide sync (macro defined elsewhere) - confirm
+
+  r.result = profile_variant(r.result.name, opt, [&]() { CUTLASS_CHECK(gemm.run()); });  // overwrites r.result wholesale; the label is handed back in as the first argument
+
+  ctx.zero_outputs();  // same capture sequence as run_ref: zero D, run once, sync, copy out
+  W4A8_CUDA_SYNC();
+  CUTLASS_CHECK(gemm.run());
+  W4A8_CUDA_SYNC();
+
+  r.host_D.resize(ctx.block_D.size());  // size the host buffer to block_D's element count before copying
+  ctx.block_D.copy_to_host(r.host_D.data());
+  return r;
+}
+
+#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED