Skip to content

Commit 45fc1d3

Browse files
committed
Warnings implementented as separate model with repr
1 parent e4c3198 commit 45fc1d3

File tree

4 files changed

+858
-134
lines changed

4 files changed

+858
-134
lines changed

src/pyetm/models/base.py

Lines changed: 85 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from __future__ import annotations
2-
from typing import Any, Type, TypeVar
2+
from typing import Any, Type, TypeVar, Union, List, Dict
33
from pydantic import BaseModel, PrivateAttr, ValidationError, ConfigDict
44
from pydantic_core import InitErrorDetails, PydanticCustomError
55
import pandas as pd
6+
from pyetm.models.warnings import WarningCollector
67

78
T = TypeVar("T", bound="Base")
89

910

1011
class Base(BaseModel):
1112
"""
1213
Custom base model that:
13-
- Collects non-breaking validation or runtime warnings
14+
- Collects non-breaking validation or runtime warnings using WarningCollector
1415
- Fails fast on critical errors
1516
- Catches validation errors and converts them into warnings
1617
- Validates on assignment, converting assignment errors into warnings
@@ -20,111 +21,101 @@ class Base(BaseModel):
2021
# Enable assignment validation
2122
model_config = ConfigDict(validate_assignment=True)
2223

23-
# Internal list of warnings (not part of serialized schema)
24-
_warnings: dict[str, list[str]] = PrivateAttr(default_factory=dict)
24+
_warning_collector: WarningCollector = PrivateAttr(default_factory=WarningCollector)
2525

2626
def __init__(self, **data: Any) -> None:
27-
# Ensure private warnings list exists before any validation
28-
object.__setattr__(self, "_warnings", {})
27+
"""
28+
Initialize the model, converting validation errors to warnings.
29+
"""
30+
object.__setattr__(self, "_warning_collector", WarningCollector())
31+
2932
try:
3033
super().__init__(**data)
3134
except ValidationError as e:
32-
# Construct without validation to preserve fields
33-
inst = self.__class__.model_construct(**data)
34-
# Copy field data
35-
object.__setattr__(self, "__dict__", inst.__dict__.copy())
36-
# Ensure warnings list on this instance
37-
if not hasattr(self, "_warnings"):
38-
object.__setattr__(self, "_warnings", {})
39-
# Convert each validation error into a warning
40-
for err in e.errors():
41-
loc = ".".join(str(x) for x in err.get("loc", []))
42-
msg = err.get("msg", "")
43-
self.add_warning(loc, msg)
35+
# If validation fails, create model without validation and collect warnings
36+
# Use model_construct to bypass validation
37+
temp_instance = self.__class__.model_construct(**data)
38+
39+
# Copy the constructed data to this instance
40+
for field_name, field_value in temp_instance.__dict__.items():
41+
if not field_name.startswith("_"):
42+
object.__setattr__(self, field_name, field_value)
43+
44+
# Convert validation errors to warnings
45+
for error in e.errors():
46+
field_path = ".".join(str(part) for part in error.get("loc", []))
47+
message = error.get("msg", "Validation failed")
48+
self._warning_collector.add(field_path, message, "error")
4449

4550
def __setattr__(self, name: str, value: Any) -> None:
46-
"""Abuses the fact that init does not return on valdiation errors"""
47-
# Intercept assignment-time validation errors
48-
if name in self.__class__.model_fields:
49-
try:
50-
self._clear_warnings_for_attr(name)
51-
current_data = self.model_dump()
52-
current_data[name] = value
53-
obj = self.__class__.model_validate(current_data)
54-
if name in obj.warnings:
55-
self.add_warning(name, obj.warnings[name])
56-
# Do not assign invalid value
57-
return
58-
except ValidationError as e:
59-
for err in e.errors():
60-
if err.get("loc") == (name,):
61-
msg = err.get("msg", "")
62-
self.add_warning(name, msg)
63-
# Do not assign invalid value
64-
return
65-
66-
super().__setattr__(name, value)
67-
68-
def add_warning(self, key: str, message: str) -> None:
69-
"""Append a warning message to this model."""
70-
# TODO: this is horrible. we need a struct for it!!
71-
if key in self._warnings:
72-
if isinstance(self._warnings[key], dict):
73-
if isinstance(message, dict):
74-
self._warnings[key].update(message)
75-
else:
76-
self._warnings[key].update({"base", message})
77-
elif isinstance(message, list):
78-
self._warnings[key].extend(message)
79-
else:
80-
self._warnings[key].append(message)
81-
else:
82-
# TODO: this is horrible. we need a struct for it
83-
if isinstance(message, list) or isinstance(message, dict):
84-
self._warnings[key] = message
85-
else:
86-
self._warnings[key] = [message]
51+
"""
52+
Handle assignment with validation error capture.
53+
Simplified from the original complex implementation.
54+
"""
55+
# Skip validation for private attributes
56+
if name.startswith("_") or name not in self.__class__.model_fields:
57+
super().__setattr__(name, value)
58+
return
8759

88-
@property
89-
def warnings(self) -> dict[str, list[str]]:
90-
"""Return a copy of the warnings list."""
91-
return self._warnings
60+
# Clear existing warnings for this field
61+
self._warning_collector.clear(name)
9262

93-
def show_warnings(self) -> None:
94-
"""Print all warnings to the console."""
95-
if not self._warnings:
96-
print("No warnings.")
63+
try:
64+
# Try to validate the new value by creating a copy with the update
65+
current_data = self.model_dump()
66+
current_data[name] = value
67+
68+
# Test validation with a temporary instance
69+
test_instance = self.__class__.model_validate(current_data)
70+
71+
# If validation succeeds, set the value
72+
super().__setattr__(name, value)
73+
74+
except ValidationError as e:
75+
# If validation fails, add warnings but don't set the value
76+
for error in e.errors():
77+
if error.get("loc") == (name,):
78+
message = error.get("msg", "Validation failed")
79+
self._warning_collector.add(name, message, "warning")
9780
return
98-
print("Warnings:")
99-
# TODO: use prettyprint
100-
for i, w in enumerate(self._warnings, start=1):
101-
print(f" {i}. {w}")
10281

103-
def _clear_warnings_for_attr(self, key):
104-
"""
105-
Remove a key from the warnings.
106-
"""
107-
self._warnings.pop(key, None)
82+
def add_warning(
83+
self,
84+
field: str,
85+
message: Union[str, List[str], Dict[str, Any]],
86+
severity: str = "warning",
87+
) -> None:
88+
"""Add a warning to this model instance."""
89+
self._warning_collector.add(field, message, severity)
10890

109-
def _merge_submodel_warnings(self, *submodels: Base, key_attr=None) -> None:
91+
@property
92+
def warnings(self) -> Union[WarningCollector, Dict[str, List[str]]]:
11093
"""
111-
Bring warnings from a nested Base (or list thereof)
112-
into this model's warnings list.
94+
Return warnings.
95+
96+
For backward compatibility, this can return either the new WarningCollector
97+
or the legacy dict format. The implementation can be switched based on needs.
11398
"""
114-
from typing import Iterable
99+
# Return the new collector (recommended)
100+
return self._warning_collector
101+
102+
# OR return legacy format for backward compatibility:
103+
# return self._warning_collector.to_legacy_dict()
115104

116-
def _collect(wm: Base):
117-
if not wm.warnings:
118-
return
105+
def show_warnings(self) -> None:
106+
"""Print all warnings to the console."""
107+
self._warning_collector.show_warnings()
119108

120-
key = wm.__class__.__name__
121-
if not key_attr is None:
122-
key += f"({key_attr}={getattr(wm, key_attr)})"
123-
self.add_warning(key, wm.warnings)
109+
def _clear_warnings_for_attr(self, field: str) -> None:
110+
"""Remove warnings for a specific field."""
111+
self._warning_collector.clear(field)
124112

125-
for item in submodels:
126-
if isinstance(item, Base):
127-
_collect(item)
113+
def _merge_submodel_warnings(self, *submodels: Base, key_attr: str = None) -> None:
114+
"""
115+
Merge warnings from nested Base models.
116+
Maintains compatibility with existing code while using the new system.
117+
"""
118+
self._warning_collector.merge_submodel_warnings(*submodels, key_attr=key_attr)
128119

129120
@classmethod
130121
def load_safe(cls: Type[T], **data: Any) -> T:
@@ -134,7 +125,7 @@ def load_safe(cls: Type[T], **data: Any) -> T:
134125
"""
135126
return cls(**data)
136127

137-
def _get_serializable_fields(self) -> list[str]:
128+
def _get_serializable_fields(self) -> List[str]:
138129
"""
139130
Parse and return column names for serialization.
140131
Override this method in subclasses if you need custom field selection logic.
@@ -147,17 +138,14 @@ def _get_serializable_fields(self) -> list[str]:
147138

148139
def _raise_exception_on_loc(self, err: str, type: str, loc: str, msg: str):
149140
"""
150-
Nice and convoluted way to raise validation errors on custom locs.
151-
Used in model validators
141+
Raise validation errors on custom locations.
142+
Used in model validators.
152143
"""
153144
raise ValidationError.from_exception_data(
154145
err,
155146
[
156147
InitErrorDetails(
157-
type=PydanticCustomError(
158-
type,
159-
msg,
160-
),
148+
type=PydanticCustomError(type, msg),
161149
loc=(loc,),
162150
input=self,
163151
),

0 commit comments

Comments
 (0)