Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,22 @@ class GemmUniversal<
static constexpr uint32_t NumMmaThreads = size(TiledMma{});
static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
#ifndef CUTLASS_W4A8_MIN_BLOCKS_PER_SM
#define CUTLASS_W4A8_MIN_BLOCKS_PER_SM 1
#endif
static constexpr uint32_t MinBlocksPerMultiprocessor = CUTLASS_W4A8_MIN_BLOCKS_PER_SM;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v<typename CollectiveMainloop::DispatchPolicy>;

/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
#ifndef CUTLASS_W4A8_LOAD_REG_REQUIREMENT
#define CUTLASS_W4A8_LOAD_REG_REQUIREMENT 40
#endif
#ifndef CUTLASS_W4A8_MMA_REG_REQUIREMENT
#define CUTLASS_W4A8_MMA_REG_REQUIREMENT 232
#endif
static constexpr uint32_t LoadRegisterRequirement = CUTLASS_W4A8_LOAD_REG_REQUIREMENT;
static constexpr uint32_t MmaRegisterRequirement = CUTLASS_W4A8_MMA_REG_REQUIREMENT;

// 1 stage ordered sequence between mainloop and epilogue producer load threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>;
Expand Down
2 changes: 2 additions & 0 deletions microbenchmarks/w4a8/baseline/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
build/
profiles/
39 changes: 39 additions & 0 deletions microbenchmarks/w4a8/baseline/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Standalone build for the W4A8 grouped GEMM tile/schedule sweep (Hopper SM90a).
cmake_minimum_required(VERSION 3.19)
project(w4a8_microbench_sweep LANGUAGES CXX CUDA)

find_package(CUDAToolkit REQUIRED)

# Locate CUTLASS three levels up by default: <CUTLASS_DIR>/microbenchmarks/w4a8/baseline -> CUTLASS_DIR.
get_filename_component(_default_cutlass_dir "${CMAKE_CURRENT_SOURCE_DIR}/../../.." ABSOLUTE)
set(CUTLASS_DIR "${_default_cutlass_dir}" CACHE PATH "CUTLASS tree")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# sm_90a: arch-conditional MMA and full Hopper ISA (H100/H200).
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 90a)
endif()

set(CUTLASS_INCLUDES
"${CUTLASS_DIR}/include"
"${CUTLASS_DIR}/tools/util/include"
"${CUTLASS_DIR}/examples/common"
"${CUTLASS_DIR}/examples/55_hopper_mixed_dtype_gemm"
)

add_executable(w4a8_baseline_sweep w4a8_baseline_sweep.cu)
target_include_directories(w4a8_baseline_sweep PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}/.." # for #include "common/..."
${CUTLASS_INCLUDES}
)
target_compile_definitions(w4a8_baseline_sweep PRIVATE INT4FP8_GROUPED_SUPPORTED)
target_compile_options(w4a8_baseline_sweep PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>
$<$<COMPILE_LANGUAGE:CUDA>:-ftemplate-backtrace-limit=0>
)
target_link_libraries(w4a8_baseline_sweep PRIVATE CUDA::cudart)
122 changes: 122 additions & 0 deletions microbenchmarks/w4a8/baseline/profile.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env bash
# Profile a chosen variant (default C1) of the W4A8 grouped GEMM with Nsight Compute.
# Builds the sweep binary if needed, then runs ncu on --only=<VARIANT> with a curated section set,
# saving the report to profiles/<variant>[suffix].ncu-rep, a details dump, and a raw-metrics CSV.
#
# Modes:
# default : SpeedOfLight + memory/scheduler/warp/instruction/launch/occupancy sections.
# PMSAMPLE=1 : adds SourceCounters + PmSampling + PmSampling_WarpStates so the SASS view
# in ncu-ui can break down stalls per source/SASS line.
# Writes to profiles/<variant>_pmsample.ncu-rep so the default report is
# preserved untouched.
#
# Requires CUDA 12.3+ and a Hopper (SM90) GPU. Pinned to GPU 4 by default.
# Needs access to NVIDIA GPU perf counters - either run as root (sudo -E ...) or set
# NVreg_RestrictProfilingToAdminUsers=0 on the nvidia kernel module.
set -euo pipefail

ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BUILD_DIR="${BUILD_DIR:-${ROOT}/build}"
PROFILES_DIR="${ROOT}/profiles"
CUTLASS_DIR="${CUTLASS_DIR:-$(cd "${ROOT}/../../.." && pwd)}"
VARIANT="${VARIANT:-C1}"
PMSAMPLE="${PMSAMPLE:-0}"

if [[ -n "${CUDA_HOME:-}" && -x "${CUDA_HOME}/bin/nvcc" ]]; then
export PATH="${CUDA_HOME}/bin:${PATH}"
elif [[ -x "/usr/local/cuda/bin/nvcc" ]]; then
export CUDA_HOME="/usr/local/cuda"
export PATH="${CUDA_HOME}/bin:${PATH}"
fi
export CMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER:-$(command -v nvcc)}"

if ! command -v nvcc >/dev/null 2>&1; then
echo "nvcc not found. Set CUDA_HOME or add CUDA toolkit to PATH." >&2
exit 1
fi
if ! command -v ncu >/dev/null 2>&1; then
echo "ncu (Nsight Compute) not found. Add CUDA toolkit bin/ to PATH." >&2
exit 1
fi

cmake -S "${ROOT}" -B "${BUILD_DIR}" \
-DCMAKE_BUILD_TYPE=Release \
-DCUTLASS_DIR="${CUTLASS_DIR}" \
-DCMAKE_CUDA_ARCHITECTURES=90a \
-DCMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER}" >/dev/null

cmake --build "${BUILD_DIR}" -j "${JOBS:-$(nproc 2>/dev/null || echo 64)}" --target w4a8_baseline_sweep

mkdir -p "${PROFILES_DIR}"
REPORT_SUFFIX=""
NCU_LAUNCH_COUNT=3
NCU_BENCH_ITERATIONS=4

# ncu kernel filter: the cooperative TMA-WS grouped-gemm entry. The warmup launch in run_one()
# is skipped via --launch-skip 1; we then capture timed launches with rich sections.
NCU_SECTIONS=(
--section SpeedOfLight
--section SpeedOfLight_RooflineChart
--section LaunchStats
--section Occupancy
--section MemoryWorkloadAnalysis
--section MemoryWorkloadAnalysis_Chart
--section MemoryWorkloadAnalysis_Tables
--section SchedulerStats
--section WarpStateStats
--section ComputeWorkloadAnalysis
--section InstructionStats
)

NCU_EXTRA_ARGS=()
if [[ "${PMSAMPLE}" == "1" ]]; then
# PC sampling: attribute warp stalls to specific SASS / source lines so we can tell whether
# the ALU stalls cluster on PRMT / LOP3 (op-bound), on LDSM (smem-bound), on WGMMA (tensor),
# or on long_scoreboard (waiting for global / L2 fills). Re-encodes report with --import-source
# so ncu-ui can show the kernel's source alongside SASS.
NCU_SECTIONS+=(
--section SourceCounters
--section PmSampling
--section PmSampling_WarpStates
)
NCU_EXTRA_ARGS+=(--import-source yes)
REPORT_SUFFIX="_pmsample"
# PC sampling needs more steady-state samples; widen the capture window.
NCU_LAUNCH_COUNT=2
NCU_BENCH_ITERATIONS=8
fi
REPORT_BASE="${PROFILES_DIR}/${VARIANT,,}${REPORT_SUFFIX}"

export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-4}"

ncu \
--target-processes all \
--kernel-name "regex:.*device_kernel.*" \
--launch-skip 1 \
--launch-count "${NCU_LAUNCH_COUNT}" \
"${NCU_SECTIONS[@]}" \
"${NCU_EXTRA_ARGS[@]}" \
--export "${REPORT_BASE}" --force-overwrite \
"${BUILD_DIR}/w4a8_baseline_sweep" \
--only="${VARIANT}" --m=1 --n=1536 --k=4096 --groups=128 --c=128 \
--warmup=1 --iterations="${NCU_BENCH_ITERATIONS}"

# Text summary for human reading: page=details, full per-metric breakdown.
ncu --import "${REPORT_BASE}.ncu-rep" --page details --print-units base \
> "${REPORT_BASE}_details.txt" 2>&1 || true

# CSV with all metrics, one row per kernel launch.
ncu --import "${REPORT_BASE}.ncu-rep" --csv --page raw \
> "${REPORT_BASE}_metrics.csv" 2>&1 || true

echo
echo "ncu report : ${REPORT_BASE}.ncu-rep"
echo "ncu details : ${REPORT_BASE}_details.txt"
echo "ncu raw CSV : ${REPORT_BASE}_metrics.csv"
if [[ "${PMSAMPLE}" == "1" ]]; then
echo
echo "PC sampling enabled. Open the report in Nsight Compute UI and switch to"
echo "the Source page, then sort SASS by smsp__warp_issue_stalled_long_scoreboard"
echo "(or stall_short_scoreboard / stall_mio_throttle) to see which SASS lines"
echo "are absorbing the stalls."
fi
43 changes: 43 additions & 0 deletions microbenchmarks/w4a8/baseline/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# Build and run the W4A8 grouped GEMM tile/schedule sweep.
# Requires CUDA 12.3+ nvcc and a Hopper (SM90) GPU.
set -euo pipefail

ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
BUILD_DIR="${BUILD_DIR:-${ROOT}/build}"
CUTLASS_DIR="${CUTLASS_DIR:-$(cd "${ROOT}/../../.." && pwd)}"

if [[ -n "${CUDA_HOME:-}" && -x "${CUDA_HOME}/bin/nvcc" ]]; then
export PATH="${CUDA_HOME}/bin:${PATH}"
elif [[ -x "/usr/local/cuda/bin/nvcc" ]]; then
export CUDA_HOME="/usr/local/cuda"
export PATH="${CUDA_HOME}/bin:${PATH}"
fi
export CMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER:-$(command -v nvcc)}"

if ! command -v nvcc >/dev/null 2>&1; then
echo "nvcc not found. Set CUDA_HOME or add CUDA toolkit to PATH." >&2
exit 1
fi

cmake -S "${ROOT}" -B "${BUILD_DIR}" \
-DCMAKE_BUILD_TYPE=Release \
-DCUTLASS_DIR="${CUTLASS_DIR}" \
-DCMAKE_CUDA_ARCHITECTURES=90a \
-DCMAKE_CUDA_COMPILER="${CMAKE_CUDA_COMPILER}"

cmake --build "${BUILD_DIR}" -j "${JOBS:-$(nproc 2>/dev/null || echo 64)}" --target w4a8_baseline_sweep

# Default decode shape: 128 experts, M=1, N=1536, K=4096, c=128.
exec env CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" \
"${BUILD_DIR}/w4a8_baseline_sweep" \
--groups=128 \
--m=1 \
--n=1536 \
--k=4096 \
--c=128 \
--alpha=1 \
--beta=0 \
--warmup=20 \
--iterations=200 \
"$@"
Loading