diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/rmsnorm/rmsnorm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/rmsnorm/rmsnorm.py index 27165604c3..e771bb4d10 100644 --- a/examples/python/CuTeDSL/cute/blackwell/kernel/rmsnorm/rmsnorm.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/rmsnorm/rmsnorm.py @@ -30,6 +30,8 @@ import ctypes import functools import math +import os +import sys from typing import Optional, Tuple, Union import cuda.bindings.driver as cuda @@ -44,10 +46,11 @@ from cutlass.cute.runtime import make_ptr # Support both direct execution and module import -try: - from .reduce import row_reduce -except ImportError: - from reduce import row_reduce +if __name__ == "__main__": + current_dir = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, os.path.join(current_dir, "../../..")) + +from blackwell.kernel.reduce.reduce import row_reduce """ RMSNorm: Root Mean Square Layer Normalization for Hopper & Blackwell (SM90+) diff --git a/test/examples/CuTeDSL/conftest.py b/test/examples/CuTeDSL/conftest.py index 506ba1476f..28d75cc80c 100644 --- a/test/examples/CuTeDSL/conftest.py +++ b/test/examples/CuTeDSL/conftest.py @@ -39,6 +39,8 @@ project_root = Path(__file__).resolve().parent.parent.parent.parent + +cute_example_path = project_root / "examples" / "python" / "CuTeDSL" / "cute" example_path = project_root / "examples" / "python" / "CuTeDSL" utils_path = project_root / "test" / "utils" @@ -50,9 +52,11 @@ # Importing cutlass here, while sys.path is still clean, avoids that race. import cutlass # noqa: E402 (intentional early import) +sys.path.append(str(cute_example_path)) sys.path.append(str(example_path)) sys.path.append(str(utils_path)) + # The helper class to prevent modification of sys.path from test files # Only allow modification of sys.path from pytest monkeypatch API calls class ImmutableSysPath(list): @@ -70,6 +74,7 @@ class ImmutableSysPath(list): } for mtd in mutating_methods: + def mutating_method(self, *args, mtd=mtd, **kwargs): frame = sys._getframe().f_back if ( @@ -98,6 +103,7 @@ def __init__(self, initial=None): pytest_plugins = ["test_sharding"] + def pytest_addoption(parser): parser.addoption( "--sample-interval", diff --git a/test/examples/CuTeDSL/hopper/conftest.py b/test/examples/CuTeDSL/hopper/conftest.py index 0b648de36a..8dbc39c529 100644 --- a/test/examples/CuTeDSL/hopper/conftest.py +++ b/test/examples/CuTeDSL/hopper/conftest.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + def pytest_configure(config): config.default_SMs[__file__] = "90a" config.addinivalue_line( diff --git a/test/examples/CuTeDSL/hopper/test_dense_gemm_fp8_2xacc.py b/test/examples/CuTeDSL/hopper/test_dense_gemm_fp8_2xacc.py index 12d824c227..96e7c32534 100644 --- a/test/examples/CuTeDSL/hopper/test_dense_gemm_fp8_2xacc.py +++ b/test/examples/CuTeDSL/hopper/test_dense_gemm_fp8_2xacc.py @@ -55,7 +55,7 @@ import pytest import cutlass -from hopper.dense_gemm_fp8_2xacc import run +from hopper.kernel.dense_gemm.dense_gemm_fp8_2xacc import run # --------------------------------------------------------------------------- # Type aliases @@ -169,8 +169,8 @@ def _run_benchmark( [ pytest.param((128, 256), (2048, 2048, 2048, 1), id="tile128x256"), pytest.param((128, 128), (2048, 2048, 2048, 1), id="tile128x128"), - pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"), - pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"), + pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"), + pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"), ], ) def test_l0_tile_shapes(tile_shape_mn, mnkl): @@ -195,8 +195,11 @@ def test_l0_tile_shapes(tile_shape_mn, mnkl): ) def test_l0_cluster_shapes(cluster_shape_mn): """All valid cluster shapes compile (tile 128x128, 2048^3).""" - _run_compile(mnkl=(2048, 2048, 2048, 1), tile_shape_mn=(128, 128), - cluster_shape_mn=cluster_shape_mn) + _run_compile( + mnkl=(2048, 2048, 2048, 1), + tile_shape_mn=(128, 128), + cluster_shape_mn=cluster_shape_mn, + ) # --------------------------------------------------------------------------- @@ -208,8 +211,8 @@ def test_l0_cluster_shapes(cluster_shape_mn): @pytest.mark.parametrize( "c_dtype", [ - pytest.param(F16, id="Float16"), - pytest.param(F32, id="Float32"), + pytest.param(F16, id="Float16"), + pytest.param(F32, id="Float32"), pytest.param(F8E4, id="Float8E4M3FN"), ], ) @@ -227,8 +230,8 @@ def test_l0_output_dtypes(c_dtype): @pytest.mark.parametrize( "mma_promotion_interval", [ - pytest.param(4, id="interval4"), - pytest.param(8, id="interval8"), + pytest.param(4, id="interval4"), + pytest.param(8, id="interval8"), pytest.param(16, id="interval16"), ], ) @@ -249,8 +252,8 @@ def test_l0_mma_promotion_intervals(mma_promotion_interval): [ pytest.param((128, 256), (2048, 2048, 2048, 1), id="tile128x256"), pytest.param((128, 128), (2048, 2048, 2048, 1), id="tile128x128"), - pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"), - pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"), + pytest.param((128, 64), (2048, 2048, 2048, 1), id="tile128x64"), + pytest.param((64, 64), (2048, 2048, 2048, 1), id="tile64x64"), ], ) def test_l1_tile_shapes(tile_shape_mn, mnkl): @@ -276,8 +279,11 @@ def test_l1_tile_shapes(tile_shape_mn, mnkl): ) def test_l1_cluster_shapes(cluster_shape_mn): """All cluster shapes (including A/B multicast paths) produce correct results.""" - _run_correctness(mnkl=(2048, 2048, 2048, 1), tile_shape_mn=(128, 128), - cluster_shape_mn=cluster_shape_mn) + _run_correctness( + mnkl=(2048, 2048, 2048, 1), + tile_shape_mn=(128, 128), + cluster_shape_mn=cluster_shape_mn, + ) # --------------------------------------------------------------------------- @@ -290,9 +296,9 @@ def test_l1_cluster_shapes(cluster_shape_mn): @pytest.mark.parametrize( "c_dtype, tolerance", [ - pytest.param(F16, 0.1, id="Float16"), - pytest.param(F32, 0.1, id="Float32"), - pytest.param(F8E4, 0.5, id="Float8E4M3FN"), + pytest.param(F16, 0.1, id="Float16"), + pytest.param(F32, 0.1, id="Float32"), + pytest.param(F8E4, 0.5, id="Float8E4M3FN"), ], ) def test_l1_output_dtypes(c_dtype, tolerance): @@ -310,8 +316,8 @@ def test_l1_output_dtypes(c_dtype, tolerance): @pytest.mark.parametrize( "mma_promotion_interval", [ - pytest.param(4, id="interval4"), - pytest.param(8, id="interval8"), + pytest.param(4, id="interval4"), + pytest.param(8, id="interval8"), pytest.param(16, id="interval16"), ], ) @@ -330,9 +336,9 @@ def test_l1_mma_promotion_intervals(mma_promotion_interval): @pytest.mark.parametrize( "scale_a_val, scale_b_val", [ - pytest.param(0.5, 2.0, id="scale_a0.5_b2.0"), - pytest.param(0.25, 4.0, id="scale_a0.25_b4.0"), - pytest.param(2.0, 0.5, id="scale_a2.0_b0.5"), + pytest.param(0.5, 2.0, id="scale_a0.5_b2.0"), + pytest.param(0.25, 4.0, id="scale_a0.25_b4.0"), + pytest.param(2.0, 0.5, id="scale_a2.0_b0.5"), ], ) def test_l1_scale_factors(scale_a_val, scale_b_val): @@ -351,7 +357,7 @@ def test_l1_scale_factors(scale_a_val, scale_b_val): "mnkl", [ pytest.param((1024, 1024, 1024, 2), id="L2"), - pytest.param((512, 512, 512, 4), id="L4"), + pytest.param((512, 512, 512, 4), id="L4"), ], ) def test_l1_batched(mnkl): @@ -370,25 +376,118 @@ def test_l1_batched(mnkl): "mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label", [ # Square 4096^3 — tile / cluster sweep - pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 1), 4, "4096^3 tile=128x128 cluster=1x1", id="4096-128x128-1x1"), - pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 tile=128x128 cluster=1x2", id="4096-128x128-1x2"), - pytest.param((4096, 4096, 4096, 1), (128, 128), (2, 2), 4, "4096^3 tile=128x128 cluster=2x2", id="4096-128x128-2x2"), - pytest.param((4096, 4096, 4096, 1), (128, 256), (1, 2), 4, "4096^3 tile=128x256 cluster=1x2", id="4096-128x256-1x2"), - pytest.param((4096, 4096, 4096, 1), (128, 256), (2, 2), 4, "4096^3 tile=128x256 cluster=2x2", id="4096-128x256-2x2"), - pytest.param((4096, 4096, 4096, 1), (128, 64), (1, 2), 4, "4096^3 tile=128x64 cluster=1x2", id="4096-128x64-1x2"), - pytest.param((4096, 4096, 4096, 1), (64, 64), (1, 2), 4, "4096^3 tile=64x64 cluster=1x2", id="4096-64x64-1x2"), + pytest.param( + (4096, 4096, 4096, 1), + (128, 128), + (1, 1), + 4, + "4096^3 tile=128x128 cluster=1x1", + id="4096-128x128-1x1", + ), + pytest.param( + (4096, 4096, 4096, 1), + (128, 128), + (1, 2), + 4, + "4096^3 tile=128x128 cluster=1x2", + id="4096-128x128-1x2", + ), + pytest.param( + (4096, 4096, 4096, 1), + (128, 128), + (2, 2), + 4, + "4096^3 tile=128x128 cluster=2x2", + id="4096-128x128-2x2", + ), + pytest.param( + (4096, 4096, 4096, 1), + (128, 256), + (1, 2), + 4, + "4096^3 tile=128x256 cluster=1x2", + id="4096-128x256-1x2", + ), + pytest.param( + (4096, 4096, 4096, 1), + (128, 256), + (2, 2), + 4, + "4096^3 tile=128x256 cluster=2x2", + id="4096-128x256-2x2", + ), + pytest.param( + (4096, 4096, 4096, 1), + (128, 64), + (1, 2), + 4, + "4096^3 tile=128x64 cluster=1x2", + id="4096-128x64-1x2", + ), + pytest.param( + (4096, 4096, 4096, 1), + (64, 64), + (1, 2), + 4, + "4096^3 tile=64x64 cluster=1x2", + id="4096-64x64-1x2", + ), # LLM-like: 8192x8192x4096 - pytest.param((8192, 8192, 4096, 1), (128, 128), (1, 2), 4, "8192x8192x4096 tile=128x128 cluster=1x2", id="llm-128x128-1x2"), - pytest.param((8192, 8192, 4096, 1), (128, 256), (2, 2), 4, "8192x8192x4096 tile=128x256 cluster=2x2", id="llm-128x256-2x2"), + pytest.param( + (8192, 8192, 4096, 1), + (128, 128), + (1, 2), + 4, + "8192x8192x4096 tile=128x128 cluster=1x2", + id="llm-128x128-1x2", + ), + pytest.param( + (8192, 8192, 4096, 1), + (128, 256), + (2, 2), + 4, + "8192x8192x4096 tile=128x256 cluster=2x2", + id="llm-128x256-2x2", + ), # mma_promotion_interval sweep (shows precision/performance trade-off) - pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 interval=4", id="4096-interval4"), - pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 8, "4096^3 interval=8", id="4096-interval8"), - pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 16, "4096^3 interval=16", id="4096-interval16"), + pytest.param( + (4096, 4096, 4096, 1), + (128, 128), + (1, 2), + 4, + "4096^3 interval=4", + id="4096-interval4", + ), + pytest.param( + (4096, 4096, 4096, 1), + (128, 128), + (1, 2), + 8, + "4096^3 interval=8", + id="4096-interval8", + ), + pytest.param( + (4096, 4096, 4096, 1), + (128, 128), + (1, 2), + 16, + "4096^3 interval=16", + id="4096-interval16", + ), # FP8 output - pytest.param((4096, 4096, 4096, 1), (128, 128), (1, 2), 4, "4096^3 out=FP8E4M3", id="4096-fp8-out"), + pytest.param( + (4096, 4096, 4096, 1), + (128, 128), + (1, 2), + 4, + "4096^3 out=FP8E4M3", + id="4096-fp8-out", + ), ], ) -def test_bench(mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label, capsys): +def test_bench( + mnkl, tile_shape_mn, cluster_shape_mn, mma_promotion_interval, label, capsys +): """ Performance benchmark — run with: pytest -m bench -s diff --git a/test/examples/CuTeDSL/hopper/test_grouped_gemm.py b/test/examples/CuTeDSL/hopper/test_grouped_gemm.py index e38d4084fd..1b15febcde 100644 --- a/test/examples/CuTeDSL/hopper/test_grouped_gemm.py +++ b/test/examples/CuTeDSL/hopper/test_grouped_gemm.py @@ -59,7 +59,7 @@ import pytest import cutlass import cutlass.utils as utils -from hopper.grouped_gemm import run +from hopper.kernel.grouped_gemm.grouped_gemm import run # --------------------------------------------------------------------------- @@ -203,8 +203,8 @@ def _run_case( [ pytest.param((128, 256), [(128, 256, 64, 1)], id="tile128x256"), pytest.param((128, 128), [(128, 128, 64, 1)], id="tile128x128"), - pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"), - pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"), + pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"), + pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"), ], ) def test_l0_tile_shapes(tile_shape_mn, problem_sizes_mnkl, tmap_mode): @@ -222,15 +222,20 @@ def test_l0_tile_shapes(tile_shape_mn, problem_sizes_mnkl, tmap_mode): @pytest.mark.parametrize( "num_groups, problem_sizes_mnkl", [ - pytest.param(2, [(128, 256, 64, 1)] * 2, id="2g-uniform"), - pytest.param(4, [(128, 256, 64, 1), (64, 128, 64, 1), - (256, 128, 64, 1), (192, 256, 64, 1)], id="4g-mixed"), - pytest.param(8, [(128, 256, 64, 1)] * 8, id="8g-uniform"), + pytest.param(2, [(128, 256, 64, 1)] * 2, id="2g-uniform"), + pytest.param( + 4, + [(128, 256, 64, 1), (64, 128, 64, 1), (256, 128, 64, 1), (192, 256, 64, 1)], + id="4g-mixed", + ), + pytest.param(8, [(128, 256, 64, 1)] * 8, id="8g-uniform"), ], ) def test_l0_group_counts(num_groups, problem_sizes_mnkl, tmap_mode): """Various group counts compile for tile (128,256) fp16.""" - _run_compile(num_groups, problem_sizes_mnkl, (128, 256), tensormap_update_mode=tmap_mode) + _run_compile( + num_groups, problem_sizes_mnkl, (128, 256), tensormap_update_mode=tmap_mode + ) # --------------------------------------------------------------------------- @@ -244,28 +249,43 @@ def test_l0_group_counts(num_groups, problem_sizes_mnkl, tmap_mode): "a_dtype, b_dtype, c_dtype, acc_dtype, problem_sizes_mnkl", [ # fp16 → fp16 output - pytest.param(F16, F16, F16, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp32"), + pytest.param(F16, F16, F16, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp32"), # fp16 → fp32 output - pytest.param(F16, F16, F32, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp32-fp32"), + pytest.param(F16, F16, F32, F32, [(128, 256, 64, 1)], id="fp16-fp16-fp32-fp32"), # fp16 with fp16 accumulator - pytest.param(F16, F16, F16, F16, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp16"), + pytest.param(F16, F16, F16, F16, [(128, 256, 64, 1)], id="fp16-fp16-fp16-fp16"), # fp8 E4M3 → fp16 output (K must be multiple of 16 for fp8 alignment) - pytest.param(F8E4, F8E4, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e4-fp16-fp32"), + pytest.param( + F8E4, F8E4, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e4-fp16-fp32" + ), # fp8 E5M2 → fp16 output - pytest.param(F8E5, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e5-fp8e5-fp16-fp32"), + pytest.param( + F8E5, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e5-fp8e5-fp16-fp32" + ), # mixed fp8: E4M3 × E5M2 - pytest.param(F8E4, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e5-fp16-fp32"), + pytest.param( + F8E4, F8E5, F16, F32, [(128, 256, 128, 1)], id="fp8e4-fp8e5-fp16-fp32" + ), # int8 → int32 output (K must be multiple of 16) - pytest.param(I8, I8, I32, I32, [(128, 256, 128, 1)], id="int8-int8-int32-int32"), + pytest.param( + I8, I8, I32, I32, [(128, 256, 128, 1)], id="int8-int8-int32-int32" + ), # uint8 → int32 output - pytest.param(U8, U8, I32, I32, [(128, 256, 128, 1)], id="uint8-uint8-int32-int32"), + pytest.param( + U8, U8, I32, I32, [(128, 256, 128, 1)], id="uint8-uint8-int32-int32" + ), ], ) def test_l0_dtypes(a_dtype, b_dtype, c_dtype, acc_dtype, problem_sizes_mnkl, tmap_mode): """Data type combinations compile for tile (128,256).""" _run_compile( - 1, problem_sizes_mnkl, (128, 256), - a_dtype=a_dtype, b_dtype=b_dtype, c_dtype=c_dtype, acc_dtype=acc_dtype, + 1, + problem_sizes_mnkl, + (128, 256), + a_dtype=a_dtype, + b_dtype=b_dtype, + c_dtype=c_dtype, + acc_dtype=acc_dtype, tensormap_update_mode=tmap_mode, ) @@ -289,14 +309,22 @@ def test_l0_dtypes(a_dtype, b_dtype, c_dtype, acc_dtype, problem_sizes_mnkl, tma # m-major C output (M must be multiple of 8) pytest.param("k", "k", "m", [(128, 256, 64, 1)], (128, 256), id="akm-bkm-cmaj"), # m-major A + n-major B - pytest.param("m", "n", "n", [(128, 128, 64, 1)], (128, 128), id="amaj-bnmaj-cn"), + pytest.param( + "m", "n", "n", [(128, 128, 64, 1)], (128, 128), id="amaj-bnmaj-cn" + ), ], ) -def test_l0_major_modes(a_major, b_major, c_major, problem_sizes_mnkl, tile_shape_mn, tmap_mode): +def test_l0_major_modes( + a_major, b_major, c_major, problem_sizes_mnkl, tile_shape_mn, tmap_mode +): """Matrix major mode combinations compile.""" _run_compile( - 1, problem_sizes_mnkl, tile_shape_mn, - a_major=a_major, b_major=b_major, c_major=c_major, + 1, + problem_sizes_mnkl, + tile_shape_mn, + a_major=a_major, + b_major=b_major, + c_major=c_major, tensormap_update_mode=tmap_mode, ) @@ -321,10 +349,14 @@ def test_l0_major_modes(a_major, b_major, c_major, problem_sizes_mnkl, tile_shap pytest.param((2, 2), [(256, 512, 64, 1)], (128, 256), id="cluster2x2"), ], ) -def test_l0_cluster_shapes(cluster_shape_mn, problem_sizes_mnkl, tile_shape_mn, tmap_mode): +def test_l0_cluster_shapes( + cluster_shape_mn, problem_sizes_mnkl, tile_shape_mn, tmap_mode +): """Cluster shapes including multicast paths compile.""" _run_compile( - 1, problem_sizes_mnkl, tile_shape_mn, + 1, + problem_sizes_mnkl, + tile_shape_mn, cluster_shape_mn=cluster_shape_mn, tensormap_update_mode=tmap_mode, ) @@ -341,18 +373,20 @@ def test_l0_cluster_shapes(cluster_shape_mn, problem_sizes_mnkl, tile_shape_mn, "num_groups, problem_sizes_mnkl", [ # groups with very different shapes - pytest.param(4, [(64, 64, 64, 1), - (128, 128, 64, 1), - (256, 128, 64, 1), - (128, 256, 64, 1)], id="4g-all-tiles"), + pytest.param( + 4, + [(64, 64, 64, 1), (128, 128, 64, 1), (256, 128, 64, 1), (128, 256, 64, 1)], + id="4g-all-tiles", + ), # tiny vs large - pytest.param(2, [(64, 64, 64, 1), - (512, 512, 64, 1)], id="2g-tiny-large"), + pytest.param(2, [(64, 64, 64, 1), (512, 512, 64, 1)], id="2g-tiny-large"), ], ) def test_l0_mixed_problem_sizes(num_groups, problem_sizes_mnkl, tmap_mode): """Heterogeneous per-group problem sizes compile.""" - _run_compile(num_groups, problem_sizes_mnkl, (128, 256), tensormap_update_mode=tmap_mode) + _run_compile( + num_groups, problem_sizes_mnkl, (128, 256), tensormap_update_mode=tmap_mode + ) # --------------------------------------------------------------------------- @@ -386,13 +420,15 @@ def test_l1_fp16_4g_mixed(tmap_mode): [ pytest.param((128, 256), [(128, 256, 64, 1)], id="tile128x256"), pytest.param((128, 128), [(128, 128, 64, 1)], id="tile128x128"), - pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"), - pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"), + pytest.param((128, 64), [(128, 64, 64, 1)], id="tile128x64"), + pytest.param((64, 64), [(64, 64, 64, 1)], id="tile64x64"), ], ) def test_l1_tile_shapes_fp16(tile_shape_mn, problem_sizes_mnkl, tmap_mode): """All tile shapes produce correct results.""" - _run_correctness(1, problem_sizes_mnkl, tile_shape_mn, tensormap_update_mode=tmap_mode) + _run_correctness( + 1, problem_sizes_mnkl, tile_shape_mn, tensormap_update_mode=tmap_mode + ) # --------------------------------------------------------------------------- @@ -429,8 +465,11 @@ def test_l1_group_count_scaling(num_groups, tmap_mode): def test_l1_fp16_c_fp32(tmap_mode): """fp16 inputs with fp32 output are numerically correct.""" _run_correctness( - 1, [(128, 256, 64, 1)], (128, 256), - c_dtype=F32, acc_dtype=F32, + 1, + [(128, 256, 64, 1)], + (128, 256), + c_dtype=F32, + acc_dtype=F32, tensormap_update_mode=tmap_mode, ) @@ -441,8 +480,13 @@ def test_l1_fp16_c_fp32(tmap_mode): def test_l1_fp8_e4m3(tmap_mode): """fp8 E4M3FN inputs are numerically correct (K=128 for 16B alignment).""" _run_correctness( - 1, [(128, 256, 128, 1)], (128, 256), - a_dtype=F8E4, b_dtype=F8E4, c_dtype=F16, acc_dtype=F32, + 1, + [(128, 256, 128, 1)], + (128, 256), + a_dtype=F8E4, + b_dtype=F8E4, + c_dtype=F16, + acc_dtype=F32, tensormap_update_mode=tmap_mode, tolerance=0.5, ) @@ -454,8 +498,13 @@ def test_l1_fp8_e4m3(tmap_mode): def test_l1_fp8_mixed(tmap_mode): """Mixed fp8 inputs (E4M3 × E5M2) are numerically correct.""" _run_correctness( - 1, [(128, 256, 128, 1)], (128, 256), - a_dtype=F8E4, b_dtype=F8E5, c_dtype=F16, acc_dtype=F32, + 1, + [(128, 256, 128, 1)], + (128, 256), + a_dtype=F8E4, + b_dtype=F8E5, + c_dtype=F16, + acc_dtype=F32, tensormap_update_mode=tmap_mode, tolerance=0.5, ) @@ -467,8 +516,13 @@ def test_l1_fp8_mixed(tmap_mode): def test_l1_int8(tmap_mode): """int8 inputs with int32 accumulator are correct.""" _run_correctness( - 1, [(128, 256, 128, 1)], (128, 256), - a_dtype=I8, b_dtype=I8, c_dtype=I32, acc_dtype=I32, + 1, + [(128, 256, 128, 1)], + (128, 256), + a_dtype=I8, + b_dtype=I8, + c_dtype=I32, + acc_dtype=I32, tensormap_update_mode=tmap_mode, tolerance=0, ) @@ -485,7 +539,9 @@ def test_l1_int8(tmap_mode): def test_l1_c_m_major(tmap_mode): """m-major C output is correct.""" _run_correctness( - 1, [(128, 256, 64, 1)], (128, 256), + 1, + [(128, 256, 64, 1)], + (128, 256), c_major="m", tensormap_update_mode=tmap_mode, ) @@ -498,8 +554,12 @@ def test_l1_c_m_major(tmap_mode): def test_l1_all_non_default_majors(tmap_mode): """m-major A, n-major B, m-major C together are correct.""" _run_correctness( - 1, [(64, 64, 64, 1)], (128, 128), - a_major="m", b_major="n", c_major="m", + 1, + [(64, 64, 64, 1)], + (128, 128), + a_major="m", + b_major="n", + c_major="m", tensormap_update_mode=tmap_mode, ) @@ -521,7 +581,9 @@ def test_l1_all_non_default_majors(tmap_mode): def test_l1_cluster_shapes(cluster_shape_mn, problem_sizes_mnkl, tmap_mode): """Multicast cluster shapes produce correct results.""" _run_correctness( - 1, problem_sizes_mnkl, (128, 256), + 1, + problem_sizes_mnkl, + (128, 256), cluster_shape_mn=cluster_shape_mn, tensormap_update_mode=tmap_mode, ) @@ -540,14 +602,14 @@ def test_l1_8g_mixed_sizes(tmap_mode): _run_correctness( 8, [ - (128, 256, 64, 1), - (64, 128, 64, 1), - (256, 128, 64, 1), + (128, 256, 64, 1), + (64, 128, 64, 1), + (256, 128, 64, 1), (128, 128, 128, 1), - (192, 256, 64, 1), - (64, 64, 64, 1), + (192, 256, 64, 1), + (64, 64, 64, 1), (128, 256, 128, 1), - (256, 256, 64, 1), + (256, 256, 64, 1), ], (128, 256), tensormap_update_mode=tmap_mode, diff --git a/test/examples/CuTeDSL/sm_100a/conftest.py b/test/examples/CuTeDSL/sm_100a/conftest.py index 6a2f300232..f263300a6c 100644 --- a/test/examples/CuTeDSL/sm_100a/conftest.py +++ b/test/examples/CuTeDSL/sm_100a/conftest.py @@ -26,5 +26,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + def pytest_configure(config): - config.default_SMs[__file__] = "100f" \ No newline at end of file + config.default_SMs[__file__] = "100f" diff --git a/test/examples/CuTeDSL/sm_100a/test_dense_blockscaled_gemm_persistent_prefetch.py b/test/examples/CuTeDSL/sm_100a/test_dense_blockscaled_gemm_persistent_prefetch.py index 116c796398..e6f5b8eac0 100644 --- a/test/examples/CuTeDSL/sm_100a/test_dense_blockscaled_gemm_persistent_prefetch.py +++ b/test/examples/CuTeDSL/sm_100a/test_dense_blockscaled_gemm_persistent_prefetch.py @@ -41,29 +41,33 @@ import pytest -from blackwell.dense_blockscaled_gemm_persistent_prefetch import ( +from blackwell.kernel.blockscaled_gemm.dense_blockscaled_gemm_persistent_prefetch import ( Sm100BlockScaledPersistentDenseGemmKernel, run, ) import cutlass + pytestmark = [pytest.mark.arch(["100a"])] + @pytest.mark.invalid_case( - lambda: not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( - ab_dtype, - sf_dtype, - sf_vec_size, - c_dtype, - mma_tiler_mn, - cluster_shape_mn, - mnkl[0], - mnkl[1], - mnkl[2], - mnkl[3], - a_major, - b_major, - c_major, + lambda: ( + not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + mma_tiler_mn, + cluster_shape_mn, + mnkl[0], + mnkl[1], + mnkl[2], + mnkl[3], + a_major, + b_major, + c_major, + ) ) ) @pytest.mark.parametrize( @@ -110,7 +114,7 @@ "prefetch_dist", [ None, # Default: auto (uses num_ab_stage) - 0, # Disabled + 0, # Disabled ], ) @pytest.mark.parametrize("tolerance", [1e-01]) @@ -145,20 +149,22 @@ def test_dense_blockscaled_gemm_prefetch( @pytest.mark.invalid_case( - lambda: not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( - ab_dtype, - sf_dtype, - sf_vec_size, - c_dtype, - mma_tiler_mn, - cluster_shape_mn, - mnkl[0], - mnkl[1], - mnkl[2], - mnkl[3], - a_major, - b_major, - c_major, + lambda: ( + not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + mma_tiler_mn, + cluster_shape_mn, + mnkl[0], + mnkl[1], + mnkl[2], + mnkl[3], + a_major, + b_major, + c_major, + ) ) ) @pytest.mark.parametrize( @@ -190,7 +196,7 @@ def test_dense_blockscaled_gemm_prefetch( "prefetch_dist", [ None, # Default: auto (uses num_ab_stage) - 4, # Explicit distance + 4, # Explicit distance ], ) @pytest.mark.parametrize("tolerance", [1e-01]) @@ -228,15 +234,15 @@ def test_dense_blockscaled_gemm_prefetch_L0( "prefetch_dist", [ None, # Auto: uses num_ab_stage - 0, # Disabled - 2, # Small distance - 4, # Medium distance + 0, # Disabled + 2, # Small distance + 4, # Medium distance ], ) def test_prefetch_dist_configurations(prefetch_dist: Optional[int]): """ Test different prefetch_dist configurations specifically for blockscaled GEMM. - + - None: Auto mode, uses num_ab_stage as prefetch distance - 0: Prefetch disabled - >0: Explicit prefetch distance @@ -451,4 +457,3 @@ def test_invalid_tensor_alignment( cluster_shape_mn, tolerance, ) - diff --git a/test/examples/CuTeDSL/sm_100a/test_dense_gemm_persistent_prefetch.py b/test/examples/CuTeDSL/sm_100a/test_dense_gemm_persistent_prefetch.py index 51cc59d64e..5bc29eaacf 100644 --- a/test/examples/CuTeDSL/sm_100a/test_dense_gemm_persistent_prefetch.py +++ b/test/examples/CuTeDSL/sm_100a/test_dense_gemm_persistent_prefetch.py @@ -40,8 +40,7 @@ import pytest -from blackwell.dense_gemm_persistent_prefetch import ( - PersistentDenseGemmKernel, +from blackwell.kernel.dense_gemm.dense_gemm_persistent_prefetch import ( run, ) @@ -92,8 +91,8 @@ "prefetch_dist", [ None, # Default: auto (uses num_ab_stage) - 0, # Disabled - 2, # Explicit distance + 0, # Disabled + 2, # Explicit distance ], ) @pytest.mark.parametrize("tolerance", [1e-01]) @@ -168,7 +167,7 @@ def test_dense_gemm_prefetch( "prefetch_dist", [ None, # Default: auto (uses num_ab_stage) - 4, # Explicit distance + 4, # Explicit distance ], ) def test_dense_gemm_prefetch_L0( @@ -215,15 +214,15 @@ def test_dense_gemm_prefetch_L0( "prefetch_dist", [ None, # Auto: uses num_ab_stage - 0, # Disabled - 2, # Small distance - 4, # Medium distance + 0, # Disabled + 2, # Small distance + 4, # Medium distance ], ) def test_prefetch_dist_configurations(prefetch_dist: Optional[int]): """ Test different prefetch_dist configurations specifically. - + - None: Auto mode, uses num_ab_stage as prefetch distance - 0: Prefetch disabled - >0: Explicit prefetch distance @@ -259,4 +258,3 @@ def test_prefetch_dist_configurations(prefetch_dist: Optional[int]): ) except testing.CantImplementError: pytest.skip(f"Skip unsupported testcase with prefetch_dist={prefetch_dist}") - diff --git a/test/examples/CuTeDSL/sm_100a/test_rmsnorm.py b/test/examples/CuTeDSL/sm_100a/test_rmsnorm.py index 2b23a4e958..a118566f9e 100644 --- a/test/examples/CuTeDSL/sm_100a/test_rmsnorm.py +++ b/test/examples/CuTeDSL/sm_100a/test_rmsnorm.py @@ -38,11 +38,10 @@ """ import pytest -import torch import cutlass -from blackwell.rmsnorm import ( +from blackwell.kernel.rmsnorm.rmsnorm import ( run, get_sm_version, supports_cluster, @@ -104,7 +103,9 @@ def test_rmsnorm_without_weight(self, N): class TestRMSNormClusterPath: """Test the cluster path for large N (SM90+/SM100 only).""" - @pytest.mark.skipif(not supports_cluster(), reason="Cluster not supported on this GPU") + @pytest.mark.skipif( + not supports_cluster(), reason="Cluster not supported on this GPU" + ) @pytest.mark.parametrize("N", [32768, 65536]) def test_cluster_path_correctness(self, N): """Test cluster path produces correct results.""" @@ -119,6 +120,7 @@ def test_cluster_path_correctness(self, N): benchmark=False, ) + class TestRMSNormLargeN: """Test RMSNorm with large N values.""" @@ -151,7 +153,6 @@ def test_large_batch_dim(self, M): ) - class TestRMSNormEdgeCases: """Test edge cases for RMSNorm.""" @@ -197,4 +198,4 @@ def test_float32_correctness(self, N): tolerance=1e-4, # Tighter tolerance for FP32 skip_ref_check=False, benchmark=False, - ) \ No newline at end of file + ) diff --git a/test/examples/CuTeDSL/sm_100a/test_tutorial_gemm.py b/test/examples/CuTeDSL/sm_100a/test_tutorial_gemm.py index 291d1e593f..512aad09ff 100644 --- a/test/examples/CuTeDSL/sm_100a/test_tutorial_gemm.py +++ b/test/examples/CuTeDSL/sm_100a/test_tutorial_gemm.py @@ -26,14 +26,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from blackwell.tutorial_gemm import fp16_gemm_0 -from blackwell.tutorial_gemm import fp16_gemm_1 -from blackwell.tutorial_gemm import fp16_gemm_2 -from blackwell.tutorial_gemm import fp16_gemm_3 -from blackwell.tutorial_gemm import fp16_gemm_3_1 -from blackwell.tutorial_gemm import fp16_gemm_4 -from blackwell.tutorial_gemm import fp16_gemm_5 -from blackwell.tutorial_gemm import fp16_gemm_6 +from blackwell.tutorial.tutorial_gemm import fp16_gemm_0 +from blackwell.tutorial.tutorial_gemm import fp16_gemm_1 +from blackwell.tutorial.tutorial_gemm import fp16_gemm_2 +from blackwell.tutorial.tutorial_gemm import fp16_gemm_3 +from blackwell.tutorial.tutorial_gemm import fp16_gemm_3_1 +from blackwell.tutorial.tutorial_gemm import fp16_gemm_4 +from blackwell.tutorial.tutorial_gemm import fp16_gemm_5 +from blackwell.tutorial.tutorial_gemm import fp16_gemm_6 import pytest from typing import Tuple @@ -63,7 +63,6 @@ def test_fp16_gemm_1( fp16_gemm_1.run_dense_gemm(mnk, tolerance) - @pytest.mark.parametrize( "mnk", [(512, 512, 256)], diff --git a/test/examples/CuTeDSL/test_math.py b/test/examples/CuTeDSL/test_math.py index 9c63900f02..fc257cd804 100644 --- a/test/examples/CuTeDSL/test_math.py +++ b/test/examples/CuTeDSL/test_math.py @@ -36,8 +36,10 @@ @cute.kernel def _unary_ops_kernel( - absf_inp: cute.Tensor, absf_out: cute.Tensor, - floor_inp: cute.Tensor, floor_out: cute.Tensor, + absf_inp: cute.Tensor, + absf_out: cute.Tensor, + floor_inp: cute.Tensor, + floor_out: cute.Tensor, ): tidx, _, _ = cute.arch.thread_idx() absf_out[tidx] = cute.math.absf(absf_inp[tidx]) @@ -46,8 +48,10 @@ def _unary_ops_kernel( @cute.jit def _unary_ops_host( - absf_inp: cute.Tensor, absf_out: cute.Tensor, - floor_inp: cute.Tensor, floor_out: cute.Tensor, + absf_inp: cute.Tensor, + absf_out: cute.Tensor, + floor_inp: cute.Tensor, + floor_out: cute.Tensor, ): _unary_ops_kernel(absf_inp, absf_out, floor_inp, floor_out).launch( grid=[1, 1, 1], block=[absf_inp.shape[0], 1, 1] @@ -77,7 +81,9 @@ def test_unary_ops(): @cute.kernel def _binary_ops_kernel( - mag_inp: cute.Tensor, sign_inp: cute.Tensor, out: cute.Tensor, + mag_inp: cute.Tensor, + sign_inp: cute.Tensor, + out: cute.Tensor, ): tidx, _, _ = cute.arch.thread_idx() out[tidx] = cute.math.copysign(mag_inp[tidx], sign_inp[tidx]) @@ -85,7 +91,9 @@ def _binary_ops_kernel( @cute.jit def _binary_ops_host( - mag_inp: cute.Tensor, sign_inp: cute.Tensor, out: cute.Tensor, + mag_inp: cute.Tensor, + sign_inp: cute.Tensor, + out: cute.Tensor, ): _binary_ops_kernel(mag_inp, sign_inp, out).launch( grid=[1, 1, 1], block=[mag_inp.shape[0], 1, 1]