40 commits

317501e  Replace cf.rank==0 with utils.distributed.is_root (Jul 16, 2025)
77de417  replace cf.rank==0 with weathergen.utils.distributed.is_root (Jul 16, 2025)
6439618  Merge branch 'ecmwf:develop' into develop (csjfwang, Jul 22, 2025)
8993875  Merge branch 'ecmwf:develop' into develop (csjfwang, Jul 25, 2025)
f4a9d85  Merge branch 'ecmwf:develop' into develop (csjfwang, Jul 28, 2025)
f8fdef4  Merge branch 'ecmwf:develop' into develop (csjfwang, Jul 29, 2025)
ca89e7b  Merge branch 'ecmwf:develop' into develop (csjfwang, Jul 30, 2025)
49d7a4d  Merge branch 'ecmwf:develop' into develop (csjfwang, Jul 31, 2025)
f39f094  Merge branch 'ecmwf:develop' into develop (csjfwang, Jul 31, 2025)
ebb03ea  Merge branch 'ecmwf:develop' into develop (csjfwang, Aug 25, 2025)
f40737d  Merge branch 'ecmwf:develop' into develop (csjfwang, Aug 28, 2025)
87fa078  Merge branch 'ecmwf:develop' into develop (csjfwang, Sep 10, 2025)
5dfe275  Merge branch 'ecmwf:develop' into develop (csjfwang, Sep 19, 2025)
b7244d9  Merge branch 'ecmwf:develop' into develop (csjfwang, Sep 22, 2025)
5be41f5  Merge branch 'ecmwf:develop' into develop (csjfwang, Sep 22, 2025)
39d3965  Merge branch 'ecmwf:develop' into develop (csjfwang, Sep 23, 2025)
015ec88  Merge branch 'ecmwf:develop' into develop (csjfwang, Sep 24, 2025)
cb1b7cc  Merge branch 'ecmwf:develop' into develop (csjfwang, Oct 1, 2025)
90da4cf  Merge branch 'ecmwf:develop' into develop (csjfwang, Oct 20, 2025)
f04891b  Merge branch 'ecmwf:develop' into develop (csjfwang, Oct 21, 2025)
105d992  Merge branch 'ecmwf:develop' into develop (csjfwang, Oct 24, 2025)
5f56073  Merge branch 'ecmwf:develop' into develop (csjfwang, Oct 26, 2025)
95ee18a  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 3, 2025)
3c702d3  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 10, 2025)
6f14a30  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 13, 2025)
5e87881  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 14, 2025)
0c7d305  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 24, 2025)
e43ac94  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 25, 2025)
5f63bcc  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 26, 2025)
c51eb94  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 26, 2025)
dd5acc2  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 26, 2025)
f03672d  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 27, 2025)
49c52e1  Merge branch 'ecmwf:develop' into develop (csjfwang, Nov 28, 2025)
c6356a2  Merge branch 'ecmwf:develop' into develop (csjfwang, Dec 1, 2025)
36c709a  Merge branch 'ecmwf:develop' into develop (csjfwang, Dec 1, 2025)
b230904  2D RoPE on Query Aggregation & Global Assimilation (Dec 1, 2025)
7f0f2a0  fix compute_mixed_cis multi-GPU training (Dec 1, 2025)
a2f5f12  2D RoPE for GlobalAssimilation only (Dec 2, 2025)
36c14e1  add QueryAggregationEngine (Dec 3, 2025)
622c342  remove 2D RoPE for forecast engine (Dec 3, 2025)
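The substance of the PR is in the last few commits: 2D rotary position embeddings (RoPE) are added to the query-aggregation and global-assimilation attention blocks, with a fix for multi-GPU use of compute_mixed_cis. For orientation, the sketch below shows the basic mechanics of 2D RoPE: query/key feature pairs are rotated by angles derived from a token's 2D coordinates. It is a minimal, self-contained illustration with assumed helper names and shapes, not the weathergen implementation, which builds on init_random_2d_freqs, compute_mixed_cis and apply_rotary_emb from weathergen.model.positional_encoding as the diff below shows.

```python
# Minimal 2D RoPE sketch (illustrative assumptions; not the weathergen code).
import torch


def rope_2d_angles(coords: torch.Tensor, dim_head: int, theta: float = 10.0) -> torch.Tensor:
    """coords: (n, 2) positions -> (n, dim_head // 2) rotation angles (half per axis)."""
    assert dim_head % 4 == 0, "dim_head must split into (x, y) rotation pairs"
    n_pairs_per_axis = dim_head // 4
    freqs = 1.0 / (theta ** (torch.arange(n_pairs_per_axis) / n_pairs_per_axis))
    ang_x = coords[:, 0:1] * freqs  # angles driven by the first coordinate
    ang_y = coords[:, 1:2] * freqs  # angles driven by the second coordinate
    return torch.cat([ang_x, ang_y], dim=-1)


def apply_rope_2d(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """x: (n, dim_head) queries or keys; rotate consecutive feature pairs by `angles`."""
    x_pairs = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rotation = torch.polar(torch.ones_like(angles), angles)  # unit complex numbers e^{i*angle}
    return torch.view_as_real(x_pairs * rotation).flatten(-2).to(x.dtype)


# Rotating both queries and keys by the same position-dependent phases makes the
# attention logits depend on relative 2D offsets between tokens.
q, k = torch.randn(16, 64), torch.randn(16, 64)
coords = torch.rand(16, 2)  # e.g. normalized lat/lon or local grid coordinates
angles = rope_2d_angles(coords, dim_head=64)
q_rot, k_rot = apply_rope_2d(q, angles), apply_rope_2d(k, angles)
```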
12 changes: 6 additions & 6 deletions config/default_config.yml
@@ -50,12 +50,12 @@ pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
-forecast_offset : 0
+forecast_offset : 1
forecast_delta_hrs: 0
-forecast_steps: 0
-forecast_policy: null
+forecast_steps: 2
+forecast_policy: fixed
forecast_att_dense_rate: 1.0
-fe_num_blocks: 0
+fe_num_blocks: 8
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
@@ -93,7 +93,7 @@ ema_halflife_in_thousands: 1e-3

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
-training_mode: "masking"
+training_mode: "forecast"
training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]},}
}
# training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]],
@@ -121,7 +121,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
"same_strategy_per_batch": false
}

-num_mini_epochs: 32
+num_mini_epochs: 16
samples_per_mini_epoch: 4096
samples_per_validation: 512

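Net effect of the config changes: the defaults move from masked, auto-encoder style training (forecast_offset 0, forecast_steps 0, no forecast-engine blocks) to forecast training (offset 1, two forecast steps with a fixed policy, 8 forecast-engine blocks, half the mini-epochs). As a rough illustration of the comment quoted in the diff, here is a hedged sketch of how forecast_offset and forecast_steps could map to target windows; the exact windowing convention is an assumption, not read out of the weathergen data loader.

```python
# Illustrative sketch only: a plausible mapping from forecast_offset / forecast_steps to
# target-window indices relative to an input window t. The exact convention (e.g. whether
# forecast_steps counts windows inclusively) is an assumption, not the weathergen logic.
def target_window_indices(t: int, forecast_offset: int, forecast_steps: int) -> list[int]:
    if forecast_offset == 0 and forecast_steps == 0:
        # per the config comment above, this combination corresponds to auto-encoder training
        return [t]
    return [t + forecast_offset + s for s in range(forecast_steps + 1)]

print(target_window_indices(5, 0, 0))  # old defaults (masking): [5]
print(target_window_indices(5, 1, 2))  # new defaults (forecast): [6, 7, 8]
```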
73 changes: 70 additions & 3 deletions src/weathergen/model/attention.py
@@ -14,6 +14,25 @@
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from weathergen.model.norms import AdaLayerNorm, RMSNorm
+from weathergen.model.positional_encoding import (
+apply_rotary_emb,
+compute_mixed_cis,
+init_random_2d_freqs,
+)


+def _maybe_init_rope(dim_head: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
+dim_total = dim_head * num_heads
+return init_random_2d_freqs(dim_total, num_heads, theta=theta, rotate=rotate)


+def _compute_rope(freqs: torch.Tensor, coords: torch.Tensor, num_heads: int) -> torch.Tensor:
+coords = coords.to(freqs.device)
+coords_flat = coords.reshape(-1, coords.shape[-1])
+freqs_cis = compute_mixed_cis(freqs, coords_flat[:, 0], coords_flat[:, 1], num_heads)
+freqs_cis = torch.diagonal(freqs_cis, dim1=0, dim2=1).permute(1, 0, 2)
+freqs_cis = freqs_cis.reshape(*coords.shape[:-1], num_heads, -1)
+return freqs_cis


class MultiSelfAttentionHeadVarlen(torch.nn.Module):
@@ -197,13 +216,16 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
+with_rope=False,
+rope_theta=10.0,
):
super(MultiSelfAttentionHeadLocal, self).__init__()

self.num_heads = num_heads
self.with_flash = with_flash
self.softcap = softcap
self.with_residual = with_residual
+self.with_rope = with_rope

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
@@ -231,6 +253,14 @@ def __init__(

self.dtype = attention_dtype
assert with_flash, "Only flash attention supported."
+if self.with_rope:
+self.register_buffer(
+"rope_freqs",
+_maybe_init_rope(self.dim_head_proj, self.num_heads, theta=rope_theta),
+persistent=False,
+)
+else:
+self.rope_freqs = None

# define block mask
def mask_block_local(batch, head, idx_q, idx_kv):
@@ -242,7 +272,7 @@ def mask_block_local(batch, head, idx_q, idx_kv):
# compile for efficiency
self.flex_attention = torch.compile(flex_attention, dynamic=False)

-def forward(self, x, ada_ln_aux=None):
+def forward(self, x, ada_ln_aux=None, rope_coords=None):
if self.with_residual:
x_in = x
x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
@@ -253,6 +283,10 @@ def forward(self, x, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3])
vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3])

+if self.with_rope and rope_coords is not None:
+freqs = _compute_rope(self.rope_freqs, rope_coords, self.num_heads)
+qs, ks = apply_rotary_emb(qs, ks, freqs)

outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2)

out = self.proj_out(self.dropout(outs.flatten(-2, -1)))
@@ -378,6 +412,8 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
+with_rope=False,
+rope_theta=10.0,
):
super(MultiCrossAttentionHeadVarlenSlicedQ, self).__init__()

@@ -387,6 +423,7 @@ def __init__(
self.with_residual = with_residual
self.with_flash = with_flash
self.softcap = softcap
+self.with_rope = with_rope

if norm_type == "LayerNorm":
norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps)
@@ -426,8 +463,16 @@

self.dtype = attention_dtype
assert with_flash, "Only flash attention supported at the moment"
+if self.with_rope:
+self.register_buffer(
+"rope_freqs",
+_maybe_init_rope(self.dim_head_proj, self.num_heads, theta=rope_theta),
+persistent=False,
+)
+else:
+self.rope_freqs = None

-def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
+def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None, rope_coords=None):
if self.with_residual:
x_q_in = x_q
x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux)
@@ -444,6 +489,13 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x_kv).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x_kv).reshape(s)

+if self.with_rope and rope_coords is not None:
+freqs = _compute_rope(self.rope_freqs, rope_coords, self.num_heads)
+qs = [
+apply_rotary_emb(q_i, q_i, freqs[:, idx].contiguous())[0]
+for idx, q_i in enumerate(qs)
+]

# set dropout rate according to training/eval mode as required by flash_attn
dropout_rate = self.dropout_rate if self.training else 0.0

@@ -487,6 +539,8 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
+with_rope=False,
+rope_theta=10.0,
):
super(MultiSelfAttentionHead, self).__init__()

@@ -495,6 +549,7 @@
self.softcap = softcap
self.dropout_rate = dropout_rate
self.with_residual = with_residual
+self.with_rope = with_rope

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
@@ -526,8 +581,16 @@
else:
self.att = self.attention
self.softmax = torch.nn.Softmax(dim=-1)
+if self.with_rope:
+self.register_buffer(
+"rope_freqs",
+_maybe_init_rope(self.dim_head_proj, self.num_heads, theta=rope_theta),
+persistent=False,
+)
+else:
+self.rope_freqs = None

-def forward(self, x, ada_ln_aux=None):
+def forward(self, x, ada_ln_aux=None, rope_coords=None):
if self.with_residual:
x_in = x
x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
@@ -539,6 +602,10 @@ def forward(self, x, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x).reshape(s).to(self.dtype)

+if self.with_rope and rope_coords is not None:
+freqs = _compute_rope(self.rope_freqs, rope_coords, self.num_heads)
+qs, ks = apply_rotary_emb(qs, ks, freqs)

# set dropout rate according to training/eval mode as required by flash_attn
dropout_rate = self.dropout_rate if self.training else 0.0

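A side note on the buffers added above: rope_freqs is registered with persistent=False, so it follows the module across devices like any buffer but is neither written to nor expected in the state_dict. The snippet below is a self-contained illustration of that behaviour; reading the choice as checkpoint compatibility for models built with with_rope=True is an interpretation, not a motivation stated in the PR.

```python
# Demonstrates persistent=False buffer semantics with plain PyTorch; WithFreqs is a
# stand-in module, not a weathergen class.
import torch


class WithFreqs(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # analogous to the rope_freqs registration in the diff above
        self.register_buffer("rope_freqs", torch.randn(4, 8), persistent=False)


m = WithFreqs()
print("rope_freqs" in m.state_dict())  # False: non-persistent buffers are not serialized
m.load_state_dict({}, strict=True)     # an "old" state_dict without rope_freqs still loads
print(m.rope_freqs.shape)              # the buffer itself exists and moves with .to()/.cuda()
```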
42 changes: 37 additions & 5 deletions src/weathergen/model/engines.py
@@ -225,7 +225,9 @@ def __init__(self, cf: Config) -> None:
)
)

-def forward(self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant):
+def forward(
+self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant
+):
for block in self.ae_adapter:
tokens_global_c = checkpoint(
block,
@@ -273,6 +275,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
norm_type=self.cf.norm_type,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
+with_rope=True,
)
)
else:
@@ -288,6 +291,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
norm_type=self.cf.norm_type,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
+with_rope=True,
)
)
# MLP block
@@ -303,9 +307,22 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
)
)

-def forward(self, tokens, use_reentrant):
+def forward(self, tokens, coords, use_reentrant):
for block in self.ae_aggregation_blocks:
-tokens = checkpoint(block, tokens, use_reentrant=use_reentrant)
+if isinstance(block, MultiSelfAttentionHead):
+tokens = checkpoint(
+lambda x, blk=block, c=coords: blk(x, rope_coords=c),
+tokens,
+use_reentrant=use_reentrant,
+)
+elif isinstance(block, MultiSelfAttentionHeadLocal):
+tokens = checkpoint(
+lambda x, blk=block, c=coords: blk(x, rope_coords=c),
+tokens,
+use_reentrant=use_reentrant,
+)
+else:
+tokens = checkpoint(block, tokens, use_reentrant=use_reentrant)
return tokens


@@ -341,6 +358,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
norm_type=self.cf.norm_type,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
+with_rope=True,
)
)
else:
@@ -356,6 +374,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
norm_type=self.cf.norm_type,
norm_eps=self.cf.norm_eps,
attention_dtype=get_dtype(self.cf.attention_dtype),
+with_rope=True,
)
)
# MLP block
@@ -371,9 +390,22 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None:
)
)

-def forward(self, tokens, use_reentrant):
+def forward(self, tokens, coords, use_reentrant):
for block in self.ae_global_blocks:
-tokens = checkpoint(block, tokens, use_reentrant=use_reentrant)
+if isinstance(block, MultiSelfAttentionHead):
+tokens = checkpoint(
+lambda x, blk=block, c=coords: blk(x, rope_coords=c),
+tokens,
+use_reentrant=use_reentrant,
+)
+elif isinstance(block, MultiSelfAttentionHeadLocal):
+tokens = checkpoint(
+lambda x, blk=block, c=coords: blk(x, rope_coords=c),
+tokens,
+use_reentrant=use_reentrant,
+)
+else:
+tokens = checkpoint(block, tokens, use_reentrant=use_reentrant)
return tokens


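One detail of the forward() changes above that is easy to misread: the checkpointed callables bind the loop variables through default arguments (blk=block, c=coords) rather than closing over them. Because torch.utils.checkpoint re-invokes the callable during the backward pass to recompute activations, a plain closure over block would late-bind to whatever the loop variable holds at that time. The rationale is inferred, not stated in the PR; below is a minimal, framework-free demonstration of the late-binding pitfall the default arguments avoid.

```python
# Late binding vs. default-argument binding in a loop of callables.
adders = (lambda x: x + 1, lambda x: x * 10)

fns_late = [lambda x: f(x) for f in adders]        # closure: f is resolved at call time
fns_bound = [lambda x, f=f: f(x) for f in adders]  # default arg: f is frozen per iteration

print([fn(1) for fn in fns_late])   # [10, 10] -> every lambda sees the last f
print([fn(1) for fn in fns_bound])  # [2, 10]  -> each lambda keeps its own f
```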