diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py
index d7aaa30e9..0890b4ad8 100644
--- a/packages/evaluate/src/weathergen/evaluate/io_reader.py
+++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py
@@ -7,6 +7,7 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+import json
 import logging
 import re
 from dataclasses import dataclass
@@ -14,6 +15,7 @@
 
 import numpy as np
 import omegaconf as oc
+import pandas as pd
 import xarray as xr
 from tqdm import tqdm
 
@@ -85,8 +87,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] |
         self.eval_cfg = eval_cfg
         self.run_id = run_id
         self.private_paths = private_paths
-        self.streams = eval_cfg.streams.keys()
+        self.data = None
 
         # If results_base_dir and model_base_dir are not provided, default paths are used
         self.model_base_dir = self.eval_cfg.get("model_base_dir", None)
@@ -128,6 +130,10 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
         """Placeholder implementation ensemble member names getter. Override in subclass."""
         return list()
 
+    def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
+        """Placeholder to load pre-computed scores for a run, stream and metric. Override in subclass."""
+        return None
+
     def check_availability(
         self,
         stream: str,
@@ -309,6 +315,144 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili
     )
 
 
+#### Helper function for CsvReader ####
+def _rename_channels(data: pd.DataFrame) -> pd.DataFrame:
+    """
+    Scores downloaded from Quaver follow a different naming convention and need renaming.
+    Rename channel names to insert an underscore between letters and digits.
+    E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff'
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        DataFrame indexed by the original channel names.
+
+    Returns
+    -------
+    pd.DataFrame
+        DataFrame with renamed channel names.
+    """
+    for name in list(data.index):
+        # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged
+        if re.match(r"^\d", name):
+            continue
+
+        # Otherwise, insert underscore between letters and digits
+        data = data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)})
+
+    return data
+
+
+class CsvReader(Reader):
+    """
+    Reader class to read evaluation data from CSV files and convert it to an xarray DataArray.
+    """
+
+    def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
+        """
+        Initialize the CsvReader.
+
+        Parameters
+        ----------
+        eval_cfg : dict
+            config with plotting and evaluation options for that run id
+        run_id : str
+            run id of the model
+        private_paths : dict | None
+            private paths for the supported HPCs
+        """
+
+        super().__init__(eval_cfg, run_id, private_paths)
+        self.csv_path = eval_cfg.get("csv_path")
+        assert self.csv_path is not None, "CSV path must be provided in the config."
+
+        pd_data = pd.read_csv(self.csv_path, index_col=0)
+
+        self.data = _rename_channels(pd_data)
+        self.metrics_base_dir = Path(self.csv_path).parent
+        # for backward compatibility allow metrics_dir to be specified in the run config
+        self.metrics_dir = Path(
+            self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation")
+        )
+
+        assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream."
+        self.stream = list(eval_cfg.streams.keys())[0]
+        self.channels = self.data.index.tolist()
+        self.samples = [0]
+        self.forecast_steps = [int(col.split()[0]) for col in self.data.columns]
+        self.npoints_per_sample = [0]
+        self.epoch = eval_cfg.get("epoch", 0)
+        self.metric = eval_cfg.get("metric")
+        self.region = eval_cfg.get("region")
+
+    def get_samples(self) -> set[int]:
+        """get set of samples for the retrieved scores (initialisation times)"""
+        return set(self.samples)
+
+    def get_forecast_steps(self) -> set[int]:
+        """get set of forecast steps"""
+        return set(self.forecast_steps)
+
+    # TODO: get this from config
+    def get_channels(self, stream: str | None = None) -> list[str]:
+        """get list of channels"""
+        assert stream == self.stream, "streams do not match in CsvReader."
+        return list(self.channels)
+
+    def get_values(self) -> np.ndarray:
+        """get score values reshaped to (sample, forecast_step, channel, metric)"""
+        return self.data.values[np.newaxis, :, :, np.newaxis].T
+
+    def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray:
+        """
+        Load the existing scores for a given run, stream, region and metric.
+
+        Parameters
+        ----------
+        stream :
+            Stream name.
+        region :
+            Region name.
+        metric :
+            Metric name.
+
+        Returns
+        -------
+        xr.DataArray
+            The metric DataArray.
+        """
+
+        available_data = self.check_availability(stream, mode="evaluation")
+
+        # fill with real values only for the matching metric, region and stream; NaNs otherwise
+        if metric == self.metric and region == self.region and stream == self.stream:
+            data = self.get_values()
+        else:
+            data = np.full(
+                (
+                    len(available_data.samples),
+                    len(available_data.fsteps),
+                    len(available_data.channels),
+                    1,
+                ),
+                np.nan,
+            )
+
+        da = xr.DataArray(
+            data.astype(np.float32),
+            dims=("sample", "forecast_step", "channel", "metric"),
+            coords={
+                "sample": available_data.samples,
+                "forecast_step": available_data.fsteps,
+                "channel": available_data.channels,
+                "metric": [metric],
+            },
+            attrs={"npoints_per_sample": self.npoints_per_sample},
+        )
+
+        return da
+
+
 class WeatherGenReader(Reader):
     def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
         """Data reader class for WeatherGenerator model outputs stored in Zarr format."""
@@ -656,6 +800,37 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
         dummy = zio.get_data(0, stream, zio.forecast_steps[0])
         return list(dummy.prediction.as_xarray().coords["ens"].values)
 
+    def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
+        """
+        Load the pre-computed scores for a given run, stream, region, metric and epoch.
+
+        Parameters
+        ----------
+        stream :
+            Stream name.
+        region :
+            Region name.
+        metric :
+            Metric name.
+
+        Returns
+        -------
+        xr.DataArray
+            The metric DataArray or None if the file does not exist.
+        """
+        score_path = (
+            Path(self.metrics_dir)
+            / f"{self.run_id}_{stream}_{region}_{metric}_epoch{self.epoch:05d}.json"
+        )
+        _logger.debug(f"Looking for: {score_path}")
+
+        if score_path.exists():
+            with open(score_path) as f:
+                data_dict = json.load(f)
+            return xr.DataArray.from_dict(data_dict)
+        else:
+            return None
+
     def get_inference_stream_attr(self, stream_name: str, key: str, default=None):
         """
         Get the value of a key for a specific stream from the a model config.
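
Note on the CSV layout `CsvReader` assumes (my reading of the code above, not documented in the diff): rows are channels in Quaver-style names that `_rename_channels` normalises, and column labels must start with an integer forecast step (e.g. `"24 h"`), since `forecast_steps` is parsed with `int(col.split()[0])`. A minimal runnable sketch of both conventions; the renaming rule is re-implemented inline for illustration, and the channel and column names are made up:

```python
import re

import pandas as pd

# Toy frame mimicking the assumed Quaver CSV layout: channels as index, "<step> h" columns.
scores = pd.DataFrame({"24 h": [0.1, 0.2, 0.3, 0.4]}, index=["z500", "t850", "2t", "10ff"])

def _rename(name: str) -> str:
    # Names starting with a digit (surface vars like 2t, 10ff) stay unchanged;
    # otherwise an underscore is inserted between letters and digits.
    return name if re.match(r"^\d", name) else re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)

print([_rename(n) for n in scores.index])           # ['z_500', 't_850', '2t', '10ff']
print([int(c.split()[0]) for c in scores.columns])  # [24] -> forecast_steps
```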
diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
index 3e30acedd..8884b5e91 100755
--- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
+++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -22,14 +22,13 @@
 from weathergen.common.config import _REPO_ROOT
 from weathergen.common.platform_env import get_platform_env
-from weathergen.evaluate.io_reader import WeatherGenReader
+from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader
 from weathergen.evaluate.plot_utils import collect_channels
 from weathergen.evaluate.utils import (
     calc_scores_per_stream,
     metric_list_to_json,
     plot_data,
     plot_summary,
-    retrieve_metric_from_json,
 )
 from weathergen.metrics.mlflow_utils import (
     MlFlowUpload,
@@ -111,7 +110,13 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None:
 
     for run_id, run in runs.items():
         _logger.info(f"RUN {run_id}: Getting data...")
-        reader = WeatherGenReader(run, run_id, private_paths)
+        run_type = run.get("type", "zarr")
+        if run_type == "zarr":
+            reader = WeatherGenReader(run, run_id, private_paths)
+        elif run_type == "csv":
+            reader = CsvReader(run, run_id, private_paths)
+        else:
+            raise ValueError(f"Unknown run type {run_type} for run {run_id}. Supported: zarr, csv.")
 
         for stream in reader.streams:
             _logger.info(f"RUN {run_id}: Processing stream {stream}...")
@@ -135,29 +140,29 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None:
 
                 metrics_to_compute = []
                 for metric in metrics:
-                    try:
-                        metric_data = retrieve_metric_from_json(
-                            reader,
-                            stream,
-                            region,
-                            metric,
-                        )
+                    metric_data = reader.load_scores(
+                        stream,
+                        region,
+                        metric,
+                    )
 
-                        available_data = reader.check_availability(
-                            stream, metric_data, mode="evaluation"
-                        )
+                    if metric_data is None:
+                        metrics_to_compute.append(metric)
+                        continue
+
+                    available_data = reader.check_availability(
+                        stream, metric_data, mode="evaluation"
+                    )
 
-                        if not available_data.score_availability:
-                            metrics_to_compute.append(metric)
-                        else:
-                            # simply select the chosen eval channels, samples, fsteps here...
-                            scores_dict[metric][region][stream][run_id] = metric_data.sel(
-                                sample=available_data.samples,
-                                channel=available_data.channels,
-                                forecast_step=available_data.fsteps,
-                            )
-                    except (FileNotFoundError, KeyError):
+                    if not available_data.score_availability:
                         metrics_to_compute.append(metric)
+                    else:
+                        # simply select the chosen eval channels, samples, fsteps here...
+                        scores_dict[metric][region][stream][run_id] = metric_data.sel(
+                            sample=available_data.samples,
+                            channel=available_data.channels,
+                            forecast_step=available_data.fsteps,
+                        )
 
                 if metrics_to_compute:
                     all_metrics, points_per_sample = calc_scores_per_stream(
diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py
index ae29f7bc2..6ba654bf0 100644
--- a/packages/evaluate/src/weathergen/evaluate/utils.py
+++ b/packages/evaluate/src/weathergen/evaluate/utils.py
@@ -375,40 +375,6 @@ def metric_list_to_json(
     )
 
 
-def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str):
-    """
-    Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file.
-
-    Parameters
-    ----------
-    reader :
-        Reader object containing all info for a specific run_id
-    stream :
-        Stream name.
-    region :
-        Region name.
-    metric :
-        Metric name.
-
-    Returns
-    -------
-    xr.DataArray
-        The metric DataArray.
- """ - score_path = ( - Path(reader.metrics_dir) - / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" - ) - _logger.debug(f"Looking for: {score_path}") - - if score_path.exists(): - with open(score_path) as f: - data_dict = json.load(f) - return xr.DataArray.from_dict(data_dict) - else: - raise FileNotFoundError(f"File {score_path} not found in the archive.") - - def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): """ Plot summary of the evaluation results.