diff --git a/config/default_config.yml b/config/default_config.yml index 66b57c865..4e645c45b 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -50,12 +50,12 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 2 +forecast_policy: fixed forecast_att_dense_rate: 1.0 -fe_num_blocks: 0 +fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -93,7 +93,7 @@ ema_halflife_in_thousands: 1e-3 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" training_mode_config: {"losses": {LossPhysical: {weight: 0.7, loss_fcts: [['mse', 0.8], ['mae', 0.2]]},} } # training_mode_config: {"loss": {LossPhysical: [['mse', 0.7]], @@ -121,7 +121,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_mini_epochs: 32 +num_mini_epochs: 16 samples_per_mini_epoch: 4096 samples_per_validation: 512 diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 39ed1c041..9c45b3e1a 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,6 +14,25 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.norms import AdaLayerNorm, RMSNorm +from weathergen.model.positional_encoding import ( + apply_rotary_emb, + compute_mixed_cis, + init_random_2d_freqs, +) + + +def _maybe_init_rope(dim_head: int, num_heads: int, theta: float = 10.0, rotate: bool = True): + dim_total = dim_head * num_heads + return init_random_2d_freqs(dim_total, num_heads, theta=theta, rotate=rotate) + + +def _compute_rope(freqs: torch.Tensor, coords: torch.Tensor, num_heads: int) -> torch.Tensor: + coords = coords.to(freqs.device) + coords_flat = coords.reshape(-1, coords.shape[-1]) + freqs_cis = compute_mixed_cis(freqs, coords_flat[:, 0], coords_flat[:, 1], num_heads) + freqs_cis = torch.diagonal(freqs_cis, dim1=0, dim2=1).permute(1, 0, 2) + freqs_cis = freqs_cis.reshape(*coords.shape[:-1], num_heads, -1) + return freqs_cis class MultiSelfAttentionHeadVarlen(torch.nn.Module): @@ -197,6 +216,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_rope=False, + rope_theta=10.0, ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -204,6 +225,7 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual + self.with_rope = with_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -231,6 +253,14 @@ def __init__( self.dtype = attention_dtype assert with_flash, "Only flash attention supported." + if self.with_rope: + self.register_buffer( + "rope_freqs", + _maybe_init_rope(self.dim_head_proj, self.num_heads, theta=rope_theta), + persistent=False, + ) + else: + self.rope_freqs = None # define block mask def mask_block_local(batch, head, idx_q, idx_kv): @@ -242,7 +272,7 @@ def mask_block_local(batch, head, idx_q, idx_kv): # compile for efficiency self.flex_attention = torch.compile(flex_attention, dynamic=False) - def forward(self, x, ada_ln_aux=None): + def forward(self, x, ada_ln_aux=None, rope_coords=None): if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -253,6 +283,10 @@ def forward(self, x, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) + if self.with_rope and rope_coords is not None: + freqs = _compute_rope(self.rope_freqs, rope_coords, self.num_heads) + qs, ks = apply_rotary_emb(qs, ks, freqs) + outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) out = self.proj_out(self.dropout(outs.flatten(-2, -1))) @@ -378,6 +412,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_rope=False, + rope_theta=10.0, ): super(MultiCrossAttentionHeadVarlenSlicedQ, self).__init__() @@ -387,6 +423,7 @@ def __init__( self.with_residual = with_residual self.with_flash = with_flash self.softcap = softcap + self.with_rope = with_rope if norm_type == "LayerNorm": norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps) @@ -426,8 +463,16 @@ def __init__( self.dtype = attention_dtype assert with_flash, "Only flash attention supported at the moment" + if self.with_rope: + self.register_buffer( + "rope_freqs", + _maybe_init_rope(self.dim_head_proj, self.num_heads, theta=rope_theta), + persistent=False, + ) + else: + self.rope_freqs = None - def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): + def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None, rope_coords=None): if self.with_residual: x_q_in = x_q x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux) @@ -444,6 +489,13 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x_kv).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x_kv).reshape(s) + if self.with_rope and rope_coords is not None: + freqs = _compute_rope(self.rope_freqs, rope_coords, self.num_heads) + qs = [ + apply_rotary_emb(q_i, q_i, freqs[:, idx].contiguous())[0] + for idx, q_i in enumerate(qs) + ] + # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 @@ -487,6 +539,8 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_rope=False, + rope_theta=10.0, ): super(MultiSelfAttentionHead, self).__init__() @@ -495,6 +549,7 @@ def __init__( self.softcap = softcap self.dropout_rate = dropout_rate self.with_residual = with_residual + self.with_rope = with_rope assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -526,8 +581,16 @@ def __init__( else: self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) + if self.with_rope: + self.register_buffer( + "rope_freqs", + _maybe_init_rope(self.dim_head_proj, self.num_heads, theta=rope_theta), + persistent=False, + ) + else: + self.rope_freqs = None - def forward(self, x, ada_ln_aux=None): + def forward(self, x, ada_ln_aux=None, rope_coords=None): if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) @@ -539,6 +602,10 @@ def forward(self, x, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s).to(self.dtype) + if self.with_rope and rope_coords is not None: + freqs = _compute_rope(self.rope_freqs, rope_coords, self.num_heads) + qs, ks = apply_rotary_emb(qs, ks, freqs) + # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 190fd6548..b86805dc6 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -225,7 +225,9 @@ def __init__(self, cf: Config) -> None: ) ) - def forward(self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant): + def forward( + self, tokens_c, tokens_global_c, q_cells_lens_c, cell_lens_c, use_reentrant + ): for block in self.ae_adapter: tokens_global_c = checkpoint( block, @@ -273,6 +275,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_rope=True, ) ) else: @@ -288,6 +291,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_rope=True, ) ) # MLP block @@ -303,9 +307,22 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) - def forward(self, tokens, use_reentrant): + def forward(self, tokens, coords, use_reentrant): for block in self.ae_aggregation_blocks: - tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) + if isinstance(block, MultiSelfAttentionHead): + tokens = checkpoint( + lambda x, blk=block, c=coords: blk(x, rope_coords=c), + tokens, + use_reentrant=use_reentrant, + ) + elif isinstance(block, MultiSelfAttentionHeadLocal): + tokens = checkpoint( + lambda x, blk=block, c=coords: blk(x, rope_coords=c), + tokens, + use_reentrant=use_reentrant, + ) + else: + tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) return tokens @@ -341,6 +358,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_rope=True, ) ) else: @@ -356,6 +374,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_type=self.cf.norm_type, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_rope=True, ) ) # MLP block @@ -371,9 +390,22 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) - def forward(self, tokens, use_reentrant): + def forward(self, tokens, coords, use_reentrant): for block in self.ae_global_blocks: - tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) + if isinstance(block, MultiSelfAttentionHead): + tokens = checkpoint( + lambda x, blk=block, c=coords: blk(x, rope_coords=c), + tokens, + use_reentrant=use_reentrant, + ) + elif isinstance(block, MultiSelfAttentionHeadLocal): + tokens = checkpoint( + lambda x, blk=block, c=coords: blk(x, rope_coords=c), + tokens, + use_reentrant=use_reentrant, + ) + else: + tokens = checkpoint(block, tokens, use_reentrant=use_reentrant) return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 7a71cda15..60f146611 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -23,6 +23,7 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.engines import ( EmbeddingEngine, EnsPredictionHead, @@ -75,13 +76,9 @@ def __init__(self, cf) -> None: torch.zeros(len_token_seq, cf.ae_local_dim_embed, dtype=self.dtype), requires_grad=False ) - pe = torch.zeros( - self.num_healpix_cells, - cf.ae_local_num_queries, - cf.ae_global_dim_embed, - dtype=self.dtype, + self.query_coords = torch.nn.Parameter( + torch.zeros(self.num_healpix_cells, 2, dtype=torch.float32), requires_grad=False ) - self.pe_global = torch.nn.Parameter(pe, requires_grad=False) ### HEALPIX NEIGHBOURS ### hlc = self.healpix_level @@ -149,29 +146,9 @@ def reset_parameters(self, cf: Config) -> "ModelParams": self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]]) self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) - dim_embed = cf.ae_global_dim_embed - self.pe_global.data.fill_(0.0) - xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed - self.pe_global.data[..., 0::2] = 0.5 * torch.sin( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 0::2] += ( - torch.sin( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) - ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - self.pe_global.data[..., 1::2] = 0.5 * torch.cos( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 1::2] += ( - torch.cos( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) - ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) + verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.query_coords.device)).to(self.query_coords.dtype) + self.query_coords.data.copy_(coords) # healpix neighborhood structure @@ -598,9 +575,11 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = self.embed_cells(model_params, streams_data) # local assimilation engine and adapter - tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens) + tokens, coords_tokens, posteriors = self.assimilate_local( + model_params, tokens, source_cell_lens + ) - tokens = self.assimilate_global(model_params, tokens) + tokens = self.assimilate_global(model_params, tokens, coords_tokens) # roll-out in latent space preds_all = [] @@ -666,7 +645,7 @@ def assimilate_local( cell_lens : Used to identify range of tokens to use from generated tokens in cell embedding Returns: - Tokens for global assimilation + Tuple of (global tokens, per-token coordinates, posterior stats) """ batch_size = ( @@ -674,15 +653,15 @@ def assimilate_local( ) s = self.q_cells.shape - # print( f'{np.prod(np.array(tokens.shape))} :: {np.prod(np.array(s))}' - # + ':: {np.prod(np.array(tokens.shape))/np.prod(np.array(s))}') - # TODO: test if positional encoding is needed here + coords_cells = model_params.query_coords.to(tokens.device) + coords_base = coords_cells.unsqueeze(1).repeat(1, self.cf.ae_local_num_queries, 1) + coords_seq_full = coords_base.unsqueeze(0).repeat(batch_size, 1, 1, 1) + coords_global = coords_seq_full.reshape(-1, self.cf.ae_local_num_queries, 2) if self.cf.ae_local_queries_per_cell: - tokens_global = (self.q_cells + model_params.pe_global).repeat(batch_size, 1, 1) + tokens_global = self.q_cells else: - tokens_global = ( - self.q_cells.repeat(self.num_healpix_cells, 1, 1) + model_params.pe_global - ) + tokens_global = self.q_cells.repeat(self.num_healpix_cells, 1, 1) + tokens_global = tokens_global.repeat(batch_size, 1, 1) q_cells_lens = torch.cat( [model_params.q_cells_lens[0].unsqueeze(0)] + [model_params.q_cells_lens[1:] for _ in range(batch_size)] @@ -714,6 +693,7 @@ def assimilate_local( cell_lens = cell_lens[1:] clen = self.num_healpix_cells // (2 if self.cf.healpix_level <= 5 else 8) tokens_global_unmasked_all = [] + coords_global_unmasked_all = [] if self.ae_aggregation_engine is not None else None posteriors = [] zero_pad = torch.zeros(1, device=tokens.device, dtype=torch.int32) for i in range((cell_lens.shape[0]) // clen): @@ -726,6 +706,7 @@ def assimilate_local( tokens_c = tokens[l0:l1] tokens_global_c = tokens_global[i * clen : i_end] + coords_cells_c = coords_global[i * clen : i_end] cell_lens_c = torch.cat([zero_pad, cell_lens[i * clen : i_end]]) q_cells_lens_c = q_cells_lens[: cell_lens_c.shape[0]] @@ -743,11 +724,18 @@ def assimilate_local( # create mask for global tokens, without first element (used for padding) mask_c = cell_lens_c[1:].to(torch.bool) tokens_global_unmasked_c = tokens_global_c[mask_c] + coords_global_unmasked_c = ( + coords_cells_c[mask_c] + if self.ae_aggregation_engine is not None + else None + ) q_cells_lens_unmasked_c = torch.cat([zero_pad, q_cells_lens_c[1:][mask_c]]) cell_lens_unmasked_c = torch.cat([zero_pad, cell_lens_c[1:][mask_c]]) if l0 == l1 or tokens_c.shape[0] == 0: tokens_global_unmasked_all += [tokens_global_unmasked_c] + if coords_global_unmasked_all is not None: + coords_global_unmasked_all += [coords_global_unmasked_c] continue # local to global adapter engine @@ -760,18 +748,18 @@ def assimilate_local( ) tokens_global_unmasked_all += [tokens_global_unmasked_c] + if coords_global_unmasked_all is not None: + coords_global_unmasked_all += [coords_global_unmasked_c] tokens_global_unmasked = torch.cat(tokens_global_unmasked_all) - - # query aggregation engine on the query tokens in unmasked cells - # (applying this here assumes batch_size=1) - # permute to use ae_local_num_queries as the batchsize and no_of_tokens - # as seq len for flash attention - tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) - tokens_global_unmasked = self.ae_aggregation_engine( - tokens_global_unmasked, use_reentrant=False - ) - tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) + if self.ae_aggregation_engine is not None: + coords_global_unmasked = torch.cat(coords_global_unmasked_all) + tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) + coords_global_unmasked = torch.permute(coords_global_unmasked, [1, 0, 2]) + tokens_global_unmasked = self.ae_aggregation_engine( + tokens_global_unmasked, coords_global_unmasked, use_reentrant=False + ) + tokens_global_unmasked = torch.permute(tokens_global_unmasked, [1, 0, 2]) # create mask from cell lens mask = cell_lens.to(torch.bool) @@ -780,30 +768,37 @@ def assimilate_local( tokens_global[mask] = tokens_global_unmasked.to(tokens_global.dtype) # recover batch dimension and build global token list - tokens_global = ( - tokens_global.reshape([batch_size, self.num_healpix_cells, s[-2], s[-1]]) - + model_params.pe_global - ).flatten(1, 2) + tokens_global = tokens_global.reshape([batch_size, self.num_healpix_cells, s[-2], s[-1]]) + tokens_global = tokens_global.flatten(1, 2) + coords_flat = coords_seq_full.flatten(1, 2) - return tokens_global, posteriors + return tokens_global, coords_flat, posteriors ######################################### - def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> torch.Tensor: + def assimilate_global( + self, model_params: ModelParams, tokens: torch.Tensor, coords: torch.Tensor + ) -> torch.Tensor: """Performs transformer based global assimilation in latent space Args: model_params : Query and embedding parameters (never used) tokens : Input tokens to be pre-processed by global assimilation + coords : Lat/lon coordinates associated with the query tokens Returns: Latent representation of the model """ # global assimilation engine and adapter - tokens = self.ae_global_engine(tokens, use_reentrant=False) + tokens = self.ae_global_engine(tokens, coords, use_reentrant=False) return tokens ######################################### - def forecast(self, model_params: ModelParams, tokens: torch.Tensor, fstep: int) -> torch.Tensor: + def forecast( + self, + model_params: ModelParams, + tokens: torch.Tensor, + fstep: int, + ) -> torch.Tensor: """Advances latent space representation in time Args: diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index 88df67fa3..b88a42aad 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -94,3 +94,79 @@ def positional_encoding_harmonic_coord(x, lats, lons): x = x + pe return x + +#################################################################################################### +# 2D Rotary Position Embedding +# https://github.com/naver-ai/rope-vit/blob/main/models/vit_rope.py +#################################################################################################### +def init_random_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True): + freqs_x = [] + freqs_y = [] + mag = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + for i in range(num_heads): + angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1) + fx = torch.cat([mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)], dim=-1) + fy = torch.cat([mag * torch.sin(angles), mag * torch.sin(torch.pi/2 + angles)], dim=-1) + freqs_x.append(fx) + freqs_y.append(fy) + freqs_x = torch.stack(freqs_x, dim=0) + freqs_y = torch.stack(freqs_y, dim=0) + freqs = torch.stack([freqs_x, freqs_y], dim=0) + return freqs + +def compute_mixed_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int): + N = t_x.shape[0] + depth = freqs.shape[1] + # ensure float32 math even if inputs are bf16, otherwise will raise error when using multi gpu + freqs = freqs.float() + t_x = t_x.float() + t_y = t_y.float() + # No float 16 for this range + with torch.cuda.amp.autocast(enabled=False): + freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2)).view(depth, N, num_heads, -1).permute(0, 2, 1, 3) + freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2)).view(depth, N, num_heads, -1).permute(0, 2, 1, 3) + freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y) + + return freqs_cis + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode='floor').float() + return t_x, t_y + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + if freqs_cis.shape == (x.shape[-2], x.shape[-1]): + shape = [d if i >= ndim-2 else 1 for i, d in enumerate(x.shape)] + elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]): + shape = [d if i >= ndim-3 else 1 for i, d in enumerate(x.shape)] + # for Query Aggregation engine fallback + else: + # fallback for cases where rotary frequencies already include head dim + shape = [1] * ndim + shape[-freqs_cis.ndim :] = freqs_cis.shape + + return freqs_cis.view(*shape) + +def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/uv.lock b/uv.lock index a524103f1..b37100a61 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = "==3.12.*" resolution-markers = [ "platform_machine == 'aarch64' and sys_platform == 'linux'", @@ -547,7 +547,7 @@ wheels = [ [[package]] name = "earthkit-data" -version = "0.18.0" +version = "0.18.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cfgrib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -570,9 +570,9 @@ dependencies = [ { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "xarray", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/d0/0208abe30005d6073b18f7bb19eba703407c2ec1aa81f978d514d525790a/earthkit_data-0.18.0.tar.gz", hash = "sha256:dff18c616f4f817a74ac9a1588267409707e8774f253336ac15c821b4505a1cf", size = 5553184, upload-time = "2025-11-12T10:11:44.184Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c5/a4/a78a78258093ea85f11bf2b5b90274403f0c88fe82c2b53070f4ab0d4bdb/earthkit_data-0.18.2.tar.gz", hash = "sha256:fbbb9ade7898b872456913af70dea2f680734cd414747dd368739804794670df", size = 5554363, upload-time = "2025-11-18T19:35:09.109Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bc/99/2672ba5c6e457e9d83fdfa5ea0848c8a5ed2747ba311f98f4d8e2304ed23/earthkit_data-0.18.0-py3-none-any.whl", hash = "sha256:fc61ffbe829cc1538cde1b9e6562031e25e9c8df75213b418fc30234b929fb78", size = 389538, upload-time = "2025-11-12T10:11:42.284Z" }, + { url = "https://files.pythonhosted.org/packages/57/cb/d6d435c7ce7782fa3c7aaf260f779cab80f6944c13a1546a0a3aed797b69/earthkit_data-0.18.2-py3-none-any.whl", hash = "sha256:0c61b5f61c7decb921ff3543f9c73b4988b6f2c88d6e8b68ee1ee34bee9d3573", size = 389574, upload-time = "2025-11-18T19:35:07.334Z" }, ] [[package]] @@ -616,21 +616,45 @@ wheels = [ [[package]] name = "eccodes" -version = "2.40.0" +version = "2.44.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "cffi", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodeslib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "findlibs", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cd/d7/588c558512359382eef169d444281a46d385cb21f9c672a4aeed6b20f78d/eccodes-2.40.0.tar.gz", hash = "sha256:47eb68869c2510fe63c26490d25065fc88521ec92a431facc12033340c449032", size = 2267613, upload-time = "2025-02-11T14:05:44.977Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1a/2a/9242d0a83de707ed401906a34bfe1d9a3af616abf498580ef73a6e8cebd5/eccodes-2.44.0.tar.gz", hash = "sha256:8aba9316749349e64db7d075100bff8e24a892814e3529132ec97b6d787eb8f4", size = 2310714, upload-time = "2025-10-03T14:02:37.462Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f6/75/26ecc31b88ece99295dd7d5ed73870ab31bb1e27a71d41c312702ed85c04/eccodes-2.40.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:4ed9c11b5bf335872e265e094aa64f45888a0fde3b62c670ac832deba8b0d317", size = 6539794, upload-time = "2025-02-11T13:52:55.22Z" }, - { url = "https://files.pythonhosted.org/packages/fb/2a/dafcb6240652d356c0e2e0691a0e69733cdf8738f97ef804197474e8598b/eccodes-2.40.0-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:49ac38aac6437244779b3b78d2156804f48336edbd9b5789c4b03c39ff39f515", size = 6638887, upload-time = "2025-02-11T13:58:17.818Z" }, - { url = "https://files.pythonhosted.org/packages/06/39/1d22469705e596823c1ed18f8f33c9ae958f30cbf95b2add51d5571b5d5e/eccodes-2.40.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:590286ac150e0a3ffc18c7cc59162c9a02245bf5f3097bc4582084d56126e370", size = 7403664, upload-time = "2025-02-11T13:51:57.121Z" }, - { url = "https://files.pythonhosted.org/packages/42/07/f174bbd973224786f03b53cdf1691de6d957c3631946f32d17616ae2e197/eccodes-2.40.0-cp312-cp312-win_amd64.whl", hash = "sha256:80fd3a5737a8f1aeb9036db83fca08f82fe10532b0b9fbc8a6bc099b4401ec23", size = 6208888, upload-time = "2025-02-11T13:51:26.314Z" }, - { url = "https://files.pythonhosted.org/packages/a4/9c/2000b4a14efb5060e57226e103f7275d52bc00e162f634906eb20423f8a1/eccodes-2.40.0-py3-none-any.whl", hash = "sha256:3033b1023f1a2e875d30a17ad06f6119003d682687b7ea4b7484adc58e16c20c", size = 43556, upload-time = "2025-02-11T14:05:42.814Z" }, + { url = "https://files.pythonhosted.org/packages/f2/a8/4d3b00f09440b269da208831b450a77e150ecfd1ac3981ca83d984ede4bd/eccodes-2.44.0-cp312-cp312-win_amd64.whl", hash = "sha256:20864247343bf88df88eafbf811fa90c290c45ed32d24f046238bd0f1684e16e", size = 7247248, upload-time = "2025-10-03T14:02:05.837Z" }, + { url = "https://files.pythonhosted.org/packages/dd/b8/9d15cea1f63fb2e1e14fda4160c355e6187e69b71b848c05faaae08b2e6c/eccodes-2.44.0-py3-none-any.whl", hash = "sha256:c3f11041bde7c3f53767c5bbed608c43695f257c09c58bb4de24bcd9cdae4e3a", size = 83465, upload-time = "2025-10-03T14:02:36.181Z" }, +] + +[[package]] +name = "eccodeslib" +version = "2.44.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "fckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/21/555e76b8dfa2ac050df8e638e9b91c6e671c3e2ba0abc2213e8df84d1e5c/eccodeslib-2.44.0.7-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:f28cebbfae6594ed393214f59b828b55238bdd2c61e4f533e96098c2e19bb47f", size = 8926805, upload-time = "2025-11-25T11:59:59.543Z" }, + { url = "https://files.pythonhosted.org/packages/1c/b8/e50cfc8588a85f31568ef02f6913b42d44e36c476cd1aaf61f2489e6749b/eccodeslib-2.44.0.7-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:e191ce8d33fce4c796fe6ffb57652e0faa19e61b3ae8e9d0adacb50fc824d77b", size = 8723732, upload-time = "2025-11-25T11:59:18.999Z" }, + { url = "https://files.pythonhosted.org/packages/55/7f/a81915d7693e8d46df61b44d5bbc1717c8b41deaf3084831b369191ee24c/eccodeslib-2.44.0.7-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:c1730012671b8c6a70001fc9f6fa4a557ca8d0888c2f76eae81ab6f978190cad", size = 20983542, upload-time = "2025-11-25T12:01:17.72Z" }, + { url = "https://files.pythonhosted.org/packages/b9/ae/a8f0fc3468e7d0e3cbcf7d2d51d55c53a785f7e3440f9b4546a0994b29b9/eccodeslib-2.44.0.7-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:61774756ea652e3bcea436e8bd6dcbbd3e044f0a41b39d748e110fada48ffdbe", size = 20853439, upload-time = "2025-11-25T12:04:15.242Z" }, +] + +[[package]] +name = "eckitlib" +version = "1.32.3.7" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/02/55468294aa6fb836d1a4d3d18459fad467e2f622df980e59181da2ed80a4/eckitlib-1.32.3.7-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:2344c7250b28f3cad2110ee6703c2a58714ed2d012c66fc6b38edb39eb567cd7", size = 2925833, upload-time = "2025-11-25T12:00:05.094Z" }, + { url = "https://files.pythonhosted.org/packages/91/19/33ba5777745f1f237ee6a549fb585afc6dde6f51672ea269d0285237214e/eckitlib-1.32.3.7-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:3964fee0a5cf828886957d87d053d42f29081f4d8b0f3b9f78fcc4a2401f6335", size = 3028987, upload-time = "2025-11-25T11:59:25.777Z" }, + { url = "https://files.pythonhosted.org/packages/7e/42/51dbb879c0e4b3a70dfa3463c24c41aca5097a2cddc68accacd1f7b572e8/eckitlib-1.32.3.7-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:de323af51b6560b22de2fb0ec4ad98c5f318975688526d04351b13428ad72de6", size = 43683895, upload-time = "2025-11-25T12:01:26.934Z" }, + { url = "https://files.pythonhosted.org/packages/40/88/2e751d24663b15a50e8aec49332020cb5e3c1305e6dc229e8cf396f92809/eckitlib-1.32.3.7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3825984742c687db0feed540c8ba4bb26472859711e78c560711eba3fe6d12cf", size = 44585482, upload-time = "2025-11-25T12:04:28.06Z" }, ] [[package]] @@ -707,6 +731,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/2b/0817a2b257fe88725c25589d89aec060581aabf668707a8d03b2e9e0cb2a/fastjsonschema-2.21.1-py3-none-any.whl", hash = "sha256:c9e5b7e908310918cf494a434eeb31384dd84a98b57a30bcb1f535015b554667", size = 23924, upload-time = "2024-12-02T10:55:07.599Z" }, ] +[[package]] +name = "fckitlib" +version = "0.14.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/07/921d9adf99b4cb0983f4327f32e76718e88e1fbc78eb253e6a33ce1004e4/fckitlib-0.14.1.7-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:cd3213c33656e6bb7652cbbf4dac7b466d294f42188ef0f4c7f69aab124c006e", size = 411476, upload-time = "2025-11-25T12:00:07.712Z" }, + { url = "https://files.pythonhosted.org/packages/e8/e8/3339b155d2486a3710bf59274259bee846325b7bad5aaa269565e2b76838/fckitlib-0.14.1.7-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:073c97897e032e51ff64028fe4602e50a8092c40563d147f06641f0f7b4f8f23", size = 417158, upload-time = "2025-11-25T11:59:29.004Z" }, + { url = "https://files.pythonhosted.org/packages/bd/46/06c9fd28b580a8fc59f7a889a7710fdd4afe6f029325ac908d687bdbc3eb/fckitlib-0.14.1.7-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6af3c23fa3aecf8dc45d7dc610ed542b807ae81e73f2e164c74e90ebbcc0252e", size = 1342966, upload-time = "2025-11-25T12:01:31.618Z" }, + { url = "https://files.pythonhosted.org/packages/f7/8f/d85d55b3582e168a0221a71b2f54c28b02b1d7ce78b37926cc6019da7945/fckitlib-0.14.1.7-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d70bea78e3248780b4b49fce4683ebbf9419e2b5cfd9a9fdd9512ead8627aa3d", size = 12761273, upload-time = "2025-11-25T12:04:35.044Z" }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -1013,7 +1051,7 @@ name = "jinja2" version = "3.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markupsafe", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "markupsafe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } wheels = [ @@ -2810,7 +2848,6 @@ dependencies = [ { name = "anemoi-datasets", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "astropy-healpix", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "dask", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "eccodes", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "hatchling", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "matplotlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "numexpr", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2857,7 +2894,6 @@ requires-dist = [ { name = "anemoi-datasets", git = "https://github.com/ecmwf/anemoi-datasets?branch=feature%2Fzarr3" }, { name = "astropy-healpix", specifier = "~=1.1.2" }, { name = "dask", specifier = "~=2025.5.1" }, - { name = "eccodes", specifier = "<=2.40" }, { name = "flash-attn", marker = "platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://object-store.os-api.cci1.ecmwf.int/weathergenerator-dev/wheels/flash_attn-2.7.3-cp312-cp312-linux_aarch64.whl" }, { name = "flash-attn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp312-cp312-linux_x86_64.whl" }, { name = "flash-attn", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'gpu') or (sys_platform != 'linux' and extra == 'gpu')" }, @@ -2944,6 +2980,10 @@ version = "0.1.0" source = { editable = "packages/evaluate" } dependencies = [ { name = "cartopy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "earthkit-data", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodes", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eccodeslib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "eckitlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "omegaconf", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "panel", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "plotly", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2964,6 +3004,10 @@ dev = [ [package.metadata] requires-dist = [ { name = "cartopy", specifier = ">=0.24.1" }, + { name = "earthkit-data", specifier = "==0.18.2" }, + { name = "eccodes", specifier = "==2.44.0" }, + { name = "eccodeslib", specifier = "==2.44.0.7" }, + { name = "eckitlib", specifier = "==1.32.3.7" }, { name = "omegaconf" }, { name = "panel" }, { name = "plotly", specifier = ">=6.2.0" },