Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/cli/introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ The commands are:
- :ref:`Validate Command <validate-command>`
- :ref:`Patch Command <patch-command>`
- :ref:`Requests Command <requests-command>`
- :ref:`redefine_graph Command <redefine-command>`
160 changes: 160 additions & 0 deletions docs/cli/redefine_graph.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
.. _redefine_graph-command:

Redefine Graph Command
======================

With this command, you can redefine the graph of a checkpoint file.
This is useful when you want to change / reconfigure the local-domain of a model, or rebuild with a new graph.

We should caution that such transfer of the model from one graph to
another is not guaranteed to lead to good results. Still, it is a
powerful tool to explore generalisability of the model or to test
performance before starting fine tuning through transfer learning.

This will create a new checkpoint file with the updated graph, and optionally save the graph to a file.

Subcommands allow for a graph to be made from a lat/lon coordinate file, bounding box, or from a defined graph config.

*********
Usage
*********

.. code-block:: bash

% anemoi-inference redefine_graph --help

Redefine the graph of a checkpoint file.

positional arguments:
path Path to the checkpoint.

options:
-h, --help show this help message and exit
-g GRAPH, --graph GRAPH
Path to graph file to use
-y GRAPH_CONFIG, --graph-config GRAPH_CONFIG
Path to graph config to use
-ll LATLON, --latlon LATLON
Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes.
-c COORDS COORDS COORDS COORDS COORDS, --coords COORDS COORDS COORDS COORDS COORDS
Coordinates, (North West South East Resolution).
-gr GLOBAL_RESOLUTION, --global_-esolution GLOBAL_RESOLUTION
Global grid resolution required with --coords, (e.g. n320, o96).
--save-graph SAVE_GRAPH
Path to save the updated graph.
--output OUTPUT Path to save the updated checkpoint.


*********
Examples
*********

Here are some examples of how to use the `redefine_graph` command:

#. Using a graph file:

.. code-block:: bash

anemoi-inference redefine_graph path/to/checkpoint --graph path/to/graph

#. Using a graph configuration:

.. code-block:: bash

anemoi-inference redefine_graph path/to/checkpoint --graph_config path/to/graph_config

.. note::
The configuration of the existing graph can be found using:

.. code-block:: bash

anemoi-inference metadata path/to/checkpoint -get config.graph ----yaml

#. Using latitude/longitude coordinates:
This lat lon file should be a numpy file of shape (N, 2) with latitudes and longitudes.

It can be easily made from a list of coordinates as follows:

.. code-block:: python

import numpy as np
coords = np.array(np.meshgrid(latitudes, longitudes)).T.reshape(-1, 2)
np.save('path/to/latlon.npy', coords)

Once created,

.. code-block:: bash

anemoi-inference redefine_graph path/to/checkpoint --latlon path/to/latlon.npy

#. Using bounding box coordinates:

.. code-block:: bash

anemoi-inference redefine_graph path/to/checkpoint --coords North West South East Resolution

i.e.

.. code-block:: bash

anemoi-inference redefine_graph path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global_resolution n320


All examples can optionally save the updated graph and checkpoint using the `--save-graph` and `--output` options.

***************************
Complete Inference Example
***************************

For this example we will redefine a checkpoint using a bounding box and then run inference


Redefine the checkpoint
-----------------------

.. code-block:: bash

anemoi-inference redefine_graph path/to/checkpoint --coords 30.0 -10.0 20.0 0.0 0.1/0.1 --global_resolution n320 --save-graph path/to/updated_graph --output path/to/updated_checkpoint

Create the inference config
---------------------------

If you have an input file of the expected shape handy use it in place of the input block, here we will show
how to use MARS to handle the regridding.

.. note::
Using the `anemoi-plugins-ecmwf-inference <https://github.com/ecmwf/anemoi-plugins-ecmwf>`_ package, preprocessors are available which can handle the regridding for you from other sources.

.. code-block:: yaml

checkpoint: path/to/updated_checkpoint
date: -2

input:
cutout:
lam_0:
mars:
grid: 0.1/0.1 # RESOLUTION WE SET
area: 30.0/-10.0/20.0/0.0 # BOUNDING BOX WE SET, N W S E
global:
mars:
grid: n320 # GLOBAL RESOLUTION WE SET


Run inference
-----------------

.. code-block:: bash

anemoi-inference run path/to/updated_checkpoint


**********
Reference
**********

.. argparse::
:module: anemoi.inference.__main__
:func: create_parser
:prog: anemoi-inference
:path: redefine_graph
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ You may also have to install pandoc on MacOS:
cli/inspect
cli/patch
cli/requests
cli/redefine_graph

.. toctree::
:maxdepth: 1
Expand Down
165 changes: 165 additions & 0 deletions src/anemoi/inference/commands/redefine_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# (C) Copyright 2025- Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
from argparse import ArgumentParser
from argparse import Namespace
from pathlib import Path

from . import Command

LOG = logging.getLogger(__name__)


def check_redefine_imports():
"""Check if required packages are installed."""
required_packages = ["anemoi.datasets", "anemoi.graphs", "anemoi.models"]
from importlib.util import find_spec

for package in required_packages:
if find_spec(package) is None:
raise ImportError(f"{package!r} is required for this command.")


def format_namespace_as_str(namespace: Namespace) -> str:
"""Format an argparse Namespace object as command-line arguments."""
args = []

for key, value in vars(namespace).items():
if key == "command":
continue
if value is None:
continue

# Convert underscores to hyphens for command line format
arg_name = f"--{key.replace('_', '-')}"

if isinstance(value, bool):
if value:
args.append(arg_name)
elif isinstance(value, list):
args.append(f"{arg_name} {' '.join(map(str, value))}")
else:
args.extend([arg_name, str(value)])

return " ".join(args)


class RedefineGraphCmd(Command):
"""Redefine the graph of a checkpoint file."""

def add_arguments(self, command_parser: ArgumentParser) -> None:
"""Add arguments to the command parser.

Parameters
----------
command_parser : ArgumentParser
The argument parser to which the arguments will be added.
"""
command_parser.description = "Redefine the graph of a checkpoint file. If using coordinate specifications, assumes the input to the local domain is already regridded."
command_parser.add_argument("path", help="Path to the checkpoint.")

group = command_parser.add_mutually_exclusive_group(required=True)

group.add_argument("-g", "--graph", type=Path, help="Path to graph file to use")
group.add_argument("-y", "--graph-config", type=Path, help="Path to graph config to use")
group.add_argument(
"-ll",
"--latlon",
type=Path,
help="Path to coordinate npy, should be of shape (N, 2) with latitudes and longitudes.",
)
group.add_argument("-c", "--coords", type=str, help="Coordinates, (North West South East Resolution).", nargs=5)

command_parser.add_argument(
"-gr",
"--global-resolution",
type=str,
help="Global grid resolution required with --coords, (e.g. n320, o96).",
)

command_parser.add_argument("--save-graph", type=str, help="Path to save the updated graph.", default=None)
command_parser.add_argument("--output", type=str, help="Path to save the updated checkpoint.", default=None)

def run(self, args: Namespace) -> None:
"""Run the redefine_graph command.

Parameters
----------
args : Namespace
The arguments passed to the command.
"""
from anemoi.inference.utils.redefine_graph import create_graph_from_config
from anemoi.inference.utils.redefine_graph import get_coordinates_from_file
from anemoi.inference.utils.redefine_graph import get_coordinates_from_mars_request
from anemoi.inference.utils.redefine_graph import load_graph_from_file
from anemoi.inference.utils.redefine_graph import make_graph_from_coordinates
from anemoi.inference.utils.redefine_graph import update_checkpoint

check_redefine_imports()

import torch
from anemoi.utils.checkpoints import load_metadata
from anemoi.utils.checkpoints import save_metadata

path = Path(args.path)

# Load checkpoint metadata and supporting arrays
metadata, supporting_arrays = load_metadata(str(path), supporting_arrays=True)

# Add command to history
metadata.setdefault("history", [])
metadata["history"].append(f"anemoi-inference redefine_graph {format_namespace_as_str(args)}")

# Create or load the graph
if args.graph is not None:
graph = load_graph_from_file(args.graph)
elif args.graph_config is not None:
graph = create_graph_from_config(args.graph_config)
else:
# Generate graph from coordinates
LOG.info("Generating graph from coordinates...")

# Get coordinates based on input type
if args.latlon is not None:
local_lats, local_lons = get_coordinates_from_file(args.latlon)
elif args.coords is not None:
local_lats, local_lons = get_coordinates_from_mars_request(args.coords)
else:
raise ValueError("No valid coordinates found.")

metadata, supporting_arrays, graph = make_graph_from_coordinates(
local_lats, local_lons, args.global_resolution, metadata, supporting_arrays
)

# Save graph if requested
if args.save_graph is not None:
torch.save(graph, args.save_graph)
LOG.info("Saved updated graph to %s", args.save_graph)

# Update checkpoint
LOG.info("Updating checkpoint...")
model = torch.load(path, weights_only=False, map_location=torch.device("cpu"))
model = update_checkpoint(model, metadata, graph)

# Save updated checkpoint
model_path = args.output if args.output is not None else f"{path.stem}_updated{path.suffix}"
torch.save(model, model_path)

save_metadata(
model_path,
metadata=metadata,
supporting_arrays=supporting_arrays,
)

LOG.info("Updated checkpoint saved to %s", model_path)


command = RedefineGraphCmd
Loading
Loading