Skip to content

Commit 4ad6b98

Browse files
authored
Illustrate the usage of infer_discrete in the annotation example (#1043)
* first try to make infer_discrete work for scan
* fix substituting errors
* add todo infer_discrete for annotation example
* add infer discrete to annotation
* print history in annotation example
* mark xfail for nontrivial scan history
* make lint
* simplify infer discrete code
* fix failing test
1 parent 1c44b1e commit 4ad6b98

File tree

6 files changed

+170
-100
lines changed

6 files changed

+170
-100
lines changed

examples/annotation.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,12 @@
3737

3838
import numpy as np
3939

40-
from jax import nn, random
40+
from jax import nn, random, vmap
4141
import jax.numpy as jnp
4242

4343
import numpyro
4444
from numpyro import handlers
45+
from numpyro.contrib.funsor import config_enumerate, infer_discrete
4546
from numpyro.contrib.indexing import Vindex
4647
import numpyro.distributions as dist
4748
from numpyro.infer import MCMC, NUTS
@@ -312,6 +313,34 @@ def main(args):
312313
mcmc.run(random.PRNGKey(0), *data)
313314
mcmc.print_summary()
314315

316+
def infer_discrete_model(rng_key, samples):
    """Run the model once and collect its parallel-enumerated sample sites.

    The model is conditioned on ``samples``, wrapped with
    ``config_enumerate`` and ``infer_discrete`` (seeded with ``rng_key``),
    and executed on ``data`` under a trace handler.

    :param rng_key: random key forwarded to ``infer_discrete``.
    :param samples: values passed as ``data`` to ``handlers.condition``.
    :return: dict mapping site name to value for every traced sample site
        whose ``infer`` config has ``enumerate == "parallel"``.
    """
    sampler = infer_discrete(
        config_enumerate(handlers.condition(model, data=samples)),
        rng_key=rng_key,
    )
    with handlers.trace() as tr:
        sampler(*data)

    discrete_values = {}
    for name, site in tr.items():
        if site["type"] != "sample":
            continue
        if site["infer"].get("enumerate") == "parallel":
            discrete_values[name] = site["value"]
    return discrete_values
329+
330+
posterior_samples = mcmc.get_samples()
331+
discrete_samples = vmap(infer_discrete_model)(
332+
random.split(random.PRNGKey(1), args.num_samples), posterior_samples
333+
)
334+
335+
item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(
336+
discrete_samples["c"].squeeze(-1)
337+
)
338+
print("Histogram of the predicted class of each item:")
339+
row_format = "{:>10}" * 5
340+
print(row_format.format("", *["c={}".format(i) for i in range(4)]))
341+
for i, row in enumerate(item_class):
342+
print(row_format.format(f"item[{i}]", *row))
343+
315344

316345
if __name__ == "__main__":
317346
assert numpyro.__version__.startswith("0.6.0")

numpyro/contrib/funsor/discrete.py

Lines changed: 40 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Copyright Contributors to the Pyro project.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from collections import OrderedDict
4+
from collections import OrderedDict, defaultdict
55
import functools
66

77
from jax import random
8+
import jax.numpy as jnp
89

910
import funsor
10-
from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace
11-
from numpyro.contrib.funsor.infer_util import plate_to_enum_plate
12-
from numpyro.distributions.util import is_identically_one
13-
from numpyro.handlers import block, replay, seed, trace
11+
from numpyro.contrib.funsor.enum_messenger import enum
12+
from numpyro.contrib.funsor.infer_util import _enum_log_density, _get_shift, _shift_name
13+
from numpyro.handlers import block, seed, substitute, trace
1414
from numpyro.infer.util import _guess_max_plate_nesting
1515

1616

@@ -38,46 +38,6 @@ def _get_support_value_delta(funsor_dist, name, **kwargs):
3838
return OrderedDict(funsor_dist.terms)[name][0]
3939

4040

41-
def terms_from_trace(tr):
42-
"""Helper function to extract elbo components from execution traces."""
43-
log_factors = {}
44-
log_measures = {}
45-
sum_vars, prod_vars = frozenset(), frozenset()
46-
for site in tr.values():
47-
if site["type"] == "sample":
48-
value = site["value"]
49-
intermediates = site["intermediates"]
50-
scale = site["scale"]
51-
if intermediates:
52-
log_prob = site["fn"].log_prob(value, intermediates)
53-
else:
54-
log_prob = site["fn"].log_prob(value)
55-
56-
if (scale is not None) and (not is_identically_one(scale)):
57-
log_prob = scale * log_prob
58-
59-
dim_to_name = site["infer"]["dim_to_name"]
60-
log_prob_factor = funsor.to_funsor(
61-
log_prob, output=funsor.Real, dim_to_name=dim_to_name
62-
)
63-
64-
if site["is_observed"]:
65-
log_factors[site["name"]] = log_prob_factor
66-
else:
67-
log_measures[site["name"]] = log_prob_factor
68-
sum_vars |= frozenset({site["name"]})
69-
prod_vars |= frozenset(
70-
f.name for f in site["cond_indep_stack"] if f.dim is not None
71-
)
72-
73-
return {
74-
"log_factors": log_factors,
75-
"log_measures": log_measures,
76-
"measure_vars": sum_vars,
77-
"plate_vars": prod_vars,
78-
}
79-
80-
8141
def _sample_posterior(
8242
model, first_available_dim, temperature, rng_key, *args, **kwargs
8343
):
@@ -97,27 +57,14 @@ def _sample_posterior(
9757
model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
9858
first_available_dim = -_guess_max_plate_nesting(model_trace) - 1
9959

100-
with block(), enum(first_available_dim=first_available_dim):
101-
with plate_to_enum_plate():
102-
model_tr = packed_trace(model).get_trace(*args, **kwargs)
103-
104-
terms = terms_from_trace(model_tr)
105-
# terms["log_factors"] = [log p(x) for each observed or latent sample site x]
106-
# terms["log_measures"] = [log p(z) or other Dice factor
107-
# for each latent sample site z]
108-
109-
with funsor.interpretations.lazy:
110-
log_prob = funsor.sum_product.sum_product(
111-
sum_op,
112-
prod_op,
113-
list(terms["log_factors"].values()) + list(terms["log_measures"].values()),
114-
eliminate=terms["measure_vars"] | terms["plate_vars"],
115-
plates=terms["plate_vars"],
116-
)
117-
log_prob = funsor.optimizer.apply_optimizer(log_prob)
60+
with funsor.adjoint.AdjointTape() as tape:
61+
with block(), enum(first_available_dim=first_available_dim):
62+
log_prob, model_tr, log_measures = _enum_log_density(
63+
model, args, kwargs, {}, sum_op, prod_op
64+
)
11865

11966
with approx:
120-
approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
67+
approx_factors = tape.adjoint(sum_op, prod_op, log_prob)
12168

12269
# construct a result trace to replay against the model
12370
sample_tr = model_tr.copy()
@@ -138,13 +85,40 @@ def _sample_posterior(
13885
value, name_to_dim=node["infer"]["name_to_dim"]
13986
)
14087
else:
141-
log_measure = approx_factors[terms["log_measures"][name]]
88+
log_measure = approx_factors[log_measures[name]]
14289
sample_subs[name] = _get_support_value(log_measure, name)
14390
node["value"] = funsor.to_data(
14491
sample_subs[name], name_to_dim=node["infer"]["name_to_dim"]
14592
)
14693

147-
with replay(guide_trace=sample_tr):
94+
data = {
95+
name: site["value"]
96+
for name, site in sample_tr.items()
97+
if site["type"] == "sample"
98+
}
99+
100+
# concatenate _PREV_foo to foo
101+
time_vars = defaultdict(list)
102+
for name in data:
103+
if name.startswith("_PREV_"):
104+
root_name = _shift_name(name, -_get_shift(name))
105+
time_vars[root_name].append(name)
106+
for name in time_vars:
107+
if name in data:
108+
time_vars[name].append(name)
109+
time_vars[name] = sorted(time_vars[name], key=len, reverse=True)
110+
111+
for root_name, vars in time_vars.items():
112+
prototype_shape = model_trace[root_name]["value"].shape
113+
values = [data.pop(name) for name in vars]
114+
if len(values) == 1:
115+
data[root_name] = values[0].reshape(prototype_shape)
116+
else:
117+
assert len(prototype_shape) >= 1
118+
values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
119+
data[root_name] = jnp.concatenate(values)
120+
121+
with substitute(data=data):
148122
return model(*args, **kwargs)
149123

150124

numpyro/contrib/funsor/infer_util.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def compute_markov_factors(
100100
sum_vars,
101101
prod_vars,
102102
history,
103+
sum_op,
104+
prod_op,
103105
):
104106
"""
105107
:param dict time_to_factors: a map from time variable to the log prob factors.
@@ -119,8 +121,8 @@ def compute_markov_factors(
119121
eliminate_vars = (sum_vars | prod_vars) - time_to_markov_dims[time_var]
120122
with funsor.interpretations.lazy:
121123
lazy_result = funsor.sum_product.sum_product(
122-
funsor.ops.logaddexp,
123-
funsor.ops.add,
124+
sum_op,
125+
prod_op,
124126
log_factors,
125127
eliminate=eliminate_vars,
126128
plates=prod_vars,
@@ -136,41 +138,22 @@ def compute_markov_factors(
136138
)
137139
markov_factors.append(
138140
funsor.sum_product.sarkka_bilmes_product(
139-
funsor.ops.logaddexp, funsor.ops.add, trans, time_var, global_vars
141+
sum_op, prod_op, trans, time_var, global_vars
140142
)
141143
)
142144
else:
143145
# remove `_PREV_` prefix to convert prev to curr
144146
prev_to_curr = {k: _shift_name(k, -_get_shift(k)) for k in prev_vars}
145147
markov_factors.append(
146148
funsor.sum_product.sequential_sum_product(
147-
funsor.ops.logaddexp, funsor.ops.add, trans, time_var, prev_to_curr
149+
sum_op, prod_op, trans, time_var, prev_to_curr
148150
)
149151
)
150152
return markov_factors
151153

152154

153-
def log_density(model, model_args, model_kwargs, params):
154-
"""
155-
Similar to :func:`numpyro.infer.util.log_density` but works for models
156-
with discrete latent variables. Internally, this uses :mod:`funsor`
157-
to marginalize discrete latent sites and evaluate the joint log probability.
158-
159-
:param model: Python callable containing NumPyro primitives. Typically,
160-
the model has been enumerated by using
161-
:class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::
162-
163-
def model(*args, **kwargs):
164-
...
165-
166-
log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)
167-
168-
:param tuple model_args: args provided to the model.
169-
:param dict model_kwargs: kwargs provided to the model.
170-
:param dict params: dictionary of current parameter values keyed by site
171-
name.
172-
:return: log of joint density and a corresponding model trace
173-
"""
155+
def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
156+
"""Helper function to compute elbo and extract its components from execution traces."""
174157
model = substitute(model, data=params)
175158
with plate_to_enum_plate():
176159
model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
@@ -180,6 +163,7 @@ def model(*args, **kwargs):
180163
time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites
181164
sum_vars, prod_vars = frozenset(), frozenset()
182165
history = 1
166+
log_measures = {}
183167
for site in model_trace.values():
184168
if site["type"] == "sample":
185169
value = site["value"]
@@ -214,7 +198,9 @@ def model(*args, **kwargs):
214198
log_factors.append(log_prob_factor)
215199

216200
if not site["is_observed"]:
201+
log_measures[site["name"]] = log_prob_factor
217202
sum_vars |= frozenset({site["name"]})
203+
218204
prod_vars |= frozenset(
219205
f.name for f in site["cond_indep_stack"] if f.dim is not None
220206
)
@@ -236,13 +222,15 @@ def model(*args, **kwargs):
236222
sum_vars,
237223
prod_vars,
238224
history,
225+
sum_op,
226+
prod_op,
239227
)
240228
log_factors = log_factors + markov_factors
241229

242230
with funsor.interpretations.lazy:
243231
lazy_result = funsor.sum_product.sum_product(
244-
funsor.ops.logaddexp,
245-
funsor.ops.add,
232+
sum_op,
233+
prod_op,
246234
log_factors,
247235
eliminate=sum_vars | prod_vars,
248236
plates=prod_vars,
@@ -255,4 +243,31 @@ def model(*args, **kwargs):
255243
result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
256244
)
257245
)
246+
return result, model_trace, log_measures
247+
248+
249+
def log_density(model, model_args, model_kwargs, params):
    """
    Enumeration-aware counterpart of :func:`numpyro.infer.util.log_density`.

    Works for models with discrete latent variables: :mod:`funsor` is used
    internally to marginalize discrete latent sites and evaluate the joint
    log probability.

    :param model: Python callable containing NumPyro primitives. Typically,
        the model has been enumerated by using
        :class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::

            def model(*args, **kwargs):
                ...

            log_joint, model_trace = log_density(
                enum(config_enumerate(model)), args, kwargs, params
            )

    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    # Marginalize with log-add-exp as the sum op and addition as the product op.
    sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
    density, tr, _ = _enum_log_density(
        model, model_args, model_kwargs, params, sum_op, prod_op
    )
    return density.data, tr

numpyro/handlers.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,21 @@ class replay(Messenger):
196196
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
197197
"""
198198

199-
def __init__(self, fn=None, guide_trace=None):
200-
assert guide_trace is not None
201-
self.guide_trace = guide_trace
199+
def __init__(self, fn=None, trace=None, guide_trace=None):
    """Store the trace whose recorded values will be replayed.

    :param fn: callable to wrap; forwarded to the base handler constructor.
    :param trace: trace whose site values are replayed. Required unless
        the deprecated ``guide_trace`` is given.
    :param guide_trace: deprecated alias of ``trace``. When provided, a
        ``FutureWarning`` is emitted and it takes precedence over ``trace``.
    """
    if guide_trace is not None:
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this constructor.
        warnings.warn(
            "`guide_trace` argument is deprecated. Please replace it by `trace`.",
            FutureWarning,
            stacklevel=2,
        )
        trace = guide_trace
    assert trace is not None
    self.trace = trace
    super(replay, self).__init__(fn)
203210

204211
def process_message(self, msg):
205-
if msg["type"] in ("sample", "plate") and msg["name"] in self.guide_trace:
206-
msg["value"] = self.guide_trace[msg["name"]]["value"]
212+
if msg["type"] in ("sample", "plate") and msg["name"] in self.trace:
213+
msg["value"] = self.trace[msg["name"]]["value"]
207214

208215

209216
class block(Messenger):

0 commit comments

Comments
 (0)