Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/releases/changelog-0.13.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@
for example the one-shot mid circuit measurement transform.
[(#2056)](https://github.com/PennyLaneAI/catalyst/pull/2056)

* Fixed a bug where applying a quantum transform after a QNode could produce incorrect results or
errors in certain cases. This resolves issues related to transforms operating on QNodes with
classical outputs and improves compatibility with measurement transforms.
[(#2081)](https://github.com/PennyLaneAI/catalyst/pull/2081)

<h3>Internal changes ⚙️</h3>

* Updates use of `qml.transforms.dynamic_one_shot.parse_native_mid_circuit_measurements` to improved signature.
Expand Down
10 changes: 4 additions & 6 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,15 +1257,13 @@ def apply_transforms(
raise CompileError(msg)
tracing_mode = TracingMode.TRANSFORM
elif len(qnode_program) or have_measurements_changed(tape, tapes[0]):
# TODO: Ideally we should allow qnode transforms that don't modify the measurements to
# operate in the permissive tracing mode, but that currently leads to a small number of
# test failures due to the different result format produced in trace_quantum_function.
only_with_dynamic_one_shot = all(
"dynamic_one_shot_partial" in str(getattr(qnode, "transform", ""))
with_measurement_from_counts_or_samples = any(
"measurements_from_counts" in (transform_str := str(getattr(qnode, "transform", "")))
or "measurements_from_samples" in transform_str
for qnode in qnode_program
)

if has_classical_outputs(flat_results) and not only_with_dynamic_one_shot:
if has_classical_outputs(flat_results) and with_measurement_from_counts_or_samples:
msg = (
"Transforming MeasurementProcesses is unsupported with non-MeasurementProcess "
"QNode outputs. The selected device, options, or applied QNode transforms, may be "
Expand Down
2 changes: 2 additions & 0 deletions frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ def _process_terminal_measurements(mcm_method, cpy_tape, out, snapshots, shot_ve
"""Process measurements when there are no mid-circuit measurements."""
assert mcm_method == "one-shot"

# flatten the outs structure
out, _ = tree_flatten(out)
new_out = []
idx = 0

Expand Down
2 changes: 1 addition & 1 deletion frontend/test/pytest/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ def inject_device_transforms(self, ctx, execution_config=None, shots=None):

return program, config

# Simulate a Qrack-like device that requires meassurement process transforms.
# Simulate a Qrack-like device that requires measurement process transforms.
# Qnode transforms raise this error anyway so we cannot use them directly.
original_preprocess = QJITDevice.preprocess
monkeypatch.setattr(QJITDevice, "preprocess", inject_device_transforms)
Expand Down