Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions examples/python/CuTeDSL/cute/blackwell/kernel/rmsnorm/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import ctypes
import functools
import math
import os
import sys
from typing import Optional, Tuple, Union

import cuda.bindings.driver as cuda
Expand All @@ -44,10 +46,11 @@
from cutlass.cute.runtime import make_ptr

# Support both direct execution and module import
try:
from .reduce import row_reduce
except ImportError:
from reduce import row_reduce
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../../.."))

from blackwell.kernel.reduce.reduce import row_reduce

"""
RMSNorm: Root Mean Square Layer Normalization for Hopper & Blackwell (SM90+)
Expand Down
6 changes: 6 additions & 0 deletions test/examples/CuTeDSL/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@


project_root = Path(__file__).resolve().parent.parent.parent.parent

cute_example_path = project_root / "examples" / "python" / "CuTeDSL" / "cute"
example_path = project_root / "examples" / "python" / "CuTeDSL"
utils_path = project_root / "test" / "utils"

Expand All @@ -50,9 +52,11 @@
# Importing cutlass here, while sys.path is still clean, avoids that race.
import cutlass # noqa: E402 (intentional early import)

sys.path.append(str(cute_example_path))
sys.path.append(str(example_path))
sys.path.append(str(utils_path))


# The helper class to prevent modification of sys.path from test files
# Only allow modification of sys.path from pytest monkeypatch API calls
class ImmutableSysPath(list):
Expand All @@ -70,6 +74,7 @@ class ImmutableSysPath(list):
}

for mtd in mutating_methods:

def mutating_method(self, *args, mtd=mtd, **kwargs):
frame = sys._getframe().f_back
if (
Expand Down Expand Up @@ -98,6 +103,7 @@ def __init__(self, initial=None):

pytest_plugins = ["test_sharding"]


def pytest_addoption(parser):
parser.addoption(
"--sample-interval",
Expand Down
1 change: 1 addition & 0 deletions test/examples/CuTeDSL/hopper/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


def pytest_configure(config):
config.default_SMs[__file__] = "90a"
config.addinivalue_line(
Expand Down
171 changes: 135 additions & 36 deletions test/examples/CuTeDSL/hopper/test_dense_gemm_fp8_2xacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

import pytest
import cutlass
from hopper.dense_gemm_fp8_2xacc import run
from hopper.kernel.dense_gemm.dense_gemm_fp8_2xacc import run

# ---------------------------------------------------------------------------
# Type aliases
Expand Down Expand Up @@ -169,8 +169,8 @@ def _run_benchmark(
[
pytest.param((128, 256), (2048, 2048, 2048, 1), id="tile128x256"),
pytest.param((128, 128), (2048, 2048, 2048, 1), id="tile128x128"),
pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"),
pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"),
pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"),
pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"),
],
)
def test_l0_tile_shapes(tile_shape_mn, mnkl):
Expand All @@ -195,8 +195,11 @@ def test_l0_tile_shapes(tile_shape_mn, mnkl):
)
def test_l0_cluster_shapes(cluster_shape_mn):
"""All valid cluster shapes compile (tile 128x128, 2048^3)."""
_run_compile(mnkl=(2048, 2048, 2048, 1), tile_shape_mn=(128, 128),
cluster_shape_mn=cluster_shape_mn)
_run_compile(
mnkl=(2048, 2048, 2048, 1),
tile_shape_mn=(128, 128),
cluster_shape_mn=cluster_shape_mn,
)


# ---------------------------------------------------------------------------
Expand All @@ -208,8 +211,8 @@ def test_l0_cluster_shapes(cluster_shape_mn):
@pytest.mark.parametrize(
"c_dtype",
[
pytest.param(F16, id="Float16"),
pytest.param(F32, id="Float32"),
pytest.param(F16, id="Float16"),
pytest.param(F32, id="Float32"),
pytest.param(F8E4, id="Float8E4M3FN"),
],
)
Expand All @@ -227,8 +230,8 @@ def test_l0_output_dtypes(c_dtype):
@pytest.mark.parametrize(
"mma_promotion_interval",
[
pytest.param(4, id="interval4"),
pytest.param(8, id="interval8"),
pytest.param(4, id="interval4"),
pytest.param(8, id="interval8"),
pytest.param(16, id="interval16"),
],
)
Expand All @@ -249,8 +252,8 @@ def test_l0_mma_promotion_intervals(mma_promotion_interval):
[
pytest.param((128, 256), (2048, 2048, 2048, 1), id="tile128x256"),
pytest.param((128, 128), (2048, 2048, 2048, 1), id="tile128x128"),
pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"),
pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"),
pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"),
pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"),
],
)
def test_l1_tile_shapes(tile_shape_mn, mnkl):
Expand All @@ -276,8 +279,11 @@ def test_l1_tile_shapes(tile_shape_mn, mnkl):
)
def test_l1_cluster_shapes(cluster_shape_mn):
"""All cluster shapes (including A/B multicast paths) produce correct results."""
_run_correctness(mnkl=(2048, 2048, 2048, 1), tile_shape_mn=(128, 128),
cluster_shape_mn=cluster_shape_mn)
_run_correctness(
mnkl=(2048, 2048, 2048, 1),
tile_shape_mn=(128, 128),
cluster_shape_mn=cluster_shape_mn,
)


# ---------------------------------------------------------------------------
Expand All @@ -290,9 +296,9 @@ def test_l1_cluster_shapes(cluster_shape_mn):
@pytest.mark.parametrize(
"c_dtype, tolerance",
[
pytest.param(F16, 0.1, id="Float16"),
pytest.param(F32, 0.1, id="Float32"),
pytest.param(F8E4, 0.5, id="Float8E4M3FN"),
pytest.param(F16, 0.1, id="Float16"),
pytest.param(F32, 0.1, id="Float32"),
pytest.param(F8E4, 0.5, id="Float8E4M3FN"),
],
)
def test_l1_output_dtypes(c_dtype, tolerance):
Expand All @@ -310,8 +316,8 @@ def test_l1_output_dtypes(c_dtype, tolerance):
@pytest.mark.parametrize(
"mma_promotion_interval",
[
pytest.param(4, id="interval4"),
pytest.param(8, id="interval8"),
pytest.param(4, id="interval4"),
pytest.param(8, id="interval8"),
pytest.param(16, id="interval16"),
],
)
Expand All @@ -330,9 +336,9 @@ def test_l1_mma_promotion_intervals(mma_promotion_interval):
@pytest.mark.parametrize(
"scale_a_val, scale_b_val",
[
pytest.param(0.5, 2.0, id="scale_a0.5_b2.0"),
pytest.param(0.25, 4.0, id="scale_a0.25_b4.0"),
pytest.param(2.0, 0.5, id="scale_a2.0_b0.5"),
pytest.param(0.5, 2.0, id="scale_a0.5_b2.0"),
pytest.param(0.25, 4.0, id="scale_a0.25_b4.0"),
pytest.param(2.0, 0.5, id="scale_a2.0_b0.5"),
],
)
def test_l1_scale_factors(scale_a_val, scale_b_val):
Expand All @@ -351,7 +357,7 @@ def test_l1_scale_factors(scale_a_val, scale_b_val):
"mnkl",
[
pytest.param((1024, 1024, 1024, 2), id="L2"),
pytest.param((512, 512, 512, 4), id="L4"),
pytest.param((512, 512, 512, 4), id="L4"),
],
)
def test_l1_batched(mnkl):
Expand All @@ -370,25 +376,118 @@ def test_l1_batched(mnkl):
"mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label",
[
# Square 4096^3 — tile / cluster sweep
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 1), 4, "4096^3 tile=128x128 cluster=1x1", id="4096-128x128-1x1"),
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 tile=128x128 cluster=1x2", id="4096-128x128-1x2"),
pytest.param((4096, 4096, 4096, 1), (128, 128), (2, 2), 4, "4096^3 tile=128x128 cluster=2x2", id="4096-128x128-2x2"),
pytest.param((4096, 4096, 4096, 1), (128, 256), (1, 2), 4, "4096^3 tile=128x256 cluster=1x2", id="4096-128x256-1x2"),
pytest.param((4096, 4096, 4096, 1), (128, 256), (2, 2), 4, "4096^3 tile=128x256 cluster=2x2", id="4096-128x256-2x2"),
pytest.param((4096, 4096, 4096, 1), (128, 64), (1, 2), 4, "4096^3 tile=128x64 cluster=1x2", id="4096-128x64-1x2"),
pytest.param((4096, 4096, 4096, 1), (64, 64), (1, 2), 4, "4096^3 tile=64x64 cluster=1x2", id="4096-64x64-1x2"),
pytest.param(
(4096, 4096, 4096, 1),
(128, 128),
(1, 1),
4,
"4096^3 tile=128x128 cluster=1x1",
id="4096-128x128-1x1",
),
pytest.param(
(4096, 4096, 4096, 1),
(128, 128),
(1, 2),
4,
"4096^3 tile=128x128 cluster=1x2",
id="4096-128x128-1x2",
),
pytest.param(
(4096, 4096, 4096, 1),
(128, 128),
(2, 2),
4,
"4096^3 tile=128x128 cluster=2x2",
id="4096-128x128-2x2",
),
pytest.param(
(4096, 4096, 4096, 1),
(128, 256),
(1, 2),
4,
"4096^3 tile=128x256 cluster=1x2",
id="4096-128x256-1x2",
),
pytest.param(
(4096, 4096, 4096, 1),
(128, 256),
(2, 2),
4,
"4096^3 tile=128x256 cluster=2x2",
id="4096-128x256-2x2",
),
pytest.param(
(4096, 4096, 4096, 1),
(128, 64),
(1, 2),
4,
"4096^3 tile=128x64 cluster=1x2",
id="4096-128x64-1x2",
),
pytest.param(
(4096, 4096, 4096, 1),
(64, 64),
(1, 2),
4,
"4096^3 tile=64x64 cluster=1x2",
id="4096-64x64-1x2",
),
# LLM-like: 8192x8192x4096
pytest.param((8192, 8192, 4096, 1), (128, 128), (1, 2), 4, "8192x8192x4096 tile=128x128 cluster=1x2", id="llm-128x128-1x2"),
pytest.param((8192, 8192, 4096, 1), (128, 256), (2, 2), 4, "8192x8192x4096 tile=128x256 cluster=2x2", id="llm-128x256-2x2"),
pytest.param(
(8192, 8192, 4096, 1),
(128, 128),
(1, 2),
4,
"8192x8192x4096 tile=128x128 cluster=1x2",
id="llm-128x128-1x2",
),
pytest.param(
(8192, 8192, 4096, 1),
(128, 256),
(2, 2),
4,
"8192x8192x4096 tile=128x256 cluster=2x2",
id="llm-128x256-2x2",
),
# mma_promotion_interval sweep (shows precision/performance trade-off)
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 interval=4", id="4096-interval4"),
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 8, "4096^3 interval=8", id="4096-interval8"),
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 16, "4096^3 interval=16", id="4096-interval16"),
pytest.param(
(4096, 4096, 4096, 1),
(128, 128),
(1, 2),
4,
"4096^3 interval=4",
id="4096-interval4",
),
pytest.param(
(4096, 4096, 4096, 1),
(128, 128),
(1, 2),
8,
"4096^3 interval=8",
id="4096-interval8",
),
pytest.param(
(4096, 4096, 4096, 1),
(128, 128),
(1, 2),
16,
"4096^3 interval=16",
id="4096-interval16",
),
# FP8 output
pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 out=FP8E4M3", id="4096-fp8-out"),
pytest.param(
(4096, 4096, 4096, 1),
(128, 128),
(1, 2),
4,
"4096^3 out=FP8E4M3",
id="4096-fp8-out",
),
],
)
def test_bench(mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label, capsys):
def test_bench(
mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label, capsys
):
"""
Performance benchmark — run with: pytest -m bench -s

Expand Down
Loading