diff --git a/config/cams_eac4_config.yml b/config/cams_eac4_config.yml
new file mode 100644
index 000000000..fa8eebcba
--- /dev/null
+++ b/config/cams_eac4_config.yml
@@ -0,0 +1,20 @@
+streams_directory: "./config/streams/streams_cams_eac4/"
+
+start_date: 200301010000
+# end_date: 202112310000
+# start_date_val: 202201010000
+# end_date_val: 202205300000
+
+num_epochs: 45
+
+# samples_per_epoch: 700
+# samples_per_validation: 200
+# shuffle: True
+
+loader_num_workers: 4
+
+masking_rate: 0.8
+forecast_offset: 1
+
+
+with_mixed_precision: True
\ No newline at end of file
diff --git a/config/config_eval.yml b/config/config_eval.yml
new file mode 100644
index 000000000..df105c63a
--- /dev/null
+++ b/config/config_eval.yml
@@ -0,0 +1,161 @@
+verbose: true
+global_plotting_options:
+  image_format : "png"   # options: "png", "pdf", "svg", "eps", "jpg", ...
+  dpi_val : 300
+
+summary_plots : true
+print_summary: false
+
+evaluation:
+  metrics : ["rmse"]
+  regions: ["global", "nhem"]
+run_ids :
+  wtqfk9i5:
+    label: "CAMS EAC4 forecast epoch=25 dim_embed = 512 token_size = 16 num_blocks = 2"
+    epoch: 0
+    rank: 0
+    streams:
+      # ERA5:
+      #   channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850"]
+      #   evaluation:
+      #     forecast_step: "all"
+      #     sample: "all"
+      #   plotting:
+      #     sample: [0]
+      #     forecast_step: [1, 2, 3, 4, 5, 6, 7, 8]
+      #     plot_maps: true
+      #     plot_histograms: true
+      #     plot_animations: false
+      CAMSEAC4:
+        channels: [
+          # Surface variables
+          'pm1', 'pm2p5', 'pm10',
+          'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2',
+
+          # Ozone (o3)
+          'go3_1000', 'go3_925', 'go3_850', 'go3_700', 'go3_600', 'go3_500',
+          'go3_400', 'go3_300', 'go3_250', 'go3_200', 'go3_150', 'go3_100', 'go3_50',
+
+          # Sulfur dioxide (so2)
+          'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500',
+          'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50',
+
+          # Nitrogen monoxide (no)
+          'no_1000', 'no_925', 'no_850', 'no_700', 'no_600', 'no_500',
+          'no_400', 'no_300', 'no_250', 'no_200', 'no_150', 'no_100', 'no_50',
+
+          # Nitrogen dioxide (no2)
+          'no2_1000', 'no2_925', 'no2_850', 'no2_700', 'no2_600', 'no2_500',
+          'no2_400', 'no2_300', 'no2_250', 'no2_200', 'no2_150', 'no2_100', 'no2_50',
+
+          # Carbon monoxide (co)
+          'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500',
+          'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50',
+        ]
+        evaluation:
+          forecast_step: "all"
+          sample: "all"
+        plotting:
+          sample: [0]
+          forecast_step: [1, 2, 3, 4, 5, 6, 7, 8]
+          plot_maps: true
+          plot_histograms: true
+          plot_animations: true
+  # e8fzh2t1:
+  #   label: "CAMS forecast finetune dim_embed = 256 token_size = 16 num_blocks = 2"
+  #   epoch: 0
+  #   rank: 0
+  #   streams:
+  #     # ERA5:
+  #     #   channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850"]
+  #     #   evaluation:
+  #     #     forecast_step: "all"
+  #     #     sample: "all"
+  #     #   plotting:
+  #     #     sample: [0]
+  #     #     forecast_step: [1, 2, 3, 4, 5, 6, 7, 8]
+  #     #     plot_maps: true
+  #     #     plot_histograms: true
+  #     #     plot_animations: false
+  #     CAMSEAC4:
+  #       channels: [
+  #         # Surface variables
+  #         'pm1', 'pm2p5', 'pm10',
+  #         'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2',
+  #
+  #         # Ozone (o3)
+  #         'go3_1000', 'go3_500', 'go3_250', 'go3_50',
+  #
+  #         # Sulfur dioxide (so2)
+  #         'so2_1000', 'so2_500', 'so2_250', 'so2_50',
+  #
+  #         # Nitrogen monoxide (no)
+  #         'no_1000', 'no_500', 'no_250', 'no_50',
+  #
+  #         # Nitrogen dioxide (no2)
+  #         'no2_1000', 'no2_500', 'no2_250', 'no2_50',
+  #
+  #         # Carbon monoxide (co)
+  #         'co_1000', 'co_500', 'co_250', 'co_50',
+  #       ]
+  #       evaluation:
+  #         forecast_step: "all"
+  #         sample: "all"
+  #       plotting:
+  #         sample: [0]
+  #         forecast_step: [1, 2, 3, 4, 5, 6, 7, 8]
+  #         plot_maps: true
+  #         plot_histograms: true
+  #         plot_animations: true
+  # z6aup4r1:
+  #   label: "CAMS forecast finetune dim_embed = 512 token_size = 32 num_blocks = 2"
+  #   epoch: 0
+  #   rank: 0
+  #   streams:
+  #     # ERA5:
+  #     #   channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850"]
+  #     #   evaluation:
+  #     #     forecast_step: "all"
+  #     #     sample: "all"
+  #     #   plotting:
+  #     #     sample: [0]
+  #     #     forecast_step: [1, 2, 3, 4, 5, 6, 7, 8]
+  #     #     plot_maps: true
+  #     #     plot_histograms: true
+  #     #     plot_animations: false
+  #     CAMSEAC4:
+  #       channels: [
+  #         # Surface variables
+  #         'pm1', 'pm2p5', 'pm10',
+  #         'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2',
+  #
+  #         # Ozone (o3)
+  #         'go3_1000', 'go3_500', 'go3_250', 'go3_50',
+  #
+  #         # Sulfur dioxide (so2)
+  #         'so2_1000', 'so2_500', 'so2_250', 'so2_50',
+  #
+  #         # Nitrogen monoxide (no)
+  #         'no_1000', 'no_500', 'no_250', 'no_50',
+  #
+  #         # Nitrogen dioxide (no2)
+  #         'no2_1000', 'no2_500', 'no2_250', 'no2_50',
+  #
+  #         # Carbon monoxide (co)
+  #         'co_1000', 'co_500', 'co_250', 'co_50',
+  #       ]
+  #       evaluation:
+  #         forecast_step: "all"
+  #         sample: "all"
+  #       plotting:
+  #         sample: [0]
+  #         forecast_step: [1, 2, 3, 4, 5, 6, 7, 8]
+  #         plot_maps: true
+  #         plot_histograms: true
+  #         plot_animations: true
\ No newline at end of file
diff --git a/config/streams/icon/icon.yml b/config/streams/icon/icon.yml
index a38bbdc97..47275aed6 100644
--- a/config/streams/icon/icon.yml
+++ b/config/streams/icon/icon.yml
@@ -10,8 +10,8 @@ ICON :
   type : icon
   filenames : ['icon-art-NWP_OH_CHEMISTRY-chem_DOM01_ML_daily_repeat_reduced_levels.zarr']
-  source : ['u_00', 'v_00', 'w_80', 'temp_00']
-  target : ['u_00', 'v_00', 'w_80', 'temp_00']
+  source_channels : ['u_00', 'v_00', 'w_80', 'temp_00']
+  target_channels : ['u_00', 'v_00', 'w_80', 'temp_00']
   loss_weight : 1.
   diagnostic : False
   masking_rate : 0.6
diff --git a/config/streams/streams_cams_eac4/cams_eac4.yml b/config/streams/streams_cams_eac4/cams_eac4.yml
new file mode 100644
index 000000000..079396005
--- /dev/null
+++ b/config/streams/streams_cams_eac4/cams_eac4.yml
@@ -0,0 +1,112 @@
+# (C) Copyright 2024 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
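+
+# Channel naming convention: surface fields use their short names (e.g. 'pm2p5',
+# 'tc_o3'); profile fields are encoded as '<variable>_<pressure level in hPa>',
+# with the levels drawn from the 'pressure_levels' list below.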
+
+CAMSEAC4 :
+  type : camseac4
+  filenames : ['cams_eac4_2003_2024.zarr']
+  source : [
+    # Surface variables
+    'pm1', 'pm2p5', 'pm10',
+    'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2',
+
+    # Ozone (o3)
+    'go3_1000', 'go3_925', 'go3_850', 'go3_700', 'go3_600', 'go3_500',
+    'go3_400', 'go3_300', 'go3_250', 'go3_200', 'go3_150', 'go3_100', 'go3_50',
+
+    # Sulfur dioxide (so2)
+    'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500',
+    'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50',
+
+    # Nitrogen monoxide (no)
+    'no_1000', 'no_925', 'no_850', 'no_700', 'no_600', 'no_500',
+    'no_400', 'no_300', 'no_250', 'no_200', 'no_150', 'no_100', 'no_50',
+
+    # Nitrogen dioxide (no2)
+    'no2_1000', 'no2_925', 'no2_850', 'no2_700', 'no2_600', 'no2_500',
+    'no2_400', 'no2_300', 'no2_250', 'no2_200', 'no2_150', 'no2_100', 'no2_50',
+
+    # Carbon monoxide (co)
+    'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500',
+    'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50',
+  ]
+  target : [
+    # Surface variables
+    'pm1', 'pm2p5', 'pm10',
+    'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2',
+
+    # Ozone (o3)
+    'go3_1000', 'go3_925', 'go3_850', 'go3_700', 'go3_600', 'go3_500',
+    'go3_400', 'go3_300', 'go3_250', 'go3_200', 'go3_150', 'go3_100', 'go3_50',
+
+    # Sulfur dioxide (so2)
+    'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500',
+    'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50',
+
+    # Nitrogen monoxide (no)
+    'no_1000', 'no_925', 'no_850', 'no_700', 'no_600', 'no_500',
+    'no_400', 'no_300', 'no_250', 'no_200', 'no_150', 'no_100', 'no_50',
+
+    # Nitrogen dioxide (no2)
+    'no2_1000', 'no2_925', 'no2_850', 'no2_700', 'no2_600', 'no2_500',
+    'no2_400', 'no2_300', 'no2_250', 'no2_200', 'no2_150', 'no2_100', 'no2_50',
+
+    # Carbon monoxide (co)
+    'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500',
+    'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50',
+  ]
+  # source_exclude : []
+  # target_exclude : []
+  variables: [
+    # Surface variables
+    'pm1', 'pm2p5', 'pm10',
+    'tc_co', 'tc_no', 'tc_no2', 'tc_o3', 'tc_so2',
+
+    # Ozone (o3)
+    'go3_1000', 'go3_925', 'go3_850', 'go3_700', 'go3_600', 'go3_500',
+    'go3_400', 'go3_300', 'go3_250', 'go3_200', 'go3_150', 'go3_100', 'go3_50',
+
+    # Sulfur dioxide (so2)
+    'so2_1000', 'so2_925', 'so2_850', 'so2_700', 'so2_600', 'so2_500',
+    'so2_400', 'so2_300', 'so2_250', 'so2_200', 'so2_150', 'so2_100', 'so2_50',
+
+    # Nitrogen monoxide (no)
+    'no_1000', 'no_925', 'no_850', 'no_700', 'no_600', 'no_500',
+    'no_400', 'no_300', 'no_250', 'no_200', 'no_150', 'no_100', 'no_50',
+
+    # Nitrogen dioxide (no2)
+    'no2_1000', 'no2_925', 'no2_850', 'no2_700', 'no2_600', 'no2_500',
+    'no2_400', 'no2_300', 'no2_250', 'no2_200', 'no2_150', 'no2_100', 'no2_50',
+
+    # Carbon monoxide (co)
+    'co_1000', 'co_925', 'co_850', 'co_700', 'co_600', 'co_500',
+    'co_400', 'co_300', 'co_250', 'co_200', 'co_150', 'co_100', 'co_50',
+  ]
+  pressure_levels : ["50", "100", "150", "200", "250", "300", "400", "500", "600", "700", "850", "925", "1000"]
+  loss_weight : 1.
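+  # The settings below control masking, tokenization, the embedding network,
+  # the target-coordinate embedding, the target readout and the prediction head
+  # for this stream.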
+  masking_rate : 0.6
+  masking_rate_none : 0.05
+  token_size : 16
+  tokenize_spacetime : True
+  embed :
+    net : transformer
+    num_tokens : 1
+    num_heads : 8
+    dim_embed : 512
+    num_blocks : 2
+  embed_target_coords :
+    net : linear
+    dim_embed : 512
+  target_readout :
+    type : 'obs_value'   # token or obs_value
+    num_layers : 2
+    num_heads : 4
+    # sampling_rate : 0.2
+  pred_head :
+    ens_size : 1
+    num_layers : 1
\ No newline at end of file
diff --git a/config/streams/streams_cams_eac4/era5.yml b/config/streams/streams_cams_eac4/era5.yml
new file mode 100644
index 000000000..5561ef0c6
--- /dev/null
+++ b/config/streams/streams_cams_eac4/era5.yml
@@ -0,0 +1,37 @@
+# (C) Copyright 2024 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+ERA5 :
+  type : anemoi
+  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
+  source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
+  target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
+  loss_weight : 1.
+  masking_rate : 0.6
+  masking_rate_none : 0.05
+  token_size : 8
+  tokenize_spacetime : True
+  max_num_targets: -1
+  embed :
+    net : transformer
+    num_tokens : 1
+    num_heads : 8
+    dim_embed : 256
+    num_blocks : 2
+  embed_target_coords :
+    net : linear
+    dim_embed : 256
+  target_readout :
+    type : 'obs_value'   # token or obs_value
+    num_layers : 2
+    num_heads : 4
+    # sampling_rate : 0.2
+  pred_head :
+    ens_size : 1
+    num_layers : 1
\ No newline at end of file
diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py
index a558a3478..657571351 100644
--- a/packages/common/src/weathergen/common/config.py
+++ b/packages/common/src/weathergen/common/config.py
@@ -64,6 +64,7 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
     Load a configuration file from a given run_id and epoch.
     If run_id is a full path, loads it from the full path.
     """
+    _logger.debug(f"model_path = {model_path}")
     if Path(run_id).exists():  # load from the full path if a full path is provided
         fname = Path(run_id)
         _logger.info(f"Loading config from provided full run_id path: {fname}")
@@ -76,6 +77,7 @@
         )
         model_path = Path(model_path)
         fname = model_path / run_id / _get_model_config_file_name(run_id, epoch)
+        _logger.debug(f"fname = {fname}")
         assert fname.exists(), (
             "The fallback path to the model does not exist. Please provide a `model_path`."
         )
diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py
index 80ec2d65d..53a614364 100644
--- a/packages/evaluate/src/weathergen/evaluate/plotter.py
+++ b/packages/evaluate/src/weathergen/evaluate/plotter.py
@@ -15,7 +15,7 @@
 from weathergen.common.config import _load_private_conf
 from weathergen.evaluate.plot_utils import DefaultMarkerSize
 
-work_dir = Path(_load_private_conf(None)["path_shared_working_dir"]) / "assets/cartopy"
+work_dir = Path("./assets/cartopy")  # was: Path(_load_private_conf(None)["path_shared_working_dir"]) / "assets/cartopy"
 
 cartopy.config["data_dir"] = str(work_dir)
 cartopy.config["pre_existing_data_dir"] = str(work_dir)
diff --git a/src/weathergen/datasets/data_reader_base.py b/src/weathergen/datasets/data_reader_base.py
index 8e01b189f..354a15aa8 100644
--- a/src/weathergen/datasets/data_reader_base.py
+++ b/src/weathergen/datasets/data_reader_base.py
@@ -327,7 +327,6 @@ def __len__(self) -> int:
         -------
         length of dataset
         """
-
         return self.length()
 
     def get_source(self, idx: TIndex) -> ReaderData:
diff --git a/src/weathergen/datasets/data_reader_cams.py b/src/weathergen/datasets/data_reader_cams.py
new file mode 100644
index 000000000..6ba6621e8
--- /dev/null
+++ b/src/weathergen/datasets/data_reader_cams.py
@@ -0,0 +1,293 @@
+import json
+import logging
+import os
+import signal
+from pathlib import Path
+from typing import override
+
+import numpy as np
+import xarray as xr
+
+from weathergen.datasets.data_reader_anemoi import _clip_lat, _clip_lon
+from weathergen.datasets.data_reader_base import (
+    DataReaderTimestep,
+    ReaderData,
+    TimeWindowHandler,
+    TIndex,
+    check_reader_data,
+)
+
+
+def _pfx() -> str:
+    # Debug prefix; helpful when multiple workers/ranks print at once.
+    return f"[DATAREADER DEBUG:{os.environ.get('RANK', '?')}/{os.getpid()}]"
+
+
+class _Timeout(Exception):
+    pass
+
+
+def _alarm_handler(signum, frame):
+    raise _Timeout()
+
+
+# SIGALRM-based watchdog for slow zarr reads; the handler is only effective in
+# the main thread of each (worker) process.
+signal.signal(signal.SIGALRM, _alarm_handler)
+
+############################################################################
+
+_logger = logging.getLogger(__name__)
+
+
+class DataReaderCams(DataReaderTimestep):
+    """Wrapper for CAMS EAC4 data variables."""
+
+    def __init__(
+        self,
+        tw_handler: TimeWindowHandler,
+        filename: Path,
+        stream_info: dict,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        tw_handler : TimeWindowHandler
+            Handles temporal slicing and mapping from time indices to datetimes
+        filename : Path
+            filename (and path) of dataset
+        stream_info : dict
+            Stream metadata
+        """
+
+        # ======= Reading the dataset ================
+        # open groups
+        ds_surface = xr.open_zarr(filename, group="surface", chunks={"time": 24})
+        ds_profiles = xr.open_zarr(filename, group="profiles", chunks={"time": 24})
+
+        # merge along variables
+        self.ds = xr.merge([ds_surface, ds_profiles])
+
+        # Column (variable) names and indices
+        self.colnames = stream_info["variables"]
+        self.cols_idx = np.arange(len(self.colnames))
+
+        # Load associated statistics file for normalization
+        stats_filename = Path(filename).with_name(Path(filename).stem + "_stats.json")
+        with open(stats_filename) as stats_file:
+            self.stats = json.load(stats_file)
+
+        # Variables included in the stats
+        self.stats_vars = list(self.stats)
+
+        # Load mean and standard deviation per variable
+        self.mean = np.array([self.stats[var]["mean"] for var in self.stats_vars], dtype=np.float64)
+        self.stdev = np.array([self.stats[var]["std"] for var in self.stats_vars], dtype=np.float64)
+
+        # Extract coordinates and pressure levels
+        self.lat = _clip_lat(self.ds["latitude"].values)
+        self.lon = _clip_lon(self.ds["longitude"].values)
+        self.levels = stream_info["pressure_levels"]
+
+        # Time range in the dataset
+        self.time = self.ds["time"].values
+        start_ds = np.datetime64(self.time[0])
+        end_ds = np.datetime64(self.time[-1])
+        self.temporal_frequency = self.time[1] - self.time[0]
+
+        if start_ds > tw_handler.t_end or end_ds < tw_handler.t_start:
+            name = stream_info["name"]
+            _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.")
+            super().__init__(tw_handler, stream_info)
+            self.init_empty()
+            return
+
+        # Initialize parent class with resolved time window
+        super().__init__(
+            tw_handler,
+            stream_info,
+            start_ds,
+            end_ds,
+            self.temporal_frequency,
+        )
+
+        # Compute absolute start/end indices in the dataset based on time window
+        self.start_idx = (tw_handler.t_start - start_ds).astype("timedelta64[ns]").astype(int)
+        self.end_idx = (tw_handler.t_end - start_ds).astype("timedelta64[ns]").astype(int) + 1
+
+        # Number of time steps in selected range
+        self.len = self.end_idx - self.start_idx + 1
+
+        # Placeholder; currently unused
+        self.step_hrs = 1
+
+        # Stream metadata
+        self.properties = {
+            "stream_id": 0,
+        }
+
+        # === Normalization statistics ===
+
+        # Ensure stats match dataset columns
+        assert self.stats_vars == self.colnames, (
+            f"Variables in normalization file {self.stats_vars} do not match "
+            f"dataset columns {self.colnames}"
+        )
+
+        # === Channel selection ===
+
+        # Source channels
+        source_channels = stream_info.get("source")
+        if source_channels:
+            self.source_channels, self.source_idx = self.select(source_channels)
+        else:
+            self.source_channels = self.colnames
+            self.source_idx = self.cols_idx
+        # self.source_levels = self.get_levels(self.source_channels)
+
+        # Target channels
+        target_channels = stream_info.get("target")
+        if target_channels:
+            self.target_channels, self.target_idx = self.select(target_channels)
+        else:
+            self.target_channels = self.colnames
+            self.target_idx = self.cols_idx
+        # self.target_levels = self.get_levels(self.target_channels)
+
+        # Ensure all selected channels have valid standard deviations
+        selected_channel_indices = list(set(self.source_idx).union(set(self.target_idx)))
+        non_positive_stds = np.where(self.stdev[selected_channel_indices] <= 0)[0]
+        assert len(non_positive_stds) == 0, (
+            f"Abort: Encountered non-positive standard deviations for selected columns "
+            f"{[self.colnames[selected_channel_indices][i] for i in non_positive_stds]}."
+        )
+
+        # === Geo-info channels (currently unused) ===
+        self.geoinfo_channels = []
+        self.geoinfo_idx = []
+
+    def select(self, ch_filters: list[str]) -> tuple[list[str], np.ndarray]:
+        """
+        Allow the user to specify which columns they want to access.
+        Only the selected columns are exposed through the get functions.
+        Matching is by substring ('f in c'), so a filter such as 'go3_' selects
+        every column whose name contains it.
+
+        Parameters
+        ----------
+        ch_filters: list[str]
+            list of patterns to access
+
+        Returns
+        -------
+        selected_colnames: list[str]
+            Selected columns according to the patterns specified in ch_filters
+        selected_cols_idx: np.ndarray
+            respective indices of these columns in the data array
+        """
+        mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames]
+
+        selected_cols_idx = self.cols_idx[np.where(mask)[0]]
+        selected_colnames = [self.colnames[int(i)] for i in np.where(mask)[0]]
+
+        return selected_colnames, selected_cols_idx
+
+    @override
+    def init_empty(self) -> None:
+        super().init_empty()
+        self.len = 0
+
+    @override
+    def length(self) -> int:
+        """
+        Length of dataset
+
+        Returns
+        -------
+        length of dataset
+        """
+        return self.len
+
+    @override
+    def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
+        """
+        Get data for temporal window
+
+        Returns
+        -------
+        ReaderData providing coords, geoinfos, data, datetimes
+        """
+        (t_idxs, dtr) = self._get_dataset_idxs(idx)
+
+        if self.ds is None or self.len == 0 or len(t_idxs) == 0:
+            return ReaderData.empty(
+                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+            )
+
+        assert t_idxs[0] >= 0, "index must be non-negative"
+        t0 = t_idxs[0]
+        t1 = t_idxs[-1] + 1  # end is exclusive
+        T = t1 - t0
+
+        nlat = len(self.lat)
+        nlon = len(self.lon)
+        # channels to read
+        channels = np.array(self.colnames)[channels_idx].tolist()
+
+        # --- read & shape data to match the anemoi path: (T, C, G) -> (T, G, C) -> (T*G, C)
+        data_per_channel = []
+        try:
+            for ch in channels:
+                ch_parts = ch.split("_")
+                # profile channels are encoded as '<variable>_<pressure level>'
+                if len(ch_parts) == 2 and ch_parts[1] in self.levels:
+                    ch_ = ch_parts[0]
+                    level = int(ch_parts[1])
+                    data_lazy = self.ds[ch_].sel(isobaricInhPa=level)[t0:t1, :, :].astype("float32")
+                # surface channels
+                else:
+                    data_lazy = self.ds[ch][t0:t1, :, :].astype("float32")
+
+                signal.alarm(600)  # watchdog: abort the read after 600 seconds
+                try:
+                    data = data_lazy.compute(scheduler="synchronous").values
+                    data_per_channel.append(data.reshape(T, nlat * nlon))  # (T, G)
+                except _Timeout:
+                    print(f"{_pfx()} idx={idx} TIMEOUT while reading channel '{ch}' [{t0}:{t1}] after 600s", flush=True)
+                    print(f"{_pfx()} idx={idx} TIMEOUT time steps: {self.time[t0:t1]}", flush=True)
+                    # 'data' is unbound when the read itself timed out, and
+                    # continuing would leave this channel missing; return an
+                    # empty sample instead.
+                    return ReaderData.empty(
+                        num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+                    )
+                finally:
+                    signal.alarm(0)  # always cancel the alarm
+
+        except Exception as e:
+            _logger.debug(f"Date not present in CAMS dataset: {str(e)}. Skipping.")
+            return ReaderData.empty(
+                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+            )
+
+        # stack channels to (T, C, G)
+        data_TCG = np.stack(data_per_channel, axis=1)  # (T, C, G)
+        # move channels last and flatten time: (T, G, C) -> (T*G, C)
+        data = np.transpose(data_TCG, (0, 2, 1)).reshape(T * (nlat * nlon), len(channels)).astype(np.float32)
+
+        # --- coords: build flattened [lat, lon] once, then repeat for each time step
+        lon2d, lat2d = np.meshgrid(np.asarray(self.lon), np.asarray(self.lat))  # shapes (nlat, nlon)
+        G = lon2d.size
+        latlon_flat = np.column_stack([lat2d.ravel(order="C"), lon2d.ravel(order="C")])  # (G, 2); LAT first, LON second
+        coords = np.vstack([latlon_flat] * T)  # (T*G, 2)
+
+        # --- datetimes: repeat each timestamp for all grid points
+        datetimes = np.repeat(self.time[t0:t1], G)
+
+        # --- empty geoinfos (match anemoi)
+        geoinfos = np.zeros((data.shape[0], 0), dtype=np.float32)
+
+        rd = ReaderData(
+            coords=coords,
+            geoinfos=geoinfos,
+            data=data,
+            datetimes=datetimes,
+        )
+        check_reader_data(rd, dtr)
+        return rd
\ No newline at end of file
diff --git a/src/weathergen/datasets/data_reader_icon.py b/src/weathergen/datasets/data_reader_icon.py
new file mode 100644
index 000000000..bdc72a810
--- /dev/null
+++ b/src/weathergen/datasets/data_reader_icon.py
@@ -0,0 +1,305 @@
+# (C) Copyright 2025 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import json
+import logging
+from pathlib import Path
+from typing import override
+
+import numpy as np
+import zarr
+
+from weathergen.datasets.data_reader_base import (
+    DataReaderTimestep,
+    ReaderData,
+    TimeWindowHandler,
+    TIndex,
+    check_reader_data,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class DataReaderIcon(DataReaderTimestep):
+    """Wrapper for ICON data variables."""
+
+    def __init__(
+        self,
+        tw_handler: TimeWindowHandler,
+        filename: Path,
+        stream_info: dict,
+    ) -> None:
+        """
+        Construct data reader for ICON data variables
+
+        Parameters
+        ----------
+        tw_handler : TimeWindowHandler
+            Handles temporal slicing and mapping from time indices to datetimes
+        filename : Path
+            filename (and path) of the kerchunk-generated JSON file
+        stream_info : dict
+            information about the stream
+
+        Attributes
+        ----------
+        self.filename
+        self.ds
+        self.mesh_size
+        self.colnames
+        self.cols_idx
+        self.stats
+        self.time
+        self.start_idx
+        self.end_idx
+        self.len
+        self.lat
+        self.lon
+        self.step_hrs
+        self.properties
+        self.mean
+        self.stdev
+        self.source_channels
+        self.source_idx
+        self.target_channels
+        self.target_idx
+        self.geoinfo_channels
+        self.geoinfo_idx
+
+        Returns
+        -------
+        None
+        """
+
+        # loading datafile
+        self.filename = filename
+        self.ds = zarr.open(filename, mode="r")
+        self.mesh_size = self.ds.attrs["ncells"]
+
+        # variables
+        self.colnames = list(self.ds)
+        self.cols_idx = np.arange(len(self.colnames))
+
+        stats_filename = Path(filename).with_name(Path(filename).stem + ".json")
+        with open(stats_filename) as stats_file:
+            self.stats = json.load(stats_file)
+
+        stats_vars = self.stats["metadata"]["variables"]
+        assert stats_vars == self.colnames, (
+            f"Variables in normalization file {stats_vars} do not match dataset columns {self.colnames}"
+        )
+
+        # time
+        self.time = np.array(self.ds["time"], dtype="timedelta64[D]") + np.datetime64(
+            self.ds["time"].attrs["units"].split("since ")[-1]
+        )
+
+        start_ds = self.time[0]
+        end_ds = self.time[-1]
+
+        if start_ds > tw_handler.t_end or end_ds < tw_handler.t_start:
+            name = stream_info["name"]
+            _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.")
+            super().__init__(tw_handler, stream_info)
+            self.init_empty()
+            return
+
+        self.start_idx = (tw_handler.t_start - start_ds).astype("timedelta64[D]").astype(
+            int
+        ) * self.mesh_size
+        self.end_idx = (
+            (tw_handler.t_end - start_ds).astype("timedelta64[D]").astype(int) + 1
+        ) * self.mesh_size - 1
+
+        self.len = (self.end_idx - self.start_idx) // self.mesh_size
+
+        assert self.end_idx > self.start_idx, (
+            f"Abort: Final index {self.end_idx} is not larger than start index {self.start_idx}"
+        )
+
+        # TODO @Asma - use something more generalizable than the first time delta
+        period = self.time[1] - self.time[0]
+
+        super().__init__(
+            tw_handler,
+            stream_info,
+            start_ds,
+            end_ds,
+            period,
+        )
+
+        len_data_entries = len(self.time) * self.mesh_size
+        len_hrs = tw_handler.t_window_len
+        assert self.end_idx + len_hrs <= len_data_entries, (
+            f"Abort: end_date must be set at least {len_hrs} before the last date in the dataset"
+        )
+
+        # coordinates
+        coords_units = self.ds["clat"].attrs["units"]
+
+        if coords_units == "radian":
+            self.lat = np.rad2deg(self.ds["clat"][:].astype("f"))
+            self.lon = np.rad2deg(self.ds["clon"][:].astype("f"))
+        else:
+            self.lat = self.ds["clat"][:].astype("f")
+            self.lon = self.ds["clon"][:].astype("f")
+
+        # TODO: step_hrs is currently ignored; clarify its intended semantics
+        self.step_hrs = 1
+
+        self.properties = {
+            "stream_id": 0,
+        }
+
+        # stats (the variable list was validated against the stats file above)
+        self.mean = np.array(self.stats["statistics"]["mean"], dtype="d")
+        self.stdev = np.array(self.stats["statistics"]["std"], dtype="d")
+
+        source_channels = stream_info.get("source_channels")
+        if source_channels:
+            self.source_channels, self.source_idx = self.select(source_channels)
+        else:
+            self.source_channels = self.colnames
+            self.source_idx = self.cols_idx
+
+        target_channels = stream_info.get("target_channels")
+        if target_channels:
+            self.target_channels, self.target_idx = self.select(target_channels)
+        else:
+            self.target_channels = self.colnames
+            self.target_idx = self.cols_idx
+
+        # Check if standard deviations are strictly positive for selected channels
+        selected_channel_indices = list(set(self.source_idx).union(set(self.target_idx)))
+        non_positive_stds = np.where(self.stdev[selected_channel_indices] <= 0)[0]
+        assert len(non_positive_stds) == 0, (
+            f"Abort: Encountered non-positive standard deviations for selected columns "
+            f"{[self.colnames[selected_channel_indices][i] for i in non_positive_stds]}."
+        )
+
+        self.geoinfo_channels = []
+        self.geoinfo_idx = []
+
+    def select(self, ch_filters: list[str]) -> tuple[list[str], np.ndarray]:
+        """
+        Allow the user to specify which columns they want to access.
+        Only the selected columns are exposed through the get functions.
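+        Matching is by substring ('f in c'), so a filter such as 'u_' selects
+        every column whose name contains it.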
+
+        Parameters
+        ----------
+        ch_filters: list[str]
+            list of patterns to access
+
+        Returns
+        -------
+        selected_colnames: list[str]
+            Selected columns according to the patterns specified in ch_filters
+        selected_cols_idx: np.ndarray
+            respective indices of these columns in the data array
+        """
+        mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames]
+
+        selected_cols_idx = self.cols_idx[np.where(mask)[0]]
+        selected_colnames = [self.colnames[int(i)] for i in np.where(mask)[0]]
+
+        return selected_colnames, selected_cols_idx
+
+    @override
+    def init_empty(self) -> None:
+        super().init_empty()
+        self.len = 0
+
+    @override
+    def length(self) -> int:
+        """
+        Length of dataset
+
+        Returns
+        -------
+        length of dataset
+        """
+        return self.len
+
+    @override
+    def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
+        """
+        Get data for temporal window
+
+        Parameters
+        ----------
+        idx : int
+            Index of temporal window
+        channels_idx : np.array
+            Selection of channels
+
+        Returns
+        -------
+        data (coords, geoinfos, data, datetimes)
+        """
+        (t_idxs, dtr) = self._get_dataset_idxs(idx)
+
+        if self.ds is None or self.len == 0 or len(t_idxs) == 0:
+            return ReaderData.empty(
+                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+            )
+
+        # TODO: handle sub-sampling
+
+        t_idxs_start = t_idxs[0]
+        t_idxs_end = t_idxs[-1] + 1
+
+        # datetimes of the window (one entry per time step)
+        datetimes = self.time[t_idxs_start:t_idxs_end]
+
+        # lat/lon coordinates, stacked once per time step; the (T, 1) reps tuple
+        # gives time-major rows that line up with the data rows below
+        lat = np.tile(self.lat[:, np.newaxis], (len(datetimes), 1))
+        lon = np.tile(self.lon[:, np.newaxis], (len(datetimes), 1))
+
+        coords = np.concatenate([lat, lon], axis=1)
+
+        # time coordinate repeated to match grid points
+        datetimes = np.repeat(datetimes, self.mesh_size)
+
+        # expanding indexes for data
+        start_row = t_idxs_start * self.mesh_size
+        end_row = t_idxs_end * self.mesh_size
+
+        # data
+        channels = np.array(self.colnames)[channels_idx]
+        data_reshaped = [
+            np.asarray(self.ds[ch_]).reshape(-1, 1)[start_row:end_row] for ch_ in channels
+        ]
+        data = np.concatenate(data_reshaped, axis=1)
+
+        # empty geoinfos
+        geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype)
+
+        rd = ReaderData(
+            coords=coords,
+            geoinfos=geoinfos,
+            data=data,
+            datetimes=datetimes,
+        )
+        check_reader_data(rd, dtr)
+
+        return rd
\ No newline at end of file
diff --git a/src/weathergen/datasets/icon_dataset.py b/src/weathergen/datasets/icon_dataset.py
deleted file mode 100644
index abc17e32a..000000000
--- a/src/weathergen/datasets/icon_dataset.py
+++ /dev/null
@@ -1,484 +0,0 @@
-# (C) Copyright 2025 WeatherGenerator contributors.
-#
-# This software is licensed under the terms of the Apache Licence Version 2.0
-# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
-#
-# In applying this licence, ECMWF does not waive the privileges and immunities
-# granted to it by virtue of its status as an intergovernmental organisation
-# nor does it submit to any jurisdiction.
-
-import json
-from datetime import datetime
-from pathlib import Path
-
-import numpy as np
-import torch
-import zarr
-
-
-class IconDataset:
-    """
-    A data reader for ICON model output stored in zarr.
-
-    Parameters
-    ----------
-    start : datetime | int
-        Start time of the data period as datetime object or integer in "%Y%m%d%H%M" format
-    end : datetime | int
-        End time of the data period (inclusive) with same format as start
-    len_hrs : int
-        Length of temporal windows in days
-    step_hrs : int
-        (Currently unused) Intended step size between windows in hours
-    filename : Path
-        Path to Zarr dataset containing ICON output
-    stream_info : dict[str, list[str]]
-        Dictionary with "source" and "target" keys specifying channel subsets to use
-        (e.g., {"source": ["temp_00"], "target": ["TRCH4_chemtr_00"]})
-
-    Attributes
-    ----------
-    len_hrs : int
-        Temporal window length in days
-    mesh_size : int
-        Number of nodes in the ICON mesh
-    source_channels : list[str]
-        Patterns of selected source channels
-    target_channels : list[str]
-        Patterns of selected target channels
-    mean : np.ndarray
-        Per-channel means for normalization (includes coordinates)
-    stdev : np.ndarray
-        Per-channel standard deviations for normalization (includes coordinates)
-    properties : dict[str, list[str]]
-        Dataset metadata including 'stream_id' from Zarr attributes
-
-    """
-
-    def __init__(
-        self,
-        start: datetime | int,
-        end: datetime | int,
-        len_hrs: int,
-        step_hrs: int,
-        filename: Path,
-        stream_info: dict,
-    ):
-        self.len_hrs = len_hrs
-
-        format_str = "%Y%m%d%H%M"
-        if type(start) is not datetime:
-            start = datetime.strptime(str(start), format_str)
-        start = np.datetime64(start).astype("datetime64[D]")
-
-        if type(end) is not datetime:
-            end = datetime.strptime(str(end), format_str)
-        end = np.datetime64(end).astype("datetime64[D]")
-
-        # loading datafile
-        self.filename = filename
-        self.ds = zarr.open(filename, mode="r")
-        self.mesh_size = self.ds.attrs["ncells"]
-
-        # Loading stat file
-        stats_filename = Path(filename).with_suffix(".json")
-        with open(stats_filename) as stats_file:
-            self.stats = json.load(stats_file)
-
-        time_as_in_data_file = np.array(self.ds["time"], dtype="timedelta64[D]") + np.datetime64(
-            self.ds["time"].attrs["units"].split("since ")[-1]
-        )
-
-        start_ds = time_as_in_data_file[0]
-        end_ds = time_as_in_data_file[-1]
-
-        # asserting start and end times
-        if start_ds > end or end_ds < start:
-            # TODO: this should be set in the base class
-            self.source_channels = []
-            self.target_channels = []
-            self.source_idx = np.array([])
-            self.target_idx = np.array([])
-            self.geoinfo_idx = []
-            self.len = 0
-            self.ds = None
-            return
-
-        self.start_idx = (start - start_ds).astype("timedelta64[D]").astype(int) * self.mesh_size
-        self.end_idx = (
-            (end - start_ds).astype("timedelta64[D]").astype(int) + 1
-        ) * self.mesh_size - 1
-
-        self.len = (self.end_idx - self.start_idx) // self.mesh_size
-
-        assert self.end_idx > self.start_idx, (
-            f"Abort: Final index of {self.end_idx} is the same of larger than",
-            f" start index {self.start_idx}",
-        )
-
-        len_data_entries = len(self.ds["time"]) * self.mesh_size
-
-        assert self.end_idx + len_hrs <= len_data_entries, (
-            f"Abort: end_date must be set at least {len_hrs} before the last date in the dataset"
-        )
-
-        # variables
-        self.colnames = list(self.ds)
-        self.cols_idx = np.array(list(np.arange(len(self.colnames))))
-
-        # Ignore step_hrs, idk how it supposed to work
-        # TODO, TODO, TODO:
-        self.step_hrs = 1
-
-        # time
-        repeated_times = np.repeat(time_as_in_data_file, self.mesh_size).reshape(-1, 1)
-        self.time = repeated_times
-
-        # coordinates
-        coords_units = self.ds["clat"].attrs["units"]
-
-        if coords_units == "radian":
-            lat_as_in_data_file = np.rad2deg(self.ds["clat"][:].astype("f"))
-            lon_as_in_data_file = np.rad2deg(self.ds["clon"][:].astype("f"))
-
-        else:
-            lat_as_in_data_file = self.ds["clat"][:].astype("f")
-            lon_as_in_data_file = self.ds["clon"][:].astype("f")
-
-        self.lat = np.tile(lat_as_in_data_file, len(time_as_in_data_file))
-        self.lon = np.tile(lon_as_in_data_file, len(time_as_in_data_file))
-
-        self.properties = {"stream_id": 0}
-
-        # stats
-        stats_vars = self.stats["metadata"]["variables"]
-        assert stats_vars == self.colnames, (
-            f"Variables in normalization file {stats_vars}"
-            f"do not match dataset columns {self.colnames}"
-        )
-
-        self.mean = np.array(self.stats["statistics"]["mean"], dtype="d")
-        self.stdev = np.array(self.stats["statistics"]["std"], dtype="d")
-
-        # Channel selection and indexing
-        source_channels = stream_info["source"] if "source" in stream_info else None
-        if source_channels:
-            self.source_channels, self.source_idx = self.select(source_channels)
-        else:
-            self.source_channels = self.colnames
-            self.source_idx = self.cols_idx
-
-        target_channels = stream_info["target"] if "target" in stream_info else None
-        if target_channels:
-            self.target_channels, self.target_idx = self.select(target_channels)
-        else:
-            self.target_channels = self.colnames
-            self.target_idx = self.cols_idx
-
-        # Check if standard deviations are strictly positive for selected channels
-        selected_channel_indices = list(set(self.source_idx).union(set(self.target_idx)))
-        non_positive_stds = np.where(self.stdev[selected_channel_indices] <= 0)[0]
-        assert len(non_positive_stds) == 0, (
-            f"Abort: Encountered non-positive standard deviations "
-            f"for selected columns {
-                [self.colnames[selected_channel_indices][i] for i in non_positive_stds]
-            }."
-        )
-        # TODO: define in base class
-        self.geoinfo_idx = []
-
-    def select(self, ch_filters: list[str]) -> tuple[list[str], np.array]:
-        """
-        Allow user to specify which columns they want to access.
-        Get functions only returned for these specified columns.
-        """
-
-        mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames]
-
-        selected_cols_idx = np.where(mask)[0]
-        selected_colnames = [self.colnames[i] for i in selected_cols_idx]
-
-        return selected_colnames, selected_cols_idx
-
-    def __len__(self) -> int:
-        """
-        Length of dataset
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        length of dataset
-        """
-        return self.len
-
-    def _get(self, idx: int, channels: np.array) -> tuple:
-        """
-        Get data for window
-
-        Parameters
-        ----------
-        idx : int
-            Index of temporal window
-        channels_idx : np.array
-            Selection of channels
-
-        Returns
-        -------
-        data (coords, geoinfos, data, datetimes)
-        """
-        if self.ds is None:
-            fp32 = np.float32
-            return (
-                np.array([], dtype=fp32),
-                np.array([], dtype=fp32),
-                np.array([], dtype=fp32),
-                np.array([], dtype=fp32),
-            )
-
-        # indexing
-        start_row = self.start_idx + idx * self.mesh_size
-        end_row = start_row + self.len_hrs * self.mesh_size
-
-        # data
-        data_reshaped = [
-            np.asarray(self.ds[ch_]).reshape(-1, 1)[start_row:end_row] for ch_ in channels
-        ]
-        data = np.concatenate(data_reshaped, axis=1)
-
-        lat = np.expand_dims(self.lat[start_row:end_row], 1)
-        lon = np.expand_dims(self.lon[start_row:end_row], 1)
-
-        latlon = np.concatenate([lat, lon], 1)
-
-        # empty geoinfos
-        geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype)
-        datetimes = np.squeeze(self.time[start_row:end_row])
-
-        return (latlon, geoinfos, data, datetimes)
-
-    def get_source(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]:
-        """
-        Get source data for idx
-
-        Parameters
-        ----------
-        idx : int
-            Index of temporal window
-
-        Returns
-        -------
-        source data (coords, geoinfos, data, datetimes)
-        """
-        return self._get(idx, self.source_channels)
-
-    def get_target(self, idx: int) -> tuple[np.array, np.array, np.array, np.array]:
-        """
-        Get target data for idx
-
-        Parameters
-        ----------
-        idx : int
-            Index of temporal window
-
-        Returns
-        -------
-        target data (coords, geoinfos, data, datetimes)
-        """
-        return self._get(idx, self.target_channels)
-
-    def get_source_size(self) -> int:
-        """
-        Get size of all columns, including coordinates and geoinfo, with source
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        size of coords
-        """
-        return 2 + len(self.geoinfo_idx) + len(self.source_idx) if self.ds else 0
-
-    def get_target_size(self) -> int:
-        """
-        Get size of all columns, including coordinates and geoinfo, with source
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        size of coords
-        """
-        return 2 + len(self.geoinfo_idx) + len(self.target_idx) if self.ds else 0
-
-    def get_coords_size(self) -> int:
-        """
-        Get size of coords
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        size of coords
-        """
-        return 2
-
-    def normalize_coords(self, coords: torch.tensor) -> torch.tensor:
-        """
-        Normalize coordinates
-
-        Parameters
-        ----------
-        coords :
-            coordinates to be normalized
-
-        Returns
-        -------
-        Normalized coordinates
-        """
-        coords[..., 0] = np.sin(np.deg2rad(coords[..., 0]))
-        coords[..., 1] = np.sin(0.5 * np.deg2rad(coords[..., 1]))
-
-        return coords
-
-    def normalize_source_channels(self, source: torch.tensor) -> torch.tensor:
-        """
-        Normalize source channels
-
-        Parameters
-        ----------
-        source :
-            data to be normalized
-
-        Returns
-        -------
-        Normalized data
-        """
-        assert source.shape[1] == len(self.source_idx)
-        for i, ch in enumerate(self.source_idx):
-            source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch]
-
-        return source
-
-    def normalize_target_channels(self, target: torch.tensor) -> torch.tensor:
-        """
-        Normalize target channels
-
-        Parameters
-        ----------
-        target :
-            data to be normalized
-
-        Returns
-        -------
-        Normalized data
-        """
-        assert target.shape[1] == len(self.target_idx)
-        for i, ch in enumerate(self.target_idx):
-            target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch]
-
-        return target
-
-    def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]:
-        """
-        Temporal window corresponding to index
-
-        Parameters
-        ----------
-        idx :
-            index of temporal window
-
-        Returns
-        -------
-        start and end of temporal window
-        """
-        start_row = self.start_idx + idx * self.mesh_size
-        end_row = start_row + self.len_hrs * self.mesh_size
-
-        return (self.time[start_row, 0], self.time[end_row, 0])
-
-    def denormalize_target_channels(self, data: torch.tensor) -> torch.tensor:
-        """
-        Denormalize target channels
-
-        Parameters
-        ----------
-        data :
-            data to be denormalized (target or pred)
-
-        Returns
-        -------
-        Denormalized data
-        """
-        assert data.shape[-1] == len(self.target_idx), "incorrect number of channels"
-        for i, ch in enumerate(self.target_idx):
-            data[..., i] = (data[..., i] * self.stdev[ch]) + self.mean[ch]
-
-        return data
-
-    def get_source_num_channels(self) -> int:
-        """
-        Get number of source channels
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        number of source channels
-        """
-        return len(self.source_idx)
-
-    def get_target_num_channels(self) -> int:
-        """
-        Get number of target channels
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        number of target channels
-        """
-        return len(self.target_idx)
-
-    def get_geoinfo_size(self) -> int:
-        """
-        Get size of geoinfos
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        size of geoinfos
-        """
-        return len(self.geoinfo_idx)
-
-    def normalize_geoinfos(self, geoinfos: torch.tensor) -> torch.tensor:
-        """
-        Normalize geoinfos
-
-        Parameters
-        ----------
-        geoinfos :
-            geoinfos to be normalized
-
-        Returns
-        -------
-        Normalized geoinfo
-        """
-
-        assert geoinfos.shape[-1] == 0, "incorrect number of geoinfo channels"
-        return geoinfos
diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py
index c41916f9d..57e682b0e 100644
--- a/src/weathergen/datasets/multi_stream_data_sampler.py
+++ b/src/weathergen/datasets/multi_stream_data_sampler.py
@@ -23,7 +23,8 @@
 )
 from weathergen.datasets.data_reader_fesom import DataReaderFesom
 from weathergen.datasets.data_reader_obs import DataReaderObs
-from weathergen.datasets.icon_dataset import IconDataset
+from weathergen.datasets.data_reader_icon import DataReaderIcon
+from weathergen.datasets.data_reader_cams import DataReaderCams
 from weathergen.datasets.masking import Masker
 from weathergen.datasets.stream_data import StreamData
 from weathergen.datasets.tokenizer_forecast import TokenizerForecast
@@ -106,8 +107,11 @@ def __init__(
                 dataset = DataReaderFesom
                 datapath = cf.data_path_fesom
             case "icon":
-                dataset = IconDataset
+                dataset = DataReaderIcon
                 datapath = cf.data_path_icon
+            case "camseac4":
+                dataset = DataReaderCams
+                datapath = cf.data_path_cams
             case _:
                 msg = f"Unsupported stream type {stream_info['type']}"
                 f"for stream name '{stream_info['name']}'."
@@ -300,6 +304,7 @@ def __iter__(self):
             # idx_raw is used to index into the dataset; the decoupling is needed
             # since there are empty batches
             idx_raw = iter_start
+
             for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)):
                 # forecast_dt needs to be constant per batch (amortized through data parallel training)
                 forecast_dt = self.perms_forecast_dt[i]
diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py
index 6b954b717..f0c3900b9 100644
--- a/src/weathergen/run_train.py
+++ b/src/weathergen/run_train.py
@@ -92,8 +92,8 @@ def train_continue_from_args(argl: list[str]):
     if args.finetune_forecast:
         finetune_overwrite = dict(
             training_mode="forecast",
-            forecast_delta_hrs=0,  # 12
-            forecast_steps=1,  # [j for j in range(1,9) for i in range(4)]
+            forecast_delta_hrs=6,
+            forecast_steps=7,  # alternatively e.g. [j for j in range(1, 9) for i in range(4)]
             forecast_policy="fixed",  # 'sequential_random' # 'fixed' #'sequential' #_random'
             forecast_freeze_model=True,
             forecast_att_dense_rate=1.0,  # 0.25
diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py
index f3eb8850e..3c731a2e1 100644
--- a/src/weathergen/train/trainer.py
+++ b/src/weathergen/train/trainer.py
@@ -95,6 +95,11 @@ def init(
 
     def inference(self, cf, run_id_trained, epoch):
         # general initalization
+
+        # Temporary workaround (currently disabled): force batch size 1 for
+        # inference.
+        # cf.batch_size_per_gpu = 1
+        # cf.batch_size_validation_per_gpu = 1
         self.init(cf)
         cf = self.cf
 
@@ -322,8 +327,11 @@ def run(self, cf, run_id_contd=None, epoch_contd=None):
         self.loss_calculator_val = LossCalculator(cf=cf, stage=VAL, device=self.devices[0])
 
         # recover epoch when continuing run
+        logger.debug(f"self.num_ranks_original = {self.num_ranks_original}")
         if self.num_ranks_original is None:
             epoch_base = int(self.cf.istep / len(self.data_loader))
+        elif epoch_contd is not None:
+            # resume from the epoch after the one that was loaded
+            epoch_base = epoch_contd + 1
         else:
             len_per_rank = (
                 len(self.dataset) // (self.num_ranks_original * cf.batch_size_per_gpu)
@@ -332,6 +340,7 @@
             self.cf.istep / (min(len_per_rank, cf.samples_per_epoch) * self.num_ranks_original)
         )
 
+        # torch.autograd.set_detect_anomaly(True)  # uncomment to debug autograd anomalies
         if cf.forecast_policy is not None:
             torch._dynamo.config.optimize_ddp = False
 
@@ -346,6 +355,10 @@
         if cf.val_initial:
             self.validate(-1)
 
+        logger.debug(f"epoch_base = {epoch_base}, epoch_contd = {epoch_contd}")
+        logger.debug(f"num_epochs = {cf.num_epochs}")
+
         for epoch in range(epoch_base, cf.num_epochs):
             logger.info(f"Epoch {epoch} of {cf.num_epochs}: train.")
             self.train(epoch)
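
Editorial note (not part of the patch): the new DataReaderCams._get flattens gridded fields into time-major rows with channels last, and builds the coords and datetimes arrays to line up with that flattening. A minimal, self-contained NumPy sketch of the convention, with synthetic sizes and values:

import numpy as np

# Synthetic sizes: T time steps, C channels, an nlat x nlon grid with G points.
T, C, nlat, nlon = 2, 3, 4, 5
G = nlat * nlon

data_TCG = np.arange(T * C * G, dtype=np.float32).reshape(T, C, G)

# (T, C, G) -> (T, G, C) -> (T*G, C): time-major rows, channels last.
data = np.transpose(data_TCG, (0, 2, 1)).reshape(T * G, C)

# Coordinates: one flattened [lat, lon] block per time step.
lat = np.linspace(-60.0, 60.0, nlat)
lon = np.linspace(0.0, 288.0, nlon)
lon2d, lat2d = np.meshgrid(lon, lat)                           # shapes (nlat, nlon)
latlon_flat = np.column_stack([lat2d.ravel(), lon2d.ravel()])  # (G, 2)
coords = np.vstack([latlon_flat] * T)                          # (T*G, 2)

# Datetimes: each timestamp repeated for all G grid points.
times = np.array(["2003-01-01T00", "2003-01-01T03"], dtype="datetime64[h]")
datetimes = np.repeat(times, G)

# Row k corresponds to time step k // G and grid point k % G.
k = G + 7  # second time step, eighth grid point
assert np.array_equal(data[k], data_TCG[1, :, 7])
assert datetimes[k] == times[1]
assert np.array_equal(coords[k], latlon_flat[7])

The same convention explains the paired tile/repeat calls in DataReaderIcon._get: the mesh coordinates are stacked once per time step, while each timestamp is repeated once per mesh point.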