Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 0 additions & 93 deletions frontend/catalyst/autograph/operator_update.py

This file was deleted.

235 changes: 32 additions & 203 deletions frontend/catalyst/autograph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,145 +22,70 @@
"""
import copy
import functools
import inspect
from contextlib import ContextDecorator

import pennylane as qml
from malt.core import ag_ctx, config, converter
from malt.impl.api import PyToPy
from malt.core import config
from pennylane.capture import autograph as pl_autograph
from pennylane.capture.autograph.transformer import (
PennyLaneTransformer,
)

import catalyst
from catalyst.autograph import ag_primitives, operator_update
from catalyst.autograph import ag_primitives
from catalyst.passes.pass_api import PassPipelineWrapper, QNodeWrapper
from catalyst.utils.exceptions import AutoGraphError
from catalyst.utils.patching import Patcher


class CatalystTransformer(PyToPy):
"""A source-to-source transformer to convert imperative style control flow into a function style
suitable for tracing."""

def __init__(self):
super().__init__()

self._extra_locals = None
class CatalystTransformer(PennyLaneTransformer):
"""A source-to-source transformer that extends the PennyLane transformer
to handle Catalyst-specific objects like QNodeWrapper."""

def transform(self, obj, user_context):
"""Launch the transformation process. Typically this only works on function objects.
Here we also allow QNodes to be transformed."""
"""Launch the transformation process, with special handling for
Catalyst's QNodeWrapper and PassPipelineWrapper."""

# By default AutoGraph will only convert function or method objects, not arbitrary classes
# such as QNode objects. Here we handle them explicitly, but we might need a more general
# way to handle these in the future.
# We may also need to check how this interacts with other common function decorators.
fn = obj
if isinstance(obj, qml.QNode):
fn = obj.func
elif isinstance(obj, QNodeWrapper):
if isinstance(obj, QNodeWrapper):
fn = obj
data = []
while isinstance(fn, QNodeWrapper):
data.append((fn.pass_name_or_pipeline, fn.flags, fn.valued_options))
fn = fn.qnode
fn = obj.original_qnode.func
elif inspect.isfunction(fn) or inspect.ismethod(fn):
pass
elif callable(obj):
# pylint: disable=unnecessary-lambda,unnecessary-lambda-assignment
fn = lambda *args, **kwargs: obj(*args, **kwargs)

else:
raise AutoGraphError(f"Unsupported object for transformation: {type(fn)}")
new_obj, module, source_map = super().transform(obj, user_context)

if isinstance(obj, PassPipelineWrapper):
new_qnode = copy.copy(obj.original_qnode)
new_qnode.func = new_obj
data.reverse()
for _pass, flags, kwopts in data:
new_qnode = PassPipelineWrapper(new_qnode, _pass, *flags, **kwopts)
new_obj = new_qnode

return new_obj, module, source_map

new_fn, module, source_map = self.transform_function(fn, user_context)
new_obj = new_fn

if isinstance(obj, qml.QNode):
new_obj = copy.copy(obj)
new_obj.func = new_fn
elif isinstance(obj, PassPipelineWrapper):
if isinstance(obj, QNodeWrapper):
new_qnode = copy.copy(obj.original_qnode)
new_qnode.func = new_fn
data.reverse()
for _pass, flags, kwopts in data:
new_qnode = PassPipelineWrapper(new_qnode, _pass, *flags, **kwopts)
new_obj = new_qnode

return new_obj, module, source_map

def get_extra_locals(self):
"""Here we can provide any extra names that the converted function should have access to.
At a minimum we need to provide the module with definitions for AutoGraph primitives."""

if self._extra_locals is None:
extra_locals = super().get_extra_locals()
updates = {key: ag_primitives.__dict__[key] for key in ag_primitives.__all__}
extra_locals["ag__"].__dict__.update(updates)
self._extra_locals = extra_locals

return self._extra_locals

def has_cache(self, fn):
"""Check for the presence of the given function in the cache. Functions to be converted are
cached by the function object itself as well as the conversion options."""

return (
self._cache.has(fn, TOPLEVEL_OPTIONS)
or self._cache.has(fn, NESTED_OPTIONS)
or self._cache.has(fn, STANDARD_OPTIONS)
)

def get_cached_function(self, fn):
"""Retrieve a Python function object for a previously converted function.
Note that repeatedly calling this function with the same arguments will result in new
function objects every time, however their source code should be identical with the
exception of auto-generated names."""

# Converted functions are cached as a _PythonFnFactory object.
if self._cache.has(fn, TOPLEVEL_OPTIONS):
cached_factory = self._cached_factory(fn, TOPLEVEL_OPTIONS)
elif self._cache.has(fn, NESTED_OPTIONS):
cached_factory = self._cached_factory(fn, NESTED_OPTIONS)
else:
cached_factory = self._cached_factory(fn, STANDARD_OPTIONS)

# Convert to a Python function object before returning (e.g. to obtain its source code).
new_fn = cached_factory.instantiate(
fn.__globals__,
fn.__closure__ or (),
defaults=fn.__defaults__,
kwdefaults=getattr(fn, "__kwdefaults__", None),
)

return new_fn

def transform_ast(self, node, ctx):
"""Overload of PyToPy.transform_ast from DiastaticMalt

.. note::
Once the operator_update interface has been migrated to the
DiastaticMalt project, this overload can be deleted."""
# The operator_update transform would be more correct if placed with
# slices.transform in PyToPy.transform_ast in DiastaticMalt rather than
# at the beginning of the transformation. operator_update.transform
# should come after the unsupported features check and intial analysis,
# but it fails if it does not come before variables.transform.
node = operator_update.transform(node, ctx)
node = super().transform_ast(node, ctx)
return node
return new_qnode, module, source_map

return new_fn, module, source_map


def run_autograph(fn, *modules):
"""Decorator that converts the given function into graph form."""

new_fn = pl_autograph.run_autograph(fn)

allowed_modules = tuple(config.Convert(module) for module in modules)
allowed_modules += ag_primitives.module_allowlist

user_context = converter.ProgramContext(TOPLEVEL_OPTIONS)
new_fn, module, source_map = TRANSFORMER.transform(fn, user_context)
new_fn.ag_module = module
new_fn.ag_source_map = source_map
new_fn.ag_unconverted = fn

@functools.wraps(new_fn)
def wrapper(*args, **kwargs):
with Patcher(
Expand Down Expand Up @@ -226,113 +151,17 @@ def else_body():
return y
"""

# Handle directly converted objects.
if hasattr(fn, "ag_unconverted"):
return inspect.getsource(fn)

# Unwrap known objects to get the function actually transformed by autograph.
if isinstance(fn, catalyst.QJIT):
fn = fn.original_function
if isinstance(fn, qml.QNode):
fn = fn.func
if isinstance(fn, QNodeWrapper):
fn = fn.original_qnode

if TRANSFORMER.has_cache(fn):
new_fn = TRANSFORMER.get_cached_function(fn)
return inspect.getsource(new_fn)

raise AutoGraphError(
"The given function was not converted by AutoGraph. If you expect the"
"given function to be converted, please submit a bug report."
)


class DisableAutograph(ag_ctx.ControlStatusCtx, ContextDecorator):
"""Context decorator that disables AutoGraph for the given function/context.

.. note::

A singleton instance is used for discarding parentheses usage:

@disable_autograph
instead of
@DisableAutograph()

with disable_autograph:
instead of
with DisableAutograph()

**Example 1: as a function decorator**

.. code-block:: python

@disable_autograph
def f():
x = 6
if x > 5:
y = x ** 2
else:
y = x ** 3
return y

@qjit(autograph=True)
def g(x: float, n: int):
for _ in range(n):
x = x + f()
return x

>>> print(g(0.4, 6))
216.4
return pl_autograph.autograph_source(fn)

**Example 2: as a context manager**

.. code-block:: python

def f():
x = 6
if x > 5:
y = x ** 2
else:
y = x ** 3
return y

@qjit(autograph=True)
def g():
x = 0.4
with disable_autograph:
x += f()
return x

>>> print(g())
36.4
"""

def __init__(self):
super().__init__(status=ag_ctx.Status.DISABLED)


# Singleton instance of DisableAutograph
disable_autograph = DisableAutograph()

# converter.Feature.LISTS permits overloading the 'set_item' function in 'ag_primitives.py'
OPTIONAL_FEATURES = [converter.Feature.BUILTIN_FUNCTIONS, converter.Feature.LISTS]

TOPLEVEL_OPTIONS = converter.ConversionOptions(
recursive=True,
user_requested=True,
internal_convert_user_code=True,
optional_features=OPTIONAL_FEATURES,
)

NESTED_OPTIONS = converter.ConversionOptions(
recursive=True,
user_requested=False,
internal_convert_user_code=True,
optional_features=OPTIONAL_FEATURES,
)
disable_autograph = pl_autograph.disable_autograph

STANDARD_OPTIONS = converter.STANDARD_OPTIONS

# Keep a global instance of the transformer to benefit from caching.
TRANSFORMER = CatalystTransformer()
Loading