From d1185549b9eb53af65923b0fce89f5aa75495607 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 5 Aug 2025 17:27:30 +0200 Subject: [PATCH 1/3] Support cirq emit of rotation when axis is X,Y,Z --- src/bloqade/squin/cirq/emit/op.py | 28 ++++++++++++++++++++++++++ src/bloqade/squin/cirq/emit/runtime.py | 15 ++++++++++++++ test/squin/cirq/test_squin_to_cirq.py | 16 ++++++++++++++- test/squin/test_sugar.py | 2 +- 4 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/bloqade/squin/cirq/emit/op.py b/src/bloqade/squin/cirq/emit/op.py index a94a1082..8ee49fc7 100644 --- a/src/bloqade/squin/cirq/emit/op.py +++ b/src/bloqade/squin/cirq/emit/op.py @@ -2,6 +2,7 @@ import cirq import numpy as np +from kirin.emit import EmitError from kirin.interp import MethodTable, impl from ... import op @@ -123,3 +124,30 @@ def pauli_string( self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString ): return (PauliStringRuntime(stmt.string),) + + @impl(op.stmts.Rot) + def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot): + axis: OperatorRuntimeABC = frame.get(stmt.axis) + + if not isinstance(axis, HermitianRuntime): + raise EmitError( + f"Circuit emission only supported for Pauli operators! Got axis {axis}" + ) + + angle = frame.get(stmt.angle) + + exponent = angle / math.pi + + match axis.gate: + case cirq.X: + gate = cirq.XPowGate(exponent=exponent) + case cirq.Y: + gate = cirq.YPowGate(exponent=exponent) + case cirq.Z: + gate = cirq.ZPowGate(exponent=exponent) + case _: + raise EmitError( + f"Circuit emission only supported for Pauli operators! Got axis {axis.gate}" + ) + + return (HermitianRuntime(gate=gate),) diff --git a/src/bloqade/squin/cirq/emit/runtime.py b/src/bloqade/squin/cirq/emit/runtime.py index 8d5399c3..f7ec0d80 100644 --- a/src/bloqade/squin/cirq/emit/runtime.py +++ b/src/bloqade/squin/cirq/emit/runtime.py @@ -240,3 +240,18 @@ def unsafe_apply( qbit: pauli_label for (qbit, pauli_label) in zip(qubits, self.string) } return [cirq.PauliString(pauli_mapping)] + + +@dataclass +class RotRuntime(OperatorRuntimeABC): + axis: HermitianRuntime + angle: float + + def num_qubits(self) -> int: + return self.axis.num_qubits() + + def unsafe_apply( + self, qubits: Sequence[cirq.Qid], adjoint: bool = False + ) -> list[cirq.Operation]: + return [] + # if not isinstance(self.axis.gate, ) diff --git a/test/squin/cirq/test_squin_to_cirq.py b/test/squin/cirq/test_squin_to_cirq.py index ce5c8f82..d11fec29 100644 --- a/test/squin/cirq/test_squin_to_cirq.py +++ b/test/squin/cirq/test_squin_to_cirq.py @@ -349,4 +349,18 @@ def main(): assert len(target._cached_circuit_operations) == 2 -test_return_value() +def test_rot(): + @squin.kernel + def main(): + axis = squin.op.x() + q = squin.qubit.new(1) + r = squin.op.rot(axis=axis, angle=math.pi / 2) + squin.qubit.apply(r, q[0]) + + main.print() + + circuit = squin.cirq.emit_circuit(main) + + print(circuit) + + assert circuit[0].operations[0].gate == cirq.XPowGate(exponent=0.5) diff --git a/test/squin/test_sugar.py b/test/squin/test_sugar.py index 3147855b..39b517cd 100644 --- a/test/squin/test_sugar.py +++ b/test/squin/test_sugar.py @@ -87,7 +87,7 @@ def main(): sim = StackMemorySimulator(min_qubits=2) ket = sim.state_vector(main) - assert math.isclose(abs(ket[0]) ** 2, 1, abs_tol=1e-7) + assert math.isclose(abs(ket[0]) ** 2, 1, abs_tol=1e-4) assert ket[1] == ket[2] == ket[3] == 0 From 4e7603d6ba72bdfe367221ab0533a42584c84b8c Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 5 Aug 2025 17:29:33 +0200 Subject: [PATCH 2/3] Remove unused runtime --- src/bloqade/squin/cirq/emit/runtime.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/bloqade/squin/cirq/emit/runtime.py b/src/bloqade/squin/cirq/emit/runtime.py index f7ec0d80..8d5399c3 100644 --- a/src/bloqade/squin/cirq/emit/runtime.py +++ b/src/bloqade/squin/cirq/emit/runtime.py @@ -240,18 +240,3 @@ def unsafe_apply( qbit: pauli_label for (qbit, pauli_label) in zip(qubits, self.string) } return [cirq.PauliString(pauli_mapping)] - - -@dataclass -class RotRuntime(OperatorRuntimeABC): - axis: HermitianRuntime - angle: float - - def num_qubits(self) -> int: - return self.axis.num_qubits() - - def unsafe_apply( - self, qubits: Sequence[cirq.Qid], adjoint: bool = False - ) -> list[cirq.Operation]: - return [] - # if not isinstance(self.axis.gate, ) From 5435313871c94b7859881fd2d315b85778ceb92b Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 12 Aug 2025 09:44:13 +0200 Subject: [PATCH 3/3] Don't rewrite to PowGate --- src/bloqade/squin/cirq/emit/op.py | 8 +++--- test/squin/cirq/test_squin_to_cirq.py | 38 +++++++++++++++++++++------ 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/bloqade/squin/cirq/emit/op.py b/src/bloqade/squin/cirq/emit/op.py index e408c3fc..25bca61b 100644 --- a/src/bloqade/squin/cirq/emit/op.py +++ b/src/bloqade/squin/cirq/emit/op.py @@ -137,15 +137,13 @@ def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot): angle = frame.get(stmt.angle) - exponent = angle / math.pi - match axis.gate: case cirq.X: - gate = cirq.XPowGate(exponent=exponent) + gate = cirq.Rx(rads=angle) case cirq.Y: - gate = cirq.YPowGate(exponent=exponent) + gate = cirq.Ry(rads=angle) case cirq.Z: - gate = cirq.ZPowGate(exponent=exponent) + gate = cirq.Rz(rads=angle) case _: raise EmitError( f"Circuit emission only supported for Pauli operators! Got axis {axis.gate}" diff --git a/test/squin/cirq/test_squin_to_cirq.py b/test/squin/cirq/test_squin_to_cirq.py index de640b76..3d74c310 100644 --- a/test/squin/cirq/test_squin_to_cirq.py +++ b/test/squin/cirq/test_squin_to_cirq.py @@ -362,15 +362,28 @@ def main(): print(circuit) - assert circuit[0].operations[0].gate == cirq.XPowGate(exponent=0.5) - q = cirq.LineQubit.range(3) - expected_circuit = cirq.Circuit( - cirq.Rx(rads=0.123).on(q[0]), - cirq.X(q[1]) ** 0.5, - cirq.Y(q[2]) ** 0.5, - ) + assert circuit[0].operations[0].gate == cirq.Rx(rads=math.pi / 2) - assert circuit == expected_circuit + @squin.kernel + def main2(): + x = squin.op.x() + y = squin.op.y() + q = squin.qubit.new(1) + r = squin.op.rot(axis=x * y, angle=0.123) + squin.qubit.apply(r, q[0]) + + with pytest.raises(EmitError): + squin.cirq.emit_circuit(main2) + + @squin.kernel + def main3(): + op = squin.op.h() + q = squin.qubit.new(1) + r = squin.op.rot(axis=op, angle=0.123) + squin.qubit.apply(r, q[0]) + + with pytest.raises(EmitError): + squin.cirq.emit_circuit(main3) def test_additional_stmts(): @@ -391,6 +404,15 @@ def main(): print(circuit) + q = cirq.LineQubit.range(3) + expected_circuit = cirq.Circuit( + cirq.Rx(rads=0.123).on(q[0]), + cirq.X(q[1]) ** 0.5, + cirq.Y(q[2]) ** 0.5, + ) + + assert circuit == expected_circuit + def test_return_measurement():