Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 14 additions & 73 deletions src/bloqade/qasm2/passes/fold.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
from dataclasses import field, dataclass

from kirin import ir
from kirin.passes import Pass, TypeInfer
from kirin.rewrite import (
Walk,
Chain,
Inline,
Fixpoint,
WrapConst,
Call2Invoke,
ConstantFold,
CFGCompactify,
InlineGetItem,
InlineGetField,
DeadCodeElimination,
CommonSubexpressionElimination,
)
from kirin.analysis import const
from kirin.dialects import scf, ilist
from kirin.passes import Pass
from kirin.ir.method import Method
from kirin.rewrite.abc import RewriteResult

from bloqade.qasm2.dialects import expr
from bloqade.rewrite.passes import AggressiveUnroll

from .unroll_if import UnrollIfs

Expand All @@ -30,71 +15,27 @@
class QASM2Fold(Pass):
"""Fold pass for qasm2.extended"""

constprop: const.Propagate = field(init=False)
inline_gate_subroutine: bool = True
unroll_ifs: bool = True
aggressive_unroll: AggressiveUnroll = field(init=False)

def __post_init__(self):
self.constprop = const.Propagate(self.dialects)
self.typeinfer = TypeInfer(self.dialects)
def inline_simple(node: ir.Statement):
if isinstance(node, expr.GateFunction):
return self.inline_gate_subroutine

def unsafe_run(self, mt: Method) -> RewriteResult:
result = RewriteResult()
frame, _ = self.constprop.run_analysis(mt)
result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
rule = Chain(
ConstantFold(),
Call2Invoke(),
InlineGetField(),
InlineGetItem(),
DeadCodeElimination(),
CommonSubexpressionElimination(),
)
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
return True

result = (
Walk(
Chain(
scf.unroll.PickIfElse(),
scf.unroll.ForLoop(),
scf.trim.UnusedYield(),
)
)
.rewrite(mt.code)
.join(result)
self.aggressive_unroll = AggressiveUnroll(
self.dialects, inline_simple, no_raise=self.no_raise
)

if self.unroll_ifs:
UnrollIfs(mt.dialects).unsafe_run(mt).join(result)

# run typeinfer again after unroll etc. because we now insert
# a lot of new nodes, which might have more precise types
self.typeinfer.unsafe_run(mt)
result = (
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
.rewrite(mt.code)
.join(result)
)

def inline_simple(node: ir.Statement):
if isinstance(node, expr.GateFunction):
return self.inline_gate_subroutine
def unsafe_run(self, mt: Method) -> RewriteResult:
result = RewriteResult()

if not isinstance(node.parent_stmt, (scf.For, scf.IfElse)):
return True # always inline calls outside of loops and if-else
if self.unroll_ifs:
result = UnrollIfs(mt.dialects).unsafe_run(mt).join(result)

# inside loops and if-else, only inline simple functions, i.e. functions with a single block
if (trait := node.get_trait(ir.CallableStmtInterface)) is None:
return False # not a callable, don't inline to be safe
region = trait.get_callable_region(node)
return len(region.blocks) == 1
result = self.aggressive_unroll.unsafe_run(mt).join(result)

result = (
Walk(
Inline(inline_simple),
)
.rewrite(mt.code)
.join(result)
)
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
return result
1 change: 1 addition & 0 deletions src/bloqade/rewrite/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
93 changes: 93 additions & 0 deletions src/bloqade/rewrite/passes/aggressive_unroll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Callable
from dataclasses import field, dataclass

from kirin import ir
from kirin.passes import Pass, HintConst, TypeInfer
from kirin.rewrite import (
Walk,
Chain,
Inline,
Fixpoint,
Call2Invoke,
ConstantFold,
CFGCompactify,
InlineGetItem,
InlineGetField,
DeadCodeElimination,
)
from kirin.dialects import scf, ilist
from kirin.ir.method import Method
from kirin.rewrite.abc import RewriteResult
from kirin.rewrite.cse import CommonSubexpressionElimination
from kirin.passes.aggressive import UnrollScf


@dataclass
class Fold(Pass):
hint_const: HintConst = field(init=False)

def __post_init__(self):
self.hint_const = HintConst(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: Method) -> RewriteResult:
result = RewriteResult()
result = self.hint_const.unsafe_run(mt).join(result)
rule = Chain(
ConstantFold(),
Call2Invoke(),
InlineGetField(),
InlineGetItem(),
ilist.rewrite.InlineGetItem(),
ilist.rewrite.HintLen(),
)
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)

return result


@dataclass
class AggressiveUnroll(Pass):
"""A pass to unroll structured control flow"""

additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True

fold: Fold = field(init=False)
typeinfer: TypeInfer = field(init=False)
scf_unroll: UnrollScf = field(init=False)

def __post_init__(self):
self.fold = Fold(self.dialects, no_raise=self.no_raise)
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: Method) -> RewriteResult:
result = RewriteResult()
result = self.scf_unroll.unsafe_run(mt).join(result)
result = (
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
.rewrite(mt.code)
.join(result)
)
result = self.typeinfer.unsafe_run(mt).join(result)
result = self.fold.unsafe_run(mt).join(result)
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)

rule = Chain(
CommonSubexpressionElimination(),
DeadCodeElimination(),
)
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)

return result

def inline_heuristic(self, node: ir.Statement) -> bool:
"""The heuristic to decide whether to inline a function call or not.
inside loops and if-else, only inline simple functions, i.e.
functions with a single block
"""
return not isinstance(
node.parent_stmt, (scf.For, scf.IfElse)
) and self.additional_inline_heuristic(
node
) # always inline calls outside of loops and if-else
32 changes: 18 additions & 14 deletions src/bloqade/rewrite/passes/canonicalize_ilist.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
from dataclasses import dataclass
from dataclasses import field, dataclass

from kirin import ir
from kirin.passes import Pass
from kirin import ir, passes
from kirin.rewrite import (
Walk,
Chain,
Fixpoint,
)
from kirin.analysis import const

from ..rules.flatten_ilist import FlattenAddOpIList
from ..rules.inline_getitem_ilist import InlineGetItemFromIList
from kirin.dialects.ilist import rewrite


@dataclass
class CanonicalizeIList(Pass):
class CanonicalizeIList(passes.Pass):

def unsafe_run(self, mt: ir.Method):
fold_pass: passes.Fold = field(init=False)

cp_result_frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
def __post_init__(self):
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

return Fixpoint(
Chain(
Walk(InlineGetItemFromIList(constprop_result=cp_result_frame.entries)),
Walk(FlattenAddOpIList()),
def unsafe_run(self, mt: ir.Method):
result = Fixpoint(
Walk(
Chain(
rewrite.InlineGetItem(),
rewrite.FlattenAdd(),
rewrite.HintLen(),
)
)
).rewrite(mt.code)

result = self.fold_pass(mt).join(result)
return result
51 changes: 0 additions & 51 deletions src/bloqade/rewrite/rules/flatten_ilist.py

This file was deleted.

31 changes: 0 additions & 31 deletions src/bloqade/rewrite/rules/inline_getitem_ilist.py

This file was deleted.

27 changes: 27 additions & 0 deletions test/qasm2/emit/test_qasm2_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,30 @@ def nested_kernel():

target = qasm2.emit.QASM2()
target.emit(nested_kernel)


def test_loop_unroll():
n_qubits = 4

@qasm2.extended
def ghz_linear():
q = qasm2.qreg(n_qubits)
qasm2.h(q[0])
for i in range(1, n_qubits):
qasm2.cx(q[i - 1], q[i])

target = qasm2.emit.QASM2(
allow_parallel=True,
)
qasm2_str = target.emit_str(ghz_linear)

assert qasm2_str == (
"""KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf};
include "qelib1.inc";
qreg q[4];
h q[0];
CX q[0], q[1];
CX q[1], q[2];
CX q[2], q[3];
"""
)