diff --git a/frontend/catalyst/autograph/operator_update.py b/frontend/catalyst/autograph/operator_update.py deleted file mode 100644 index cca096c0e8..0000000000 --- a/frontend/catalyst/autograph/operator_update.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Converter for array element operator assignment.""" - -import gast -from malt.core import converter -from malt.pyct import templates - - -# TODO: The methods from this class should be moved to the SliceTransformer class in DiastaticMalt -class SingleIndexArrayOperatorUpdateTransformer(converter.Base): - """Converts array element operator assignment statements into calls to update_item_with_{op}, - where op is one of the following: - - - `add` corresponding to `+=` - - `sub` to `-=` - - `mult` to `*=` - - `div` to `/=` - - `pow` to `**=` - """ - - def _process_single_update(self, target, op, value): - if not isinstance(target, gast.Subscript): - return None - s = target.slice - if isinstance(s, (gast.Tuple, gast.Call)): - return None - if not isinstance(op, (gast.Mult, gast.Add, gast.Sub, gast.Div, gast.Pow)): - return None - - template = f""" - target = ag__.update_item_with_op(target, index, x, "{type(op).__name__.lower()}") - """ - lower, upper, step = None, None, None - - if isinstance(s, (gast.Slice)): - # Replace unused arguments in template with "None" to preserve each arguments' position. - # templates.replace ignores None and does not accept string so change is applied here. - lower_str = "lower" if s.lower is not None else "None" - upper_str = "upper" if s.upper is not None else "None" - step_str = "step" if s.step is not None else "None" - template = template.replace("index", f"slice({lower_str}, {upper_str}, {step_str})") - - lower, upper, step = s.lower, s.upper, s.step - - return templates.replace( - template, - target=target.value, - index=target.slice, - lower=lower, - upper=upper, - step=step, - x=value, - ) - - def visit_AugAssign(self, node): - """The AugAssign node is replaced with a call to ag__.update_item_with_{op} - when its target is a single index array subscript and its op is an arithmetic - operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is. - - Example: - `x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)` - `x[i] ^= y` remains unchanged - """ - node = self.generic_visit(node) - replacement = self._process_single_update(node.target, node.op, node.value) - if replacement is not None: - return replacement - return node - - -def transform(node, ctx): - """Replace an AugAssign node with a call to ag__.update_item_with_{op} - when the its target is a single index array subscript and its op is an arithmetic - operator (i.e. Add, Sub, Mult, Div, or Pow), otherwise the node is left as is. - - Example: - `x[i] += y` is replaced with `x = ag__.update_item_with(x, i, y)` - `x[i] ^= y` remains unchanged - """ - return SingleIndexArrayOperatorUpdateTransformer(ctx).visit(node) diff --git a/frontend/catalyst/autograph/transformer.py b/frontend/catalyst/autograph/transformer.py index 16b8774548..b1c31048b6 100644 --- a/frontend/catalyst/autograph/transformer.py +++ b/frontend/catalyst/autograph/transformer.py @@ -22,145 +22,70 @@ """ import copy import functools -import inspect -from contextlib import ContextDecorator -import pennylane as qml -from malt.core import ag_ctx, config, converter -from malt.impl.api import PyToPy +from malt.core import config +from pennylane.capture import autograph as pl_autograph +from pennylane.capture.autograph.transformer import ( + PennyLaneTransformer, +) import catalyst -from catalyst.autograph import ag_primitives, operator_update +from catalyst.autograph import ag_primitives from catalyst.passes.pass_api import PassPipelineWrapper, QNodeWrapper -from catalyst.utils.exceptions import AutoGraphError from catalyst.utils.patching import Patcher -class CatalystTransformer(PyToPy): - """A source-to-source transformer to convert imperative style control flow into a function style - suitable for tracing.""" - - def __init__(self): - super().__init__() - - self._extra_locals = None +class CatalystTransformer(PennyLaneTransformer): + """A source-to-source transformer that extends the PennyLane transformer + to handle Catalyst-specific objects like QNodeWrapper.""" def transform(self, obj, user_context): - """Launch the transformation process. Typically this only works on function objects. - Here we also allow QNodes to be transformed.""" + """Launch the transformation process, with special handling for + Catalyst's QNodeWrapper and PassPipelineWrapper.""" - # By default AutoGraph will only convert function or method objects, not arbitrary classes - # such as QNode objects. Here we handle them explicitly, but we might need a more general - # way to handle these in the future. - # We may also need to check how this interacts with other common function decorators. fn = obj - if isinstance(obj, qml.QNode): - fn = obj.func - elif isinstance(obj, QNodeWrapper): + if isinstance(obj, QNodeWrapper): fn = obj data = [] while isinstance(fn, QNodeWrapper): data.append((fn.pass_name_or_pipeline, fn.flags, fn.valued_options)) fn = fn.qnode fn = obj.original_qnode.func - elif inspect.isfunction(fn) or inspect.ismethod(fn): - pass - elif callable(obj): - # pylint: disable=unnecessary-lambda,unnecessary-lambda-assignment - fn = lambda *args, **kwargs: obj(*args, **kwargs) + else: - raise AutoGraphError(f"Unsupported object for transformation: {type(fn)}") + new_obj, module, source_map = super().transform(obj, user_context) + + if isinstance(obj, PassPipelineWrapper): + new_qnode = copy.copy(obj.original_qnode) + new_qnode.func = new_obj + data.reverse() + for _pass, flags, kwopts in data: + new_qnode = PassPipelineWrapper(new_qnode, _pass, *flags, **kwopts) + new_obj = new_qnode + + return new_obj, module, source_map new_fn, module, source_map = self.transform_function(fn, user_context) - new_obj = new_fn - if isinstance(obj, qml.QNode): - new_obj = copy.copy(obj) - new_obj.func = new_fn - elif isinstance(obj, PassPipelineWrapper): + if isinstance(obj, QNodeWrapper): new_qnode = copy.copy(obj.original_qnode) new_qnode.func = new_fn data.reverse() for _pass, flags, kwopts in data: new_qnode = PassPipelineWrapper(new_qnode, _pass, *flags, **kwopts) - new_obj = new_qnode - - return new_obj, module, source_map - - def get_extra_locals(self): - """Here we can provide any extra names that the converted function should have access to. - At a minimum we need to provide the module with definitions for AutoGraph primitives.""" - - if self._extra_locals is None: - extra_locals = super().get_extra_locals() - updates = {key: ag_primitives.__dict__[key] for key in ag_primitives.__all__} - extra_locals["ag__"].__dict__.update(updates) - self._extra_locals = extra_locals - - return self._extra_locals - - def has_cache(self, fn): - """Check for the presence of the given function in the cache. Functions to be converted are - cached by the function object itself as well as the conversion options.""" - - return ( - self._cache.has(fn, TOPLEVEL_OPTIONS) - or self._cache.has(fn, NESTED_OPTIONS) - or self._cache.has(fn, STANDARD_OPTIONS) - ) - - def get_cached_function(self, fn): - """Retrieve a Python function object for a previously converted function. - Note that repeatedly calling this function with the same arguments will result in new - function objects every time, however their source code should be identical with the - exception of auto-generated names.""" - - # Converted functions are cached as a _PythonFnFactory object. - if self._cache.has(fn, TOPLEVEL_OPTIONS): - cached_factory = self._cached_factory(fn, TOPLEVEL_OPTIONS) - elif self._cache.has(fn, NESTED_OPTIONS): - cached_factory = self._cached_factory(fn, NESTED_OPTIONS) - else: - cached_factory = self._cached_factory(fn, STANDARD_OPTIONS) - - # Convert to a Python function object before returning (e.g. to obtain its source code). - new_fn = cached_factory.instantiate( - fn.__globals__, - fn.__closure__ or (), - defaults=fn.__defaults__, - kwdefaults=getattr(fn, "__kwdefaults__", None), - ) - - return new_fn - - def transform_ast(self, node, ctx): - """Overload of PyToPy.transform_ast from DiastaticMalt - - .. note:: - Once the operator_update interface has been migrated to the - DiastaticMalt project, this overload can be deleted.""" - # The operator_update transform would be more correct if placed with - # slices.transform in PyToPy.transform_ast in DiastaticMalt rather than - # at the beginning of the transformation. operator_update.transform - # should come after the unsupported features check and intial analysis, - # but it fails if it does not come before variables.transform. - node = operator_update.transform(node, ctx) - node = super().transform_ast(node, ctx) - return node + return new_qnode, module, source_map + + return new_fn, module, source_map def run_autograph(fn, *modules): """Decorator that converts the given function into graph form.""" + new_fn = pl_autograph.run_autograph(fn) + allowed_modules = tuple(config.Convert(module) for module in modules) allowed_modules += ag_primitives.module_allowlist - user_context = converter.ProgramContext(TOPLEVEL_OPTIONS) - new_fn, module, source_map = TRANSFORMER.transform(fn, user_context) - new_fn.ag_module = module - new_fn.ag_source_map = source_map - new_fn.ag_unconverted = fn - @functools.wraps(new_fn) def wrapper(*args, **kwargs): with Patcher( @@ -226,113 +151,17 @@ def else_body(): return y """ - # Handle directly converted objects. - if hasattr(fn, "ag_unconverted"): - return inspect.getsource(fn) - # Unwrap known objects to get the function actually transformed by autograph. if isinstance(fn, catalyst.QJIT): fn = fn.original_function - if isinstance(fn, qml.QNode): - fn = fn.func if isinstance(fn, QNodeWrapper): fn = fn.original_qnode - if TRANSFORMER.has_cache(fn): - new_fn = TRANSFORMER.get_cached_function(fn) - return inspect.getsource(new_fn) - - raise AutoGraphError( - "The given function was not converted by AutoGraph. If you expect the" - "given function to be converted, please submit a bug report." - ) - - -class DisableAutograph(ag_ctx.ControlStatusCtx, ContextDecorator): - """Context decorator that disables AutoGraph for the given function/context. - - .. note:: - - A singleton instance is used for discarding parentheses usage: - - @disable_autograph - instead of - @DisableAutograph() - - with disable_autograph: - instead of - with DisableAutograph() - - **Example 1: as a function decorator** - - .. code-block:: python - - @disable_autograph - def f(): - x = 6 - if x > 5: - y = x ** 2 - else: - y = x ** 3 - return y - - @qjit(autograph=True) - def g(x: float, n: int): - for _ in range(n): - x = x + f() - return x - - >>> print(g(0.4, 6)) - 216.4 + return pl_autograph.autograph_source(fn) - **Example 2: as a context manager** - - .. code-block:: python - - def f(): - x = 6 - if x > 5: - y = x ** 2 - else: - y = x ** 3 - return y - @qjit(autograph=True) - def g(): - x = 0.4 - with disable_autograph: - x += f() - return x - - >>> print(g()) - 36.4 - """ - - def __init__(self): - super().__init__(status=ag_ctx.Status.DISABLED) - - -# Singleton instance of DisableAutograph -disable_autograph = DisableAutograph() - -# converter.Feature.LISTS permits overloading the 'set_item' function in 'ag_primitives.py' -OPTIONAL_FEATURES = [converter.Feature.BUILTIN_FUNCTIONS, converter.Feature.LISTS] - -TOPLEVEL_OPTIONS = converter.ConversionOptions( - recursive=True, - user_requested=True, - internal_convert_user_code=True, - optional_features=OPTIONAL_FEATURES, -) - -NESTED_OPTIONS = converter.ConversionOptions( - recursive=True, - user_requested=False, - internal_convert_user_code=True, - optional_features=OPTIONAL_FEATURES, -) +disable_autograph = pl_autograph.disable_autograph -STANDARD_OPTIONS = converter.STANDARD_OPTIONS # Keep a global instance of the transformer to benefit from caching. TRANSFORMER = CatalystTransformer()