@@ -362,19 +362,23 @@ ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings, DataType
362362
363363#define TVM_TIR_IR_BUILDER_FOR_FRAME (Method, Kind ) \
364364 ForFrame Method (PrimExpr start, PrimExpr stop, \
365- ffi::Optional<ffi::Map<ffi::String, Any>> annotations) { \
365+ ffi::Optional<ffi::Map<ffi::String, Any>> annotations, \
366+ ffi::Optional<PrimExpr> step) { \
366367 PrimExpr min = start; \
367368 PrimExpr extent = arith::Analyzer ().Simplify (stop - start); \
368369 ObjectPtr<ForFrameNode> n = ffi::make_object<ForFrameNode>(); \
369370 int bits = std::max (min.dtype ().bits (), extent.dtype ().bits ()); \
370371 n->vars = {Var (" v" , DataType (min.dtype ().code (), bits, 1 ))}; \
371372 n->doms = {Range::FromMinExtent (min, extent)}; \
373+ n->steps = {step}; \
372374 n->f_make_for_loop = [annotations](ffi::Array<Var> vars, ffi::Array<Range> doms, \
375+ ffi::Array<ffi::Optional<PrimExpr>> steps, \
373376 tvm::tir::Stmt body) { \
374377 ICHECK_EQ (vars.size (), 1 ); \
375378 ICHECK_EQ (doms.size (), 1 ); \
379+ ICHECK_EQ (steps.size (), 1 ); \
376380 return tvm::tir::For (vars[0 ], doms[0 ]->min , doms[0 ]->extent , Kind, body, std::nullopt , \
377- annotations.value_or (ffi::Map<ffi::String, Any>())); \
381+ annotations.value_or (ffi::Map<ffi::String, Any>()), steps[ 0 ]); \
378382 }; \
379383 return ForFrame (n); \
380384 }
@@ -396,13 +400,16 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread,
396400 DataType dtype = DataType (min.dtype ().code (), bits, 1 );
397401 n->vars = {Var (" v" , dtype)};
398402 n->doms = {Range::FromMinExtent (min, extent)};
403+ n->steps = {std::nullopt };
399404 n->f_make_for_loop = [annotations, thread, dtype](ffi::Array<Var> vars, ffi::Array<Range> doms,
405+ ffi::Array<ffi::Optional<PrimExpr>> steps,
400406 Stmt body) -> For {
401407 ICHECK_EQ (vars.size (), 1 );
402408 ICHECK_EQ (doms.size (), 1 );
409+ ICHECK (steps.size () == 1 && (!steps[0 ].has_value () || is_one (*steps[0 ])));
403410 IterVar iter_var (Range (nullptr ), Var (" iter" , dtype), IterVarType::kThreadIndex , thread);
404411 return For (vars[0 ], doms[0 ]->min , doms[0 ]->extent , ForKind::kThreadBinding , body, iter_var,
405- annotations.value_or (ffi::Map<ffi::String, ffi::Any>()));
412+ annotations.value_or (ffi::Map<ffi::String, ffi::Any>()), std:: nullopt );
406413 };
407414 return ForFrame (n);
408415}
@@ -412,19 +419,22 @@ ForFrame Grid(ffi::Array<PrimExpr> extents) {
412419 ObjectPtr<ForFrameNode> n = ffi::make_object<ForFrameNode>();
413420 n->vars .reserve (extents.size ());
414421 n->doms .reserve (extents.size ());
422+ n->steps .resize (extents.size ());
415423 for (const auto & extent : extents) {
416424 DataType dtype = extent.dtype ();
417425 n->vars .push_back (Var (" v" , extent.dtype ()));
418426 n->doms .push_back (Range (make_const (dtype, 0 ), extent));
419427 }
420- n->f_make_for_loop = [](ffi::Array<Var> vars, ffi::Array<Range> doms, Stmt body) -> Stmt {
428+ n->f_make_for_loop = [](ffi::Array<Var> vars, ffi::Array<Range> doms,
429+ ffi::Array<ffi::Optional<PrimExpr>> steps, Stmt body) -> Stmt {
421430 ICHECK_EQ (vars.size (), doms.size ());
431+ ICHECK_EQ (vars.size (), steps.size ());
422432 int n = vars.size ();
423433 for (int i = n - 1 ; i >= 0 ; --i) {
424434 Range dom = doms[i];
425435 Var var = vars[i];
426436 body = For (var, dom->min , dom->extent , ForKind::kSerial , std::move (body),
427- /* thread_binding=*/ std::nullopt , /* annotations=*/ {});
437+ /* thread_binding=*/ std::nullopt , /* annotations=*/ {}, /* step= */ steps[i] );
428438 }
429439 return body;
430440 };
0 commit comments