|
41 | 41 |
|
42 | 42 | import pymc as pm |
43 | 43 |
|
44 | | -from pymc.aesaraf import floatX, intX |
| 44 | +from pymc.aesaraf import change_rv_size, floatX, intX |
45 | 45 | from pymc.distributions import transforms |
46 | 46 | from pymc.distributions.continuous import ( |
47 | 47 | BoundedContinuous, |
@@ -1199,8 +1199,21 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs): |
1199 | 1199 | isinstance(sd_dist, Variable) |
1200 | 1200 | and sd_dist.owner is not None |
1201 | 1201 | and isinstance(sd_dist.owner.op, RandomVariable) |
| 1202 | + and sd_dist.owner.op.ndim_supp < 2 |
1202 | 1203 | ): |
1203 | | - raise TypeError("sd_dist must be a Distribution variable") |
| 1204 | + raise TypeError("sd_dist must be a scalar or vector distribution variable") |
| 1205 | + |
| 1206 | + # We resize the sd_dist automatically so that it has (size x n) independent draws |
| 1207 | + # which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the random |
| 1208 | + # and logp methods equivalent, as the latter also assumes a unique value for each |
| 1209 | + # diagonal element. |
| 1210 | + # Since `eta` and `n` are forced to be scalars we don't need to worry about |
| 1211 | + # implied batched dimensions for the time being. |
| 1212 | + if sd_dist.owner.op.ndim_supp == 0: |
| 1213 | + sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,)) |
| 1214 | + else: |
| 1215 | + # The support shape must be `n` but we have no way of controlling it |
| 1216 | + sd_dist = change_rv_size(sd_dist, to_tuple(size)) |
1204 | 1217 |
|
1205 | 1218 | # sd_dist is part of the generative graph, but should be completely ignored |
1206 | 1219 | # by the logp graph, since the LKJ logp explicitly includes these terms. |
@@ -1288,7 +1301,9 @@ class LKJCholeskyCov: |
1288 | 1301 | n: int |
1289 | 1302 | Dimension of the covariance matrix (n > 1). |
1290 | 1303 | sd_dist: pm.Distribution |
1291 | | - A distribution for the standard deviations, should have `size=n`. |
| 1304 | + A positive scalar or vector distribution for the standard deviations, created |
| 1305 | + with the `.dist()` API. Should have `shape[-1]=n`. Scalar distributions will be |
| 1306 | + automatically resized to ensure this. |
1292 | 1307 | compute_corr: bool, default=True |
1293 | 1308 | If `True`, returns three values: the Cholesky decomposition, the correlations |
1294 | 1309 | and the standard deviations of the covariance matrix. Otherwise, only returns |
|
0 commit comments