We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ff5e774 commit 1d79315Copy full SHA for 1d79315
edward2/jax/nn/random_feature.py
@@ -207,7 +207,11 @@ def __call__(self, inputs: Array) -> Array:
207
208
# Performs forward pass.
209
inputs = jnp.asarray(inputs, self.dtype)
210
- outputs = lax.dot_general(inputs, kernel.value,
+ # Cast the kernel to correct dtype in case the parameter is saved and
211
+ # restored with a different dtype.
212
+ # TODO(b/235921783): Avoid casting dtype here.
213
+ kernel_value = jnp.asarray(kernel.value, self.dtype)
214
+ outputs = lax.dot_general(inputs, kernel_value,
215
(contracting_dims, batch_dims))
216
outputs = outputs + jnp.broadcast_to(bias.value, outputs.shape)
217
0 commit comments