From b45f727ae841b435f73520283229b7c048d90008 Mon Sep 17 00:00:00 2001 From: Jordan Deklerk <111652310+jordandeklerk@users.noreply.github.com> Date: Sat, 22 Nov 2025 14:04:58 -0500 Subject: [PATCH 1/3] feat: add calculate_khat_bin_edges function for Pareto k diagnostics --- src/arviz_stats/base/stats_utils.py | 46 ++++++++++++++++++++++++++++- tests/base/test_stats_utils.py | 22 +++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/arviz_stats/base/stats_utils.py b/src/arviz_stats/base/stats_utils.py index f3421af5..704d31dd 100644 --- a/src/arviz_stats/base/stats_utils.py +++ b/src/arviz_stats/base/stats_utils.py @@ -5,7 +5,7 @@ import numpy as np -__all__ = ["make_ufunc"] +__all__ = ["make_ufunc", "calculate_khat_bin_edges"] _log = logging.getLogger(__name__) @@ -298,3 +298,47 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwar _log.info(error_msg) return nan_error | chain_error | draw_error + + +def calculate_khat_bin_edges(ary, thresholds, tolerance=1e-9): + """Calculate edges for Pareto k diagnostic bins. + + Parameters + ---------- + ary : array_like + Pareto k values to bin + thresholds : sequence of float + Diagnostic threshold values to use as potential bin edges (e.g., [0.7, 1.0]) + tolerance : float, default 1e-9 + Numerical tolerance for edge comparisons to avoid duplicate edges + + Returns + ------- + bin_edges : list of float or None + Calculated bin edges suitable for np.histogram, or None if edges cannot + be computed. + """ + if not ary.size: + return None + + ymin = np.nanmin(ary) + ymax = np.nanmax(ary) + + if not (np.isfinite(ymin) and np.isfinite(ymax)): + return None + + bin_edges = [ymin] + + for edge in thresholds: + if ( + edge is not None + and np.isfinite(edge) + and bin_edges[-1] + tolerance < edge < ymax - tolerance + ): + bin_edges.append(edge) + + if ymax > bin_edges[-1] + tolerance: + bin_edges.append(ymax) + else: + bin_edges[-1] = ymax + return bin_edges if len(bin_edges) > 1 else None diff --git a/tests/base/test_stats_utils.py b/tests/base/test_stats_utils.py index e48f53b4..c79da5a1 100644 --- a/tests/base/test_stats_utils.py +++ b/tests/base/test_stats_utils.py @@ -7,8 +7,8 @@ from numpy.testing import assert_array_almost_equal from scipy.special import logsumexp +from arviz_stats.base.stats_utils import calculate_khat_bin_edges, make_ufunc, not_valid from arviz_stats.base.stats_utils import logsumexp as _logsumexp -from arviz_stats.base.stats_utils import make_ufunc, not_valid @pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64]) @@ -488,3 +488,23 @@ def test_logsumexp_loo_bounds(rng): assert np.all(result <= max_vals) assert np.all(result >= mean_vals - 5) + + +def test_calculate_khat_bin_edges(): + values = np.array([0.3, 0.5, 0.8, 0.9, 1.2, 1.5]) + thresholds = [0.7, 1.0] + bin_edges = calculate_khat_bin_edges(values, thresholds) + + assert bin_edges is not None + assert len(bin_edges) == 4 + assert bin_edges[0] == 0.3 + assert bin_edges[1] == 0.7 + assert bin_edges[2] == 1.0 + assert bin_edges[3] == 1.5 + + assert calculate_khat_bin_edges(np.array([]), thresholds) is None + + nan_values = np.array([0.3, 0.5, np.nan, 0.9]) + bin_edges_nan = calculate_khat_bin_edges(nan_values, thresholds) + assert bin_edges_nan is not None + assert len(bin_edges_nan) == 3 From b7211b0fcf3c2ccd070479bbdb89b68172208523 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Mon, 24 Nov 2025 09:18:15 +0200 Subject: [PATCH 2/3] Add newline at end of stats_utils.py --- src/arviz_stats/base/stats_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/arviz_stats/base/stats_utils.py b/src/arviz_stats/base/stats_utils.py index c23368a3..db35c265 100644 --- a/src/arviz_stats/base/stats_utils.py +++ b/src/arviz_stats/base/stats_utils.py @@ -374,4 +374,5 @@ def round_num(value, precision): return round(value, sig_digits - int(np.floor(np.log10(abs(value)))) - 1) return value - \ No newline at end of file + + From e8d79f59e2f06fd58afffb5d8a720e4ec9c446a4 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 24 Nov 2025 09:41:35 +0200 Subject: [PATCH 3/3] lint --- src/arviz_stats/base/stats_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/arviz_stats/base/stats_utils.py b/src/arviz_stats/base/stats_utils.py index db35c265..8efb109e 100644 --- a/src/arviz_stats/base/stats_utils.py +++ b/src/arviz_stats/base/stats_utils.py @@ -374,5 +374,3 @@ def round_num(value, precision): return round(value, sig_digits - int(np.floor(np.log10(abs(value)))) - 1) return value - -