Skip to content

Commit

Permalink
Fold 0 <<,>>,/,% n to 0. Fold a / 1 to a. Fold a % 1 to 0. Fold f % 1…
Browse files Browse the repository at this point in the history
….0 to 0.0. Fold 0.0 % f to 0.0 (#6020)

* Fold 0 <<,>>,/,% n to 0. Fold a / 1 to a. Fold a % 1 to 0

* Fold OpFMod (f % 1.0) = 0.0 and (0.0 % f) = 0.0
  • Loading branch information
nabijaczleweli authored Mar 7, 2025
1 parent bac6ca7 commit ba828b2
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 36 deletions.
143 changes: 131 additions & 12 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2454,20 +2454,51 @@ FoldingRule RedundantFDiv() {
FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
FloatConstantKind kind1 = getFloatConstantKind(constants[1]);

if (kind0 == FloatConstantKind::Zero) {
if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::One) {
inst->SetOpcode(spv::Op::OpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
return true;
}

if (kind1 == FloatConstantKind::One) {
return false;
};
}

FoldingRule RedundantFMod() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(inst->opcode() == spv::Op::OpFMod &&
"Wrong opcode. Should be OpFMod.");
assert(constants.size() == 2);

if (!inst->IsFloatingPointFoldingAllowed()) {
return false;
}

FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
FloatConstantKind kind1 = getFloatConstantKind(constants[1]);

if (kind0 == FloatConstantKind::Zero) {
inst->SetOpcode(spv::Op::OpCopyObject);
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
return true;
}

if (kind1 == FloatConstantKind::One) {
auto type = context->get_type_mgr()->GetType(inst->type_id());
std::vector<uint32_t> zero_words;
zero_words.resize(ElementWidth(type) / 32);
auto const_mgr = context->get_constant_mgr();
auto zero = const_mgr->GetConstant(type, std::move(zero_words));
auto zero_id = const_mgr->GetDefiningInstruction(zero)->result_id();

inst->SetOpcode(spv::Op::OpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}});
return true;
}

return false;
};
}
Expand Down Expand Up @@ -2507,15 +2538,16 @@ FoldingRule RedundantFMix() {
};
}

// Returns a folding rule that folds the instruction to the operand not being
// checked if the operand that is checked is zero.
FoldingRule RedundantBinaryOpWithZeroOperand(uint32_t arg) {
return [arg](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
// Returns a folding rule that folds the instruction to operand |foldToArg|
// (0 or 1) if operand |arg| (0 or 1) is a zero constant.
FoldingRule RedundantBinaryOpWithZeroOperand(uint32_t arg, uint32_t foldToArg) {
return [arg, foldToArg](
IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(constants.size() == 2);

if (constants[arg] && constants[arg]->IsZero()) {
auto operand = inst->GetSingleWordInOperand(1 - arg);
auto operand = inst->GetSingleWordInOperand(foldToArg);
auto operand_type = constants[arg]->type();

const analysis::Type* inst_type =
Expand All @@ -2533,7 +2565,7 @@ FoldingRule RedundantBinaryOpWithZeroOperand(uint32_t arg) {
}

// This rule handles any of RedundantBinaryRhs0Ops with a 0 or vector 0 on the
// right-hand side.
// right-hand side (a | 0 => a).
static const constexpr spv::Op RedundantBinaryRhs0Ops[] = {
spv::Op::OpBitwiseOr,
spv::Op::OpBitwiseXor,
Expand All @@ -2548,11 +2580,11 @@ FoldingRule RedundantBinaryRhs0(spv::Op op) {
op) != std::end(RedundantBinaryRhs0Ops) &&
"Wrong opcode.");
(void)op;
return RedundantBinaryOpWithZeroOperand(1);
return RedundantBinaryOpWithZeroOperand(1, 0);
}

// This rule handles any of RedundantBinaryLhs0Ops with a 0 or vector 0 on the
// left-hand side.
// left-hand side (0 | a => a).
static const constexpr spv::Op RedundantBinaryLhs0Ops[] = {
spv::Op::OpBitwiseOr, spv::Op::OpBitwiseXor, spv::Op::OpIAdd};
FoldingRule RedundantBinaryLhs0(spv::Op op) {
Expand All @@ -2561,7 +2593,86 @@ FoldingRule RedundantBinaryLhs0(spv::Op op) {
op) != std::end(RedundantBinaryLhs0Ops) &&
"Wrong opcode.");
(void)op;
return RedundantBinaryOpWithZeroOperand(0);
return RedundantBinaryOpWithZeroOperand(0, 1);
}

// This rule handles shifts and divisions of 0 or vector 0 by any amount
// (0 >> a => 0).
static const constexpr spv::Op RedundantBinaryLhs0To0Ops[] = {
spv::Op::OpShiftRightLogical,
spv::Op::OpShiftRightArithmetic,
spv::Op::OpShiftLeftLogical,
spv::Op::OpSDiv,
spv::Op::OpUDiv,
spv::Op::OpSMod,
spv::Op::OpUMod};
FoldingRule RedundantBinaryLhs0To0(spv::Op op) {
assert(std::find(std::begin(RedundantBinaryLhs0To0Ops),
std::end(RedundantBinaryLhs0To0Ops),
op) != std::end(RedundantBinaryLhs0To0Ops) &&
"Wrong opcode.");
(void)op;
return RedundantBinaryOpWithZeroOperand(0, 0);
}

// Returns true if all elements in |c| are 1.
bool IsAllInt1(const analysis::Constant* c) {
if (auto composite = c->AsCompositeConstant()) {
auto& components = composite->GetComponents();
return std::all_of(std::begin(components), std::end(components), IsAllInt1);
} else if (c->AsIntConstant()) {
return c->GetSignExtendedValue() == 1;
}

return false;
}

// This rule handles divisions by 1 or vector 1 (a / 1 => a).
FoldingRule RedundantSUDiv() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(constants.size() == 2);
assert((inst->opcode() == spv::Op::OpUDiv ||
inst->opcode() == spv::Op::OpSDiv) &&
"Wrong opcode.");

if (constants[1] && IsAllInt1(constants[1])) {
auto operand = inst->GetSingleWordInOperand(0);
auto operand_type = constants[1]->type();

const analysis::Type* inst_type =
context->get_type_mgr()->GetType(inst->type_id());
if (inst_type->IsSame(operand_type)) {
inst->SetOpcode(spv::Op::OpCopyObject);
} else {
inst->SetOpcode(spv::Op::OpBitcast);
}
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
return true;
}
return false;
};
}

// This rule handles modulo from division by 1 or vector 1 (a % 1 => 0).
FoldingRule RedundantSUMod() {
return [](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
assert(constants.size() == 2);
assert((inst->opcode() == spv::Op::OpUMod ||
inst->opcode() == spv::Op::OpSMod) &&
"Wrong opcode.");

if (constants[1] && IsAllInt1(constants[1])) {
auto type = context->get_type_mgr()->GetType(inst->type_id());
auto zero_id = context->get_constant_mgr()->GetNullConstId(type);

inst->SetOpcode(spv::Op::OpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}});
return true;
}
return false;
};
}

// This rule look for a dot with a constant vector containing a single 1 and
Expand Down Expand Up @@ -2905,6 +3016,12 @@ void FoldingRules::AddFoldingRules() {
rules_[op].push_back(RedundantBinaryRhs0(op));
for (auto op : RedundantBinaryLhs0Ops)
rules_[op].push_back(RedundantBinaryLhs0(op));
for (auto op : RedundantBinaryLhs0To0Ops)
rules_[op].push_back(RedundantBinaryLhs0To0(op));
rules_[spv::Op::OpSDiv].push_back(RedundantSUDiv());
rules_[spv::Op::OpUDiv].push_back(RedundantSUDiv());
rules_[spv::Op::OpSMod].push_back(RedundantSUMod());
rules_[spv::Op::OpUMod].push_back(RedundantSUMod());

rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());

Expand Down Expand Up @@ -2937,6 +3054,8 @@ void FoldingRules::AddFoldingRules() {
rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());

rules_[spv::Op::OpFMod].push_back(RedundantFMod());

rules_[spv::Op::OpFMul].push_back(RedundantFMul());
rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());
Expand Down
Loading

0 comments on commit ba828b2

Please sign in to comment.