Skip to content

Commit 12b3906

Browse files
authored
Implement structure preserving dynamical decoupling (dd-v2) (#7609)
Previous version of dd allows adding new moments to the circuit in the transformer. In some use cases, we might see a lot of new moments in the transformed circuits. This PR will fix the issues. First, it captures the circuit's structure. Second, it inserts elements based on the structural information gathered in the initial step: [the design](https://docs.google.com/document/d/1FhPwRRazCKKpEG2L3P8CvN5Eqd_3s-MGJSVU6xLxv44/edit?tab=t.0#heading=h.xgjl2srtytjt).
1 parent 472b6ad commit 12b3906

File tree

2 files changed

+359
-292
lines changed

2 files changed

+359
-292
lines changed

cirq-core/cirq/transformers/dynamical_decoupling.py

Lines changed: 185 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
from __future__ import annotations
1818

19+
from enum import Enum
1920
from functools import reduce
2021
from itertools import cycle
2122
from typing import TYPE_CHECKING
2223

2324
import numpy as np
25+
from attrs import frozen
2426

2527
from cirq import ops, protocols
2628
from cirq.circuits import Circuit, FrozenCircuit, Moment
@@ -133,10 +135,6 @@ def _calc_busy_moment_range_of_each_qubit(circuit: FrozenCircuit) -> dict[ops.Qi
133135
return busy_moment_range_by_qubit
134136

135137

136-
def _is_insertable_moment(moment: Moment, single_qubit_gate_moments_only: bool) -> bool:
137-
return not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment)
138-
139-
140138
def _merge_single_qubit_ops_to_phxz(
141139
q: ops.Qid, operations: tuple[ops.Operation, ...]
142140
) -> ops.Operation:
@@ -149,34 +147,162 @@ def _merge_single_qubit_ops_to_phxz(
149147
return gate.on(q)
150148

151149

152-
def _calc_pulled_through(moment: Moment, input_pauli_ops: ops.PauliString) -> ops.PauliString:
153-
"""Calculates the pulled_through such that circuit(input_pauli_ops, moment.clifford_ops) is
154-
equivalent to circuit(moment.clifford_ops, pulled_through).
150+
def _backward_set_stopping_slots(
151+
q: ops.Qid,
152+
from_mid: int,
153+
mergable: dict[ops.Qid, dict[int, bool]],
154+
need_to_stop: dict[ops.Qid, dict[int, bool]],
155+
gate_types: dict[ops.Qid, dict[int, _CellType]],
156+
circuit: FrozenCircuit,
157+
):
158+
"""Sets stopping slots for dynamical decoupling insertion.
159+
160+
This function traverses backward from a given moment `from_mid` for a specific qubit `q`.
161+
It identifies moments where a dynamical decoupling sequence needs to be "stopped".
162+
163+
Args:
164+
q: The qubit for which to set stopping slots.
165+
from_mid: The moment ID to start the backward traversal from.
166+
mergable: A dictionary indicating if a single-qubit Clifford gate at (qubit, moment_id)
167+
can be merged with a Pauli gate.
168+
need_to_stop: A dictionary to mark moments where a DD sequence must be stopped.
169+
gate_types: A dictionary indicating the type of gate at each (qubit, moment_id).
170+
circuit: The original frozen circuit.
155171
"""
156-
clifford_ops_in_moment: list[ops.Operation] = [
157-
op for op in moment.operations if _is_clifford_op(op)
158-
]
159-
return input_pauli_ops.after(clifford_ops_in_moment)
172+
affected_qubits: set[ops.Qid] = {q}
173+
for back_mid in range(from_mid, -1, -1):
174+
for back_q in set(affected_qubits):
175+
if gate_types[back_q][back_mid] == _CellType.WALL:
176+
affected_qubits.remove(back_q)
177+
continue
178+
if mergable[back_q][back_mid]:
179+
need_to_stop[back_q][back_mid] = True
180+
affected_qubits.remove(back_q)
181+
continue
182+
op_at_q = circuit[back_mid].operation_at(back_q) or ops.I(q)
183+
affected_qubits.update(op_at_q.qubits)
184+
if not affected_qubits:
185+
break
186+
187+
188+
class _CellType(Enum):
189+
UNKNOWN = '?'
190+
# Non-insertable gates that cannot be pulled through
191+
WALL = 'w'
192+
# Clifford gates where Pauli Gates can be pulled through
193+
DOOR = 'd'
194+
# An empty gate can be used to insert Pauli gates from the dd sequence
195+
INSERTABLE = 'i'
196+
197+
198+
@frozen
199+
class _Grid:
200+
"""A grid representation of the circuit where each gate position is labeled for
201+
dynamical decoupling.
202+
203+
With this representation, a DD sequence can be automatically navigated in a
204+
forward-only process. This avoids issues where a partially inserted DD
205+
sequence encounters a "wall" and a new moment must be inserted because the
206+
remaining DD sequence cannot be absorbed by nearby gates.
207+
208+
This labeled representation pre-calculates where DD pulses can be inserted
209+
and where leftover DD sequences must be merged, avoiding the need for
210+
backtracking.
211+
212+
An example labeled circuit is shown below:
213+
| 0 | 1 | 2 | 3 | 4 |
214+
-----+-----+-----+-----+-----+-----+
215+
q(0) | d | i | i,s | d | w |
216+
q(1) | d | i | d,s | w | w |
217+
q(2) | d | d | d,s | w | w |
218+
where `w`=WALL, `d`=DOOR, `i`=INSERTABLE. `s` represents a stop gate,
219+
meaning that any unfinished DD sequences must be merged at this gate.
220+
"""
221+
222+
gate_types: dict[ops.Qid, dict[int, _CellType]]
223+
need_to_stop: dict[ops.Qid, dict[int, bool]]
224+
circuit: FrozenCircuit
225+
226+
@classmethod
227+
def from_circuit(
228+
cls, circuit: cirq.FrozenCircuit, single_qubit_gate_moments_only: bool
229+
) -> _Grid:
230+
gate_types: dict[ops.Qid, dict[int, _CellType]] = {
231+
q: {mid: _CellType.UNKNOWN for mid in range(len(circuit))} for q in circuit.all_qubits()
232+
}
233+
mergable: dict[ops.Qid, dict[int, bool]] = {
234+
q: {mid: False for mid in range(len(circuit))} for q in circuit.all_qubits()
235+
}
236+
busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(circuit)
237+
238+
# Set gate types for each (q, mid)
239+
for mid, moment in enumerate(circuit):
240+
is_insertable_moment = (
241+
not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment)
242+
)
243+
for q in circuit.all_qubits():
244+
if mid < busy_moment_range_by_qubit[q][0] or mid > busy_moment_range_by_qubit[q][1]:
245+
gate_types[q][mid] = _CellType.WALL
246+
continue
247+
op_at_q = moment.operation_at(q)
248+
if op_at_q is None:
249+
if is_insertable_moment:
250+
gate_types[q][mid] = _CellType.INSERTABLE
251+
mergable[q][mid] = True
252+
else:
253+
gate_types[q][mid] = _CellType.DOOR
254+
else:
255+
if _is_clifford_op(op_at_q):
256+
gate_types[q][mid] = _CellType.DOOR
257+
mergable[q][mid] = _is_single_qubit_operation(op_at_q)
258+
else:
259+
gate_types[q][mid] = _CellType.WALL
260+
261+
need_to_stop: dict[ops.Qid, dict[int, bool]] = {
262+
q: {mid: False for mid in range(len(circuit))} for q in circuit.all_qubits()
263+
}
264+
# Reversely find the last mergeable gate of each qubit, set them as need_to_stop.
265+
for q in circuit.all_qubits():
266+
_backward_set_stopping_slots(
267+
q, len(circuit) - 1, mergable, need_to_stop, gate_types, circuit
268+
)
269+
# Reversely check for each wall gate, mark the closest mergeable gate as need_to_stop.
270+
for mid in range(len(circuit)):
271+
for q in circuit.all_qubits():
272+
if gate_types[q][mid] == _CellType.WALL:
273+
_backward_set_stopping_slots(
274+
q, mid - 1, mergable, need_to_stop, gate_types, circuit
275+
)
276+
return cls(circuit=circuit, gate_types=gate_types, need_to_stop=need_to_stop)
160277

278+
def __str__(self) -> str:
279+
if not self.gate_types:
280+
return "Grid(empty)"
161281

162-
def _get_stop_qubits(moment: Moment) -> set[ops.Qid]:
163-
stop_pulling_through_qubits: set[ops.Qid] = set()
164-
for op in moment:
165-
if (not _is_clifford_op(op) and not _is_single_qubit_operation(op)) or not has_unitary(
166-
op
167-
): # multi-qubit clifford op or non-mergable op.
168-
stop_pulling_through_qubits.update(op.qubits)
169-
return stop_pulling_through_qubits
282+
qubits = sorted(list(self.gate_types.keys()))
283+
num_moments = len(self.gate_types[qubits[0]])
170284

285+
max_qubit_len = max(len(str(q)) for q in qubits) if qubits else 0
171286

172-
def _need_merge_pulled_through(op_at_q: ops.Operation, is_at_last_busy_moment: bool) -> bool:
173-
"""With a pulling through pauli gate before op_at_q, need to merge with the
174-
pauli in the conditions below."""
175-
# The op must be mergable and single-qubit
176-
if not (_is_single_qubit_operation(op_at_q) and has_unitary(op_at_q)):
177-
return False
178-
# Either non-Clifford or at the last busy moment
179-
return is_at_last_busy_moment or not _is_clifford_op(op_at_q)
287+
header = f"{'':>{max_qubit_len}} |"
288+
for i in range(num_moments):
289+
header += f" {i:^3} |"
290+
291+
separator = f"{'-' * max_qubit_len}-+"
292+
separator += '-----+' * num_moments
293+
294+
lines = ["Grid Repr:", header, separator]
295+
296+
for q in qubits:
297+
row_str = f"{str(q):>{max_qubit_len}} |"
298+
for mid in range(num_moments):
299+
gate_type = self.gate_types[q][mid].value
300+
stop = self.need_to_stop[q][mid]
301+
cell = f"{gate_type},s" if stop else f" {gate_type} "
302+
row_str += f" {cell} |"
303+
lines.append(row_str)
304+
305+
return "\n".join(lines)
180306

181307

182308
@transformer_api.transformer
@@ -188,7 +314,7 @@ def add_dynamical_decoupling(
188314
single_qubit_gate_moments_only: bool = True,
189315
) -> cirq.Circuit:
190316
"""Adds dynamical decoupling gate operations to a given circuit.
191-
This transformer might add new moments and thus change the structure of the original circuit.
317+
This transformer preserves the structure of the original circuit.
192318
193319
Args:
194320
circuit: Input circuit to transform.
@@ -202,11 +328,18 @@ def add_dynamical_decoupling(
202328
Returns:
203329
A copy of the input circuit with dynamical decoupling operations.
204330
"""
205-
base_dd_sequence, pauli_map = _parse_dd_sequence(schema)
331+
332+
if context is not None and context.deep:
333+
raise ValueError("Deep transformation is not supported.")
334+
206335
orig_circuit = circuit.freeze()
207336

208-
busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(orig_circuit)
337+
grid = _Grid.from_circuit(orig_circuit, single_qubit_gate_moments_only)
338+
339+
if context is not None and context.logger is not None:
340+
context.logger.log("Preprocessed input circuit grid repr:\n%s", str(grid))
209341

342+
base_dd_sequence, pauli_map = _parse_dd_sequence(schema)
210343
# Stores all the moments of the output circuit chronologically.
211344
transformed_moments: list[Moment] = []
212345
# A PauliString stores the result of 'pulling' Pauli gates past each operations
@@ -215,90 +348,30 @@ def add_dynamical_decoupling(
215348
# Iterator of gate to be used in dd sequence for each qubit.
216349
dd_iter_by_qubits = {q: cycle(base_dd_sequence) for q in circuit.all_qubits()}
217350

218-
def _update_pulled_through(q: ops.Qid, insert_gate: ops.Gate) -> ops.Operation:
219-
nonlocal pulled_through, pauli_map
220-
pulled_through *= pauli_map[insert_gate].on(q)
221-
return insert_gate.on(q)
222-
223-
# Insert and pull remaining Pauli ops through the whole circuit.
224-
# General ideas are
225-
# * Pull through Clifford gates.
226-
# * Stop at multi-qubit non-Clifford ops (and other non-mergable ops).
227-
# * Merge to single-qubit non-Clifford ops.
228-
# * Insert a new moment if necessary.
229-
# After pulling through pulled_through at `moment`, we expect a transformation of
230-
# (pulled_through, moment) -> (updated_moment, updated_pulled_through) or
231-
# (pulled_through, moment) -> (new_moment, updated_moment, updated_pulled_through)
232-
# Moments structure changes are split into 3 steps:
233-
# 1, (..., last_moment, pulled_through1, moment, ...)
234-
# -> (..., last_moment, new_moment or None, pulled_through2, moment, ...)
235-
# 2, (..., pulled_through2, moment, ...) -> (..., pulled_through3, updated_moment, ...)
236-
# 3, (..., pulled_through3, updated_moment, ...)
237-
# -> (..., updated_moment, pulled_through4, ...)
238351
for moment_id, moment in enumerate(orig_circuit.moments):
239-
# Step 1, insert new_moment if necessary.
240-
# In detail: stop pulling through for multi-qubit non-Clifford ops or gates without
241-
# unitary representation (e.g., measure gates). If there are remaining pulled through ops,
242-
# insert into a new moment before current moment.
243-
stop_pulling_through_qubits: set[ops.Qid] = _get_stop_qubits(moment)
244-
new_moment_ops: list[ops.Operation] = []
245-
for q in stop_pulling_through_qubits:
246-
# Insert the remaining pulled_through
247-
remaining_pulled_through_gate = pulled_through.get(q)
248-
if remaining_pulled_through_gate is not None:
249-
new_moment_ops.append(_update_pulled_through(q, remaining_pulled_through_gate))
250-
# Reset dd sequence
251-
dd_iter_by_qubits[q] = cycle(base_dd_sequence)
252-
# Need to insert a new moment before current moment
253-
if new_moment_ops:
254-
# Fill insertable idle moments in the new moment using dd sequence
255-
for q in orig_circuit.all_qubits() - stop_pulling_through_qubits:
256-
if busy_moment_range_by_qubit[q][0] < moment_id <= busy_moment_range_by_qubit[q][1]:
257-
new_moment_ops.append(_update_pulled_through(q, next(dd_iter_by_qubits[q])))
258-
transformed_moments.append(Moment(new_moment_ops))
259-
260-
# Step 2, calc updated_moment with insertions / merges.
261352
updated_moment_ops: set[cirq.Operation] = set()
262353
for q in orig_circuit.all_qubits():
263-
op_at_q = moment.operation_at(q)
264-
remaining_pulled_through_gate = pulled_through.get(q)
265-
updated_op = op_at_q
266-
if op_at_q is None: # insert into idle op
267-
if not _is_insertable_moment(moment, single_qubit_gate_moments_only):
268-
continue
269-
if (
270-
busy_moment_range_by_qubit[q][0] < moment_id < busy_moment_range_by_qubit[q][1]
271-
): # insert next pauli gate in the dd sequence
272-
updated_op = _update_pulled_through(q, next(dd_iter_by_qubits[q]))
273-
elif ( # insert the remaining pulled through if beyond the ending busy moment
274-
moment_id > busy_moment_range_by_qubit[q][1]
275-
and remaining_pulled_through_gate is not None
276-
):
277-
updated_op = _update_pulled_through(q, remaining_pulled_through_gate)
278-
elif (
279-
remaining_pulled_through_gate is not None
280-
): # merge pulled-through of q to op_at_q if needed
281-
if _need_merge_pulled_through(
282-
op_at_q, moment_id == busy_moment_range_by_qubit[q][1]
283-
):
284-
remaining_op = _update_pulled_through(q, remaining_pulled_through_gate)
285-
updated_op = _merge_single_qubit_ops_to_phxz(q, (remaining_op, op_at_q))
286-
if updated_op is not None:
287-
updated_moment_ops.add(updated_op)
288-
289-
if updated_moment_ops:
290-
updated_moment = Moment(updated_moment_ops)
291-
transformed_moments.append(updated_moment)
292-
293-
# Step 3, update pulled through.
294-
# In detail: pulling current `pulled_through` through updated_moment.
295-
pulled_through = _calc_pulled_through(updated_moment, pulled_through)
296-
297-
# Insert a new moment if there are remaining pulled-through operations.
298-
ending_moment_ops = []
299-
for affected_q, combined_op_in_pauli in pulled_through.items():
300-
ending_moment_ops.append(combined_op_in_pauli.on(affected_q))
301-
if ending_moment_ops:
302-
transformed_moments.append(Moment(ending_moment_ops))
354+
new_op_at_q = moment.operation_at(q)
355+
if grid.gate_types[q][moment_id] == _CellType.INSERTABLE:
356+
new_gate = next(dd_iter_by_qubits[q])
357+
new_op_at_q = new_gate.on(q)
358+
pulled_through *= pauli_map[new_gate].on(q)
359+
if grid.need_to_stop[q][moment_id]:
360+
to_be_merged = pulled_through.get(q)
361+
if to_be_merged is not None:
362+
new_op_at_q = _merge_single_qubit_ops_to_phxz(
363+
q, (to_be_merged.on(q), new_op_at_q or ops.I(q))
364+
)
365+
pulled_through *= to_be_merged.on(q)
366+
if new_op_at_q is not None:
367+
updated_moment_ops.add(new_op_at_q)
368+
369+
updated_moment = Moment(updated_moment_ops)
370+
clifford_ops = [op for op in updated_moment if _is_clifford_op(op)]
371+
pulled_through = pulled_through.after(clifford_ops)
372+
transformed_moments.append(updated_moment)
373+
374+
if len(pulled_through) > 0:
375+
raise RuntimeError("Expect empty remaining Paulis after the dd insertion.")
303376

304377
return Circuit.from_moments(*transformed_moments)

0 commit comments

Comments
 (0)