Skip to content

Commit 173c5c3

Browse files
committed
[mlir][python] bind block successors
1 parent b8355a7 commit 173c5c3

File tree

4 files changed

+75
-4
lines changed

4 files changed

+75
-4
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,13 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
986986
MLIR_CAPI_EXPORTED void
987987
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
988988

989+
/// Returns the number of successor blocks of the block.
990+
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
991+
992+
/// Returns `pos`-th successor of the block.
993+
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
994+
intptr_t pos);
995+
989996
//===----------------------------------------------------------------------===//
990997
// Value API.
991998
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2626,6 +2626,45 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
26262626
PyOperationRef operation;
26272627
};
26282628

2629+
/// A list of block successors. Internally, these are stored as consecutive
2630+
/// elements, random access is cheap. The (returned) successor list is
2631+
/// associated with the operation and block whose successors these are, and thus
2632+
/// extends the lifetime of this operation and block.
2633+
class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
2634+
public:
2635+
static constexpr const char *pyClassName = "BlockSuccessors";
2636+
2637+
PyBlockSuccessors(PyBlock block, PyOperationRef operation,
2638+
intptr_t startIndex = 0, intptr_t length = -1,
2639+
intptr_t step = 1)
2640+
: Sliceable(startIndex,
2641+
length == -1 ? mlirBlockGetNumSuccessors(block.get())
2642+
: length,
2643+
step),
2644+
operation(operation), block(block) {}
2645+
2646+
private:
2647+
/// Give the parent CRTP class access to hook implementations below.
2648+
friend class Sliceable<PyBlockSuccessors, PyBlock>;
2649+
2650+
intptr_t getRawNumElements() {
2651+
block.checkValid();
2652+
return mlirBlockGetNumSuccessors(block.get());
2653+
}
2654+
2655+
PyBlock getRawElement(intptr_t pos) {
2656+
MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
2657+
return PyBlock(operation, block);
2658+
}
2659+
2660+
PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2661+
return PyBlockSuccessors(block, operation, startIndex, length, step);
2662+
}
2663+
2664+
PyOperationRef operation;
2665+
PyBlock block;
2666+
};
2667+
26292668
/// A list of operation attributes. Can be indexed by name, producing
26302669
/// attributes, or by index, producing named attributes.
26312670
class PyOpAttributeMap {
@@ -3655,7 +3694,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
36553694
},
36563695
nb::arg("operation"),
36573696
"Appends an operation to this block. If the operation is currently "
3658-
"in another block, it will be moved.");
3697+
"in another block, it will be moved.")
3698+
.def_prop_ro(
3699+
"successors",
3700+
[](PyBlock &self) {
3701+
return PyBlockSuccessors(self, self.getParentOperation());
3702+
},
3703+
"Returns the list of Block successors.");
36593704

36603705
//----------------------------------------------------------------------------
36613706
// Mapping of PyInsertionPoint.
@@ -4099,6 +4144,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
40994144
PyBlockArgumentList::bind(m);
41004145
PyBlockIterator::bind(m);
41014146
PyBlockList::bind(m);
4147+
PyBlockSuccessors::bind(m);
41024148
PyOperationIterator::bind(m);
41034149
PyOperationList::bind(m);
41044150
PyOpAttributeMap::bind(m);

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,14 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
10591059
unwrap(block)->print(stream);
10601060
}
10611061

1062+
intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
1063+
return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
1064+
}
1065+
1066+
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
1067+
return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
1068+
}
1069+
10621070
//===----------------------------------------------------------------------===//
10631071
// Value API.
10641072
//===----------------------------------------------------------------------===//

mlir/test/python/ir/blocks.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
import gc
4-
import io
5-
import itertools
6-
from mlir.ir import *
4+
75
from mlir.dialects import builtin
86
from mlir.dialects import cf
97
from mlir.dialects import func
8+
from mlir.ir import *
109

1110

1211
def run(f):
@@ -58,6 +57,17 @@ def testBlockCreation():
5857
# Ensure region back references are coherent.
5958
assert entry_block.region == middle_block.region == successor_block.region
6059

60+
entry_block_successors = entry_block.successors
61+
assert len(entry_block_successors) == 1
62+
assert middle_block == entry_block_successors[0]
63+
64+
middle_block_successors = middle_block.successors
65+
assert len(middle_block_successors) == 1
66+
assert successor_block == middle_block_successors[0]
67+
68+
successor_block_successors = successor_block.successors
69+
assert len(successor_block_successors) == 0
70+
6171

6272
# CHECK-LABEL: TEST: testBlockCreationArgLocs
6373
@run

0 commit comments

Comments
 (0)