import logging
from argparse import ArgumentParser
from argparse import Namespace
- from copy import deepcopy
from pathlib import Path
- from typing import TYPE_CHECKING

from . import Command

LOG = logging.getLogger(__name__)

- if TYPE_CHECKING:
-     import numpy as np
-     from torch_geometric.data import HeteroData
+
+ def check_redefine_imports():
+     """Check if required packages are installed."""
+     required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"]
+     from importlib.util import find_spec
+
+     for package in required_packages:
+         if find_spec(package) is None:
+             raise ImportError(f"{package!r} is required for this command.")


def format_namespace_as_str(namespace: Namespace) -> str:
@@ -48,44 +52,6 @@ def format_namespace_as_str(namespace: Namespace) -> str:
    return " ".join(args)


- def update_state_dict(
-     model,
-     external_state_dict,
-     keywords: list[str] | None = None,
-     ignore_mismatched_layers=False,
-     ignore_additional_layers=False,
- ):
-     """Update the model's state_dict with entries from an external state_dict. Only entries whose keys contain the specified keywords are considered."""
-
-     LOG.info("Updating model state dictionary.")
-
-     keywords = keywords or []
-
-     # select relevant part of external_state_dict
-     reduced_state_dict = {k: v for k, v in external_state_dict.items() if any(kw in k for kw in keywords)}
-     model_state_dict = model.state_dict()
-
-     # check layers and their shapes
-     for key in list(reduced_state_dict):
-         if key not in model_state_dict:
-             if ignore_additional_layers:
-                 LOG.info("Skipping injection of %s, which is not in the model.", key)
-                 del reduced_state_dict[key]
-             else:
-                 raise AssertionError(f"Layer {key} not in model. Consider setting 'ignore_additional_layers = True'.")
-         elif reduced_state_dict[key].shape != model_state_dict[key].shape:
-             if ignore_mismatched_layers:
-                 LOG.info("Skipping injection of %s due to shape mismatch.", key)
-                 LOG.info("Model shape: %s", model_state_dict[key].shape)
-                 LOG.info("Provided shape: %s", reduced_state_dict[key].shape)
-                 del reduced_state_dict[key]
-             else:
-                 raise AssertionError(f"Mismatch in shape of {key}. Consider setting 'ignore_mismatched_layers = True'.")
-
-     model.load_state_dict(reduced_state_dict, strict=False)
-     return model
-
-
class RedefineCmd(Command):
    """Redefine the graph of a checkpoint file."""

@@ -97,7 +63,7 @@ def add_arguments(self, command_parser: ArgumentParser) -> None:
        command_parser : ArgumentParser
            The argument parser to which the arguments will be added.
        """
-         command_parser.description = "Redefine the graph of a checkpoint file."
+         command_parser.description = "Redefine the graph of a checkpoint file. If using coordinate specifications, assumes the input to the local domain is already regridded."
        command_parser.add_argument("path", help="Path to the checkpoint.")

        group = command_parser.add_mutually_exclusive_group(required=True)
@@ -122,155 +88,6 @@ def add_arguments(self, command_parser: ArgumentParser) -> None:
        command_parser.add_argument("--save-graph", type=str, help="Path to save the updated graph.", default=None)
        command_parser.add_argument("--output", type=str, help="Path to save the updated checkpoint.", default=None)

-     def _get_coordinates(self, args: Namespace) -> tuple["np.ndarray", "np.ndarray"]:
-         """Get coordinates from command line arguments.
-
-         Either from files or from coords which are extracted from a MARS request.
-         """
-         import numpy as np
-
-         if args.latlon is not None:
-             latlon = np.load(args.latlon)
-             return latlon[:, 0], latlon[:, 1]
-
-         elif args.coords is not None:
-             import earthkit.data as ekd
-
-             area = [args.coords[0], args.coords[1], args.coords[2], args.coords[3]]
-
-             resolution = str(args.coords[4])
-             if resolution.isdigit():
-                 resolution = f"{resolution}/{resolution}"
-
-             ds = ekd.from_source(
-                 "mars",
-                 {
-                     "AREA": area,
-                     "GRID": f"{resolution}",
-                     "param": "2t",
-                     "date": -2,
-                     "stream": "oper",
-                     "type": "an",
-                     "levtype": "sfc",
-                 },
-             )
-             return ds[0].grid_points()  # type: ignore
-         raise ValueError("No valid coordinates found.")
-
-     def _combine_nodes(
-         self, latitudes: "np.ndarray", longitudes: "np.ndarray", global_grid: str
-     ) -> tuple["np.ndarray", "np.ndarray", "np.ndarray", "np.ndarray"]:
-         """Combine lat/lon nodes with global grid if specified.
-
-         Returns lats, lons, local_mask, global_mask
-         """
-         import numpy as np
-         from anemoi.datasets.grids import cutout_mask
-         from anemoi.utils.grids import grids
-
-         global_points = grids(global_grid)
-
-         global_removal_mask = cutout_mask(
-             latitudes, longitudes, global_points["latitudes"], global_points["longitudes"]
-         )
-         lats = np.concatenate([latitudes, global_points["latitudes"][global_removal_mask]])
-         lons = np.concatenate([longitudes, global_points["longitudes"][global_removal_mask]])
-         local_mask = np.array([True] * len(latitudes) + [False] * sum(global_removal_mask), dtype=bool)
-
-         return lats, lons, local_mask, global_removal_mask
-
-     def _make_data_graph(
-         self,
-         lats: "np.ndarray",
-         lons: "np.ndarray",
-         local_mask: "np.ndarray",
-         global_mask: "np.ndarray",
-         *,
-         mask_attr_name: str = "cutout",
-         attrs,
-     ) -> "HeteroData":
-         """Make a data graph with the given lat/lon nodes and attributes."""
-         import torch
-         from anemoi.graphs.nodes import LatLonNodes
-         from torch_geometric.data import HeteroData
-
-         graph = LatLonNodes(lats, lons, name="data").update_graph(HeteroData(), attrs_config=attrs)
-         graph["data"][mask_attr_name] = torch.from_numpy(local_mask)
-         return graph
-
-     def _make_graph_from_coordinates(
-         self, args: Namespace, metadata: dict, supporting_arrays: dict
-     ) -> tuple[dict, dict, "HeteroData"]:
-         """Make a graph from coordinates given in args."""
-         import numpy as np
-
-         if args.global_resolution is None:
-             raise ValueError("Global resolution must be specified when generating graph from coordinates.")
-
-         local_lats, local_lons = self._get_coordinates(args)
-         LOG.info("Coordinates loaded. Number of local nodes: %d", len(local_lats))
-         lats, lons, local_mask, global_mask = self._combine_nodes(local_lats, local_lons, args.global_resolution)
-
-         graph_config = deepcopy(metadata["config"]["graph"])
-         data_graph = graph_config["nodes"].pop("data")
-
-         from anemoi.graphs.create import GraphCreator
-         from anemoi.utils.config import DotDict
-
-         creator = GraphCreator(DotDict(graph_config))
-
-         LOG.info("Updating graph...")
-         LOG.debug("Using %r", graph_config)
-
-         def nested_get(d, keys, default=None):
-             for key in keys:
-                 d = d.get(key, {})
-             return d or default
-
-         mask_attr_name = nested_get(graph_config, ["nodes", "hidden", "node_builder", "mask_attr_name"], "cutout")
-
-         data_graph = self._make_data_graph(
-             lats, lons, local_mask, global_mask, mask_attr_name=mask_attr_name, attrs=data_graph.get("attrs", None)
-         )
-         LOG.info("Created data graph with %d nodes.", data_graph.num_nodes)
-         graph = creator.update_graph(data_graph)
-
-         supporting_arrays[f"global/{mask_attr_name}"] = global_mask
-         supporting_arrays[f"lam_0/{mask_attr_name}"] = np.array([True] * len(local_lats))
-
-         supporting_arrays["latitudes"] = lats
-         supporting_arrays["longitudes"] = lons
-         supporting_arrays["grid_indices"] = np.ones(global_mask.shape, dtype=np.int64)
-
-         return metadata, supporting_arrays, graph
-
-     def _update_checkpoint(self, model, metadata, graph: "HeteroData"):
-         from anemoi.utils.config import DotDict
-
-         state_dict_ckpt = deepcopy(model.state_dict())
-
-         # rebuild the model with the new graph
-         model.graph_data = graph
-         model.config = DotDict(metadata).config
-         model._build_model()
-
-         # reinstate the weights, biases and normalizer from the checkpoint
-         # reinstating the normalizer is necessary for checkpoints that were created
-         # using transfer learning, where the statistics as stored in the checkpoint
-         # do not match the statistics used to build the normalizer in the checkpoint.
-         model_instance = update_state_dict(model, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"])
-
-         return model_instance
-
-     def _check_imports(self):
-         """Check if required packages are installed."""
-         required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"]
-         from importlib.util import find_spec
-
-         for package in required_packages:
-             if find_spec(package) is None:
-                 raise ImportError(f"{package!r} is required for this command.")
-
    def run(self, args: Namespace) -> None:
        """Run the redefine command.

@@ -279,44 +96,61 @@ def run(self, args: Namespace) -> None:
        args : Namespace
            The arguments passed to the command.
        """
-         self._check_imports()
+         from anemoi.inference.utils.redefine import create_graph_from_config
+         from anemoi.inference.utils.redefine import get_coordinates_from_file
+         from anemoi.inference.utils.redefine import get_coordinates_from_mars_request
+         from anemoi.inference.utils.redefine import load_graph_from_file
+         from anemoi.inference.utils.redefine import make_graph_from_coordinates
+         from anemoi.inference.utils.redefine import update_checkpoint
+
+         check_redefine_imports()

        import torch
        from anemoi.utils.checkpoints import load_metadata
        from anemoi.utils.checkpoints import save_metadata

        path = Path(args.path)

+         # Load checkpoint metadata and supporting arrays
        metadata, supporting_arrays = load_metadata(str(path), supporting_arrays=True)

+         # Add command to history
        metadata.setdefault("history", [])
        metadata["history"].append(f"anemoi-inference redefine {format_namespace_as_str(args)}")

+         # Create or load the graph
        if args.graph is not None:
-             LOG.info("Loading graph from %s", args.graph)
-             graph = torch.load(args.graph)
+             graph = load_graph_from_file(args.graph)
+         elif args.graph_config is not None:
+             graph = create_graph_from_config(args.graph_config)
        else:
-             if args.graph_config is not None:
-                 from anemoi.graphs.create import GraphCreator
-                 from torch_geometric.data import HeteroData
-
-                 graph = GraphCreator(args.graph_config).update_graph(HeteroData())
+             # Generate graph from coordinates
+             LOG.info("Generating graph from coordinates...")
+
+             # Get coordinates based on input type
+             if args.latlon is not None:
+                 local_lats, local_lons = get_coordinates_from_file(args.latlon)
+             elif args.coords is not None:
+                 local_lats, local_lons = get_coordinates_from_mars_request(args.coords)
            else:
-                 LOG.info("Generating graph from coordinates...")
-                 metadata, supporting_arrays, graph = self._make_graph_from_coordinates(
-                     args, metadata, supporting_arrays
-                 )
+                 raise ValueError("No valid coordinates found.")

-         if args.save_graph is not None:
-             torch.save(graph, args.save_graph)
-             LOG.info("Saved updated graph to %s", args.save_graph)
+             metadata, supporting_arrays, graph = make_graph_from_coordinates(
+                 local_lats, local_lons, args.global_resolution, metadata, supporting_arrays
+             )

-         LOG.info("Updating checkpoint...")
+         # Save graph if requested
+         if args.save_graph is not None:
+             torch.save(graph, args.save_graph)
+             LOG.info("Saved updated graph to %s", args.save_graph)

+         # Update checkpoint
+         LOG.info("Updating checkpoint...")
        model = torch.load(str(path), weights_only=False, map_location=torch.device("cpu"))
-         model = self._update_checkpoint(model, metadata, graph)
-         model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}"
+         model = update_checkpoint(model, metadata, graph)

+         # Save updated checkpoint
+         model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}"
        torch.save(model, model_path)

        save_metadata(
@@ -325,5 +159,7 @@ def run(self, args: Namespace) -> None:
            supporting_arrays=supporting_arrays,
        )

+         LOG.info("Updated checkpoint saved to %s", model_path)
+

command = RedefineCmd