Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@
ComputationalBasisOp,
CountsOp,
CustomOp,
PauliXOp,
PauliYOp,
PauliZOp,
HadamardOp,
SGateOp,
TGateOp,
CNOTOp,
RXOp,
RYOp,
RZOp,
DeallocOp,
DeallocQubitOp,
DeviceInitOp,
Expand Down Expand Up @@ -1268,7 +1278,7 @@


# pylint: disable=too-many-arguments
def _qinst_lowering(

Check notice on line 1281 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L1281

Too many branches (15/12) (too-many-branches)

Check notice on line 1281 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L1281

Too many return statements (13/6) (too-many-return-statements)
jax_ctx: mlir.LoweringRuleContext,
*qubits_or_params,
op=None,
Expand Down Expand Up @@ -1310,6 +1320,122 @@
name_str = str(name_attr)
name_str = name_str.replace('"', "")

if name_str == "PauliX":
assert len(float_params) == 0, "PauliX takes no float parameters"
return PauliXOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "PauliY":
assert len(float_params) == 0, "PauliY takes no float parameters"
return PauliYOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "PauliZ":
assert len(float_params) == 0, "PauliZ takes no float parameters"
return PauliZOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "Hadamard":
assert len(float_params) == 0, "Hadamard takes no float parameters"
return HadamardOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "S":
assert len(float_params) == 0, "S takes no float parameters"
return SGateOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "T":
assert len(float_params) == 0, "T takes no float parameters"
return TGateOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "CNOT":
assert len(float_params) == 0, "CNOT takes no float parameters"
return CNOTOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "RX":
assert len(float_params) == 1, "RX takes one float parameter"
float_param = float_params[0]
return RXOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
theta=float_param,
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "RY":
assert len(float_params) == 1, "RY takes one float parameter"
float_param = float_params[0]
return RYOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
theta=float_param,
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "RZ":
assert len(float_params) == 1, "RZ takes one float parameter"
float_param = float_params[0]
return RZOp(
out_qubits=[qubit.type for qubit in qubits],
out_ctrl_qubits=[qubit.type for qubit in ctrl_qubits],
theta=float_param,
in_qubits=qubits,
in_ctrl_qubits=ctrl_qubits,
in_ctrl_values=ctrl_values_i1,
adjoint=adjoint,
).results

if name_str == "MultiRZ":
assert len(float_params) == 1, "MultiRZ takes one float parameter"
float_param = float_params[0]
Expand Down
Loading