From 682dd1068ef3405539d8bf54452319639a63cf71 Mon Sep 17 00:00:00 2001 From: elseml <60779710+elseml@users.noreply.github.com> Date: Mon, 16 Jun 2025 10:31:51 +0200 Subject: [PATCH 1/3] Move docstring to comment --- bayesflow/adapters/transforms/nnpe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/adapters/transforms/nnpe.py b/bayesflow/adapters/transforms/nnpe.py index 60ee9dcf0..36c8227cb 100644 --- a/bayesflow/adapters/transforms/nnpe.py +++ b/bayesflow/adapters/transforms/nnpe.py @@ -173,7 +173,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd return data + noise def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: - """Non-invertible transform.""" + # Non-invertible transform. return data def get_config(self) -> dict: From 61865acf8d68d4fcefdc16cbd0a01cc59e86eb7f Mon Sep 17 00:00:00 2001 From: elseml <60779710+elseml@users.noreply.github.com> Date: Mon, 16 Jun 2025 11:25:01 +0200 Subject: [PATCH 2/3] Always cast to _resolve_scale --- bayesflow/adapters/transforms/nnpe.py | 36 +++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/bayesflow/adapters/transforms/nnpe.py b/bayesflow/adapters/transforms/nnpe.py index 36c8227cb..91b92a945 100644 --- a/bayesflow/adapters/transforms/nnpe.py +++ b/bayesflow/adapters/transforms/nnpe.py @@ -65,8 +65,8 @@ class NNPE(ElementwiseTransform): def __init__( self, *, - spike_scale: float | np.ndarray | None = None, - slab_scale: float | np.ndarray | None = None, + spike_scale: np.typing.ArrayLike | None = None, + slab_scale: np.typing.ArrayLike | None = None, per_dimension: bool = True, seed: int | None = None, ): @@ -80,14 +80,14 @@ def __init__( def _resolve_scale( self, name: str, - passed: float | np.ndarray | None, + passed: np.typing.ArrayLike | None, default: float, data: np.ndarray, ) -> np.ndarray | float: """ Determine spike/slab scale: - - If passed is None: Automatic determination via default * std(data) (per‐dimension or global). - - Else: validate & cast passed to the correct shape/type. + - If `passed` is None: Automatic determination via default * std(data) (per‐dimension or global). + - Else: validate & cast `passed to the correct shape/type. Parameters ---------- @@ -103,8 +103,8 @@ def _resolve_scale( Returns ------- - float or np.ndarray - The resolved scale, either as a scalar (if per_dimension=False) or an 1D array of length data.shape[-1] + np.ndarray + The resolved scale, either as a 0D array (if per_dimension=False) or an 1D array of length data.shape[-1] (if per_dimension=True). """ @@ -119,22 +119,22 @@ def _resolve_scale( # If no scale is passed, determine scale automatically given the dimensionwise or global std if passed is None: - return default * std + return np.array(default * std) # If a scale is passed, check if the passed shape matches the expected shape else: - if self.per_dimension: + try: arr = np.asarray(passed, dtype=float) - if arr.shape != expected_shape or arr.ndim != 1: + except Exception as e: + raise TypeError(f"{name}: expected values convertible to float, got {type(passed).__name__}") from e + + if self.per_dimension: + if arr.ndim != 1 or arr.shape != expected_shape: raise ValueError(f"{name}: expected array of shape {expected_shape}, got {arr.shape}") return arr else: - try: - scalar = float(passed) - except TypeError: - raise TypeError(f"{name}: expected a scalar convertible to float, got type {type(passed).__name__}") - except ValueError: - raise ValueError(f"{name}: expected a scalar convertible to float, got value {passed!r}") - return scalar + if arr.ndim != 0: + raise ValueError(f"{name}: expected scalar, got array of shape {arr.shape}") + return arr def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray: """ @@ -173,7 +173,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd return data + noise def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: - # Non-invertible transform. + # Non-invertible transform return data def get_config(self) -> dict: From 8fb4158dddef84170092cc225563bd39193977e1 Mon Sep 17 00:00:00 2001 From: elseml <60779710+elseml@users.noreply.github.com> Date: Mon, 16 Jun 2025 14:11:58 +0200 Subject: [PATCH 3/3] Fix typo --- bayesflow/adapters/transforms/nnpe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/adapters/transforms/nnpe.py b/bayesflow/adapters/transforms/nnpe.py index 91b92a945..b48847c12 100644 --- a/bayesflow/adapters/transforms/nnpe.py +++ b/bayesflow/adapters/transforms/nnpe.py @@ -87,7 +87,7 @@ def _resolve_scale( """ Determine spike/slab scale: - If `passed` is None: Automatic determination via default * std(data) (per‐dimension or global). - - Else: validate & cast `passed to the correct shape/type. + - Else: Validate & cast `passed` to the correct shape/type. Parameters ----------