Using stop_gradient to fix parameters in gradient based optimiser? #32232
Say I have a batch of input vectors over which I am performing some kind of optimisation, and I start them with random state. In my setting it will be the case that some of the vector components have fixed values; say the first 3 values are fixed for all vectors in the batch, and we want to optimise the rest of each vector given those fixed points:

```python
x_temp = np.random.rand(batch_size, n_vars)
x_fixed = np.repeat(np.array([[1, 2, 3]]), batch_size, axis=0)
x_temp[:, :x_fixed.shape[-1]] = x_fixed
x_init = jnp.array(x_temp)
```

I've almost certainly got this wrong, but based on reading the documentation on stopping gradients, I had hoped that something like this:

```python
mask = np.zeros((batch_size, n_vars))
mask[:, :x_fixed.shape[-1]] = np.ones(x_fixed.shape)
x_init = jnp.where(jnp.array(mask), jax.lax.stop_gradient(x_init), x_init)
```

would mean that those positions should never get adjusted by a gradient-based optimiser; at least, assuming the optimiser is not doing something like directly adding terms to the update step that don't involve multiplication by the gradient. Instead, I'm seeing some of the points to which I've applied `stop_gradient` still being updated.
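For concreteness, a minimal sketch of the setup above (with assumed values for `batch_size` and `n_vars`, and a toy quadratic loss standing in for the real objective) shows that the gradient at the "fixed" positions comes out nonzero, so a plain gradient step still moves them:

```python
import numpy as np
import jax
import jax.numpy as jnp

batch_size, n_vars = 4, 6  # assumed sizes for illustration

x_temp = np.random.rand(batch_size, n_vars)
x_fixed = np.repeat(np.array([[1.0, 2.0, 3.0]]), batch_size, axis=0)
x_temp[:, :x_fixed.shape[-1]] = x_fixed
x_init = jnp.array(x_temp)

# Mirror the masking approach from above, with a boolean mask.
mask = np.zeros((batch_size, n_vars), dtype=bool)
mask[:, :x_fixed.shape[-1]] = True
x_init = jnp.where(jnp.array(mask), jax.lax.stop_gradient(x_init), x_init)

def loss(x):
    return jnp.sum(x ** 2)  # toy objective standing in for the real one

grads = jax.grad(loss)(x_init)
print(grads[:, :3])  # nonzero: a step like x_init - lr * grads moves the "fixed" values too
```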
Replies: 1 comment 7 replies
This is not something that `stop_gradient` can accomplish. Your best bet for this would probably be to split your parameters into an array of constants and an array of parameters to be fit, and to adjust your loss function to take only the fittable parameters as its explicit argument.
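A minimal sketch of that suggestion, assuming a toy quadratic objective and a single step of plain gradient descent (the names `x_free`, `x_fixed`, and `full_vectors` are illustrative, not part of any JAX API):

```python
import jax
import jax.numpy as jnp

batch_size, n_vars, n_fixed = 4, 6, 3  # assumed sizes for illustration

# Constants that the optimiser never sees as a fittable argument.
x_fixed = jnp.tile(jnp.array([[1.0, 2.0, 3.0]]), (batch_size, 1))

# Only the remaining components are treated as free parameters.
key = jax.random.PRNGKey(0)
x_free = jax.random.uniform(key, (batch_size, n_vars - n_fixed))

def full_vectors(x_free, x_fixed):
    # Reassemble the full vectors from the fixed and free parts.
    return jnp.concatenate([x_fixed, x_free], axis=-1)

def loss(x_free, x_fixed):
    x = full_vectors(x_free, x_fixed)
    return jnp.sum(x ** 2)  # stand-in objective

# Differentiate only with respect to the first (fittable) argument;
# the fixed components never receive an update.
grads = jax.grad(loss)(x_free, x_fixed)
x_free = x_free - 0.1 * grads  # one step of plain gradient descent
```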
Yes, exactly. You can see this with a simple example: calling `stop_gradient` on an array at the top level has no effect; there are no gradients to stop outside the context of an autodiff transformation.
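For illustration, a small example along those lines might look like this:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Outside of any autodiff transformation, stop_gradient just returns its input;
# there is no trace whose gradients it could block.
x_stopped = jax.lax.stop_gradient(x)
print(jnp.allclose(x, x_stopped))                      # True
print(jax.grad(lambda v: jnp.sum(v ** 2))(x_stopped))  # [2. 4. 6.]

# It only has an effect when it appears inside the function being differentiated:
def f(v):
    return jnp.sum(jax.lax.stop_gradient(v) ** 2)

print(jax.grad(f)(x))  # [0. 0. 0.]
```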