Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@
from jax._src.interpreters.batching import not_mapped
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
tree_leaves_with_path, keystr, treedef_children, tree_structure, PyTreeDef)
register_pytree_node_class, tree_leaves, tree_map_with_path,
tree_flatten_with_path, tree_leaves_with_path, keystr, treedef_children,
tree_structure, PyTreeDef)
from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2,
weakref_lru_cache)

Expand Down Expand Up @@ -72,10 +73,15 @@ def _sum_tangents(_, x, *xs):
def _zeros_like_pytree(x):
return tree_map(Zero.from_primal_value, x)

_stop_gradient = partial(
tree_map,
lambda x: stop_gradient_p.bind(x) if isinstance(x, core.Tracer) else x,
)
def _check_tracer(i, x):
return tree_map_with_path(partial(_check_tracer_, i), x)

def _check_tracer_(i, path, x):
if isinstance(x, core.Tracer):
pathstr = f' path {keystr(path)}' if path else ''
raise Exception(f"Tracer of type {x.aval} passed at argument {i}{pathstr}, "
"but Tracers can't be passed as nondiff_argnums.")
Comment on lines +81 to +83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is a great improvement for error reporting! To make it even better, consider using the more specific UnexpectedTracerError (which is already imported) instead of the generic Exception. Also, the error message could be more descriptive to help users understand why the error is happening and how to fix it, similar to the error message in _check_for_tracers for custom_vjp. I've also taken the liberty to rename pathstr to path_str for PEP8 compliance.

Suggested change
pathstr = f' path {keystr(path)}' if path else ''
raise Exception(f"Tracer of type {x.aval} passed at argument {i}{pathstr}, "
"but Tracers can't be passed as nondiff_argnums.")
path_str = f' at path {keystr(path)}' if path else ''
msg = (f"A JAX Tracer of type {x.aval} was passed to a custom_jvp function "
f"at argument {i}{path_str}, which was specified as non-differentiable. "
"Tracers cannot be passed as non-differentiable arguments. Instead, "
"`nondiff_argnums` should only be used for arguments that are not JAX "
"tracers, e.g. Python scalars, strings, or callables.")
raise UnexpectedTracerError(msg)

return x


# like the api_util.py function, but also grabs output avals for error checking
Expand Down Expand Up @@ -276,7 +282,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
) from e

if self.nondiff_argnums:
args = tuple(_stop_gradient(x) if i in self.nondiff_argnums else x
args = tuple(_check_tracer(i, x) if i in self.nondiff_argnums else x
for i, x in enumerate(args))
diff_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums]
f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug),
Expand Down
Loading