Skip to content

Commit 4fa9f00

Browse files
committed
coop: support generic argument on coopLoad
1 parent 07be9e9 commit 4fa9f00

File tree

8 files changed

+59
-12
lines changed

8 files changed

+59
-12
lines changed

naga/src/back/wgsl/writer.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,9 +1712,29 @@ impl<W: Write> Writer<W> {
17121712
| Expression::SubgroupBallotResult
17131713
| Expression::SubgroupOperationResult { .. }
17141714
| Expression::WorkGroupUniformLoadResult { .. } => {}
1715-
Expression::CooperativeLoad { ref data, .. } => {
1715+
Expression::CooperativeLoad {
1716+
columns,
1717+
rows,
1718+
role,
1719+
ref data,
1720+
} => {
17161721
let suffix = if data.row_major { "T" } else { "" };
1717-
write!(self.out, "coopLoad{suffix}(")?;
1722+
let scalar = func_ctx.info[data.pointer]
1723+
.ty
1724+
.inner_with(&module.types)
1725+
.pointer_base_type()
1726+
.unwrap()
1727+
.inner_with(&module.types)
1728+
.scalar()
1729+
.unwrap();
1730+
write!(
1731+
self.out,
1732+
"coopLoad{suffix}<coop_mat{}x{}<{},{:?}>>(",
1733+
columns as u32,
1734+
rows as u32,
1735+
scalar.try_to_wgsl().unwrap(),
1736+
role,
1737+
)?;
17181738
self.write_expr(module, data.pointer, func_ctx)?;
17191739
write!(self.out, ", ")?;
17201740
self.write_expr(module, data.stride, func_ctx)?;

naga/src/front/wgsl/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ pub(crate) enum Error<'a> {
413413
span: Span,
414414
},
415415
UnderspecifiedCooperativeMatrix,
416+
InvalidCooperativeLoadType(Span),
416417
UnsupportedCooperativeScalar(Span),
417418
}
418419

@@ -1393,6 +1394,11 @@ impl<'a> Error<'a> {
13931394
labels: vec![],
13941395
notes: vec![format!("must be F32")],
13951396
},
1397+
Error::InvalidCooperativeLoadType(span) => ParseError {
1398+
message: "cooperative load should have a generic type for coop_mat".into(),
1399+
labels: vec![(span, "type needs the coop_mat<...>".into())],
1400+
notes: vec![format!("must be a valid cooperative type")],
1401+
},
13961402
Error::UnsupportedCooperativeScalar(span) => ParseError {
13971403
message: "cooperative scalar type is not supported".into(),
13981404
labels: vec![(span, "type needs the scalar type specified".into())],

naga/src/front/wgsl/lower/mod.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,6 +1905,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
19051905
stmt.span,
19061906
function,
19071907
arguments,
1908+
None,
19081909
&mut ctx.as_expression(block, &mut emitter),
19091910
true,
19101911
)?;
@@ -2227,9 +2228,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
22272228
ast::Expression::Call {
22282229
ref function,
22292230
ref arguments,
2231+
result_ty,
22302232
} => {
22312233
let handle = self
2232-
.call(span, function, arguments, ctx, false)?
2234+
.call(span, function, arguments, result_ty, ctx, false)?
22332235
.ok_or(Error::FunctionReturnsVoid(function.span))?;
22342236
return Ok(Typed::Plain(handle));
22352237
}
@@ -2424,6 +2426,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
24242426
span: Span,
24252427
function: &ast::Ident<'source>,
24262428
arguments: &[Handle<ast::Expression<'source>>],
2429+
result_ty: Option<(Handle<ast::Type<'source>>, Span)>,
24272430
ctx: &mut ExpressionContext<'source, '_, '_>,
24282431
is_statement: bool,
24292432
) -> Result<'source, Option<Handle<ir::Expression>>> {
@@ -3145,9 +3148,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31453148
let row_major = function.name.ends_with("T");
31463149
let mut args = ctx.prepare_args(arguments, 1, span);
31473150
let pointer = self.expression(args.next()?, ctx)?;
3148-
//TODO: read from generic argument
3149-
let columns = crate::CooperativeSize::Eight;
3150-
let rows = crate::CooperativeSize::Eight;
3151+
let (matrix_ty, matrix_span) = result_ty.expect("generic argument");
3152+
let (columns, rows, role) = match ctx.types[matrix_ty] {
3153+
ast::Type::CooperativeMatrix {
3154+
columns,
3155+
rows,
3156+
role,
3157+
..
3158+
} => (columns, rows, role),
3159+
_ => {
3160+
return Err(Box::new(Error::InvalidCooperativeLoadType(
3161+
matrix_span,
3162+
)))
3163+
}
3164+
};
31513165
let stride = if args.total_args > 1 {
31523166
self.expression(args.next()?, ctx)?
31533167
} else {
@@ -3167,7 +3181,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31673181
crate::Expression::CooperativeLoad {
31683182
columns,
31693183
rows,
3170-
role: crate::CooperativeRole::C, //TODO
3184+
role,
31713185
data: crate::CooperativeData {
31723186
pointer,
31733187
stride,

naga/src/front/wgsl/parse/ast.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,7 @@ pub enum Expression<'a> {
487487
Call {
488488
function: Ident<'a>,
489489
arguments: Vec<Handle<Expression<'a>>>,
490+
result_ty: Option<(Handle<Type<'a>>, Span)>,
490491
},
491492
Index {
492493
base: Handle<Expression<'a>>,

naga/src/front/wgsl/parse/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,11 @@ impl Parser {
800800
}
801801
// everything else must be handled later, since they can be hidden by user-defined functions.
802802
_ => {
803+
let result_ty = if lexer.peek().0 == Token::Paren('<') {
804+
Some(self.singular_generic(lexer, ctx)?)
805+
} else {
806+
None
807+
};
803808
let arguments = self.arguments(lexer, ctx)?;
804809
ctx.unresolved.insert(ast::Dependency {
805810
ident: name,
@@ -811,6 +816,7 @@ impl Parser {
811816
span: name_span,
812817
},
813818
arguments,
819+
result_ty,
814820
}
815821
}
816822
};
@@ -959,7 +965,7 @@ impl Parser {
959965
} else if let Token::Paren('(') = lexer.peek().0 {
960966
self.pop_rule_span(lexer);
961967
return self.function_call(lexer, word, span, ctx);
962-
} else if word == "bitcast" {
968+
} else if ["bitcast", "coopLoad"].contains(&word) {
963969
self.pop_rule_span(lexer);
964970
return self.function_call(lexer, word, span, ctx);
965971
} else {

naga/tests/in/wgsl/cooperative-matrix.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ var<storage, read_write> ext: array<f32>;
55

66
@compute @workgroup_size(8, 8, 1)
77
fn main() {
8-
var c = coopLoad(&ext[4]);
8+
var c = coopLoad<coop_mat8x8<f32, C>>(&ext[4]);
99
var d = coopMultiplyAdd(a, b, c);
1010
coopStore(d, &ext[0]);
1111
c = d;

naga/tests/out/spv/wgsl-cooperative-matrix.spvasm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ var<storage, read_write> ext: array<f32>;
1919

2020
@compute @workgroup_size(8, 8, 1)
2121
fn main() {
22-
var c = coopLoad(&ext[4]);
22+
var c = coopLoad<coop_mat8x8<f32, C>>(&ext[4]);
2323
var d = coopMultiplyAdd(a, b, c);
2424
coopStore(d, &ext[0]);
2525
c = d;
@@ -71,7 +71,7 @@ OpMemberDecorate %22 0 Offset 0
7171
%28 = OpAccessChain %27 %21 %9
7272
OpBranch %34
7373
%34 = OpLabel
74-
OpLine %3 8 23
74+
OpLine %3 8 44
7575
OpLine %3 8 13
7676
%37 = OpAccessChain %35 %28 %36
7777
%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 %8

naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ fn main() {
88
var c: coop_mat8x8<f32,C>;
99
var d: coop_mat8x8<f32,C>;
1010

11-
c = coopLoad((&ext[4]), 8u);
11+
c = coopLoad<coop_mat8x8<f32,C>>((&ext[4]), 8u);
1212
let _e6 = a;
1313
let _e8 = b;
1414
let _e9 = c;

0 commit comments

Comments
 (0)