Skip to content

Commit

Permalink
【CINN】Add equal function for IndexExpr (PaddlePaddle#68987)
Browse files Browse the repository at this point in the history
* equal IndexExpr

* add comment.

* fix bug

* Empty-Commit

* fix bug

* add equal in vector

* add unit test.
  • Loading branch information
liuruyan authored and fxfxfxfxfxfxfxfx committed Oct 29, 2024
1 parent 8e3f283 commit 3196505
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 102 deletions.
51 changes: 8 additions & 43 deletions paddle/cinn/common/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,23 +502,7 @@ Expr min(Expr a, Expr b) {
return ir::Min::Make(a, b);
}

int32_t CalculateExprComplexity(const Expr &expr, int count) {
switch (expr.node_type()) {
case ir::IrNodeTy::_Var_:
case ir::IrNodeTy::IntImm:
return count + 1;
case ir::IrNodeTy::Add:
case ir::IrNodeTy::Mul:
case ir::IrNodeTy::Div:
case ir::IrNodeTy::Mod: {
int lhs_count = CalculateExprComplexity(expr->operand(0), count);
int rhs_count = CalculateExprComplexity(expr->operand(1), count);
return lhs_count + rhs_count + 1;
}
}
}

bool IsCorrectPriority(const Expr &lhs, const Expr &rhs) {
bool ComparePriority(const ir::IndexExpr &lhs, const ir::IndexExpr &rhs) {
if (lhs.node_type() == ir::IrNodeTy::IntImm &&
rhs.node_type() != ir::IrNodeTy::IntImm)
return false;
Expand All @@ -529,35 +513,16 @@ bool IsCorrectPriority(const Expr &lhs, const Expr &rhs) {
if (auto rhsVar = rhs.As<ir::_Var_>())
return std::make_tuple(lhsVar->name.length(), lhsVar->name) <=
std::make_tuple(rhsVar->name.length(), rhsVar->name);
auto lhsComplexity = CalculateExprComplexity(lhs);
auto rhsComplexity = CalculateExprComplexity(rhs);
if (lhsComplexity < rhsComplexity) return false;
// Mul < Div < Mod.
else if (lhsComplexity == rhsComplexity)
auto lhsLen = lhs.length();
auto rhsLen = rhs.length();
if (lhsLen < rhsLen) return false;
// Add < Mul < Div < Mod.
else if (lhsLen == rhsLen)
return lhs.node_type() <= rhs.node_type();
else
return true;
}

ir::IndexExpr MulAndNormalize(const ir::IndexExpr &lhs,
const ir::IndexExpr &rhs) {
int64_t cscale = 1;
ir::IndexExpr res = ir::One(lhs.type());
auto fcollect = [&](ir::IndexExpr val) {
if (const auto *intimm = val.As<ir::IntImm>()) {
cscale *= intimm->value;
} else {
res = res * val;
}
};
UnpackReduction<ir::Mul>(lhs, fcollect);
UnpackReduction<ir::Mul>(rhs, fcollect);
if (cscale != 1) {
res = res * ir::IndexExpr(make_shared<ir::IntImm>(res.type(), cscale));
}
return res;
}

bool IsSumPartialBySymbol(const ir::IndexExpr &expr,
const ir::IndexExpr &symbol) {
// TODO(liujinnan): Check Ty
Expand All @@ -568,8 +533,6 @@ bool IsSumPartialBySymbol(const ir::IndexExpr &expr,
case ir::IrNodeTy::_Var_:
return expr == symbol;
case ir::IrNodeTy::Add:
[[fallthrough]];
case ir::IrNodeTy::Sub:
return IsSumPartialBySymbol(expr->operand(0).as_index(), symbol) ||
IsSumPartialBySymbol(expr->operand(1).as_index(), symbol);
case ir::IrNodeTy::Mul:
Expand Down Expand Up @@ -601,6 +564,8 @@ bool IsDivisiblieBySymbol(const ir::IndexExpr &expr,
return IsDivisiblieBySymbol(expr->operand(0).as_index(), symbol, ty) ||
IsDivisiblieBySymbol(expr->operand(1).as_index(), symbol, ty);
case ir::IrNodeTy::Mod:
// Because S0 % 3 + S0 % 5 is not divisiblie by S0, so we push
// `expr.node_type()` into third parameter.
return IsDivisiblieBySymbol(
expr->operand(0).as_index(), symbol, expr.node_type()) &&
IsDivisiblieBySymbol(
Expand Down
119 changes: 104 additions & 15 deletions paddle/cinn/common/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,36 +178,125 @@ inline bool IsZero(const Expr &expr) {
return false;
}

/*!
* \brief Apply func `fleaf` into each leaf node of `expr`.
* which leaf node is the most outside node that has TNode type.
* \param expr The expression to be applied.
* \param fleaf The function to be applied.
*/
template <typename TNode, typename FLeaf>
inline void UnpackReduction(const ir::IndexExpr &value, FLeaf fleaf) {
if (const TNode *node = value.As<TNode>()) {
inline void UnpackReduction(const ir::IndexExpr &expr, FLeaf fleaf) {
if (const TNode *node = expr.As<TNode>()) {
UnpackReduction<TNode, FLeaf>(node->a(), fleaf);
UnpackReduction<TNode, FLeaf>(node->b(), fleaf);
} else {
fleaf(value);
fleaf(expr);
}
}

// TODO(liuruyan): canby simplify into IndexExpr multiply.
ir::IndexExpr MulAndNormalize(const ir::IndexExpr &lhs,
const ir::IndexExpr &rhs);

int32_t CalculateExprComplexity(const Expr &expr, int count = 0);

// True means don't change sequence
bool IsCorrectPriority(const Expr &lhs, const Expr &rhs);
/*!
* \brief Flattern the expression into a vector of expressions splited by `Add`
* or `Mul`.
*
* For example (Add):
* 1. `S0 + S1` ==> {S0, S1}
* 2. `S0 + S1 * S2` ==> {S0, S1 * S2}
* 3. `S0 + S1 * (S2 + S3)` ==> {S0, S1 * (S2 + S3)}
*
* \param lhs The left hand side expression to be compared.
* \param rhs The right hand side expression to be compared.
* \return A boolean value indicating whether the priority of `lhs` is higher
* than `rhs`.
*/
template <typename T>
inline std::vector<ir::IndexExpr> GetFlatternExprs(const ir::IndexExpr &expr) {
std::vector<ir::IndexExpr> result;
auto fcollect = [&](ir::IndexExpr val) { result.push_back(val); };
UnpackReduction<T>(expr, fcollect);
return result;
}

/*!
* \brief Compare the priority of the two expressions. this func follows the
* above rules:
* 1. if lhs = var, rhs = const, return true;
* 2. if lhs = const, rhs = var, return false;
* 3. if lhs = var, rhs = var, return lhs_var_name <= lhs_var_name;
* 4. if lhs.length > rhs.length, return true;
* 5. if lhs.length == rhs.length, return lhs_type <= rhs_type; (Add < Mul <
* Div < Mod)
* 6. if lhs.length < rhs.length return false;
*
* For example:
* 1. `ComparePriority(S0, 2)` return true;
* 2. `ComparePriority(S0, S0)` return true;
* 2. `ComparePriority(S0, S1)` return false;
* 3. `ComparePriority(S0, S1 + 1)` return false;
* 4. `ComparePriority(S0 % 2, S1 + 1)` return false;
*
* \param lhs The left hand side expression to be compared.
* \param rhs The right hand side expression to be compared.
* \return A boolean value indicating whether the priority of `lhs` is higher
* than `rhs`.
*/
bool ComparePriority(const ir::IndexExpr &lhs, const ir::IndexExpr &rhs);

/*!
* \brief Determines whether there are sub-parts in the `expr` that can be
* simplified by `Add` operation with the input `symbol`. If true is returned,
* the operation will be attempted on each subpart in outter
* `SimplifySymbolicAdd` function.
*
* For example:
* 1. `IsSumPartialBySymbol(5, S0)` return false;
* 2. `IsSumPartialBySymbol(S0, S0)` return true;
* 3. `IsSumPartialBySymbol(S0 + S1, S1)` return true;
* 4. `IsSumPartialBySymbol(S0 * 5 + S1, S0)` return true;
* 5. `IsSumPartialBySymbol(S0 / 3, S0)` return true;
* 6. `IsSumPartialBySymbol(S0 / 3 + S1, S0)` return true;
* 7. `IsSumPartialBySymbol(S0 % 3, S0)` return false;
*
* \param expr The expression to be checked.
* \param symbol The symbol to be checked.
* \return True means there are sub-parts in the `expr` that can be simplified.
*/
bool IsSumPartialBySymbol(const ir::IndexExpr &expr,
const ir::IndexExpr &symbol);

// If true is returned, the operation will be attempted on each subpart in
// outter `simplify` function. Note: this func dont deal the corner case, e.g.
// `IsDivisiblieBySymbol(f % S0 - f, S0)` is `false`. please use
// `ProveDivisible` for exact result.
/*!
* \brief Determines whether there are sub-parts in the `expr` that can be
* simplified by `Div` operation with the input `symbol`. If true is returned,
* the operation will be attempted on each subpart in outter
* `SimplifySymbolicDivide` function.
*
* For example:
* 1. `IsDivisiblieBySymbol(5, S0, div)` return false;
* 2. `IsDivisiblieBySymbol(S0, S0, div)` return true;
* 3. `IsDivisiblieBySymbol(S0 + S1, S1, div)` return false;
* 4. `IsDivisiblieBySymbol(S0 * 5 + S1 * S2, S0, div)` return true;
* 5. `IsDivisiblieBySymbol(S0 / 3, S0, div)` return true;
* 6. `IsDivisiblieBySymbol(S0 * 4 / 3, S0, div)` return true;
* 7. `IsDivisiblieBySymbol(S0 % 3, S0, div)` return false;
* 8. `IsDivisiblieBySymbol(S0 / 3, S0, mod)` return false;
*
* \param expr The expression to be checked.
* \param symbol The symbol to be checked.
* \param ty ty is `Mod` or `Div`.
* \return True means there are sub-parts in the `expr` that can be simplified.
* \note this func dont deal the corner case, please use `ProveDivisible` for
* exact result. e.g. `IsDivisiblieBySymbol(f % S0 - f, S0, div)` is false
*/
bool IsDivisiblieBySymbol(const ir::IndexExpr &expr,
const ir::IndexExpr &symbol,
const ir::IrNodeTy &ty);

/*!
* \brief Determine whether `lhs` is divisible by `rhs`, regardless of whether
* `rhs` is a constant or a symbol.
* \param lhs lhs is dividend.
* \param rhs rhs is divisor.
* \return A boolean value indicating whether the `lhs` is divisible by `rhs`
*/
bool ProveDivisible(const ir::IndexExpr &lhs, const ir::IndexExpr &rhs);

} // namespace common
Expand Down
11 changes: 5 additions & 6 deletions paddle/cinn/common/iter_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,8 @@ std::optional<ir::IndexExpr> IterMapRewriter::TryFuse(
grouped_iters.push_back(arg_copy);

// Update expected_scale = matched_split->scale * matched_split->extent
expected_scale = MulAndNormalize(
iter_sum->args[matched_pos].As<ir::IterSplit>()->extent, matched_scale);
expected_scale =
iter_sum->args[matched_pos].As<ir::IterSplit>()->extent * matched_scale;
}
std::reverse(grouped_iters.begin(), grouped_iters.end());
ir::IndexExpr grouped_sum =
Expand Down Expand Up @@ -585,8 +585,7 @@ std::optional<ir::IndexExpr> IterMapRewriter::TryFuseSameSource(
// 2. lhs->scale == rhs->extent * rhs->scale.
// 3. lhs->lower_factor == rhs->lower_factor * rhs->extent.
while (true) {
ir::IndexExpr lhs_scale =
MulAndNormalize(rhs_iter->extent, rhs_iter->scale);
ir::IndexExpr lhs_scale = rhs_iter->extent * rhs_iter->scale;
matched_index = FindSplitWithExactScale(*iter_sum,
visited,
lhs_scale,
Expand All @@ -596,13 +595,13 @@ std::optional<ir::IndexExpr> IterMapRewriter::TryFuseSameSource(
if (matched_index == -1) break;
auto lhs_iter = iter_sum->args[matched_index].As<ir::IterSplit>();
ir::IndexExpr lhs_lower_factor =
MulAndNormalize(rhs_iter->lower_factor, rhs_iter->extent);
rhs_iter->lower_factor * rhs_iter->extent;
if (!analyzer_.ProveEQ(lhs_iter->lower_factor, lhs_lower_factor)
.value_or(false))
break;
visited[matched_index] = true;

rhs_iter->extent = MulAndNormalize(lhs_iter->extent, rhs_iter->extent);
rhs_iter->extent = lhs_iter->extent * rhs_iter->extent;
}
reverse_flattened_iters.push_back(split_copy);
}
Expand Down
33 changes: 26 additions & 7 deletions paddle/cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,26 @@ int64_t IndexExpr::GetLargestMutiplyPart() const {
::common::errors::Unimplemented("Unsupported type of expr: %s", type()));
}

int32_t IndexExpr::length(int32_t count) const {
switch (node_type()) {
case ir::IrNodeTy::_Var_:
[[fallthrough]];
case ir::IrNodeTy::IntImm:
return count + 1;
case ir::IrNodeTy::Add:
[[fallthrough]];
case ir::IrNodeTy::Mul:
[[fallthrough]];
case ir::IrNodeTy::Div:
[[fallthrough]];
case ir::IrNodeTy::Mod: {
int lhs_count = ptr()->operand(0).as_index().length(count);
int rhs_count = ptr()->operand(1).as_index().length(count);
return lhs_count + rhs_count + 1;
}
}
}

IndexExpr ConstructIndexExprByNodeType(const IrNodeTy &ty,
const IndexExpr &lhs,
const IndexExpr &rhs) {
Expand Down Expand Up @@ -1682,10 +1702,9 @@ IndexExpr Simplify(const IndexExpr &expr) {
case ir::IrNodeTy::Div:
[[fallthrough]];
case ir::IrNodeTy::Mod: {
auto a1 = Simplify(expr->operand(0).as_index());
auto a2 = Simplify(expr->operand(1).as_index());

return ConstructIndexExprByNodeType(expr.node_type(), a1, a2);
auto lhs = Simplify(expr->operand(0).as_index());
auto rhs = Simplify(expr->operand(1).as_index());
return ConstructIndexExprByNodeType(expr.node_type(), lhs, rhs);
}
}
}
Expand All @@ -1698,13 +1717,13 @@ static IndexExpr SimplifyAdd(const IndexExpr &lhs, const IndexExpr &rhs) {
return constRes.value().as_index();
// 3 + d0 ===> d0 + 3.
// d0 + (d1 + d2) ===> (d1 + d2) + d0.
if (!IsCorrectPriority(lhs, rhs)) {
if (!ComparePriority(lhs, rhs)) {
return rhs + lhs;
}

// (d0 + d1) + (d2 + d3) ===> ((d0 + d1) + d2) + d3.
if (auto rhsAdd = rhs.As<Add>()) {
return lhs * rhsAdd->a().as_index() * rhsAdd->b().as_index();
return lhs + rhsAdd->a().as_index() + rhsAdd->b().as_index();
}

// (d0 + 2) + 3 ===> d0 + 5.
Expand Down Expand Up @@ -1769,7 +1788,7 @@ static IndexExpr SimplifyMul(const IndexExpr &lhs, const IndexExpr &rhs) {

// 3 * d0 ===> d0 * 3.
// d0 * (d1 + d2) ===> (d1 + d2) * d0.
if (!IsCorrectPriority(lhs, rhs)) {
if (!ComparePriority(lhs, rhs)) {
return rhs * lhs;
}

Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1025,10 +1025,16 @@ struct IndexExpr : public Expr {

int64_t GetLargestMutiplyPart() const;

IndexExpr& operator=(const IndexExpr& other);

IndexExpr Normalize() const;

// count the `IndeExpr` length, each node has weight 1, e.g.
// S0, length = 1
// S0 + S1, length = 3
// S0 + S1 * 2, length = 5
int32_t length(int32_t count = 0) const;

IndexExpr& operator=(const IndexExpr& other);

IndexExpr operator-() const;

#define DEFINE_OPERATOR(op) \
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/ir_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ bool Expr::is_var() const { return As<_Var_>(); }

bool Expr::is_index() const {
switch (node_type()) {
case ir::IrNodeTy::Cast:
[[fallthrough]];
case ir::IrNodeTy::_Var_:
return true;
case ir::IrNodeTy::IntImm: {
Expand Down
Loading

0 comments on commit 3196505

Please sign in to comment.