Skip to content

Commit e409d1d

Browse files
david-plkaihsin
andauthored
Fix UnrollIfs pass for non-empty else bodies (#417)
Closes #284 and #286. These issues were blocked by another bug in kirin that is fixed now. Also, note that this change basically just keeps the rewrite from raising an error, but ultimately things like non-empty else bodies (including `elif`) is just not supported in QASM2, so you will not be able to emit QASM2 from such a program. Co-authored-by: Kai-Hsin Wu <[email protected]>
1 parent 69818de commit e409d1d

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

src/bloqade/rewrite/rules/split_ifs.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,13 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
4646
if not isinstance(node, scf.IfElse):
4747
return RewriteResult()
4848

49+
# NOTE: only empty else bodies are allowed in valid QASM2
50+
if not self._has_empty_else(node):
51+
return RewriteResult()
52+
4953
*stmts, yield_or_return = node.then_body.stmts()
5054

51-
if len(stmts) == 1:
55+
if len(stmts) <= 1:
5256
return RewriteResult()
5357

5458
is_yield = isinstance(yield_or_return, scf.Yield)
@@ -71,3 +75,16 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
7175
node.delete()
7276

7377
return RewriteResult(has_done_something=True)
78+
79+
def _has_empty_else(self, node: scf.IfElse) -> bool:
80+
else_stmts = list(node.else_body.stmts())
81+
if len(else_stmts) > 1:
82+
return False
83+
84+
if len(else_stmts) == 0:
85+
return True
86+
87+
if not isinstance(else_stmts[0], scf.Yield):
88+
return False
89+
90+
return len(else_stmts[0].values) == 0

test/qasm2/passes/test_unroll_if.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from bloqade import qasm2
22
from bloqade.qasm2.emit import QASM2
3+
from bloqade.qasm2.passes import QASM2Fold
34

45

56
def test_unrolling_ifs():
@@ -103,3 +104,28 @@ def main():
103104
ast = target.emit(main)
104105

105106
qasm2.parse.pprint(ast)
107+
108+
109+
def test_elif():
110+
111+
@qasm2.extended
112+
def main():
113+
q = qasm2.qreg(1)
114+
c = qasm2.creg(1)
115+
qasm2.h(q[0])
116+
qasm2.measure(q, c)
117+
118+
parity = 0
119+
if c[0] == 1 and parity == 0:
120+
qasm2.x(q[0])
121+
parity = 0
122+
elif c[0] == 0:
123+
parity = 1
124+
elif c[0] == 2:
125+
parity = 2
126+
127+
return parity
128+
129+
QASM2Fold(qasm2.extended, unroll_ifs=True, no_raise=False).fixpoint(main)
130+
131+
main.print()

0 commit comments

Comments
 (0)