Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add runtime check for Gather Op #3069

Merged
merged 5 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ std::string opsForCall; // common for both
bool disableKrnlOpFusion; // common for both
bool disableQuantZeroPoint; // common for both
bool enableKrnlBufferReuse; // common for both
bool enableSafeCodeGen; // common for both
bool disableMemRefPrefetch; // common for both
uint64_t compilationNumThreads; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
Expand Down Expand Up @@ -245,6 +246,16 @@ static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

// Command-line flag "--enable-safe-code-gen". Its value is stored in the
// global `enableSafeCodeGen` (via cl::location) and defaults to false.
// The description is assembled from adjacent string literals, so each piece
// must end with a space (or newline) to keep the words of the help text
// separated.
static llvm::cl::opt<bool, true> enableSafeCodeGenOpt("enable-safe-code-gen",
    llvm::cl::desc("enable extra runtime check to be created in code gen. "
                   "Such check will have cost at runtime, and is not needed if "
                   "the model and the data are correct. "
                   "Failure of check will trigger assertion error. "
                   "(default=false).\n"
                   "Set to 'true' if you want to enable the check."),
    llvm::cl::location(enableSafeCodeGen), llvm::cl::init(false),
    llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n"
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ extern std::string opsForCall; // common for both
extern bool disableKrnlOpFusion; // common for both
extern bool disableQuantZeroPoint; // common for both
extern bool enableKrnlBufferReuse; // common for both
extern bool enableSafeCodeGen; // common for both
extern bool disableMemRefPrefetch; // common for both
extern uint64_t compilationNumThreads; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
Expand Down Expand Up @@ -367,7 +368,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
target.addLegalDialect<KrnlDialect, affine::AffineDialect,
arith::ArithDialect, func::FuncDialect, linalg::LinalgDialect,
math::MathDialect, vector::VectorDialect, memref::MemRefDialect,
shape::ShapeDialect, scf::SCFDialect>();
shape::ShapeDialect, scf::SCFDialect, cf::ControlFlowDialect>();
// Needed to support unsigned int computations. To be removed if we use a
// scheme that does not rely on the UnrealizedConversionCastOp.
target.addLegalOp<::mlir::UnrealizedConversionCastOp>();
Expand Down
42 changes: 41 additions & 1 deletion src/Conversion/ONNXToKrnl/Tensor/Gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"

Expand All @@ -37,7 +41,8 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
Location loc = ONNXLoc<ONNXGatherOp>(op);
ValueRange operands = adaptor.getOperands();

MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder>
MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder,
MathBuilder>
create(rewriter, loc);

// Get shape.
Expand Down Expand Up @@ -122,6 +127,41 @@ struct ONNXGatherOpLowering : public OpConversionPattern<ONNXGatherOp> {
if (indicesMayBeNegative)
index = index.selectOrSelf(index < zeroIE, index + axisDim);

// The Gather op is data dependent: the value of index should be
// within the input data size.
// Add runtime check if enableSafeCodeGen is set true
// Implementation notes vs. createGenerateRuntimeVerificationPass:
// This check follows the ONNX op semantics; it is not a general bounds
// check for memref. The implementation of RuntimeVerification could be
// borrowed. A slight difference is that the ONNX semantics check applies
// to each dimension independently, rather than checking that the final
// address is within the memref bound.
if (enableSafeCodeGen) {
// From onnx document:
// All index values are expected to be within bounds [-s, s-1]
// along axis of size s. It is an error if any of the index values
// are out of bounds.
// After the negative correction, the range should be [0, s-1]
Value upperBound = create.mem.dim(data, axisLit);
Value compareUpperBound =
create.math.slt(index.getValue(), upperBound);
// Report onnx_node_name if the op has the attribute
std::string nodeNameStr = op->getName().getStringRef().str() + " ";
StringAttr nodeName =
op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
if (nodeName && !nodeName.getValue().empty()) {
nodeNameStr = nodeNameStr + nodeName.getValue().str();
}
rewriter.create<cf::AssertOp>(loc, compareUpperBound,
nodeNameStr +
" indices of GatherOp is larger than the upper bound");
Value compareLowerBound =
create.math.sge(index.getValue(), zeroIE.getValue());
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chentong319 I know it is not optimized for speed, "anding" both condition and calling assert only once would speed the check up a bit. create.math.andi(compareUpperBound, compareLowerBound).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate the check to provide more accurate error message.

nodeNameStr +
" indices of GatherOp is less than the lower bound");
}

// Compute access function of data: data[ii + (indices[jj],) + kk]
SmallVector<IndexExpr, 4> dataAccessFct;
// First add indices iis
Expand Down
33 changes: 32 additions & 1 deletion src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"

Expand All @@ -31,7 +35,8 @@ struct ONNXGatherElementsOpLowering
Location loc = ONNXLoc<ONNXGatherElementsOp>(op);
ValueRange operands = adaptor.getOperands();

MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder>
MultiDialectBuilder<KrnlBuilder, IndexExprBuilderForKrnl, MemRefBuilder,
MathBuilder>
create(rewriter, loc);

// Get shape.
Expand Down Expand Up @@ -93,6 +98,32 @@ struct ONNXGatherElementsOpLowering
index = index.selectOrSelf(index < zero, index + axisDim);
}

// Check the dynamic requirement of the GatherElements op.
// Refer to the comments in Gather.cpp
if (enableSafeCodeGen) {
// From onnx document:
// All index values are expected to be within bounds [-s, s-1]
// along axis of size s. It is an error if any of the index values
// are out of bounds.
// After the negative correction, the range should be [0, s-1]
Value upperBound = create.mem.dim(data, axis);
Value compareUpperBound =
create.math.slt(index.getValue(), upperBound);
// Build a message prefix from the operation name, followed by the
// onnx_node_name attribute when the op carries a non-empty one.
std::string nodeNameStr = op->getName().getStringRef().str() + " ";
StringAttr nodeName =
    op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
if (nodeName && !nodeName.getValue().empty()) {
  nodeNameStr = nodeNameStr + nodeName.getValue().str();
}
// Prefix the assertion message with nodeNameStr, in the same way as the
// Gather lowering, so a failing check identifies the offending node.
// NOTE(review): the lower-bound assert that follows still uses a bare
// message; it should receive the same prefix.
rewriter.create<cf::AssertOp>(loc, compareUpperBound,
    nodeNameStr +
        " indices of GatherOp is larger than the upper bound");
LiteralIndexExpr zero(0);
Value compareLowerBound =
create.math.sge(index.getValue(), zero.getValue());
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, can use andi.

"indices of GatherOp is less than the lower bound");
}

// Access function for the 'data' tensor.
DimsExpr dataAccessFct;
for (int64_t i = 0; i < dataRank; ++i)
Expand Down