Skip to content

Commit 9ac839d

Browse files
Pre-commit
1 parent d040f6a commit 9ac839d

File tree

10 files changed

+308
-208
lines changed

10 files changed

+308
-208
lines changed

pysindy/_core.py

Lines changed: 117 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import warnings
2-
32
from abc import ABC
43
from abc import abstractmethod
54
from itertools import product
@@ -12,11 +11,12 @@
1211
from scipy.integrate import odeint
1312
from scipy.integrate import solve_ivp
1413
from scipy.interpolate import interp1d
14+
from sklearn import set_config
1515
from sklearn.base import BaseEstimator
1616
from sklearn.metrics import r2_score
1717
from sklearn.pipeline import Pipeline
1818
from sklearn.utils.validation import check_is_fitted
19-
from sklearn import set_config
19+
2020
set_config(enable_metadata_routing=True)
2121
from typing_extensions import Self
2222

@@ -348,7 +348,7 @@ def fit(
348348
feature_names : list of string, length n_input_features, optional
349349
Names for the input features (e.g. :code:`['x', 'y', 'z']`).
350350
If None, will use :code:`['x0', 'x1', ...]`.
351-
351+
352352
sample_weight : float or array-like of shape (n_samples,), optional
353353
Per-sample weights for the regression. Passed internally to
354354
the optimizer (e.g. STLSQ). Supports compatibility with
@@ -383,11 +383,15 @@ def fit(
383383
self.feature_names = feature_names
384384

385385
if sample_weight is not None:
386-
mode = "weak" if "Weak" in self.feature_library.__class__.__name__ else "standard"
386+
mode = (
387+
"weak"
388+
if "Weak" in self.feature_library.__class__.__name__
389+
else "standard"
390+
)
387391
sample_weight = _expand_sample_weights(
388392
sample_weight, x, feature_library=self.feature_library, mode=mode
389393
)
390-
394+
391395
steps = [
392396
("features", self.feature_library),
393397
("shaping", SampleConcatter()),
@@ -429,7 +433,7 @@ def predict(self, x, u=None):
429433
x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)
430434

431435
check_is_fitted(self, "model")
432-
436+
433437
if self.n_control_features_ > 0 and u is None:
434438
raise TypeError("Model was fit using control variables, so u is required")
435439
if self.n_control_features_ == 0 and u is not None:
@@ -485,7 +489,16 @@ def print(self, lhs=None, precision=3, **kwargs):
485489
names = f"{lhs[i]}"
486490
print(f"{names} = {eqn}", **kwargs)
487491

488-
def score(self, x, t, x_dot=None, u=None, metric=r2_score, sample_weight=None, **metric_kws):
492+
def score(
493+
self,
494+
x,
495+
t,
496+
x_dot=None,
497+
u=None,
498+
metric=r2_score,
499+
sample_weight=None,
500+
**metric_kws,
501+
):
489502
"""
490503
Returns a score for the time derivative prediction produced by the model.
491504
@@ -518,14 +531,14 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, sample_weight=None, *
518531
See `Scikit-learn \
519532
<https://scikit-learn.org/stable/modules/model_evaluation.html>`_
520533
for more options.
521-
534+
522535
sample_weight : array-like of shape (n_samples,), optional
523536
Per-sample weights passed directly to the metric. This is the
524537
preferred way to supply weights.
525538
526539
metric_kws: dict, optional
527540
Optional keyword arguments to pass to the metric function.
528-
541+
529542
530543
Returns
531544
-------
@@ -548,7 +561,7 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, sample_weight=None, *
548561

549562
if sample_weight is not None:
550563
sample_weight = _expand_sample_weights(sample_weight, x)
551-
564+
552565
x_dot = concat_sample_axis(x_dot)
553566
x_dot_predict = concat_sample_axis(x_dot_predict)
554567

@@ -557,7 +570,7 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, sample_weight=None, *
557570
x_dot, x_dot_predict, return_indices=True
558571
)
559572
sample_weight = sample_weight[good_idx]
560-
metric_kws = {**metric_kws, "sample_weight": sample_weight}
573+
metric_kws = {**metric_kws, "sample_weight": sample_weight}
561574
else:
562575
x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)
563576

@@ -945,78 +958,115 @@ def comprehend_and_validate(arr, t):
945958
u = [comprehend_and_validate(ui, ti) for ui, ti in _zip_like_sequence(u, t)]
946959
return x, x_dot, u
947960

948-
def _expand_sample_weights(sample_weight, trajectories, feature_library=None, mode="standard"):
949-
"""Expand per-trajectory sample weights for estimators or weak-form libraries.
961+
962+
def _expand_sample_weights(
963+
sample_weight, trajectories, feature_library=None, mode="standard"
964+
):
965+
"""
966+
Expand per-trajectory or per-sample weights for use in SINDy estimators.
950967
951968
Parameters
952969
----------
953-
sample_weight : sequence of array-like or None
954-
Per-trajectory sample weights. Each element corresponds to one trajectory.
970+
sample_weight : sequence of scalars or array-like
971+
Weights for each trajectory. In "standard" mode, each entry can be:
972+
- a scalar weight (applied to all samples in that trajectory), or
973+
- an array of length equal to the number of samples (n_time) for that
974+
trajectory.
975+
In "weak" mode, each entry must be a single scalar weight per trajectory.
976+
955977
trajectories : sequence
956-
Sequence of trajectory objects, each with attributes `n_time` and `n_coord`.
978+
Sequence of trajectory-like objects, each having attributes `n_time` and
979+
`n_coord`.
980+
957981
feature_library : object, optional
958-
Library instance, required when mode='weak'.
982+
Library instance used in weak-form mode. Must define attribute `K`
983+
(the number of weak test functions). If missing, assumes K=1 with a warning.
984+
959985
mode : {'standard', 'weak'}, default='standard'
960-
Expansion mode:
961-
- 'standard' : Concatenate weights per sample or per coordinate.
962-
- 'weak' : Expand weights for weak-form (integral) test functions.
986+
- "standard": Expand per-sample weights to match concatenated samples.
987+
- "weak": Repeat each trajectory’s single scalar weight `K` times.
963988
964989
Returns
965990
-------
966991
np.ndarray or None
967-
Concatenated and expanded sample weights, or None if no weights are given.
992+
A 1D numpy array of concatenated and expanded sample weights,
993+
or None if `sample_weight` is None.
968994
"""
995+
# -------------------------------------------------------------
996+
# Early exit for None
997+
# -------------------------------------------------------------
969998
if sample_weight is None:
970999
return None
9711000

972-
if not (isinstance(sample_weight, Sequence) and not isinstance(sample_weight, np.ndarray)):
973-
raise ValueError("sample_weight must be a list or tuple, not a scalar or numpy array.")
1001+
if not (
1002+
isinstance(sample_weight, Sequence)
1003+
and not isinstance(sample_weight, np.ndarray)
1004+
):
1005+
raise ValueError(
1006+
"sample_weight must be a list or tuple, not a scalar or numpy array."
1007+
)
9741008

9751009
if len(sample_weight) != len(trajectories):
9761010
raise ValueError("sample_weight length must match number of trajectories.")
9771011

978-
# --- Validate shape consistency ---
979-
validated = []
980-
for sw, traj in zip(sample_weight, trajectories):
981-
arr = np.asarray(sw)
982-
if arr.ndim == 0:
983-
validated.append(arr)
984-
continue
985-
if arr.shape[0] != traj.n_time:
986-
raise ValueError("sample_weight entry length does not match trajectory length.")
987-
if arr.ndim == 2 and arr.shape[1] not in (1, traj.n_coord):
988-
raise ValueError("sample_weight 2D second dim must be 1 or equal to n_coord.")
989-
validated.append(arr)
990-
991-
# --- Weak-form expansion ---
1012+
# -------------------------------------------------------------
1013+
# Weak mode: one weight per trajectory, repeated K times
1014+
# -------------------------------------------------------------
9921015
if mode == "weak":
993-
n_funcs = getattr(feature_library, "K", 1)
994-
if n_funcs is None:
995-
warnings.warn("feature_library missing 'K'; assuming 1 test function.")
996-
n_funcs = 1
997-
return np.concatenate([np.repeat(np.asarray(sw), n_funcs, axis=0) for sw in validated])
998-
999-
# --- Standard expansion ---
1000-
n_coords = {int(t.n_coord) for t in trajectories}
1001-
if len(n_coords) != 1:
1002-
raise ValueError("All trajectories must have the same n_coord.")
1003-
n_coord = n_coords.pop()
1004-
1005-
processed = []
1006-
for arr in validated:
1007-
arr = np.asarray(arr)
1008-
if arr.ndim == 1:
1009-
arr = arr.reshape(-1, 1)
1010-
elif arr.ndim == 2 and arr.shape[1] == 1:
1011-
pass # already correct shape
1012-
processed.append(arr)
1013-
1014-
# Promote to n_coord if any arrays have multiple coordinates
1015-
is_scalar_weight = all(a.shape[1] == 1 for a in processed)
1016-
if is_scalar_weight:
1017-
return np.concatenate([a.ravel() for a in processed])
1018-
expanded = [
1019-
np.broadcast_to(a, (a.shape[0], n_coord)) if a.shape[1] == 1 else a
1020-
for a in processed
1021-
]
1022-
return np.concatenate(expanded, axis=0)
1016+
if feature_library is None:
1017+
raise ValueError("feature_library is required in weak mode.")
1018+
1019+
K = getattr(feature_library, "K", None)
1020+
if K is None:
1021+
warnings.warn("feature_library missing 'K'; assuming K=1.", UserWarning)
1022+
K = 1
1023+
1024+
validated = []
1025+
for w, traj in zip(sample_weight, trajectories):
1026+
arr = np.asarray(w)
1027+
if arr.ndim > 0 and arr.size > 1:
1028+
raise ValueError(
1029+
"Weak mode expects exactly one weight per trajectory (scalar), "
1030+
f"but got shape {arr.shape} for trajectory with {traj.n_time}"
1031+
f"samples."
1032+
)
1033+
validated.append(float(arr))
1034+
return np.repeat(validated, K)
1035+
1036+
# -------------------------------------------------------------
1037+
# Standard mode: expand scalars or per-sample arrays
1038+
# -------------------------------------------------------------
1039+
expanded = []
1040+
for w, traj in zip(sample_weight, trajectories):
1041+
arr = np.asarray(w)
1042+
1043+
# Scalar → expand to all samples in trajectory
1044+
if arr.ndim == 0:
1045+
arr = np.full(traj.n_time, arr, dtype=float)
1046+
1047+
# 1D array → must match number of samples
1048+
elif arr.ndim == 1:
1049+
if arr.shape[0] != traj.n_time:
1050+
raise ValueError(
1051+
f"sample_weight length {arr.shape[0]} does"
1052+
f" not match trajectory length {traj.n_time}."
1053+
)
1054+
1055+
# 2D array → only (n,1) allowed
1056+
elif arr.ndim == 2:
1057+
if arr.shape[1] != 1:
1058+
raise ValueError(
1059+
"sample_weight 2D arrays must have second dimension = 1."
1060+
)
1061+
if arr.shape[0] != traj.n_time:
1062+
raise ValueError(
1063+
"sample_weight 2D array length does not match trajectory length."
1064+
)
1065+
arr = arr.ravel()
1066+
1067+
else:
1068+
raise ValueError("Invalid sample_weight shape.")
1069+
1070+
expanded.append(arr.ravel())
1071+
1072+
return np.concatenate(expanded)

pysindy/feature_library/weighted_weak_pde_library.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import numpy as np
2-
from .weak_pde_library import WeakPDELibrary
2+
33
from ..utils import AxesArray
4+
from .weak_pde_library import WeakPDELibrary
45

56

67
class WeightedWeakPDELibrary(WeakPDELibrary):
@@ -39,14 +40,15 @@ def _build_whitener_from_variance(self):
3940

4041
# --- robust weight-field shape handling ---
4142
base_grid = np.asarray(self.spatiotemporal_grid)
42-
expected = tuple(base_grid.shape[:-1]) # e.g. (Nx, Nt) for a 2D grid
43+
expected = tuple(base_grid.shape[:-1]) # e.g. (Nx, Nt) for a 2D grid
4344
var_grid = np.asarray(self.spatiotemporal_weights)
4445

4546
if var_grid.shape == expected + (1,):
4647
var_grid = var_grid[..., 0]
4748
elif var_grid.shape != expected:
4849
raise ValueError(
49-
f"spatiotemporal_weights must have shape {expected} or {expected + (1,)}, "
50+
f"spatiotemporal_weights must have \
51+
shape {expected} or {expected + (1,)}, "
5052
f"got {var_grid.shape}"
5153
)
5254

@@ -86,7 +88,7 @@ def _build_whitener_from_variance(self):
8688
vk = val_lists[k]
8789
Cov[k, k] = np.dot(vk, vk)
8890
idx_k = idx_lists[k]
89-
91+
9092
map_k = dict(zip(idx_k.tolist(), vk.tolist()))
9193
for ell in range(k + 1, K):
9294
s = 0.0
@@ -133,18 +135,17 @@ def _weak_form_setup(self):
133135
self._build_whitener_from_variance()
134136

135137
def convert_u_dot_integral(self, u):
136-
Vy = super().convert_u_dot_integral(u) # (K, 1)
138+
Vy = super().convert_u_dot_integral(u) # (K, 1)
137139
Vy_w = self._apply_whitener(np.asarray(Vy))
138140
return AxesArray(Vy_w, {"ax_sample": 0, "ax_coord": 1})
139141

140142
def transform(self, x_full):
141-
VTheta_list = super().transform(x_full) # list of (K, n_features)
143+
VTheta_list = super().transform(x_full) # list of (K, n_features)
142144
if self._L_chol is None:
143145
return VTheta_list
144146
out = []
145147
for VTheta in VTheta_list:
146148
A = np.asarray(VTheta)
147-
A_w = self._apply_whitener(A) # (K, m)
149+
A_w = self._apply_whitener(A) # (K, m)
148150
out.append(AxesArray(A_w, {"ax_sample": 0, "ax_coord": 1}))
149151
return out
150-

pysindy/optimizers/base.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
from sklearn.utils.extmath import safe_sparse_dot
1717
from sklearn.utils.validation import check_is_fitted
1818
from sklearn.utils.validation import check_X_y
19-
from sklearn.base import clone
2019

2120
from .._typing import Float2D
2221
from .._typing import FloatDType
@@ -173,7 +172,7 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws):
173172
y = AxesArray(np.asarray(y), y_axes)
174173
x_, y = drop_nan_samples(x_, y)
175174
x_, y = check_X_y(x_, y, accept_sparse=[], y_numeric=True, multi_output=True)
176-
175+
177176
x, y, X_offset, y_offset, X_scale = _preprocess_data(
178177
x_,
179178
y,

pysindy/optimizers/stlsq.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -78,8 +78,8 @@ class STLSQ(BaseOptimizer):
7878
history_ : list
7979
History of ``coef_``. ``history_[k]`` contains the values of
8080
``coef_`` at iteration k of sequentially thresholded least-squares.
81-
82-
81+
82+
8383
Notes
8484
-----
8585
- Supports ``sample_weight`` during :meth:`fit`. Sample weights are applied

pysindy/optimizers/trapping_sr3.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -517,7 +517,8 @@ def _objective(self, x, y, coef_sparse, A, PW, k):
517517
if self.verbose and k % max(1, self.max_iter // 10) == 0:
518518
print(
519519
f"{k:5d} ... {sindy_loss:8.3e} ... {relax_loss:8.3e} ... {L1:8.2e}"
520-
f" ... {nonlin_ens_loss:8.2e} ... {cubic_ens_loss:8.2e} ... {obj:8.2e}"
520+
f" ... {nonlin_ens_loss:8.2e} ... {cubic_ens_loss:8.2e}"
521+
f" ... {obj:8.2e}"
521522
)
522523
return obj
523524

pysindy/utils/_axes.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,8 @@
6969

7070
import numpy as np
7171
from numpy.typing import NDArray
72-
from sklearn.base import BaseEstimator, TransformerMixin
72+
from sklearn.base import BaseEstimator
73+
from sklearn.base import TransformerMixin
7374

7475
HANDLED_FUNCTIONS = {}
7576

@@ -839,6 +840,7 @@ def __sklearn_is_fitted__(self):
839840
def transform(self, x_list):
840841
return concat_sample_axis(x_list)
841842

843+
842844
def concat_sample_axis(x_list: List[AxesArray]):
843845
"""Concatenate all trajectories and axes used to create samples."""
844846
new_arrs = []

0 commit comments

Comments
 (0)