181 changes: 180 additions & 1 deletion packages/evaluate/src/weathergen/evaluate/io_reader.py
@@ -7,13 +7,15 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import omegaconf as oc
import pandas as pd
import xarray as xr
from tqdm import tqdm

@@ -85,8 +87,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] |
self.eval_cfg = eval_cfg
self.run_id = run_id
self.private_paths = private_paths

self.streams = eval_cfg.streams.keys()
self.data = None

# If results_base_dir and model_base_dir are not provided, default paths are used
self.model_base_dir = self.eval_cfg.get("model_base_dir", None)
Expand Down Expand Up @@ -128,6 +130,10 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
"""Placeholder implementation ensemble member names getter. Override in subclass."""
return list()

def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
"""Placeholder to load pre-computed scores for a given run, stream and metric. Override in subclass."""
return None

def check_availability(
self,
stream: str,
@@ -309,6 +315,146 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili
)


#### Helper function for CsvReader ####
def _rename_channels(data: pd.DataFrame) -> pd.DataFrame:
"""
Scores downloaded from Quaver follow a different naming convention, so channel
names need renaming: insert an underscore between letters and digits.
E.g., 'z500' -> 'z_500', 't850' -> 't_850'; surface variables such as '2t'
and '10ff' are left unchanged.

Parameters
----------
data : pd.DataFrame
DataFrame whose index holds the original channel names.

Returns
-------
pd.DataFrame
DataFrame with renamed channel names in the index.
"""
for name in list(data.index):
# If it starts with digits (surface vars like 2t, 10ff) → leave unchanged
if re.match(r"^\d", name):
continue

# Otherwise, insert underscore between letters and digits
data = data.rename(index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)})

return data

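# A minimal usage sketch of _rename_channels (illustrative data; pd is imported
# at module top): pressure-level channels gain an underscore, surface channels
# keep their names.
scores = pd.DataFrame({"6": [0.1, 0.2, 0.3]}, index=["z500", "t850", "2t"])
renamed = _rename_channels(scores)
assert list(renamed.index) == ["z_500", "t_850", "2t"]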

class CsvReader(Reader):
"""
Reader class to read evaluation data from CSV files and convert to xarray DataArray.
"""

def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
"""
Initialize the CsvReader.

Parameters
----------
eval_cfg : dict
config with plotting and evaluation options for that run id
run_id : str
run id of the model
private_paths : dict | None
mapping of private paths for the supported HPC
"""

super().__init__(eval_cfg, run_id, private_paths)
self.csv_path = eval_cfg.get("csv_path")
assert self.csv_path is not None, "CSV path must be provided in the config."

pd_data = pd.read_csv(self.csv_path, index_col=0)

self.data = _rename_channels(pd_data)
self.metrics_base_dir = Path(self.csv_path).parent
# for backward compatibility, allow metrics_dir to be specified in the run config
self.metrics_dir = Path(
self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation")
)

assert len(eval_cfg.streams.keys()) == 1, "CsvReader only supports one stream."
self.stream = list(eval_cfg.streams.keys())[0]
self.channels = self.data.index.tolist()
self.samples = [0]
# each column label is expected to start with the integer forecast step
self.forecast_steps = [int(col.split()[0]) for col in self.data.columns]
self.npoints_per_sample = [0]
self.epoch = eval_cfg.get("epoch", 0)
self.metric = eval_cfg.get("metric")
self.region = eval_cfg.get("region")

def get_samples(self) -> set[int]:
"""get set of samples for the retrieved scores (initialisation times)"""
return set(self.samples)

def get_forecast_steps(self) -> set[int]:
"""get set of forecast steps"""
return set(self.forecast_steps)

# TODO: get this from config
def get_channels(self, stream: str | None = None) -> list[str]:
"""get set of channels"""
assert stream == self.stream, "streams do not match in CsvReader."
return list(self.channels)

def get_values(self) -> np.ndarray:
"""Return score values reshaped to (sample, forecast_step, channel, metric)."""
return self.data.values[np.newaxis, :, :, np.newaxis].T

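# Shape note (worked example): self.data.values is (channel, forecast_step), so
# np.arange(6).reshape(3, 2)[np.newaxis, :, :, np.newaxis].T has shape
# (1, 2, 3, 1), i.e. one sample, 2 forecast steps, 3 channels, one metric.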
def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray:
"""
Load the existing scores for a given run, stream, region and metric.

Parameters
----------
stream :
Stream name.
region :
Region name.
metric :
Metric name.

Returns
-------
xr.DataArray
The metric DataArray.
"""

available_data = self.check_availability(stream, mode="evaluation")

# fill values only when metric, region and stream all match; otherwise return NaNs
if metric == self.metric and region == self.region and stream == self.stream:
data = self.get_values()
else:
data = np.full(
(
len(available_data.samples),
len(available_data.fsteps),
len(available_data.channels),
1,
),
np.nan,
)

da = xr.DataArray(
data.astype(np.float32),
dims=("sample", "forecast_step", "channel", "metric"),
coords={
"sample": available_data.samples,
"forecast_step": available_data.fsteps,
"channel": available_data.channels,
"metric": [metric],
},
attrs={"npoints_per_sample": self.npoints_per_sample},
)

return da

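# End-to-end sketch for CsvReader (hypothetical run id, path and stream name;
# key names mirror those read in __init__, and oc is imported at module top):
run_cfg = oc.OmegaConf.create(
    {
        "type": "csv",  # selects CsvReader in run_evaluation.py
        "csv_path": "/path/to/quaver_scores.csv",
        "streams": {"ERA5": {}},  # CsvReader supports exactly one stream
        "epoch": 0,
        "metric": "rmse",
        "region": "global",
    }
)
reader = CsvReader(run_cfg, run_id="quaver_baseline", private_paths=None)
scores = reader.load_scores(stream="ERA5", region="global", metric="rmse")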

class WeatherGenReader(Reader):
def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
"""Data reader class for WeatherGenerator model outputs stored in Zarr format."""
@@ -656,6 +802,39 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
dummy = zio.get_data(0, stream, zio.forecast_steps[0])
return list(dummy.prediction.as_xarray().coords["ens"].values)

def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
"""
Load the pre-computed scores for a given run, stream, region, metric and epoch.

Parameters
----------
stream :
Stream name.
region :
Region name.
metric :
Metric name.

Returns
-------
xr.DataArray
The metric DataArray or None if the file does not exist.
"""
score_path = (
Path(self.metrics_dir)
/ f"{self.run_id}_{stream}_{region}_{metric}_epoch{self.epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")

if score_path.exists():
with open(score_path) as f:
data_dict = json.load(f)
return xr.DataArray.from_dict(data_dict)
else:
return None

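# Round-trip sketch for the JSON score files load_scores reads: a DataArray
# serialised with to_dict() under the filename pattern above (run id, coords and
# values here are illustrative; json, np and xr are imported at module top).
da = xr.DataArray(
    np.zeros((1, 2, 1, 1), dtype=np.float32),
    dims=("sample", "forecast_step", "channel", "metric"),
    coords={"sample": [0], "forecast_step": [6, 12], "channel": ["z_500"], "metric": ["rmse"]},
)
with open("run01_ERA5_global_rmse_epoch00000.json", "w") as f:
    json.dump(da.to_dict(), f)
with open("run01_ERA5_global_rmse_epoch00000.json") as f:
    restored = xr.DataArray.from_dict(json.load(f))
assert restored.dims == da.dims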
def get_inference_stream_attr(self, stream_name: str, key: str, default=None):
"""
Get the value of a key for a specific stream from the model config.
51 changes: 28 additions & 23 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -22,14 +22,13 @@

from weathergen.common.config import _REPO_ROOT
from weathergen.common.platform_env import get_platform_env
from weathergen.evaluate.io_reader import WeatherGenReader
from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader
from weathergen.evaluate.plot_utils import collect_channels
from weathergen.evaluate.utils import (
calc_scores_per_stream,
metric_list_to_json,
plot_data,
plot_summary,
retrieve_metric_from_json,
)
from weathergen.metrics.mlflow_utils import (
MlFlowUpload,
@@ -111,7 +110,13 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None:
for run_id, run in runs.items():
_logger.info(f"RUN {run_id}: Getting data...")

reader = WeatherGenReader(run, run_id, private_paths)
run_type = run.get("type", "zarr")
if run_type == "zarr":
reader = WeatherGenReader(run, run_id, private_paths)
elif run_type == "csv":
reader = CsvReader(run, run_id, private_paths)
else:
raise ValueError(f"Unknown run type {run_type} for run {run_id}. Supported: zarr, csv.")

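# The run "type" key picks the reader; config sketch (hypothetical run ids):
#
#   runs:
#     my_model_run: {}        # no type -> defaults to zarr -> WeatherGenReader
#     quaver_baseline:
#       type: csv             # -> CsvReader
#       csv_path: /path/to/quaver_scores.csv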
for stream in reader.streams:
_logger.info(f"RUN {run_id}: Processing stream {stream}...")
@@ -135,29 +140,29 @@ def evaluate_from_config(cfg, mlflow_client: MlflowClient | None) -> None:
metrics_to_compute = []

for metric in metrics:
try:
metric_data = retrieve_metric_from_json(
reader,
stream,
region,
metric,
)
metric_data = reader.load_scores(
stream,
region,
metric,
)

available_data = reader.check_availability(
stream, metric_data, mode="evaluation"
)
if metric_data is None:
metrics_to_compute.append(metric)
continue

available_data = reader.check_availability(
stream, metric_data, mode="evaluation"
)

if not available_data.score_availability:
metrics_to_compute.append(metric)
else:
# simply select the chosen eval channels, samples, fsteps here...
scores_dict[metric][region][stream][run_id] = metric_data.sel(
sample=available_data.samples,
channel=available_data.channels,
forecast_step=available_data.fsteps,
)
except (FileNotFoundError, KeyError):
if not available_data.score_availability:
metrics_to_compute.append(metric)
else:
# simply select the chosen eval channels, samples, fsteps here...
scores_dict[metric][region][stream][run_id] = metric_data.sel(
sample=available_data.samples,
channel=available_data.channels,
forecast_step=available_data.fsteps,
)

if metrics_to_compute:
all_metrics, points_per_sample = calc_scores_per_stream(
34 changes: 0 additions & 34 deletions packages/evaluate/src/weathergen/evaluate/utils.py
@@ -375,40 +375,6 @@ def metric_list_to_json(
)


def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str):
"""
Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file.

Parameters
----------
reader :
Reader object containing all info for a specific run_id
stream :
Stream name.
region :
Region name.
metric :
Metric name.

Returns
-------
xr.DataArray
The metric DataArray.
"""
score_path = (
Path(reader.metrics_dir)
/ f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")

if score_path.exists():
with open(score_path) as f:
data_dict = json.load(f)
return xr.DataArray.from_dict(data_dict)
else:
raise FileNotFoundError(f"File {score_path} not found in the archive.")


def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path):
"""
Plot summary of the evaluation results.