POC: Integer cooperative matrix (i32/u32/i8/u8) on Vulkan #9490
ruihe774 wants to merge 3 commits into gfx-rs:trunk
…, u8)
Adds integer scalar support to cooperative matrices on the Vulkan backend.
Extends naga IR, WGSL frontend, SPIR-V backend, and Vulkan HAL to support
{i32, u32, i8, u8} element types alongside existing f16/f32, including
asymmetric matrix shapes (coop_mat16x32, coop_mat32x16, etc.) needed for
integer multiplies on real hardware.
Relaxes the scalar-equality requirement in CooperativeMultiplyAdd validation to allow widened accumulators: i8/u8 inputs with i32/u32 accumulators, and f16 inputs with f32 accumulators. The SPIR-V backend already emits per-operand CooperativeMatrixOperands flags correctly, so no codegen changes are needed there. The MSL backend wrapper is updated to track input and accumulator scalars separately and emit the right mixed-type simdgroup_multiply_accumulate signature (enabling the f16→f32 combo that Metal already advertises via CooperativeMatrixProperties).
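The relaxed rule can be sketched as a small standalone check. This is a minimal illustration, not the PR's validator: `Kind`, `Scalar`, `CoopError`, and `check_multiply_add_scalars` are hypothetical names, and only the accepted combinations (same scalar, or f16→f32, i8→i32, u8→u32) and the two error cases come from this PR.

```rust
// Hypothetical stand-ins for naga's scalar types; names are assumptions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Kind {
    Float,
    Sint,
    Uint,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Scalar {
    kind: Kind,
    width: u8, // width in bytes
}

const F16: Scalar = Scalar { kind: Kind::Float, width: 2 };
const F32: Scalar = Scalar { kind: Kind::Float, width: 4 };
const I8: Scalar = Scalar { kind: Kind::Sint, width: 1 };
const I32: Scalar = Scalar { kind: Kind::Sint, width: 4 };
const U8: Scalar = Scalar { kind: Kind::Uint, width: 1 };
const U32: Scalar = Scalar { kind: Kind::Uint, width: 4 };

// Hypothetical error type loosely mirroring the two failure modes in the PR.
#[derive(Debug, PartialEq, Eq)]
enum CoopError {
    MixedInputs { a: Scalar, b: Scalar },
    Accumulator { ab: Scalar, c: Scalar },
}

/// A and B must share a scalar; C is either that same scalar or one of the
/// widened accumulators listed in the PR.
fn check_multiply_add_scalars(a: Scalar, b: Scalar, c: Scalar) -> Result<(), CoopError> {
    if a != b {
        return Err(CoopError::MixedInputs { a, b });
    }
    let widened = [(F16, F32), (I8, I32), (U8, U32)].contains(&(a, c));
    if a == c || widened {
        Ok(())
    } else {
        Err(CoopError::Accumulator { ab: a, c })
    }
}
```

With these definitions, `check_multiply_add_scalars(I8, I8, I32)` is accepted, while `(I8, I8, F32)` is rejected as an accumulator mismatch and `(I8, U8, I32)` as mixed inputs.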
Simplify the CooperativeMatrixOperands construction in the SPIR-V
backend to check `kind != Float` directly instead of two `matches!`
clauses on Sint/Uint, and replace defensive `.scalar().unwrap_or(F32)`
fallbacks with a `scalar_for` closure that pattern-matches
`CooperativeMatrix` and reaches `unreachable!()` otherwise.
Add a dedicated `InvalidCooperativeMixedInputs { a, b }` validation
error so an A/B scalar mismatch in `coopMultiplyAdd` no longer shares
`InvalidCooperativeOperand` with shape mismatches.
Document the remaining design tradeoffs inline: the SHADER_INT8 ↔
COOPERATIVE_MATRIX coupling in the bridge, the global scope of the
i8/u8 scalars unlocked by `enable wgpu_cooperative_matrix;`, the
`coop_mat<C>x<R>` naming convention, the unrealistic accumulator combo
in the int8 test fixture, and the float-only invariant the MSL wrapper
generator relies on its caller to enforce.
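To make the `coop_mat<C>x<R>` convention concrete, here is a minimal sketch of the shape-compatibility rule the PR description lists for `coopMultiplyAdd` (`a_cols == b_rows`, `a_rows == c_rows`, `b_cols == c_cols`). `Shape`, `ShapeError`, `coop_mat`, and `check_multiply_add_shapes` are hypothetical names used only for illustration; the PR's actual validator is not shown on this page.

```rust
/// Hypothetical shape record. `coop_mat<C>x<R>` lists columns first,
/// matching WGSL's `matCxR` ordering.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Shape {
    columns: u32,
    rows: u32,
}

/// Builds the shape named `coop_mat{columns}x{rows}`.
fn coop_mat(columns: u32, rows: u32) -> Shape {
    Shape { columns, rows }
}

// Hypothetical error type; the PR reports these through its own variants.
#[derive(Debug, PartialEq, Eq)]
enum ShapeError {
    InnerMismatch, // a_cols != b_rows
    OuterMismatch, // a_rows != c_rows or b_cols != c_cols
}

/// Checks that A * B + C is dimensionally consistent.
fn check_multiply_add_shapes(a: Shape, b: Shape, c: Shape) -> Result<(), ShapeError> {
    if a.columns != b.rows {
        return Err(ShapeError::InnerMismatch);
    }
    if a.rows != c.rows || b.columns != c.columns {
        return Err(ShapeError::OuterMismatch);
    }
    Ok(())
}
```

Under this reading, the int8 fixture's `coop_mat32x16` × `coop_mat16x32` → `coop_mat16x16` passes: A is 32 columns × 16 rows, B is 16 columns × 32 rows, so the inner dimension is 32 and the result is 16 × 16.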
CC @kvark. Also, can you clarify whether you used LLMs to write the code or the comment above? It's quite a long description, and that may prevent it from getting timely reviews. Finally, if it is a POC we can bring it up at the meeting, but it will probably end up getting marked as a draft.
Yes, I used them. And I don't think the PR is in a state ready for review.
It's actually an implementation of the integer types in https://github.com/gpuweb/gpuweb/blob/main/proposals/subgroup-matrix.md. However, as wgpu's coop mat implementation already diverges from the proposal, I do not expect it can be merged until we refactor the current coop mat interface and implementation.
I'm going to assign this to myself so I don't lose it, but I probably won't get to this for a few weeks.
Posting this as a discussion-only POC. The W3C WebGPU cooperative-matrix proposal hasn't settled on integer types or asymmetric shapes yet, so the WGSL surface here is necessarily provisional and likely needs to change before it can land for real. The intent of this PR is to share what works end-to-end on real Vulkan hardware, surface the API questions that fall out, and give #8251's "follow-up" list a concrete starting point. Not expected to be merged as-is.
Connections
- […] `wgpu_cooperative_matrix` extension).
- […] `InvalidCooperativeMixedInputs` and `InvalidCooperativeAccumulator` error variants, a `CooperativeMatrixScalarUnsupported { kind, width }` type error, and explicit shape-compatibility checks (`a_cols == b_rows`, outer dims) that previously fell through to a generic operand error.
Description
Extends `EXPERIMENTAL_COOPERATIVE_MATRIX` to integer scalar types and the asymmetric matrix shapes real hardware needs for integer GEMM, and adds mixed-precision accumulation. Vulkan-only on the runtime side; Metal continues to reject non-float at the MSL backend (its `simdgroup_matrix` is float-only).
Naga
- `Scalar::I8`/`U8` (width = 1 byte), `CooperativeSize::ThirtyTwo`.
- `i8`/`u8` predeclared scalars and the asymmetric `coop_mat8x16`, `coop_mat16x8`, `coop_mat8x32`, `coop_mat32x8`, `coop_mat16x32`, `coop_mat32x16` types, all gated behind the existing `enable wgpu_cooperative_matrix;` directive. Naming follows WGSL's existing `mat<C>x<R><T>` ordering: `coop_mat32x16<i8, A>` is 32 columns × 16 rows.
- `Capabilities::SHADER_INT8` capability gating i8/u8 scalars (currently auto-implied by `EXPERIMENTAL_COOPERATIVE_MATRIX` in the bridge; see "Open questions" below).
- […] `(kind, width)` pairs with a dedicated `CooperativeMatrixScalarUnsupported` error.
- `coopMultiplyAdd` now validates shape compatibility (`a_cols == b_rows`, `a_rows == c_rows`, `b_cols == c_cols`) and the scalar combination: A and B must match (else `InvalidCooperativeMixedInputs`), and C may be either the same scalar or a widened accumulator (`f16→f32`, `i8→i32`, `u8→u32`; else `InvalidCooperativeAccumulator { ab, c }`).
- […] `CooperativeMatrixOperands` (`MATRIX_{A,B,C,RESULT}_SIGNED_COMPONENTS_KHR`) per-operand for any non-float matmul; omitted entirely for all-float […].
- Requests `Int8`, `StorageBuffer8BitAccess`, `UniformAndStorageBuffer8BitAccess`, and `SPV_KHR_8bit_storage` when 8-bit scalars are used.
- `Add`/`Subtract` on cooperative matrices now lower to `IAdd`/`ISub` (previously hardcoded to `FAdd`/`FSub`).
- MSL backend: `char`/`uchar` mappings for completeness, but cooperative-matrix load and `coopMultiplyAdd` reject non-float at the use site (`simdgroup_matrix` doesn't exist for integer types). The wrapper generator now tracks input vs accumulator scalars separately so `f16→f32` simdgroup multiply-accumulate emits the right typed signature.
wgpu-hal/vulkan
- `VK_KHR_8bit_storage` (promoted to Vulkan 1.2): pushes `VkPhysicalDevice8BitStorageFeatures` when supported, sets a new `private_caps.storage_buffer_8bit` […] `StorageBuffer8BitAccess`/`UniformAndStorageBuffer8BitAccess` SPIR-V capabilities to the naga writer when the device advertises them.
- `vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR` → `wgt::CooperativeScalarType` mapping extended for `SINT8`/`UINT8` and for size 32 (M/N/K).
wgpu-types
- `CooperativeScalarType::{I8, U8}` added.
- `EXPERIMENTAL_COOPERATIVE_MATRIX` docs rewritten: […] scalar types and platforms, spells out the `coop_mat<C>x<R>` rows/columns convention, and documents that `enable wgpu_cooperative_matrix;` makes `i8`/`u8` first-class WGSL scalars (usable in storage buffers and struct fields, not just inside cooperative matrices). This is required so cooperative-matrix loads can address backing […] strides.
wgpu-naga-bridge
- Maps `Features::EXPERIMENTAL_COOPERATIVE_MATRIX` to `Capabilities::SHADER_INT8`, with a TODO noting that this should ideally be its own feature flag (see "Open questions").
Open questions
- Naming: uses `coop_mat<C>x<R>` to match WGSL's existing `matCxR` convention, but it inverts the M×N notation everyone uses verbally.
- `SHADER_INT8` ↔ `EXPERIMENTAL_COOPERATIVE_MATRIX` coupling. Logically orthogonal; pragmatically tied here because the only consumer right now is cooperative matrix.
- Global scope of `i8`/`u8` in WGSL. Currently any code under `enable wgpu_cooperative_matrix;` can use i8/u8 freely (e.g. `array<i8>` storage buffers). That's required to back CM loads, but it does make i8/u8 WGSL scalars before the W3C proposal has spoken on them.
- Other backends: Metal's `simdgroup_matrix` is float-only, so integer CM is rejected at the MSL backend's use site. DX12 isn't wired up.
- The int8 test fixture uses `i8 × i8 → i8` (no real Vulkan device advertises this; runtime rejects via `vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR`). Cheap to also reject at validation; left permissive for codegen-test symmetry.
Testing
- […] `spirv-val` on Vulkan 1.2 + `SPV_KHR_8bit_storage`) and WGSL round-trip golden output:
  - `naga/tests/in/wgsl/cooperative-matrix-int.wgsl`: i32/u32 cooperative matrices.
  - `naga/tests/in/wgsl/cooperative-matrix-int8.wgsl`: i8/u8 with asymmetric `coop_mat32x16` × `coop_mat16x32` → `coop_mat16x16` shapes.
  - `naga/tests/in/wgsl/cooperative-matrix-mixed.wgsl`: f16→f32, i8→i32, u8→u32 mixed-precision combos.
- `wgsl_errors` cases: invalid scalar types, invalid mixed inputs (A ≠ B), invalid widened accumulator (i8 with f32, f16 with i32).
- […] `CooperativeMatrixProperties` query.
Squash or Rebase?
Not intended for merging
Checklist
- […] `wgpu` may be affected behaviorally.
- `CHANGELOG.md` entries for the user-facing effects of this change are present.