Skip to content

Commit 8a0e2e6

Browse files
committed
[MLIR][Python] fix PyRegionList __iter__
1 parent 321afc7 commit 8a0e2e6

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ namespace {
204204

205205
class PyRegionIterator {
206206
public:
207-
PyRegionIterator(PyOperationRef operation)
208-
: operation(std::move(operation)) {}
207+
PyRegionIterator(PyOperationRef operation, int nextIndex)
208+
: operation(std::move(operation)), nextIndex(nextIndex) {}
209209

210210
PyRegionIterator &dunderIter() { return *this; }
211211

@@ -228,7 +228,7 @@ class PyRegionIterator {
228228

229229
private:
230230
PyOperationRef operation;
231-
int nextIndex = 0;
231+
intptr_t nextIndex = 0;
232232
};
233233

234234
/// Regions of an op are fixed length and indexed numerically so are represented
@@ -247,7 +247,7 @@ class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
247247

248248
PyRegionIterator dunderIter() {
249249
operation->checkValid();
250-
return PyRegionIterator(operation);
250+
return PyRegionIterator(operation, startIndex);
251251
}
252252

253253
static void bindDerived(ClassTy &c) {

mlir/lib/Bindings/Python/NanobindUtils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,6 @@ class Sliceable {
395395
/// Hook for derived classes willing to bind more methods.
396396
static void bindDerived(ClassTy &) {}
397397

398-
private:
399398
intptr_t startIndex;
400399
intptr_t length;
401400
intptr_t step;

mlir/test/python/ir/operation.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tempfile import NamedTemporaryFile
77
from mlir.ir import *
88
from mlir.dialects.builtin import ModuleOp
9-
from mlir.dialects import arith
9+
from mlir.dialects import arith, func, scf
1010
from mlir.dialects._ods_common import _cext
1111

1212

@@ -1199,3 +1199,28 @@ def testGetOwnerConcreteOpview():
11991199
r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw)
12001200
for u in a.result.uses:
12011201
assert isinstance(u.owner, arith.AddIOp)
1202+
1203+
1204+
# CHECK-LABEL: TEST: testIndexSwitch
1205+
@run
1206+
def testIndexSwitch(_module):
1207+
from mlir.extras import types as T
1208+
1209+
i32 = T.i32()
1210+
1211+
with Context() as ctx, Location.unknown():
1212+
module = Module.create()
1213+
with InsertionPoint(module.body):
1214+
1215+
@func.FuncOp.from_py_func(T.index())
1216+
def index_switch(index):
1217+
c1 = arith.constant(i32, 1)
1218+
switch_op = scf.IndexSwitchOp(
1219+
results_=[i32], arg=index, cases=range(3), num_caseRegions=3
1220+
)
1221+
1222+
assert len(switch_op.regions) == 4
1223+
assert len(switch_op.regions[2:]) == 2
1224+
assert len([i for i in switch_op.regions[2:]]) == 2
1225+
assert len(switch_op.caseRegions) == 3
1226+
assert len([i for i in switch_op.caseRegions]) == 3

0 commit comments

Comments
 (0)