
> implies that calling `stop_gradient` inside my objective function will have different behaviour to calling it on the original array prior to passing it as input?

Yes, exactly. You can see this with a simple example:

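```python
import jax

def f(x):
  return x  # identity function: df/dx = 1

def g(x):
  # stop_gradient is applied inside the function that jax.grad traces
  return jax.lax.stop_gradient(x)

x = 1.0
print(jax.grad(f)(x))  # 1.0
print(jax.grad(f)(jax.lax.stop_gradient(x)))  # 1.0: stop_gradient outside grad has no effect
print(jax.grad(g)(x))  # 0.0
```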

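Calling `stop_gradient` on an array at the top level has no effect: outside the context of an autodiff transformation there are no gradients to stop, so it simply returns its input. `jax.grad(f)` then treats that returned value as a fresh input and differentiates `f` with respect to it as usual. `stop_gradient` only blocks gradients when it is called inside the function that `jax.grad` is tracing.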
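A minimal sketch extending the example above, assuming only `jax.grad` and `jax.lax.stop_gradient`. The function names `square`, `half_stopped_square`, and `stopped_square` are hypothetical, introduced here for illustration. It shows that stopping only part of an expression inside the objective blocks just that part, and that if you want to stop the gradient to an input "from outside", the call has to be moved into a wrapper that `jax.grad` traces.

```python
import jax

def square(x):
    return x * x

def half_stopped_square(x):
    # stop_gradient is traced here, so autodiff treats the second factor
    # as a constant c (equal to x's value): d/dx [x * c] = c
    return x * jax.lax.stop_gradient(x)

def stopped_square(x):
    # Hypothetical wrapper: stop_gradient now runs inside the traced function,
    # so the whole input is treated as a constant
    return square(jax.lax.stop_gradient(x))

x = 3.0

# stop_gradient before calling jax.grad: grad still sees a fresh input
print(jax.grad(square)(jax.lax.stop_gradient(x)))  # 6.0

# Inside the traced function, only the stopped factor is blocked
print(jax.grad(half_stopped_square)(x))  # 3.0

# Wrapping the call so stop_gradient is traced blocks the input entirely
print(jax.grad(stopped_square)(x))  # 0.0

# The forward value is unaffected by stop_gradient
print(half_stopped_square(x))  # 9.0
```

The forward values of all three functions are identical; `stop_gradient` only changes what autodiff sees, and only when it is part of the traced computation.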

Answer selected by cjchristopher
Category: Q&A · 2 participants