1
1
from __future__ import annotations
2
- from typing import Any , Type , TypeVar
2
+ from typing import Any , Type , TypeVar , Union , List , Dict
3
3
from pydantic import BaseModel , PrivateAttr , ValidationError , ConfigDict
4
4
from pydantic_core import InitErrorDetails , PydanticCustomError
5
5
import pandas as pd
6
+ from pyetm .models .warnings import WarningCollector
6
7
7
8
T = TypeVar ("T" , bound = "Base" )
8
9
9
10
10
11
class Base (BaseModel ):
11
12
"""
12
13
Custom base model that:
13
- - Collects non-breaking validation or runtime warnings
14
+ - Collects non-breaking validation or runtime warnings using WarningCollector
14
15
- Fails fast on critical errors
15
16
- Catches validation errors and converts them into warnings
16
17
- Validates on assignment, converting assignment errors into warnings
@@ -20,111 +21,101 @@ class Base(BaseModel):
20
21
# Enable assignment validation
21
22
model_config = ConfigDict (validate_assignment = True )
22
23
23
- # Internal list of warnings (not part of serialized schema)
24
- _warnings : dict [str , list [str ]] = PrivateAttr (default_factory = dict )
24
+ _warning_collector : WarningCollector = PrivateAttr (default_factory = WarningCollector )
25
25
26
26
def __init__ (self , ** data : Any ) -> None :
27
- # Ensure private warnings list exists before any validation
28
- object .__setattr__ (self , "_warnings" , {})
27
+ """
28
+ Initialize the model, converting validation errors to warnings.
29
+ """
30
+ object .__setattr__ (self , "_warning_collector" , WarningCollector ())
31
+
29
32
try :
30
33
super ().__init__ (** data )
31
34
except ValidationError as e :
32
- # Construct without validation to preserve fields
33
- inst = self .__class__ .model_construct (** data )
34
- # Copy field data
35
- object .__setattr__ (self , "__dict__" , inst .__dict__ .copy ())
36
- # Ensure warnings list on this instance
37
- if not hasattr (self , "_warnings" ):
38
- object .__setattr__ (self , "_warnings" , {})
39
- # Convert each validation error into a warning
40
- for err in e .errors ():
41
- loc = "." .join (str (x ) for x in err .get ("loc" , []))
42
- msg = err .get ("msg" , "" )
43
- self .add_warning (loc , msg )
35
+ # If validation fails, create model without validation and collect warnings
36
+ # Use model_construct to bypass validation
37
+ temp_instance = self .__class__ .model_construct (** data )
38
+
39
+ # Copy the constructed data to this instance
40
+ for field_name , field_value in temp_instance .__dict__ .items ():
41
+ if not field_name .startswith ("_" ):
42
+ object .__setattr__ (self , field_name , field_value )
43
+
44
+ # Convert validation errors to warnings
45
+ for error in e .errors ():
46
+ field_path = "." .join (str (part ) for part in error .get ("loc" , []))
47
+ message = error .get ("msg" , "Validation failed" )
48
+ self ._warning_collector .add (field_path , message , "error" )
44
49
45
50
def __setattr__ (self , name : str , value : Any ) -> None :
46
- """Abuses the fact that init does not return on valdiation errors"""
47
- # Intercept assignment-time validation errors
48
- if name in self .__class__ .model_fields :
49
- try :
50
- self ._clear_warnings_for_attr (name )
51
- current_data = self .model_dump ()
52
- current_data [name ] = value
53
- obj = self .__class__ .model_validate (current_data )
54
- if name in obj .warnings :
55
- self .add_warning (name , obj .warnings [name ])
56
- # Do not assign invalid value
57
- return
58
- except ValidationError as e :
59
- for err in e .errors ():
60
- if err .get ("loc" ) == (name ,):
61
- msg = err .get ("msg" , "" )
62
- self .add_warning (name , msg )
63
- # Do not assign invalid value
64
- return
65
-
66
- super ().__setattr__ (name , value )
67
-
68
- def add_warning (self , key : str , message : str ) -> None :
69
- """Append a warning message to this model."""
70
- # TODO: this is horrible. we need a struct for it!!
71
- if key in self ._warnings :
72
- if isinstance (self ._warnings [key ], dict ):
73
- if isinstance (message , dict ):
74
- self ._warnings [key ].update (message )
75
- else :
76
- self ._warnings [key ].update ({"base" , message })
77
- elif isinstance (message , list ):
78
- self ._warnings [key ].extend (message )
79
- else :
80
- self ._warnings [key ].append (message )
81
- else :
82
- # TODO: this is horrible. we need a struct for it
83
- if isinstance (message , list ) or isinstance (message , dict ):
84
- self ._warnings [key ] = message
85
- else :
86
- self ._warnings [key ] = [message ]
51
+ """
52
+ Handle assignment with validation error capture.
53
+ Simplified from the original complex implementation.
54
+ """
55
+ # Skip validation for private attributes
56
+ if name .startswith ("_" ) or name not in self .__class__ .model_fields :
57
+ super ().__setattr__ (name , value )
58
+ return
87
59
88
- @property
89
- def warnings (self ) -> dict [str , list [str ]]:
90
- """Return a copy of the warnings list."""
91
- return self ._warnings
60
+ # Clear existing warnings for this field
61
+ self ._warning_collector .clear (name )
92
62
93
- def show_warnings (self ) -> None :
94
- """Print all warnings to the console."""
95
- if not self ._warnings :
96
- print ("No warnings." )
63
+ try :
64
+ # Try to validate the new value by creating a copy with the update
65
+ current_data = self .model_dump ()
66
+ current_data [name ] = value
67
+
68
+ # Test validation with a temporary instance
69
+ test_instance = self .__class__ .model_validate (current_data )
70
+
71
+ # If validation succeeds, set the value
72
+ super ().__setattr__ (name , value )
73
+
74
+ except ValidationError as e :
75
+ # If validation fails, add warnings but don't set the value
76
+ for error in e .errors ():
77
+ if error .get ("loc" ) == (name ,):
78
+ message = error .get ("msg" , "Validation failed" )
79
+ self ._warning_collector .add (name , message , "warning" )
97
80
return
98
- print ("Warnings:" )
99
- # TODO: use prettyprint
100
- for i , w in enumerate (self ._warnings , start = 1 ):
101
- print (f" { i } . { w } " )
102
81
103
- def _clear_warnings_for_attr (self , key ):
104
- """
105
- Remove a key from the warnings.
106
- """
107
- self ._warnings .pop (key , None )
82
+ def add_warning (
83
+ self ,
84
+ field : str ,
85
+ message : Union [str , List [str ], Dict [str , Any ]],
86
+ severity : str = "warning" ,
87
+ ) -> None :
88
+ """Add a warning to this model instance."""
89
+ self ._warning_collector .add (field , message , severity )
108
90
109
- def _merge_submodel_warnings (self , * submodels : Base , key_attr = None ) -> None :
91
+ @property
92
+ def warnings (self ) -> Union [WarningCollector , Dict [str , List [str ]]]:
110
93
"""
111
- Bring warnings from a nested Base (or list thereof)
112
- into this model's warnings list.
94
+ Return warnings.
95
+
96
+ For backward compatibility, this can return either the new WarningCollector
97
+ or the legacy dict format. The implementation can be switched based on needs.
113
98
"""
114
- from typing import Iterable
99
+ # Return the new collector (recommended)
100
+ return self ._warning_collector
101
+
102
+ # OR return legacy format for backward compatibility:
103
+ # return self._warning_collector.to_legacy_dict()
115
104
116
- def _collect ( wm : Base ) :
117
- if not wm . warnings :
118
- return
105
+ def show_warnings ( self ) -> None :
106
+ """Print all warnings to the console."""
107
+ self . _warning_collector . show_warnings ()
119
108
120
- key = wm .__class__ .__name__
121
- if not key_attr is None :
122
- key += f"({ key_attr } ={ getattr (wm , key_attr )} )"
123
- self .add_warning (key , wm .warnings )
109
+ def _clear_warnings_for_attr (self , field : str ) -> None :
110
+ """Remove warnings for a specific field."""
111
+ self ._warning_collector .clear (field )
124
112
125
- for item in submodels :
126
- if isinstance (item , Base ):
127
- _collect (item )
113
+ def _merge_submodel_warnings (self , * submodels : Base , key_attr : str = None ) -> None :
114
+ """
115
+ Merge warnings from nested Base models.
116
+ Maintains compatibility with existing code while using the new system.
117
+ """
118
+ self ._warning_collector .merge_submodel_warnings (* submodels , key_attr = key_attr )
128
119
129
120
@classmethod
130
121
def load_safe (cls : Type [T ], ** data : Any ) -> T :
@@ -134,7 +125,7 @@ def load_safe(cls: Type[T], **data: Any) -> T:
134
125
"""
135
126
return cls (** data )
136
127
137
- def _get_serializable_fields (self ) -> list [str ]:
128
+ def _get_serializable_fields (self ) -> List [str ]:
138
129
"""
139
130
Parse and return column names for serialization.
140
131
Override this method in subclasses if you need custom field selection logic.
@@ -147,17 +138,14 @@ def _get_serializable_fields(self) -> list[str]:
147
138
148
139
def _raise_exception_on_loc (self , err : str , type : str , loc : str , msg : str ):
149
140
"""
150
- Nice and convoluted way to raise validation errors on custom locs .
151
- Used in model validators
141
+ Raise validation errors on custom locations .
142
+ Used in model validators.
152
143
"""
153
144
raise ValidationError .from_exception_data (
154
145
err ,
155
146
[
156
147
InitErrorDetails (
157
- type = PydanticCustomError (
158
- type ,
159
- msg ,
160
- ),
148
+ type = PydanticCustomError (type , msg ),
161
149
loc = (loc ,),
162
150
input = self ,
163
151
),
0 commit comments