Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion packages/evaluate/src/weathergen/evaluate/io_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import omegaconf as oc
import pandas as pd
import xarray as xr
from tqdm import tqdm

Expand Down Expand Up @@ -69,6 +70,9 @@ class DataAvailability:


class Reader:

data: pd.DataFrame | None # Data attributes (if specified)

def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] | None = None):
"""
Generic data reader class.
Expand All @@ -85,8 +89,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] |
self.eval_cfg = eval_cfg
self.run_id = run_id
self.private_paths = private_paths

self.streams = eval_cfg.streams.keys()
self.data = None

# If results_base_dir and model_base_dir are not provided, default paths are used
self.model_base_dir = self.eval_cfg.get("model_base_dir", None)
Expand Down Expand Up @@ -297,6 +301,96 @@ def _get_channels_fsteps_samples(self, stream: str, mode: str) -> DataAvailabili
ensemble=None if (ensemble == "all" or ensemble is None) else list(ensemble),
)

##### Helper function for CSVReader ####
def _rename_channels(data) -> pd.DataFrame:
"""
The scores downloaded from Quaver have a different convention. Need renaming.
Rename channel names to include underscore between letters and digits.
E.g., 'z500' -> 'z_500', 't850' -> 't_850', '2t' -> '2t', '10ff' -> '10ff'

Parameters
----------
name : str
Original channel name.

Returns
-------
pd.DataFrame
Dataset with renamed channel names.
"""
for name in list(data.index):
# If it starts with digits (surface vars like 2t, 10ff) → leave unchanged
if re.match(r"^\d", name):
continue

# Otherwise, insert underscore between letters and digits
data = data.rename(
index={name: re.sub(r"([a-zA-Z])(\d+)", r"\1_\2", name)}
)

return data

class CsvReader(Reader):
    """
    Reader class to read evaluation scores from a CSV file.

    The CSV file is loaded into a pandas DataFrame with its first column as index.
    The index labels are used as channel names and the leading integer of every
    column label is used as forecast step.
    """

    def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
        """
        Initialize the CsvReader.

        Parameters
        ----------
        eval_cfg : dict
            config with plotting and evaluation options for that run id
        run_id : str
            run id of the model
        private_paths : dict | None
            private paths for the supported HPC

        Raises
        ------
        ValueError
            If ``csv_path`` is missing from the config or the config does not
            define exactly one stream.
        """

        super().__init__(eval_cfg, run_id, private_paths)

        # Validate with explicit exceptions: assert statements are removed under `python -O`.
        self.csv_path = eval_cfg.get("csv_path")
        if self.csv_path is None:
            raise ValueError("CSV path must be provided in the config.")

        stream_names = list(eval_cfg.streams.keys())
        if len(stream_names) != 1:
            raise ValueError("CsvReader only supports one stream.")

        pd_data = pd.read_csv(self.csv_path, index_col=0)

        self.data = _rename_channels(pd_data)
        self.metrics_base_dir = Path(self.csv_path).parent
        # for backward compatibility allow metric_dir to be specified in the run config
        self.metrics_dir = Path(
            self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation")
        )

        self.stream = stream_names[0]
        # Rows of the table are channels, columns are forecast steps.
        self.channels = self.data.index.tolist()
        self.samples = [0]
        # Only the first whitespace-separated token of a column label is parsed as the step.
        self.forecast_steps = [int(col.split()[0]) for col in self.data.columns]
        self.npoints_per_sample = [0]
        self.epoch = eval_cfg.get("epoch", 0)
        self.metric = eval_cfg.get("metric")
        self.region = eval_cfg.get("region")

    def get_samples(self) -> set[int]:
        """Get the set of samples for the retrieved scores (initialisation times)."""
        return set(self.samples)  # Placeholder implementation

    def get_forecast_steps(self) -> set[int]:
        """Get the set of forecast steps."""
        return set(self.forecast_steps)  # Placeholder implementation

    # TODO: get this from config
    def get_channels(self, stream: str | None = None) -> list[str]:
        """
        Get the list of channels.

        Parameters
        ----------
        stream : str | None
            Stream name to check against the single stream of this reader.
            ``None`` skips the check.

        Returns
        -------
        list[str]
            Copy of the channel names.

        Raises
        ------
        ValueError
            If ``stream`` is given and differs from the stream of this reader.
        """
        # The default (None) must not fail the consistency check.
        if stream is not None and stream != self.stream:
            raise ValueError("streams do not match in CSVReader.")
        return list(self.channels)  # Placeholder implementation

    def get_values(self) -> np.ndarray:
        """
        Get the score values in the right format.

        Returns
        -------
        np.ndarray
            Array of shape (1, n_forecast_steps, n_channels, 1).
        """
        # The table is (channel, forecast_step): swap to (forecast_step, channel) first,
        # then add the two singleton axes at the front and at the back.
        return self.data.values.T[np.newaxis, :, :, np.newaxis]

class WeatherGenReader(Reader):
def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None):
Expand Down
16 changes: 12 additions & 4 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from omegaconf import OmegaConf

from weathergen.common.config import _REPO_ROOT
from weathergen.evaluate.io_reader import WeatherGenReader
from weathergen.evaluate.io_reader import CsvReader, WeatherGenReader
from weathergen.evaluate.utils import (
calc_scores_per_stream,
metric_list_to_json,
plot_data,
plot_summary,
retrieve_metric_from_json,
retrieve_metric_from_file,
)

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -83,7 +83,15 @@ def evaluate_from_config(cfg):
for run_id, run in runs.items():
_logger.info(f"RUN {run_id}: Getting data...")

reader = WeatherGenReader(run, run_id, private_paths)
type = run.get("type", "zarr")
if type == "zarr":
reader = WeatherGenReader(run, run_id, private_paths)
elif type == "csv":
reader = CsvReader(run, run_id, private_paths)
else:
raise ValueError(
f"Unknown run type {type} for run {run_id}. Supported: zarr, csv."
)

for stream in reader.streams:
_logger.info(f"RUN {run_id}: Processing stream {stream}...")
Expand All @@ -107,7 +115,7 @@ def evaluate_from_config(cfg):

for metric in metrics:
try:
metric_data = retrieve_metric_from_json(
metric_data = retrieve_metric_from_file(
reader,
stream,
region,
Expand Down
60 changes: 48 additions & 12 deletions packages/evaluate/src/weathergen/evaluate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,9 @@ def metric_list_to_json(
)


def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric: str):
def retrieve_metric_from_file(reader: Reader, stream: str, region: str, metric: str):
"""
Retrieve the score for a given run, stream, metric, epoch, and rank from a JSON file.
Retrieve the score for a given run, stream, metric, epoch, and rank from a given file (Json or csv).

Parameters
----------
Expand All @@ -378,18 +378,54 @@ def retrieve_metric_from_json(reader: Reader, stream: str, region: str, metric:
xr.DataArray
The metric DataArray.
"""
score_path = (
Path(reader.metrics_dir)
/ f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")
if reader.data is not None:
available_data = reader.check_availability(stream, mode="evaluation")

# fill it only for matching metric
if (
metric == reader.metric
and region == reader.region
and stream == reader.stream
):
data = reader.get_values()
else:

if score_path.exists():
with open(score_path) as f:
data_dict = json.load(f)
return xr.DataArray.from_dict(data_dict)
data = np.full(
(
len(available_data.samples),
len(available_data.fsteps),
len(available_data.channels),
1,
),
np.nan,
)

da = xr.DataArray(
data.astype(np.float32),
dims=("sample", "forecast_step", "channel", "metric"),
coords={
"sample": available_data.samples,
"forecast_step": available_data.fsteps,
"channel": available_data.channels,
"metric": [metric],
},
attrs={"npoints_per_sample": reader.npoints_per_sample},
)

return da
else:
raise FileNotFoundError(f"File {score_path} not found in the archive.")
score_path = (
Path(reader.metrics_dir)
/ f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")

if score_path.exists():
with open(score_path) as f:
data_dict = json.load(f)
return xr.DataArray.from_dict(data_dict)
else:
raise FileNotFoundError(f"File {score_path} not found in the archive.")


def plot_summary(cfg: dict, scores_dict: dict, summary_dir: Path):
Expand Down