-
Notifications
You must be signed in to change notification settings - Fork 269
Illustrate the usage of infer_discrete in the annotation example #1043
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d838426
c9311db
394fd8b
72e1523
c38c2c8
c786fb5
12b1ee6
2864d35
53650b3
0dc44b8
0245910
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,16 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from collections import OrderedDict | ||
from collections import OrderedDict, defaultdict | ||
import functools | ||
|
||
from jax import random | ||
import jax.numpy as jnp | ||
|
||
import funsor | ||
from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace | ||
from numpyro.contrib.funsor.infer_util import plate_to_enum_plate | ||
from numpyro.distributions.util import is_identically_one | ||
from numpyro.handlers import block, replay, seed, trace | ||
from numpyro.contrib.funsor.enum_messenger import enum | ||
from numpyro.contrib.funsor.infer_util import _enum_log_density, _get_shift, _shift_name | ||
from numpyro.handlers import block, seed, substitute, trace | ||
from numpyro.infer.util import _guess_max_plate_nesting | ||
|
||
|
||
|
@@ -38,46 +38,6 @@ def _get_support_value_delta(funsor_dist, name, **kwargs): | |
return OrderedDict(funsor_dist.terms)[name][0] | ||
|
||
|
||
def terms_from_trace(tr): | ||
"""Helper function to extract elbo components from execution traces.""" | ||
log_factors = {} | ||
log_measures = {} | ||
sum_vars, prod_vars = frozenset(), frozenset() | ||
for site in tr.values(): | ||
if site["type"] == "sample": | ||
value = site["value"] | ||
intermediates = site["intermediates"] | ||
scale = site["scale"] | ||
if intermediates: | ||
log_prob = site["fn"].log_prob(value, intermediates) | ||
else: | ||
log_prob = site["fn"].log_prob(value) | ||
|
||
if (scale is not None) and (not is_identically_one(scale)): | ||
log_prob = scale * log_prob | ||
|
||
dim_to_name = site["infer"]["dim_to_name"] | ||
log_prob_factor = funsor.to_funsor( | ||
log_prob, output=funsor.Real, dim_to_name=dim_to_name | ||
) | ||
|
||
if site["is_observed"]: | ||
log_factors[site["name"]] = log_prob_factor | ||
else: | ||
log_measures[site["name"]] = log_prob_factor | ||
sum_vars |= frozenset({site["name"]}) | ||
prod_vars |= frozenset( | ||
f.name for f in site["cond_indep_stack"] if f.dim is not None | ||
) | ||
|
||
return { | ||
"log_factors": log_factors, | ||
"log_measures": log_measures, | ||
"measure_vars": sum_vars, | ||
"plate_vars": prod_vars, | ||
} | ||
|
||
|
||
def _sample_posterior( | ||
model, first_available_dim, temperature, rng_key, *args, **kwargs | ||
): | ||
|
@@ -97,27 +57,14 @@ def _sample_posterior( | |
model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs) | ||
first_available_dim = -_guess_max_plate_nesting(model_trace) - 1 | ||
|
||
with block(), enum(first_available_dim=first_available_dim): | ||
with plate_to_enum_plate(): | ||
model_tr = packed_trace(model).get_trace(*args, **kwargs) | ||
|
||
terms = terms_from_trace(model_tr) | ||
# terms["log_factors"] = [log p(x) for each observed or latent sample site x] | ||
# terms["log_measures"] = [log p(z) or other Dice factor | ||
# for each latent sample site z] | ||
|
||
with funsor.interpretations.lazy: | ||
log_prob = funsor.sum_product.sum_product( | ||
sum_op, | ||
prod_op, | ||
list(terms["log_factors"].values()) + list(terms["log_measures"].values()), | ||
eliminate=terms["measure_vars"] | terms["plate_vars"], | ||
plates=terms["plate_vars"], | ||
) | ||
log_prob = funsor.optimizer.apply_optimizer(log_prob) | ||
with funsor.adjoint.AdjointTape() as tape: | ||
with block(), enum(first_available_dim=first_available_dim): | ||
log_prob, model_tr, log_measures = _enum_log_density( | ||
model, args, kwargs, {}, sum_op, prod_op | ||
) | ||
|
||
with approx: | ||
approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob) | ||
approx_factors = tape.adjoint(sum_op, prod_op, log_prob) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does using the more functional interface … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry for the late response, @eb8680! Yes, for some reason an error occurs at this line.
|
||
|
||
# construct a result trace to replay against the model | ||
sample_tr = model_tr.copy() | ||
|
@@ -138,13 +85,40 @@ def _sample_posterior( | |
value, name_to_dim=node["infer"]["name_to_dim"] | ||
) | ||
else: | ||
log_measure = approx_factors[terms["log_measures"][name]] | ||
log_measure = approx_factors[log_measures[name]] | ||
sample_subs[name] = _get_support_value(log_measure, name) | ||
node["value"] = funsor.to_data( | ||
sample_subs[name], name_to_dim=node["infer"]["name_to_dim"] | ||
) | ||
|
||
with replay(guide_trace=sample_tr): | ||
data = { | ||
name: site["value"] | ||
for name, site in sample_tr.items() | ||
if site["type"] == "sample" | ||
} | ||
|
||
# concatenate _PREV_foo to foo | ||
time_vars = defaultdict(list) | ||
for name in data: | ||
if name.startswith("_PREV_"): | ||
root_name = _shift_name(name, -_get_shift(name)) | ||
time_vars[root_name].append(name) | ||
for name in time_vars: | ||
if name in data: | ||
time_vars[name].append(name) | ||
time_vars[name] = sorted(time_vars[name], key=len, reverse=True) | ||
|
||
for root_name, vars in time_vars.items(): | ||
prototype_shape = model_trace[root_name]["value"].shape | ||
values = [data.pop(name) for name in vars] | ||
if len(values) == 1: | ||
data[root_name] = values[0].reshape(prototype_shape) | ||
else: | ||
assert len(prototype_shape) >= 1 | ||
values = [v.reshape((-1,) + prototype_shape[1:]) for v in values] | ||
data[root_name] = jnp.concatenate(values) | ||
|
||
with substitute(data=data): | ||
return model(*args, **kwargs) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does using `lazy` here, as in the old version, and `funsor.adjoint.adjoint` below not work? I see we made this change in #991 but can't recall why.