Skip to content

Commit 1d79315

Browse files
fehiepsiedward-bot
authored andcommitted
Internal
PiperOrigin-RevId: 449584663
1 parent ff5e774 commit 1d79315

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

edward2/jax/nn/random_feature.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,11 @@ def __call__(self, inputs: Array) -> Array:
207207

208208
# Performs forward pass.
209209
inputs = jnp.asarray(inputs, self.dtype)
210-
outputs = lax.dot_general(inputs, kernel.value,
210+
# Cast the kernel to correct dtype in case the parameter is saved and
211+
# restored with a different dtype.
212+
# TODO(b/235921783): Avoid casting dtype here.
213+
kernel_value = jnp.asarray(kernel.value, self.dtype)
214+
outputs = lax.dot_general(inputs, kernel_value,
211215
(contracting_dims, batch_dims))
212216
outputs = outputs + jnp.broadcast_to(bias.value, outputs.shape)
213217

0 commit comments

Comments
 (0)