Skip to content

Commit 52dcbaa

Browse files
committed
update
1 parent 083016b commit 52dcbaa

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

frontend/catalyst/jax_primitives.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def wrapper(*args, **kwargs):
391391

392392

393393
def decomposition_rule(
394-
func=None, *, is_qreg=True, num_params=0, base: qml.operation.Operator = None
394+
func=None, *, is_qreg=True, num_params=0, op: qml.operation.Operator = None
395395
):
396396
"""
397397
Denotes the creation of a quantum definition in the intermediate representation.
@@ -403,13 +403,13 @@ def decomposition_rule(
403403

404404
if func is None:
405405
return functools.partial(
406-
decomposition_rule, is_qreg=is_qreg, num_params=num_params, base=base
406+
decomposition_rule, is_qreg=is_qreg, num_params=num_params, op=op
407407
)
408408

409409
@functools.wraps(func)
410410
def wrapper(*args, **kwargs):
411-
if base is not None:
412-
new_func = functools.partial(func, **base.hyperparameters)
411+
if op is not None:
412+
new_func = functools.partial(func,wires=op.wires, **op.hyperparameters)
413413
jaxpr = jax.make_jaxpr(new_func)(*args, **kwargs)
414414
else:
415415
jaxpr = jax.make_jaxpr(func)(*args, **kwargs)
@@ -418,7 +418,7 @@ def wrapper(*args, **kwargs):
418418
func_jaxpr=jaxpr,
419419
is_qreg=is_qreg,
420420
num_params=num_params,
421-
op_name=base.name if base else "",
421+
op_name=op.name if op else "",
422422
)
423423

424424
return wrapper

test_new_decomp_adjoint.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,35 @@
1515
qml.decomposition.enable_graph()
1616

1717

18-
@decomposition_rule(base=qml.adjoint(qml.RZ(float, wires=0)))
18+
@decomposition_rule(op=qml.adjoint(qml.RZ(float, wires=1)))
1919
def adjoint_rotation(phi, wires, base, **__):
2020
"""Decompose the adjoint of a rotation operator by inverting the angle."""
2121
_, struct = base._flatten()
2222
base._unflatten((-phi,), struct)
23+
24+
25+
@decomposition_rule(op=qml.adjoint(qml.adjoint(qml.Hadamard(wires=0))))
26+
def cancel_adjoint(*params, wires, base):
27+
"""Decompose the adjoint of a rotation operator by inverting the angle."""
28+
base.base._unflatten(*base.base._flatten())
29+
30+
#
31+
@decomposition_rule(op=qml.ctrl(qml.adjoint(qml.RX(0.2, wires=[0])), control=1))
32+
def flip_control_adjoint(
33+
*_, wires, control_wires, control_values, work_wires, work_wire_type, base, **__
34+
):
35+
"""Decompose the control of an adjoint by applying control to the base of the adjoint
36+
and taking the adjoint of the control."""
37+
base_op = base.base._unflatten(*base.base._flatten())
38+
qml.adjoint(
39+
qml.ctrl(
40+
base_op,
41+
control=wires[: len(control_wires)],
42+
control_values=control_values,
43+
work_wires=work_wires,
44+
work_wire_type=work_wire_type,
45+
)
46+
)
2347

2448

2549
@qml.qjit()
@@ -29,7 +53,9 @@ def circuit():
2953

3054
qml.adjoint(qml.RY(0.432, wires=1))
3155

32-
adjoint_rotation(float, int)
56+
adjoint_rotation(float)
57+
cancel_adjoint()
58+
flip_control_adjoint()
3359

3460
return qml.expval(qml.Z(0))
3561

0 commit comments

Comments
 (0)