Skip to content

Commit d18bc38

Browse files
committed
coop: make stride non-optional
1 parent 6813768 commit d18bc38

File tree

17 files changed

+105
-120
lines changed

17 files changed

+105
-120
lines changed

naga/src/back/dot/mod.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,7 @@ impl StatementGraph {
431431
} => {
432432
self.dependencies.push((id, target, "target"));
433433
self.dependencies.push((id, pointer, "pointer"));
434-
if let Some(stride) = stride {
435-
self.dependencies.push((id, stride, "stride"));
436-
}
434+
self.dependencies.push((id, stride, "stride"));
437435
if store {
438436
"Store"
439437
} else {

naga/src/back/msl/writer.rs

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4254,30 +4254,8 @@ impl<W: Write> Writer<W> {
42544254
self.put_expression(target, &context.expression, true)?;
42554255
write!(self.out, ", ")?;
42564256
self.put_expression(pointer, &context.expression, true)?;
4257-
if stride.is_some() || row_major {
4258-
write!(self.out, ", ")?;
4259-
match stride {
4260-
Some(expression) => {
4261-
self.put_expression(expression, &context.expression, true)?;
4262-
}
4263-
None => {
4264-
let default_stride = match *context.expression.resolve_type(target)
4265-
{
4266-
crate::TypeInner::CooperativeMatrix {
4267-
columns, rows, ..
4268-
} => {
4269-
if row_major {
4270-
columns as u32
4271-
} else {
4272-
rows as u32
4273-
}
4274-
}
4275-
_ => 0,
4276-
};
4277-
write!(self.out, "{default_stride}")?;
4278-
}
4279-
}
4280-
}
4257+
write!(self.out, ", ")?;
4258+
self.put_expression(stride, &context.expression, true)?;
42814259
if row_major {
42824260
let matrix_origin = "0";
42834261
let transpose = true;

naga/src/back/pipeline_constants.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -898,9 +898,7 @@ fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut S
898898
} => {
899899
adjust(target);
900900
adjust(pointer);
901-
if let Some(ref mut stride) = *stride {
902-
adjust(stride);
903-
}
901+
adjust(stride);
904902
}
905903
Statement::Break
906904
| Statement::Continue

naga/src/back/spv/block.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3711,13 +3711,12 @@ impl BlockContext<'_> {
37113711
spirv::CooperativeMatrixLayout::ColumnMajorKHR
37123712
};
37133713
let layout_id = self.get_index_constant(layout as u32);
3714-
let stride_id = stride.map(|exp| self.cached[exp]);
37153714
if store {
37163715
block.body.push(Instruction::coop_store(
37173716
self.cached[target],
37183717
pointer_id,
37193718
layout_id,
3720-
stride_id,
3719+
self.cached[stride],
37213720
));
37223721
} else {
37233722
let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty);
@@ -3727,7 +3726,7 @@ impl BlockContext<'_> {
37273726
id,
37283727
pointer_id,
37293728
layout_id,
3730-
stride_id,
3729+
self.cached[stride],
37313730
));
37323731
block
37333732
.body

naga/src/back/spv/instructions.rs

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,33 +1252,22 @@ impl super::Instruction {
12521252
id: Word,
12531253
pointer_id: Word,
12541254
layout_id: Word,
1255-
stride_id: Option<Word>,
1255+
stride_id: Word,
12561256
) -> Self {
12571257
let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR);
12581258
instruction.set_type(result_type_id);
12591259
instruction.set_result(id);
12601260
instruction.add_operand(pointer_id);
12611261
instruction.add_operand(layout_id);
1262-
if let Some(stride_id) = stride_id {
1263-
instruction.add_operand(stride_id);
1264-
}
1265-
1262+
instruction.add_operand(stride_id);
12661263
instruction
12671264
}
1268-
pub(super) fn coop_store(
1269-
id: Word,
1270-
pointer_id: Word,
1271-
layout_id: Word,
1272-
stride_id: Option<Word>,
1273-
) -> Self {
1265+
pub(super) fn coop_store(id: Word, pointer_id: Word, layout_id: Word, stride_id: Word) -> Self {
12741266
let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR);
12751267
instruction.add_operand(pointer_id);
12761268
instruction.add_operand(id);
12771269
instruction.add_operand(layout_id);
1278-
if let Some(stride_id) = stride_id {
1279-
instruction.add_operand(stride_id);
1280-
}
1281-
1270+
instruction.add_operand(stride_id);
12821271
instruction
12831272
}
12841273
pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self {

naga/src/back/wgsl/writer.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -998,10 +998,8 @@ impl<W: Write> Writer<W> {
998998
self.write_expr(module, target, func_ctx)?;
999999
write!(self.out, ", ")?;
10001000
self.write_expr(module, pointer, func_ctx)?;
1001-
if let Some(stride) = stride {
1002-
write!(self.out, ", ")?;
1003-
self.write_expr(module, stride, func_ctx)?;
1004-
}
1001+
write!(self.out, ", ")?;
1002+
self.write_expr(module, stride, func_ctx)?;
10051003
write!(self.out, ")")?
10061004
}
10071005
}

naga/src/compact/statements.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,7 @@ impl FunctionTracer<'_> {
175175
} => {
176176
self.expressions_used.insert(target);
177177
self.expressions_used.insert(pointer);
178-
if let Some(stride) = stride {
179-
self.expressions_used.insert(stride);
180-
}
178+
self.expressions_used.insert(stride);
181179
}
182180

183181
// Trivial statements.
@@ -427,9 +425,7 @@ impl FunctionMap {
427425
} => {
428426
adjust(target);
429427
adjust(pointer);
430-
if let Some(ref mut stride) = *stride {
431-
adjust(stride);
432-
}
428+
adjust(stride);
433429
}
434430

435431
// Trivial statements.

naga/src/front/wgsl/lower/mod.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3138,19 +3138,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
31383138
return Ok(Some(result));
31393139
}
31403140
"coopLoad" | "coopLoadT" | "coopStore" | "coopStoreT" => {
3141+
let store = function.name.contains("Store");
3142+
let row_major = function.name.ends_with("T");
3143+
31413144
let mut args = ctx.prepare_args(arguments, 2, span);
31423145
let target = self.expression(args.next()?, ctx)?;
31433146
let pointer = self.expression(args.next()?, ctx)?;
31443147
let stride = if args.total_args > 2 {
3145-
Some(self.expression(args.next()?, ctx)?)
3148+
self.expression(args.next()?, ctx)?
31463149
} else {
3147-
None
3150+
// Infer the stride from the matrix type
3151+
let stride = match *resolve_inner!(ctx, target) {
3152+
ir::TypeInner::CooperativeMatrix { columns, rows, .. } => {
3153+
if row_major {
3154+
columns as u32
3155+
} else {
3156+
rows as u32
3157+
}
3158+
}
3159+
_ => 0,
3160+
};
3161+
ctx.append_expression(
3162+
ir::Expression::Literal(ir::Literal::U32(stride)),
3163+
Span::UNDEFINED,
3164+
)?
31483165
};
31493166
args.finish()?;
31503167

3151-
let store = function.name.contains("Store");
3152-
let row_major = function.name.ends_with("T");
3153-
31543168
let rctx = ctx.runtime_expression_ctx(span)?;
31553169
rctx.block.push(
31563170
crate::Statement::CooperativeLoadStore {

naga/src/ir/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2288,7 +2288,7 @@ pub enum Statement {
22882288
store: bool,
22892289
target: Handle<Expression>,
22902290
pointer: Handle<Expression>,
2291-
stride: Option<Handle<Expression>>,
2291+
stride: Handle<Expression>,
22922292
row_major: bool,
22932293
},
22942294
}

naga/src/valid/analyzer.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,18 +1234,16 @@ impl FunctionInfo {
12341234
pointer,
12351235
stride,
12361236
row_major: _,
1237-
} => {
1238-
if let Some(stride) = stride {
1239-
let _ = self.add_ref(stride);
1240-
}
1241-
FunctionUniformity {
1242-
result: Uniformity {
1243-
non_uniform_result: self.add_ref(target).or(self.add_ref(pointer)),
1244-
requirements: UniformityRequirements::COOP_OPS,
1245-
},
1246-
exit: ExitFlags::empty(),
1247-
}
1248-
}
1237+
} => FunctionUniformity {
1238+
result: Uniformity {
1239+
non_uniform_result: self
1240+
.add_ref(target)
1241+
.or(self.add_ref(pointer))
1242+
.or(self.add_ref(stride)),
1243+
requirements: UniformityRequirements::COOP_OPS,
1244+
},
1245+
exit: ExitFlags::empty(),
1246+
},
12491247
};
12501248

12511249
disruptor = disruptor.or(uniformity.exit_disruptor());

0 commit comments

Comments
 (0)