Skip to content

Commit d801c1e

Browse files
committed
coop: rewire WGSL support using references
1 parent d18bc38 commit d801c1e

19 files changed

+174
-371
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ By @wumpf in [#8282](https://github.com/gfx-rs/wgpu/pull/8282), [#8285](https://
376376

377377
- Expose `naga::front::wgsl::UnimplementedEnableExtension`. By @ErichDonGubler in [#8237](https://github.com/gfx-rs/wgpu/pull/8237).
378378

379-
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
379+
- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V,METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251).
380380

381381
### Changes
382382

naga/src/back/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,16 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
311311
}
312312
}
313313

314+
impl crate::TypeInner {
315+
/// Returns true if a variable of this type is a handle.
316+
pub const fn is_handle(&self) -> bool {
317+
match *self {
318+
Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true,
319+
_ => false,
320+
}
321+
}
322+
}
323+
314324
impl crate::Statement {
315325
/// Returns true if the statement directly terminates the current block.
316326
///

naga/src/back/msl/writer.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6327,16 +6327,22 @@ template <typename A>
63276327
b: Handle<crate::Expression>,
63286328
) -> BackendResult {
63296329
let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) {
6330-
crate::TypeInner::CooperativeMatrix {
6331-
columns,
6332-
rows,
6333-
scalar,
6334-
..
6335-
} => (columns, rows, scalar),
6330+
crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner {
6331+
crate::TypeInner::CooperativeMatrix {
6332+
columns,
6333+
rows,
6334+
scalar,
6335+
..
6336+
} => (columns, rows, scalar),
6337+
_ => unreachable!(),
6338+
},
63366339
_ => unreachable!(),
63376340
};
63386341
let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) {
6339-
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6342+
crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner {
6343+
crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows),
6344+
_ => unreachable!(),
6345+
},
63406346
_ => unreachable!(),
63416347
};
63426348
let wrapped = WrappedFunction::CooperativeMultiplyAdd {

naga/src/back/spv/block.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3719,7 +3719,13 @@ impl BlockContext<'_> {
37193719
self.cached[stride],
37203720
));
37213721
} else {
3722-
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
3722+
let result_type_id =
3723+
match *self.fun_info[target].ty.inner_with(&self.ir_module.types) {
3724+
crate::TypeInner::Pointer { base, space: _ } => {
3725+
self.get_handle_type_id(base)
3726+
}
3727+
_ => unreachable!(),
3728+
};
37233729
let id = self.gen_id();
37243730
block.body.push(Instruction::coop_load(
37253731
result_type_id,

naga/src/back/spv/writer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -970,14 +970,13 @@ impl Writer {
970970
}
971971
}
972972

973-
// Handle globals are pre-emitted and should be loaded automatically.
974-
//
975-
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
976973
match ir_module.types[var.ty].inner {
974+
// Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
977975
crate::TypeInner::BindingArray { .. } => {
978976
gv.access_id = gv.var_id;
979977
}
980978
_ => {
979+
// Handle globals are pre-emitted and should be loaded automatically.
981980
if var.space == crate::AddressSpace::Handle {
982981
let var_type_id = self.get_handle_type_id(var.ty);
983982
let id = self.id_gen.next();
@@ -1063,6 +1062,7 @@ impl Writer {
10631062
}
10641063
}),
10651064
);
1065+
10661066
context
10671067
.function
10681068
.variables

naga/src/back/wgsl/writer.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -994,13 +994,13 @@ impl<W: Write> Writer<W> {
994994
} => {
995995
let op_str = if store { "Store" } else { "Load" };
996996
let suffix = if row_major { "T" } else { "" };
997-
write!(self.out, "coop{op_str}{suffix}(")?;
998-
self.write_expr(module, target, func_ctx)?;
997+
write!(self.out, "{level}coop{op_str}{suffix}(")?;
998+
self.write_expr_with_indirection(module, target, func_ctx, Indirection::Reference)?;
999999
write!(self.out, ", ")?;
10001000
self.write_expr(module, pointer, func_ctx)?;
10011001
write!(self.out, ", ")?;
10021002
self.write_expr(module, stride, func_ctx)?;
1003-
write!(self.out, ")")?
1003+
writeln!(self.out, ");")?
10041004
}
10051005
}
10061006

@@ -1715,11 +1715,11 @@ impl<W: Write> Writer<W> {
17151715
| Expression::WorkGroupUniformLoadResult { .. } => {}
17161716
Expression::CooperativeMultiplyAdd { a, b, c } => {
17171717
write!(self.out, "coopMultiplyAdd(")?;
1718-
self.write_expr(module, a, func_ctx)?;
1718+
self.write_expr_with_indirection(module, a, func_ctx, Indirection::Reference)?;
17191719
write!(self.out, ", ")?;
1720-
self.write_expr(module, b, func_ctx)?;
1720+
self.write_expr_with_indirection(module, b, func_ctx, Indirection::Reference)?;
17211721
write!(self.out, ", ")?;
1722-
self.write_expr(module, c, func_ctx)?;
1722+
self.write_expr_with_indirection(module, c, func_ctx, Indirection::Reference)?;
17231723
write!(self.out, ")")?;
17241724
}
17251725
}

naga/src/front/wgsl/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ pub(crate) enum Error<'a> {
412412
TypeTooLarge {
413413
span: Span,
414414
},
415+
InvalidCooperativeMatrix,
415416
UnderspecifiedCooperativeMatrix,
416417
UnsupportedCooperativeScalar(Span),
417418
}
@@ -1388,6 +1389,11 @@ impl<'a> Error<'a> {
13881389
crate::valid::MAX_TYPE_SIZE
13891390
)],
13901391
},
1392+
Error::InvalidCooperativeMatrix => ParseError {
1393+
message: "given type is not a cooperative matrix".into(),
1394+
labels: vec![],
1395+
notes: vec![format!("must be coop_mat")],
1396+
},
13911397
Error::UnderspecifiedCooperativeMatrix => ParseError {
13921398
message: "cooperative matrix constructor is underspecified".into(),
13931399
labels: vec![],

naga/src/front/wgsl/lower/mod.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
846846
fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle<ir::Type> {
847847
self.as_global().ensure_type_exists(None, inner)
848848
}
849+
850+
fn _get_runtime_expression(&self, expr: Handle<ir::Expression>) -> &ir::Expression {
851+
match self.expr_type {
852+
ExpressionContextType::Runtime(ref ctx) => &ctx.function.expressions[expr],
853+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
854+
unreachable!()
855+
}
856+
}
857+
}
849858
}
850859

851860
struct ArgumentContext<'ctx, 'source> {
@@ -955,6 +964,13 @@ impl<T> Typed<T> {
955964
Self::Plain(expr) => Typed::Plain(f(expr)?),
956965
})
957966
}
967+
968+
fn ref_or<E>(self, error: E) -> core::result::Result<T, E> {
969+
match self {
970+
Self::Reference(v) => Ok(v),
971+
Self::Plain(_) => Err(error),
972+
}
973+
}
958974
}
959975

960976
/// A single vector component or swizzle.
@@ -1679,12 +1695,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
16791695
.as_expression(block, &mut emitter)
16801696
.interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?;
16811697
block.extend(emitter.finish(&ctx.function.expressions));
1682-
let typed = if ctx.module.types[ty].inner.is_handle() {
1683-
Typed::Plain(handle)
1684-
} else {
1685-
Typed::Reference(handle)
1686-
};
1687-
ctx.local_table.insert(v.handle, Declared::Runtime(typed));
1698+
ctx.local_table
1699+
.insert(v.handle, Declared::Runtime(Typed::Reference(handle)));
16881700

16891701
match initializer {
16901702
Some(initializer) => ir::Statement::Store {
@@ -1979,12 +1991,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
19791991
let value_span = ctx.ast_expressions.get_span(value);
19801992
let target = self
19811993
.expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
1982-
let target_handle = match target {
1983-
Typed::Reference(handle) => handle,
1984-
Typed::Plain(_) => {
1985-
return Err(Box::new(Error::BadIncrDecrReferenceType(value_span)))
1986-
}
1987-
};
1994+
let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?;
19881995

19891996
let mut ectx = ctx.as_expression(block, &mut emitter);
19901997
let scalar = match *resolve_inner!(ectx, target_handle) {
@@ -2141,10 +2148,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
21412148
LoweredGlobalDecl::Var(handle) => {
21422149
let expr = ir::Expression::GlobalVariable(handle);
21432150
let v = &ctx.module.global_variables[handle];
2144-
let force_value = ctx.module.types[v.ty].inner.is_handle();
21452151
match v.space {
21462152
ir::AddressSpace::Handle => Typed::Plain(expr),
2147-
_ if force_value => Typed::Plain(expr),
21482153
_ => Typed::Reference(expr),
21492154
}
21502155
}
@@ -3142,7 +3147,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31423147
let row_major = function.name.ends_with("T");
31433148

31443149
let mut args = ctx.prepare_args(arguments, 2, span);
3145-
let target = self.expression(args.next()?, ctx)?;
3150+
let target = self
3151+
.expression_for_reference(args.next()?, ctx)?
3152+
.ref_or(Error::InvalidCooperativeMatrix)?;
31463153
let pointer = self.expression(args.next()?, ctx)?;
31473154
let stride = if args.total_args > 2 {
31483155
self.expression(args.next()?, ctx)?
@@ -3180,9 +3187,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31803187
}
31813188
"coopMultiplyAdd" => {
31823189
let mut args = ctx.prepare_args(arguments, 3, span);
3183-
let a = self.expression(args.next()?, ctx)?;
3184-
let b = self.expression(args.next()?, ctx)?;
3185-
let c = self.expression(args.next()?, ctx)?;
3190+
let a = self
3191+
.expression_for_reference(args.next()?, ctx)?
3192+
.ref_or(Error::InvalidCooperativeMatrix)?;
3193+
let b = self
3194+
.expression_for_reference(args.next()?, ctx)?
3195+
.ref_or(Error::InvalidCooperativeMatrix)?;
3196+
let c = self
3197+
.expression_for_reference(args.next()?, ctx)?
3198+
.ref_or(Error::InvalidCooperativeMatrix)?;
31863199
args.finish()?;
31873200

31883201
ir::Expression::CooperativeMultiplyAdd { a, b, c }

naga/src/proc/type_methods.rs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,17 +191,6 @@ impl crate::TypeInner {
191191
}
192192
}
193193

194-
/// Returns true if a variable of this type is a handle.
195-
pub const fn is_handle(&self) -> bool {
196-
match *self {
197-
Self::Image { .. }
198-
| Self::Sampler { .. }
199-
| Self::AccelerationStructure { .. }
200-
| Self::CooperativeMatrix { .. } => true,
201-
_ => false,
202-
}
203-
}
204-
205194
/// Attempt to calculate the size of this type. Returns `None` if the size
206195
/// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`].
207196
pub fn try_size(&self, gctx: super::GlobalCtx) -> Option<u32> {

naga/src/proc/typifier.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,7 @@ impl<'a> ResolveContext<'a> {
454454
}
455455
crate::Expression::GlobalVariable(h) => {
456456
let var = &self.global_vars[h];
457-
let ty = &types[var.ty].inner;
458-
if var.space == crate::AddressSpace::Handle || ty.is_handle() {
457+
if var.space == crate::AddressSpace::Handle {
459458
TypeResolution::Handle(var.ty)
460459
} else {
461460
TypeResolution::Value(Ti::Pointer {
@@ -466,15 +465,10 @@ impl<'a> ResolveContext<'a> {
466465
}
467466
crate::Expression::LocalVariable(h) => {
468467
let var = &self.local_vars[h];
469-
let ty = &types[var.ty].inner;
470-
if ty.is_handle() {
471-
TypeResolution::Handle(var.ty)
472-
} else {
473-
TypeResolution::Value(Ti::Pointer {
474-
base: var.ty,
475-
space: crate::AddressSpace::Function,
476-
})
477-
}
468+
TypeResolution::Value(Ti::Pointer {
469+
base: var.ty,
470+
space: crate::AddressSpace::Function,
471+
})
478472
}
479473
crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
480474
Ti::Pointer { base, space: _ } => {
@@ -807,7 +801,15 @@ impl<'a> ResolveContext<'a> {
807801
scalar: crate::Scalar::U32,
808802
size: crate::VectorSize::Quad,
809803
}),
810-
crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => past(c)?.clone(),
804+
crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => {
805+
match *past(c)?.inner_with(types) {
806+
Ti::Pointer { base, space: _ } => TypeResolution::Handle(base),
807+
ref other => {
808+
log::error!("Pointer type {other:?}");
809+
return Err(ResolveError::InvalidPointer(c));
810+
}
811+
}
812+
}
811813
})
812814
}
813815
}

0 commit comments

Comments
 (0)