diff --git a/doc/releases/changelog-0.13.0.md b/doc/releases/changelog-0.13.0.md index 1167e2296..3873df5f2 100644 --- a/doc/releases/changelog-0.13.0.md +++ b/doc/releases/changelog-0.13.0.md @@ -348,6 +348,11 @@ for example the one-shot mid circuit measurement transform. [(#2056)](https://github.com/PennyLaneAI/catalyst/pull/2056) +* Fixed a bug where applying a quantum transform after a QNode could produce incorrect results or + errors in certain cases. This resolves issues related to transforms operating on QNodes with + classical outputs and improves compatibility with measurement transforms. + [(#2081)](https://github.com/PennyLaneAI/catalyst/pull/2081) +

Internal changes ⚙️

* Updates use of `qml.transforms.dynamic_one_shot.parse_native_mid_circuit_measurements` to improved signature. diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 82325b780..06a7a7662 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -1257,15 +1257,13 @@ def apply_transforms( raise CompileError(msg) tracing_mode = TracingMode.TRANSFORM elif len(qnode_program) or have_measurements_changed(tape, tapes[0]): - # TODO: Ideally we should allow qnode transforms that don't modify the measurements to - # operate in the permissive tracing mode, but that currently leads to a small number of - # test failures due to the different result format produced in trace_quantum_function. - only_with_dynamic_one_shot = all( - "dynamic_one_shot_partial" in str(getattr(qnode, "transform", "")) + with_measurement_from_counts_or_samples = any( + "measurements_from_counts" in (transform_str := str(getattr(qnode, "transform", ""))) + or "measurements_from_samples" in transform_str for qnode in qnode_program ) - if has_classical_outputs(flat_results) and not only_with_dynamic_one_shot: + if has_classical_outputs(flat_results) and with_measurement_from_counts_or_samples: msg = ( "Transforming MeasurementProcesses is unsupported with non-MeasurementProcess " "QNode outputs. The selected device, options, or applied QNode transforms, may be " diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py index 38173c5ed..314137353 100644 --- a/frontend/catalyst/qfunc.py +++ b/frontend/catalyst/qfunc.py @@ -404,6 +404,8 @@ def _process_terminal_measurements(mcm_method, cpy_tape, out, snapshots, shot_ve """Process measurements when there are no mid-circuit measurements.""" assert mcm_method == "one-shot" + # flatten the outs structure + out, _ = tree_flatten(out) new_out = [] idx = 0 diff --git a/frontend/test/pytest/test_transform.py b/frontend/test/pytest/test_transform.py index deda75533..b7aafd2ef 100644 --- a/frontend/test/pytest/test_transform.py +++ b/frontend/test/pytest/test_transform.py @@ -1103,7 +1103,7 @@ def inject_device_transforms(self, ctx, execution_config=None, shots=None): return program, config - # Simulate a Qrack-like device that requires meassurement process transforms. + # Simulate a Qrack-like device that requires measurement process transforms. # Qnode transforms raise this error anyway so we cannot use them directly. original_preprocess = QJITDevice.preprocess monkeypatch.setattr(QJITDevice, "preprocess", inject_device_transforms)