From 599ab5442012ebeb70728b1e3d7762b59a5dd7b8 Mon Sep 17 00:00:00 2001 From: Boyana Norris Date: Fri, 21 Feb 2025 10:08:12 -0800 Subject: [PATCH] update float types, tosa, other misc changes Signed-off-by: Boyana Norris --- .../ZLowToLLVM/ZLowToLLVMCommon.cpp | 8 +++---- .../KrnlToLLVM/KrnlRandomNormal.cpp | 4 ++-- src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp | 13 +++++----- .../KrnlToLLVM/KrnlVectorTypeCast.cpp | 4 ++-- .../ONNXToKrnl/ML/CategoryMapper.cpp | 2 +- src/Conversion/ONNXToKrnl/Math/LRN.cpp | 2 +- src/Conversion/ONNXToTOSA/DialectBuilder.cpp | 24 ++++++++++++------- .../ONNXToTOSA/Math/Elementwise.cpp | 20 ++++++++++++---- src/Conversion/ONNXToTOSA/Math/Gemm.cpp | 10 ++++---- src/Dialect/ONNX/ElementsAttr/BType.cpp | 10 ++++---- src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp | 2 +- .../ONNX/ONNXOps/Math/ElementwiseUnary.cpp | 2 +- .../ONNX/ONNXOps/Math/RandomNormal.cpp | 14 +++++------ .../ONNX/ONNXOps/Math/RandomNormalLike.cpp | 14 +++++------ src/Dialect/ONNX/ONNXOps/OpHelper.cpp | 8 +++---- .../Quantize/DynamicQuantizeLinear.cpp | 2 +- src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp | 4 ++-- .../ONNX/ONNXOps/Tensor/ConstantOfShape.cpp | 4 ++-- utils/clone-mlir.sh | 2 +- 19 files changed, 83 insertions(+), 66 deletions(-) diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp index 114c19d618..6853f9d070 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp @@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) { auto int16Ty = IntegerType::get(context, 16); auto int32Ty = IntegerType::get(context, 32); auto int64Ty = IntegerType::get(context, 64); - auto float32Ty = FloatType::getF32(context); + auto float32Ty = Float32Type::get(context); // Declare API type as an enum value, its string name and an LLVM Type // specifying its signature. @@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) { Type llvmI64Ty = IntegerType::get(context, 64); Type llvmI1Ty = IntegerType::get(context, 1); Type llvmI8Ty = IntegerType::get(context, 8); - Type llvmF32Ty = FloatType::getF32(context); + Type llvmF32Ty = Float32Type::get(context); Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3); Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20); Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty); @@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module, scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float"); create.llvm.store(recScale, recScalePtr); } else { - Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.); + Value zero = create.llvm.constant(Float32Type::get(context), (double)0.); create.llvm.store(zero, recScalePtr); } @@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module, offsetTy.isF32() && "Wrong type for zTensor's offset. 
Must be float"); create.llvm.store(offset, offsetPtr); } else { - Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.); + Value zero = create.llvm.constant(Float32Type::get(context), (double)0.); create.llvm.store(zero, offsetPtr); } diff --git a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp index e976b42b7f..5a4c494f14 100644 --- a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp @@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern { // or // (memref<3x4x5xf64>, index, f64, f64, f64) Type llvmVoidTy = LLVM::LLVMVoidType::get(context); - Type llvmOptionsTy = FloatType::getF32(context); + Type llvmOptionsTy = Float32Type::get(context); Type llvmOutputTy = getPointerType(context, llvmOptionsTy); if (inType.isF64()) { - llvmOptionsTy = FloatType::getF64(context); + llvmOptionsTy = Float64Type::get(context); llvmOutputTy = getPointerType(context, llvmOptionsTy); } Type llvmI64Ty = IntegerType::get(context, 64); diff --git a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp index 2a0ee747c7..a50acf402f 100644 --- a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp @@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern { Type outType = op->getResultTypes().front(); Type llvmInType, llvmOutType; if (inType.isF16()) - llvmInType = FloatType::getF16(context); + llvmInType = Float16Type::get(context); else if (inType.isF32()) - llvmInType = FloatType::getF32(context); + llvmInType = Float32Type::get(context); else if (inType.isF64()) - llvmInType = FloatType::getF64(context); + llvmInType = Float64Type::get(context); else if (inType.isBF16()) - llvmInType = FloatType::getBF16(context); + llvmInType = Float64Type::get(context); if (outType.isInteger(1)) llvmOutType = IntegerType::get(context, 1); else if (outType.isF32()) - llvmOutType = FloatType::getF32(context); + llvmOutType = Float32Type::get(context); else if (outType.isF64()) - llvmOutType = FloatType::getF64(context); + llvmOutType = Float64Type::get(context); // Insert and/or get reference to elementary math function declaration. assert( @@ -214,7 +214,6 @@ class KrnlUnaryMathOpLowering : public ConversionPattern { return SymbolRefAttr::get(context, mathFuncName); // Create function declaration. - // auto llvmF32Ty = FloatType::get(context); auto llvmFnType = LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef({llvmInType})); diff --git a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp index 62d7c25de3..a52e57afe7 100644 --- a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp @@ -62,7 +62,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern { // Get memRefDescriptor, the new memref descriptor. MemRefDescriptor memRefDescriptor = - MemRefDescriptor::undef(rewriter, loc, targetStructType); + MemRefDescriptor::poison(rewriter, loc, targetStructType); auto targetElementPtrType = memRefDescriptor.getElementPtrType(); // Set the new memref to the same buffer as the source memref. @@ -78,7 +78,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern { int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(targetType, strides, offset))) + if (failed(targetType.getStridesAndOffset(strides, offset))) return failure(); // Unhandled dynamic offset. 
diff --git a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
index 565e63a7d7..00e252fdb6 100644
--- a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
+++ b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
@@ -281,7 +281,7 @@ struct ONNXCategoryMapperOpLowering
       SmallVector<int64_t, 4> strides;
       int64_t alignmentOffset; // not used, just to make the function call
                                // completed.
-      if (getStridesAndOffset(memRefType, strides, alignmentOffset)
+      if (memRefType.getStridesAndOffset(strides, alignmentOffset)
               .failed())
         llvm_unreachable("Failed to get strides");
       Value stringMemRef =
diff --git a/src/Conversion/ONNXToKrnl/Math/LRN.cpp b/src/Conversion/ONNXToKrnl/Math/LRN.cpp
index 1b08661a2d..12a596d08c 100644
--- a/src/Conversion/ONNXToKrnl/Math/LRN.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/LRN.cpp
@@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern<ONNXLRNOp> {
     float alphaLit = adaptor.getAlpha().convertToFloat();
     float betaLit = adaptor.getBeta().convertToFloat();
     int sizeLit = adaptor.getSize();
-    auto f32Type = FloatType::getF32(rewriter.getContext());
+    auto f32Type = Float32Type::get(rewriter.getContext());
     Value biasValue = create.math.constant(f32Type, biasLit);
     Value alphaDivSizeValue =
         create.math.constant(f32Type, alphaLit / static_cast<float>(sizeLit));
diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
index adf494c88e..4b6755520c 100644
--- a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
+++ b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 
 #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -147,14 +148,16 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {
 
 Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef<int64_t> size,
     llvm::ArrayRef<int64_t> start) {
-  DenseI64ArrayAttr sizeAttr = rewriter().getDenseI64ArrayAttr(size);
-  DenseI64ArrayAttr startAttr = rewriter().getDenseI64ArrayAttr(start);
+  auto startVal =
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(start));
+  auto sizeVal =
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(size));
   Value newSliceInput =
       tosa::CreateOpAndInfer<mlir::tosa::SliceOp>(rewriter(), loc(),
           RankedTensorType::get(
              llvm::SmallVector<int64_t>(size.size(), ShapedType::kDynamic),
              mlir::cast<ShapedType>(inputConst.getType()).getElementType()),
-          inputConst, startAttr, sizeAttr);
+          inputConst, startVal, sizeVal);
   return newSliceInput;
 }
@@ -164,8 +167,9 @@ Value TosaBuilder::reshape(Value &value, llvm::ArrayRef<int64_t> shape) {
   Type newValueType = RankedTensorType::get(
       llvm::SmallVector<int64_t>(shape.size(), ShapedType::kDynamic),
       valueType.getElementType());
-  return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(
-      rewriter(), loc(), newValueType, value, shapeAttr);
+  return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter(), loc(),
+      newValueType, value,
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), shapeAttr));
 }
 
@@ -178,8 +182,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
   Type newValueType = RankedTensorType::get(
       llvm::SmallVector<int64_t>(lhsType.getRank(), ShapedType::kDynamic),
       lhsType.getElementType());
+
+  auto int8Type = rewriter().getI8Type();
+  auto shiftValue = TosaBuilder::createConst(
+      ArrayRef<int8_t>{static_cast<int8_t>(shift)}, {1}, int8Type);
   return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
-      rewriter(), loc(), newValueType, lhs, rhs, shift);
+
      rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
 }
 
 Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
@@ -236,8 +244,8 @@ template Value TosaBuilder::binaryOp(Value &lhs, Value &rhs);
 // Return null if none is found.
 ElementsAttr IndexExprBuilderForTosa::getConst(Value value) {
   auto definingOp = value.getDefiningOp();
-  // If we have a cast between index/integer, skip it, i.e. get the defining op
-  // that is the input to the cast.
+  // If we have a cast between index/integer, skip it, i.e. get the defining
+  // op that is the input to the cast.
   if (auto castOp = dyn_cast_or_null<mlir::arith::IndexCastOp>(definingOp)) {
     Value input = castOp.getIn();
     definingOp = input.getDefiningOp();
diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
index 2e105d2dc5..ab8b9a43a0 100644
--- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -121,11 +121,21 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
     // Quantized types are not supported right now (in type conversion).
     // Once they are, the input should be rescaled for quantized types. (TBD)
     // Maps to `tosa.clamp` which has both int and fp limits.
-    rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(op, op.getType(), input,
-        rewriter.getI64IntegerAttr(0),
-        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
-        rewriter.getF32FloatAttr(0.0f),
-        rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
+    auto inputElementType =
+        llvm::cast<TensorType>(op.getType()).getElementType();
+    if (llvm::isa<IntegerType>(inputElementType)) {
+      auto minClamp = rewriter.getI64IntegerAttr(0);
+      auto maxClamp =
+          rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max());
+      rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
+          op, op.getType(), input, minClamp, maxClamp);
+    } else {
+      auto minClamp = rewriter.getF32FloatAttr(0.0f);
+      auto maxClamp =
+          rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
+      rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
+          op, op.getType(), input, minClamp, maxClamp);
+    }
     return success();
   }
 };
diff --git a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
index 4f1028002c..9d25f922ad 100644
--- a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
+++ b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
 #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
@@ -67,13 +68,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {
     llvm::SmallVector<int64_t> dynamicTensorShape = {
         ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic};
-    A = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
+
+    A = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
            RankedTensorType::get(dynamicTensorShape, AType.getElementType()), A,
-           rewriter.getDenseI64ArrayAttr(newShapeA))
-            .getResult();
+           mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeA))
+            .getResult();
     B = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
            RankedTensorType::get(dynamicTensorShape, BType.getElementType()), B,
-           rewriter.getDenseI64ArrayAttr(newShapeB))
+           mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeB))
            .getResult();
 
     // If transA or transB are present, create Transpose operators.
diff --git a/src/Dialect/ONNX/ElementsAttr/BType.cpp b/src/Dialect/ONNX/ElementsAttr/BType.cpp
index 8073d2a4e2..a6aa4b17f5 100644
--- a/src/Dialect/ONNX/ElementsAttr/BType.cpp
+++ b/src/Dialect/ONNX/ElementsAttr/BType.cpp
@@ -55,10 +55,10 @@ Type mlirTypeOfBType(BType btype, MLIRContext *ctx) {
     case BType::FLOAT          : return b.getF32Type();
     case BType::FLOAT16        : return b.getF16Type();
     case BType::BFLOAT16       : return b.getBF16Type();
-    case BType::FLOAT8E4M3FN   : return b.getFloat8E4M3FNType();
-    case BType::FLOAT8E4M3FNUZ : return b.getFloat8E4M3FNUZType();
-    case BType::FLOAT8E5M2     : return b.getFloat8E5M2Type();
-    case BType::FLOAT8E5M2FNUZ : return b.getFloat8E5M2FNUZType();
+    case BType::FLOAT8E4M3FN   : return b.getType<Float8E4M3FNType>();
+    case BType::FLOAT8E4M3FNUZ : return b.getType<Float8E4M3FNUZType>();
+    case BType::FLOAT8E5M2     : return b.getType<Float8E5M2Type>();
+    case BType::FLOAT8E5M2FNUZ : return b.getType<Float8E5M2FNUZType>();
     default: llvm_unreachable("unsupported data type");
   }
   // clang-format on
@@ -104,4 +104,4 @@ BType wideBTypeOfBType(BType d) {
       [](auto btype) { return toBType<typename BTypeTrait<btype>::widetype>; });
 }
 
-} // namespace onnx_mlir
\ No newline at end of file
+} // namespace onnx_mlir
diff --git a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
index 47a74a0093..56fd3c5ca8 100644
--- a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
+++ b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
@@ -96,7 +96,7 @@ LogicalResult ONNXOneHotEncoderOp::inferShapes(
     return success();
 
   ONNXOneHotEncoderOpShapeHelper shapeHelper(getOperation(), {});
-  return shapeHelper.computeShapeAndUpdateType(FloatType::getF32(getContext()));
+  return shapeHelper.computeShapeAndUpdateType(Float32Type::get(getContext()));
   return success();
 }
 
diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
index a38ddfcb11..13308602cd 100644
--- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
@@ -452,7 +452,7 @@ LogicalResult ONNXScalerOp::inferShapes(
   ONNXUnaryOpShapeHelper shapeHelper(getOperation(), {});
   RankedTensorType xType = mlir::dyn_cast<RankedTensorType>(getX().getType());
   return shapeHelper.computeShapeAndUpdateType(
-      FloatType::getF32(getContext()), xType.getEncoding());
+      Float32Type::get(getContext()), xType.getEncoding());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp b/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
index 926f37764f..e5cdb01cde 100644
--- a/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
@@ -47,16 +47,16 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
   Type elementType;
   if (auto attr = getDtypeAttr()) {
     if (getDtype() == 0) {
-      elementType = FloatType::getF16(getContext());
+      elementType = Float16Type::get(getContext());
     } else if (getDtype() == 1) {
-      elementType = FloatType::getF32(getContext());
+      elementType = Float32Type::get(getContext());
     } else if (getDtype() == 2) {
-      elementType = FloatType::getF64(getContext());
+      elementType = Float64Type::get(getContext());
     } else {
       llvm_unreachable("dtype not supported for RandomNormal");
     }
   } else {
-    elementType = FloatType::getF32(getContext());
+    elementType = Float32Type::get(getContext());
   }
   return {UnrankedTensorType::get(elementType)};
 }
@@ -68,11 +68,11 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
 LogicalResult ONNXRandomNormalOp::inferShapes(
     std::function<void(Region &)> doShapeInference) {
   auto elementTypeID = 
getDtype();
-  Type elementType = FloatType::getF32(getContext());
+  Type elementType = Float32Type::get(getContext());
   if (elementTypeID == 0)
-    elementType = FloatType::getF16(getContext());
+    elementType = Float16Type::get(getContext());
   else if (elementTypeID == 2)
-    elementType = FloatType::getF64(getContext());
+    elementType = Float64Type::get(getContext());
 
   ONNXRandomNormalOpShapeHelper shapeHelper(getOperation(), {});
   return shapeHelper.computeShapeAndUpdateType(elementType);
diff --git a/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp b/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
index 9df2bbe18b..321d2b55a1 100644
--- a/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
@@ -42,13 +42,11 @@ LogicalResult ONNXRandomNormalLikeOp::verify() {
     if (elementTypeID < 0 || elementTypeID > 2) {
       return emitOpError("dtype not 0, 1 or 2.");
     }
-    if (elementTypeID == 0 && outputType != FloatType::getF16(getContext()))
+    if (elementTypeID == 0 && outputType != Float16Type::get(getContext()))
       return emitOpError("output tensor does match 0 dtype.");
-    else if (elementTypeID == 1 &&
-             outputType != FloatType::getF32(getContext()))
+    else if (elementTypeID == 1 && outputType != Float32Type::get(getContext()))
       return emitOpError("output tensor does match 1 dtype.");
-    else if (elementTypeID == 2 &&
-             outputType != FloatType::getF64(getContext()))
+    else if (elementTypeID == 2 && outputType != Float64Type::get(getContext()))
       return emitOpError("output tensor does match 2 dtype.");
   } else if (inputType != outputType) {
     return emitOpError("output and input element types do not match.");
@@ -75,11 +73,11 @@ LogicalResult ONNXRandomNormalLikeOp::inferShapes(
   } else {
     int64_t elementTypeID = elementTypeIDDType.value();
     if (elementTypeID == 0)
-      elementType = FloatType::getF16(getContext());
+      elementType = Float16Type::get(getContext());
     else if (elementTypeID == 1)
-      elementType = FloatType::getF32(getContext());
+      elementType = Float32Type::get(getContext());
     else if (elementTypeID == 2)
-      elementType = FloatType::getF64(getContext());
+      elementType = Float64Type::get(getContext());
     else
       return emitError("dtype attribute is invalid (use: 0, 1 or 2)");
   }
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
index 7f260f2e99..9eb0b27d21 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -608,13 +608,13 @@ Type convertONNXTypeToMLIRType(
     Builder &builder, onnx::TensorProto_DataType onnxType) {
   switch (onnxType) {
   case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
-    return builder.getFloat8E4M3FNType();
+    return builder.getType<Float8E4M3FNType>();
   case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
-    return builder.getFloat8E4M3FNUZType();
+    return builder.getType<Float8E4M3FNUZType>();
   case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
-    return builder.getFloat8E5M2Type();
+    return builder.getType<Float8E5M2Type>();
   case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
-    return builder.getFloat8E5M2FNUZType();
+    return builder.getType<Float8E5M2FNUZType>();
   case onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16:
     return builder.getBF16Type();
   case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
diff --git a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp
index 7f27d19ebb..ae1ea165fd 100644
--- a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp
+++ 
b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp
@@ -61,7 +61,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes(
 
   IntegerType ui8Type =
       IntegerType::get(getContext(), 8, IntegerType::Unsigned);
-  FloatType f32Type = FloatType::getF32(getContext());
+  FloatType f32Type = Float32Type::get(getContext());
 
   ONNXDynamicQuantizeLinearOpShapeHelper shapeHelper(getOperation(), {});
   return shapeHelper.computeShapeAndUpdateTypes(
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp
index 70ee132682..bfa487d74a 100644
--- a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp
@@ -54,10 +54,10 @@ std::vector<Type> ONNXConstantOp::resultTypeInference() {
   } else if (auto attr = getSparseValueAttr()) {
     type = mlir::cast<ElementsAttr>(attr).getShapedType();
   } else if (auto attr = getValueFloatAttr()) {
-    type = RankedTensorType::get({}, FloatType::getF32(getContext()));
+    type = RankedTensorType::get({}, Float32Type::get(getContext()));
   } else if (auto attr = getValueFloatsAttr()) {
     int64_t size = attr.size();
-    type = RankedTensorType::get({size}, FloatType::getF32(getContext()));
+    type = RankedTensorType::get({size}, Float32Type::get(getContext()));
   } else if (auto attr = getValueIntAttr()) {
     type = RankedTensorType::get({}, IntegerType::get(getContext(), 64));
   } else if (auto attr = getValueIntsAttr()) {
diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp
index 6058adfcdb..773152fc52 100644
--- a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp
@@ -99,7 +99,7 @@ std::vector<Type> ONNXConstantOfShapeOp::resultTypeInference() {
   if (auto attr = getValueAttr()) {
     elementType = mlir::cast<ElementsAttr>(attr).getElementType();
   } else {
-    elementType = FloatType::getF32(getContext());
+    elementType = Float32Type::get(getContext());
   }
   return {UnrankedTensorType::get(elementType)};
 }
@@ -125,7 +125,7 @@ LogicalResult ONNXConstantOfShapeOp::inferShapes(
   } else {
     // If 'value' attribute is not specified, it defaults to a tensor of
     // value 0 and datatype float32.
-    elementType = FloatType::getF32(getContext());
+    elementType = Float32Type::get(getContext());
     llvm::SmallVector<int64_t> dims(1, 1);
     auto tensorType = RankedTensorType::get(dims, elementType);
 
diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh
index 804dff5fda..c7366d7b5f 100644
--- a/utils/clone-mlir.sh
+++ b/utils/clone-mlir.sh
@@ -1,3 +1,3 @@
 git clone -n https://github.com/llvm/llvm-project.git
 # Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
+cd llvm-project && git checkout 0e779ad4998ef65907502101c5b82ede05ddfa4e && cd ..