54 commits
269dd75
dist and rv init commit
ColtAllen Mar 29, 2025
b264161
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Apr 11, 2025
d734c68
docstrings
ColtAllen Apr 15, 2025
71bd632
Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
ColtAllen Apr 15, 2025
48e93f3
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Apr 15, 2025
93c4a60
unit tests
ColtAllen Apr 20, 2025
d2e72b5
alpha min value
ColtAllen Apr 20, 2025
8685005
revert alpha lim
ColtAllen Apr 21, 2025
026f182
small lam value tests
ColtAllen Apr 22, 2025
d12dd0b
ruff formatting
ColtAllen Apr 22, 2025
bcd9cac
TODOs
ColtAllen Apr 22, 2025
78be107
WIP add covar support to RV
ColtAllen Apr 22, 2025
f3ae359
Merge branch 'main' into grassia2geo-dist
ColtAllen Jun 20, 2025
8a30459
WIP time indexing
ColtAllen Jun 20, 2025
7c7afc8
WIP time indexing
ColtAllen Jun 20, 2025
fa9c1ec
Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
ColtAllen Jun 20, 2025
b957333
WIP symbolic indexing
ColtAllen Jun 20, 2025
d0c1d98
delete test_simple.py
ColtAllen Jun 20, 2025
264c55e
fix symbolic indexing errors
ColtAllen Jul 11, 2025
05e7c55
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Jul 11, 2025
0fa3390
clean up cursor code
ColtAllen Jul 11, 2025
5baa6f7
warn for ndims deprecation
ColtAllen Jul 11, 2025
a715ec7
clean up comments and final TODO
ColtAllen Jul 11, 2025
f3c0f29
remove ndims deprecation and extraneous code
ColtAllen Jul 11, 2025
a232e4c
revert changes to irrelevant test
ColtAllen Jul 12, 2025
ffc059f
remove time_covariate_vector default args
ColtAllen Jul 12, 2025
1d41eb7
revert remaining changes in irrelevant tests
ColtAllen Jul 12, 2025
47ad523
remove test_sampling_consistency
ColtAllen Jul 12, 2025
5b77263
checkpoint commit for log_cdf and test frameworks
ColtAllen Jul 12, 2025
eb7222f
checkpoint commit for log_cdf and test frameworks
ColtAllen Jul 12, 2025
b34e3d8
make C_t external function, code cleanup
ColtAllen Jul 12, 2025
9803321
rng_fn cleanup
ColtAllen Jul 13, 2025
5ff6853
WIP test frameworks
ColtAllen Jul 13, 2025
63a0b10
inverse cdf
ColtAllen Jul 15, 2025
932a046
covariate pos constraint and WIP RV
ColtAllen Jul 15, 2025
b78a5c4
Merge branch 'pymc-devs:main' into grassia2geo-dist
ColtAllen Jul 28, 2025
ab45a9c
WIP rng_fn testing
ColtAllen Jul 28, 2025
0d1dcea
WIP time covars required param
ColtAllen Jul 29, 2025
434e5a5
C_t for RV time covar support
ColtAllen Aug 10, 2025
c66c8a6
time_covar optional param
ColtAllen Aug 10, 2025
fb96220
restore GPT5 code
ColtAllen Aug 13, 2025
a0ed4f5
Merge branch 'main' into grassia2geo-dist
ColtAllen Aug 30, 2025
c9f5dc2
init commit, WIP testing
ColtAllen Aug 30, 2025
217c521
TODOs and WIP logp testing
ColtAllen Aug 30, 2025
d39ec3e
WIP recursive logp
ColtAllen Aug 31, 2025
cd7815a
revert to beta logp
ColtAllen Sep 3, 2025
4f56c2a
remove logcdf
ColtAllen Sep 3, 2025
3fd1cc7
docstrings
ColtAllen Sep 3, 2025
11a8cac
docstrings plot code
ColtAllen Sep 3, 2025
d5a98c6
WIP test log_p
ColtAllen Sep 4, 2025
e61b702
Merge branch 'pymc-devs:main' into sbg-dist
ColtAllen Sep 4, 2025
e3743a3
cleanup tests
ColtAllen Sep 4, 2025
3d44733
add logcdf
ColtAllen Sep 9, 2025
12d2c36
SymbolicRandomVariable
ColtAllen Sep 9, 2025
3 changes: 3 additions & 0 deletions pymc_extras/distributions/__init__.py
@@ -21,6 +21,7 @@
from pymc_extras.distributions.discrete import (
BetaNegativeBinomial,
GeneralizedPoisson,
ShiftedBetaGeometric,
Skellam,
)
from pymc_extras.distributions.histogram_utils import histogram_approximation
@@ -39,4 +40,6 @@
"PartialOrder",
"Skellam",
"histogram_approximation",
"ShiftedBetaGeometric",
"PartialOrder",
]
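The export above makes ShiftedBetaGeometric importable from the package namespace. A minimal sanity check, illustrative only and not part of the diff (assumes this branch of pymc-extras is installed):

# Illustrative check that the new symbol is exported and can draw samples
from pymc_extras.distributions import ShiftedBetaGeometric

dist = ShiftedBetaGeometric.dist(alpha=1.0, beta=1.0, size=5)
print(dist.eval())  # five integer draws, each >= 1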
149 changes: 148 additions & 1 deletion pymc_extras/distributions/discrete.py
@@ -16,9 +16,14 @@
import pymc as pm

from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
from pymc.distributions.distribution import Discrete, SymbolicRandomVariable
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
from pymc.pytensorf import normalize_rng_param
from pytensor import tensor as pt
from pytensor.tensor.random.basic import beta as beta_rng
from pytensor.tensor.random.basic import geometric as geometric_rng
from pytensor.tensor.random.utils import normalize_size_param


def log1mexp(x):
@@ -399,3 +404,145 @@ def dist(cls, mu1, mu2, **kwargs):
class_name="Skellam",
**kwargs,
)


class ShiftedBetaGeometricRV(SymbolicRandomVariable):
name = "sbg"
extended_signature = "[rng],[size],(),()->[rng],()"
signature = "(),()->()"
_print_name = ("ShiftedBetaGeometric", "\\operatorname{ShiftedBetaGeometric}")

@classmethod
def rv_op(cls, alpha, beta, *, size=None, rng=None):
alpha = pt.as_tensor(alpha)
beta = pt.as_tensor(beta)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

if rv_size_is_none(size):
size = implicit_size_from_params(alpha, beta, ndims_params=cls.ndims_params)

        next_rng, p = beta_rng(a=alpha, b=beta, size=size, rng=rng).owner.outputs

        # Thread the updated rng into the geometric draw so sampling stays reproducible
        final_rng, draws = geometric_rng(p, size=size, rng=next_rng).owner.outputs
        draws = draws.astype("int64")

        return cls(inputs=[rng, size, alpha, beta], outputs=[final_rng, draws])(
            rng, size, alpha, beta
        )


class ShiftedBetaGeometric(Discrete):
r"""Shifted Beta-Geometric distribution.

    This mixture distribution extends the Geometric distribution by allowing the success
    probability to vary across observations according to a Beta(alpha, beta) distribution.

    Fader and Hardie describe this distribution with the following PMF and survival function in [1]_:

    .. math::
        \mathbb{P}(T=t \mid \alpha, \beta) = \frac{B(\alpha+1, \beta+t-1)}{B(\alpha, \beta)}, \quad t = 1, 2, \ldots \\
        \mathbb{S}(t \mid \alpha, \beta) = \frac{B(\alpha, \beta+t)}{B(\alpha, \beta)}, \quad t = 1, 2, \ldots

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as st
from scipy.special import beta
import arviz as az

plt.style.use('arviz-darkgrid')
t = np.arange(1, 11)
alpha_vals = [.1, .1, 1, 1]
beta_vals = [.5, 1, 1, 4]
for alpha, _beta in zip(alpha_vals, beta_vals):
pmf = beta(alpha + 1, _beta + t - 1) / beta(alpha, _beta)
            plt.plot(t, pmf, '-o', label=r'$\alpha$ = {}, $\beta$ = {}'.format(alpha, _beta))
plt.xlabel('t', fontsize=12)
plt.ylabel('p(t)', fontsize=12)
plt.legend(loc=1)
plt.show()

======== ===============================================
Support :math:`t \in \mathbb{N}_{>0}`
======== ===============================================

Parameters
----------
    alpha : tensor_like of float
        Shape parameter of the mixing Beta distribution (alpha > 0).
    beta : tensor_like of float
        Shape parameter of the mixing Beta distribution (beta > 0).

References
----------
.. [1] Fader, P. S., & Hardie, B. G. (2007). How to project customer retention.
Journal of Interactive Marketing, 21(1), 76-90.
https://faculty.wharton.upenn.edu/wp-content/uploads/2012/04/Fader_hardie_jim_07.pdf
"""

    rv_type = ShiftedBetaGeometricRV
    rv_op = ShiftedBetaGeometricRV.rv_op

@classmethod
def dist(cls, alpha, beta, *args, **kwargs):
alpha = pt.as_tensor_variable(alpha)
beta = pt.as_tensor_variable(beta)

return super().dist([alpha, beta], *args, **kwargs)

def logp(value, alpha, beta):
"""From Expression (5) on p.6 of Fader & Hardie (2007)"""

logp = betaln(alpha + 1, beta + value - 1) - betaln(alpha, beta)

logp = pt.switch(
pt.or_(
pt.lt(value, 1),
pt.or_(
alpha <= 0,
beta <= 0,
),
),
-np.inf,
logp,
)

return check_parameters(
logp,
alpha > 0,
beta > 0,
msg="alpha > 0, beta > 0",
)

def logcdf(value, alpha, beta):
"""Adapted from Expression (6) on p.6 of Fader & Hardie (2007)"""
# survival function from paper
logS = (
pt.gammaln(beta + value)
- pt.gammaln(beta)
+ pt.gammaln(alpha + beta)
- pt.gammaln(alpha + beta + value)
)
# log(1-exp(logS))
return pt.log1mexp(logS)

def support_point(rv, size, alpha, beta):
"""Calculate a reasonable starting point for sampling.

For the Shifted Beta-Geometric distribution, we use a point estimate based on
the expected value of the mixture components.
"""
geo_mean = pt.ceil(
pt.reciprocal(
alpha / (alpha + beta) # expected value of the beta distribution
) # expected value of the geometric distribution
)
if not rv_size_is_none(size):
geo_mean = pt.full(size, geo_mean)
return geo_mean
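The docstring PMF and survival function can be verified numerically against the values used in test_logp_matches_paper below: for alpha = beta = 1 the PMF reduces to B(2, t) / B(1, 1) = 1 / (t (t + 1)). A short sketch using only numpy and scipy, illustrative and not part of the diff:

# Check the docstring PMF: P(T=t) = B(alpha+1, beta+t-1) / B(alpha, beta)
import numpy as np
from scipy.special import beta as beta_fn

alpha, beta = 1.0, 1.0
t = np.arange(1, 8)
pmf = beta_fn(alpha + 1, beta + t - 1) / beta_fn(alpha, beta)
print(np.round(pmf, 3))  # [0.5 0.167 0.083 0.05 0.033 0.024 0.018], the Fader & Hardie (2007) values

# Survival function S(t) = B(alpha, beta+t) / B(alpha, beta); logcdf above returns log(1 - S(t))
surv = beta_fn(alpha, beta + t) / beta_fn(alpha, beta)
print(np.allclose(np.cumsum(pmf), 1.0 - surv))  # True: the CDF and 1 - S(t) agree

# support_point arithmetic for the same parameters:
# E[p] = alpha / (alpha + beta) = 0.5, so the starting value is ceil(1 / 0.5) = 2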
129 changes: 129 additions & 0 deletions tests/distributions/test_discrete.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pymc as pm
import pytensor
@@ -23,16 +24,19 @@
BaseTestDistributionRandom,
Domain,
I,
Nat,
Rplus,
assert_support_point_is_expected,
check_logp,
check_selfconsistency_discrete_logcdf,
discrete_random_tester,
)
from pytensor import config

from pymc_extras.distributions import (
BetaNegativeBinomial,
GeneralizedPoisson,
ShiftedBetaGeometric,
Skellam,
)

@@ -208,3 +212,128 @@ def test_logp(self):
{"mu1": Rplus_small, "mu2": Rplus_small},
lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
)


class TestShiftedBetaGeometric:
class TestRandomVariable(BaseTestDistributionRandom):
pymc_dist = ShiftedBetaGeometric
pymc_dist_params = {"alpha": 1.0, "beta": 1.0}
expected_rv_op_params = {"alpha": 1.0, "beta": 1.0}
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
]

def test_random_basic_properties(self):
"""Test basic random sampling properties"""
# Test with standard parameter values
alpha_vals = [1.0, 0.5, 2.0]
beta_vals = [0.5, 1.0, 2.0]

for alpha in alpha_vals:
for beta in beta_vals:
dist = self.pymc_dist.dist(alpha=alpha, beta=beta, size=1000)
draws = dist.eval()

# Check basic properties
assert np.all(draws > 0)
assert np.all(draws.astype(int) == draws)
assert np.mean(draws) > 0
assert np.var(draws) > 0

def test_random_edge_cases(self):
"""Test with very small and large beta and alpha values"""
beta_vals = [20.0, 0.08, 18.0, 0.06]
alpha_vals = [20.0, 14.0, 0.07, 0.06]

for beta in beta_vals:
for alpha in alpha_vals:
dist = self.pymc_dist.dist(alpha=alpha, beta=beta, size=1000)
draws = dist.eval()

# Check basic properties
assert np.all(draws > 0)
assert np.all(draws.astype(int) == draws)
assert np.mean(draws) > 0
assert np.var(draws) > 0

def test_logp(self):
alpha = pt.scalar("alpha")
beta = pt.scalar("beta")
value = pt.vector(dtype="int64")

# Compile logp function for testing
dist = ShiftedBetaGeometric.dist(alpha, beta)
logp = pm.logp(dist, value)
logp_fn = pytensor.function([value, alpha, beta], logp)

# Test basic properties of logp
test_value = np.array([1, 2, 3, 4, 5])
logp_vals = logp_fn(test_value, 1.2, 3.4)
assert not np.any(np.isnan(logp_vals))
assert np.all(np.isfinite(logp_vals))

neg_value = np.array([-1], dtype="int64")
assert logp_fn(neg_value, alpha=5, beta=1)[0] == -np.inf

# Check alpha/beta restrictions
valid_value = np.array([1], dtype="int64")
with pytest.raises(ParameterValueError):
logp_fn(valid_value, alpha=-1, beta=2) # invalid neg alpha
with pytest.raises(ParameterValueError):
logp_fn(valid_value, alpha=0, beta=0) # invalid zero alpha and beta
with pytest.raises(ParameterValueError):
logp_fn(valid_value, alpha=1, beta=-1) # invalid neg beta

def test_logp_matches_paper(self):
# Reference: Fader & Hardie (2007), Appendix B (Figure B1, cells B6:B12)
# https://faculty.wharton.upenn.edu/wp-content/uploads/2012/04/Fader_hardie_jim_07.pdf
alpha = 1.0
beta = 1.0
t_vec = np.array([1, 2, 3, 4, 5, 6, 7], dtype="int64")
p_t = np.array([0.5, 0.167, 0.083, 0.05, 0.033, 0.024, 0.018], dtype="float64")
expected = np.log(p_t)

alpha_ = pt.scalar("alpha")
beta_ = pt.scalar("beta")
value = pt.vector(dtype="int64")
logp = pm.logp(ShiftedBetaGeometric.dist(alpha_, beta_), value)
logp_fn = pytensor.function([value, alpha_, beta_], logp)
np.testing.assert_allclose(logp_fn(t_vec, alpha, beta), expected, rtol=1e-2)

def test_logcdf(self):
check_selfconsistency_discrete_logcdf(
distribution=ShiftedBetaGeometric,
domain=Nat,
paramdomains={"alpha": Rplus, "beta": Rplus},
)

@pytest.mark.parametrize(
"alpha, beta, size, expected_shape",
[
(1.0, 1.0, None, ()), # Scalar output
([1.0, 2.0], 1.0, None, (2,)), # Vector output from alpha
(1.0, [1.0, 2.0], None, (2,)), # Vector output from beta
([1.0, 2.0], [1.0, 2.0], None, (2,)), # Vector output from alpha and beta
(1.0, [1.0, 2.0], (1, 2), (1, 2)), # Explicit size with scalar alpha and vector beta
],
)
def test_support_point(self, alpha, beta, size, expected_shape):
"""Test that support_point returns reasonable values with correct shapes"""
with pm.Model() as model:
ShiftedBetaGeometric("x", alpha=alpha, beta=beta, size=size)

init_point = model.initial_point()["x"]

# Check shape
assert init_point.shape == expected_shape

# Check values are positive integers
assert np.all(init_point > 0)
assert np.all(init_point.astype(int) == init_point)

# Check values are finite and reasonable
assert np.all(np.isfinite(init_point))
assert np.all(init_point < 1e6) # Should not be extremely large

assert_support_point_is_expected(model, init_point)
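Beyond these unit tests, a typical end-to-end use is to place priors on alpha and beta and condition on observed churn times. A hedged sketch using only the public PyMC API and the distribution added in this diff; the priors and simulated data are illustrative choices, not part of the PR:

# Illustrative model sketch: recover alpha and beta from simulated churn times
import numpy as np
import pymc as pm

from pymc_extras.distributions import ShiftedBetaGeometric

rng = np.random.default_rng(42)
p = rng.beta(1.0, 2.0, size=500)  # heterogeneous per-customer churn probabilities
t_obs = rng.geometric(p)  # observed churn times; support starts at 1

with pm.Model():
    alpha = pm.HalfNormal("alpha", sigma=2.0)
    beta = pm.HalfNormal("beta", sigma=2.0)
    ShiftedBetaGeometric("t", alpha=alpha, beta=beta, observed=t_obs)
    idata = pm.sample()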