This repository was archived by the owner on Mar 31, 2025. It is now read-only.
Hello,
I'm currently working on modifying objax to work in my framework, and I have a question about how the jax.grad function is called by _DerivativeBase.__call__. By my understanding, jax.grad only cares about the first argument it is given and does not modify the BaseStates at all. Running a couple of tests, I found the following:

- Checking the ids of a BaseState inside f_func, the values contained in the function argument and in the object's variable collection are the same.
- Removing the line self.vc.subset(BaseState).assign(state_tensors) does not change the behaviour of my program (based on the MNIST tutorial, but with an added layer that keeps track of how many times it has been called via a BaseState counter).

So my question is: am I missing something? Is passing the BaseStates around necessary? More generally, is it necessary to pass any argument through f_func at all if I'm not asking for the gradient with respect to it?
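For context, here is a minimal pure-Python sketch of the state-threading pattern as I understand it from reading _DerivativeBase (this deliberately avoids a jax dependency; BaseState, f_func, and state_tensors here are stand-ins for the objax names, not the real implementation):

```python
# Sketch: why state is passed into f_func and written back afterwards.
# In jax, traced functions must be pure, so state updates only survive
# if they are returned and reassigned explicitly.

class BaseState:
    """Stand-in for an objax state variable."""
    def __init__(self, value):
        self.value = value

    def assign(self, value):
        self.value = value


def f_func(inputs, state_tensors):
    # Functional core: receives state explicitly and returns the updated
    # state instead of mutating anything in place.
    counter = state_tensors[0]
    loss = inputs * 2.0            # stand-in for the real computation
    return loss, [counter + 1]     # updated state threaded back out


states = [BaseState(0)]            # e.g. a "times called" counter


def call(inputs):
    loss, new_state_tensors = f_func(inputs, [s.value for s in states])
    # The line in question: copy the returned tensors back into the
    # variable collection. Without a tracing transform like jax.jit,
    # in-place mutation of the same Python objects would work anyway,
    # which could explain why removing the assign appears harmless.
    for s, t in zip(states, new_state_tensors):
        s.assign(t)
    return loss


call(3.0)
call(3.0)
print(states[0].value)  # -> 2: the counter advanced once per call
```

In this sketch the counter only advances because the updated tensors are returned and reassigned; my question is essentially whether objax relies on that round trip for correctness under jax transformations, or whether it is redundant in the un-jitted case.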
Thank you.