diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 6ae7e5a0d3..8ef7e33103 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -192,6 +192,12 @@ def __init__(self): super().__init__() +def _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device): + gate_set = set(device.capabilities.operations) + targs = () + tkwargs = {"gate_set": gate_set} + return qml.transforms.decompose.plxpr_transform(qfunc_jaxpr, consts, targs, tkwargs) + # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) def handle_qnode( @@ -199,21 +205,24 @@ def handle_qnode( ): """Handle the conversion from plxpr to Catalyst jaxpr for the qnode primitive""" + shots = args[0] if shots_len else 0 + consts = args[shots_len : n_consts + shots_len] + non_const_args = args[shots_len + n_consts :] + + # hopefully this patch stays patchy and doesn't become permanent + closed_jaxpr = _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device) + self.qubit_index_recorder = QubitIndexRecorder() if shots_len > 1: raise NotImplementedError("shot vectors are not yet supported for catalyst conversion.") - shots = args[0] if shots_len else 0 - consts = args[shots_len : n_consts + shots_len] - non_const_args = args[shots_len + n_consts :] - closed_jaxpr = ( - ClosedJaxpr(qfunc_jaxpr, consts) + closed_jaxpr if not self.requires_decompose_lowering else _apply_compiler_decompose_to_plxpr( - inner_jaxpr=qfunc_jaxpr, - consts=consts, + inner_jaxpr=closed_jaxpr.jaxpr, + consts=closed_jaxpr.consts, ncargs=non_const_args, tgateset=list(self.decompose_tkwargs.get("gate_set", [])), )