Skip to content

Commit 0acbf74

Browse files
committed
[MLIR][Python] fix PyRegionList __iter__
1 parent 321afc7 commit 0acbf74

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ namespace {
204204

205205
class PyRegionIterator {
206206
public:
207-
PyRegionIterator(PyOperationRef operation)
208-
: operation(std::move(operation)) {}
207+
PyRegionIterator(PyOperationRef operation, int nextIndex)
208+
: operation(std::move(operation)), nextIndex(nextIndex) {}
209209

210210
PyRegionIterator &dunderIter() { return *this; }
211211

@@ -228,7 +228,7 @@ class PyRegionIterator {
228228

229229
private:
230230
PyOperationRef operation;
231-
int nextIndex = 0;
231+
intptr_t nextIndex = 0;
232232
};
233233

234234
/// Regions of an op are fixed length and indexed numerically so are represented
@@ -247,7 +247,7 @@ class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
247247

248248
PyRegionIterator dunderIter() {
249249
operation->checkValid();
250-
return PyRegionIterator(operation);
250+
return PyRegionIterator(operation, startIndex);
251251
}
252252

253253
static void bindDerived(ClassTy &c) {

mlir/lib/Bindings/Python/NanobindUtils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,6 @@ class Sliceable {
395395
/// Hook for derived classes willing to bind more methods.
396396
static void bindDerived(ClassTy &) {}
397397

398-
private:
399398
intptr_t startIndex;
400399
intptr_t length;
401400
intptr_t step;

mlir/test/python/ir/operation.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
import gc
44
import io
5-
import itertools
65
from tempfile import NamedTemporaryFile
76
from mlir.ir import *
87
from mlir.dialects.builtin import ModuleOp
9-
from mlir.dialects import arith
8+
from mlir.dialects import arith, func, scf
109
from mlir.dialects._ods_common import _cext
10+
from mlir.extras import types as T
1111

1212

1313
def run(f):
@@ -1199,3 +1199,27 @@ def testGetOwnerConcreteOpview():
11991199
r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw)
12001200
for u in a.result.uses:
12011201
assert isinstance(u.owner, arith.AddIOp)
1202+
1203+
1204+
# CHECK-LABEL: TEST: testIndexSwitch
1205+
@run
1206+
def testIndexSwitch():
1207+
with Context() as ctx, Location.unknown():
1208+
i32 = T.i32()
1209+
module = Module.create()
1210+
with InsertionPoint(module.body):
1211+
1212+
@func.FuncOp.from_py_func(T.index())
1213+
def index_switch(index):
1214+
c1 = arith.constant(i32, 1)
1215+
switch_op = scf.IndexSwitchOp(
1216+
results_=[i32], arg=index, cases=range(3), num_caseRegions=3
1217+
)
1218+
1219+
assert len(switch_op.regions) == 4
1220+
assert len(switch_op.regions[2:]) == 2
1221+
assert len([i for i in switch_op.regions[2:]]) == 2
1222+
assert len(switch_op.caseRegions) == 3
1223+
assert len([i for i in switch_op.caseRegions]) == 3
1224+
assert len(switch_op.caseRegions[1:]) == 2
1225+
assert len([i for i in switch_op.caseRegions[1:]]) == 2

0 commit comments

Comments
 (0)