Skip to content

Commit 08586f7

Browse files
authored
Add cirq emit methods for new and missing statements (#372)
Implements part of #352
1 parent d575aba commit 08586f7

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

src/bloqade/squin/cirq/emit/op.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
SnRuntime,
1010
SpRuntime,
1111
U3Runtime,
12+
RotRuntime,
1213
KronRuntime,
1314
MultRuntime,
1415
ScaleRuntime,
@@ -123,3 +124,21 @@ def pauli_string(
123124
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString
124125
):
125126
return (PauliStringRuntime(stmt.string),)
127+
128+
@impl(op.stmts.Rot)
129+
def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
130+
axis_op: HermitianRuntime = frame.get(stmt.axis)
131+
angle = frame.get(stmt.angle)
132+
133+
axis_name = str(axis_op.gate).lower()
134+
return (RotRuntime(axis=axis_name, angle=angle),)
135+
136+
@impl(op.stmts.SqrtX)
137+
def sqrt_x(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtX):
138+
cirq_op = cirq.XPowGate(exponent=0.5)
139+
return (UnitaryRuntime(cirq_op),)
140+
141+
@impl(op.stmts.SqrtY)
142+
def sqrt_y(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtY):
143+
cirq_op = cirq.YPowGate(exponent=0.5)
144+
return (UnitaryRuntime(cirq_op),)

src/bloqade/squin/cirq/emit/runtime.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,18 @@ def unsafe_apply(
240240
qbit: pauli_label for (qbit, pauli_label) in zip(qubits, self.string)
241241
}
242242
return [cirq.PauliString(pauli_mapping)]
243+
244+
245+
@dataclass
246+
class RotRuntime(OperatorRuntimeABC):
247+
axis: str
248+
angle: float
249+
250+
def num_qubits(self) -> int:
251+
return 1
252+
253+
def unsafe_apply(
254+
self, qubits: Sequence[cirq.Qid], adjoint: bool = False
255+
) -> list[cirq.Operation]:
256+
rot = getattr(cirq, "R" + self.axis.lower())(rads=self.angle)
257+
return [rot(*qubits)]

test/squin/cirq/test_squin_to_cirq.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,32 @@ def main():
349349
assert len(target._cached_circuit_operations) == 2
350350

351351

352+
def test_additional_stmts():
353+
@squin.kernel
354+
def main():
355+
x = squin.op.x()
356+
r = squin.op.rot(x, 0.123)
357+
q = squin.qubit.new(3)
358+
squin.qubit.apply(r, q[0])
359+
sqrt_x = squin.op.sqrt_x()
360+
sqrt_y = squin.op.sqrt_y()
361+
squin.qubit.apply(sqrt_x, q[1])
362+
squin.qubit.apply(sqrt_y, q[2])
363+
364+
main.print()
365+
366+
circuit = squin.cirq.emit_circuit(main)
367+
368+
q = cirq.LineQubit.range(3)
369+
expected_circuit = cirq.Circuit(
370+
cirq.Rx(rads=0.123).on(q[0]),
371+
cirq.X(q[1]) ** 0.5,
372+
cirq.Y(q[2]) ** 0.5,
373+
)
374+
375+
assert circuit == expected_circuit
376+
377+
352378
def test_return_measurement():
353379

354380
@squin.kernel

0 commit comments

Comments
 (0)