Skip to content

Commit e3743a3

Browse files
committed
cleanup tests
1 parent e61b702 commit e3743a3

File tree

2 files changed

+24
-80
lines changed

2 files changed

+24
-80
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,9 @@ class ShiftedBetaGeometricRV(RandomVariable):
411411

412412
@classmethod
413413
def rng_fn(cls, rng, alpha, beta, size):
414-
# Determine output size
415414
if size is None:
416415
size = np.broadcast_shapes(alpha.shape, beta.shape)
417416

418-
# Broadcast parameters to output size
419417
alpha = np.broadcast_to(alpha, size)
420418
beta = np.broadcast_to(beta, size)
421419

@@ -432,13 +430,12 @@ def rng_fn(cls, rng, alpha, beta, size):
432430
class ShiftedBetaGeometric(Discrete):
433431
r"""Shifted Beta-Geometric distribution.
434432
435-
This mixture distribution extends the Geometric distribution for the number of trials until a discrete event
436-
to support heterogeneity across observations.
433+
This mixture distribution extends the Geometric distribution to support heterogeneity across observations.
437434
438435
Hardie and Fader describe this distribution with the following PMF and survival functions in [1]_:
439436
440437
.. math::
441-
\mathbb{P}T=t|\alpha,\beta) = (\frac{B(\alpha+1,\beta+t-1)}{B(\alpha,\beta}),t=1,2,... \\
438+
\mathbb{P}(T=t|\alpha,\beta) = (\frac{B(\alpha+1,\beta+t-1)}{B(\alpha,\beta)}),t=1,2,... \\
442439
\begin{align}
443440
\mathbb{S}(t|\alpha,\beta) = (\frac{B(\alpha,\beta+t)}{B(\alpha,\beta)}),t=1,2,... \\
444441
\end{align}
@@ -491,36 +488,16 @@ def dist(cls, alpha, beta, *args, **kwargs):
491488

492489
return super().dist([alpha, beta], *args, **kwargs)
493490

494-
# TODO: Determine if current period cohorts must be excluded and/or if S(t) must be called and added as well.
495491
def logp(value, alpha, beta):
496-
##### RECURSIVE VARIANT PRESERVED UNTIL PR MERGED #####
497-
# # Number of recursive steps: T = 2..t ⇒ n_steps = max(t-1, 0)
498-
# n_steps = pt.maximum(value - 1, 0)
499-
# t_seq = pt.arange(n_steps, dtype="int64") + 2 # [2, 3, ..., t]
500-
501-
# def step(t, acc, alpha, beta):
502-
# term = pt.log(beta + t - 2) - pt.log(alpha + beta + t - 1)
503-
# return acc + term
504-
505-
# (accs, updates) = scan(
506-
# fn=step,
507-
# sequences=[t_seq],
508-
# outputs_info=pt.as_tensor_variable(0.0),
509-
# non_sequences=[alpha, beta],
510-
# )
511-
512-
# sum_increments = pt.switch(pt.gt(n_steps, 0), accs[-1], 0.0)
513-
# logp = pt.log(alpha / (alpha + beta)) + sum_increments
514-
515492
logp = betaln(alpha + 1, beta + value - 1) - betaln(alpha, beta)
516493

517494
logp = pt.switch(
518495
pt.or_(
496+
pt.lt(value, 1),
519497
pt.or_(
520498
alpha <= 0,
521499
beta <= 0,
522500
),
523-
pt.lt(value, 1),
524501
),
525502
-np.inf,
526503
logp,
@@ -538,7 +515,6 @@ def support_point(rv, size, alpha, beta):
538515
539516
For the Shifted Beta-Geometric distribution, we use a point estimate based on
540517
the expected value of the mixture components.
541-
542518
"""
543519
geo_mean = pt.ceil(
544520
pt.reciprocal(

tests/distributions/test_discrete.py

Lines changed: 21 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,9 @@ def test_random_basic_properties(self):
240240
assert np.var(draws) > 0
241241

242242
def test_random_edge_cases(self):
243-
"""Test edge cases with more reasonable parameter values"""
244-
# Test with small beta and large alpha values
245-
beta_vals = [0.1, 0.5]
246-
alpha_vals = [5.0, 10.0]
243+
"""Test with very small and large beta and alpha values"""
244+
beta_vals = [20.0, 0.07, 18.0, 0.05]
245+
alpha_vals = [20.0, 14.0, 0.06, 0.05]
247246

248247
for beta in beta_vals:
249248
for alpha in alpha_vals:
@@ -256,27 +255,10 @@ def test_random_edge_cases(self):
256255
assert np.mean(draws) > 0
257256
assert np.var(draws) > 0
258257

259-
@pytest.mark.parametrize(
260-
"alpha",
261-
[
262-
(0.5, 1.0, 10.0),
263-
],
264-
)
265-
def test_random_moments(self, alpha):
266-
beta = np.array([0.5, 1.0, 10.0])
267-
dist = self.pymc_dist.dist(alpha=alpha, beta=beta, size=(10_000, len(beta)))
268-
draws = dist.eval()
269-
270-
assert np.all(draws > 0)
271-
assert np.all(draws.astype(int) == draws)
272-
assert np.mean(draws) > 0
273-
assert np.var(draws) > 0
274-
275258
def test_logp(self):
276-
# Create PyTensor variables with explicit values to ensure proper initialization
277259
alpha = pt.scalar("alpha")
278260
beta = pt.scalar("beta")
279-
value = pt.vector("value", dtype="int64")
261+
value = pt.vector(dtype="int64")
280262

281263
# Compile logp function for testing
282264
dist = ShiftedBetaGeometric.dist(alpha, beta)
@@ -289,33 +271,33 @@ def test_logp(self):
289271
assert not np.any(np.isnan(logp_vals))
290272
assert np.all(np.isfinite(logp_vals))
291273

292-
assert logp_fn(-1, alpha=5, beta=1) == -np.inf
274+
neg_value = np.array([-1], dtype="int64")
275+
assert logp_fn(neg_value, alpha=5, beta=1)[0] == -np.inf
293276

294277
# Check alpha/beta restrictions
278+
valid_value = np.array([1], dtype="int64")
295279
with pytest.raises(ParameterValueError):
296-
logp_fn(1, alpha=-1, beta=2)
280+
logp_fn(valid_value, alpha=-1, beta=2) # invalid neg alpha
297281
with pytest.raises(ParameterValueError):
298-
logp_fn(1, alpha=0, beta=0)
282+
logp_fn(valid_value, alpha=0, beta=0) # invalid zero alpha and beta
299283
with pytest.raises(ParameterValueError):
300-
logp_fn(1, alpha=1, beta=-1)
284+
logp_fn(valid_value, alpha=1, beta=-1) # invalid neg beta
301285

302-
def test_logp_matches_paper_alpha1_beta1(self):
303-
# For alpha=1, beta=1, P(T=t) = 1 / (t (t+1)) → logp = -log(t(t+1))
304-
# Derived from B(2, t) / B(1,1); see Appendix B (B3)
305-
# Reference: Fader & Hardie (2007), Appendix B [Figure B1, expression (B3)]
286+
def test_logp_matches_paper(self):
287+
# Reference: Fader & Hardie (2007), Appendix B (Figure B1, cells B6:B12)
306288
# https://faculty.wharton.upenn.edu/wp-content/uploads/2012/04/Fader_hardie_jim_07.pdf
307289
alpha = 1.0
308290
beta = 1.0
309-
t_vec = np.array([1, 2, 3, 4, 5], dtype="int64")
310-
expected = -np.log(t_vec * (t_vec + 1))
291+
t_vec = np.array([1, 2, 3, 4, 5, 6, 7], dtype="int64")
292+
p_t = np.array([0.5, 0.167, 0.083, 0.05, 0.033, 0.024, 0.018], dtype="float64")
293+
expected = np.log(p_t)
311294

312-
alpha_sym = pt.scalar()
313-
beta_sym = pt.scalar()
295+
alpha_ = pt.scalar("alpha")
296+
beta_ = pt.scalar("beta")
314297
value = pt.vector(dtype="int64")
315-
dist = ShiftedBetaGeometric.dist(alpha_sym, beta_sym)
316-
logp = pm.logp(dist, value)
317-
fn = pytensor.function([value, alpha_sym, beta_sym], logp)
318-
np.testing.assert_allclose(fn(t_vec, alpha, beta), expected, rtol=1e-12, atol=1e-12)
298+
logp = pm.logp(ShiftedBetaGeometric.dist(alpha_, beta_), value)
299+
logp_fn = pytensor.function([value, alpha_, beta_], logp)
300+
np.testing.assert_allclose(logp_fn(t_vec, alpha, beta), expected, rtol=1e-2)
319301

320302
@pytest.mark.parametrize(
321303
"alpha, beta, size, expected_shape",
@@ -339,24 +321,10 @@ def test_support_point(self, alpha, beta, size, expected_shape):
339321

340322
# Check values are positive integers
341323
assert np.all(init_point > 0)
342-
# assert np.all(init_point.astype(int) == init_point)
324+
assert np.all(init_point.astype(int) == init_point)
343325

344326
# Check values are finite and reasonable
345327
assert np.all(np.isfinite(init_point))
346328
assert np.all(init_point < 1e6) # Should not be extremely large
347329

348-
# TODO: expected values must be provided
349330
assert_support_point_is_expected(model, init_point)
350-
351-
# TODO: Adapt this to ShiftedBetaGeometric and delete above test?
352-
@pytest.mark.parametrize(
353-
"mu, lam, size, expected",
354-
[
355-
(50, [-0.6, 0, 0.6], None, np.floor(50 / (1 - np.array([-0.6, 0, 0.6])))),
356-
([5, 50], -0.1, (4, 2), np.full((4, 2), np.floor(np.array([5, 50]) / 1.1))),
357-
],
358-
)
359-
def test_moment(self, mu, lam, size, expected):
360-
with pm.Model() as model:
361-
GeneralizedPoisson("x", mu=mu, lam=lam, size=size)
362-
assert_support_point_is_expected(model, expected)

0 commit comments

Comments
 (0)