Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion
from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot

from catalyst.device import extract_backend_info
from catalyst.device import extract_backend_info, get_device_capabilities
from catalyst.from_plxpr.decompose import COMPILER_OPS_FOR_DECOMPOSITION, DecompRuleInterpreter
from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder, get_in_qubit_values
from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config
Expand Down Expand Up @@ -192,6 +192,16 @@
super().__init__()


def _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device):
gate_set = set(get_device_capabilities(device).operations)
if get_device_capabilities(device).initial_state_prep:
gate_set.add("StatePrep")
targs = ()
tkwargs = {"gate_set": gate_set}
breakpoint()

Check notice on line 201 in frontend/catalyst/from_plxpr/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/from_plxpr.py#L201

Leaving functions creating breakpoints in production code is not recommended (forgotten-debug-statement)
return qml.transforms.decompose.plxpr_transform(qfunc_jaxpr, consts, targs, tkwargs)


# pylint: disable=unused-argument, too-many-arguments
@WorkflowInterpreter.register_primitive(qnode_prim)
def handle_qnode(
Expand All @@ -208,6 +218,10 @@
consts = args[shots_len : n_consts + shots_len]
non_const_args = args[shots_len + n_consts :]

# hopefully this patch stays patchy and doesn't become permanent
# TODO: Too much has changed within this function, need to rework the patch
closed_jaxpr = _decompose_jaxpr_to_gateset(qfunc_jaxpr, consts, device)

closed_jaxpr = (
ClosedJaxpr(qfunc_jaxpr, consts)
if not self.requires_decompose_lowering
Expand Down