Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/bloqade/cirq_utils/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def main():
| cirq.CSwapGate
| cirq.XXPowGate
| cirq.YYPowGate
| cirq.ZZPowGate
| cirq.CCXPowGate
| cirq.CCZPowGate
)
Expand Down Expand Up @@ -523,6 +522,27 @@ def visit_CZPowGate(
gate.stmts.CZ(controls=control_qarg, targets=target_qarg)
)

def visit_ZZPowGate(
self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
):
if node.gate.exponent % 2 == 0:
return

qubit1, qubit2 = node.qubits
qarg1 = self.lower_qubit_getindices(state, (qubit1,))
qarg2 = self.lower_qubit_getindices(state, (qubit2,))

if node.gate.exponent % 2 == 1:
state.current_frame.push(gate.stmts.X(qarg1))
state.current_frame.push(gate.stmts.X(qarg2))
return

# NOTE: arbitrary exponent, write as CX * Rz * CX (up to global phase)
state.current_frame.push(gate.stmts.CX(qarg1, qarg2))
angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent))
state.current_frame.push(gate.stmts.Rz(angle.result, qarg2))
state.current_frame.push(gate.stmts.CX(qarg1, qarg2))

def visit_ControlledOperation(
self, state: lowering.State[cirq.Circuit], node: cirq.ControlledOperation
):
Expand Down
45 changes: 45 additions & 0 deletions test/cirq_utils/test_cirq_to_squin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math

import cirq
import numpy as np
import pytest
from kirin import types
from kirin.passes import inline
Expand Down Expand Up @@ -414,3 +415,47 @@ def multi_arg(n: int, p: float):
@pytest.mark.xfail
def test_amplitude_damping():
test_circuit(amplitude_damping)


def test_trotter():

# NOTE: stolen from jonathan's tutorial
def trotter_layer(
qubits, dt: float = 0.01, J: float = 1, h: float = 1
) -> cirq.Circuit:
"""
Cirq builder function that returns a circuit of
a Trotter step of the 1D transverse Ising model
"""
op_zz = cirq.ZZ ** (dt * J / math.pi)
op_x = cirq.X ** (dt * h / math.pi)
circuit = cirq.Circuit()
for i in range(0, len(qubits) - 1, 2):
circuit.append(op_zz.on(qubits[i], qubits[i + 1]))
for i in range(1, len(qubits) - 1, 2):
circuit.append(op_zz.on(qubits[i], qubits[i + 1]))
for i in range(len(qubits)):
circuit.append(op_x.on(qubits[i]))
return circuit

N = 4
steps = 10
dt = 0.01
J = 1
h = 1

qubits = cirq.LineQubit.range(N)
circuit = cirq.Circuit()
for _ in range(steps):
circuit += trotter_layer(qubits, dt, J, h)

main = load_circuit(circuit)

# actually run
cirq_statevector = cirq.Simulator().simulate(circuit).state_vector()
sim = DynamicMemorySimulator()
ket = sim.state_vector(main)

assert math.isclose(
np.abs(np.dot(np.conj(ket), cirq_statevector)) ** 2, 1.0, abs_tol=1e-3
)