Skip to content

Commit 6e5d133

Browse files
committed
Bump onnx.Cast to opset 21, adding int/uint4 support (onnx#3057)
* Add support for TensorProto::UINT4/INT4
  Signed-off-by: Rickert, Jonas <[email protected]>
* Upgrade onnx.Cast to opset 21
  Signed-off-by: Rickert, Jonas <[email protected]>
---------
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent d863cb7 commit 6e5d133

File tree

6 files changed

+43
-8
lines changed

6 files changed

+43
-8
lines changed

docs/Dialects/onnx.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -1163,13 +1163,13 @@ Effects: `MemoryEffects::Effect{}`
11631163

11641164
| Operand | Description |
11651165
| :-----: | ----------- |
1166-
| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values
1166+
| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values
11671167

11681168
#### Results:
11691169

11701170
| Result | Description |
11711171
| :----: | ----------- |
1172-
| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values
1172+
| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values
11731173

11741174
### `onnx.CategoryMapper` (ONNXCategoryMapperOp)
11751175

src/Builder/OpBuildTable.inc

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ op_dialect_version_map_["BitwiseNot"] = {18};
2828
op_dialect_version_map_["BitwiseOr"] = {18};
2929
op_dialect_version_map_["BitwiseXor"] = {18};
3030
op_dialect_version_map_["BlackmanWindow"] = {17};
31-
op_dialect_version_map_["Cast"] = {19};
31+
op_dialect_version_map_["Cast"] = {21};
3232
op_dialect_version_map_["CastLike"] = {19};
3333
op_dialect_version_map_["CastMap"] = {1};
3434
op_dialect_version_map_["CategoryMapper"] = {1};

src/Dialect/ONNX/ONNXOps.td.inc

+2-2
Original file line numberDiff line numberDiff line change
@@ -911,10 +911,10 @@ def ONNXCastOp:ONNX_Op<"Cast",
911911
| [x] < -FLT_MAX | NaN | NaN | -Inf | NaN |
912912
| else | RNE | RNE | RNE | RNE |
913913
}];
914-
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$input,
914+
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$input,
915915
DefaultValuedAttr<SI64Attr, "1">:$saturate,
916916
TypeAttr:$to);
917-
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output);
917+
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$output);
918918
let extraClassDeclaration = [{
919919
static int getNumberOfOperands() {
920920
return 1;

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -667,11 +667,13 @@ Type convertONNXTypeToMLIRType(
667667
return builder.getI1Type();
668668
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
669669
return ONNXStringType::get(builder.getContext());
670+
case onnx::TensorProto_DataType::TensorProto_DataType_INT4:
671+
return builder.getIntegerType(/*width=*/4);
672+
case onnx::TensorProto_DataType::TensorProto_DataType_UINT4:
673+
return builder.getIntegerType(/*width=*/4, false);
670674

671675
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
672676
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
673-
case onnx::TensorProto_DataType::TensorProto_DataType_INT4:
674-
case onnx::TensorProto_DataType::TensorProto_DataType_UINT4:
675677
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
676678
llvm_unreachable("Unsupported data type encountered.");
677679
return nullptr;
@@ -721,6 +723,10 @@ int64_t mlirTypeToOnnxType(Type elemType) {
721723
? onnx::TensorProto::UNDEFINED
722724
: onnx::TensorProto::BOOL;
723725
break;
726+
case 4:
727+
onnxType = type.isUnsigned() ? onnx::TensorProto::UINT4
728+
: onnx::TensorProto::INT4;
729+
break;
724730
case 8:
725731
onnxType = type.isUnsigned() ? onnx::TensorProto::UINT8
726732
: onnx::TensorProto::INT8;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s
2+
<
3+
ir_version: 10,
4+
opset_import: ["" : 22]
5+
>
6+
test_int4_casting (int4[1] input, uint4[1] input2) => (int4[1] int4_cast_output, uint4[1] uint4_cast_output) {
7+
int8_cast_output = Cast <to: int = 3> (input)
8+
int4_cast_output = Cast <to: int = 22> (int8_cast_output)
9+
uint8_cast_output = Cast <to: int = 2> (input2)
10+
uint4_cast_output = Cast <to: int = 21> (uint8_cast_output)
11+
}
12+
// CHECK-LABEL: func.func @main_graph
13+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4> {onnx.name = "input"}, [[PARAM_1_:%.+]]: tensor<1xui4> {onnx.name = "input2"}) -> (tensor<1xi4> {onnx.name = "int4_cast_output"}, tensor<1xui4> {onnx.name = "uint4_cast_output"}) {
14+
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8>
15+
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4>
16+
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8>
17+
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Cast"([[VAR_2_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4>
18+
// CHECK: onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4>
19+
// CHECK: }

utils/gen_onnx_mlir.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
"BitwiseOr": [18],
110110
"BitwiseXor": [18],
111111
"BlackmanWindow": [17],
112-
"Cast": [19],
112+
"Cast": [21],
113113
"CastLike": [19],
114114
"CastMap": [1],
115115
"CategoryMapper": [1],
@@ -614,6 +614,10 @@
614614
# FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
615615
# FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
616616
#
617+
# // 4-bit integer data types
618+
# UINT4 = 21; // Unsigned integer in range [0, 15]
619+
# INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation
620+
#
617621
# // Future extensions go here.
618622
# }
619623
onnx_types = (
@@ -638,6 +642,8 @@
638642
"float8e4m3fnuz",
639643
"float8e5m2",
640644
"float8e5m2fnuz",
645+
"uint4",
646+
"int4",
641647
)
642648
tblgen_types = (
643649
"BF16",
@@ -661,6 +667,8 @@
661667
"F8E4M3FNUZ",
662668
"F8E5M2",
663669
"F8E5M2FNUZ",
670+
"AnyUI4",
671+
"AnyI4",
664672
)
665673

666674
# Maximum count for actual type. Number more than MAX_NUM_TYPES will be used to encode
@@ -1051,10 +1059,12 @@ def parse_type_str(allowedType):
10511059
"seq": "SeqOf",
10521060
"map": "TupleOf",
10531061
"bool": "I1",
1062+
"uint4": "UI<4>",
10541063
"uint8": "UI8",
10551064
"uint16": "UI16",
10561065
"uint32": "UI32",
10571066
"uint64": "UI64",
1067+
"int4": "I<4>",
10581068
"int8": "I8",
10591069
"int16": "I16",
10601070
"int32": "I32",

0 commit comments

Comments
 (0)