-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] bind block predecessors and successors #145116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][python] bind block predecessors and successors #145116
Conversation
173c5c3
to
2160339
Compare
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) Changesbind Full diff: https://github.com/llvm/llvm-project/pull/145116.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 1a8e8737f7fed..30763c0c8c052 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -986,6 +986,13 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
MLIR_CAPI_EXPORTED void
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
+/// Returns the number of successor blocks of the block.
+MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
+
+/// Returns `pos`-th successor of the block.
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
+ intptr_t pos);
+
//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cbd35f2974ae9..6f9db50e2aaa1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2626,6 +2626,45 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
PyOperationRef operation;
};
+/// A list of block successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation and block whose successors these are, and thus
+/// extends the lifetime of this operation and block.
+class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+ static constexpr const char *pyClassName = "BlockSuccessors";
+
+ PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+ intptr_t startIndex = 0, intptr_t length = -1,
+ intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumSuccessors(block.get())
+ : length,
+ step),
+ operation(operation), block(block) {}
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+ intptr_t getRawNumElements() {
+ block.checkValid();
+ return mlirBlockGetNumSuccessors(block.get());
+ }
+
+ PyBlock getRawElement(intptr_t pos) {
+ MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
+ return PyBlock(operation, block);
+ }
+
+ PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyBlockSuccessors(block, operation, startIndex, length, step);
+ }
+
+ PyOperationRef operation;
+ PyBlock block;
+};
+
/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
class PyOpAttributeMap {
@@ -3655,7 +3694,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("operation"),
"Appends an operation to this block. If the operation is currently "
- "in another block, it will be moved.");
+ "in another block, it will be moved.")
+ .def_prop_ro(
+ "successors",
+ [](PyBlock &self) {
+ return PyBlockSuccessors(self, self.getParentOperation());
+ },
+ "Returns the list of Block successors.");
//----------------------------------------------------------------------------
// Mapping of PyInsertionPoint.
@@ -4099,6 +4144,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
+ PyBlockSuccessors::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e0e386d55ede1..7fa831da3a4d8 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -1059,6 +1059,14 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
unwrap(block)->print(stream);
}
+intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
+ return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
+}
+
+MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
+ return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
+}
+
//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 70ccaeeb5435b..551b2181b6552 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -1,12 +1,11 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
-import io
-import itertools
-from mlir.ir import *
+
from mlir.dialects import builtin
from mlir.dialects import cf
from mlir.dialects import func
+from mlir.ir import *
def run(f):
@@ -54,10 +53,22 @@ def testBlockCreation():
with InsertionPoint(middle_block) as middle_ip:
assert middle_ip.block == middle_block
cf.BranchOp([i32_arg], dest=successor_block)
+
module.print(enable_debug_info=True)
# Ensure region back references are coherent.
assert entry_block.region == middle_block.region == successor_block.region
+ entry_block_successors = entry_block.successors
+ assert len(entry_block_successors) == 1
+ assert middle_block == entry_block_successors[0]
+
+ middle_block_successors = middle_block.successors
+ assert len(middle_block_successors) == 1
+ assert successor_block == middle_block_successors[0]
+
+ successor_block_successors = successor_block.successors
+ assert len(successor_block_successors) == 0
+
# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
|
04ef6e4
to
c49fd8a
Compare
c49fd8a
to
69186b6
Compare
69186b6
to
bedc679
Compare
@@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, | |||
MLIR_CAPI_EXPORTED void | |||
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData); | |||
|
|||
/// Returns the number of successor blocks of the block. | |||
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's have tests for the C API as well
Block::pred_iterator it = b->pred_begin(); | ||
std::advance(it, pos); | ||
return wrap(*it); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather avoid iterating over the use-def list every time... This goes through block's use-def chain, maybe there is a way to expose a BlockOperand
(and incidentally OpOperand
if it isn't) and a getNextUse
.
bind
block.getSuccessor
andblock.getPredecessors
.