Skip to content

Commit 54e7cf7

Browse files
committed
Improvements to caching and DRY
1 parent 4e826e8 commit 54e7cf7

File tree

6 files changed

+130
-67
lines changed

6 files changed

+130
-67
lines changed

examples/create_or_query_scenarios.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"scenarios = Scenarios.from_excel(\"example_input_excel.xlsx\")\n",
4949
"\n",
5050
"# Here we're also loading a scenario directly from the API and adding it to the scenarios loaded/created via the excel\n",
51-
"scenario_a = Scenario.load(1357691)\n",
51+
"scenario_a = Scenario.load(2690439)\n",
5252
"scenarios.add(scenario_a)"
5353
]
5454
},

inputs/example_input_excel.xlsx

25 Bytes
Binary file not shown.

src/pyetm/models/inputs.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -162,24 +162,23 @@ def __iter__(self):
162162
def keys(self):
163163
return [input.key for input in self.inputs]
164164

165-
# TODO: Check the efficiency of doing this in a loop
166165
def is_valid_update(self, key_vals: dict) -> dict[str, WarningCollector]:
167166
"""
168167
Returns a dict mapping input keys to their WarningCollectors when errors were found.
169168
"""
170-
warnings = {}
169+
warnings: dict[str, WarningCollector] = {}
171170

172-
# Check each input that has an update
173-
for input_obj in self.inputs:
174-
if input_obj.key in key_vals:
175-
input_warnings = input_obj.is_valid_update(key_vals[input_obj.key])
176-
if len(input_warnings) > 0:
177-
warnings[input_obj.key] = input_warnings
178-
179-
# Check for non-existent keys
180-
non_existent_keys = set(key_vals.keys()) - set(self.keys())
181-
for key in non_existent_keys:
182-
warnings[key] = WarningCollector.with_warning(key, "Key does not exist")
171+
input_map = {inp.key: inp for inp in self.inputs}
172+
173+
for key, value in key_vals.items():
174+
input_obj = input_map.get(key)
175+
if input_obj is None:
176+
warnings[key] = WarningCollector.with_warning(key, "Key does not exist")
177+
continue
178+
179+
input_warnings = input_obj.is_valid_update(value)
180+
if len(input_warnings) > 0:
181+
warnings[key] = input_warnings
183182

184183
return warnings
185184

src/pyetm/models/output_curves.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pandas as pd
44
from pathlib import Path
55
from typing import Optional
6+
import os
67

78
import yaml
89
from pyetm.clients import BaseClient
@@ -15,6 +16,18 @@
1516
)
1617

1718

19+
# Small LRU cache for reading CSVs from disk. Uses mtime to invalidate when file changes.
20+
def _read_csv_cached(path: Path) -> pd.DataFrame:
21+
return _read_csv_cached_impl(str(path), os.path.getmtime(path))
22+
23+
24+
# TODO determine appropriate maxsize
25+
@lru_cache(maxsize=64)
26+
def _read_csv_cached_impl(path_str: str, mtime: float) -> pd.DataFrame:
27+
df = pd.read_csv(path_str, index_col=0)
28+
return df.dropna(how="all")
29+
30+
1831
class OutputCurveError(Exception):
1932
"""Base carrier curve error"""
2033

@@ -34,18 +47,26 @@ class OutputCurve(Base):
3447
def available(self) -> bool:
3548
return bool(self.file_path)
3649

37-
def retrieve(self, client, scenario) -> Optional[pd.DataFrame]:
50+
def retrieve(
51+
self, client, scenario, force_refresh: bool = False
52+
) -> Optional[pd.DataFrame]:
3853
"""Process curve from client, save to file, set file_path"""
3954
file_path = (
4055
get_settings().path_to_tmp(str(scenario.id))
4156
/ f"{self.key.replace('/','-')}.csv"
4257
)
4358

44-
# TODO: Examine the caching situation in the future if time permits: could be particularly
45-
# relevant for bulk processing
46-
# if file_path.is_file():
47-
# self.file_path = file_path
48-
# return self.contents()
59+
# Reuse a cached file if present unless explicitly refreshing.
60+
if not force_refresh and file_path.is_file():
61+
self.file_path = file_path
62+
try:
63+
return _read_csv_cached(self.file_path)
64+
except Exception as e:
65+
# Fall through to re-download on cache read failure
66+
self.add_warning(
67+
"file_path",
68+
f"Failed to read cached curve file for {self.key}: {e}; refetching",
69+
)
4970
try:
5071
result = DownloadOutputCurveRunner.run(client, scenario, self.key)
5172
if result.success:
@@ -80,8 +101,7 @@ def contents(self) -> Optional[pd.DataFrame]:
80101
return None
81102

82103
try:
83-
df = pd.read_csv(self.file_path, index_col=0)
84-
return df.dropna(how="all")
104+
return _read_csv_cached(self.file_path)
85105
except Exception as e:
86106
self.add_warning(
87107
"file_path", f"Failed to read curve file for {self.key}: {e}"
@@ -147,6 +167,17 @@ def get_contents(self, scenario, curve_name: str) -> Optional[pd.DataFrame]:
147167
return None
148168

149169
if not curve.available():
170+
# Try to attach a cached file from disk first
171+
expected_path = (
172+
get_settings().path_to_tmp(str(scenario.id))
173+
/ f"{curve.key.replace('/', '-')}.csv"
174+
)
175+
if expected_path.is_file():
176+
curve.file_path = expected_path
177+
contents = curve.contents()
178+
self._merge_submodel_warnings(curve, key_attr="key")
179+
return contents
180+
150181
result = curve.retrieve(BaseClient(), scenario)
151182
self._merge_submodel_warnings(curve, key_attr="key")
152183
return result
@@ -193,17 +224,7 @@ def get_curves_by_carrier_type(
193224
Returns:
194225
Dictionary mapping curve names to DataFrames
195226
"""
196-
carrier_mapping = {
197-
"electricity": ["merit_order", "electricity_price", "residual_load"],
198-
"heat": [
199-
"heat_network",
200-
"agriculture_heat",
201-
"household_heat",
202-
"buildings_heat",
203-
],
204-
"hydrogen": ["hydrogen", "hydrogen_integral_cost"],
205-
"methane": ["network_gas"],
206-
}
227+
carrier_mapping = self._load_carrier_mappings()
207228

208229
if carrier_type not in carrier_mapping:
209230
valid_types = ", ".join(carrier_mapping.keys())

src/pyetm/models/scenario.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,7 @@ def update_user_values(self, update_inputs: Dict[str, Any]) -> None:
288288
"""
289289
# Update them in the Inputs object, and check validation
290290
validity_errors = self.inputs.is_valid_update(update_inputs)
291-
if validity_errors:
292-
error_summary = []
293-
for key, warning_collector in validity_errors.items():
294-
warnings_list = [w.message for w in warning_collector]
295-
error_summary.append(f"{key}: {warnings_list}")
296-
raise ScenarioError(f"Could not update user values: {error_summary}")
291+
self._handle_validity_errors(validity_errors, "user values")
297292

298293
result = UpdateInputsRunner.run(BaseClient(), self, update_inputs)
299294

@@ -357,12 +352,7 @@ def update_sortables(self, update_sortables: Dict[str, List[Any]]) -> None:
357352
"""
358353
# Validate the updates first
359354
validity_errors = self.sortables.is_valid_update(update_sortables)
360-
if validity_errors:
361-
error_summary = []
362-
for key, warning_collector in validity_errors.items():
363-
warnings_list = [w.message for w in warning_collector]
364-
error_summary.append(f"{key}: {warnings_list}")
365-
raise ScenarioError(f"Could not update sortables: {error_summary}")
355+
self._handle_validity_errors(validity_errors, "sortables")
366356

367357
# Make individual API calls for each sortable as there is no bulk endpoint
368358
for name, order in update_sortables.items():
@@ -443,16 +433,9 @@ def update_custom_curves(self, custom_curves) -> None:
443433
Args:
444434
custom_curves: CustomCurves object containing curves to upload
445435
"""
446-
447436
# Validate curves before uploading
448437
validity_errors = custom_curves.validate_for_upload()
449-
# TODO: Extract all these validity_errors thingys to a single util or something, lots of repetition at the moment
450-
if validity_errors:
451-
error_summary = []
452-
for key, warning_collector in validity_errors.items():
453-
warnings_list = [w.message for w in warning_collector]
454-
error_summary.append(f"{key}: {warnings_list}")
455-
raise ScenarioError(f"Could not update custom curves: {error_summary}")
438+
self._handle_validity_errors(validity_errors, "custom curves")
456439

457440
# Upload curves
458441
result = UpdateCustomCurvesRunner.run(BaseClient(), self, custom_curves)
@@ -556,3 +539,19 @@ def show_all_warnings(self) -> None:
556539
if submodel is not None and len(submodel.warnings) > 0:
557540
print(f"\n{name} warnings:")
558541
submodel.show_warnings()
542+
543+
def _handle_validity_errors(
544+
self, validity_errors: Dict[str, Any], context: str
545+
) -> None:
546+
"""
547+
Helper method to format and raise ScenarioError for validity errors.
548+
"""
549+
if not validity_errors:
550+
return
551+
552+
error_summary = []
553+
for key, warning_collector in validity_errors.items():
554+
warnings_list = [w.message for w in warning_collector]
555+
error_summary.append(f"{key}: {warnings_list}")
556+
557+
raise ScenarioError(f"Could not update {context}: {error_summary}")

src/pyetm/models/scenario_packer.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,41 @@ def extract_from_main_sheet(
5656
candidate_series = main_df.iloc[:, 0]
5757

5858
return ExportConfigResolver._parse_config_from_series(candidate_series)
59-
except Exception:
59+
except Exception as e:
60+
logger.exception("Error extracting from main sheet: %s", e)
6061
return None
6162

6263
@staticmethod
63-
def _parse_config_from_series(series: pd.Series) -> ExportConfig:
64+
def _parse_config_from_series(series: pd.Series) -> "ExportConfig":
6465
"""Parse ExportConfig from a pandas Series (column from main sheet)."""
65-
index_map = {str(idx).strip().lower(): idx for idx in series.index}
66+
67+
def _iter_rows():
68+
for label, value in zip(series.index, series.values):
69+
yield str(label).strip().lower(), value
70+
71+
def _value_after_output(name: str) -> Any:
72+
target = name.strip().lower()
73+
seen_output = False
74+
chosen: Any = None
75+
for lbl, val in _iter_rows():
76+
if lbl == "output":
77+
seen_output = True
78+
continue
79+
if seen_output and lbl == target:
80+
chosen = val
81+
return chosen
82+
83+
def _value_any(name: str) -> Any:
84+
target = name.strip().lower()
85+
chosen: Any = None
86+
for lbl, val in _iter_rows():
87+
if lbl == target:
88+
chosen = val
89+
return chosen
6690

6791
def get_cell_value(name: str) -> Any:
68-
key = name.strip().lower()
69-
original_key = index_map.get(key)
70-
return series.get(original_key) if original_key is not None else None
92+
val = _value_after_output(name)
93+
return val if val is not None else _value_any(name)
7194

7295
def parse_bool(value: Any) -> Optional[bool]:
7396
"""Parse boolean from various formats."""
@@ -88,26 +111,47 @@ def parse_bool(value: Any) -> Optional[bool]:
88111
return False
89112
return None
90113

114+
def parse_bool_field(*names: str) -> Optional[bool]:
115+
"""Return the first non-None boolean parsed from the provided field names."""
116+
for n in names:
117+
val = parse_bool(get_cell_value(n))
118+
if val is not None:
119+
return val
120+
return None
121+
91122
def parse_carriers(value: Any) -> Optional[List[str]]:
92123
"""Parse comma-separated carrier list."""
93124
if not isinstance(value, str) or not value.strip():
94125
return None
95126
return [carrier.strip() for carrier in value.split(",") if carrier.strip()]
96127

97-
carriers_raw = get_cell_value("exports") or get_cell_value("output_carriers")
128+
exports_val = get_cell_value("exports")
129+
carriers_val = get_cell_value("output_carriers")
98130

99-
return ExportConfig(
100-
include_inputs=parse_bool(get_cell_value("inputs")),
101-
include_sortables=parse_bool(get_cell_value("sortables")),
102-
include_custom_curves=parse_bool(get_cell_value("custom_curves")),
131+
exports_bool = parse_bool(exports_val)
132+
if exports_bool is True:
133+
output_carriers = ["electricity", "hydrogen", "heat", "methane"]
134+
elif exports_bool is False:
135+
output_carriers = None
136+
else:
137+
output_carriers = parse_carriers(carriers_val) or parse_carriers(
138+
exports_val
139+
)
140+
141+
config = ExportConfig(
142+
include_inputs=parse_bool_field("include_inputs", "inputs"),
143+
include_sortables=parse_bool_field("include_sortables", "sortables"),
144+
include_custom_curves=parse_bool_field(
145+
"include_custom_curves", "custom_curves"
146+
),
103147
include_gqueries=(
104-
parse_bool(get_cell_value("gquery_results"))
105-
or parse_bool(get_cell_value("gqueries"))
148+
parse_bool_field("include_gqueries", "gquery_results", "gqueries")
106149
),
107150
inputs_defaults=parse_bool(get_cell_value("defaults")),
108151
inputs_min_max=parse_bool(get_cell_value("min_max")),
109-
output_carriers=parse_carriers(carriers_raw),
152+
output_carriers=output_carriers,
110153
)
154+
return config
111155

112156

113157
class ScenarioPacker(BaseModel):

0 commit comments

Comments
 (0)