12
12
//
13
13
// ===----------------------------------------------------------------------===//
14
14
15
+ #include " mlir/Dialect/ControlFlow/IR/ControlFlow.h"
16
+ #include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17
+
18
+ #include " src/Compiler/CompilerOptions.hpp"
15
19
#include " src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
16
20
#include " src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
17
21
@@ -37,7 +41,8 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
37
41
Location loc = ONNXLoc<ONNXGatherOp>(op);
38
42
ValueRange operands = adaptor.getOperands ();
39
43
40
- MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder>
44
+ MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder,
45
+ MathBuilder>
41
46
create (rewriter, loc);
42
47
43
48
// Get shape.
@@ -122,6 +127,41 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
122
127
if (indicesMayBeNegative)
123
128
index = index .selectOrSelf (index < zeroIE, index + axisDim);
124
129
130
+ // The Gather op is data dependent: the value of index should be
131
+ // within the input data size.
132
+ // Add runtime check if enableSafeCodeGen is set true
133
+ // Implementation comments vs. createGenerateRuntimeVerificationPass
134
+ // This check is according to onnx op semantics, not general bound
135
+ // check for memref. Implementation of RuntimeVerification could be
136
+ // borrowed. Slightly difference is that onnx semenatics check is for
137
+ // each dimension independently, not the final address is within
138
+ // the memref bound.
139
+ if (enableSafeCodeGen) {
140
+ // From onnx document:
141
+ // All index values are expected to be within bounds [-s, s-1]
142
+ // along axis of size s. It is an error if any of the index values
143
+ // are out of bounds.
144
+ // After the negative correction, the range should be [0, s-1]
145
+ Value upperBound = create.mem .dim (data, axisLit);
146
+ Value compareUpperBound =
147
+ create.math .slt (index .getValue (), upperBound);
148
+ // Report onnx_node_name if the op has the attribute
149
+ std::string nodeNameStr = op->getName ().getStringRef ().str () + " " ;
150
+ StringAttr nodeName =
151
+ op->getAttrOfType <mlir::StringAttr>(" onnx_node_name" );
152
+ if (nodeName && !nodeName.getValue ().empty ()) {
153
+ nodeNameStr = nodeNameStr + nodeName.getValue ().str ();
154
+ }
155
+ rewriter.create <cf::AssertOp>(loc, compareUpperBound,
156
+ nodeNameStr +
157
+ " indices of GatherOp is larger than the upper bound" );
158
+ Value compareLowerBound =
159
+ create.math .sge (index .getValue (), zeroIE.getValue ());
160
+ rewriter.create <cf::AssertOp>(loc, compareLowerBound,
161
+ nodeNameStr +
162
+ " indices of GatherOp is less than the lower bound" );
163
+ }
164
+
125
165
// Compute access function of data: data[ii + (indices[jj],) + kk]
126
166
SmallVector<IndexExpr, 4 > dataAccessFct;
127
167
// First add indices iis
0 commit comments