diff --git a/config/default_config.yml b/config/default_config.yml
index e40fa3b7..679f58dd 100644
--- a/config/default_config.yml
+++ b/config/default_config.yml
@@ -86,6 +86,11 @@ batch_size_validation_per_gpu: 1
 # encoders and decoders that exist per stream have the stream name attached at the end
 freeze_modules: ""
 
+# whether to track the exponential moving average of weights for validation
+validate_with_ema: True
+ema_ramp_up_ratio: 0.09
+ema_halflife_in_thousands: 1e-3
+
 # training mode: "forecast" or "masking" (masked token modeling)
 # for "masking" to train with auto-encoder mode, forecast_offset should be 0
 training_mode: "masking"
diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py
index dc0d5b54..61fd8015 100644
--- a/packages/common/src/weathergen/common/config.py
+++ b/packages/common/src/weathergen/common/config.py
@@ -100,6 +100,7 @@ def _get_model_config_file_name(run_id: str, epoch: int | None):
         epoch_str = f"_epoch{epoch:05d}"
     return f"model_{run_id}{epoch_str}.json"
 
+
 def get_model_results(run_id: str, epoch: int, rank: int) -> Path:
     """
     Get the path to the model results zarr store from a given run_id and epoch.
@@ -110,6 +111,7 @@ def get_model_results(run_id: str, epoch: int, rank: int) -> Path:
         raise FileNotFoundError(f"Zarr file {zarr_path} does not exist or is not a directory.")
     return zarr_path
 
+
 def _apply_fixes(config: Config) -> Config:
     """
     Apply fixes to maintain a best effort backward combatibility.
@@ -129,12 +131,13 @@ def _check_logging(config: Config) -> Config:
     """
     config = config.copy()
     if config.get("train_log_freq") is None:  # TODO remove this for next version
-        config.train_log_freq = OmegaConf.construct(
+        config.train_log_freq = OmegaConf.create(
            {"checkpoint": 250, "terminal": 10, "metrics": config.train_log.log_interval}
        )
 
     return config
 
+
 def load_config(
     private_home: Path | None,
     from_run_id: str | None,
diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py
index 740fb2d0..c24441e4 100644
--- a/packages/common/src/weathergen/common/io.py
+++ b/packages/common/src/weathergen/common/io.py
@@ -13,7 +13,9 @@
 import logging
 import pathlib
 import typing
+from copy import deepcopy
 
+import astropy_healpix as hp
 import dask.array as da
 import numpy as np
 import xarray as xr
@@ -65,6 +67,45 @@ def create(cls, other: typing.Any) -> "IOReaderData":
         return cls(**dataclasses.asdict(other))
 
+    @classmethod
+    def spoof(
+        cls, other: typing.Any, nchannels, datetime, geoinfo_size, mean_of_data
+    ) -> typing.Self:
+        """
+        Spoof an instance from data_reader_base.ReaderData instance.
+        other should be such an instance.
+        """
+
+        hl = 5
+        dx = 0.5
+        dy = 0.5
+        other_copy = deepcopy(other)
+        num_healpix_cells = 12 * 4**hl
+        lons, lats = hp.healpix_to_lonlat(
+            np.arange(0, num_healpix_cells), 2**hl, dx=dx, dy=dy, order="nested"
+        )
+        other_copy.coords = np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32)
+        other_copy.geoinfos = np.zeros((other_copy.coords.shape[0], geoinfo_size), dtype=np.float32)
+
+        other_copy.data = np.expand_dims(mean_of_data.astype(np.float32), axis=0).repeat(
+            other_copy.coords.shape[0], axis=0
+        )
+        other_copy.datetimes = np.array(datetime).repeat(other_copy.coords.shape[0])
+
+        n_datapoints = len(other_copy.data)
+
+        assert other_copy.coords.shape == (n_datapoints, 2), (
+            "number of datapoints do not match data"
+        )
+        assert other_copy.geoinfos.shape[0] == n_datapoints, (
+            "number of datapoints do not match data"
+        )
+        assert other_copy.datetimes.shape[0] == n_datapoints, (
+            "number of datapoints do not match data"
+        )
+
+        return cls(**dataclasses.asdict(other_copy))
+
 
 
 @dataclasses.dataclass
 class ItemKey:
diff --git a/packages/evaluate/src/weathergen/evaluate/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export_inference.py
index b6da11f5..6d723d7f 100755
--- a/packages/evaluate/src/weathergen/evaluate/export_inference.py
+++ b/packages/evaluate/src/weathergen/evaluate/export_inference.py
@@ -31,9 +31,7 @@
 if not _logger.handlers:
     handler = logging.StreamHandler()
-    formatter = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-    )
+    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
     handler.setFormatter(formatter)
     _logger.addHandler(handler)
 
 
@@ -98,9 +96,9 @@ def reshape_dataset(input_data_array: xr.DataArray) -> xr.Dataset:
         ipoint=input_data_array.coords["ipoint"],
         pressure_level=pl,
     )
-    reshaped_dataset = reshaped_dataset.set_index(
-        ipoint=("valid_time", "lat", "lon")
-    ).unstack("ipoint")
+    reshaped_dataset = reshaped_dataset.set_index(ipoint=("valid_time", "lat", "lon")).unstack(
+        "ipoint"
+    )
     return reshaped_dataset
 
 
@@ -120,9 +118,8 @@ def add_conventions(stream: str, run_id: str, ds: xr.Dataset) -> xr.Dataset:
     ds.attrs["title"] = f"WeatherGenerator Output for {run_id} using stream {stream}"
     ds.attrs["institution"] = "WeatherGenerator Project"
     ds.attrs["source"] = "WeatherGenerator v0.0"
-    ds.attrs["history"] = (
-        "Created using the zarr_nc.py script on "
-        + np.datetime_as_string(np.datetime64("now"), unit="s")
+    ds.attrs["history"] = "Created using the zarr_nc.py script on " + np.datetime_as_string(
+        np.datetime64("now"), unit="s"
     )
     ds.attrs["Conventions"] = "CF-1.12"
     return ds
@@ -172,9 +169,7 @@ def cf_parser(config: OmegaConf, ds: xr.Dataset) -> xr.Dataset:
         if mapping[var_name]["level_type"] == "sfc":
             dims.remove("pressure")
         coordinates = {}
-        for coord, new_name in config["coordinates"][
-            mapping[var_name]["level_type"]
-        ].items():
+        for coord, new_name in config["coordinates"][mapping[var_name]["level_type"]].items():
             coordinates |= {
                 new_name: (
                     ds.coords[coord].dims,
@@ -257,7 +252,7 @@ def get_data(
     dtype: str,
     fsteps: list,
     channels: list,
-    fstep_hours: int, 
+    fstep_hours: int,
     n_processes: list,
     epoch: int,
     rank: int,
@@ -295,11 +290,7 @@ def get_data(
     all_channels = dummy_out.target.channels
     channels = all_channels if channels is None else channels
 
-    fsteps = (
-        zio_forecast_steps
-        if fsteps is None
-        else sorted([int(fstep) for fstep in fsteps])
-    )
+    fsteps = zio_forecast_steps if fsteps is None else sorted([int(fstep) for fstep in fsteps])
 
     samples = (
         zio_samples
@@ -310,8 +301,7 @@ def get_data(
     for sample_idx in tqdm(samples):
         da_fs = []
         step_tasks = [
-            (sample_idx, fstep, run_id, stream, dtype, epoch, rank)
-            for fstep in fsteps
+            (sample_idx, fstep, run_id, stream, dtype, epoch, rank) for fstep in fsteps
         ]
         for result in tqdm(
             pool.imap_unordered(get_data_worker, step_tasks, chunksize=1),
@@ -323,9 +313,7 @@ def get_data(
             result = result.as_xarray().squeeze()
             if set(channels) != set(all_channels):
                 available_channels = result.channel.values
-                existing_channels = [
-                    ch for ch in channels if ch in available_channels
-                ]
+                existing_channels = [ch for ch in channels if ch in available_channels]
                 if len(existing_channels) < len(channels):
                     _logger.info(
                         f"The following channels were not found: "
@@ -340,7 +328,7 @@ def get_data(
         _logger.info(
             f"Saving sample {sample_idx} data to {output_format} format in {output_dir}."
         )
-        
+
         save_sample_to_netcdf(
             str(dtype)[:4],
             da_fs,
@@ -383,9 +371,7 @@ def save_sample_to_netcdf(
         Loaded config for cf_parser function.
     """
     # find forecast_ref_time
-    frt = array_list[0].valid_time.values[0] - fstep_hours * int(
-        array_list[0].forecast_step.values
-    )
+    frt = array_list[0].valid_time.values[0] - fstep_hours * int(array_list[0].forecast_step.values)
     out_fname = output_filename(type_str, run_id, output_dir, output_format, frt)
     # check if file already exists
     if out_fname.exists():
@@ -406,9 +392,7 @@ def save_sample_to_netcdf(
     sample_all_steps = cf_parser(config, sample_all_steps)
     # add forecast_period attributes
     n_hours = fstep_hours.astype("int64")
-    sample_all_steps["forecast_period"] = (
-        sample_all_steps["forecast_period"] * n_hours
-    )
+    sample_all_steps["forecast_period"] = sample_all_steps["forecast_period"] * n_hours
     sample_all_steps["forecast_period"].attrs = {
         "standard_name": "forecast_period",
         "long_name": "time since forecast_reference_time",
@@ -495,11 +479,11 @@ def parse_args(args: list) -> argparse.Namespace:
     )
 
     parser.add_argument(
-        "--fstep-hours",
-        type = int,
-        default= 6,
-        help= "Time difference between forecast steps in hours (e.g., 6)"
-    )
+        "--fstep-hours",
+        type=int,
+        default=6,
+        help="Time difference between forecast steps in hours (e.g., 6)",
+    )
 
     parser.add_argument(
         "--epoch",
@@ -535,7 +519,7 @@ def export_from_args(args: list) -> None:
     Export data from Zarr store to NetCDF files based on command line arguments.
     Parameters
     ----------
-    args : List of command line arguments. 
+    args : List of command line arguments.
     """
     args = parse_args(sys.argv[1:])
     run_id = args.run_id
@@ -559,7 +543,7 @@ def export_from_args(args: list) -> None:
     config_file = Path(_REPO_ROOT, "config/evaluate/config_zarr2cf.yaml")
     config = OmegaConf.load(config_file)
     # check config loaded correctly
-    assert len(config["variables"].keys()) > 0 , "Config file not loaded correctly"
+    assert len(config["variables"].keys()) > 0, "Config file not loaded correctly"
 
     for dtype in data_type:
         _logger.info(f"Starting processing {dtype} for run ID {run_id}.")
@@ -570,7 +554,7 @@ def export_from_args(args: list) -> None:
             dtype,
             fsteps,
             channels,
-            fstep_hours, 
+            fstep_hours,
             n_processes,
             epoch,
             rank,
diff --git a/src/weathergen/model/ema.py b/src/weathergen/model/ema.py
index 0adbc9fc..7acbbf9f 100644
--- a/src/weathergen/model/ema.py
+++ b/src/weathergen/model/ema.py
@@ -8,52 +8,64 @@
 # nor does it submit to any jurisdiction.
-import copy
-
 import torch
 
 
-class EMA:
+class EMAModel:
     """
     Taken and modified from https://github.com/NVlabs/edm2/tree/main
     """
 
     @torch.no_grad()
-    def __init__(self, net, halflife_steps=float("inf"), rampup_ratio=0.09):
-        self.net = net
+    def __init__(
+        self,
+        model,
+        empty_model,
+        halflife_steps=float("inf"),
+        rampup_ratio=0.09,
+        is_model_sharded=False,
+    ):
+        self.original_model = model
         self.halflife_steps = halflife_steps
         self.rampup_ratio = rampup_ratio
-        self.ema = copy.deepcopy(net)
+        self.ema_model = empty_model
+        self.is_model_sharded = is_model_sharded
+
+        self.reset()
 
     @torch.no_grad()
     def reset(self):
-        for p_net, p_ema in zip(self.net.parameters(), self.ema.parameters(), strict=False):
-            p_ema.copy_(p_net)
+        """
+        This function resets the EMAModel to be the same as the Model.
+
+        It operates via the state_dict to be able to deal with sharded tensors in case
+        FSDP2 is used.
+        """
+        self.ema_model.to_empty(device="cuda")
+        maybe_sharded_sd = self.original_model.state_dict()
+        # this copies correctly tested in pdb
+        mkeys, ukeys = self.ema_model.load_state_dict(maybe_sharded_sd, strict=True, assign=False)
 
     @torch.no_grad()
-    def update(self, cur_steps, batch_size):
+    def update(self, cur_step, batch_size):
         halflife_steps = self.halflife_steps
         if self.rampup_ratio is not None:
-            halflife_steps = min(halflife_steps, cur_steps / 1e3 * self.rampup_ratio)
+            halflife_steps = min(halflife_steps, cur_step / 1e3 * self.rampup_ratio)
         beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6))
-        for p_net, p_ema in zip(self.net.parameters(), self.ema.parameters(), strict=False):
+        for p_net, p_ema in zip(
+            self.original_model.parameters(), self.ema_model.parameters(), strict=True
+        ):
             p_ema.lerp_(p_net, 1 - beta)
 
     @torch.no_grad()
     def forward_eval(self, *args, **kwargs):
-        self.ema.eval()
-        out = self.ema(*args, **kwargs)
-        self.ema.train()
+        self.ema_model.eval()
+        out = self.ema_model(*args, **kwargs)
+        self.ema_model.train()
         return out
 
-    @torch.no_grad()
-    def get(self):
-        for p_net, p_ema in zip(self.net.buffers(), self.ema.buffers(), strict=False):
-            p_ema.copy_(p_net)
-        return self.ema
-
     def state_dict(self):
-        return self.ema.state_dict()
+        return self.ema_model.state_dict()
 
-    def load_state_dict(self, state):
-        self.ema.load_state_dict(state)
+    def load_state_dict(self, state, **kwargs):
+        self.ema_model.load_state_dict(state, **kwargs)
diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py
index 36e97aee..d7137d0a 100644
--- a/src/weathergen/train/trainer.py
+++ b/src/weathergen/train/trainer.py
@@ -39,6 +39,7 @@
     MultiSelfAttentionHeadLocal,
     MultiSelfAttentionHeadVarlen,
 )
+from weathergen.model.ema import EMAModel
 from weathergen.model.layers import MLP
 from weathergen.model.model import Model, ModelParams
 from weathergen.model.utils import freeze_weights
@@ -152,54 +153,15 @@ def inference(self, cf, devices, run_id_trained, epoch):
         self.validate(epoch=0)
         logger.info(f"Finished inference run with id: {cf.run_id}")
 
-    def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
-        # general initalization
-        self.init(cf, devices)
-        cf = self.cf
-
-        # TODO: do not define new members outside of the init!!
-        self.device_type = torch.accelerator.current_accelerator()
-        self.device = torch.device(f"{self.device_type}:{cf.local_rank}")
-
-        self.dataset = MultiStreamDataSampler(
-            cf,
-            cf.start_date,
-            cf.end_date,
-            cf.batch_size_per_gpu,
-            cf.samples_per_epoch,
-            stage=TRAIN,
-            shuffle=cf.shuffle,
-        )
-        self.dataset_val = MultiStreamDataSampler(
-            cf,
-            cf.start_date_val,
-            cf.end_date_val,
-            cf.batch_size_validation_per_gpu,
-            cf.samples_per_validation,
-            stage=VAL,
-            shuffle=True,
-        )
-
-        loader_params = {
-            "batch_size": None,
-            "batch_sampler": None,
-            "shuffle": False,
-            "num_workers": cf.loader_num_workers,
-            "pin_memory": True,
-        }
-        self.data_loader = torch.utils.data.DataLoader(self.dataset, **loader_params, sampler=None)
-        self.data_loader_validation = torch.utils.data.DataLoader(
-            self.dataset_val, **loader_params, sampler=None
-        )
-
+    def init_model_and_shard(self, cf, devices):
         sources_size = self.dataset.get_sources_size()
         targets_num_channels = self.dataset.get_targets_num_channels()
         targets_coords_size = self.dataset.get_targets_coords_size()
 
         with torch.device("meta"):
-            self.model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create()
+            model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create()
 
-        for name, module in self.model.named_modules():
+        for name, module in model.named_modules():
             name = module.name if hasattr(module, "name") else name
             # avoid the whole model element which has name ''
             if name == "":
@@ -208,8 +170,8 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
                 freeze_weights(module)
 
         if cf.with_ddp and not cf.with_fsdp:
-            self.model = torch.nn.parallel.DistributedDataParallel(
-                self.model,
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
                 broadcast_buffers=True,
                 find_unused_parameters=True,
                 gradient_as_bucket_view=True,
@@ -235,23 +197,20 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
                 MultiCrossAttentionHeadVarlenSlicedQ,
                 MultiSelfAttentionHeadVarlen,
             )
 
-            # for module in self.model.embeds.modules():
-            #     if isinstance(module, modules_to_shard):
-            #         fully_shard(module, **fsdp_kwargs)
-            for module in self.model.ae_local_blocks.modules():
+            for module in model.ae_local_blocks.modules():
                 if isinstance(module, modules_to_shard):
                     fully_shard(module, **fsdp_kwargs)
 
-            for module in self.model.ae_adapter.modules():
+            for module in model.ae_adapter.modules():
                 if isinstance(module, modules_to_shard):
                     fully_shard(module, **fsdp_kwargs)
 
-            for module in self.model.ae_global_blocks.modules():
+            for module in model.ae_global_blocks.modules():
                 if isinstance(module, modules_to_shard):
                     fully_shard(module, **fsdp_kwargs)
 
-            for module in self.model.fe_blocks.modules():
+            for module in model.fe_blocks.modules():
                 if isinstance(module, modules_to_shard):
                     fully_shard(module, **fsdp_kwargs)
 
@@ -265,20 +224,63 @@
                     else None
                 ),
             }
-            for module in self.model.pred_adapter_kv.modules():
+            for module in model.pred_adapter_kv.modules():
                 if isinstance(module, modules_to_shard):
                     fully_shard(module, **full_precision_fsdp_kwargs)
 
-            for module in self.model.target_token_engines.modules():
+            for module in model.target_token_engines.modules():
                 if isinstance(module, modules_to_shard):
                     fully_shard(module, **full_precision_fsdp_kwargs)
 
-        self.model_params = ModelParams(cf).create(cf)  # .to(device)
+        model_params = ModelParams(cf).create(cf)
 
         if cf.with_ddp and cf.with_fsdp:
-            fully_shard(self.model)
-            for tensor in itertools.chain(self.model.parameters(), self.model.buffers()):
+            fully_shard(model)
+            for tensor in itertools.chain(model.parameters(), model.buffers()):
                 assert tensor.device == torch.device("meta")
 
+        return model, model_params
+
+    def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
+        # general initalization
+        self.init(cf, devices)
+        cf = self.cf
+
+        # TODO: do not define new members outside of the init!!
+        self.device_type = torch.accelerator.current_accelerator()
+        self.device = torch.device(f"{self.device_type}:{cf.local_rank}")
+
+        self.dataset = MultiStreamDataSampler(
+            cf,
+            cf.start_date,
+            cf.end_date,
+            cf.batch_size_per_gpu,
+            cf.samples_per_epoch,
+            stage=TRAIN,
+            shuffle=cf.shuffle,
+        )
+        self.dataset_val = MultiStreamDataSampler(
+            cf,
+            cf.start_date_val,
+            cf.end_date_val,
+            cf.batch_size_validation_per_gpu,
+            cf.samples_per_validation,
+            stage=VAL,
+            shuffle=True,
+        )
+
+        loader_params = {
+            "batch_size": None,
+            "batch_sampler": None,
+            "shuffle": False,
+            "num_workers": cf.loader_num_workers,
+            "pin_memory": True,
+        }
+        self.data_loader = torch.utils.data.DataLoader(self.dataset, **loader_params, sampler=None)
+        self.data_loader_validation = torch.utils.data.DataLoader(
+            self.dataset_val, **loader_params, sampler=None
+        )
+
+        self.model, self.model_params = self.init_model_and_shard(cf, devices)
 
         if run_id_contd is None:
             self.model.to_empty(device="cuda")
@@ -295,6 +297,17 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
         if cf.compile_model:
             self.model = torch.compile(self.model, dynamic=True)
 
+        self.validate_with_ema = cf.get("validate_with_ema", False)
+        if self.validate_with_ema:
+            meta_ema_model = self.init_model_and_shard(cf, devices)[0]
+            self.ema_model = EMAModel(
+                self.model,
+                meta_ema_model,
+                halflife_steps=cf.get("ema_halflife_in_thousands", 1e-3),
+                rampup_ratio=cf.get("ema_ramp_up_ratio", 0.09),
+                is_model_sharded=(cf.with_ddp and cf.with_fsdp),
+            )
+
         # if with_fsdp then parameter count is unreliable
         if (is_root() and not cf.with_fsdp) or not cf.with_ddp:
             self.model.print_num_parameters()
@@ -580,6 +593,13 @@ def train(self, epoch):
             # update learning rate
             self.lr_scheduler.step()
 
+            # EMA update
+            if self.validate_with_ema:
+                self.ema_model.update(
+                    self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu,
+                    self.world_size_original * self.cf.batch_size_per_gpu,
+                )
+
             self.loss_unweighted_hist += [loss_values.losses_all]
             self.loss_model_hist += [loss_values.loss.item()]
             self.stdev_unweighted_hist += [loss_values.stddev_all]
@@ -596,7 +616,7 @@ def train(self, epoch):
             if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0:
                 self.save_model(-1)
 
-            self.cf.istep += cf.batch_size_per_gpu
+            self.cf.istep += 1
 
         self.dataset.advance()
 
@@ -622,7 +642,7 @@ def validate(self, epoch):
                     dtype=self.mixed_precision_dtype,
                    enabled=cf.with_mixed_precision,
                ):
-                    preds, _ = self.model(
+                    preds, _ = self.ema_model.forward_eval(
                        self.model_params, batch, cf.forecast_offset, forecast_steps
                    )
 
@@ -806,7 +826,9 @@ def load_optim(self):
    #     )
 
     def _get_full_model_state_dict(self):
-        maybe_sharded_sd = self.model.state_dict()
+        maybe_sharded_sd = (
+            self.model.state_dict() if self.ema_model is None else self.ema_model.state_dict()
+        )
         if self.cf.with_ddp and self.cf.with_fsdp:
             cpu_state_dict = {}
             for param_name, sharded_param in maybe_sharded_sd.items():
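
Illustrative note (not part of the patch): EMAModel.update above derives the per-update decay factor beta from a half-life expressed in thousands of samples, optionally capped early in training by rampup_ratio, and then folds the current weights into the EMA copy with an in-place lerp. The standalone sketch below mirrors only that formula; the toy Linear modules, the step loop, and the hyperparameter values are made up for illustration and do not appear in the repository.

import torch


def ema_update(model, ema_model, cur_sample_count, batch_size,
               halflife_steps=float("inf"), rampup_ratio=0.09):
    # Ramp-up: early in training the effective half-life is capped at a
    # fraction of the samples seen so far, so the EMA tracks the model closely.
    if rampup_ratio is not None:
        halflife_steps = min(halflife_steps, cur_sample_count / 1e3 * rampup_ratio)
    # halflife_steps is in thousands of samples; batch_size samples are folded in per update.
    beta = 0.5 ** (batch_size / max(halflife_steps * 1e3, 1e-6))
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), ema_model.parameters(), strict=True):
            p_ema.lerp_(p, 1.0 - beta)


# Hypothetical usage with a toy module.
net = torch.nn.Linear(4, 4)
ema_net = torch.nn.Linear(4, 4)
ema_net.load_state_dict(net.state_dict())
batch_size = 8
for step in range(1, 101):
    # ... an optimizer step on net would happen here ...
    ema_update(net, ema_net, cur_sample_count=step * batch_size,
               batch_size=batch_size, halflife_steps=0.5)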