Skip to content

Commit 8670b8b

Browse files
authored
Support cirq emit of rotation when axis is X,Y,Z (#437)
One of the things that dropped out of #436
1 parent a2fa304 commit 8670b8b

File tree

4 files changed

+65
-25
lines changed

4 files changed

+65
-25
lines changed

src/bloqade/squin/cirq/emit/op.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
import cirq
44
import numpy as np
5+
from kirin.emit import EmitError
56
from kirin.interp import MethodTable, impl
67

78
from ... import op
89
from .runtime import (
910
SnRuntime,
1011
SpRuntime,
1112
U3Runtime,
12-
RotRuntime,
1313
KronRuntime,
1414
MultRuntime,
1515
ScaleRuntime,
@@ -126,6 +126,31 @@ def pauli_string(
126126
):
127127
return (PauliStringRuntime(stmt.string),)
128128

129+
@impl(op.stmts.Rot)
130+
def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
131+
axis: OperatorRuntimeABC = frame.get(stmt.axis)
132+
133+
if not isinstance(axis, HermitianRuntime):
134+
raise EmitError(
135+
f"Circuit emission only supported for Pauli operators! Got axis {axis}"
136+
)
137+
138+
angle = frame.get(stmt.angle)
139+
140+
match axis.gate:
141+
case cirq.X:
142+
gate = cirq.Rx(rads=angle)
143+
case cirq.Y:
144+
gate = cirq.Ry(rads=angle)
145+
case cirq.Z:
146+
gate = cirq.Rz(rads=angle)
147+
case _:
148+
raise EmitError(
149+
f"Circuit emission only supported for Pauli operators! Got axis {axis.gate}"
150+
)
151+
152+
return (HermitianRuntime(gate=gate),)
153+
129154
@impl(op.stmts.ResetToOne)
130155
def reset_to_one(
131156
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ResetToOne
@@ -140,14 +165,6 @@ def reset_to_one(
140165
# NOTE: mind the order: rhs is applied first
141166
return (MultRuntime(rt2, rt1),)
142167

143-
@impl(op.stmts.Rot)
144-
def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
145-
axis_op: HermitianRuntime = frame.get(stmt.axis)
146-
angle = frame.get(stmt.angle)
147-
148-
axis_name = str(axis_op.gate).lower()
149-
return (RotRuntime(axis=axis_name, angle=angle),)
150-
151168
@impl(op.stmts.SqrtX)
152169
def sqrt_x(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtX):
153170
cirq_op = cirq.XPowGate(exponent=0.5)

src/bloqade/squin/cirq/emit/runtime.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -240,18 +240,3 @@ def unsafe_apply(
240240
qbit: pauli_label for (qbit, pauli_label) in zip(qubits, self.string)
241241
}
242242
return [cirq.PauliString(pauli_mapping)]
243-
244-
245-
@dataclass
246-
class RotRuntime(OperatorRuntimeABC):
247-
axis: str
248-
angle: float
249-
250-
def num_qubits(self) -> int:
251-
return 1
252-
253-
def unsafe_apply(
254-
self, qubits: Sequence[cirq.Qid], adjoint: bool = False
255-
) -> list[cirq.Operation]:
256-
rot = getattr(cirq, "R" + self.axis.lower())(rads=self.angle)
257-
return [rot(*qubits)]

test/squin/cirq/test_squin_to_cirq.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,42 @@ def main():
350350
assert len(target._cached_circuit_operations) == 2
351351

352352

353+
def test_rot():
354+
@squin.kernel
355+
def main():
356+
axis = squin.op.x()
357+
q = squin.qubit.new(1)
358+
r = squin.op.rot(axis=axis, angle=math.pi / 2)
359+
squin.qubit.apply(r, q[0])
360+
361+
circuit = squin.cirq.emit_circuit(main)
362+
363+
print(circuit)
364+
365+
assert circuit[0].operations[0].gate == cirq.Rx(rads=math.pi / 2)
366+
367+
@squin.kernel
368+
def main2():
369+
x = squin.op.x()
370+
y = squin.op.y()
371+
q = squin.qubit.new(1)
372+
r = squin.op.rot(axis=x * y, angle=0.123)
373+
squin.qubit.apply(r, q[0])
374+
375+
with pytest.raises(EmitError):
376+
squin.cirq.emit_circuit(main2)
377+
378+
@squin.kernel
379+
def main3():
380+
op = squin.op.h()
381+
q = squin.qubit.new(1)
382+
r = squin.op.rot(axis=op, angle=0.123)
383+
squin.qubit.apply(r, q[0])
384+
385+
with pytest.raises(EmitError):
386+
squin.cirq.emit_circuit(main3)
387+
388+
353389
def test_additional_stmts():
354390
@squin.kernel
355391
def main():
@@ -366,6 +402,8 @@ def main():
366402

367403
circuit = squin.cirq.emit_circuit(main)
368404

405+
print(circuit)
406+
369407
q = cirq.LineQubit.range(3)
370408
expected_circuit = cirq.Circuit(
371409
cirq.Rx(rads=0.123).on(q[0]),

test/squin/test_sugar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def main():
8787
sim = StackMemorySimulator(min_qubits=2)
8888
ket = sim.state_vector(main)
8989

90-
assert math.isclose(abs(ket[0]) ** 2, 1, abs_tol=1e-5)
90+
assert math.isclose(abs(ket[0]) ** 2, 1, abs_tol=1e-4)
9191
assert ket[1] == ket[2] == ket[3] == 0
9292

9393

0 commit comments

Comments
 (0)