Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ By @wumpf in [#8282](https://github.com/gfx-rs/wgpu/pull/8282), [#8285](https://

- Expose `naga::front::wgsl::UnimplementedEnableExtension`. By @ErichDonGubler in [#8237](https://github.com/gfx-rs/wgpu/pull/8237).

- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V,METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).

### Changes

#### General
Expand Down
19 changes: 13 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ ndk-sys = "0.6"
# These overrides allow our examples to explicitly depend on release crates
[patch.crates-io]
wgpu = { path = "./wgpu" }
rspirv = { git = "https://github.com/gfx-rs/rspirv", rev = "89ce4d0e64c91b0635f617409dc57cb031749a39" }

# https://github.com/Xudong-Huang/generator-rs/pull/75
generator = { git = "https://github.com/Xudong-Huang/generator-rs", rev = "70b89fdabcc0e82fe84ca17f65cc52ff25e8e6de" }
Expand Down
22 changes: 22 additions & 0 deletions naga/src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ impl StatementGraph {
},
}
}
S::CooperativeLoadStore {
store,
target,
pointer,
stride,
row_major: _,
} => {
self.dependencies.push((id, target, "target"));
self.dependencies.push((id, pointer, "pointer"));
self.dependencies.push((id, stride, "stride"));
if store {
"Store"
} else {
"Load"
}
}
};
// Set the last node to the merge node
last_node = merge_id;
Expand Down Expand Up @@ -761,6 +777,12 @@ fn write_function_expressions(
let ty = if committed { "Committed" } else { "Candidate" };
(format!("get{ty}HitVertexPositions").into(), 4)
}
E::CooperativeMultiplyAdd { a, b, c } => {
edges.insert("a", a);
edges.insert("b", b);
edges.insert("c", c);
("cooperativeMultiplyAdd".into(), 4)
}
};

// give uniform expressions an outline
Expand Down
7 changes: 5 additions & 2 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,8 @@ impl<'a, W: Write> Writer<'a, W> {
TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?,
// Write all variants instead of `_` so that if new variants are added a
// no exhaustiveness error is thrown
TypeInner::Pointer { .. }
TypeInner::CooperativeMatrix { .. }
| TypeInner::Pointer { .. }
| TypeInner::Struct { .. }
| TypeInner::Image { .. }
| TypeInner::Sampler { .. }
Expand Down Expand Up @@ -2815,6 +2816,7 @@ impl<'a, W: Write> Writer<'a, W> {
}
writeln!(self.out, ");")?;
}
Statement::CooperativeLoadStore { .. } => unimplemented!(),
}

Ok(())
Expand Down Expand Up @@ -4351,7 +4353,8 @@ impl<'a, W: Write> Writer<'a, W> {
}
// not supported yet
Expression::RayQueryGetIntersection { .. }
| Expression::RayQueryVertexPositions { .. } => unreachable!(),
| Expression::RayQueryVertexPositions { .. }
| Expression::CooperativeMultiplyAdd { .. } => unreachable!(),
}

Ok(())
Expand Down
6 changes: 5 additions & 1 deletion naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2769,6 +2769,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
writeln!(self.out, ");")?;
}
Statement::CooperativeLoadStore { .. } => unimplemented!(),
}

Ok(())
Expand Down Expand Up @@ -4298,7 +4299,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
// Not supported yet
Expression::RayQueryVertexPositions { .. } => unreachable!(),
Expression::RayQueryVertexPositions { .. }
| Expression::CooperativeMultiplyAdd { .. } => {
unreachable!()
}
// Nothing to do here, since call expression already cached
Expression::CallResult(_)
| Expression::AtomicResult { .. }
Expand Down
6 changes: 2 additions & 4 deletions naga/src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,10 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
}

impl crate::TypeInner {
/// Returns true if this is a handle to a type rather than the type directly.
/// Returns true if a variable of this type is a handle.
pub const fn is_handle(&self) -> bool {
match *self {
crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::AccelerationStructure { .. } => true,
Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true,
_ => false,
}
}
Expand Down
118 changes: 118 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp
/// allowing them to be conveniently passed to user-defined or wrapper
/// functions. The struct is declared in [`Writer::write_type_defs`].
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper";
pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
Expand Down Expand Up @@ -235,6 +236,21 @@ impl Display for TypeContext<'_> {
rows,
scalar,
} => put_numeric_type(out, scalar, &[rows, columns]),
crate::TypeInner::CooperativeMatrix {
columns,
rows,
scalar,
role: _,
} => {
write!(
out,
"{}::simdgroup_{}{}x{}",
NAMESPACE,
scalar.to_msl_name(),
columns as u32,
rows as u32,
)
}
crate::TypeInner::Pointer { base, space } => {
let sub = Self {
handle: base,
Expand Down Expand Up @@ -468,6 +484,12 @@ enum WrappedFunction {
ImageQuerySize {
class: crate::ImageClass,
},
CooperativeMultiplyAdd {
columns: crate::CooperativeSize,
rows: crate::CooperativeSize,
intermediate: crate::CooperativeSize,
scalar: crate::Scalar,
},
}

pub struct Writer<W> {
Expand Down Expand Up @@ -640,6 +662,7 @@ impl crate::Type {
Ti::Scalar(_)
| Ti::Vector { .. }
| Ti::Matrix { .. }
| Ti::CooperativeMatrix { .. }
| Ti::Atomic(_)
| Ti::Pointer { .. }
| Ti::ValuePointer { .. } => self.name.is_some(),
Expand Down Expand Up @@ -2821,6 +2844,15 @@ impl<W: Write> Writer<W> {
}
write!(self.out, "}}")?;
}
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?;
self.put_expression(a, context, true)?;
write!(self.out, ", ")?;
self.put_expression(b, context, true)?;
write!(self.out, ", ")?;
self.put_expression(c, context, true)?;
write!(self.out, ")")?;
}
}
Ok(())
}
Expand Down Expand Up @@ -4210,6 +4242,27 @@ impl<W: Write> Writer<W> {
}
writeln!(self.out, ");")?;
}
crate::Statement::CooperativeLoadStore {
store,
target,
pointer,
stride,
row_major,
} => {
let op_str = if store { "store" } else { "load" };
write!(self.out, "{level}{NAMESPACE}::simdgroup_{op_str}(")?;
self.put_expression(target, &context.expression, true)?;
write!(self.out, ", ")?;
self.put_expression(pointer, &context.expression, true)?;
write!(self.out, ", ")?;
self.put_expression(stride, &context.expression, true)?;
if row_major {
let matrix_origin = "0";
let transpose = true;
write!(self.out, ", {matrix_origin}, {transpose}")?;
}
writeln!(self.out, ");")?;
}
}
}

Expand Down Expand Up @@ -6266,6 +6319,68 @@ template <typename A>
Ok(())
}

fn write_wrapped_cooperative_multiply_add(
&mut self,
module: &crate::Module,
func_ctx: &back::FunctionCtx,
a: Handle<crate::Expression>,
b: Handle<crate::Expression>,
) -> BackendResult {
let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner {
crate::TypeInner::CooperativeMatrix {
columns,
rows,
scalar,
..
} => (columns, rows, scalar),
_ => unreachable!(),
},
_ => unreachable!(),
};
let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner {
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
_ => unreachable!(),
},
_ => unreachable!(),
};
let wrapped = WrappedFunction::CooperativeMultiplyAdd {
columns: b_c,
rows: a_r,
intermediate: a_c,
scalar,
};
if !self.wrapped_functions.insert(wrapped) {
return Ok(());
}
let scalar_name = match scalar.width {
2 => "half",
4 => "float",
8 => "double",
_ => unreachable!(),
};
writeln!(
self.out,
"{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{",
b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32,
)?;
let l1 = back::Level(1);
writeln!(
self.out,
"{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;",
b_c as u32, a_r as u32
)?;
writeln!(
self.out,
"{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);"
)?;
writeln!(self.out, "{l1}return d;")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
Ok(())
}

pub(super) fn write_wrapped_functions(
&mut self,
module: &crate::Module,
Expand Down Expand Up @@ -6340,6 +6455,9 @@ template <typename A>
crate::Expression::ImageQuery { image, query } => {
self.write_wrapped_image_query(module, func_ctx, image, query)?;
}
crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => {
self.write_wrapped_cooperative_multiply_add(module, func_ctx, a, b)?;
}
_ => {}
}
}
Expand Down
20 changes: 20 additions & 0 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,15 @@ fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut E
} => {
adjust(query);
}
Expression::CooperativeMultiplyAdd {
ref mut a,
ref mut b,
ref mut c,
} => {
adjust(a);
adjust(b);
adjust(c);
}
}
}

Expand Down Expand Up @@ -880,6 +889,17 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
adjust(index);
adjust(value);
}
Statement::CooperativeLoadStore {
store: _,
ref mut target,
ref mut pointer,
ref mut stride,
row_major: _,
} => {
adjust(target);
adjust(pointer);
adjust(stride);
}
Statement::Break
| Statement::Continue
| Statement::Kill
Expand Down
Loading
Loading