From d8384260fcfb9448990167ca084715df5ace5b57 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 2 Apr 2021 22:00:16 -0500 Subject: [PATCH 1/9] first try to make infer_discrete work for scan --- numpyro/contrib/funsor/discrete.py | 82 ++++++---------------------- numpyro/contrib/funsor/infer_util.py | 69 ++++++++++++++--------- test/contrib/test_infer_discrete.py | 33 +++++++++++ 3 files changed, 92 insertions(+), 92 deletions(-) diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index 59ed4513a..8ab2cfa3b 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -7,10 +7,9 @@ from jax import random import funsor -from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace -from numpyro.contrib.funsor.infer_util import plate_to_enum_plate -from numpyro.distributions.util import is_identically_one -from numpyro.handlers import block, replay, seed, trace +from numpyro.contrib.funsor.enum_messenger import enum +from numpyro.contrib.funsor.infer_util import _enum_log_density +from numpyro.handlers import block, seed, substitute, trace from numpyro.infer.util import _guess_max_plate_nesting @@ -38,46 +37,6 @@ def _get_support_value_delta(funsor_dist, name, **kwargs): return OrderedDict(funsor_dist.terms)[name][0] -def terms_from_trace(tr): - """Helper function to extract elbo components from execution traces.""" - log_factors = {} - log_measures = {} - sum_vars, prod_vars = frozenset(), frozenset() - for site in tr.values(): - if site["type"] == "sample": - value = site["value"] - intermediates = site["intermediates"] - scale = site["scale"] - if intermediates: - log_prob = site["fn"].log_prob(value, intermediates) - else: - log_prob = site["fn"].log_prob(value) - - if (scale is not None) and (not is_identically_one(scale)): - log_prob = scale * log_prob - - dim_to_name = site["infer"]["dim_to_name"] - log_prob_factor = funsor.to_funsor( - log_prob, output=funsor.Real, dim_to_name=dim_to_name - ) - - if site["is_observed"]: - log_factors[site["name"]] = log_prob_factor - else: - log_measures[site["name"]] = log_prob_factor - sum_vars |= frozenset({site["name"]}) - prod_vars |= frozenset( - f.name for f in site["cond_indep_stack"] if f.dim is not None - ) - - return { - "log_factors": log_factors, - "log_measures": log_measures, - "measure_vars": sum_vars, - "plate_vars": prod_vars, - } - - def _sample_posterior( model, first_available_dim, temperature, rng_key, *args, **kwargs ): @@ -97,27 +56,14 @@ def _sample_posterior( model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs) first_available_dim = -_guess_max_plate_nesting(model_trace) - 1 - with block(), enum(first_available_dim=first_available_dim): - with plate_to_enum_plate(): - model_tr = packed_trace(model).get_trace(*args, **kwargs) - - terms = terms_from_trace(model_tr) - # terms["log_factors"] = [log p(x) for each observed or latent sample site x] - # terms["log_measures"] = [log p(z) or other Dice factor - # for each latent sample site z] - - with funsor.interpretations.lazy: - log_prob = funsor.sum_product.sum_product( - sum_op, - prod_op, - list(terms["log_factors"].values()) + list(terms["log_measures"].values()), - eliminate=terms["measure_vars"] | terms["plate_vars"], - plates=terms["plate_vars"], - ) - log_prob = funsor.optimizer.apply_optimizer(log_prob) + with funsor.adjoint.AdjointTape() as tape: + with block(), enum(first_available_dim=first_available_dim): + log_prob, model_tr, log_measures = _enum_log_density( + model, args, kwargs, {}, sum_op, prod_op + ) with approx: - approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob) + approx_factors = tape.adjoint(sum_op, prod_op, log_prob) # construct a result trace to replay against the model sample_tr = model_tr.copy() @@ -138,13 +84,19 @@ def _sample_posterior( value, name_to_dim=node["infer"]["name_to_dim"] ) else: - log_measure = approx_factors[terms["log_measures"][name]] + log_measure = approx_factors[log_measures[name]] sample_subs[name] = _get_support_value(log_measure, name) node["value"] = funsor.to_data( sample_subs[name], name_to_dim=node["infer"]["name_to_dim"] ) - with replay(guide_trace=sample_tr): + data = { + name: site["value"] + for name, site in sample_tr.items() + if site["type"] == "sample" + } + print(data) + with substitute(data=data): return model(*args, **kwargs) diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 09d94b88f..5a65a2db5 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -100,6 +100,8 @@ def compute_markov_factors( sum_vars, prod_vars, history, + sum_op, + prod_op, ): """ :param dict time_to_factors: a map from time variable to the log prob factors. @@ -119,8 +121,8 @@ def compute_markov_factors( eliminate_vars = (sum_vars | prod_vars) - time_to_markov_dims[time_var] with funsor.interpretations.lazy: lazy_result = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, + sum_op, + prod_op, log_factors, eliminate=eliminate_vars, plates=prod_vars, @@ -136,7 +138,7 @@ def compute_markov_factors( ) markov_factors.append( funsor.sum_product.sarkka_bilmes_product( - funsor.ops.logaddexp, funsor.ops.add, trans, time_var, global_vars + sum_op, prod_op, trans, time_var, global_vars ) ) else: @@ -144,33 +146,14 @@ def compute_markov_factors( prev_to_curr = {k: _shift_name(k, -_get_shift(k)) for k in prev_vars} markov_factors.append( funsor.sum_product.sequential_sum_product( - funsor.ops.logaddexp, funsor.ops.add, trans, time_var, prev_to_curr + sum_op, prod_op, trans, time_var, prev_to_curr ) ) return markov_factors -def log_density(model, model_args, model_kwargs, params): - """ - Similar to :func:`numpyro.infer.util.log_density` but works for models - with discrete latent variables. Internally, this uses :mod:`funsor` - to marginalize discrete latent sites and evaluate the joint log probability. - - :param model: Python callable containing NumPyro primitives. Typically, - the model has been enumerated by using - :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler:: - - def model(*args, **kwargs): - ... - - log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params) - - :param tuple model_args: args provided to the model. - :param dict model_kwargs: kwargs provided to the model. - :param dict params: dictionary of current parameter values keyed by site - name. - :return: log of joint density and a corresponding model trace - """ +def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): + """Helper function to compute elbo and extract its components from execution traces.""" model = substitute(model, data=params) with plate_to_enum_plate(): model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) @@ -180,6 +163,7 @@ def model(*args, **kwargs): time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites sum_vars, prod_vars = frozenset(), frozenset() history = 1 + log_measures = {} for site in model_trace.values(): if site["type"] == "sample": value = site["value"] @@ -214,7 +198,9 @@ def model(*args, **kwargs): log_factors.append(log_prob_factor) if not site["is_observed"]: + log_measures[site["name"]] = log_prob_factor sum_vars |= frozenset({site["name"]}) + prod_vars |= frozenset( f.name for f in site["cond_indep_stack"] if f.dim is not None ) @@ -236,13 +222,15 @@ def model(*args, **kwargs): sum_vars, prod_vars, history, + sum_op, + prod_op, ) log_factors = log_factors + markov_factors with funsor.interpretations.lazy: lazy_result = funsor.sum_product.sum_product( - funsor.ops.logaddexp, - funsor.ops.add, + sum_op, + prod_op, log_factors, eliminate=sum_vars | prod_vars, plates=prod_vars, @@ -255,4 +243,31 @@ def model(*args, **kwargs): result.data.shape, {k.split("__BOUND")[0] for k in result.inputs} ) ) + return result, model_trace, log_measures + + +def log_density(model, model_args, model_kwargs, params): + """ + Similar to :func:`numpyro.infer.util.log_density` but works for models + with discrete latent variables. Internally, this uses :mod:`funsor` + to marginalize discrete latent sites and evaluate the joint log probability. + + :param model: Python callable containing NumPyro primitives. Typically, + the model has been enumerated by using + :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler:: + + def model(*args, **kwargs): + ... + + log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params) + + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. + :param dict params: dictionary of current parameter values keyed by site + name. + :return: log of joint density and a corresponding model trace + """ + result, model_trace, _ = _enum_log_density( + model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add + ) return result.data, model_trace diff --git a/test/contrib/test_infer_discrete.py b/test/contrib/test_infer_discrete.py index 78b24d15b..cf914baed 100644 --- a/test/contrib/test_infer_discrete.py +++ b/test/contrib/test_infer_discrete.py @@ -12,6 +12,7 @@ import numpyro from numpyro import handlers, infer +from numpyro.contrib.control_flow import scan import numpyro.distributions as dist from numpyro.distributions.util import is_identically_one @@ -81,6 +82,38 @@ def hmm(data, hidden_dim=10): logger.info("inferred states: {}".format(list(map(int, inferred_states)))) +@pytest.mark.parametrize("length", [1, 2, 10]) +@pytest.mark.parametrize("temperature", [0, 1]) +def test_scan_hmm_smoke(length, temperature): + + # This should match the example in the infer_discrete docstring. + def hmm(data, hidden_dim=10): + transition = 0.3 / hidden_dim + 0.7 * jnp.eye(hidden_dim) + means = jnp.arange(float(hidden_dim)) + + def transition_fn(state, y): + state = numpyro.sample("states", dist.Categorical(transition[state])) + y = numpyro.sample("obs", dist.Normal(means[state], 1.0), obs=y) + return state, (state, y) + + _, (states, data) = scan(transition_fn, 0, data, length=length) + + return [0] + [s for s in states], data + + true_states, data = handlers.seed(hmm, 0)(None) + assert len(data) == length + assert len(true_states) == 1 + len(data) + + decoder = infer_discrete( + config_enumerate(hmm), temperature=temperature, rng_key=random.PRNGKey(1) + ) + inferred_states, _ = decoder(data) + assert len(inferred_states) == len(true_states) + + logger.info("true states: {}".format(list(map(int, true_states)))) + logger.info("inferred states: {}".format(list(map(int, inferred_states)))) + + def vectorize_model(model, size, dim): def fn(*args, **kwargs): with numpyro.plate("particles", size=size, dim=dim): From c9311db1d9508192d6bd372e23d678804e422fb0 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 3 Apr 2021 00:45:00 -0500 Subject: [PATCH 2/9] fix substituting errors --- numpyro/contrib/funsor/discrete.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py index 8ab2cfa3b..72767ff84 100644 --- a/numpyro/contrib/funsor/discrete.py +++ b/numpyro/contrib/funsor/discrete.py @@ -1,14 +1,15 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict +from collections import OrderedDict, defaultdict import functools from jax import random +import jax.numpy as jnp import funsor from numpyro.contrib.funsor.enum_messenger import enum -from numpyro.contrib.funsor.infer_util import _enum_log_density +from numpyro.contrib.funsor.infer_util import _enum_log_density, _get_shift, _shift_name from numpyro.handlers import block, seed, substitute, trace from numpyro.infer.util import _guess_max_plate_nesting @@ -95,7 +96,28 @@ def _sample_posterior( for name, site in sample_tr.items() if site["type"] == "sample" } - print(data) + + # concatenate _PREV_foo to foo + time_vars = defaultdict(list) + for name in data: + if name.startswith("_PREV_"): + root_name = _shift_name(name, -_get_shift(name)) + time_vars[root_name].append(name) + for name in time_vars: + if name in data: + time_vars[name].append(name) + time_vars[name] = sorted(time_vars[name], key=len, reverse=True) + + for root_name, vars in time_vars.items(): + prototype_shape = model_trace[root_name]["value"].shape + values = [data.pop(name) for name in vars] + if len(values) == 1: + data[root_name] = values[0].reshape(prototype_shape) + else: + assert len(prototype_shape) >= 1 + values = [v.reshape((-1,) + prototype_shape[1:]) for v in values] + data[root_name] = jnp.concatenate(values) + with substitute(data=data): return model(*args, **kwargs) From 72e152330d752dbb72b8e0491da688a570fa042f Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 27 May 2021 16:52:33 -0500 Subject: [PATCH 3/9] add todo infer_discrete for annotation example --- examples/annotation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/annotation.py b/examples/annotation.py index da35e89ed..2b9819c2d 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -46,6 +46,7 @@ import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS from numpyro.infer.reparam import LocScaleReparam +# TODO: use infer_discrete def get_data(): From c38c2c8c0513e88bc54684972f75abaa8acd9601 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 31 May 2021 11:55:47 -0500 Subject: [PATCH 4/9] add infer discrete to annotation --- examples/annotation.py | 22 ++++++++++++++++++---- numpyro/handlers.py | 15 ++++++++++----- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index 2b9819c2d..7994efa20 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -37,16 +37,17 @@ import numpy as np -from jax import nn, random +from jax import nn, random, vmap import jax.numpy as jnp import numpyro from numpyro import handlers +from numpyro.contrib.funsor import config_enumerate, infer_discrete from numpyro.contrib.indexing import Vindex +from numpyro.diagnostics import print_summary import numpyro.distributions as dist -from numpyro.infer import MCMC, NUTS +from numpyro.infer import MCMC, NUTS, Predictive from numpyro.infer.reparam import LocScaleReparam -# TODO: use infer_discrete def get_data(): @@ -311,7 +312,20 @@ def main(args): progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, ) mcmc.run(random.PRNGKey(0), *data) - mcmc.print_summary() + posterior_samples = mcmc.get_samples() + + def infer_discrete_model(rng_key, samples): + rng_key, subkey = random.split(rng_key) + conditioned_model = handlers.condition(model, data=samples) + infer_discrete_model = infer_discrete(config_enumerate(conditioned_model), + rng_key=subkey) + predictive = Predictive(infer_discrete_model, num_samples=1, batch_ndims=0) + return predictive(rng_key, *data) + + discrete_samples = vmap(infer_discrete_model)( + random.split(random.PRNGKey(1), args.num_samples), posterior_samples) + posterior_samples.update(discrete_samples) + print_summary(posterior_samples) if __name__ == "__main__": diff --git a/numpyro/handlers.py b/numpyro/handlers.py index e4d736a27..857f0888e 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -196,14 +196,19 @@ class replay(Messenger): >>> assert replayed_trace['a']['value'] == exec_trace['a']['value'] """ - def __init__(self, fn=None, guide_trace=None): - assert guide_trace is not None - self.guide_trace = guide_trace + def __init__(self, fn=None, trace=None, guide_trace=None): + if guide_trace is not None: + warnings.warn("`guide_trace` argument is deprecated. Please replace it by `trace`.", + FutureWarning) + if guide_trace is not None: + trace = guide_trace + assert trace is not None + self.trace = trace super(replay, self).__init__(fn) def process_message(self, msg): - if msg["type"] in ("sample", "plate") and msg["name"] in self.guide_trace: - msg["value"] = self.guide_trace[msg["name"]]["value"] + if msg["type"] in ("sample", "plate") and msg["name"] in self.trace: + msg["value"] = self.trace[msg["name"]]["value"] class block(Messenger): From 12b1ee6d23d3ace1df8d3897612b21c3c19647ca Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 31 May 2021 13:30:17 -0500 Subject: [PATCH 5/9] print history in annotation example --- examples/annotation.py | 32 ++++++++++++++++++++++++-------- numpyro/handlers.py | 6 ++++-- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index 7994efa20..ab1ba6272 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -312,20 +312,36 @@ def main(args): progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True, ) mcmc.run(random.PRNGKey(0), *data) - posterior_samples = mcmc.get_samples() + mcmc.print_summary() def infer_discrete_model(rng_key, samples): rng_key, subkey = random.split(rng_key) conditioned_model = handlers.condition(model, data=samples) - infer_discrete_model = infer_discrete(config_enumerate(conditioned_model), - rng_key=subkey) - predictive = Predictive(infer_discrete_model, num_samples=1, batch_ndims=0) - return predictive(rng_key, *data) + infer_discrete_model = infer_discrete( + config_enumerate(conditioned_model), rng_key=subkey + ) + with handlers.trace() as tr: + infer_discrete_model(*data) + return { + name: site["value"] + for name, site in tr.items() + if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel" + } + + posterior_samples = mcmc.get_samples() discrete_samples = vmap(infer_discrete_model)( - random.split(random.PRNGKey(1), args.num_samples), posterior_samples) - posterior_samples.update(discrete_samples) - print_summary(posterior_samples) + random.split(random.PRNGKey(1), args.num_samples), posterior_samples + ) + + item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)( + discrete_samples["c"].squeeze(-1) + ) + print("Histogram of the predicted class of each item:") + row_format = "{:>10}" * 5 + print(row_format.format("", *["c={}".format(i) for i in range(4)])) + for i, row in enumerate(item_class): + print(row_format.format(f"item[{i}]", *row)) if __name__ == "__main__": diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 857f0888e..53be5116d 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -198,8 +198,10 @@ class replay(Messenger): def __init__(self, fn=None, trace=None, guide_trace=None): if guide_trace is not None: - warnings.warn("`guide_trace` argument is deprecated. Please replace it by `trace`.", - FutureWarning) + warnings.warn( + "`guide_trace` argument is deprecated. Please replace it by `trace`.", + FutureWarning, + ) if guide_trace is not None: trace = guide_trace assert trace is not None From 2864d3589b8a1ff8c662163d0df1900f86267742 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 31 May 2021 13:38:22 -0500 Subject: [PATCH 6/9] mark xfail for nontrivial scan history --- test/contrib/test_infer_discrete.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/contrib/test_infer_discrete.py b/test/contrib/test_infer_discrete.py index 45d282cba..dd364531e 100644 --- a/test/contrib/test_infer_discrete.py +++ b/test/contrib/test_infer_discrete.py @@ -82,7 +82,19 @@ def hmm(data, hidden_dim=10): logger.info("inferred states: {}".format(list(map(int, inferred_states)))) -@pytest.mark.parametrize("length", [1, 2, 10]) +@pytest.mark.parametrize( + "length", + [ + 1, + 2, + pytest.param( + 10, + marks=pytest.mark.xfail( + reason="adjoint does not work with markov sum product yet." + ), + ), + ], +) @pytest.mark.parametrize("temperature", [0, 1]) def test_scan_hmm_smoke(length, temperature): From 53650b33092dae576162441748640e2ee89cadec Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 31 May 2021 13:40:48 -0500 Subject: [PATCH 7/9] make lint --- examples/annotation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index ab1ba6272..8fac3ccf4 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -44,9 +44,8 @@ from numpyro import handlers from numpyro.contrib.funsor import config_enumerate, infer_discrete from numpyro.contrib.indexing import Vindex -from numpyro.diagnostics import print_summary import numpyro.distributions as dist -from numpyro.infer import MCMC, NUTS, Predictive +from numpyro.infer import MCMC, NUTS from numpyro.infer.reparam import LocScaleReparam From 0dc44b8edd118bc3a56a713c9954bf19bbc324f2 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 31 May 2021 13:52:41 -0500 Subject: [PATCH 8/9] simplify infer discrete code --- examples/annotation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index 8fac3ccf4..6b7ada33e 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -314,10 +314,9 @@ def main(args): mcmc.print_summary() def infer_discrete_model(rng_key, samples): - rng_key, subkey = random.split(rng_key) conditioned_model = handlers.condition(model, data=samples) infer_discrete_model = infer_discrete( - config_enumerate(conditioned_model), rng_key=subkey + config_enumerate(conditioned_model), rng_key=rng_key ) with handlers.trace() as tr: infer_discrete_model(*data) From 02459100198dd892edcf0c1d9463aeab8558bf2d Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 31 May 2021 15:07:44 -0500 Subject: [PATCH 9/9] fix failing test --- test/test_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_handlers.py b/test/test_handlers.py index ccdfe9c66..a58926278 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -383,7 +383,7 @@ def test_subsample_replay(): with numpyro.plate("a", len(data), subsample_size=subsample_size): pass - with handlers.seed(rng_seed=1), handlers.replay(guide_trace=guide_trace): + with handlers.seed(rng_seed=1), handlers.replay(trace=guide_trace): with numpyro.plate("a", len(data)): subsample_data = numpyro.subsample(data, event_dim=0) assert subsample_data.shape == (subsample_size,)