Changes from all commits (39 commits)
cf1f829
Save current state
sophie-xhonneux Aug 5, 2025
5017f46
Save current state
sophie-xhonneux Aug 6, 2025
305658e
Barebone FSDP2 prototype TODO save checkpoints
sophie-xhonneux Aug 11, 2025
1db9e19
First version of saving model
sophie-xhonneux Sep 4, 2025
38226bd
Fix save_model
sophie-xhonneux Sep 5, 2025
f9183d9
Merge branch 'develop' into sophiex/dev/fsdp2
sophie-xhonneux Sep 5, 2025
24e865b
Log everything and log to files
sophie-xhonneux Sep 5, 2025
f2562b9
Remove redundant path creation
sophie-xhonneux Sep 5, 2025
3eb5bec
Allow for both slurm and torchrun + fewer log files
sophie-xhonneux Sep 5, 2025
3ba291b
Cleaning up init_ddp
sophie-xhonneux Sep 8, 2025
7f0a088
Ruff
sophie-xhonneux Sep 8, 2025
748021b
Attempt to avoid duplicate logging
sophie-xhonneux Sep 8, 2025
181d170
FSDP2 with mixed precision policy
sophie-xhonneux Sep 9, 2025
0176658
Ruff
sophie-xhonneux Sep 9, 2025
44e1062
Clean up and logging
sophie-xhonneux Sep 9, 2025
cb58cda
Try to get loggers to behave as we want
sophie-xhonneux Sep 9, 2025
f877ab0
Makes ruff unhappy but works
sophie-xhonneux Sep 10, 2025
a398ffa
Fixed ruff issue
clessig Sep 11, 2025
2f8ab49
Fixed problems with multi-node training.
clessig Sep 11, 2025
27bd8ba
Fix for interactive/non-DDP runs
clessig Sep 13, 2025
c4b47c4
No idea why, but this seems to work so far
sophie-xhonneux Sep 16, 2025
4b0fd83
Still works! So which is it memory or the grad scaler?
sophie-xhonneux Sep 16, 2025
ca4e56a
Also still works, I now strongly suspect the amp.gradscaler
sophie-xhonneux Sep 16, 2025
f4ecf2c
This still works, I have no clue anymore why but whatever it works
sophie-xhonneux Sep 16, 2025
6426614
Enable loading model from absolute paths
sophie-xhonneux Sep 17, 2025
df97c31
Enable loading for 1 GPU only
sophie-xhonneux Sep 17, 2025
0669dc1
Fix 1 GPU train continue
sophie-xhonneux Sep 17, 2025
9426c0f
Merge branch 'develop' into sophiex/dev/fsdp2
sophie-xhonneux Sep 17, 2025
beceba2
Appease ruff
sophie-xhonneux Sep 17, 2025
ee7e619
Fix saving the model more regularly and perf logging
sophie-xhonneux Sep 18, 2025
3b3a754
Fixed problem when training with 2 nodes.
clessig Sep 22, 2025
76ac336
Fix data loader seed
sophie-xhonneux Sep 19, 2025
fecfe66
Appease ruff
sophie-xhonneux Sep 19, 2025
5170ea5
Shouldn't overwrite with_fsdp like this
sophie-xhonneux Sep 22, 2025
a092f05
Potential fix for FSDP2 issue
sophie-xhonneux Sep 24, 2025
f90b030
Fix loss scaling and logging of dummy data loss
sophie-xhonneux Sep 25, 2025
0bed983
Clean up
sophie-xhonneux Sep 25, 2025
7790924
Appease ruff
sophie-xhonneux Sep 25, 2025
5495860
Start implementing EMA, works for 1 GPU
sophie-xhonneux Sep 29, 2025
2 changes: 1 addition & 1 deletion config/default_config.yml
@@ -127,7 +127,7 @@ weight_decay: 0.1
norm_type: "LayerNorm"
nn_module: "te"

start_date: 197901010000
start_date: 199001010000
end_date: 202012310000
start_date_val: 202101010000
end_date_val: 202201010000
5 changes: 4 additions & 1 deletion packages/common/src/weathergen/common/config.py
@@ -16,6 +16,7 @@

import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf.omegaconf import open_dict

from weathergen.train.utils import get_run_id

@@ -143,7 +144,9 @@ def load_config(
base_config = _load_default_conf()
else:
base_config = load_model_config(from_run_id, epoch, private_config.get("model_path", None))

from_run_id = base_config.run_id
with open_dict(base_config):
base_config.from_run_id = from_run_id
# use OmegaConf.unsafe_merge if too slow
return OmegaConf.merge(base_config, private_config, *overwrite_configs)

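Note on the `open_dict` usage above: a config loaded back from a previous run is typically in OmegaConf struct mode, which rejects assignments to keys that do not yet exist, so `from_run_id` can only be added inside an `open_dict` context. A minimal standalone sketch of the pattern (illustration only, not code from this PR):

```python
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict

cfg = OmegaConf.create({"run_id": "abc123"})
OmegaConf.set_struct(cfg, True)   # struct mode: unknown keys are rejected

# cfg.from_run_id = cfg.run_id    # would raise an error in struct mode
with open_dict(cfg):              # temporarily allow adding new keys
    cfg.from_run_id = cfg.run_id

print(cfg.from_run_id)            # abc123
```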
42 changes: 42 additions & 0 deletions packages/common/src/weathergen/common/io.py
@@ -13,7 +13,10 @@
import logging
import pathlib
import typing
from copy import deepcopy

import torch
import astropy_healpix as hp
import dask.array as da
import numpy as np
import xarray as xr
@@ -68,6 +71,45 @@ def create(cls, other: typing.Any) -> typing.Self:

return cls(**dataclasses.asdict(other))

@classmethod
def spoof(
cls, other: typing.Any, nchannels, datetime, geoinfo_size, mean_of_data
) -> typing.Self:
"""
Spoof an instance from data_reader_base.ReaderData instance.
other should be such an instance.
"""

hl = 5
dx = 0.5
dy = 0.5
other_copy = deepcopy(other)
num_healpix_cells = 12 * 4**hl
lons, lats = hp.healpix_to_lonlat(
np.arange(0, num_healpix_cells), 2**hl, dx=dx, dy=dy, order="nested"
)
other_copy.coords = np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32)
other_copy.geoinfos = np.zeros((other_copy.coords.shape[0], geoinfo_size), dtype=np.float32)

other_copy.data = np.expand_dims(mean_of_data.astype(np.float32), axis=0).repeat(
other_copy.coords.shape[0], axis=0
)
other_copy.datetimes = np.array(datetime).repeat(other_copy.coords.shape[0])

n_datapoints = len(other_copy.data)

assert other_copy.coords.shape == (n_datapoints, 2), (
"number of datapoints does not match coords"
)
assert other_copy.geoinfos.shape[0] == n_datapoints, (
"number of datapoints does not match geoinfos"
)
assert other_copy.datetimes.shape[0] == n_datapoints, (
"number of datapoints does not match datetimes"
)

return cls(**dataclasses.asdict(other_copy))


@dataclasses.dataclass
class ItemKey:
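For reference, the coordinate construction inside `spoof` can be reproduced in isolation: the placeholder datapoints are placed at the centres of all cells of a level-5 nested HEALPix grid (12 * 4^5 = 12288 cells). A minimal sketch, separate from the PR code:

```python
import astropy_healpix as hp
import numpy as np

hl = 5                       # HEALPix level used by spoof()
num_cells = 12 * 4**hl       # 12288 cells at level 5
lons, lats = hp.healpix_to_lonlat(
    np.arange(num_cells), 2**hl, dx=0.5, dy=0.5, order="nested"
)
coords = np.stack([lats.deg, lons.deg], axis=-1, dtype=np.float32)
print(coords.shape)          # (12288, 2): one (lat, lon) pair per cell centre
```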
101 changes: 60 additions & 41 deletions src/weathergen/datasets/multi_stream_data_sampler.py
@@ -7,6 +7,7 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging
import pathlib

import numpy as np
@@ -33,11 +34,13 @@
compute_offsets_scatter_embed,
compute_source_cell_lens,
)
from weathergen.utils.logger import logger
from weathergen.utils.distributed import is_root
from weathergen.utils.train_logger import Stage

type AnyDataReader = DataReaderBase | DataReaderAnemoi | DataReaderObs

logger = logging.getLogger(__name__)


class MultiStreamDataSampler(torch.utils.data.IterableDataset):
###################################################
@@ -64,10 +67,11 @@ def __init__(
self.len_hrs: int = cf.len_hrs
self.step_hrs: int = cf.step_hrs
self.time_window_handler = TimeWindowHandler(start_date, end_date, cf.len_hrs, cf.step_hrs)
logger.info(
f"Time window handler: start={start_date}, end={end_date},"
f"len_hrs={cf.len_hrs}, step_hrs={cf.step_hrs}"
)
if is_root():
logger.info(
f"Time window handler: start={start_date}, end={end_date},"
f"len_hrs={cf.len_hrs}, step_hrs={cf.step_hrs}"
)

self.forecast_offset = cf.forecast_offset
self.forecast_delta_hrs = (
@@ -78,7 +82,7 @@ def __init__(
[cf.forecast_steps] if isinstance(cf.forecast_steps, int) else cf.forecast_steps
)
if cf.forecast_policy is not None:
if self.forecast_steps.max() == 0:
if self.forecast_steps.max() == 0 and is_root():
logger.warning("forecast policy is not None but number of forecast steps is 0.")
self.forecast_policy = cf.forecast_policy

@@ -130,10 +134,11 @@ def __init__(
raise FileNotFoundError(msg)

ds_type = stream_info["type"]
logger.info(
f"Opening dataset with type: {ds_type}"
+ f" from stream config {stream_info['name']}.",
)
if is_root():
logger.info(
f"Opening dataset with type: {ds_type}"
+ f"from stream config {stream_info['name']}.",
)
ds = dataset(filename=filename, **kwargs)

fsm = self.forecast_steps[0]
@@ -150,12 +155,12 @@ def __init__(
self.len = int(index_range.end - index_range.start)
self.len = min(self.len, samples_per_epoch if samples_per_epoch else self.len)
# adjust len to split loading across all workers and ensure it is multiple of batch_size
len_chunk = ((self.len // cf.num_ranks) // batch_size) * batch_size
len_chunk = ((self.len // cf.world_size) // batch_size) * batch_size
self.len = min(self.len, len_chunk)
logger.info(f"index_range={index_range}, len={self.len}, len_chunk={len_chunk}")

self.rank = cf.rank
self.num_ranks = cf.num_ranks
self.world_size = cf.world_size

self.streams = cf.streams
self.shuffle = shuffle
@@ -334,23 +339,29 @@ def __iter__(self):
# to avoid unwanted dependencies => see IOReaderData docstring
rdata_wrapped = IOReaderData.create(rdata)

if rdata.is_empty():
stream_data.add_empty_source(rdata_wrapped)
else:
# TODO: handling of conversion from numpy to torch here and below
# TODO: this should only be collected in validation mode

(ss_cells, ss_lens, ss_centroids) = self.tokenizer.batchify_source(
stream_info,
torch.from_numpy(rdata.coords),
torch.from_numpy(rdata.geoinfos),
torch.from_numpy(rdata.data),
rdata.datetimes,
(time_win1.start, time_win1.end),
ds,
sample_is_empty = rdata.is_empty()
if sample_is_empty:
rdata = IOReaderData.spoof(
rdata,
len(stream_info.train_source_channels),
time_win1.start,
ds.get_geoinfo_size(),
ds.mean[ds.source_idx],
)
# TODO: handling of conversion from numpy to torch here and below
# TODO: this should only be collected in validation mode
(ss_cells, ss_lens, ss_centroids) = self.tokenizer.batchify_source(
stream_info,
torch.from_numpy(rdata.coords),
torch.from_numpy(rdata.geoinfos),
torch.from_numpy(rdata.data),
rdata.datetimes,
(time_win1.start, time_win1.end),
ds,
)

stream_data.add_source(rdata_wrapped, ss_lens, ss_cells, ss_centroids)
stream_data.add_source(rdata_wrapped, ss_lens, ss_cells, ss_centroids)
stream_data.is_spoof = sample_is_empty

# target

@@ -365,21 +376,29 @@

rdata = ds.get_target(step_forecast_dt)

if rdata.is_empty():
stream_data.add_empty_target(fstep)
else:
(tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target(
stream_info,
self.sampling_rate_target,
torch.from_numpy(rdata.coords),
torch.from_numpy(rdata.geoinfos),
torch.from_numpy(rdata.data),
rdata.datetimes,
(time_win2.start, time_win2.end),
ds,
sample_is_empty = rdata.is_empty()
if sample_is_empty:
rdata = IOReaderData.spoof(
rdata,
len(stream_info.train_target_channels),
time_win1.start,
ds.get_geoinfo_size(),
ds.mean[ds.target_idx],
)

stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t)
(tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target(
stream_info,
self.sampling_rate_target,
torch.from_numpy(rdata.coords),
torch.from_numpy(rdata.geoinfos),
torch.from_numpy(rdata.data),
rdata.datetimes,
(time_win2.start, time_win2.end),
ds,
)

stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t)
stream_data.is_spoof = sample_is_empty

# merge inputs for sources and targets for current stream
stream_data.merge_inputs()
@@ -417,7 +436,7 @@ def worker_workset(self):
worker_info = torch.utils.data.get_worker_info()

if worker_info is None:
assert self.num_ranks == 1
assert self.world_size == 1, self.world_size
iter_start = 0
iter_end = len(self)

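The `is_spoof` flag set above marks samples whose source or target window contained no data: instead of skipping them, every rank now produces a structurally valid (spoofed) sample so that FSDP2/DDP collectives stay in sync, and the flag lets the training loop discount the dummy loss (see the "Fix loss scaling and logging of dummy data loss" commit). A hypothetical sketch of how such a flag could be consumed downstream; the actual trainer code is not part of this diff:

```python
import torch

def masked_loss(pred: torch.Tensor, target: torch.Tensor, is_spoof: bool) -> torch.Tensor:
    """Run the loss on every rank, but zero it out for spoofed samples.

    Multiplying by zero (rather than skipping backward) keeps the gradient
    and communication pattern identical across ranks.
    """
    loss = torch.nn.functional.mse_loss(pred, target)
    return loss * 0.0 if is_spoof else loss
```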
3 changes: 2 additions & 1 deletion src/weathergen/datasets/stream_data.py
@@ -106,8 +106,9 @@ def add_empty_source(self, source: IOReaderData) -> None:
None
"""

source = IOReaderData.spoof(source)
self.source_raw += [source]
self.source_tokens_lens += [torch.zeros([self.nhc_source], dtype=torch.int32)]
self.source_tokens_lens += [torch.ones([self.nhc_source], dtype=torch.int32)]
self.source_tokens_cells += [torch.tensor([])]
self.source_centroids += [torch.tensor([])]

2 changes: 0 additions & 2 deletions src/weathergen/datasets/tokenizer_forecast.py
@@ -23,7 +23,6 @@
from weathergen.datasets.utils import (
get_target_coords_local_ffast,
)
from weathergen.utils.logger import init_loggers


class TokenizerForecast(Tokenizer):
@@ -43,7 +42,6 @@ def batchify_source(
time_win: tuple,
normalizer, # dataset
):
init_loggers()
token_size = stream_info["token_size"]
is_diagnostic = stream_info.get("diagnostic", False)
tokenize_spacetime = stream_info.get("tokenize_spacetime", False)
3 changes: 1 addition & 2 deletions src/weathergen/datasets/tokenizer_masking.py
@@ -24,7 +24,6 @@
from weathergen.datasets.utils import (
get_target_coords_local_ffast,
)
from weathergen.utils.logger import init_loggers


class TokenizerMasking(Tokenizer):
@@ -49,7 +48,6 @@ def batchify_source(
time_win: tuple,
normalizer, # dataset
):
init_loggers()
token_size = stream_info["token_size"]
is_diagnostic = stream_info.get("diagnostic", False)
tokenize_spacetime = stream_info.get("tokenize_spacetime", False)
@@ -245,6 +243,7 @@ def sample_tensors_uniform_vectorized(
return [], 0

num_selected = valid_mask.sum().item()
perm = torch.tensor(perm)
selected_indices = perm[:num_selected]
selected_indices = torch.zeros_like(perm).scatter(0, selected_indices, 1)

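The added `perm = torch.tensor(perm)` converts the permutation (presumably a NumPy array) into a tensor so that the subsequent slicing and `scatter` can build a 0/1 selection mask. The pattern in isolation, with made-up sizes for illustration:

```python
import torch

perm = torch.randperm(8)                 # e.g. tensor([3, 0, 6, 1, 7, 2, 5, 4])
num_selected = 5
selected_indices = perm[:num_selected]   # first 5 entries of the permutation
# 0/1 mask with ones at the selected positions
mask = torch.zeros_like(perm).scatter(0, selected_indices, 1)
print(mask.sum())                        # tensor(5): one entry per selected index
```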
5 changes: 4 additions & 1 deletion src/weathergen/model/embeddings.py
@@ -124,7 +124,10 @@ def __init__(
assert self.unembed_mode == "block" # only supported mode at the moment
# padding needed if the unembedded columns cannot be concatenated to dim_out (e.g GPSRO)
self.pad = self.dim_out % token_size
self.out_pad = torch.nn.Parameter(torch.zeros(self.pad))
self.out_pad = torch.nn.Parameter(
torch.zeros(self.pad)
)  # TODO: why is this a Parameter, i.e. why should the padding receive gradients?
self.unembed = torch.nn.Linear(
self.dim_embed,
self.num_tokens * ((self.dim_out - embed_size_centroids) // token_size),
1 change: 0 additions & 1 deletion src/weathergen/model/engines.py
@@ -385,7 +385,6 @@ def __init__(
self.pred_heads[-1].append(fal)

#########################################
@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
def forward(self, toks):
preds = []
for pred_head in self.pred_heads:
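The removed `torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")` decorator forced the prediction heads back to fp32 under autocast; with FSDP2 the precision behaviour is instead governed by a `MixedPrecisionPolicy` passed to `fully_shard` (see the "FSDP2 with mixed precision policy" commit). A rough sketch of that API, assuming a recent PyTorch exposing the FSDP2 `fully_shard` front-end and an already initialised process group; the model and dtypes here are illustrative, not the code in this PR:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(2))
        self.head = torch.nn.Linear(16, 4)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.head(x)

# Compute and communicate parameters in bf16, reduce gradients in fp32.
policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

model = TinyModel()
for block in model.blocks:
    fully_shard(block, mp_policy=policy)   # shard inner blocks first
fully_shard(model, mp_policy=policy)       # then the root module
```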