11import warnings
2-
32from abc import ABC
43from abc import abstractmethod
54from itertools import product
1211from scipy .integrate import odeint
1312from scipy .integrate import solve_ivp
1413from scipy .interpolate import interp1d
14+ from sklearn import set_config
1515from sklearn .base import BaseEstimator
1616from sklearn .metrics import r2_score
1717from sklearn .pipeline import Pipeline
1818from sklearn .utils .validation import check_is_fitted
19- from sklearn import set_config
19+
2020set_config (enable_metadata_routing = True )
2121from typing_extensions import Self
2222
@@ -348,7 +348,7 @@ def fit(
348348 feature_names : list of string, length n_input_features, optional
349349 Names for the input features (e.g. :code:`['x', 'y', 'z']`).
350350 If None, will use :code:`['x0', 'x1', ...]`.
351-
351+
352352 sample_weight : float or array-like of shape (n_samples,), optional
353353 Per-sample weights for the regression. Passed internally to
354354 the optimizer (e.g. STLSQ). Supports compatibility with
@@ -383,11 +383,15 @@ def fit(
383383 self .feature_names = feature_names
384384
385385 if sample_weight is not None :
386- mode = "weak" if "Weak" in self .feature_library .__class__ .__name__ else "standard"
386+ mode = (
387+ "weak"
388+ if "Weak" in self .feature_library .__class__ .__name__
389+ else "standard"
390+ )
387391 sample_weight = _expand_sample_weights (
388392 sample_weight , x , feature_library = self .feature_library , mode = mode
389393 )
390-
394+
391395 steps = [
392396 ("features" , self .feature_library ),
393397 ("shaping" , SampleConcatter ()),
@@ -429,7 +433,7 @@ def predict(self, x, u=None):
429433 x , _ , u = _comprehend_and_validate_inputs (x , 1 , None , u , self .feature_library )
430434
431435 check_is_fitted (self , "model" )
432-
436+
433437 if self .n_control_features_ > 0 and u is None :
434438 raise TypeError ("Model was fit using control variables, so u is required" )
435439 if self .n_control_features_ == 0 and u is not None :
@@ -485,7 +489,16 @@ def print(self, lhs=None, precision=3, **kwargs):
485489 names = f"{ lhs [i ]} "
486490 print (f"{ names } = { eqn } " , ** kwargs )
487491
488- def score (self , x , t , x_dot = None , u = None , metric = r2_score , sample_weight = None , ** metric_kws ):
492+ def score (
493+ self ,
494+ x ,
495+ t ,
496+ x_dot = None ,
497+ u = None ,
498+ metric = r2_score ,
499+ sample_weight = None ,
500+ ** metric_kws ,
501+ ):
489502 """
490503 Returns a score for the time derivative prediction produced by the model.
491504
@@ -518,14 +531,14 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, sample_weight=None, *
518531 See `Scikit-learn \
519532 <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
520533 for more options.
521-
534+
522535 sample_weight : array-like of shape (n_samples,), optional
523536 Per-sample weights passed directly to the metric. This is the
524537 preferred way to supply weights.
525538
526539 metric_kws: dict, optional
527540 Optional keyword arguments to pass to the metric function.
528-
541+
529542
530543 Returns
531544 -------
@@ -548,7 +561,7 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, sample_weight=None, *
548561
549562 if sample_weight is not None :
550563 sample_weight = _expand_sample_weights (sample_weight , x )
551-
564+
552565 x_dot = concat_sample_axis (x_dot )
553566 x_dot_predict = concat_sample_axis (x_dot_predict )
554567
@@ -557,7 +570,7 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, sample_weight=None, *
557570 x_dot , x_dot_predict , return_indices = True
558571 )
559572 sample_weight = sample_weight [good_idx ]
560- metric_kws = {** metric_kws , "sample_weight" : sample_weight }
573+ metric_kws = {** metric_kws , "sample_weight" : sample_weight }
561574 else :
562575 x_dot , x_dot_predict = drop_nan_samples (x_dot , x_dot_predict )
563576
@@ -945,78 +958,115 @@ def comprehend_and_validate(arr, t):
945958 u = [comprehend_and_validate (ui , ti ) for ui , ti in _zip_like_sequence (u , t )]
946959 return x , x_dot , u
947960
948- def _expand_sample_weights (sample_weight , trajectories , feature_library = None , mode = "standard" ):
949- """Expand per-trajectory sample weights for estimators or weak-form libraries.
961+
962+ def _expand_sample_weights (
963+ sample_weight , trajectories , feature_library = None , mode = "standard"
964+ ):
965+ """
966+ Expand per-trajectory or per-sample weights for use in SINDy estimators.
950967
951968 Parameters
952969 ----------
953- sample_weight : sequence of array-like or None
954- Per-trajectory sample weights. Each element corresponds to one trajectory.
970+ sample_weight : sequence of scalars or array-like
971+ Weights for each trajectory. In "standard" mode, each entry can be:
972+ - a scalar weight (applied to all samples in that trajectory), or
973+ - an array of length equal to the number of samples (n_time) for that
974+ trajectory.
975+ In "weak" mode, each entry must be a single scalar weight per trajectory.
976+
955977 trajectories : sequence
956- Sequence of trajectory objects, each with attributes `n_time` and `n_coord`.
978+ Sequence of trajectory-like objects, each having attributes `n_time` and
979+ `n_coord`.
980+
957981 feature_library : object, optional
958- Library instance, required when mode='weak'.
982+ Library instance used in weak-form mode. Must define attribute `K`
983+ (the number of weak test functions). If missing, assumes K=1 with a warning.
984+
959985 mode : {'standard', 'weak'}, default='standard'
960- Expansion mode:
961- - 'standard' : Concatenate weights per sample or per coordinate.
962- - 'weak' : Expand weights for weak-form (integral) test functions.
986+ - "standard": Expand per-sample weights to match concatenated samples.
987+ - "weak": Repeat each trajectory’s single scalar weight `K` times.
963988
964989 Returns
965990 -------
966991 np.ndarray or None
967- Concatenated and expanded sample weights, or None if no weights are given.
992+ A 1D numpy array of concatenated and expanded sample weights,
993+ or None if `sample_weight` is None.
968994 """
995+ # -------------------------------------------------------------
996+ # Early exit for None
997+ # -------------------------------------------------------------
969998 if sample_weight is None :
970999 return None
9711000
972- if not (isinstance (sample_weight , Sequence ) and not isinstance (sample_weight , np .ndarray )):
973- raise ValueError ("sample_weight must be a list or tuple, not a scalar or numpy array." )
1001+ if not (
1002+ isinstance (sample_weight , Sequence )
1003+ and not isinstance (sample_weight , np .ndarray )
1004+ ):
1005+ raise ValueError (
1006+ "sample_weight must be a list or tuple, not a scalar or numpy array."
1007+ )
9741008
9751009 if len (sample_weight ) != len (trajectories ):
9761010 raise ValueError ("sample_weight length must match number of trajectories." )
9771011
978- # --- Validate shape consistency ---
979- validated = []
980- for sw , traj in zip (sample_weight , trajectories ):
981- arr = np .asarray (sw )
982- if arr .ndim == 0 :
983- validated .append (arr )
984- continue
985- if arr .shape [0 ] != traj .n_time :
986- raise ValueError ("sample_weight entry length does not match trajectory length." )
987- if arr .ndim == 2 and arr .shape [1 ] not in (1 , traj .n_coord ):
988- raise ValueError ("sample_weight 2D second dim must be 1 or equal to n_coord." )
989- validated .append (arr )
990-
991- # --- Weak-form expansion ---
1012+ # -------------------------------------------------------------
1013+ # Weak mode: one weight per trajectory, repeated K times
1014+ # -------------------------------------------------------------
9921015 if mode == "weak" :
993- n_funcs = getattr (feature_library , "K" , 1 )
994- if n_funcs is None :
995- warnings .warn ("feature_library missing 'K'; assuming 1 test function." )
996- n_funcs = 1
997- return np .concatenate ([np .repeat (np .asarray (sw ), n_funcs , axis = 0 ) for sw in validated ])
998-
999- # --- Standard expansion ---
1000- n_coords = {int (t .n_coord ) for t in trajectories }
1001- if len (n_coords ) != 1 :
1002- raise ValueError ("All trajectories must have the same n_coord." )
1003- n_coord = n_coords .pop ()
1004-
1005- processed = []
1006- for arr in validated :
1007- arr = np .asarray (arr )
1008- if arr .ndim == 1 :
1009- arr = arr .reshape (- 1 , 1 )
1010- elif arr .ndim == 2 and arr .shape [1 ] == 1 :
1011- pass # already correct shape
1012- processed .append (arr )
1013-
1014- # Promote to n_coord if any arrays have multiple coordinates
1015- is_scalar_weight = all (a .shape [1 ] == 1 for a in processed )
1016- if is_scalar_weight :
1017- return np .concatenate ([a .ravel () for a in processed ])
1018- expanded = [
1019- np .broadcast_to (a , (a .shape [0 ], n_coord )) if a .shape [1 ] == 1 else a
1020- for a in processed
1021- ]
1022- return np .concatenate (expanded , axis = 0 )
1016+ if feature_library is None :
1017+ raise ValueError ("feature_library is required in weak mode." )
1018+
1019+ K = getattr (feature_library , "K" , None )
1020+ if K is None :
1021+ warnings .warn ("feature_library missing 'K'; assuming K=1." , UserWarning )
1022+ K = 1
1023+
1024+ validated = []
1025+ for w , traj in zip (sample_weight , trajectories ):
1026+ arr = np .asarray (w )
1027+ if arr .ndim > 0 and arr .size > 1 :
1028+ raise ValueError (
1029+ "Weak mode expects exactly one weight per trajectory (scalar), "
1030+ f"but got shape { arr .shape } for trajectory with { traj .n_time } "
1031+ f"samples."
1032+ )
1033+ validated .append (float (arr ))
1034+ return np .repeat (validated , K )
1035+
1036+ # -------------------------------------------------------------
1037+ # Standard mode: expand scalars or per-sample arrays
1038+ # -------------------------------------------------------------
1039+ expanded = []
1040+ for w , traj in zip (sample_weight , trajectories ):
1041+ arr = np .asarray (w )
1042+
1043+ # Scalar → expand to all samples in trajectory
1044+ if arr .ndim == 0 :
1045+ arr = np .full (traj .n_time , arr , dtype = float )
1046+
1047+ # 1D array → must match number of samples
1048+ elif arr .ndim == 1 :
1049+ if arr .shape [0 ] != traj .n_time :
1050+ raise ValueError (
1051+ f"sample_weight length { arr .shape [0 ]} does"
1052+ f" not match trajectory length { traj .n_time } ."
1053+ )
1054+
1055+ # 2D array → only (n,1) allowed
1056+ elif arr .ndim == 2 :
1057+ if arr .shape [1 ] != 1 :
1058+ raise ValueError (
1059+ "sample_weight 2D arrays must have second dimension = 1."
1060+ )
1061+ if arr .shape [0 ] != traj .n_time :
1062+ raise ValueError (
1063+ "sample_weight 2D array length does not match trajectory length."
1064+ )
1065+ arr = arr .ravel ()
1066+
1067+ else :
1068+ raise ValueError ("Invalid sample_weight shape." )
1069+
1070+ expanded .append (arr .ravel ())
1071+
1072+ return np .concatenate (expanded )
0 commit comments