diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 999e5ad137..8775c24f8d 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -475,6 +475,10 @@ jobs: python3 -m pip install oqc-qcaas-client make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -558,6 +562,11 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim + + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -620,6 +629,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@add-fn-to-grad-prim + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 6ae7e5a0d3..cc5c76b4d0 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -33,6 +33,7 @@ from pennylane.capture.expand_transforms import ExpandTransformsInterpreter from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim +from pennylane.capture.primitives import grad_prim from pennylane.capture.primitives import measure_prim as plxpr_measure_prim from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim from pennylane.measurements import CountsMP @@ -192,6 +193,15 @@ def __init__(self): super().__init__() +@WorkflowInterpreter.register_primitive(grad_prim) +def handle_grad(self, *args, jaxpr, n_consts, **kwargs): + """Translate a grad equation.""" + f = partial(copy(self).eval, jaxpr, args[:n_consts]) + new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:]).jaxpr + + return grad_prim.bind(*args, jaxpr=new_jaxpr, n_consts=n_consts, **kwargs) + + # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) def handle_qnode( diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index fec8e41684..5fba1fd5f5 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -100,6 +100,7 @@ VarianceOp, ) from mlir_quantum.dialects.quantum import YieldOp as QYieldOp +from pennylane.capture.primitives import grad_prim as pl_grad_prim from catalyst.compiler import get_lib_path from catalyst.jax_extras import ( @@ -644,9 +645,12 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): consts = [] offset = len(args) - len(jaxpr.consts) for i, jax_array_or_tracer in enumerate(jaxpr.consts): - if not isinstance( - jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer - ): + if isinstance(jax_array_or_tracer, jax._src.interpreters.partial_eval.DynamicJaxprTracer): + # There are some cases where this value cannot be converted into + # a jax.numpy.array. + # in that case we get it from the arguments. + consts.append(args[offset + i]) + else: # ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of # element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to # cast such constants to numpy array types. @@ -656,11 +660,6 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): attr = ir.DenseElementsAttr.get(nparray, type=const_type) constval = StableHLOConstantOp(attr).results consts.append(constval) - else: - # There are some cases where this value cannot be converted into - # a jax.numpy.array. - # in that case we get it from the arguments. - consts.append(args[offset + i]) method, h, argnums = grad_params.method, grad_params.h, grad_params.expanded_argnums mlir_ctx = ctx.module_context.context @@ -673,7 +672,6 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): argnum_numpy = np.array(new_argnums) diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy) func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums)) - symbol_ref = get_symbolref(ctx, func_op) output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) @@ -692,6 +690,32 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): ).results +# pylint: disable=too-many-arguments +def _capture_grad_lowering(ctx, *args, argnum, jaxpr, n_consts, method, h, fn, scalar_out): + mlir_ctx = ctx.module_context.context + if h: + f64 = ir.F64Type.get(mlir_ctx) + finiteDiffParam = ir.FloatAttr.get(f64, h) + else: + finiteDiffParam = None + + new_argnums = [num+n_consts for num in argnum] + argnum_numpy = np.array(new_argnums) + diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy) + func_op = lower_jaxpr(ctx, jaxpr, (method, h, *new_argnums), fn=fn) + symbol_ref = get_symbolref(ctx, func_op) + output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) + flat_output_types = util.flatten(output_types) + + return GradOp( + flat_output_types, + ir.StringAttr.get(method), + symbol_ref, + mlir.flatten_lowering_ir_args(args), + diffArgIndices=diffArgIndices, + finiteDiffParam=finiteDiffParam, + ).results + # value_and_grad # @value_and_grad_p.def_impl @@ -2542,6 +2566,7 @@ def subroutine_lowering(*args, **kwargs): (while_p, _while_loop_lowering), (for_p, _for_loop_lowering), (grad_p, _grad_lowering), + (pl_grad_prim, _capture_grad_lowering), (func_p, _func_lowering), (jvp_p, _jvp_lowering), (vjp_p, _vjp_lowering), diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 1227b7f1b5..a14949c6df 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -29,6 +29,31 @@ from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval +def _only_single_expval(call_jaxpr: core.ClosedJaxpr) -> bool: + found_expval = False + for eqn in call_jaxpr.eqns: + name = eqn.primitive.name + if name in {"probs", "counts", "sample"}: + return False + elif name == "expval": + if found_expval: + return False + found_expval = True + return True + + +def _calculate_diff_method(qn: qml.QNode, call_jaxpr: core.ClosedJaxpr): + diff_method = str(qn.diff_method) + if diff_method != "best": + return diff_method + + device_name = getattr(getattr(qn, "device", None), "name", None) + + if device_name and "lightning" in device_name and _only_single_expval(call_jaxpr): + return "adjoint" + return "parameter-shift" + + def get_call_jaxpr(jaxpr): """Extracts the `call_jaxpr` from a JAXPR if it exists.""" "" for eqn in jaxpr.eqns: @@ -45,28 +70,35 @@ def get_call_equation(jaxpr): raise AssertionError("No call_jaxpr found in the JAXPR.") -def lower_jaxpr(ctx, jaxpr, context=None): +def lower_jaxpr(ctx, jaxpr, metadata=None, fn=None): """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p Args: ctx: LoweringRuleContext jaxpr: JAXPR to be lowered - context: additional context to distinguish different FuncOps + metadata: additional metadata to distinguish different FuncOps Returns: FuncOp """ - equation = get_call_equation(jaxpr) - call_jaxpr = equation.params["call_jaxpr"] - callable_ = equation.params.get("fn") - if callable_ is None: - callable_ = equation.params.get("qnode") - pipeline = equation.params.get("pipeline") - return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context) + + if fn is None or isinstance(fn, qml.QNode): + equation = get_call_equation(jaxpr) + call_jaxpr = equation.params["call_jaxpr"] + pipeline = equation.params.get("pipeline") + callable_ = equation.params.get("fn") + if callable_ is None: + callable_ = equation.params.get("qnode", None) + else: + call_jaxpr = jaxpr + pipeline = () + callable_ = fn + + return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, metadata=metadata) # pylint: disable=too-many-arguments, too-many-positional-arguments -def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, public=False): +def lower_callable(ctx, callable_, call_jaxpr, pipeline=(), metadata=None, public=False): """Lowers _callable to MLIR. If callable_ is a qnode, then we will first create a module, then @@ -86,33 +118,33 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, publ if pipeline is None: pipeline = tuple() - if not isinstance(callable_, qml.QNode): - return get_or_create_funcop( - ctx, callable_, call_jaxpr, pipeline, context=context, public=public - ) - - return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context) + if isinstance(callable_, qml.QNode): + return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=metadata) + return get_or_create_funcop( + ctx, callable_, call_jaxpr, pipeline, metadata=metadata, public=public + ) # pylint: disable=too-many-arguments, too-many-positional-arguments -def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None, public=False): +def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=None, public=False): """Get funcOp from cache, or create it from scratch Args: ctx: LoweringRuleContext callable_: python function call_jaxpr: jaxpr representing callable_ - context: additional context to distinguish different FuncOps + metadata: additional metadata to distinguish different FuncOps public: whether the visibility should be marked public Returns: FuncOp """ - if context is None: - context = tuple() - key = (callable_, *context, *pipeline) - if func_op := get_cached(ctx, key): - return func_op + if metadata is None: + metadata = tuple() + key = (callable_, *metadata, *pipeline) + if callable_ is not None: + if func_op := get_cached(ctx, key): + return func_op func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public) cache(ctx, key, func_op) return func_op @@ -135,10 +167,10 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): kwargs = {} kwargs["ctx"] = ctx.module_context - if not isinstance(callable_, functools.partial): - name = callable_.__name__ - else: + if isinstance(callable_, functools.partial): name = callable_.func.__name__ + ".partial" + else: + name = callable_.__name__ kwargs["name"] = name kwargs["jaxpr"] = call_jaxpr @@ -154,28 +186,7 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): if isinstance(callable_, qml.QNode): func_op.attributes["qnode"] = ir.UnitAttr.get() - diff_method = str(callable_.diff_method) - - if diff_method == "best": - - def only_single_expval(): - found_expval = False - for eqn in call_jaxpr.eqns: - name = eqn.primitive.name - if name in {"probs", "counts", "sample"}: - return False - elif name == "expval": - if found_expval: - return False - found_expval = True - return True - - device_name = getattr(getattr(callable_, "device", None), "name", None) - - if device_name and "lightning" in device_name and only_single_expval(): - diff_method = "adjoint" - else: - diff_method = "parameter-shift" + diff_method = _calculate_diff_method(callable_, call_jaxpr) func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) @@ -195,7 +206,7 @@ def only_single_expval(): return func_op -def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context): +def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, metadata): """A wrapper around lower_qnode_to_funcop that will cache the FuncOp. Args: @@ -205,11 +216,11 @@ def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context): Returns: FuncOp """ - if context is None: - context = tuple() + if metadata is None: + metadata = tuple() if callable_.static_argnums: return lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline) - key = (callable_, *context, *pipeline) + key = (callable_, *metadata, *pipeline) if func_op := get_cached(ctx, key): return func_op func_op = lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline) diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index 0d3decdd35..e1fb93dce6 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -28,7 +28,6 @@ from catalyst import ( CompileError, DifferentiableCompileError, - cond, for_loop, grad, jacobian, @@ -78,6 +77,7 @@ def test_deduction_float(self): infer.calculate_grad_shape(in_signature, [0]) +@pytest.mark.usefixtures("use_both_frontend") def test_gradient_generate_once(): """Test that gradients are only generated once even if they are called multiple times. This is already tested @@ -89,7 +89,7 @@ def identity(x): @qjit def wrap(x: float): - diff = grad(identity) + diff = qml.grad(identity) return diff(x) + diff(x) assert "@identity_0" not in wrap.mlir @@ -173,6 +173,7 @@ def f(x): assert np.allclose(expected[1], result[1]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("argnums", (None, 0, [1], (0, 1))) def test_jacobian_outside_qjit_argnums(argnums): """Test that argnums work correctly outside of a jitting context.""" @@ -203,7 +204,7 @@ def f(x: float): @qjit def grad_f(x): - return grad(f, method="auto")(x) + return qml.grad(f, method="auto")(x) with pytest.raises( DifferentiableCompileError, @@ -224,7 +225,7 @@ def func(p): return x, y def workflow(p: float): - return jacobian(func, method="auto")(p) + return qml.jacobian(func, method="auto")(p) with pytest.raises( DifferentiableCompileError, match="The parameter-shift method can only be used" @@ -244,7 +245,7 @@ def func(p): return x, y def workflow(p: float): - return jacobian(func, method="auto")(p) + return qml.jacobian(func, method="auto")(p) with pytest.raises(DifferentiableCompileError, match="The adjoint method can only be used"): qjit(workflow) @@ -257,7 +258,7 @@ def test_grad_on_qjit(): def f(x: float): return x * x - result = qjit(grad(f))(3.0) + result = qjit(qml.grad(f))(3.0) expected = 6.0 assert np.allclose(result, expected) @@ -470,6 +471,7 @@ def circuit(params): assert np.allclose(result[1]["y"], expected[1]["y"]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_finite_diff(inp, backend): """Test finite diff.""" @@ -481,7 +483,7 @@ def f(x): @qjit def compiled_grad_default(x: float): g = qml.qnode(qml.device(backend, wires=1))(f) - h = grad(g, method="fd") + h = qml.grad(g, method="fd") return h(x) def interpretted_grad_default(x): @@ -493,6 +495,7 @@ def interpretted_grad_default(x): assert np.allclose(compiled_grad_default(inp), interpretted_grad_default(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_finite_diff_mul(inp, backend): """Test finite diff with mul.""" @@ -504,7 +507,7 @@ def f(x): @qjit def compiled_grad_default(x: float): g = qml.qnode(qml.device(backend, wires=1))(f) - h = grad(g, method="fd") + h = qml.grad(g, method="fd") return h(x) def interpretted_grad_default(x): @@ -516,6 +519,7 @@ def interpretted_grad_default(x): assert np.allclose(compiled_grad_default(inp), interpretted_grad_default(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [1.0, 2.0, 3.0, 4.0]) def test_finite_diff_in_loop(inp, backend): """Test finite diff in loop.""" @@ -527,12 +531,12 @@ def f(x): @qjit def compiled_grad_default(params, ntrials): - diff = grad(f, argnums=0, method="fd") + diff = qml.grad(f, argnum=0, method="fd") def fn(i, g): return diff(params) - return for_loop(0, ntrials, 1)(fn)(params) + return qml.for_loop(0, ntrials, 1)(fn)(params) def interpretted_grad_default(x): device = qml.device("default.qubit", wires=1) @@ -540,9 +544,16 @@ def interpretted_grad_default(x): h = qml.grad(g, argnum=0) return h(x) - assert np.allclose(compiled_grad_default(inp, 5), interpretted_grad_default(inp)) + enabled = qml.capture.enabled() + qml.capture.disable() + expected = interpretted_grad_default(inp) + if enabled: + qml.capture.enable() + assert np.allclose(compiled_grad_default(inp, 5), expected) + +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_adj(inp, backend): """Test the adjoint method.""" @@ -554,7 +565,7 @@ def f(x): @qjit def compiled(x: float): g = qml.qnode(qml.device(backend, wires=1), diff_method="adjoint")(f) - h = grad(g, method="auto") + h = qml.grad(g, method="auto") return h(x) def interpreted(x): @@ -566,6 +577,7 @@ def interpreted(x): assert np.allclose(compiled(inp), interpreted(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_adj_mult(inp, backend): """Test the adjoint method with mult.""" @@ -577,7 +589,7 @@ def f(x): @qjit def compiled(x: float): g = qml.qnode(qml.device(backend, wires=1), diff_method="adjoint")(f) - h = grad(g, method="auto") + h = qml.grad(g, method="auto") return h(x) def interpreted(x): @@ -589,6 +601,7 @@ def interpreted(x): assert np.allclose(compiled(inp), interpreted(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [1.0, 2.0, 3.0, 4.0]) def test_adj_in_loop(inp, backend): """Test the adjoint method in loop.""" @@ -600,12 +613,12 @@ def f(x): @qjit def compiled_grad_default(params, ntrials): - diff = grad(f, argnums=0, method="auto") + diff = qml.grad(f, argnum=0, method="auto") def fn(i, g): return diff(params) - return for_loop(0, ntrials, 1)(fn)(params) + return qml.for_loop(0, ntrials, 1)(fn)(params) def interpretted_grad_default(x): device = qml.device("default.qubit", wires=1) @@ -613,9 +626,15 @@ def interpretted_grad_default(x): h = qml.grad(g, argnum=0) return h(x) - assert np.allclose(compiled_grad_default(inp, 5), interpretted_grad_default(inp)) + enabled = qml.capture.enabled() + qml.capture.disable() + expected = interpretted_grad_default(inp) + if enabled: + qml.capture.enable() + assert np.allclose(compiled_grad_default(inp, 5), expected) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_ps(inp, backend): """Test the ps method.""" @@ -627,7 +646,7 @@ def f(x): @qjit def compiled(x: float): g = qml.qnode(qml.device(backend, wires=1), diff_method="parameter-shift")(f) - h = grad(g, method="auto") + h = qml.grad(g, method="auto") return h(x) def interpreted(x): @@ -639,12 +658,13 @@ def interpreted(x): assert np.allclose(compiled(inp), interpreted(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_ps_conditionals(inp, backend): """Test the ps method and conditionals.""" def f_compiled(x, y): - @cond(y > 1.5) + @qml.cond(y > 1.5) def true_path(): qml.RX(x * 2, wires=0) @@ -665,7 +685,7 @@ def f_interpreted(x, y): @qjit def compiled(x: float, y: float): g = qml.qnode(qml.device(backend, wires=1), diff_method="parameter-shift")(f_compiled) - h = grad(g, method="auto", argnums=0) + h = qml.grad(g, method="auto", argnum=0) return h(x, y) def interpreted(x, y): @@ -674,16 +694,24 @@ def interpreted(x, y): h = qml.grad(g, argnum=0) return h(x, y) - assert np.allclose(compiled(inp, 0.0), interpreted(inp, 0.0)) - assert np.allclose(compiled(inp, 2.0), interpreted(inp, 2.0)) + enabled = qml.capture.enabled() + qml.capture.disable() + expected0 = interpreted(inp, 0.0) + expected2 = interpreted(inp, 2.0) + if enabled: + qml.capture.enable() + + assert np.allclose(compiled(inp, 0.0), expected0) + assert np.allclose(compiled(inp, 2.0), expected2) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_ps_for_loops(inp, backend): """Test the ps method with for loops.""" def f_compiled(x, y): - @for_loop(0, y, 1) + @qml.for_loop(0, y, 1) def loop_fn(i): qml.RX(x * i * 1.5, wires=0) @@ -698,7 +726,7 @@ def f_interpreted(x, y): @qjit def compiled(x: float, y: int): g = qml.qnode(qml.device(backend, wires=1), diff_method="parameter-shift")(f_compiled) - h = grad(g, method="auto", argnums=0) + h = qml.grad(g, method="auto", argnum=0) return h(x, y) def interpreted(x, y): @@ -707,12 +735,22 @@ def interpreted(x, y): h = qml.grad(g, argnum=0) return h(x, y) - assert np.allclose(compiled(inp, 1), interpreted(inp, 1)) - assert np.allclose(compiled(inp, 2), interpreted(inp, 2)) - assert np.allclose(compiled(inp, 3), interpreted(inp, 3)) - assert np.allclose(compiled(inp, 4), interpreted(inp, 4)) + enabled = qml.capture.enabled() + qml.capture.disable() + expected1 = interpreted(inp, 1) + expected2 = interpreted(inp, 2) + expected3 = interpreted(inp, 3) + expected4 = interpreted(inp, 4) + if enabled: + qml.capture.enable() + assert np.allclose(compiled(inp, 1), expected1) + assert np.allclose(compiled(inp, 2), expected2) + assert np.allclose(compiled(inp, 3), expected3) + assert np.allclose(compiled(inp, 4), expected4) + +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_ps_for_loops_entangled(inp, backend): """Test the ps method with for loops and entangled.""" @@ -721,7 +759,7 @@ def f_compiled(x, y, z): qml.RX(x, wires=0) qml.Hadamard(wires=0) - @for_loop(1, y, 1) + @qml.for_loop(1, y, 1) def loop_fn(i): qml.RX(x, wires=i) qml.CNOT(wires=[0, i]) @@ -740,7 +778,7 @@ def f_interpreted(x, y, z): @qjit def compiled(x: float, y: int, z: int): g = qml.qnode(qml.device(backend, wires=3), diff_method="parameter-shift")(f_compiled) - h = grad(g, method="auto", argnums=0) + h = qml.grad(g, method="auto", argnum=0) return h(x, y, z) def interpreted(x, y, z): @@ -749,26 +787,31 @@ def interpreted(x, y, z): h = qml.grad(g, argnum=0) return h(x, y, z) - assert np.allclose(compiled(inp, 1, 1), interpreted(inp, 1, 1)) - assert np.allclose(compiled(inp, 2, 2), interpreted(inp, 2, 2)) + qml.capture.disable() + expected11 = interpreted(inp, 1, 1) + expected22 = interpreted(inp, 2, 2) + + assert np.allclose(compiled(inp, 1, 1), expected11) + assert np.allclose(compiled(inp, 2, 2), expected22) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_ps_qft(inp, backend): """Test the ps method in QFT.""" def qft_compiled(x, n, z): # Input state: equal superposition - @for_loop(0, n, 1) + @qml.for_loop(0, n, 1) def init(i): qml.Hadamard(wires=i) # QFT - @for_loop(0, n, 1) + @qml.for_loop(0, n, 1) def qft(i): qml.Hadamard(wires=i) - @for_loop(i + 1, n, 1) + @qml.for_loop(i + 1, n, 1) def inner(j): qml.RY(x, wires=j) qml.ControlledPhaseShift(jnp.pi / 2 ** (n - j + 1), [i, j]) @@ -798,7 +841,7 @@ def qft_interpreted(x, n, z): @qjit def compiled(x: float, y: int, z: int): g = qml.qnode(qml.device(backend, wires=3), diff_method="parameter-shift")(qft_compiled) - h = grad(g, method="auto", argnums=0) + h = qml.grad(g, method="auto", argnum=0) return h(x, y, z) def interpreted(x, y, z): @@ -807,9 +850,18 @@ def interpreted(x, y, z): h = qml.grad(g, argnum=0) return h(x, y, z) - assert np.allclose(compiled(inp, 2, 2), interpreted(inp, 2, 2)) + enabled = qml.capture.enabled() + qml.capture.disable() + expected = interpreted(inp, 2, 2) + if enabled: + qml.capture.enable() + + print("finish interpreted") + assert np.allclose(compiled(inp, 2, 2), expected) + +@pytest.mark.usefixtures("use_both_frontend") def test_ps_probs(backend): """Check that the parameter-shift method works for qml.probs.""" @@ -820,11 +872,14 @@ def func(p): @qjit def workflow(p: float): - return jacobian(func, method="auto")(p) + return qml.jacobian(func, method="auto")(p) result = workflow(0.5) + enabled = qml.capture.enabled() + qml.capture.disable() reference = qml.jacobian(func, argnum=0)(0.5) - + if enabled: + qml.capture.enable() assert np.allclose(result, reference) @@ -850,6 +905,7 @@ def main(x: float): assert np.allclose(result, reference) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_finite_diff_h(inp, backend): """Test finite diff.""" @@ -861,7 +917,7 @@ def f(x): @qjit def compiled_grad_h(x: float): g = qml.qnode(qml.device(backend, wires=1))(f) - h = grad(g, method="fd", h=0.1) + h = qml.grad(g, method="fd", h=0.1) return h(x) def interpretted_grad_h(x): @@ -873,6 +929,7 @@ def interpretted_grad_h(x): assert np.allclose(compiled_grad_h(inp), interpretted_grad_h(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_finite_diff_argnum(inp, backend): """Test finite diff.""" @@ -884,7 +941,7 @@ def f2(x, y): @qjit def compiled_grad_argnum(x: float): g = qml.qnode(qml.device(backend, wires=1))(f2) - h = grad(g, method="fd", argnums=1) + h = qml.grad(g, method="fd", argnum=1) return h(x, 2.0) def interpretted_grad_argnum(x): @@ -896,6 +953,7 @@ def interpretted_grad_argnum(x): assert np.allclose(compiled_grad_argnum(inp), interpretted_grad_argnum(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_finite_diff_argnum_list(inp, backend): """Test finite diff.""" @@ -907,7 +965,7 @@ def f2(x, y): @qjit def compiled_grad_argnum_list(x: float): g = qml.qnode(qml.device(backend, wires=1))(f2) - h = grad(g, method="fd", argnums=[1]) + h = qml.grad(g, method="fd", argnum=[1]) return h(x, 2.0) def interpretted_grad_argnum_list(x): @@ -923,6 +981,7 @@ def interpretted_grad_argnum_list(x): assert np.allclose(compiled_grad_argnum_list(inp), interpretted_grad_argnum_list(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_finite_grad_range_change(inp, backend): """Test finite diff.""" @@ -934,7 +993,7 @@ def f2(x, y): @qjit def compiled_grad_range_change(x: float): g = qml.qnode(qml.device(backend, wires=1))(f2) - h = grad(g, method="fd", argnums=[0, 1]) + h = qml.grad(g, method="fd", argnum=[0, 1]) return h(x, 2.0) def interpretted_grad_range_change(x): @@ -946,6 +1005,7 @@ def interpretted_grad_range_change(x): assert np.allclose(compiled_grad_range_change(inp), interpretted_grad_range_change(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_ps_grad_range_change(inp, backend): """Test param shift.""" @@ -957,7 +1017,7 @@ def f2(x, y): @qjit def compiled_grad_range_change(x: float): g = qml.qnode(qml.device(backend, wires=1), diff_method="parameter-shift")(f2) - h = grad(g, method="auto", argnums=[0, 1]) + h = qml.grad(g, method="auto", argnum=[0, 1]) return h(x, 2.0) def interpretted_grad_range_change(x): @@ -969,6 +1029,7 @@ def interpretted_grad_range_change(x): assert np.allclose(compiled_grad_range_change(inp), interpretted_grad_range_change(inp)) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_ps_tensorinp(inp, backend): """Test param shift.""" @@ -980,7 +1041,7 @@ def f2(x, y): @qjit def compiled(x: jax.core.ShapedArray([1], float)): g = qml.qnode(qml.device(backend, wires=1), diff_method="parameter-shift")(f2) - h = grad(g, method="auto", argnums=[0, 1]) + h = qml.grad(g, method="auto", argnum=[0, 1]) return h(x, 2.0) def interpretted(x): @@ -993,6 +1054,7 @@ def interpretted(x): assert np.allclose(dydx_c, dydx_i) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_adjoint_grad_range_change(inp, backend): """Test adjoint.""" @@ -1004,7 +1066,7 @@ def f2(x, y): @qjit def compiled_grad_range_change(x: float): g = qml.qnode(qml.device(backend, wires=1), diff_method="adjoint")(f2) - h = grad(g, method="auto", argnums=[0, 1]) + h = qml.grad(g, method="auto", argnum=[0, 1]) return h(x, 2.0) def interpretted_grad_range_change(x): @@ -1029,8 +1091,8 @@ def f(x): @qjit def workflow(x: float): g = qml.qnode(qml.device(backend, wires=1), diff_method=method)(f) - h = grad(g, method="auto") - i = grad(h, method="auto") + h = qml.grad(g, method="auto") + i = qml.grad(h, method="auto") return i(x) @@ -1077,6 +1139,7 @@ def workflow(x: float): qjit(workflow) +@pytest.mark.usefixtures("use_both_frontend") def test_finite_diff_arbitrary_functions(): """Test gradients on non-qnode functions.""" @@ -1085,11 +1148,12 @@ def workflow(x): def _f(x): return 2 * x - return grad(_f, method="fd")(x) + return qml.grad(_f, method="fd")(x) assert workflow(0.0) == 2.0 +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("inp", [(1.0), (2.0), (3.0), (4.0)]) def test_finite_diff_higher_order(inp, backend): """Test finite diff.""" @@ -1101,8 +1165,8 @@ def f(x): @qjit def compiled_grad2_default(x: float): g = qml.qnode(qml.device(backend, wires=1))(f) - h = grad(g, method="fd") - i = grad(h, method="fd") + h = qml.grad(g, method="fd") + i = qml.grad(h, method="fd") return i(x) def interpretted_grad2_default(x): @@ -1112,9 +1176,16 @@ def interpretted_grad2_default(x): i = qml.grad(h, argnum=0) return i(x) - assert np.allclose(compiled_grad2_default(inp), interpretted_grad2_default(inp), rtol=0.1) + enabled = qml.capture.enabled() + qml.capture.disable() + expected = interpretted_grad2_default(inp) + if enabled: + qml.capture.enable() + assert np.allclose(compiled_grad2_default(inp), expected, rtol=0.1) + +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("g_method", ["fd", "auto"]) @pytest.mark.parametrize( "h_coeffs", [[0.2, -0.53], np.array([0.2, -0.53]), jnp.array([0.2, -0.53])] @@ -1123,6 +1194,7 @@ def test_jax_consts(h_coeffs, g_method, backend): """Test jax constants.""" def circuit(params): + qml.H(0) qml.CRX(params[0], wires=[0, 1]) qml.CRX(params[0], wires=[0, 2]) h_obs = [qml.PauliX(0) @ qml.PauliZ(1), qml.PauliZ(0) @ qml.Hadamard(2)] @@ -1132,7 +1204,7 @@ def circuit(params): def compile_grad(params): diff_method = "adjoint" if g_method == "auto" else "finite-diff" g = qml.qnode(qml.device(backend, wires=3), diff_method=diff_method)(circuit) - h = grad(g, method=g_method) + h = qml.grad(g, method=g_method) return h(params) def interpret_grad(params): @@ -1142,7 +1214,14 @@ def interpret_grad(params): return h(params) inp = jnp.array([1.0, 2.0]) - assert np.allclose(compile_grad(jnp.array(inp)), interpret_grad(inp)) + + enabled = qml.capture.enabled() + qml.capture.disable() + expected = interpret_grad(inp) + if enabled: + qml.capture.enable() + + assert np.allclose(compile_grad(jnp.array(inp)), expected) def test_non_float_arg(backend): @@ -1185,6 +1264,7 @@ def cost_fn(x, y): cost_fn(1.0, 2.0) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", ["fd", "auto"]) @pytest.mark.parametrize("inp", [(1.0), (2.0)]) def test_finite_diff_multiple_devices(inp, diff_method, backend): @@ -1203,18 +1283,18 @@ def g(x): @qjit def compiled_grad_default(params, ntrials): - d_f = grad(f, argnums=0, method=diff_method) + d_f = qml.grad(f, argnum=0, method=diff_method) def fn_f(_i, _g): return d_f(params) - d_g = grad(g, argnums=0, method=diff_method) + d_g = qml.grad(g, argnum=0, method=diff_method) def fn_g(_i, _g): return d_g(params) - d1 = for_loop(0, ntrials, 1)(fn_f)(params) - d2 = for_loop(0, ntrials, 1)(fn_g)(params) + d1 = qml.for_loop(0, ntrials, 1)(fn_f)(params) + d2 = qml.for_loop(0, ntrials, 1)(fn_g)(params) return d1, d2 result = compiled_grad_default(inp, 5) @@ -1253,6 +1333,7 @@ def compiled(x): compiled(1.0) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", ["parameter-shift", "adjoint"]) def test_multiple_grad_invocations(backend, diff_method): """Test a function that uses grad multiple times.""" @@ -1265,16 +1346,18 @@ def f(x, y): @qjit def compiled(x: float, y: float): - g1 = grad(f, argnums=0, method="auto")(x, y) - g2 = grad(f, argnums=1, method="auto")(x, y) + g1 = qml.grad(f, argnum=0, method="auto")(x, y) + g2 = qml.grad(f, argnum=1, method="auto")(x, y) return jnp.array([g1, g2]) actual = compiled(0.1, 0.2) + qml.capture.disable() expected = jax.jacobian(f, argnums=(0, 1))(0.1, 0.2) for actual_entry, expected_entry in zip(actual, expected): assert actual_entry == pytest.approx(expected_entry) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", ["parameter-shift", "adjoint"]) def test_loop_with_dyn_wires(backend, diff_method): """Test the gradient on a function with a loop and modular wire arithmetic.""" @@ -1283,7 +1366,7 @@ def test_loop_with_dyn_wires(backend, diff_method): @qml.qnode(dev, diff_method=diff_method) def cat(phi): - @for_loop(0, 3, 1) + @qml.for_loop(0, 3, 1) def loop(i): qml.RY(phi, wires=jnp.mod(i, num_wires)) @@ -1293,7 +1376,7 @@ def loop(i): @qml.qnode(dev, diff_method=diff_method) def pl(phi): - @for_loop(0, 3, 1) + @qml.for_loop(0, 3, 1) def loop(i): qml.RY(phi, wires=i % num_wires) @@ -1302,7 +1385,8 @@ def loop(i): return qml.expval(qml.prod(*[qml.PauliZ(i) for i in range(num_wires)])) arg = 0.75 - result = qjit(grad(cat))(arg) + result = qjit(qml.grad(cat))(arg) + qml.capture.disable() expected = qml.grad(pl, argnum=0)(arg) assert np.allclose(result, expected) @@ -1315,11 +1399,15 @@ def test_classical_kwargs(): def f1(x, y, z): return x * (y - z) - result = qjit(grad(f1, argnums=0))(3.0, y=1.0, z=2.0) - expected = qjit(grad(f1, argnums=0))(3.0, 1.0, 2.0) + def g(*args, **kwargs): + return qml.grad(f1, argnum=0)(*args, **kwargs) + + result = qjit(g)(3.0, y=1.0, z=2.0) + expected = qjit(g)(3.0, 1.0, 2.0) assert np.allclose(expected, result) +@pytest.mark.usefixtures("use_both_frontend") def test_classical_kwargs_switched_arg_order(): """Test the gradient on classical function with keyword arguments and switched argument order""" @@ -1327,11 +1415,15 @@ def test_classical_kwargs_switched_arg_order(): def f1(x, y, z): return x * (y - z) - result = qjit(grad(f1, argnums=0))(3.0, z=2.0, y=1.0) - expected = qjit(grad(f1, argnums=0))(3.0, 1.0, 2.0) + def g(*args, **kwargs): + return qml.grad(f1, argnum=0)(*args, **kwargs) + + result = qjit(g)(3.0, z=2.0, y=1.0) + expected = qjit(g)(3.0, 1.0, 2.0) assert np.allclose(expected, result) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", ["parameter-shift", "adjoint"]) def test_qnode_kwargs(backend, diff_method): """Test the gradient on a qnode with keyword arguments""" @@ -1345,21 +1437,23 @@ def circuit(x, y, z): qml.RX(z, wires=0) return qml.expval(qml.PauliZ(0)) - result = qjit(jacobian(circuit, argnums=[0]))(0.1, y=0.2, z=0.3) - expected = qjit(jacobian(circuit, argnums=[0]))(0.1, 0.2, 0.3) + result = qjit(qml.jacobian(circuit, argnum=[0]))(0.1, y=0.2, z=0.3) + expected = qjit(qml.jacobian(circuit, argnum=[0]))(0.1, 0.2, 0.3) assert np.allclose(expected, result) - result = qjit(grad(circuit, argnums=[0]))(0.1, y=0.2, z=0.3) - expected = qjit(grad(circuit, argnums=[0]))(0.1, 0.2, 0.3) + result = qjit(qml.grad(circuit, argnum=[0]))(0.1, y=0.2, z=0.3) + expected = qjit(qml.grad(circuit, argnum=[0]))(0.1, 0.2, 0.3) assert np.allclose(expected, result) - result_val, result_grad = qjit(value_and_grad(circuit, argnums=[0]))(0.1, y=0.2, z=0.3) - expected_val = qjit(circuit)(0.1, 0.2, 0.3) - expected_grad = qjit(grad(circuit, argnums=[0]))(0.1, 0.2, 0.3) - print(result_val, result_grad) - print(expected_val, expected_grad) - assert np.allclose(expected_val, result_val) - assert np.allclose(expected_grad, result_grad) + if not qml.capture.enabled(): + result_val, result_grad = qjit(value_and_grad(circuit, argnums=[0]))(0.1, y=0.2, z=0.3) + expected_val = qjit(circuit)(0.1, 0.2, 0.3) + expected_grad = qjit(qml.grad(circuit, argnum=[0]))(0.1, 0.2, 0.3) + + assert np.allclose(expected_val, result_val) + assert np.allclose(expected_grad, result_grad) + +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", ["parameter-shift", "adjoint"]) def test_qnode_kwargs_switched_arg_order(backend, diff_method): """Test the gradient on a qnode with keyword arguments and switched argument order""" @@ -1373,21 +1467,23 @@ def circuit(x, y, z): qml.RX(z, wires=0) return qml.expval(qml.PauliZ(0)) - switched_order = qjit(jacobian(circuit, argnums=[0]))(0.1, z=0.3, y=0.2) - expected = qjit(jacobian(circuit, argnums=[0]))(0.1, 0.2, 0.3) + switched_order = qjit(qml.jacobian(circuit, argnum=[0]))(0.1, z=0.3, y=0.2) + expected = qjit(qml.jacobian(circuit, argnum=[0]))(0.1, 0.2, 0.3) assert np.allclose(expected[0], switched_order[0]) - switched_order = qjit(grad(circuit, argnums=[0]))(0.1, z=0.3, y=0.2) - expected = qjit(grad(circuit, argnums=[0]))(0.1, 0.2, 0.3) + switched_order = qjit(qml.grad(circuit, argnum=[0]))(0.1, z=0.3, y=0.2) + expected = qjit(qml.grad(circuit, argnum=[0]))(0.1, 0.2, 0.3) assert np.allclose(expected[0], switched_order[0]) - switched_order_val, switched_order_grad = qjit(value_and_grad(circuit, argnums=[0]))( - 0.1, z=0.3, y=0.2 - ) - expected_val = qjit(circuit)(0.1, 0.2, 0.3) - expected_grad = qjit(grad(circuit, argnums=[0]))(0.1, 0.2, 0.3) - assert np.allclose(expected_val, switched_order_val) - assert np.allclose(expected_grad, switched_order_grad) + if not qml.capture.enabled(): + switched_order_val, switched_order_grad = qjit(value_and_grad(circuit, argnums=[0]))( + 0.1, z=0.3, y=0.2 + ) + expected_val = qjit(circuit)(0.1, 0.2, 0.3) + expected_grad = qjit(grad(circuit, argnums=[0]))(0.1, 0.2, 0.3) + assert np.allclose(expected_val, switched_order_val) + assert np.allclose(expected_grad, switched_order_grad) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", ["parameter-shift", "adjoint"]) def test_pytrees_return_classical_function(backend, diff_method): """Test the jacobian on a qnode with a return including list and dictionaries.""" @@ -1405,10 +1501,12 @@ def circuit(phi, psi): if diff_method == "adjoint": # Adjoint method does not support multiple return values + if qml.capture.enabled(): + pytest.xfail("TODO") with pytest.raises(CompileError): - qjit(jacobian(circuit, argnums=[0, 1]))(psi, phi) + qjit(qml.jacobian(circuit, argnum=[0, 1]))(psi, phi) else: - result = qjit(jacobian(circuit, argnums=[0, 1]))(psi, phi) + result = qjit(qml.jacobian(circuit, argnum=[0, 1]))(psi, phi) assert isinstance(result, list) assert len(result) == 2 @@ -1419,6 +1517,7 @@ def circuit(phi, psi): assert len(result[1]) == 2 +@pytest.mark.usefixtures("use_both_frontend") def test_pytrees_return_classical(): """Test the jacobian on a function with a return including list and dictionaries.""" @@ -1429,7 +1528,7 @@ def f(x, y): y = 0.2 jax_expected_results = jax.jit(jax.jacobian(f, argnums=[0, 1]))(x, y) - catalyst_results = qjit(jacobian(f, argnums=[0, 1]))(x, y) + catalyst_results = qjit(qml.jacobian(f, argnum=[0, 1]))(x, y) flatten_res_jax, tree_jax = tree_flatten(jax_expected_results) flatten_res_catalyst, tree_catalyst = tree_flatten(catalyst_results) @@ -1438,6 +1537,7 @@ def f(x, y): assert np.allclose(flatten_res_jax, flatten_res_catalyst) +@pytest.mark.usefixtures("use_both_frontend") def test_pytrees_args_classical(): """Test the jacobian on a function with a return including list and dictionaries.""" @@ -1448,7 +1548,7 @@ def f(x, y): y = 0.2 jax_expected_results = jax.jit(jax.jacobian(f, argnums=[0, 1]))(x, y) - catalyst_results = qjit(jacobian(f, argnums=[0, 1]))(x, y) + catalyst_results = qjit(qml.jacobian(f, argnum=[0, 1]))(x, y) flatten_res_jax, tree_jax = tree_flatten(jax_expected_results) flatten_res_catalyst, tree_catalyst = tree_flatten(catalyst_results) @@ -1457,6 +1557,7 @@ def f(x, y): assert np.allclose(flatten_res_jax, flatten_res_catalyst) +@pytest.mark.usefixtures("use_both_frontend") def test_pytrees_args_return_classical(): """Test the jacobian on a function with a args and return including list and dictionnaries.""" @@ -1467,7 +1568,7 @@ def f(x, y): y = 0.2 jax_expected_results = jax.jit(jax.jacobian(f, argnums=[0, 1]))(x, y) - catalyst_results = qjit(jacobian(f, argnums=[0, 1]))(x, y) + catalyst_results = qjit(qml.jacobian(f, argnum=[0, 1]))(x, y) flatten_res_jax, tree_jax = tree_flatten(jax_expected_results) flatten_res_catalyst, tree_catalyst = tree_flatten(catalyst_results) @@ -1476,6 +1577,7 @@ def f(x, y): assert np.allclose(flatten_res_jax, flatten_res_catalyst) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", ["parameter-shift", "adjoint"]) def test_non_parametrized_circuit(backend, diff_method): """Test that the derivate of non parametrized circuit is null.""" @@ -1489,7 +1591,7 @@ def circuit(x): # pylint: disable=unused-argument return circuit(x) - assert np.allclose(qjit(grad(cost))(1.1), 0.0) + assert np.allclose(qjit(qml.grad(cost))(1.1), 0.0) @pytest.mark.xfail(reason="The verifier currently doesn't distinguish between active/inactive ops")