Skip to content

POC: Integer cooperative matrix (i32/u32/i8/u8) on Vulkan#9490

Open
ruihe774 wants to merge 3 commits intogfx-rs:trunkfrom
ruihe774:coop_mat_int_trunk
Open

POC: Integer cooperative matrix (i32/u32/i8/u8) on Vulkan#9490
ruihe774 wants to merge 3 commits intogfx-rs:trunkfrom
ruihe774:coop_mat_int_trunk

Conversation

@ruihe774
Copy link
Copy Markdown

@ruihe774 ruihe774 commented May 3, 2026

⚠️ Proof of concept / RFC

Posting this as a discussion-only POC. The W3C WebGPU cooperative-matrix proposal hasn't settled on integer types or asymmetric shapes yet, so the WGSL surface here is necessarily provisional and likely needs to change before it can land for real. The intent of this PR is to share what works end-to-end on real Vulkan hardware, surface the API questions that fall out, and give #8251's "follow-up" list a concrete starting point. Not expected to be merged as-is.

Connections

Description

Extends EXPERIMENTAL_COOPERATIVE_MATRIX to integer scalar types and the asymmetric matrix shapes real hardware needs for integer GEMM, and adds mixed-precision accumulation. Vulkan-only on the runtime side; Metal continues to reject non-float at the MSL backend (its simdgroup_matrix is float-only).

Naga

  • IR: Scalar::I8/U8 (width=1 byte), CooperativeSize::ThirtyTwo.
  • WGSL frontend: i8/u8 predeclared scalars and the asymmetric coop_mat8x16, coop_mat16x8, coop_mat8x32, coop_mat32x8, coop_mat16x32, coop_mat32x16 types, all gated behind the existing enable wgpu_cooperative_matrix; directive. Naming follows WGSL's existing mat<C>x<R><T> ordering — coop_mat32x16<i8, A> is 32 columns × 16 rows.
  • Validation:
    • New Capabilities::SHADER_INT8 capability gating i8/u8 scalars (currently auto-implied by EXPERIMENTAL_COOPERATIVE_MATRIX in the bridge — see "Open questions" below).
    • Cooperative matrix scalar check rewritten to enumerate allowed (kind, width) pairs with a dedicated CooperativeMatrixScalarUnsupported error.
    • coopMultiplyAdd now validates shape compatibility (a_cols == b_rows, a_rows == c_rows, b_cols == c_cols) and the scalar combination: A and B must match (else InvalidCooperativeMixedInputs), and C may be either the same scalar or a widened accumulator (f16→f32, i8→i32, u8→u32; else InvalidCooperativeAccumulator { ab, c }).
  • SPIR-V backend:
    • Emits CooperativeMatrixOperands (MATRINENTS_KHR) per-operand for any non-float

Naga

  • IR: Scalar::I8/U8 (width=1 byte), CooperativeSize::ThirtyTwo.
  • WGSL frontend: i8/u8 predeclared scalars and the asymmetric coop_mat8x16, coop_mat16x8, coop_mat8x32, coop_mat32x8, coop_mat16x32, coop_mat32x16 types, all gated behind the existing enable wgpu_cooperative_matrix; directive. Naming follows WGSL's existing mat<C>x<R><T> ordering — coop_mat32x16<i8, A> is 32 columns × 16 rows.
  • Validation: - New Capabilities::SHADER_INT8 capabilitntly auto-implied byEXPERIMENTAL_COOPERATIVE_MATRIX in the bridge — see "Open questions" below).
    • Cooperative matrix scalar check rewritten to enumerate allowed (kind, width) pairs with a dedicated CooperativeMatrixScalarUnsupported error. - coopMultiplyAdd now validates shape com, a_rows == c_rows, b_cols == c_cols) andthe scalar combination: A and B must match (enputs), and C may be either the same scalar or a widened accumulator (f16→f32, i8→i32, u8→u32; else InvalidCooperativeAccumulator { ab, c }).
  • SPIR-V backend:
    • Emits CooperativeMatrixOperands (MATRIX_{A,B,C,RESULT}_SIGNED_COMPONENTS_KHR) per-operand for any non-float matmul; omitted entirely for all-float to keele. - Requests Int8, StorageBuffer8BitAccessitAccess, and SPV_KHR_8bit_storage when8-bit scalars are used.
    • Integer-typed Add/Subtract on cooperative matrices now lower to IAdd/ISub (previously hardcoded to FAdd/FSub). - MSL backend: char/uchar mappings fopleteness, but cooperative-matrix load andcoopMultiplyAdd reject non-float at the userixdoesn't exist for integer types). Thewrapper generator now tracks input vs accumulator scalars separately sof16→f32` simdgroup multiply-accumulate emits the right typed signature.
  • WGSL writer: round-trips i8/u8 and the new asymmetric shapes.

wgpu-hal/vulkan

  • Negotiates VK_KHR_8bit_storage (promoted to Vulkan 1.2): pushes VkPhysicalDevice8BitStorageFeatures when supported sets a new private_caps.storage_buffer_8bitgeBuffer8BitAccess /UniformAndStorageBuffer8BitAccess SPIR-V capabilities to the naga writer when the device advertises them.
  • vkGetPhysicalDeviceCooperativeMatrixPropertiesKHRwgt::CooperativeScalarType mapping extended for SINT8/UINT8 and for size 32 (M/N/K).

wgpu-types

  • CooperativeScalarType::{I8, U8} added. - EXPERIMENTAL_COOPERATIVE_MATRIX docs rewralar types and platforms, spells out thecoop_mat<C>x<R> rows/columns convention, and documents that enable wgpu_cooperative_matrix; makes i8/u8 first-class WGSL scalars (usable in storage buffers and struct fields, not just inside cooperative matrices) — required so cooperative-matrix loads can address backi strides.

wgpu-naga-bridge

  • Maps Features::EXPERIMENTAL_COOPERATIVE_MATRIX to Capabilities::SHADER_INT8, with a TODO noting that this should ideally be its own feature flag (see "Open questions").

Open questions

  1. Naming for asymmetric shapes. I followed coop_mat<C>x<R> to match WGSL's existing matCxR convention, but it inverts the M×N notation everyone uses verbally.
  2. SHADER_INT8EXPERIMENTAL_COOPERATIVE_MATRIX coupling. Logically orthogonal; pragmatically tied here because the only consumer right now is cooperative matrix.
  3. Scope of i8/u8 in WGSL. Currently any code under enable wgpu_cooperative_matrix; can use i8/u8 freely (e.g. array<i8> storage buffers). That's required to back CM loads, but it does make i8/u8 WGSL scalars before the W3C proposal has spoken on them.
  4. Vulkan-only. Metal's simdgroup_matrix is float-only, so integer CM is rejected at the MSL backend's use site. DX12 isn't wired up.
  5. Same-type i8/u8 accumulator. Validator currently accepts i8 × i8 → i8 (no real Vulkan device advertises this; runtime rejects via vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR). Cheap to also reject at validation; left permissive for codegen-test symmetry.

Testing

  • New WGSL snapshot fixtures with SPV (validated with spirv-val on Vulkan 1.2 + SPV_KHR_8bit_storage) and WGSL round-trip golden output:
    • naga/tests/in/wgsl/cooperative-matrix-int.wgsl — i32/u32 cooperative matrices.
    • naga/tests/in/wgsl/cooperative-matrix-int8.wgsl — i8/u8 with asymmetric coop_mat32x16 × coop_mat16x32coop_mat16x16 shapes.
    • naga/tests/in/wgsl/cooperative-matrix-mixed.wgsl — f16→f32, i8→i32, u8→u32 mixed-precision combos.
  • Three new wgsl_errors cases: invalid scalar types, invalid mixed inputs (A ≠ B), invalid widened accumulator (i8 with f32, f16 with i32).
  • All existing 202 naga tests pass; existing f16/f32 cooperative-matrix golden output is byte-identical.
  • Manually exercised on Vulkan against a discrete NVIDIA device that advertises i8/i8/i32 and f16/f16/f32 configurations; matches expected device combos via the CooperativeMatrixProperties query.

Squash or Rebase?

Not intended for merging

Checklist

  • I self-reviewed and fully understand this PR.
  • WebGPU implementations built with wgpu may be affected behaviorally.
  • Validation and feature gates are in place to confine behavioral changes.
  • Tests demonstrate the validation and altered logic works.
  • CHANGELOG.md entries for the user-facing effects of this change are present.
  • The PR is minimal, and doesn't make sense to land as multiple PRs.
  • Commits are logically scoped and individually reviewable.
  • The PR description has enough context to understand the motivation and solution implemented.

ruihe774 added 3 commits May 3, 2026 15:45
…, u8)

Adds integer scalar support to cooperative matrices on the Vulkan backend.
Extends naga IR, WGSL frontend, SPIR-V backend, and Vulkan HAL to support
{i32, u32, i8, u8} element types alongside existing f16/f32, including
asymmetric matrix shapes (coop_mat16x32, coop_mat32x16, etc.) needed for
integer multiplies on real hardware.
Relaxes the scalar-equality requirement in CooperativeMultiplyAdd
validation to allow widened accumulators: i8/u8 inputs with i32/u32
accumulators, and f16 inputs with f32 accumulators. The SPIR-V backend
already emits per-operand CooperativeMatrixOperands flags correctly, so
no codegen changes are needed there. The MSL backend wrapper is updated
to track input and accumulator scalars separately and emit the right
mixed-type simdgroup_multiply_accumulate signature (enabling the f16→f32
combo that Metal already advertises via CooperativeMatrixProperties).
Simplify the CooperativeMatrixOperands construction in the SPIR-V
backend to check `kind != Float` directly instead of two `matches!`
clauses on Sint/Uint, and replace defensive `.scalar().unwrap_or(F32)`
fallbacks with a `scalar_for` closure that pattern-matches
`CooperativeMatrix` and reaches `unreachable!()` otherwise.

Add a dedicated `InvalidCooperativeMixedInputs { a, b }` validation
error so an A/B scalar mismatch in `coopMultiplyAdd` no longer shares
`InvalidCooperativeOperand` with shape mismatches.

Document the remaining design tradeoffs inline: the SHADER_INT8 ↔
COOPERATIVE_MATRIX coupling in the bridge, the global scope of the
i8/u8 scalars unlocked by `enable wgpu_cooperative_matrix;`, the
`coop_mat<C>x<R>` naming convention, the unrealistic accumulator combo
in the int8 test fixture, and the float-only invariant the MSL wrapper
generator relies on its caller to enforce.
@inner-daemons
Copy link
Copy Markdown
Collaborator

CC @kvark

Also, can you clarify if you used LLMs for the writing of the code or the comment above? Its quite a long description and that may prevent it from getting timely reviews. Finally, if it is a POC we can bring it up at the meeting but it will probably end up getting marked as a draft.

@ruihe774
Copy link
Copy Markdown
Author

ruihe774 commented May 6, 2026

Also, can you clarify if you used LLMs for the writing of the code or the comment above?

Yes, I used them. And I don't think the PR is in a status ready for review.

Finally, if it is a POC we can bring it up at the meeting but it will probably end up getting marked as a draft.

It's actually an implementation of integer types in https://github.com/gpuweb/gpuweb/blob/main/proposals/subgroup-matrix.md. However, as the coop mat implementation of wgpu already diverges with the proposal, I do not expect it can be merged until we refactor the current coop mat interface and implementation

@cwfitzgerald cwfitzgerald self-assigned this May 6, 2026
@cwfitzgerald
Copy link
Copy Markdown
Member

cwfitzgerald commented May 6, 2026

I'm going to assign this to myself so I don't lose it, but I probably won't get to this for a few weeks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants