diff --git a/docs/cli/introduction.rst b/docs/cli/introduction.rst index 1fa81f73..be902f98 100644 --- a/docs/cli/introduction.rst +++ b/docs/cli/introduction.rst @@ -19,3 +19,4 @@ The commands are: - :ref:`Validate Command ` - :ref:`Patch Command ` - :ref:`Requests Command ` +- :ref:`redefine_graph Command ` diff --git a/docs/cli/redefine_graph.rst b/docs/cli/redefine_graph.rst new file mode 100644 index 00000000..97b5d9a3 --- /dev/null +++ b/docs/cli/redefine_graph.rst @@ -0,0 +1,160 @@ +.. _redefine_graph-command: + +Redefine Graph Command +====================== + +With this command, you can redefine the graph of a checkpoint file. +This is useful when you want to change / reconfigure the local-domain of a model, or rebuild with a new graph. + +We should caution that such transfer of the model from one graph to +another is not guaranteed to lead to good results. Still, it is a +powerful tool to explore generalisability of the model or to test +performance before starting fine tuning through transfer learning. + +This will create a new checkpoint file with the updated graph, and optionally save the graph to a file. + +Subcommands allow for a graph to be made from a lat/lon coordinate file, bounding box, or from a defined graph config. + +********* + Usage +********* + +.. code-block:: bash + + % anemoi-inference redefine_graph --help + + Redefine the graph of a checkpoint file. + + positional arguments: + path Path to the checkpoint. + + options: + -h, --help show this help message and exit + -g GRAPH, --graph GRAPH + Path to graph file to use + -y GRAPH_CONFIG, --graph-config GRAPH_CONFIG + Path to graph config to use + -ll LATLON, --latlon LATLON + Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes. + -c COORDS COORDS COORDS COORDS COORDS, --coords COORDS COORDS COORDS COORDS COORDS + Coordinates, (North West South East Resolution). + -gr GLOBAL_RESOLUTION, --global_-esolution GLOBAL_RESOLUTION + Global grid resolution required with --coords, (e.g. n320, o96). + --save-graph SAVE_GRAPH + Path to save the updated graph. + --output OUTPUT Path to save the updated checkpoint. + + +********* +Examples +********* + +Here are some examples of how to use the `redefine_graph` command: + +#. Using a graph file: + + .. code-block:: bash + + anemoi-inference redefine_graph path/to/checkpoint --graph path/to/graph + +#. Using a graph configuration: + + .. code-block:: bash + + anemoi-inference redefine_graph path/to/checkpoint --graph_config path/to/graph_config + + .. note:: + The configuration of the existing graph can be found using: + + .. code-block:: bash + + anemoi-inference metadata path/to/checkpoint -get config.graph ----yaml + +#. Using latitude/longitude coordinates: + This lat lon file should be a numpy file of shape (N, 2) with latitudes and longitudes. + + It can be easily made from a list of coordinates as follows: + + .. code-block:: python + + import numpy as np + coords = np.array(np.meshgrid(latitudes, longitudes)).T.reshape(-1, 2) + np.save('path/to/latlon.npy', coords) + + Once created, + + .. code-block:: bash + + anemoi-inference redefine_graph path/to/checkpoint --latlon path/to/latlon.npy + +#. Using bounding box coordinates: + + .. code-block:: bash + + anemoi-inference redefine_graph path/to/checkpoint --coords North West South East Resolution + + i.e. + + .. code-block:: bash + + anemoi-inference redefine_graph path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global_resolution n320 + + +All examples can optionally save the updated graph and checkpoint using the `--save-graph` and `--output` options. + +*************************** +Complete Inference Example +*************************** + +For this example we will redefine a checkpoint using a bounding box and then run inference + + +Redefine the checkpoint +----------------------- + +.. code-block:: bash + + anemoi-inference redefine_graph path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global_resolution n320 --save-graph path/to/updated_graph --output path/to/updated_checkpoint + +Create the inference config +--------------------------- + +If you have an input file of the expected shape handy use it in place of the input block, here we will show +how to use MARS to handle the regridding. + +.. note:: + Using the `anemoi-plugins-ecmwf-inference `_ package, preprocessors are available which can handle the regridding for you from other sources. + +.. code-block:: yaml + + checkpoint: path/to/updated_checkpoint + date: -2 + + input: + cutout: + lam_0: + mars: + grid: 0.1/0.1 # RESOLUTION WE SET + area: 30.0/-10.0/20.0/0.0 # BOUNDING BOX WE SET, N W S E + global: + mars: + grid: n320 # GLOBAL RESOLUTION WE SET + + +Run inference +----------------- + +.. code-block:: bash + + anemoi-inference run path/to/updated_checkpoint + + +********** +Reference +********** + +.. argparse:: + :module: anemoi.inference.__main__ + :func: create_parser + :prog: anemoi-inference + :path: redefine_graph diff --git a/docs/index.rst b/docs/index.rst index ff24cab3..353428a4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -153,6 +153,7 @@ You may also have to install pandoc on MacOS: cli/inspect cli/patch cli/requests + cli/redefine_graph .. toctree:: :maxdepth: 1 diff --git a/src/anemoi/inference/commands/redefine_graph.py b/src/anemoi/inference/commands/redefine_graph.py new file mode 100644 index 00000000..2637d3c4 --- /dev/null +++ b/src/anemoi/inference/commands/redefine_graph.py @@ -0,0 +1,165 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from argparse import ArgumentParser +from argparse import Namespace +from pathlib import Path + +from . import Command + +LOG = logging.getLogger(__name__) + + +def check_redefine_imports(): + """Check if required packages are installed.""" + required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"] + from importlib.util import find_spec + + for package in required_packages: + if find_spec(package) is None: + raise ImportError(f"{package!r} is required for this command.") + + +def format_namespace_as_str(namespace: Namespace) -> str: + """Format an argparse Namespace object as command-line arguments.""" + args = [] + + for key, value in vars(namespace).items(): + if key == "command": + continue + if value is None: + continue + + # Convert underscores to hyphens for command line format + arg_name = f"--{key.replace('_', '-')}" + + if isinstance(value, bool): + if value: + args.append(arg_name) + elif isinstance(value, list): + args.append(f"{arg_name} {' '.join(map(str, value))}") + else: + args.extend([arg_name, str(value)]) + + return " ".join(args) + + +class RedefineGraphCmd(Command): + """Redefine the graph of a checkpoint file.""" + + def add_arguments(self, command_parser: ArgumentParser) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : ArgumentParser + The argument parser to which the arguments will be added. + """ + command_parser.description = "Redefine the graph of a checkpoint file. If using coordinate specifications, assumes the input to the local domain is already regridded." + command_parser.add_argument("path", help="Path to the checkpoint.") + + group = command_parser.add_mutually_exclusive_group(required=True) + + group.add_argument("-g", "--graph", type=Path, help="Path to graph file to use") + group.add_argument("-y", "--graph-config", type=Path, help="Path to graph config to use") + group.add_argument( + "-ll", + "--latlon", + type=Path, + help="Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes.", + ) + group.add_argument("-c", "--coords", type=str, help="Coordinates, (North West South East Resolution).", nargs=5) + + command_parser.add_argument( + "-gr", + "--global-resolution", + type=str, + help="Global grid resolution required with --coords, (e.g. n320, o96).", + ) + + command_parser.add_argument("--save-graph", type=str, help="Path to save the updated graph.", default=None) + command_parser.add_argument("--output", type=str, help="Path to save the updated checkpoint.", default=None) + + def run(self, args: Namespace) -> None: + """Run the redefine_graph command. + + Parameters + ---------- + args : Namespace + The arguments passed to the command. + """ + from anemoi.inference.utils.redefine_graph import create_graph_from_config + from anemoi.inference.utils.redefine_graph import get_coordinates_from_file + from anemoi.inference.utils.redefine_graph import get_coordinates_from_mars_request + from anemoi.inference.utils.redefine_graph import load_graph_from_file + from anemoi.inference.utils.redefine_graph import make_graph_from_coordinates + from anemoi.inference.utils.redefine_graph import update_checkpoint + + check_redefine_imports() + + import torch + from anemoi.utils.checkpoints import load_metadata + from anemoi.utils.checkpoints import save_metadata + + path = Path(args.path) + + # Load checkpoint metadata and supporting arrays + metadata, supporting_arrays = load_metadata(str(path), supporting_arrays=True) + + # Add command to history + metadata.setdefault("history", []) + metadata["history"].append(f"anemoi-inference redefine_graph {format_namespace_as_str(args)}") + + # Create or load the graph + if args.graph is not None: + graph = load_graph_from_file(args.graph) + elif args.graph_config is not None: + graph = create_graph_from_config(args.graph_config) + else: + # Generate graph from coordinates + LOG.info("Generating graph from coordinates...") + + # Get coordinates based on input type + if args.latlon is not None: + local_lats, local_lons = get_coordinates_from_file(args.latlon) + elif args.coords is not None: + local_lats, local_lons = get_coordinates_from_mars_request(args.coords) + else: + raise ValueError("No valid coordinates found.") + + metadata, supporting_arrays, graph = make_graph_from_coordinates( + local_lats, local_lons, args.global_resolution, metadata, supporting_arrays + ) + + # Save graph if requested + if args.save_graph is not None: + torch.save(graph, args.save_graph) + LOG.info("Saved updated graph to %s", args.save_graph) + + # Update checkpoint + LOG.info("Updating checkpoint...") + model = torch.load(path, weights_only=False, map_location=torch.device("cpu")) + model = update_checkpoint(model, metadata, graph) + + # Save updated checkpoint + model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}" + torch.save(model, model_path) + + save_metadata( + model_path, + metadata=metadata, + supporting_arrays=supporting_arrays, + ) + + LOG.info("Updated checkpoint saved to %s", model_path) + + +command = RedefineGraphCmd diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py index e0b752ca..7491daac 100644 --- a/src/anemoi/inference/metadata.py +++ b/src/anemoi/inference/metadata.py @@ -95,6 +95,16 @@ def __init__(self, metadata: dict[str, Any], supporting_arrays: dict[str, FloatA self._supporting_arrays = supporting_arrays self._variables_categories = None + def to_dict(self) -> dict[str, Any]: + """Convert the Metadata object to a dictionary. + + Returns + ------- + dict + A copy of the metadata dictionary. + """ + return dict(self._metadata).copy() + @property def _indices(self) -> DotDict: """Return the data indices.""" diff --git a/src/anemoi/inference/runners/external_graph.py b/src/anemoi/inference/runners/external_graph.py index 4ee175ad..fec6dffe 100644 --- a/src/anemoi/inference/runners/external_graph.py +++ b/src/anemoi/inference/runners/external_graph.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024- Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -12,7 +12,7 @@ import logging import os -from copy import deepcopy +from contextlib import contextmanager from functools import cached_property from typing import Any from typing import Literal @@ -23,59 +23,11 @@ from ..decorators import main_argument from ..runners.default import DefaultRunner +from ..utils.redefine_graph import update_checkpoint from . import runner_registry LOG = logging.getLogger(__name__) -# Possibly move the function(s) below to anemoi-models or anemoi-utils since it could be used in transfer learning. - - -def contains_any(key, specifications): - contained = False - for specification in specifications: - if specification in key: - contained = True - break - return contained - - -def update_state_dict( - model, external_state_dict, keywords="", ignore_mismatched_layers=False, ignore_additional_layers=False -): - """Update the model's stated_dict with entries from an external state_dict. Only entries whose keys contain the specified keywords are considered.""" - - LOG.info("Updating model state dictionary.") - - if isinstance(keywords, str): - keywords = [keywords] - - # select relevant part of external_state_dict - reduced_state_dict = {k: v for k, v in external_state_dict.items() if contains_any(k, keywords)} - model_state_dict = model.state_dict() - - # check layers and their shapes - for key in list(reduced_state_dict): - if key not in model_state_dict: - if ignore_additional_layers: - LOG.info("Skipping injection of %s, which is not in the model.", key) - del reduced_state_dict[key] - else: - raise AssertionError(f"Layer {key} not in model. Consider setting 'ignore_additional_layers = True'.") - elif reduced_state_dict[key].shape != model_state_dict[key].shape: - if ignore_mismatched_layers: - LOG.info("Skipping injection of %s due to shape mismatch.", key) - LOG.info("Model shape: %s", model_state_dict[key].shape) - LOG.info("Provided shape: %s", reduced_state_dict[key].shape) - del reduced_state_dict[key] - else: - raise AssertionError( - "Mismatch in shape of %s. Consider setting 'ignore_mismatched_layers = True'.", key - ) - - # update - model.load_state_dict(reduced_state_dict, strict=False) - return model - def _get_supporting_arrays_from_graph(update_supporting_arrays: dict[str, str], graph: Any) -> dict: """Update the supporting arrays from the graph data.""" @@ -258,35 +210,37 @@ def __init__( @cached_property def graph(self): - + """Get the external graph from file.""" graph_path = self.graph_path + assert os.path.isfile( graph_path ), f"No graph found at {graph_path}. An external graph needs to be specified in the config file for this runner." + LOG.info("Loading external graph from path %s.", graph_path) return torch.load(graph_path, map_location="cpu", weights_only=False) + def on_device(self, device: str = "cpu"): + """Temporally reassign the device of the runner""" + + @contextmanager + def _device_manager(runner: ExternalGraphRunner, device: str): # type: ignore + original_device = runner.device + try: + runner.device = device + yield + finally: + runner.device = original_device + + return _device_manager(self, device) + @cached_property def model(self): # load the model from the checkpoint - device = self.device - self.device = "cpu" - model_instance = super().model - state_dict_ckpt = deepcopy(model_instance.state_dict()) - - # rebuild the model with the new graph - model_instance.graph_data = self.graph - model_instance.config = self.checkpoint._metadata._config - model_instance._build_model() - - # reinstate the weights, biases and normalizer from the checkpoint - # reinstating the normalizer is necessary for checkpoints that were created - # using transfer learning, where the statistics as stored in the checkpoint - # do not match the statistics used to build the normalizer in the checkpoint. - model_instance = update_state_dict( - model_instance, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"] - ) + with self.on_device("cpu"): + metadata = self.checkpoint._metadata.to_dict() + model = update_checkpoint(super().model, metadata, self.graph) LOG.info("Successfully built model with external graph and reassigned model weights!") - self.device = device - return model_instance.to(self.device) + + return model.to(self.device) diff --git a/src/anemoi/inference/utils/__init__.py b/src/anemoi/inference/utils/__init__.py new file mode 100644 index 00000000..c6149c4e --- /dev/null +++ b/src/anemoi/inference/utils/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/src/anemoi/inference/utils/redefine_graph.py b/src/anemoi/inference/utils/redefine_graph.py new file mode 100644 index 00000000..bc7242de --- /dev/null +++ b/src/anemoi/inference/utils/redefine_graph.py @@ -0,0 +1,300 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from copy import deepcopy +from pathlib import Path +from typing import TYPE_CHECKING +from typing import NamedTuple + +LOG = logging.getLogger(__name__) + +if TYPE_CHECKING: + import numpy as np + from torch_geometric.data import HeteroData + + +def update_state_dict( + model, + external_state_dict, + keywords: list[str] | None = None, + ignore_mismatched_layers=False, + ignore_additional_layers=False, +): + """Update the model's state_dict with entries from an external state_dict. Only entries whose keys contain the specified keywords are considered.""" + + LOG.info("Updating model state dictionary.") + + keywords = keywords or [] + + # select relevant part of external_state_dict + reduced_state_dict = {k: v for k, v in external_state_dict.items() if any(kw in k for kw in keywords)} + model_state_dict = model.state_dict() + + # check layers and their shapes + for key in list(reduced_state_dict): + if key not in model_state_dict: + if ignore_additional_layers: + LOG.info("Skipping injection of %s, which is not in the model.", key) + del reduced_state_dict[key] + else: + raise AssertionError(f"Layer {key} not in model. Consider setting 'ignore_additional_layers = True'.") + elif reduced_state_dict[key].shape != model_state_dict[key].shape: + if ignore_mismatched_layers: + LOG.info("Skipping injection of %s due to shape mismatch.", key) + LOG.info("Model shape: %s", model_state_dict[key].shape) + LOG.info("Provided shape: %s", reduced_state_dict[key].shape) + del reduced_state_dict[key] + else: + raise AssertionError(f"Mismatch in shape of {key}. Consider setting 'ignore_mismatched_layers = True'.") + + model.load_state_dict(reduced_state_dict, strict=False) + return model + + +def get_coordinates_from_file(latlon_path: Path) -> tuple["np.ndarray", "np.ndarray"]: + """Get coordinates from a numpy file. + + Parameters + ---------- + latlon_path : Path + Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Latitudes and longitudes arrays + """ + import numpy as np + + latlon = np.load(latlon_path) + return latlon[:, 0], latlon[:, 1] + + +class Coordinate(NamedTuple): + north: float + west: float + south: float + east: float + resolution: float + + +def get_coordinates_from_mars_request(coords: Coordinate) -> tuple["np.ndarray", "np.ndarray"]: + """Get coordinates from MARS request parameters. + + Parameters + ---------- + coords : Coordinate + Coordinates (North West South East Resolution) + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Latitudes and longitudes arrays + """ + import earthkit.data as ekd + + area = [coords.north, coords.west, coords.south, coords.east] + + resolution = str(coords.resolution) + if resolution.replace(".", "", 1).isdigit(): + resolution = f"{resolution}/{resolution}" + + ds = ekd.from_source( + "mars", + { + "AREA": area, + "GRID": f"{resolution}", + "param": "2t", + "date": -2, + "stream": "oper", + "type": "an", + "levtype": "sfc", + }, + ) + return ds[0].grid_points() # type: ignore + + +def combine_nodes_with_global_grid( + latitudes: "np.ndarray", longitudes: "np.ndarray", global_grid: str +) -> tuple["np.ndarray", "np.ndarray", "np.ndarray", "np.ndarray"]: + """Combine lat/lon nodes with global grid if specified. + + Returns lats, lons, local_mask, global_mask + """ + import numpy as np + from anemoi.datasets.grids import cutout_mask + from anemoi.utils.grids import grids + + global_points = grids(global_grid) + + global_removal_mask = cutout_mask(latitudes, longitudes, global_points["latitudes"], global_points["longitudes"]) + lats = np.concatenate([latitudes, global_points["latitudes"][global_removal_mask]]) + lons = np.concatenate([longitudes, global_points["longitudes"][global_removal_mask]]) + local_mask = np.array([True] * len(latitudes) + [False] * sum(global_removal_mask), dtype=bool) + + return lats, lons, local_mask, global_removal_mask + + +def make_data_graph( + lats: "np.ndarray", + lons: "np.ndarray", + local_mask: "np.ndarray", + global_mask: "np.ndarray", + reference_node_name: str = "data", + *, + mask_attr_name: str = "cutout_mask", + attrs: dict | None = None, +) -> "HeteroData": + """Make a data graph with the given lat/lon nodes and attributes.""" + import torch + from anemoi.graphs.nodes import LatLonNodes + from torch_geometric.data import HeteroData + + graph = LatLonNodes(lats, lons, name=reference_node_name).update_graph(HeteroData(), attrs_config=attrs) # type: ignore + graph[reference_node_name][mask_attr_name] = torch.from_numpy(local_mask).unsqueeze(1) + return graph + + +def make_graph_from_coordinates( + local_lats: "np.ndarray", local_lons: "np.ndarray", global_resolution: str, metadata: dict, supporting_arrays: dict +) -> tuple[dict, dict, "HeteroData"]: + """Make a graph from coordinates. + + Parameters + ---------- + local_lats : np.ndarray + Local latitude coordinates + local_lons : np.ndarray + Local longitude coordinates + global_resolution : str + Global grid resolution (e.g. n320, o96) + metadata : dict + Checkpoint metadata + supporting_arrays : dict + Supporting arrays from checkpoint + + Returns + ------- + tuple[dict, dict, HeteroData] + Updated metadata, supporting arrays, and graph + """ + import numpy as np + + if global_resolution is None: + raise ValueError("Global resolution must be specified when generating graph from coordinates.") + + LOG.info("Coordinates loaded. Number of local nodes: %d", len(local_lats)) + lats, lons, local_mask, global_mask = combine_nodes_with_global_grid(local_lats, local_lons, global_resolution) + + graph_config = deepcopy(metadata["config"]["graph"]) + data_graph = graph_config["nodes"].pop("data") + + from anemoi.graphs.create import GraphCreator + from anemoi.utils.config import DotDict + + creator = GraphCreator(DotDict(graph_config)) + + LOG.info("Updating graph...") + LOG.debug("Using %r", graph_config) + + def nested_get(d, keys, default=None): + for key in keys: + d = d.get(key, {}) + return d or default + + mask_attr_name = nested_get(graph_config, ["nodes", "hidden", "node_builder", "mask_attr_name"], "cutout") + + data_graph_attributes = None + # if mask_attr_name in data_graph.get("attributes", {}): + # data_graph_attributes = {mask_attr_name: data_graph["attributes"][mask_attr_name]} + + LOG.info("Found mask attribute name: %r", mask_attr_name) + # LOG.info("Found data graph attributes: %s", data_graph_attributes) + + data_graph = make_data_graph( + lats, + lons, + local_mask, + global_mask, + reference_node_name="data", + mask_attr_name=mask_attr_name, + attrs=data_graph_attributes, + ) + + LOG.info("Created data graph with %d nodes.", data_graph.num_nodes) + graph = creator.clean(creator.update_graph(data_graph)) + + supporting_arrays[f"global/{mask_attr_name}"] = global_mask + supporting_arrays[f"lam_0/{mask_attr_name}"] = np.array([True] * len(local_lats)) + + supporting_arrays["latitudes"] = lats + supporting_arrays["longitudes"] = lons + supporting_arrays["grid_indices"] = np.ones(local_mask.shape, dtype=np.int64) + + return metadata, supporting_arrays, graph + + +def update_checkpoint(model, metadata: dict, graph: "HeteroData"): + """Update checkpoint with new graph and update state dict.""" + from anemoi.utils.config import DotDict + + state_dict_ckpt = deepcopy(model.state_dict()) + + # rebuild the model with the new graph + model.graph_data = graph + model.config = DotDict(metadata).config + model._build_model() + + # reinstate the weights, biases and normalizer from the checkpoint + # reinstating the normalizer is necessary for checkpoints that were created + # using transfer learning, where the statistics as stored in the checkpoint + # do not match the statistics used to build the normalizer in the checkpoint. + model_instance = update_state_dict(model, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"]) + + return model_instance + + +def load_graph_from_file(graph_path: Path) -> "HeteroData": + """Load graph from file. + + Parameters + ---------- + graph_path : Path + Path to graph file + + Returns + ------- + HeteroData + Loaded graph + """ + import torch + + LOG.info("Loading graph from %s", graph_path) + return torch.load(graph_path, weights_only=False, map_location=torch.device("cpu")) + + +def create_graph_from_config(graph_config_path: Path) -> "HeteroData": + """Create graph from configuration file. + + Parameters + ---------- + graph_config_path : Path + Path to graph configuration file + + Returns + ------- + HeteroData + Created graph + """ + from anemoi.graphs.create import GraphCreator + from torch_geometric.data import HeteroData + + return GraphCreator(graph_config_path).update_graph(HeteroData())