diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index aef156c5f1d05..fc99a8e30ef1f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -798,7 +798,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [ This operation decomposes all the scalar elements from a vector. The decomposed scalar elements are returned in row-major order. The number of scalar results must match the number of elements in the input vector type. - All the result elements have the same result type, which must match the + All the result elements have the same type, which must match the element type of the input vector. Scalable vectors are not supported. Examples: @@ -813,7 +813,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [ // %0#0 = %v1[0] // %0#1 = %v1[1] - // Decompose a 2-D. + // Decompose a 2-D vector. %0:6 = vector.to_elements %v2 : vector<2x3xf32> // %0#0 = %v2[0, 0] // %0#1 = %v2[0, 1] @@ -835,6 +835,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [ let arguments = (ins AnyVectorOfAnyRank:$source); let results = (outs Variadic:$elements); + + + let builders = [ + // Build method that infers the result types from `elements`. + OpBuilder<(ins "Value":$elements)>, + ]; + let assemblyFormat = "$source attr-dict `:` type($source)"; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 6f0ac6bb58282..cd0516c80377b 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2417,6 +2417,15 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, return foldToElementsFromElements(*this, results); } +void vector::ToElementsOp::build(OpBuilder &builder, OperationState &result, + Value elements) { + auto vectorType = cast(elements.getType()); + Type elementType = vectorType.getElementType(); + int64_t nbElements = vectorType.getNumElements(); + SmallVector scalarTypes(nbElements, elementType); + build(builder, result, scalarTypes, elements); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===//