1
1
from __future__ import annotations
2
- from typing import Any , Optional , Union
3
- from pydantic import field_validator , model_validator , ValidationInfo
2
+ from typing import Optional , Union
3
+ from pydantic import field_validator , model_validator
4
4
import pandas as pd
5
+ from pyetm .models .warnings import WarningCollector
5
6
from pyetm .models .base import Base
6
7
7
8
@@ -19,9 +20,9 @@ class Input(Base):
19
20
coupling_groups : Optional [list [str ]] = []
20
21
disabled_by : Optional [str ] = None
21
22
22
- def is_valid_update (self , value ) -> list [ str ] :
23
+ def is_valid_update (self , value ) -> WarningCollector :
23
24
"""
24
- Returns a list of validation warnings without updating the current object
25
+ Returns a WarningCollector with validation warnings without updating the current object.
25
26
"""
26
27
new_obj_dict = self .model_dump ()
27
28
new_obj_dict ["user" ] = value
@@ -30,7 +31,7 @@ def is_valid_update(self, value) -> list[str]:
30
31
return warnings_obj .warnings
31
32
32
33
@classmethod
33
- def from_json (cls , data : tuple [str , dict ]) -> Input :
34
+ def from_json (cls , data : tuple [str , dict ]) -> " Input" :
34
35
"""
35
36
Initialize an Input from a JSON-like tuple coming from .items()
36
37
"""
@@ -43,7 +44,8 @@ def from_json(cls, data: tuple[str, dict]) -> Input:
43
44
return input_instance
44
45
except Exception as e :
45
46
# Create a basic Input with warning attached
46
- basic_input = cls .model_validate (payload )
47
+ basic_input = cls .model_construct (** payload ) # Bypass validation
48
+ basic_input ._warning_collector = WarningCollector ()
47
49
basic_input .add_warning (key , f"Failed to create specialized input: { e } " )
48
50
return basic_input
49
51
@@ -78,7 +80,7 @@ class BoolInput(Input):
78
80
79
81
@field_validator ("user" , mode = "after" )
80
82
@classmethod
81
- def is_bool_float (cls , value : float ) -> float :
83
+ def is_bool_float (cls , value : Optional [ float ] ) -> Optional [ float ] :
82
84
if value == 1.0 or value == 0.0 or value is None :
83
85
return value
84
86
raise ValueError (
@@ -161,30 +163,39 @@ def __iter__(self):
161
163
def keys (self ):
162
164
return [input .key for input in self .inputs ]
163
165
164
- def is_valid_update (self , key_vals : dict ) -> dict :
166
+ # TODO: Check the efficiency of doing this in a loop
167
+ def is_valid_update (self , key_vals : dict ) -> dict [str , WarningCollector ]:
165
168
"""
166
- Returns a dict of input keys and errors when errors were found
169
+ Returns a dict mapping input keys to their WarningCollectors when errors were found.
167
170
"""
168
171
warnings = {}
169
- for input in self .inputs :
170
- if input .key in key_vals :
171
- input_warn = input .is_valid_update (key_vals [input .key ])
172
- if len (input_warn ) > 0 :
173
- warnings [input .key ] = input_warn
174
172
173
+ # Check each input that has an update
174
+ for input_obj in self .inputs :
175
+ if input_obj .key in key_vals :
176
+ input_warnings = input_obj .is_valid_update (key_vals [input_obj .key ])
177
+ if len (input_warnings ) > 0 :
178
+ warnings [input_obj .key ] = input_warnings
179
+
180
+ # Check for non-existent keys
175
181
non_existent_keys = set (key_vals .keys ()) - set (self .keys ())
176
182
for key in non_existent_keys :
177
- warnings [key ] = "Key does not exist"
183
+ # Create a warning collector for non-existent keys
184
+ warning_collector = WarningCollector ()
185
+ warning_collector .add (key , "Key does not exist" )
186
+ warnings [key ] = warning_collector
178
187
179
188
return warnings
180
189
181
190
def update (self , key_vals : dict ):
182
191
"""
183
- Update the values of certain inputs
192
+ Update the values of certain inputs.
193
+ Uses the new warning system for validation.
184
194
"""
185
- for input in self .inputs :
186
- if input .key in key_vals :
187
- input .user = key_vals [input .key ]
195
+ for input_obj in self .inputs :
196
+ if input_obj .key in key_vals :
197
+ # Use assignment which goes through __setattr__ validation
198
+ input_obj .user = key_vals [input_obj .key ]
188
199
189
200
def _to_dataframe (self , columns = "user" , ** kwargs ) -> pd .DataFrame :
190
201
"""
0 commit comments