diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp index 671c7359..cbb9c916 100644 --- a/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -1217,10 +1217,6 @@ LogicalResult PtrAnalysis::rewriteAddptrOp(triton::AddPtrOp op) { auto maketptrOp = state.createTTSMakeTensorPtrOp(builder, op.getLoc()); ptrMap.map(op.getResult(), maketptrOp.getResult()); } else if (enableMakeGatherScatterTensorPtr) { - // If there is only one dimension, return failure since there are no - // continuous dimensions. - if (state.getRank() == 1) - return failure(); PtrState unstructuredState; // Switch to unstructured state analysis to create offsets and strides // for the non-structured dimension. diff --git a/test/Conversion/StructuredToMemref/gather_scatter_ptr_in_loop_to_linalg.mlir b/test/Conversion/StructuredToMemref/gather_scatter_ptr_in_loop_to_linalg.mlir index f3f1fee9..af41515e 100644 --- a/test/Conversion/StructuredToMemref/gather_scatter_ptr_in_loop_to_linalg.mlir +++ b/test/Conversion/StructuredToMemref/gather_scatter_ptr_in_loop_to_linalg.mlir @@ -56,205 +56,198 @@ // CHECK: %[[VAL_44:.*]] = arith.index_cast %[[VAL_43]] : i32 to index // CHECK: %[[VAL_45:.*]] = arith.muli %[[VAL_3]], %[[VAL_4]] : i32 // CHECK: %[[VAL_46:.*]] = linalg.fill ins(%[[VAL_45]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_47:.*]] = tensor.empty() : tensor<64xi1> -// CHECK: %[[VAL_48:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_38]], %[[VAL_46]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_47]] : tensor<64xi1>) { -// CHECK: ^bb0(%[[VAL_49:.*]]: i32, %[[VAL_50:.*]]: i32, %[[VAL_51:.*]]: i1): -// CHECK: %[[VAL_52:.*]] = arith.cmpi slt, %[[VAL_49]], %[[VAL_50]] : i32 -// CHECK: linalg.yield %[[VAL_52]] : i1 -// CHECK: } -> tensor<64xi1> -// CHECK: %[[VAL_53:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_38]], %[[VAL_46]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_38]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_54:.*]]: i32, %[[VAL_55:.*]]: i32, %[[VAL_56:.*]]: i32): -// CHECK: %[[VAL_57:.*]] = arith.divsi %[[VAL_54]], %[[VAL_55]] : i32 -// CHECK: linalg.yield %[[VAL_57]] : i32 +// CHECK: %[[VAL_47:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_38]], %[[VAL_46]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_38]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_48:.*]]: i32, %[[VAL_49:.*]]: i32, %[[VAL_50:.*]]: i32): +// CHECK: %[[VAL_51:.*]] = arith.divsi %[[VAL_48]], %[[VAL_49]] : i32 +// CHECK: linalg.yield %[[VAL_51]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_58:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_38]], %[[VAL_46]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_38]] : tensor<64xi32>) { +// CHECK: %[[VAL_52:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_38]], %[[VAL_46]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_38]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_53:.*]]: i32, %[[VAL_54:.*]]: i32, %[[VAL_55:.*]]: i32): +// CHECK: %[[VAL_56:.*]] = arith.remsi %[[VAL_53]], %[[VAL_54]] : i32 +// CHECK: linalg.yield %[[VAL_56]] : i32 +// CHECK: } -> tensor<64xi32> +// CHECK: %[[VAL_57:.*]] = linalg.fill ins(%[[VAL_4]] : 
i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_58:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_52]], %[[VAL_57]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_52]] : tensor<64xi32>) { // CHECK: ^bb0(%[[VAL_59:.*]]: i32, %[[VAL_60:.*]]: i32, %[[VAL_61:.*]]: i32): -// CHECK: %[[VAL_62:.*]] = arith.remsi %[[VAL_59]], %[[VAL_60]] : i32 +// CHECK: %[[VAL_62:.*]] = arith.divsi %[[VAL_59]], %[[VAL_60]] : i32 // CHECK: linalg.yield %[[VAL_62]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_63:.*]] = linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_64:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_58]], %[[VAL_63]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_58]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_65:.*]]: i32, %[[VAL_66:.*]]: i32, %[[VAL_67:.*]]: i32): -// CHECK: %[[VAL_68:.*]] = arith.divsi %[[VAL_65]], %[[VAL_66]] : i32 -// CHECK: linalg.yield %[[VAL_68]] : i32 -// CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_69:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_58]], %[[VAL_63]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_58]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_70:.*]]: i32, %[[VAL_71:.*]]: i32, %[[VAL_72:.*]]: i32): -// CHECK: %[[VAL_73:.*]] = arith.remsi %[[VAL_70]], %[[VAL_71]] : i32 -// CHECK: linalg.yield %[[VAL_73]] : i32 +// CHECK: %[[VAL_63:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_52]], %[[VAL_57]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_52]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_64:.*]]: i32, %[[VAL_65:.*]]: i32, %[[VAL_66:.*]]: i32): +// CHECK: %[[VAL_67:.*]] = arith.remsi %[[VAL_64]], %[[VAL_65]] : i32 +// CHECK: linalg.yield %[[VAL_67]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_74:.*]] = scf.for %[[VAL_75:.*]] = %[[VAL_27]] to %[[VAL_6]] step %[[VAL_24]] iter_args(%[[VAL_76:.*]] = %[[VAL_30]]) -> (tensor<64x64xf32>) : i32 { -// CHECK: %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_78:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_53]], %[[VAL_77]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_53]] : tensor<64xi32>) { +// CHECK: %[[VAL_68:.*]] = scf.for %[[VAL_69:.*]] = %[[VAL_27]] to %[[VAL_6]] step %[[VAL_24]] iter_args(%[[VAL_70:.*]] = %[[VAL_30]]) -> (tensor<64x64xf32>) : i32 { +// CHECK: %[[VAL_71:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_72:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_47]], %[[VAL_71]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_47]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_73:.*]]: i32, %[[VAL_74:.*]]: i32, %[[VAL_75:.*]]: i32): +// CHECK: %[[VAL_76:.*]] = arith.muli %[[VAL_73]], %[[VAL_74]] : i32 +// CHECK: linalg.yield %[[VAL_76]] : i32 +// CHECK: } -> tensor<64xi32> +// CHECK: %[[VAL_77:.*]] = linalg.fill ins(%[[VAL_8]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_78:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} 
ins(%[[VAL_58]], %[[VAL_77]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_58]] : tensor<64xi32>) { // CHECK: ^bb0(%[[VAL_79:.*]]: i32, %[[VAL_80:.*]]: i32, %[[VAL_81:.*]]: i32): // CHECK: %[[VAL_82:.*]] = arith.muli %[[VAL_79]], %[[VAL_80]] : i32 // CHECK: linalg.yield %[[VAL_82]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_83:.*]] = linalg.fill ins(%[[VAL_8]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_84:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_64]], %[[VAL_83]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_64]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_85:.*]]: i32, %[[VAL_86:.*]]: i32, %[[VAL_87:.*]]: i32): -// CHECK: %[[VAL_88:.*]] = arith.muli %[[VAL_85]], %[[VAL_86]] : i32 -// CHECK: linalg.yield %[[VAL_88]] : i32 +// CHECK: %[[VAL_83:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_72]], %[[VAL_78]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_72]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_84:.*]]: i32, %[[VAL_85:.*]]: i32, %[[VAL_86:.*]]: i32): +// CHECK: %[[VAL_87:.*]] = arith.addi %[[VAL_84]], %[[VAL_85]] : i32 +// CHECK: linalg.yield %[[VAL_87]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_89:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_78]], %[[VAL_84]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_78]] : tensor<64xi32>) { +// CHECK: %[[VAL_88:.*]] = linalg.fill ins(%[[VAL_9]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_89:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_63]], %[[VAL_88]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_63]] : tensor<64xi32>) { // CHECK: ^bb0(%[[VAL_90:.*]]: i32, %[[VAL_91:.*]]: i32, %[[VAL_92:.*]]: i32): -// CHECK: %[[VAL_93:.*]] = arith.addi %[[VAL_90]], %[[VAL_91]] : i32 +// CHECK: %[[VAL_93:.*]] = arith.muli %[[VAL_90]], %[[VAL_91]] : i32 // CHECK: linalg.yield %[[VAL_93]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_94:.*]] = linalg.fill ins(%[[VAL_9]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_95:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_69]], %[[VAL_94]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_69]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_96:.*]]: i32, %[[VAL_97:.*]]: i32, %[[VAL_98:.*]]: i32): -// CHECK: %[[VAL_99:.*]] = arith.muli %[[VAL_96]], %[[VAL_97]] : i32 -// CHECK: linalg.yield %[[VAL_99]] : i32 +// CHECK: %[[VAL_94:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_83]], %[[VAL_89]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_83]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_95:.*]]: i32, %[[VAL_96:.*]]: i32, %[[VAL_97:.*]]: i32): +// CHECK: %[[VAL_98:.*]] = arith.addi %[[VAL_95]], %[[VAL_96]] : i32 +// CHECK: linalg.yield %[[VAL_98]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_100:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_89]], %[[VAL_95]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_89]] : tensor<64xi32>) { +// CHECK: %[[VAL_99:.*]] = linalg.fill ins(%[[VAL_69]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: 
%[[VAL_100:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_94]], %[[VAL_99]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_94]] : tensor<64xi32>) { // CHECK: ^bb0(%[[VAL_101:.*]]: i32, %[[VAL_102:.*]]: i32, %[[VAL_103:.*]]: i32): // CHECK: %[[VAL_104:.*]] = arith.addi %[[VAL_101]], %[[VAL_102]] : i32 // CHECK: linalg.yield %[[VAL_104]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_105:.*]] = linalg.fill ins(%[[VAL_75]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_106:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_100]], %[[VAL_105]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_100]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_107:.*]]: i32, %[[VAL_108:.*]]: i32, %[[VAL_109:.*]]: i32): -// CHECK: %[[VAL_110:.*]] = arith.addi %[[VAL_107]], %[[VAL_108]] : i32 -// CHECK: linalg.yield %[[VAL_110]] : i32 -// CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_111:.*]] = memref.cast %[[VAL_0]] : memref<*xf32> to memref -// CHECK: %[[VAL_112:.*]] = bufferization.to_tensor %[[VAL_111]] restrict : memref to tensor -// CHECK: %[[VAL_113:.*]] = tensor.empty() : tensor<64xf32> -// CHECK: %[[VAL_114:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_106]], %[[VAL_48]] : tensor<64xi32>, tensor<64xi1>) outs(%[[VAL_113]] : tensor<64xf32>) { -// CHECK: ^bb0(%[[VAL_115:.*]]: i32, %[[VAL_116:.*]]: i1, %[[VAL_117:.*]]: f32): -// CHECK: %[[VAL_118:.*]] = scf.if %[[VAL_116]] -> (f32) { -// CHECK: %[[VAL_119:.*]] = arith.index_cast %[[VAL_115]] : i32 to index -// CHECK: %[[VAL_120:.*]] = tensor.extract %[[VAL_112]]{{\[}}%[[VAL_119]]] : tensor -// CHECK: scf.yield %[[VAL_120]] : f32 -// CHECK: } else { -// CHECK: scf.yield %[[VAL_26]] : f32 -// CHECK: } -// CHECK: linalg.yield %[[VAL_118]] : f32 -// CHECK: } -> tensor<64xf32> -// CHECK: %[[VAL_121:.*]] = arith.index_cast %[[VAL_12]] : i32 to index -// CHECK: %[[VAL_122:.*]] = arith.muli %[[VAL_44]], %[[VAL_121]] : index -// CHECK: %[[VAL_123:.*]] = linalg.fill ins(%[[VAL_10]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_124:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_64]], %[[VAL_123]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_64]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_125:.*]]: i32, %[[VAL_126:.*]]: i32, %[[VAL_127:.*]]: i32): -// CHECK: %[[VAL_128:.*]] = arith.muli %[[VAL_125]], %[[VAL_126]] : i32 -// CHECK: linalg.yield %[[VAL_128]] : i32 +// CHECK: %[[VAL_105:.*]] = arith.index_cast %[[VAL_31]] : i32 to index +// CHECK: %[[VAL_106:.*]] = arith.addi %[[VAL_105]], %[[VAL_25]] : index +// CHECK: %[[VAL_107:.*]] = arith.index_cast %[[VAL_45]] : i32 to index +// CHECK: %[[VAL_108:.*]] = arith.minsi %[[VAL_106]], %[[VAL_107]] : index +// CHECK: %[[VAL_109:.*]] = arith.maxsi %[[VAL_108]], %[[VAL_105]] : index +// CHECK: %[[VAL_110:.*]] = arith.subi %[[VAL_109]], %[[VAL_105]] : index +// CHECK: %[[VAL_111:.*]] = memref.alloc() : memref<64xf32> +// CHECK: %[[VAL_112:.*]] = arith.cmpi slt, %[[VAL_110]], %[[VAL_25]] : index +// CHECK: scf.if %[[VAL_112]] { +// CHECK: linalg.fill ins(%[[VAL_26]] : f32) outs(%[[VAL_111]] : memref<64xf32>) +// CHECK: } +// CHECK: %[[VAL_113:.*]] = arith.minsi %[[VAL_110]], %[[VAL_25]] : index +// CHECK: scf.for %[[VAL_114:.*]] = %[[VAL_23]] to %[[VAL_113]] step 
%[[VAL_22]] { +// CHECK: %[[VAL_115:.*]] = tensor.extract %[[VAL_100]]{{\[}}%[[VAL_114]]] : tensor<64xi32> +// CHECK: %[[VAL_116:.*]] = arith.index_cast %[[VAL_115]] : i32 to index +// CHECK: %[[VAL_117:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_116]]], sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>> +// CHECK: %[[VAL_118:.*]] = memref.subview %[[VAL_111]]{{\[}}%[[VAL_114]]] [1] [1] : memref<64xf32> to memref<1xf32, strided<[1], offset: ?>> +// CHECK: memref.copy %[[VAL_117]], %[[VAL_118]] : memref<1xf32, strided<[1], offset: ?>> to memref<1xf32, strided<[1], offset: ?>> +// CHECK: } +// CHECK: %[[VAL_119:.*]] = bufferization.to_tensor %[[VAL_111]] restrict writable : memref<64xf32> to tensor<64xf32> +// CHECK: %[[VAL_120:.*]] = arith.index_cast %[[VAL_12]] : i32 to index +// CHECK: %[[VAL_121:.*]] = arith.muli %[[VAL_44]], %[[VAL_120]] : index +// CHECK: %[[VAL_122:.*]] = linalg.fill ins(%[[VAL_10]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_123:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_58]], %[[VAL_122]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_58]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_124:.*]]: i32, %[[VAL_125:.*]]: i32, %[[VAL_126:.*]]: i32): +// CHECK: %[[VAL_127:.*]] = arith.muli %[[VAL_124]], %[[VAL_125]] : i32 +// CHECK: linalg.yield %[[VAL_127]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_129:.*]] = linalg.fill ins(%[[VAL_11]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_130:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_69]], %[[VAL_129]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_69]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_131:.*]]: i32, %[[VAL_132:.*]]: i32, %[[VAL_133:.*]]: i32): -// CHECK: %[[VAL_134:.*]] = arith.muli %[[VAL_131]], %[[VAL_132]] : i32 -// CHECK: linalg.yield %[[VAL_134]] : i32 +// CHECK: %[[VAL_128:.*]] = linalg.fill ins(%[[VAL_11]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_129:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_63]], %[[VAL_128]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_63]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_130:.*]]: i32, %[[VAL_131:.*]]: i32, %[[VAL_132:.*]]: i32): +// CHECK: %[[VAL_133:.*]] = arith.muli %[[VAL_130]], %[[VAL_131]] : i32 +// CHECK: linalg.yield %[[VAL_133]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_135:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_124]], %[[VAL_130]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_124]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_136:.*]]: i32, %[[VAL_137:.*]]: i32, %[[VAL_138:.*]]: i32): -// CHECK: %[[VAL_139:.*]] = arith.addi %[[VAL_136]], %[[VAL_137]] : i32 -// CHECK: linalg.yield %[[VAL_139]] : i32 +// CHECK: %[[VAL_134:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_123]], %[[VAL_129]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_123]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_135:.*]]: i32, %[[VAL_136:.*]]: i32, %[[VAL_137:.*]]: i32): +// CHECK: %[[VAL_138:.*]] = arith.addi %[[VAL_135]], %[[VAL_136]] : i32 +// CHECK: linalg.yield %[[VAL_138]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: 
%[[VAL_140:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_135]], %[[VAL_105]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_135]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_141:.*]]: i32, %[[VAL_142:.*]]: i32, %[[VAL_143:.*]]: i32): -// CHECK: %[[VAL_144:.*]] = arith.addi %[[VAL_141]], %[[VAL_142]] : i32 -// CHECK: linalg.yield %[[VAL_144]] : i32 +// CHECK: %[[VAL_139:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_134]], %[[VAL_99]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_134]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_140:.*]]: i32, %[[VAL_141:.*]]: i32, %[[VAL_142:.*]]: i32): +// CHECK: %[[VAL_143:.*]] = arith.addi %[[VAL_140]], %[[VAL_141]] : i32 +// CHECK: linalg.yield %[[VAL_143]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_145:.*]] = arith.index_cast %[[VAL_31]] : i32 to index -// CHECK: %[[VAL_146:.*]] = arith.addi %[[VAL_145]], %[[VAL_25]] : index -// CHECK: %[[VAL_147:.*]] = arith.index_cast %[[VAL_45]] : i32 to index -// CHECK: %[[VAL_148:.*]] = arith.minsi %[[VAL_146]], %[[VAL_147]] : index -// CHECK: %[[VAL_149:.*]] = arith.maxsi %[[VAL_148]], %[[VAL_145]] : index -// CHECK: %[[VAL_150:.*]] = arith.subi %[[VAL_149]], %[[VAL_145]] : index -// CHECK: %[[VAL_151:.*]] = arith.addi %[[VAL_44]], %[[VAL_25]] : index -// CHECK: %[[VAL_152:.*]] = arith.index_cast %[[VAL_5]] : i32 to index -// CHECK: %[[VAL_153:.*]] = arith.minsi %[[VAL_151]], %[[VAL_152]] : index -// CHECK: %[[VAL_154:.*]] = arith.maxsi %[[VAL_153]], %[[VAL_44]] : index -// CHECK: %[[VAL_155:.*]] = arith.subi %[[VAL_154]], %[[VAL_44]] : index -// CHECK: %[[VAL_156:.*]] = arith.minsi %[[VAL_150]], %[[VAL_25]] : index -// CHECK: %[[VAL_157:.*]] = arith.minsi %[[VAL_155]], %[[VAL_25]] : index -// CHECK: %[[VAL_158:.*]] = memref.alloc() : memref<64x64xf32> -// CHECK: %[[VAL_159:.*]] = arith.cmpi slt, %[[VAL_156]], %[[VAL_25]] : index -// CHECK: %[[VAL_160:.*]] = arith.cmpi slt, %[[VAL_157]], %[[VAL_25]] : index -// CHECK: %[[VAL_161:.*]] = arith.ori %[[VAL_159]], %[[VAL_160]] : i1 -// CHECK: scf.if %[[VAL_161]] { -// CHECK: linalg.fill ins(%[[VAL_26]] : f32) outs(%[[VAL_158]] : memref<64x64xf32>) +// CHECK: %[[VAL_144:.*]] = arith.addi %[[VAL_44]], %[[VAL_25]] : index +// CHECK: %[[VAL_145:.*]] = arith.index_cast %[[VAL_5]] : i32 to index +// CHECK: %[[VAL_146:.*]] = arith.minsi %[[VAL_144]], %[[VAL_145]] : index +// CHECK: %[[VAL_147:.*]] = arith.maxsi %[[VAL_146]], %[[VAL_44]] : index +// CHECK: %[[VAL_148:.*]] = arith.subi %[[VAL_147]], %[[VAL_44]] : index +// CHECK: %[[VAL_149:.*]] = arith.minsi %[[VAL_148]], %[[VAL_25]] : index +// CHECK: %[[VAL_150:.*]] = memref.alloc() : memref<64x64xf32> +// CHECK: %[[VAL_151:.*]] = arith.cmpi slt, %[[VAL_113]], %[[VAL_25]] : index +// CHECK: %[[VAL_152:.*]] = arith.cmpi slt, %[[VAL_149]], %[[VAL_25]] : index +// CHECK: %[[VAL_153:.*]] = arith.ori %[[VAL_151]], %[[VAL_152]] : i1 +// CHECK: scf.if %[[VAL_153]] { +// CHECK: linalg.fill ins(%[[VAL_26]] : f32) outs(%[[VAL_150]] : memref<64x64xf32>) // CHECK: } -// CHECK: %[[VAL_162:.*]] = arith.minsi %[[VAL_156]], %[[VAL_25]] : index -// CHECK: scf.for %[[VAL_163:.*]] = %[[VAL_23]] to %[[VAL_162]] step %[[VAL_22]] { -// CHECK: %[[VAL_164:.*]] = tensor.extract %[[VAL_140]]{{\[}}%[[VAL_163]]] : tensor<64xi32> -// CHECK: %[[VAL_165:.*]] = arith.index_cast %[[VAL_164]] : i32 to index -// CHECK: %[[VAL_166:.*]] = arith.addi %[[VAL_165]], %[[VAL_122]] : index -// CHECK: 
%[[VAL_167:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_166]]], sizes: [1, 64], strides: [1, %[[VAL_121]]] : memref<*xf32> to memref<1x64xf32, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_168:.*]] = memref.subview %[[VAL_167]][0, 0] [1, %[[VAL_157]]] [1, 1] : memref<1x64xf32, strided<[1, ?], offset: ?>> to memref<1x?xf32, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_169:.*]] = memref.subview %[[VAL_158]]{{\[}}%[[VAL_163]], 0] [1, %[[VAL_157]]] [1, 1] : memref<64x64xf32> to memref<1x?xf32, strided<[64, 1], offset: ?>> -// CHECK: memref.copy %[[VAL_168]], %[[VAL_169]] : memref<1x?xf32, strided<[1, ?], offset: ?>> to memref<1x?xf32, strided<[64, 1], offset: ?>> +// CHECK: %[[VAL_154:.*]] = arith.minsi %[[VAL_113]], %[[VAL_25]] : index +// CHECK: scf.for %[[VAL_155:.*]] = %[[VAL_23]] to %[[VAL_154]] step %[[VAL_22]] { +// CHECK: %[[VAL_156:.*]] = tensor.extract %[[VAL_139]]{{\[}}%[[VAL_155]]] : tensor<64xi32> +// CHECK: %[[VAL_157:.*]] = arith.index_cast %[[VAL_156]] : i32 to index +// CHECK: %[[VAL_158:.*]] = arith.addi %[[VAL_157]], %[[VAL_121]] : index +// CHECK: %[[VAL_159:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_158]]], sizes: [1, 64], strides: [1, %[[VAL_120]]] : memref<*xf32> to memref<1x64xf32, strided<[1, ?], offset: ?>> +// CHECK: %[[VAL_160:.*]] = memref.subview %[[VAL_159]][0, 0] [1, %[[VAL_149]]] [1, 1] : memref<1x64xf32, strided<[1, ?], offset: ?>> to memref<1x?xf32, strided<[1, ?], offset: ?>> +// CHECK: %[[VAL_161:.*]] = memref.subview %[[VAL_150]]{{\[}}%[[VAL_155]], 0] [1, %[[VAL_149]]] [1, 1] : memref<64x64xf32> to memref<1x?xf32, strided<[64, 1], offset: ?>> +// CHECK: memref.copy %[[VAL_160]], %[[VAL_161]] : memref<1x?xf32, strided<[1, ?], offset: ?>> to memref<1x?xf32, strided<[64, 1], offset: ?>> // CHECK: } -// CHECK: %[[VAL_170:.*]] = bufferization.to_tensor %[[VAL_158]] restrict writable : memref<64x64xf32> to tensor<64x64xf32> -// CHECK: %[[VAL_171:.*]] = tensor.expand_shape %[[VAL_114]] {{\[\[}}0, 1]] output_shape [64, 1] : tensor<64xf32> into tensor<64x1xf32> -// CHECK: %[[VAL_172:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_171]] : tensor<64x1xf32>) outs(%[[VAL_29]] : tensor<64x64xf32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_173:.*]]: f32, %[[VAL_174:.*]]: f32): -// CHECK: linalg.yield %[[VAL_173]] : f32 +// CHECK: %[[VAL_162:.*]] = bufferization.to_tensor %[[VAL_150]] restrict writable : memref<64x64xf32> to tensor<64x64xf32> +// CHECK: %[[VAL_163:.*]] = tensor.expand_shape %[[VAL_119]] {{\[\[}}0, 1]] output_shape [64, 1] : tensor<64xf32> into tensor<64x1xf32> +// CHECK: %[[VAL_164:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_163]] : tensor<64x1xf32>) outs(%[[VAL_29]] : tensor<64x64xf32>) attrs = {broadcastDims = array} { +// CHECK: ^bb0(%[[VAL_165:.*]]: f32, %[[VAL_166:.*]]: f32): +// CHECK: linalg.yield %[[VAL_165]] : f32 // CHECK: } -> tensor<64x64xf32> -// CHECK: %[[VAL_175:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_172]], %[[VAL_170]] : tensor<64x64xf32>, tensor<64x64xf32>) outs(%[[VAL_172]] : tensor<64x64xf32>) { -// CHECK: ^bb0(%[[VAL_176:.*]]: f32, %[[VAL_177:.*]]: f32, %[[VAL_178:.*]]: f32): -// CHECK: %[[VAL_179:.*]] = arith.mulf %[[VAL_176]], %[[VAL_177]] : f32 -// CHECK: linalg.yield %[[VAL_179]] : f32 +// CHECK: %[[VAL_167:.*]] = 
linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_164]], %[[VAL_162]] : tensor<64x64xf32>, tensor<64x64xf32>) outs(%[[VAL_164]] : tensor<64x64xf32>) { +// CHECK: ^bb0(%[[VAL_168:.*]]: f32, %[[VAL_169:.*]]: f32, %[[VAL_170:.*]]: f32): +// CHECK: %[[VAL_171:.*]] = arith.mulf %[[VAL_168]], %[[VAL_169]] : f32 +// CHECK: linalg.yield %[[VAL_171]] : f32 // CHECK: } -> tensor<64x64xf32> -// CHECK: %[[VAL_180:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_76]], %[[VAL_175]] : tensor<64x64xf32>, tensor<64x64xf32>) outs(%[[VAL_76]] : tensor<64x64xf32>) { -// CHECK: ^bb0(%[[VAL_181:.*]]: f32, %[[VAL_182:.*]]: f32, %[[VAL_183:.*]]: f32): -// CHECK: %[[VAL_184:.*]] = arith.addf %[[VAL_181]], %[[VAL_182]] : f32 -// CHECK: linalg.yield %[[VAL_184]] : f32 +// CHECK: %[[VAL_172:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_70]], %[[VAL_167]] : tensor<64x64xf32>, tensor<64x64xf32>) outs(%[[VAL_70]] : tensor<64x64xf32>) { +// CHECK: ^bb0(%[[VAL_173:.*]]: f32, %[[VAL_174:.*]]: f32, %[[VAL_175:.*]]: f32): +// CHECK: %[[VAL_176:.*]] = arith.addf %[[VAL_173]], %[[VAL_174]] : f32 +// CHECK: linalg.yield %[[VAL_176]] : f32 // CHECK: } -> tensor<64x64xf32> -// CHECK: scf.yield %[[VAL_180]] : tensor<64x64xf32> +// CHECK: scf.yield %[[VAL_172]] : tensor<64x64xf32> // CHECK: } -// CHECK: %[[VAL_185:.*]] = linalg.fill ins(%[[VAL_13]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_186:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_53]], %[[VAL_185]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_53]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_187:.*]]: i32, %[[VAL_188:.*]]: i32, %[[VAL_189:.*]]: i32): -// CHECK: %[[VAL_190:.*]] = arith.muli %[[VAL_187]], %[[VAL_188]] : i32 -// CHECK: linalg.yield %[[VAL_190]] : i32 +// CHECK: %[[VAL_177:.*]] = linalg.fill ins(%[[VAL_13]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_178:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_47]], %[[VAL_177]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_47]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_179:.*]]: i32, %[[VAL_180:.*]]: i32, %[[VAL_181:.*]]: i32): +// CHECK: %[[VAL_182:.*]] = arith.muli %[[VAL_179]], %[[VAL_180]] : i32 +// CHECK: linalg.yield %[[VAL_182]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_191:.*]] = linalg.fill ins(%[[VAL_14]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_192:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_64]], %[[VAL_191]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_64]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_193:.*]]: i32, %[[VAL_194:.*]]: i32, %[[VAL_195:.*]]: i32): -// CHECK: %[[VAL_196:.*]] = arith.muli %[[VAL_193]], %[[VAL_194]] : i32 -// CHECK: linalg.yield %[[VAL_196]] : i32 +// CHECK: %[[VAL_183:.*]] = linalg.fill ins(%[[VAL_14]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_184:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_58]], %[[VAL_183]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_58]] : 
tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_185:.*]]: i32, %[[VAL_186:.*]]: i32, %[[VAL_187:.*]]: i32): +// CHECK: %[[VAL_188:.*]] = arith.muli %[[VAL_185]], %[[VAL_186]] : i32 +// CHECK: linalg.yield %[[VAL_188]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_197:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_186]], %[[VAL_192]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_186]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_198:.*]]: i32, %[[VAL_199:.*]]: i32, %[[VAL_200:.*]]: i32): -// CHECK: %[[VAL_201:.*]] = arith.addi %[[VAL_198]], %[[VAL_199]] : i32 -// CHECK: linalg.yield %[[VAL_201]] : i32 +// CHECK: %[[VAL_189:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_178]], %[[VAL_184]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_178]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_190:.*]]: i32, %[[VAL_191:.*]]: i32, %[[VAL_192:.*]]: i32): +// CHECK: %[[VAL_193:.*]] = arith.addi %[[VAL_190]], %[[VAL_191]] : i32 +// CHECK: linalg.yield %[[VAL_193]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_202:.*]] = linalg.fill ins(%[[VAL_15]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> -// CHECK: %[[VAL_203:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_69]], %[[VAL_202]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_69]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_204:.*]]: i32, %[[VAL_205:.*]]: i32, %[[VAL_206:.*]]: i32): -// CHECK: %[[VAL_207:.*]] = arith.muli %[[VAL_204]], %[[VAL_205]] : i32 -// CHECK: linalg.yield %[[VAL_207]] : i32 +// CHECK: %[[VAL_194:.*]] = linalg.fill ins(%[[VAL_15]] : i32) outs(%[[VAL_32]] : tensor<64xi32>) -> tensor<64xi32> +// CHECK: %[[VAL_195:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_63]], %[[VAL_194]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_63]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_196:.*]]: i32, %[[VAL_197:.*]]: i32, %[[VAL_198:.*]]: i32): +// CHECK: %[[VAL_199:.*]] = arith.muli %[[VAL_196]], %[[VAL_197]] : i32 +// CHECK: linalg.yield %[[VAL_199]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_208:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_197]], %[[VAL_203]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_197]] : tensor<64xi32>) { -// CHECK: ^bb0(%[[VAL_209:.*]]: i32, %[[VAL_210:.*]]: i32, %[[VAL_211:.*]]: i32): -// CHECK: %[[VAL_212:.*]] = arith.addi %[[VAL_209]], %[[VAL_210]] : i32 -// CHECK: linalg.yield %[[VAL_212]] : i32 +// CHECK: %[[VAL_200:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_189]], %[[VAL_195]] : tensor<64xi32>, tensor<64xi32>) outs(%[[VAL_189]] : tensor<64xi32>) { +// CHECK: ^bb0(%[[VAL_201:.*]]: i32, %[[VAL_202:.*]]: i32, %[[VAL_203:.*]]: i32): +// CHECK: %[[VAL_204:.*]] = arith.addi %[[VAL_201]], %[[VAL_202]] : i32 +// CHECK: linalg.yield %[[VAL_204]] : i32 // CHECK: } -> tensor<64xi32> -// CHECK: %[[VAL_213:.*]] = arith.index_cast %[[VAL_31]] : i32 to index -// CHECK: %[[VAL_214:.*]] = arith.addi %[[VAL_213]], %[[VAL_25]] : index -// CHECK: %[[VAL_215:.*]] = arith.index_cast %[[VAL_45]] : i32 to index -// CHECK: %[[VAL_216:.*]] = arith.minsi %[[VAL_214]], %[[VAL_215]] : index -// CHECK: %[[VAL_217:.*]] = arith.maxsi %[[VAL_216]], 
%[[VAL_213]] : index -// CHECK: %[[VAL_218:.*]] = arith.subi %[[VAL_217]], %[[VAL_213]] : index -// CHECK: %[[VAL_219:.*]] = arith.addi %[[VAL_44]], %[[VAL_25]] : index -// CHECK: %[[VAL_220:.*]] = arith.index_cast %[[VAL_5]] : i32 to index -// CHECK: %[[VAL_221:.*]] = arith.minsi %[[VAL_219]], %[[VAL_220]] : index -// CHECK: %[[VAL_222:.*]] = arith.maxsi %[[VAL_221]], %[[VAL_44]] : index -// CHECK: %[[VAL_223:.*]] = arith.subi %[[VAL_222]], %[[VAL_44]] : index -// CHECK: %[[VAL_224:.*]] = arith.minsi %[[VAL_218]], %[[VAL_25]] : index -// CHECK: %[[VAL_225:.*]] = arith.minsi %[[VAL_223]], %[[VAL_25]] : index -// CHECK: %[[VAL_226:.*]] = arith.minsi %[[VAL_224]], %[[VAL_25]] : index -// CHECK: scf.for %[[VAL_227:.*]] = %[[VAL_23]] to %[[VAL_226]] step %[[VAL_22]] { -// CHECK: %[[VAL_228:.*]] = tensor.extract %[[VAL_208]]{{\[}}%[[VAL_227]]] : tensor<64xi32> -// CHECK: %[[VAL_229:.*]] = arith.index_cast %[[VAL_228]] : i32 to index -// CHECK: %[[VAL_230:.*]] = tensor.extract_slice %[[VAL_74]]{{\[}}%[[VAL_227]], 0] [1, %[[VAL_225]]] [1, 1] : tensor<64x64xf32> to tensor<1x?xf32> -// CHECK: %[[VAL_231:.*]] = arith.addi %[[VAL_229]], %[[VAL_44]] : index -// CHECK: %[[VAL_232:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_231]]], sizes: [1, 64], strides: [1, 1] : memref<*xf32> to memref<1x64xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[VAL_233:.*]] = memref.subview %[[VAL_232]][0, 0] [1, %[[VAL_225]]] [1, 1] : memref<1x64xf32, strided<[1, 1], offset: ?>> to memref<1x?xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[VAL_234:.*]] = memref.cast %[[VAL_233]] : memref<1x?xf32, strided<[1, 1], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_230]] in writable %[[VAL_234]] : (tensor<1x?xf32>, memref<1x?xf32, strided<[?, ?], offset: ?>>) -> () +// CHECK: %[[VAL_205:.*]] = arith.index_cast %[[VAL_31]] : i32 to index +// CHECK: %[[VAL_206:.*]] = arith.addi %[[VAL_205]], %[[VAL_25]] : index +// CHECK: %[[VAL_207:.*]] = arith.index_cast %[[VAL_45]] : i32 to index +// CHECK: %[[VAL_208:.*]] = arith.minsi %[[VAL_206]], %[[VAL_207]] : index +// CHECK: %[[VAL_209:.*]] = arith.maxsi %[[VAL_208]], %[[VAL_205]] : index +// CHECK: %[[VAL_210:.*]] = arith.subi %[[VAL_209]], %[[VAL_205]] : index +// CHECK: %[[VAL_211:.*]] = arith.addi %[[VAL_44]], %[[VAL_25]] : index +// CHECK: %[[VAL_212:.*]] = arith.index_cast %[[VAL_5]] : i32 to index +// CHECK: %[[VAL_213:.*]] = arith.minsi %[[VAL_211]], %[[VAL_212]] : index +// CHECK: %[[VAL_214:.*]] = arith.maxsi %[[VAL_213]], %[[VAL_44]] : index +// CHECK: %[[VAL_215:.*]] = arith.subi %[[VAL_214]], %[[VAL_44]] : index +// CHECK: %[[VAL_216:.*]] = arith.minsi %[[VAL_210]], %[[VAL_25]] : index +// CHECK: %[[VAL_217:.*]] = arith.minsi %[[VAL_215]], %[[VAL_25]] : index +// CHECK: %[[VAL_218:.*]] = arith.minsi %[[VAL_216]], %[[VAL_25]] : index +// CHECK: scf.for %[[VAL_219:.*]] = %[[VAL_23]] to %[[VAL_218]] step %[[VAL_22]] { +// CHECK: %[[VAL_220:.*]] = tensor.extract %[[VAL_200]]{{\[}}%[[VAL_219]]] : tensor<64xi32> +// CHECK: %[[VAL_221:.*]] = arith.index_cast %[[VAL_220]] : i32 to index +// CHECK: %[[VAL_222:.*]] = tensor.extract_slice %[[VAL_68]]{{\[}}%[[VAL_219]], 0] [1, %[[VAL_217]]] [1, 1] : tensor<64x64xf32> to tensor<1x?xf32> +// CHECK: %[[VAL_223:.*]] = arith.addi %[[VAL_221]], %[[VAL_44]] : index +// CHECK: %[[VAL_224:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_223]]], sizes: [1, 64], strides: [1, 1] : memref<*xf32> to memref<1x64xf32, strided<[1, 1], offset: ?>> 
+// CHECK: %[[VAL_225:.*]] = memref.subview %[[VAL_224]][0, 0] [1, %[[VAL_217]]] [1, 1] : memref<1x64xf32, strided<[1, 1], offset: ?>> to memref<1x?xf32, strided<[1, 1], offset: ?>>
+// CHECK: %[[VAL_226:.*]] = memref.cast %[[VAL_225]] : memref<1x?xf32, strided<[1, 1], offset: ?>> to memref<1x?xf32, strided<[?, ?], offset: ?>>
+// CHECK: bufferization.materialize_in_destination %[[VAL_222]] in writable %[[VAL_226]] : (tensor<1x?xf32>, memref<1x?xf32, strided<[?, ?], offset: ?>>) -> ()
// CHECK: }
// CHECK: return
diff --git a/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_if.mlir b/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_if.mlir
new file mode 100644
index 00000000..44ed4f34
--- /dev/null
+++ b/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_if.mlir
@@ -0,0 +1,50 @@
+// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize --cse %s | FileCheck %s
+
+// Make sure tts.make_gather_scatter_tptr is generated for a 1D tensor on an addptr whose offset comes from an scf.if.
+
+// CHECK-LABEL: tt.func public @gather_row(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
+// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) attributes {noinline = false} {
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<8> : tensor<4xi32>
+// CHECK: %[[VAL_5:.*]] = arith.constant 8 : i32
+// CHECK: %[[VAL_6:.*]] = arith.constant dense<4> : tensor<4xi32>
+// CHECK: %[[VAL_7:.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_6]] : tensor<4xi32>
+// CHECK: %[[VAL_9:.*]] = arith.cmpi slt, %[[VAL_2]], %[[VAL_5]] : i32
+// CHECK: %[[VAL_10:.*]] = scf.if %[[VAL_9]] -> (tensor<4xi32>) {
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_8]], %[[VAL_4]] : tensor<4xi32>
+// CHECK: scf.yield %[[VAL_11]] : tensor<4xi32>
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_8]] : tensor<4xi32>
+// CHECK: }
+// CHECK: %[[VAL_12:.*]] = tts.make_gather_scatter_tptr %[[VAL_0]] to sizes: [4] gather_scatter_dim: 0 gather_scatter_offset: %[[VAL_10]], strides: [1], offsets: [0] : tensor<4xi32> to !tt.ptr>
+// CHECK: %[[VAL_13:.*]] = tts.make_gather_scatter_tptr %[[VAL_1]] to sizes: [4] gather_scatter_dim: 0 gather_scatter_offset: %[[VAL_10]], strides: [1], offsets: [0] : tensor<4xi32> to !tt.ptr>
+// CHECK: %[[VAL_14:.*]] = "tts.load"(%[[VAL_12]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (!tt.ptr>) -> tensor<4xf32>
+// CHECK: "tts.store"(%[[VAL_13]], %[[VAL_14]]) <{static_mask_dims = array}> : (!tt.ptr>, tensor<4xf32>) -> ()
+// CHECK: tt.return
+
+module {
+  tt.func public @gather_row(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} {
+    %cst = arith.constant dense<8> : tensor<4xi32>
+    %c8_i32 = arith.constant 8 : i32
+    %cst_0 = arith.constant dense<4> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = arith.divsi %0, %cst_0 : tensor<4xi32>
+    %2 = arith.cmpi slt, %arg2, %c8_i32 : i32
+    %3 = scf.if %2 -> (tensor<4xi32>) {
+      %9 = arith.addi %1, %cst : tensor<4xi32>
+      scf.yield %9 : tensor<4xi32>
+    } else {
+      scf.yield %1 : tensor<4xi32>
+    }
+    %4 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr>
+    %5 = tt.addptr %4, %3 : tensor<4x!tt.ptr>, tensor<4xi32>
+    %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr>
+    %7 = tt.addptr %6, %3 : tensor<4x!tt.ptr>, tensor<4xi32>
+    %8 = tt.load %5 : tensor<4x!tt.ptr>
+    tt.store %7, %8 : tensor<4x!tt.ptr>
+    tt.return
+  }
+}
diff --git a/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_loop.mlir b/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_loop.mlir
new file mode 100644
index 00000000..450aa3e8
--- /dev/null
+++ b/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_loop.mlir
@@ -0,0 +1,32 @@
+// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize --cse %s | FileCheck %s
+
+// TODO: fix test case in https://github.com/microsoft/triton-shared/pull/332, remove XFAIL and update the CHECKs.
+// XFAIL: *
+
+// Make sure tts.make_gather_scatter_tptr is generated for a 1D tensor on an addptr that is advanced inside an scf.for loop.
+
+// CHECK: make_gather_scatter_tptr
+
+module {
+  tt.func public @gather_row(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant dense<10> : tensor<4xi32>
+    %cst_0 = arith.constant dense<4> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = arith.divsi %0, %cst_0 : tensor<4xi32>
+    %2 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr>
+    %3 = tt.addptr %2, %1 : tensor<4x!tt.ptr>, tensor<4xi32>
+    %4 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr>
+    %5 = tt.addptr %4, %1 : tensor<4x!tt.ptr>, tensor<4xi32>
+    %6:2 = scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg5 = %3, %arg6 = %5) -> (tensor<4x!tt.ptr>, tensor<4x!tt.ptr>) : i32 {
+      %7 = tt.load %arg5 : tensor<4x!tt.ptr>
+      tt.store %arg6, %7 : tensor<4x!tt.ptr>
+      %8 = tt.addptr %arg5, %cst : tensor<4x!tt.ptr>, tensor<4xi32>
+      %9 = tt.addptr %arg6, %cst : tensor<4x!tt.ptr>, tensor<4xi32>
+      scf.yield %8, %9 : tensor<4x!tt.ptr>, tensor<4x!tt.ptr>
+    }
+    tt.return
+  }
+}
diff --git a/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_mask.mlir b/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_mask.mlir
new file mode 100644
index 00000000..ad58d8cd
--- /dev/null
+++ b/test/Conversion/TritonToStructured/gather_scatter_ptr_1d_with_mask.mlir
@@ -0,0 +1,40 @@
+// RUN: triton-shared-opt --triton-to-structured --remove-dead-values --canonicalize --cse %s | FileCheck %s
+
+// Make sure tts.load is generated with the correct mask for a 1D tensor.
+
+// CHECK-LABEL: tt.func public @row_gather1d_with_mask(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr,
+// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr,
+// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !tt.ptr,
+// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) attributes {noinline = false} {
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 32 : index
+// CHECK: %[[VAL_6:.*]] = tts.make_tptr %[[VAL_1]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr>
+// CHECK: %[[VAL_7:.*]] = "tts.load"(%[[VAL_6]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (tensor<32x!tt.ptr>) -> tensor<32xi32>
+// CHECK: %[[VAL_8:.*]] = tts.make_gather_scatter_tptr %[[VAL_0]] to sizes: [32] gather_scatter_dim: 0 gather_scatter_offset: %[[VAL_7]], strides: [1], offsets: [0] : tensor<32xi32> to !tt.ptr>
+// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
+// CHECK: %[[VAL_10:.*]] = arith.minsi %[[VAL_9]], %[[VAL_5]] : index
+// CHECK: %[[VAL_11:.*]] = arith.maxsi %[[VAL_10]], %[[VAL_4]] : index
+// The tts.load with mask.
+// CHECK: %[[VAL_12:.*]] = "tts.load"(%[[VAL_8]], %[[VAL_11]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (!tt.ptr>, index) -> tensor<32xf32>
+// CHECK: %[[VAL_13:.*]] = tts.make_tptr %[[VAL_2]] to sizes: [32], strides: [1], offsets: [0], shape: [0], order: [] : to tensor<32x!tt.ptr>
+// CHECK: "tts.store"(%[[VAL_13]], %[[VAL_12]]) <{static_mask_dims = array}> : (tensor<32x!tt.ptr>, tensor<32xf32>) -> ()
+// CHECK: tt.return
+
+module attributes {} {
+  tt.func public @row_gather1d_with_mask(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} {
+    %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
+    %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr>
+    %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32>
+    %3 = tt.load %2 : tensor<32x!tt.ptr>
+    %4 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr>
+    %5 = tt.addptr %4, %3 : tensor<32x!tt.ptr>, tensor<32xi32>
+    %6 = tt.splat %arg3 : i32 -> tensor<32xi32>
+    %7 = arith.cmpi slt, %0, %6 : tensor<32xi32>
+    %8 = tt.load %5, %7: tensor<32x!tt.ptr>
+    %9 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr>
+    %10 = tt.addptr %9, %0 : tensor<32x!tt.ptr>, tensor<32xi32>
+    tt.store %10, %8: tensor<32x!tt.ptr>
+    tt.return
+  }
+}