diff --git a/doc/releases/changelog-0.13.0.md b/doc/releases/changelog-0.13.0.md
index 1167e2296..3873df5f2 100644
--- a/doc/releases/changelog-0.13.0.md
+++ b/doc/releases/changelog-0.13.0.md
@@ -348,6 +348,11 @@
for example the one-shot mid circuit measurement transform.
[(#2056)](https://github.com/PennyLaneAI/catalyst/pull/2056)
+* Fixed a bug where applying a quantum transform after a QNode could produce incorrect results or
+ errors in certain cases. This resolves issues related to transforms operating on QNodes with
+ classical outputs and improves compatibility with measurement transforms.
+ [(#2081)](https://github.com/PennyLaneAI/catalyst/pull/2081)
+
Internal changes ⚙️
* Updates use of `qml.transforms.dynamic_one_shot.parse_native_mid_circuit_measurements` to improved signature.
diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py
index 82325b780..06a7a7662 100644
--- a/frontend/catalyst/jax_tracer.py
+++ b/frontend/catalyst/jax_tracer.py
@@ -1257,15 +1257,13 @@ def apply_transforms(
raise CompileError(msg)
tracing_mode = TracingMode.TRANSFORM
elif len(qnode_program) or have_measurements_changed(tape, tapes[0]):
- # TODO: Ideally we should allow qnode transforms that don't modify the measurements to
- # operate in the permissive tracing mode, but that currently leads to a small number of
- # test failures due to the different result format produced in trace_quantum_function.
- only_with_dynamic_one_shot = all(
- "dynamic_one_shot_partial" in str(getattr(qnode, "transform", ""))
+ with_measurement_from_counts_or_samples = any(
+ "measurements_from_counts" in (transform_str := str(getattr(qnode, "transform", "")))
+ or "measurements_from_samples" in transform_str
for qnode in qnode_program
)
- if has_classical_outputs(flat_results) and not only_with_dynamic_one_shot:
+ if has_classical_outputs(flat_results) and with_measurement_from_counts_or_samples:
msg = (
"Transforming MeasurementProcesses is unsupported with non-MeasurementProcess "
"QNode outputs. The selected device, options, or applied QNode transforms, may be "
diff --git a/frontend/catalyst/qfunc.py b/frontend/catalyst/qfunc.py
index 38173c5ed..314137353 100644
--- a/frontend/catalyst/qfunc.py
+++ b/frontend/catalyst/qfunc.py
@@ -404,6 +404,8 @@ def _process_terminal_measurements(mcm_method, cpy_tape, out, snapshots, shot_ve
"""Process measurements when there are no mid-circuit measurements."""
assert mcm_method == "one-shot"
+ # flatten the outs structure
+ out, _ = tree_flatten(out)
new_out = []
idx = 0
diff --git a/frontend/test/pytest/test_transform.py b/frontend/test/pytest/test_transform.py
index deda75533..b7aafd2ef 100644
--- a/frontend/test/pytest/test_transform.py
+++ b/frontend/test/pytest/test_transform.py
@@ -1103,7 +1103,7 @@ def inject_device_transforms(self, ctx, execution_config=None, shots=None):
return program, config
- # Simulate a Qrack-like device that requires meassurement process transforms.
+ # Simulate a Qrack-like device that requires measurement process transforms.
# Qnode transforms raise this error anyway so we cannot use them directly.
original_preprocess = QJITDevice.preprocess
monkeypatch.setattr(QJITDevice, "preprocess", inject_device_transforms)