Skip to content

Commit d575aba

Browse files
jon-wurtzdavid-pl
andauthored
Reduced density matrix to pyqrack (#444)
Co-authored-by: David Plankensteiner <[email protected]>
1 parent e409d1d commit d575aba

File tree

3 files changed

+317
-3
lines changed

3 files changed

+317
-3
lines changed

src/bloqade/pyqrack/device.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from kirin import ir
66
from kirin.passes import fold
7+
from kirin.dialects.ilist import IList
78

89
from bloqade.squin import noise as squin_noise
910
from pyqrack.pauli import Pauli
@@ -18,6 +19,7 @@
1819
_default_pyqrack_args,
1920
)
2021
from bloqade.pyqrack.task import PyQrackSimulatorTask
22+
from pyqrack.qrack_simulator import QrackSimulator
2123
from bloqade.squin.noise.rewrite import RewriteNoiseStmts
2224
from bloqade.analysis.address.lattice import AnyAddress
2325
from bloqade.analysis.address.analysis import AddressAnalysis
@@ -26,6 +28,55 @@
2628
Params = ParamSpec("Params")
2729

2830

31+
def _pyqrack_reduced_density_matrix(
32+
inds: tuple[int, ...], sim_reg: QrackSimulator, tol: float = 1e-12
33+
) -> "np.linalg._linalg.EighResult":
34+
"""
35+
Extract the reduced density matrix representing the state of a list
36+
of qubits from a PyQRack simulator register.
37+
38+
Inputs:
39+
inds: A list of integers labeling the qubit registers to extract the reduced density matrix for
40+
sim_reg: The PyQRack simulator register to extract the reduced density matrix from
41+
tol: The tolerance for density matrix eigenvalues to be considered non-zero.
42+
Outputs:
43+
An eigh result containing the eigenvalues and eigenvectors of the reduced density matrix.
44+
"""
45+
# Identify the rest of the qubits in the register
46+
N = sim_reg.num_qubits()
47+
other = tuple(set(range(N)).difference(inds))
48+
49+
if len(set(inds)) != len(inds):
50+
raise ValueError("Qubits must be unique.")
51+
52+
if max(inds) > N - 1:
53+
raise ValueError(
54+
f"Qubit indices {inds} exceed the number of qubits in the register {N}."
55+
)
56+
57+
reordering = inds + other
58+
# Fix pyqrack edannes to be consistent with Cirq.
59+
reordering = tuple(N - 1 - x for x in reordering)
60+
# Extract the statevector from the PyQRack qubits
61+
statevector = np.array(sim_reg.out_ket())
62+
# Reshape into a (2,2,2, ..., 2) tensor
63+
vec_f = np.reshape(statevector, (2,) * N)
64+
# Reorder the indexes to obey the order of the qubits
65+
vec_p = np.transpose(vec_f, reordering)
66+
# Rehape into a 2^N by 2^M matrix to compute the singular value decomposition
67+
vec_svd = np.reshape(vec_p, (2 ** len(inds), 2 ** len(other)))
68+
# The singular values and vectors are the eigenspace of the reduced density matrix
69+
s, v, d = np.linalg.svd(vec_svd, full_matrices=False)
70+
71+
# Remove the negligable singular values
72+
nonzero_inds = np.where(np.abs(v) > tol)[0]
73+
s = s[:, nonzero_inds]
74+
v = v[nonzero_inds] ** 2
75+
# Forge into the correct result type
76+
result = np.linalg._linalg.EighResult(eigenvalues=v, eigenvectors=s)
77+
return result
78+
79+
2980
@dataclass
3081
class PyQrackSimulatorBase(AbstractSimulatorDevice[PyQrackSimulatorTask]):
3182
"""PyQrack simulation device base class."""
@@ -50,7 +101,6 @@ def new_task(
50101
kwargs: dict[str, Any],
51102
memory: MemoryType,
52103
) -> PyQrackSimulatorTask[Params, RetType, MemoryType]:
53-
54104
if squin_noise in mt.dialects:
55105
# NOTE: rewrite noise statements
56106
mt_ = mt.similar(mt.dialects)
@@ -112,6 +162,51 @@ def pauli_expectation(pauli: list[Pauli], qubits: list[PyQrackQubit]) -> float:
112162

113163
return sim_reg.pauli_expectation(qubit_ids, pauli)
114164

165+
@staticmethod
166+
def quantum_state(
167+
qubits: list[PyQrackQubit] | IList[PyQrackQubit, Any], tol: float = 1e-12
168+
) -> "np.linalg._linalg.EighResult":
169+
"""
170+
Extract the reduced density matrix representing the state of a list
171+
of qubits from a PyQRack simulator register.
172+
173+
Inputs:
174+
qubits: A list of PyQRack qubits to extract the reduced density matrix for
175+
tol: The tolerance for density matrix eigenvalues to be considered non-zero.
176+
Outputs:
177+
An eigh result containing the eigenvalues and eigenvectors of the reduced density matrix.
178+
"""
179+
if len(qubits) == 0:
180+
return np.linalg._linalg.EighResult(
181+
eigenvalues=np.array([]), eigenvectors=np.array([]).reshape(0, 0)
182+
)
183+
sim_reg = qubits[0].sim_reg
184+
185+
if not all([x.sim_reg is sim_reg for x in qubits]):
186+
raise ValueError("All qubits must be from the same simulator register.")
187+
inds: tuple[int, ...] = tuple(qubit.addr for qubit in qubits)
188+
189+
return _pyqrack_reduced_density_matrix(inds, sim_reg, tol)
190+
191+
@classmethod
192+
def reduced_density_matrix(
193+
cls, qubits: list[PyQrackQubit] | IList[PyQrackQubit, Any], tol: float = 1e-12
194+
) -> np.ndarray:
195+
"""
196+
Extract the reduced density matrix representing the state of a list
197+
of qubits from a PyQRack simulator register.
198+
199+
Inputs:
200+
qubits: A list of PyQRack qubits to extract the reduced density matrix for
201+
tol: The tolerance for density matrix eigenvalues to be considered non-zero.
202+
Outputs:
203+
A dense 2^n x 2^n numpy array representing the reduced density matrix.
204+
"""
205+
rdm = cls.quantum_state(qubits, tol)
206+
return np.einsum(
207+
"ax,x,bx", rdm.eigenvectors, rdm.eigenvalues, rdm.eigenvectors.conj()
208+
)
209+
115210

116211
@dataclass
117212
class StackMemorySimulator(PyQrackSimulatorBase):

src/bloqade/pyqrack/task.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33

44
from bloqade.task import AbstractSimulatorTask
5+
from bloqade.pyqrack.reg import QubitState, PyQrackQubit
56
from bloqade.pyqrack.base import (
67
MemoryABC,
78
PyQrackInterpreter,
@@ -36,3 +37,17 @@ def state_vector(self) -> list[complex]:
3637
"""Returns the state vector of the simulator."""
3738
self.run()
3839
return self.state.sim_reg.out_ket()
40+
41+
def qubits(self) -> list[PyQrackQubit]:
42+
"""Returns the qubits in the simulator."""
43+
try:
44+
N = self.state.sim_reg.num_qubits()
45+
return [
46+
PyQrackQubit(
47+
addr=i, sim_reg=self.state.sim_reg, state=QubitState.Active
48+
)
49+
for i in range(N)
50+
]
51+
except AttributeError:
52+
Warning("Task has not been run, there are no qubits!")
53+
return []

test/pyqrack/runtime/test_qrack.py

Lines changed: 206 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import math
22
from unittest.mock import Mock, call
33

4+
import cirq
5+
import numpy as np
46
from kirin import ir
57

6-
from bloqade import qasm2
8+
from bloqade import qasm2, squin
79
from pyqrack.pauli import Pauli
10+
from bloqade.pyqrack import StackMemorySimulator
811
from bloqade.pyqrack.base import MockMemory, PyQrackInterpreter
912

1013

@@ -91,7 +94,6 @@ def program():
9194

9295

9396
def test_basic_control_gates():
94-
9597
@qasm2.main
9698
def program():
9799
q = qasm2.qreg(3)
@@ -162,3 +164,205 @@ def program():
162164
call.r(3, 0.5, 1),
163165
]
164166
)
167+
168+
169+
def test_rdm1():
170+
"""
171+
Is extracting the exact state vector consistent with cirq?
172+
This test also validates the ordering of the qubit basis.
173+
"""
174+
175+
@squin.kernel
176+
def program():
177+
q = squin.qubit.new(5)
178+
squin.gate.h(q[1])
179+
return q
180+
181+
emulator = StackMemorySimulator(min_qubits=6)
182+
task = emulator.task(program)
183+
qubits = task.run()
184+
rho = emulator.quantum_state(qubits)
185+
186+
assert all(np.isclose(rho.eigenvalues, [1]))
187+
188+
circuit = cirq.Circuit()
189+
qbs = cirq.LineQubit.range(5)
190+
circuit.append(cirq.H(qbs[1]))
191+
for i in range(5):
192+
circuit.append(cirq.I(qbs[i]))
193+
state = cirq.Simulator().simulate(circuit).state_vector()
194+
assert cirq.equal_up_to_global_phase(state, rho[1][:, 0])
195+
196+
197+
def test_rdm1b():
198+
"""
199+
Is extracting the exact state vector consistent with cirq?
200+
This test also validates the ordering of the qubit basis.
201+
Same as test_rdm1, but with the total qubits equal to the number of qubits in the program.
202+
"""
203+
204+
@squin.kernel
205+
def program():
206+
q = squin.qubit.new(5)
207+
squin.gate.h(q[1])
208+
return q
209+
210+
emulator = StackMemorySimulator(min_qubits=5)
211+
task = emulator.task(program)
212+
qubits = task.run()
213+
rho = emulator.quantum_state(qubits)
214+
215+
assert all(np.isclose(rho.eigenvalues, [1]))
216+
217+
circuit = cirq.Circuit()
218+
qbs = cirq.LineQubit.range(5)
219+
circuit.append(cirq.H(qbs[1]))
220+
for i in range(5):
221+
circuit.append(cirq.I(qbs[i]))
222+
state = cirq.Simulator().simulate(circuit).state_vector()
223+
assert cirq.equal_up_to_global_phase(state, rho[1][:, 0])
224+
225+
226+
def test_rdm2():
227+
"""
228+
Does the RDM project correctly?
229+
"""
230+
231+
@squin.kernel
232+
def program():
233+
"""
234+
Creates a GHZ state on qubits 0,1,3,4 on a total of 6 qubits.
235+
"""
236+
q = squin.qubit.new(6)
237+
squin.gate.h(q[0])
238+
squin.gate.cx(q[0], q[1])
239+
squin.gate.cx(q[0], q[3])
240+
squin.gate.cx(q[0], q[4])
241+
return q
242+
243+
emulator = StackMemorySimulator(min_qubits=6)
244+
task = emulator.task(program)
245+
qubits = task.run()
246+
rho = emulator.quantum_state([qubits[x] for x in [0, 1, 3, 4]])
247+
target = np.array([1] + [0] * (14) + [1]) / np.sqrt(2) + 0j
248+
assert all(np.isclose(rho.eigenvalues, [1]))
249+
assert cirq.equal_up_to_global_phase(rho[1][:, 0], target)
250+
251+
rho2 = emulator.quantum_state([qubits[x] for x in [0, 1, 3]])
252+
assert all(np.isclose(rho2.eigenvalues, [0.5, 0.5]))
253+
assert rho2.eigenvectors.shape == (8, 2)
254+
255+
256+
def test_rdm3():
257+
"""
258+
Out-of-order indexing is consistent with cirq.
259+
"""
260+
261+
@squin.kernel
262+
def program():
263+
"""
264+
Random unitaries on 3 qubits.
265+
"""
266+
q = squin.qubit.new(3)
267+
squin.gate.rx(0.1, q[0])
268+
squin.gate.ry(0.2, q[1])
269+
squin.gate.rx(0.3, q[2])
270+
return q
271+
272+
emulator = StackMemorySimulator(min_qubits=6)
273+
task = emulator.task(program)
274+
qubits = task.run()
275+
276+
# Canonical ordering
277+
rho = emulator.quantum_state([qubits[x] for x in [0, 1, 2]])
278+
circuit = cirq.Circuit()
279+
qbs = cirq.LineQubit.range(3)
280+
circuit.append(cirq.rx(0.1)(qbs[0]))
281+
circuit.append(cirq.ry(0.2)(qbs[1]))
282+
circuit.append(cirq.rx(0.3)(qbs[2]))
283+
state = cirq.Simulator().simulate(circuit).state_vector()
284+
assert cirq.equal_up_to_global_phase(state, rho[1][:, 0])
285+
286+
# Reverse ordering 0->2, 1->, 2->0
287+
rho = emulator.quantum_state([qubits[x] for x in [2, 1, 0]])
288+
circuit = cirq.Circuit()
289+
qbs = cirq.LineQubit.range(3)
290+
circuit.append(cirq.rx(0.1)(qbs[2]))
291+
circuit.append(cirq.ry(0.2)(qbs[1]))
292+
circuit.append(cirq.rx(0.3)(qbs[0]))
293+
state = cirq.Simulator().simulate(circuit).state_vector()
294+
assert cirq.equal_up_to_global_phase(state, rho[1][:, 0])
295+
296+
# Other ordering
297+
rho = emulator.quantum_state([qubits[x] for x in [1, 2, 0]])
298+
circuit = cirq.Circuit()
299+
qbs = cirq.LineQubit.range(3)
300+
circuit.append(cirq.rx(0.1)(qbs[2]))
301+
circuit.append(cirq.ry(0.2)(qbs[0]))
302+
circuit.append(cirq.rx(0.3)(qbs[1]))
303+
state = cirq.Simulator().simulate(circuit).state_vector()
304+
assert cirq.equal_up_to_global_phase(state, rho[1][:, 0])
305+
306+
307+
def test_rdm4():
308+
rho = StackMemorySimulator.quantum_state([])
309+
assert rho.eigenvalues.shape == (0,)
310+
assert rho.eigenvectors.shape == (0, 0)
311+
312+
313+
def test_rdm5():
314+
@squin.kernel
315+
def program():
316+
"""
317+
Random unitaries on 3 qubits.
318+
"""
319+
q = squin.qubit.new(3)
320+
return q
321+
322+
emulator = StackMemorySimulator(min_qubits=6)
323+
task = emulator.task(program)
324+
qubits = task.run()
325+
rho = emulator.reduced_density_matrix(qubits)
326+
assert rho.shape == (8, 8)
327+
328+
329+
def test_rdm_failures():
330+
@squin.kernel
331+
def program():
332+
q = squin.qubit.new(3)
333+
return q
334+
335+
emulator = StackMemorySimulator(min_qubits=6)
336+
task = emulator.task(program)
337+
qbsA = task.qubits()
338+
qubits = task.run()
339+
qubits2 = task.run()
340+
qbsB = task.qubits()
341+
assert len(qbsA) == 0
342+
assert len(qbsB) == 6
343+
344+
try:
345+
emulator.quantum_state([qubits[0], qubits[0]])
346+
assert False, "Should have failed; qubits must be unique"
347+
except ValueError as e:
348+
assert str(e) == "Qubits must be unique."
349+
350+
try:
351+
emulator.quantum_state([qubits[0], qubits2[1]])
352+
assert False, "Should have failed; qubits must be from the same register"
353+
except ValueError as e:
354+
assert str(e) == "All qubits must be from the same simulator register."
355+
356+
357+
def test_get_qubits():
358+
@squin.kernel
359+
def program():
360+
q = squin.qubit.new(3)
361+
return q
362+
363+
emulator = StackMemorySimulator(min_qubits=6)
364+
task = emulator.task(program)
365+
task.run()
366+
367+
qubits2 = task.qubits()
368+
assert len(qubits2) == 6

0 commit comments

Comments
 (0)