@@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
6262 }
6363}
6464
65+ // translate the TC def input params to corresponding Halide components.
66+ // params, inputs will be populated here
6567void translateParam (
6668 const lang::Param& p,
6769 map<string, Parameter>* params,
6870 vector<ImageParam>* inputs) {
71+ // check if the param is already converted to halide components
6972 if (params->find (p.ident ().name ()) != params->end ()) {
7073 return ;
71- } else {
72- lang::TensorType type = p.tensorType ();
73- int dimensions = (int )type.dims ().size ();
74- ImageParam imageParam (
75- translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
76- inputs->push_back (imageParam);
77- vector<Expr> dims;
78- for (auto d_ : type.dims ()) {
79- if (d_->kind () == lang::TK_IDENT) {
80- auto d = lang::Ident (d_);
81- auto it = params->find (d.name ());
82- Parameter p;
83- if (it != params->end ()) {
84- p = it->second ;
85- } else {
86- p = Parameter (Int (32 ), false , 0 , d.name (), true );
87- (*params)[d.name ()] = p;
88- }
89- dims.push_back (Variable::make (Int (32 ), p.name (), p));
74+ }
75+ lang::TensorType type = p.tensorType ();
76+ int dimensions = (int )type.dims ().size ();
77+ ImageParam imageParam (
78+ translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
79+ inputs->push_back (imageParam);
80+ vector<Expr> dims;
81+ for (auto d_ : type.dims ()) {
82+ if (d_->kind () == lang::TK_IDENT) {
83+ auto d = lang::Ident (d_);
84+ auto it = params->find (d.name ());
85+ Parameter p;
86+ if (it != params->end ()) {
87+ p = it->second ;
9088 } else {
91- CHECK (d_->kind () == lang::TK_CONST);
92- int32_t value = lang::Const (d_).value ();
93- dims.push_back (Expr (value));
89+ p = Parameter (Int (32 ), false , 0 , d.name (), true );
90+ (*params)[d.name ()] = p;
9491 }
92+ dims.push_back (Variable::make (Int (32 ), p.name (), p));
93+ } else {
94+ CHECK (d_->kind () == lang::TK_CONST);
95+ int32_t value = lang::Const (d_).value ();
96+ dims.push_back (Expr (value));
9597 }
98+ }
9699
97- for (int i = 0 ; i < imageParam.dimensions (); i++) {
98- imageParam.dim (i).set_bounds (0 , dims[i]);
99- }
100- (*params)[imageParam.name ()] = imageParam.parameter ();
100+ for (int i = 0 ; i < imageParam.dimensions (); i++) {
101+ imageParam.dim (i).set_bounds (0 , dims[i]);
101102 }
103+ (*params)[imageParam.name ()] = imageParam.parameter ();
102104}
103105
104106void translateOutput (
@@ -156,6 +158,8 @@ Expr translateExpr(
156158 return t (0 ) * t (1 );
157159 case ' /' :
158160 return t (0 ) / t (1 );
161+ case ' %' :
162+ return t (0 ) % t (1 );
159163 case lang::TK_MIN:
160164 return min (t (0 ), t (1 ));
161165 case lang::TK_MAX:
@@ -492,20 +496,25 @@ Expr reductionUpdate(Expr e) {
492496 return Call::make (e.type (), kReductionUpdate , {e}, Call::Intrinsic);
493497}
494498
499+ // translate a single TC comprehension/statement to Halide component.
500+ // funcs, bounds, reductions will be populated
495501void translateComprehension (
496- const lang::Comprehension& c ,
502+ const lang::Comprehension& comprehension ,
497503 const map<string, Parameter>& params,
498504 bool throwWarnings,
499505 map<string, Function>* funcs,
500506 FunctionBounds* bounds,
501507 vector<Function>* reductions) {
508+ // Function is the internal Halide IR type for a pipeline
509+ // stage. Func is the front-end class that wraps it. Here it's
510+ // convenient to use both. Why? what is not exposed in Func?
502511 Function f;
503- auto it = funcs->find (c .ident ().name ());
512+ auto it = funcs->find (comprehension .ident ().name ());
504513 if (it != funcs->end ()) {
505514 f = it->second ;
506515 } else {
507- f = Function (c .ident ().name ());
508- (*funcs)[c .ident ().name ()] = f;
516+ f = Function (comprehension .ident ().name ());
517+ (*funcs)[comprehension .ident ().name ()] = f;
509518 }
510519 // Function is the internal Halide IR type for a pipeline
511520 // stage. Func is the front-end class that wraps it. Here it's
@@ -514,7 +523,7 @@ void translateComprehension(
514523
515524 vector<Var> lhs;
516525 vector<Expr> lhs_as_exprs;
517- for (lang::Ident id : c .indices ()) {
526+ for (lang::Ident id : comprehension .indices ()) {
518527 lhs.push_back (Var (id.name ()));
519528 lhs_as_exprs.push_back (lhs.back ());
520529 }
@@ -523,17 +532,17 @@ void translateComprehension(
523532 // in the future we may consider using Halide Let bindings when they
524533 // are supported later
525534 map<string, Expr> lets;
526- for (auto wc : c .whereClauses ()) {
535+ for (auto wc : comprehension .whereClauses ()) {
527536 if (wc->kind () == lang::TK_LET) {
528537 auto let = lang::Let (wc);
529538 lets[let.name ().name ()] = translateExpr (let.rhs (), params, *funcs, lets);
530539 }
531540 }
532541
533- Expr rhs = translateExpr (c .rhs (), params, *funcs, lets);
542+ Expr rhs = translateExpr (comprehension .rhs (), params, *funcs, lets);
534543
535544 std::vector<Expr> all_exprs;
536- for (auto wc : c .whereClauses ()) {
545+ for (auto wc : comprehension .whereClauses ()) {
537546 if (wc->kind () == lang::TK_EXISTS) {
538547 all_exprs.push_back (
539548 translateExpr (lang::Exists (wc).exp (), params, *funcs, lets));
@@ -557,7 +566,7 @@ void translateComprehension(
557566 // values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
558567 // for the reduction and then applies the reduction.
559568 bool should_zero = false ;
560- switch (c .assignment ()->kind ()) {
569+ switch (comprehension .assignment ()->kind ()) {
561570 case lang::TK_PLUS_EQ_B:
562571 should_zero = true ; // fallthrough
563572 case lang::TK_PLUS_EQ:
@@ -589,11 +598,12 @@ void translateComprehension(
589598 case ' =' :
590599 break ;
591600 default :
592- throw lang::ErrorReport (c) << " Unimplemented reduction "
593- << c.assignment ()->range ().text () << " \n " ;
601+ throw lang::ErrorReport (comprehension)
602+ << " Unimplemented reduction "
603+ << comprehension.assignment ()->range ().text () << " \n " ;
594604 }
595605
596- if (c .assignment ()->kind () != ' =' ) {
606+ if (comprehension .assignment ()->kind () != ' =' ) {
597607 reductions->push_back (f);
598608 }
599609
@@ -633,7 +643,7 @@ void translateComprehension(
633643 Scope<Interval> solution;
634644
635645 // Put anything explicitly specified with a 'where' class in the solution
636- for (auto constraint_ : c .whereClauses ()) {
646+ for (auto constraint_ : comprehension .whereClauses ()) {
637647 if (constraint_->kind () != lang::TK_RANGE_CONSTRAINT)
638648 continue ;
639649 auto constraint = lang::RangeConstraint (constraint_);
@@ -654,7 +664,8 @@ void translateComprehension(
654664
655665 // Infer the rest
656666 all_exprs.push_back (rhs);
657- forwardBoundsInference (all_exprs, *bounds, c, throwWarnings, &solution);
667+ forwardBoundsInference (
668+ all_exprs, *bounds, comprehension, throwWarnings, &solution);
658669
659670 // TODO: What if subsequent updates have incompatible bounds
660671 // (e.g. an in-place stencil)?. The .bound directive will use the
@@ -665,7 +676,7 @@ void translateComprehension(
665676
666677 for (Var v : lhs) {
667678 if (!solution.contains (v.name ())) {
668- throw lang::ErrorReport (c )
679+ throw lang::ErrorReport (comprehension )
669680 << " Free variable " << v
670681 << " was not solved in range inference. May not be used right-hand side" ;
671682 }
@@ -689,7 +700,7 @@ void translateComprehension(
689700 for (size_t i = 0 ; i < unbound.size (); i++) {
690701 auto v = unbound[unbound.size () - 1 - i];
691702 if (!solution.contains (v->name )) {
692- throw lang::ErrorReport (c )
703+ throw lang::ErrorReport (comprehension )
693704 << " Free variable " << v << " is unconstrained. "
694705 << " Use a 'where' clause to set its range." ;
695706 }
@@ -737,6 +748,7 @@ void translateComprehension(
737748 stage.reorder (loop_nest);
738749}
739750
751+ // translate a semantically checked TC def to Halide components struct
740752HalideComponents translateDef (const lang::Def& def, bool throwWarnings) {
741753 map<string, Function> funcs;
742754 HalideComponents components;
@@ -956,6 +968,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
956968 lang::Def (lang::Sema ().checkFunction (treeRef)), throwWarnings);
957969}
958970
971+ // NOTE: there is no guarantee here that the tc string has only one def. It
972+ // could have many defs. Only first def will be converted in that case.
959973HalideComponents
960974translate (isl::ctx ctx, const std::string& tc, bool throwWarnings) {
961975 LOG_IF (INFO, tc::FLAGS_debug_halide) << tc;
0 commit comments