Skip to content

Commit 5272093

Browse files
committed
Implement Dim ZeroSumNormal
1 parent e6aec0e commit 5272093

File tree

5 files changed

+121
-7
lines changed

5 files changed

+121
-7
lines changed

pymc/dims/distributions/transforms.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import pytensor.tensor as pt
1415
import pytensor.xtensor as ptx
1516

1617
from pymc.logprob.transforms import Transform
@@ -51,3 +52,44 @@ def log_jac_det(self, value, *inputs):
5152

5253

5354
log_odds_transform = LogOddsTransform()
55+
56+
57+
class ZeroSumTransform(DimTransform):
58+
name = "zerosum"
59+
60+
def __init__(self, dims: tuple[str, ...]):
61+
self.dims = dims
62+
63+
@staticmethod
64+
def extend_dim(array, dim):
65+
n = (array.sizes[dim] + 1).astype("floatX")
66+
sum_vals = array.sum(dim)
67+
norm = sum_vals / (pt.sqrt(n) + n)
68+
fill_val = norm - sum_vals / pt.sqrt(n)
69+
70+
out = ptx.concat([array, fill_val], dim=dim)
71+
return out - norm
72+
73+
@staticmethod
74+
def reduce_dim(array, dim):
75+
n = array.sizes[dim].astype("floatX")
76+
last = array.isel({dim: -1})
77+
78+
sum_vals = -last * pt.sqrt(n)
79+
norm = sum_vals / (pt.sqrt(n) + n)
80+
return array.isel({dim: slice(None, -1)}) + norm
81+
82+
def forward(self, value, *rv_inputs):
83+
for dim in self.dims:
84+
value = self.reduce_dim(value, dim=dim)
85+
return value
86+
87+
def backward(self, value, *rv_inputs):
88+
for dim in self.dims:
89+
value = self.extend_dim(value, dim=dim)
90+
return value
91+
92+
def log_jac_det(self, value, *rv_inputs):
93+
# Use following once broadcast_like is implemented
94+
# as_xtensor(0).broadcast_like(value, exclude=self.dims)`
95+
return value.sum(self.dims) * 0

pymc/dims/distributions/vector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def xrv_op(self, sigma, support_dims, core_dims, extra_dims=None, rng=None):
146146
support_dims = as_xtensor(support_dims, dims=("_",))
147147
support_shape = support_dims.values
148148
core_rv = ZeroSumNormalRV.rv_op(sigma=sigma.values, support_shape=support_shape).owner.op
149-
xop = pxr._as_xrv(
149+
xop = pxr.as_xrv(
150150
core_rv,
151151
core_inps_dims_map=[(), (0,)],
152152
core_out_dims_map=tuple(range(1, len(core_dims) + 1)),

pymc/distributions/multivariate.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,6 +2664,7 @@ def logp(value, alpha, K):
26642664
class ZeroSumNormalRV(SymbolicRandomVariable):
26652665
"""ZeroSumNormal random variable."""
26662666

2667+
name = "ZeroSumNormal"
26672668
_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
26682669

26692670
@classmethod
@@ -2687,12 +2688,12 @@ def rv_op(cls, sigma, support_shape, *, size=None, rng=None):
26872688
zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True)
26882689

26892690
support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
2690-
extended_signature = f"[rng],(),(s),[size]->[rng],({support_str})"
2691-
return ZeroSumNormalRV(
2692-
inputs=[rng, sigma, support_shape, size],
2691+
extended_signature = f"[rng],[size],(),(s)->[rng],({support_str})"
2692+
return cls(
2693+
inputs=[rng, size, sigma, support_shape],
26932694
outputs=[next_rng, zerosum_rv],
26942695
extended_signature=extended_signature,
2695-
)(rng, sigma, support_shape, size)
2696+
)(rng, size, sigma, support_shape)
26962697

26972698

26982699
class ZeroSumNormal(Distribution):
@@ -2828,7 +2829,7 @@ def zerosum_default_transform(op, rv):
28282829

28292830

28302831
@_logprob.register(ZeroSumNormalRV)
2831-
def zerosumnormal_logp(op, values, rng, sigma, support_shape, size, **kwargs):
2832+
def zerosumnormal_logp(op, values, rng, size, sigma, support_shape, **kwargs):
28322833
(value,) = values
28332834
shape = value.shape
28342835
n_zerosum_axes = op.ndim_supp

tests/dims/distributions/test_vector.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pymc.distributions as regular_distributions
2020

2121
from pymc import Model
22-
from pymc.dims import Categorical, MvNormal
22+
from pymc.dims import Categorical, MvNormal, ZeroSumNormal
2323
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
2424

2525

@@ -60,3 +60,21 @@ def test_mvnormal():
6060

6161
assert_equivalent_random_graph(model, reference_model)
6262
assert_equivalent_logp_graph(model, reference_model)
63+
64+
65+
def test_zerosumnormal():
66+
coords = {"a": range(3), "b": range(2)}
67+
with Model(coords=coords) as model:
68+
ZeroSumNormal("x", core_dims=("b",), dims=("a", "b"))
69+
ZeroSumNormal("y", sigma=3, core_dims=("b",), dims=("a", "b"))
70+
ZeroSumNormal("z", core_dims=("a", "b"), dims=("a", "b"))
71+
72+
with Model(coords=coords) as reference_model:
73+
regular_distributions.ZeroSumNormal("x", dims=("a", "b"))
74+
regular_distributions.ZeroSumNormal("y", sigma=3, n_zerosum_axes=1, dims=("a", "b"))
75+
regular_distributions.ZeroSumNormal("z", n_zerosum_axes=2, dims=("a", "b"))
76+
77+
assert_equivalent_random_graph(model, reference_model)
78+
# Logp is correct, but we have join(..., -1) and join(..., 1), that don't get canonicalized to the same
79+
# Should work once https://github.com/pymc-devs/pytensor/issues/1505 is fixed
80+
# assert_equivalent_logp_graph(model, reference_model)

tests/dims/test_model.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,56 @@ def test_complex_model():
172172
tune=200, chains=2, draws=50, compute_convergence_checks=False, progressbar=False
173173
)
174174
pm.sample_posterior_predictive(idata, progressbar=False)
175+
176+
177+
def test_zerosumnormal_model():
178+
coords = {"time": range(5), "item": range(3)}
179+
180+
with pm.Model(coords=coords) as model:
181+
zsn_item = pmd.ZeroSumNormal("zsn_item", core_dims="item", dims=("time", "item"))
182+
zsn_time = pmd.ZeroSumNormal("zsn_time", core_dims="time", dims=("time", "item"))
183+
zsn_item_time = pmd.ZeroSumNormal("zsn_item_time", core_dims=("item", "time"))
184+
assert zsn_item.type.dims == ("time", "item")
185+
assert zsn_time.type.dims == ("time", "item")
186+
assert zsn_item_time.type.dims == ("item", "time")
187+
188+
zsn_item_draw, zsn_time_draw, zsn_item_time_draw = pm.draw(
189+
[zsn_item, zsn_time, zsn_item_time], random_seed=1
190+
)
191+
assert zsn_item_draw.shape == (5, 3)
192+
np.testing.assert_allclose(zsn_item_draw.mean(-1), 0, atol=1e-13)
193+
assert not np.allclose(zsn_item_draw.mean(0), 0, atol=1e-13)
194+
195+
assert zsn_time_draw.shape == (5, 3)
196+
np.testing.assert_allclose(zsn_time_draw.mean(0), 0, atol=1e-13)
197+
assert not np.allclose(zsn_time_draw.mean(-1), 0, atol=1e-13)
198+
199+
assert zsn_item_time_draw.shape == (3, 5)
200+
np.testing.assert_allclose(zsn_item_time_draw.mean(), 0, atol=1e-13)
201+
202+
with pm.Model(coords=coords) as ref_model:
203+
# Check that the ZeroSumNormal can be used in a model
204+
pm.ZeroSumNormal("zsn_item", dims=("time", "item"))
205+
pm.ZeroSumNormal("zsn_time", dims=("item", "time"))
206+
pm.ZeroSumNormal("zsn_item_time", n_zerosum_axes=2, dims=("item", "time"))
207+
208+
# Check initial_point and logp
209+
ip = model.initial_point()
210+
ref_ip = ref_model.initial_point()
211+
assert ip.keys() == ref_ip.keys()
212+
for i, (ip_value, ref_ip_value) in enumerate(zip(ip.values(), ref_ip.values())):
213+
if i == 1:
214+
# zsn_time is actually transposed in the original model
215+
ip_value = ip_value.T
216+
np.testing.assert_allclose(ip_value, ref_ip_value)
217+
218+
logp_fn = model.compile_logp()
219+
ref_logp_fn = ref_model.compile_logp()
220+
np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ref_ip))
221+
222+
# Test a new point
223+
rng = np.random.default_rng(68)
224+
new_ip = ip.copy()
225+
for key in new_ip:
226+
new_ip[key] += rng.uniform(size=new_ip[key].shape)
227+
np.testing.assert_allclose(logp_fn(new_ip), ref_logp_fn(new_ip))

0 commit comments

Comments
 (0)