diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index ed9b051d5777..4706dfd685f4 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -43,8 +43,9 @@ from jax._src.interpreters.batching import not_mapped from jax._src.tree_util import ( tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, - register_pytree_node_class, tree_leaves, tree_flatten_with_path, - tree_leaves_with_path, keystr, treedef_children, tree_structure, PyTreeDef) + register_pytree_node_class, tree_leaves, tree_map_with_path, + tree_flatten_with_path, tree_leaves_with_path, keystr, treedef_children, + tree_structure, PyTreeDef) from jax._src.util import (cache, safe_zip, safe_map, split_list, unzip2, weakref_lru_cache) @@ -72,10 +73,15 @@ def _sum_tangents(_, x, *xs): def _zeros_like_pytree(x): return tree_map(Zero.from_primal_value, x) -_stop_gradient = partial( - tree_map, - lambda x: stop_gradient_p.bind(x) if isinstance(x, core.Tracer) else x, -) +def _check_tracer(i, x): + return tree_map_with_path(partial(_check_tracer_, i), x) + +def _check_tracer_(i, path, x): + if isinstance(x, core.Tracer): + pathstr = f' path {keystr(path)}' if path else '' + raise Exception(f"Tracer of type {x.aval} passed at argument {i}{pathstr}, " + "but Tracers can't be passed as nondiff_argnums.") + return x # like the api_util.py function, but also grabs output avals for error checking @@ -276,7 +282,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable ) from e if self.nondiff_argnums: - args = tuple(_stop_gradient(x) if i in self.nondiff_argnums else x + args = tuple(_check_tracer(i, x) if i in self.nondiff_argnums else x for i, x in enumerate(args)) diff_argnums = [i for i in range(len(args)) if i not in self.nondiff_argnums] f_, dyn_args = argnums_partial(lu.wrap_init(self.fun, debug_info=debug),