Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/arviz_stats/base/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

__all__ = ["make_ufunc"]
__all__ = ["make_ufunc", "calculate_khat_bin_edges"]

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -300,6 +300,50 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwar
return nan_error | chain_error | draw_error


def calculate_khat_bin_edges(ary, thresholds, tolerance=1e-9):
"""Calculate edges for Pareto k diagnostic bins.

Parameters
----------
ary : array_like
Pareto k values to bin
thresholds : sequence of float
Diagnostic threshold values to use as potential bin edges (e.g., [0.7, 1.0])
tolerance : float, default 1e-9
Numerical tolerance for edge comparisons to avoid duplicate edges

Returns
-------
bin_edges : list of float or None
Calculated bin edges suitable for np.histogram, or None if edges cannot
be computed.
"""
if not ary.size:
return None

ymin = np.nanmin(ary)
ymax = np.nanmax(ary)

if not (np.isfinite(ymin) and np.isfinite(ymax)):
return None

bin_edges = [ymin]

for edge in thresholds:
if (
edge is not None
and np.isfinite(edge)
and bin_edges[-1] + tolerance < edge < ymax - tolerance
):
bin_edges.append(edge)

if ymax > bin_edges[-1] + tolerance:
bin_edges.append(ymax)
else:
bin_edges[-1] = ymax
return bin_edges if len(bin_edges) > 1 else None


def round_num(value, precision):
"""Round a number to a given precision.

Expand Down
22 changes: 21 additions & 1 deletion tests/base/test_stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from numpy.testing import assert_array_almost_equal
from scipy.special import logsumexp

from arviz_stats.base.stats_utils import calculate_khat_bin_edges, make_ufunc, not_valid
from arviz_stats.base.stats_utils import logsumexp as _logsumexp
from arviz_stats.base.stats_utils import make_ufunc, not_valid


@pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64])
Expand Down Expand Up @@ -488,3 +488,23 @@ def test_logsumexp_loo_bounds(rng):

assert np.all(result <= max_vals)
assert np.all(result >= mean_vals - 5)


def test_calculate_khat_bin_edges():
values = np.array([0.3, 0.5, 0.8, 0.9, 1.2, 1.5])
thresholds = [0.7, 1.0]
bin_edges = calculate_khat_bin_edges(values, thresholds)

assert bin_edges is not None
assert len(bin_edges) == 4
assert bin_edges[0] == 0.3
assert bin_edges[1] == 0.7
assert bin_edges[2] == 1.0
assert bin_edges[3] == 1.5

assert calculate_khat_bin_edges(np.array([]), thresholds) is None

nan_values = np.array([0.3, 0.5, np.nan, 0.9])
bin_edges_nan = calculate_khat_bin_edges(nan_values, thresholds)
assert bin_edges_nan is not None
assert len(bin_edges_nan) == 3