Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ By @beholdnec in [#8505](https://github.com/gfx-rs/wgpu/pull/8505).
- Conditional compilation by @jimblandy in [#9390](https://github.com/gfx-rs/wgpu/pull/9390)
- Fixed alignment and MatrixStride for mat2x2 in SPIR-V uniform blocks. By @39ali [#9369](https://github.com/gfx-rs/wgpu/pull/9369).
- Add `wgpu_hal::vulkan::Buffer::raw_handle()` for retrieving the underlying `vk::Buffer` resource. By @WillowGriffiths in [#9459](https://github.com/gfx-rs/wgpu/pull/9459).
- `EXPERIMENTAL_COOPERATIVE_MATRIX` now supports integer scalar types (`i32`, `u32`, `i8`, `u8`) in addition to the existing `f16`/`f32` support. The `i8`/`u8` types require device support for `shaderInt8` and `storageBuffer8BitAccess` (part of Vulkan 1.2 or `VK_KHR_8bit_storage`). New WGSL predeclared types `coop_mat8x16`, `coop_mat16x8`, `coop_mat8x32`, `coop_mat32x8`, `coop_mat16x32`, and `coop_mat32x16` are available behind `enable wgpu_cooperative_matrix;` for the asymmetric matrix shapes that integer multiplies typically require. By @ruihe774.
- `coopMultiplyAdd` now supports mixed-precision accumulation: `i8`/`u8` input matrices can use `i32`/`u32` accumulators, and `f16` input matrices can use `f32` accumulators (e.g. `coopMultiplyAdd(a_i8, b_i8, c_i32)`). By @ruihe774.

#### naga

Expand Down
94 changes: 85 additions & 9 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,11 @@ pub(super) enum WrappedFunction {
columns: crate::CooperativeSize,
rows: crate::CooperativeSize,
intermediate: crate::CooperativeSize,
scalar: crate::Scalar,
/// Scalar type of A and B (the multiplied operands).
input_scalar: crate::Scalar,
/// Scalar type of C and the result (the accumulator). May differ from
/// `input_scalar` for mixed-precision (e.g. f16 inputs, f32 accumulator).
accumulator_scalar: crate::Scalar,
},
RayQueryGetIntersection {
committed: bool,
Expand Down Expand Up @@ -554,6 +558,14 @@ impl crate::Scalar {
kind: Sk::Float,
width: 2,
} => "half",
Self {
kind: Sk::Sint,
width: 1,
} => "char",
Self {
kind: Sk::Uint,
width: 1,
} => "uchar",
Self {
kind: Sk::Sint,
width: 2,
Expand Down Expand Up @@ -2926,6 +2938,22 @@ impl<W: Write> Writer<W> {
if context.lang_version < (2, 3) {
return Err(Error::UnsupportedCooperativeMatrix);
}
// Metal simdgroup_matrix only supports floating-point types.
{
let ptr_ty = context.resolve_type(data.pointer);
let scalar = ptr_ty
.pointer_base_type()
.and_then(|tr| tr.inner_with(&context.module.types).scalar());
if !matches!(
scalar,
Some(crate::Scalar {
kind: crate::ScalarKind::Float,
..
})
) {
return Err(Error::UnsupportedCooperativeMatrix);
}
}
write!(self.out, "{COOPERATIVE_LOAD_FUNCTION}(")?;
write!(self.out, "&")?;
self.put_access_chain(data.pointer, context.policies.index, context)?;
Expand All @@ -2937,7 +2965,33 @@ impl<W: Write> Writer<W> {
if context.lang_version < (2, 3) {
return Err(Error::UnsupportedCooperativeMatrix);
}
write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
// Metal simdgroup_matrix only supports floating-point types.
let (in_name, acc_name) = {
let a_scalar = context.resolve_type(a).scalar();
let c_scalar = context.resolve_type(c).scalar();
match (a_scalar, c_scalar) {
(
Some(crate::Scalar {
kind: crate::ScalarKind::Float,
..
}),
Some(crate::Scalar {
kind: crate::ScalarKind::Float,
..
}),
) => (
a_scalar.unwrap().to_msl_name(),
c_scalar.unwrap().to_msl_name(),
),
_ => return Err(Error::UnsupportedCooperativeMatrix),
}
};
let fn_suffix = if in_name == acc_name {
String::new()
} else {
format!("_{in_name}_{acc_name}")
};
write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}{fn_suffix}(")?;
self.put_expression(a, context, true)?;
write!(self.out, ", ")?;
self.put_expression(b, context, true)?;
Expand Down Expand Up @@ -6448,16 +6502,25 @@ template <typename A>
Ok(())
}

/// NOTE: this generator does not itself reject non-float scalar types. Metal's
/// `simdgroup_matrix` only exists for `float`/`half`, so emitting a wrapper for
/// e.g. `simdgroup_char32x16` would produce invalid MSL. The float-only check
/// lives at the use site in `put_expression` for `CooperativeMultiplyAdd`,
/// which fails the whole compilation before the consumer ever reads the
/// emitted output. If we ever start writing wrappers in a context where
/// `put_expression` cannot bail (e.g. emitting a stub library), mirror that
/// check here.
fn write_wrapped_cooperative_multiply_add(
&mut self,
module: &crate::Module,
func_ctx: &back::FunctionCtx,
space: crate::AddressSpace,
a: Handle<crate::Expression>,
b: Handle<crate::Expression>,
c: Handle<crate::Expression>,
) -> BackendResult {
let space_name = space.to_msl_name().unwrap_or_default();
let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
let (a_c, a_r, input_scalar) = match *func_ctx.resolve_type(a, &module.types) {
crate::TypeInner::CooperativeMatrix {
columns,
rows,
Expand All @@ -6470,26 +6533,39 @@ template <typename A>
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
_ => unreachable!(),
};
let accumulator_scalar = match *func_ctx.resolve_type(c, &module.types) {
crate::TypeInner::CooperativeMatrix { scalar, .. } => scalar,
_ => unreachable!(),
};
let wrapped = WrappedFunction::CooperativeMultiplyAdd {
space_name,
columns: b_c,
rows: a_r,
intermediate: a_c,
scalar,
input_scalar,
accumulator_scalar,
};
if !self.wrapped_functions.insert(wrapped) {
return Ok(());
}
let scalar_name = scalar.to_msl_name();
let in_name = input_scalar.to_msl_name();
let acc_name = accumulator_scalar.to_msl_name();
// When input and accumulator types differ we need a disambiguated name
// so two combos (e.g. f16→f32 and f32→f32) can coexist in one module.
let fn_suffix = if in_name == acc_name {
String::new()
} else {
format!("_{in_name}_{acc_name}")
};
writeln!(
self.out,
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
"{NAMESPACE}::simdgroup_{acc_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}{fn_suffix}(const {space_name} {NAMESPACE}::simdgroup_{in_name}{}x{}& a, const {space_name} {NAMESPACE}::simdgroup_{in_name}{}x{}& b, const {space_name} {NAMESPACE}::simdgroup_{acc_name}{}x{}& c) {{",
b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
)?;
let l1 = back::Level(1);
writeln!(
self.out,
"{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
"{l1}{NAMESPACE}::simdgroup_{acc_name}{}x{} d;",
b_c as u32, a_r as u32
)?;
writeln!(self.out, "{l1}simdgroup_multiply_accumulate(d,a,b,c);")?;
Expand Down Expand Up @@ -6587,9 +6663,9 @@ template <typename A>
data.pointer,
)?;
}
crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
let space = crate::AddressSpace::Private;
self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b)?;
self.write_wrapped_cooperative_multiply_add(module, func_ctx, space, a, b, c)?;
}
crate::Expression::RayQueryGetIntersection { committed, .. } => {
self.write_rq_get_intersection_function(module, committed)?;
Expand Down
51 changes: 49 additions & 2 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,12 @@
self.cached[expr_handle] = id;
return Ok(());
}
crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FAdd,
crate::TypeInner::CooperativeMatrix { scalar, .. } => {
match scalar.kind {
crate::ScalarKind::Float => spirv::Op::FAdd,
_ => spirv::Op::IAdd,
}
}
_ => unimplemented!(),
},
crate::BinaryOperator::Subtract => match *left_ty_inner {
Expand Down Expand Up @@ -1105,7 +1110,12 @@
self.cached[expr_handle] = id;
return Ok(());
}
crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub,
crate::TypeInner::CooperativeMatrix { scalar, .. } => {
match scalar.kind {
crate::ScalarKind::Float => spirv::Op::FSub,
_ => spirv::Op::ISub,
}
}
_ => unimplemented!(),
},
crate::BinaryOperator::Multiply => {
Expand Down Expand Up @@ -2178,12 +2188,49 @@
let b_id = self.cached[b];
let c_id = self.cached[c];
let id = self.gen_id();

// Build the CooperativeMatrixOperands word. SPIR-V requires this
// operand whenever any matrix has signed-integer components. We emit
// it whenever any of A, B, or C is non-float (for all-unsigned
// operands the value is simply NONE_KHR) and omit it only when all
// three are float: NONE_KHR == 0 is equivalent there but adds a word,
// and skipping it keeps the float golden tests stable.
let matrix_operands = {
use crate::ScalarKind::{Float, Sint};
use spirv::CooperativeMatrixOperands as Cmo;

Check warning on line 2198 in naga/src/back/spv/block.rs

View workflow job for this annotation

GitHub Actions / Format & Typos

"Cmo" should be "Com".
let scalar_for = |h: Handle<crate::Expression>| match *self.fun_info[h]
.ty
.inner_with(&self.ir_module.types)
{
crate::TypeInner::CooperativeMatrix { scalar, .. } => scalar,
_ => unreachable!("validated as CooperativeMatrix"),
};
let a_scalar = scalar_for(a);
let b_scalar = scalar_for(b);
let c_scalar = scalar_for(c);
if a_scalar.kind != Float || b_scalar.kind != Float || c_scalar.kind != Float {
let mut ops = Cmo::NONE_KHR;

Check warning on line 2210 in naga/src/back/spv/block.rs

View workflow job for this annotation

GitHub Actions / Format & Typos

"Cmo" should be "Com".
if a_scalar.kind == Sint {
ops |= Cmo::MATRIX_A_SIGNED_COMPONENTS_KHR;

Check warning on line 2212 in naga/src/back/spv/block.rs

View workflow job for this annotation

GitHub Actions / Format & Typos

"Cmo" should be "Com".
}
if b_scalar.kind == Sint {
ops |= Cmo::MATRIX_B_SIGNED_COMPONENTS_KHR;

Check warning on line 2215 in naga/src/back/spv/block.rs

View workflow job for this annotation

GitHub Actions / Format & Typos

"Cmo" should be "Com".
}
if c_scalar.kind == Sint {
ops |= Cmo::MATRIX_C_SIGNED_COMPONENTS_KHR;

Check warning on line 2218 in naga/src/back/spv/block.rs

View workflow job for this annotation

GitHub Actions / Format & Typos

"Cmo" should be "Com".
ops |= Cmo::MATRIX_RESULT_SIGNED_COMPONENTS_KHR;

Check warning on line 2219 in naga/src/back/spv/block.rs

View workflow job for this annotation

GitHub Actions / Format & Typos

"Cmo" should be "Com".
}
Some(ops)
} else {
None
}
};

block.body.push(Instruction::coop_mul_add(
result_type_id,
id,
a_id,
b_id,
c_id,
matrix_operands,
));
id
}
Expand Down
13 changes: 11 additions & 2 deletions naga/src/back/spv/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1326,14 +1326,23 @@ impl super::Instruction {
instruction.add_operand(stride_id);
instruction
}
pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self {
pub(super) fn coop_mul_add(
result_type_id: Word,
id: Word,
a: Word,
b: Word,
c: Word,
matrix_operands: Option<spirv::CooperativeMatrixOperands>,
) -> Self {
let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(a);
instruction.add_operand(b);
instruction.add_operand(c);

if let Some(operands) = matrix_operands {
instruction.add_operand(operands.bits());
}
instruction
}
}
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1222,4 +1222,5 @@ pub fn supported_capabilities() -> crate::valid::Capabilities {
| Caps::DRAW_INDEX
| Caps::MEMORY_DECORATION_COHERENT
| Caps::MEMORY_DECORATION_VOLATILE
| Caps::SHADER_INT8
}
20 changes: 19 additions & 1 deletion naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2048,15 +2048,33 @@ impl Writer {
self.require_any("16 bit integer", &[spirv::Capability::Int16])?;
self.use_extension("SPV_KHR_16bit_storage");
}
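// 8-bit integers: enable the 8-bit storage capabilities
// (StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess) and the
// SPV_KHR_8bit_storage extension.
// NOTE(review): the Int8 capability itself is not requested in this arm —
// confirm it is requested elsewhere before relying on 8-bit arithmetic.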
crate::TypeInner::Scalar(crate::Scalar {
width: 1,
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
}) => {
self.capabilities_used
.insert(spirv::Capability::StorageBuffer8BitAccess);
self.capabilities_used
.insert(spirv::Capability::UniformAndStorageBuffer8BitAccess);
self.use_extension("SPV_KHR_8bit_storage");
}
// Cooperative types and ops
crate::TypeInner::CooperativeMatrix { .. } => {
crate::TypeInner::CooperativeMatrix { scalar, .. } => {
self.require_any(
"cooperative matrix",
&[spirv::Capability::CooperativeMatrixKHR],
)?;
self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?;
self.use_extension("SPV_KHR_cooperative_matrix");
self.use_extension("SPV_KHR_vulkan_memory_model");
if scalar.width == 1 {
self.capabilities_used
.insert(spirv::Capability::StorageBuffer8BitAccess);
self.capabilities_used
.insert(spirv::Capability::UniformAndStorageBuffer8BitAccess);
self.use_extension("SPV_KHR_8bit_storage");
}
}
_ => {}
}
Expand Down
2 changes: 2 additions & 0 deletions naga/src/common/wgsl/to_wgsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ impl TryToWgsl for crate::Scalar {
Scalar::F16 => "f16",
Scalar::F32 => "f32",
Scalar::F64 => "f64",
Scalar::I8 => "i8",
Scalar::U8 => "u8",
Scalar::I16 => "i16",
Scalar::U16 => "u16",
Scalar::I32 => "i32",
Expand Down
4 changes: 2 additions & 2 deletions naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,7 @@ impl<'a> Error<'a> {
Error::UnderspecifiedCooperativeMatrix => ParseError {
message: "cooperative matrix constructor is underspecified".into(),
labels: vec![],
notes: vec![format!("must be F32")],
notes: vec![format!("scalar type must be one of: f16, f32, i32, u32, i8, u8")],
},
Error::InvalidCooperativeLoadType(span) => ParseError {
message: "cooperative load should have a generic type for coop_mat".into(),
Expand All @@ -1472,7 +1472,7 @@ impl<'a> Error<'a> {
Error::UnsupportedCooperativeScalar(span) => ParseError {
message: "cooperative scalar type is not supported".into(),
labels: vec![(span, "type needs the scalar type specified".into())],
notes: vec![format!("must be F32")],
notes: vec![format!("scalar type must be one of: f16, f32, i32, u32, i8, u8")],
},
Error::UnexpectedIdentForEnumerant(ident_span) => ParseError {
message: format!(
Expand Down
Loading
Loading