diff --git a/src/bloqade/squin/op/rewrite.py b/src/bloqade/squin/op/rewrite.py index 64000343..35b7c424 100644 --- a/src/bloqade/squin/op/rewrite.py +++ b/src/bloqade/squin/op/rewrite.py @@ -6,7 +6,7 @@ from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from .stmts import Mult, Scale +from .stmts import Rot, Mult, Scale, Adjoint from .types import OpType @@ -44,3 +44,20 @@ class PyMultToSquinMult(Pass): def unsafe_run(self, mt: ir.Method) -> RewriteResult: return Walk(_PyMultToSquinMult()).rewrite(mt.code) + + +class CanonicalizeAdjointRot(RewriteRule): + """This canonicalizes adjoint of rotations: Adj(Rot(angle, axis)) -> Rot(-angle, Adjoint(axis))""" + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if not isinstance(node, Adjoint) or not isinstance( + rot_stmt := node.op.owner, Rot + ): + return RewriteResult() + + # NOTE: angle is a float so the adjoint will give a negative angle because of the imaginary unit + (neg_angle := py.USub(rot_stmt.angle)).insert_before(node) + (new_axis_stmt := Adjoint(op=rot_stmt.axis)).insert_before(node) + node.replace_by(Rot(angle=neg_angle.result, axis=new_axis_stmt.result)) + + return RewriteResult(has_done_something=True) diff --git a/test/squin/op/test_rewrite.py b/test/squin/op/test_rewrite.py new file mode 100644 index 00000000..d6149799 --- /dev/null +++ b/test/squin/op/test_rewrite.py @@ -0,0 +1,29 @@ +from kirin import ir +from kirin.rewrite import Walk +from kirin.dialects import py + +from bloqade.squin.op import stmts as op_stmts +from bloqade.test_utils import assert_nodes +from bloqade.squin.op.rewrite import CanonicalizeAdjointRot + + +def test_rot_canonicalize(): + angle = ir.TestValue() + axis = ir.TestValue() + test_block = ir.Block() + test_block.stmts.append(rot := op_stmts.Rot(angle=angle, axis=axis)) + test_block.stmts.append(final_op := op_stmts.Adjoint(rot.result)) + test_block.stmts.append(op_stmts.Control(final_op.result, n_controls=1)) + + Walk(CanonicalizeAdjointRot()).rewrite(test_block) + + expected_block = ir.Block() + expected_block.stmts.append(rot := op_stmts.Rot(angle=angle, axis=axis)) + expected_block.stmts.append(new_angle := py.USub(angle)) + expected_block.stmts.append(new_axis := op_stmts.Adjoint(op=axis)) + expected_block.stmts.append( + final_op := op_stmts.Rot(new_axis.result, new_angle.result) + ) + expected_block.stmts.append(op_stmts.Control(final_op.result, n_controls=1)) + + assert_nodes(test_block, expected_block)