From a42bd4714eeeca2def008759c4013a312f8c1a6a Mon Sep 17 00:00:00 2001
From: "Rickert, Jonas"
Date: Tue, 28 Jan 2025 13:57:11 +0000
Subject: [PATCH 1/2] Add support for TensorProto::UINT4/INT4

Signed-off-by: Rickert, Jonas
---
 src/Dialect/ONNX/ONNXOps/OpHelper.cpp | 10 ++++++++--
 utils/gen_onnx_mlir.py                | 10 ++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
index 0468919038..5ff1d54ac6 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -642,11 +642,13 @@ Type convertONNXTypeToMLIRType(
     return builder.getI1Type();
   case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
     return ONNXStringType::get(builder.getContext());
+  case onnx::TensorProto_DataType::TensorProto_DataType_INT4:
+    return builder.getIntegerType(/*width=*/4);
+  case onnx::TensorProto_DataType::TensorProto_DataType_UINT4:
+    return builder.getIntegerType(/*width=*/4, false);
 
   case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
   case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
-  case onnx::TensorProto_DataType::TensorProto_DataType_INT4:
-  case onnx::TensorProto_DataType::TensorProto_DataType_UINT4:
   case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
     llvm_unreachable("Unsupported data type encountered.");
     return nullptr;
@@ -696,6 +698,10 @@ int64_t mlirTypeToOnnxType(Type elemType) {
                          ? onnx::TensorProto::UNDEFINED
                          : onnx::TensorProto::BOOL;
           break;
+        case 4:
+          onnxType = type.isUnsigned() ? onnx::TensorProto::UINT4
+                                       : onnx::TensorProto::INT4;
+          break;
         case 8:
           onnxType = type.isUnsigned() ? onnx::TensorProto::UINT8
                                        : onnx::TensorProto::INT8;
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index a32c931521..29c1e9960f 100755
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -610,6 +610,10 @@
 #   FLOAT8E5M2 = 19;       // follows IEEE 754, supports nan, inf, mostly used for gradients
 #   FLOAT8E5M2FNUZ = 20;   // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
 #
+#   // 4-bit integer data types
+#   UINT4 = 21;   // Unsigned integer in range [0, 15]
+#   INT4 = 22;    // Signed integer in range [-8, 7], using two's-complement representation
+#
 #   // Future extensions go here.
 # }
 onnx_types = (
@@ -634,6 +638,8 @@
     "float8e4m3fnuz",
     "float8e5m2",
     "float8e5m2fnuz",
+    "uint4",
+    "int4",
 )
 tblgen_types = (
     "BF16",
@@ -657,6 +663,8 @@
     "F8E4M3FNUZ",
     "F8E5M2",
     "F8E5M2FNUZ",
+    "AnyUI4",
+    "AnyI4",
 )
 
 # Maximum count for actual type. Number more than MAX_NUM_TYPES will be used to encode
@@ -1047,10 +1055,12 @@ def parse_type_str(allowedType):
         "seq": "SeqOf",
         "map": "TupleOf",
         "bool": "I1",
+        "uint4": "UI<4>",
         "uint8": "UI8",
         "uint16": "UI16",
         "uint32": "UI32",
         "uint64": "UI64",
+        "int4": "I<4>",
         "int8": "I8",
         "int16": "I16",
         "int32": "I32",

From 803b8b133eb0b12f5624ff451e5a25921d20094b Mon Sep 17 00:00:00 2001
From: "Rickert, Jonas"
Date: Tue, 28 Jan 2025 14:02:20 +0000
Subject: [PATCH 2/2] Upgrade onnx.Cast to opset 21

Signed-off-by: Rickert, Jonas
---
 docs/Dialects/onnx.md                      |  4 ++--
 src/Builder/OpBuildTable.inc               |  2 +-
 src/Dialect/ONNX/ONNXOps.td.inc            |  4 ++--
 .../parse/cast_to_int_4_and_back.onnxtext  | 19 +++++++++++++++++++
 utils/gen_onnx_mlir.py                     |  2 +-
 5 files changed, 25 insertions(+), 6 deletions(-)
 create mode 100644 test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext

diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md
index 3996ad35d6..5470bd6bc0 100644
--- a/docs/Dialects/onnx.md
+++ b/docs/Dialects/onnx.md
@@ -1114,13 +1114,13 @@ Effects: `MemoryEffects::Effect{}`
 
 | Operand | Description |
 | :-----: | ----------- |
-| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values
+| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values
 
 #### Results:
 
 | Result | Description |
 | :----: | ----------- |
-| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values
+| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values
 
 ### `onnx.CategoryMapper` (ONNXCategoryMapperOp)
 
diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc
index 067839b22f..ecb5ed8920 100644
--- a/src/Builder/OpBuildTable.inc
+++ b/src/Builder/OpBuildTable.inc
@@ -28,7 +28,7 @@ op_dialect_version_map_["BitwiseNot"] = {18};
 op_dialect_version_map_["BitwiseOr"] = {18};
 op_dialect_version_map_["BitwiseXor"] = {18};
 op_dialect_version_map_["BlackmanWindow"] = {17};
-op_dialect_version_map_["Cast"] = {19};
+op_dialect_version_map_["Cast"] = {21};
 op_dialect_version_map_["CastLike"] = {19};
 op_dialect_version_map_["CastMap"] = {1};
 op_dialect_version_map_["CategoryMapper"] = {1};
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 685c5438be..47d5adb32f 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -863,10 +863,10 @@ def ONNXCastOp:ONNX_Op<"Cast",
     | [x] < -FLT_MAX | NaN | NaN | -Inf | NaN |
     | else | RNE | RNE | RNE | RNE |
   }];
-  let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$input,
+  let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$input,
     DefaultValuedAttr<SI64Attr, "1">:$saturate,
     TypeAttr:$to);
-  let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output);
+  let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$output);
   let extraClassDeclaration = [{
     static int getNumberOfOperands() {
       return 1;
diff --git a/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext b/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext
new file mode 100644
index 0000000000..c5005ca136
--- /dev/null
+++ b/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext
@@ -0,0 +1,19 @@
+// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s
+<
+   ir_version: 10,
+   opset_import: ["" : 22]
+>
+test_int4_casting (int4[1] input, uint4[1] input2) => (int4[1] int4_cast_output, uint4[1] uint4_cast_output) {
+   int8_cast_output = Cast <to = 3> (input)
+   int4_cast_output = Cast <to = 22> (int8_cast_output)
+   uint8_cast_output = Cast <to = 2> (input2)
+   uint4_cast_output = Cast <to = 21> (uint8_cast_output)
+}
+// CHECK-LABEL:  func.func @main_graph
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<1xi4> {onnx.name = "input"}, [[PARAM_1_:%.+]]: tensor<1xui4> {onnx.name = "input2"}) -> (tensor<1xi4> {onnx.name = "int4_cast_output"}, tensor<1xui4> {onnx.name = "uint4_cast_output"}) {
+// CHECK-DAG:      [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8>
+// CHECK-DAG:      [[VAR_1_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4>
+// CHECK-DAG:      [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8>
+// CHECK-DAG:      [[VAR_3_:%.+]] = "onnx.Cast"([[VAR_2_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4>
+// CHECK:          onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4>
+// CHECK:        }
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 29c1e9960f..d5de80a908 100755
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -109,7 +109,7 @@
     "BitwiseOr": [18],
     "BitwiseXor": [18],
     "BlackmanWindow": [17],
-    "Cast": [19],
+    "Cast": [21],
     "CastLike": [19],
     "CastMap": [1],
     "CategoryMapper": [1],
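Reviewer note (not part of the patch): the OpHelper.cpp hunk maps INT4 to the signless 4-bit MLIR type and UINT4 to the explicitly unsigned one, mirroring the existing INT8/UINT8 handling. The standalone C++ sketch below only illustrates what those two builder calls produce; it assumes an MLIR development tree to compile against, and main() with the printed output is illustrative, not onnx-mlir code.

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);

  // TensorProto INT4: signless 4-bit integer, the same convention the dialect
  // already uses for the signed 8/16/32/64-bit types; prints as "i4".
  mlir::Type int4 = builder.getIntegerType(/*width=*/4);

  // TensorProto UINT4: explicitly unsigned 4-bit integer; prints as "ui4".
  mlir::Type uint4 = builder.getIntegerType(/*width=*/4, /*isSigned=*/false);

  int4.print(llvm::outs());
  llvm::outs() << "\n";
  uint4.print(llvm::outs());
  llvm::outs() << "\n";
  return 0;
}

The reverse mapping added to mlirTypeToOnnxType keys its new "case 4:" on type.isUnsigned(), so ui4 round-trips to UINT4 and signless i4 to INT4.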
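Reviewer note (not part of the patch): in the new lit test the Cast "to" attribute carries the TensorProto enum code (INT8 = 3, UINT8 = 2, plus the UINT4 = 21 / INT4 = 22 values documented in the gen_onnx_mlir.py comment above). The widening casts are where the 4-bit ranges from that comment matter: INT4 covers [-8, 7] in two's complement, UINT4 covers [0, 15]. The C++ sketch below is purely illustrative of that range behaviour; signExtendInt4 is a made-up helper, not an onnx-mlir API.

#include <cstdint>
#include <cstdio>

// Widen the low 4 bits of a byte to int8_t the way a two's-complement INT4
// payload widens when cast to INT8: values 8..15 become -8..-1.
static int8_t signExtendInt4(uint8_t nibble) {
  nibble &= 0x0F;  // keep only the 4-bit payload
  return static_cast<int8_t>(nibble < 8 ? nibble : nibble - 16);
}

int main() {
  // As UINT4 the raw bits pass through unchanged (0..15); reinterpreted as
  // INT4, the upper half of that range maps to the negative values.
  for (uint8_t n = 0; n < 16; ++n)
    std::printf("bits %2d -> uint4 %2d, int4 %3d\n", n, n, signExtendInt4(n));
  return 0;
}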