Skip to content

Commit 584ee43

Browse files
authored
Add runtime check for Gather Op (#3069)
* implemented Signed-off-by: Chen Tong <[email protected]> * add node name Signed-off-by: Chen Tong <[email protected]> * format Signed-off-by: Chen Tong <[email protected]> --------- Signed-off-by: Chen Tong <[email protected]>
1 parent f07fe63 commit 584ee43

File tree

5 files changed

+87
-3
lines changed

5 files changed

+87
-3
lines changed

src/Compiler/CompilerOptions.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ std::string opsForCall; // common for both
4545
bool disableKrnlOpFusion; // common for both
4646
bool disableQuantZeroPoint; // common for both
4747
bool enableKrnlBufferReuse; // common for both
48+
bool enableSafeCodeGen; // common for both
4849
bool disableMemRefPrefetch; // common for both
4950
uint64_t compilationNumThreads; // common for both
5051
EmissionTargetType emissionTarget; // onnx-mlir only
@@ -245,6 +246,16 @@ static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
245246
llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
246247
llvm::cl::cat(OnnxMlirCommonOptions));
247248

249+
static llvm::cl::opt<bool, true> enableSafeCodeGenOpt("enable-safe-code-gen",
250+
llvm::cl::desc("enable extra runtime check to be created in code gen. "
251+
"Such check will have cost at runtime, and is not needed if"
252+
"the model and the data are correct."
253+
"Failure of check will trigger assertion error."
254+
"(default=false).\n"
255+
"Set to 'true' if you want to enable the check."),
256+
llvm::cl::location(enableSafeCodeGen), llvm::cl::init(false),
257+
llvm::cl::cat(OnnxMlirCommonOptions));
258+
248259
static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
249260
"disable-memref-prefetch",
250261
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"

src/Compiler/CompilerOptions.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ extern std::string opsForCall; // common for both
9191
extern bool disableKrnlOpFusion; // common for both
9292
extern bool disableQuantZeroPoint; // common for both
9393
extern bool enableKrnlBufferReuse; // common for both
94+
extern bool enableSafeCodeGen; // common for both
9495
extern bool disableMemRefPrefetch; // common for both
9596
extern uint64_t compilationNumThreads; // common for both
9697
extern EmissionTargetType emissionTarget; // onnx-mlir only

src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// Krnl IR and standard operations.
1313
//
1414
//===----------------------------------------------------------------------===//
15+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1516
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1617
#include "mlir/Dialect/SCF/IR/SCF.h"
1718
#include "mlir/Dialect/Shape/IR/Shape.h"
@@ -367,7 +368,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
367368
target.addLegalDialect<KrnlDialect, affine::AffineDialect,
368369
arith::ArithDialect, func::FuncDialect, linalg::LinalgDialect,
369370
math::MathDialect, vector::VectorDialect, memref::MemRefDialect,
370-
shape::ShapeDialect, scf::SCFDialect>();
371+
shape::ShapeDialect, scf::SCFDialect, cf::ControlFlowDialect>();
371372
// Needed to support unsigned int computations. To be removed if we use a
372373
// scheme that does not rely on the UnrealizedConversionCastOp.
373374
target.addLegalOp<::mlir::UnrealizedConversionCastOp>();

src/Conversion/ONNXToKrnl/Tensor/Gather.cpp

+41-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
16+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17+
18+
#include "src/Compiler/CompilerOptions.hpp"
1519
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
1620
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
1721

@@ -37,7 +41,8 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
3741
Location loc = ONNXLoc<ONNXGatherOp>(op);
3842
ValueRange operands = adaptor.getOperands();
3943

40-
MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder>
44+
MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder,
45+
MathBuilder>
4146
create(rewriter, loc);
4247

4348
// Get shape.
@@ -122,6 +127,41 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
122127
if (indicesMayBeNegative)
123128
index = index.selectOrSelf(index < zeroIE, index + axisDim);
124129

130+
// The Gather op is data dependent: the value of index should be
131+
// within the input data size.
132+
// Add runtime check if enableSafeCodeGen is set true
133+
// Implementation comments vs. createGenerateRuntimeVerificationPass
134+
// This check is according to onnx op semantics, not general bound
135+
// check for memref. Implementation of RuntimeVerification could be
136+
// borrowed. Slightly difference is that onnx semenatics check is for
137+
// each dimension independently, not the final address is within
138+
// the memref bound.
139+
if (enableSafeCodeGen) {
140+
// From onnx document:
141+
// All index values are expected to be within bounds [-s, s-1]
142+
// along axis of size s. It is an error if any of the index values
143+
// are out of bounds.
144+
// After the negative correction, the range should be [0, s-1]
145+
Value upperBound = create.mem.dim(data, axisLit);
146+
Value compareUpperBound =
147+
create.math.slt(index.getValue(), upperBound);
148+
// Report onnx_node_name if the op has the attribute
149+
std::string nodeNameStr = op->getName().getStringRef().str() + " ";
150+
StringAttr nodeName =
151+
op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
152+
if (nodeName && !nodeName.getValue().empty()) {
153+
nodeNameStr = nodeNameStr + nodeName.getValue().str();
154+
}
155+
rewriter.create<cf::AssertOp>(loc, compareUpperBound,
156+
nodeNameStr +
157+
" indices of GatherOp is larger than the upper bound");
158+
Value compareLowerBound =
159+
create.math.sge(index.getValue(), zeroIE.getValue());
160+
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
161+
nodeNameStr +
162+
" indices of GatherOp is less than the lower bound");
163+
}
164+
125165
// Compute access function of data: data[ii + (indices[jj],) + kk]
126166
SmallVector<IndexExpr, 4> dataAccessFct;
127167
// First add indices iis

src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp

+32-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
16+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17+
18+
#include "src/Compiler/CompilerOptions.hpp"
1519
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
1620
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
1721

@@ -31,7 +35,8 @@ struct ONNXGatherElementsOpLowering
3135
Location loc = ONNXLoc<ONNXGatherElementsOp>(op);
3236
ValueRange operands = adaptor.getOperands();
3337

34-
MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder>
38+
MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder,
39+
MathBuilder>
3540
create(rewriter, loc);
3641

3742
// Get shape.
@@ -93,6 +98,32 @@ struct ONNXGatherElementsOpLowering
9398
index = index.selectOrSelf(index < zero, index + axisDim);
9499
}
95100

101+
// Check the dynamic requirement of GatherElement Op
102+
// Refer to the comments in Gather.cpp
103+
if (enableSafeCodeGen) {
104+
// From onnx document:
105+
// All index values are expected to be within bounds [-s, s-1]
106+
// along axis of size s. It is an error if any of the index values
107+
// are out of bounds.
108+
// After the negative correction, the range should be [0, s-1]
109+
Value upperBound = create.mem.dim(data, axis);
110+
Value compareUpperBound =
111+
create.math.slt(index.getValue(), upperBound);
112+
std::string nodeNameStr = op->getName().getStringRef().str() + " ";
113+
StringAttr nodeName =
114+
op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
115+
if (nodeName && !nodeName.getValue().empty()) {
116+
nodeNameStr = nodeNameStr + nodeName.getValue().str();
117+
}
118+
rewriter.create<cf::AssertOp>(loc, compareUpperBound,
119+
"indices of GatherOp is larger than the upper bound");
120+
LiteralIndexExpr zero(0);
121+
Value compareLowerBound =
122+
create.math.sge(index.getValue(), zero.getValue());
123+
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
124+
"indices of GatherOp is less than the lower bound");
125+
}
126+
96127
// Access function for the 'data' tensor.
97128
DimsExpr dataAccessFct;
98129
for (int64_t i = 0; i < dataRank; ++i)

0 commit comments

Comments
 (0)