-
Notifications
You must be signed in to change notification settings - Fork 267
Closed
Labels
Description
Hi,
I was trying to combine subsampling with MCMC. Is this possible? When I run the code below, I get the following error: `AssertionError: Missing random key to generate subsample indices.`
I've looked into the code but couldn't figure out where the `rng_key` should be passed to `_subsample_fn`.
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from flax import linen as nn
from jax import random
from numpyro.contrib.module import random_flax_module
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
def get_data(N: int = 30, N_test: int = 1000, *, seed=None):
    """Generate a noisy 1-D sine regression dataset.

    Args:
        N: Number of training points. Inputs are drawn uniformly between
            -3*pi/2 and pi.
        N_test: Number of evenly spaced test inputs between -2*pi and 2*pi.
        seed: Optional seed for a private NumPy generator, making the data
            reproducible. With the default ``None`` the global ``np.random``
            state is used.

    Returns:
        A tuple ``(X, y, X_test)`` of flattened (1-D) JAX arrays of lengths
        ``N``, ``N`` and ``N_test``, where ``y`` is ``sin(X)`` plus Gaussian
        noise with standard deviation 0.2.
    """
    # Both the ``np.random`` module and a ``Generator`` expose ``uniform`` and
    # ``normal`` with the same arguments, so either can serve as the source.
    rng = np.random if seed is None else np.random.default_rng(seed)
    X = jnp.asarray(rng.uniform(-np.pi * 3 / 2, np.pi, size=(N, 1)))
    y = jnp.asarray(np.sin(X) + rng.normal(loc=0, scale=0.2, size=(N, 1)))
    X_test = jnp.linspace(-np.pi * 2, 2 * np.pi, num=N_test).reshape((-1, 1))
    return X.ravel(), y.ravel(), X_test.ravel()
class Net(nn.Module):
    """Small MLP with two hidden ReLU layers and two scalar output heads."""

    # Width of each of the two hidden Dense layers.
    n_units: int

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x``.

        Returns:
            A pair ``(mean, rho)``: the outputs of two separate ``Dense(1)``
            heads on the last hidden layer, each with its trailing unit axis
            removed.
        """
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        mean = nn.Dense(1)(x)
        rho = nn.Dense(1)(x)
        # Squeeze only the trailing axis produced by Dense(1); a bare
        # squeeze() would also drop a batch axis of length 1.
        return mean.squeeze(axis=-1), rho.squeeze(axis=-1)
def model(x, y=None, batch_size=None):
    """Bayesian neural-network regression with optional mini-batching.

    Args:
        x: Inputs, either 1-D (one scalar feature per observation) or 2-D
            with observations along the first axis.
        y: Observed targets aligned with the first axis of ``x``, or ``None``
            to leave the ``"obs"`` site unobserved.
        batch_size: Passed to the plate as ``subsample_size``; ``None`` is
            forwarded unchanged.
    """
    # subsample(..., event_dim=1) slices the axis to the left of the feature
    # axis, so a 1-D input needs an explicit feature axis; without it x would
    # not be subsampled consistently with y.
    if jnp.ndim(x) == 1:
        x = jnp.reshape(x, (-1, 1))

    module = Net(n_units=16)
    net = random_flax_module("nn", module, dist.Normal(0, 1.0), input_shape=x.shape)

    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
        batch_x = numpyro.subsample(x, event_dim=1)
        batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
        mean, rho = net(batch_x)
        # softplus maps the unconstrained head to a positive scale.
        sigma = nn.softplus(rho)
        numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)
# Simulate the training data for the reproducer.
n_train_data = 5000
X, y, X_test = get_data(N=n_train_data)

# SVI variant, left disabled:
# guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
# svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())

# Run NUTS on the model, asking the plate for mini-batches of 256.
kernel = NUTS(model, max_tree_depth=1)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=4000,
    num_chains=1,
)
mcmc.run(random.PRNGKey(63547901), x=X, y=y, batch_size=256)
Traceback (most recent call last):
File "src/DNN_flax.py", line 101, in <module>
mcmc.run(x=X, y=y, batch_size=256, rng_key=random.PRNGKey(63547901))
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 572, in run
states_flat, last_state = partial_map_fn(map_args)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 383, in _single_chain_mcmc
collect_vals = fori_collect(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 353, in fori_collect
vals = jit(_body_fn)(i, vals)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 320, in _body_fn
val = body_fun(val)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 174, in _sample_fn_nojit_args
return (sampler.sample(state[0], args, kwargs),)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 760, in sample
return self._sample_fn(state, model_args, model_kwargs)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 460, in sample_kernel
vv_state, energy, num_steps, accept_prob, diverging = _next(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 400, in _nuts_next
binary_tree = build_tree(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1181, in build_tree
tree, _ = while_loop(_cond_fn, _body_fn, state)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 129, in while_loop
return lax.while_loop(cond_fun, body_fun, init_val)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1165, in _body_fn
tree = _double_tree(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 917, in _double_tree
new_tree = _iterative_build_subtree(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1065, in _iterative_build_subtree
tree, turning, _, _, _ = while_loop(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 129, in while_loop
return lax.while_loop(cond_fun, body_fun, init_val)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1007, in _body_fn
new_leaf = _build_basetree(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 859, in _build_basetree
z_new, r_new, potential_energy_new, z_new_grad = vv_update(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 298, in update_fn
potential_energy, z_grad = _value_and_grad(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 247, in _value_and_grad
return value_and_grad(f)(x)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/util.py", line 227, in potential_energy
log_joint, model_trace = log_density_(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/util.py", line 53, in log_density
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/handlers.py", line 165, in get_trace
self(*args, **kwargs)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
return self.fn(*args, **kwargs)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
return self.fn(*args, **kwargs)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
return self.fn(*args, **kwargs)
[Previous line repeated 1 more time]
File "src/DNN_flax.py", line 78, in model
with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 444, in __init__
self.dim, self._indices = self._subsample(
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 467, in _subsample
apply_stack(msg)
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 35, in apply_stack
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 385, in _subsample_fn
assert rng_key is not None, "Missing random key to generate subsample indices."
AssertionError: Missing random key to generate subsample indices.
``