From 8e05a4b6ed8cb7f125e25d14247b6cfb42224353 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Thu, 18 Sep 2025 16:34:01 +0000 Subject: [PATCH 1/7] first version of quaver reader --- .../src/weathergen/evaluate/io_reader.py | 565 ++++++++++++------ .../src/weathergen/evaluate/run_evaluation.py | 57 +- .../evaluate/src/weathergen/evaluate/utils.py | 94 +-- 3 files changed, 483 insertions(+), 233 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index b73e85fa2..b58791414 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -15,6 +15,8 @@ import omegaconf as oc import xarray as xr from tqdm import tqdm +import re +import pandas as pd from weathergen.common.config import load_config, load_model_config from weathergen.common.io import ZarrIO @@ -25,7 +27,7 @@ @dataclass -class WeatherGeneratorOutput: +class ReaderOutput: """ Dataclass to hold the output of the Reader.get_data method. Attributes @@ -82,20 +84,297 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.eval_cfg = eval_cfg self.run_id = run_id self.private_paths = private_paths - self.streams = eval_cfg.streams.keys() - self.epoch = eval_cfg.epoch - self.rank = eval_cfg.rank # If results_base_dir and model_base_dir are not provided, default paths are used self.model_base_dir = self.eval_cfg.get("model_base_dir", None) - # Load model configuration and set (run-id specific) directories - self.inference_cfg = self.get_inference_config() - self.results_base_dir = self.eval_cfg.get( "results_base_dir", None ) # base directory where results will be stored + + def get_stream(self, stream: str): + """ + returns the dictionary associated to a particular stream + + Parameters + ---------- + stream: str + the stream name + + Returns + ------- + dict + the config dictionary associated to that stream + """ + return self.eval_cfg.streams.get(stream, {}) + + def get_samples(self) -> set[int]: + return set() # Placeholder implementation + + def get_forecast_steps(self) -> set[int]: + return set() # Placeholder implementation + + # TODO: get this from config + def get_channels(self, stream: str | None = None) -> list[str]: + return list() # Placeholder implementation + + def check_availability( + self, + stream: str, + available_data: dict = None, + mode: str = "", + ) -> DataAvailability: + """ + Check if requested channels, forecast steps and samples are + i) available in the previously saved json if metric data is specified (return False otherwise) + ii) available in the Zarr file (return error otherwise) + Additionally, if channels, forecast steps or samples is None/'all', it will + i) set the variable to all available vars in Zarr file + ii) return True only if the respective variable contains the same indeces in JSON and Zarr (return False otherwise) + + Parameters + ---------- + stream : str + The stream considered. + available_data : dict, optional + The available data loaded from JSON. + Returns + ------- + DataAvailability + A dataclass containing: + - channels: list of channels or None if 'all' + - fsteps: list of forecast steps or None if 'all' + - samples: list of samples or None if 'all' + """ + + # fill info for requested channels, fsteps, samples + requested_data = self._get_channels_fsteps_samples(stream, mode) + + channels = requested_data.channels + fsteps = requested_data.fsteps + samples = requested_data.samples + + requested = { + "channel": set(channels) if channels is not None else None, + "fstep": set(fsteps) if fsteps is not None else None, + "sample": set(samples) if samples is not None else None, + } + + # fill info from available json file (if provided) + available = { + "channel": set(available_data["channel"].values.ravel()) + if available_data is not None + else {}, + "fstep": set(available_data["forecast_step"].values.ravel()) + if available_data is not None + else {}, + "sample": set(available_data.coords["sample"].values.ravel()) + if available_data is not None + else {}, + } + + # fill info from reader + reader_data = { + "fstep": set(int(f) for f in self.get_forecast_steps()), + "sample": set(int(s) for s in self.get_samples()), + "channel": set(self.get_channels(stream)), + } + + check_json = True + corrected = False + for name in ["channel", "fstep", "sample"]: + if requested[name] is None: + # Default to all in Zarr + requested[name] = reader_data[name] + # If JSON exists, must exactly match + if available_data is not None and reader_data[name] != available[name]: + _logger.info( + f"Requested all {name}s for {mode}, but previous config was a strict subset. Recomputing." + ) + check_json = False + + # Must be subset of Zarr + if not requested[name] <= reader_data[name]: + missing = requested[name] - reader_data[name] + _logger.info( + f"Requested {name}(s) {missing} do(es) not exist in Zarr. " + f"Removing missing {name}(s) for {mode}." + ) + requested[name] = requested[name] & reader_data[name] + corrected = True + + # Must be a subset of available_data (if provided) + if available_data is not None and not requested[name] <= available[name]: + missing = requested[name] - available[name] + _logger.info( + f"{name.capitalize()}(s) {missing} missing in previous evaluation. Recomputing." + ) + check_json = False + + if check_json and not corrected: + scope = "metric file" if available_data is not None else "source file" + _logger.info( + f"All checks passed – All channels, samples, fsteps requested for {mode} are present in {scope}..." + ) + + return DataAvailability( + json_availability=check_json, + channels=sorted(list(requested["channel"])), + fsteps=sorted(list(requested["fstep"])), + samples=sorted(list(requested["sample"])), + ) + + + def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailability: + """ + Get channels, fsteps and samples for a given run and stream from the config. Replace 'all' with None. + + Parameters + ---------- + stream: str + The stream considered. + mode: str + if plotting or evaluation mode + + Returns + ------- + DataAvailability + A dataclass containing: + - channels: list of channels or None if 'all' + - fsteps: list of forecast steps or None if 'all' + - samples: list of samples or None if 'all' + """ + assert mode == "plotting" or mode == "evaluation", ( + "get_channels_fsteps_samples:: Mode should be either 'plotting' or 'evaluation'" + ) + + stream_cfg = self.get_stream(stream) + assert stream_cfg.get(mode, False), ( + "Mode does not exist in stream config. Please add it." + ) + + samples = stream_cfg[mode].get("sample", None) + fsteps = stream_cfg[mode].get("forecast_step", None) + channels = stream_cfg.get("channels", None) + + return DataAvailability( + json_availability=True, + channels=None + if (channels == "all" or channels is None) + else list(channels), + fsteps=None if (fsteps == "all" or fsteps is None) else list(fsteps), + samples=None if (samples == "all" or samples is None) else list(samples), + ) + + +class CsvReader(Reader): + """ + Reader class to read evaluation data from CSV files and convert to xarray DataArray. + """ + + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + """ + Initialize the CsvReader. + + Parameters + ---------- + eval_cfg : dir + config with plotting and evaluation options for that run id + run_id : str + run id of the model + private_paths: lists + list of private paths for the supported HPC + """ + + super().__init__(eval_cfg, run_id, private_paths) + self.csv_path = eval_cfg.get("csv_path", None) + assert self.csv_path is not None, "CSV path must be provided in the config." + + self.data = pd.read_csv(self.csv_path, index_col=0) + + self.data = self.rename_channels() + self.metrics_base_dir = Path(self.csv_path).parent + # for backward compatibility allow metric_dir to be specified in the run config + self.metrics_dir = Path( + self.eval_cfg.get( + "metrics_dir", self.metrics_base_dir / self.run_id / "evaluation" + ) + ) + + assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream." + self.stream = list(eval_cfg.streams.keys())[0] + self.channels = self.data.index.tolist() + self.samples = [0] + self.forecast_steps = [int(col.split()[0]) for col in self.data.columns] + self.npoints_per_sample = [0] + self.epoch = eval_cfg.get("epoch", 0) + self.metric = eval_cfg.get("metric") + self.region = eval_cfg.get("region") + # self.data = xr.DataArray( + # data.astype(np.float32), + # dims=("forecast_step", "sample", "channel"), + # coords={ + # "forecast_step": self.forecast_steps, + # "sample": self.samples, + # "channel": self.channels, + # "stream": self.stream, + # }, + # attrs={"npoints_per_sample": self.npoints_per_sample}, + # ) + + # da_dict = da.to_dict() + + + def rename_channels(self) -> str: + """ + Rename channel names to include underscore between letters and digits. + E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' + + Parameters + ---------- + name : str + Original channel name. + + Returns + ------- + str + Renamed channel name. + """ + for name in list(self.data.index): + # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged + if re.match(r"^\d", name): + continue + + # Otherwise, insert underscore between letters and digits + self.data = self.data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)}) + + return self.data + + def get_samples(self) -> set[int]: + return set(self.samples) # Placeholder implementation + + def get_forecast_steps(self) -> set[int]: + return set(self.forecast_steps) # Placeholder implementation + + # TODO: get this from config + def get_channels(self, stream: str | None = None) -> list[str]: + assert stream == self.stream, "streams do not match in CSVReader." + return list(self.channels) # Placeholder implementation + + + + +class WeatherGenReader(Reader): + + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): + super().__init__(eval_cfg, run_id, private_paths) + + self.epoch = eval_cfg.epoch + self.rank = eval_cfg.rank + + # Load model configuration and set (run-id specific) directories + self.inference_cfg = self.get_inference_config() if not self.results_base_dir: self.results_base_dir = Path(self.inference_cfg["run_path"]) @@ -162,22 +441,6 @@ def get_inference_config(self): return config - def get_stream(self, stream: str): - """ - returns the dictionary associated to a particular stream - - Parameters - ---------- - stream: str - the stream name - - Returns - ------- - dict - the config dictionary associated to that stream - """ - return self.eval_cfg.streams.get(stream, {}) - def get_data( self, stream: str, @@ -186,7 +449,7 @@ def get_data( fsteps: list[str] | None = None, channels: list[str] | None = None, return_counts: bool = False, - ) -> WeatherGeneratorOutput: + ) -> ReaderOutput: """ Retrieve prediction and target data for a given run from the Zarr store. @@ -211,7 +474,7 @@ def get_data( Returns ------- - WeatherGeneratorOutput + ReaderOutput A dataclass containing: - target: Dictionary of xarray DataArrays for targets, indexed by forecast step. - prediction: Dictionary of xarray DataArrays for predictions, indexed by forecast step. @@ -322,7 +585,7 @@ def get_data( fstep: da for fstep, da in zip(fsteps_final, da_preds, strict=False) } - return WeatherGeneratorOutput( + return ReaderOutput( target=da_tars, prediction=da_preds, points_per_sample=points_per_sample ) @@ -384,149 +647,109 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None): return stream.get(key, default) return default - def check_availability( - self, - stream: str, - available_data: dict = None, - mode: str = "", - ) -> DataAvailability: - """ - Check if requested channels, forecast steps and samples are - i) available in the previously saved json if metric data is specified (return False otherwise) - ii) available in the Zarr file (return error otherwise) - Additionally, if channels, forecast steps or samples is None/'all', it will - i) set the variable to all available vars in Zarr file - ii) return True only if the respective variable contains the same indeces in JSON and Zarr (return False otherwise) - - Parameters - ---------- - stream : str - The stream considered. - available_data : dict, optional - The available data loaded from JSON. - Returns - ------- - DataAvailability - A dataclass containing: - - channels: list of channels or None if 'all' - - fsteps: list of forecast steps or None if 'all' - - samples: list of samples or None if 'all' - """ - - # fill info for requested channels, fsteps, samples - requested_data = self._get_channels_fsteps_samples(stream, mode) - - channels = requested_data.channels - fsteps = requested_data.fsteps - samples = requested_data.samples - - requested = { - "channel": set(channels) if channels is not None else None, - "fstep": set(fsteps) if fsteps is not None else None, - "sample": set(samples) if samples is not None else None, - } - - # fill info from available json file (if provided) - available = { - "channel": set(available_data["channel"].values.ravel()) - if available_data is not None - else {}, - "fstep": set(available_data["forecast_step"].values.ravel()) - if available_data is not None - else {}, - "sample": set(available_data.coords["sample"].values.ravel()) - if available_data is not None - else {}, - } - - # fill info from reader - reader_data = { - "fstep": set(int(f) for f in self.get_forecast_steps()), - "sample": set(int(s) for s in self.get_samples()), - "channel": set(self.get_channels(stream)), - } - - check_json = True - corrected = False - for name in ["channel", "fstep", "sample"]: - if requested[name] is None: - # Default to all in Zarr - requested[name] = reader_data[name] - # If JSON exists, must exactly match - if available_data is not None and reader_data[name] != available[name]: - _logger.info( - f"Requested all {name}s for {mode}, but previous config was a strict subset. Recomputing." - ) - check_json = False - - # Must be subset of Zarr - if not requested[name] <= reader_data[name]: - missing = requested[name] - reader_data[name] - _logger.info( - f"Requested {name}(s) {missing} do(es) not exist in Zarr. " - f"Removing missing {name}(s) for {mode}." - ) - requested[name] = requested[name] & reader_data[name] - corrected = True - - # Must be a subset of available_data (if provided) - if available_data is not None and not requested[name] <= available[name]: - missing = requested[name] - available[name] - _logger.info( - f"{name.capitalize()}(s) {missing} missing in previous evaluation. Recomputing." - ) - check_json = False - - if check_json and not corrected: - scope = "metric file" if available_data is not None else "Zarr file" - _logger.info( - f"All checks passed – All channels, samples, fsteps requested for {mode} are present in {scope}..." - ) + # def check_availability( + # self, + # stream: str, + # available_data: dict = None, + # mode: str = "", + # ) -> DataAvailability: + # """ + # Check if requested channels, forecast steps and samples are + # i) available in the previously saved json if metric data is specified (return False otherwise) + # ii) available in the Zarr file (return error otherwise) + # Additionally, if channels, forecast steps or samples is None/'all', it will + # i) set the variable to all available vars in Zarr file + # ii) return True only if the respective variable contains the same indeces in JSON and Zarr (return False otherwise) + + # Parameters + # ---------- + # stream : str + # The stream considered. + # available_data : dict, optional + # The available data loaded from JSON. + # Returns + # ------- + # DataAvailability + # A dataclass containing: + # - channels: list of channels or None if 'all' + # - fsteps: list of forecast steps or None if 'all' + # - samples: list of samples or None if 'all' + # """ + + # # fill info for requested channels, fsteps, samples + # requested_data = self._get_channels_fsteps_samples(stream, mode) + + # channels = requested_data.channels + # fsteps = requested_data.fsteps + # samples = requested_data.samples + + # requested = { + # "channel": set(channels) if channels is not None else None, + # "fstep": set(fsteps) if fsteps is not None else None, + # "sample": set(samples) if samples is not None else None, + # } + + # # fill info from available json file (if provided) + # available = { + # "channel": set(available_data["channel"].values.ravel()) + # if available_data is not None + # else {}, + # "fstep": set(available_data["forecast_step"].values.ravel()) + # if available_data is not None + # else {}, + # "sample": set(available_data.coords["sample"].values.ravel()) + # if available_data is not None + # else {}, + # } + + # # fill info from reader + # reader_data = { + # "fstep": set(int(f) for f in self.get_forecast_steps()), + # "sample": set(int(s) for s in self.get_samples()), + # "channel": set(self.get_channels(stream)), + # } + + # check_json = True + # corrected = False + # for name in ["channel", "fstep", "sample"]: + # if requested[name] is None: + # # Default to all in Zarr + # requested[name] = reader_data[name] + # # If JSON exists, must exactly match + # if available_data is not None and reader_data[name] != available[name]: + # _logger.info( + # f"Requested all {name}s for {mode}, but previous config was a strict subset. Recomputing." + # ) + # check_json = False + + # # Must be subset of Zarr + # if not requested[name] <= reader_data[name]: + # missing = requested[name] - reader_data[name] + # _logger.info( + # f"Requested {name}(s) {missing} do(es) not exist in Zarr. " + # f"Removing missing {name}(s) for {mode}." + # ) + # requested[name] = requested[name] & reader_data[name] + # corrected = True + + # # Must be a subset of available_data (if provided) + # if available_data is not None and not requested[name] <= available[name]: + # missing = requested[name] - available[name] + # _logger.info( + # f"{name.capitalize()}(s) {missing} missing in previous evaluation. Recomputing." + # ) + # check_json = False + + # if check_json and not corrected: + # scope = "metric file" if available_data is not None else "Zarr file" + # _logger.info( + # f"All checks passed – All channels, samples, fsteps requested for {mode} are present in {scope}..." + # ) + + # return DataAvailability( + # json_availability=check_json, + # channels=sorted(list(requested["channel"])), + # fsteps=sorted(list(requested["fstep"])), + # samples=sorted(list(requested["sample"])), + # ) - return DataAvailability( - json_availability=check_json, - channels=sorted(list(requested["channel"])), - fsteps=sorted(list(requested["fstep"])), - samples=sorted(list(requested["sample"])), - ) - - def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailability: - """ - Get channels, fsteps and samples for a given run and stream from the config. Replace 'all' with None. - - Parameters - ---------- - stream: str - The stream considered. - mode: str - if plotting or evaluation mode - - Returns - ------- - DataAvailability - A dataclass containing: - - channels: list of channels or None if 'all' - - fsteps: list of forecast steps or None if 'all' - - samples: list of samples or None if 'all' - """ - assert mode == "plotting" or mode == "evaluation", ( - "get_channels_fsteps_samples:: Mode should be either 'plotting' or 'evaluation'" - ) - - stream_cfg = self.get_stream(stream) - assert stream_cfg.get(mode, False), ( - "Mode does not exist in stream config. Please add it." - ) - - samples = stream_cfg[mode].get("sample", None) - fsteps = stream_cfg[mode].get("forecast_step", None) - channels = stream_cfg.get("channels", None) - - return DataAvailability( - json_availability=True, - channels=None - if (channels == "all" or channels is None) - else list(channels), - fsteps=None if (fsteps == "all" or fsteps is None) else list(fsteps), - samples=None if (samples == "all" or samples is None) else list(samples), - ) diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 4e8472e19..0eaa69bc0 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -17,13 +17,13 @@ from omegaconf import DictConfig, OmegaConf from weathergen.common.config import _REPO_ROOT -from weathergen.evaluate.io_reader import Reader +from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader from weathergen.evaluate.utils import ( calc_scores_per_stream, metric_list_to_json, plot_data, plot_summary, - retrieve_metric_from_json, + retrieve_metric_from_file, ) _logger = logging.getLogger(__name__) @@ -77,7 +77,13 @@ def evaluate_from_config(cfg): for run_id, run in runs.items(): _logger.info(f"RUN {run_id}: Getting data...") - reader = Reader(run, run_id, private_paths) + type = run.get("type", "zarr") + if type == "zarr": + reader = WeatherGenReader(run, run_id, private_paths) + elif type == "csv": + reader = CsvReader(run, run_id, private_paths) + else: + raise ValueError(f"Unknown run type {type} for run {run_id}. Supported: zarr, csv.") for stream in reader.streams: _logger.info(f"RUN {run_id}: Processing stream {stream}...") @@ -95,31 +101,30 @@ def evaluate_from_config(cfg): metrics_to_compute = [] for metric in metrics: - try: - metric_data = retrieve_metric_from_json( - reader, - stream, - region, - metric, - ) - - available_data = reader.check_availability( - stream, metric_data, mode="evaluation" - ) + metric_data = retrieve_metric_from_file( + reader, + stream, + region, + metric, + ) - if not available_data.json_availability: - metrics_to_compute.append(metric) - else: - # simply select the chosen eval channels, samples, fsteps here... - scores_dict[metric][region][stream][run_id] = ( - metric_data.sel( - sample=available_data.samples, - channel=available_data.channels, - forecast_step=available_data.fsteps, - ) - ) - except (FileNotFoundError, KeyError): + available_data = reader.check_availability( + stream, metric_data, mode="evaluation" + ) + + if not available_data.json_availability: metrics_to_compute.append(metric) + else: + # simply select the chosen eval channels, samples, fsteps here... + scores_dict[metric][region][stream][run_id] = ( + metric_data.sel( + sample=available_data.samples, + channel=available_data.channels, + forecast_step=available_data.fsteps, + ) + ) + # except (FileNotFoundError, KeyError): + # metrics_to_compute.append(metric) if metrics_to_compute: all_metrics, points_per_sample = calc_scores_per_stream( diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 0ae0a1c69..670abc06b 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -52,30 +52,9 @@ def calc_scores_per_stream( ) available_data = reader.check_availability(stream, mode="evaluation") - - output_data = reader.get_data( - stream, - region=region, - fsteps=available_data.fsteps, - samples=available_data.samples, - channels=available_data.channels, - return_counts=True, - ) - - da_preds = output_data.prediction - da_tars = output_data.target - points_per_sample = output_data.points_per_sample - - # get coordinate information from retrieved data - fsteps = [int(k) for k in da_tars.keys()] - - first_da = list(da_preds.values())[0] - - # TODO: improve the way we handle samples. - samples = list(np.atleast_1d(np.unique(first_da.sample.values))) - channels = list(np.atleast_1d(first_da.channel.values)) - - metric_list = [] + channels = available_data.channels + samples = available_data.samples + fsteps = available_data.fsteps metric_stream = xr.DataArray( np.full( @@ -88,8 +67,24 @@ def calc_scores_per_stream( "channel": channels, "metric": metrics, }, + ) + + output_data = reader.get_data( + stream, + region=region, + fsteps=fsteps, + samples=samples, + channels=channels, + return_counts=True, ) + da_preds = output_data.prediction + da_tars = output_data.target + points_per_sample = output_data.points_per_sample + + # TODO: improve the way we handle samples. + metric_list = [] + for (fstep, tars), (_, preds) in zip( da_tars.items(), da_preds.items(), strict=False ): @@ -337,9 +332,9 @@ def metric_list_to_json( ) -def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str): +def retrieve_metric_from_file(reader: Reader, stream: str, region: str, metric: str): """ - Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file. + Retrieve the score for a given run, stream, metric, epoch, and rank from a given file (Json or csv). Parameters ---------- @@ -357,18 +352,45 @@ def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: xr.DataArray The metric DataArray. """ - score_path = ( - Path(reader.metrics_dir) - / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" - ) - _logger.debug(f"Looking for: {score_path}") + if hasattr(reader, "data") and reader.data is not None: + + available_data = reader.check_availability(stream, mode="evaluation") + + #empty DataArray with NaNs + data = np.full( + (len(available_data.samples), len(available_data.fsteps), len(available_data.channels), 1), + np.nan, + ) + #fill it only for matching metric + if metric == reader.metric and region == reader.region and stream == reader.stream: + data = reader.data.values[np.newaxis, :, :, np.newaxis].T + + da = xr.DataArray( + data.astype(np.float32), + dims=("sample", "forecast_step","channel", "metric"), + coords={ + "sample": available_data.samples, + "forecast_step": available_data.fsteps, + "channel": available_data.channels, + "metric": [metric], + }, + attrs={"npoints_per_sample": reader.npoints_per_sample}, + ) - if score_path.exists(): - with open(score_path) as f: - data_dict = json.load(f) - return xr.DataArray.from_dict(data_dict) + return da else: - raise FileNotFoundError(f"File {score_path} not found in the archive.") + score_path = ( + Path(reader.metrics_dir) + / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" + ) + _logger.debug(f"Looking for: {score_path}") + + if score_path.exists(): + with open(score_path) as f: + data_dict = json.load(f) + return xr.DataArray.from_dict(data_dict) + else: + raise FileNotFoundError(f"File {score_path} not found in the archive.") def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): From a87444e78b2dc3e1545b54258a4d7f35382bfda3 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 19 Sep 2025 12:35:41 +0000 Subject: [PATCH 2/7] working version --- .../src/weathergen/evaluate/io_reader.py | 134 +----------------- .../src/weathergen/evaluate/run_evaluation.py | 8 +- .../evaluate/src/weathergen/evaluate/utils.py | 2 +- 3 files changed, 12 insertions(+), 132 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index b58791414..623d018ea 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -61,7 +61,7 @@ class DataAvailability: List of samples requested """ - json_availability: bool + score_availability: bool channels: list[str] | None fsteps: list[int] | None samples: list[int] | None @@ -122,7 +122,7 @@ def get_channels(self, stream: str | None = None) -> list[str]: def check_availability( self, stream: str, - available_data: dict = None, + available_data: dict | None = None, mode: str = "", ) -> DataAvailability: """ @@ -219,7 +219,7 @@ def check_availability( ) return DataAvailability( - json_availability=check_json, + score_availability=check_json, channels=sorted(list(requested["channel"])), fsteps=sorted(list(requested["fstep"])), samples=sorted(list(requested["sample"])), @@ -259,7 +259,7 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili channels = stream_cfg.get("channels", None) return DataAvailability( - json_availability=True, + score_availability=True, channels=None if (channels == "all" or channels is None) else list(channels), @@ -311,20 +311,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.epoch = eval_cfg.get("epoch", 0) self.metric = eval_cfg.get("metric") self.region = eval_cfg.get("region") - # self.data = xr.DataArray( - # data.astype(np.float32), - # dims=("forecast_step", "sample", "channel"), - # coords={ - # "forecast_step": self.forecast_steps, - # "sample": self.samples, - # "channel": self.channels, - # "stream": self.stream, - # }, - # attrs={"npoints_per_sample": self.npoints_per_sample}, - # ) - - # da_dict = da.to_dict() - def rename_channels(self) -> str: """ @@ -362,8 +348,6 @@ def get_channels(self, stream: str | None = None) -> list[str]: assert stream == self.stream, "streams do not match in CSVReader." return list(self.channels) # Placeholder implementation - - class WeatherGenReader(Reader): @@ -579,10 +563,10 @@ def get_data( # Safer than a list da_tars = { - fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=False) + fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=True) } da_preds = { - fstep: da for fstep, da in zip(fsteps_final, da_preds, strict=False) + fstep: da for fstep, da in zip(fsteps_final, da_preds, strict=True) } return ReaderOutput( @@ -647,109 +631,3 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None): return stream.get(key, default) return default - # def check_availability( - # self, - # stream: str, - # available_data: dict = None, - # mode: str = "", - # ) -> DataAvailability: - # """ - # Check if requested channels, forecast steps and samples are - # i) available in the previously saved json if metric data is specified (return False otherwise) - # ii) available in the Zarr file (return error otherwise) - # Additionally, if channels, forecast steps or samples is None/'all', it will - # i) set the variable to all available vars in Zarr file - # ii) return True only if the respective variable contains the same indeces in JSON and Zarr (return False otherwise) - - # Parameters - # ---------- - # stream : str - # The stream considered. - # available_data : dict, optional - # The available data loaded from JSON. - # Returns - # ------- - # DataAvailability - # A dataclass containing: - # - channels: list of channels or None if 'all' - # - fsteps: list of forecast steps or None if 'all' - # - samples: list of samples or None if 'all' - # """ - - # # fill info for requested channels, fsteps, samples - # requested_data = self._get_channels_fsteps_samples(stream, mode) - - # channels = requested_data.channels - # fsteps = requested_data.fsteps - # samples = requested_data.samples - - # requested = { - # "channel": set(channels) if channels is not None else None, - # "fstep": set(fsteps) if fsteps is not None else None, - # "sample": set(samples) if samples is not None else None, - # } - - # # fill info from available json file (if provided) - # available = { - # "channel": set(available_data["channel"].values.ravel()) - # if available_data is not None - # else {}, - # "fstep": set(available_data["forecast_step"].values.ravel()) - # if available_data is not None - # else {}, - # "sample": set(available_data.coords["sample"].values.ravel()) - # if available_data is not None - # else {}, - # } - - # # fill info from reader - # reader_data = { - # "fstep": set(int(f) for f in self.get_forecast_steps()), - # "sample": set(int(s) for s in self.get_samples()), - # "channel": set(self.get_channels(stream)), - # } - - # check_json = True - # corrected = False - # for name in ["channel", "fstep", "sample"]: - # if requested[name] is None: - # # Default to all in Zarr - # requested[name] = reader_data[name] - # # If JSON exists, must exactly match - # if available_data is not None and reader_data[name] != available[name]: - # _logger.info( - # f"Requested all {name}s for {mode}, but previous config was a strict subset. Recomputing." - # ) - # check_json = False - - # # Must be subset of Zarr - # if not requested[name] <= reader_data[name]: - # missing = requested[name] - reader_data[name] - # _logger.info( - # f"Requested {name}(s) {missing} do(es) not exist in Zarr. " - # f"Removing missing {name}(s) for {mode}." - # ) - # requested[name] = requested[name] & reader_data[name] - # corrected = True - - # # Must be a subset of available_data (if provided) - # if available_data is not None and not requested[name] <= available[name]: - # missing = requested[name] - available[name] - # _logger.info( - # f"{name.capitalize()}(s) {missing} missing in previous evaluation. Recomputing." - # ) - # check_json = False - - # if check_json and not corrected: - # scope = "metric file" if available_data is not None else "Zarr file" - # _logger.info( - # f"All checks passed – All channels, samples, fsteps requested for {mode} are present in {scope}..." - # ) - - # return DataAvailability( - # json_availability=check_json, - # channels=sorted(list(requested["channel"])), - # fsteps=sorted(list(requested["fstep"])), - # samples=sorted(list(requested["sample"])), - # ) - diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 0eaa69bc0..7f74bb84e 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -108,11 +108,15 @@ def evaluate_from_config(cfg): metric, ) + if metric_data is None: + metrics_to_compute.append(metric) + continue + available_data = reader.check_availability( stream, metric_data, mode="evaluation" ) - if not available_data.json_availability: + if not available_data.score_availability: metrics_to_compute.append(metric) else: # simply select the chosen eval channels, samples, fsteps here... @@ -123,8 +127,6 @@ def evaluate_from_config(cfg): forecast_step=available_data.fsteps, ) ) - # except (FileNotFoundError, KeyError): - # metrics_to_compute.append(metric) if metrics_to_compute: all_metrics, points_per_sample = calc_scores_per_stream( diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 670abc06b..4c46a0aaa 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -390,7 +390,7 @@ def retrieve_metric_from_file(reader: Reader, stream: str, region: str, metric: data_dict = json.load(f) return xr.DataArray.from_dict(data_dict) else: - raise FileNotFoundError(f"File {score_path} not found in the archive.") + return None def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): From 4d3a63db8a907fae3b079a807075b31f2945e6bd Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Fri, 19 Sep 2025 14:18:49 +0000 Subject: [PATCH 3/7] add CSVReader --- .../src/weathergen/evaluate/io_reader.py | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index f69091802..cedb2dfc9 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -110,14 +110,17 @@ def get_stream(self, stream: str): return self.eval_cfg.streams.get(stream, {}) def get_samples(self) -> set[int]: - return set() # Placeholder implementation + """Placeholder implementation of sample getter. Override in subclass.""" + return set() def get_forecast_steps(self) -> set[int]: - return set() # Placeholder implementation + """Placeholder implementation forecast step getter. Override in subclass.""" + return set() # TODO: get this from config def get_channels(self, stream: str | None = None) -> list[str]: - return list() # Placeholder implementation + """Placeholder implementation channel names getter. Override in subclass.""" + return list() def check_availability( self, @@ -127,18 +130,18 @@ def check_availability( ) -> DataAvailability: """ Check if requested channels, forecast steps and samples are - i) available in the previously saved json if metric data is specified (return False otherwise) - ii) available in the Zarr file (return error otherwise) + i) available in the previously saved metric file if specified (return False otherwise) + ii) available in the source file (e.g. the Zarr file, return error otherwise) Additionally, if channels, forecast steps or samples is None/'all', it will - i) set the variable to all available vars in Zarr file - ii) return True only if the respective variable contains the same indeces in JSON and Zarr (return False otherwise) + i) set the variable to all available vars in source file + ii) return True only if the respective variable contains the same indeces in metric file and source file (return False otherwise) Parameters ---------- stream : str The stream considered. available_data : dict, optional - The available data loaded from JSON. + The available data loaded from metric file. Returns ------- DataAvailability @@ -161,7 +164,7 @@ def check_availability( "sample": set(samples) if samples is not None else None, } - # fill info from available json file (if provided) + # fill info from available metric file (if provided) available = { "channel": set(available_data["channel"].values.ravel()) if available_data is not None @@ -181,18 +184,18 @@ def check_availability( "channel": set(self.get_channels(stream)), } - check_json = True + check_score = True corrected = False for name in ["channel", "fstep", "sample"]: if requested[name] is None: # Default to all in Zarr requested[name] = reader_data[name] - # If JSON exists, must exactly match + # If file with metrics exists, must exactly match if available_data is not None and reader_data[name] != available[name]: _logger.info( f"Requested all {name}s for {mode}, but previous config was a strict subset. Recomputing." ) - check_json = False + check_score = False # Must be subset of Zarr if not requested[name] <= reader_data[name]: @@ -210,16 +213,16 @@ def check_availability( _logger.info( f"{name.capitalize()}(s) {missing} missing in previous evaluation. Recomputing." ) - check_json = False + check_score = False - if check_json and not corrected: - scope = "metric file" if available_data is not None else "source file" + if check_score and not corrected: + scope = "metric file" if available_data is not None else "Zarr file" _logger.info( f"All checks passed – All channels, samples, fsteps requested for {mode} are present in {scope}..." ) return DataAvailability( - score_availability=check_json, + score_availability=check_score, channels=sorted(list(requested["channel"])), fsteps=sorted(list(requested["fstep"])), samples=sorted(list(requested["sample"])), From cb8c6aa1075f4511cbdefe67acf7bb6150cb28a5 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Tue, 28 Oct 2025 11:21:29 +0000 Subject: [PATCH 4/7] rebase to develop --- .../src/weathergen/evaluate/io_reader.py | 69 +++++++++++-------- .../evaluate/src/weathergen/evaluate/utils.py | 25 +++---- 2 files changed, 53 insertions(+), 41 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 6feffe2c3..39ba21121 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -70,6 +70,9 @@ class DataAvailability: class Reader: + + data: pd.DataFrame | None # Data attributes (if specified) + def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None): """ Generic data reader class. @@ -87,6 +90,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | self.run_id = run_id self.private_paths = private_paths self.streams = eval_cfg.streams.keys() + self.data = None # If results_base_dir and model_base_dir are not provided, default paths are used self.model_base_dir = self.eval_cfg.get("model_base_dir", None) @@ -297,6 +301,34 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili ensemble=None if (ensemble == "all" or ensemble is None) else list(ensemble), ) +##### Helper function for CSVReader #### +def _rename_channels(data) -> pd.DataFrame: + """ + The scores downloaded from Quaver have a different convention. Need renaming. + Rename channel names to include underscore between letters and digits. + E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' + + Parameters + ---------- + name : str + Original channel name. + + Returns + ------- + pd.DataFrame + Dataset with renamed channel names. + """ + for name in list(data.index): + # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged + if re.match(r"^\d", name): + continue + + # Otherwise, insert underscore between letters and digits + data = data.rename( + index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)} + ) + + return data class CsvReader(Reader): """ @@ -321,9 +353,9 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.csv_path = eval_cfg.get("csv_path") assert self.csv_path is not None, "CSV path must be provided in the config." - self.data = pd.read_csv(self.csv_path, index_col=0) + pd_data = pd.read_csv(self.csv_path, index_col=0) - self.data = self.rename_channels() + self.data = _rename_channels(pd_data) self.metrics_base_dir = Path(self.csv_path).parent # for backward compatibility allow metric_dir to be specified in the run config self.metrics_dir = Path( @@ -342,44 +374,23 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.metric = eval_cfg.get("metric") self.region = eval_cfg.get("region") - def rename_channels(self) -> str: - """ - Rename channel names to include underscore between letters and digits. - E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' - - Parameters - ---------- - name : str - Original channel name. - - Returns - ------- - str - Renamed channel name. - """ - for name in list(self.data.index): - # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged - if re.match(r"^\d", name): - continue - - # Otherwise, insert underscore between letters and digits - self.data = self.data.rename( - index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)} - ) - - return self.data - def get_samples(self) -> set[int]: + """ get set of samples for the retrieved scores (initialisation times) """ return set(self.samples) # Placeholder implementation def get_forecast_steps(self) -> set[int]: + """ get set of forecast steps """ return set(self.forecast_steps) # Placeholder implementation # TODO: get this from config def get_channels(self, stream: str | None = None) -> list[str]: + """ get set of channels """ assert stream == self.stream, "streams do not match in CSVReader." return list(self.channels) # Placeholder implementation + def get_values(self) -> xr.DataArray: + """ get score values in the right format """ + return self.data.values[np.newaxis, :, :, np.newaxis].T class WeatherGenReader(Reader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 4797d2f11..38fb98fb4 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -378,26 +378,27 @@ def retrieve_metric_from_file(reader: Reader, stream: str, region: str, metric: xr.DataArray The metric DataArray. """ - if hasattr(reader, "data") and reader.data is not None: + if reader.data is not None: available_data = reader.check_availability(stream, mode="evaluation") - # empty DataArray with NaNs - data = np.full( - ( - len(available_data.samples), - len(available_data.fsteps), - len(available_data.channels), - 1, - ), - np.nan, - ) # fill it only for matching metric if ( metric == reader.metric and region == reader.region and stream == reader.stream ): - data = reader.data.values[np.newaxis, :, :, np.newaxis].T + data = reader.get_values() + else: + + data = np.full( + ( + len(available_data.samples), + len(available_data.fsteps), + len(available_data.channels), + 1, + ), + np.nan, + ) da = xr.DataArray( data.astype(np.float32), From 3d4f5b6796d17cdce971bbbea8e38cbe4c2187f4 Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 5 Nov 2025 13:54:31 +0100 Subject: [PATCH 5/7] add polimorphism --- .../src/weathergen/evaluate/io_reader.py | 94 +++++++++++++++++++ .../src/weathergen/evaluate/run_evaluation.py | 44 ++++----- .../evaluate/src/weathergen/evaluate/utils.py | 71 -------------- 3 files changed, 116 insertions(+), 93 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 4d4b394aa..9c9a04ad1 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -131,6 +131,10 @@ def get_channels(self, stream: str | None = None) -> list[str]: def get_ensemble(self, stream: str | None = None) -> list[str]: """Placeholder implementation ensemble member names getter. Override in subclass.""" return list() + + def retrieve_computed_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: + """Placeholder to retrieve the score for a given run, stream, metric""" + return None def check_availability( self, @@ -402,6 +406,63 @@ def get_channels(self, stream: str | None = None) -> list[str]: def get_values(self) -> xr.DataArray: """ get score values in the right format """ return self.data.values[np.newaxis, :, :, np.newaxis].T + + def retrieve_computed_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: + """ + Retrieve the scores for a given run, stream and metric. + + Parameters + ---------- + reader : + Reader object containing all info for a specific run_id + stream : + Stream name. + region : + Region name. + metric : + Metric name. + + Returns + ------- + xr.DataArray + The metric DataArray. + """ + + available_data = self.check_availability(stream, mode="evaluation") + + # fill it only for matching metric + if ( + metric == self.metric + and region == self.region + and stream == self.stream + ): + data = self.get_values() + else: + + data = np.full( + ( + len(available_data.samples), + len(available_data.fsteps), + len(available_data.channels), + 1, + ), + np.nan, + ) + + da = xr.DataArray( + data.astype(np.float32), + dims=("sample", "forecast_step", "channel", "metric"), + coords={ + "sample": available_data.samples, + "forecast_step": available_data.fsteps, + "channel": available_data.channels, + "metric": [metric], + }, + attrs={"npoints_per_sample": reader.npoints_per_sample}, + ) + + return da + class WeatherGenReader(Reader): def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None): @@ -750,6 +811,39 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: dummy = zio.get_data(0, stream, zio.forecast_steps[0]) return list(dummy.prediction.as_xarray().coords["ens"].values) + def retrieve_computed_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None: + """ + Retrieve the scores for a given run, stream and metric and epoch. + + Parameters + ---------- + reader : + Reader object containing all info for a specific run_id + stream : + Stream name. + region : + Region name. + metric : + Metric name. + + Returns + ------- + xr.DataArray + The metric DataArray or None if the file does not exist. + """ + score_path = ( + Path(self.metrics_dir) + / f"{self.run_id}_{stream}_{region}_{metric}_epoch{self.epoch:05d}.json" + ) + _logger.debug(f"Looking for: {score_path}") + + if score_path.exists(): + with open(score_path) as f: + data_dict = json.load(f) + return xr.DataArray.from_dict(data_dict) + else: + return None + def get_inference_stream_attr(self, stream_name: str, key: str, default=None): """ Get the value of a key for a specific stream from the a model config. diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index b259f99ad..67c87ebe3 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -28,8 +28,7 @@ calc_scores_per_stream, metric_list_to_json, plot_data, - plot_summary, - retrieve_metric_from_file, + plot_summary ) from weathergen.metrics.mlflow_utils import ( MlFlowUpload, @@ -143,29 +142,30 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: metrics_to_compute = [] for metric in metrics: - try: - metric_data = retrieve_metric_from_file( - reader, - stream, - region, - metric, - ) + + metric_data = reader.retrieve_scores( + stream, + region, + metric, + ) - available_data = reader.check_availability( - stream, metric_data, mode="evaluation" - ) + if metric_data is None: + metrics_to_compute.append(metric) + continue + + available_data = reader.check_availability( + stream, metric_data, mode="evaluation" + ) - if not available_data.score_availability: - metrics_to_compute.append(metric) - else: - # simply select the chosen eval channels, samples, fsteps here... - scores_dict[metric][region][stream][run_id] = metric_data.sel( - sample=available_data.samples, - channel=available_data.channels, - forecast_step=available_data.fsteps, - ) - except (FileNotFoundError, KeyError): + if not available_data.score_availability: metrics_to_compute.append(metric) + else: + # simply select the chosen eval channels, samples, fsteps here... + scores_dict[metric][region][stream][run_id] = metric_data.sel( + sample=available_data.samples, + channel=available_data.channels, + forecast_step=available_data.fsteps, + ) if metrics_to_compute: all_metrics, points_per_sample = calc_scores_per_stream( diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 9a4cda4ff..43cb4a708 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -374,77 +374,6 @@ def metric_list_to_json( f"to {reader.metrics_dir}." ) - -def retrieve_metric_from_file(reader: Reader, stream: str, region: str, metric: str): - """ - Retrieve the score for a given run, stream, metric, epoch, and rank from a given file (Json or csv). - - Parameters - ---------- - reader : - Reader object containing all info for a specific run_id - stream : - Stream name. - region : - Region name. - metric : - Metric name. - - Returns - ------- - xr.DataArray - The metric DataArray. - """ - if reader.data is not None: - available_data = reader.check_availability(stream, mode="evaluation") - - # fill it only for matching metric - if ( - metric == reader.metric - and region == reader.region - and stream == reader.stream - ): - data = reader.get_values() - else: - - data = np.full( - ( - len(available_data.samples), - len(available_data.fsteps), - len(available_data.channels), - 1, - ), - np.nan, - ) - - da = xr.DataArray( - data.astype(np.float32), - dims=("sample", "forecast_step", "channel", "metric"), - coords={ - "sample": available_data.samples, - "forecast_step": available_data.fsteps, - "channel": available_data.channels, - "metric": [metric], - }, - attrs={"npoints_per_sample": reader.npoints_per_sample}, - ) - - return da - else: - score_path = ( - Path(reader.metrics_dir) - / f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json" - ) - _logger.debug(f"Looking for: {score_path}") - - if score_path.exists(): - with open(score_path) as f: - data_dict = json.load(f) - return xr.DataArray.from_dict(data_dict) - else: - raise FileNotFoundError(f"File {score_path} not found in the archive.") - - def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): """ Plot summary of the evaluation results. From 8a8f7ff4333a3e805f934caab77831acd764c8fd Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 5 Nov 2025 13:20:58 +0000 Subject: [PATCH 6/7] fix names --- .../src/weathergen/evaluate/io_reader.py | 17 ++++++++--------- .../src/weathergen/evaluate/run_evaluation.py | 2 +- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 9c9a04ad1..b1cf90862 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -9,6 +9,7 @@ import logging import re +import json from dataclasses import dataclass from pathlib import Path @@ -71,8 +72,6 @@ class DataAvailability: class Reader: - data: pd.DataFrame | None # Data attributes (if specified) - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None): """ Generic data reader class. @@ -132,8 +131,8 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: """Placeholder implementation ensemble member names getter. Override in subclass.""" return list() - def retrieve_computed_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: - """Placeholder to retrieve the score for a given run, stream, metric""" + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: + """Placeholder to load pre-computed scores for a given run, stream, metric""" return None def check_availability( @@ -407,9 +406,9 @@ def get_values(self) -> xr.DataArray: """ get score values in the right format """ return self.data.values[np.newaxis, :, :, np.newaxis].T - def retrieve_computed_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: """ - Retrieve the scores for a given run, stream and metric. + Load the existing scores for a given run, stream and metric. Parameters ---------- @@ -458,7 +457,7 @@ def retrieve_computed_scores(self, stream: str, region: str, metric: str) -> xr. "channel": available_data.channels, "metric": [metric], }, - attrs={"npoints_per_sample": reader.npoints_per_sample}, + attrs={"npoints_per_sample": self.npoints_per_sample}, ) return da @@ -811,9 +810,9 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: dummy = zio.get_data(0, stream, zio.forecast_steps[0]) return list(dummy.prediction.as_xarray().coords["ens"].values) - def retrieve_computed_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None: + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None: """ - Retrieve the scores for a given run, stream and metric and epoch. + Load the pre-computed scores for a given run, stream and metric and epoch. Parameters ---------- diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 67c87ebe3..3d0a085e3 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -143,7 +143,7 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: for metric in metrics: - metric_data = reader.retrieve_scores( + metric_data = reader.load_scores( stream, region, metric, From b9c283d5e70269f644c2a759061b24a90411857d Mon Sep 17 00:00:00 2001 From: Ilaria Luise Date: Wed, 5 Nov 2025 13:21:23 +0000 Subject: [PATCH 7/7] lint --- .../src/weathergen/evaluate/io_reader.py | 76 +++++++++---------- .../src/weathergen/evaluate/run_evaluation.py | 9 +-- .../evaluate/src/weathergen/evaluate/utils.py | 1 + 3 files changed, 38 insertions(+), 48 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index b1cf90862..0890b4ad8 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -7,9 +7,9 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import json import logging import re -import json from dataclasses import dataclass from pathlib import Path @@ -71,7 +71,6 @@ class DataAvailability: class Reader: - def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None): """ Generic data reader class. @@ -130,7 +129,7 @@ def get_channels(self, stream: str | None = None) -> list[str]: def get_ensemble(self, stream: str | None = None) -> list[str]: """Placeholder implementation ensemble member names getter. Override in subclass.""" return list() - + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: """Placeholder to load pre-computed scores for a given run, stream, metric""" return None @@ -315,34 +314,34 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili ensemble=None if (ensemble == "all" or ensemble is None) else list(ensemble), ) + ##### Helper function for CSVReader #### def _rename_channels(data) -> pd.DataFrame: - """ - The scores downloaded from Quaver have a different convention. Need renaming. - Rename channel names to include underscore between letters and digits. - E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' + """ + The scores downloaded from Quaver have a different convention. Need renaming. + Rename channel names to include underscore between letters and digits. + E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff' - Parameters - ---------- - name : str - Original channel name. + Parameters + ---------- + name : str + Original channel name. - Returns - ------- - pd.DataFrame - Dataset with renamed channel names. - """ - for name in list(data.index): - # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged - if re.match(r"^\d", name): - continue - - # Otherwise, insert underscore between letters and digits - data = data.rename( - index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)} - ) + Returns + ------- + pd.DataFrame + Dataset with renamed channel names. + """ + for name in list(data.index): + # If it starts with digits (surface vars like 2t, 10ff) → leave unchanged + if re.match(r"^\d", name): + continue + + # Otherwise, insert underscore between letters and digits + data = data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)}) + + return data - return data class CsvReader(Reader): """ @@ -373,9 +372,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.metrics_base_dir = Path(self.csv_path).parent # for backward compatibility allow metric_dir to be specified in the run config self.metrics_dir = Path( - self.eval_cfg.get( - "metrics_dir", self.metrics_base_dir / self.run_id / "evaluation" - ) + self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation") ) assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream." @@ -389,27 +386,27 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non self.region = eval_cfg.get("region") def get_samples(self) -> set[int]: - """ get set of samples for the retrieved scores (initialisation times) """ + """get set of samples for the retrieved scores (initialisation times)""" return set(self.samples) # Placeholder implementation def get_forecast_steps(self) -> set[int]: - """ get set of forecast steps """ + """get set of forecast steps""" return set(self.forecast_steps) # Placeholder implementation # TODO: get this from config def get_channels(self, stream: str | None = None) -> list[str]: - """ get set of channels """ + """get set of channels""" assert stream == self.stream, "streams do not match in CSVReader." return list(self.channels) # Placeholder implementation def get_values(self) -> xr.DataArray: - """ get score values in the right format """ + """get score values in the right format""" return self.data.values[np.newaxis, :, :, np.newaxis].T - + def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: """ Load the existing scores for a given run, stream and metric. - + Parameters ---------- reader : @@ -430,14 +427,9 @@ def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray: available_data = self.check_availability(stream, mode="evaluation") # fill it only for matching metric - if ( - metric == self.metric - and region == self.region - and stream == self.stream - ): + if metric == self.metric and region == self.region and stream == self.stream: data = self.get_values() else: - data = np.full( ( len(available_data.samples), @@ -813,7 +805,7 @@ def get_ensemble(self, stream: str | None = None) -> list[str]: def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None: """ Load the pre-computed scores for a given run, stream and metric and epoch. - + Parameters ---------- reader : diff --git a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py index 3d0a085e3..8884b5e91 100755 --- a/packages/evaluate/src/weathergen/evaluate/run_evaluation.py +++ b/packages/evaluate/src/weathergen/evaluate/run_evaluation.py @@ -28,7 +28,7 @@ calc_scores_per_stream, metric_list_to_json, plot_data, - plot_summary + plot_summary, ) from weathergen.metrics.mlflow_utils import ( MlFlowUpload, @@ -116,9 +116,7 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: elif type == "csv": reader = CsvReader(run, run_id, private_paths) else: - raise ValueError( - f"Unknown run type {type} for run {run_id}. Supported: zarr, csv." - ) + raise ValueError(f"Unknown run type {type} for run {run_id}. Supported: zarr, csv.") for stream in reader.streams: _logger.info(f"RUN {run_id}: Processing stream {stream}...") @@ -142,7 +140,6 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: metrics_to_compute = [] for metric in metrics: - metric_data = reader.load_scores( stream, region, @@ -152,7 +149,7 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None: if metric_data is None: metrics_to_compute.append(metric) continue - + available_data = reader.check_availability( stream, metric_data, mode="evaluation" ) diff --git a/packages/evaluate/src/weathergen/evaluate/utils.py b/packages/evaluate/src/weathergen/evaluate/utils.py index 43cb4a708..6ba654bf0 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils.py @@ -374,6 +374,7 @@ def metric_list_to_json( f"to {reader.metrics_dir}." ) + def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path): """ Plot summary of the evaluation results.