Skip to content

[mlir][python] bind block predecessors and successors #145116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
MLIR_CAPI_EXPORTED void
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);

/// Returns the number of successor blocks of the block.
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have tests for the C API as well


/// Returns `pos`-th successor of the block.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
intptr_t pos);

/// Returns the number of predecessor blocks of the block.
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);

/// Returns `pos`-th predecessor of the block.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
intptr_t pos);

//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
Expand Down
95 changes: 94 additions & 1 deletion mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2626,6 +2626,85 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
PyOperationRef operation;
};

/// A list of block successors. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation and block whose successors these are, and thus
/// extends the lifetime of this operation and block.
class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockSuccessors";

PyBlockSuccessors(PyBlock block, PyOperationRef operation,
intptr_t startIndex = 0, intptr_t length = -1,
intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirBlockGetNumSuccessors(block.get())
: length,
step),
operation(operation), block(block) {}

private:
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockSuccessors, PyBlock>;

intptr_t getRawNumElements() {
block.checkValid();
return mlirBlockGetNumSuccessors(block.get());
}

PyBlock getRawElement(intptr_t pos) {
MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
return PyBlock(operation, block);
}

PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
return PyBlockSuccessors(block, operation, startIndex, length, step);
}

PyOperationRef operation;
PyBlock block;
};

/// A list of block predecessors. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) predecessor list is
/// associated with the operation and block whose predecessors these are, and
/// thus extends the lifetime of this operation and block.
class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockPredecessors";

PyBlockPredecessors(PyBlock block, PyOperationRef operation,
intptr_t startIndex = 0, intptr_t length = -1,
intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirBlockGetNumPredecessors(block.get())
: length,
step),
operation(operation), block(block) {}

private:
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockPredecessors, PyBlock>;

intptr_t getRawNumElements() {
block.checkValid();
return mlirBlockGetNumPredecessors(block.get());
}

PyBlock getRawElement(intptr_t pos) {
MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
return PyBlock(operation, block);
}

PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
intptr_t step) {
return PyBlockPredecessors(block, operation, startIndex, length, step);
}

PyOperationRef operation;
PyBlock block;
};

/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
class PyOpAttributeMap {
Expand Down Expand Up @@ -3655,7 +3734,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("operation"),
"Appends an operation to this block. If the operation is currently "
"in another block, it will be moved.");
"in another block, it will be moved.")
.def_prop_ro(
"successors",
[](PyBlock &self) {
return PyBlockSuccessors(self, self.getParentOperation());
},
"Returns the list of Block successors.")
.def_prop_ro(
"predecessors",
[](PyBlock &self) {
return PyBlockPredecessors(self, self.getParentOperation());
},
"Returns the list of Block predecessors.");

//----------------------------------------------------------------------------
// Mapping of PyInsertionPoint.
Expand Down Expand Up @@ -4099,6 +4190,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
PyBlockSuccessors::bind(m);
PyBlockPredecessors::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
unwrap(block)->print(stream);
}

intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
}

MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
}

intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
Block *b = unwrap(block);
return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
}

MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
Block *b = unwrap(block);
Block::pred_iterator it = b->pred_begin();
std::advance(it, pos);
return wrap(*it);
Comment on lines +1077 to +1079
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather avoid iterating over the use-def list every time... This goes through block's use-def chain, maybe there is a way to expose a BlockOperand (and incidentally OpOperand if it isn't) and a getNextUse.

Copy link
Contributor Author

@makslevental makslevental Jun 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think so

/// Implement a predecessor iterator for blocks. This works by walking the use

Compare with SuccessorRange just below there. But maybe I'm wrong and it's just not clicking for me.

}

//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 17 additions & 3 deletions mlir/test/python/ir/blocks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *

from mlir.dialects import builtin
from mlir.dialects import cf
from mlir.dialects import func
from mlir.ir import *


def run(f):
Expand Down Expand Up @@ -54,10 +53,25 @@ def testBlockCreation():
with InsertionPoint(middle_block) as middle_ip:
assert middle_ip.block == middle_block
cf.BranchOp([i32_arg], dest=successor_block)

module.print(enable_debug_info=True)
# Ensure region back references are coherent.
assert entry_block.region == middle_block.region == successor_block.region

assert len(entry_block.predecessors) == 0

assert len(entry_block.successors) == 1
assert middle_block == entry_block.successors[0]
assert len(middle_block.predecessors) == 1
assert entry_block == middle_block.predecessors[0]

assert len(middle_block.successors) == 1
assert successor_block == middle_block.successors[0]
assert len(successor_block.predecessors) == 1
assert middle_block == successor_block.predecessors[0]

assert len(successor_block.successors) == 0


# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
Expand Down