From 1d79315e73a71341b3d0a0eaa9b72097e94a5b00 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 18 May 2022 14:56:19 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 449584663 --- edward2/jax/nn/random_feature.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/edward2/jax/nn/random_feature.py b/edward2/jax/nn/random_feature.py index acc717f6..5535c20d 100644 --- a/edward2/jax/nn/random_feature.py +++ b/edward2/jax/nn/random_feature.py @@ -207,7 +207,11 @@ def __call__(self, inputs: Array) -> Array: # Performs forward pass. inputs = jnp.asarray(inputs, self.dtype) - outputs = lax.dot_general(inputs, kernel.value, + # Cast the kernel to correct dtype in case the parameter is saved and + # restored with a different dtype. + # TODO(b/235921783): Avoid casting dtype here. + kernel_value = jnp.asarray(kernel.value, self.dtype) + outputs = lax.dot_general(inputs, kernel_value, (contracting_dims, batch_dims)) outputs = outputs + jnp.broadcast_to(bias.value, outputs.shape)