Skip to content

Commit 0a2bdba

Browse files
TillHaeclessigkacpnowakshmh40iluise
authored
Rename epoch to mini_epoch (#1190)
* training progress unit realignment from epoch to mini_epoch * small naming fix of mini_epoch * ruffed * linted * Fix out of bounds in data_reader_obs (#1180) * fix out of bounds access * Adding comment * Removed debgu * Fixed to use forward function for forecast engine (#1188) * Fixed to use forward function for forecast engine, and also fstep for conditioning * Fixed missing return statement * Enable FesomDataReader to have different source and target datasets (#1046) * Implement separate target and source files, adjust masking * Fix casual masking * Fix longitude conversion flag * Fix casual masking strategy --------- Co-authored-by: Seb Hickman <[email protected]> * Add support for constant learning rate (#1186) * Added support for constant learning rate and minor clean-up in code * Fixed issues with overlap between lr phases * Changing default lr to constant * [issue 1123] restore probabilistic scores (#1128) * rebase * add ensemble * fix deterministic * fix plotting * lint * fix eval_config * probabilistic scores working now * lint * Fix spoofing and refactor handling of multiple source files (#1118) * Cleaning up spoofing and related code on data preprocessing for model * Fixed typo * Updated comments * Removed merge cells and implemented necessary adjustments * Fixed forecasting * Fixed missing handling of NaNs in coordinates and channel data * Minor clean up * Fix to removing/renaming variables * Changed funtion name to improve readability * Fixed bug with incorrect handling of multiple input datasources. * Addressed reviewer comments * resolve conflict * [1131] fixes circular dependencies (#1134) * fixes dependencies * cleanup * make the type checker not fail * cleanup * cleanup of type issues * Give option to plot only prediction maps (#1139) * add plot_preds_only feature * minor changes after comments * Tell FSDP2 about embedding engine forward functions (#1133) * Tell FSDP2 about embedding engine forward functions Note DO NOT add print functions in forward functions of the model, it will break with FSDP2 * Add comment * recover 'all' option (#1146) * Fixed problem in inferecne (#1145) * implement vrmse (#1147) * [1144] Extra fixes (#1148) * Fixed problem in inferecne * more fixes * fixes * lint * lint --------- Co-authored-by: Christian Lessig <[email protected]> * Jk/log grad norms/log grad norms (#1068) * Log gradient norms * Prototype for recording grad norms * Address review changes + hide behind feature flag * Final fixes including backward compatibility * Ruff * More ruff stuff * forecast config with small decoder * fixed uv.lock * test gradient logging on mutli gpus * update uv.lock to latest develop version * revert to default confit * add comment on FSDP2 specifics * move plot grad script to private repo * rm seaborn from pyproject * updating terminal and metrics loggin, add get_tensor_item fct * check for DTensor instead of world size * revert forecast fct, fix in separate PR * rename grad_norm log names to exclude from MLFlow * add log_grad_norms to default config --------- Co-authored-by: sophiex <[email protected]> * Add forecast and observation activity (#1126) * Add calculation methods for forecast and observation activity metrics in Scores class * Add new calculation methods for forecast activity metrics in Scores class * ruff * fix func name * Rename observation activity calculation method to target activity in Scores class * typo * refactor to common calc_act function for activity * fix cases * have calc_tact and calc_fact that use _calc_act for maintainability * fix small thing in style --------- Co-authored-by: iluise <[email protected]> * hotfix: use correct methot `create` instead of `construct` (#1090) * restore develop * fix deterministic * fix plotting * lint * fix eval_config * probabilistic scores working now * lint * update utils * packages/evaluate/src/weathergen/evaluate/score.py * lint * removing duplication --------- Co-authored-by: Christian Lessig <[email protected]> Co-authored-by: Timothy Hunter <[email protected]> Co-authored-by: Savvas Melidonis <[email protected]> Co-authored-by: Sophie X <[email protected]> Co-authored-by: Julius Polz <[email protected]> Co-authored-by: Julian Kuehnert <[email protected]> Co-authored-by: Simon Grasse <[email protected]> * Adding config to issue templates The issue template seems to have disappeared, attempting to solve that. * Add the duration of animation as global plotting option (#1189) * Add the animation duration as global plotting option * Linting * Use FPS instead of milliseconds * Linting * Attempt to fix the bug report template * Attempt to fix initiative template * Update task template * [1081][Evaluation] Use parent ruff rules (#1177) * use ruff settings from parent * fix code checks * check fixes 2nd round * reformat to line length * [1092] Adds pushing metrics to the evaluation pipeline (#1127) * changes * changes * changes * changes * changes * scores successfully pushed to MLFlow, still need to refactor * try to batch upload all metrics form same runid * batch logging all scores of each run_id * get parent_run by from_run_id * changes * cleanups * bug fixes * typing issue * Cleanup * pdb * integration test --------- Co-authored-by: Jubeku <[email protected]> * Fix the issue - "Empty source still have embedding network" (#1114) * Replace cf.rank==0 with utils.distributed.is_root * fix empty source inputs still have embedding layer * fix lint * fix source empty or source exclude all * fix source empty or source exclude all * fix forecast mode empty source --------- Co-authored-by: wang85 <[email protected]> Co-authored-by: wang85 <[email protected]> Co-authored-by: wang85 <[email protected]> Co-authored-by: wang85 <[email protected]> * [930][evaluation] implement CSVReader (#932) * first version of quaver reader * working version * add CSVReader * rebase to develop * add polimorphism * fix names * lint * Iluise/hot fixes (#1209) * fix froct * fix 1150 * Fix plot_train verbosity (#1225) * [1206] Experimentation for extra data readers (#1207) * initial implementation * changes * toml * add module to annotations.json (#1142) Co-authored-by: Javad Kasravi <[email protected]> * Correct bug with score cards and bar plots for different metrics (#1192) * Rebase to develop * Linting * Address comments and linting * [eval][1122] Plot scores on a map (#1176) * first version of score maps * add maps to compute_scores * fix single sample situation * fix single sample * lint * restore score.py * fix bug in metric stream * default flag to false * Minor correction, a line was deleted by mistake? (#1193) * fix * working setup for regridded data * fix missing valid time case * lint and fix color in score cards * fix path for score maps * Allow plotting score maps every time --------- Co-authored-by: Savvas Melidonis <[email protected]> * Fix DDP without FSDP (#1227) * Fix DDP without FSDP * Fixed taht freezing would not have worked with only DDP * refactor export scripts - Part 1 (#1223) * move stuff around * move stuff around * rename files * [1034][reader_extra] E-Obs datareader (#1228) * [1034] rebase * [1034] add dataloader * [1034] Zarr3-->Zarr2 * [1034] lint * [1034] lint * [1034] Moved to reader_extra * [1034] registry E-Obs * training progress unit realignment from epoch to mini_epoch * ruffed * check if path is dir in io_reader * fix overwrite of fname_zarr in io_reader * add backward compatibility to config read * Separate write and read functions for model*chkpt*.json files (MatKBauer) --------- Co-authored-by: Christian Lessig <[email protected]> Co-authored-by: Kacper Nowak <[email protected]> Co-authored-by: Seb Hickman <[email protected]> Co-authored-by: iluise <[email protected]> Co-authored-by: Timothy Hunter <[email protected]> Co-authored-by: Savvas Melidonis <[email protected]> Co-authored-by: Sophie X <[email protected]> Co-authored-by: Julius Polz <[email protected]> Co-authored-by: Julian Kuehnert <[email protected]> Co-authored-by: Simon Grasse <[email protected]> Co-authored-by: Michael Tarnawa <[email protected]> Co-authored-by: Jubeku <[email protected]> Co-authored-by: Jifeng Wang <[email protected]> Co-authored-by: wang85 <[email protected]> Co-authored-by: wang85 <[email protected]> Co-authored-by: wang85 <[email protected]> Co-authored-by: wang85 <[email protected]> Co-authored-by: Javad kasravi <[email protected]> Co-authored-by: Javad Kasravi <[email protected]> Co-authored-by: Simone Norberti <[email protected]>
1 parent 8aade86 commit 0a2bdba

File tree

21 files changed

+210
-143
lines changed

21 files changed

+210
-143
lines changed

config/default_config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
113113
"same_strategy_per_batch": false
114114
}
115115

116-
num_epochs: 32
117-
samples_per_epoch: 4096
116+
num_mini_epochs: 32
117+
samples_per_mini_epoch: 4096
118118
samples_per_validation: 512
119119
shuffle: True
120120

config/evaluate/eval_config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ run_ids :
3131
ar40mckx:
3232
label: "pretrained model ar40mckx"
3333
results_base_dir : "./results/"
34-
epoch: 0
34+
mini_epoch: 0
3535
rank: 0
3636
streams:
3737
ERA5:
@@ -62,7 +62,7 @@ run_ids :
6262
c8g5katp:
6363
label: "2 steps window"
6464
results_base_dir : "./results/"
65-
epoch: 0
65+
mini_epoch: 0
6666
rank: 0
6767
streams:
6868
ERA5:

integration_tests/small1.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ run_path: "./results"
33
model_path: "./models"
44
loss_fcts: [["mse", 1.0]]
55
loss_fcts_val: [["mse", 1.0]]
6-
num_epochs: 1
7-
samples_per_epoch: 10
6+
num_mini_epochs: 1
7+
samples_per_mini_epoch: 10
88
samples_per_validation: 5
99
lr_steps: 4
1010
lr_steps_warmup: 2

integration_tests/small1_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_train(setup, test_run_id):
6969
def infer(run_id):
7070
logger.info("run inference")
7171
inference_from_args(
72-
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--epoch", "0"]
72+
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"]
7373
+ [
7474
"--from_run_id",
7575
run_id,
@@ -84,7 +84,7 @@ def infer(run_id):
8484
def infer_with_missing(run_id):
8585
logger.info("run inference")
8686
inference_from_args(
87-
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--epoch", "0"]
87+
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--mini_epoch", "0"]
8888
+ [
8989
"--from_run_id",
9090
run_id,
@@ -128,7 +128,7 @@ def evaluate_results(run_id):
128128
}
129129
},
130130
"label": "MTM ERA5",
131-
"epoch": 0,
131+
"mini_epoch": 0,
132132
"rank": 0,
133133
}
134134
},
@@ -171,7 +171,7 @@ def assert_train_loss_below_threshold(run_id):
171171
assert loss_metric is not None, (
172172
"'stream.ERA5.loss_mse.loss_avg' metric is missing in metrics file"
173173
)
174-
# Check that the loss does not explode in a single epoch
174+
# Check that the loss does not explode in a single mini_epoch
175175
# This is meant to be a quick test, not a convergence test
176176
target = 1.5
177177
assert loss_metric < target, (
@@ -193,7 +193,7 @@ def assert_val_loss_below_threshold(run_id):
193193
assert loss_metric is not None, (
194194
"'stream.ERA5.loss_mse.loss_avg' metric is missing in metrics file"
195195
)
196-
# Check that the loss does not explode in a single epoch
196+
# Check that the loss does not explode in a single mini_epoch
197197
# This is meant to be a quick test, not a convergence test
198198
assert loss_metric < 1.25, (
199199
f"'stream.ERA5.loss_mse.loss_avg' is {loss_metric}, expected to be below 0.25"

packages/common/src/weathergen/common/config.py

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,23 @@ def format_cf(config: Config) -> str:
5454
return stream.getvalue()
5555

5656

57-
def save(config: Config, epoch: int | None):
57+
def save(config: Config, mini_epoch: int | None):
5858
"""Save current config into the current runs model directory."""
5959
path_models = Path(config.model_path)
6060
# save in directory with model files
6161
dirname = path_models / config.run_id
6262
dirname.mkdir(exist_ok=True, parents=True)
6363

64-
fname = dirname / _get_model_config_file_name(config.run_id, epoch)
64+
fname = _get_model_config_file_write_name(path_models, config.run_id, mini_epoch)
6565

6666
json_str = json.dumps(OmegaConf.to_container(config))
6767
with fname.open("w") as f:
6868
f.write(json_str)
6969

7070

71-
def load_model_config(run_id: str, epoch: int | None, model_path: str | None) -> Config:
71+
def load_model_config(run_id: str, mini_epoch: int | None, model_path: str | None) -> Config:
7272
"""
73-
Load a configuration file from a given run_id and epoch.
73+
Load a configuration file from a given run_id and mini_epoch.
7474
If run_id is a full path, loads it from the full path.
7575
"""
7676
if Path(run_id).exists(): # load from the full path if a full path is provided
@@ -84,13 +84,13 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
8484
config=pconf, attribute_name="model_path", fallback="models"
8585
)
8686
path = Path(model_path)
87-
fname = path / run_id / _get_model_config_file_name(run_id, epoch)
87+
fname = _get_model_config_file_read_name(path, run_id, mini_epoch)
8888
assert fname.exists(), (
8989
"The fallback path to the model does not exist. Please provide a `model_path`.",
9090
fname,
9191
)
9292

93-
_logger.info(f"Loading config from specified run_id and epoch: {fname}")
93+
_logger.info(f"Loading config from specified run_id and mini_epoch: {fname}")
9494

9595
with fname.open() as f:
9696
json_str = f.read()
@@ -100,24 +100,49 @@ def load_model_config(run_id: str, epoch: int | None, model_path: str | None) ->
100100
return _apply_fixes(config)
101101

102102

103-
def _get_model_config_file_name(run_id: str, epoch: int | None):
104-
if epoch is None:
105-
epoch_str = ""
106-
elif epoch == -1:
107-
epoch_str = "_latest"
103+
def _get_model_config_file_write_name(path: Path, run_id: str, mini_epoch: int | None):
104+
if mini_epoch is None:
105+
mini_epoch_str = ""
106+
elif mini_epoch == -1:
107+
mini_epoch_str = "_latest"
108108
else:
109-
epoch_str = f"_epoch{epoch:05d}"
110-
return f"model_{run_id}{epoch_str}.json"
109+
mini_epoch_str = f"_chkpt{mini_epoch:05d}"
111110

111+
return path / run_id / f"model_{run_id}{mini_epoch_str}.json"
112112

113-
def get_model_results(run_id: str, epoch: int, rank: int) -> Path:
113+
114+
def _get_model_config_file_read_name(path: Path, run_id: str, mini_epoch: int | None):
115+
if mini_epoch is None:
116+
mini_epoch_str = ""
117+
elif mini_epoch == -1:
118+
mini_epoch_str = "_latest"
119+
elif (path / run_id / f"model_{run_id}_epoch{mini_epoch:05d}.json").exists():
120+
mini_epoch_str = f"_epoch{mini_epoch:05d}"
121+
else:
122+
mini_epoch_str = f"_chkpt{mini_epoch:05d}"
123+
124+
return path / run_id / f"model_{run_id}{mini_epoch_str}.json"
125+
126+
127+
def get_model_results(run_id: str, mini_epoch: int, rank: int) -> Path:
114128
"""
115-
Get the path to the model results zarr store from a given run_id and epoch.
129+
Get the path to the model results zarr store from a given run_id and mini_epoch.
116130
"""
117131
run_results = Path(_load_private_conf(None)["path_shared_working_dir"]) / f"results/{run_id}"
118-
zarr_path = run_results / f"validation_epoch{epoch:05d}_rank{rank:04d}.zarr"
119-
if not zarr_path.exists() or not zarr_path.is_dir():
120-
raise FileNotFoundError(f"Zarr file {zarr_path} does not exist or is not a directory.")
132+
133+
zarr_path_new = run_results / f"validation_chkpt{mini_epoch:05d}_rank{rank:04d}.zarr"
134+
zarr_path_old = run_results / f"validation_epoch{mini_epoch:05d}_rank{rank:04d}.zarr"
135+
136+
if zarr_path_new.exists() or zarr_path_new.is_dir():
137+
zarr_path = zarr_path_new
138+
elif zarr_path_old.exists() or zarr_path_old.is_dir():
139+
zarr_path = zarr_path_old
140+
else:
141+
raise FileNotFoundError(
142+
f"Zarr file with run_id {run_id}, mini_epoch {mini_epoch} and rank {rank} does not "
143+
f"exist or is not a directory."
144+
)
145+
121146
return zarr_path
122147

123148

@@ -150,7 +175,7 @@ def _check_logging(config: Config) -> Config:
150175
def load_config(
151176
private_home: Path | None,
152177
from_run_id: str | None,
153-
epoch: int | None,
178+
mini_epoch: int | None,
154179
*overwrites: Path | dict | Config,
155180
) -> Config:
156181
"""
@@ -161,7 +186,7 @@ def load_config(
161186
private_home: Configuration file containing platform dependent information and secretes
162187
from_run_id: Run id of the pretrained WeatherGenerator model
163188
to continue training or inference
164-
epoch: epoch of the checkpoint to load. -1 indicates last checkpoint available.
189+
mini_epoch: mini_epoch of the checkpoint to load. -1 indicates last checkpoint available.
165190
*overwrites: Additional overwrites from different sources
166191
167192
Note: The order of precendence for merging the final config is in ascending order:
@@ -191,13 +216,21 @@ def load_config(
191216
if from_run_id is None:
192217
base_config = _load_default_conf()
193218
else:
194-
base_config = load_model_config(from_run_id, epoch, private_config.get("model_path", None))
219+
base_config = load_model_config(
220+
from_run_id, mini_epoch, private_config.get("model_path", None)
221+
)
195222
from_run_id = base_config.run_id
196223
with open_dict(base_config):
197224
base_config.from_run_id = from_run_id
198225
# use OmegaConf.unsafe_merge if too slow
199226
c = OmegaConf.merge(base_config, private_config, *overwrite_configs)
200227
assert isinstance(c, Config)
228+
229+
# Ensure the config has mini-epoch notation
230+
if hasattr(c, "samples_per_epoch"):
231+
c.samples_per_mini_epoch = c.samples_per_epoch
232+
c.num_mini_epochs = c.num_epochs
233+
201234
return c
202235

203236

@@ -456,9 +489,9 @@ def get_path_model(config: Config) -> Path:
456489
return Path(config.model_path) / config.run_id
457490

458491

459-
def get_path_output(config: Config, epoch: int) -> Path:
492+
def get_path_output(config: Config, mini_epoch: int) -> Path:
460493
base_path = get_path_run(config)
461-
fname = f"validation_epoch{epoch:05d}_rank{config.rank:04d}.zarr"
494+
fname = f"validation_chkpt{mini_epoch:05d}_rank{config.rank:04d}.zarr"
462495

463496
return base_path / fname
464497

@@ -523,7 +556,7 @@ def validate_forecast_policy_and_steps(cf: OmegaConf):
523556
valid_forecast_policies = (
524557
"Valid values for 'forecast_policy' are, e.g., 'fixed' when using constant "
525558
"forecast steps throughout the training, or 'sequential' when varying the forecast "
526-
"steps over epochs, such as, e.g., 'forecast_steps: [2, 2, 4, 4]'. "
559+
"steps over mini_epochs, such as, e.g., 'forecast_steps: [2, 2, 4, 4]'. "
527560
)
528561
valid_forecast_steps = (
529562
"'forecast_steps' must be a positive integer or a non-empty list of positive integers. "

packages/evaluate/src/weathergen/evaluate/io_reader.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non
469469

470470
super().__init__(eval_cfg, run_id, private_paths)
471471

472-
self.epoch = eval_cfg.epoch
472+
self.mini_epoch = eval_cfg.mini_epoch
473473
self.rank = eval_cfg.rank
474474

475475
# Load model configuration and set (run-id specific) directories
@@ -498,9 +498,17 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non
498498
self.eval_cfg.get("metrics_dir", self.metrics_base_dir / self.run_id / "evaluation")
499499
)
500500

501-
self.fname_zarr = self.results_dir.joinpath(
502-
f"validation_epoch{self.epoch:05d}_rank{self.rank:04d}.zarr"
501+
fname_zarr_new = self.results_dir.joinpath(
502+
f"validation_chkpt{self.mini_epoch:05d}_rank{self.rank:04d}.zarr"
503503
)
504+
fname_zarr_old = self.results_dir.joinpath(
505+
f"validation_epoch{self.mini_epoch:05d}_rank{self.rank:04d}.zarr"
506+
)
507+
508+
if fname_zarr_new.exists() or fname_zarr_new.is_dir():
509+
self.fname_zarr = fname_zarr_new
510+
else:
511+
self.fname_zarr = fname_zarr_old
504512

505513
if not self.fname_zarr.exists() or not self.fname_zarr.is_dir():
506514
_logger.error(f"Zarr file {self.fname_zarr} does not exist.")
@@ -522,12 +530,12 @@ def get_inference_config(self):
522530
_logger.info(
523531
f"Loading config for run {self.run_id} from private paths: {self.private_paths}"
524532
)
525-
config = load_config(self.private_paths, self.run_id, self.epoch)
533+
config = load_config(self.private_paths, self.run_id, self.mini_epoch)
526534
else:
527535
_logger.info(
528536
f"Loading config for run {self.run_id} from model directory: {self.model_base_dir}"
529537
)
530-
config = load_model_config(self.run_id, self.epoch, self.model_base_dir)
538+
config = load_model_config(self.run_id, self.mini_epoch, self.model_base_dir)
531539

532540
if type(config) not in [dict, oc.DictConfig]:
533541
_logger.warning("Model config not found. inference config will be empty.")

packages/evaluate/src/weathergen/evaluate/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,8 @@ def metric_list_to_json(
435435
Output directory.
436436
run_id :
437437
Identifier of the inference run.
438-
epoch :
439-
Epoch number.
438+
mini_epoch :
439+
Mini_epoch number.
440440
"""
441441
assert len(metrics_list) == len(npoints_sample_list) == len(streams), (
442442
"The lengths of metrics_list, npoints_sample_list, and streams must be the same."
@@ -460,16 +460,16 @@ def metric_list_to_json(
460460
# Match the expected filename pattern
461461
save_path = (
462462
reader.metrics_dir
463-
/ f"{reader.run_id}_{stream}_{region}_{metric}_epoch{reader.epoch:05d}.json"
463+
/ f"{reader.run_id}_{stream}_{region}_{metric}_chkpt{reader.mini_epoch:05d}.json"
464464
)
465465

466466
_logger.info(f"Saving results to {save_path}")
467467
with open(save_path, "w") as f:
468468
json.dump(metric_dict, f, indent=4)
469469

470470
_logger.info(
471-
f"Saved all results of inference run {reader.run_id} - epoch {reader.epoch:d} successfully "
472-
f"to {reader.metrics_dir}."
471+
f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} "
472+
f"successfully to {reader.metrics_dir}."
473473
)
474474

475475

src/weathergen/datasets/masking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, cf: Config):
9797

9898
def reset_rng(self, rng) -> None:
9999
"""
100-
Reset rng after epoch to ensure proper randomization
100+
Reset rng after mini_epoch to ensure proper randomization
101101
"""
102102
self.rng = rng
103103

0 commit comments

Comments
 (0)