@@ -216,7 +216,7 @@ Expr translateExpr(
216216 }
217217}
218218
219- vector<const Variable*> unboundVariables (const vector<Var >& lhs, Expr rhs) {
219+ vector<const Variable*> unboundVariables (const vector<Expr >& lhs, Expr rhs) {
220220 class FindUnboundVariables : public IRVisitor {
221221 using IRVisitor::visit;
222222
@@ -241,14 +241,19 @@ vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
241241 set<string> visited;
242242
243243 public:
244- FindUnboundVariables (const vector<Var>& lhs) {
245- for (auto v : lhs) {
246- bound.push (v.name ());
244+ FindUnboundVariables (const vector<Expr>& lhs) {
245+ for (auto e : lhs) {
246+ if (const Variable *v = e.as <Variable>()) {
247+ bound.push (v->name );
248+ }
247249 }
248250 }
249251 vector<const Variable*> result;
250252 } finder (lhs);
251253 rhs.accept (&finder);
254+ for (auto e : lhs) {
255+ e.accept (&finder);
256+ }
252257 return finder.result ;
253258}
254259
@@ -507,22 +512,31 @@ void translateComprehension(
507512 f = Function (c.ident ().name ());
508513 (*funcs)[c.ident ().name ()] = f;
509514 }
515+
516+ // we currently inline all of the let bindings generated in where clauses
517+ // in the future we may consider using Halide Let bindings when they
518+ // are supported later
519+ map<string, Expr> lets;
520+
510521 // Function is the internal Halide IR type for a pipeline
511522 // stage. Func is the front-end class that wraps it. Here it's
512523 // convenient to use both.
513524 Func func (f);
514525
515- vector<Var> lhs;
516- vector<Expr> lhs_as_exprs;
517- for (lang::Ident id : c.indices ()) {
518- lhs.push_back (Var (id.name ()));
519- lhs_as_exprs.push_back (lhs.back ());
526+ vector<Expr> lhs;
527+ vector<Var> lhs_vars;
528+ bool total_definition = true ;
529+ for (lang::TreeRef idx : c.indices ()) {
530+ Expr e = translateExpr (idx, params, *funcs, lets);
531+ if (const Variable *op = e.as <Variable>()) {
532+ lhs_vars.push_back (Var (op->name ));
533+ } else {
534+ total_definition = false ;
535+ lhs_vars.push_back (Var ());
536+ }
537+ lhs.push_back (e);
520538 }
521539
522- // we currently inline all of the let bindings generated in where clauses
523- // in the future we may consider using Halide Let bindings when they
524- // are supported later
525- map<string, Expr> lets;
526540 for (auto wc : c.whereClauses ()) {
527541 if (wc->kind () == lang::TK_LET) {
528542 auto let = lang::Let (wc);
@@ -546,9 +560,8 @@ void translateComprehension(
546560 auto setupIdentity = [&](const Expr& identity, bool zero) {
547561 if (!f.has_pure_definition ()) {
548562 added_implicit_initialization = true ;
549- func (lhs) = (zero) ? identity
550- : undef (rhs.type ()); // undef causes the original value
551- // to remain in input arrays
563+ // undef causes the original value to remain in input arrays
564+ func (lhs_vars) = (zero) ? identity : undef (rhs.type ());
552565 }
553566 };
554567
@@ -587,6 +600,9 @@ void translateComprehension(
587600 break ;
588601
589602 case ' =' :
603+ if (!total_definition) {
604+ setupIdentity (rhs, false );
605+ }
590606 break ;
591607 default :
592608 throw lang::ErrorReport (c) << " Unimplemented reduction "
@@ -618,9 +634,10 @@ void translateComprehension(
618634 for (auto & exp : all_exprs) {
619635 exp = bindParams.mutate (exp);
620636 }
621-
622- // TODO: When the LHS incorporates general expressions we'll need to
623- // bind params there too.
637+ for (auto &e : lhs) {
638+ e = bindParams.mutate (e);
639+ all_exprs.push_back (e);
640+ }
624641
625642 // Do forward bounds inference -- construct an expression that says
626643 // this expression never reads out of bounds on its inputs, and
@@ -660,19 +677,34 @@ void translateComprehension(
660677 // (e.g. an in-place stencil)?. The .bound directive will use the
661678 // bounds of the last stage for all stages.
662679
663- // Does a tensor have a single bound, or can its bounds shrink over
664- // time? Solve for a single bound for now.
665-
666- for (Var v : lhs) {
667- if (!solution.contains (v.name ())) {
668- throw lang::ErrorReport (c)
680+ // Set the bounds to be the union of the boxes written to by every
681+ // comprehension touching the tensor.
682+ for (size_t i = 0 ; i < lhs.size (); i++) {
683+ Expr e = lhs[i];
684+ if (const Variable *v = e.as <Variable>()) {
685+ if (!solution.contains (v->name )) {
686+ throw lang::ErrorReport (c)
669687 << " Free variable " << v
670688 << " was not solved in range inference. May not be used right-hand side" ;
689+ }
690+ }
691+
692+ Interval in = bounds_of_expr_in_scope (e, solution);
693+ if (!in.is_bounded ()) {
694+ throw lang::ErrorReport (c.indices ()[i])
695+ << " Left-hand side expression is unbounded" ;
696+ }
697+ in.min = cast<int >(in.min );
698+ in.max = cast<int >(in.max );
699+
700+ map<string, Interval> &b = (*bounds)[f];
701+ string dim_name = f.dimensions () ? f.args ()[i] : lhs_vars[i].name ();
702+ auto old = b.find (dim_name);
703+ if (old != b.end ()) {
704+ // Take the union with any existing bounds
705+ in.include (old->second );
671706 }
672- // TODO: We're enforcing a single bound across all comprehensions
673- // for now. We should really check later ones are equal to earlier
674- // ones instead of just clobbering.
675- (*bounds)[f][v.name ()] = solution.get (v.name ());
707+ b[dim_name] = in;
676708 }
677709
678710 // Free variables that appear on the rhs but not the lhs are
@@ -703,6 +735,9 @@ void translateComprehension(
703735 for (auto v : unbound) {
704736 Expr rv = Variable::make (Int (32 ), v->name , domain);
705737 rhs = substitute (v->name , rv, rhs);
738+ for (Expr &e : lhs) {
739+ e = substitute (v->name , rv, e);
740+ }
706741 }
707742 rdom = RDom (domain);
708743 }
@@ -718,9 +753,12 @@ void translateComprehension(
718753 }
719754 }
720755 while (!lhs.empty ()) {
721- loop_nest.push_back (lhs.back ());
756+ if (const Variable *v = lhs.back ().as <Variable>()) {
757+ loop_nest.push_back (Var (v->name ));
758+ }
722759 lhs.pop_back ();
723760 }
761+ stage.reorder (loop_nest);
724762
725763 if (added_implicit_initialization) {
726764 // Also reorder reduction initializations to the TC convention
@@ -734,7 +772,6 @@ void translateComprehension(
734772 }
735773
736774 func.compute_root ();
737- stage.reorder (loop_nest);
738775}
739776
740777HalideComponents translateDef (const lang::Def& def, bool throwWarnings) {
0 commit comments