Skip to content

Commit fd57ac5

Browse files
author
baoxinqi
committed
Add step attribute to ForNode (Initial codes)
1 parent 8ab96af commit fd57ac5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+382
-178
lines changed

include/tvm/script/ir_builder/tir/frame.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,13 +251,15 @@ class ForFrameNode : public TIRFrameNode {
251251
* \param loop_body The loop body
252252
* \return A stmt, the loop nest
253253
*/
254-
using FMakeForLoop =
255-
ffi::TypedFunction<tvm::tir::Stmt(ffi::Array<tvm::tir::Var> loop_vars,
256-
ffi::Array<Range> loop_extents, tvm::tir::Stmt loop_body)>;
254+
using FMakeForLoop = ffi::TypedFunction<tvm::tir::Stmt(
255+
ffi::Array<tvm::tir::Var> loop_vars, ffi::Array<Range> loop_extents,
256+
ffi::Array<ffi::Optional<PrimExpr>> loop_steps, tvm::tir::Stmt loop_body)>;
257257
/*! \brief The loop variables. */
258258
ffi::Array<tvm::tir::Var> vars;
259259
/*! \brief The domains of iteration. */
260260
ffi::Array<Range> doms;
261+
/*! \brief The optional steps of iteration. */
262+
ffi::Array<ffi::Optional<PrimExpr>> steps;
261263
/*! \brief The for loop generating function. */
262264
FMakeForLoop f_make_for_loop;
263265

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,37 +228,45 @@ ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings,
228228
* \param start The minimum value of iteration.
229229
* \param stop The maximum value of iteration.
230230
* \param annotations The optional annotations of the For statement.
231+
* \param step The optional step value of iteration.
231232
* \return The ForFrame.
232233
*/
233234
ForFrame Serial(PrimExpr start, PrimExpr stop,
234-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
235+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
236+
ffi::Optional<PrimExpr> step = std::nullopt);
235237
/*!
236238
* \brief The parallel For statement.
237239
* \param start The minimum value of iteration.
238240
* \param stop The maximum value of iteration.
239241
* \param annotations The optional annotations of the For statement.
242+
* \param step The optional step value of iteration.
240243
* \return The ForFrame.
241244
*/
242245
ForFrame Parallel(PrimExpr start, PrimExpr stop,
243-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
246+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
247+
ffi::Optional<PrimExpr> step = std::nullopt);
244248
/*!
245249
* \brief The vectorized For statement.
246250
* \param start The minimum value of iteration.
247251
* \param stop The maximum value of iteration.
248252
* \param annotations The optional annotations of the For statement.
253+
* \param step The optional step value of iteration.
249254
* \return The ForFrame.
250255
*/
251256
ForFrame Vectorized(PrimExpr start, PrimExpr stop,
252-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
257+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
258+
ffi::Optional<PrimExpr> step = std::nullopt);
253259
/*!
254260
* \brief The unrolled For statement.
255261
* \param start The minimum value of iteration.
256262
* \param stop The maximum value of iteration.
257263
* \param annotations The optional annotations of the For statement.
264+
* \param step The optional step value of iteration.
258265
* \return The ForFrame.
259266
*/
260267
ForFrame Unroll(PrimExpr start, PrimExpr stop,
261-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
268+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
269+
ffi::Optional<PrimExpr> step = std::nullopt);
262270
/*!
263271
* \brief The thread-binding For statement.
264272
* \param start The minimum value of iteration.

include/tvm/tir/stmt.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,7 @@ enum class ForKind : int {
717717
*
718718
* \code
719719
*
720-
* for (loop_var = min; loop_var < min + extent; ++loop_var) {
720+
* for (loop_var = min; loop_var < min + extent; loop_var += step) {
721721
* // body
722722
* }
723723
* \endcode
@@ -748,6 +748,10 @@ class ForNode : public StmtNode {
748748
* and can be ignored in most passes.
749749
*/
750750
ffi::Map<ffi::String, ffi::Any> annotations;
751+
/*!
752+
* \brief The loop step. It is one if not specified.
753+
*/
754+
ffi::Optional<PrimExpr> step;
751755

752756
static void RegisterReflection() {
753757
namespace refl = tvm::ffi::reflection;
@@ -758,8 +762,13 @@ class ForNode : public StmtNode {
758762
.def_ro("kind", &ForNode::kind)
759763
.def_ro("body", &ForNode::body)
760764
.def_ro("thread_binding", &ForNode::thread_binding)
761-
.def_ro("annotations", &ForNode::annotations);
765+
.def_ro("annotations", &ForNode::annotations)
766+
.def_ro("step", &ForNode::step);
762767
}
768+
769+
/*! \brief Check whether this loop has no nontrivial loop step. */
770+
bool HasTrivialStep() const;
771+
763772
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode);
764773
};
765774

@@ -770,9 +779,11 @@ class ForNode : public StmtNode {
770779
class For : public Stmt {
771780
public:
772781
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
773-
ffi::Optional<IterVar> thread_binding = std::nullopt,
774-
ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
775-
Span span = Span());
782+
ffi::Optional<IterVar> thread_binding, ffi::Map<ffi::String, ffi::Any> annotations,
783+
ffi::Optional<PrimExpr> step, Span span = Span());
784+
785+
TVM_DLL static For ForSimple(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind,
786+
Stmt body);
776787

777788
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode);
778789
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,11 @@ def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[L
677677

678678

679679
def serial(
680-
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
680+
start: PrimExpr,
681+
stop: PrimExpr = None,
682+
*,
683+
annotations: Dict[str, Any] = None,
684+
step: Optional[PrimExpr] = None,
681685
) -> frame.ForFrame:
682686
"""The serial For statement.
683687
@@ -692,6 +696,9 @@ def serial(
692696
annotations : Dict[str, Any]
693697
The optional annotations of the For statement.
694698
699+
step : PrimExpr
700+
The optional step value of iteration.
701+
695702
Returns
696703
-------
697704
res : frame.ForFrame
@@ -703,11 +710,15 @@ def serial(
703710
start = IntImm(start.dtype, 0)
704711
else:
705712
start = 0
706-
return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
713+
return _ffi_api.Serial(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
707714

708715

709716
def parallel(
710-
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
717+
start: PrimExpr,
718+
stop: PrimExpr = None,
719+
*,
720+
annotations: Dict[str, Any] = None,
721+
step: Optional[PrimExpr] = None,
711722
) -> frame.ForFrame:
712723
"""The parallel For statement.
713724
@@ -722,6 +733,9 @@ def parallel(
722733
annotations : Dict[str, Any]
723734
The optional annotations of the For statement.
724735
736+
step : PrimExpr
737+
The optional step value of iteration.
738+
725739
Returns
726740
-------
727741
res : frame.ForFrame
@@ -733,11 +747,15 @@ def parallel(
733747
start = IntImm(start.dtype, 0)
734748
else:
735749
start = 0
736-
return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
750+
return _ffi_api.Parallel(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
737751

738752

739753
def vectorized(
740-
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
754+
start: PrimExpr,
755+
stop: PrimExpr = None,
756+
*,
757+
annotations: Dict[str, Any] = None,
758+
step: Optional[PrimExpr] = None,
741759
) -> frame.ForFrame:
742760
"""The vectorized For statement.
743761
@@ -752,6 +770,9 @@ def vectorized(
752770
annotations : Dict[str, Any]
753771
The optional annotations of the For statement.
754772
773+
step : PrimExpr
774+
The optional step value of iteration.
775+
755776
Returns
756777
-------
757778
res : frame.ForFrame
@@ -763,11 +784,15 @@ def vectorized(
763784
start = IntImm(start.dtype, 0)
764785
else:
765786
start = 0
766-
return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
787+
return _ffi_api.Vectorized(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
767788

768789

769790
def unroll(
770-
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
791+
start: PrimExpr,
792+
stop: PrimExpr = None,
793+
*,
794+
annotations: Dict[str, Any] = None,
795+
step: Optional[PrimExpr] = None,
771796
) -> frame.ForFrame:
772797
"""The unrolled For statement.
773798
@@ -782,6 +807,9 @@ def unroll(
782807
annotations : Dict[str, Any]
783808
The optional annotations of the For statement.
784809
810+
step : PrimExpr
811+
The optional step value of iteration.
812+
785813
Returns
786814
-------
787815
res : frame.ForFrame
@@ -793,7 +821,7 @@ def unroll(
793821
start = IntImm(start.dtype, 0)
794822
else:
795823
start = 0
796-
return _ffi_api.Unroll(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
824+
return _ffi_api.Unroll(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
797825

798826

799827
def thread_binding(

python/tvm/script/parser/tir/parser.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import contextlib
2020
from functools import partial
21-
from typing import Any
21+
from typing import Any, Dict, Optional
2222

2323
import tvm
2424
from tvm.ir import GlobalVar, PrimType
@@ -168,6 +168,17 @@ def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: b
168168
return default
169169

170170

171+
def range_sugar(
172+
start: PrimExpr,
173+
stop: PrimExpr = None,
174+
step: Optional[PrimExpr] = None,
175+
*,
176+
annotations: Dict[str, Any] = None,
177+
) -> T.frame.ForFrame:
178+
"""The sugar for python range builtin."""
179+
return T.serial(start, stop, annotations=annotations, step=step)
180+
181+
171182
@dispatch.register(token="tir", type_name="For")
172183
def visit_for(self: Parser, node: doc.For) -> None:
173184
"""The for visiting method for tir.
@@ -379,7 +390,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
379390
privacy = find_decorator_annotation(node, "private", default=False)
380391
self.function_annotations = None
381392
with self.var_table.with_frame():
382-
self.var_table.add("range", T.serial)
393+
394+
self.var_table.add("range", range_sugar)
383395
with T.prim_func(is_private=privacy):
384396
T.func_name(node.name)
385397
if node.returns is not None:

python/tvm/tir/stmt.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ class For(Stmt):
145145
The thread this loop binds to. Only valid
146146
if kind is ThreadBinding
147147
148+
step : PrimExpr
149+
The loop step. Defaults to None, which
150+
represents one.
151+
148152
annotations: Optional[Mapping[str, Object]]
149153
Additional annotation hints.
150154
@@ -159,6 +163,7 @@ class For(Stmt):
159163
body: Stmt
160164
thread_binding: Optional[IterVar]
161165
annotations: Mapping[str, Object]
166+
step: Optional[PrimExpr]
162167
span: Optional[Span]
163168

164169
def __init__(
@@ -170,6 +175,7 @@ def __init__(
170175
body: Stmt,
171176
thread_binding: Optional[IterVar] = None,
172177
annotations: Optional[Mapping[str, Object]] = None,
178+
step: Optional[PrimExpr] = None,
173179
span: Optional[Span] = None,
174180
) -> None:
175181
self.__init_handle_by_constructor__(
@@ -181,6 +187,7 @@ def __init__(
181187
body,
182188
thread_binding,
183189
annotations,
190+
step,
184191
span,
185192
)
186193

src/relax/distributed/transform/lower_global_view_to_local_view.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ class DistributedBufferCompactor : StmtExprMutator {
330330
if (shard > 1) {
331331
arith::Analyzer analyzer;
332332
ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0));
333-
return For(new_loop->loop_var, new_loop->min, floordiv(new_loop->extent, shard),
334-
new_loop->kind, new_loop->body, new_loop->thread_binding, new_loop->annotations);
333+
new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
334+
return new_loop;
335335
}
336336
}
337337
return new_loop;

src/script/ir_builder/tir/frame.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void BlockInitFrameNode::ExitWithScope() {
123123

124124
void ForFrameNode::ExitWithScope() {
125125
TIRFrameNode::ExitWithScope();
126-
AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
126+
AddToParent(this->f_make_for_loop(vars, doms, steps, AsStmt(stmts)));
127127
}
128128

129129
void AssertFrameNode::ExitWithScope() {

src/script/ir_builder/tir/ir.cc

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,19 +362,23 @@ ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings, DataType
362362

363363
#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \
364364
ForFrame Method(PrimExpr start, PrimExpr stop, \
365-
ffi::Optional<ffi::Map<ffi::String, Any>> annotations) { \
365+
ffi::Optional<ffi::Map<ffi::String, Any>> annotations, \
366+
ffi::Optional<PrimExpr> step) { \
366367
PrimExpr min = start; \
367368
PrimExpr extent = arith::Analyzer().Simplify(stop - start); \
368369
ObjectPtr<ForFrameNode> n = ffi::make_object<ForFrameNode>(); \
369370
int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
370371
n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \
371372
n->doms = {Range::FromMinExtent(min, extent)}; \
373+
n->steps = {step}; \
372374
n->f_make_for_loop = [annotations](ffi::Array<Var> vars, ffi::Array<Range> doms, \
375+
ffi::Array<ffi::Optional<PrimExpr>> steps, \
373376
tvm::tir::Stmt body) { \
374377
ICHECK_EQ(vars.size(), 1); \
375378
ICHECK_EQ(doms.size(), 1); \
379+
ICHECK_EQ(steps.size(), 1); \
376380
return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \
377-
annotations.value_or(ffi::Map<ffi::String, Any>())); \
381+
annotations.value_or(ffi::Map<ffi::String, Any>()), steps[0]); \
378382
}; \
379383
return ForFrame(n); \
380384
}
@@ -396,13 +400,16 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread,
396400
DataType dtype = DataType(min.dtype().code(), bits, 1);
397401
n->vars = {Var("v", dtype)};
398402
n->doms = {Range::FromMinExtent(min, extent)};
403+
n->steps = {std::nullopt};
399404
n->f_make_for_loop = [annotations, thread, dtype](ffi::Array<Var> vars, ffi::Array<Range> doms,
405+
ffi::Array<ffi::Optional<PrimExpr>> steps,
400406
Stmt body) -> For {
401407
ICHECK_EQ(vars.size(), 1);
402408
ICHECK_EQ(doms.size(), 1);
409+
ICHECK(steps.size() == 1 && (!steps[0].has_value() || is_one(*steps[0])));
403410
IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread);
404411
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
405-
annotations.value_or(ffi::Map<ffi::String, ffi::Any>()));
412+
annotations.value_or(ffi::Map<ffi::String, ffi::Any>()), std::nullopt);
406413
};
407414
return ForFrame(n);
408415
}
@@ -412,19 +419,22 @@ ForFrame Grid(ffi::Array<PrimExpr> extents) {
412419
ObjectPtr<ForFrameNode> n = ffi::make_object<ForFrameNode>();
413420
n->vars.reserve(extents.size());
414421
n->doms.reserve(extents.size());
422+
n->steps.resize(extents.size());
415423
for (const auto& extent : extents) {
416424
DataType dtype = extent.dtype();
417425
n->vars.push_back(Var("v", extent.dtype()));
418426
n->doms.push_back(Range(make_const(dtype, 0), extent));
419427
}
420-
n->f_make_for_loop = [](ffi::Array<Var> vars, ffi::Array<Range> doms, Stmt body) -> Stmt {
428+
n->f_make_for_loop = [](ffi::Array<Var> vars, ffi::Array<Range> doms,
429+
ffi::Array<ffi::Optional<PrimExpr>> steps, Stmt body) -> Stmt {
421430
ICHECK_EQ(vars.size(), doms.size());
431+
ICHECK_EQ(vars.size(), steps.size());
422432
int n = vars.size();
423433
for (int i = n - 1; i >= 0; --i) {
424434
Range dom = doms[i];
425435
Var var = vars[i];
426436
body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body),
427-
/*thread_binding=*/std::nullopt, /*annotations=*/{});
437+
/*thread_binding=*/std::nullopt, /*annotations=*/{}, /*step=*/steps[i]);
428438
}
429439
return body;
430440
};

0 commit comments

Comments
 (0)