Skip to content

Commit bfe2ca8

Browse files
committed
Update inputs to new warnings model
1 parent 45fc1d3 commit bfe2ca8

File tree

3 files changed

+380
-128
lines changed

3 files changed

+380
-128
lines changed

src/pyetm/models/inputs.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
2-
from typing import Any, Optional, Union
3-
from pydantic import field_validator, model_validator, ValidationInfo
2+
from typing import Optional, Union
3+
from pydantic import field_validator, model_validator
44
import pandas as pd
5+
from pyetm.models.warnings import WarningCollector
56
from pyetm.models.base import Base
67

78

@@ -19,9 +20,9 @@ class Input(Base):
1920
coupling_groups: Optional[list[str]] = []
2021
disabled_by: Optional[str] = None
2122

22-
def is_valid_update(self, value) -> list[str]:
23+
def is_valid_update(self, value) -> WarningCollector:
2324
"""
24-
Returns a list of validation warnings without updating the current object
25+
Returns a WarningCollector with validation warnings without updating the current object.
2526
"""
2627
new_obj_dict = self.model_dump()
2728
new_obj_dict["user"] = value
@@ -30,7 +31,7 @@ def is_valid_update(self, value) -> list[str]:
3031
return warnings_obj.warnings
3132

3233
@classmethod
33-
def from_json(cls, data: tuple[str, dict]) -> Input:
34+
def from_json(cls, data: tuple[str, dict]) -> "Input":
3435
"""
3536
Initialize an Input from a JSON-like tuple coming from .items()
3637
"""
@@ -43,7 +44,8 @@ def from_json(cls, data: tuple[str, dict]) -> Input:
4344
return input_instance
4445
except Exception as e:
4546
# Create a basic Input with warning attached
46-
basic_input = cls.model_validate(payload)
47+
basic_input = cls.model_construct(**payload) # Bypass validation
48+
basic_input._warning_collector = WarningCollector()
4749
basic_input.add_warning(key, f"Failed to create specialized input: {e}")
4850
return basic_input
4951

@@ -78,7 +80,7 @@ class BoolInput(Input):
7880

7981
@field_validator("user", mode="after")
8082
@classmethod
81-
def is_bool_float(cls, value: float) -> float:
83+
def is_bool_float(cls, value: Optional[float]) -> Optional[float]:
8284
if value == 1.0 or value == 0.0 or value is None:
8385
return value
8486
raise ValueError(
@@ -161,30 +163,39 @@ def __iter__(self):
161163
def keys(self):
162164
return [input.key for input in self.inputs]
163165

164-
def is_valid_update(self, key_vals: dict) -> dict:
166+
# TODO: Check the efficiency of doing this in a loop
167+
def is_valid_update(self, key_vals: dict) -> dict[str, WarningCollector]:
165168
"""
166-
Returns a dict of input keys and errors when errors were found
169+
Returns a dict mapping input keys to their WarningCollectors when errors were found.
167170
"""
168171
warnings = {}
169-
for input in self.inputs:
170-
if input.key in key_vals:
171-
input_warn = input.is_valid_update(key_vals[input.key])
172-
if len(input_warn) > 0:
173-
warnings[input.key] = input_warn
174172

173+
# Check each input that has an update
174+
for input_obj in self.inputs:
175+
if input_obj.key in key_vals:
176+
input_warnings = input_obj.is_valid_update(key_vals[input_obj.key])
177+
if len(input_warnings) > 0:
178+
warnings[input_obj.key] = input_warnings
179+
180+
# Check for non-existent keys
175181
non_existent_keys = set(key_vals.keys()) - set(self.keys())
176182
for key in non_existent_keys:
177-
warnings[key] = "Key does not exist"
183+
# Create a warning collector for non-existent keys
184+
warning_collector = WarningCollector()
185+
warning_collector.add(key, "Key does not exist")
186+
warnings[key] = warning_collector
178187

179188
return warnings
180189

181190
def update(self, key_vals: dict):
182191
"""
183-
Update the values of certain inputs
192+
Update the values of certain inputs.
193+
Uses the new warning system for validation.
184194
"""
185-
for input in self.inputs:
186-
if input.key in key_vals:
187-
input.user = key_vals[input.key]
195+
for input_obj in self.inputs:
196+
if input_obj.key in key_vals:
197+
# Use assignment which goes through __setattr__ validation
198+
input_obj.user = key_vals[input_obj.key]
188199

189200
def _to_dataframe(self, columns="user", **kwargs) -> pd.DataFrame:
190201
"""

tests/models/test_input.py

Lines changed: 0 additions & 100 deletions
This file was deleted.

0 commit comments

Comments
 (0)