Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,10 +1062,8 @@ def vectorize_over_posterior(
if rv in all_rvs
]:
rv_ancestors = ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs])
if (
rv not in needed_rvs
and not ({*outputs, *independent_rvs} & set(rv_ancestors))
and {var for var in rv_ancestors if var in all_rvs} <= {rv, *needed_rvs}
if rv not in needed_rvs and not (
{*outputs, *needed_rvs, *independent_rvs} & set(rv_ancestors)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outputs / needed_rvs / independent_rvs are also the blockers store than in an intermediate variable?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filter rv not in needed_rvs already in the list comp that starts this loop?

):
independent_rvs.append(rv)
for rv in independent_rvs:
Expand Down
30 changes: 30 additions & 0 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1958,3 +1958,33 @@ def test_vectorize_over_posterior_matches_sample():
atol=0.6 / np.sqrt(10000),
)
assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)


def test_vectorize_over_posterior_with_intermediate_rvs():
with pm.Model() as model:
a = pm.Normal("a")
b = pm.Normal.dist(a)
c = b + 1
d = pm.Normal.dist(c)
idata = pm.sample_prior_predictive(100, var_names=["a"])
idata.add_groups({"posterior": idata.prior})
_, _, vectorized_no_intermediate = vectorize_over_posterior(
outputs=[b, c, d],
posterior=idata.posterior,
input_rvs=[a],
allow_rvs_in_graph=True,
)
[vectorized_intermediate_rvs] = vectorize_over_posterior(
outputs=[d],
posterior=idata.posterior,
input_rvs=[a],
allow_rvs_in_graph=True,
)
assert vectorized_no_intermediate.type.shape == (1, 100)
assert vectorized_no_intermediate.type.shape == vectorized_intermediate_rvs.type.shape
a_ancestor1 = get_var_by_name([vectorized_no_intermediate], "a")[0]
a_ancestor2 = get_var_by_name([vectorized_intermediate_rvs], "a")[0]
Comment on lines +1985 to +1986
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming there should only be one match:

Suggested change
a_ancestor1 = get_var_by_name([vectorized_no_intermediate], "a")[0]
a_ancestor2 = get_var_by_name([vectorized_intermediate_rvs], "a")[0]
[a_ancestor1] = get_var_by_name([vectorized_no_intermediate], "a")
[a_ancestor2] = get_var_by_name([vectorized_intermediate_rvs], "a")

assert isinstance(a_ancestor1, TensorConstant)
assert np.array_equiv(a_ancestor1.eval(), idata.posterior.a.data)
assert isinstance(a_ancestor2, TensorConstant)
assert np.array_equiv(a_ancestor2.eval(), idata.posterior.a.data)
Loading