From fcdadf83fbe5ed6467fdc530bb22c6e25b15d360 Mon Sep 17 00:00:00 2001 From: Katja Sirazitdinova Date: Wed, 13 May 2026 15:27:45 +0000 Subject: [PATCH] Update CuTe DSL JAX tutorial --- .../dsl_tutorials/jax/cute_dsl_jax.ipynb | 1059 ++++++++++++----- .../dsl_tutorials/jax/cute_dsl_jax_kernels.py | 77 +- 2 files changed, 773 insertions(+), 363 deletions(-) diff --git a/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax.ipynb b/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax.ipynb index 70ef40b759..8e49f04934 100644 --- a/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax.ipynb +++ b/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "5dd6f35b", "metadata": {}, "source": [ "# Writing High-Performance GPU Kernels with CuTe DSL and JAX\n", @@ -24,6 +25,7 @@ }, { "cell_type": "markdown", + "id": "7f492727", "metadata": {}, "source": [ "## Introduction\n", @@ -37,6 +39,7 @@ }, { "cell_type": "markdown", + "id": "bf2a63ae", "metadata": {}, "source": [ "## The CuTe mental model\n", @@ -56,7 +59,7 @@ "\n", "**Stride** tells CuTe how far you move in memory when you step along each dimension. In a row-major `(4, 8)` matrix, memory is laid out row by row: the first 8 elements belong to row 0, the next 8 to row 1, and so on. Moving one column to the right simply advances to the next element in memory (stride 1). Moving one row down skips over an entire row of 8 elements (stride 8). So the stride is `(8, 1)`.\n", "\n", - "**Layout = (Shape, Stride)** is CuTe's central abstraction. Shape and stride must have the same rank — each logical dimension must have a corresponding stride. \n", + "**Layout = (Shape, Stride)** is CuTe's central abstraction. Shape and stride must have the same rank — each logical dimension must have a corresponding stride.\n", "\n", "Although we think of tensors as multi-dimensional, GPU memory itself is just a long one-dimensional array of elements. Given a coordinate, a layout tells you where that element lives in memory. It does this by combining the coordinate with the stride:\n", "\n", @@ -72,9 +75,9 @@ "\n", "One important thing to note here: in *row-major* layout, elements of each row are stored contiguously in memory (so columns vary fastest), whereas in *column-major* layout, elements of each column are stored contiguously (so rows vary fastest), meaning the logical shape stays the same but the stride — and therefore the memory access pattern — changes.\n", "\n", - "For example, with layout in row-major order `((4, 8), (8, 1))`, coordinate `(2, 5)` maps to offset `2*8 + 5*1 = 21`. \n", + "For example, with layout in row-major order `((4, 8), (8, 1))`, coordinate `(2, 5)` maps to offset `2*8 + 5*1 = 21`.\n", "\n", - "A column-major layout for the same shape would use stride `(1, 4)`, so the same coordinate maps to `2*1 + 5*4 = 22`. \n", + "A column-major layout for the same shape would use stride `(1, 4)`, so the same coordinate maps to `2*1 + 5*4 = 22`.\n", "\n", "The shape stays the same — only the stride changes.\n", "\n", @@ -90,6 +93,7 @@ }, { "cell_type": "markdown", + "id": "e2c456ed", "metadata": {}, "source": [ "## Hardware and software requirements\n", @@ -100,12 +104,12 @@ "| CUDA | 12.9 | 13.1 |\n", "| JAX | 0.8.1+ | Latest |\n", "| CUTLASS | 4.4+ (CuTe DSL) | Latest |\n", - "| Python | 3.10+ | 3.12 |\n", - "\n" + "| Python | 3.10+ | 3.12 |" ] }, { "cell_type": "markdown", + "id": "342b8ce6", "metadata": {}, "source": [ "First, let's check which GPUs are available in this environment. The `nvidia-smi` command shows the GPU model, driver version, CUDA toolkit version, and current memory usage." @@ -114,6 +118,7 @@ { "cell_type": "code", "execution_count": null, + "id": "5e61b94e", "metadata": {}, "outputs": [], "source": [ @@ -122,6 +127,7 @@ }, { "cell_type": "markdown", + "id": "4c8c09ad", "metadata": {}, "source": [ "We programmatically query the GPU's **compute capability** using `nvidia-smi`. This two-digit number (e.g., 9.0 for Hopper) tells us which hardware features are available. CuTe DSL requires SM 8.0 (Ampere) or newer." @@ -130,51 +136,58 @@ { "cell_type": "code", "execution_count": null, + "id": "aa073591", "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "\n", + "\n", "def get_compute_capability():\n", - " \"\"\"Query the compute capability of the first visible GPU.\"\"\"\n", - " out = subprocess.check_output(\n", - " [\"nvidia-smi\", \"--query-gpu=compute_cap\", \"--format=csv,noheader\"], text=True\n", - " )\n", - " major, minor = out.strip().split(\"\\n\")[0].split(\".\")\n", - " return int(major), int(minor)\n", + " \"\"\"Query the compute capability of the first visible GPU.\"\"\"\n", + " out = subprocess.check_output(\n", + " [\"nvidia-smi\", \"--query-gpu=compute_cap\", \"--format=csv,noheader\"],\n", + " text=True,\n", + " )\n", + " major, minor = out.strip().split(\"\\n\")[0].split(\".\")\n", + " return int(major), int(minor)\n", + "\n", "\n", "SM_MAJOR, SM_MINOR = get_compute_capability()\n", "print(f\"Detected compute capability: SM {SM_MAJOR}.{SM_MINOR}\")\n", "\n", "if SM_MAJOR < 8:\n", - " print(\"WARNING: CuTe DSL requires SM 8.0+ (Ampere or newer).\")\n", - " print(\"Some examples may not run on this GPU.\")\n", + " print(\"WARNING: CuTe DSL requires SM 8.0+ (Ampere or newer).\")\n", + " print(\"Some examples may not run on this GPU.\")\n", "else:\n", - " print(\"GPU is compatible with CuTe DSL.\")" + " print(\"GPU is compatible with CuTe DSL.\")" ] }, { "cell_type": "markdown", + "id": "fca7c082", "metadata": {}, "source": [ - "## Install CuTe DSL\n", + "## Install CuTe DSL and import dependencies\n", "\n", "The `nvidia-cutlass-dsl` package bundles CuTe DSL together with its JAX integration module (`cutlass.jax`). The `[cu13]` extra pulls in CUDA 13.x compatible runtime libraries. Version 4.4+ is required for the JAX integration.\n", "\n", - "Refer to the [official documentation](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html) for a more comprehansive installation guide." + "Refer to the [official documentation](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html) for a more comprehensive installation guide." ] }, { "cell_type": "code", "execution_count": null, + "id": "f9d3534e", "metadata": {}, "outputs": [], "source": [ - "%pip install \"nvidia-cutlass-dsl[cu13]==4.4.0.dev1\" --quiet" + "%pip install \"nvidia-cutlass-dsl[cu13]\" --quiet" ] }, { "cell_type": "markdown", + "id": "4fa5be3b", "metadata": {}, "source": [ "With CUTLASS installed, we import the libraries we'll use throughout the notebook: `cutlass` for kernel definitions, `jax` and `jnp` for array computation and JIT compilation, and `numpy` for result validation." @@ -183,26 +196,35 @@ { "cell_type": "code", "execution_count": null, + "id": "7a306776", "metadata": {}, "outputs": [], "source": [ "import os\n", - "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\" # suppress TF/XLA info & warnings\n", + "\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\" # suppress TF/XLA info & warnings\n", "os.environ[\"XLA_FLAGS\"] = \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n", "\n", "import cutlass\n", "from importlib.metadata import version as _pkg_version\n", + "\n", "print(f\"CUTLASS version: {_pkg_version('nvidia-cutlass-dsl')}\")\n", "\n", + "import cutlass.cute as cute\n", + "import cutlass.jax as cjax\n", + "import cuda.bindings.driver as cuda\n", + "\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", + "\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"JAX devices: {jax.devices()}\")" ] }, { "cell_type": "markdown", + "id": "f0d79747", "metadata": {}, "source": [ "## Defining kernels\n", @@ -211,94 +233,145 @@ "\n", "1. **`@cute.kernel`** defines the per-thread program — the sequence of instructions executed by each thread instance.\n", "2. **`@cute.jit`** compiles the kernel and specifies how it runs on the GPU: the grid (how many blocks), the block (how many threads per block), and the CUDA stream (which execution queue to launch into).\n", - " \n", + "\n", "CuTe DSL lowers Python kernels to CUDA/CUTLASS code and compiles them just-in-time using the CUTLASS JIT toolchain.\n", "\n", - "**Note:** CuTe DSL uses `inspect.getsourcelines()` to parse kernel source, so `@cute.kernel` / `@cute.jit` functions must live in `.py` files rather than notebook cells. We have written them to a module [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) — refer to this file for the concept explainations. \n", + "**Note:** CuTe DSL relies on Python source inspection `inspect.getsourcelines()` to parse kernel definitions. In many environments (including this notebook), defining `@cute.kernel` / `@cute.jit` functions directly in notebook cells works correctly. However, this is not consistently reliable across all interactive environments (e.g. plain Python REPL), where source inspection may fail with errors like `OSError: could not get source code`.\n", "\n", - "Here we import the pre-written kernel launch functions from `cute_dsl_jax_kernels.py`." + "We show the executable kernel definitions inline in the notebook. At the same time, for robustness and reproducibility, we keep equivalent definitions in a separate .py module ([cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py)).\n", + "\n", + "Here, we import the pre-written kernel launch functions from [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py)." ] }, { "cell_type": "code", "execution_count": null, + "id": "933e6a80", "metadata": {}, "outputs": [], "source": [ - "import cutlass.jax as cjax\n", - "from cute_dsl_jax_kernels import (\n", - " launch_vector_add, launch_saxpy, launch_gemm,\n", - " launch_relu, launch_fused_bias_relu,\n", - " launch_elementwise_add,\n", - ")\n", - "print(\"Imported: launch_vector_add, launch_saxpy, launch_gemm, launch_relu, launch_fused_bias_relu, launch_elementwise_add\")" + "# Optional, if you execute the equivalent kernel definitions further in the notebook\n", + "\n", + "# from cute_dsl_jax_kernels import (\n", + "# launch_vector_add, launch_saxpy, launch_gemm,\n", + "# launch_relu, launch_fused_bias_relu,\n", + "# launch_elementwise_add,\n", + "# )\n", + "# print(\"Imported: launch_vector_add, launch_saxpy, launch_gemm, launch_relu, launch_fused_bias_relu, launch_elementwise_add\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66eca19c-dff1-4ea5-8f1b-7df3348b5aba", + "metadata": {}, + "outputs": [], + "source": [ + "def split_keys(seed=0):\n", + " key = jax.random.key(seed)\n", + " while True:\n", + " key, subkey = jax.random.split(key)\n", + " yield subkey\n", + "\n", + "keys = iter(split_keys())" ] }, { "cell_type": "markdown", + "id": "75187a02", "metadata": {}, "source": [ "## Basic kernel: vector add\n", "\n", - "We’ll start with the simplest GPU kernel — vector add: `c[i] = a[i] + b[i]`. \n", + "We’ll start with the simplest GPU kernel — vector add: `c[i] = a[i] + b[i]`.\n", "\n", - "Refer to [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) for the corresponding implementation.\n", - "\n", - "### CuTe vector add kernel explained\n", - "\n", - "Each thread identifies itself using `thread_idx()` and `block_idx()`. Thread and block indices are accessed through `cute.arch` (e.g., `thread_idx`, `block_idx`), each returning `(x, y, z)` tuples, because CUDA’s execution and indexing are 3-dimensional by design. Since this kernel is launched in 1D, we only use the `x` component (`tidx` and `bidx`) and ignore the unused `y` and `z` values with `_`.\n", + "Each thread in the kernel below identifies itself using `thread_idx()` and `block_idx()`. Thread and block indices are accessed through `cute.arch` (e.g., `thread_idx`, `block_idx`), each returning `(x, y, z)` tuples, because CUDA’s execution and indexing are 3-dimensional by design. Since this kernel is launched in 1D, we only use the `x` component (`tidx` and `bidx`) and ignore the unused `y` and `z` values with `_`.\n", "\n", "```python\n", "tidx, _, _ = cute.arch.thread_idx()\n", "bidx, _, _ = cute.arch.block_idx()\n", "```\n", "\n", - "Inside the kernel, tensors are typically created in register memory using `cute.make_rmem_tensor`:\n", + "Inside the kernel, tensors are typically created in register memory using `cute.make_rmem_tensor`.\n", "\n", - "```python\n", - "frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)\n", - "frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type)\n", - "frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type)\n", - "```\n", + "Below, `frgA` and `frgB` hold the input values in registers, while `frgC` is a register fragment that will store the computed result before it is written back to global memory. `mode=[0]` selects the first dimension of the tensor — the \"elements per thread\" axis — so the register fragment is sized to hold exactly the data owned by one thread.\n", "\n", - "Here, `frgA` and `frgB` hold the input values in registers, while `frgC` is a register fragment that will store the computed result before it is written back to global memory. `mode=[0]` selects the first dimension of the tensor — the \"elements per thread\" axis — so the register fragment is sized to hold exactly the data owned by one thread.\n", + "Data movement between global and register memory is explicit: fragments are read using `load()` and written back using `store()`, while `cute.autovec_copy` performs efficient, vectorized transfers between memory spaces. Here, one element of `a` and `b` is loaded into register fragments, the sum in registers is computed and the result is stored back to `c`. The `None` selects the entire first dimension (which has size 1 in this example), preserving the (`elems_per_thread`, `threads_per_block`, `num_blocks`) structure while allowing each thread to access its own slice of the tensor.\n", "\n", - "Data movement between global and register memory is explicit: fragments are read using `load()` and written back using `store()`, while `cute.autovec_copy` performs efficient, vectorized transfers between memory spaces. Here, one element of `a` and `b` is loaded into register fragments, the sum in registers is computed and the result is stored back to `c`:\n", + "> **Concept: Tensor = Pointer + Layout**\n", + ">\n", + "> A CuTe **Tensor** pairs a pointer to GPU memory with a **Layout** that describes how to navigate it. When the kernel receives `a: cute.Tensor`, it gets both the raw data and the index mapping. In this example, our tensors have shape `(1, BLOCK, num_blocks)` — one element per thread, `BLOCK=256` (defined in the example below) threads per block, spread across blocks. The layout maps a `(elems_per_thread, threads_per_block, num_blocks)` coordinate to the flat memory offset where that element lives. The kernel never computes offsets manually — it just indexes the tensor with `a[None, tidx, bidx]` and CuTe's layout handles the rest." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94260fa2", + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def vector_add_kernel(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):\n", + " \"\"\"Per-thread kernel: each thread adds one element.\"\"\"\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bidx, _, _ = cute.arch.block_idx()\n", "\n", - "```python\n", - "cute.autovec_copy(a[None, tidx, bidx], frgA)\n", - "cute.autovec_copy(b[None, tidx, bidx], frgB)\n", - "frgC.store(frgA.load() + frgB.load())\n", - "cute.autovec_copy(frgC, c[None, tidx, bidx])\n", - "```\n", + " frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)\n", + " frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type)\n", + " frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type)\n", "\n", - "The `None` selects the entire first dimension (which has size 1 in this example), preserving the (`elems_per_thread`, `threads_per_block`, `num_blocks`) structure while allowing each thread to access its own slice of the tensor.\n", + " cute.autovec_copy(a[None, tidx, bidx], frgA)\n", + " cute.autovec_copy(b[None, tidx, bidx], frgB)\n", + " frgC.store(frgA.load() + frgB.load())\n", + " cute.autovec_copy(frgC, c[None, tidx, bidx])\n", "\n", - "> **Concept: Tensor = Pointer + Layout**\n", - "> \n", - "> A CuTe **Tensor** pairs a pointer to GPU memory with a **Layout** that describes how to navigate it. When the kernel receives `a: cute.Tensor`, it gets both the raw data and the index mapping. In this example, our tensors have shape `(1, BLOCK, num_blocks)` — one element per thread, `BLOCK=256` (defined in the example below) threads per block, spread across blocks. The layout maps a `(elems_per_thread, threads_per_block, num_blocks)` coordinate to the flat memory offset where that element lives. The kernel never computes offsets manually — it just indexes the tensor with `a[None, tidx, bidx]` and CuTe's layout handles the rest.\n", "\n", + "print(\"vector_add_kernel defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "d68a2f51", + "metadata": {}, + "source": [ "The `@cute.kernel` defines one thread’s work. The `@cute.jit` launcher decides how many threads run, and how they’re grouped. It must follow the signature convention: `(stream, *inputs, *outputs, *, **kwargs)` — where `stream` is a CUDA stream managed by XLA, followed by input tensors, then output tensors.\n", "\n", - "```python\n", + "We launch `a.shape[-2]` threads per block and `a.shape[-1]` blocks, directly matching the tensor’s `(1, threads_per_block, num_blocks)` layout so that `threadIdx.x` indexes the thread dimension and `blockIdx.x` indexes the block dimension. We use -2 and -1 because they refer to the last two tensor dimensions (threads per block and number of blocks), making the launch configuration robust even if additional leading dimensions are added.\n", + "\n", + "> **Concept: Layout composition**\n", + ">\n", + "> The vector add kernel expects 3-D tensors with shape `(elems_per_thread, threads_per_block, num_blocks)`, but our data is a flat 1-D array. The JAX wrapper reshapes from 1-D to 3-D before calling the kernel, and back afterward. In CuTe terms, this reshape is a **layout composition** — combining the original 1-D layout with a new layout that splits the single dimension into three. CuTe performs this algebraically: the composed layout maps 3-D coordinates directly to the original flat offsets, with no data movement. Reshaping is free — it's just a change of layout, not a copy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3844feb", + "metadata": {}, + "outputs": [], + "source": [ "@cute.jit\n", "def launch_vector_add(\n", " stream: cuda.CUstream,\n", - " a: cute.Tensor, b: cute.Tensor, c: cute.Tensor,\n", + " a: cute.Tensor,\n", + " b: cute.Tensor,\n", + " c: cute.Tensor,\n", "):\n", - " vector_add_kernel(a, b, c).launch(\n", - " grid=[a.shape[-1], 1, 1],\n", - " block=[a.shape[-2], 1, 1],\n", - " stream=stream,\n", - " )\n", - "```\n", - "\n", - "We launch `a.shape[-2]` threads per block and `a.shape[-1]` blocks, directly matching the tensor’s `(1, threads_per_block, num_blocks)` layout so that `threadIdx.x` indexes the thread dimension and `blockIdx.x` indexes the block dimension. We use -2 and -1 because they refer to the last two tensor dimensions (threads per block and number of blocks), making the launch configuration robust even if additional leading dimensions are added.\n", + " vector_add_kernel(a, b, c).launch(\n", + " grid=[a.shape[-1], 1, 1],\n", + " block=[a.shape[-2], 1, 1],\n", + " stream=stream,\n", + " )\n", "\n", - "> **Concept: Layout composition**\n", - ">\n", - "> The vector add kernel expects 3-D tensors with shape `(elems_per_thread, threads_per_block, num_blocks)`, but our data is a flat 1-D array. The JAX wrapper reshapes from 1-D to 3-D before calling the kernel, and back afterward. In CuTe terms, this reshape is a **layout composition** — combining the original 1-D layout with a new layout that splits the single dimension into three. CuTe performs this algebraically: the composed layout maps 3-D coordinates directly to the original flat offsets, with no data movement. Reshaping is free — it's just a change of layout, not a copy.\n", "\n", + "print(\"launch_vector_add defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "e80f7061", + "metadata": {}, + "source": [ "### JAX integration via `cutlass_call`\n", "\n", "The `cutlass.jax.cutlass_call` function wraps a CuTe `@cute.jit` launch function as a JAX custom call, so your CuTe/CUTLASS kernel can be invoked inside `@jax.jit`-compiled code and become part of the XLA computation graph.\n", @@ -330,7 +403,7 @@ ")\n", "result = call(*input_arrays) # Pass JAX arrays here\n", "```\n", - " \n", + "\n", "3. Call the wrapped launcher\n", "\n", "* Inside a `@jax.jit` function this becomes a custom call node in the XLA graph; at runtime XLA provides a CUDA `CUstream` and device pointers, and the CUTLASS kernel is invoked on that stream.\n", @@ -352,37 +425,41 @@ { "cell_type": "code", "execution_count": null, + "id": "c8538e21", "metadata": {}, "outputs": [], "source": [ - "BLOCK = 256 # threads per block for vector add: 256 is a practical default: \n", - " # large enough to expose parallelism, small enough to scale \n", - " # well across different GPUs, and aligned with the hardware’s \n", - " # 32-thread warp execution model.\n", + "BLOCK = 256 # threads per block for vector add: 256 is a practical default:\n", + "# large enough to expose parallelism, small enough to scale\n", + "# well across different GPUs, and aligned with the hardware’s\n", + "# 32-thread warp execution model.\n", + "\n", "\n", "@jax.jit\n", "def jax_vector_add(a, b):\n", - " \"\"\"JAX-compatible vector add using CUTLASS kernel.\"\"\"\n", - " N = a.shape[0]\n", - " padded = ((N + BLOCK - 1) // BLOCK) * BLOCK\n", - " a_pad = jnp.pad(a, (0, padded - N))\n", - " b_pad = jnp.pad(b, (0, padded - N))\n", - " # Reshape to (1, BLOCK, num_blocks) for the CuTe kernel\n", - " a_3d = a_pad.reshape(1, BLOCK, padded // BLOCK)\n", - " b_3d = b_pad.reshape(1, BLOCK, padded // BLOCK)\n", - " call = cjax.cutlass_call(\n", - " launch_vector_add,\n", - " output_shape_dtype=jax.ShapeDtypeStruct(a_3d.shape, a_3d.dtype),\n", - " use_static_tensors=True,\n", - " )\n", - " c_3d = call(a_3d, b_3d)\n", - " return c_3d.reshape(-1)[:N]\n", + " \"\"\"JAX-compatible vector add using CUTLASS kernel.\"\"\"\n", + " N = a.shape[0]\n", + " padded = ((N + BLOCK - 1) // BLOCK) * BLOCK\n", + " a_pad = jnp.pad(a, (0, padded - N))\n", + " b_pad = jnp.pad(b, (0, padded - N))\n", + " # Reshape to (1, BLOCK, num_blocks) for the CuTe kernel\n", + " a_3d = a_pad.reshape(1, BLOCK, padded // BLOCK)\n", + " b_3d = b_pad.reshape(1, BLOCK, padded // BLOCK)\n", + " call = cjax.cutlass_call(\n", + " launch_vector_add,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(a_3d.shape, a_3d.dtype),\n", + " use_static_tensors=True,\n", + " )\n", + " c_3d = call(a_3d, b_3d)\n", + " return c_3d.reshape(-1)[:N]\n", + "\n", "\n", "print(\"jax_vector_add defined.\")" ] }, { "cell_type": "markdown", + "id": "63f6eec4", "metadata": {}, "source": [ "Let's test our CUTLASS vector add by comparing its output against JAX's built-in `+` operator. We generate two random arrays, run both implementations, and verify the results match element-by-element." @@ -391,14 +468,14 @@ { "cell_type": "code", "execution_count": null, + "id": "8568963f", "metadata": {}, "outputs": [], "source": [ - " # Test vector add\n", + "# Test vector add\n", "N = 1024\n", - "key = jax.random.PRNGKey(0)\n", - "a = jax.random.normal(key, (N,), dtype=jnp.float32)\n", - "b = jax.random.normal(jax.random.PRNGKey(1), (N,), dtype=jnp.float32)\n", + "a = jax.random.normal(next(keys), (N,), dtype=jnp.float32)\n", + "b = jax.random.normal(next(keys), (N,), dtype=jnp.float32)\n", "\n", "c = jax_vector_add(a, b)\n", "c_ref = a + b\n", @@ -410,20 +487,16 @@ }, { "cell_type": "markdown", + "id": "a17a7c36", "metadata": {}, "source": [ "## SAXPY: scalar parameters in kernels\n", "\n", "**SAXPY** computes `out[i] = alpha * x[i] + y[i]`. This builds on the vector add pattern and introduces passing a **scalar argument** (`alpha`) alongside tensor arguments.\n", "\n", - "Refer to [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) for the corresponding implementation.\n", - "\n", - "### CuTe SAXPY kernel explained\n", - "\n", "The SAXPY kernel follows the same structure as vector add — identify the thread, load data into registers, compute, write back — with one addition: a scalar `alpha` parameter.\n", "\n", "```python\n", - "@cute.kernel\n", "def saxpy_kernel(x: cute.Tensor, y: cute.Tensor, out: cute.Tensor, alpha: float):\n", "```\n", "\n", @@ -435,24 +508,76 @@ "frgO.store(alpha * frgX.load() + frgY.load())\n", "```\n", "\n", - "Each thread loads its element of `x` and `y` into register fragments, multiplies `x` by `alpha`, adds `y`, and writes the result to `out`.\n", + "Each thread loads its element of `x` and `y` into register fragments, multiplies `x` by `alpha`, adds `y`, and writes the result to `out`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da81bdd6", + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def saxpy_kernel(\n", + " x: cute.Tensor, y: cute.Tensor, out: cute.Tensor, alpha: float\n", + "):\n", + " \"\"\"SAXPY: out[i] = alpha * x[i] + y[i].\"\"\"\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bidx, _, _ = cute.arch.block_idx()\n", "\n", - "The launcher passes `alpha` as a **keyword-only** argument (note the `*` in the signature):\n", + " frgX = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type)\n", + " frgY = cute.make_rmem_tensor(cute.size(y, mode=[0]), y.element_type)\n", + " frgO = cute.make_rmem_tensor(cute.size(out, mode=[0]), out.element_type)\n", "\n", - "```python\n", + " cute.autovec_copy(x[None, tidx, bidx], frgX)\n", + " cute.autovec_copy(y[None, tidx, bidx], frgY)\n", + " frgO.store(alpha * frgX.load() + frgY.load())\n", + " cute.autovec_copy(frgO, out[None, tidx, bidx])\n", + "\n", + "\n", + "print(\"saxpy_kernel defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "dfc084b8", + "metadata": {}, + "source": [ + "The launcher passes `alpha` as a **keyword-only** argument (note the `*` in the signature):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a0eefad", + "metadata": {}, + "outputs": [], + "source": [ "@cute.jit\n", "def launch_saxpy(\n", " stream: cuda.CUstream,\n", - " x: cute.Tensor, y: cute.Tensor, out: cute.Tensor,\n", - " *, alpha: float,\n", + " x: cute.Tensor,\n", + " y: cute.Tensor,\n", + " out: cute.Tensor,\n", + " *,\n", + " alpha: float,\n", "):\n", - " saxpy_kernel(x, y, out, alpha).launch(\n", - " grid=[x.shape[-1], 1, 1],\n", - " block=[x.shape[-2], 1, 1],\n", - " stream=stream,\n", - " )\n", - "```\n", + " saxpy_kernel(x, y, out, alpha).launch(\n", + " grid=[x.shape[-1], 1, 1],\n", + " block=[x.shape[-2], 1, 1],\n", + " stream=stream,\n", + " )\n", "\n", + "\n", + "print(\"launch_saxpy defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "763afd1b", + "metadata": {}, + "source": [ "The keyword-only convention matters for `cutlass_call`: positional arguments correspond to JAX tensors (managed by XLA), while keyword arguments are scalar values passed directly to the kernel. In the JAX wrapper below, `alpha=alpha` routes through `cutlass_call` as a kernel kwarg:\n", "\n", "```python\n", @@ -474,34 +599,40 @@ { "cell_type": "code", "execution_count": null, + "id": "fbbd970b", "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "\n", + "BLOCK = 256\n", + "\n", + "\n", "@partial(jax.jit, static_argnums=(2,))\n", "def jax_saxpy(x, y, alpha=2.0):\n", - " \"\"\"JAX-compatible SAXPY using CUTLASS kernel.\"\"\"\n", - " N = x.shape[0]\n", - " padded = ((N + BLOCK - 1) // BLOCK) * BLOCK\n", - " x_pad = jnp.pad(x, (0, padded - N))\n", - " y_pad = jnp.pad(y, (0, padded - N))\n", - " x_3d = x_pad.reshape(1, BLOCK, padded // BLOCK)\n", - " y_3d = y_pad.reshape(1, BLOCK, padded // BLOCK)\n", - " call = cjax.cutlass_call(\n", - " launch_saxpy,\n", - " output_shape_dtype=jax.ShapeDtypeStruct(x_3d.shape, x_3d.dtype),\n", - " use_static_tensors=True,\n", - " alpha=alpha,\n", - " )\n", - " out_3d = call(x_3d, y_3d)\n", - " return out_3d.reshape(-1)[:N]\n", + " \"\"\"JAX-compatible SAXPY using CUTLASS kernel.\"\"\"\n", + " N = x.shape[0]\n", + " padded = ((N + BLOCK - 1) // BLOCK) * BLOCK\n", + " x_pad = jnp.pad(x, (0, padded - N))\n", + " y_pad = jnp.pad(y, (0, padded - N))\n", + " x_3d = x_pad.reshape(1, BLOCK, padded // BLOCK)\n", + " y_3d = y_pad.reshape(1, BLOCK, padded // BLOCK)\n", + " call = cjax.cutlass_call(\n", + " launch_saxpy,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(x_3d.shape, x_3d.dtype),\n", + " use_static_tensors=True,\n", + " alpha=alpha,\n", + " )\n", + " out_3d = call(x_3d, y_3d)\n", + " return out_3d.reshape(-1)[:N]\n", + "\n", "\n", "print(\"jax_saxpy defined.\")" ] }, { "cell_type": "markdown", + "id": "bbd8d386", "metadata": {}, "source": [ "We test the SAXPY kernel with `alpha=2.5`, comparing against the reference computation `alpha * x + y`. The `assert_allclose` check verifies that results match within floating-point tolerance." @@ -510,78 +641,117 @@ { "cell_type": "code", "execution_count": null, + "id": "6a01dcb1", "metadata": {}, "outputs": [], "source": [ "# Test SAXPY\n", "N = 2048\n", "ALPHA = 2.5\n", - "key = jax.random.PRNGKey(42)\n", - "x = jax.random.normal(key, (N,), dtype=jnp.float32)\n", - "y = jax.random.normal(jax.random.PRNGKey(43), (N,), dtype=jnp.float32)\n", + "x = jax.random.normal(next(keys), (N,), dtype=jnp.float32)\n", + "y = jax.random.normal(next(keys), (N,), dtype=jnp.float32)\n", "\n", "result = jax_saxpy(x, y, alpha=ALPHA)\n", "ref = ALPHA * x + y\n", "\n", - "np.testing.assert_allclose(np.array(result), np.array(ref), rtol=1e-5)\n", + "np.testing.assert_allclose(np.array(result), np.array(ref), rtol=1e-5, atol=1e-6,)\n", "print(f\"SAXPY PASSED (N={N}, alpha={ALPHA})\")\n", "print(f\" Max error: {float(jnp.max(jnp.abs(result - ref))):.2e}\")" ] }, { "cell_type": "markdown", + "id": "f21f6d7f", "metadata": {}, "source": [ "## Deep learning activations: ReLU and fused bias+ReLU\n", "\n", - "**ReLU** (`max(0, x)`) is the most widely used activation function in deep learning. It's elementwise and trivially parallel — a perfect custom kernel for ML workloads.\n", + "### ReLU\n", "\n", - "Refer to [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) for the corresponding implementations.\n", - "\n", - "### CuTe ReLU kernel explained\n", + "**ReLU** (`max(0, x)`) is the most widely used activation function in deep learning. It's elementwise and trivially parallel — a perfect custom kernel for ML workloads.\n", "\n", "The ReLU kernel uses a different pattern from vector add and SAXPY. Instead of the 3-D tensor approach with register fragments, it uses **flat 1-D indexing** — simpler and equally efficient for elementwise operations.\n", "\n", "```python\n", - "@cute.kernel\n", - "def relu_kernel(x: cute.Tensor, out: cute.Tensor, N: int):\n", - " tidx, _, _ = cute.arch.thread_idx()\n", - " bidx, _, _ = cute.arch.block_idx()\n", - " bdx, _, _ = cute.arch.block_dim()\n", + "tidx, _, _ = cute.arch.thread_idx()\n", + "bidx, _, _ = cute.arch.block_idx()\n", + "bdx, _, _ = cute.arch.block_dim()\n", "```\n", "\n", "A new call appears here: `cute.arch.block_dim()` returns the number of threads per block (set at launch time). Together with `thread_idx` and `block_idx`, it lets each thread compute its unique **global index**:\n", "\n", "```python\n", - " idx = bidx * bdx + tidx\n", + "idx = bidx * bdx + tidx\n", "```\n", "\n", "For example, if we launch 256 threads per block: thread 3 in block 2 gets `idx = 2 * 256 + 3 = 515`. This is the standard CUDA pattern for mapping threads to 1-D data.\n", "\n", - "Because the data length `N` may not be a multiple of the block size, the last block could contain threads that point past the end of the array. The bounds check prevents out-of-bounds writes:\n", - "\n", - "```python\n", - " if idx < N:\n", - " val = x[idx]\n", - " out[idx] = cutlass.max(val, cutlass.Float32(0.0))\n", - "```\n", + "Here we index the tensors directly with `x[idx]` — no register fragments or `autovec_copy`. For simple elementwise operations this flat approach is cleaner. `cutlass.max` is CuTe DSL's built-in max function, and `cutlass.Float32(0.0)` creates a typed zero constant that matches the tensor's element type." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae15b6fa", + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def relu_kernel(x: cute.Tensor, out: cute.Tensor, N: int):\n", + " \"\"\"Per-thread kernel: each thread computes ReLU of one element.\"\"\"\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bidx, _, _ = cute.arch.block_idx()\n", + " bdx, _, _ = cute.arch.block_dim()\n", "\n", - "Here we index the tensors directly with `x[idx]` — no register fragments or `autovec_copy`. For simple elementwise operations this flat approach is cleaner. `cutlass.max` is CuTe DSL's built-in max function, and `cutlass.Float32(0.0)` creates a typed zero constant that matches the tensor's element type.\n", + " idx = bidx * bdx + tidx\n", + " if idx < N:\n", + " val = x[idx]\n", + " out[idx] = cutlass.max(val, cutlass.Float32(0.0))\n", "\n", - "The launcher computes how many blocks are needed to cover `N` elements:\n", "\n", - "```python\n", + "print(\"relu_kernel defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "89e73969", + "metadata": {}, + "source": [ + "The launcher computes how many blocks are needed to cover `N` elements:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa178f26", + "metadata": {}, + "outputs": [], + "source": [ "@cute.jit\n", - "def launch_relu(stream, x, out, *, N):\n", - " BLOCK_SIZE = 256\n", - " grid_size = (N + BLOCK_SIZE - 1) // BLOCK_SIZE\n", - " relu_kernel(x, out, N).launch(\n", - " grid=[grid_size, 1, 1],\n", - " block=[BLOCK_SIZE, 1, 1],\n", - " stream=stream,\n", - " )\n", - "```\n", + "def launch_relu(\n", + " stream: cuda.CUstream,\n", + " x: cute.Tensor,\n", + " out: cute.Tensor,\n", + " *,\n", + " N: int,\n", + "):\n", + " BLOCK_SIZE = 256\n", + " grid_size = (N + BLOCK_SIZE - 1) // BLOCK_SIZE\n", + " relu_kernel(x, out, N).launch(\n", + " grid=[grid_size, 1, 1],\n", + " block=[BLOCK_SIZE, 1, 1],\n", + " stream=stream,\n", + " )\n", "\n", + "\n", + "print(\"launch_relu defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "f04bd7be", + "metadata": {}, + "source": [ "The formula `(N + BLOCK_SIZE - 1) // BLOCK_SIZE` is ceiling division — it ensures we launch enough blocks even when `N` isn't a multiple of 256. The bounds check inside the kernel handles the leftover threads in the last block.\n", "\n", "### JAX wrapper: ReLU\n", @@ -605,27 +775,30 @@ { "cell_type": "code", "execution_count": null, + "id": "ad51598f", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def jax_relu(x):\n", - " \"\"\"JAX-compatible ReLU using CUTLASS kernel.\"\"\"\n", - " N = x.size\n", - " x_flat = x.reshape(-1)\n", - " call = cjax.cutlass_call(\n", - " launch_relu,\n", - " output_shape_dtype=jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype),\n", - " N=N,\n", - " )\n", - " out_flat = call(x_flat)\n", - " return out_flat.reshape(x.shape)\n", + " \"\"\"JAX-compatible ReLU using CUTLASS kernel.\"\"\"\n", + " N = x.size\n", + " x_flat = x.reshape(-1)\n", + " call = cjax.cutlass_call(\n", + " launch_relu,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype),\n", + " N=N,\n", + " )\n", + " out_flat = call(x_flat)\n", + " return out_flat.reshape(x.shape)\n", + "\n", "\n", "print(\"jax_relu defined.\")" ] }, { "cell_type": "markdown", + "id": "dd0e6663", "metadata": {}, "source": [ "We verify the ReLU kernel by comparing against `jax.nn.relu`. Positive values should pass through unchanged, and negative values should become zero." @@ -634,13 +807,13 @@ { "cell_type": "code", "execution_count": null, + "id": "689a8bba", "metadata": {}, "outputs": [], "source": [ "# Test ReLU\n", "N = 2048\n", - "key = jax.random.PRNGKey(7)\n", - "x = jax.random.normal(key, (N,), dtype=jnp.float32)\n", + "x = jax.random.normal(next(keys), (N,), dtype=jnp.float32)\n", "\n", "result = jax_relu(x)\n", "ref = jax.nn.relu(x)\n", @@ -654,9 +827,10 @@ }, { "cell_type": "markdown", + "id": "46c9aa23", "metadata": {}, "source": [ - "### CuTe kernel explained: fused bias+ReLU\n", + "### Fused bias+ReLU\n", "\n", "**Fused bias+ReLU** computes `max(0, x + bias)` in a single kernel. This demonstrates **kernel fusion** — combining multiple operations into one GPU pass.\n", "\n", @@ -664,19 +838,84 @@ "- **Without fusion:** `x + bias` writes an intermediate array to global memory, then `max(0, ...)` reads it back. That's two kernel launches and two round-trips to memory.\n", "- **With fusion:** one kernel reads `x` and `bias`, computes the sum and the max, and writes the final result. Half the memory traffic, one launch instead of two.\n", "\n", - "The kernel extends the ReLU pattern with a bias lookup:\n", + "The kernel extends the ReLU pattern with a bias lookup.\n", "\n", - "```python\n", - " idx = bidx * bdx + tidx\n", - " if idx < N:\n", - " col = idx % width\n", - " val = x[idx] + bias[col]\n", - " out[idx] = cutlass.max(val, cutlass.Float32(0.0))\n", - "```\n", + "The input `x` is a flattened `(batch, width)` matrix. `idx` is the global flat index, and `col = idx % width` recovers which column (feature) this element belongs to, so we can look up the correct bias. This modular indexing pattern is common in fused kernels that combine elementwise and broadcast operations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c6b325e", + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def fused_bias_relu_kernel(\n", + " x: cute.Tensor,\n", + " bias: cute.Tensor,\n", + " out: cute.Tensor,\n", + " N: int,\n", + " width: int,\n", + "):\n", + " \"\"\"Per-thread: out[i] = max(0, x[i] + bias[i % width]).\"\"\"\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bidx, _, _ = cute.arch.block_idx()\n", + " bdx, _, _ = cute.arch.block_dim()\n", "\n", - "The input `x` is a flattened `(batch, width)` matrix. `idx` is the global flat index, and `col = idx % width` recovers which column (feature) this element belongs to, so we can look up the correct bias. This modular indexing pattern is common in fused kernels that combine elementwise and broadcast operations.\n", + " idx = bidx * bdx + tidx\n", + " if idx < N:\n", + " col = idx % width\n", + " val = x[idx] + bias[col]\n", + " out[idx] = cutlass.max(val, cutlass.Float32(0.0))\n", "\n", - "The launcher and JAX wrapper follow the same flat-indexing pattern as ReLU, with `N` (total elements) and `width` (columns) passed as keyword arguments:\n", + "\n", + "print(\"fused_bias_relu_kernel defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "18b3d2f6", + "metadata": {}, + "source": [ + "The launcher and JAX wrapper follow the same flat-indexing pattern as ReLU, with `N` (total elements) and `width` (columns) passed as keyword arguments:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14d67faa", + "metadata": {}, + "outputs": [], + "source": [ + "@cute.jit\n", + "def launch_fused_bias_relu(\n", + " stream: cuda.CUstream,\n", + " x: cute.Tensor,\n", + " bias: cute.Tensor,\n", + " out: cute.Tensor,\n", + " *,\n", + " N: int,\n", + " width: int,\n", + "):\n", + " BLOCK_SIZE = 256\n", + " grid_size = (N + BLOCK_SIZE - 1) // BLOCK_SIZE\n", + " fused_bias_relu_kernel(x, bias, out, N, width).launch(\n", + " grid=[grid_size, 1, 1],\n", + " block=[BLOCK_SIZE, 1, 1],\n", + " stream=stream,\n", + " )\n", + "\n", + "\n", + "print(\"launch_fused_bias_relu defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "0c3ac473", + "metadata": {}, + "source": [ + "Note that `width` is marked as a static argument in the JAX wrapper via `static_argnums=(2,)`. This means JAX recompiles when the feature dimension changes, allowing CUTLASS to generate specialized code for each width.\n", "\n", "```python\n", "call = cjax.cutlass_call(\n", @@ -685,43 +924,46 @@ " N=N, width=width,\n", ")\n", "out_flat = call(x_flat, bias) # two input tensors: x and bias\n", - "```\n", - "\n", - "Note that `width` is marked as a static argument in the JAX wrapper via `static_argnums=(2,)`. This means JAX recompiles when the feature dimension changes, allowing CUTLASS to generate specialized code for each width." + "```" ] }, { "cell_type": "code", "execution_count": null, + "id": "930f0016", "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "\n", + "\n", "@partial(jax.jit, static_argnums=(2,))\n", "def jax_fused_bias_relu(x, bias, width):\n", - " \"\"\"JAX-compatible fused Bias+ReLU using CUTLASS kernel.\n", - "\n", - " Args:\n", - " x: Input matrix of shape (batch, width), flattened to 1-D for the kernel.\n", - " bias: Bias vector of shape (width,).\n", - " width: Number of columns (static, passed as constexpr to the kernel).\n", - " \"\"\"\n", - " N = x.size\n", - " x_flat = x.reshape(-1)\n", - " call = cjax.cutlass_call(\n", - " launch_fused_bias_relu,\n", - " output_shape_dtype=jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype),\n", - " N=N, width=width,\n", - " )\n", - " out_flat = call(x_flat, bias)\n", - " return out_flat.reshape(x.shape)\n", + " \"\"\"JAX-compatible fused Bias+ReLU using CUTLASS kernel.\n", + "\n", + " Args:\n", + " x: Input matrix of shape (batch, width), flattened to 1-D for the kernel.\n", + " bias: Bias vector of shape (width,).\n", + " width: Number of columns (static, passed as constexpr to the kernel).\n", + " \"\"\"\n", + " N = x.size\n", + " x_flat = x.reshape(-1)\n", + " call = cjax.cutlass_call(\n", + " launch_fused_bias_relu,\n", + " output_shape_dtype=jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype),\n", + " N=N,\n", + " width=width,\n", + " )\n", + " out_flat = call(x_flat, bias)\n", + " return out_flat.reshape(x.shape)\n", + "\n", "\n", "print(\"jax_fused_bias_relu defined.\")" ] }, { "cell_type": "markdown", + "id": "e25f1fa5", "metadata": {}, "source": [ "Test the fused kernel against the equivalent two-step JAX computation: add bias, then apply ReLU. The results should match exactly since both paths perform the same arithmetic." @@ -730,14 +972,14 @@ { "cell_type": "code", "execution_count": null, + "id": "06eded2d", "metadata": {}, "outputs": [], "source": [ "# Test Fused Bias+ReLU\n", "BATCH, WIDTH = 64, 512\n", - "key = jax.random.PRNGKey(99)\n", - "x = jax.random.normal(key, (BATCH, WIDTH), dtype=jnp.float32)\n", - "bias = jax.random.normal(jax.random.PRNGKey(100), (WIDTH,), dtype=jnp.float32)\n", + "x = jax.random.normal(next(keys), (BATCH, WIDTH), dtype=jnp.float32)\n", + "bias = jax.random.normal(next(keys), (WIDTH,), dtype=jnp.float32)\n", "\n", "result = jax_fused_bias_relu(x, bias, WIDTH)\n", "ref = jnp.maximum(0, x + bias[None, :])\n", @@ -749,6 +991,7 @@ }, { "cell_type": "markdown", + "id": "56596b5e", "metadata": {}, "source": [ "> **Going further:** For a production-grade generalization of elementwise kernels — with optimized TV (thread-value) layouts, vectorized memory access, and support for arbitrary binary operators including custom ops like `leaky_relu` — see NVIDIA's [elementwise_apply_example.py](https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/jax/elementwise_apply_example.py)." @@ -756,20 +999,17 @@ }, { "cell_type": "markdown", + "id": "e6a4adfc", "metadata": {}, "source": [ "## Advanced: Tiled GEMM\n", "\n", - "This demonstrates a general matrix multiply (GEMM) kernel: `D = A @ B` where A is (M, K), B is (K, N), and D is (M, N). Unlike the previous elementwise kernels, GEMM requires cooperation across data dimensions — each output element is a dot product over K values.\n", - "\n", - "Refer to [cute_dsl_jax_kernels.py](cute_dsl_jax_kernels.py) for the corresponding implementation.\n", + "This demonstrates a general matrix multiply (GEMM) kernel: `D = A @ B` where `A` is `(M, K)`, `B` is `(K, N)`, and `D` is `(M, N)`. Unlike the previous elementwise kernels, GEMM requires cooperation across data dimensions — each output element is a dot product over `K` values.\n", "\n", "> **Concept: Tiling**\n", ">\n", "> Tiling is CuTe's mechanism for partitioning data into sub-problems that map onto the GPU's execution hierarchy. In a GEMM, we divide the output matrix into `BLOCK_M x BLOCK_N` tiles, each assigned to one thread block. Within a tile, individual threads split the work further. CuTe's tiling operations decompose a layout into an \"inner\" part (the tile itself) and an \"outer\" part (which tile we're on). The block index `(bm, bn)` selects the outer coordinate, and thread indices work within the inner tile. This two-level decomposition — partition then index locally — is the fundamental pattern for mapping parallel GPU work to data.\n", "\n", - "### CuTe GEMM kernel explained\n", - "\n", "This is our first kernel using a **2-D grid**. Each block is responsible for a `BLOCK_M x BLOCK_N` tile of the output matrix:\n", "\n", "```python\n", @@ -812,72 +1052,129 @@ " D[m_idx * N + n_idx] = acc\n", "```\n", "\n", - "Here we use **manual row-major indexing**: `A[m_idx * K + k]` computes the offset into the flattened 1-D tensor for element `(m_idx, k)` of a row-major matrix with K columns. Similarly, `B[k * N + n_idx]` indexes element `(k, n_idx)`. Production CUTLASS kernels use multi-dimensional CuTe tensor indexing instead, but explicit indexing makes the memory layout visible for learning.\n", - "\n", - "### Launch configuration\n", - "\n", + "Here we use **manual row-major indexing**: `A[m_idx * K + k]` computes the offset into the flattened 1-D tensor for element `(m_idx, k)` of a row-major matrix with K columns. Similarly, `B[k * N + n_idx]` indexes element `(k, n_idx)`. Production CUTLASS kernels use multi-dimensional CuTe tensor indexing instead, but explicit indexing makes the memory layout visible for learning." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d8d792a", + "metadata": {}, + "outputs": [], + "source": [ + "@cute.kernel\n", + "def gemm_kernel(\n", + " A: cute.Tensor,\n", + " B: cute.Tensor,\n", + " D: cute.Tensor,\n", + " M: int,\n", + " N: int,\n", + " K: int,\n", + " BLOCK_M: int,\n", + " BLOCK_N: int,\n", + "):\n", + " \"\"\"Tiled GEMM: each thread accumulates output elements.\"\"\"\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bm, bn, _ = cute.arch.block_idx()\n", + " bdx, _, _ = cute.arch.block_dim()\n", + "\n", + " for i in cutlass.range(tidx, BLOCK_M * BLOCK_N, bdx):\n", + " row = i // BLOCK_N\n", + " col = i % BLOCK_N\n", + " m_idx = bm * BLOCK_M + row\n", + " n_idx = bn * BLOCK_N + col\n", + " if m_idx < M and n_idx < N:\n", + " acc = cutlass.Float32(0.0)\n", + " for k in cutlass.range(K):\n", + " acc += A[m_idx * K + k] * B[k * N + n_idx]\n", + " D[m_idx * N + n_idx] = acc\n", + "\n", + "\n", + "print(\"gemm_kernel defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "5b85beeb", + "metadata": {}, + "source": [ "The launcher sets up a 2-D grid matching the tile decomposition:\n", "\n", - "```python\n", - "@cute.jit\n", - "def launch_gemm(stream, A, B, D, *, M, N, K):\n", - " BLOCK_M, BLOCK_N = 64, 64\n", - " grid_m = (M + BLOCK_M - 1) // BLOCK_M\n", - " grid_n = (N + BLOCK_N - 1) // BLOCK_N\n", - " gemm_kernel(A, B, D, M, N, K, BLOCK_M, BLOCK_N).launch(\n", - " grid=[grid_m, grid_n, 1],\n", - " block=[256, 1, 1],\n", - " stream=stream,\n", - " )\n", - "```\n", - "\n", "- `grid=[grid_m, grid_n, 1]` — one block per output tile, arranged in a 2-D grid\n", "- `block=[256, 1, 1]` — 256 threads per block, each handling multiple elements via the stride loop\n", - "- `M, N, K, BLOCK_M, BLOCK_N` are all passed as compile-time constants to the kernel\n", - "\n", - "### JAX wrapper for GEMM\n", + "- `M, N, K, BLOCK_M, BLOCK_N` are all passed as compile-time constants to the kernel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6d73450", + "metadata": {}, + "outputs": [], + "source": [ + "@cute.jit\n", + "def launch_gemm(\n", + " stream: cuda.CUstream,\n", + " A: cute.Tensor,\n", + " B: cute.Tensor,\n", + " D: cute.Tensor,\n", + " *,\n", + " M: int,\n", + " N: int,\n", + " K: int,\n", + "):\n", + " BLOCK_M, BLOCK_N = 64, 64\n", + " grid_m = (M + BLOCK_M - 1) // BLOCK_M\n", + " grid_n = (N + BLOCK_N - 1) // BLOCK_N\n", + " gemm_kernel(A, B, D, M, N, K, BLOCK_M, BLOCK_N).launch(\n", + " grid=[grid_m, grid_n, 1],\n", + " block=[256, 1, 1],\n", + " stream=stream,\n", + " )\n", "\n", - "The JAX wrapper flattens both input matrices to 1-D (matching the kernel's flat indexing), passes the matrix dimensions as keyword arguments, and reshapes the result:\n", "\n", - "```python\n", - "a_flat = a.reshape(-1)\n", - "b_flat = b.reshape(-1)\n", - "call = cjax.cutlass_call(\n", - " launch_gemm,\n", - " output_shape_dtype=jax.ShapeDtypeStruct((M * N,), a.dtype),\n", - " M=M, N=N, K=K,\n", - ")\n", - "d_flat = call(a_flat, b_flat)\n", - "return d_flat.reshape(M, N)\n", - "```" + "print(\"launch_gemm defined.\")" + ] + }, + { + "cell_type": "markdown", + "id": "50c510bf", + "metadata": {}, + "source": [ + "The JAX wrapper flattens both input matrices to 1-D (matching the kernel's flat indexing), passes the matrix dimensions as keyword arguments, and reshapes the result:" ] }, { "cell_type": "code", "execution_count": null, + "id": "1077f87c", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def jax_cutlass_gemm(a, b):\n", - " \"\"\"JAX wrapper for the CUTLASS GEMM kernel.\"\"\"\n", - " M, K = a.shape\n", - " _, N = b.shape\n", - " a_flat = a.reshape(-1)\n", - " b_flat = b.reshape(-1)\n", - " call = cjax.cutlass_call(\n", - " launch_gemm,\n", - " output_shape_dtype=jax.ShapeDtypeStruct((M * N,), a.dtype),\n", - " M=M, N=N, K=K,\n", - " )\n", - " d_flat = call(a_flat, b_flat)\n", - " return d_flat.reshape(M, N)\n", + " \"\"\"JAX wrapper for the CUTLASS GEMM kernel.\"\"\"\n", + " M, K = a.shape\n", + " _, N = b.shape\n", + " a_flat = a.reshape(-1)\n", + " b_flat = b.reshape(-1)\n", + " call = cjax.cutlass_call(\n", + " launch_gemm,\n", + " output_shape_dtype=jax.ShapeDtypeStruct((M * N,), a.dtype),\n", + " M=M,\n", + " N=N,\n", + " K=K,\n", + " )\n", + " d_flat = call(a_flat, b_flat)\n", + " return d_flat.reshape(M, N)\n", + "\n", "\n", "print(\"jax_cutlass_gemm defined.\")" ] }, { "cell_type": "markdown", + "id": "d01132c8", "metadata": {}, "source": [ "Test the CUTLASS GEMM against JAX's `jnp.matmul`. We use relaxed tolerances (`rtol=1e-2`) because our simple kernel accumulates the K-dimension in a different order than cuBLAS, leading to small floating-point differences that are expected and harmless." @@ -886,25 +1183,26 @@ { "cell_type": "code", "execution_count": null, + "id": "3740a868", "metadata": {}, "outputs": [], "source": [ "# Test GEMM\n", "M, N, K = 256, 256, 128\n", - "key = jax.random.PRNGKey(0)\n", - "A = jax.random.normal(key, (M, K), dtype=jnp.float32)\n", - "B = jax.random.normal(jax.random.PRNGKey(1), (K, N), dtype=jnp.float32)\n", + "A = jax.random.normal(next(keys), (M, K), dtype=jnp.float32)\n", + "B = jax.random.normal(next(keys), (K, N), dtype=jnp.float32)\n", "\n", "D = jax_cutlass_gemm(A, B)\n", - "D_ref = jnp.matmul(A, B)\n", + "D_ref = jnp.matmul(A, B, precision=jax.lax.Precision.HIGHEST)\n", "\n", - "np.testing.assert_allclose(np.array(D), np.array(D_ref), rtol=1e-2, atol=1e-2)\n", + "np.testing.assert_allclose(np.array(D), np.array(D_ref), rtol=1e-2, atol=2e-2)\n", "print(f\"GEMM PASSED (M={M}, N={N}, K={K})\")\n", "print(f\" Max error: {float(jnp.max(jnp.abs(D - D_ref))):.2e}\")" ] }, { "cell_type": "markdown", + "id": "d8fda249", "metadata": {}, "source": [ "## Performance comparison\n", @@ -921,14 +1219,15 @@ { "cell_type": "code", "execution_count": null, + "id": "6a127df5", "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "M, N, K = 512, 512, 512\n", - "A = jax.random.normal(jax.random.PRNGKey(0), (M, K), dtype=jnp.float32)\n", - "B = jax.random.normal(jax.random.PRNGKey(1), (K, N), dtype=jnp.float32)\n", + "A = jax.random.normal(next(keys), (M, K), dtype=jnp.float32)\n", + "B = jax.random.normal(next(keys), (K, N), dtype=jnp.float32)\n", "\n", "# Warmup\n", "_ = jax_cutlass_gemm(A, B).block_until_ready()\n", @@ -937,15 +1236,15 @@ "NUM_RUNS = 20\n", "\n", "# Time CUTLASS GEMM\n", - "start = time.perf_counter() \n", + "start = time.perf_counter()\n", "for _ in range(NUM_RUNS):\n", - " _ = jax_cutlass_gemm(A, B).block_until_ready()\n", + " _ = jax_cutlass_gemm(A, B).block_until_ready()\n", "cutlass_time = (time.perf_counter() - start) / NUM_RUNS\n", "\n", "# Time JAX matmul\n", "start = time.perf_counter()\n", "for _ in range(NUM_RUNS):\n", - " _ = jnp.matmul(A, B).block_until_ready()\n", + " _ = jnp.matmul(A, B).block_until_ready()\n", "jax_time = (time.perf_counter() - start) / NUM_RUNS\n", "\n", "print(f\"Matrix size: {M}x{N}x{K}\")\n", @@ -959,6 +1258,7 @@ }, { "cell_type": "markdown", + "id": "ef7dbe90", "metadata": {}, "source": [ "## Multi-GPU: sharding CUTLASS kernels via `jax.shard_map`\n", @@ -972,12 +1272,15 @@ "**1. Create a device mesh.** A mesh maps physical devices to named logical axes:\n", "\n", "```python\n", - "mesh = jax.make_mesh((num_devices,), \"x\")\n", + "mesh = jax.make_mesh((num_devices,), (\"x\",))\n", + "jax.set_mesh(mesh)\n", "```\n", "\n", "This creates a 1-D mesh with `num_devices` devices along an axis called `\"x\"`. For 8 GPUs, the mesh maps device 0 through device 7 to positions 0–7 along the `\"x\"` axis.\n", "\n", - "**2. Define the sharding spec.** `PartitionSpec` tells JAX how to slice each tensor dimension across the mesh:\n", + "**2. Define the sharding spec.**\n", + "\n", + "`PartitionSpec` tells JAX how to slice each tensor dimension across the mesh:\n", "\n", "```python\n", "sharding = P(None, None, \"x\")\n", @@ -990,86 +1293,124 @@ "\n", "So with 8 devices and 128 total blocks, each device gets a tensor of shape `(1, 256, 16)` — its 16 local blocks.\n", "\n", - "**3. Use `shard_map` to run per-device code.** `shard_map` wraps a function so that each device receives only its local shard:\n", + "**3. Create sharded inputs.**\n", + "\n", + "With explicit mesh axes, inputs must already have a layout compatible with the mesh.\n", + "\n", + "We create them directly with the desired sharding:\n", "\n", "```python\n", - "@partial(\n", - " shard_map,\n", - " mesh=mesh,\n", - " in_specs=(sharding, sharding),\n", - " out_specs=sharding,\n", + "a = jax.random.normal(\n", + " jax.random.key(next(keys)),\n", + " shape,\n", + " dtype=jnp.float32,\n", + " out_sharding=sharding,\n", + ")\n", + "b = jax.random.normal(\n", + " jax.random.key(next(keys)),\n", + " shape,\n", + " dtype=jnp.float32,\n", + " out_sharding=sharding,\n", ")\n", - "def _add(a_shard, b_shard):\n", + "```\n", + "\n", + "This produces arrays with sharding P(None, None, \"x\"), matching the computation.\n", + "\n", + "An equivalent alternative is to create unsharded arrays and place them explicitly:\n", + "\n", + "```python\n", + "from jax.sharding import NamedSharding\n", + "\n", + "named_sharding = NamedSharding(mesh, sharding)\n", + "\n", + "a = jax.random.normal(jax.random.key(next(keys)), shape, dtype=jnp.float32)\n", + "b = jax.random.normal(jax.random.key(next(keys)), shape, dtype=jnp.float32)\n", + "\n", + "a = jax.device_put(a, named_sharding)\n", + "b = jax.device_put(b, named_sharding)\n", + "```\n", + "\n", + "**4. Use `jax.shard_map` to run per-device code**\n", + "\n", + "With an explicit mesh set via `jax.set_mesh`, `jax.shard_map` can be written concisely:\n", + "\n", + "```python\n", + "@jax.shard_map(out_specs=sharding)\n", + "def sharded_vector_add(a_shard, b_shard):\n", " call = cjax.cutlass_call(\n", " launch_vector_add,\n", - " output_shape_dtype=jax.ShapeDtypeStruct(a_shard.shape, a_shard.dtype),\n", + " output_shape_dtype=jax.typeof(a_shard),\n", " use_static_tensors=True,\n", " )\n", " return call(a_shard, b_shard)\n", "```\n", "\n", - "Inside `_add`, the code is identical to single-GPU — it sees a regular tensor and calls the same CUTLASS kernel. The kernel has no idea it's running on multiple GPUs. JAX handles splitting inputs before the kernel and reassembling outputs afterward." + "Inside `sharded_vector_add`, the code is identical to single-GPU — it sees a regular tensor and calls the same CUTLASS kernel. The kernel has no idea it's running on multiple GPUs. JAX handles splitting inputs before the kernel and reassembling outputs afterward." ] }, { "cell_type": "code", "execution_count": null, + "id": "6b946c7f", "metadata": {}, "outputs": [], "source": [ - "import warnings\n", - "from functools import partial\n", "from jax.sharding import PartitionSpec as P\n", - "with warnings.catch_warnings():\n", - " warnings.simplefilter(\"ignore\", DeprecationWarning)\n", - " from jax.experimental.shard_map import shard_map\n", "\n", "num_devices = len(jax.devices())\n", "print(f\"Number of devices: {num_devices}\")\n", "\n", - "if num_devices > 1:\n", - " mesh = jax.make_mesh((num_devices,), \"x\")\n", - " # Kernel expects 3-D tensors: (elems_per_thread, threads, blocks)\n", - " # Shard along the blocks axis (last dim)\n", - " sharding = P(None, None, \"x\")\n", - "\n", - " @jax.jit\n", - " def sharded_vector_add(a, b):\n", - " @partial(\n", - " shard_map,\n", - " mesh=mesh,\n", - " in_specs=(sharding, sharding),\n", - " out_specs=sharding,\n", - " )\n", - " def _add(a_shard, b_shard):\n", - " call = cjax.cutlass_call(\n", - " launch_vector_add,\n", - " output_shape_dtype=jax.ShapeDtypeStruct(\n", - " a_shard.shape, a_shard.dtype\n", - " ),\n", - " use_static_tensors=True,\n", - " )\n", - " return call(a_shard, b_shard)\n", - " return _add(a, b)\n", - "\n", - " # Create 3-D tensors: (1, 256, total_blocks) with total_blocks divisible by device count\n", - " blocks_per_device = 16\n", - " total_blocks = blocks_per_device * num_devices\n", - " shape = (1, BLOCK, total_blocks)\n", - " a_m = jax.random.normal(jax.random.PRNGKey(10), shape, dtype=jnp.float32)\n", - " b_m = jax.random.normal(jax.random.PRNGKey(11), shape, dtype=jnp.float32)\n", - "\n", - " c_m = sharded_vector_add(a_m, b_m)\n", - " np.testing.assert_allclose(np.array(c_m), np.array(a_m + b_m), rtol=1e-5)\n", - " N_total = int(np.prod(shape))\n", - " print(f\"Sharded Vector Add PASSED across {num_devices} devices (N={N_total})\")\n", - "else:\n", - " print(\"Only 1 device detected. Skipping multi-GPU example.\")\n", - " print(\"On a multi-GPU system, shard_map distributes CUTLASS kernels across devices.\")" + "BLOCK = 256\n", + "\n", + "mesh = jax.make_mesh((num_devices,), (\"x\",))\n", + "\n", + "# Use `jax.set_mesh` as a context manager so the mesh is scoped to this\n", + "# sharding demo and does not leak into later cells.\n", + "with jax.set_mesh(mesh):\n", + " # Kernel expects 3-D tensors: (elems_per_thread, threads, blocks)\n", + " # Shard along the blocks axis (last dim)\n", + " sharding = P(None, None, \"x\")\n", + "\n", + " @jax.shard_map(out_specs=sharding)\n", + " def sharded_vector_add(a_shard, b_shard):\n", + " call = cjax.cutlass_call(\n", + " launch_vector_add,\n", + " output_shape_dtype=jax.typeof(a_shard),\n", + " use_static_tensors=True,\n", + " )\n", + " return call(a_shard, b_shard)\n", + "\n", + " # Create 3-D tensors: (1, 256, total_blocks) with total_blocks divisible by device count\n", + " blocks_per_device = 16\n", + " total_blocks = blocks_per_device * num_devices\n", + " shape = (1, BLOCK, total_blocks)\n", + "\n", + " a_m = jax.random.normal(\n", + " jax.random.key(next(keys)),\n", + " shape,\n", + " dtype=jnp.float32,\n", + " out_sharding=sharding,\n", + " )\n", + " b_m = jax.random.normal(\n", + " jax.random.key(next(keys)),\n", + " shape,\n", + " dtype=jnp.float32,\n", + " out_sharding=sharding,\n", + " )\n", + "\n", + " print(\"a_m sharding:\", a_m.sharding)\n", + " print(\"b_m sharding:\", b_m.sharding)\n", + "\n", + " c_m = sharded_vector_add(a_m, b_m)\n", + "\n", + " np.testing.assert_allclose(jnp.array(c_m), jnp.array(a_m + b_m), rtol=1e-5)\n", + " n_total = int(np.prod(shape))\n", + " print(f\"Sharded Vector Add PASSED across {num_devices} devices (N={n_total})\")" ] }, { "cell_type": "markdown", + "id": "6e8f2686", "metadata": {}, "source": [ "## Exporting CUTLASS kernels with `jax.export`\n", @@ -1122,34 +1463,81 @@ { "cell_type": "code", "execution_count": null, + "id": "220b2712", "metadata": {}, "outputs": [], "source": [ - "from jax import export\n", "from cutlass.jax import get_export_disabled_safety_checks\n", + "from jax import export\n", + "\n", + "\n", + "# Element-wise Add (2-D, flat indexing)\n", + "@cute.kernel\n", + "def elementwise_add_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):\n", + " \"\"\"Per-thread kernel: 2-D element-wise add using flat indexing.\"\"\"\n", + " tidx, _, _ = cute.arch.thread_idx()\n", + " bidx, _, _ = cute.arch.block_idx()\n", + " bdim, _, _ = cute.arch.block_dim()\n", + "\n", + " thread_idx = bidx * bdim + tidx\n", + "\n", + " m, n = gA.shape\n", + "\n", + " if thread_idx < m * n:\n", + "\n", + " ni = thread_idx % n\n", + " mi = thread_idx // n\n", + "\n", + " a_val = gA[mi, ni]\n", + " b_val = gB[mi, ni]\n", + " gC[mi, ni] = a_val + b_val\n", + "\n", + "\n", + "@cute.jit\n", + "def launch_elementwise_add(\n", + " stream: cuda.CUstream,\n", + " mA: cute.Tensor,\n", + " mB: cute.Tensor,\n", + " mC: cute.Tensor,\n", + "):\n", + " num_threads_per_block = 256\n", + " m, n = mA.shape\n", + " elementwise_add_kernel(mA, mB, mC).launch(\n", + " grid=((m * n + num_threads_per_block - 1) // num_threads_per_block, 1, 1),\n", + " block=(num_threads_per_block, 1, 1),\n", + " stream=stream,\n", + " )\n", + "\n", "\n", "# Define a function that uses a CUTLASS kernel + JAX ops.\n", "# We use launch_elementwise_add which accepts 2-D tensors directly\n", "# with flat indexing — compatible with jax.export's tracing.\n", "@jax.jit\n", "def f(a, b):\n", - " call = cjax.cutlass_call(launch_elementwise_add, output_shape_dtype=a)\n", - " return jax.nn.sigmoid(call(a, b))\n", + " call = cjax.cutlass_call(launch_elementwise_add, output_shape_dtype=a)\n", + " return jax.nn.sigmoid(call(a, b))\n", + "\n", "\n", "# Reference implementation (pure JAX)\n", "@jax.jit\n", "def ref_f(a, b):\n", - " return jax.nn.sigmoid(a + b)\n", + " return jax.nn.sigmoid(a + b)\n", + "\n", "\n", "# --- Export with concrete shapes ---\n", "M, N = 512, 256\n", "export_shape_dtype = jax.ShapeDtypeStruct((M, N), jnp.float32)\n", "\n", - "print(f\"Exporting with input signature: ({export_shape_dtype}, {export_shape_dtype})\")\n", + "print(\n", + " f\"Exporting with input signature: ({export_shape_dtype},\"\n", + " f\" {export_shape_dtype})\"\n", + ")\n", "\n", "# Export the function — get_export_disabled_safety_checks() tells JAX\n", "# that CUTLASS custom call targets are safe to include\n", - "exported = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks())\n", + "exported = jax.export.export(\n", + " f, disabled_checks=get_export_disabled_safety_checks()\n", + ")\n", "traced = exported(export_shape_dtype, export_shape_dtype)\n", "\n", "# Serialize to a byte blob\n", @@ -1159,9 +1547,8 @@ "# Deserialize and run — this works independently of the original function\n", "rehydrated = export.deserialize(blob)\n", "\n", - "key = jax.random.PRNGKey(1123)\n", - "a = jax.random.normal(key, (M, N), dtype=jnp.float32)\n", - "b = jax.random.normal(jax.random.PRNGKey(456), (M, N), dtype=jnp.float32)\n", + "a = jax.random.normal(next(keys), (M, N), dtype=jnp.float32)\n", + "b = jax.random.normal(next(keys), (M, N), dtype=jnp.float32)\n", "\n", "c = rehydrated.call(a, b)\n", "c_ref = ref_f(a, b)\n", @@ -1173,6 +1560,7 @@ }, { "cell_type": "markdown", + "id": "d3a7d39b", "metadata": {}, "source": [ "### Exporting with symbolic shapes\n", @@ -1185,6 +1573,7 @@ { "cell_type": "code", "execution_count": null, + "id": "6260595b", "metadata": {}, "outputs": [], "source": [ @@ -1192,9 +1581,14 @@ "a_sym, b_sym = export.symbolic_shape(\"a, b\")\n", "symbolic_shape_dtype = jax.ShapeDtypeStruct((a_sym, b_sym), jnp.float32)\n", "\n", - "print(f\"Exporting with symbolic signature: ({symbolic_shape_dtype}, {symbolic_shape_dtype})\")\n", + "print(\n", + " f\"Exporting with symbolic signature: ({symbolic_shape_dtype},\"\n", + " f\" {symbolic_shape_dtype})\"\n", + ")\n", "\n", - "exported_sym = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks())\n", + "exported_sym = jax.export.export(\n", + " f, disabled_checks=get_export_disabled_safety_checks()\n", + ")\n", "traced_sym = exported_sym(symbolic_shape_dtype, symbolic_shape_dtype)\n", "blob_sym = traced_sym.serialize()\n", "print(f\"Serialized computation: {len(blob_sym):,} bytes\")\n", @@ -1205,18 +1599,19 @@ "# The same serialized blob works for any (M, N) where M*N is a\n", "# multiple of the kernel's block size (256).\n", "for shape in [(512, 256), (1024, 512), (2048, 1024)]:\n", - " a = jax.random.normal(jax.random.PRNGKey(42), shape, dtype=jnp.float32)\n", - " b = jax.random.normal(jax.random.PRNGKey(43), shape, dtype=jnp.float32)\n", - " c = rehydrated_sym.call(a, b)\n", - " c_ref = ref_f(a, b)\n", - " np.testing.assert_allclose(np.array(c), np.array(c_ref), rtol=1e-5)\n", - " print(f\" Symbolic export PASSED for shape {shape}\")\n", + " a = jax.random.normal(next(keys), shape, dtype=jnp.float32)\n", + " b = jax.random.normal(next(keys), shape, dtype=jnp.float32)\n", + " c = rehydrated_sym.call(a, b)\n", + " c_ref = ref_f(a, b)\n", + " np.testing.assert_allclose(np.array(c), np.array(c_ref), rtol=1e-5)\n", + " print(f\" Symbolic export PASSED for shape {shape}\")\n", "\n", "print(\"All symbolic shape tests passed.\")" ] }, { "cell_type": "markdown", + "id": "2372a537", "metadata": {}, "source": [ "## Summary\n", @@ -1237,6 +1632,10 @@ } ], "metadata": { + "jupytext": { + "default_lexer": "ipython3", + "formats": "ipynb,md:myst" + }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", @@ -1256,5 +1655,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax_kernels.py b/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax_kernels.py index 7089368713..79b27d38f2 100644 --- a/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax_kernels.py +++ b/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax_kernels.py @@ -29,7 +29,7 @@ import cutlass import cutlass.cute as cute import cutlass.jax as cjax -import cuda.bindings.driver as cuda +import cuda.bindings.driver as cuda # pyrefly: ignore """ CuTe DSL kernels used by the ``cute_dsl_jax.ipynb`` notebook. @@ -53,7 +53,7 @@ inside ``@jax.jit`` functions. See ``cute_dsl_jax.ipynb`` for usage, validation, and step-by-step explanations. -This module is imported by the notebook and by ``cute_dsl_jax.py``. It can also +This module is imported by the notebook and by ``cute_dsl_jax_kernels.py``. It can also be run directly to validate every kernel: .. code-block:: bash @@ -91,8 +91,8 @@ def launch_vector_add( a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, ): vector_add_kernel(a, b, c).launch( - grid=[a.shape[-1], 1, 1], - block=[a.shape[-2], 1, 1], + grid=[a.shape[-1], 1, 1], # pyrefly: ignore + block=[a.shape[-2], 1, 1], # pyrefly: ignore stream=stream, ) @@ -123,8 +123,8 @@ def launch_saxpy( *, alpha: float, ): saxpy_kernel(x, y, out, alpha).launch( - grid=[x.shape[-1], 1, 1], - block=[x.shape[-2], 1, 1], + grid=[x.shape[-1], 1, 1], # pyrefly: ignore + block=[x.shape[-2], 1, 1], # pyrefly: ignore stream=stream, ) @@ -175,7 +175,7 @@ def fused_bias_relu_kernel( idx = bidx * bdx + tidx if idx < N: col = idx % width - val = x[idx] + bias[col] + val = x[idx] + bias[col] # pyrefly: ignore out[idx] = cutlass.max(val, cutlass.Float32(0.0)) @@ -215,7 +215,7 @@ def gemm_kernel( if m_idx < M and n_idx < N: acc = cutlass.Float32(0.0) for k in cutlass.range(K): - acc += A[m_idx * K + k] * B[k * N + n_idx] + acc += A[m_idx * K + k] * B[k * N + n_idx] # pyrefly: ignore D[m_idx * N + n_idx] = acc @@ -225,35 +225,38 @@ def launch_gemm( A: cute.Tensor, B: cute.Tensor, D: cute.Tensor, *, M: int, N: int, K: int, ): - BLOCK_M, BLOCK_N = 64, 64 - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - gemm_kernel(A, B, D, M, N, K, BLOCK_M, BLOCK_N).launch( + BLOCK_M, BLOCK_N = 64, 64 + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + gemm_kernel(A, B, D, M, N, K, BLOCK_M, BLOCK_N).launch( grid=[grid_m, grid_n, 1], block=[256, 1, 1], stream=stream, ) - + # ------------------------------------------------------------------ # # Element-wise Add (2-D, flat indexing) # # ------------------------------------------------------------------ # @cute.kernel def elementwise_add_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): - """Per-thread kernel: 2-D element-wise add using flat indexing.""" - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - bdim, _, _ = cute.arch.block_dim() + """Per-thread kernel: 2-D element-wise add using flat indexing.""" + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + + thread_idx = bidx * bdim + tidx - thread_idx = bidx * bdim + tidx + m, n = gA.shape # pyrefly: ignore + + if thread_idx < m * n: # pyrefly: ignore - m, n = gA.shape ni = thread_idx % n mi = thread_idx // n a_val = gA[mi, ni] b_val = gB[mi, ni] - gC[mi, ni] = a_val + b_val + gC[mi, ni] = a_val + b_val # pyrefly: ignore @cute.jit @@ -262,9 +265,9 @@ def launch_elementwise_add( mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, ): num_threads_per_block = 256 - m, n = mA.shape + m, n = mA.shape # pyrefly: ignore elementwise_add_kernel(mA, mB, mC).launch( - grid=((m * n) // num_threads_per_block, 1, 1), + grid=((m * n + num_threads_per_block - 1) // num_threads_per_block, 1, 1), # pyrefly: ignore block=(num_threads_per_block, 1, 1), stream=stream, ) @@ -281,13 +284,21 @@ def launch_elementwise_add( import jax.numpy as jnp import numpy as np + def split_keys(seed=0): + key = jax.random.key(seed) + while True: + key, subkey = jax.random.split(key) + yield subkey + + keys = iter(split_keys()) + BLOCK = 256 N_BLOCKS = 4 # ── Vector Add ──────────────────────────────────────────────────── # 3-D CuTe layout: (elems_per_thread, threads_per_block, num_blocks) - a = jax.random.normal(jax.random.PRNGKey(0), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) - b = jax.random.normal(jax.random.PRNGKey(1), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) + a = jax.random.normal(next(keys), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) + b = jax.random.normal(next(keys), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) call = cjax.cutlass_call( launch_vector_add, output_shape_dtype=jax.ShapeDtypeStruct(a.shape, a.dtype), @@ -298,8 +309,8 @@ def launch_elementwise_add( print('vector_add: PASSED') # ── SAXPY ───────────────────────────────────────────────────────── - x = jax.random.normal(jax.random.PRNGKey(2), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) - y = jax.random.normal(jax.random.PRNGKey(3), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) + x = jax.random.normal(next(keys), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) + y = jax.random.normal(next(keys), (1, BLOCK, N_BLOCKS), dtype=jnp.float32) alpha = 2.5 call = cjax.cutlass_call( launch_saxpy, @@ -313,7 +324,7 @@ def launch_elementwise_add( # ── ReLU ────────────────────────────────────────────────────────── N_ELEM = BLOCK * N_BLOCKS - x = jax.random.normal(jax.random.PRNGKey(4), (N_ELEM,), dtype=jnp.float32) + x = jax.random.normal(next(keys), (N_ELEM,), dtype=jnp.float32) call = cjax.cutlass_call( launch_relu, output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype), @@ -325,8 +336,8 @@ def launch_elementwise_add( # ── Fused Bias + ReLU ───────────────────────────────────────────── ROWS, COLS = 16, 64 - x = jax.random.normal(jax.random.PRNGKey(5), (ROWS * COLS,), dtype=jnp.float32) - bias = jax.random.normal(jax.random.PRNGKey(6), (COLS,), dtype=jnp.float32) + x = jax.random.normal(next(keys), (ROWS * COLS,), dtype=jnp.float32) + bias = jax.random.normal(next(keys), (COLS,), dtype=jnp.float32) call = cjax.cutlass_call( launch_fused_bias_relu, output_shape_dtype=jax.ShapeDtypeStruct(x.shape, x.dtype), @@ -339,8 +350,8 @@ def launch_elementwise_add( # ── GEMM ────────────────────────────────────────────────────────── M, N, K = 128, 128, 64 - A = jax.random.normal(jax.random.PRNGKey(7), (M * K,), dtype=jnp.float32) - B = jax.random.normal(jax.random.PRNGKey(8), (K * N,), dtype=jnp.float32) + A = jax.random.normal(next(keys), (M * K,), dtype=jnp.float32) + B = jax.random.normal(next(keys), (K * N,), dtype=jnp.float32) call = cjax.cutlass_call( launch_gemm, output_shape_dtype=jax.ShapeDtypeStruct((M * N,), A.dtype), @@ -353,8 +364,8 @@ def launch_elementwise_add( # ── Elementwise Add (2-D) ───────────────────────────────────────── M, N = 16, 256 - a = jax.random.normal(jax.random.PRNGKey(9), (M, N), dtype=jnp.float32) - b = jax.random.normal(jax.random.PRNGKey(10), (M, N), dtype=jnp.float32) + a = jax.random.normal(next(keys), (M, N), dtype=jnp.float32) + b = jax.random.normal(next(keys), (M, N), dtype=jnp.float32) call = cjax.cutlass_call( launch_elementwise_add, output_shape_dtype=jax.ShapeDtypeStruct(a.shape, a.dtype),