Commit 836a864

committed
Move to separate utils module
1 parent 15c9d7d commit 836a864

4 files changed: +359 −282 lines changed


src/anemoi/inference/commands/redefine.py

Lines changed: 47 additions & 211 deletions
@@ -11,17 +11,21 @@
 import logging
 from argparse import ArgumentParser
 from argparse import Namespace
-from copy import deepcopy
 from pathlib import Path
-from typing import TYPE_CHECKING
 
 from . import Command
 
 LOG = logging.getLogger(__name__)
 
-if TYPE_CHECKING:
-    import numpy as np
-    from torch_geometric.data import HeteroData
+
+def check_redefine_imports():
+    """Check if required packages are installed."""
+    required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"]
+    from importlib.util import find_spec
+
+    for package in required_packages:
+        if find_spec(package) is None:
+            raise ImportError(f"{package!r} is required for this command.")
 
 
 def format_namespace_as_str(namespace: Namespace) -> str:
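For readers skimming the diff: the new module-level guard simply probes importability with importlib.util.find_spec before the heavy optional dependencies are touched. A minimal, self-contained sketch of the same pattern follows; the helper name and package list below are illustrative, not part of the commit.

# Sketch of the find_spec-based guard used by check_redefine_imports().
# `require_packages` is a hypothetical name, shown only to make the failure
# mode of the new helper concrete.
from importlib.util import find_spec


def require_packages(packages: list[str]) -> None:
    missing = [name for name in packages if find_spec(name) is None]
    if missing:
        raise ImportError(f"{missing!r} are required for this command.")


# require_packages(["anemoi.datasets", "anemoi.graphs", "anemoi.models"])
# raises ImportError if any of the optional dependencies is not installed.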
@@ -48,44 +52,6 @@ def format_namespace_as_str(namespace: Namespace) -> str:
     return " ".join(args)
 
 
-def update_state_dict(
-    model,
-    external_state_dict,
-    keywords: list[str] | None = None,
-    ignore_mismatched_layers=False,
-    ignore_additional_layers=False,
-):
-    """Update the model's state_dict with entries from an external state_dict. Only entries whose keys contain the specified keywords are considered."""
-
-    LOG.info("Updating model state dictionary.")
-
-    keywords = keywords or []
-
-    # select relevant part of external_state_dict
-    reduced_state_dict = {k: v for k, v in external_state_dict.items() if any(kw in k for kw in keywords)}
-    model_state_dict = model.state_dict()
-
-    # check layers and their shapes
-    for key in list(reduced_state_dict):
-        if key not in model_state_dict:
-            if ignore_additional_layers:
-                LOG.info("Skipping injection of %s, which is not in the model.", key)
-                del reduced_state_dict[key]
-            else:
-                raise AssertionError(f"Layer {key} not in model. Consider setting 'ignore_additional_layers = True'.")
-        elif reduced_state_dict[key].shape != model_state_dict[key].shape:
-            if ignore_mismatched_layers:
-                LOG.info("Skipping injection of %s due to shape mismatch.", key)
-                LOG.info("Model shape: %s", model_state_dict[key].shape)
-                LOG.info("Provided shape: %s", reduced_state_dict[key].shape)
-                del reduced_state_dict[key]
-            else:
-                raise AssertionError(f"Mismatch in shape of {key}. Consider setting 'ignore_mismatched_layers = True'.")
-
-    model.load_state_dict(reduced_state_dict, strict=False)
-    return model
-
-
 class RedefineCmd(Command):
     """Redefine the graph of a checkpoint file."""
 
@@ -97,7 +63,7 @@ def add_arguments(self, command_parser: ArgumentParser) -> None:
         command_parser : ArgumentParser
             The argument parser to which the arguments will be added.
         """
-        command_parser.description = "Redefine the graph of a checkpoint file."
+        command_parser.description = "Redefine the graph of a checkpoint file. If using coordinate specifications, assumes the input to the local domain is already regridded."
         command_parser.add_argument("path", help="Path to the checkpoint.")
 
         group = command_parser.add_mutually_exclusive_group(required=True)
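The update_state_dict helper removed in the hunk above selectively re-injects checkpoint entries into a rebuilt model. Below is a hedged usage sketch of that pattern with a toy torch module standing in for the rebuilt anemoi model; the helper itself is presumed to live on in the new utils module, which this excerpt does not show.

# Hedged sketch: selective state_dict re-injection, as done by the removed
# update_state_dict() helper. `TinyModel` is a stand-in, not anemoi code.
import torch


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


model = TinyModel()
checkpoint_state = {k: v.clone() for k, v in TinyModel().state_dict().items()}

# Keep only entries whose keys mention the chosen keywords, mirroring the
# keywords=["bias", "weight", "processors.normalizer"] call in this file.
keywords = ["bias", "weight"]
reduced = {k: v for k, v in checkpoint_state.items() if any(kw in k for kw in keywords)}
model.load_state_dict(reduced, strict=False)  # strict=False tolerates missing keys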
@@ -122,155 +88,6 @@ def add_arguments(self, command_parser: ArgumentParser) -> None:
         command_parser.add_argument("--save-graph", type=str, help="Path to save the updated graph.", default=None)
         command_parser.add_argument("--output", type=str, help="Path to save the updated checkpoint.", default=None)
 
-    def _get_coordinates(self, args: Namespace) -> tuple["np.ndarray", "np.ndarray"]:
-        """Get coordinates from command line arguments.
-
-        Either from files or from coords which are extracted from a MARS request.
-        """
-        import numpy as np
-
-        if args.latlon is not None:
-            latlon = np.load(args.latlon)
-            return latlon[:, 0], latlon[:, 1]
-
-        elif args.coords is not None:
-            import earthkit.data as ekd
-
-            area = [args.coords[0], args.coords[1], args.coords[2], args.coords[3]]
-
-            resolution = str(args.coords[4])
-            if resolution.isdigit():
-                resolution = f"{resolution}/{resolution}"
-
-            ds = ekd.from_source(
-                "mars",
-                {
-                    "AREA": area,
-                    "GRID": f"{resolution}",
-                    "param": "2t",
-                    "date": -2,
-                    "stream": "oper",
-                    "type": "an",
-                    "levtype": "sfc",
-                },
-            )
-            return ds[0].grid_points()  # type: ignore
-        raise ValueError("No valid coordinates found.")
-
-    def _combine_nodes(
-        self, latitudes: "np.ndarray", longitudes: "np.ndarray", global_grid: str
-    ) -> tuple["np.ndarray", "np.ndarray", "np.ndarray", "np.ndarray"]:
-        """Combine lat/lon nodes with global grid if specified.
-
-        Returns lats, lons, local_mask, global_mask
-        """
-        import numpy as np
-        from anemoi.datasets.grids import cutout_mask
-        from anemoi.utils.grids import grids
-
-        global_points = grids(global_grid)
-
-        global_removal_mask = cutout_mask(
-            latitudes, longitudes, global_points["latitudes"], global_points["longitudes"]
-        )
-        lats = np.concatenate([latitudes, global_points["latitudes"][global_removal_mask]])
-        lons = np.concatenate([longitudes, global_points["longitudes"][global_removal_mask]])
-        local_mask = np.array([True] * len(latitudes) + [False] * sum(global_removal_mask), dtype=bool)
-
-        return lats, lons, local_mask, global_removal_mask
-
-    def _make_data_graph(
-        self,
-        lats: "np.ndarray",
-        lons: "np.ndarray",
-        local_mask: "np.ndarray",
-        global_mask: "np.ndarray",
-        *,
-        mask_attr_name: str = "cutout",
-        attrs,
-    ) -> "HeteroData":
-        """Make a data graph with the given lat/lon nodes and attributes."""
-        import torch
-        from anemoi.graphs.nodes import LatLonNodes
-        from torch_geometric.data import HeteroData
-
-        graph = LatLonNodes(lats, lons, name="data").update_graph(HeteroData(), attrs_config=attrs)
-        graph["data"][mask_attr_name] = torch.from_numpy(local_mask)
-        return graph
-
-    def _make_graph_from_coordinates(
-        self, args: Namespace, metadata: dict, supporting_arrays: dict
-    ) -> tuple[dict, dict, "HeteroData"]:
-        """Make a graph from coordinates given in args."""
-        import numpy as np
-
-        if args.global_resolution is None:
-            raise ValueError("Global resolution must be specified when generating graph from coordinates.")
-
-        local_lats, local_lons = self._get_coordinates(args)
-        LOG.info("Coordinates loaded. Number of local nodes: %d", len(local_lats))
-        lats, lons, local_mask, global_mask = self._combine_nodes(local_lats, local_lons, args.global_resolution)
-
-        graph_config = deepcopy(metadata["config"]["graph"])
-        data_graph = graph_config["nodes"].pop("data")
-
-        from anemoi.graphs.create import GraphCreator
-        from anemoi.utils.config import DotDict
-
-        creator = GraphCreator(DotDict(graph_config))
-
-        LOG.info("Updating graph...")
-        LOG.debug("Using %r", graph_config)
-
-        def nested_get(d, keys, default=None):
-            for key in keys:
-                d = d.get(key, {})
-            return d or default
-
-        mask_attr_name = nested_get(graph_config, ["nodes", "hidden", "node_builder", "mask_attr_name"], "cutout")
-
-        data_graph = self._make_data_graph(
-            lats, lons, local_mask, global_mask, mask_attr_name=mask_attr_name, attrs=data_graph.get("attrs", None)
-        )
-        LOG.info("Created data graph with %d nodes.", data_graph.num_nodes)
-        graph = creator.update_graph(data_graph)
-
-        supporting_arrays[f"global/{mask_attr_name}"] = global_mask
-        supporting_arrays[f"lam_0/{mask_attr_name}"] = np.array([True] * len(local_lats))
-
-        supporting_arrays["latitudes"] = lats
-        supporting_arrays["longitudes"] = lons
-        supporting_arrays["grid_indices"] = np.ones(global_mask.shape, dtype=np.int64)
-
-        return metadata, supporting_arrays, graph
-
-    def _update_checkpoint(self, model, metadata, graph: "HeteroData"):
-        from anemoi.utils.config import DotDict
-
-        state_dict_ckpt = deepcopy(model.state_dict())
-
-        # rebuild the model with the new graph
-        model.graph_data = graph
-        model.config = DotDict(metadata).config
-        model._build_model()
-
-        # reinstate the weights, biases and normalizer from the checkpoint
-        # reinstating the normalizer is necessary for checkpoints that were created
-        # using transfer learning, where the statistics as stored in the checkpoint
-        # do not match the statistics used to build the normalizer in the checkpoint.
-        model_instance = update_state_dict(model, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"])
-
-        return model_instance
-
-    def _check_imports(self):
-        """Check if required packages are installed."""
-        required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"]
-        from importlib.util import find_spec
-
-        for package in required_packages:
-            if find_spec(package) is None:
-                raise ImportError(f"{package!r} is required for this command.")
-
     def run(self, args: Namespace) -> None:
         """Run the redefine command.
 
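The removed _combine_nodes method concatenates the local lat/lon nodes with the global-grid nodes that survive the cutout mask, and records which rows of the combined arrays are local. A toy numpy illustration of that bookkeeping follows; the arrays are made up and cutout_mask itself is not reproduced.

# Toy illustration of the mask bookkeeping in the removed _combine_nodes():
# local nodes come first, then the surviving global nodes, and local_mask marks
# which rows of the concatenated arrays belong to the local domain.
import numpy as np

local_lats = np.array([50.0, 51.0, 52.0])
global_lats = np.array([0.0, 10.0, 20.0, 30.0])
global_removal_mask = np.array([True, False, True, False])  # keep nodes outside the cutout

lats = np.concatenate([local_lats, global_lats[global_removal_mask]])
local_mask = np.array([True] * len(local_lats) + [False] * int(global_removal_mask.sum()), dtype=bool)

assert lats.shape == local_mask.shape
assert local_mask.sum() == len(local_lats)  # exactly the local nodes are flagged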
@@ -279,44 +96,61 @@ def run(self, args: Namespace) -> None:
         args : Namespace
             The arguments passed to the command.
         """
-        self._check_imports()
+        from anemoi.inference.utils.redefine import create_graph_from_config
+        from anemoi.inference.utils.redefine import get_coordinates_from_file
+        from anemoi.inference.utils.redefine import get_coordinates_from_mars_request
+        from anemoi.inference.utils.redefine import load_graph_from_file
+        from anemoi.inference.utils.redefine import make_graph_from_coordinates
+        from anemoi.inference.utils.redefine import update_checkpoint
+
+        check_redefine_imports()
 
         import torch
         from anemoi.utils.checkpoints import load_metadata
         from anemoi.utils.checkpoints import save_metadata
 
         path = Path(args.path)
 
+        # Load checkpoint metadata and supporting arrays
         metadata, supporting_arrays = load_metadata(str(path), supporting_arrays=True)
 
+        # Add command to history
         metadata.setdefault("history", [])
         metadata["history"].append(f"anemoi-inference redefine {format_namespace_as_str(args)}")
 
+        # Create or load the graph
         if args.graph is not None:
-            LOG.info("Loading graph from %s", args.graph)
-            graph = torch.load(args.graph)
+            graph = load_graph_from_file(args.graph)
+        elif args.graph_config is not None:
+            graph = create_graph_from_config(args.graph_config)
         else:
-            if args.graph_config is not None:
-                from anemoi.graphs.create import GraphCreator
-                from torch_geometric.data import HeteroData
-
-                graph = GraphCreator(args.graph_config).update_graph(HeteroData())
+            # Generate graph from coordinates
+            LOG.info("Generating graph from coordinates...")
+
+            # Get coordinates based on input type
+            if args.latlon is not None:
+                local_lats, local_lons = get_coordinates_from_file(args.latlon)
+            elif args.coords is not None:
+                local_lats, local_lons = get_coordinates_from_mars_request(args.coords)
             else:
-                LOG.info("Generating graph from coordinates...")
-                metadata, supporting_arrays, graph = self._make_graph_from_coordinates(
-                    args, metadata, supporting_arrays
-                )
+                raise ValueError("No valid coordinates found.")
 
-        if args.save_graph is not None:
-            torch.save(graph, args.save_graph)
-            LOG.info("Saved updated graph to %s", args.save_graph)
+            metadata, supporting_arrays, graph = make_graph_from_coordinates(
+                local_lats, local_lons, args.global_resolution, metadata, supporting_arrays
+            )
 
-        LOG.info("Updating checkpoint...")
+        # Save graph if requested
+        if args.save_graph is not None:
+            torch.save(graph, args.save_graph)
+            LOG.info("Saved updated graph to %s", args.save_graph)
 
+        # Update checkpoint
+        LOG.info("Updating checkpoint...")
         model = torch.load(str(path), weights_only=False, map_location=torch.device("cpu"))
-        model = self._update_checkpoint(model, metadata, graph)
-        model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}"
+        model = update_checkpoint(model, metadata, graph)
 
+        # Save updated checkpoint
+        model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}"
         torch.save(model, model_path)
 
         save_metadata(
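run() now delegates to helpers in anemoi.inference.utils.redefine, a module added by this commit but not shown in this excerpt. Only the function names are confirmed by the imports above; the bodies below are inferred from the inline code they replace and may differ from the actual module.

# Inferred sketches of two helpers imported from anemoi.inference.utils.redefine.
# The bodies mirror the inline code removed in this hunk; treat them as assumptions.
import logging

import torch
from anemoi.graphs.create import GraphCreator
from torch_geometric.data import HeteroData

LOG = logging.getLogger(__name__)


def load_graph_from_file(path: str) -> HeteroData:
    LOG.info("Loading graph from %s", path)
    return torch.load(path)


def create_graph_from_config(graph_config: str) -> HeteroData:
    return GraphCreator(graph_config).update_graph(HeteroData())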
@@ -325,5 +159,7 @@ def run(self, args: Namespace) -> None:
             supporting_arrays=supporting_arrays,
         )
 
+        LOG.info("Updated checkpoint saved to %s", model_path)
+
 
 command = RedefineCmd
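Similarly, update_checkpoint presumably carries over the logic of the removed _update_checkpoint method: rebuild the model around the new graph, then re-inject weights, biases and the normalizer from the original state dict. A hedged reconstruction follows; whether update_state_dict is exposed by the new utils module is an assumption.

# Inferred sketch of update_checkpoint(); the real implementation lives in the
# new utils module, which this excerpt does not show.
from copy import deepcopy

from anemoi.utils.config import DotDict

# Assumption: the update_state_dict helper removed earlier in this diff now
# lives in the same utils module.
from anemoi.inference.utils.redefine import update_state_dict


def update_checkpoint(model, metadata, graph):
    state_dict_ckpt = deepcopy(model.state_dict())

    # Rebuild the model around the new graph.
    model.graph_data = graph
    model.config = DotDict(metadata).config
    model._build_model()

    # Re-inject weights, biases and the normalizer from the checkpoint; the
    # normalizer matters for checkpoints produced via transfer learning.
    return update_state_dict(model, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"])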
