diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 678a88627ca82..f0b77da5acd02 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final } }; +/// This pattern linearizes vector.load from vector<1xN> to vector. +/// It currently supports only lineariztion of <1XN> to +/// Following, +/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> +/// is converted to: +/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32> +/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32> +struct LinearizeVectorLoad final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType vecTy = loadOp.getType(); + if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1) + return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported"); + auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(), + vecTy.isScalable()); + auto newLoad = rewriter.create( + loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices()); + auto shapeCast = rewriter.create( + loadOp.getLoc(), vecTy, newLoad.getResult()); + rewriter.replaceOp(loadOp, shapeCast.getResult()); + return success(); + } +}; + +/// This pattern linearizes vector.store from vector<1xN> to vector. +/// It currently supports only lineariztion of <1XN> to +/// Following, +/// vector.store %arg0, %arg1[%c0, %c0] +/// : vector<1x4xf32>, memref<1x4xf32> +/// is converted to: +/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> +/// vector.store %arg0, %arg1[%c0, %%c0] +/// : vector<4xf32>, memref<1x4xf32> +struct LinearizeVectorStore final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType vecTy = storeOp.getValueToStore().getType(); + if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1) + return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported"); + auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(), + vecTy.isScalable()); + + Value valueToStore = adaptor.getValueToStore(); + if (valueToStore.getType() != linearTy) { + valueToStore = rewriter.create( + storeOp.getLoc(), linearTy, valueToStore); + } + + rewriter.replaceOpWithNewOp( + storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices()); + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( RewritePatternSet &patterns) { patterns .add( - typeConverter, patterns.getContext()); + LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad, + LinearizeVectorStore>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 9cbf319ffddb2..fa0436792d3f0 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1> return %0 : vector<1x[16]xi1> } + +// CHECK-LABEL: linearize_vector_load +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32> +func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> { + // CHECK: %[[CST0:.*]] = arith.constant 0 : index + // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32> + // CHECK: return %[[CAST]] : vector<1x4xf32> + %c0 = arith.constant 0 : index + %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> + return %0 : vector<1x4xf32> +} + +// CHECK-LABEL: linearize_vector_store +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>) +func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) { + // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32> + // CHECK: %[[CST0:.*]] = arith.constant 0 : index + // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32> + %c0 = arith.constant 0 : index + vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> + return +}