13 changes: 13 additions & 0 deletions .github/workflows/check-catalyst.yaml
@@ -475,6 +475,10 @@ jobs:
          python3 -m pip install oqc-qcaas-client
          make frontend

      - name: Install PennyLane branch
        run: |
          pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim

      - name: Get Cached LLVM Build
        id: cache-llvm-build
        uses: actions/cache@v4
@@ -558,6 +562,11 @@ jobs:
          python3 -m pip install -r requirements.txt
          make frontend

      - name: Install PennyLane branch
        run: |
          pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim


      - name: Get Cached LLVM Build
        id: cache-llvm-build
        uses: actions/cache@v4
@@ -620,6 +629,10 @@ jobs:
          python3 -m pip install -r requirements.txt
          make frontend

      - name: Install PennyLane branch
        run: |
          pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim

      - name: Get Cached LLVM Build
        id: cache-llvm-build
        uses: actions/cache@v4
10 changes: 10 additions & 0 deletions frontend/catalyst/from_plxpr/from_plxpr.py
@@ -33,6 +33,7 @@
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim
from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim
from pennylane.capture.primitives import grad_prim
from pennylane.capture.primitives import measure_prim as plxpr_measure_prim
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.measurements import CountsMP
@@ -192,6 +193,15 @@ def __init__(self):
        super().__init__()


@WorkflowInterpreter.register_primitive(grad_prim)
def handle_grad(self, *args, jaxpr, n_consts, **kwargs):
    """Translate a grad equation by re-tracing its body under this interpreter."""
    # Evaluate the stored jaxpr with a copy of this interpreter so that any
    # primitives nested inside the differentiated function are translated too.
    f = partial(copy(self).eval, jaxpr, args[:n_consts])
    new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:]).jaxpr

    return grad_prim.bind(*args, jaxpr=new_jaxpr, n_consts=n_consts, **kwargs)
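
For context, a minimal sketch of the re-trace pattern handle_grad relies on, in plain JAX; `retrace` is a hypothetical name introduced here for illustration, not part of Catalyst:

from functools import partial

import jax


def retrace(eval_fn, jaxpr, consts, *args):
    # Close over the interpreter's eval and the constants, then trace the
    # remaining arguments; every primitive in the body passes through the
    # interpreter's handlers and comes out rewritten in the new jaxpr.
    f = partial(eval_fn, jaxpr, consts)
    return jax.make_jaxpr(f)(*args).jaxpr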


# pylint: disable=unused-argument, too-many-arguments
@WorkflowInterpreter.register_primitive(qnode_prim)
def handle_qnode(
43 changes: 34 additions & 9 deletions frontend/catalyst/jax_primitives.py
@@ -100,6 +100,7 @@
    VarianceOp,
)
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp
from pennylane.capture.primitives import grad_prim as pl_grad_prim

from catalyst.compiler import get_lib_path
from catalyst.jax_extras import (
@@ -644,9 +645,12 @@
    consts = []
    offset = len(args) - len(jaxpr.consts)
    for i, jax_array_or_tracer in enumerate(jaxpr.consts):
        if not isinstance(
            jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer
        ):
        if isinstance(jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer):
            # There are some cases where this value cannot be converted into
            # a jax.numpy.array. In that case we get it from the arguments.
            consts.append(args[offset + i])
        else:
            # ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of
            # element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to
            # cast such constants to numpy array types.
@@ -656,11 +660,6 @@
            attr = ir.DenseElementsAttr.get(nparray, type=const_type)
            constval = StableHLOConstantOp(attr).results
            consts.append(constval)
        else:
            # There are some cases where this value cannot be converted into
            # a jax.numpy.array.
            # In that case we get it from the arguments.
            consts.append(args[offset + i])

    method, h, argnums = grad_params.method, grad_params.h, grad_params.expanded_argnums
    mlir_ctx = ctx.module_context.context
@@ -673,7 +672,6 @@
    argnum_numpy = np.array(new_argnums)
    diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
    func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums))

    symbol_ref = get_symbolref(ctx, func_op)
    output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
    flat_output_types = util.flatten(output_types)
@@ -692,6 +690,32 @@
    ).results


# pylint: disable=too-many-arguments
def _capture_grad_lowering(ctx, *args, argnum, jaxpr, n_consts, method, h, fn, scalar_out):
    mlir_ctx = ctx.module_context.context
    if h:
        f64 = ir.F64Type.get(mlir_ctx)
        finiteDiffParam = ir.FloatAttr.get(f64, h)
    else:
        finiteDiffParam = None

    # The jaxpr's constants are passed as leading arguments, so the
    # user-facing differentiable-argument indices are shifted by n_consts.
    new_argnums = [num + n_consts for num in argnum]
    argnum_numpy = np.array(new_argnums)
    diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
    func_op = lower_jaxpr(ctx, jaxpr, (method, h, *new_argnums), fn=fn)
    symbol_ref = get_symbolref(ctx, func_op)
    output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
    flat_output_types = util.flatten(output_types)

    return GradOp(
        flat_output_types,
        ir.StringAttr.get(method),
        symbol_ref,
        mlir.flatten_lowering_ir_args(args),
        diffArgIndices=diffArgIndices,
        finiteDiffParam=finiteDiffParam,
    ).results
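
The shift above exists because the lowered callee receives the jaxpr's constants prepended to its primal arguments; a tiny self-contained illustration of the assumed layout:

# Assumed flat layout: args = (*consts, *primal_args). With n_consts = 2, a
# user-facing argnum of 0 refers to primal_args[0], i.e. flat position 2.
argnum, n_consts = [0], 2
assert [num + n_consts for num in argnum] == [2]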

# value_and_grad
#
@value_and_grad_p.def_impl
@@ -2542,6 +2566,7 @@
    (while_p, _while_loop_lowering),
    (for_p, _for_loop_lowering),
    (grad_p, _grad_lowering),
    (pl_grad_prim, _capture_grad_lowering),
    (func_p, _func_lowering),
    (jvp_p, _jvp_lowering),
    (vjp_p, _vjp_lowering),
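
For context, pairs like these are typically fed to JAX's MLIR registration machinery; a hedged sketch of that pattern (CUSTOM_LOWERINGS is a hypothetical name, and Catalyst's actual registration call may differ):

from jax.interpreters import mlir as jax_mlir

# CUSTOM_LOWERINGS: hypothetical list of (primitive, lowering_rule) pairs,
# e.g. (pl_grad_prim, _capture_grad_lowering) above. Each pair installs a
# custom MLIR lowering for its primitive.
for prim, rule in CUSTOM_LOWERINGS:
    jax_mlir.register_lowering(prim, rule)
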
115 changes: 63 additions & 52 deletions frontend/catalyst/jax_primitives_utils.py
@@ -29,6 +29,31 @@
from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval


def _only_single_expval(call_jaxpr: core.ClosedJaxpr) -> bool:
    """Return whether the program measures at most one expval and no probs,
    counts, or sample."""
    found_expval = False
    for eqn in call_jaxpr.eqns:
        name = eqn.primitive.name
        if name in {"probs", "counts", "sample"}:
            return False
        elif name == "expval":
            if found_expval:
                return False
            found_expval = True
    return True


def _calculate_diff_method(qn: qml.QNode, call_jaxpr: core.ClosedJaxpr):
    """Resolve a QNode's diff_method, mapping "best" to a concrete method."""
    diff_method = str(qn.diff_method)
    if diff_method != "best":
        return diff_method

    device_name = getattr(getattr(qn, "device", None), "name", None)

    if device_name and "lightning" in device_name and _only_single_expval(call_jaxpr):
        return "adjoint"
    return "parameter-shift"
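
A hypothetical example of how these helpers resolve diff_method="best" (illustrative only; assumes the pennylane-lightning device is installed):

import pennylane as qml

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev, diff_method="best")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# With a device name containing "lightning" and a program measuring a single
# expval, _calculate_diff_method(circuit, call_jaxpr) returns "adjoint";
# adding a qml.probs() return (or a second expval) flips it to
# "parameter-shift".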


def get_call_jaxpr(jaxpr):
    """Extracts the `call_jaxpr` from a JAXPR if it exists."""
    for eqn in jaxpr.eqns:
@@ -45,28 +70,35 @@ def get_call_equation(jaxpr):
    raise AssertionError("No call_jaxpr found in the JAXPR.")


def lower_jaxpr(ctx, jaxpr, context=None):
def lower_jaxpr(ctx, jaxpr, metadata=None, fn=None):
    """Lowers a call primitive jaxpr, which may be either func_p or quantum_kernel_p.

    Args:
        ctx: LoweringRuleContext
        jaxpr: JAXPR to be lowered
        context: additional context to distinguish different FuncOps
        metadata: additional metadata to distinguish different FuncOps
        fn: optional callable; when given and not a QNode, the jaxpr is
            lowered directly as its body instead of being unwrapped as a
            call primitive

    Returns:
        FuncOp
    """
    equation = get_call_equation(jaxpr)
    call_jaxpr = equation.params["call_jaxpr"]
    callable_ = equation.params.get("fn")
    if callable_ is None:
        callable_ = equation.params.get("qnode")
    pipeline = equation.params.get("pipeline")
    return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context)

    if fn is None or isinstance(fn, qml.QNode):
        equation = get_call_equation(jaxpr)
        call_jaxpr = equation.params["call_jaxpr"]
        pipeline = equation.params.get("pipeline")
        callable_ = equation.params.get("fn")
        if callable_ is None:
            callable_ = equation.params.get("qnode", None)
    else:
        call_jaxpr = jaxpr
        pipeline = ()
        callable_ = fn

    return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, metadata=metadata)

# pylint: disable=too-many-arguments, too-many-positional-arguments
def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False):
def lower_callable(ctx, callable_, call_jaxpr, pipeline=(), metadata=None, public=False):
    """Lowers callable_ to MLIR.

    If callable_ is a qnode, then we will first create a module, then
@@ -86,33 +118,33 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, publ
    if pipeline is None:
        pipeline = tuple()

    if not isinstance(callable_, qml.QNode):
        return get_or_create_funcop(
            ctx, callable_, call_jaxpr, pipeline, context=context, public=public
        )

    return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context)
    if isinstance(callable_, qml.QNode):
        return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=metadata)
    return get_or_create_funcop(
        ctx, callable_, call_jaxpr, pipeline, metadata=metadata, public=public
    )


# pylint: disable=too-many-arguments, too-many-positional-arguments
def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False):
def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=None, public=False):
    """Get funcOp from cache, or create it from scratch

    Args:
        ctx: LoweringRuleContext
        callable_: python function
        call_jaxpr: jaxpr representing callable_
        context: additional context to distinguish different FuncOps
        metadata: additional metadata to distinguish different FuncOps
        public: whether the visibility should be marked public

    Returns:
        FuncOp
    """
    if context is None:
        context = tuple()
    key = (callable_, *context, *pipeline)
    if func_op := get_cached(ctx, key):
        return func_op
    if metadata is None:
        metadata = tuple()
    key = (callable_, *metadata, *pipeline)
    if callable_ is not None:
        if func_op := get_cached(ctx, key):
            return func_op
    func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public)
    cache(ctx, key, func_op)
    return func_op
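
A small sketch of the memoization used above, with hypothetical stand-ins for the real get_cached/cache helpers (which store FuncOps on the lowering context): lookups are keyed on (callable, *metadata, *pipeline), so lowering the same callable twice under the same metadata and pipeline reuses one FuncOp.

_funcop_cache = {}

def get_cached(key):
    # Return a previously lowered FuncOp for this key, if any.
    return _funcop_cache.get(key)

def cache(key, func_op):
    # Remember the FuncOp so later lowerings with the same key reuse it.
    _funcop_cache[key] = func_op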
@@ -135,10 +167,10 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):

    kwargs = {}
    kwargs["ctx"] = ctx.module_context
    if not isinstance(callable_, functools.partial):
        name = callable_.__name__
    else:
    if isinstance(callable_, functools.partial):
        name = callable_.func.__name__ + ".partial"
    else:
        name = callable_.__name__

    kwargs["name"] = name
    kwargs["jaxpr"] = call_jaxpr
@@ -154,28 +186,7 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
    if isinstance(callable_, qml.QNode):
        func_op.attributes["qnode"] = ir.UnitAttr.get()

        diff_method = str(callable_.diff_method)

        if diff_method == "best":

            def only_single_expval():
                found_expval = False
                for eqn in call_jaxpr.eqns:
                    name = eqn.primitive.name
                    if name in {"probs", "counts", "sample"}:
                        return False
                    elif name == "expval":
                        if found_expval:
                            return False
                        found_expval = True
                return True

            device_name = getattr(getattr(callable_, "device", None), "name", None)

            if device_name and "lightning" in device_name and only_single_expval():
                diff_method = "adjoint"
            else:
                diff_method = "parameter-shift"
        diff_method = _calculate_diff_method(callable_, call_jaxpr)

        func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method)

@@ -195,7 +206,7 @@ def only_single_expval():
    return func_op


def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata):
    """A wrapper around lower_qnode_to_funcop that will cache the FuncOp.

    Args:
@@ -205,11 +216,11 @@ def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
    Returns:
        FuncOp
    """
    if context is None:
        context = tuple()
    if metadata is None:
        metadata = tuple()
    # QNodes with static_argnums are lowered fresh each time rather than cached.
    if callable_.static_argnums:
        return lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)
    key = (callable_, *context, *pipeline)
    key = (callable_, *metadata, *pipeline)
    if func_op := get_cached(ctx, key):
        return func_op
    func_op = lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)