Skip to content

Commit

Permalink
Add result type inference to RandomNormalLike and fix wrong hardcodin…
Browse files Browse the repository at this point in the history
…gs for dtypes (#3091)

* Add RandomNormalLike to ops with ResultTypeInference

Signed-off-by: Rickert, Jonas <[email protected]>

* Add type inference for RandomNormalLike op

Signed-off-by: Rickert, Jonas <[email protected]>

* Fix wrong dtype hardcoding in RandomNormal and RandomNormalLike

Also add support for bf16 dtype.
Change tests to not contain invalid mlir

Signed-off-by: Rickert, Jonas <[email protected]>

* Tests for output type inference for RandomNormalLike op

Signed-off-by: Rickert, Jonas <[email protected]>

---------

Signed-off-by: Rickert, Jonas <[email protected]>
Co-authored-by: Sai Kiran Yeddlapalli Ganesh <[email protected]>
  • Loading branch information
jorickert and Sai Kiran Yeddlapalli Ganesh authored Mar 7, 2025
1 parent 32c0fc4 commit d137f49
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 78 deletions.
2 changes: 1 addition & 1 deletion docs/Dialects/onnx.md
Original file line number Diff line number Diff line change
Expand Up @@ -6916,7 +6916,7 @@ TensorProto message, and be valid as an output type.

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

Expand Down
2 changes: 1 addition & 1 deletion src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -6177,7 +6177,7 @@ def ONNXRandomNormalOp:ONNX_Op<"RandomNormal",
}

def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike",
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>, DeclareOpInterfaceMethods<ResultTypeInferenceOpInterface>]> {
let summary = "ONNX RandomNormalLike operation";
let description = [{
Generate a tensor with random values drawn from a normal distribution.
Expand Down
45 changes: 25 additions & 20 deletions src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,33 @@ LogicalResult ONNXRandomNormalOpShapeHelper::computeShape() {
// Type Inference
//===----------------------------------------------------------------------===//

std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
Type elementType;
if (auto attr = getDtypeAttr()) {
if (getDtype() == 0) {
elementType = FloatType::getF16(getContext());
} else if (getDtype() == 1) {
elementType = FloatType::getF32(getContext());
} else if (getDtype() == 2) {
elementType = FloatType::getF64(getContext());
namespace {
/// Maps the optional `dtype` attribute of an ONNXRandomNormalOp to the MLIR
/// element type of its result. Per the ONNX spec, when no dtype attribute is
/// present the default element type is float (f32).
/// Aborts (llvm_unreachable) on dtype codes this lowering does not support.
Type getRandomNormalElementType(ONNXRandomNormalOp op) {
  if (!op.getDtypeAttr()) {
    // No dtype attribute: ONNX defaults to float.
    return FloatType::getF32(op.getContext());
  }
  // dtype values follow onnx.TensorProto.DataType (FLOAT=1, DOUBLE=11,
  // FLOAT16=10, BFLOAT16=16), not an ad-hoc 0/1/2 encoding.
  switch (static_cast<onnx::TensorProto_DataType>(op.getDtype())) {
  case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
    return FloatType::getF16(op.getContext());
  case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
    return FloatType::getF32(op.getContext());
  case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
    return FloatType::getF64(op.getContext());
  case onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16:
    return FloatType::getBF16(op.getContext());
  default:
    llvm_unreachable("dtype not supported for RandomNormal");
  }
}
} // namespace

/// Result type inference: the element type is fully determined by the dtype
/// attribute, but the shape is only materialized later during shape
/// inference, so an unranked tensor type is returned here.
std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
  const Type elemTy = getRandomNormalElementType(*this);
  return {UnrankedTensorType::get(elemTy)};
}

//===----------------------------------------------------------------------===//
Expand All @@ -67,15 +78,9 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {

LogicalResult ONNXRandomNormalOp::inferShapes(
    std::function<void(Region &)> doShapeInference) {
  // The element type comes from the dtype attribute (default f32); the
  // result shape is derived from the mandatory `shape` attribute by the
  // shape helper, which also updates the result type in place.
  ONNXRandomNormalOpShapeHelper shapeHelper(getOperation(), {});
  return shapeHelper.computeShapeAndUpdateType(
      getRandomNormalElementType(*this));
}

//===----------------------------------------------------------------------===//
Expand Down
82 changes: 54 additions & 28 deletions src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,69 @@ LogicalResult ONNXRandomNormalLikeOp::verify() {

auto elementTypeIDDType = operandAdaptor.getDtype();
if (elementTypeIDDType) {
int64_t elementTypeID = elementTypeIDDType.value();
if (elementTypeID < 0 || elementTypeID > 2) {
return emitOpError("dtype not 0, 1 or 2.");
const auto elementTypeID =
static_cast<onnx::TensorProto_DataType>(*elementTypeIDDType);
if (elementTypeID !=
onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16 &&
elementTypeID !=
onnx::TensorProto_DataType::TensorProto_DataType_FLOAT &&
elementTypeID !=
onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE &&
elementTypeID !=
onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16) {
return emitOpError("dtype not float16, float, double or bfloat16");
}
if (elementTypeID == 0 && outputType != FloatType::getF16(getContext()))
return emitOpError("output tensor does match 0 dtype.");
else if (elementTypeID == 1 &&
if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16 &&
outputType != FloatType::getF16(getContext()))
return emitOpError("output tensor does not match float16 dtype.");
else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_FLOAT &&
outputType != FloatType::getF32(getContext()))
return emitOpError("output tensor does match 1 dtype.");
else if (elementTypeID == 2 &&
return emitOpError("output tensor does not match float dtype.");
else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE &&
outputType != FloatType::getF64(getContext()))
return emitOpError("output tensor does match 2 dtype.");
return emitOpError("output tensor does not match double dtype.");
else if (elementTypeID ==
onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16 &&
outputType != FloatType::getBF16(getContext()))
return emitOpError("output tensor does not match bfloat16 dtype.");
} else if (inputType != outputType) {
return emitOpError("output and input element types do not match.");
}

return success();
}

/// Element type of the RandomNormalLike result: the optional dtype attribute
/// wins when present (decoded via onnx.TensorProto.DataType); otherwise the
/// element type of the input tensor is reused.
static Type getRandomNormalLikeOutputElementType(ONNXRandomNormalLikeOp op) {
  const auto dtypeAttr = op.getDtypeAttr();
  if (!dtypeAttr) {
    // No dtype attribute: output element type mirrors the input's.
    const auto tensorTy = mlir::cast<TensorType>(op.getInput().getType());
    return tensorTy.getElementType();
  }
  OpBuilder builder(op.getContext());
  const auto onnxDtype = static_cast<onnx::TensorProto_DataType>(
      dtypeAttr.getValue().getSExtValue());
  return convertONNXTypeToMLIRType(builder, onnxDtype);
}

//===----------------------------------------------------------------------===//
// Type Inference
//===----------------------------------------------------------------------===//

std::vector<Type> ONNXRandomNormalLikeOp::resultTypeInference() {
Type elementType = getRandomNormalLikeOutputElementType(*this);
std::vector<Type> resultTypes;
if (auto rankedInputType =
mlir::dyn_cast<RankedTensorType>(getInput().getType())) {
resultTypes.push_back(rankedInputType.clone(elementType));
} else {
resultTypes.push_back(UnrankedTensorType::get(elementType));
}
return resultTypes;
}

//===----------------------------------------------------------------------===//
// Shape Inference
//===----------------------------------------------------------------------===//
Expand All @@ -65,24 +109,6 @@ LogicalResult ONNXRandomNormalLikeOp::inferShapes(
std::function<void(Region &)> doShapeInference) {
if (!hasShapeAndRank(getInput()))
return success();
auto inputType = mlir::cast<RankedTensorType>(getInput().getType());
auto elementTypeIDDType = getDtype();

// Default output tensor type in all cases is the input tensor type.
Type elementType;
if (!elementTypeIDDType) {
elementType = inputType.getElementType();
} else {
int64_t elementTypeID = elementTypeIDDType.value();
if (elementTypeID == 0)
elementType = FloatType::getF16(getContext());
else if (elementTypeID == 1)
elementType = FloatType::getF32(getContext());
else if (elementTypeID == 2)
elementType = FloatType::getF64(getContext());
else
return emitError("dtype attribute is invalid (use: 0, 1 or 2)");
}

Type elementType = getRandomNormalLikeOutputElementType(*this);
return inferShapeForUnaryOps(getOperation(), elementType);
}
4 changes: 2 additions & 2 deletions test/mlir/conversion/onnx_to_krnl/Math/RamdomNormal.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func.func @test_random_normal1() -> tensor<*xf32> {
// -----

func.func @test_random_normal2() -> tensor<*xf32> {
%0 = "onnx.RandomNormal"() {shape = [3, 4, 5], dtype = 2 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : () -> tensor<*xf32>
%0 = "onnx.RandomNormal"() {shape = [3, 4, 5], dtype = 11 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : () -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: @test_random_normal2
// CHECK-DAG: [[ALLOC:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<3x4x5xf64>
Expand All @@ -31,7 +31,7 @@ func.func @test_random_normal2() -> tensor<*xf32> {
// -----

func.func @test_random_normal3() -> tensor<*xf32> {
%0 = "onnx.RandomNormal"() {shape = [3, 4, 5], dtype = 2 : si64, mean = 0.0 :f32, scale = 1.0 : f32} : () -> tensor<*xf32>
%0 = "onnx.RandomNormal"() {shape = [3, 4, 5], dtype = 11 : si64, mean = 0.0 :f32, scale = 1.0 : f32} : () -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: @test_random_normal3
// CHECK-DAG: [[ALLOC:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<3x4x5xf64>
Expand Down
4 changes: 2 additions & 2 deletions test/mlir/conversion/onnx_to_krnl/Math/RamdomNormalLike.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func.func @test_random_normal_like1(%arg0: tensor<3x4x5xf32>) -> tensor<*xf32> {
// -----

func.func @test_random_normal_like2(%arg0: tensor<3x4x5xf32>) -> tensor<*xf32> {
%0 = "onnx.RandomNormalLike"(%arg0) {dtype = 2 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : (tensor<3x4x5xf32>) -> tensor<*xf32>
%0 = "onnx.RandomNormalLike"(%arg0) {dtype = 11 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : (tensor<3x4x5xf32>) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: @test_random_normal_like2
// CHECK-DAG: [[ALLOC:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<3x4x5xf64>
Expand All @@ -31,7 +31,7 @@ func.func @test_random_normal_like2(%arg0: tensor<3x4x5xf32>) -> tensor<*xf32> {
// -----

func.func @test_random_normal_like3(%arg0: tensor<3x4x5xf32>) -> tensor<*xf32> {
%0 = "onnx.RandomNormalLike"(%arg0) {dtype = 2 : si64, mean = 0.0 :f32, scale = 1.0 : f32} : (tensor<3x4x5xf32>) -> tensor<*xf32>
%0 = "onnx.RandomNormalLike"(%arg0) {dtype = 11 : si64, mean = 0.0 :f32, scale = 1.0 : f32} : (tensor<3x4x5xf32>) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: @test_random_normal_like3
// CHECK-DAG: [[ALLOC:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<3x4x5xf64>
Expand Down
68 changes: 44 additions & 24 deletions test/mlir/onnx/onnx_shape_inference.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2899,12 +2899,12 @@ func.func @test_onehot_dynamic(%arg0: tensor<?x2xi64>, %arg1: tensor<i64>, %arg2

// Test RandomNormal static

// f16 output: dtype = 10 (onnx.TensorProto.DataType FLOAT16) must infer f16.
func.func @test_random_normal_static_f16() -> tensor<*xf16> {
  %0 = "onnx.RandomNormal"() {shape = [3, 4, 5], dtype = 10 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : () -> tensor<*xf16>
  "onnx.Return"(%0) : (tensor<*xf16>) -> ()

  // CHECK-LABEL: @test_random_normal_static_f16
  // CHECK: [[R0:%.+]] = "onnx.RandomNormal"() {dtype = 10 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32, shape = [3, 4, 5]} : () -> tensor<3x4x5xf16>
}

// -----
Expand All @@ -2919,12 +2919,22 @@ func.func @test_random_normal_static_f32() -> tensor<*xf32> {

// -----

// f64 output: dtype = 11 (onnx.TensorProto.DataType DOUBLE) must infer f64.
func.func @test_random_normal_static_f64() -> tensor<*xf64> {
  %0 = "onnx.RandomNormal"() {shape = [3, 4, 5], dtype = 11 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : () -> tensor<*xf64>
  "onnx.Return"(%0) : (tensor<*xf64>) -> ()

  // CHECK-LABEL: @test_random_normal_static_f64
  // CHECK: [[R0:%.+]] = "onnx.RandomNormal"() {dtype = 11 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32, shape = [3, 4, 5]} : () -> tensor<3x4x5xf64>
}

// -----

// bf16 output: dtype = 16 (onnx.TensorProto.DataType BFLOAT16) must infer bf16.
func.func @test_random_normal_static_bf16() -> tensor<*xbf16> {
%0 = "onnx.RandomNormal"() {shape = [3, 4, 5], dtype = 16 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : () -> tensor<*xbf16>
"onnx.Return"(%0) : (tensor<*xbf16>) -> ()

// CHECK-LABEL: @test_random_normal_static_bf16
// CHECK: [[R0:%.+]] = "onnx.RandomNormal"() {dtype = 16 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32, shape = [3, 4, 5]} : () -> tensor<3x4x5xbf16>
}

//===----------------------------------------------------------------------===//
Expand All @@ -2933,12 +2943,12 @@ func.func @test_random_normal_static_f64() -> tensor<*xf32> {

// Test RandomNormalLike static

// dtype = 10 (FLOAT16) overrides the f32 input element type; static shape kept.
func.func @test_random_normal_like_static_f16(%arg0: tensor<1x1x28x28xf32>) -> tensor<*xf16> {
  %0 = "onnx.RandomNormalLike"(%arg0) {dtype = 10 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : (tensor<1x1x28x28xf32>) -> tensor<*xf16>
  "onnx.Return"(%0) : (tensor<*xf16>) -> ()

  // CHECK-LABEL: @test_random_normal_like_static_f16
  // CHECK: [[R0:%.+]] = "onnx.RandomNormalLike"(%arg0) {dtype = 10 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32} : (tensor<1x1x28x28xf32>) -> tensor<1x1x28x28xf16>
}

// -----
Expand All @@ -2953,24 +2963,34 @@ func.func @test_random_normal_like_static_f32(%arg0: tensor<1x1x28x28xf32>) -> t

// -----

func.func @test_random_normal_like_static_f64(%arg0: tensor<1x1x28x28xf32>) -> tensor<*xf32> {
%0 = "onnx.RandomNormalLike"(%arg0) {dtype = 2 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : (tensor<1x1x28x28xf32>) -> tensor<*xf32>
"onnx.Return"(%0) : (tensor<*xf32>) -> ()
func.func @test_random_normal_like_static_f64(%arg0: tensor<1x1x28x28xf32>) -> tensor<*xf64> {
%0 = "onnx.RandomNormalLike"(%arg0) {dtype = 11 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : (tensor<1x1x28x28xf32>) -> tensor<*xf64>
"onnx.Return"(%0) : (tensor<*xf64>) -> ()

// CHECK-LABEL: @test_random_normal_like_static_f64
// CHECK: [[R0:%.+]] = "onnx.RandomNormalLike"(%arg0) {dtype = 2 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32} : (tensor<1x1x28x28xf32>) -> tensor<1x1x28x28xf64>
// CHECK: [[R0:%.+]] = "onnx.RandomNormalLike"(%arg0) {dtype = 11 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32} : (tensor<1x1x28x28xf32>) -> tensor<1x1x28x28xf64>
}

// -----

// dtype = 16 (BFLOAT16) overrides the f32 input element type; static shape kept.
func.func @test_random_normal_like_static_bf16(%arg0: tensor<1x1x28x28xf32>) -> tensor<*xbf16> {
%0 = "onnx.RandomNormalLike"(%arg0) {dtype = 16 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : (tensor<1x1x28x28xf32>) -> tensor<*xbf16>
"onnx.Return"(%0) : (tensor<*xbf16>) -> ()

// CHECK-LABEL: @test_random_normal_like_static_bf16
// CHECK: [[R0:%.+]] = "onnx.RandomNormalLike"(%arg0) {dtype = 16 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32} : (tensor<1x1x28x28xf32>) -> tensor<1x1x28x28xbf16>
}

// -----

// Test RandomNormalLike dynamic

// Dynamic dim preserved; dtype = 10 (FLOAT16) overrides the input element type.
func.func @test_random_normal_like_dynamic_f16(%arg0: tensor<1x?x28x28xf32>) -> tensor<*xf16> {
  %0 = "onnx.RandomNormalLike"(%arg0) {dtype = 10 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : (tensor<1x?x28x28xf32>) -> tensor<*xf16>
  "onnx.Return"(%0) : (tensor<*xf16>) -> ()

  // CHECK-LABEL: @test_random_normal_like_dynamic_f16
  // CHECK: [[R0:%.+]] = "onnx.RandomNormalLike"(%arg0) {dtype = 10 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32} : (tensor<1x?x28x28xf32>) -> tensor<1x?x28x28xf16>
}

// -----
Expand All @@ -2985,12 +3005,12 @@ func.func @test_random_normal_like_dynamic_f32(%arg0: tensor<1x1x?x28xf32>) -> t

// -----

// Dynamic dim preserved; dtype = 11 (DOUBLE) overrides the input element type.
func.func @test_random_normal_like_dynamic_f64(%arg0: tensor<1x1x28x?xf32>) -> tensor<*xf64> {
  %0 = "onnx.RandomNormalLike"(%arg0) {dtype = 11 : si64, mean = 0.0 :f32, scale = 1.0 : f32, seed = 2.0 : f32} : (tensor<1x1x28x?xf32>) -> tensor<*xf64>
  "onnx.Return"(%0) : (tensor<*xf64>) -> ()

  // CHECK-LABEL: @test_random_normal_like_dynamic_f64
  // CHECK: [[R0:%.+]] = "onnx.RandomNormalLike"(%arg0) {dtype = 11 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32} : (tensor<1x1x28x?xf32>) -> tensor<1x1x28x?xf64>
}

// -----
Expand Down
19 changes: 19 additions & 0 deletions test/mlir/onnx/parse/random_normal_like_dtype_bf16.onnxtext
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: onnx-mlir --EmitONNXBasic --useOnnxModelTypes=false --printIR %s | FileCheck %s

// Test output type inference of RandomNormalLike ignoring the types from the model using --useOnnxModelTypes=false
// Output type should be bf16 as dtype = 16 even though the output type specified in model is float32

<
ir_version: 4,
opset_import: ["" : 22]
>
test_random_normal_like_dtype (float[unk__a,unk__b] RandomNormalLike_in) => (float[] RandomNormalLike_out)
{
RandomNormalLike_out = RandomNormalLike<dtype: int = 16, mean: float = 0.0, scale: float = 1.0, seed: float = 2.0> (RandomNormalLike_in)
}

// CHECK-LABEL: func.func @main_graph(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32> {onnx.dim_params = "0:unk__a,1:unk__b", onnx.name = "RandomNormalLike_in"}) -> (tensor<?x?xbf16> {onnx.name = "RandomNormalLike_out"}) {
// CHECK: %[[VAL_1:.*]] = "onnx.RandomNormalLike"(%[[VAL_0]]) {dtype = 16 : si64, mean = 0.000000e+00 : f32, scale = 1.000000e+00 : f32, seed = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xbf16>
// CHECK: onnx.Return %[[VAL_1]] : tensor<?x?xbf16>
// CHECK: }
Loading

0 comments on commit d137f49

Please sign in to comment.