diff --git a/CHANGELOG.md b/CHANGELOG.md index d9cd1cf8902..9fcd9fd6d39 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -149,6 +149,8 @@ By @beholdnec in [#8505](https://github.com/gfx-rs/wgpu/pull/8505). - Conditional compilation by @jimblandy in [#9390](https://github.com/gfx-rs/wgpu/pull/9390) - Fixed alignment and MatrixStride for mat2x2 in SPIR-V uniform blocks. By @39ali [#9369](https://github.com/gfx-rs/wgpu/pull/9369). - Add `wgpu_hal::vulkan::Buffer::raw_handle()` for retrieving the underlying `vk::Buffer` resource. By @WillowGriffiths in [#9459](https://github.com/gfx-rs/wgpu/pull/9459). +- `EXPERIMENTAL_COOPERATIVE_MATRIX` now supports integer scalar types (`i32`, `u32`, `i8`, `u8`) in addition to the existing `f16`/`f32` support. The `i8`/`u8` types require device support for `shaderInt8` and `storageBuffer8BitAccess` (part of Vulkan 1.2 or `VK_KHR_8bit_storage`). New WGSL predeclared types `coop_mat8x16`, `coop_mat16x8`, `coop_mat8x32`, `coop_mat32x8`, `coop_mat16x32`, and `coop_mat32x16` are available behind `enable wgpu_cooperative_matrix;` for the asymmetric matrix shapes that integer multiplies typically require. By @ruihe774. +- `coopMultiplyAdd` now supports mixed-precision accumulation: `i8`/`u8` input matrices can use `i32`/`u32` accumulators, and `f16` input matrices can use `f32` accumulators (e.g. `coopMultiplyAdd(a_i8, b_i8, c_i32)`). By @ruihe774. #### naga diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index c2cea6dd5c3..1c949786e75 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -517,7 +517,11 @@ pub(super) enum WrappedFunction { columns: crate::CooperativeSize, rows: crate::CooperativeSize, intermediate: crate::CooperativeSize, - scalar: crate::Scalar, + /// Scalar type of A and B (the multiplied operands). + input_scalar: crate::Scalar, + /// Scalar type of C and the result (the accumulator). May differ from + /// `input_scalar` for mixed-precision (e.g. 
f16 inputs, f32 accumulator). + accumulator_scalar: crate::Scalar, }, RayQueryGetIntersection { committed: bool, @@ -554,6 +558,14 @@ impl crate::Scalar { kind: Sk::Float, width: 2, } => "half", + Self { + kind: Sk::Sint, + width: 1, + } => "char", + Self { + kind: Sk::Uint, + width: 1, + } => "uchar", Self { kind: Sk::Sint, width: 2, @@ -2926,6 +2938,22 @@ impl Writer { if context.lang_version < (2, 3) { return Err(Error::UnsupportedCooperativeMatrix); } + // Metal simdgroup_matrix only supports floating-point types. + { + let ptr_ty = context.resolve_type(data.pointer); + let scalar = ptr_ty + .pointer_base_type() + .and_then(|tr| tr.inner_with(&context.module.types).scalar()); + if !matches!( + scalar, + Some(crate::Scalar { + kind: crate::ScalarKind::Float, + .. + }) + ) { + return Err(Error::UnsupportedCooperativeMatrix); + } + } write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?; write!(self.out, "&")?; self.put_access_chain(data.pointer, context.policies.index, context)?; @@ -2937,7 +2965,33 @@ impl Writer { if context.lang_version < (2, 3) { return Err(Error::UnsupportedCooperativeMatrix); } - write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?; + // Metal simdgroup_matrix only supports floating-point types. + let (in_name, acc_name) = { + let a_scalar = context.resolve_type(a).scalar(); + let c_scalar = context.resolve_type(c).scalar(); + match (a_scalar, c_scalar) { + ( + Some(crate::Scalar { + kind: crate::ScalarKind::Float, + .. + }), + Some(crate::Scalar { + kind: crate::ScalarKind::Float, + .. 
+ }), + ) => ( + a_scalar.unwrap().to_msl_name(), + c_scalar.unwrap().to_msl_name(), + ), + _ => return Err(Error::UnsupportedCooperativeMatrix), + } + }; + let fn_suffix = if in_name == acc_name { + String::new() + } else { + format!("_{in_name}_{acc_name}") + }; + write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}{fn_suffix}(")?; self.put_expression(a, context, true)?; write!(self.out, ", ")?; self.put_expression(b, context, true)?; @@ -6448,6 +6502,14 @@ template Ok(()) } + /// NOTE: this generator does not itself reject non-float scalar types. Metal's + /// `simdgroup_matrix` only exists for `float`/`half`, so emitting a wrapper for + /// e.g. `simdgroup_char32x16` would produce invalid MSL. The float-only check + /// lives at the use site in `put_expression` for `CooperativeMultiplyAdd`, + /// which fails the whole compilation before the consumer ever reads the + /// emitted output. If we ever start writing wrappers in a context where + /// `put_expression` cannot bail (e.g. emitting a stub library), mirror that + /// check here. fn write_wrapped_cooperative_multiply_add( &mut self, module: &crate::Module, @@ -6455,9 +6517,10 @@ template space: crate::AddressSpace, a: Handle, b: Handle, + c: Handle, ) -> BackendResult { let space_name = space.to_msl_name().unwrap_or_default(); - let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) { + let (a_c, a_r, input_scalar) = match *func_ctx.resolve_type(a, &module.types) { crate::TypeInner::CooperativeMatrix { columns, rows, @@ -6470,26 +6533,39 @@ template crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows), _ => unreachable!(), }; + let accumulator_scalar = match *func_ctx.resolve_type(c, &module.types) { + crate::TypeInner::CooperativeMatrix { scalar, .. 
} => scalar, + _ => unreachable!(), + }; let wrapped = WrappedFunction::CooperativeMultiplyAdd { space_name, columns: b_c, rows: a_r, intermediate: a_c, - scalar, + input_scalar, + accumulator_scalar, }; if !self.wrapped_functions.insert(wrapped) { return Ok(()); } - let scalar_name = scalar.to_msl_name(); + let in_name = input_scalar.to_msl_name(); + let acc_name = accumulator_scalar.to_msl_name(); + // When input and accumulator types differ we need a disambiguated name + // so two combos (e.g. f16→f32 and f32→f32) can coexist in one module. + let fn_suffix = if in_name == acc_name { + String::new() + } else { + format!("_{in_name}_{acc_name}") + }; writeln!( self.out, - "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{", + "{NAMESPACE}::simdgroup_{acc_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}{fn_suffix}(const {space_name} {NAMESPACE}::simdgroup_{in_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{in_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{acc_name}{}x{}& c) {{", b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32, )?; let l1 = back::Level(1); writeln!( self.out, - "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;", + "{l1}{NAMESPACE}::simdgroup_{acc_name}{}x{} d;", b_c as u32, a_r as u32 )?; writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?; @@ -6587,9 +6663,9 @@ template data.pointer, )?; } - crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => { + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { let space = crate::AddressSpace::Private; - self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?; + self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b, c)?; } 
crate::Expression::RayQueryGetIntersection { committed, .. } => { self.write_rq_get_intersection_function(module, committed)?; diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index f9051883719..bda2b920340 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1076,7 +1076,12 @@ impl BlockContext<'_> { self.cached[expr_handle] = id; return Ok(()); } - crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd, + crate::TypeInner::CooperativeMatrix { scalar, .. } => { + match scalar.kind { + crate::ScalarKind::Float => spirv::Op::FAdd, + _ => spirv::Op::IAdd, + } + } _ => unimplemented!(), }, crate::BinaryOperator::Subtract => match *left_ty_inner { @@ -1105,7 +1110,12 @@ impl BlockContext<'_> { self.cached[expr_handle] = id; return Ok(()); } - crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub, + crate::TypeInner::CooperativeMatrix { scalar, .. } => { + match scalar.kind { + crate::ScalarKind::Float => spirv::Op::FSub, + _ => spirv::Op::ISub, + } + } _ => unimplemented!(), }, crate::BinaryOperator::Multiply => { @@ -2178,12 +2188,49 @@ impl BlockContext<'_> { let b_id = self.cached[b]; let c_id = self.cached[c]; let id = self.gen_id(); + + // Build the CooperativeMatrixOperands word. SPIR-V requires this + // operand whenever any matrix has signed-integer components; for + // all-float operands we omit it (NONE_KHR == 0 is equivalent but + // adds a word, and skipping it keeps the float golden tests stable). + let matrix_operands = { + use crate::ScalarKind::{Float, Sint}; + use spirv::CooperativeMatrixOperands as Cmo; + let scalar_for = |h: Handle| match *self.fun_info[h] + .ty + .inner_with(&self.ir_module.types) + { + crate::TypeInner::CooperativeMatrix { scalar, .. 
} => scalar, + _ => unreachable!("validated as CooperativeMatrix"), + }; + let a_scalar = scalar_for(a); + let b_scalar = scalar_for(b); + let c_scalar = scalar_for(c); + if a_scalar.kind != Float || b_scalar.kind != Float || c_scalar.kind != Float { + let mut ops = Cmo::NONE_KHR; + if a_scalar.kind == Sint { + ops |= Cmo::MATRIX_A_SIGNED_COMPONENTS_KHR; + } + if b_scalar.kind == Sint { + ops |= Cmo::MATRIX_B_SIGNED_COMPONENTS_KHR; + } + if c_scalar.kind == Sint { + ops |= Cmo::MATRIX_C_SIGNED_COMPONENTS_KHR; + ops |= Cmo::MATRIX_RESULT_SIGNED_COMPONENTS_KHR; + } + Some(ops) + } else { + None + } + }; + block.body.push(Instruction::coop_mul_add( result_type_id, id, a_id, b_id, c_id, + matrix_operands, )); id } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 0a22bc3dbfe..687a7e552c5 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1326,14 +1326,23 @@ impl super::Instruction { instruction.add_operand(stride_id); instruction } - pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { + pub(super) fn coop_mul_add( + result_type_id: Word, + id: Word, + a: Word, + b: Word, + c: Word, + matrix_operands: Option, + ) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(a); instruction.add_operand(b); instruction.add_operand(c); - + if let Some(operands) = matrix_operands { + instruction.add_operand(operands.bits()); + } instruction } } diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index ae69c892d09..4614c23dc93 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -1222,4 +1222,5 @@ pub fn supported_capabilities() -> crate::valid::Capabilities { | Caps::DRAW_INDEX | Caps::MEMORY_DECORATION_COHERENT | Caps::MEMORY_DECORATION_VOLATILE + | Caps::SHADER_INT8 } diff --git a/naga/src/back/spv/writer.rs 
b/naga/src/back/spv/writer.rs index 7d4df508aaa..bc89b015906 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -2048,8 +2048,19 @@ impl Writer { self.require_any("16 bit integer", &[spirv::Capability::Int16])?; self.use_extension("SPV_KHR_16bit_storage"); } + // 8-bit integer support requires Int8 and StorageBuffer8BitAccess + crate::TypeInner::Scalar(crate::Scalar { + width: 1, + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + }) => { + self.capabilities_used + .insert(spirv::Capability::StorageBuffer8BitAccess); + self.capabilities_used + .insert(spirv::Capability::UniformAndStorageBuffer8BitAccess); + self.use_extension("SPV_KHR_8bit_storage"); + } // Cooperative types and ops - crate::TypeInner::CooperativeMatrix { .. } => { + crate::TypeInner::CooperativeMatrix { scalar, .. } => { self.require_any( "cooperative matrix", &[spirv::Capability::CooperativeMatrixKHR], @@ -2057,6 +2068,13 @@ impl Writer { self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?; self.use_extension("SPV_KHR_cooperative_matrix"); self.use_extension("SPV_KHR_vulkan_memory_model"); + if scalar.width == 1 { + self.capabilities_used + .insert(spirv::Capability::StorageBuffer8BitAccess); + self.capabilities_used + .insert(spirv::Capability::UniformAndStorageBuffer8BitAccess); + self.use_extension("SPV_KHR_8bit_storage"); + } } _ => {} } diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index c3fa20832c9..b78fc44d145 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -302,6 +302,8 @@ impl TryToWgsl for crate::Scalar { Scalar::F16 => "f16", Scalar::F32 => "f32", Scalar::F64 => "f64", + Scalar::I8 => "i8", + Scalar::U8 => "u8", Scalar::I16 => "i16", Scalar::U16 => "u16", Scalar::I32 => "i32", diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 59d2268333f..abc51c48b5f 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ 
-1462,7 +1462,7 @@ impl<'a> Error<'a> { Error::UnderspecifiedCooperativeMatrix => ParseError { message: "cooperative matrix constructor is underspecified".into(), labels: vec![], - notes: vec![format!("must be F32")], + notes: vec![format!("scalar type must be one of: f16, f32, i32, u32, i8, u8")], }, Error::InvalidCooperativeLoadType(span) => ParseError { message: "cooperative load should have a generic type for coop_mat".into(), @@ -1472,7 +1472,7 @@ impl<'a> Error<'a> { Error::UnsupportedCooperativeScalar(span) => ParseError { message: "cooperative scalar type is not supported".into(), labels: vec![(span, "type needs the scalar type specified".into())], - notes: vec![format!("must be F32")], + notes: vec![format!("scalar type must be one of: f16, f32, i32, u32, i8, u8")], }, Error::UnexpectedIdentForEnumerant(ident_span) => ParseError { message: format!( diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 4dc27709007..30e26ff836c 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -459,6 +459,8 @@ pub fn map_predeclared_type( // scalars "bool" => Ti::Scalar(Sc::BOOL).into(), + "i8" => Ti::Scalar(Sc::I8).into(), + "u8" => Ti::Scalar(Sc::U8).into(), "i32" => Ti::Scalar(Sc::I32).into(), "u32" => Ti::Scalar(Sc::U32).into(), "f32" => Ti::Scalar(Sc::F32).into(), @@ -557,7 +559,7 @@ pub fn map_predeclared_type( "acceleration_structure" => TypeGenerator::AccelerationStructure.into(), // ray query "ray_query" => TypeGenerator::RayQuery.into(), - // cooperative matrix + // cooperative matrix — square shapes "coop_mat8x8" => TypeGenerator::CooperativeMatrix { columns: crate::CooperativeSize::Eight, rows: crate::CooperativeSize::Eight, @@ -566,6 +568,31 @@ pub fn map_predeclared_type( columns: crate::CooperativeSize::Sixteen, rows: crate::CooperativeSize::Sixteen, }.into(), + // cooperative matrix — rectangular shapes for asymmetric MxNxK configs + "coop_mat8x16" => TypeGenerator::CooperativeMatrix { 
+ columns: crate::CooperativeSize::Eight, + rows: crate::CooperativeSize::Sixteen, + }.into(), + "coop_mat16x8" => TypeGenerator::CooperativeMatrix { + columns: crate::CooperativeSize::Sixteen, + rows: crate::CooperativeSize::Eight, + }.into(), + "coop_mat8x32" => TypeGenerator::CooperativeMatrix { + columns: crate::CooperativeSize::Eight, + rows: crate::CooperativeSize::ThirtyTwo, + }.into(), + "coop_mat32x8" => TypeGenerator::CooperativeMatrix { + columns: crate::CooperativeSize::ThirtyTwo, + rows: crate::CooperativeSize::Eight, + }.into(), + "coop_mat16x32" => TypeGenerator::CooperativeMatrix { + columns: crate::CooperativeSize::Sixteen, + rows: crate::CooperativeSize::ThirtyTwo, + }.into(), + "coop_mat32x16" => TypeGenerator::CooperativeMatrix { + columns: crate::CooperativeSize::ThirtyTwo, + rows: crate::CooperativeSize::Sixteen, + }.into(), _ => return Ok(None), }; @@ -578,6 +605,20 @@ pub fn map_predeclared_type( PredeclaredType::TypeInner(ref ty) if matches!(ty.scalar(), Some(s) if s == Sc::I16 || s == Sc::U16) => { Some(&[ImplementedEnableExtension::WgpuInt16]) } + PredeclaredType::TypeInner(ref ty) + if matches!( + ty.scalar(), + Some(Sc { + kind: crate::ScalarKind::Sint, + width: 1 + }) | Some(Sc { + kind: crate::ScalarKind::Uint, + width: 1 + }) + ) => + { + Some(&[ImplementedEnableExtension::WgpuCooperativeMatrix]) + } PredeclaredType::RayDesc | PredeclaredType::RayIntersection | PredeclaredType::TypeGenerator(TypeGenerator::AccelerationStructure) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index be69efc8d0c..4b70841d767 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -567,6 +567,7 @@ impl From for u32 { pub enum CooperativeSize { Eight = 8, Sixteen = 16, + ThirtyTwo = 32, } /// Primitive type for a scalar. 
diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index eaa8faa5499..d8c732eb450 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -23,6 +23,14 @@ impl crate::ScalarKind { } impl crate::Scalar { + pub const I8: Self = Self { + kind: crate::ScalarKind::Sint, + width: 1, + }; + pub const U8: Self = Self { + kind: crate::ScalarKind::Uint, + width: 1, + }; pub const I16: Self = Self { kind: crate::ScalarKind::Sint, width: 2, diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 6f206bae3f0..899b485b9dd 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -151,6 +151,16 @@ pub enum ExpressionError { UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), #[error("Invalid operand for cooperative op")] InvalidCooperativeOperand(Handle), + #[error( + "Cooperative matrix `coopMultiplyAdd` requires A and B to share a scalar type, \ + got A: {a:?}, B: {b:?}" + )] + InvalidCooperativeMixedInputs { a: crate::Scalar, b: crate::Scalar }, + #[error( + "Invalid accumulator type for coopMultiplyAdd: A/B use {ab:?} but C uses {c:?}; \ + allowed widened accumulators are f16→f32, i8→i32, u8→u32" + )] + InvalidCooperativeAccumulator { ab: crate::Scalar, c: crate::Scalar }, #[error("Shift amount exceeds the bit width of {lhs_type:?}")] ShiftAmountTooLarge { lhs_type: crate::TypeInner, @@ -1419,6 +1429,61 @@ impl super::Validator { } } } + // Validate that the shapes are compatible: A[rows×cols_a] * B[rows_b×cols] + + // C[rows×cols] requires cols_a == rows_b, plus consistent outer dimensions. + let (a_rows, a_cols) = match resolver[a] { + Ti::CooperativeMatrix { rows, columns, .. } => (rows, columns), + _ => unreachable!(), + }; + let (b_rows, b_cols) = match resolver[b] { + Ti::CooperativeMatrix { rows, columns, .. } => (rows, columns), + _ => unreachable!(), + }; + let (c_rows, c_cols) = match resolver[c] { + Ti::CooperativeMatrix { rows, columns, .. 
} => (rows, columns), + _ => unreachable!(), + }; + if a_cols != b_rows || a_rows != c_rows || b_cols != c_cols { + return Err(ExpressionError::InvalidCooperativeOperand(a)); + } + // A and B must have the same scalar type. C (the accumulator) + // may be the same type or a canonical wider type for + // mixed-precision GEMM: f16→f32, i8→i32, u8→u32. + let a_scalar = match resolver[a] { + Ti::CooperativeMatrix { scalar, .. } => scalar, + _ => unreachable!(), + }; + let b_scalar = match resolver[b] { + Ti::CooperativeMatrix { scalar, .. } => scalar, + _ => unreachable!(), + }; + let c_scalar = match resolver[c] { + Ti::CooperativeMatrix { scalar, .. } => scalar, + _ => unreachable!(), + }; + if a_scalar != b_scalar { + return Err(ExpressionError::InvalidCooperativeMixedInputs { + a: a_scalar, + b: b_scalar, + }); + } + // Same-type accumulators are allowed for every scalar so that the + // codegen path is exercised uniformly. The runtime + // (`vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR`) is what + // ultimately rejects unrealistic combos like `i8 × i8 → i8`. + let valid_accumulator = c_scalar == a_scalar + || matches!( + (a_scalar, c_scalar), + (crate::Scalar::F16, crate::Scalar::F32) + | (crate::Scalar::I8, crate::Scalar::I32) + | (crate::Scalar::U8, crate::Scalar::U32) + ); + if !valid_accumulator { + return Err(ExpressionError::InvalidCooperativeAccumulator { + ab: a_scalar, + c: c_scalar, + }); + } ShaderStages::COMPUTE } }; diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 8dc549f2fe8..621f2bd3f21 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -218,6 +218,9 @@ bitflags::bitflags! { const MEMORY_DECORATION_VOLATILE = 1 << 42; /// Support for 16-bit integer types. const SHADER_INT16 = 1 << 43; + /// Support for 8-bit integer scalars (`i8`/`u8`) in storage buffers and + /// as the component type of cooperative matrices. 
+ const SHADER_INT8 = 1 << 44; } } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 78850aeed18..ce3232b520a 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -124,6 +124,13 @@ pub enum TypeError { InvalidArrayBaseType(Handle), #[error("Matrix elements must always be floating-point types")] MatrixElementNotFloat, + #[error( + "Cooperative matrix component type must be f16, f32, i32, u32, i8, or u8 (got {kind:?} width {width})" + )] + CooperativeMatrixScalarUnsupported { + kind: crate::ScalarKind, + width: crate::Bytes, + }, #[error("The constant {0:?} is specialized, and cannot be used as an array size")] UnsupportedSpecializedArrayLength(Handle), #[error("{} of dimensionality {dim:?} and class {class:?} are not supported", if *.arrayed {"Arrayed images"} else {"Images"})] @@ -335,6 +342,15 @@ impl super::Validator { }); } true + } else if scalar.width == 1 { + if !self.capabilities.contains(Capabilities::SHADER_INT8) { + return Err(WidthError::MissingCapability { + name: "i8", + flag: "SHADER_INT8", + }); + } + immediates_compatibility = Err(ImmediateError::InvalidScalar(scalar)); + true } else { scalar.width == 4 } @@ -359,6 +375,15 @@ impl super::Validator { }); } true + } else if scalar.width == 1 { + if !self.capabilities.contains(Capabilities::SHADER_INT8) { + return Err(WidthError::MissingCapability { + name: "u8", + flag: "SHADER_INT8", + }); + } + immediates_compatibility = Err(ImmediateError::InvalidScalar(scalar)); + true } else { scalar.width == 4 } @@ -456,11 +481,19 @@ impl super::Validator { role: _, } => { self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?; - // Allow f16 (width 2) and f32 (width 4) for cooperative matrices - if scalar.kind != crate::ScalarKind::Float - || (scalar.width != 2 && scalar.width != 4) - { - return Err(TypeError::MatrixElementNotFloat); + let scalar_ok = match (scalar.kind, scalar.width) { + (crate::ScalarKind::Float, 2) | (crate::ScalarKind::Float, 4) => true, + 
(crate::ScalarKind::Sint | crate::ScalarKind::Uint, 4) => true, + (crate::ScalarKind::Sint | crate::ScalarKind::Uint, 1) => { + self.capabilities.contains(Capabilities::SHADER_INT8) + } + _ => false, + }; + if !scalar_ok { + return Err(TypeError::CooperativeMatrixScalarUnsupported { + kind: scalar.kind, + width: scalar.width, + }); } TypeInfo::new( TypeFlags::DATA diff --git a/naga/tests/in/wgsl/cooperative-matrix-int.toml b/naga/tests/in/wgsl/cooperative-matrix-int.toml new file mode 100644 index 00000000000..9618e28f75c --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix-int.toml @@ -0,0 +1,6 @@ +targets = "SPIRV | WGSL" +capabilities = "COOPERATIVE_MATRIX" + +[spv] +debug = false +version = [1, 4] diff --git a/naga/tests/in/wgsl/cooperative-matrix-int.wgsl b/naga/tests/in/wgsl/cooperative-matrix-int.wgsl new file mode 100644 index 00000000000..498c9635301 --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix-int.wgsl @@ -0,0 +1,29 @@ +enable wgpu_cooperative_matrix; + +// i32 cooperative matrix — square 16x16 (most portable integer config) +var a_i32: coop_mat16x16; +var b_i32: coop_mat16x16; +@group(0) @binding(0) +var ext_i32: array; + +@compute @workgroup_size(8, 8, 1) +fn main_i32() { + var c = coopLoad>(&ext_i32[4]); + var d = coopMultiplyAdd(a_i32, b_i32, c); + coopStore(d, &ext_i32[0]); + c = d; +} + +// u32 cooperative matrix — square 16x16 +var a_u32: coop_mat16x16; +var b_u32: coop_mat16x16; +@group(0) @binding(1) +var ext_u32: array; + +@compute @workgroup_size(8, 8, 1) +fn main_u32() { + var c = coopLoad>(&ext_u32[4]); + var d = coopMultiplyAdd(a_u32, b_u32, c); + coopStore(d, &ext_u32[0]); + c = d; +} diff --git a/naga/tests/in/wgsl/cooperative-matrix-int8.toml b/naga/tests/in/wgsl/cooperative-matrix-int8.toml new file mode 100644 index 00000000000..2e94e6d227c --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix-int8.toml @@ -0,0 +1,6 @@ +targets = "SPIRV | WGSL" +capabilities = "COOPERATIVE_MATRIX | SHADER_INT8" + +[spv] +debug = 
false +version = [1, 4] diff --git a/naga/tests/in/wgsl/cooperative-matrix-int8.wgsl b/naga/tests/in/wgsl/cooperative-matrix-int8.wgsl new file mode 100644 index 00000000000..8496e03c422 --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix-int8.wgsl @@ -0,0 +1,38 @@ +enable wgpu_cooperative_matrix; + +// i8 cooperative matrix — asymmetric 16x32 x 32x16 = 16x16 (M=16, N=16, K=32) +// coop_mat32x16 has columns=32, rows=16 → an M×K=16×32 A-matrix +// coop_mat16x32 has columns=16, rows=32 → a K×N=32×16 B-matrix +// coop_mat16x16 has columns=16, rows=16 → an M×N=16×16 C-matrix +// +// Note: this fixture uses an i8/i8/i8 (and u8/u8/u8) accumulator combo to exercise +// the codegen path uniformly. No real Vulkan device advertises an 8-bit accumulator +// (it would overflow immediately), so this combo would be rejected at runtime by +// `vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR`. The realistic widened +// combos (i8→i32, u8→u32) are covered by `cooperative-matrix-mixed.wgsl`. +var a_i8: coop_mat32x16; +var b_i8: coop_mat16x32; +@group(0) @binding(0) +var ext_i8: array; + +@compute @workgroup_size(8, 8, 1) +fn main_i8() { + var c = coopLoad>(&ext_i8[4]); + var d = coopMultiplyAdd(a_i8, b_i8, c); + coopStore(d, &ext_i8[0]); + c = d; +} + +// u8 cooperative matrix — same asymmetric shape +var a_u8: coop_mat32x16; +var b_u8: coop_mat16x32; +@group(0) @binding(1) +var ext_u8: array; + +@compute @workgroup_size(8, 8, 1) +fn main_u8() { + var c = coopLoad>(&ext_u8[4]); + var d = coopMultiplyAdd(a_u8, b_u8, c); + coopStore(d, &ext_u8[0]); + c = d; +} diff --git a/naga/tests/in/wgsl/cooperative-matrix-mixed.toml b/naga/tests/in/wgsl/cooperative-matrix-mixed.toml new file mode 100644 index 00000000000..17b9f015b9d --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix-mixed.toml @@ -0,0 +1,6 @@ +targets = "SPIRV | WGSL" +capabilities = "COOPERATIVE_MATRIX | SHADER_INT8 | SHADER_FLOAT16" + +[spv] +debug = false +version = [1, 4] diff --git 
a/naga/tests/in/wgsl/cooperative-matrix-mixed.wgsl b/naga/tests/in/wgsl/cooperative-matrix-mixed.wgsl new file mode 100644 index 00000000000..8dfab771ba7 --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix-mixed.wgsl @@ -0,0 +1,50 @@ +enable wgpu_cooperative_matrix; +enable f16; + +// Mixed-precision cooperative matrix multiply-add. +// All three combinations share the same asymmetric 16×16×32 shape: +// A is M×K = 16×32 (coop_mat32x16, columns=32, rows=16) +// B is K×N = 32×16 (coop_mat16x32, columns=16, rows=32) +// C is M×N = 16×16 (coop_mat16x16) + +// f16 inputs, f32 accumulator — supported on AMD/NVIDIA + Metal Apple7+. +var a_f16: coop_mat32x16; +var b_f16: coop_mat16x32; +@group(0) @binding(0) +var ext_f32: array; + +@compute @workgroup_size(8, 8, 1) +fn main_f16_f32() { + var c = coopLoad>(&ext_f32[4]); + var d = coopMultiplyAdd(a_f16, b_f16, c); + coopStore(d, &ext_f32[0]); + c = d; +} + +// i8 inputs, i32 accumulator — typical integer GEMM on Turing/Ampere/RDNA3. +var a_i8: coop_mat32x16; +var b_i8: coop_mat16x32; +@group(0) @binding(1) +var ext_i32: array; + +@compute @workgroup_size(8, 8, 1) +fn main_i8_i32() { + var c = coopLoad>(&ext_i32[4]); + var d = coopMultiplyAdd(a_i8, b_i8, c); + coopStore(d, &ext_i32[0]); + c = d; +} + +// u8 inputs, u32 accumulator. +var a_u8: coop_mat32x16; +var b_u8: coop_mat16x32; +@group(0) @binding(2) +var ext_u32: array; + +@compute @workgroup_size(8, 8, 1) +fn main_u8_u32() { + var c = coopLoad>(&ext_u32[4]); + var d = coopMultiplyAdd(a_u8, b_u8, c); + coopStore(d, &ext_u32[0]); + c = d; +} diff --git a/naga/tests/naga/wgsl_errors.rs b/naga/tests/naga/wgsl_errors.rs index 3ab89f11f5d..f3479cf389e 100644 --- a/naga/tests/naga/wgsl_errors.rs +++ b/naga/tests/naga/wgsl_errors.rs @@ -5180,6 +5180,110 @@ fn cooperative_matrix_enable_extension() { } } +/// Tests that cooperative matrices reject unsupported scalar types. 
+#[test] +fn cooperative_matrix_scalar_unsupported() { + use naga::valid::{Capabilities, TypeError}; + + // bool is never a valid cooperative matrix element type + check_one_validation!( + r#"enable wgpu_cooperative_matrix; +var a: coop_mat8x8; +"#, + Err(naga::valid::ValidationError::Type { + source: TypeError::CooperativeMatrixScalarUnsupported { .. }, + .. + }), + Capabilities::COOPERATIVE_MATRIX + ); + + // f64 is never a valid cooperative matrix element type + check_one_validation!( + r#"enable wgpu_cooperative_matrix; +var a: coop_mat8x8; +"#, + Err(naga::valid::ValidationError::Type { + source: TypeError::CooperativeMatrixScalarUnsupported { .. }, + .. + }), + Capabilities::COOPERATIVE_MATRIX | Capabilities::FLOAT64 + ); +} + +/// Tests that cooperative matrix `coopMultiplyAdd` rejects invalid mixed-precision +/// combinations. +#[test] +fn cooperative_matrix_invalid_mixed_precision() { + use naga::valid::{Capabilities, ExpressionError}; + + // A and B have different scalar types — always invalid. + check_one_validation!( + r#"enable wgpu_cooperative_matrix; +var a: coop_mat16x16; +var b: coop_mat16x16; +@compute @workgroup_size(8, 8, 1) +fn main() { + var c: coop_mat16x16; + var d = coopMultiplyAdd(a, b, c); +}"#, + Err(naga::valid::ValidationError::EntryPoint { + source: naga::valid::EntryPointError::Function( + naga::valid::FunctionError::Expression { + source: ExpressionError::InvalidCooperativeMixedInputs { .. }, + .. + } + ), + .. + }), + Capabilities::COOPERATIVE_MATRIX | Capabilities::SHADER_INT8 + ); + + // i8 × i8 with f32 accumulator — not a valid widened accumulator for i8. 
+ check_one_validation!( + r#"enable wgpu_cooperative_matrix; +var a: coop_mat16x16; +var b: coop_mat16x16; +@compute @workgroup_size(8, 8, 1) +fn main() { + var c: coop_mat16x16; + var d = coopMultiplyAdd(a, b, c); +}"#, + Err(naga::valid::ValidationError::EntryPoint { + source: naga::valid::EntryPointError::Function( + naga::valid::FunctionError::Expression { + source: ExpressionError::InvalidCooperativeAccumulator { .. }, + .. + } + ), + .. + }), + Capabilities::COOPERATIVE_MATRIX | Capabilities::SHADER_INT8 + ); + + // f16 × f16 with i32 accumulator — not a valid widened accumulator for f16. + check_one_validation!( + r#"enable wgpu_cooperative_matrix; +enable f16; +var a: coop_mat16x16; +var b: coop_mat16x16; +@compute @workgroup_size(8, 8, 1) +fn main() { + var c: coop_mat16x16; + var d = coopMultiplyAdd(a, b, c); +}"#, + Err(naga::valid::ValidationError::EntryPoint { + source: naga::valid::EntryPointError::Function( + naga::valid::FunctionError::Expression { + source: ExpressionError::InvalidCooperativeAccumulator { .. }, + .. + } + ), + .. + }), + Capabilities::COOPERATIVE_MATRIX | Capabilities::SHADER_FLOAT16 + ); +} + /// Tests for mesh shader extension validation via WGSL parsing. 
/// /// Some mesh shader features can only be tested at parse-level in WGSL due to diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix-int.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix-int.spvasm new file mode 100644 index 00000000000..66a123c1a14 --- /dev/null +++ b/naga/tests/out/spv/wgsl-cooperative-matrix-int.spvasm @@ -0,0 +1,115 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 78 +OpCapability Shader +OpCapability CooperativeMatrixKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %37 "main_i32" %18 %21 %24 +OpEntryPoint GLCompute %59 "main_u32" %27 %30 %33 +OpExecutionMode %37 LocalSize 8 8 1 +OpExecutionMode %59 LocalSize 8 8 1 +OpDecorate %11 ArrayStride 4 +OpDecorate %16 ArrayStride 4 +OpDecorate %24 DescriptorSet 0 +OpDecorate %24 Binding 0 +OpDecorate %25 Block +OpMemberDecorate %25 0 Offset 0 +OpDecorate %33 DescriptorSet 0 +OpDecorate %33 Binding 1 +OpDecorate %34 Block +OpMemberDecorate %34 0 Offset 0 +%2 = OpTypeVoid +%3 = OpTypeInt 32 1 +%6 = OpTypeInt 32 0 +%5 = OpConstant %6 3 +%7 = OpConstant %6 16 +%8 = OpConstant %6 0 +%4 = OpTypeCooperativeMatrixKHR %3 %5 %7 %7 %8 +%10 = OpConstant %6 1 +%9 = OpTypeCooperativeMatrixKHR %3 %5 %7 %7 %10 +%11 = OpTypeRuntimeArray %3 +%13 = OpConstant %6 2 +%12 = OpTypeCooperativeMatrixKHR %3 %5 %7 %7 %13 +%14 = OpTypeCooperativeMatrixKHR %6 %5 %7 %7 %8 +%15 = OpTypeCooperativeMatrixKHR %6 %5 %7 %7 %10 +%16 = OpTypeRuntimeArray %6 +%17 = OpTypeCooperativeMatrixKHR %6 %5 %7 %7 %13 +%19 = OpTypePointer Private %4 +%20 = OpConstantNull %4 +%18 = OpVariable %19 Private %20 +%22 = OpTypePointer Private %9 +%23 = OpConstantNull %9 +%21 = OpVariable %22 Private %23 +%25 = OpTypeStruct %11 +%26 = OpTypePointer StorageBuffer %25 +%24 = OpVariable %26 StorageBuffer +%28 = OpTypePointer Private %14 +%29 = OpConstantNull %14 +%27 = OpVariable %28 
Private %29 +%31 = OpTypePointer Private %15 +%32 = OpConstantNull %15 +%30 = OpVariable %31 Private %32 +%34 = OpTypeStruct %16 +%35 = OpTypePointer StorageBuffer %34 +%33 = OpVariable %35 StorageBuffer +%38 = OpTypeFunction %2 +%39 = OpTypePointer StorageBuffer %11 +%42 = OpTypePointer Function %12 +%43 = OpConstantNull %12 +%45 = OpConstantNull %12 +%47 = OpTypePointer StorageBuffer %3 +%48 = OpConstant %6 4 +%60 = OpTypePointer StorageBuffer %16 +%63 = OpTypePointer Function %17 +%64 = OpConstantNull %17 +%66 = OpConstantNull %17 +%68 = OpTypePointer StorageBuffer %6 +%37 = OpFunction %2 None %38 +%36 = OpLabel +%41 = OpVariable %42 Function %43 +%44 = OpVariable %42 Function %45 +%40 = OpAccessChain %39 %24 %8 +OpBranch %46 +%46 = OpLabel +%49 = OpAccessChain %47 %40 %48 +%50 = OpCooperativeMatrixLoadKHR %12 %49 %10 %7 +OpStore %41 %50 +%51 = OpLoad %4 %18 +%52 = OpLoad %9 %21 +%53 = OpLoad %12 %41 +%54 = OpCooperativeMatrixMulAddKHR %12 %51 %52 %53 CooperativeMatrixOperands(MATRIX_A_SIGNED_COMPONENTS_KHR | MATRIX_B_SIGNED_COMPONENTS_KHR | MATRIX_C_SIGNED_COMPONENTS_KHR | MATRIX_RESULT_SIGNED_COMPONENTS_KHR) +OpStore %44 %54 +%55 = OpLoad %12 %44 +%56 = OpAccessChain %47 %40 %8 +OpCooperativeMatrixStoreKHR %56 %55 %10 %7 +%57 = OpLoad %12 %44 +OpStore %41 %57 +OpReturn +OpFunctionEnd +%59 = OpFunction %2 None %38 +%58 = OpLabel +%62 = OpVariable %63 Function %64 +%65 = OpVariable %63 Function %66 +%61 = OpAccessChain %60 %33 %8 +OpBranch %67 +%67 = OpLabel +%69 = OpAccessChain %68 %61 %48 +%70 = OpCooperativeMatrixLoadKHR %17 %69 %10 %7 +OpStore %62 %70 +%71 = OpLoad %14 %27 +%72 = OpLoad %15 %30 +%73 = OpLoad %17 %62 +%74 = OpCooperativeMatrixMulAddKHR %17 %71 %72 %73 CooperativeMatrixOperands(0x0) +OpStore %65 %74 +%75 = OpLoad %17 %65 +%76 = OpAccessChain %68 %61 %8 +OpCooperativeMatrixStoreKHR %76 %75 %10 %7 +%77 = OpLoad %17 %65 +OpStore %62 %77 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git 
a/naga/tests/out/spv/wgsl-cooperative-matrix-int8.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix-int8.spvasm new file mode 100644 index 00000000000..0bedb7fa7c4 --- /dev/null +++ b/naga/tests/out/spv/wgsl-cooperative-matrix-int8.spvasm @@ -0,0 +1,121 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 80 +OpCapability Shader +OpCapability StorageBuffer8BitAccess +OpCapability UniformAndStorageBuffer8BitAccess +OpCapability Int8 +OpCapability CooperativeMatrixKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_8bit_storage" +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %39 "main_i8" %20 %23 %26 +OpEntryPoint GLCompute %61 "main_u8" %29 %32 %35 +OpExecutionMode %39 LocalSize 8 8 1 +OpExecutionMode %61 LocalSize 8 8 1 +OpDecorate %12 ArrayStride 1 +OpDecorate %18 ArrayStride 1 +OpDecorate %26 DescriptorSet 0 +OpDecorate %26 Binding 0 +OpDecorate %27 Block +OpMemberDecorate %27 0 Offset 0 +OpDecorate %35 DescriptorSet 0 +OpDecorate %35 Binding 1 +OpDecorate %36 Block +OpMemberDecorate %36 0 Offset 0 +%2 = OpTypeVoid +%3 = OpTypeInt 8 1 +%6 = OpTypeInt 32 0 +%5 = OpConstant %6 3 +%7 = OpConstant %6 32 +%8 = OpConstant %6 16 +%9 = OpConstant %6 0 +%4 = OpTypeCooperativeMatrixKHR %3 %5 %8 %7 %9 +%11 = OpConstant %6 1 +%10 = OpTypeCooperativeMatrixKHR %3 %5 %7 %8 %11 +%12 = OpTypeRuntimeArray %3 +%14 = OpConstant %6 2 +%13 = OpTypeCooperativeMatrixKHR %3 %5 %8 %8 %14 +%15 = OpTypeInt 8 0 +%16 = OpTypeCooperativeMatrixKHR %15 %5 %8 %7 %9 +%17 = OpTypeCooperativeMatrixKHR %15 %5 %7 %8 %11 +%18 = OpTypeRuntimeArray %15 +%19 = OpTypeCooperativeMatrixKHR %15 %5 %8 %8 %14 +%21 = OpTypePointer Private %4 +%22 = OpConstantNull %4 +%20 = OpVariable %21 Private %22 +%24 = OpTypePointer Private %10 +%25 = OpConstantNull %10 +%23 = OpVariable %24 Private %25 +%27 = OpTypeStruct %12 +%28 = OpTypePointer StorageBuffer %27 +%26 = OpVariable 
%28 StorageBuffer +%30 = OpTypePointer Private %16 +%31 = OpConstantNull %16 +%29 = OpVariable %30 Private %31 +%33 = OpTypePointer Private %17 +%34 = OpConstantNull %17 +%32 = OpVariable %33 Private %34 +%36 = OpTypeStruct %18 +%37 = OpTypePointer StorageBuffer %36 +%35 = OpVariable %37 StorageBuffer +%40 = OpTypeFunction %2 +%41 = OpTypePointer StorageBuffer %12 +%44 = OpTypePointer Function %13 +%45 = OpConstantNull %13 +%47 = OpConstantNull %13 +%49 = OpTypePointer StorageBuffer %3 +%50 = OpConstant %6 4 +%62 = OpTypePointer StorageBuffer %18 +%65 = OpTypePointer Function %19 +%66 = OpConstantNull %19 +%68 = OpConstantNull %19 +%70 = OpTypePointer StorageBuffer %15 +%39 = OpFunction %2 None %40 +%38 = OpLabel +%43 = OpVariable %44 Function %45 +%46 = OpVariable %44 Function %47 +%42 = OpAccessChain %41 %26 %9 +OpBranch %48 +%48 = OpLabel +%51 = OpAccessChain %49 %42 %50 +%52 = OpCooperativeMatrixLoadKHR %13 %51 %11 %8 +OpStore %43 %52 +%53 = OpLoad %4 %20 +%54 = OpLoad %10 %23 +%55 = OpLoad %13 %43 +%56 = OpCooperativeMatrixMulAddKHR %13 %53 %54 %55 CooperativeMatrixOperands(MATRIX_A_SIGNED_COMPONENTS_KHR | MATRIX_B_SIGNED_COMPONENTS_KHR | MATRIX_C_SIGNED_COMPONENTS_KHR | MATRIX_RESULT_SIGNED_COMPONENTS_KHR) +OpStore %46 %56 +%57 = OpLoad %13 %46 +%58 = OpAccessChain %49 %42 %9 +OpCooperativeMatrixStoreKHR %58 %57 %11 %8 +%59 = OpLoad %13 %46 +OpStore %43 %59 +OpReturn +OpFunctionEnd +%61 = OpFunction %2 None %40 +%60 = OpLabel +%64 = OpVariable %65 Function %66 +%67 = OpVariable %65 Function %68 +%63 = OpAccessChain %62 %35 %9 +OpBranch %69 +%69 = OpLabel +%71 = OpAccessChain %70 %63 %50 +%72 = OpCooperativeMatrixLoadKHR %19 %71 %11 %8 +OpStore %64 %72 +%73 = OpLoad %16 %29 +%74 = OpLoad %17 %32 +%75 = OpLoad %19 %64 +%76 = OpCooperativeMatrixMulAddKHR %19 %73 %74 %75 CooperativeMatrixOperands(0x0) +OpStore %67 %76 +%77 = OpLoad %19 %67 +%78 = OpAccessChain %70 %63 %9 +OpCooperativeMatrixStoreKHR %78 %77 %11 %8 +%79 = OpLoad %19 %67 +OpStore %64 %79 +OpReturn 
+OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix-mixed.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix-mixed.spvasm new file mode 100644 index 00000000000..b2a5eb8fe5c --- /dev/null +++ b/naga/tests/out/spv/wgsl-cooperative-matrix-mixed.spvasm @@ -0,0 +1,175 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 116 +OpCapability Shader +OpCapability CooperativeMatrixKHR +OpCapability VulkanMemoryModel +OpCapability Float16 +OpCapability StorageBuffer16BitAccess +OpCapability UniformAndStorageBuffer16BitAccess +OpCapability StorageInputOutput16 +OpCapability StorageBuffer8BitAccess +OpCapability UniformAndStorageBuffer8BitAccess +OpCapability Int8 +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" +OpExtension "SPV_KHR_8bit_storage" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %55 "main_f16_f32" %27 %30 %33 +OpEntryPoint GLCompute %77 "main_i8_i32" %36 %39 %42 +OpEntryPoint GLCompute %97 "main_u8_u32" %45 %48 %51 +OpExecutionMode %55 LocalSize 8 8 1 +OpExecutionMode %77 LocalSize 8 8 1 +OpExecutionMode %97 LocalSize 8 8 1 +OpDecorate %13 ArrayStride 4 +OpDecorate %20 ArrayStride 4 +OpDecorate %25 ArrayStride 4 +OpDecorate %33 DescriptorSet 0 +OpDecorate %33 Binding 0 +OpDecorate %34 Block +OpMemberDecorate %34 0 Offset 0 +OpDecorate %42 DescriptorSet 0 +OpDecorate %42 Binding 1 +OpDecorate %43 Block +OpMemberDecorate %43 0 Offset 0 +OpDecorate %51 DescriptorSet 0 +OpDecorate %51 Binding 2 +OpDecorate %52 Block +OpMemberDecorate %52 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeFloat 16 +%6 = OpTypeInt 32 0 +%5 = OpConstant %6 3 +%7 = OpConstant %6 32 +%8 = OpConstant %6 16 +%9 = OpConstant %6 0 +%3 = OpTypeCooperativeMatrixKHR %4 %5 %8 %7 %9 +%11 = OpConstant %6 1 +%10 = OpTypeCooperativeMatrixKHR %4 %5 %7 %8 %11 +%12 = OpTypeFloat 32 +%13 = OpTypeRuntimeArray %12 +%15 = OpConstant %6 2 +%14 = OpTypeCooperativeMatrixKHR %12 %5 %8 %8 
%15 +%17 = OpTypeInt 8 1 +%16 = OpTypeCooperativeMatrixKHR %17 %5 %8 %7 %9 +%18 = OpTypeCooperativeMatrixKHR %17 %5 %7 %8 %11 +%19 = OpTypeInt 32 1 +%20 = OpTypeRuntimeArray %19 +%21 = OpTypeCooperativeMatrixKHR %19 %5 %8 %8 %15 +%23 = OpTypeInt 8 0 +%22 = OpTypeCooperativeMatrixKHR %23 %5 %8 %7 %9 +%24 = OpTypeCooperativeMatrixKHR %23 %5 %7 %8 %11 +%25 = OpTypeRuntimeArray %6 +%26 = OpTypeCooperativeMatrixKHR %6 %5 %8 %8 %15 +%28 = OpTypePointer Private %3 +%29 = OpConstantNull %3 +%27 = OpVariable %28 Private %29 +%31 = OpTypePointer Private %10 +%32 = OpConstantNull %10 +%30 = OpVariable %31 Private %32 +%34 = OpTypeStruct %13 +%35 = OpTypePointer StorageBuffer %34 +%33 = OpVariable %35 StorageBuffer +%37 = OpTypePointer Private %16 +%38 = OpConstantNull %16 +%36 = OpVariable %37 Private %38 +%40 = OpTypePointer Private %18 +%41 = OpConstantNull %18 +%39 = OpVariable %40 Private %41 +%43 = OpTypeStruct %20 +%44 = OpTypePointer StorageBuffer %43 +%42 = OpVariable %44 StorageBuffer +%46 = OpTypePointer Private %22 +%47 = OpConstantNull %22 +%45 = OpVariable %46 Private %47 +%49 = OpTypePointer Private %24 +%50 = OpConstantNull %24 +%48 = OpVariable %49 Private %50 +%52 = OpTypeStruct %25 +%53 = OpTypePointer StorageBuffer %52 +%51 = OpVariable %53 StorageBuffer +%56 = OpTypeFunction %2 +%57 = OpTypePointer StorageBuffer %13 +%60 = OpTypePointer Function %14 +%61 = OpConstantNull %14 +%63 = OpConstantNull %14 +%65 = OpTypePointer StorageBuffer %12 +%66 = OpConstant %6 4 +%78 = OpTypePointer StorageBuffer %20 +%81 = OpTypePointer Function %21 +%82 = OpConstantNull %21 +%84 = OpConstantNull %21 +%86 = OpTypePointer StorageBuffer %19 +%98 = OpTypePointer StorageBuffer %25 +%101 = OpTypePointer Function %26 +%102 = OpConstantNull %26 +%104 = OpConstantNull %26 +%106 = OpTypePointer StorageBuffer %6 +%55 = OpFunction %2 None %56 +%54 = OpLabel +%59 = OpVariable %60 Function %61 +%62 = OpVariable %60 Function %63 +%58 = OpAccessChain %57 %33 %9 +OpBranch %64 +%64 = 
OpLabel +%67 = OpAccessChain %65 %58 %66 +%68 = OpCooperativeMatrixLoadKHR %14 %67 %11 %8 +OpStore %59 %68 +%69 = OpLoad %3 %27 +%70 = OpLoad %10 %30 +%71 = OpLoad %14 %59 +%72 = OpCooperativeMatrixMulAddKHR %14 %69 %70 %71 +OpStore %62 %72 +%73 = OpLoad %14 %62 +%74 = OpAccessChain %65 %58 %9 +OpCooperativeMatrixStoreKHR %74 %73 %11 %8 +%75 = OpLoad %14 %62 +OpStore %59 %75 +OpReturn +OpFunctionEnd +%77 = OpFunction %2 None %56 +%76 = OpLabel +%80 = OpVariable %81 Function %82 +%83 = OpVariable %81 Function %84 +%79 = OpAccessChain %78 %42 %9 +OpBranch %85 +%85 = OpLabel +%87 = OpAccessChain %86 %79 %66 +%88 = OpCooperativeMatrixLoadKHR %21 %87 %11 %8 +OpStore %80 %88 +%89 = OpLoad %16 %36 +%90 = OpLoad %18 %39 +%91 = OpLoad %21 %80 +%92 = OpCooperativeMatrixMulAddKHR %21 %89 %90 %91 CooperativeMatrixOperands(MATRIX_A_SIGNED_COMPONENTS_KHR | MATRIX_B_SIGNED_COMPONENTS_KHR | MATRIX_C_SIGNED_COMPONENTS_KHR | MATRIX_RESULT_SIGNED_COMPONENTS_KHR) +OpStore %83 %92 +%93 = OpLoad %21 %83 +%94 = OpAccessChain %86 %79 %9 +OpCooperativeMatrixStoreKHR %94 %93 %11 %8 +%95 = OpLoad %21 %83 +OpStore %80 %95 +OpReturn +OpFunctionEnd +%97 = OpFunction %2 None %56 +%96 = OpLabel +%100 = OpVariable %101 Function %102 +%103 = OpVariable %101 Function %104 +%99 = OpAccessChain %98 %51 %9 +OpBranch %105 +%105 = OpLabel +%107 = OpAccessChain %106 %99 %66 +%108 = OpCooperativeMatrixLoadKHR %26 %107 %11 %8 +OpStore %100 %108 +%109 = OpLoad %22 %45 +%110 = OpLoad %24 %48 +%111 = OpLoad %26 %100 +%112 = OpCooperativeMatrixMulAddKHR %26 %109 %110 %111 CooperativeMatrixOperands(0x0) +OpStore %103 %112 +%113 = OpLoad %26 %103 +%114 = OpAccessChain %106 %99 %9 +OpCooperativeMatrixStoreKHR %114 %113 %11 %8 +%115 = OpLoad %26 %103 +OpStore %100 %115 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix-int.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix-int.wgsl new file mode 100644 index 00000000000..adbc8c8d151 --- /dev/null +++ 
b/naga/tests/out/wgsl/wgsl-cooperative-matrix-int.wgsl @@ -0,0 +1,44 @@ +enable wgpu_cooperative_matrix; + +var a_i32_: coop_mat16x16; +var b_i32_: coop_mat16x16; +@group(0) @binding(0) +var ext_i32_: array; +var a_u32_: coop_mat16x16; +var b_u32_: coop_mat16x16; +@group(0) @binding(1) +var ext_u32_: array; + +@compute @workgroup_size(8, 8, 1) +fn main_i32_() { + var c: coop_mat16x16; + var d: coop_mat16x16; + + c = coopLoad>((&ext_i32_[4]), 16u); + let _e6 = a_i32_; + let _e8 = b_i32_; + let _e9 = c; + d = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d; + coopStore(_e12, (&ext_i32_[0]), 16u); + let _e16 = d; + c = _e16; + return; +} + +@compute @workgroup_size(8, 8, 1) +fn main_u32_() { + var c_1: coop_mat16x16; + var d_1: coop_mat16x16; + + c_1 = coopLoad>((&ext_u32_[4]), 16u); + let _e6 = a_u32_; + let _e8 = b_u32_; + let _e9 = c_1; + d_1 = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d_1; + coopStore(_e12, (&ext_u32_[0]), 16u); + let _e16 = d_1; + c_1 = _e16; + return; +} diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix-int8.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix-int8.wgsl new file mode 100644 index 00000000000..5c6c15e14e0 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix-int8.wgsl @@ -0,0 +1,44 @@ +enable wgpu_cooperative_matrix; + +var a_i8_: coop_mat32x16; +var b_i8_: coop_mat16x32; +@group(0) @binding(0) +var ext_i8_: array; +var a_u8_: coop_mat32x16; +var b_u8_: coop_mat16x32; +@group(0) @binding(1) +var ext_u8_: array; + +@compute @workgroup_size(8, 8, 1) +fn main_i8_() { + var c: coop_mat16x16; + var d: coop_mat16x16; + + c = coopLoad>((&ext_i8_[4]), 16u); + let _e6 = a_i8_; + let _e8 = b_i8_; + let _e9 = c; + d = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d; + coopStore(_e12, (&ext_i8_[0]), 16u); + let _e16 = d; + c = _e16; + return; +} + +@compute @workgroup_size(8, 8, 1) +fn main_u8_() { + var c_1: coop_mat16x16; + var d_1: coop_mat16x16; + + c_1 = coopLoad>((&ext_u8_[4]), 16u); + let _e6 = a_u8_; + let _e8 = 
b_u8_; + let _e9 = c_1; + d_1 = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d_1; + coopStore(_e12, (&ext_u8_[0]), 16u); + let _e16 = d_1; + c_1 = _e16; + return; +} diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix-mixed.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix-mixed.wgsl new file mode 100644 index 00000000000..45a4505d841 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix-mixed.wgsl @@ -0,0 +1,65 @@ +enable wgpu_cooperative_matrix; + +var a_f16_: coop_mat32x16; +var b_f16_: coop_mat16x32; +@group(0) @binding(0) +var ext_f32_: array; +var a_i8_: coop_mat32x16; +var b_i8_: coop_mat16x32; +@group(0) @binding(1) +var ext_i32_: array; +var a_u8_: coop_mat32x16; +var b_u8_: coop_mat16x32; +@group(0) @binding(2) +var ext_u32_: array; + +@compute @workgroup_size(8, 8, 1) +fn main_f16_f32_() { + var c: coop_mat16x16; + var d: coop_mat16x16; + + c = coopLoad>((&ext_f32_[4]), 16u); + let _e6 = a_f16_; + let _e8 = b_f16_; + let _e9 = c; + d = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d; + coopStore(_e12, (&ext_f32_[0]), 16u); + let _e16 = d; + c = _e16; + return; +} + +@compute @workgroup_size(8, 8, 1) +fn main_i8_i32_() { + var c_1: coop_mat16x16; + var d_1: coop_mat16x16; + + c_1 = coopLoad>((&ext_i32_[4]), 16u); + let _e6 = a_i8_; + let _e8 = b_i8_; + let _e9 = c_1; + d_1 = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d_1; + coopStore(_e12, (&ext_i32_[0]), 16u); + let _e16 = d_1; + c_1 = _e16; + return; +} + +@compute @workgroup_size(8, 8, 1) +fn main_u8_u32_() { + var c_2: coop_mat16x16; + var d_2: coop_mat16x16; + + c_2 = coopLoad>((&ext_u32_[4]), 16u); + let _e6 = a_u8_; + let _e8 = b_u8_; + let _e9 = c_2; + d_2 = coopMultiplyAdd(_e6, _e8, _e9); + let _e12 = d_2; + coopStore(_e12, (&ext_u32_[0]), 16u); + let _e16 = d_2; + c_2 = _e16; + return; +} diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index 211e53e3731..c4bc3eaa72a 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ 
-70,6 +70,12 @@ pub struct PhysicalDeviceFeatures { /// Features provided by `VK_KHR_16bit_storage`, promoted to Vulkan 1.1 _16bit_storage: Option>, + /// Features provided by `VK_KHR_8bit_storage`, promoted to Vulkan 1.2. + /// + /// Required to use 8-bit integers in `StorageBuffer` address space (e.g. + /// `array` storage buffers for cooperative matrix loads/stores). + _8bit_storage: Option>, + /// Features provided by `VK_KHR_acceleration_structure`. acceleration_structure: Option>, @@ -222,6 +228,9 @@ impl PhysicalDeviceFeatures { if let Some(ref mut feature) = self.portability_subset { info = info.push_next(feature); } + if let Some(ref mut feature) = self._8bit_storage { + info = info.push_next(feature); + } if let Some(ref mut feature) = self.cooperative_matrix { info = info.push_next(feature); } @@ -448,6 +457,20 @@ impl PhysicalDeviceFeatures { } else { None }, + _8bit_storage: if requested_features + .contains(wgt::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) + && private_caps.shader_int8 + && private_caps.storage_buffer_8bit_access + && enabled_extensions.contains(&khr::_8bit_storage::NAME) + { + Some( + vk::PhysicalDevice8BitStorageFeatures::default() + .storage_buffer8_bit_access(true) + .uniform_and_storage_buffer8_bit_access(true), + ) + } else { + None + }, acceleration_structure: if enabled_extensions .contains(&khr::acceleration_structure::NAME) { @@ -1395,6 +1418,16 @@ impl PhysicalDeviceProperties { extensions.push(khr::cooperative_matrix::NAME); } + // Optionally require `VK_KHR_8bit_storage` when cooperative matrix is requested and + // the device supports it. Needed for 8-bit integer storage buffers used with i8/u8 + // cooperative matrix loads and stores. This is a no-op on Vulkan 1.2+. 
+ if requested_features.contains(wgt::Features::EXPERIMENTAL_COOPERATIVE_MATRIX) + && self.device_api_version < vk::API_VERSION_1_2 + && self.supports_extension(khr::_8bit_storage::NAME) + { + extensions.push(khr::_8bit_storage::NAME); + } + extensions } @@ -2013,6 +2046,16 @@ impl super::InstanceShared { .insert(vk::PhysicalDevice16BitStorageFeaturesKHR::default()); features2 = features2.push_next(next); } + + // `VK_KHR_8bit_storage` is promoted to Vulkan 1.2 + if capabilities.device_api_version >= vk::API_VERSION_1_2 + || capabilities.supports_extension(khr::_8bit_storage::NAME) + { + let next = features + ._8bit_storage + .insert(vk::PhysicalDevice8BitStorageFeaturesKHR::default()); + features2 = features2.push_next(next); + } if capabilities.supports_extension(khr::acceleration_structure::NAME) { let next = features .acceleration_structure @@ -2341,6 +2384,9 @@ impl super::Instance { shader_int8: phd_features .shader_float16_int8 .is_some_and(|features| features.shader_int8 != 0), + storage_buffer_8bit_access: phd_features + ._8bit_storage + .is_some_and(|f| f.storage_buffer8_bit_access != 0), multiview_instance_index_limit: phd_capabilities .multiview .map(|a| a.max_multiview_instance_index) @@ -2697,6 +2743,12 @@ impl super::Adapter { // See . 
capabilities.extend(&[spv::Capability::Int8]); } + if self.private_caps.storage_buffer_8bit_access { + capabilities.extend(&[ + spv::Capability::StorageBuffer8BitAccess, + spv::Capability::UniformAndStorageBuffer8BitAccess, + ]); + } spv::Options { lang_version: match self.phd_capabilities.device_api_version { // Use maximum supported SPIR-V version according to @@ -3302,6 +3354,8 @@ fn map_vk_component_type(ty: vk::ComponentTypeKHR) -> Option Some(wgt::CooperativeScalarType::F32), vk::ComponentTypeKHR::SINT32 => Some(wgt::CooperativeScalarType::I32), vk::ComponentTypeKHR::UINT32 => Some(wgt::CooperativeScalarType::U32), + vk::ComponentTypeKHR::SINT8 => Some(wgt::CooperativeScalarType::I8), + vk::ComponentTypeKHR::UINT8 => Some(wgt::CooperativeScalarType::U8), _ => None, } } @@ -3309,7 +3363,7 @@ fn map_vk_component_type(ty: vk::ComponentTypeKHR) -> Option Option { match size { - 8 | 16 => Some(size), + 8 | 16 | 32 => Some(size), _ => None, } } diff --git a/wgpu-hal/src/vulkan/mod.rs b/wgpu-hal/src/vulkan/mod.rs index 124df07c8aa..b20ecf49235 100644 --- a/wgpu-hal/src/vulkan/mod.rs +++ b/wgpu-hal/src/vulkan/mod.rs @@ -374,6 +374,12 @@ struct PrivateCapabilities { /// [see spec]: https://registry.khronos.org/vulkan/specs/latest/man/html/VkPhysicalDeviceShaderFloat16Int8Features.html#extension-features-shaderInt8 shader_int8: bool, + /// True if the device supports `storageBuffer8BitAccess` from + /// `VK_KHR_8bit_storage` (promoted to Vulkan 1.2). + /// + /// Required to use 8-bit integers in `StorageBuffer` address space. + storage_buffer_8bit_access: bool, + /// This is done to panic before undefined behavior, and is imperfect. 
/// Basically, to allow implementations to emulate mv using instancing, if you /// want to draw `n` instances to VR, you must draw `2n` instances, but you diff --git a/wgpu-naga-bridge/src/lib.rs b/wgpu-naga-bridge/src/lib.rs index b04bf1e4a40..be1d15f042e 100644 --- a/wgpu-naga-bridge/src/lib.rs +++ b/wgpu-naga-bridge/src/lib.rs @@ -163,6 +163,17 @@ pub fn features_to_naga_capabilities( Caps::COOPERATIVE_MATRIX, features.intersects(wgt::Features::EXPERIMENTAL_COOPERATIVE_MATRIX), ); + // i8/u8 scalars are enabled alongside cooperative matrix (needed for int8 component types). + // Backends that don't support cooperative matrix never set this feature so this is safe. + // + // TODO: SHADER_INT8 is logically orthogonal to cooperative matrix. If a future + // use case wants 8-bit integer scalars without cooperative matrix, split this + // into its own `wgt::Features` flag and gate the WGSL `wgpu_cooperative_matrix` + // enable's grant of `i8`/`u8` types on it independently. + caps.set( + Caps::SHADER_INT8, + features.intersects(wgt::Features::EXPERIMENTAL_COOPERATIVE_MATRIX), + ); caps.set( Caps::PER_VERTEX, features.intersects(wgt::Features::SHADER_PER_VERTEX), diff --git a/wgpu-types/src/adapter.rs b/wgpu-types/src/adapter.rs index 99bbb15e68d..78f4549863d 100644 --- a/wgpu-types/src/adapter.rs +++ b/wgpu-types/src/adapter.rs @@ -334,6 +334,16 @@ pub enum CooperativeScalarType { I32, /// 32-bit unsigned integer. U32, + /// 8-bit signed integer. + /// + /// Requires `EXPERIMENTAL_COOPERATIVE_MATRIX` and a device that advertises + /// `shaderInt8` + `storageBuffer8BitAccess` (i.e. `VK_KHR_8bit_storage`). + I8, + /// 8-bit unsigned integer. + /// + /// Requires `EXPERIMENTAL_COOPERATIVE_MATRIX` and a device that advertises + /// `shaderInt8` + `storageBuffer8BitAccess` (i.e. `VK_KHR_8bit_storage`). + U8, } /// Describes a supported cooperative matrix configuration. 
diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index 4048a139ed9..65ddcafa65f 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -1402,14 +1402,34 @@ bitflags_array! { /// matrix multiply-accumulate operations on small tiles of data, enabling /// hardware-accelerated matrix math. /// - /// **Current limitations:** The implementation currently only supports 8x8 f32 matrices. - /// On Vulkan, support is determined by querying `vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR` - /// for configurations matching 8x8x8 f32. Most Vulkan implementations (NVIDIA, AMD) primarily - /// support f16 inputs at larger sizes (e.g., 16x16), so Vulkan support may be limited. + /// **Supported scalar types:** + /// - `f16`, `f32` — float types, supported on Metal and Vulkan + /// - `i32`, `u32` — 32-bit integer types, Vulkan only + /// - `i8`, `u8` — 8-bit integer types, Vulkan only (requires device support for + /// `shaderInt8` and `storageBuffer8BitAccess`; use `enable wgpu_cooperative_matrix;` + /// to access the `i8`/`u8` predeclared scalars and the asymmetric matrix types + /// such as `coop_mat16x32`, `coop_mat32x16` that 8-bit multiplies require) + /// + /// **Note on `i8`/`u8` scope:** `enable wgpu_cooperative_matrix;` makes `i8` and + /// `u8` first-class WGSL scalar types — they may appear in storage buffers + /// (`var<storage> buf: array<i8>;`) and struct fields, not just as cooperative + /// matrix component types. This is required so cooperative matrix loads and + /// stores can address backing buffers with 8-bit element strides. + /// + /// **Naming convention for cooperative matrix types:** `coop_matCxR<T, U>` + /// is `C` columns by `R` rows of scalar `T` with usage `U`, matching WGSL's + /// `matCxR<T>` ordering. So `coop_mat32x16` is a matrix with 32 + /// columns and 16 rows (which represents an `M=16, K=32` A-operand in textbook + /// `M×N = (M×K)·(K×N)` GEMM notation).
+ /// + /// On Vulkan, the set of supported `(M, N, K, scalar)` configurations is determined + /// at adapter creation by querying `vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR`. + /// If no matching configuration is available on the device, this feature will not be + /// reported as supported. /// /// Supported platforms: - /// - Metal (with MSL 2.3+ and Apple7+/Mac2+, using simdgroup matrix operations) - /// - Vulkan (with [VK_KHR_cooperative_matrix](https://registry.khronos.org/vulkan/specs/latest/man/html/VK_KHR_cooperative_matrix.html), if 8x8 f32 is supported) + /// - Metal (with MSL 2.3+ and Apple7+/Mac2+, using simdgroup matrix operations; float only) + /// - Vulkan (with [`VK_KHR_cooperative_matrix`](https://registry.khronos.org/vulkan/specs/latest/man/html/VK_KHR_cooperative_matrix.html)) /// /// This is a native only feature. #[name("wgpu-cooperative-matrix")]