From 18164d67881c791262a385fc27f09000257ef606 Mon Sep 17 00:00:00 2001 From: Wael Almikaeel Date: Fri, 17 Oct 2025 14:25:38 +0200 Subject: [PATCH 1/3] default config adjusted for rebase --- config/default_config.yml | 15 +++ src/weathergen/model/engines.py | 54 +++++++-- src/weathergen/model/layers.py | 209 ++++++++++++++++++++++++++++++++ 3 files changed, 267 insertions(+), 11 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 679f58dd3..596055de3 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -155,8 +155,23 @@ desc: "" data_loader_rng_seed: ??? run_id: ??? +<<<<<<< HEAD # The period to log in the training loop (in number of batch steps) train_log_freq: terminal: 10 metrics: 20 checkpoint: 250 +======= +# Parameters for logging/printing in the training loop +train_log: + # The period to log metrics (in number of batch steps) + log_interval: 20 + +# Forecast MLP type: "dense" (default) or "moe" +fe_mlp_type: "dense" # set to "moe" to enable MoE + +# MoE-only params (ignored when fe_mlp_type != "moe") +fe_moe_num_experts: 8 +fe_moe_top_k: 2 +fe_moe_hidden_factor: 0.5 # = HF_dense / 4 +>>>>>>> 36fea3a (Adding MoEMLP layer to the layers file, integrate the MoE layer in Forecasting engine, and set up the config file to control the use of this layer) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 78d11a4a6..74e488fce 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -318,18 +318,50 @@ def create(self) -> torch.nn.ModuleList: ) ) # Add MLP block - self.fe_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.fe_dropout_rate, - norm_type=self.cf.norm_type, - dim_aux=1, - norm_eps=self.cf.mlp_norm_eps, - ) + use_moe = getattr(self.cf, "fe_mlp_type", "dense") == "moe" + mlp_common_kwargs = dict( + dim_in=self.cf.ae_global_dim_embed, + dim_out=self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=1, + norm_eps=self.cf.mlp_norm_eps, ) - + # self.fe_blocks.append( + # MLP( + # self.cf.ae_global_dim_embed, + # self.cf.ae_global_dim_embed, + # with_residual=True, + # dropout_rate=self.cf.fe_dropout_rate, + # norm_type=self.cf.norm_type, + # dim_aux=1, + # norm_eps=self.cf.mlp_norm_eps, + # ) + # ) + if use_moe: + self.fe_blocks.append( + MoEMLP( + **mlp_common_kwargs, + num_experts=getattr(self.cf, "fe_moe_num_experts", 8), + top_k=getattr(self.cf, "fe_moe_top_k", 4), + router_noisy_std=getattr(self.cf, "fe_moe_router_noisy_std", 0.0), + hidden_factor=getattr(self.cf, "fe_moe_hidden_factor", 2), + ) + ) + else: + self.fe_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.fe_dropout_rate, + norm_type=self.cf.norm_type, + dim_aux=1, + norm_eps=self.cf.mlp_norm_eps, + ) + ) + # ------------------------------------------------------------------ def init_weights_final(m): if isinstance(m, torch.nn.Linear): torch.nn.init.normal_(m.weight, mean=0, std=0.001) diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..f0f4b6ea6 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -93,3 +93,212 @@ def forward(self, *args): x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) return x + +class _DenseBlock(nn.Module): + """A tiny FFN that mirrors the structure of the current MLP stack.""" + def __init__(self, dim_in, dim_hidden, dim_out, num_layers=2, + nonlin=nn.GELU, dropout_rate=0.0): + super().__init__() + layers = [nn.Linear(dim_in, dim_hidden), nonlin(), nn.Dropout(dropout_rate)] + for _ in range(num_layers - 2): + layers += [nn.Linear(dim_hidden, dim_hidden), nonlin(), nn.Dropout(dropout_rate)] + layers += [nn.Linear(dim_hidden, dim_out)] + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + +class MoEMLP(nn.Module): + """ + Drop-in MoE MLP (memory-friendly): + - Same call pattern as the current MLP: forward(*args) where args=(x, ...) and optional aux at the end + - Supports residual add exactly like MLP + - Optional AdaLayerNorm when dim_aux is provided + - Simple top-k router; mixes experts with streaming accumulation (no big [E, ..., D] stack) + """ + def __init__( + self, + dim_in, + dim_out, + num_layers=2, + hidden_factor=2, + pre_layer_norm=True, + dropout_rate=0.0, + nonlin=nn.GELU, + with_residual=False, + norm_type="LayerNorm", + dim_aux=None, + norm_eps=1e-5, + name: str | None = None, + # MoE bits + num_experts: int = 8, + top_k: int = 4, + router_noisy_std: float = 0.0, # set >0 to add noise to router logits + # Memory bits + use_checkpoint: bool = False, # checkpoint expert forward to save memory + ): + super().__init__() + if name is not None: + self.name = name + + assert num_layers >= 2 + assert 1 <= top_k <= num_experts + + self.with_residual = with_residual + self.with_aux = dim_aux is not None + self.pre_layer_norm = pre_layer_norm + self.top_k = top_k + self.num_experts = num_experts + self.use_checkpoint = use_checkpoint + + dim_hidden = int(dim_in * hidden_factor) + + # Norm (match MLP behavior) + Norm = nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm + if pre_layer_norm: + self.norm = ( + Norm(dim_in, eps=norm_eps) + if dim_aux is None + else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) + ) + else: + self.norm = None # no pre-norm + + # Router + self.router = nn.Linear(dim_in, num_experts) + self.router_noisy_std = router_noisy_std + + # Experts (identical shape) + self.experts = nn.ModuleList( + [ + _DenseBlock( + dim_in=dim_in, + dim_hidden=dim_hidden, + dim_out=dim_out, + num_layers=num_layers, + nonlin=nonlin, + dropout_rate=dropout_rate, + ) + for _ in range(num_experts) + ] + ) + + # For optional aux loss (load-balancing); not used unless you read it + self.register_buffer("last_aux_loss", torch.zeros((), dtype=torch.float32)) + + def _gate(self, x_norm): + # x_norm: [*, D]. Router works on the last dim. + logits = self.router(x_norm) + if self.router_noisy_std > 0: + logits = logits + torch.randn_like(logits) * self.router_noisy_std + + if self.top_k == self.num_experts: + # softmax over all experts + weights = torch.softmax(logits, dim=-1) # [..., E] + top_idx = None # not needed + else: + # top-k softmax + top_vals, top_idx = torch.topk(logits, k=self.top_k, dim=-1) # [*, k] + weights = torch.softmax(top_vals, dim=-1) # [*, k] + return weights, top_idx + + @torch.no_grad() + def _compute_load_balance_aux(self, weights, top_idx, num_experts): + """ + Simple load-balancing penalty from Switch/MoE papers: + Encourage uniform expert probability and uniform usage. + Works with both full-softmax (top_idx None) and top-k. + """ + if top_idx is None: + # weights over E + probs = weights.mean(dim=tuple(range(weights.dim() - 1))) # [E] + else: + # Build usage over experts from top-k selection + # *prefix, K = weights.shape + # flat_w = weights.reshape(-1, K) # [N, K] + # flat_i = top_idx.reshape(-1, K) # [N, K] + if weights.shape != top_idx.shape: + raise ValueError( + "Top-k weights and indices must share the same shape" + ) + + K = weights.shape[-1] + flat_w = weights.reshape(-1, K) # [N, K] + flat_i = top_idx.reshape(-1, K) # [N, K] + E = num_experts + usage = torch.zeros(E, device=weights.device, dtype=weights.dtype) + usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1)) + usage = usage / usage.sum().clamp_min(1e-6) # normalize + probs = usage # proxy + # Target is uniform 1/E + E = num_experts + target = torch.full_like(probs, 1.0 / E) + aux = (probs * (probs.add(1e-6).log() - target.add(1e-6).log())).sum() + return aux + + def forward(self, *args): + # Match your MLP(*args) calling convention + x = args[0] + x_in = x + aux = args[-1] if self.with_aux else None + + # Optional pre-norm (possibly adaptive) + if self.norm is not None: + if self.with_aux: + x = self.norm(x, aux) + else: + x = self.norm(x) + + # Router + weights, top_idx = self._gate(x) # weights: [..., E] or [..., K] + + # Build a full weight tensor [..., E] if we are in top-k mode, + # so we can stream over experts without stacking their outputs. + if top_idx is None: + w_full = weights # [..., E] + else: + # scatter top-k weights into a zero tensor of size E + E = self.num_experts + w_full = torch.zeros(*weights.shape[:-1], E, device=weights.device, dtype=weights.dtype) # [..., E] + w_full.scatter_(-1, top_idx, weights) + + # Output accumulator (no expert stacking) + out_dim = self.experts[0].net[-1].out_features # last Linear of _DenseBlock + y = x.new_zeros(*x.shape[:-1], out_dim) + + # Optional gradient checkpoint + if self.use_checkpoint: + from torch.utils.checkpoint import checkpoint + + # Stream over experts: y += expert(x) * w_full[..., e] + for e, expert in enumerate(self.experts): + # skip compute if weight mass is (nearly) zero for this expert + w_e = w_full[..., e] # [...] + if torch.allclose(w_e, torch.zeros((), device=w_e.device, dtype=w_e.dtype)): + continue + + if self.use_checkpoint and self.training: + y_e = checkpoint(expert, x) + else: + y_e = expert(x) + y = y + y_e * w_e.unsqueeze(-1) + + # Residual (same logic as your MLP) + if self.with_residual: + if y.shape[-1] == x_in.shape[-1]: + y = x_in + y + else: + assert y.shape[-1] % x_in.shape[-1] == 0 + y = y + x_in.repeat([*[1 for _ in y.shape[:-1]], y.shape[-1] // x_in.shape[-1]]) + + # Optional: update aux loss (not returned; read if you want) + with torch.no_grad(): + self.last_aux_loss = self._compute_load_balance_aux( + # w_full if top_idx is not None else weights, # use full probs if we built them + # None if top_idx is None else top_idx, + weights, + top_idx, + self.num_experts, + ) + + return y \ No newline at end of file From c73578e1f341faf16c8404e41922900327f37e05 Mon Sep 17 00:00:00 2001 From: Wael Almikaeel Date: Fri, 17 Oct 2025 14:27:45 +0200 Subject: [PATCH 2/3] correcting default config after rebase --- config/default_config.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 596055de3..62d20d345 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -155,13 +155,12 @@ desc: "" data_loader_rng_seed: ??? run_id: ??? -<<<<<<< HEAD # The period to log in the training loop (in number of batch steps) train_log_freq: terminal: 10 metrics: 20 checkpoint: 250 -======= + # Parameters for logging/printing in the training loop train_log: # The period to log metrics (in number of batch steps) @@ -174,4 +173,3 @@ fe_mlp_type: "dense" # set to "moe" to enable MoE fe_moe_num_experts: 8 fe_moe_top_k: 2 fe_moe_hidden_factor: 0.5 # = HF_dense / 4 ->>>>>>> 36fea3a (Adding MoEMLP layer to the layers file, integrate the MoE layer in Forecasting engine, and set up the config file to control the use of this layer) From b3de51e4e14756ebf38998741c8307a8137359ab Mon Sep 17 00:00:00 2001 From: Wael Almikaeel Date: Fri, 24 Oct 2025 13:59:55 +0200 Subject: [PATCH 3/3] adding MoE layers to the global engine and decoder layers, adding the router loss to the trainier --- config/default_config.yml | 36 +++-- src/weathergen/model/blocks.py | 125 +++++++++++++---- src/weathergen/model/engines.py | 104 ++++++++++++--- src/weathergen/model/layers.py | 228 +++++++++++++++++++------------- src/weathergen/train/trainer.py | 173 ++++++++++++++++++------ 5 files changed, 477 insertions(+), 189 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 62d20d345..bf1f1f160 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -24,7 +24,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 8 +ae_global_num_blocks: 2 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -42,12 +42,12 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 1 +forecast_policy: "fixed" forecast_att_dense_rate: 1.0 -fe_num_blocks: 0 +fe_num_blocks: 2 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -93,7 +93,7 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) @@ -168,8 +168,28 @@ train_log: # Forecast MLP type: "dense" (default) or "moe" fe_mlp_type: "dense" # set to "moe" to enable MoE +ae_global_mlp_type: "dense" # set to "moe" to enable MoE +ffn_mlp_type: "dense" # set to "moe" to enable MoE in the feed-forward network of the decoder blocks +decoder_mlp_type: "dense" # set to "moe" to enable MoE in the decoder prediction MLP +moe_lambda: 0.02 # coefficient for the MoE load balancing loss # MoE-only params (ignored when fe_mlp_type != "moe") -fe_moe_num_experts: 8 -fe_moe_top_k: 2 +fe_moe_num_experts: 2 +fe_moe_top_k: 1 fe_moe_hidden_factor: 0.5 # = HF_dense / 4 + +# MoE-only params (ignored when ae_global_mlp_type != "moe") +ae_global_moe_num_experts: 4 +ae_global_moe_top_k: 2 +ae_global_moe_hidden_factor: 0.5 # = HF_dense / 4 + +# MoE-only params (ignored when ffn_mlp_type != "moe") +ffn_moe_num_experts: 2 +ffn_moe_top_k: 1 +ffn_moe_hidden_factor: 0.5 # = HF_dense / 4 + +# MoE-only params (ignored when decoder_mlp_type != "moe") +decoder_moe_num_experts: 2 +decoder_moe_top_k: 1 +decoder_moe_hidden_factor: 0.5 # = HF_dense / 4 +tr_mlp_hidden_factor: 2 diff --git a/src/weathergen/model/blocks.py b/src/weathergen/model/blocks.py index 061928f64..94dc73e9d 100644 --- a/src/weathergen/model/blocks.py +++ b/src/weathergen/model/blocks.py @@ -14,10 +14,12 @@ MultiCrossAttentionHeadVarlen, MultiSelfAttentionHeadVarlen, ) -from weathergen.model.layers import MLP +from weathergen.model.layers import MLP, MoEMLP from weathergen.model.norms import AdaLayerNormLayer from weathergen.utils.utils import get_dtype +import logging +logger = logging.getLogger(__name__) class SelfAttentionBlock(nn.Module): """ @@ -43,14 +45,32 @@ def __init__(self, dim, dim_aux, with_adanorm, num_heads, dropout_rate, **kwargs self.mhsa_block = lambda x, _, **kwargs: self.mhsa(self.ln_sa(x), **kwargs) + x approx_gelu = lambda: nn.GELU(approximate="tanh") - self.mlp = MLP( - dim_in=dim, - dim_out=dim, - hidden_factor=4, - dropout_rate=0.1, - nonlin=approx_gelu, - with_residual=False, - ) + use_moe_ffn = (kwargs.get("ffn_mlp_type", "dense") == "moe") + ffn_hidden_factor = kwargs.get("ffn_hidden_factor", 4) + moe_kwargs = kwargs.get("moe_kwargs", {}) # e.g. num_experts, top_k, router_noisy_std + + if use_moe_ffn: + self.mlp = MoEMLP( + dim_in=dim, + dim_out=dim, + hidden_factor=ffn_hidden_factor, + dropout_rate=0.1, + nonlin=nn.GELU, # internal block constructs nonlin() + with_residual=False, + norm_type=kwargs["attention_kwargs"]["norm_type"], + dim_aux=(dim_aux if self.with_adanorm else None), + norm_eps=kwargs["attention_kwargs"]["norm_eps"], + **moe_kwargs, # <- e.g. num_experts=8, top_k=2, router_noisy_std=0.01 + ) + else: + self.mlp = MLP( + dim_in=dim, + dim_out=dim, + hidden_factor=4, + dropout_rate=0.1, + nonlin=approx_gelu, + with_residual=False, + ) if self.with_adanorm: self.mlp_fn = lambda x, **kwargs: self.mlp(x) self.mlp_block = AdaLayerNormLayer(dim, dim_aux, self.mlp_fn, dropout_rate) @@ -104,7 +124,7 @@ def __init__( self.with_adanorm = with_adanorm self.with_self_attn = with_self_attn - self.with_mlp = with_self_attn + self.with_mlp = with_mlp if with_self_attn: self.mhsa = MultiSelfAttentionHeadVarlen( @@ -136,18 +156,37 @@ def __init__( if self.with_mlp: approx_gelu = lambda: nn.GELU(approximate="tanh") - self.mlp = MLP( - dim_in=dim_q, - dim_out=dim_q, - hidden_factor=4, - nonlin=approx_gelu, - with_residual=False, - ) + + use_moe_ffn = (kwargs.get("ffn_mlp_type", "dense") == "moe") + ffn_hidden_factor = kwargs.get("ffn_hidden_factor", 4) + moe_kwargs = kwargs.get("moe_kwargs", {}) + + if use_moe_ffn: + self.mlp = MoEMLP( + dim_in=dim_q, + dim_out=dim_q, + hidden_factor=ffn_hidden_factor, + dropout_rate=0.1, + nonlin=nn.GELU, # internal block constructs nonlin() + with_residual=False, + norm_type=kwargs["attention_kwargs"]["norm_type"], + dim_aux=(dim_aux if self.with_adanorm else None), + norm_eps=kwargs["attention_kwargs"]["norm_eps"], + **moe_kwargs, # <- e.g. num_experts=8, top_k=2, router_noisy_std=0.01 + ) + else: + self.mlp = MLP( + dim_in=dim_q, + dim_out=dim_q, + hidden_factor=4, + nonlin=approx_gelu, + with_residual=False, + ) if self.with_adanorm: self.mlp_fn = lambda x, **kwargs: self.mlp(x) self.mlp_block = AdaLayerNormLayer(dim_q, dim_aux, self.mlp_fn, dropout_rate) else: - self.ln_mlp = nn.LayerNorm(dim_q, eps=kwargs["attention_kwargs"]["norm_eps"]) + self.ln_mlp = nn.LayerNorm(eps=kwargs["attention_kwargs"]["norm_eps"]) self.mlp_block = lambda x, _, **kwargs: self.mlp(self.ln_mlp(x)) + x else: self.mlp_block = lambda x, _, **kwargs: x @@ -191,6 +230,7 @@ def __init__( tr_mlp_hidden_factor, tro_type, mlp_norm_eps=1e-6, + **kwargs, ): super().__init__() @@ -237,19 +277,46 @@ def __init__( ) # MLP Block - self.block.append( - MLP( - dim_in, - dim_out, - with_residual=True, - hidden_factor=self.tr_mlp_hidden_factor, - dropout_rate=0.1, # Assuming dropout_rate is 0.1 - norm_type=self.cf.norm_type, - dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None), - norm_eps=self.cf.mlp_norm_eps, - ) + # Add MoE option + use_moe = getattr(self.cf, "decoder_mlp_type", "dense") == "moe" + logger.info( + "[MoE] Decoder head: type=%s%s", + "moe" if use_moe else "dense", + ("" if not use_moe else + f" (experts={getattr(self.cf,'moe_num_experts',None)}, top_k={getattr(self.cf,'moe_top_k',None)})"), ) + if use_moe: + self.block.append( + MoEMLP( + dim_in, + dim_out, + hidden_factor=self.tr_mlp_hidden_factor, + dropout_rate=0.1, + with_residual=True, # mirror dense + norm_type=self.cf.norm_type, + dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None), + norm_eps=self.cf.mlp_norm_eps, + num_experts=getattr(self.cf, "moe_num_experts", 8), + top_k=getattr(self.cf, "moe_top_k", 2), + router_noisy_std=getattr(self.cf, "moe_router_noisy_std", 0.0), + use_checkpoint=getattr(self.cf, "moe_use_checkpoint", False), + ) + ) + else: + self.block.append( + MLP( + dim_in, + dim_out, + with_residual=True, + hidden_factor=self.tr_mlp_hidden_factor, + dropout_rate=0.1, # Assuming dropout_rate is 0.1 + norm_type=self.cf.norm_type, + dim_aux=(dim_aux if self.cf.pred_mlp_adaln else None), + norm_eps=self.cf.mlp_norm_eps, + ) + ) + def forward(self, latent, output, coords, latent_lens, output_lens): for layer in self.block: if isinstance(layer, MultiCrossAttentionHeadVarlen): diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 74e488fce..2a7a9721e 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -24,10 +24,12 @@ StreamEmbedLinear, StreamEmbedTransformer, ) -from weathergen.model.layers import MLP +from weathergen.model.layers import MLP, MoEMLP from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype +import logging +logger = logging.getLogger(__name__) class EmbeddingEngine: name: "EmbeddingEngine" @@ -249,17 +251,50 @@ def create(self) -> torch.nn.ModuleList: ) ) # MLP block - self.ae_global_blocks.append( - MLP( - self.cf.ae_global_dim_embed, - self.cf.ae_global_dim_embed, - with_residual=True, - dropout_rate=self.cf.ae_global_dropout_rate, - hidden_factor=self.cf.ae_global_mlp_hidden_factor, - norm_type=self.cf.norm_type, - norm_eps=self.cf.mlp_norm_eps, - ) + # Add MoE option + use_moe = getattr(self.cf, "ae_global_mlp_type", "dense") == "moe" + mlp_common_kwargs = dict( + dim_in=self.cf.ae_global_dim_embed, + dim_out=self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.ae_global_dropout_rate, + norm_type=self.cf.norm_type, + norm_eps=self.cf.mlp_norm_eps, ) + if use_moe: + self.ae_global_blocks.append( + MoEMLP( + **mlp_common_kwargs, + num_experts=getattr(self.cf, "ae_global_moe_num_experts", 2), + top_k=getattr(self.cf, "ae_global_moe_top_k", 1), + router_noisy_std=getattr(self.cf, "ae_global_moe_router_noisy_std", 0.0), + hidden_factor=getattr(self.cf, "ae_global_moe_hidden_factor", 2), + ) + ) + else: + self.ae_global_blocks.append( + MLP( + self.cf.ae_global_dim_embed, + self.cf.ae_global_dim_embed, + with_residual=True, + dropout_rate=self.cf.ae_global_dropout_rate, + hidden_factor=self.cf.ae_global_mlp_hidden_factor, + norm_type=self.cf.norm_type, + norm_eps=self.cf.mlp_norm_eps, + ) + ) + # Count MoE blocks + num_moe = sum(1 for m in self.ae_global_blocks if isinstance(m, MoEMLP)) + logger.info( + "[MoE] GlobalAssimilationEngine: %d MoEMLP blocks " + "(ae_global_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)", + num_moe, + getattr(self.cf, "ae_global_mlp_type", "dense"), + getattr(self.cf, "ae_global_moe_num_experts", None), + getattr(self.cf, "ae_global_moe_top_k", None), + getattr(self.cf, "ae_global_moe_hidden_factor", None), + ) + return self.ae_global_blocks @@ -343,8 +378,8 @@ def create(self) -> torch.nn.ModuleList: self.fe_blocks.append( MoEMLP( **mlp_common_kwargs, - num_experts=getattr(self.cf, "fe_moe_num_experts", 8), - top_k=getattr(self.cf, "fe_moe_top_k", 4), + num_experts=getattr(self.cf, "fe_moe_num_experts", 2), + top_k=getattr(self.cf, "fe_moe_top_k", 2), router_noisy_std=getattr(self.cf, "fe_moe_router_noisy_std", 0.0), hidden_factor=getattr(self.cf, "fe_moe_hidden_factor", 2), ) @@ -362,15 +397,24 @@ def create(self) -> torch.nn.ModuleList: ) ) # ------------------------------------------------------------------ - def init_weights_final(m): - if isinstance(m, torch.nn.Linear): - torch.nn.init.normal_(m.weight, mean=0, std=0.001) - if m.bias is not None: - torch.nn.init.normal_(m.bias, mean=0, std=0.001) - - for block in self.fe_blocks: - block.apply(init_weights_final) - + # def init_weights_final(m): + # if isinstance(m, torch.nn.Linear) and not getattr(m, "is_moe_router", False): + # torch.nn.init.normal_(m.weight, mean=0, std=0.001) + # if m.bias is not None: + # torch.nn.init.normal_(m.bias, mean=0, std=0.001) + + # for block in self.fe_blocks: + # block.apply(init_weights_final) + num_moe = sum(1 for m in self.fe_blocks if isinstance(m, MoEMLP)) + logger.info( + "[MoE] ForecastingEngine: %d MoEMLP blocks " + "(fe_mlp_type=%s, experts=%s, top_k=%s, hidden_factor=%s)", + num_moe, + getattr(self.cf, "fe_mlp_type", "dense"), + getattr(self.cf, "fe_moe_num_experts", None), + getattr(self.cf, "fe_moe_top_k", None), + getattr(self.cf, "fe_moe_hidden_factor", None), + ) return self.fe_blocks @@ -619,6 +663,14 @@ def __init__( with_adanorm=False, with_mlp=False, attention_kwargs=attention_kwargs, + ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"), + ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4), + moe_kwargs=dict( + num_experts=getattr(self.cf, "decoder_moe_num_experts", 2), + top_k=getattr(self.cf, "decoder_moe_top_k", 2), + router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0), + use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False), + ) ) ) elif self.cf.decoder_type == "AdaLayerNormConditioning": @@ -674,6 +726,14 @@ def __init__( tr_mlp_hidden_factor=tr_mlp_hidden_factor, tro_type=tro_type, mlp_norm_eps=self.cf.mlp_norm_eps, + ffn_mlp_type=getattr(self.cf, "decoder_ffn_mlp_type", "dense"), + ffn_hidden_factor=getattr(self.cf, "decoder_ffn_hidden_factor", 4), + moe_kwargs=dict( + num_experts=getattr(self.cf, "decoder_moe_num_experts", 2), + top_k=getattr(self.cf, "decoder_moe_top_k", 2), + router_noisy_std=getattr(self.cf, "decoder_moe_router_noisy_std", 0.0), + use_checkpoint=getattr(self.cf, "decoder_moe_use_checkpoint", False), + ) ) ) else: diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index f0f4b6ea6..acd06b721 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn +from typing import Optional, Tuple, Dict, Any from weathergen.model.norms import AdaLayerNorm, RMSNorm @@ -93,7 +94,7 @@ def forward(self, *args): x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) return x - + class _DenseBlock(nn.Module): """A tiny FFN that mirrors the structure of the current MLP stack.""" def __init__(self, dim_in, dim_hidden, dim_out, num_layers=2, @@ -108,49 +109,67 @@ def __init__(self, dim_in, dim_hidden, dim_out, num_layers=2, def forward(self, x): return self.net(x) + class MoEMLP(nn.Module): """ - Drop-in MoE MLP (memory-friendly): - - Same call pattern as the current MLP: forward(*args) where args=(x, ...) and optional aux at the end - - Supports residual add exactly like MLP - - Optional AdaLayerNorm when dim_aux is provided - - Simple top-k router; mixes experts with streaming accumulation (no big [E, ..., D] stack) + Memory-friendly MoE MLP. + + Features + -------- + - Matches MLP call pattern: forward(*args) where args=(x, ...) and optional aux at the end + - Optional AdaLayerNorm pre-norm when dim_aux is provided + - Top-k routing with softmax over selected logits + - Streams experts and accumulates outputs (no large [E, ..., D] stacks) + - Optional auxiliary outputs (gate loss, route histogram) via `return_aux` + + Notes + ----- + - If `return_aux=False` (default), we still *compute* the aux loss (with grads) and stash it + on `self.last_aux` and `self.last_aux_loss` so you can read it after forward if desired. + - To actively use the load-balancing loss in training, either set `return_aux=True` and add it + to your loss, or read `self.last_aux['gate_loss']` from the module instance. """ def __init__( self, - dim_in, - dim_out, - num_layers=2, - hidden_factor=2, - pre_layer_norm=True, - dropout_rate=0.0, + dim_in: int, + dim_out: int, + num_layers: int = 2, + hidden_factor: float = 2.0, + pre_layer_norm: bool = True, + dropout_rate: float = 0.0, nonlin=nn.GELU, - with_residual=False, - norm_type="LayerNorm", - dim_aux=None, - norm_eps=1e-5, - name: str | None = None, - # MoE bits + with_residual: bool = False, + norm_type: str = "LayerNorm", + dim_aux: Optional[int] = None, + norm_eps: float = 1e-5, + name: Optional[str] = None, + # MoE num_experts: int = 8, top_k: int = 4, - router_noisy_std: float = 0.0, # set >0 to add noise to router logits - # Memory bits - use_checkpoint: bool = False, # checkpoint expert forward to save memory + router_noisy_std: float = 0.0, + # Memory + use_checkpoint: bool = False, + # API + return_aux: bool = False, ): super().__init__() if name is not None: self.name = name - assert num_layers >= 2 - assert 1 <= top_k <= num_experts + assert num_layers >= 2, "MoEMLP requires at least 2 layers" + assert 1 <= top_k <= num_experts, "top_k must be in [1, num_experts]" self.with_residual = with_residual self.with_aux = dim_aux is not None self.pre_layer_norm = pre_layer_norm self.top_k = top_k self.num_experts = num_experts + self.router_noisy_std = router_noisy_std self.use_checkpoint = use_checkpoint + self.return_aux = return_aux + self.enable_gate_loss = True + self.register_buffer("usage_buf", torch.zeros(num_experts), persistent=False) dim_hidden = int(dim_in * hidden_factor) # Norm (match MLP behavior) @@ -162,13 +181,15 @@ def __init__( else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps) ) else: - self.norm = None # no pre-norm + self.norm = None # Router self.router = nn.Linear(dim_in, num_experts) - self.router_noisy_std = router_noisy_std + # Recommended init: small std, zero bias + nn.init.normal_(self.router.weight, mean=0.0, std=1e-2) + nn.init.constant_(self.router.bias, 0.0) - # Experts (identical shape) + # Experts self.experts = nn.ModuleList( [ _DenseBlock( @@ -183,107 +204,98 @@ def __init__( ] ) - # For optional aux loss (load-balancing); not used unless you read it + # Stashed aux for consumers that don't use return_aux self.register_buffer("last_aux_loss", torch.zeros((), dtype=torch.float32)) + self.last_aux: Dict[str, torch.Tensor] = {} - def _gate(self, x_norm): - # x_norm: [*, D]. Router works on the last dim. + def _gate(self, x_norm: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Returns: + weights: [..., E] if top_k == E else [..., K] + top_idx: None if full softmax, else [..., K] int indices + """ logits = self.router(x_norm) if self.router_noisy_std > 0: logits = logits + torch.randn_like(logits) * self.router_noisy_std if self.top_k == self.num_experts: - # softmax over all experts - weights = torch.softmax(logits, dim=-1) # [..., E] - top_idx = None # not needed + weights = torch.softmax(logits, dim=-1) + top_idx = None else: - # top-k softmax - top_vals, top_idx = torch.topk(logits, k=self.top_k, dim=-1) # [*, k] - weights = torch.softmax(top_vals, dim=-1) # [*, k] + top_vals, top_idx = torch.topk(logits, k=self.top_k, dim=-1) + weights = torch.softmax(top_vals, dim=-1) return weights, top_idx - @torch.no_grad() - def _compute_load_balance_aux(self, weights, top_idx, num_experts): + def _compute_load_balance_aux( + self, weights: torch.Tensor, top_idx: Optional[torch.Tensor], num_experts: int + ) -> torch.Tensor: """ - Simple load-balancing penalty from Switch/MoE papers: - Encourage uniform expert probability and uniform usage. - Works with both full-softmax (top_idx None) and top-k. + Cross-entropy between observed expert usage and uniform 1/E target. + Works for both full-softmax and top-k. """ if top_idx is None: - # weights over E + # weights over E -> average across batch/time dims probs = weights.mean(dim=tuple(range(weights.dim() - 1))) # [E] else: - # Build usage over experts from top-k selection - # *prefix, K = weights.shape - # flat_w = weights.reshape(-1, K) # [N, K] - # flat_i = top_idx.reshape(-1, K) # [N, K] + # Aggregate usage from top-k selections if weights.shape != top_idx.shape: - raise ValueError( - "Top-k weights and indices must share the same shape" - ) - + raise ValueError("Top-k weights and indices must share the same shape") K = weights.shape[-1] flat_w = weights.reshape(-1, K) # [N, K] flat_i = top_idx.reshape(-1, K) # [N, K] - E = num_experts - usage = torch.zeros(E, device=weights.device, dtype=weights.dtype) + usage = torch.zeros(num_experts, device=weights.device, dtype=weights.dtype) usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1)) - usage = usage / usage.sum().clamp_min(1e-6) # normalize - probs = usage # proxy - # Target is uniform 1/E + probs = usage / usage.sum().clamp_min(1e-6) # [E] + E = num_experts target = torch.full_like(probs, 1.0 / E) aux = (probs * (probs.add(1e-6).log() - target.add(1e-6).log())).sum() return aux def forward(self, *args): - # Match your MLP(*args) calling convention + """ + Args: + *args: expects x first; if AdaLN is enabled (dim_aux != None), the last arg is aux. + + Returns: + y or (y, aux_out) depending on `self.return_aux`. + aux_out = {"gate_loss": ..., "route_hist": ...} + """ x = args[0] x_in = x - aux = args[-1] if self.with_aux else None + aux_in = args[-1] if self.with_aux else None - # Optional pre-norm (possibly adaptive) + # Optional pre-norm if self.norm is not None: - if self.with_aux: - x = self.norm(x, aux) - else: - x = self.norm(x) + x = self.norm(x, aux_in) if self.with_aux else self.norm(x) - # Router - weights, top_idx = self._gate(x) # weights: [..., E] or [..., K] + # Routing + weights, top_idx = self._gate(x) # [..., E] or [..., K] - # Build a full weight tensor [..., E] if we are in top-k mode, - # so we can stream over experts without stacking their outputs. + # Build full weights when in top-k mode to stream experts if top_idx is None: w_full = weights # [..., E] else: - # scatter top-k weights into a zero tensor of size E E = self.num_experts - w_full = torch.zeros(*weights.shape[:-1], E, device=weights.device, dtype=weights.dtype) # [..., E] + w_full = torch.zeros(*weights.shape[:-1], E, device=weights.device, dtype=weights.dtype) w_full.scatter_(-1, top_idx, weights) - # Output accumulator (no expert stacking) + # Accumulate outputs without stacking out_dim = self.experts[0].net[-1].out_features # last Linear of _DenseBlock y = x.new_zeros(*x.shape[:-1], out_dim) - # Optional gradient checkpoint - if self.use_checkpoint: + if self.use_checkpoint and self.training: from torch.utils.checkpoint import checkpoint - # Stream over experts: y += expert(x) * w_full[..., e] for e, expert in enumerate(self.experts): - # skip compute if weight mass is (nearly) zero for this expert w_e = w_full[..., e] # [...] - if torch.allclose(w_e, torch.zeros((), device=w_e.device, dtype=w_e.dtype)): + # Skip experts with (near) zero mass + if w_e.abs().max() <= 1e-12: continue - - if self.use_checkpoint and self.training: - y_e = checkpoint(expert, x) - else: - y_e = expert(x) + y_e = expert(x) if not (self.use_checkpoint and self.training) else checkpoint(expert, x) y = y + y_e * w_e.unsqueeze(-1) - # Residual (same logic as your MLP) + # Residual if self.with_residual: if y.shape[-1] == x_in.shape[-1]: y = x_in + y @@ -291,14 +303,52 @@ def forward(self, *args): assert y.shape[-1] % x_in.shape[-1] == 0 y = y + x_in.repeat([*[1 for _ in y.shape[:-1]], y.shape[-1] // x_in.shape[-1]]) - # Optional: update aux loss (not returned; read if you want) - with torch.no_grad(): - self.last_aux_loss = self._compute_load_balance_aux( - # w_full if top_idx is not None else weights, # use full probs if we built them - # None if top_idx is None else top_idx, - weights, - top_idx, - self.num_experts, - ) + # # Aux outputs (WITH grads so router learns; also stash for external access) + # aux_out: Dict[str, Any] = {} + # gate_loss = self._compute_load_balance_aux(weights, top_idx, self.num_experts) + # aux_out["gate_loss"] = gate_loss + + # # utilization histogram (for logging) + # if top_idx is None: + # aux_out["route_hist"] = weights.mean(dim=tuple(range(weights.dim() - 1))) # [E] + # else: + # K = weights.shape[-1] + # flat_w = weights.reshape(-1, K) + # flat_i = top_idx.reshape(-1, K) + # usage = torch.zeros(self.num_experts, device=weights.device, dtype=weights.dtype) + # usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1)) + # aux_out["route_hist"] = usage / usage.sum().clamp_min(1e-6) # [E] + + # # stash for consumers that don't use return_aux + # self.last_aux = aux_out + # self.last_aux_loss = gate_loss + + # return (y, aux_out) if self.return_aux else y + # --- Aux outputs (gate loss + route hist) --- + aux_out: Dict[str, Any] = {} + if self.enable_gate_loss: + gate_loss = self._compute_load_balance_aux(weights, top_idx, self.num_experts) + aux_out["gate_loss"] = gate_loss + + # utilization histogram (for logging / debugging only) + if top_idx is None: + aux_out["route_hist"] = weights.mean(dim=tuple(range(weights.dim() - 1))) # [E] + else: + K = weights.shape[-1] + flat_w = weights.reshape(-1, K) + flat_i = top_idx.reshape(-1, K) + usage = self.usage_buf + usage = usage.to(weights.device, dtype=weights.dtype) + usage.zero_() + usage.scatter_add_(0, flat_i.reshape(-1), flat_w.reshape(-1)) + aux_out["route_hist"] = usage / usage.sum().clamp_min(1e-6) # [E] + else: + # no aux computation this step + pass + + # stash + self.last_aux = aux_out + if "gate_loss" in aux_out: + self.last_aux_loss = aux_out["gate_loss"] - return y \ No newline at end of file + return (y, aux_out) if self.return_aux else y \ No newline at end of file diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 8f26da14d..2c2396c84 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -40,7 +40,7 @@ MultiSelfAttentionHeadVarlen, ) from weathergen.model.ema import EMAModel -from weathergen.model.layers import MLP +from weathergen.model.layers import MLP, MoEMLP from weathergen.model.model import Model, ModelParams from weathergen.model.utils import freeze_weights from weathergen.train.loss_calculator import LossCalculator @@ -154,6 +154,14 @@ def inference(self, cf, devices, run_id_trained, epoch): self.validate(epoch=0) logger.info(f"Finished inference run with id: {cf.run_id}") + def _ensure_moe_modules_cached(self): + # Works with plain, DDP-wrapped, FSDP, or compiled models + from weathergen.model.layers import MoEMLP + m = self.model + if hasattr(m, "module"): # DDP + m = m.module + self.moe_modules = [x for x in m.modules() if isinstance(x, MoEMLP)] + def init_model_and_shard(self, cf, devices): sources_size = self.dataset.get_sources_size() targets_num_channels = self.dataset.get_targets_num_channels() @@ -197,6 +205,7 @@ def init_model_and_shard(self, cf, devices): MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, MultiSelfAttentionHeadVarlen, + MoEMLP, ) for module in model.ae_local_blocks.modules(): @@ -239,6 +248,7 @@ def init_model_and_shard(self, cf, devices): fully_shard(model) for tensor in itertools.chain(model.parameters(), model.buffers()): assert tensor.device == torch.device("meta") + return model, model_params def run(self, cf, devices, run_id_contd=None, epoch_contd=None): @@ -282,7 +292,7 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None): ) self.model, self.model_params = self.init_model_and_shard(cf, devices) - + self._ensure_moe_modules_cached() if run_id_contd is None: self.model.to_empty(device="cuda") self.model.reset_parameters() @@ -560,7 +570,12 @@ def train(self, epoch): for bidx, batch in enumerate(dataset_iter): forecast_steps = batch[-1] batch = self.batch_to_device(batch) - + interval = max(1, int(getattr(self.cf, "moe_loss_interval", 1))) + collect = (self.cf.istep % interval) == 0 + for m in self.moe_modules: + # only set if MoEMLP implements the flag + if hasattr(m, "enable_gate_loss"): + m.enable_gate_loss = collect # evaluate model with torch.autocast( device_type=f"cuda:{cf.local_rank}", @@ -577,50 +592,126 @@ def train(self, epoch): if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in posteriors]) loss_values.loss += cf.latent_noise_kl_weight * kl.mean() + + # MoE gate loss + moe_lambda_base = float(getattr(self.cf, "moe_lambda", 0.02)) + if moe_lambda_base and collect: + # optional warmup + warm = int(getattr(self.cf, "moe_lambda_warmup_steps", 0)) + warm_mult = 1.0 if warm <= 0 else min(1.0, self.cf.istep / float(warm)) + + gate_loss = None + for m in self.moe_modules: + la = getattr(m, "last_aux", None) + if isinstance(la, dict) and ("gate_loss" in la): + gate_loss = la["gate_loss"] if gate_loss is None else (gate_loss + la["gate_loss"]) + + if gate_loss is not None: + # scale λ by interval so average gradient matches per-step application + effective_lambda = moe_lambda_base * interval * warm_mult + loss_values.loss = loss_values.loss + effective_lambda * gate_loss + + # moe_lambda = getattr(self.cf, "moe_lambda", 0.0) + # gate_loss = None + # if moe_lambda != 0.0: + # gate_loss = torch.zeros((), device=self.device) + # route_hists = [] + + # for m in self.model.modules(): + # if isinstance(m, MoEMLP) and hasattr(m, "last_aux"): + # la = m.last_aux + # if isinstance(la, dict): + # if "gate_loss" in la: + # gate_loss = gate_loss + la["gate_loss"] + # if "route_hist" in la: + # # route_hist: [E] + # route_hists.append(la["route_hist"].detach()) + + # loss_values.loss = loss_values.loss + moe_lambda * gate_loss + + # # Lightweight logging every metrics interval + # if (self.cf.istep % self.train_log_freq.metrics) == 0: + # # summarize routing (entropy and max-util) + # # summarize routing (entropy and max-util) without stacking, since E can differ per block + # if route_hists: + # entropies = [] + # max_utils = [] + # sizes = [] + # for rh in route_hists: + # p = rh.float() # [E], sums ~1 + # ent = (-(p * p.clamp_min(1e-6).log())).sum() # scalar + # entropies.append(ent.item()) + # max_utils.append(p.max().item()) + # sizes.append(p.numel()) + + # # averages across blocks + # entropy_mean = torch.tensor(entropies, device=self.device).mean().item() + # max_util_mean = torch.tensor(max_utils, device=self.device).mean().item() + + # # optional: quick distribution of expert counts across MoE modules + # # (kept tiny for logging) + # unique_E = sorted(set(sizes)) + # logger.info( + # "[MoE] step=%d | gate_loss=%.4e (λ=%.3g) | blocks=%d | route: entropy=%.3f, max_util=%.3f | E=%s", + # self.cf.istep, + # gate_loss.item(), + # moe_lambda, + # len(route_hists), + # entropy_mean, + # max_util_mean, + # unique_E, + # ) + # else: + # logger.info( + # "[MoE] step=%d | gate_loss=%.4e (λ=%.3g) | blocks=0 (no route_hist yet)", + # self.cf.istep, + # gate_loss.item(), + # moe_lambda, + # ) + + # backward pass + self.optimizer.zero_grad() + self.grad_scaler.scale(loss_values.loss).backward() + # loss_values.loss.backward() + + # gradient clipping + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=cf.grad_clip) + + # optimizer step + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + # self.optimizer.step() + + # update learning rate + self.lr_scheduler.step() + + # EMA update + if self.validate_with_ema: + self.ema_model.update( + self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu, + self.world_size_original * self.cf.batch_size_per_gpu, + ) - # backward pass - self.optimizer.zero_grad() - self.grad_scaler.scale(loss_values.loss).backward() - # loss_values.loss.backward() - - # gradient clipping - self.grad_scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=cf.grad_clip) - - # optimizer step - self.grad_scaler.step(self.optimizer) - self.grad_scaler.update() - # self.optimizer.step() - - # update learning rate - self.lr_scheduler.step() - - # EMA update - if self.validate_with_ema: - self.ema_model.update( - self.cf.istep * self.world_size_original * self.cf.batch_size_per_gpu, - self.world_size_original * self.cf.batch_size_per_gpu, - ) - - self.loss_unweighted_hist += [loss_values.losses_all] - self.loss_model_hist += [loss_values.loss.item()] - self.stdev_unweighted_hist += [loss_values.stddev_all] + self.loss_unweighted_hist += [loss_values.losses_all] + self.loss_model_hist += [loss_values.loss.item()] + self.stdev_unweighted_hist += [loss_values.stddev_all] - perf_gpu, perf_mem = self.get_perf() - self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item() - self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item() + perf_gpu, perf_mem = self.get_perf() + self.perf_gpu = ddp_average(torch.tensor([perf_gpu], device=self.device)).item() + self.perf_mem = ddp_average(torch.tensor([perf_mem], device=self.device)).item() - self._log_terminal(bidx, epoch, TRAIN) - if bidx % self.train_log_freq.metrics == 0: - self._log(TRAIN) + self._log_terminal(bidx, epoch, TRAIN) + if bidx % self.train_log_freq.metrics == 0: + self._log(TRAIN) - # save model checkpoint (with designation _latest) - if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: - self.save_model(-1) + # save model checkpoint (with designation _latest) + if bidx % self.train_log_freq.checkpoint == 0 and bidx > 0: + self.save_model(-1) - self.cf.istep += 1 + self.cf.istep += 1 - self.dataset.advance() + self.dataset.advance() def validate(self, epoch): cf = self.cf