Skip to content

Commit 329ebe7

Browse files
committed
rename approximator.summaries to summarize with deprecation
1 parent 212af40 commit 329ebe7

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
import keras
6+
import warnings
67

78
from bayesflow.adapters import Adapter
89
from bayesflow.networks import InferenceNetwork, SummaryNetwork
@@ -539,7 +540,7 @@ def _sample(
539540
batch_shape, conditions=inference_conditions, **filter_kwargs(kwargs, self.inference_network.sample)
540541
)
541542

542-
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
543+
def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
543544
"""
544545
Computes the learned summary statistics of given summary variables.
545546
@@ -570,6 +571,14 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
570571

571572
return summaries
572573

574+
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
575+
"""
576+
.. deprecated:: 2.0.4
577+
`summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead.
578+
"""
579+
warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning)
580+
return self.summarize(data=data, **kwargs)
581+
573582
def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
574583
"""
575584
Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import keras
44
import numpy as np
5+
import warnings
56

67
from bayesflow.adapters import Adapter
78
from bayesflow.datasets import OnlineDataset
@@ -404,7 +405,7 @@ def predict(
404405

405406
return keras.ops.convert_to_numpy(keras.ops.softmax(output) if probs else output)
406407

407-
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
408+
def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
408409
"""
409410
Computes the learned summary statistics of given summary variables.
410411
@@ -435,6 +436,14 @@ def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
435436

436437
return summaries
437438

439+
def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
440+
"""
441+
.. deprecated:: 2.0.4
442+
`summaries` will be removed in version 2.0.5, it was renamed to `summarize` which should be used instead.
443+
"""
444+
warnings.warn("`summaries` was renamed to `summarize` and will be removed in version 2.0.5.", FutureWarning)
445+
return self.summarize(data=data, **kwargs)
446+
438447
def _compute_logits(self, classifier_conditions: Tensor) -> Tensor:
439448
"""Helper to compute projected logits from the classifier network."""
440449
logits = self.classifier_network(classifier_conditions)

bayesflow/diagnostics/metrics/model_misspecification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ def summary_space_comparison(
142142
"statistics, or want to compare raw data and not summary statistics, please use the "
143143
f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays."
144144
)
145-
observed_summaries = convert_to_numpy(approximator.summaries(observed_data))
146-
reference_summaries = convert_to_numpy(approximator.summaries(reference_data))
145+
observed_summaries = convert_to_numpy(approximator.summarize(observed_data))
146+
reference_summaries = convert_to_numpy(approximator.summarize(reference_data))
147147

148148
distance_observed, distance_null = bootstrap_comparison(
149149
observed_samples=observed_summaries,

tests/test_approximators/test_summaries.py renamed to tests/test_approximators/test_summarize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55

66
def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch):
77
monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
8-
summaries = approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
8+
summaries = approximator_with_summaries.summarize({"summary_variables": keras.ops.ones((2, 3))})
99
assert_allclose(summaries, keras.ops.stack([keras.ops.ones((2,)), keras.ops.zeros((2,))], axis=-1))
1010

1111

1212
def test_no_summary_network(approximator_with_summaries, monkeypatch):
1313
monkeypatch.setattr(approximator_with_summaries, "summary_network", None)
1414

1515
with pytest.raises(ValueError):
16-
approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
16+
approximator_with_summaries.summarize({"summary_variables": keras.ops.ones((2, 3))})
1717

1818

1919
def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch):
2020
monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
2121

2222
with pytest.raises(ValueError):
23-
approximator_with_summaries.summaries({})
23+
approximator_with_summaries.summarize({})

0 commit comments

Comments
 (0)