@@ -176,43 +176,33 @@ class SubToCast final : public OpRewritePattern<SubIndexOp> {
176176 }
177177};
178178
179- // Simplify polygeist.subindex to memref.subview .
180- class SubToSubView final : public OpRewritePattern<SubIndexOp> {
179+ // Simplify polygeist.subindex to a memref.reinterpret_cast .
180+ class SubToReinterpretCast final : public OpRewritePattern<SubIndexOp> {
181181public:
182182 using OpRewritePattern<SubIndexOp>::OpRewritePattern;
183183
184184 LogicalResult matchAndRewrite (SubIndexOp op,
185185 PatternRewriter &rewriter) const override {
186186 auto srcMemRefType = op.source ().getType ().cast <MemRefType>();
187187 auto resMemRefType = op.result ().getType ().cast <MemRefType>();
188- auto dims = srcMemRefType.getShape (). size ();
188+ auto shape = srcMemRefType.getShape ();
189189
190- // For now, restrict subview lowering to statically defined memref's
191- if (!srcMemRefType.hasStaticShape () | !resMemRefType.hasStaticShape ())
190+ if (!resMemRefType.hasStaticShape ())
192191 return failure ();
193192
194- // For now, restrict to simple rank-reducing indexing
195- if (srcMemRefType.getShape ().size () <= resMemRefType.getShape ().size ())
196- return failure ();
193+ int64_t innerSize = resMemRefType.getNumElements ();
194+ auto offset = rewriter.create <arith::MulIOp>(
195+ op.getLoc (), op.index (),
196+ rewriter.create <ConstantIndexOp>(op.getLoc (), innerSize));
197197
198- // Build offset, sizes and strides
199- SmallVector<OpFoldResult> sizes (dims, rewriter.getIndexAttr (0 ));
200- sizes[0 ] = op.index ();
201- SmallVector<OpFoldResult> offsets (dims);
202- for (auto dim : llvm::enumerate (srcMemRefType.getShape ())) {
203- if (dim.index () == 0 )
204- offsets[0 ] = rewriter.getIndexAttr (1 );
205- else
206- offsets[dim.index ()] = rewriter.getIndexAttr (dim.value ());
198+ llvm::SmallVector<OpFoldResult> sizes, strides;
199+ for (auto dim : shape.drop_front ()) {
200+ sizes.push_back (rewriter.getIndexAttr (dim));
201+ strides.push_back (rewriter.getIndexAttr (1 ));
207202 }
208- SmallVector<OpFoldResult> strides (dims, rewriter.getIndexAttr (1 ));
209-
210- // Generate the appropriate return type:
211- auto subMemRefType = MemRefType::get (srcMemRefType.getShape ().drop_front (),
212- srcMemRefType.getElementType ());
213203
214- rewriter.replaceOpWithNewOp <memref::SubViewOp >(
215- op, subMemRefType , op.source (), sizes, offsets , strides);
204+ rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp >(
205+ op, resMemRefType , op.source (), offset. getResult (), sizes , strides);
216206
217207 return success ();
218208 }
@@ -677,8 +667,8 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
677667 MLIRContext *context) {
678668 results.insert <CastOfSubIndex, SubIndex2, SubToCast, SimplifySubViewUsers,
679669 SimplifySubIndexUsers, SelectOfCast, SelectOfSubIndex,
680- RedundantDynSubIndex>(context);
681- // Disabled: SubToSubView
670+ RedundantDynSubIndex, SubToReinterpretCast >(context);
671+ // Disabled:
682672}
683673
684674// / Simplify pointer2memref(memref2pointer(x)) to cast(x)
0 commit comments