From 6619ba15de9749e743759661e1932ea691e06808 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 7 Oct 2025 12:11:39 +0200 Subject: [PATCH] Fix for loop unrolling in qasm2 emit --- src/bloqade/qasm2/passes/fold.py | 15 +++------------ test/qasm2/emit/test_qasm2_emit.py | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/src/bloqade/qasm2/passes/fold.py b/src/bloqade/qasm2/passes/fold.py index 65968f03..ef45843d 100644 --- a/src/bloqade/qasm2/passes/fold.py +++ b/src/bloqade/qasm2/passes/fold.py @@ -20,6 +20,7 @@ from kirin.dialects import scf, ilist from kirin.ir.method import Method from kirin.rewrite.abc import RewriteResult +from kirin.passes.aggressive import UnrollScf from bloqade.qasm2.dialects import expr @@ -51,18 +52,8 @@ def unsafe_run(self, mt: Method) -> RewriteResult: CommonSubexpressionElimination(), ) result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) - - result = ( - Walk( - Chain( - scf.unroll.PickIfElse(), - scf.unroll.ForLoop(), - scf.trim.UnusedYield(), - ) - ) - .rewrite(mt.code) - .join(result) - ) + result = UnrollScf(self.dialects).fixpoint(mt).join(result) + result = Walk(scf.trim.UnusedYield()).rewrite(mt.code).join(result) if self.unroll_ifs: UnrollIfs(mt.dialects).unsafe_run(mt).join(result) diff --git a/test/qasm2/emit/test_qasm2_emit.py b/test/qasm2/emit/test_qasm2_emit.py index 34474165..e044b7ee 100644 --- a/test/qasm2/emit/test_qasm2_emit.py +++ b/test/qasm2/emit/test_qasm2_emit.py @@ -288,3 +288,30 @@ def nested_kernel(): target = qasm2.emit.QASM2() target.emit(nested_kernel) + + +def test_loop_unroll(): + n_qubits = 4 + + @qasm2.extended + def ghz_linear(): + q = qasm2.qreg(n_qubits) + qasm2.h(q[0]) + for i in range(1, n_qubits): + qasm2.cx(q[i - 1], q[i]) + + target = qasm2.emit.QASM2( + allow_parallel=True, + ) + qasm2_str = target.emit_str(ghz_linear) + + assert qasm2_str == ( + """KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf}; +include "qelib1.inc"; +qreg q[4]; +h q[0]; +CX q[0], q[1]; +CX q[1], q[2]; +CX q[2], q[3]; +""" + )