Skip to content

Commit 581456b

Browse files
committed
Allow Dim version of simple SymbolicRandomVariables
1 parent d79fb3c commit 581456b

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

pymc/dims/distributions/scalar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ def _get_sigma_from_either_sigma_or_tau(*, sigma, tau):
4040

4141

4242
class Flat(DimDistribution):
43-
xrv_op = pxr._as_xrv(flat)
43+
xrv_op = pxr.as_xrv(flat)
4444

4545
@classmethod
4646
def dist(cls, **kwargs):
4747
return super().dist([], **kwargs)
4848

4949

5050
class HalfFlat(PositiveDimDistribution):
51-
xrv_op = pxr._as_xrv(halfflat, [], ())
51+
xrv_op = pxr.as_xrv(halfflat, [], ())
5252

5353
@classmethod
5454
def dist(cls, **kwargs):
@@ -102,7 +102,7 @@ def xrv_op(self, nu, sigma, core_dims=None, extra_dims=None, rng=None):
102102
nu = as_xtensor(nu)
103103
sigma = as_xtensor(sigma)
104104
core_rv = HalfStudentTRV.rv_op(nu=nu.values, sigma=sigma.values).owner.op
105-
xop = pxr._as_xrv(core_rv)
105+
xop = pxr.as_xrv(core_rv)
106106
return xop(nu, sigma, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
107107

108108

pymc/distributions/distribution.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,27 @@ def __init__(
370370

371371
kwargs.setdefault("inline", True)
372372
kwargs.setdefault("strict", True)
373+
# Many RVS have a size argument, even when this is `None` and is therefore unused
374+
kwargs.setdefault("on_unused_input", "ignore")
373375
super().__init__(*args, **kwargs)
374376

377+
def make_node(self, *inputs):
378+
# If we try to build the RV with a different size type (vector -> None or None -> vector)
379+
# We need to rebuild the Op with new size type in the inner graph
380+
if self.extended_signature is not None:
381+
(rng_arg_idxs, size_arg_idx, param_idxs), _ = self.get_input_output_type_idxs(
382+
self.extended_signature
383+
)
384+
if size_arg_idx is not None and len(rng_arg_idxs) == 1:
385+
new_size_type = normalize_size_param(inputs[size_arg_idx]).type
386+
if not self.input_types[size_arg_idx].in_same_class(new_size_type):
387+
params = [inputs[idx] for idx in param_idxs]
388+
size = inputs[size_arg_idx]
389+
rng = inputs[rng_arg_idxs[0]]
390+
return self.rebuild_rv(*params, size=size, rng=rng).owner
391+
392+
return super().make_node(*inputs)
393+
375394
def update(self, node: Apply) -> dict[Variable, Variable]:
376395
"""Symbolic update expression for input random state variables.
377396

tests/dims/distributions/test_scalar.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
HalfCauchy,
2323
HalfFlat,
2424
HalfNormal,
25+
HalfStudentT,
2526
InverseGamma,
2627
Laplace,
2728
LogNormal,
@@ -119,6 +120,22 @@ def test_studentt():
119120
assert_equivalent_logp_graph(model, reference_model)
120121

121122

123+
def test_halfstudentt():
124+
coords = {"a": range(3)}
125+
with Model(coords=coords) as model:
126+
HalfStudentT("x", nu=1, dims="a")
127+
HalfStudentT("y", nu=1, sigma=3, dims="a")
128+
HalfStudentT("z", nu=1, lam=3, dims="a")
129+
130+
with Model(coords=coords) as reference_model:
131+
regular_distributions.HalfStudentT("x", nu=1, dims="a")
132+
regular_distributions.HalfStudentT("y", nu=1, sigma=3, dims="a")
133+
regular_distributions.HalfStudentT("z", nu=1, lam=3, dims="a")
134+
135+
assert_equivalent_random_graph(model, reference_model)
136+
assert_equivalent_logp_graph(model, reference_model)
137+
138+
122139
def test_cauchy():
123140
coords = {"a": range(3)}
124141
with Model(coords=coords) as model:

0 commit comments

Comments
 (0)