Skip to content

Raise better error message when using HMC for models with subsample #1293

@kaijennissen

Description

@kaijennissen

Hi,

I was trying to combine subsampling and MCMC. Is this possible? When I run the code below, I receive the following error: `AssertionError: Missing random key to generate subsample indices.`
I've looked into the code but couldn't figure out where the `rng_key` should be passed to `_subsample_fn`.

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from flax import linen as nn
from jax import random
from numpyro.contrib.module import random_flax_module
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO


def get_data(N: int = 30, N_test: int = 1000):
    """Generate a noisy sine-wave training set and an evenly spaced test grid.

    Training inputs are drawn uniformly from ``[-3*pi/2, pi)`` and the
    targets are ``sin(input)`` plus Gaussian noise with standard deviation
    0.2. Test inputs are ``N_test`` evenly spaced points on ``[-2*pi, 2*pi]``.
    Randomness comes from NumPy's global ``np.random`` state.

    Args:
        N: number of training points.
        N_test: number of test points.

    Returns:
        A tuple ``(X, y, X_test)`` of flattened (1-D) arrays.
    """
    lower, upper = -np.pi * 3 / 2, np.pi
    train_inputs = jnp.asarray(np.random.uniform(lower, upper, size=(N, 1)))

    # Additive observation noise, drawn after the inputs so the global RNG
    # is consumed in the order: uniform, then normal.
    noise = np.random.normal(loc=0, scale=0.2, size=(N, 1))
    train_targets = jnp.asarray(np.sin(train_inputs) + noise)

    test_grid = jnp.linspace(-np.pi * 2, 2 * np.pi, num=N_test)
    test_inputs = test_grid.reshape((-1, 1))

    return train_inputs.ravel(), train_targets.ravel(), test_inputs.ravel()


class Net(nn.Module):
    """Small MLP with two hidden layers and two scalar output heads.

    The input goes through two ``Dense(n_units)`` + ReLU stages. Two separate
    ``Dense(1)`` heads are then applied to the final hidden activation and
    returned as ``(mean, rho)``.
    """

    # Width of each of the two hidden Dense layers.
    n_units: int

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x`` and return ``(mean, rho)``, both squeezed."""
        hidden = x
        # Submodules are created in the order: hidden, hidden, mean head,
        # rho head.
        for _ in range(2):
            hidden = nn.relu(nn.Dense(self.n_units)(hidden))

        mean_head = nn.Dense(1)(hidden)
        rho_head = nn.Dense(1)(hidden)
        # squeeze() drops every size-1 axis, including the trailing one
        # produced by Dense(1).
        return mean_head.squeeze(), rho_head.squeeze()


def model(x, y=None, batch_size=None):
    """Regression model whose network weights are random variables.

    A ``Net`` with 16 hidden units is wrapped by ``random_flax_module`` with a
    ``Normal(0, 1)`` distribution. Inside the ``"batch"`` plate the network
    output parameterises a Normal likelihood at the ``"obs"`` site.

    Args:
        x: inputs. ``x.shape[0]`` is the plate size and ``x.shape`` is passed
            to ``random_flax_module`` as ``input_shape``.
        y: optional observations; when ``None`` the ``"obs"`` site is
            unobserved.
        batch_size: forwarded to the plate as ``subsample_size``.
    """
    weight_prior = dist.Normal(0, 1.)
    bnn = random_flax_module(
        "nn", Net(n_units=16), weight_prior, input_shape=x.shape
    )
    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
        batch_x = numpyro.subsample(x, event_dim=1)
        if y is None:
            batch_y = None
        else:
            batch_y = numpyro.subsample(y, event_dim=0)

        loc, rho = bnn(batch_x)
        # softplus maps the unconstrained head output to a positive scale.
        scale = nn.softplus(rho)
        numpyro.sample("obs", dist.Normal(loc, scale), obs=batch_y)


# Number of synthetic training points to generate.
n_train_data = 5000
X, y, X_test = get_data(N=n_train_data)

# Disabled SVI setup, kept for reference only (never executed). It refers to
# names (autoguide, init_to_feasible, TraceMeanField_ELBO) that the imports
# above do not bring into scope.
# guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
# svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())


# Run NUTS on the model with batch_size=256, which is passed to the "batch"
# plate as subsample_size. The mcmc.run(...) call below is the line reported
# in the traceback that ends with
# "AssertionError: Missing random key to generate subsample indices."
kernel = NUTS(model, max_tree_depth=1)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=4000, num_chains=1)
mcmc.run(x=X, y=y, batch_size=256, rng_key=random.PRNGKey(63547901))
Traceback (most recent call last):
  File "src/DNN_flax.py", line 101, in <module>
    mcmc.run(x=X, y=y, batch_size=256, rng_key=random.PRNGKey(63547901))
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 572, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 383, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 353, in fori_collect
    vals = jit(_body_fn)(i, vals)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 320, in _body_fn
    val = body_fun(val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 174, in _sample_fn_nojit_args
    return (sampler.sample(state[0], args, kwargs),)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 760, in sample
    return self._sample_fn(state, model_args, model_kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 460, in sample_kernel
    vv_state, energy, num_steps, accept_prob, diverging = _next(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 400, in _nuts_next
    binary_tree = build_tree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1181, in build_tree
    tree, _ = while_loop(_cond_fn, _body_fn, state)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 129, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1165, in _body_fn
    tree = _double_tree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 917, in _double_tree
    new_tree = _iterative_build_subtree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1065, in _iterative_build_subtree
    tree, turning, _, _, _ = while_loop(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 129, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1007, in _body_fn
    new_leaf = _build_basetree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 859, in _build_basetree
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 298, in update_fn
    potential_energy, z_grad = _value_and_grad(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 247, in _value_and_grad
    return value_and_grad(f)(x)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/util.py", line 227, in potential_energy
    log_joint, model_trace = log_density_(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/util.py", line 53, in log_density
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/handlers.py", line 165, in get_trace
    self(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  [Previous line repeated 1 more time]
  File "src/DNN_flax.py", line 78, in model
    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 444, in __init__
    self.dim, self._indices = self._subsample(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 467, in _subsample
    apply_stack(msg)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 35, in apply_stack
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 385, in _subsample_fn
    assert rng_key is not None, "Missing random key to generate subsample indices."
AssertionError: Missing random key to generate subsample indices.
``

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions