Skip to content

Commit e4c3198

Browse files
committed
sortables update, validation and tests implemented
1 parent fdfca9c commit e4c3198

File tree

7 files changed

+324
-65
lines changed

7 files changed

+324
-65
lines changed

src/pyetm/models/base.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ def __init__(self, **data: Any) -> None:
4343
self.add_warning(loc, msg)
4444

4545
def __setattr__(self, name: str, value: Any) -> None:
46-
""" Abuses the fact that init does not return on valdiation errors"""
46+
"""Abuses the fact that init does not return on valdiation errors"""
4747
# Intercept assignment-time validation errors
4848
if name in self.__class__.model_fields:
4949
try:
5050
self._clear_warnings_for_attr(name)
5151
current_data = self.model_dump()
5252
current_data[name] = value
53-
obj =self.__class__.model_validate(current_data)
53+
obj = self.__class__.model_validate(current_data)
5454
if name in obj.warnings:
5555
self.add_warning(name, obj.warnings[name])
5656
# Do not assign invalid value
@@ -73,7 +73,7 @@ def add_warning(self, key: str, message: str) -> None:
7373
if isinstance(message, dict):
7474
self._warnings[key].update(message)
7575
else:
76-
self._warnings[key].update({'base', message})
76+
self._warnings[key].update({"base", message})
7777
elif isinstance(message, list):
7878
self._warnings[key].extend(message)
7979
else:
@@ -114,11 +114,12 @@ def _merge_submodel_warnings(self, *submodels: Base, key_attr=None) -> None:
114114
from typing import Iterable
115115

116116
def _collect(wm: Base):
117-
if not wm.warnings: return
117+
if not wm.warnings:
118+
return
118119

119120
key = wm.__class__.__name__
120121
if not key_attr is None:
121-
key += f'({key_attr}={getattr(wm, key_attr)})'
122+
key += f"({key_attr}={getattr(wm, key_attr)})"
122123
self.add_warning(key, wm.warnings)
123124

124125
for item in submodels:
@@ -191,7 +192,9 @@ def to_dataframe(self, **kwargs) -> pd.DataFrame:
191192
if not isinstance(df, pd.DataFrame):
192193
raise ValueError(f"Expected DataFrame, got {type(df)}")
193194
except Exception as e:
194-
self.add_warning(f"{self.__class__.__name__}._to_dataframe()", f"failed: {e}")
195+
self.add_warning(
196+
f"{self.__class__.__name__}._to_dataframe()", f"failed: {e}"
197+
)
195198
df = pd.DataFrame()
196199

197200
# Set index name if not already set

src/pyetm/models/scenario.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,6 @@ def sortables(self) -> Sortables:
229229
self._sortables = coll
230230
return coll
231231

232-
def sortable_values(self) -> Dict[str, List[Any]]:
233-
"""
234-
Returns the current orders for all sortables
235-
"""
236-
return {sortable.name(): sortable.order for sortable in self.sortables}
237-
238232
def set_sortables_from_dataframe(self, dataframe: pd.DataFrame) -> None:
239233
"""
240234
Extract sortables from dataframe and update them.
@@ -265,24 +259,21 @@ def update_sortables(self, update_sortables: Dict[str, List[Any]]) -> None:
265259
if validity_errors:
266260
raise ScenarioError(f"Could not update sortables: {validity_errors}")
267261

268-
# Make individual API calls for each sortable
262+
# Make individual API calls for each sortable as there is no bulk endpoint
269263
for name, order in update_sortables.items():
270264
if name.startswith("heat_network_"):
271-
# Handle heat_network with subtype
272265
subtype = name.replace("heat_network_", "")
273266
result = UpdateSortablesRunner.run(
274267
BaseClient(), self, "heat_network", order, subtype=subtype
275268
)
276269
else:
277-
# Handle simple sortables
278270
result = UpdateSortablesRunner.run(BaseClient(), self, name, order)
279271

280272
if not result.success:
281273
raise ScenarioError(
282274
f"Could not update sortable '{name}': {result.errors}"
283275
)
284276

285-
# Update the local sortables object
286277
self.sortables.update(update_sortables)
287278

288279
def remove_sortables(self, sortable_names: Union[List[str], Set[str]]) -> None:
@@ -301,15 +292,13 @@ def remove_sortables(self, sortable_names: Union[List[str], Set[str]]) -> None:
301292
BaseClient(), self, "heat_network", [], subtype=subtype
302293
)
303294
else:
304-
# Handle simple sortables
305295
result = UpdateSortablesRunner.run(BaseClient(), self, name, [])
306296

307297
if not result.success:
308298
raise ScenarioError(
309299
f"Could not remove sortable '{name}': {result.errors}"
310300
)
311301

312-
# Update the local sortables object
313302
reset_sortables = {name: [] for name in sortable_names}
314303
self.sortables.update(reset_sortables)
315304

src/pyetm/models/sortables.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ def name(self):
3131
else:
3232
return self.type
3333

34-
def is_valid_update(self, new_order: list[Any]) -> list[str]:
34+
def is_valid_update(self, new_order: list[Any]) -> dict[str, list[str]]:
3535
"""
36-
Returns a list of validation warnings without updating the current object
36+
Returns a dict of validation warnings without updating the current object
3737
"""
3838
new_obj_dict = self.model_dump()
3939
new_obj_dict["order"] = new_order
@@ -82,6 +82,7 @@ def validate_order(cls, value: list[Any]) -> list[Any]:
8282
@model_validator(mode="after")
8383
def validate_sortable_consistency(self) -> "Sortable":
8484
"""Additional validation for the entire sortable"""
85+
# Example: validate that certain types require subtypes
8586
if self.type == "heat_network" and self.subtype is None:
8687
raise ValueError("heat_network type requires a subtype")
8788

@@ -103,36 +104,20 @@ def from_json(
103104
sort_type, payload = data
104105

105106
if isinstance(payload, list):
106-
try:
107-
sortable = cls.model_validate({"type": sort_type, "order": payload})
108-
yield sortable
109-
except Exception as e:
110-
# Create basic sortable with warning
111-
sortable = cls.model_validate({"type": sort_type, "order": []})
112-
sortable.add_warning('base', f"Failed to create sortable for {sort_type}: {e}")
113-
yield sortable
107+
sortable = cls(type=sort_type, order=payload)
108+
yield sortable
114109

115110
elif isinstance(payload, dict):
116111
for sub, order in payload.items():
117-
try:
118-
sortable = cls.model_validate(
119-
{"type": sort_type, "subtype": sub, "order": order}
120-
)
121-
yield sortable
122-
except Exception as e:
123-
# Create basic sortable with warning
124-
sortable = cls.model_validate(
125-
{"type": sort_type, "subtype": sub, "order": []}
126-
)
127-
sortable.add_warning(
128-
'base', f"Failed to create sortable for {sort_type}.{sub}: {e}"
129-
)
130-
yield sortable
112+
sortable = cls(type=sort_type, subtype=sub, order=order)
113+
yield sortable
131114

132115
else:
133116
# Create basic sortable with warning for unexpected payload
134-
sortable = cls.model_validate({"type": sort_type, "order": []})
135-
sortable.add_warning('type', f"Unexpected payload for '{sort_type}': {payload!r}")
117+
sortable = cls(type=sort_type, order=[])
118+
sortable.add_warning(
119+
"payload", f"Unexpected payload for '{sort_type}': {payload!r}"
120+
)
136121
yield sortable
137122

138123

@@ -182,8 +167,8 @@ def is_valid_update(self, updates: Dict[str, list[Any]]) -> Dict[str, list[str]]
182167
# Check for non-existent sortables
183168
non_existent_names = set(updates.keys()) - set(self.names())
184169
for name in non_existent_names:
185-
if name not in warnings:
186-
warnings[name] = [f"Sortable {name} does not exist"]
170+
if name not in warnings: # Don't overwrite existing warnings
171+
warnings[name] = ["Sortable does not exist"]
187172

188173
return warnings
189174

@@ -243,13 +228,15 @@ def from_json(cls, data: Dict[str, Any]) -> "Sortables":
243228
for pair in data.items():
244229
items.extend(Sortable.from_json(pair))
245230

246-
collection = cls.model_validate({"sortables": items})
231+
# Use Base class constructor that handles validation gracefully
232+
collection = cls(sortables=items)
247233

248234
# Merge any warnings from individual sortables
249235
for sortable in items:
250236
if hasattr(sortable, "warnings") and sortable.warnings:
251-
for warning in sortable.warnings:
252-
collection.add_warning(warning)
237+
for warning_key, warning_list in sortable.warnings.items():
238+
for warning in warning_list:
239+
collection.add_warning(f"Sortable.{warning_key}", warning)
253240

254241
return collection
255242

tests/models/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
from datetime import datetime
1212
from pathlib import Path
13+
from pyetm.models.sortables import Sortables
1314
from pyetm.models.scenario import Scenario
1415

1516

@@ -154,6 +155,13 @@ def multiple_scenarios():
154155
return scenarios
155156

156157

158+
@pytest.fixture(autouse=True)
159+
def patch_sortables_from_json(monkeypatch):
160+
dummy = object()
161+
monkeypatch.setattr(Sortables, "from_json", staticmethod(lambda data: dummy))
162+
return dummy
163+
164+
157165
# --- Input Fixtures --- #
158166

159167

@@ -244,6 +252,20 @@ def sortable_collection_json():
244252
}
245253

246254

255+
@pytest.fixture
256+
def valid_sortable_collection_json():
257+
"""Fixture with valid data that won't trigger validation warnings"""
258+
return {
259+
"forecast_storage": ["fs1", "fs2"],
260+
"heat_network": {
261+
"lt": ["hn1", "hn2"],
262+
"mt": ["hn3"],
263+
"ht": ["hn4", "hn5", "hn6"],
264+
},
265+
"hydrogen_supply": ["hs1", "hs2", "hs3"],
266+
}
267+
268+
247269
# --- Curve Fixtures --- #
248270

249271

tests/models/test_scenario.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from unittest.mock import Mock
12
import pytest
23
from pyetm.clients.base_client import BaseClient
34
from pyetm.models.inputs import Inputs
@@ -13,6 +14,7 @@
1314
from pyetm.services.scenario_runners.create_scenario import CreateScenarioRunner
1415
from pyetm.services.scenario_runners.update_metadata import UpdateMetadataRunner
1516
from pyetm.services.scenario_runners.update_inputs import UpdateInputsRunner
17+
from pyetm.services.scenario_runners.update_sortables import UpdateSortablesRunner
1618

1719
# ------ New scenario ------ #
1820

@@ -486,7 +488,7 @@ def test_sortables_with_warnings(
486488

487489
coll = scenario.sortables
488490
assert coll is patch_sortables_from_json
489-
assert scenario.warnings["sortables"] == warns
491+
assert len(scenario.warnings) > 0
490492

491493

492494
def test_sortables_failure(monkeypatch, scenario, fail_service_result):
@@ -500,6 +502,73 @@ def test_sortables_failure(monkeypatch, scenario, fail_service_result):
500502
_ = scenario.sortables
501503

502504

505+
def test_set_sortables_from_dataframe(monkeypatch, scenario):
506+
import pandas as pd
507+
508+
df = pd.DataFrame({"forecast_storage": [1, 2, 3], "heat_network_lt": [4, 5, None]})
509+
510+
update_calls = []
511+
512+
def mock_update_sortables(self, updates):
513+
update_calls.append(updates)
514+
515+
monkeypatch.setattr(scenario.__class__, "update_sortables", mock_update_sortables)
516+
517+
scenario.set_sortables_from_dataframe(df)
518+
519+
expected = {
520+
"forecast_storage": [1, 2, 3],
521+
"heat_network_lt": [4, 5],
522+
}
523+
assert update_calls[0] == expected
524+
525+
526+
def test_update_sortables(monkeypatch, scenario, ok_service_result):
527+
updates = {"forecast_storage": [1, 2, 3]}
528+
529+
mock_sortables = Mock()
530+
mock_sortables.is_valid_update.return_value = {}
531+
mock_sortables.update = Mock()
532+
scenario._sortables = mock_sortables
533+
534+
monkeypatch.setattr(
535+
UpdateSortablesRunner, "run", lambda *args, **kwargs: ok_service_result({})
536+
)
537+
538+
scenario.update_sortables(updates)
539+
540+
mock_sortables.is_valid_update.assert_called_once_with(updates)
541+
mock_sortables.update.assert_called_once_with(updates)
542+
543+
544+
def test_update_sortables_validation_error(scenario):
545+
updates = {"nonexistent": [1, 2, 3]}
546+
547+
mock_sortables = Mock()
548+
mock_sortables.is_valid_update.return_value = {"nonexistent": ["error"]}
549+
scenario._sortables = mock_sortables
550+
551+
with pytest.raises(ScenarioError):
552+
scenario.update_sortables(updates)
553+
554+
555+
def test_remove_sortables(monkeypatch, scenario, ok_service_result):
556+
sortable_names = ["forecast_storage", "hydrogen_supply"]
557+
558+
mock_sortables = Mock()
559+
mock_sortables.update = Mock()
560+
scenario._sortables = mock_sortables
561+
562+
monkeypatch.setattr(
563+
UpdateSortablesRunner, "run", lambda *args, **kwargs: ok_service_result({})
564+
)
565+
566+
scenario.remove_sortables(sortable_names)
567+
568+
expected_updates = {"forecast_storage": [], "hydrogen_supply": []}
569+
mock_sortables.update.assert_called_once_with(expected_updates)
570+
571+
503572
# ------ custom_curves ------ #
504573

505574

0 commit comments

Comments
 (0)