1717
1818#include < algorithm>
1919#include < numeric>
20+ #include < tuple>
2021#include < unordered_set>
2122
2223#include " tc/core/constants.h"
@@ -228,7 +229,20 @@ isl::space makeParamSpace(isl::ctx ctx, const SymbolTable& symbolTable) {
228229 return space;
229230}
230231
231- isl::map extractAccess (
232+ // Extract a tagged affine access relation from Halide IR.
233+ // The relation is tagged with a unique identifier, i.e. it lives in the space
234+ // [D[...] -> __tc_ref_#[]] -> A[]
235+ // where # is a unique sequential number, D is the statement identifier
236+ // extracted from "domain" and A is the tensor identifier constructed from
237+ // "tensor". "accesses" map is updated to keep track of the Halide IR nodes in
238+ // which a particular reference # appeared.
239+ // Returns the access relation and a flag indicating whether this relation is
240+ // exact or not. The relation is overapproximated (that is, not exact) if it
241+ // represents a non-affine access, for example, an access with indirection such
242+ // as O(Index(i)) = 42. In such overapproximated access relation, dimensions
243+ // that correspond to affine subscripts are still exact while those that
244+ // correspond to non-affine subscripts are not constrained.
245+ std::pair<isl::map, bool > extractAccess (
232246 isl::set domain,
233247 const IRNode* op,
234248 const std::string& tensor,
@@ -257,6 +271,7 @@ isl::map extractAccess(
257271 isl::map map =
258272 isl::map::universe (domainSpace.map_from_domain_and_range (rangeSpace));
259273
274+ bool exact = true ;
260275 for (size_t i = 0 ; i < args.size (); i++) {
261276 // Then add one equality constraint per dimension to encode the
262277 // point in the allocation actually read/written for each point in
@@ -268,47 +283,64 @@ isl::map extractAccess(
268283 isl::pw_aff (isl::local_space (rangeSpace), isl::dim_type::set, i);
269284 // ... equals the coordinate accessed as a function of the domain.
270285 auto domainPoint = halide2isl::makeIslAffFromExpr (domainSpace, args[i]);
271- if (!domainPoint.is_null ()) {
286+ if (!domainPoint) {
287+ exact = false ;
288+ } else {
272289 map = map.intersect (isl::pw_aff (domainPoint).eq_map (rangePoint));
273290 }
274291 }
275292
276- return map;
293+ return std::make_pair ( map, exact) ;
277294}
278295
279- std::pair< isl::union_map, isl::union_map>
296+ std::tuple<isl::union_map, isl::union_map, isl::union_map>
280297extractAccesses (isl::set domain, const Stmt& s, AccessMap* accesses) {
281298 class FindAccesses : public IRGraphVisitor {
282299 using IRGraphVisitor::visit;
283300
284301 void visit (const Call* op) override {
285302 IRGraphVisitor::visit (op);
286303 if (op->call_type == Call::Halide || op->call_type == Call::Image) {
287- reads = reads.unite (
288- extractAccess (domain, op, op->name , op->args , accesses));
304+ // Read relations can be safely overapproximated.
305+ isl::map read;
306+ std::tie (read, std::ignore) =
307+ extractAccess (domain, op, op->name , op->args , accesses);
308+ reads = reads.unite (read);
289309 }
290310 }
291311
292312 void visit (const Provide* op) override {
293313 IRGraphVisitor::visit (op);
294- writes =
295- writes.unite (extractAccess (domain, op, op->name , op->args , accesses));
314+
315+ // If the write access relation is not exact, we consider that any
316+ // element _may_ be written by the statement. If it is exact, then we
317+ // can guarantee that all the elements specified by the relation _must_
318+ // be written and any previously stored value will be killed.
319+ isl::map write;
320+ bool exact;
321+ std::tie (write, exact) =
322+ extractAccess (domain, op, op->name , op->args , accesses);
323+ if (exact) {
324+ mustWrites = mustWrites.unite (write);
325+ }
326+ mayWrites = mayWrites.unite (write);
296327 }
297328
298329 const isl::set& domain;
299330 AccessMap* accesses;
300331
301332 public:
302- isl::union_map reads, writes ;
333+ isl::union_map reads, mayWrites, mustWrites ;
303334
304335 FindAccesses (const isl::set& domain, AccessMap* accesses)
305336 : domain(domain),
306337 accesses (accesses),
307338 reads(isl::union_map::empty(domain.get_space())),
308- writes(isl::union_map::empty(domain.get_space())) {}
339+ mayWrites(isl::union_map::empty(domain.get_space())),
340+ mustWrites(isl::union_map::empty(domain.get_space())) {}
309341 } finder(domain, accesses);
310342 s.accept(&finder);
311- return { finder.reads , finder.writes } ;
343+ return std::make_tuple( finder.reads, finder.mayWrites, finder.mustWrites) ;
312344}
313345
314346/*
@@ -333,7 +365,8 @@ isl::schedule makeScheduleTreeHelper(
333365 isl::set set,
334366 std::vector<std::string>& outer,
335367 isl::union_map* reads,
336- isl::union_map* writes,
368+ isl::union_map* mayWrites,
369+ isl::union_map* mustWrites,
337370 AccessMap* accesses,
338371 StatementMap* statements,
339372 IteratorMap* iterators) {
@@ -379,7 +412,8 @@ isl::schedule makeScheduleTreeHelper(
379412 set,
380413 outerNext,
381414 reads,
382- writes,
415+ mayWrites,
416+ mustWrites,
383417 accesses,
384418 statements,
385419 iterators);
@@ -412,7 +446,15 @@ isl::schedule makeScheduleTreeHelper(
412446 std::vector<isl::schedule> schedules;
413447 for (Stmt s : stmts) {
414448 schedules.push_back (makeScheduleTreeHelper (
415- s, set, outer, reads, writes, accesses, statements, iterators));
449+ s,
450+ set,
451+ outer,
452+ reads,
453+ mayWrites,
454+ mustWrites,
455+ accesses,
456+ statements,
457+ iterators));
416458 }
417459 schedule = schedules[0 ].sequence (schedules[1 ]);
418460
@@ -427,23 +469,25 @@ isl::schedule makeScheduleTreeHelper(
427469 isl::set domain = set.set_tuple_id (id);
428470 schedule = isl::schedule::from_domain (domain);
429471
430- isl::union_map newReads, newWrites ;
431- std::tie (newReads, newWrites ) =
472+ isl::union_map newReads, newMayWrites, newMustWrites ;
473+ std::tie (newReads, newMayWrites, newMustWrites ) =
432474 halide2isl::extractAccesses (domain, op, accesses);
433475
434476 *reads = reads->unite (newReads);
435- *writes = writes->unite (newWrites);
477+ *mayWrites = mayWrites->unite (newMayWrites);
478+ *mustWrites = mustWrites->unite (newMustWrites);
436479
437480 } else {
438481 LOG (FATAL) << " Unhandled Halide stmt: " << s;
439482 }
440483 return schedule;
441- };
484+ }
442485
443486ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
444487 ScheduleTreeAndAccesses result;
445488
446- result.writes = result.reads = isl::union_map::empty (paramSpace);
489+ result.mayWrites = result.mustWrites = result.reads =
490+ isl::union_map::empty (paramSpace);
447491
448492 // Walk the IR building a schedule tree
449493 std::vector<std::string> outer;
@@ -452,7 +496,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
452496 isl::set::universe (paramSpace),
453497 outer,
454498 &result.reads ,
455- &result.writes ,
499+ &result.mayWrites ,
500+ &result.mustWrites ,
456501 &result.accesses ,
457502 &result.statements ,
458503 &result.iterators );
0 commit comments