Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions src/bloqade/squin/cirq/emit/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import cirq
import numpy as np
from kirin.emit import EmitError
from kirin.interp import MethodTable, impl

from ... import op
from .runtime import (
SnRuntime,
SpRuntime,
U3Runtime,
RotRuntime,
KronRuntime,
MultRuntime,
ScaleRuntime,
Expand Down Expand Up @@ -126,6 +126,31 @@ def pauli_string(
):
return (PauliStringRuntime(stmt.string),)

@impl(op.stmts.Rot)
def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
axis: OperatorRuntimeABC = frame.get(stmt.axis)

if not isinstance(axis, HermitianRuntime):
raise EmitError(
f"Circuit emission only supported for Pauli operators! Got axis {axis}"
)

angle = frame.get(stmt.angle)

match axis.gate:
case cirq.X:
gate = cirq.Rx(rads=angle)
case cirq.Y:
gate = cirq.Ry(rads=angle)
case cirq.Z:
gate = cirq.Rz(rads=angle)
case _:
raise EmitError(
f"Circuit emission only supported for Pauli operators! Got axis {axis.gate}"
)

return (HermitianRuntime(gate=gate),)

@impl(op.stmts.ResetToOne)
def reset_to_one(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ResetToOne
Expand All @@ -140,14 +165,6 @@ def reset_to_one(
# NOTE: mind the order: rhs is applied first
return (MultRuntime(rt2, rt1),)

@impl(op.stmts.Rot)
def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Rot):
axis_op: HermitianRuntime = frame.get(stmt.axis)
angle = frame.get(stmt.angle)

axis_name = str(axis_op.gate).lower()
return (RotRuntime(axis=axis_name, angle=angle),)

@impl(op.stmts.SqrtX)
def sqrt_x(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.SqrtX):
cirq_op = cirq.XPowGate(exponent=0.5)
Expand Down
15 changes: 0 additions & 15 deletions src/bloqade/squin/cirq/emit/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,3 @@ def unsafe_apply(
qbit: pauli_label for (qbit, pauli_label) in zip(qubits, self.string)
}
return [cirq.PauliString(pauli_mapping)]


@dataclass
class RotRuntime(OperatorRuntimeABC):
axis: str
angle: float

def num_qubits(self) -> int:
return 1

def unsafe_apply(
self, qubits: Sequence[cirq.Qid], adjoint: bool = False
) -> list[cirq.Operation]:
rot = getattr(cirq, "R" + self.axis.lower())(rads=self.angle)
return [rot(*qubits)]
38 changes: 38 additions & 0 deletions test/squin/cirq/test_squin_to_cirq.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a not happy path test? To trigger throwing the two EmitErrors I see in emit/op.py?

Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,42 @@ def main():
assert len(target._cached_circuit_operations) == 2


def test_rot():
@squin.kernel
def main():
axis = squin.op.x()
q = squin.qubit.new(1)
r = squin.op.rot(axis=axis, angle=math.pi / 2)
squin.qubit.apply(r, q[0])

circuit = squin.cirq.emit_circuit(main)

print(circuit)

assert circuit[0].operations[0].gate == cirq.Rx(rads=math.pi / 2)

@squin.kernel
def main2():
x = squin.op.x()
y = squin.op.y()
q = squin.qubit.new(1)
r = squin.op.rot(axis=x * y, angle=0.123)
squin.qubit.apply(r, q[0])

with pytest.raises(EmitError):
squin.cirq.emit_circuit(main2)

@squin.kernel
def main3():
op = squin.op.h()
q = squin.qubit.new(1)
r = squin.op.rot(axis=op, angle=0.123)
squin.qubit.apply(r, q[0])

with pytest.raises(EmitError):
squin.cirq.emit_circuit(main3)


def test_additional_stmts():
@squin.kernel
def main():
Expand All @@ -366,6 +402,8 @@ def main():

circuit = squin.cirq.emit_circuit(main)

print(circuit)

q = cirq.LineQubit.range(3)
expected_circuit = cirq.Circuit(
cirq.Rx(rads=0.123).on(q[0]),
Expand Down
2 changes: 1 addition & 1 deletion test/squin/test_sugar.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This should fix that problem from the random phase you told me about which causes the occasional failing CI run

Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main():
sim = StackMemorySimulator(min_qubits=2)
ket = sim.state_vector(main)

assert math.isclose(abs(ket[0]) ** 2, 1, abs_tol=1e-5)
assert math.isclose(abs(ket[0]) ** 2, 1, abs_tol=1e-4)
assert ket[1] == ket[2] == ket[3] == 0


Expand Down