From 9cb8c52a46f2915185cf2dfe306b29ecbef690d7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 16:40:18 +0000 Subject: [PATCH 01/79] feat: recipe generator --- src/anemoi/datasets/recipe.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 src/anemoi/datasets/recipe.py diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py new file mode 100644 index 000000000..eed457f5b --- /dev/null +++ b/src/anemoi/datasets/recipe.py @@ -0,0 +1,23 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +class Recipe: + pass + + def dump(): + pass + + +if __name__ == "__main__": + r = Recipe() + r.description = "test" + + r.add(r.mars()) + r.add(r.rename(r.mars())) From 01774dda92f39b25aa44e455bfccc08ecfe252f4 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 18:15:39 +0100 Subject: [PATCH 02/79] update --- src/anemoi/datasets/recipe.py | 97 +++++++++++++++++++++++++++++++++-- 1 file changed, 94 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index eed457f5b..06f3e05eb 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -7,12 +7,101 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+import logging -class Recipe: +import yaml +from anemoi.transform.filters import filter_registry as transform_filter_registry + +from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry +from anemoi.datasets.create.sources import source_registry + +LOG = logging.getLogger(__name__) + + +class Step: + def __init__(self, owner, *args, **kwargs): + self.owner = owner + self.args = args + self.kwargs = kwargs + + def as_dict(self): + return {self.owner.name: self.kwargs} + + +class Source(Step): pass - def dump(): - pass + +class Filter(Step): + pass + + +class SourceMaker: + def __init__(self, name, factory): + self.name = name + self.factory = factory + + def __call__(self, *args, **kwargs): + return Source(self, *args, **kwargs) + + +class FilterMaker: + def __init__(self, name, factory): + self.name = name + self.factory = factory + + def __call__(self, *args, **kwargs): + return Filter(self, *args, **kwargs) + + +class Recipe: + + def __init__(self): + self.description = None + self._steps = [] + + sources = source_registry.factories.copy() + filters = transform_filter_registry.factories.copy() + + for key, factory in datasets_filter_registry.factories.items(): + if key in filters: + LOG.warning( + f"Filter `{key}` is registered in anemoi.datasets filter registry and in anemoi.transform filter registry" + ) + filters[key] = factory + + for key, factory in sources.items(): + if key in filters: + LOG.warning( + f"Source `{key}` is registered in anemoi.datasets source registry and in anemoi.transform filter registry" + ) + del filters[key] + + for key, factory in sources.items(): + key = key.replace("-", "_") + assert not hasattr(self, key) + setattr(self, key, SourceMaker(key, factory)) + + for key, factory in filters.items(): + key = key.replace("-", "_") + assert not hasattr(self, key) + setattr(self, key, FilterMaker(key, factory)) + + def add(self, step): + self._steps.append(step) + + def dump(self): + result = { + "description": 
self.description, + "input": [s.as_dict() for s in self._steps], + } + + if len(result["input"]) == 1: + result = result["input"][0] + else: + result["input"] = {"join": result["input"]} + + print(yaml.safe_dump(result)) if __name__ == "__main__": @@ -21,3 +110,5 @@ def dump(): r.add(r.mars()) r.add(r.rename(r.mars())) + + r.dump() From 33793a674df19b922f454055b6d4813303cb522f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 18:24:01 +0100 Subject: [PATCH 03/79] update --- src/anemoi/datasets/recipe.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 06f3e05eb..f8b01caa5 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -29,11 +29,23 @@ def as_dict(self): class Source(Step): - pass + def __init__(self, owner, **kwargs): + super().__init__(owner, **kwargs) + self._source = None class Filter(Step): - pass + def __init__(self, owner, previous, **kwargs): + super().__init__(owner, **kwargs) + self.previous = previous + + def as_dict(self): + prev = self.previous.as_dict() + if isinstance(prev, dict) and "pipe" in prev: + prev = prev.copy() + prev["pipe"] = prev["pipe"].copy() + [super().as_dict()] + return prev + return {"pipe": [prev, super().as_dict()]} class SourceMaker: @@ -109,6 +121,6 @@ def dump(self): r.description = "test" r.add(r.mars()) - r.add(r.rename(r.mars())) + r.add(r.rescale(r.rename(r.mars()))) r.dump() From 920d52303ee56836e7303da7cc4d3c6ec680b074 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 19:14:49 +0100 Subject: [PATCH 04/79] update --- src/anemoi/datasets/recipe.py | 114 ++++++++++++++++++++++++++-------- 1 file changed, 88 insertions(+), 26 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index f8b01caa5..5702db119 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -19,33 +19,60 @@ class Step: + + 
def __or__(self, other): + return Pipe(self, other) + + def __add__(self, other): + return Join(self, other) + + +class Chain(Step): + def __init__(self, *args): + if len(args) > 0 and isinstance(args[0], self.__class__): + args = args[0].steps + args[1:] + + self.steps = args + + def as_dict(self): + if len(self.steps) == 1: + return self.steps[0].as_dict() + return {self.name: [s.as_dict() for s in self.steps]} + + def __repr__(self): + return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" + + +class Pipe(Chain): + name = "pipe" + + +class Join(Chain): + name = "join" + + +class Base(Step): def __init__(self, owner, *args, **kwargs): self.owner = owner - self.args = args - self.kwargs = kwargs + self.params = {} + for a in args: + assert isinstance(a, dict), f"Invalid argument {a}" + self.params.update(a) + self.params.update(kwargs) def as_dict(self): - return {self.owner.name: self.kwargs} + return {self.owner.name: self.params} + def __repr__(self): + return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" -class Source(Step): - def __init__(self, owner, **kwargs): - super().__init__(owner, **kwargs) - self._source = None +class Source(Base): + pass -class Filter(Step): - def __init__(self, owner, previous, **kwargs): - super().__init__(owner, **kwargs) - self.previous = previous - def as_dict(self): - prev = self.previous.as_dict() - if isinstance(prev, dict) and "pipe" in prev: - prev = prev.copy() - prev["pipe"] = prev["pipe"].copy() + [super().as_dict()] - return prev - return {"pipe": [prev, super().as_dict()]} +class Filter(Base): + pass class SourceMaker: @@ -63,6 +90,10 @@ def __init__(self, name, factory): self.factory = factory def __call__(self, *args, **kwargs): + if len(args) > 0 and isinstance(args[0], Step): + prev = args[0] + args = args[1:] + return Pipe(prev, Filter(self, *args, **kwargs)) return Filter(self, *args, **kwargs) @@ -105,14 +136,9 @@ def add(self, step): 
def dump(self): result = { "description": self.description, - "input": [s.as_dict() for s in self._steps], + "input": Join(*self._steps).as_dict(), } - if len(result["input"]) == 1: - result = result["input"][0] - else: - result["input"] = {"join": result["input"]} - print(yaml.safe_dump(result)) @@ -120,7 +146,43 @@ def dump(self): r = Recipe() r.description = "test" - r.add(r.mars()) - r.add(r.rescale(r.rename(r.mars()))) + # r.add( + # r.mars( + # expver="0001", + # levtype="sfc", + # param=["2t"], + # number=[0, 1], + # ) + # ) + + # r.add( + # r.rescale( + # r.rename( + # r.mars( + # expver="0002", + # levtype="sfc", + # param=["2t"], + # number=[0, 1], + # ), + # param={"2t": "2t_0002"}, + # ), + # {"2t_0002": ["mm", "m"]}, + # ) + # ) + + m1 = r.mars(expver="0001", levtype="sfc", param=["2t"], number=[0, 1]) + m2 = r.mars(expver="0002", levtype="sfc", param=["2t"], number=[0, 1]) + + m3 = r.mars(expver="0003", levtype="sfc", param=["2t"], number=[0, 1]) + + r.add( + (m1 + m2 + m3) + | r.rename( + param={"2t": "2t_0002"}, + ) + | r.rescale( + {"2t_0002": ["mm", "m"]}, + ) + ) r.dump() From ed5d190b0fdf13fdcfa8909dacb9522bacec6e87 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 21:31:40 +0100 Subject: [PATCH 05/79] update --- src/anemoi/datasets/recipe.py | 183 +++++++++++++++++++++++++--------- 1 file changed, 136 insertions(+), 47 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 5702db119..9d8b17545 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -18,6 +18,14 @@ LOG = logging.getLogger(__name__) +class Index: + def __init__(self, index): + self.name = str(index) + + def __repr__(self): + return f"Index({self.name})" + + class Step: def __or__(self, other): @@ -33,15 +41,23 @@ def __init__(self, *args): args = args[0].steps + args[1:] self.steps = args + self.index = [Index(i) for i in range(len(self.steps))] - def as_dict(self): + def as_dict(self, recipe): if 
len(self.steps) == 1: - return self.steps[0].as_dict() - return {self.name: [s.as_dict() for s in self.steps]} + return self.steps[0].as_dict(recipe) + return {self.name: [s.as_dict(recipe) for s in self.steps]} def __repr__(self): return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" + def path(self, target, result, *path): + for i, s in enumerate(self.steps): + s.path(target, result, *path, self, self.index[i]) + + def collocated(self, a, b): + return True + class Pipe(Chain): name = "pipe" @@ -51,21 +67,71 @@ class Join(Chain): name = "join" +class Concat(Step): + name = "concat" + + def __init__(self, args): + assert isinstance(args, dict), f"Invalid argument {args}" + self.params = args + + def __setitem__(self, key, value): + self.params[key] = value + + def as_dict(self, recipe): + + result = [] + + for k, v in sorted(self.params.items()): + result.append({"dates": dict(start=k[0], end=k[1]), **v.as_dict(recipe)}) + + return {"concat": result} + + def collocated(self, a, b): + return a[0] is b[0] + + def path(self, target, result, *path): + + for i, (k, v) in enumerate(sorted(self.params.items())): + v.path(target, result, *path, self, Index(i)) + + class Base(Step): def __init__(self, owner, *args, **kwargs): self.owner = owner + self.name = owner.name self.params = {} for a in args: assert isinstance(a, dict), f"Invalid argument {a}" self.params.update(a) self.params.update(kwargs) - def as_dict(self): - return {self.owner.name: self.params} + def as_dict(self, recipe): + + def resolve(params, recipe): + if isinstance(params, dict): + return {k: resolve(v, recipe) for k, v in params.items()} + + if isinstance(params, (list, tuple)): + return [resolve(v, recipe) for v in params] + + if isinstance(params, list): + return [resolve(v, recipe) for v in params] + + if isinstance(params, Step): + return recipe.resolve(self, params) + + return params + + return {self.owner.name: resolve(self.params, recipe)} def __repr__(self): return 
f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" + def path(self, target, result, *path): + + if self is target: + result.append([*path, self]) + class Source(Base): pass @@ -101,7 +167,7 @@ class Recipe: def __init__(self): self.description = None - self._steps = [] + self.input = Join() sources = source_registry.factories.copy() filters = transform_filter_registry.factories.copy() @@ -130,59 +196,82 @@ def __init__(self): assert not hasattr(self, key) setattr(self, key, FilterMaker(key, factory)) - def add(self, step): - self._steps.append(step) - def dump(self): result = { "description": self.description, - "input": Join(*self._steps).as_dict(), + "input": self.input.as_dict(self), } print(yaml.safe_dump(result)) + def concat(self, *args, **kwargs): + return Concat(*args, **kwargs) + + def resolve(self, source, target): + assert isinstance(target, Source), f"Only sources can be used as template {target}" + + top = Index("input") # So we have 'input' first in the path + + path_to_source = [] + self.input.path(source, path_to_source, top) + if len(path_to_source) == 0: + raise ValueError(f"Source {source} not found in recipe") + if len(path_to_source) > 1: + raise ValueError(f"Source {source} found in multiple locations {path_to_source}") + path_to_source = path_to_source[0] + + path_to_target = [] + self.input.path(target, path_to_target, top) + if len(path_to_target) == 0: + raise ValueError(f"Target {target} not found in recipe") + if len(path_to_target) > 1: + raise ValueError(f"Target {target} found in multiple locations {path_to_target}") + path_to_target = path_to_target[0] + + a = [s for s in path_to_target] + b = [s for s in path_to_source] + common_ancestor = None + while a[0] is b[0]: + common_ancestor = a[0] + a = a[1:] + b = b[1:] + + assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" + + if not common_ancestor.collocated(a, b): + source = 
".".join(s.name for s in path_to_source) + target = ".".join(s.name for s in path_to_target) + raise ValueError( + f"Source ${{{source}}} and target ${{{target}}} are not collocated (i.e. they are not branch of a 'concat')" + ) + + target = ".".join(s.name for s in path_to_target) + return f"${{{target}}}" + if __name__ == "__main__": r = Recipe() r.description = "test" - # r.add( - # r.mars( - # expver="0001", - # levtype="sfc", - # param=["2t"], - # number=[0, 1], - # ) - # ) - - # r.add( - # r.rescale( - # r.rename( - # r.mars( - # expver="0002", - # levtype="sfc", - # param=["2t"], - # number=[0, 1], - # ), - # param={"2t": "2t_0002"}, - # ), - # {"2t_0002": ["mm", "m"]}, - # ) - # ) - - m1 = r.mars(expver="0001", levtype="sfc", param=["2t"], number=[0, 1]) - m2 = r.mars(expver="0002", levtype="sfc", param=["2t"], number=[0, 1]) - - m3 = r.mars(expver="0003", levtype="sfc", param=["2t"], number=[0, 1]) - - r.add( - (m1 + m2 + m3) - | r.rename( - param={"2t": "2t_0002"}, - ) - | r.rescale( - {"2t_0002": ["mm", "m"]}, - ) + m1 = r.mars(expver="0001") + m2 = r.mars(expver="0002") + m3 = r.mars(expver="0003") + + r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) + + r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) + + m0 = r.mars(expver="0000") + c = r.concat( + { + ("1900", "2000"): m0, + ("2001", "2020"): r.mars(expver="0002"), + ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + }, ) + c[("2031", "2033")] = r.mars(expver="0005") + + r.input += c + r.dump() From 2562baf9ad8cf1be050b2605794bab4a3adf956b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 09:26:08 +0000 Subject: [PATCH 06/79] fix: better handling of xarray metadata --- .../create/sources/xarray_support/field.py | 13 +- .../create/sources/xarray_support/metadata.py | 123 ++---------------- 2 files changed, 19 insertions(+), 117 deletions(-) diff --git 
a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 663aeab54..d46613474 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -80,12 +80,21 @@ def __init__(self, owner: Any, selection: Any) -> None: # Copy the metadata from the owner self._md = owner._metadata.copy() + aliases = {} for coord_name, coord_value in self.selection.coords.items(): if is_scalar(coord_value): # Extract the single value from the scalar dimension # and store it in the metadata coordinate = owner.by_name[coord_name] - self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value)) + normalised = coordinate.normalise(extract_single_value(coord_value)) + self._md[coord_name] = normalised + for alias in coordinate.mars_names: + aliases[alias] = normalised + + # Add metadata aliases (e.g. levelist == level) only if they are not already present + for alias, value in aliases.items(): + if alias not in self._md: + self._md[alias] = value # print(values.ndim, values.shape, selection.dims) # By now, the only dimensions should be latitude and longitude @@ -188,7 +197,7 @@ def forecast_reference_time(self) -> datetime.datetime: def __repr__(self) -> str: """Return a string representation of the field.""" - return repr(self._metadata) + return f"XArrayField({self._metadata})" def _values(self, dtype: Optional[type] = None) -> Any: """Return the values of the selection. diff --git a/src/anemoi/datasets/create/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/sources/xarray_support/metadata.py index 2a0e4d9cb..80f633c3c 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/metadata.py +++ b/src/anemoi/datasets/create/sources/xarray_support/metadata.py @@ -23,87 +23,6 @@ LOG = logging.getLogger(__name__) -class _MDMapping: - """A class to handle metadata mapping for variables. 
- - Attributes - ---------- - variable : Any - The variable to map. - time : Any - The time associated with the variable. - mapping : Dict[str, str] - A dictionary mapping keys to variable names. - """ - - def __init__(self, variable: Any) -> None: - """Initialize the _MDMapping class. - - Parameters - ---------- - variable : Any - The variable to map. - """ - self.variable = variable - self.time = variable.time - self.mapping = dict() - # Aliases - - def _from_user(self, key: str) -> str: - """Get the internal key corresponding to a user-provided key. - - Parameters - ---------- - key : str - The user-provided key. - - Returns - ------- - str - The internal key corresponding to the user-provided key. - """ - return self.mapping.get(key, key) - - def from_user(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: - """Convert user-provided keys to internal keys. - - Parameters - ---------- - kwargs : Dict[str, Any] - A dictionary of user-provided keys and values. - - Returns - ------- - Dict[str, Any] - A dictionary with internal keys and original values. - """ - return {self._from_user(k): v for k, v in kwargs.items()} - - def __repr__(self) -> str: - """Return a string representation of the _MDMapping object. - - Returns - ------- - str - String representation of the _MDMapping object. - """ - return f"MDMapping({self.mapping})" - - def fill_time_metadata(self, field: Any, md: Dict[str, Any]) -> None: - """Fill the time metadata for a field. - - Parameters - ---------- - field : Any - The field to fill metadata for. - md : Dict[str, Any] - The metadata dictionary to update. - """ - valid_datetime = self.variable.time.fill_time_metadata(field._md, md) - if valid_datetime is not None: - md["valid_datetime"] = as_datetime(valid_datetime).isoformat() - - class XArrayMetadata(RawMetadata): """A class to handle metadata for XArray fields. @@ -129,10 +48,16 @@ def __init__(self, field: Any) -> None: field : Any The field to extract metadata from. 
""" + from .field import XArrayField + + assert isinstance(field, XArrayField), type(field) self._field = field md = field._md.copy() - self._mapping = _MDMapping(field.owner) - self._mapping.fill_time_metadata(field, md) + + valid_datetime = field.owner.time.fill_time_metadata(field._md, md) + if valid_datetime is not None: + md["valid_datetime"] = as_datetime(valid_datetime).isoformat() + super().__init__(md) @cached_property @@ -192,38 +117,6 @@ def _valid_datetime(self) -> Optional[datetime.datetime]: """ return self._get("valid_datetime") - def get(self, key: str, astype: Optional[type] = None, **kwargs: Any) -> Any: - """Get a metadata value by key. - - Parameters - ---------- - key : str - The key to get the value for. - astype : Optional[type] - The type to cast the value to. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - Any - The value for the specified key, optionally cast to the specified type. - """ - - if key == "levelist": - # Special case for levelist, for compatibility with GRIB - if key not in self._d and "level" in self._d: - key = "level" - - if key in self._d: - if astype is not None: - return astype(self._d[key]) - return self._d[key] - - key = self._mapping._from_user(key) - - return super().get(key, astype=astype, **kwargs) - class XArrayFieldGeography(Geography): """A class to handle geography information for XArray fields. 
From f048a4b7c84ba4961b32210503920d29c61711b7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 09:30:52 +0000 Subject: [PATCH 07/79] update --- src/anemoi/datasets/commands/copy.py | 2 + src/anemoi/datasets/create/__init__.py | 18 +++++ src/anemoi/datasets/create/input/__init__.py | 8 +++ src/anemoi/datasets/create/input/action.py | 65 ++++++++++++++++++- src/anemoi/datasets/create/input/filter.py | 17 +++++ src/anemoi/datasets/create/input/function.py | 10 ++- src/anemoi/datasets/create/input/join.py | 3 + src/anemoi/datasets/create/input/pipe.py | 7 ++ .../datasets/create/input/repeated_dates.py | 9 +++ src/anemoi/datasets/create/input/step.py | 16 ++++- 10 files changed, 151 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index ea08c8aab..9cdca5a4d 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -505,6 +505,7 @@ def add_arguments(self, command_parser: Any) -> None: default=100, help="For optimisation purposes, data is transfered by blocks. 
Default is 100.", ) + command_parser.add_argument("--workdir", help="Working directory for the copy operation.", default=".") command_parser.add_argument("source", help="Source location.") command_parser.add_argument("target", help="Target location.") @@ -534,6 +535,7 @@ def run(self, args: Any) -> None: resume=args.resume, verbosity=args.verbosity, threads=args.transfers, + workdir=args.workdir, ) copier.run() return diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 081c618ed..fb46dd5d1 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1628,3 +1628,21 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An )[name] LOG.debug(f"Creating {cls.__name__} with {kwargs}") return cls(**kwargs) + + +def config_to_python(config: Any) -> Any: + import sys + + config = loader_config(config) + input = build_input_(config, build_output(config.output, None)) + code = input.to_python() + + code = f"from anemoi.datasets.recipe import Recipe;r = Recipe();r.input = {code}; r.dump()" + + try: + import black + + return black.format_str(code, mode=black.Mode()) + except ImportError: + LOG.warning("Black not installed, skipping formatting") + return code diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index cd0b4ad40..734175698 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -75,6 +75,14 @@ def select(self, group_of_dates: GroupOfDates) -> Any: action = action_factory(self.config, context, self.action_path) return action.select(group_of_dates) + def to_python(self) -> str: + from .action import ActionContext + from .action import action_factory + + context = ActionContext(**self.kwargs) + action = action_factory(self.config, context, self.action_path) + return action.to_python() + def __repr__(self) -> str: """Return a string representation 
of the InputBuilder. diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index eadf01339..98936750e 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -9,6 +9,7 @@ import json import logging +import re from copy import deepcopy from typing import Any from typing import Dict @@ -159,6 +160,68 @@ def _trace_select(self, group_of_dates: GroupOfDates) -> str: """ return f"{self.__class__.__name__}({group_of_dates})" + def _to_python(self, name, config): + """Convert the action to Python code. + + Parameters + ---------- + name : str + The name of the action. + config : dict + The configuration for the action. + + Returns + ------- + str + The Python code representation of the action. + """ + import json + + RESERVED_KEYWORDS = ( + "and", + "or", + "not", + "is", + "in", + "if", + "else", + "elif", + "for", + "while", + "return", + "class", + "def", + "with", + "as", + "import", + "from", + "try", + "except", + "finally", + "raise", + "assert", + "break", + "continue", + "pass", + ) + + config = json.loads(json.dumps(config)) + + assert len(config) == 1, (name, config) + assert name in config, (name, config) + + config = config[name] + + params = [] + for k, v in config.items(): + if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: + return f"r.{name}({config})" + params.append(f"{k}={repr(v)}") + + params = ",".join(params) + return f"r.{name}({params})" + # return f"{name}({config})" + class ActionContext(Context): """Represents the context in which an action is performed. 
@@ -254,6 +317,6 @@ def action_factory(config: Dict[str, Any], context: ActionContext, action_path: from ..sources import create_source source = create_source(None, substitute(context, config)) - return FunctionAction(context, action_path + [key], key, source) + return FunctionAction(context, action_path + [key], key, source, config) return cls(context, action_path + [key], *args, **kwargs) diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py index 289bb3602..f4c55fa67 100644 --- a/src/anemoi/datasets/create/input/filter.py +++ b/src/anemoi/datasets/create/input/filter.py @@ -95,6 +95,7 @@ def __init__( previous_step: StepAction, name: str, filter: Any, + config: dict, *args: Any, **kwargs: Any, ) -> None: @@ -116,3 +117,19 @@ def __init__( super().__init__(context, action_path, previous_step, *args, **kwargs) self.name = name self.filter = filter + self.config = config + + def to_python(self) -> Any: + """Converts the action to Python code. + + Parameters + ---------- + file : str + The file to convert. + + Returns + ------- + Any + The converted Python code. + """ + return self._to_python(self.name, self.config) diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py index 4d3d21b22..bb33a678a 100644 --- a/src/anemoi/datasets/create/input/function.py +++ b/src/anemoi/datasets/create/input/function.py @@ -91,7 +91,9 @@ class FunctionAction(Action): The name of the function. """ - def __init__(self, context: object, action_path: list, _name: str, source, **kwargs: Dict[str, Any]) -> None: + def __init__( + self, context: object, action_path: list, _name: str, source, config: dict, **kwargs: Dict[str, Any] + ) -> None: """Initializes a FunctionAction instance. 
Parameters @@ -108,6 +110,12 @@ def __init__(self, context: object, action_path: list, _name: str, source, **kwa super().__init__(context, action_path, **kwargs) self.name: str = _name self.source = source + self.config = config + + def to_python(self) -> str: + """Returns the Python representation of the function action.""" + + return self._to_python(self.name, self.config) @trace_select def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py index ba24c7072..761922271 100644 --- a/src/anemoi/datasets/create/input/join.py +++ b/src/anemoi/datasets/create/input/join.py @@ -107,6 +107,9 @@ def __init__(self, context: object, action_path: list, *configs: dict) -> None: super().__init__(context, action_path, *configs) self.actions: List[Action] = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] + def to_python(self) -> None: + return "(" + " + ".join([i.to_python() for i in self.actions]) + ")" + def __repr__(self) -> str: """Returns a string representation of the JoinAction instance.""" content: str = "\n".join([str(i) for i in self.actions]) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py index 6c9fea0df..e2fda03e4 100644 --- a/src/anemoi/datasets/create/input/pipe.py +++ b/src/anemoi/datasets/create/input/pipe.py @@ -40,9 +40,13 @@ def __init__(self, context: Any, action_path: list, *configs: dict) -> None: f"PipeAction requires at least two actions, got {len(configs)}\n{json.dumps(configs, indent=2)}" ) + self.actions: list = [] + current: Any = action_factory(configs[0], context, action_path + ["0"]) + self.actions.append(current) for i, c in enumerate(configs[1:]): current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current) + self.actions.append(current) self.last_step: Any = current @trace_select @@ -64,3 +68,6 @@ def select(self, group_of_dates: 
Any) -> Any: def __repr__(self) -> str: """Return a string representation of the PipeAction.""" return f"PipeAction({self.last_step})" + + def to_python(self) -> str: + return "(" + " | ".join([i.to_python() for i in self.actions]) + ")" diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index ebf13b36e..788bf7f10 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -351,6 +351,15 @@ def __init__(self, context: Any, action_path: List[str], source: Any, mode: str, self.source: Any = action_factory(source, context, action_path + ["source"]) self.mapper: DateMapper = DateMapper.from_mode(mode, self.source, kwargs) + def to_python(self) -> None: + """Convert the action to Python code. + + Args: + file (Any): The file to write the Python code to. + """ + return self.source.to_python() + # self.mapper.to_python(file) + @trace_select def select(self, group_of_dates: Any) -> JoinResult: """Select and transform the group of dates. 
diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py index e99717094..a2e3ccd42 100644 --- a/src/anemoi/datasets/create/input/step.py +++ b/src/anemoi/datasets/create/input/step.py @@ -177,7 +177,14 @@ def step_factory(config: Dict[str, Any], context: ActionContext, action_path: Li warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries") filter = create_datasets_filter(None, config) - return FunctionStepAction(context, action_path + [key], previous_step, key, filter) + return FunctionStepAction( + context, + action_path + [key], + previous_step, + key, + filter, + config, + ) # Use filters from transform registry @@ -185,7 +192,12 @@ def step_factory(config: Dict[str, Any], context: ActionContext, action_path: Li from ..filters.transform import TransformFilter return FunctionStepAction( - context, action_path + [key], previous_step, key, TransformFilter(context, key, config) + context, + action_path + [key], + previous_step, + key, + TransformFilter(context, key, config), + config, ) raise ValueError(f"Unknown step action `{key}`") From 8ad5eb032e26cdace61ade8b5caa8b85c948a4d6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 09:35:19 +0000 Subject: [PATCH 08/79] update --- src/anemoi/datasets/create/__init__.py | 1 - src/anemoi/datasets/create/input/action.py | 2 +- src/anemoi/datasets/create/input/filter.py | 9 +--- .../datasets/create/input/repeated_dates.py | 11 ++--- src/anemoi/datasets/recipe.py | 42 +++++++++++-------- 5 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index fb46dd5d1..34504e974 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1631,7 +1631,6 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An def config_to_python(config: Any) -> Any: - import sys config = 
loader_config(config) input = build_input_(config, build_output(config.output, None)) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 98936750e..4cadff1c1 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -160,7 +160,7 @@ def _trace_select(self, group_of_dates: GroupOfDates) -> str: """ return f"{self.__class__.__name__}({group_of_dates})" - def _to_python(self, name, config): + def _to_python(self, name: str, config: dict) -> str: """Convert the action to Python code. Parameters diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py index f4c55fa67..298cead41 100644 --- a/src/anemoi/datasets/create/input/filter.py +++ b/src/anemoi/datasets/create/input/filter.py @@ -119,17 +119,12 @@ def __init__( self.filter = filter self.config = config - def to_python(self) -> Any: + def to_python(self) -> str: """Converts the action to Python code. - Parameters - ---------- - file : str - The file to convert. - Returns ------- - Any + str The converted Python code. """ return self._to_python(self.name, self.config) diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index 788bf7f10..0f5b5730a 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -9,6 +9,7 @@ import logging +import warnings from collections import defaultdict from typing import Any from typing import Dict @@ -351,14 +352,10 @@ def __init__(self, context: Any, action_path: List[str], source: Any, mode: str, self.source: Any = action_factory(source, context, action_path + ["source"]) self.mapper: DateMapper = DateMapper.from_mode(mode, self.source, kwargs) - def to_python(self) -> None: - """Convert the action to Python code. - - Args: - file (Any): The file to write the Python code to. 
- """ + def to_python(self) -> str: + """Convert the action to Python code.""" + warnings.warn("RepeatedDatesAction.to_python is still a work in progress") return self.source.to_python() - # self.mapper.to_python(file) @trace_select def select(self, group_of_dates: Any) -> JoinResult: diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 9d8b17545..5594682dd 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -250,28 +250,34 @@ def resolve(self, source, target): if __name__ == "__main__": - r = Recipe() - r.description = "test" - m1 = r.mars(expver="0001") - m2 = r.mars(expver="0002") - m3 = r.mars(expver="0003") + if False: + r = Recipe() + r.description = "test" - r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) + m1 = r.mars(expver="0001") + m2 = r.mars(expver="0002") + m3 = r.mars(expver="0003") - r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) + r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) - m0 = r.mars(expver="0000") - c = r.concat( - { - ("1900", "2000"): m0, - ("2001", "2020"): r.mars(expver="0002"), - ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - }, - ) + r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) - c[("2031", "2033")] = r.mars(expver="0005") + m0 = r.mars(expver="0000") + c = r.concat( + { + ("1900", "2000"): m0, + ("2001", "2020"): r.mars(expver="0002"), + ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + }, + ) - r.input += c + c[("2031", "2033")] = r.mars(expver="0005") - r.dump() + r.input += c + + r.dump() + else: + from anemoi.datasets.create import config_to_python + + print(config_to_python("x.yaml")) From b33f3acfc990670c090f70b345a41ded802e5fae Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 11:42:35 +0000 Subject: [PATCH 09/79] fix: support 
other keys that param in rename filter --- src/anemoi/datasets/create/filters/rename.py | 22 +++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/anemoi/datasets/create/filters/rename.py b/src/anemoi/datasets/create/filters/rename.py index ddae03f0f..4c87bd4bb 100644 --- a/src/anemoi/datasets/create/filters/rename.py +++ b/src/anemoi/datasets/create/filters/rename.py @@ -45,9 +45,7 @@ def __init__(self, field: Any, what: str, renaming: Dict[str, Dict[str, str]]) - """ self.field = field self.what = what - self.renaming = {} - for k, v in renaming.items(): - self.renaming[k] = {str(a): str(b) for a, b in v.items()} + self.renaming = renaming.copy() def metadata(self, key: Optional[str] = None, **kwargs: Any) -> Any: """Get metadata from the original field, with the option to rename the parameter. @@ -69,7 +67,7 @@ def metadata(self, key: Optional[str] = None, **kwargs: Any) -> Any: value = self.field.metadata(key, **kwargs) if key == self.what: - return self.renaming.get(self.what, {}).get(value, value) + return self.renaming.get(value, value) return value @@ -179,7 +177,7 @@ def __repr__(self) -> str: @legacy_filter(__file__) -def execute(context: Any, input: ekd.FieldList, what: str = "param", **kwargs: Any) -> ekd.FieldList: +def execute(context: Any, input: ekd.FieldList, **kwargs: Any) -> ekd.FieldList: """Rename fields based on the value of another field or a format string. Parameters @@ -188,8 +186,6 @@ def execute(context: Any, input: ekd.FieldList, what: str = "param", **kwargs: A The context in which the function is executed. input : List[Any] List of input fields. - what : str, optional - The field to be used for renaming. Defaults to "param". **kwargs : Any Additional keyword arguments for renaming. @@ -199,7 +195,13 @@ def execute(context: Any, input: ekd.FieldList, what: str = "param", **kwargs: A Array of renamed fields. 
""" - if what in kwargs and isinstance(kwargs[what], str): - return FieldArray([RenamedFieldFormat(fs, what, kwargs[what]) for fs in input]) + for k, v in kwargs.items(): - return FieldArray([RenamedFieldMapping(fs, what, kwargs) for fs in input]) + if not isinstance(v, dict): + input = [RenamedFieldMapping(fs, k, v) for fs in input] + elif isinstance(v, str): + input = [RenamedFieldFormat(fs, k, v) for fs in input] + else: + raise ValueError("Invalid renaming dictionary. Values must be strings or dictionaries.") + + return FieldArray(input) From 6d23027d3668fa0394ce6c4e17799f828199c2d9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 11:45:59 +0000 Subject: [PATCH 10/79] typo --- src/anemoi/datasets/create/filters/rename.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/create/filters/rename.py b/src/anemoi/datasets/create/filters/rename.py index 4c87bd4bb..5905d35b2 100644 --- a/src/anemoi/datasets/create/filters/rename.py +++ b/src/anemoi/datasets/create/filters/rename.py @@ -197,11 +197,11 @@ def execute(context: Any, input: ekd.FieldList, **kwargs: Any) -> ekd.FieldList: for k, v in kwargs.items(): - if not isinstance(v, dict): + if isinstance(v, dict): input = [RenamedFieldMapping(fs, k, v) for fs in input] elif isinstance(v, str): input = [RenamedFieldFormat(fs, k, v) for fs in input] else: - raise ValueError("Invalid renaming dictionary. Values must be strings or dictionaries.") + raise ValueError(f"Invalid renaming dictionary. Values must be strings or dictionaries. 
({type(v)})") return FieldArray(input) From 9179daebf86230785d5e1102f9578b1c640909a2 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 17:52:28 +0100 Subject: [PATCH 11/79] add command line --- .gitignore | 1 + src/anemoi/datasets/commands/recipe.py | 39 ++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 src/anemoi/datasets/commands/recipe.py diff --git a/.gitignore b/.gitignore index 4b70d9e82..158ba7bdd 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,4 @@ Untitled-*.py *.db *.tgz _api/ +trace.txt diff --git a/src/anemoi/datasets/commands/recipe.py b/src/anemoi/datasets/commands/recipe.py new file mode 100644 index 000000000..b028969f6 --- /dev/null +++ b/src/anemoi/datasets/commands/recipe.py @@ -0,0 +1,39 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from typing import Any + +from . import Command + +LOG = logging.getLogger(__name__) + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. 
+ """ + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + from anemoi.datasets.create import config_to_python + + print(config_to_python(args.path)) + + +command = Recipe From 79a391be1741fffe9bb960590a4a43eef0217050 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 11 May 2025 05:34:18 +0100 Subject: [PATCH 12/79] update --- src/anemoi/datasets/create/__init__.py | 8 +- src/anemoi/datasets/create/config.py | 2 + src/anemoi/datasets/create/input/__init__.py | 9 ++ src/anemoi/datasets/create/input/concat.py | 16 +++ .../datasets/create/input/data_sources.py | 10 ++ src/anemoi/datasets/dates/__init__.py | 11 ++ src/anemoi/datasets/recipe.py | 103 +++++++++++++----- 7 files changed, 130 insertions(+), 29 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 34504e974..0aa8fc8e4 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -11,6 +11,7 @@ import json import logging import os +import re import time import uuid import warnings @@ -1634,9 +1635,12 @@ def config_to_python(config: Any) -> Any: config = loader_config(config) input = build_input_(config, build_output(config.output, None)) - code = input.to_python() + code1 = input.python_prelude() + code2 = input.to_python() - code = f"from anemoi.datasets.recipe import Recipe;r = Recipe();r.input = {code}; r.dump()" + code = f"from anemoi.datasets.recipe import Recipe;r = Recipe();{code1};r.input = {code2}; r.dump()" + + code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) try: import black diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 1042709a6..6de4a06cd 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -282,6 +282,8 @@ def __init__(self, config: dict, *args, **kwargs): self.output.order_by = normalize_order_by(self.output.order_by) + 
self.setdefault("dates", Config()) + self.dates["group_by"] = self.build.group_by ########### diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 734175698..b0b23f1b2 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -81,8 +81,17 @@ def to_python(self) -> str: context = ActionContext(**self.kwargs) action = action_factory(self.config, context, self.action_path) + return action.to_python() + def python_prelude(self) -> str: + from .action import ActionContext + from .action import action_factory + + context = ActionContext(**self.kwargs) + action = action_factory(self.config, context, self.action_path) + return action.python_prelude() + def __repr__(self) -> str: """Return a string representation of the InputBuilder. diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py index 5399bbc1f..777610a18 100644 --- a/src/anemoi/datasets/create/input/concat.py +++ b/src/anemoi/datasets/create/input/concat.py @@ -162,3 +162,19 @@ def select(self, group_of_dates: GroupOfDates) -> Union[ConcatResult, EmptyResul return EmptyResult(self.context, self.action_path, group_of_dates) return ConcatResult(self.context, self.action_path, group_of_dates, results) + + def to_python(self) -> str: + """Returns the Python representation of the ConcatAction instance. + + Returns + ------- + str + The Python representation of the ConcatAction instance. 
+ """ + + result = [] + + for i, (filtering_dates, action) in enumerate(self.parts): + result.append(f"{filtering_dates.to_python()}:{action.to_python()}") + + return f"r.concat({{{','.join(result)})" diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index b95f85568..d1ca2dc6e 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -58,6 +58,7 @@ def __init__( self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs] self.input = action_factory(input, context, ["input"]) + self.names = [a_path for a_path, config in configs] def select(self, group_of_dates: GroupOfDates) -> "DataSourcesResult": """Selects the data sources result for the given group of dates. @@ -86,6 +87,15 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) + def python_prelude(self) -> str: + result = [] + for n, s in zip(self.names, self.sources): + result.append(f"{n}={s.to_python()}") + return ";".join(result) + + def to_python(self) -> str: + return self.input.to_python() + class DataSourcesResult(Result): """Class to represent the result of data sources actions in the dataset creation process.""" diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 99771faf1..9570d381f 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -273,6 +273,17 @@ def as_dict(self) -> Dict[str, Any]: "frequency": frequency_to_string(self.frequency), }.update(self.kwargs) + def to_python(self) -> str: + """Convert the StartEndDates instance to a Python string. + + Returns + ------- + str + Python string representation of the instance. 
+ """ + # assert self.frequency == frequency_to_timedelta(1), self.frequency + return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) + class Hindcast: """Class representing a single hindcast date. diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 5594682dd..2cbef29f7 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -11,6 +11,7 @@ import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry +from anemoi.utils.config import DotDict from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry from anemoi.datasets.create.sources import source_registry @@ -165,9 +166,17 @@ def __call__(self, *args, **kwargs): class Recipe: - def __init__(self): - self.description = None + def __init__(self, name=None, description=None, attribution=None, licence=None): + + self._description = description + self._attribution = attribution + self._licence = licence + self._name = name + self.input = Join() + self.output = DotDict() + self.statistics = DotDict() + self.build = DotDict() sources = source_registry.factories.copy() filters = transform_filter_registry.factories.copy() @@ -197,12 +206,25 @@ def __init__(self): setattr(self, key, FilterMaker(key, factory)) def dump(self): + result = self.as_dict(self) + result["input"] = self.input.as_dict(self) + result["output"] = self.description + + print(yaml.safe_dump(result)) + + def as_dict(self): result = { + "name": self.name, "description": self.description, - "input": self.input.as_dict(self), + "attribution": self.attribution, + "licence": self.licence, } - print(yaml.safe_dump(result)) + for k, v in list(result.items()): + if v is None: + del result[k] + + return result def concat(self, *args, **kwargs): return Concat(*args, **kwargs) @@ -248,36 +270,63 @@ def resolve(self, source, target): target = ".".join(s.name for s in path_to_target) return f"${{{target}}}" + @property 
+ def description(self): + return self._description -if __name__ == "__main__": + @description.setter + def description(self, value): + self._description = value - if False: - r = Recipe() - r.description = "test" + @property + def attribution(self): + return self._attribution + + @attribution.setter + def attribution(self, value): + self._attribution = value + + @property + def licence(self): + return self._licence + + @licence.setter + def licence(self, value): + self._licence = value + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._name = value + + +if __name__ == "__main__": - m1 = r.mars(expver="0001") - m2 = r.mars(expver="0002") - m3 = r.mars(expver="0003") + r = Recipe() + r.description = "test" - r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) + m1 = r.mars(expver="0001") + m2 = r.mars(expver="0002") + m3 = r.mars(expver="0003") - r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) + r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) - m0 = r.mars(expver="0000") - c = r.concat( - { - ("1900", "2000"): m0, - ("2001", "2020"): r.mars(expver="0002"), - ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - }, - ) + r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) - c[("2031", "2033")] = r.mars(expver="0005") + m0 = r.mars(expver="0000") + c = r.concat( + { + ("1900", "2000"): m0, + ("2001", "2020"): r.mars(expver="0002"), + ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + }, + ) - r.input += c + c[("2031", "2033")] = r.mars(expver="0005") - r.dump() - else: - from anemoi.datasets.create import config_to_python + r.input += c - print(config_to_python("x.yaml")) + r.dump() From 203e09bd5bbef3a06c214cfc975db4bdde9ff469 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 11 May 2025 05:44:35 +0100 
Subject: [PATCH 13/79] update --- src/anemoi/datasets/create/__init__.py | 2 +- src/anemoi/datasets/recipe.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 0aa8fc8e4..0b0cb18ba 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1638,7 +1638,7 @@ def config_to_python(config: Any) -> Any: code1 = input.python_prelude() code2 = input.to_python() - code = f"from anemoi.datasets.recipe import Recipe;r = Recipe();{code1};r.input = {code2}; r.dump()" + code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 2cbef29f7..c1b8bf8f7 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -206,9 +206,9 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): setattr(self, key, FilterMaker(key, factory)) def dump(self): - result = self.as_dict(self) + result = self.as_dict() result["input"] = self.input.as_dict(self) - result["output"] = self.description + # result["output"] = self.description print(yaml.safe_dump(result)) From b4433bd35667b4238e4724490c589bb637276212 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 11 May 2025 10:11:34 +0100 Subject: [PATCH 14/79] update --- src/anemoi/datasets/create/__init__.py | 8 +- src/anemoi/datasets/create/input/__init__.py | 4 +- src/anemoi/datasets/create/input/action.py | 10 +- src/anemoi/datasets/create/input/concat.py | 4 + .../datasets/create/input/data_sources.py | 7 +- src/anemoi/datasets/create/input/filter.py | 3 + src/anemoi/datasets/create/input/function.py | 3 + src/anemoi/datasets/create/input/join.py | 4 + src/anemoi/datasets/create/input/pipe.py | 4 + src/anemoi/datasets/recipe.py | 101 ++++++++++++++---- 
10 files changed, 120 insertions(+), 28 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 0b0cb18ba..e50695f4e 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1635,7 +1635,11 @@ def config_to_python(config: Any) -> Any: config = loader_config(config) input = build_input_(config, build_output(config.output, None)) - code1 = input.python_prelude() + + prelude = [] + input.python_prelude(prelude) + code1 = "\n".join(prelude) + code2 = input.to_python() code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" @@ -1646,6 +1650,6 @@ def config_to_python(config: Any) -> Any: import black return black.format_str(code, mode=black.Mode()) - except ImportError: + except Exception: LOG.warning("Black not installed, skipping formatting") return code diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index b0b23f1b2..66266d53d 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -84,13 +84,13 @@ def to_python(self) -> str: return action.to_python() - def python_prelude(self) -> str: + def python_prelude(self, prelude) -> str: from .action import ActionContext from .action import action_factory context = ActionContext(**self.kwargs) action = action_factory(self.config, context, self.action_path) - return action.python_prelude() + return action.python_prelude(prelude) def __repr__(self) -> str: """Return a string representation of the InputBuilder. 
diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 4cadff1c1..6057106ef 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -288,9 +288,15 @@ def action_factory(config: Dict[str, Any], context: ActionContext, action_path: assert isinstance(context, Context), (type, context) if not isinstance(config, dict): raise ValueError(f"Invalid input config {config}") + if len(config) != 1: - print(json.dumps(config, indent=2, default=str)) - raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}") + if "label" in config: + config.pop("label") + if "name" in config: + config.pop("name") + if len(config) != 1: + print(json.dumps(config, indent=2, default=str)) + raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}") config = deepcopy(config) key = list(config.keys())[0] diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py index 777610a18..cef2a64c5 100644 --- a/src/anemoi/datasets/create/input/concat.py +++ b/src/anemoi/datasets/create/input/concat.py @@ -178,3 +178,7 @@ def to_python(self) -> str: result.append(f"{filtering_dates.to_python()}:{action.to_python()}") return f"r.concat({{{','.join(result)})" + + def python_prelude(self, prelude) -> None: + for filtering_dates, action in self.parts: + action.python_prelude(prelude) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index d1ca2dc6e..42c610315 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -87,11 +87,10 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) - def python_prelude(self) -> str: - result = [] + def python_prelude(self, prelude) -> str: for n, s in zip(self.names, 
self.sources): - result.append(f"{n}={s.to_python()}") - return ";".join(result) + self.sources.python_prelude(prelude) + prelude.append(f"{n}={s.to_python()}") def to_python(self) -> str: return self.input.to_python() diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py index 298cead41..9357d2178 100644 --- a/src/anemoi/datasets/create/input/filter.py +++ b/src/anemoi/datasets/create/input/filter.py @@ -128,3 +128,6 @@ def to_python(self) -> str: The converted Python code. """ return self._to_python(self.name, self.config) + + def python_prelude(self, prelude) -> None: + pass diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py index bb33a678a..651b509b2 100644 --- a/src/anemoi/datasets/create/input/function.py +++ b/src/anemoi/datasets/create/input/function.py @@ -117,6 +117,9 @@ def to_python(self) -> str: return self._to_python(self.name, self.config) + def python_prelude(self, prelude) -> str: + pass + @trace_select def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": """Selects the function result for the given group of dates. 
diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py index 761922271..9fc81eac9 100644 --- a/src/anemoi/datasets/create/input/join.py +++ b/src/anemoi/datasets/create/input/join.py @@ -110,6 +110,10 @@ def __init__(self, context: object, action_path: list, *configs: dict) -> None: def to_python(self) -> None: return "(" + " + ".join([i.to_python() for i in self.actions]) + ")" + def python_prelude(self, prelude) -> None: + for i in self.actions: + i.python_prelude(prelude) + def __repr__(self) -> str: """Returns a string representation of the JoinAction instance.""" content: str = "\n".join([str(i) for i in self.actions]) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py index e2fda03e4..7b8f062f3 100644 --- a/src/anemoi/datasets/create/input/pipe.py +++ b/src/anemoi/datasets/create/input/pipe.py @@ -71,3 +71,7 @@ def __repr__(self) -> str: def to_python(self) -> str: return "(" + " | ".join([i.to_python() for i in self.actions]) + ")" + + def python_prelude(self, prelude) -> None: + for i in self.actions: + i.python_prelude(prelude) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index c1b8bf8f7..3b6bad628 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -8,10 +8,16 @@ # nor does it submit to any jurisdiction. 
import logging +import os +import sys +from tempfile import TemporaryDirectory import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.utils.config import DotDict +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry from anemoi.datasets.create.sources import source_registry @@ -172,6 +178,7 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): self._attribution = attribution self._licence = licence self._name = name + self._dates = None self.input = Join() self.output = DotDict() @@ -205,19 +212,13 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): assert not hasattr(self, key) setattr(self, key, FilterMaker(key, factory)) - def dump(self): - result = self.as_dict() - result["input"] = self.input.as_dict(self) - # result["output"] = self.description - - print(yaml.safe_dump(result)) - def as_dict(self): result = { "name": self.name, "description": self.description, "attribution": self.attribution, "licence": self.licence, + "dates": self.dates, } for k, v in list(result.items()): @@ -302,31 +303,95 @@ def name(self): def name(self, value): self._name = value + @property + def dates(self): + return self._dates + + def _parse_dates(self, value): + + start = None + end = None + frequency = 1 + + if isinstance(value, (list, tuple)): + if len(value) in [2, 3]: + start = value[0] + end = value[1] + + if len(value) == 3: + frequency = frequency_to_string(frequency_to_timedelta(value[2])) + if isinstance(frequency, int): + frequency = f"{frequency}h" + + if start is None or end is None: + raise ValueError(f"Invalid dates {value}") + + if isinstance(frequency, int): + frequency = f"{frequency}h" + + return dict( + start=as_datetime(start), + end=as_datetime(end), + 
frequency=frequency, + ) + + @dates.setter + def dates(self, value): + self._dates = self._parse_dates(value) + + def dump(self, file=sys.stdout): + result = self.as_dict() + result["input"] = self.input.as_dict(self) + + yaml.safe_dump(result, sort_keys=False, indent=2, width=120, stream=file) + + def test(self, output="recipe.zarr"): + from argparse import ArgumentParser + + from anemoi.datasets.commands.create import command + + parser = ArgumentParser() + parser.add_argument("command", help="Command to run") + + cmd = command() + cmd.add_arguments(parser) + + with TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "recipe.yaml") + with open(path, "w") as file: + self.dump(file) + + args = parser.parse_args(["create", path, output, "--overwrite", "--test"]) + cmd.run(args) + if __name__ == "__main__": r = Recipe() r.description = "test" + r.dates = ("1900-01-01", "2023-12-31") + m1 = r.mars(expver="0001") m2 = r.mars(expver="0002") m3 = r.mars(expver="0003") - r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) + r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) # | r.rescale(tp=["mm", "m"]) r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) - m0 = r.mars(expver="0000") - c = r.concat( - { - ("1900", "2000"): m0, - ("2001", "2020"): r.mars(expver="0002"), - ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - }, - ) + # m0 = r.mars(expver="0000") + # c = r.concat( + # { + # ("190", "2000"): m0, + # ("2001", "2020"): r.mars(expver="0002"), + # ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + # }, + # ) - c[("2031", "2033")] = r.mars(expver="0005") + # c[("2031", "2033")] = r.mars(expver="0005") - r.input += c + # r.input += c r.dump() + r.test() From 6f3fdb06148426b4042b2fabc46a97cff22f0785 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 11 May 2025 10:11:58 +0100 Subject: [PATCH 
15/79] update --- src/anemoi/datasets/commands/migrate.py | 73 +++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/anemoi/datasets/commands/migrate.py diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py new file mode 100644 index 000000000..8e409740d --- /dev/null +++ b/src/anemoi/datasets/commands/migrate.py @@ -0,0 +1,73 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from typing import Any + +import yaml + +from . import Command + +LOG = logging.getLogger(__name__) + +ORDER = ("name", "description", "licence", "input", "output", "statistics", "build") +ORDER = {k: i for i, k in enumerate(ORDER)} + + +def order(x: str) -> int: + + if x[0] not in ORDER: + ORDER[x[0]] = len(ORDER) + + return ORDER[x[0]] + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. 
+ """ + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + with open(args.path, "r") as file: + old = yaml.safe_load(file) + + while True: + new = self.migrate(old) + if new == old: + break + old = new + + print(yaml.safe_dump(new, sort_keys=False, indent=2, width=120)) + + def migrate(self, config: dict) -> dict: + result = config.copy() + # if 'loop' in config: + # result.pop('loop') + # # config['loop'] = config['loop'].replace('(', '[').replace(')', ']') + + if "statistics_end" in config.get("output", {}): + result.setdefault("statistics", {}) + result["statistics"]["end"] = result["output"].pop("statistics_end") + + result = {k: v for k, v in sorted(result.items(), key=order)} + + return result + + +command = Recipe From 45365a107b49f7d7b228ee716831a40b4976b0e2 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 12 May 2025 15:44:27 +0100 Subject: [PATCH 16/79] upadte --- src/anemoi/datasets/commands/migrate.py | 106 ++++++++++++++---- src/anemoi/datasets/commands/recipe.py | 9 +- src/anemoi/datasets/create/__init__.py | 3 +- src/anemoi/datasets/create/input/action.py | 20 +++- .../datasets/create/input/data_sources.py | 2 +- .../datasets/create/input/repeated_dates.py | 24 +++- src/anemoi/datasets/recipe.py | 24 +++- 7 files changed, 156 insertions(+), 32 deletions(-) diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index 8e409740d..cca70ffee 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -29,6 +29,88 @@ def order(x: str) -> int: return ORDER[x[0]] +MIGRATE = { + "output.statistics_end": "statistics.end", + "input.dates.<<": "dates", + "input.dates.join": "input.join", + "input.dates": None, + "has_nans": "statistics.allow_nans", + "loop.dates": "dates", + # 'copyright': 'citation', +} + +SOURCES = {"oper-accumulations": "accumulations", "constants": "forcings"} + + +def _move(config, path, 
new_path, result): + path = path.split(".") + if new_path is not None: + new_path = new_path.split(".") + + for k in path[:-1]: + if k not in config: + return + config = config[k] + + if path[-1] not in config: + return + + value = config.pop(path[-1]) + + if new_path is None: + return + + for k in new_path[:-1]: + if k not in result: + result[k] = {} + result = result[k] + + result[new_path[-1]] = value + + +def _migrate(config: dict) -> dict: + result = config.copy() + for k, v in MIGRATE.items(): + _move(config, k, v, result) + + if isinstance(result["input"], list): + join = [] + prev = {} + for n in result["input"]: + assert isinstance(n, dict), (n, type(n)) + assert len(n) == 1, (n, type(n)) + name = list(n.keys())[0] + prev[name] = n[name]["kwargs"] + if "inherit" in n[name]: + i = n[name]["inherit"] + n[name]["kwargs"].update(prev[i]) + n[name].pop("inherit") + + data = n[name]["kwargs"] + + src = data.pop("name", "mars") + + join.append({SOURCES.get(src, src): data}) + + result["input"] = dict(join=join) + + result = {k: v for k, v in sorted(result.items(), key=order) if v} + + return result + + +def migrate(old: dict) -> dict: + # return _migrate(old) + for i in range(10): + new = _migrate(old) + if new == old: + # print(json.dumps(new, indent=2, default=str)) + return new + old = new + + return new + + class Recipe(Command): def add_arguments(self, command_parser: Any) -> None: """Add arguments to the command parser. 
@@ -45,29 +127,9 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: with open(args.path, "r") as file: - old = yaml.safe_load(file) - - while True: - new = self.migrate(old) - if new == old: - break - old = new - - print(yaml.safe_dump(new, sort_keys=False, indent=2, width=120)) - - def migrate(self, config: dict) -> dict: - result = config.copy() - # if 'loop' in config: - # result.pop('loop') - # # config['loop'] = config['loop'].replace('(', '[').replace(')', ']') - - if "statistics_end" in config.get("output", {}): - result.setdefault("statistics", {}) - result["statistics"]["end"] = result["output"].pop("statistics_end") - - result = {k: v for k, v in sorted(result.items(), key=order)} + config = yaml.safe_load(file) - return result + print(yaml.safe_dump(migrate(config), sort_keys=False, indent=2, width=120)) command = Recipe diff --git a/src/anemoi/datasets/commands/recipe.py b/src/anemoi/datasets/commands/recipe.py index b028969f6..3045f1a1b 100644 --- a/src/anemoi/datasets/commands/recipe.py +++ b/src/anemoi/datasets/commands/recipe.py @@ -11,7 +11,10 @@ import logging from typing import Any +import yaml + from . 
import Command +from .migrate import migrate LOG = logging.getLogger(__name__) @@ -33,7 +36,11 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: from anemoi.datasets.create import config_to_python - print(config_to_python(args.path)) + with open(args.path, "r") as file: + config = yaml.safe_load(file) + config = migrate(config) + + print(config_to_python(config)) command = Recipe diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index e50695f4e..b6b51cb69 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1634,6 +1634,7 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An def config_to_python(config: Any) -> Any: config = loader_config(config) + input = build_input_(config, build_output(config.output, None)) prelude = [] @@ -1650,6 +1651,6 @@ def config_to_python(config: Any) -> Any: import black return black.format_str(code, mode=black.Mode()) - except Exception: + except ImportError: LOG.warning("Black not installed, skipping formatting") return code diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 6057106ef..121d1e387 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+import datetime import json import logging import re @@ -15,6 +16,7 @@ from typing import Dict from typing import List +from anemoi.utils.dates import frequency_to_string from earthkit.data.core.order import build_remapping from ...dates.groups import GroupOfDates @@ -160,7 +162,7 @@ def _trace_select(self, group_of_dates: GroupOfDates) -> str: """ return f"{self.__class__.__name__}({group_of_dates})" - def _to_python(self, name: str, config: dict) -> str: + def _to_python(self, name: str, config: dict, **extra: Any) -> str: """Convert the action to Python code. Parameters @@ -169,6 +171,8 @@ def _to_python(self, name: str, config: dict) -> str: The name of the action. config : dict The configuration for the action. + extra : Any + Additional keyword arguments. Returns ------- @@ -205,7 +209,16 @@ def _to_python(self, name: str, config: dict) -> str: "pass", ) - config = json.loads(json.dumps(config)) + def convert(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if isinstance(obj, datetime.date): + return obj.isoformat() + if isinstance(obj, datetime.timedelta): + return frequency_to_string(obj) + raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + + config = json.loads(json.dumps(config, default=convert)) assert len(config) == 1, (name, config) assert name in config, (name, config) @@ -218,6 +231,9 @@ def _to_python(self, name: str, config: dict) -> str: return f"r.{name}({config})" params.append(f"{k}={repr(v)}") + for k, v in extra.items(): + params.append(f"{k}={v}") + params = ",".join(params) return f"r.{name}({params})" # return f"{name}({config})" diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 42c610315..5b6282469 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -89,7 +89,7 @@ def __repr__(self) -> str: def python_prelude(self, prelude) -> str: for n, s in 
zip(self.names, self.sources): - self.sources.python_prelude(prelude) + s.python_prelude(prelude) prelude.append(f"{n}={s.to_python()}") def to_python(self) -> str: diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index 0f5b5730a..bc121c5c9 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -204,6 +204,21 @@ def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) self.day: int = day self.hour: Optional[int] = hour + def to_python(self) -> Dict[str, Any]: + """Convert the DateMapper to Python code. + + Returns + ------- + dict + The Python code representation of the DateMapper. + """ + return { + "mode": "climatology", + "year": self.year, + "day": self.day, + "hour": self.hour, + } + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: """Transform the group of dates to the specified climatology dates. 
@@ -351,11 +366,18 @@ def __init__(self, context: Any, action_path: List[str], source: Any, mode: str, self.source: Any = action_factory(source, context, action_path + ["source"]) self.mapper: DateMapper = DateMapper.from_mode(mode, self.source, kwargs) + self.mode = mode + self.kwargs = kwargs def to_python(self) -> str: """Convert the action to Python code.""" warnings.warn("RepeatedDatesAction.to_python is still a work in progress") - return self.source.to_python() + args = {"mode": self.mode} + args.update(self.kwargs) + return self._to_python("repeated_dates", {"repeated_dates": args}, source=self.source.to_python()) + + def python_prelude(self, prelude: Any) -> None: + self.source.python_prelude(prelude) @trace_select def select(self, group_of_dates: Any) -> JoinResult: diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 3b6bad628..69ce8b7df 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -212,6 +212,8 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): assert not hasattr(self, key) setattr(self, key, FilterMaker(key, factory)) + self.repeated_dates = SourceMaker("repeated_dates", None) + def as_dict(self): result = { "name": self.name, @@ -341,8 +343,18 @@ def dates(self, value): def dump(self, file=sys.stdout): result = self.as_dict() + result["input"] = self.input.as_dict(self) + if self.output: + result["output"] = self.output.as_dict() + + if self.statistics: + result["statistics"] = self.statistics.as_dict() + + if self.build: + result["build"] = self.build.as_dict() + yaml.safe_dump(result, sort_keys=False, indent=2, width=120, stream=file) def test(self, output="recipe.zarr"): @@ -370,15 +382,15 @@ def test(self, output="recipe.zarr"): r = Recipe() r.description = "test" - r.dates = ("1900-01-01", "2023-12-31") + r.dates = ("2023-01-01 00:00:00", "2023-12-31 18:00:00", "6h") - m1 = r.mars(expver="0001") + m1 = r.mars(expver="0001", grid=[20, 20]) m2 = 
r.mars(expver="0002") m3 = r.mars(expver="0003") - r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) # | r.rescale(tp=["mm", "m"]) + r.input = m1 - r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) + r.input += r.forcings(template=m1, param=["cos_latitude", "sin_latitude"]) # m0 = r.mars(expver="0000") # c = r.concat( @@ -393,5 +405,9 @@ def test(self, output="recipe.zarr"): # r.input += c + r.output.group_by = "day" + r.build.additions = True + r.statistics.end = "80%" + r.dump() r.test() From d7cc82ca6c579bafd830d857178d703c6d9589b7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 12 May 2025 16:12:35 +0100 Subject: [PATCH 17/79] update --- src/anemoi/datasets/commands/migrate.py | 43 ++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index cca70ffee..8dbd2ca7b 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -35,11 +35,17 @@ def order(x: str) -> int: "input.dates.join": "input.join", "input.dates": None, "has_nans": "statistics.allow_nans", + "loop.dates.group_by": "build.group_by", "loop.dates": "dates", - # 'copyright': 'citation', + "copyright": "attribution", } -SOURCES = {"oper-accumulations": "accumulations", "constants": "forcings"} +SOURCES = { + "oper-accumulations": "accumulations", + "era5-accumulations": "accumulations", + "constants": "forcings", + "ensemble-perturbations": "recentre", +} def _move(config, path, new_path, result): @@ -68,12 +74,13 @@ def _move(config, path, new_path, result): result[new_path[-1]] = value -def _migrate(config: dict) -> dict: +def _migrate(config: dict, n) -> dict: result = config.copy() for k, v in MIGRATE.items(): _move(config, k, v, result) if isinstance(result["input"], list): + assert n == 0 join = [] prev = {} for n in result["input"]: @@ -94,6 +101,34 @@ def _migrate(config: dict) -> dict: result["input"] = 
dict(join=join) + if "join" in result["input"] and n == 0: + join = result["input"].pop("join") + new_join = [] + + for j in join: + + if "label" in j: + j = j["label"] + + if "source" in j: + j = j["source"] + + src = j.pop("name", "mars") + data = j + if "<<" in data: + data.update(data.pop("<<")) + + for k, v in list(data.items()): + if k in ("date", "time"): + if isinstance(v, str) and v.startswith("$"): + del data[k] + + new_join.append({SOURCES.get(src, src): data}) + + print(new_join) + + result["input"]["join"] = new_join + result = {k: v for k, v in sorted(result.items(), key=order) if v} return result @@ -102,7 +137,7 @@ def _migrate(config: dict) -> dict: def migrate(old: dict) -> dict: # return _migrate(old) for i in range(10): - new = _migrate(old) + new = _migrate(old, i) if new == old: # print(json.dumps(new, indent=2, default=str)) return new From f381b00a87e9039f71d953331d80157aba0bd6c1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 12 May 2025 16:49:27 +0100 Subject: [PATCH 18/79] update --- src/anemoi/datasets/commands/migrate.py | 20 +++++++++++++++++--- src/anemoi/datasets/create/input/concat.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index 8dbd2ca7b..ec6d2749a 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -108,7 +108,12 @@ def _migrate(config: dict, n) -> dict: for j in join: if "label" in j: - j = j["label"] + if isinstance(j["label"], str): + j.pop("label") + else: + if j["label"] is not None: + j = j["label"] + j.pop("name", None) if "source" in j: j = j["source"] @@ -125,12 +130,21 @@ def _migrate(config: dict, n) -> dict: new_join.append({SOURCES.get(src, src): data}) - print(new_join) - result["input"]["join"] = new_join + if "join" in result["input"]: + for j in result["input"]["join"]: + k = list(j.keys())[0] + j[k].pop("name", None) + + if "source_or_dataset" 
in j[k]: + j[k].pop("source_or_dataset", None) + j[k]["template"] = "${input.0.join.0.mars}" + result = {k: v for k, v in sorted(result.items(), key=order) if v} + result.pop("loop", None) + return result diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py index cef2a64c5..bd906bd03 100644 --- a/src/anemoi/datasets/create/input/concat.py +++ b/src/anemoi/datasets/create/input/concat.py @@ -177,7 +177,7 @@ def to_python(self) -> str: for i, (filtering_dates, action) in enumerate(self.parts): result.append(f"{filtering_dates.to_python()}:{action.to_python()}") - return f"r.concat({{{','.join(result)})" + return f"r.concat({{{','.join(result)}}})" def python_prelude(self, prelude) -> None: for filtering_dates, action in self.parts: From da93ad651ce742475549c31024dd208470fc2906 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 14 May 2025 14:10:40 +0100 Subject: [PATCH 19/79] update --- .gitignore | 1 + src/anemoi/datasets/commands/migrate.py | 270 +++++++++++++++++++----- 2 files changed, 221 insertions(+), 50 deletions(-) diff --git a/.gitignore b/.gitignore index 158ba7bdd..05db8afa7 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,4 @@ Untitled-*.py *.tgz _api/ trace.txt +?/ diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index ec6d2749a..3183508b7 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -9,12 +9,14 @@ import logging +from copy import deepcopy from typing import Any import yaml from . 
import Command +errors = [] LOG = logging.getLogger(__name__) ORDER = ("name", "description", "licence", "input", "output", "statistics", "build") @@ -31,9 +33,6 @@ def order(x: str) -> int: MIGRATE = { "output.statistics_end": "statistics.end", - "input.dates.<<": "dates", - "input.dates.join": "input.join", - "input.dates": None, "has_nans": "statistics.allow_nans", "loop.dates.group_by": "build.group_by", "loop.dates": "dates", @@ -45,6 +44,9 @@ def order(x: str) -> int: "era5-accumulations": "accumulations", "constants": "forcings", "ensemble-perturbations": "recentre", + "ensemble_perturbations": "recentre", + "perturbations": "recentre", + "custom-regrid": "regrid", } @@ -74,72 +76,238 @@ def _move(config, path, new_path, result): result[new_path[-1]] = value +def _fix_dates(result, config): + dates = config["input"].pop("dates") + assert "join" in dates, dates + result["input"] = dates["join"] + config["input"] = result["input"].copy() + + +def _fix_list(result, config): + result["input"] = dict(join=result["input"]) + config["input"] = result["input"].copy() + + +def _fix_join_0(result, config): + + join = config["input"]["join"] + + new_join = [] + for n in join: + + if "function" in n: + f = n["function"] + name = f.pop("name") + data = _tidy(f) + for k, v in list(data.items()): + if isinstance(v, dict): + if "name" in v: + p = v.pop("name") + data[k] = {SOURCES.get(p, p): _tidy(v)} + + new_join.append({SOURCES.get(name, name): data}) + continue + + new_join.append(n) # {SOURCES.get(src, src): _tidy(data)}) + + result["input"] = dict(join=new_join) + config["input"] = result["input"].copy() + + +def _fix_join_1(result, config): + + join = config["input"].pop("join") + new_join = [] + for n in join: + if isinstance(n, dict): + if len(n) == 1: + if "label" in n: + n = n["label"] + + if isinstance(n, dict): + if len(n) == 2: + if "name" in n and "source" in n: + n.pop("name") + + if isinstance(n, dict): + if len(n) == 1: + if "source" in n: + n = n["source"] 
+ if "<<" in n: + n.update(n.pop("<<")) + name = n.pop("name", "mars") + new_join.append({SOURCES.get(name, name): _tidy(n)}) + continue + + new_join.append(n) + + result["input"] = dict(join=new_join) + config["input"] = result["input"].copy() + + +def _fix_join_3(result, config): + + join = config["input"].pop("join") + new_join = [] + for n in join: + if not isinstance(n, dict): + return + if len(n) != 1: + return + + name = list(n.keys())[0] + data = n[name] + + new_join.append({SOURCES.get(name, name): data}) + + result["input"] = dict(join=new_join) + config["input"] = result["input"].copy() + + +def _tidy(data): + for k, v in list(data.items()): + if k in ("date", "time"): + if isinstance(v, str) and v.startswith("$"): + del data[k] + + if "name" in data: + assert False, data + name = data.pop("name") + return {SOURCES.get(name, name): _tidy(data)} + + return data + + +def _fix_join_2(result, config): + + join = config["input"]["join"] + + previous = {} + + new_join = [] + for n in join: + + if not isinstance(n, dict): + return + + if len(n) != 1: + return + + what = list(n.keys())[0] + + if n[what] is None: + assert False, (n, what, config["input"]) + + if "kwargs" not in n[what]: + return + + # assert False + + previous[what] = deepcopy(n[what]["kwargs"]) + if "inherit" in n[what]: + previous[what].update(deepcopy(previous[n[what]["inherit"]])) + + data = previous[what].copy() + src = data.pop("name", "mars") + + new_join.append({SOURCES.get(src, src): _tidy(data)}) + + result["input"] = dict(join=new_join) + config["input"] = result["input"].copy() + + def _migrate(config: dict, n) -> dict: result = config.copy() for k, v in MIGRATE.items(): _move(config, k, v, result) - if isinstance(result["input"], list): - assert n == 0 - join = [] - prev = {} - for n in result["input"]: - assert isinstance(n, dict), (n, type(n)) - assert len(n) == 1, (n, type(n)) - name = list(n.keys())[0] - prev[name] = n[name]["kwargs"] - if "inherit" in n[name]: - i = 
n[name]["inherit"] - n[name]["kwargs"].update(prev[i]) - n[name].pop("inherit") + if "dates" in config["input"]: + _fix_dates(result, config) - data = n[name]["kwargs"] + if isinstance(config["input"], list): + _fix_list(result, config) - src = data.pop("name", "mars") + if "join" in config["input"]: + _fix_join_0(result, config) - join.append({SOURCES.get(src, src): data}) + if "join" in config["input"]: + _fix_join_1(result, config) - result["input"] = dict(join=join) + if "join" in config["input"]: + _fix_join_2(result, config) - if "join" in result["input"] and n == 0: - join = result["input"].pop("join") - new_join = [] + if "join" in config["input"]: + _fix_join_3(result, config) - for j in join: + # _check(result, "1") - if "label" in j: - if isinstance(j["label"], str): - j.pop("label") - else: - if j["label"] is not None: - j = j["label"] - j.pop("name", None) + # if isinstance(result["input"], list): + # assert n == 0 + # join = [] + # prev = {} + # for n in result["input"]: + # assert isinstance(n, dict), (n, type(n)) + # assert len(n) == 1, (n, type(n)) + # name = list(n.keys())[0] + # prev[name] = n[name]["kwargs"] + # if "inherit" in n[name]: + # i = n[name]["inherit"] + # n[name]["kwargs"].update(prev[i]) + # n[name].pop("inherit") - if "source" in j: - j = j["source"] + # data = n[name]["kwargs"] - src = j.pop("name", "mars") - data = j - if "<<" in data: - data.update(data.pop("<<")) + # src = data.pop("name", "mars") - for k, v in list(data.items()): - if k in ("date", "time"): - if isinstance(v, str) and v.startswith("$"): - del data[k] + # join.append({SOURCES.get(src, src): data}) + + # result["input"] = dict(join=join) + # _check(result, "2") - new_join.append({SOURCES.get(src, src): data}) + # if "join" in result["input"] and n == 0: + # join = result["input"].pop("join") + # new_join = [] - result["input"]["join"] = new_join + # for j in join: - if "join" in result["input"]: - for j in result["input"]["join"]: - k = list(j.keys())[0] - 
j[k].pop("name", None) + # if "label" in j: + # if isinstance(j["label"], str): + # j.pop("label") + # else: + # if j["label"] is not None: + # j = j["label"] + # j.pop("name", None) - if "source_or_dataset" in j[k]: - j[k].pop("source_or_dataset", None) - j[k]["template"] = "${input.0.join.0.mars}" + # if "source" in j: + # j = j["source"] + + # src = j.pop("name", "mars") + # data = j + # if "<<" in data: + # data.update(data.pop("<<")) + + # for k, v in list(data.items()): + # if k in ("date", "time"): + # if isinstance(v, str) and v.startswith("$"): + # del data[k] + + # if "mars" in data: + # new_join.append(data) + # else: + # new_join.append({SOURCES.get(src, src): data}) + + # result["input"]["join"] = new_join + # _check(result, "3") + + # if "join" in result["input"]: + # for j in result["input"]["join"]: + # k = list(j.keys())[0] + # j[k].pop("name", None) + + # if "source_or_dataset" in j[k]: + # j[k].pop("source_or_dataset", None) + # j[k]["template"] = "${input.0.join.0.mars}" + # _check(result, "4") result = {k: v for k, v in sorted(result.items(), key=order) if v} @@ -180,5 +348,7 @@ def run(self, args: Any) -> None: print(yaml.safe_dump(migrate(config), sort_keys=False, indent=2, width=120)) + assert not errors, f"Errors: {errors}" + command = Recipe From 3082edfa84a3a479747768ef57eac77df48bfb07 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 09:45:02 +0000 Subject: [PATCH 20/79] refactor missing --- .gitignore | 3 + src/anemoi/datasets/create/__init__.py | 3 + src/anemoi/datasets/create/input/__init__.py | 15 +- src/anemoi/datasets/create/input/action.py | 388 +++++++----------- .../datasets/create/sources/__init__.py | 11 + src/anemoi/datasets/create/sources/legacy.py | 6 +- 6 files changed, 174 insertions(+), 252 deletions(-) diff --git a/.gitignore b/.gitignore index f3777cde2..746409a43 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,6 @@ _version.py *.to_upload tempCodeRunnerFile.python Untitled-*.py +*.prof 
+prof/ +*.gz diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 8d74975c5..9c8c4613a 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -21,6 +21,7 @@ import cftime import numpy as np +import rich import tqdm import zarr from anemoi.utils.dates import as_datetime @@ -671,6 +672,8 @@ def _run(self) -> int: LOG.info(f"Missing dates: {len(missing)}") lengths = tuple(len(g) for g in self.groups) + rich.print("Minimal input dates:", self.minimal_input) + variables = self.minimal_input.variables LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index c64f8275b..9f30f1abc 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -12,6 +12,8 @@ from typing import Any from typing import Union +import rich + from anemoi.datasets.dates.groups import GroupOfDates from .trace import trace_select @@ -22,7 +24,11 @@ class Context: """Context for building input data.""" - pass + use_grib_paramid = False + + def trace(self, emoji, message) -> None: + + rich.print(f"{emoji}: {message}") class InputBuilder: @@ -67,13 +73,12 @@ def select(self, group_of_dates: GroupOfDates) -> Any: Any Selected data. """ - from .action import ActionContext from .action import action_factory """This changes the context.""" - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.select(group_of_dates) + context = Context() + action = action_factory(self.config, self.action_path) + return action(context, group_of_dates) def __repr__(self) -> str: """Return a string representation of the InputBuilder. 
diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index eadf01339..7c671db2f 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -7,253 +7,153 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import json import logging -from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from earthkit.data.core.order import build_remapping +LOG = logging.getLogger(__name__) -from ...dates.groups import GroupOfDates -from .context import Context -from .template import substitute -LOG = logging.getLogger(__name__) +class Predicate: + def __init__(self, config): + self.config = config + + def __repr__(self): + return f"Predicate({self.config})" + + def match(self, dates): + # Just a demo + return True + + +class Concat: + def __init__(self, config): + assert isinstance(config, list), f"Value must be a dict {list}" + + self.choices = [] + + for item in config: + + assert "dates" in item, f"Value must contain the key 'date' {item}" + predicate = Predicate(item.pop("dates")) + action = action_factory(item) + + self.choices.append((predicate, action)) + + def __repr__(self): + return f"Concat({self.choices})" + + def __call__(self, context, group_of_dates): + + for predicate, action in self.choices: + if predicate.match(group_of_dates): + return action(group_of_dates) + + raise ValueError(f"No matching predicate for dates: {group_of_dates}") + + +class Join: + def __init__(self, config): + assert isinstance(config, list), f"Value must be a list {config}" + self.actions = [action_factory(item) for item in config] + + def __repr__(self): + return f"Join({self.actions})" + + def __call__(self, context, group_of_dates): + results = [] + for action in self.actions: + results.append(action(context, group_of_dates)) + return results + + +class Pipe: + def __init__(self, 
config): + assert isinstance(config, list), f"Value must be a list {config}" + self.actions = [action_factory(item) for item in config] + + def __repr__(self): + return f"Pipe({self.actions})" + + def __call__(self, context, dates): + result = None + for action in self.actions: + if result is None: + result = action(dates) + else: + result = action(result) + return result + + +class Function: + def __init__(self, config): + self.config = config + + # # if self._source: + # # self.source = action_factory(config[self._source]) + # # else: + # # self.source = None + + # def __repr__(self): + # return f"{self.__class__.__name__}({self.config})" + + # def __call__(self, context, dates): + # # Just a demo, in real case it would do something with the dates + + # config = self.config.copy() + # if self.source: + # config[self._source] = self.source(dates) + + # return {self.__class__.__name__: self.config, "dates": dates} + + +class SourceFunction(Function): + def __init__(self, config): + from anemoi.datasets.create.sources import create_source + + super().__init__(config) + config["_type"] = self.name + self.source = create_source(self, config) + + def __call__(self, context, group_of_dates): + return self.source.execute(context, group_of_dates.dates) + + +class FilterFunction(Function): + pass + + +def new_source(name, source=None): + return type(name.title(), (SourceFunction,), {"name": name}) + + +def new_filter(name, source=None): + return type(name.title(), (FilterFunction,), {"name": name}) + + +KLASS = { + "concat": Concat, + "join": Join, + "pipe": Pipe, +} + +LEN_KLASS = len(KLASS) + + +def make(key, config): + + if LEN_KLASS == len(KLASS): + # Load pluggins + from anemoi.datasets.create.sources import registered_sources + + for name in registered_sources(): + assert name not in KLASS, f"Duplicate source name: {name}" + KLASS[name] = new_source(name) + + return KLASS[key](config) + +def action_factory(data, *path): + assert isinstance(data, dict), f"Input data 
must be a dictionary {data}" + assert len(data) == 1, "Input data must contain exactly one key-value pair" -class Action: - """Represents an action to be performed within a given context. - - Attributes - ---------- - context : ActionContext - The context in which the action exists. - kwargs : Dict[str, Any] - Additional keyword arguments. - args : Any - Additional positional arguments. - action_path : List[str] - The action path. - """ - - def __init__( - self, context: "ActionContext", action_path: List[str], /, *args: Any, **kwargs: Dict[str, Any] - ) -> None: - """Initialize an Action instance. - - Parameters - ---------- - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - args : Any - Additional positional arguments. - kwargs : Dict[str, Any] - Additional keyword arguments. - """ - if "args" in kwargs and "kwargs" in kwargs: - """We have: - args = [] - kwargs = {args: [...], kwargs: {...}} - move the content of kwargs to args and kwargs. - """ - assert len(kwargs) == 2, (args, kwargs) - assert not args, (args, kwargs) - args = kwargs.pop("args") - kwargs = kwargs.pop("kwargs") - - assert isinstance(context, ActionContext), type(context) - self.context = context - self.kwargs = kwargs - self.args = args - self.action_path = action_path - - @classmethod - def _short_str(cls, x: str) -> str: - """Shorten the string representation if it exceeds 1000 characters. - - Parameters - ---------- - x : str - The string to shorten. - - Returns - ------- - str - The shortened string. - """ - x = str(x) - if len(x) < 1000: - return x - return x[:1000] + "..." - - def _repr(self, *args: Any, _indent_: str = "\n", _inline_: str = "", **kwargs: Any) -> str: - """Generate a string representation of the Action instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str, optional - The indentation string, by default "\n". 
- _inline_ : str, optional - The inline string, by default "". - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - more = more[:5000] - txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Action instance. - - Returns - ------- - str - The string representation. - """ - return self._repr() - - def select(self, dates: object, **kwargs: Any) -> None: - """Select dates for the action. - - Parameters - ---------- - dates : object - The dates to select. - kwargs : Any - Additional keyword arguments. - """ - self._raise_not_implemented() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the selection of a group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates to trace. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({group_of_dates})" - - -class ActionContext(Context): - """Represents the context in which an action is performed. - - Attributes - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. - """ - - def __init__(self, /, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool) -> None: - """Initialize an ActionContext instance. - - Parameters - ---------- - order_by : str - The order by criteria. 
- flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. - """ - super().__init__() - self.order_by = order_by - self.flatten_grid = flatten_grid - self.remapping = build_remapping(remapping) - self.use_grib_paramid = use_grib_paramid - - -def action_factory(config: Dict[str, Any], context: ActionContext, action_path: List[str]) -> Action: - """Factory function to create an Action instance based on the configuration. - - Parameters - ---------- - config : Dict[str, Any] - The action configuration. - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - - Returns - ------- - Action - The created Action instance. - """ - from .concat import ConcatAction - from .data_sources import DataSourcesAction - from .function import FunctionAction - from .join import JoinAction - from .pipe import PipeAction - from .repeated_dates import RepeatedDatesAction - - # from .data_sources import DataSourcesAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - if len(config) != 1: - print(json.dumps(config, indent=2, default=str)) - raise ValueError(f"Invalid input config. 
Expecting dict with only one key, got {list(config.keys())}") - - config = deepcopy(config) - key = list(config.keys())[0] - - if isinstance(config[key], list): - args, kwargs = config[key], {} - elif isinstance(config[key], dict): - args, kwargs = [], config[key] - else: - raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}") - - cls = { - "data_sources": DataSourcesAction, - "data-sources": DataSourcesAction, - "concat": ConcatAction, - "join": JoinAction, - "pipe": PipeAction, - "function": FunctionAction, - "repeated_dates": RepeatedDatesAction, - "repeated-dates": RepeatedDatesAction, - }.get(key) - - if cls is None: - from ..sources import create_source - - source = create_source(None, substitute(context, config)) - return FunctionAction(context, action_path + [key], key, source) - - return cls(context, action_path + [key], *args, **kwargs) + key, value = next(iter(data.items())) + return make(key, value) diff --git a/src/anemoi/datasets/create/sources/__init__.py b/src/anemoi/datasets/create/sources/__init__.py index f8b99f36d..1710d1303 100644 --- a/src/anemoi/datasets/create/sources/__init__.py +++ b/src/anemoi/datasets/create/sources/__init__.py @@ -34,3 +34,14 @@ def create_source(context: Any, config: Any) -> Any: The created source. """ return source_registry.from_config(config, context) + + +def registered_sources() -> list[str]: + """Get a list of registered source names. + + Returns + ------- + list[str] + A list of names of registered sources. 
+ """ + return source_registry.registered diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index d72d0b3f4..f8035ee16 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -71,13 +71,13 @@ def __call__(self, execute: Callable) -> Callable: name = f"Legacy{self.name.title()}Source" source = ".".join([execute.__module__, execute.__name__]) - def execute_wrapper(self, dates) -> Any: + def execute_wrapper(self, context, dates) -> Any: """Wrapper method to call the execute function.""" - args, kwargs = resolve(self.context, (self.args, self.kwargs)) + args, kwargs = resolve(context, (self.args, self.kwargs)) try: - return execute(self.context, dates, *args, **kwargs) + return execute(context, dates, *args, **kwargs) except TypeError: LOG.error(f"Error executing source {this.name} from {source}") LOG.error(f"Function signature is: {inspect.signature(execute)}") From ef4a5c9507dc8e11822bbbcc8d6e448b18054d17 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 10:19:48 +0000 Subject: [PATCH 21/79] add references --- src/anemoi/datasets/create/input/__init__.py | 65 ++++++--- src/anemoi/datasets/create/input/action.py | 141 +++++++++++-------- 2 files changed, 131 insertions(+), 75 deletions(-) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 9f30f1abc..d561a9fb4 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -26,10 +26,56 @@ class Context: use_grib_paramid = False + def __init__(self): + self.results = {} + def trace(self, emoji, message) -> None: rich.print(f"{emoji}: {message}") + def register(self, data: Any, path: list[str]) -> Any: + """Register data in the context. + + Parameters + ---------- + data : Any + Data to register. + path : list[str] + Path where the data should be registered. 
+ + Returns + ------- + Any + Registered data. + """ + # This is a placeholder for actual registration logic. + rich.print(f"Registering data at path: {path}") + self.results[tuple(path)] = data + return data + + def resolve(self, config): + config = config.copy() + + for key, value in list(config.items()): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + path = tuple(value[2:-1].split(".")) + if path in self.results: + config[key] = self.results[path] + else: + raise KeyError(f"Path {path} not found in results: {self.results.keys()}") + + return config + + +class FieldContext(Context): + def empty_result(self) -> Any: + import earthkit.data as ekd + + return ekd.from_source("empty") + + def source_argument(self, argument: Any) -> Any: + return argument.dates + class InputBuilder: """Builder class for creating input data from configuration and data sources.""" @@ -76,25 +122,10 @@ def select(self, group_of_dates: GroupOfDates) -> Any: from .action import action_factory """This changes the context.""" - context = Context() - action = action_factory(self.config, self.action_path) + context = FieldContext() + action = action_factory(self.config, "input") return action(context, group_of_dates) - def __repr__(self) -> str: - """Return a string representation of the InputBuilder. - - Returns - ------- - str - String representation. - """ - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - a = action_factory(self.config, context, self.action_path) - return repr(a) - def _trace_select(self, group_of_dates: GroupOfDates) -> str: """Trace the select operation. 
diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7c671db2f..8fdbf10be 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -9,6 +9,8 @@ import logging +import rich + LOG = logging.getLogger(__name__) @@ -24,8 +26,16 @@ def match(self, dates): return True -class Concat: - def __init__(self, config): +class Action: + def __init__(self, config, *path): + self.config = config + self.path = path + + +class Concat(Action): + def __init__(self, config, *path): + super().__init__(config, *path) + assert isinstance(config, list), f"Value must be a dict {list}" self.choices = [] @@ -41,104 +51,119 @@ def __init__(self, config): def __repr__(self): return f"Concat({self.choices})" - def __call__(self, context, group_of_dates): + def __call__(self, context, argument): for predicate, action in self.choices: - if predicate.match(group_of_dates): - return action(group_of_dates) + if predicate.match(argument): + return context.register( + action(context, argument), + self.path, + ) - raise ValueError(f"No matching predicate for dates: {group_of_dates}") + raise ValueError(f"No matching predicate for dates: {argument}") -class Join: - def __init__(self, config): +class Join(Action): + def __init__(self, config, *path): + super().__init__(config, *path) + assert isinstance(config, list), f"Value must be a list {config}" - self.actions = [action_factory(item) for item in config] + + self.actions = [ + action_factory( + item, + *path, + "join", + str(i), + ) + for i, item in enumerate(config) + ] def __repr__(self): return f"Join({self.actions})" - def __call__(self, context, group_of_dates): - results = [] + def __call__(self, context, argument): + results = context.empty_result() for action in self.actions: - results.append(action(context, group_of_dates)) - return results + results += action(context, argument) + return context.register( + results, + self.path, + ) -class 
Pipe: - def __init__(self, config): +class Pipe(Action): + def __init__(self, config, *path): assert isinstance(config, list), f"Value must be a list {config}" - self.actions = [action_factory(item) for item in config] + super().__init__(config, *path) + self.actions = [ + action_factory( + item, + *path, + "pipe", + str(i), + ) + for i, item in enumerate(config) + ] def __repr__(self): return f"Pipe({self.actions})" - def __call__(self, context, dates): - result = None - for action in self.actions: - if result is None: - result = action(dates) - else: - result = action(result) - return result - - -class Function: - def __init__(self, config): - self.config = config + def __call__(self, context, argument): + result = context.empty_result() - # # if self._source: - # # self.source = action_factory(config[self._source]) - # # else: - # # self.source = None - - # def __repr__(self): - # return f"{self.__class__.__name__}({self.config})" + for i, action in enumerate(self.actions): + if i == 0: + result = action(context, argument) + else: + result = action(context, result) - # def __call__(self, context, dates): - # # Just a demo, in real case it would do something with the dates + return context.register( + result, + self.path, + ) - # config = self.config.copy() - # if self.source: - # config[self._source] = self.source(dates) - # return {self.__class__.__name__: self.config, "dates": dates} +class Function(Action): + def __init__(self, config, *path): + super().__init__(config, *path, self.name) class SourceFunction(Function): - def __init__(self, config): + + def __call__(self, context, argument): from anemoi.datasets.create.sources import create_source - super().__init__(config) - config["_type"] = self.name - self.source = create_source(self, config) + config = context.resolve(self.config) # Substitute the ${} variables in the config + config["_type"] = self.name # Find a better way to do this + source = create_source(self, config) + + rich.print(f"Executing 
source {self.name} from {config}") - def __call__(self, context, group_of_dates): - return self.source.execute(context, group_of_dates.dates) + return context.register( + source.execute(context, context.source_argument(argument)), + self.path, + ) class FilterFunction(Function): pass -def new_source(name, source=None): +def new_source(name): return type(name.title(), (SourceFunction,), {"name": name}) -def new_filter(name, source=None): +def new_filter(name): return type(name.title(), (FilterFunction,), {"name": name}) -KLASS = { - "concat": Concat, - "join": Join, - "pipe": Pipe, -} +KLASS = {"concat": Concat, "join": Join, "pipe": Pipe} LEN_KLASS = len(KLASS) -def make(key, config): +def make(key, config, path): if LEN_KLASS == len(KLASS): # Load pluggins @@ -148,7 +173,7 @@ def make(key, config): assert name not in KLASS, f"Duplicate source name: {name}" KLASS[name] = new_source(name) - return KLASS[key](config) + return KLASS[key](config, *path) def action_factory(data, *path): @@ -156,4 +181,4 @@ def action_factory(data, *path): assert len(data) == 1, "Input data must contain exactly one key-value pair" key, value = next(iter(data.items())) - return make(key, value) + return make(key, value, path) From 93410d56e5b8c0312da55185b60b85ba4269574a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 10:58:12 +0000 Subject: [PATCH 22/79] refactor --- src/anemoi/datasets/create/input/__init__.py | 71 +------- src/anemoi/datasets/create/input/action.py | 1 + src/anemoi/datasets/create/input/concat.py | 164 ------------------ src/anemoi/datasets/create/input/context.py | 89 ---------- .../datasets/create/input/context/__init__.py | 70 ++++++++ .../datasets/create/input/context/field.py | 37 ++++ .../datasets/create/input/data_sources.py | 2 +- src/anemoi/datasets/create/input/empty.py | 2 +- src/anemoi/datasets/create/input/function.py | 2 +- src/anemoi/datasets/create/input/join.py | 130 -------------- src/anemoi/datasets/create/input/pipe.py | 66 
------- .../datasets/create/input/repeated_dates.py | 2 +- .../datasets/create/input/result/__init__.py | 21 +++ .../input/{result.py => result/field.py} | 6 +- src/anemoi/datasets/create/input/step.py | 2 +- 15 files changed, 141 insertions(+), 524 deletions(-) delete mode 100644 src/anemoi/datasets/create/input/concat.py delete mode 100644 src/anemoi/datasets/create/input/context.py create mode 100644 src/anemoi/datasets/create/input/context/__init__.py create mode 100644 src/anemoi/datasets/create/input/context/field.py delete mode 100644 src/anemoi/datasets/create/input/join.py delete mode 100644 src/anemoi/datasets/create/input/pipe.py create mode 100644 src/anemoi/datasets/create/input/result/__init__.py rename src/anemoi/datasets/create/input/{result.py => result/field.py} (99%) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index d561a9fb4..19269f587 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024-2025 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,75 +7,13 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-import logging from copy import deepcopy from typing import Any from typing import Union -import rich - +from anemoi.datasets.create.input.context.field import FieldContext from anemoi.datasets.dates.groups import GroupOfDates -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class Context: - """Context for building input data.""" - - use_grib_paramid = False - - def __init__(self): - self.results = {} - - def trace(self, emoji, message) -> None: - - rich.print(f"{emoji}: {message}") - - def register(self, data: Any, path: list[str]) -> Any: - """Register data in the context. - - Parameters - ---------- - data : Any - Data to register. - path : list[str] - Path where the data should be registered. - - Returns - ------- - Any - Registered data. - """ - # This is a placeholder for actual registration logic. - rich.print(f"Registering data at path: {path}") - self.results[tuple(path)] = data - return data - - def resolve(self, config): - config = config.copy() - - for key, value in list(config.items()): - if isinstance(value, str) and value.startswith("${") and value.endswith("}"): - path = tuple(value[2:-1].split(".")) - if path in self.results: - config[key] = self.results[path] - else: - raise KeyError(f"Path {path} not found in results: {self.results.keys()}") - - return config - - -class FieldContext(Context): - def empty_result(self) -> Any: - import earthkit.data as ekd - - return ekd.from_source("empty") - - def source_argument(self, argument: Any) -> Any: - return argument.dates - class InputBuilder: """Builder class for creating input data from configuration and data sources.""" @@ -105,7 +43,6 @@ def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) self.config = config self.action_path = ["input"] - @trace_select def select(self, group_of_dates: GroupOfDates) -> Any: """Select data based on the group of dates. 
@@ -122,9 +59,9 @@ def select(self, group_of_dates: GroupOfDates) -> Any: from .action import action_factory """This changes the context.""" - context = FieldContext() + context = FieldContext(**self.kwargs) action = action_factory(self.config, "input") - return action(context, group_of_dates) + return context.create_result(action(context, group_of_dates)) def _trace_select(self, group_of_dates: GroupOfDates) -> str: """Trace the select operation. diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 8fdbf10be..76d99f4b7 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -23,6 +23,7 @@ def __repr__(self): def match(self, dates): # Just a demo + raise NotImplementedError("Not yet implemented") return True diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py deleted file mode 100644 index 5399bbc1f..000000000 --- a/src/anemoi/datasets/create/input/concat.py +++ /dev/null @@ -1,164 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import logging -from copy import deepcopy -from functools import cached_property -from typing import Any -from typing import Dict -from typing import List -from typing import Union - -from earthkit.data import FieldList - -from anemoi.datasets.dates import DatesProvider - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class ConcatResult(Result): - """Represents the result of concatenating multiple results.""" - - def __init__( - self, - context: object, - action_path: List[str], - group_of_dates: GroupOfDates, - results: List[Result], - **kwargs: Any, - ) -> None: - """Initializes a ConcatResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - kwargs : Any - Additional keyword arguments. 
- """ - super().__init__(context, action_path, group_of_dates) - self.results = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the concatenated datasource from all results.""" - ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - @property - def variables(self) -> List[str]: - """Returns the list of variables, ensuring all results have the same variables.""" - variables = None - for f in self.results: - if f.empty: - continue - if variables is None: - variables = f.variables - assert variables == f.variables, (variables, f.variables) - assert variables is not None, self.results - return variables - - def __repr__(self) -> str: - """Returns a string representation of the ConcatResult instance. - - Returns - ------- - str - A string representation of the ConcatResult instance. - """ - content = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class ConcatAction(Action): - """Represents an action that concatenates multiple actions based on their dates.""" - - def __init__(self, context: object, action_path: List[str], *configs: Dict[str, Any]) -> None: - """Initializes a ConcatAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - configs : Dict[str, Any] - The configuration dictionaries. 
- """ - super().__init__(context, action_path, *configs) - parts = [] - for i, cfg in enumerate(configs): - if "dates" not in cfg: - raise ValueError(f"Missing 'dates' in {cfg}") - cfg = deepcopy(cfg) - dates_cfg = cfg.pop("dates") - assert isinstance(dates_cfg, dict), dates_cfg - filtering_dates = DatesProvider.from_config(**dates_cfg) - action = action_factory(cfg, context, action_path + [str(i)]) - parts.append((filtering_dates, action)) - self.parts = parts - - def __repr__(self) -> str: - """Returns a string representation of the ConcatAction instance. - - Returns - ------- - str - A string representation of the ConcatAction instance. - """ - content = "\n".join([str(i) for i in self.parts]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> Union[ConcatResult, EmptyResult]: - """Selects the concatenated result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - Union[ConcatResult, EmptyResult] - The concatenated result or an empty result. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - results = [] - for filtering_dates, action in self.parts: - newdates = GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) - if newdates: - results.append(action.select(newdates)) - if not results: - return EmptyResult(self.context, self.action_path, group_of_dates) - - return ConcatResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py deleted file mode 100644 index 35784dba7..000000000 --- a/src/anemoi/datasets/create/input/context.py +++ /dev/null @@ -1,89 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
-# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import textwrap -from typing import Any -from typing import List -from typing import Tuple -from typing import Union - -from anemoi.utils.humanize import plural - -from .trace import step -from .trace import trace - -LOG = logging.getLogger(__name__) - - -class Context: - """Class to handle the build context in the dataset creation process.""" - - def __init__(self) -> None: - """Initializes a Context instance.""" - # used_references is a set of reference paths that will be needed - self.used_references = set() - # results is a dictionary of reference path -> obj - self.results = {} - - def will_need_reference(self, key: Union[List, Tuple]) -> None: - """Marks a reference as needed. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - self.used_references.add(key) - - def notify_result(self, key: Union[List, Tuple], result: Any) -> None: - """Notifies that a result is available for a reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - result : Any - The result object. - """ - trace( - "🎯", - step(key), - "notify result", - textwrap.shorten(repr(result).replace(",", ", "), width=40), - plural(len(result), "field"), - ) - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.used_references: - if key in self.results: - raise ValueError(f"Duplicate result {key}") - self.results[key] = result - - def get_result(self, key: Union[List, Tuple]) -> Any: - """Retrieves the result for a given reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - - Returns - ------- - Any - The result for the given reference. 
- """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.results: - return self.results[key] - all_keys = sorted(list(self.results.keys())) - raise ValueError(f"Cannot find result {key} in {all_keys}") diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py new file mode 100644 index 000000000..23d31e6f3 --- /dev/null +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -0,0 +1,70 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from abc import ABC +from abc import abstractmethod +from typing import Any + +import rich + +LOG = logging.getLogger(__name__) + + +class Context(ABC): + """Context for building input data.""" + + def __init__(self): + self.results = {} + + def trace(self, emoji, message) -> None: + + rich.print(f"{emoji}: {message}") + + def register(self, data: Any, path: list[str]) -> Any: + """Register data in the context. + + Parameters + ---------- + data : Any + Data to register. + path : list[str] + Path where the data should be registered. + + Returns + ------- + Any + Registered data. + """ + # This is a placeholder for actual registration logic. 
+ rich.print(f"Registering data at path: {path}") + self.results[tuple(path)] = data + return data + + def resolve(self, config): + config = config.copy() + + for key, value in list(config.items()): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + path = tuple(value[2:-1].split(".")) + if path in self.results: + config[key] = self.results[path] + else: + raise KeyError(f"Path {path} not found in results: {self.results.keys()}") + + return config + + @abstractmethod + def empty_result(self) -> Any: ... + + @abstractmethod + def source_argument(self, argument: Any) -> Any: ... + + @abstractmethod + def create_result(self, data: Any) -> Any: ... diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py new file mode 100644 index 000000000..c3a95f8cb --- /dev/null +++ b/src/anemoi/datasets/create/input/context/field.py @@ -0,0 +1,37 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from typing import Any +from typing import Dict + +from earthkit.data.core.order import build_remapping + +from . 
import Context + + +class FieldContext(Context): + + def __init__(self, /, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool) -> None: + super().__init__() + self.order_by = order_by + self.flatten_grid = flatten_grid + self.remapping = build_remapping(remapping) + self.use_grib_paramid = use_grib_paramid + + def empty_result(self) -> Any: + import earthkit.data as ekd + + return ekd.from_source("empty") + + def source_argument(self, argument: Any) -> Any: + return argument.dates + + def create_result(self, data): + return FieldResult(self, data) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index b95f85568..a64f9a7ef 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -20,7 +20,7 @@ from .action import Action from .action import action_factory from .misc import _tidy -from .result import Result +from .result.field import Result LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/empty.py b/src/anemoi/datasets/create/input/empty.py index 410b4c973..fdc959f0f 100644 --- a/src/anemoi/datasets/create/input/empty.py +++ b/src/anemoi/datasets/create/input/empty.py @@ -14,7 +14,7 @@ from earthkit.data import FieldList from .misc import assert_fieldlist -from .result import Result +from .result.field import Result from .trace import trace_datasource LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py index 4d3d21b22..586003bb2 100644 --- a/src/anemoi/datasets/create/input/function.py +++ b/src/anemoi/datasets/create/input/function.py @@ -18,7 +18,7 @@ from .action import Action from .misc import _tidy from .misc import assert_fieldlist -from .result import Result +from .result.field import Result from .template import notify_result from .template import substitute from .trace import trace 
diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py deleted file mode 100644 index ba24c7072..000000000 --- a/src/anemoi/datasets/create/input/join.py +++ /dev/null @@ -1,130 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import List - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class JoinResult(Result): - """Represents a result that combines multiple results. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - """ - - def __init__( - self, context: object, action_path: list, group_of_dates: GroupOfDates, results: List[Result], **kwargs: Any - ) -> None: - """Initializes a JoinResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. 
- """ - super().__init__(context, action_path, group_of_dates) - self.results: List[Result] = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the combined datasource from all results.""" - ds: FieldList = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - def __repr__(self) -> str: - """Returns a string representation of the JoinResult instance.""" - content: str = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class JoinAction(Action): - """Represents an action that combines multiple actions. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - actions : List[Action] - The list of actions. - """ - - def __init__(self, context: object, action_path: list, *configs: dict) -> None: - """Initializes a JoinAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - *configs : dict - The configuration dictionaries. - """ - super().__init__(context, action_path, *configs) - self.actions: List[Action] = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] - - def __repr__(self) -> str: - """Returns a string representation of the JoinAction instance.""" - content: str = "\n".join([str(i) for i in self.actions]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> JoinResult: - """Selects the results for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - JoinResult - The combined result for the given group of dates. 
- """ - results: List[Result] = [a.select(group_of_dates) for a in self.actions] - return JoinResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py deleted file mode 100644 index 6c9fea0df..000000000 --- a/src/anemoi/datasets/create/input/pipe.py +++ /dev/null @@ -1,66 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -import logging -from typing import Any - -from .action import Action -from .action import action_factory -from .step import step_factory -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class PipeAction(Action): - """A class to represent a pipeline of actions.""" - - def __init__(self, context: Any, action_path: list, *configs: dict) -> None: - """Initialize the PipeAction. - - Parameters - ---------- - context : Any - The context for the action. - action_path : list - The path of the action. - configs : dict - The configurations for the actions. - """ - super().__init__(context, action_path, *configs) - if len(configs) <= 1: - raise ValueError( - f"PipeAction requires at least two actions, got {len(configs)}\n{json.dumps(configs, indent=2)}" - ) - - current: Any = action_factory(configs[0], context, action_path + ["0"]) - for i, c in enumerate(configs[1:]): - current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current) - self.last_step: Any = current - - @trace_select - def select(self, group_of_dates: Any) -> Any: - """Select data based on the group of dates. 
- - Parameters - ---------- - group_of_dates : Any - The group of dates to select data for. - - Returns - ------- - Any - The selected data. - """ - return self.last_step.select(group_of_dates) - - def __repr__(self) -> str: - """Return a string representation of the PipeAction.""" - return f"PipeAction({self.last_step})" diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index ebf13b36e..37cc2dcf0 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -27,7 +27,7 @@ from .action import Action from .action import action_factory from .join import JoinResult -from .result import Result +from .result.field import Result from .trace import trace_select LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/result/__init__.py b/src/anemoi/datasets/create/input/result/__init__.py new file mode 100644 index 000000000..04c2ab733 --- /dev/null +++ b/src/anemoi/datasets/create/input/result/__init__.py @@ -0,0 +1,21 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +from abc import ABC +from abc import abstractmethod +from typing import Any + +import rich + +LOG = logging.getLogger(__name__) + + +class Result(ABC): + pass diff --git a/src/anemoi/datasets/create/input/result.py b/src/anemoi/datasets/create/input/result/field.py similarity index 99% rename from src/anemoi/datasets/create/input/result.py rename to src/anemoi/datasets/create/input/result/field.py index de5388fd6..cd334efd3 100644 --- a/src/anemoi/datasets/create/input/result.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -26,9 +26,9 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from .action import ActionContext -from .trace import trace -from .trace import trace_datasource +from ..action import ActionContext +from ..trace import trace +from ..trace import trace_datasource LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py index e99717094..c88da3e71 100644 --- a/src/anemoi/datasets/create/input/step.py +++ b/src/anemoi/datasets/create/input/step.py @@ -19,7 +19,7 @@ from .action import Action from .action import ActionContext from .context import Context -from .result import Result +from .result.field import Result from .template import notify_result from .trace import trace_datasource from .trace import trace_select From 1df0ef758e3f3d0d9ad0fe57c2c0494c34012401 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 10:58:50 +0000 Subject: [PATCH 23/79] refactor --- src/anemoi/datasets/create/input/context/field.py | 1 + src/anemoi/datasets/create/input/result/field.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index c3a95f8cb..aee08c2c3 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ 
-14,6 +14,7 @@ from earthkit.data.core.order import build_remapping from . import Context +from ..result.field import FieldResult class FieldContext(Context): diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index cd334efd3..66e29cec3 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -282,7 +282,7 @@ def sort(old_dic: DefaultDict[str, set]) -> Dict[str, List[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class Result: +class FieldResult: """Class to represent the result of an action in the dataset creation process.""" empty: bool = False From 18df4ebe5fa4d39aef59820d70fbdd23b3d1fbee Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 11:10:07 +0000 Subject: [PATCH 24/79] refactor --- src/anemoi/datasets/create/__init__.py | 2 +- src/anemoi/datasets/create/input/__init__.py | 8 +- .../datasets/create/input/context/__init__.py | 18 +--- .../datasets/create/input/context/field.py | 8 +- .../datasets/create/input/result/field.py | 98 ++----------------- 5 files changed, 21 insertions(+), 113 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 9c8c4613a..635433e47 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -870,7 +870,7 @@ def _run(self) -> None: # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(group_of_dates=group) + result = self.input.select(argument=group) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. 
diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 19269f587..ed39c9211 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -43,12 +43,12 @@ def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) self.config = config self.action_path = ["input"] - def select(self, group_of_dates: GroupOfDates) -> Any: + def select(self, argument: GroupOfDates) -> Any: """Select data based on the group of dates. Parameters ---------- - group_of_dates : GroupOfDates + argument : GroupOfDates Group of dates to select data for. Returns @@ -59,9 +59,9 @@ def select(self, group_of_dates: GroupOfDates) -> Any: from .action import action_factory """This changes the context.""" - context = FieldContext(**self.kwargs) + context = FieldContext(argument, **self.kwargs) action = action_factory(self.config, "input") - return context.create_result(action(context, group_of_dates)) + return context.create_result(action(context, argument)) def _trace_select(self, group_of_dates: GroupOfDates) -> str: """Trace the select operation. diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 23d31e6f3..11aa1570c 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -20,29 +20,15 @@ class Context(ABC): """Context for building input data.""" - def __init__(self): + def __init__(self, /, argument: Any) -> None: self.results = {} + self.argument = argument def trace(self, emoji, message) -> None: rich.print(f"{emoji}: {message}") def register(self, data: Any, path: list[str]) -> Any: - """Register data in the context. - - Parameters - ---------- - data : Any - Data to register. - path : list[str] - Path where the data should be registered. - - Returns - ------- - Any - Registered data. 
- """ - # This is a placeholder for actual registration logic. rich.print(f"Registering data at path: {path}") self.results[tuple(path)] = data return data diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index aee08c2c3..4586e1d31 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -13,14 +13,16 @@ from earthkit.data.core.order import build_remapping -from . import Context from ..result.field import FieldResult +from . import Context class FieldContext(Context): - def __init__(self, /, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool) -> None: - super().__init__() + def __init__( + self, /, argument: Any, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool + ) -> None: + super().__init__(argument) self.order_by = order_by self.flatten_grid = flatten_grid self.remapping = build_remapping(remapping) diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index 66e29cec3..dd238998e 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -26,9 +26,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from ..action import ActionContext -from ..trace import trace -from ..trace import trace_datasource +from . 
import Result LOG = logging.getLogger(__name__) @@ -282,40 +280,22 @@ def sort(old_dic: DefaultDict[str, set]) -> Dict[str, List[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class FieldResult: +class FieldResult(Result): """Class to represent the result of an action in the dataset creation process.""" empty: bool = False _coords_already_built: bool = False - def __init__(self, context: ActionContext, action_path: List[str], dates: Any) -> None: - """Initialize a Result instance. + def __init__(self, context: Any, datasource: Any) -> None: - Parameters - ---------- - context : ActionContext - The context in which the result exists. - action_path : list of str - The action path. - dates : Any - The dates associated with the result. - """ from anemoi.datasets.dates.groups import GroupOfDates - assert isinstance(dates, GroupOfDates), dates - - assert isinstance(context, ActionContext), type(context) - assert isinstance(action_path, list), action_path - self.context: Any = context - self.group_of_dates: Any = dates - self.action_path: List[str] = action_path - - @property - @trace_datasource - def datasource(self) -> Any: - """Retrieve the data source for the result.""" - self._raise_not_implemented() + self.datasource = datasource + self.group_of_dates = context.argument + assert isinstance( + self.group_of_dates, GroupOfDates + ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" @property def data_request(self) -> Dict[str, Any]: @@ -330,7 +310,7 @@ def get_cube(self) -> Any: Any The data cube. 
""" - trace("🧊", f"getting cube from {self.__class__.__name__}") + ds: Any = self.datasource remapping: Any = self.context.remapping @@ -523,66 +503,6 @@ def explain(self, ds: Any, *args: Any, remapping: Any, patches: Any) -> None: print() exit(1) - def _repr(self, *args: Any, _indent_: str = "\n", **kwargs: Any) -> str: - """Return the string representation of the Result instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str - Indentation string. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more: str = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - dates: str = " no-dates" - if self.group_of_dates is not None: - dates = f" {len(self.group_of_dates)} dates" - dates += " (" - dates += "/".join(d.strftime("%Y-%m-%dT%H:%M") for d in self.group_of_dates) - if len(dates) > 100: - dates = dates[:100] + "..." - dates += ")" - - more = more[:5000] - txt: str = f"{self.__class__.__name__}:{dates}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Result instance.""" - return self._repr() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Trace the data source for the result. - - Parameters - ---------- - args : Any - Additional positional arguments. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The trace string. 
- """ - return f"{self.__class__.__name__}({self.group_of_dates})" - def build_coords(self) -> None: """Build the coordinates for the result.""" if self._coords_already_built: From 3341d4cf1cdf3a4c91de8f0fc8371296379cb888 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 15:20:22 +0000 Subject: [PATCH 25/79] update --- src/anemoi/datasets/create/input/__init__.py | 19 +- src/anemoi/datasets/create/input/action.py | 171 ++++++---- .../datasets/create/input/context/__init__.py | 15 +- .../datasets/create/input/context/field.py | 18 +- src/anemoi/datasets/create/input/empty.py | 54 --- src/anemoi/datasets/create/input/filter.py | 118 ------- src/anemoi/datasets/create/input/function.py | 233 ------------- src/anemoi/datasets/create/input/step.py | 191 ----------- src/anemoi/datasets/create/input/template.py | 162 --------- .../datasets/create/sources/__init__.py | 11 - .../datasets/create/sources/constants.py | 2 +- .../datasets/create/sources/forcings.py | 2 +- src/anemoi/datasets/create/sources/legacy.py | 5 +- .../datasets/create/sources/repeated_dates.py | 319 ++++++++++++++++++ 14 files changed, 455 insertions(+), 865 deletions(-) delete mode 100644 src/anemoi/datasets/create/input/empty.py delete mode 100644 src/anemoi/datasets/create/input/filter.py delete mode 100644 src/anemoi/datasets/create/input/function.py delete mode 100644 src/anemoi/datasets/create/input/step.py delete mode 100644 src/anemoi/datasets/create/input/template.py create mode 100644 src/anemoi/datasets/create/sources/repeated_dates.py diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index ed39c9211..cde4f6eaa 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -12,7 +12,6 @@ from typing import Union from anemoi.datasets.create.input.context.field import FieldContext -from anemoi.datasets.dates.groups import GroupOfDates class InputBuilder: @@ -41,9 +40,8 @@ 
def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) ) ) self.config = config - self.action_path = ["input"] - def select(self, argument: GroupOfDates) -> Any: + def select(self, argument) -> Any: """Select data based on the group of dates. Parameters @@ -62,18 +60,3 @@ def select(self, argument: GroupOfDates) -> Any: context = FieldContext(argument, **self.kwargs) action = action_factory(self.config, "input") return context.create_result(action(context, argument)) - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the select operation. - - Parameters - ---------- - group_of_dates : GroupOfDates - Group of dates to select data for. - - Returns - ------- - str - Trace string. - """ - return f"InputBuilder({group_of_dates})" diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 76d99f4b7..28e521b97 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -11,20 +11,9 @@ import rich -LOG = logging.getLogger(__name__) - - -class Predicate: - def __init__(self, config): - self.config = config - - def __repr__(self): - return f"Predicate({self.config})" +from anemoi.datasets.dates import DatesProvider - def match(self, dates): - # Just a demo - raise NotImplementedError("Not yet implemented") - return True +LOG = logging.getLogger(__name__) class Action: @@ -44,24 +33,26 @@ def __init__(self, config, *path): for item in config: assert "dates" in item, f"Value must contain the key 'date' {item}" - predicate = Predicate(item.pop("dates")) + dates = item.pop("dates") + filtering_dates = DatesProvider.from_config(**dates) action = action_factory(item) - self.choices.append((predicate, action)) + self.choices.append((filtering_dates, action)) def __repr__(self): return f"Concat({self.choices})" def __call__(self, context, argument): - for predicate, action in self.choices: - if predicate.match(argument): - return 
context.register( - action(context, argument), - self.path, - ) + results = context.empty_result() - raise ValueError(f"No matching predicate for dates: {argument}") + for filtering_dates, action in self.choices: + dates = context.matching_dates(filtering_dates, argument) + if len(dates) == 0: + continue + results += action(context, dates) + + return context.register(results, self.path) class Join(Action): @@ -70,42 +61,25 @@ def __init__(self, config, *path): assert isinstance(config, list), f"Value must be a list {config}" - self.actions = [ - action_factory( - item, - *path, - "join", - str(i), - ) - for i, item in enumerate(config) - ] + self.actions = [action_factory(item, *path, "join", str(i)) for i, item in enumerate(config)] def __repr__(self): return f"Join({self.actions})" def __call__(self, context, argument): results = context.empty_result() + for action in self.actions: results += action(context, argument) - return context.register( - results, - self.path, - ) + + return context.register(results, self.path) class Pipe(Action): def __init__(self, config, *path): assert isinstance(config, list), f"Value must be a list {config}" super().__init__(config, *path) - self.actions = [ - action_factory( - item, - *path, - "pipe", - str(i), - ) - for i, item in enumerate(config) - ] + self.actions = [action_factory(item, *path, "pipe", str(i)) for i, item in enumerate(config)] def __repr__(self): return f"Pipe({self.actions})" @@ -119,44 +93,88 @@ def __call__(self, context, argument): else: result = action(context, result) - return context.register( - result, - self.path, - ) + return context.register(result, self.path) class Function(Action): def __init__(self, config, *path): super().__init__(config, *path, self.name) - -class SourceFunction(Function): - def __call__(self, context, argument): - from anemoi.datasets.create.sources import create_source config = context.resolve(self.config) # Substitute the ${} variables in the config + config["_type"] = 
self.name # Find a better way to do this - source = create_source(self, config) + + source = self.create_object(config) rich.print(f"Executing source {self.name} from {config}") - return context.register( - source.execute(context, context.source_argument(argument)), - self.path, - ) + return context.register(self.call_object(context, source, argument), self.path) + + +class DatasetSourceMixin: + def create_object(self, config): + from anemoi.datasets.create.sources import create_source as create_datasets_source + + return create_datasets_source(self, config) + + def call_object(self, context, source, argument): + return source.execute(context, context.source_argument(argument)) + + +class DatasetFilterMixin: + def create_object(self, config): + from anemoi.datasets.create.filters import create_filter as create_datasets_filter + + return create_datasets_filter(self, config) + + def call_object(self, context, filter, argument): + return filter.execute(context.filter_argument(argument)) + + +class TransformSourceMixin: + def create_object(self, config): + from anemoi.transform.sources import create_source as create_transform_source + + return create_transform_source(self, config) + + +class TransformFilterMixin: + def create_object(self, config): + from anemoi.transform.filters import create_filter as create_transform_filter + + return create_transform_filter(self, config) + + def call_object(self, context, filter, argument): + return filter.forward(context.filter_argument(argument)) class FilterFunction(Function): - pass + def __call__(self, context, argument): + return self.call(context, argument, context.filter_argument) + +def _make_name(name, what): + name = name.replace("_", "-") + name = "".join(x.title() for x in name.split("-")) + return name + what.title() -def new_source(name): - return type(name.title(), (SourceFunction,), {"name": name}) +def new_source(name, mixin): + return type( + _make_name(name, "source"), + (Function, mixin), + {"name": name}, + ) 
-def new_filter(name): - return type(name.title(), (FilterFunction,), {"name": name}) + +def new_filter(name, mixin): + return type( + _make_name(name, "filter"), + (Function, mixin), + {"name": name}, + ) KLASS = {"concat": Concat, "join": Join, "pipe": Pipe} @@ -167,14 +185,33 @@ def new_filter(name): def make(key, config, path): if LEN_KLASS == len(KLASS): + # Load pluggins - from anemoi.datasets.create.sources import registered_sources + from anemoi.transform.filters import filter_registry as transform_filter_registry + from anemoi.transform.sources import source_registry as transform_source_registry + + from anemoi.datasets.create.filters import filter_registry as dataset_filter_registry + from anemoi.datasets.create.sources import source_registry as dataset_source_registry + + # Register sources, local first + for name in dataset_source_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) + + for name in transform_source_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) + + # Register filters, local first + for name in dataset_filter_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_filter(name, DatasetFilterMixin) - for name in registered_sources(): - assert name not in KLASS, f"Duplicate source name: {name}" - KLASS[name] = new_source(name) + for name in transform_filter_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) - return KLASS[key](config, *path) + return KLASS[key.replace("_", "-")](config, *path) def action_factory(data, *path): diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 11aa1570c..26d449659 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -22,13 +22,18 @@ 
class Context(ABC): def __init__(self, /, argument: Any) -> None: self.results = {} + self.cache = {} self.argument = argument - def trace(self, emoji, message) -> None: + def trace(self, emoji, *message) -> None: rich.print(f"{emoji}: {message}") def register(self, data: Any, path: list[str]) -> Any: + + if not path: + return data + rich.print(f"Registering data at path: {path}") self.results[tuple(path)] = data return data @@ -46,11 +51,13 @@ def resolve(self, config): return config - @abstractmethod - def empty_result(self) -> Any: ... + def create_source(self, config: Any) -> Any: + from anemoi.datasets.create.input.action import action_factory + + return action_factory(config) @abstractmethod - def source_argument(self, argument: Any) -> Any: ... + def empty_result(self) -> Any: ... @abstractmethod def create_result(self, data: Any) -> Any: ... diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index 4586e1d31..c3456d89f 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -20,7 +20,13 @@ class FieldContext(Context): def __init__( - self, /, argument: Any, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool + self, + /, + argument: Any, + order_by: str, + flatten_grid: bool, + remapping: Dict[str, Any], + use_grib_paramid: bool, ) -> None: super().__init__(argument) self.order_by = order_by @@ -34,7 +40,15 @@ def empty_result(self) -> Any: return ekd.from_source("empty") def source_argument(self, argument: Any) -> Any: - return argument.dates + return argument # .dates + + def filter_argument(self, argument: Any) -> Any: + return argument def create_result(self, data): return FieldResult(self, data) + + def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: + from anemoi.datasets.dates.groups import GroupOfDates + + return GroupOfDates(sorted(set(group_of_dates) & 
set(filtering_dates)), group_of_dates.provider) diff --git a/src/anemoi/datasets/create/input/empty.py b/src/anemoi/datasets/create/input/empty.py deleted file mode 100644 index fdc959f0f..000000000 --- a/src/anemoi/datasets/create/input/empty.py +++ /dev/null @@ -1,54 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import List - -from earthkit.data import FieldList - -from .misc import assert_fieldlist -from .result.field import Result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class EmptyResult(Result): - """Class to represent an empty result in the dataset creation process.""" - - empty = True - - def __init__(self, context: object, action_path: list, dates: object) -> None: - """Initializes an EmptyResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - dates : object - The dates object. - """ - super().__init__(context, action_path + ["empty"], dates) - - @cached_property - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns an empty datasource.""" - from earthkit.data import from_source - - return from_source("empty") - - @property - def variables(self) -> List[str]: - """Returns an empty list of variables.""" - return [] diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py deleted file mode 100644 index 289bb3602..000000000 --- a/src/anemoi/datasets/create/input/filter.py +++ /dev/null @@ -1,118 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. 
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import Type - -from earthkit.data import FieldList - -from .function import FunctionContext -from .misc import _tidy -from .misc import assert_fieldlist -from .step import StepAction -from .step import StepResult -from .template import notify_result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class FilterStepResult(StepResult): - @property - @notify_result - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns the filtered datasource.""" - ds: FieldList = self.upstream_result.datasource - ds = ds.sel(**self.action.kwargs) - return _tidy(ds) - - -class FilterStepAction(StepAction): - """Represents an action to filter a step result.""" - - result_class: Type[FilterStepResult] = FilterStepResult - - -class StepFunctionResult(StepResult): - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource after applying the function.""" - - self.action.filter.context = FunctionContext(self) - try: - return _tidy( - self.action.filter.execute( - self.upstream_result.datasource, - *self.action.args[1:], - **self.action.kwargs, - ) - ) - - except Exception: - LOG.error(f"Error in {self.action.name}", exc_info=True) - raise - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. 
- - Returns - ------- - str - A string representation of the traced datasource. - """ - return f"{self.action.name}({self.group_of_dates})" - - -class FunctionStepAction(StepAction): - """Represents an action to apply a function to a step result.""" - - result_class: Type[StepFunctionResult] = StepFunctionResult - - def __init__( - self, - context: object, - action_path: list, - previous_step: StepAction, - name: str, - filter: Any, - *args: Any, - **kwargs: Any, - ) -> None: - """Initializes a FunctionStepAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - previous_step : StepAction - The previous step action. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. - """ - super().__init__(context, action_path, previous_step, *args, **kwargs) - self.name = name - self.filter = filter diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py deleted file mode 100644 index 586003bb2..000000000 --- a/src/anemoi/datasets/create/input/function.py +++ /dev/null @@ -1,233 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import logging -from functools import cached_property -from typing import Any -from typing import Dict - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .misc import _tidy -from .misc import assert_fieldlist -from .result.field import Result -from .template import notify_result -from .template import substitute -from .trace import trace -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class FunctionContext: - """A FunctionContext is passed to all functions, it will be used to pass information - to the functions from the other actions and filters and results. - """ - - def __init__(self, owner: Result) -> None: - """Initializes a FunctionContext instance. - - Parameters - ---------- - owner : object - The owner object. - """ - self.owner = owner - self.use_grib_paramid: bool = owner.context.use_grib_paramid - - def trace(self, emoji: str, *args: Any) -> None: - """Traces the given arguments with an emoji. - - Parameters - ---------- - emoji : str - The emoji to use. - *args : Any - The arguments to trace. - """ - trace(emoji, *args) - - def info(self, *args: Any, **kwargs: Any) -> None: - """Logs an info message. - - Parameters - ---------- - *args : Any - The arguments for the log message. - **kwargs : Any - The keyword arguments for the log message. - """ - LOG.info(*args, **kwargs) - - @property - def dates_provider(self) -> object: - """Returns the dates provider.""" - return self.owner.group_of_dates.provider - - @property - def partial_ok(self) -> bool: - """Returns whether partial results are acceptable.""" - return self.owner.group_of_dates.partial_ok - - def get_result(self, *args, **kwargs) -> Any: - return self.owner.context.get_result(*args, **kwargs) - - -class FunctionAction(Action): - """Represents an action that executes a function. - - Attributes - ---------- - name : str - The name of the function. 
- """ - - def __init__(self, context: object, action_path: list, _name: str, source, **kwargs: Dict[str, Any]) -> None: - """Initializes a FunctionAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - _name : str - The name of the function. - **kwargs : Dict[str, Any] - Additional keyword arguments. - """ - super().__init__(context, action_path, **kwargs) - self.name: str = _name - self.source = source - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": - """Selects the function result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - FunctionResult - The function result instance. - """ - return FunctionResult(self.context, self.action_path, group_of_dates, action=self) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionAction instance.""" - content: str = "" - content += ",".join([self._short_str(a) for a in self.args]) - content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]) - content = self._short_str(content) - return self._repr(_inline_=content, _indent_=" ") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Traces the selection of the function for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - str - The trace string. - """ - return f"{self.name}({group_of_dates})" - - -class FunctionResult(Result): - """Represents the result of executing a function. - - Attributes - ---------- - action : Action - The action instance. - args : tuple - The positional arguments for the function. - kwargs : dict - The keyword arguments for the function. - """ - - def __init__(self, context: object, action_path: list, group_of_dates: GroupOfDates, action: Action) -> None: - """Initializes a FunctionResult instance. 
- - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - action : Action - The action instance. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(action, Action), type(action) - self.action: Action = action - - self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs)) - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.action.name}({self.group_of_dates})" - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource for the function result.""" - # args, kwargs = resolve(self.context, (self.args, self.kwargs)) - self.action.source.context = FunctionContext(self) - - return _tidy( - self.action.source.execute( - list(self.group_of_dates), # Will provide a list of datetime objects - ) - ) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionResult instance.""" - try: - return f"{self.action.name}({self.group_of_dates})" - except Exception: - return f"{self.__class__.__name__}(unitialised)" - - @property - def function(self) -> None: - """Raises NotImplementedError as this property is not implemented. - - Raises - ------ - NotImplementedError - Always raised. - """ - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py deleted file mode 100644 index c88da3e71..000000000 --- a/src/anemoi/datasets/create/input/step.py +++ /dev/null @@ -1,191 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. 
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import warnings -from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Type - -from .action import Action -from .action import ActionContext -from .context import Context -from .result.field import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class StepResult(Result): - """Represents the result of a step in the data processing pipeline.""" - - def __init__( - self, context: Context, action_path: List[str], group_of_dates: Any, action: Action, upstream_result: Result - ) -> None: - """Initialize a StepResult instance. - - Parameters - ---------- - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - group_of_dates - The group of dates associated with this step. - action - The action associated with this step. - upstream_result - The result of the upstream step. 
- """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(upstream_result, Result), type(upstream_result) - self.upstream_result: Result = upstream_result - self.action: Action = action - - @property - @notify_result - @trace_datasource - def datasource(self) -> Any: - """Retrieve the datasource associated with this step result.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - -class StepAction(Action): - """Represents an action that is part of a step in the data processing pipeline.""" - - result_class: Optional[Type[StepResult]] = None - - def __init__( - self, context: ActionContext, action_path: List[str], previous_step: Any, *args: Any, **kwargs: Any - ) -> None: - """Initialize a StepAction instance. - - Parameters - ---------- - context - The context in which the action is executed. - action_path - The path of actions leading to this step. - previous_step - The previous step in the pipeline. - """ - super().__init__(context, action_path, *args, **kwargs) - self.previous_step: Any = previous_step - - @trace_select - def select(self, group_of_dates: Any) -> StepResult: - """Select the result for a given group of dates. - - Parameters - ---------- - group_of_dates - The group of dates to select the result for. - - Returns - ------- - unknown - The result of the step. - """ - return self.result_class( - self.context, - self.action_path, - group_of_dates, - self, - self.previous_step.select(group_of_dates), - ) - - def __repr__(self) -> str: - """Return a string representation of the StepAction instance. - - Returns - ------- - unknown - String representation of the instance. - """ - return self._repr(self.previous_step, _inline_=str(self.kwargs)) - - -def step_factory(config: Dict[str, Any], context: ActionContext, action_path: List[str], previous_step: Any) -> Any: - """Factory function to create a step action based on the given configuration. 
- - Parameters - ---------- - config - The configuration dictionary for the step. - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - previous_step - The previous step in the pipeline. - - Returns - ------- - unknown - An instance of a step action. - """ - - from .filter import FilterStepAction - from .filter import FunctionStepAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - - config = deepcopy(config) - assert len(config) == 1, config - - key = list(config.keys())[0] - cls = dict( - filter=FilterStepAction, - # rename=RenameAction, - # remapping=RemappingAction, - ).get(key) - - if isinstance(config[key], list): - args, kwargs = config[key], {} - - if isinstance(config[key], dict): - args, kwargs = [], config[key] - - if isinstance(config[key], str): - args, kwargs = [config[key]], {} - - if cls is not None: - return cls(context, action_path, previous_step, *args, **kwargs) - - # Try filters from datasets filter registry - from anemoi.transform.filters import filter_registry as transform_filter_registry - - from ..filters import create_filter as create_datasets_filter - from ..filters import filter_registry as datasets_filter_registry - - if datasets_filter_registry.is_registered(key): - - if transform_filter_registry.is_registered(key): - warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries") - - filter = create_datasets_filter(None, config) - return FunctionStepAction(context, action_path + [key], previous_step, key, filter) - - # Use filters from transform registry - - if transform_filter_registry.is_registered(key): - from ..filters.transform import TransformFilter - - return FunctionStepAction( - context, action_path + [key], previous_step, key, TransformFilter(context, key, config) - ) - - raise ValueError(f"Unknown step action `{key}`") diff --git 
a/src/anemoi/datasets/create/input/template.py b/src/anemoi/datasets/create/input/template.py deleted file mode 100644 index 8ea1ec275..000000000 --- a/src/anemoi/datasets/create/input/template.py +++ /dev/null @@ -1,162 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -import re -from abc import ABC -from abc import abstractmethod -from functools import wraps -from typing import Any -from typing import Callable -from typing import List - -from .context import Context - -LOG = logging.getLogger(__name__) - - -def notify_result(method: Callable[..., Any]) -> Callable[..., Any]: - """Decorator to notify the context of the result of the method call. - - Parameters - ---------- - method : Callable[..., Any] - The method to wrap. - - Returns - ------- - Callable[..., Any] - The wrapped method. - """ - - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - result: Any = method(self, *args, **kwargs) - self.context.notify_result(self.action_path, result) - return result - - return wrapper - - -class Substitution(ABC): - """Abstract base class for substitutions in templates.""" - - @abstractmethod - def resolve(self, context: Context) -> Any: - """Resolve the substitution using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - pass - - -class Reference(Substitution): - """A class to represent a reference to another value in the context.""" - - def __init__(self, context: Any, action_path: List[str]) -> None: - """Initialize a Reference instance. 
- - Parameters - ---------- - context : Any - The context in which the reference exists. - action_path : list of str - The action path to resolve. - """ - self.context: Any = context - self.action_path: List[str] = action_path - - def resolve(self, context: Context) -> Any: - """Resolve the reference using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - return context.get_result(self.action_path) - - -def resolve(context: Context, x: Any) -> Any: - """Recursively resolve substitutions in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for resolution. - x : Union[tuple, list, dict, Substitution, Any] - The structure to resolve. - - Returns - ------- - Any - The resolved structure. - """ - if isinstance(x, tuple): - return tuple([resolve(context, y) for y in x]) - - if isinstance(x, list): - return [resolve(context, y) for y in x] - - if isinstance(x, dict): - return {k: resolve(context, v) for k, v in x.items()} - - if isinstance(x, Substitution): - return x.resolve(context) - - return x - - -def substitute(context: Context, x: Any) -> Any: - """Recursively substitute references in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for substitution. - x : Union[tuple, list, dict, str, Any] - The structure to substitute. - - Returns - ------- - Any - The substituted structure. 
- """ - if isinstance(x, tuple): - return tuple([substitute(context, y) for y in x]) - - if isinstance(x, list): - return [substitute(context, y) for y in x] - - if isinstance(x, dict): - return {k: substitute(context, v) for k, v in x.items()} - - if not isinstance(x, str): - return x - - if re.match(r"^\${[\.\w\-]+}$", x): - path = x[2:-1].split(".") - context.will_need_reference(path) - return Reference(context, path) - - return x diff --git a/src/anemoi/datasets/create/sources/__init__.py b/src/anemoi/datasets/create/sources/__init__.py index 1710d1303..f8b99f36d 100644 --- a/src/anemoi/datasets/create/sources/__init__.py +++ b/src/anemoi/datasets/create/sources/__init__.py @@ -34,14 +34,3 @@ def create_source(context: Any, config: Any) -> Any: The created source. """ return source_registry.from_config(config, context) - - -def registered_sources() -> list[str]: - """Get a list of registered source names. - - Returns - ------- - list[str] - A list of names of registered sources. - """ - return source_registry.registered diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index 921469025..1958820c4 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -47,7 +47,7 @@ def constants(context: Any, dates: List[str], template: Dict[str, Any], param: s if len(template) == 0: raise ValueError("Forcings template is empty.") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute: Any = constants diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index e1944e151..8e2977273 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -36,7 +36,7 @@ def forcings(context: Any, dates: List[str], template: str, param: 
str) -> Any: Loaded forcing data. """ context.trace("✅", f"from_source(forcings, {template}, {param}") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute = forcings diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index f8035ee16..c76a11c9b 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -14,8 +14,6 @@ from typing import Any from typing import Callable -from anemoi.datasets.create.input.template import resolve - from ..source import Source from . import source_registry @@ -74,7 +72,8 @@ def __call__(self, execute: Callable) -> Callable: def execute_wrapper(self, context, dates) -> Any: """Wrapper method to call the execute function.""" - args, kwargs = resolve(context, (self.args, self.kwargs)) + # args, kwargs = resolve(context, (self.args, self.kwargs)) + args, kwargs = self.args, self.kwargs try: return execute(context, dates, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py new file mode 100644 index 000000000..d092f08ad --- /dev/null +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -0,0 +1,319 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +import logging +from collections import defaultdict +from typing import Any +from typing import Dict +from typing import Generator +from typing import Optional +from typing import Set +from typing import Tuple + +import numpy as np +import rich +from anemoi.transform.fields import new_field_with_valid_datetime +from anemoi.transform.fields import new_fieldlist_from_list +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry + +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +LOG = logging.getLogger(__name__) + + +class Action: + pass + + +class Result: + pass + + +class DateMapper: + """A factory class to create DateMapper instances based on the given mode.""" + + @staticmethod + def from_mode(mode: str, source: Any, config: Dict[str, Any]) -> "DateMapper": + """Create a DateMapper instance based on the given mode. + + Parameters + ---------- + mode : str + The mode to use for the DateMapper. + source : Any + The data source. + config : dict + Configuration parameters. + + Returns + ------- + DateMapper + An instance of DateMapper. 
+ """ + MODES: dict = dict( + closest=DateMapperClosest, + climatology=DateMapperClimatology, + constant=DateMapperConstant, + ) + + if mode not in MODES: + raise ValueError(f"Invalid mode for DateMapper: {mode}") + + return MODES[mode](source, **config) + + +class DateMapperClosest(DateMapper): + """A DateMapper implementation that maps dates to the closest available dates.""" + + def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", skip_all_nans: bool = False) -> None: + """Initialize DateMapperClosest. + + Parameters + ---------- + source : Any + The data source. + frequency : str + Frequency of the dates. + maximum : str + Maximum time delta. + skip_all_nans : bool + Whether to skip all NaN values. + """ + self.source: Any = source + self.maximum: Any = frequency_to_timedelta(maximum) + self.frequency: Any = frequency_to_timedelta(frequency) + self.skip_all_nans: bool = skip_all_nans + self.tried: Set[Any] = set() + self.found: Set[Any] = set() + + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + """Transform the group of dates to the closest available dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. 
+ """ + from anemoi.datasets.dates.groups import GroupOfDates + + asked_dates = list(group_of_dates) + if not asked_dates: + return [] + + to_try = set() + for date in asked_dates: + start = date + while start >= date - self.maximum: + to_try.add(start) + start -= self.frequency + + end = date + while end <= date + self.maximum: + to_try.add(end) + end += self.frequency + + to_try = sorted(to_try - self.tried) + info = {k: "no-data" for k in to_try} + + if not to_try: + LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}") + # return [] + + if to_try: + result = self.source.select( + GroupOfDates( + sorted(to_try), + group_of_dates.provider, + partial_ok=True, + ) + ) + + cnt = 0 + for f in result.datasource: + cnt += 1 + # We could keep the fields in a dictionary, but we don't want to keep the fields in memory + date = as_datetime(f.metadata("valid_datetime")) + + if self.skip_all_nans: + if np.isnan(f.to_numpy()).all(): + LOG.warning(f"Skipping {date} because all values are NaN") + info[date] = "all-nans" + continue + + info[date] = "ok" + self.found.add(date) + + if cnt == 0: + raise ValueError(f"No data found for {group_of_dates} in {self.source}") + + self.tried.update(to_try) + + if not self.found: + for k, v in info.items(): + LOG.warning(f"{k}: {v}") + + raise ValueError(f"No matching data found for {asked_dates} in {self.source}") + + new_dates = defaultdict(list) + + for date in asked_dates: + best = None + for found_date in sorted(self.found): + delta = abs(date - found_date) + # With < we prefer the first date + # With <= we prefer the last date + if best is None or delta <= best[0]: + best = delta, found_date + new_dates[best[1]].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperClimatology(DateMapper): + """A DateMapper implementation that maps dates to specified climatology dates.""" + + def 
__init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) -> None: + """Initialize DateMapperClimatology. + + Parameters + ---------- + source : Any + The data source. + year : int + The year to map to. + day : int + The day to map to. + hour : Optional[int] + The hour to map to. + """ + self.year: int = year + self.day: int = day + self.hour: Optional[int] = hour + + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + """Transform the group of dates to the specified climatology dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + dates = list(group_of_dates) + if not dates: + return [] + + new_dates = defaultdict(list) + for date in dates: + new_date = date.replace(year=self.year, day=self.day) + if self.hour is not None: + new_date = new_date.replace(hour=self.hour, minute=0, second=0) + new_dates[new_date].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperConstant(DateMapper): + """A DateMapper implementation that maps dates to a constant date.""" + + def __init__(self, source: Any, date: Optional[Any] = None) -> None: + """Initialize DateMapperConstant. + + Parameters + ---------- + source : Any + The data source. + date : Optional[Any] + The constant date to map to. + """ + self.source: Any = source + self.date: Optional[Any] = date + + def transform(self, group_of_dates: Any) -> Tuple[Any, Any]: + """Transform the group of dates to a constant date. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Tuple[Any, Any] + Transformed dates. 
+ """ + from anemoi.datasets.dates.groups import GroupOfDates + + if self.date is None: + return [ + ( + GroupOfDates([], group_of_dates.provider), + group_of_dates, + ) + ] + + return [ + ( + GroupOfDates([self.date], group_of_dates.provider), + group_of_dates, + ) + ] + + +@source_registry.register("repeated_dates") +class RepeatedDatesSource(Source): + + def __init__(self, owner, source: Any, mode: str, **kwargs) -> None: + self.mapper = DateMapper.from_mode(mode, source, kwargs) + self.source = source + + def execute(self, context, group_of_dates): + source = context.create_source(self.source) + + result = [] + for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): + rich.print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") + source_results = source(context, one_date_group) + for field in source_results: + for date in many_dates_group: + result.append(new_field_with_valid_datetime(field, date)) + + return new_fieldlist_from_list(result) From 58dc8a2652482fe0fc72c89e48c71929fd709f16 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 10 Jul 2025 12:44:57 +0000 Subject: [PATCH 26/79] work on migrate --- src/anemoi/datasets/commands/migrate.py | 602 ++++++++++++++--------- src/anemoi/datasets/commands/validate.py | 44 ++ src/anemoi/datasets/create/__init__.py | 41 ++ src/anemoi/datasets/schemas/recipe.json | 131 +++++ 4 files changed, 589 insertions(+), 229 deletions(-) create mode 100644 src/anemoi/datasets/commands/validate.py create mode 100644 src/anemoi/datasets/schemas/recipe.json diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index 3183508b7..8d4b359ac 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -8,326 +8,443 @@ # nor does it submit to any jurisdiction. 
+import datetime import logging -from copy import deepcopy +import os +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any +import rich import yaml +from glom import assign +from glom import delete +from glom import glom + +from anemoi.datasets.create import validate_config from . import Command -errors = [] LOG = logging.getLogger(__name__) -ORDER = ("name", "description", "licence", "input", "output", "statistics", "build") + +class MyDumper(yaml.SafeDumper): + pass + + +def find_paths(data, target_key=None, target_value=None, *path): + + matches = [] + + if isinstance(data, Mapping): + for k, v in data.items(): + if (target_key is not None and k == target_key) or (target_value is not None and v == target_value): + matches.append(list(path) + [k]) + matches.extend(find_paths(v, target_key, target_value, *path, k)) + elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)): + for i, item in enumerate(data): + matches.extend(find_paths(item, target_key, target_value, *path, str(i))) + return matches + + +# Custom representer for datetime.date and datetime.datetime +def represent_date(dumper, data): + if isinstance(data, datetime.date) and not isinstance(data, datetime.datetime): + data = datetime.datetime(data.year, data.month, data.day, 0, 0, 0) + # Ensure it's UTC + if data.tzinfo is None: + data = data.replace(tzinfo=datetime.timezone.utc) + data = data.astimezone(datetime.timezone.utc) + # Format as ISO 8601 with 'Z' + iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) + + +# Custom representer for multiline strings using the '|' block style +def represent_multiline_str(dumper, data): + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + +# --- Represent short lists inline (flow style) --- +def 
represent_inline_list(dumper, data): + # Flow style if list has <= 4 simple elements + if ( + all(isinstance(i, (str, int, float, bool, type(None))) for i in data) + and len(", ".join([str(x) for x in data])) + 2 <= 80 + ): + return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + return dumper.represent_sequence("tag:yaml.org,2002:seq", data) + + +# Register custom representers +MyDumper.add_representer(datetime.date, represent_date) +MyDumper.add_representer(datetime.datetime, represent_date) +MyDumper.add_representer(str, represent_multiline_str) +MyDumper.add_representer(list, represent_inline_list) + + +def make_dates(config): + if isinstance(config, dict): + return {k: make_dates(v) for k, v in config.items()} + if isinstance(config, list): + return [make_dates(v) for v in config] + if isinstance(config, str): + try: + return datetime.datetime.fromisoformat(config) + except ValueError: + return config + return config + + +ORDER = ( + "name", + "description", + "dataset_status", + "licence", + "attribution", + "env", + "dates", + "common", + "data_sources", + "input", + "output", + "statistics", + "build", + "platform", +) ORDER = {k: i for i, k in enumerate(ORDER)} def order(x: str) -> int: - if x[0] not in ORDER: - ORDER[x[0]] = len(ORDER) - - return ORDER[x[0]] + try: + return ORDER[x[0]] + except KeyError: + rich.print(f"Unknown key: {x}") + raise MIGRATE = { "output.statistics_end": "statistics.end", "has_nans": "statistics.allow_nans", "loop.dates.group_by": "build.group_by", + "loop.0.dates.group_by": "build.group_by", "loop.dates": "dates", + "loop.0.dates": "dates", "copyright": "attribution", + "dates.<<": "dates", + "options.group_by": "build.group_by", + "loops.0.loop_a.dates": "dates", + "loop.0.loop_a.dates": "dates", + "dates.stop": "dates.end", + "dates.group_by": "build.group_by", + "include": "data_sources", + "ensemble_dimension": "build.ensemble_dimension", + "flatten_grid": "build.flatten_grid", } +DELETE = [ + 
"purpose", + "input.join.0.label", + "status", + "common", + "config_format_version", + "aliases", + "platform", + "loops.0.loop_a.applies_to", + "loop.0.loop_a.applies_to", + "dataset_status", + "alias", + "resources", +] + + SOURCES = { "oper-accumulations": "accumulations", "era5-accumulations": "accumulations", - "constants": "forcings", "ensemble-perturbations": "recentre", "ensemble_perturbations": "recentre", "perturbations": "recentre", "custom-regrid": "regrid", } +MARKER = object() -def _move(config, path, new_path, result): - path = path.split(".") - if new_path is not None: - new_path = new_path.split(".") - - for k in path[:-1]: - if k not in config: - return - config = config[k] - if path[-1] not in config: +def _delete(config, path, result): + x = glom(config, path, default=MARKER) + if x is MARKER: return + rich.print(f"Deleting {path}={x}") + delete(result, path) - value = config.pop(path[-1]) - if new_path is None: +def _move(config, path, new_path, result): + x = glom(config, path, default=MARKER) + if x is MARKER: return + rich.print(f"Moving {path}={x} to {new_path}={x}") + delete(result, path) + assign(result, new_path, x, missing=dict) - for k in new_path[:-1]: - if k not in result: - result[k] = {} - result = result[k] - result[new_path[-1]] = value +def _fix_input_0(result, config): + if isinstance(config["input"], dict): + return + input = config["input"] + new_input = result["input"] = [] -def _fix_dates(result, config): - dates = config["input"].pop("dates") - assert "join" in dates, dates - result["input"] = dates["join"] - config["input"] = result["input"].copy() + blocks = {} + first = None + for block in input: + assert isinstance(block, dict), block + assert len(block) == 1, block -def _fix_list(result, config): - result["input"] = dict(join=result["input"]) - config["input"] = result["input"].copy() + block_name, values = list(block.items())[0] + if "kwargs" in values: + inherit = values.pop("inherit", None) + assert len(values) == 
1, values + values = values["kwargs"] + values.pop("date", None) + source_name = values.pop("name", None) -def _fix_join_0(result, config): + if inherit is not None: + inherited = blocks[inherit].copy() + inherited.update(values) + values = inherited - join = config["input"]["join"] + if "source_or_dataset" in values: + values.pop("source_or_dataset", None) + values["template"] = "${input.join.0." + first + "}" - new_join = [] - for n in join: + if first is None: + first = source_name - if "function" in n: - f = n["function"] - name = f.pop("name") - data = _tidy(f) - for k, v in list(data.items()): - if isinstance(v, dict): - if "name" in v: - p = v.pop("name") - data[k] = {SOURCES.get(p, p): _tidy(v)} + blocks[block_name] = values.copy() - new_join.append({SOURCES.get(name, name): data}) - continue + new_input.append({block_name: {SOURCES.get(source_name, source_name): values.copy()}}) + else: + assert False, f"Block {block_name} does not have 'kwargs': {values}" - new_join.append(n) # {SOURCES.get(src, src): _tidy(data)}) + blocks[block_name] = values.copy() - result["input"] = dict(join=new_join) config["input"] = result["input"].copy() -def _fix_join_1(result, config): - - join = config["input"].pop("join") - new_join = [] - for n in join: - if isinstance(n, dict): - if len(n) == 1: - if "label" in n: - n = n["label"] - - if isinstance(n, dict): - if len(n) == 2: - if "name" in n and "source" in n: - n.pop("name") - - if isinstance(n, dict): - if len(n) == 1: - if "source" in n: - n = n["source"] - if "<<" in n: - n.update(n.pop("<<")) - name = n.pop("name", "mars") - new_join.append({SOURCES.get(name, name): _tidy(n)}) - continue - - new_join.append(n) - - result["input"] = dict(join=new_join) - config["input"] = result["input"].copy() - - -def _fix_join_3(result, config): - - join = config["input"].pop("join") - new_join = [] - for n in join: - if not isinstance(n, dict): - return - if len(n) != 1: - return - - name = list(n.keys())[0] - data = n[name] +def 
_fix_input_1(result, config): + if isinstance(config["input"], dict): + return - new_join.append({SOURCES.get(name, name): data}) + input = config["input"] + join = [] + for k in input: + assert isinstance(k, dict) + assert len(k) == 1, f"Input key {k} is not a string: {input}" + name, values = list(k.items())[0] + join.append(values) - result["input"] = dict(join=new_join) + result["input"] = {"join": join} config["input"] = result["input"].copy() -def _tidy(data): - for k, v in list(data.items()): - if k in ("date", "time"): - if isinstance(v, str) and v.startswith("$"): - del data[k] - - if "name" in data: - assert False, data - name = data.pop("name") - return {SOURCES.get(name, name): _tidy(data)} - - return data - +def remove_empties(config: dict) -> None: + """Remove empty dictionaries and lists from the config.""" + if isinstance(config, dict): + keys_to_delete = [k for k, v in config.items() if v in (None, {}, [], [{}])] -def _fix_join_2(result, config): + for k in keys_to_delete: + del config[k] - join = config["input"]["join"] + for k, v in config.items(): + remove_empties(v) - previous = {} + if isinstance(config, list): + for item in config: + remove_empties(item) - new_join = [] - for n in join: - if not isinstance(n, dict): - return - - if len(n) != 1: - return - - what = list(n.keys())[0] +def _fix_loops(result: dict, config: dict) -> None: + if "loops" not in config: + return - if n[what] is None: - assert False, (n, what, config["input"]) + input = config["input"] + loops = config["loops"] + + assert isinstance(loops, list), loops + assert isinstance(input, list), input + + entries = {} + dates_block = None + for loop in loops: + assert isinstance(loop, dict), loop + assert len(loop) == 1, loop + loop = list(loop.values())[0] + applies_to = loop["applies_to"] + dates = loop["dates"] + assert isinstance(applies_to, list), (applies_to, loop) + for a in applies_to: + entries[a] = dates.copy() + + if "start" in dates: + start = dates["start"] + else: 
+ start = max(dates["values"]) + + if "end" in dates or "stop" in dates: + end = dates.get("end", dates.get("stop")) + else: + end = min(dates["values"]) + + if dates_block is None: + dates_block = { + "start": start, + "end": end, + } + + if "frequency" in dates: + if "frequency" not in dates_block: + dates_block["frequency"] = dates["frequency"] + else: + assert dates_block["frequency"] == dates["frequency"], (dates_block["frequency"], dates["frequency"]) + + dates_block["start"] = min(dates_block["start"], start) + dates_block["end"] = max(dates_block["end"], end) + + concat = [] + result["input"] = {"concat": concat} + + rich.print("Found loops:", entries) + + for block in input: + assert isinstance(block, dict), block + assert len(block) == 1, block + name, values = list(block.items())[0] + assert name in entries, f"Loop {name} not found in loops: {list(entries.keys())}" + dates = entries[name].copy() + + assert "kwargs" not in values + + concat.append(dict(dates=dates, **values)) + + d = concat[0]["dates"] + if all(c["dates"] == d for c in concat): + join = [] + for c in concat: + del c["dates"] + join.append(c) + result["input"] = {"join": join} + + del config["loops"] + config["input"] = result["input"].copy() + config["dates"] = dates_block.copy() + del result["loops"] + result["dates"] = dates_block - if "kwargs" not in n[what]: - return - # assert False +def _fix_other(result: dict, config: dict) -> None: + paths = find_paths(config, target_key="source_or_dataset", target_value="$previous_data") + for p in paths: + rich.print(f"Fixing {'.'.join(p)}") + assign(result, ".".join(p[:-1] + ["template"]), "${input.join.0.mars}", missing=dict) + delete(result, ".".join(p)) - previous[what] = deepcopy(n[what]["kwargs"]) - if "inherit" in n[what]: - previous[what].update(deepcopy(previous[n[what]["inherit"]])) + paths = find_paths(config, target_key="date", target_value="$dates") + for p in paths: + delete(result, ".".join(p)) - data = previous[what].copy() - src 
= data.pop("name", "mars") - new_join.append({SOURCES.get(src, src): _tidy(data)}) +def _migrate(config: dict, n) -> dict: - result["input"] = dict(join=new_join) - config["input"] = result["input"].copy() + result = config.copy() + _fix_input_0(result, config) + _fix_loops(result, config) + _fix_input_1(result, config) + _fix_other(result, config) -def _migrate(config: dict, n) -> dict: - result = config.copy() for k, v in MIGRATE.items(): _move(config, k, v, result) - if "dates" in config["input"]: - _fix_dates(result, config) - - if isinstance(config["input"], list): - _fix_list(result, config) - - if "join" in config["input"]: - _fix_join_0(result, config) - - if "join" in config["input"]: - _fix_join_1(result, config) - - if "join" in config["input"]: - _fix_join_2(result, config) - - if "join" in config["input"]: - _fix_join_3(result, config) - - # _check(result, "1") - - # if isinstance(result["input"], list): - # assert n == 0 - # join = [] - # prev = {} - # for n in result["input"]: - # assert isinstance(n, dict), (n, type(n)) - # assert len(n) == 1, (n, type(n)) - # name = list(n.keys())[0] - # prev[name] = n[name]["kwargs"] - # if "inherit" in n[name]: - # i = n[name]["inherit"] - # n[name]["kwargs"].update(prev[i]) - # n[name].pop("inherit") - - # data = n[name]["kwargs"] - - # src = data.pop("name", "mars") - - # join.append({SOURCES.get(src, src): data}) - - # result["input"] = dict(join=join) - # _check(result, "2") - - # if "join" in result["input"] and n == 0: - # join = result["input"].pop("join") - # new_join = [] - - # for j in join: - - # if "label" in j: - # if isinstance(j["label"], str): - # j.pop("label") - # else: - # if j["label"] is not None: - # j = j["label"] - # j.pop("name", None) - - # if "source" in j: - # j = j["source"] - - # src = j.pop("name", "mars") - # data = j - # if "<<" in data: - # data.update(data.pop("<<")) + for k in DELETE: + _delete(config, k, result) - # for k, v in list(data.items()): - # if k in ("date", "time"): 
- # if isinstance(v, str) and v.startswith("$"): - # del data[k] - - # if "mars" in data: - # new_join.append(data) - # else: - # new_join.append({SOURCES.get(src, src): data}) - - # result["input"]["join"] = new_join - # _check(result, "3") - - # if "join" in result["input"]: - # for j in result["input"]["join"]: - # k = list(j.keys())[0] - # j[k].pop("name", None) - - # if "source_or_dataset" in j[k]: - # j[k].pop("source_or_dataset", None) - # j[k]["template"] = "${input.0.join.0.mars}" - # _check(result, "4") - - result = {k: v for k, v in sorted(result.items(), key=order) if v} - - result.pop("loop", None) + remove_empties(result) return result def migrate(old: dict) -> dict: - # return _migrate(old) + for i in range(10): new = _migrate(old, i) if new == old: - # print(json.dumps(new, indent=2, default=str)) return new old = new return new +def has_key(config, key: str) -> bool: + if isinstance(config, dict): + if key in config: + return True + for k, v in config.items(): + if has_key(v, key): + return True + if isinstance(config, list): + for item in config: + if has_key(item, key): + return True + return False + + +def has_value(config, value: str) -> bool: + if isinstance(config, dict): + for k, v in config.items(): + if v == value: + return True + if has_value(v, value): + return True + + if isinstance(config, list): + for item in config: + if item == value: + return True + if has_value(item, value): + return True + return config == value + + +def check(config): + from anemoi.datasets.create import validate_config + + try: + + validate_config(config) + assert config.get("input", {}) + assert config.get("dates", {}) + assert not has_key(config, "label") + assert not has_key(config, "kwargs") + assert not has_value(config, "$previous_data") + assert not has_value(config, "$dates") + assert not has_key(config, "inherit") + assert not has_key(config, "source_or_dataset") + assert not has_key(config, "<<") + + for n in SOURCES.keys(): + assert not 
has_key(config, n), f"Source {n} found in config. Please update to {SOURCES[n]}." + + except Exception as e: + rich.print(f"Validation failed: {e}") + rich.print(f"Config: {config}") + raise + + class Recipe(Command): def add_arguments(self, command_parser: Any) -> None: """Add arguments to the command parser. @@ -343,12 +460,39 @@ def add_arguments(self, command_parser: Any) -> None: ) def run(self, args: Any) -> None: + + rich.print(f"Migrating {args.path}") + with open(args.path, "r") as file: config = yaml.safe_load(file) - print(yaml.safe_dump(migrate(config), sort_keys=False, indent=2, width=120)) + try: + validate_config(config) + LOG.info(f"{args.path}: Validation successful.") + return + except Exception: + pass + + migrated = migrate(config) + + migrated = {k: v for k, v in sorted(migrated.items(), key=order) if v} + + check(migrated) + if migrated == config: + LOG.info(f"{args.path}: No changes needed.") + return + + migrated = make_dates(migrated) + text = yaml.dump(migrated, default_flow_style=False, sort_keys=False, indent=2, width=120, Dumper=MyDumper) + + LOG.info(f"{args.path}: updating.") + with open(args.path + ".tmp", "w") as f: + for i, line in enumerate(text.splitlines()): + if i and line and line[0] not in (" ", "-"): + line = "\n" + line + print(line, file=f) - assert not errors, f"Errors: {errors}" + os.rename(args.path + ".tmp", args.path) command = Recipe diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py new file mode 100644 index 000000000..84b25c6f8 --- /dev/null +++ b/src/anemoi/datasets/commands/validate.py @@ -0,0 +1,44 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from typing import Any + +import yaml + +from . import Command + +LOG = logging.getLogger(__name__) + + +class Validate(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. + """ + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + from anemoi.datasets.create import validate_config + + with open(args.path, "r") as file: + config = yaml.safe_load(file) + + validate_config(config) + + +command = Validate diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index b6b51cb69..37bf11763 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1631,6 +1631,47 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An return cls(**kwargs) +def validate_config(config: Any) -> None: + + import json + + import jsonschema + + def _tidy(d): + if isinstance(d, dict): + return {k: _tidy(v) for k, v in d.items()} + + if isinstance(d, list): + return [_tidy(v) for v in d if v is not None] + + # jsonschema does not support datetime.date + if isinstance(d, datetime.datetime): + return d.isoformat() + + if isinstance(d, datetime.date): + return d.isoformat() + + return d + + # https://json-schema.org + + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "schemas", + "recipe.json", + ) + ) as f: + schema = json.load(f) + + try: + jsonschema.validate(instance=_tidy(config), schema=schema) + except jsonschema.exceptions.ValidationError as e: + LOG.error("❌ Config validation failed (jsonschema):") + LOG.error(e.message) + raise + + def 
config_to_python(config: Any) -> Any: config = loader_config(config) diff --git a/src/anemoi/datasets/schemas/recipe.json b/src/anemoi/datasets/schemas/recipe.json new file mode 100644 index 000000000..3c02bfd64 --- /dev/null +++ b/src/anemoi/datasets/schemas/recipe.json @@ -0,0 +1,131 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$id": "https://ecmwf.int/anemoi-datasets-recipe.schema.json", + "title": "Product", + "description": "Anemoi datasets recipe configuration", + "additionalProperties": false, + "$defs": { + "source-or-filter": { + "type": "object", + "minProperties": 1, + "maxProperties": 1 + }, + "pipe": { + "type": "array", + "items": { + "$ref": "#/$defs/input-object" + } + }, + "join": { + "type": "array", + "items": { + "$ref": "#/$defs/input-object" + } + }, + "concat": { + "type": "array", + "items": { + "type": "object", + "minProperties": 2, + "maxProperties": 2, + "required": [ + "dates" + ] + } + }, + "input-object": { + "oneOf": [ + { + "$ref": "#/$defs/pipe" + }, + { + "$ref": "#/$defs/join" + }, + { + "$ref": "#/$defs/concat" + }, + { + "$ref": "#/$defs/source-or-filter" + } + ] + } + }, + "properties": { + "env": { + "type": "object" + }, + "description": { + "type": "string" + }, + "name": { + "type": "string" + }, + "licence": { + "type": "string" + }, + "attribution": { + "type": "string" + }, + "dates": { + "type": "object", + "required": [ + "start", + "end" + ], + "properties": { + "start": { + "type": "string", + "format": "date" + }, + "end": { + "type": "string", + "format": "date" + }, + "frequency": { + "type": [ + "integer", + "string" + ] + }, + "group_by": { + "type": [ + "integer", + "string" + ] + } + } + }, + "input": { + "$ref": "#/$defs/input-object" + }, + "data_sources": { + "type": "object", + "patternProperties": { + "^[a-zA-Z_][a-zA-Z0-9_]*$": { + "$ref": "#/$defs/input-object" + } + }, + "additionalProperties": false + }, + "output": { + "type": "object" + }, + 
"statistics": { + "type": "object" + }, + "build": { + "type": "object" + }, + "common": { + "type": "object" + }, + "platform": { + "type": "object" + } + }, + "required": [ + "dates", + "input" + ] +} From 3e180f93d579df109027843607a1ceb9791b4dfc Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 10 Jul 2025 14:53:05 +0000 Subject: [PATCH 27/79] work on migrate --- src/anemoi/datasets/commands/migrate.py | 101 +++++++++++++++++++++++- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index 8d4b359ac..fa74886b8 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -11,7 +11,6 @@ import datetime import logging import os -from collections.abc import Mapping from collections.abc import Sequence from typing import Any @@ -36,7 +35,7 @@ def find_paths(data, target_key=None, target_value=None, *path): matches = [] - if isinstance(data, Mapping): + if isinstance(data, dict): for k, v in data.items(): if (target_key is not None and k == target_key) or (target_value is not None and v == target_value): matches.append(list(path) + [k]) @@ -47,6 +46,21 @@ def find_paths(data, target_key=None, target_value=None, *path): return matches +def find_chevrons(data, *path): + + matches = [] + + if isinstance(data, dict): + for k, v in data.items(): + if k == "<<": + matches.append(list(path) + [k]) + matches.extend(find_chevrons(v, *path, k)) + elif isinstance(data, list): + for i, item in enumerate(data): + matches.extend(find_chevrons(item, *path, str(i))) + return matches + + # Custom representer for datetime.date and datetime.datetime def represent_date(dumper, data): if isinstance(data, datetime.date) and not isinstance(data, datetime.datetime): @@ -63,7 +77,9 @@ def represent_date(dumper, data): # Custom representer for multiline strings using the '|' block style def represent_multiline_str(dumper, data): if "\n" in data: - return 
dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + text_list = [line.rstrip() for line in data.splitlines()] + fixed_data = "\n".join(text_list) + return dumper.represent_scalar("tag:yaml.org,2002:str", fixed_data, style="|") return dumper.represent_scalar("tag:yaml.org,2002:str", data) @@ -358,6 +374,82 @@ def _fix_other(result: dict, config: dict) -> None: delete(result, ".".join(p)) +def _fix_join(result: dict, config: dict) -> None: + rich.print("Fixing join...") + input = config["input"] + if "dates" in input and "join" in input["dates"]: + result["input"]["join"] = input["dates"]["join"] + config["input"]["join"] = input["dates"]["join"].copy() + + if "join" not in input: + return + + join = input["join"] + new_join = [] + for j in join: + assert isinstance(j, dict) + assert len(j) == 1 + + key, values = list(j.items())[0] + + if key not in ("label", "source"): + return + + assert isinstance(values, dict), f"Join values for {key} should be a dict: {values}" + if key == "label": + j = values + j.pop("name") + key, values = list(j.items())[0] + + print(values) + source_name = values.pop("name", "mars") + new_join.append( + { + SOURCES.get(source_name, source_name): values, + } + ) + + result["input"] = {"join": new_join} + config["input"] = result["input"].copy() + + +def _fix_sources(result: dict, config: dict, what) -> None: + + input = config["input"] + if what not in input: + return + + join = input[what] + new_join = [] + for j in join: + assert isinstance(j, dict) + assert len(j) == 1 + + key, values = list(j.items())[0] + + key = SOURCES.get(key, key) + + new_join.append( + { + key: values, + } + ) + + result["input"][what] = new_join + config["input"][what] = new_join.copy() + + +def _fix_chevrons(result: dict, config: dict) -> None: + rich.print("Fixing chevrons...") + paths = find_chevrons(config) + for p in paths: + a = glom(config, ".".join(p)) + b = glom(config, ".".join(p[:-1])) + delete(result, ".".join(p)) + a.update(b) + 
assign(result, ".".join(p[:-1]), a) + + def _migrate(config: dict, n) -> dict: result = config.copy() @@ -365,6 +457,9 @@ def _migrate(config: dict, n) -> dict: _fix_input_0(result, config) _fix_loops(result, config) _fix_input_1(result, config) + _fix_join(result, config) + _fix_sources(result, config, "join") + _fix_chevrons(result, config) _fix_other(result, config) for k, v in MIGRATE.items(): From 255c22df684878f2379f7da32c35d609cabfc577 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 11 Aug 2025 20:05:00 +0200 Subject: [PATCH 28/79] merge --- CHANGELOG.md | 26 + docs/adr/adr-1.md | 63 ++ .../building/sources/anemoi-dataset.rst | 6 + .../datasets/building/sources/xarray-zarr.rst | 6 + .../sources/yaml/anemoi-zarr-dataset.yaml | 3 + docs/datasets/using/combining.rst | 4 +- docs/howtos/create/02-cf-data.rst | 9 + docs/howtos/create/yaml/zarr2.yaml | 8 + docs/howtos/usage/code/cutout-complement1.py | 1 + .../howtos/usage/yaml/cutout-complement1.yaml | 1 + pyproject.toml | 14 +- .../datasets/commands/finalise-additions.py | 3 +- src/anemoi/datasets/commands/finalise.py | 3 +- .../datasets/commands/init-additions.py | 3 +- .../datasets/commands/load-additions.py | 3 +- src/anemoi/datasets/commands/load.py | 3 +- src/anemoi/datasets/commands/recipe.py | 6 +- src/anemoi/datasets/create/__init__.py | 86 ++- src/anemoi/datasets/create/input/__init__.py | 86 +-- src/anemoi/datasets/create/input/action.py | 542 +++++++----------- src/anemoi/datasets/create/input/concat.py | 184 ------ src/anemoi/datasets/create/input/context.py | 89 --- .../datasets/create/input/context/__init__.py | 63 ++ .../datasets/create/input/context/field.py | 54 ++ .../datasets/create/input/data_sources.py | 11 +- src/anemoi/datasets/create/input/empty.py | 54 -- src/anemoi/datasets/create/input/filter.py | 133 ----- src/anemoi/datasets/create/input/function.py | 244 -------- src/anemoi/datasets/create/input/join.py | 137 ----- src/anemoi/datasets/create/input/pipe.py | 77 --- 
.../datasets/create/input/repeated_dates.py | 28 +- .../datasets/create/input/result/__init__.py | 17 + .../input/{result.py => result/field.py} | 98 +--- src/anemoi/datasets/create/input/step.py | 203 ------- src/anemoi/datasets/create/input/template.py | 162 ------ src/anemoi/datasets/create/python.py | 174 ++++++ .../datasets/create/sources/accumulations.py | 13 +- .../datasets/create/sources/constants.py | 2 +- .../datasets/create/sources/forcings.py | 2 +- src/anemoi/datasets/create/sources/grib.py | 2 +- src/anemoi/datasets/create/sources/legacy.py | 9 +- .../datasets/create/sources/patterns.py | 2 +- .../create/sources/planetary_computer.py | 44 ++ .../datasets/create/sources/repeated_dates.py | 319 +++++++++++ .../create/sources/xarray_support/__init__.py | 28 +- .../sources/xarray_support/coordinates.py | 8 + .../create/sources/xarray_support/field.py | 5 +- .../create/sources/xarray_support/flavour.py | 50 +- .../create/sources/xarray_support/patch.py | 45 +- .../create/sources/xarray_support/variable.py | 8 +- src/anemoi/datasets/data/complement.py | 54 +- src/anemoi/datasets/data/dataset.py | 29 + src/anemoi/datasets/data/forwards.py | 10 +- src/anemoi/datasets/data/misc.py | 90 ++- .../datasets/data/observations/__init__.py | 316 ++++++++++ .../data/observations/legacy_obs_dataset.py | 200 +++++++ .../datasets/data/observations/multi.py | 64 +++ src/anemoi/datasets/data/padded.py | 227 ++++++++ src/anemoi/datasets/data/records/__init__.py | 442 ++++++++++++++ .../data/records/backends/__init__.py | 157 +++++ src/anemoi/datasets/data/stores.py | 63 +- src/anemoi/datasets/data/subset.py | 5 + src/anemoi/datasets/grids.py | 9 +- tests/conftest.py | 1 + tests/create/__init__.py | 0 tests/create/test_create.py | 369 +----------- tests/create/test_sources.py | 179 +++++- tests/create/utils/__init__.py | 0 tests/create/utils/compare.py | 218 +++++++ .../create/utils/create.py | 0 tests/create/utils/mock_sources.py | 117 ++++ tests/test_data.py | 26 +- 
tests/test_records.py | 160 ++++++ tests/xarray/test_flavour.py | 104 ++++ tests/xarray/test_zarr.py | 28 - tools/build-obs.py | 52 ++ tools/check-obs.py | 60 ++ 77 files changed, 3726 insertions(+), 2395 deletions(-) create mode 100644 docs/adr/adr-1.md create mode 100644 docs/datasets/building/sources/yaml/anemoi-zarr-dataset.yaml create mode 100644 docs/howtos/create/yaml/zarr2.yaml delete mode 100644 src/anemoi/datasets/create/input/concat.py delete mode 100644 src/anemoi/datasets/create/input/context.py create mode 100644 src/anemoi/datasets/create/input/context/__init__.py create mode 100644 src/anemoi/datasets/create/input/context/field.py delete mode 100644 src/anemoi/datasets/create/input/empty.py delete mode 100644 src/anemoi/datasets/create/input/filter.py delete mode 100644 src/anemoi/datasets/create/input/function.py delete mode 100644 src/anemoi/datasets/create/input/join.py delete mode 100644 src/anemoi/datasets/create/input/pipe.py create mode 100644 src/anemoi/datasets/create/input/result/__init__.py rename src/anemoi/datasets/create/input/{result.py => result/field.py} (87%) delete mode 100644 src/anemoi/datasets/create/input/step.py delete mode 100644 src/anemoi/datasets/create/input/template.py create mode 100644 src/anemoi/datasets/create/python.py create mode 100644 src/anemoi/datasets/create/sources/planetary_computer.py create mode 100644 src/anemoi/datasets/create/sources/repeated_dates.py create mode 100644 src/anemoi/datasets/data/observations/__init__.py create mode 100644 src/anemoi/datasets/data/observations/legacy_obs_dataset.py create mode 100644 src/anemoi/datasets/data/observations/multi.py create mode 100644 src/anemoi/datasets/data/padded.py create mode 100644 src/anemoi/datasets/data/records/__init__.py create mode 100644 src/anemoi/datasets/data/records/backends/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/create/__init__.py create mode 100644 tests/create/utils/__init__.py create mode 100644 
tests/create/utils/compare.py rename src/anemoi/datasets/create/testing.py => tests/create/utils/create.py (100%) create mode 100644 tests/create/utils/mock_sources.py create mode 100644 tests/test_records.py create mode 100644 tests/xarray/test_flavour.py create mode 100755 tools/build-obs.py create mode 100755 tools/check-obs.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 73cde7091..226403999 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! +## [0.5.25](https://github.com/ecmwf/anemoi-datasets/compare/0.5.24...0.5.25) (2025-06-11) + + +### Features + +* Integrating non-regular datasets in anemoi for observations. ([#306](https://github.com/ecmwf/anemoi-datasets/issues/306)) ([95a0fe4](https://github.com/ecmwf/anemoi-datasets/commit/95a0fe4bb10dc48469c0be0efad94f4d5e2a9fe8)) + + +### Bug Fixes + +* Incremental dataset build tasks called regardless of presence of debug flag in CLI code ([#294](https://github.com/ecmwf/anemoi-datasets/issues/294)) ([37afc0d](https://github.com/ecmwf/anemoi-datasets/commit/37afc0d6489f2d6c4b3ce3f9901c40e4cec5c4eb)) +* Regression in accumulations [#354](https://github.com/ecmwf/anemoi-datasets/issues/354) ([#355](https://github.com/ecmwf/anemoi-datasets/issues/355)) ([f9769d7](https://github.com/ecmwf/anemoi-datasets/commit/f9769d7944738ecbedb6b3cc1f78cd26de36a73f)) +* Remove 2 layers of build function ([#348](https://github.com/ecmwf/anemoi-datasets/issues/348)) ([7a904c4](https://github.com/ecmwf/anemoi-datasets/commit/7a904c451772089f120419a9d39bff746e0aeebb)) + +## [0.5.24](https://github.com/ecmwf/anemoi-datasets/compare/0.5.23...0.5.24) (2025-05-23) + + +### Features + +* verify command ([#279](https://github.com/ecmwf/anemoi-datasets/issues/279)) 
([aed36d2](https://github.com/ecmwf/anemoi-datasets/commit/aed36d2ea7a39ea1ae6bbd5f8d01ef0ce7523cde)) + + +### Bug Fixes + +* adapt to earthkit 0.14 (ignore_keys issue) ([#331](https://github.com/ecmwf/anemoi-datasets/issues/331)) ([fb3ab8d](https://github.com/ecmwf/anemoi-datasets/commit/fb3ab8d46b8e00c62d8d7cbb1d1afae0efea2054)) + ## [0.5.23](https://github.com/ecmwf/anemoi-datasets/compare/0.5.22...0.5.23) (2025-05-07) diff --git a/docs/adr/adr-1.md b/docs/adr/adr-1.md new file mode 100644 index 000000000..6e22b2aa2 --- /dev/null +++ b/docs/adr/adr-1.md @@ -0,0 +1,63 @@ +# Support irregular observations datasets + +## Status + + + +Proposed - 30/04/2025 + +## Context + + + +The objective of this change is to support observations data which is not regular. + +In contrast with the fields data where each date contain the same number of points, +in the observations data, the number of points can change for every time window. + +The Zarr format fits well the fields data, but does not fit the observations data. + +To allow storing data with irregular shape, we need to use another format than the zarr used for fields. +An experimental implementation using xarray-zarr has been developed and is not optimised for ML training. + +## Decision + + + +Add a functionality in anemoi-datasets to read observations datasets and provide the data as dictionary/mapper of numpy arrays. + +Mimic as much as possible what is done for field datasets : + +`ds = open_dataset(....)` +`ds[i]` -> provides the data for a given time window, related to a given reference date. As a dictionary-like object. +`ds.dates` -> list of reference date, `ds.dates[i]` is the reference date for the data provided in `ds[i]` + +Also expose the latitudes, longitudes in a sensible way (as `ds.latitudes` and `ds.longitudes` now depend on the dates) and name_to_index and statistics and metadata, etc. + +These API choices need to be made on an actual training use case. 
+ +Step to achieve this: +- Implement now a prototype format to allow developing ML training code on observation data. +- Performing extensive benchmarking with various formats (explore parquet, and other). +- As the final format is not defined yet, ensure a flexible architecture to allow switching (this will help for benchmarking). + + +## Scope of Change + + +- anemoi-datasets +Not a breaking change, this only add functionality to read observations datasets. + +Must be in line with the change related to multi-datasets. + +## Consequences + + + +## Alternatives Considered [Optional] + + + +## References [Optional] + + diff --git a/docs/datasets/building/sources/anemoi-dataset.rst b/docs/datasets/building/sources/anemoi-dataset.rst index d279614cf..a8e336318 100644 --- a/docs/datasets/building/sources/anemoi-dataset.rst +++ b/docs/datasets/building/sources/anemoi-dataset.rst @@ -17,3 +17,9 @@ An anemoi-dataset can be a source for a dataset: The parameters are the same as those used in the ``open_dataset`` function, which allows you to subset and combine datasets. See :ref:`opening-datasets` for more information. + +In particular, this is how local zarr datasets created with anemoi in a +can be used as a source, contrary to :ref:`xarray-zarr` : + +.. literalinclude:: yaml/anemoi-zarr-dataset.yaml + :language: yaml diff --git a/docs/datasets/building/sources/xarray-zarr.rst b/docs/datasets/building/sources/xarray-zarr.rst index 84f5158df..0f9ce62c8 100644 --- a/docs/datasets/building/sources/xarray-zarr.rst +++ b/docs/datasets/building/sources/xarray-zarr.rst @@ -1,3 +1,5 @@ +.. _xarray-zarr: + ############# xarray-zarr ############# @@ -17,4 +19,8 @@ it is necessary to use the :ref:`join ` operation to join separate lists containing 2D variables and 3D variables. If all vertical levels are desired, then it is acceptable to specify a single source. +Also, an ``xarray-zarr`` source uses the ``url`` keyword, and cannot be +used for accessing local datasets. 
For using local zarr datasets as +sources, use instead :ref:`anemoi-dataset_source`. + See :ref:`create-cf-data` for more information. diff --git a/docs/datasets/building/sources/yaml/anemoi-zarr-dataset.yaml b/docs/datasets/building/sources/yaml/anemoi-zarr-dataset.yaml new file mode 100644 index 000000000..1aa7b4f7e --- /dev/null +++ b/docs/datasets/building/sources/yaml/anemoi-zarr-dataset.yaml @@ -0,0 +1,3 @@ +input: + anemoi-dataset: + dataset: path/to/dataset.zarr diff --git a/docs/datasets/using/combining.rst b/docs/datasets/using/combining.rst index 4eac4c500..358953a26 100644 --- a/docs/datasets/using/combining.rst +++ b/docs/datasets/using/combining.rst @@ -235,13 +235,15 @@ variables of `dataset1` and return the result. source=dataset2, what="variables", interpolate="nearest", + k=1, ) Currently ``what`` can only be ``variables`` and can be omitted. The value for ``interpolate`` can be one of ``none`` (default) or ``nearest``. In the case of ``none``, the grids of the two datasets must -match. +match. In case of ``interpolate``, an additional parameter ``k`` can be +set to specify the number of nearest neighbors to use. This feature was originally designed to be used in conjunction with ``cutout``, where `dataset1` is the lam, and `dataset2` is the global diff --git a/docs/howtos/create/02-cf-data.rst b/docs/howtos/create/02-cf-data.rst index 38050ba42..ba875e3a8 100644 --- a/docs/howtos/create/02-cf-data.rst +++ b/docs/howtos/create/02-cf-data.rst @@ -46,9 +46,18 @@ can contain patterns. See :ref:`file-pattern` for more information. Zarr ****** +For using remote hosted zarr datasets as sources, use +:ref:`xarray-zarr`. + .. literalinclude:: yaml/zarr1.yaml :language: yaml +For using local zarr datasets (such as anemoi-generated datasets), use +:ref:`anemoi-dataset_source`. + +.. 
literalinclude:: yaml/zarr2.yaml + :language: yaml + ********************************************* Handling data that is not 100% CF-compliant ********************************************* diff --git a/docs/howtos/create/yaml/zarr2.yaml b/docs/howtos/create/yaml/zarr2.yaml new file mode 100644 index 000000000..659be1c72 --- /dev/null +++ b/docs/howtos/create/yaml/zarr2.yaml @@ -0,0 +1,8 @@ +dates: + start: 2023-01-01T00:00:00 + end: 2023-01-02T18:00:00 + frequency: 6h + +input: + anemoi-dataset: + dataset: /path/to/input.zarr diff --git a/docs/howtos/usage/code/cutout-complement1.py b/docs/howtos/usage/code/cutout-complement1.py index 347ae420c..dd9106fa9 100644 --- a/docs/howtos/usage/code/cutout-complement1.py +++ b/docs/howtos/usage/code/cutout-complement1.py @@ -14,4 +14,5 @@ }, source="global-dataset", interpolation="nearest", + k=1, ) diff --git a/docs/howtos/usage/yaml/cutout-complement1.yaml b/docs/howtos/usage/yaml/cutout-complement1.yaml index 97109fb5b..b2dfb605d 100644 --- a/docs/howtos/usage/yaml/cutout-complement1.yaml +++ b/docs/howtos/usage/yaml/cutout-complement1.yaml @@ -9,3 +9,4 @@ dataset: adjust: dates source: global-dataset interpolation: nearest + k: 1 diff --git a/pyproject.toml b/pyproject.toml index 799c26daa..873415a24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,8 @@ dynamic = [ "version", ] dependencies = [ - "anemoi-transform>=0.1.9", - "anemoi-utils[provenance]>=0.4.21", + "anemoi-transform>=0.1.10", + "anemoi-utils[provenance]>=0.4.26", "cfunits", "numcodecs<0.16", # Until we move to zarr3 "numpy", @@ -71,7 +71,7 @@ optional-dependencies.comparelam = [ optional-dependencies.create = [ "cachetools", - "earthkit-data[mars]>=0.12.4,<0.14", + "earthkit-data[mars]>=0.14", "earthkit-geo>=0.3", "earthkit-meteo>=0.3", "eccodes>=2.39.1", @@ -100,6 +100,7 @@ optional-dependencies.remote = [ optional-dependencies.tests = [ "anemoi-datasets[xarray]", "pytest", + "pytest-xdist", ] optional-dependencies.xarray = [ @@ -131,6 
+132,13 @@ version_file = "src/anemoi/datasets/_version.py" [tool.isort] profile = "black" +[tool.pytest.ini_options] +testpaths = "tests" +addopts = [ + "--numprocesses=auto", + "--strict-config", +] + [tool.mypy] strict = false exclude = [ diff --git a/src/anemoi/datasets/commands/finalise-additions.py b/src/anemoi/datasets/commands/finalise-additions.py index a155f16e3..811760ae9 100644 --- a/src/anemoi/datasets/commands/finalise-additions.py +++ b/src/anemoi/datasets/commands/finalise-additions.py @@ -61,7 +61,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/finalise.py b/src/anemoi/datasets/commands/finalise.py index ee9fa1c00..53999ad50 100644 --- a/src/anemoi/datasets/commands/finalise.py +++ b/src/anemoi/datasets/commands/finalise.py @@ -55,7 +55,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/init-additions.py b/src/anemoi/datasets/commands/init-additions.py index 54a48ed5f..09788f0e4 100644 --- a/src/anemoi/datasets/commands/init-additions.py +++ b/src/anemoi/datasets/commands/init-additions.py @@ -61,7 +61,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/load-additions.py b/src/anemoi/datasets/commands/load-additions.py index 68aaa601a..a8cd5d5c9 100644 --- a/src/anemoi/datasets/commands/load-additions.py +++ b/src/anemoi/datasets/commands/load-additions.py @@ -62,7 +62,8 @@ def run(self, args: Any) -> None: if "debug" in options: 
options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/load.py b/src/anemoi/datasets/commands/load.py index 74220025f..3d969f5c3 100644 --- a/src/anemoi/datasets/commands/load.py +++ b/src/anemoi/datasets/commands/load.py @@ -62,7 +62,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/recipe.py b/src/anemoi/datasets/commands/recipe.py index 3045f1a1b..f111aee1b 100644 --- a/src/anemoi/datasets/commands/recipe.py +++ b/src/anemoi/datasets/commands/recipe.py @@ -28,6 +28,9 @@ def add_arguments(self, command_parser: Any) -> None: command_parser : Any Command parser object. """ + + command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") + command_parser.add_argument( "path", help="Path to recipe.", @@ -38,7 +41,8 @@ def run(self, args: Any) -> None: with open(args.path, "r") as file: config = yaml.safe_load(file) - config = migrate(config) + if args.migrate: + config = migrate(config) print(config_to_python(config)) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 37bf11763..69b5a0d42 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -11,7 +11,6 @@ import json import logging import os -import re import time import uuid import warnings @@ -22,6 +21,7 @@ import cftime import numpy as np +import rich import tqdm import zarr from anemoi.utils.dates import as_datetime @@ -45,7 +45,7 @@ from .chunks import ChunkFilter from .config import build_output from .config import loader_config -from .input import build_input +from .input import InputBuilder from .statistics import Summary from 
.statistics import TmpStatistics from .statistics import check_variance @@ -102,7 +102,9 @@ def json_tidy(o: Any) -> Any: def build_statistics_dates( - dates: list[datetime.datetime], start: Optional[datetime.datetime], end: Optional[datetime.datetime] + dates: list[datetime.datetime], + start: Optional[datetime.datetime], + end: Optional[datetime.datetime], ) -> tuple[str, str]: """Compute the start and end dates for the statistics. @@ -552,36 +554,16 @@ def create_elements(self, config: Any) -> None: self.output = build_output(config.output, parent=self) - self.input = build_input_(main_config=config, output_config=self.output) - # LOG.info("%s", self.input) - - -def build_input_(main_config: Any, output_config: Any) -> Any: - """Build the input for the dataset. - - Parameters - ---------- - main_config : Any - The main configuration. - output_config : Any - The output configuration. - - Returns - ------- - Any - The input builder. - """ - builder = build_input( - main_config.input, - data_sources=main_config.get("data_sources", {}), - order_by=output_config.order_by, - flatten_grid=output_config.flatten_grid, - remapping=build_remapping(output_config.remapping), - use_grib_paramid=main_config.build.use_grib_paramid, - ) - LOG.debug("✅ INPUT_BUILDER") - LOG.debug(builder) - return builder + self.input = InputBuilder( + config.input, + data_sources=config.get("data_sources", {}), + order_by=self.output.order_by, + flatten_grid=self.output.flatten_grid, + remapping=build_remapping(self.output.remapping), + use_grib_paramid=config.build.use_grib_paramid, + ) + LOG.debug("✅ INPUT_BUILDER") + LOG.debug(self.input) class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): @@ -690,6 +672,8 @@ def _run(self) -> int: LOG.info(f"Missing dates: {len(missing)}") lengths = tuple(len(g) for g in self.groups) + rich.print("Minimal input dates:", self.minimal_input) + variables = self.minimal_input.variables LOG.info(f"Found {len(variables)} variables : 
{','.join(variables)}.") @@ -886,7 +870,7 @@ def _run(self) -> None: # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(group_of_dates=group) + result = self.input.select(argument=group) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. @@ -1542,7 +1526,16 @@ def run(self) -> None: if not all(self.registry.get_flags(sync=False)): raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") - for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: + for k in [ + "mean", + "stdev", + "minimum", + "maximum", + "sums", + "squares", + "count", + "has_nans", + ]: self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) self.registry.add_to_history("compute_statistics_end") @@ -1674,24 +1667,11 @@ def _tidy(d): def config_to_python(config: Any) -> Any: - config = loader_config(config) - - input = build_input_(config, build_output(config.output, None)) + from ..create.python import PythonCode - prelude = [] - input.python_prelude(prelude) - code1 = "\n".join(prelude) - - code2 = input.to_python() - - code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" - - code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) + config = loader_config(config) - try: - import black + input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - return black.format_str(code, mode=black.Mode()) - except ImportError: - LOG.warning("Black not installed, skipping formatting") - return code + code = PythonCode() + return input.python_code(code) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 66266d53d..63020324c 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ 
b/src/anemoi/datasets/create/input/__init__.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024-2025 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,22 +7,12 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import logging from copy import deepcopy +from functools import cached_property from typing import Any from typing import Union -from anemoi.datasets.dates.groups import GroupOfDates - -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class Context: - """Context for building input data.""" - - pass +from anemoi.datasets.create.input.context.field import FieldContext class InputBuilder: @@ -51,15 +41,20 @@ def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) ) ) self.config = config - self.action_path = ["input"] - @trace_select - def select(self, group_of_dates: GroupOfDates) -> Any: + @cached_property + def action(self) -> Any: + """Returns the action object based on the configuration.""" + from .action import action_factory + + return action_factory(self.config, "input") + + def select(self, argument) -> Any: """Select data based on the group of dates. Parameters ---------- - group_of_dates : GroupOfDates + argument : GroupOfDates Group of dates to select data for. Returns @@ -67,60 +62,11 @@ def select(self, group_of_dates: GroupOfDates) -> Any: Any Selected data. 
""" - from .action import ActionContext - from .action import action_factory - - """This changes the context.""" - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.select(group_of_dates) - - def to_python(self) -> str: - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - - return action.to_python() - - def python_prelude(self, prelude) -> str: - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.python_prelude(prelude) - - def __repr__(self) -> str: - """Return a string representation of the InputBuilder. - - Returns - ------- - str - String representation. - """ - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - a = action_factory(self.config, context, self.action_path) - return repr(a) + context = FieldContext(argument, **self.kwargs) + return context.create_result(self.action(context, argument)) - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the select operation. - - Parameters - ---------- - group_of_dates : GroupOfDates - Group of dates to select data for. - - Returns - ------- - str - Trace string. 
- """ - return f"InputBuilder({group_of_dates})" + def python_code(self, code): + return self.action.python_code(code) def build_input(config: dict, data_sources: Union[dict, list], **kwargs: Any) -> InputBuilder: diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 121d1e387..8dadf14dc 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -7,338 +7,230 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import datetime -import json import logging -import re -from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from anemoi.utils.dates import frequency_to_string -from earthkit.data.core.order import build_remapping +import rich -from ...dates.groups import GroupOfDates -from .context import Context -from .template import substitute +from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) class Action: - """Represents an action to be performed within a given context. - - Attributes - ---------- - context : ActionContext - The context in which the action exists. - kwargs : Dict[str, Any] - Additional keyword arguments. - args : Any - Additional positional arguments. - action_path : List[str] - The action path. - """ - - def __init__( - self, context: "ActionContext", action_path: List[str], /, *args: Any, **kwargs: Dict[str, Any] - ) -> None: - """Initialize an Action instance. - - Parameters - ---------- - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - args : Any - Additional positional arguments. - kwargs : Dict[str, Any] - Additional keyword arguments. - """ - if "args" in kwargs and "kwargs" in kwargs: - """We have: - args = [] - kwargs = {args: [...], kwargs: {...}} - move the content of kwargs to args and kwargs. 
- """ - assert len(kwargs) == 2, (args, kwargs) - assert not args, (args, kwargs) - args = kwargs.pop("args") - kwargs = kwargs.pop("kwargs") - - assert isinstance(context, ActionContext), type(context) - self.context = context - self.kwargs = kwargs - self.args = args - self.action_path = action_path - - @classmethod - def _short_str(cls, x: str) -> str: - """Shorten the string representation if it exceeds 1000 characters. - - Parameters - ---------- - x : str - The string to shorten. - - Returns - ------- - str - The shortened string. - """ - x = str(x) - if len(x) < 1000: - return x - return x[:1000] + "..." - - def _repr(self, *args: Any, _indent_: str = "\n", _inline_: str = "", **kwargs: Any) -> str: - """Generate a string representation of the Action instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str, optional - The indentation string, by default "\n". - _inline_ : str, optional - The inline string, by default "". - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - more = more[:5000] - txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Action instance. - - Returns - ------- - str - The string representation. - """ - return self._repr() - - def select(self, dates: object, **kwargs: Any) -> None: - """Select dates for the action. - - Parameters - ---------- - dates : object - The dates to select. - kwargs : Any - Additional keyword arguments. 
- """ - self._raise_not_implemented() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the selection of a group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates to trace. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({group_of_dates})" - - def _to_python(self, name: str, config: dict, **extra: Any) -> str: - """Convert the action to Python code. - - Parameters - ---------- - name : str - The name of the action. - config : dict - The configuration for the action. - extra : Any - Additional keyword arguments. - - Returns - ------- - str - The Python code representation of the action. - """ - import json - - RESERVED_KEYWORDS = ( - "and", - "or", - "not", - "is", - "in", - "if", - "else", - "elif", - "for", - "while", - "return", - "class", - "def", - "with", - "as", - "import", - "from", - "try", - "except", - "finally", - "raise", - "assert", - "break", - "continue", - "pass", + def __init__(self, config, *path): + self.config = config + self.path = path + # rich.print(f"Creating {self.__class__.__name__} {'.'.join(x for x in self.path)} from {config}") + + +class Concat(Action): + def __init__(self, config, *path): + super().__init__(config, *path) + + assert isinstance(config, list), f"Value must be a dict {list}" + + self.choices = [] + + for item in config: + + assert "dates" in item, f"Value must contain the key 'dates' {item}" + dates = item["dates"] + filtering_dates = DatesProvider.from_config(**dates) + action = action_factory({k: v for k, v in item.items() if k != "dates"}) + self.choices.append((filtering_dates, action)) + + def __repr__(self): + return f"Concat({self.choices})" + + def __call__(self, context, argument): + + results = 
context.empty_result() + + for filtering_dates, action in self.choices: + dates = context.matching_dates(filtering_dates, argument) + if len(dates) == 0: + continue + results += action(context, dates) + + return context.register(results, self.path) + + def python_code(self, code): + return code.concat( + {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} ) - def convert(obj): - if isinstance(obj, datetime.datetime): - return obj.isoformat() - if isinstance(obj, datetime.date): - return obj.isoformat() - if isinstance(obj, datetime.timedelta): - return frequency_to_string(obj) - raise TypeError(f"Object of type {type(obj)} is not JSON serializable") - - config = json.loads(json.dumps(config, default=convert)) - - assert len(config) == 1, (name, config) - assert name in config, (name, config) - - config = config[name] - - params = [] - for k, v in config.items(): - if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: - return f"r.{name}({config})" - params.append(f"{k}={repr(v)}") - - for k, v in extra.items(): - params.append(f"{k}={v}") - - params = ",".join(params) - return f"r.{name}({params})" - # return f"{name}({config})" - - -class ActionContext(Context): - """Represents the context in which an action is performed. - - Attributes - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. - """ - - def __init__(self, /, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool) -> None: - """Initialize an ActionContext instance. - - Parameters - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. 
- """ - super().__init__() - self.order_by = order_by - self.flatten_grid = flatten_grid - self.remapping = build_remapping(remapping) - self.use_grib_paramid = use_grib_paramid - - -def action_factory(config: Dict[str, Any], context: ActionContext, action_path: List[str]) -> Action: - """Factory function to create an Action instance based on the configuration. - - Parameters - ---------- - config : Dict[str, Any] - The action configuration. - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - - Returns - ------- - Action - The created Action instance. - """ - from .concat import ConcatAction - from .data_sources import DataSourcesAction - from .function import FunctionAction - from .join import JoinAction - from .pipe import PipeAction - from .repeated_dates import RepeatedDatesAction - - # from .data_sources import DataSourcesAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - - if len(config) != 1: - if "label" in config: - config.pop("label") - if "name" in config: - config.pop("name") - if len(config) != 1: - print(json.dumps(config, indent=2, default=str)) - raise ValueError(f"Invalid input config. 
Expecting dict with only one key, got {list(config.keys())}") - - config = deepcopy(config) - key = list(config.keys())[0] - - if isinstance(config[key], list): - args, kwargs = config[key], {} - elif isinstance(config[key], dict): - args, kwargs = [], config[key] - else: - raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}") - - cls = { - "data_sources": DataSourcesAction, - "data-sources": DataSourcesAction, - "concat": ConcatAction, - "join": JoinAction, - "pipe": PipeAction, - "function": FunctionAction, - "repeated_dates": RepeatedDatesAction, - "repeated-dates": RepeatedDatesAction, - }.get(key) - - if cls is None: - from ..sources import create_source - - source = create_source(None, substitute(context, config)) - return FunctionAction(context, action_path + [key], key, source, config) - - return cls(context, action_path + [key], *args, **kwargs) + +class Join(Action): + def __init__(self, config, *path): + super().__init__(config, *path) + + assert isinstance(config, list), f"Value must be a list {config}" + + self.actions = [action_factory(item, *path, "join", str(i)) for i, item in enumerate(config)] + + def __repr__(self): + return f"Join({self.actions})" + + def __call__(self, context, argument): + results = context.empty_result() + + for action in self.actions: + results += action(context, argument) + + return context.register(results, self.path) + + def python_code(self, code) -> None: + return code.sum(a.python_code(code) for a in self.actions) + + +class Pipe(Action): + def __init__(self, config, *path): + assert isinstance(config, list), f"Value must be a list {config}" + super().__init__(config, *path) + self.actions = [action_factory(item, *path, "pipe", str(i)) for i, item in enumerate(config)] + + def __repr__(self): + return f"Pipe({self.actions})" + + def __call__(self, context, argument): + result = context.empty_result() + + for i, action in enumerate(self.actions): + if i == 0: + result = action(context, argument) + 
else: + result = action(context, result) + + return context.register(result, self.path) + + def python_code(self, code) -> None: + return code.pipe(a.python_code(code) for a in self.actions) + + +class Function(Action): + def __init__(self, config, *path): + super().__init__(config, *path, self.name) + + def __call__(self, context, argument): + + config = context.resolve(self.config) # Substitute the ${} variables in the config + + config["_type"] = self.name # Find a better way to do this + + source = self.create_object(config) + + rich.print(f"Executing source {self.name} from {config}") + + return context.register(self.call_object(context, source, argument), self.path) + + def python_code(self, code) -> str: + return code.call(self.name, self.config) + + +class DatasetSourceMixin: + def create_object(self, config): + from anemoi.datasets.create.sources import create_source as create_datasets_source + + return create_datasets_source(self, config) + + def call_object(self, context, source, argument): + return source.execute(context, context.source_argument(argument)) + + +class DatasetFilterMixin: + def create_object(self, config): + from anemoi.datasets.create.filters import create_filter as create_datasets_filter + + return create_datasets_filter(self, config) + + def call_object(self, context, filter, argument): + return filter.execute(context.filter_argument(argument)) + + +class TransformSourceMixin: + def create_object(self, config): + from anemoi.transform.sources import create_source as create_transform_source + + return create_transform_source(self, config) + + +class TransformFilterMixin: + def create_object(self, config): + from anemoi.transform.filters import create_filter as create_transform_filter + + return create_transform_filter(self, config) + + def call_object(self, context, filter, argument): + return filter.forward(context.filter_argument(argument)) + + +class FilterFunction(Function): + def __call__(self, context, argument): + return 
self.call(context, argument, context.filter_argument) + + +def _make_name(name, what): + name = name.replace("_", "-") + name = "".join(x.title() for x in name.split("-")) + return name + what.title() + + +def new_source(name, mixin): + return type( + _make_name(name, "source"), + (Function, mixin), + {"name": name}, + ) + + +def new_filter(name, mixin): + return type( + _make_name(name, "filter"), + (Function, mixin), + {"name": name}, + ) + + +KLASS = {"concat": Concat, "join": Join, "pipe": Pipe} + +LEN_KLASS = len(KLASS) + + +def make(key, config, path): + + if LEN_KLASS == len(KLASS): + + # Load pluggins + from anemoi.transform.filters import filter_registry as transform_filter_registry + from anemoi.transform.sources import source_registry as transform_source_registry + + from anemoi.datasets.create.filters import filter_registry as dataset_filter_registry + from anemoi.datasets.create.sources import source_registry as dataset_source_registry + + # Register sources, local first + for name in dataset_source_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) + + for name in transform_source_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) + + # Register filters, local first + for name in dataset_filter_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_filter(name, DatasetFilterMixin) + + for name in transform_filter_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) + + return KLASS[key.replace("_", "-")](config, *path) + + +def action_factory(data, *path): + assert isinstance(data, dict), f"Input data must be a dictionary {data}" + assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}" + + key, value = next(iter(data.items())) + return make(key, value, path) diff --git 
a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py deleted file mode 100644 index bd906bd03..000000000 --- a/src/anemoi/datasets/create/input/concat.py +++ /dev/null @@ -1,184 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from copy import deepcopy -from functools import cached_property -from typing import Any -from typing import Dict -from typing import List -from typing import Union - -from earthkit.data import FieldList - -from anemoi.datasets.dates import DatesProvider - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class ConcatResult(Result): - """Represents the result of concatenating multiple results.""" - - def __init__( - self, - context: object, - action_path: List[str], - group_of_dates: GroupOfDates, - results: List[Result], - **kwargs: Any, - ) -> None: - """Initializes a ConcatResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - kwargs : Any - Additional keyword arguments. 
- """ - super().__init__(context, action_path, group_of_dates) - self.results = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the concatenated datasource from all results.""" - ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - @property - def variables(self) -> List[str]: - """Returns the list of variables, ensuring all results have the same variables.""" - variables = None - for f in self.results: - if f.empty: - continue - if variables is None: - variables = f.variables - assert variables == f.variables, (variables, f.variables) - assert variables is not None, self.results - return variables - - def __repr__(self) -> str: - """Returns a string representation of the ConcatResult instance. - - Returns - ------- - str - A string representation of the ConcatResult instance. - """ - content = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class ConcatAction(Action): - """Represents an action that concatenates multiple actions based on their dates.""" - - def __init__(self, context: object, action_path: List[str], *configs: Dict[str, Any]) -> None: - """Initializes a ConcatAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - configs : Dict[str, Any] - The configuration dictionaries. 
- """ - super().__init__(context, action_path, *configs) - parts = [] - for i, cfg in enumerate(configs): - if "dates" not in cfg: - raise ValueError(f"Missing 'dates' in {cfg}") - cfg = deepcopy(cfg) - dates_cfg = cfg.pop("dates") - assert isinstance(dates_cfg, dict), dates_cfg - filtering_dates = DatesProvider.from_config(**dates_cfg) - action = action_factory(cfg, context, action_path + [str(i)]) - parts.append((filtering_dates, action)) - self.parts = parts - - def __repr__(self) -> str: - """Returns a string representation of the ConcatAction instance. - - Returns - ------- - str - A string representation of the ConcatAction instance. - """ - content = "\n".join([str(i) for i in self.parts]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> Union[ConcatResult, EmptyResult]: - """Selects the concatenated result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - Union[ConcatResult, EmptyResult] - The concatenated result or an empty result. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - results = [] - for filtering_dates, action in self.parts: - newdates = GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) - if newdates: - results.append(action.select(newdates)) - if not results: - return EmptyResult(self.context, self.action_path, group_of_dates) - - return ConcatResult(self.context, self.action_path, group_of_dates, results) - - def to_python(self) -> str: - """Returns the Python representation of the ConcatAction instance. - - Returns - ------- - str - The Python representation of the ConcatAction instance. 
- """ - - result = [] - - for i, (filtering_dates, action) in enumerate(self.parts): - result.append(f"{filtering_dates.to_python()}:{action.to_python()}") - - return f"r.concat({{{','.join(result)}}})" - - def python_prelude(self, prelude) -> None: - for filtering_dates, action in self.parts: - action.python_prelude(prelude) diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py deleted file mode 100644 index 35784dba7..000000000 --- a/src/anemoi/datasets/create/input/context.py +++ /dev/null @@ -1,89 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import textwrap -from typing import Any -from typing import List -from typing import Tuple -from typing import Union - -from anemoi.utils.humanize import plural - -from .trace import step -from .trace import trace - -LOG = logging.getLogger(__name__) - - -class Context: - """Class to handle the build context in the dataset creation process.""" - - def __init__(self) -> None: - """Initializes a Context instance.""" - # used_references is a set of reference paths that will be needed - self.used_references = set() - # results is a dictionary of reference path -> obj - self.results = {} - - def will_need_reference(self, key: Union[List, Tuple]) -> None: - """Marks a reference as needed. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - self.used_references.add(key) - - def notify_result(self, key: Union[List, Tuple], result: Any) -> None: - """Notifies that a result is available for a reference. 
- - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - result : Any - The result object. - """ - trace( - "🎯", - step(key), - "notify result", - textwrap.shorten(repr(result).replace(",", ", "), width=40), - plural(len(result), "field"), - ) - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.used_references: - if key in self.results: - raise ValueError(f"Duplicate result {key}") - self.results[key] = result - - def get_result(self, key: Union[List, Tuple]) -> Any: - """Retrieves the result for a given reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - - Returns - ------- - Any - The result for the given reference. - """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.results: - return self.results[key] - all_keys = sorted(list(self.results.keys())) - raise ValueError(f"Cannot find result {key} in {all_keys}") diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py new file mode 100644 index 000000000..26d449659 --- /dev/null +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -0,0 +1,63 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +from abc import ABC +from abc import abstractmethod +from typing import Any + +import rich + +LOG = logging.getLogger(__name__) + + +class Context(ABC): + """Context for building input data.""" + + def __init__(self, /, argument: Any) -> None: + self.results = {} + self.cache = {} + self.argument = argument + + def trace(self, emoji, *message) -> None: + + rich.print(f"{emoji}: {message}") + + def register(self, data: Any, path: list[str]) -> Any: + + if not path: + return data + + rich.print(f"Registering data at path: {path}") + self.results[tuple(path)] = data + return data + + def resolve(self, config): + config = config.copy() + + for key, value in list(config.items()): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + path = tuple(value[2:-1].split(".")) + if path in self.results: + config[key] = self.results[path] + else: + raise KeyError(f"Path {path} not found in results: {self.results.keys()}") + + return config + + def create_source(self, config: Any) -> Any: + from anemoi.datasets.create.input.action import action_factory + + return action_factory(config) + + @abstractmethod + def empty_result(self) -> Any: ... + + @abstractmethod + def create_result(self, data: Any) -> Any: ... diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py new file mode 100644 index 000000000..c3456d89f --- /dev/null +++ b/src/anemoi/datasets/create/input/context/field.py @@ -0,0 +1,54 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +from typing import Any +from typing import Dict + +from earthkit.data.core.order import build_remapping + +from ..result.field import FieldResult +from . import Context + + +class FieldContext(Context): + + def __init__( + self, + /, + argument: Any, + order_by: str, + flatten_grid: bool, + remapping: Dict[str, Any], + use_grib_paramid: bool, + ) -> None: + super().__init__(argument) + self.order_by = order_by + self.flatten_grid = flatten_grid + self.remapping = build_remapping(remapping) + self.use_grib_paramid = use_grib_paramid + + def empty_result(self) -> Any: + import earthkit.data as ekd + + return ekd.from_source("empty") + + def source_argument(self, argument: Any) -> Any: + return argument # .dates + + def filter_argument(self, argument: Any) -> Any: + return argument + + def create_result(self, data): + return FieldResult(self, data) + + def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: + from anemoi.datasets.dates.groups import GroupOfDates + + return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 5b6282469..f9811d178 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -20,7 +20,7 @@ from .action import Action from .action import action_factory from .misc import _tidy -from .result import Result +from .result.field import Result LOG = logging.getLogger(__name__) @@ -87,13 +87,10 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) - def python_prelude(self, prelude) -> str: + def python_code(self, code) -> str: for n, s in zip(self.names, self.sources): - s.python_prelude(prelude) - prelude.append(f"{n}={s.to_python()}") - - def to_python(self) -> str: - return self.input.to_python() + code.source(n, s.python_code(code)) + return code class 
DataSourcesResult(Result): diff --git a/src/anemoi/datasets/create/input/empty.py b/src/anemoi/datasets/create/input/empty.py deleted file mode 100644 index 410b4c973..000000000 --- a/src/anemoi/datasets/create/input/empty.py +++ /dev/null @@ -1,54 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import List - -from earthkit.data import FieldList - -from .misc import assert_fieldlist -from .result import Result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class EmptyResult(Result): - """Class to represent an empty result in the dataset creation process.""" - - empty = True - - def __init__(self, context: object, action_path: list, dates: object) -> None: - """Initializes an EmptyResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - dates : object - The dates object. - """ - super().__init__(context, action_path + ["empty"], dates) - - @cached_property - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns an empty datasource.""" - from earthkit.data import from_source - - return from_source("empty") - - @property - def variables(self) -> List[str]: - """Returns an empty list of variables.""" - return [] diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py deleted file mode 100644 index 9357d2178..000000000 --- a/src/anemoi/datasets/create/input/filter.py +++ /dev/null @@ -1,133 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. 
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import Type - -from earthkit.data import FieldList - -from .function import FunctionContext -from .misc import _tidy -from .misc import assert_fieldlist -from .step import StepAction -from .step import StepResult -from .template import notify_result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class FilterStepResult(StepResult): - @property - @notify_result - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns the filtered datasource.""" - ds: FieldList = self.upstream_result.datasource - ds = ds.sel(**self.action.kwargs) - return _tidy(ds) - - -class FilterStepAction(StepAction): - """Represents an action to filter a step result.""" - - result_class: Type[FilterStepResult] = FilterStepResult - - -class StepFunctionResult(StepResult): - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource after applying the function.""" - - self.action.filter.context = FunctionContext(self) - try: - return _tidy( - self.action.filter.execute( - self.upstream_result.datasource, - *self.action.args[1:], - **self.action.kwargs, - ) - ) - - except Exception: - LOG.error(f"Error in {self.action.name}", exc_info=True) - raise - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. 
- - Returns - ------- - str - A string representation of the traced datasource. - """ - return f"{self.action.name}({self.group_of_dates})" - - -class FunctionStepAction(StepAction): - """Represents an action to apply a function to a step result.""" - - result_class: Type[StepFunctionResult] = StepFunctionResult - - def __init__( - self, - context: object, - action_path: list, - previous_step: StepAction, - name: str, - filter: Any, - config: dict, - *args: Any, - **kwargs: Any, - ) -> None: - """Initializes a FunctionStepAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - previous_step : StepAction - The previous step action. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. - """ - super().__init__(context, action_path, previous_step, *args, **kwargs) - self.name = name - self.filter = filter - self.config = config - - def to_python(self) -> str: - """Converts the action to Python code. - - Returns - ------- - str - The converted Python code. - """ - return self._to_python(self.name, self.config) - - def python_prelude(self, prelude) -> None: - pass diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py deleted file mode 100644 index 651b509b2..000000000 --- a/src/anemoi/datasets/create/input/function.py +++ /dev/null @@ -1,244 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import logging -from functools import cached_property -from typing import Any -from typing import Dict - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .template import substitute -from .trace import trace -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class FunctionContext: - """A FunctionContext is passed to all functions, it will be used to pass information - to the functions from the other actions and filters and results. - """ - - def __init__(self, owner: Result) -> None: - """Initializes a FunctionContext instance. - - Parameters - ---------- - owner : object - The owner object. - """ - self.owner = owner - self.use_grib_paramid: bool = owner.context.use_grib_paramid - - def trace(self, emoji: str, *args: Any) -> None: - """Traces the given arguments with an emoji. - - Parameters - ---------- - emoji : str - The emoji to use. - *args : Any - The arguments to trace. - """ - trace(emoji, *args) - - def info(self, *args: Any, **kwargs: Any) -> None: - """Logs an info message. - - Parameters - ---------- - *args : Any - The arguments for the log message. - **kwargs : Any - The keyword arguments for the log message. - """ - LOG.info(*args, **kwargs) - - @property - def dates_provider(self) -> object: - """Returns the dates provider.""" - return self.owner.group_of_dates.provider - - @property - def partial_ok(self) -> bool: - """Returns whether partial results are acceptable.""" - return self.owner.group_of_dates.partial_ok - - def get_result(self, *args, **kwargs) -> Any: - return self.owner.context.get_result(*args, **kwargs) - - -class FunctionAction(Action): - """Represents an action that executes a function. - - Attributes - ---------- - name : str - The name of the function. 
- """ - - def __init__( - self, context: object, action_path: list, _name: str, source, config: dict, **kwargs: Dict[str, Any] - ) -> None: - """Initializes a FunctionAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - _name : str - The name of the function. - **kwargs : Dict[str, Any] - Additional keyword arguments. - """ - super().__init__(context, action_path, **kwargs) - self.name: str = _name - self.source = source - self.config = config - - def to_python(self) -> str: - """Returns the Python representation of the function action.""" - - return self._to_python(self.name, self.config) - - def python_prelude(self, prelude) -> str: - pass - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": - """Selects the function result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - FunctionResult - The function result instance. - """ - return FunctionResult(self.context, self.action_path, group_of_dates, action=self) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionAction instance.""" - content: str = "" - content += ",".join([self._short_str(a) for a in self.args]) - content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]) - content = self._short_str(content) - return self._repr(_inline_=content, _indent_=" ") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Traces the selection of the function for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - str - The trace string. - """ - return f"{self.name}({group_of_dates})" - - -class FunctionResult(Result): - """Represents the result of executing a function. - - Attributes - ---------- - action : Action - The action instance. 
- args : tuple - The positional arguments for the function. - kwargs : dict - The keyword arguments for the function. - """ - - def __init__(self, context: object, action_path: list, group_of_dates: GroupOfDates, action: Action) -> None: - """Initializes a FunctionResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - action : Action - The action instance. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(action, Action), type(action) - self.action: Action = action - - self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs)) - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.action.name}({self.group_of_dates})" - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource for the function result.""" - # args, kwargs = resolve(self.context, (self.args, self.kwargs)) - self.action.source.context = FunctionContext(self) - - return _tidy( - self.action.source.execute( - list(self.group_of_dates), # Will provide a list of datetime objects - ) - ) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionResult instance.""" - try: - return f"{self.action.name}({self.group_of_dates})" - except Exception: - return f"{self.__class__.__name__}(unitialised)" - - @property - def function(self) -> None: - """Raises NotImplementedError as this property is not implemented. - - Raises - ------ - NotImplementedError - Always raised. 
- """ - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py deleted file mode 100644 index 9fc81eac9..000000000 --- a/src/anemoi/datasets/create/input/join.py +++ /dev/null @@ -1,137 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import List - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class JoinResult(Result): - """Represents a result that combines multiple results. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - """ - - def __init__( - self, context: object, action_path: list, group_of_dates: GroupOfDates, results: List[Result], **kwargs: Any - ) -> None: - """Initializes a JoinResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. 
- """ - super().__init__(context, action_path, group_of_dates) - self.results: List[Result] = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the combined datasource from all results.""" - ds: FieldList = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - def __repr__(self) -> str: - """Returns a string representation of the JoinResult instance.""" - content: str = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class JoinAction(Action): - """Represents an action that combines multiple actions. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - actions : List[Action] - The list of actions. - """ - - def __init__(self, context: object, action_path: list, *configs: dict) -> None: - """Initializes a JoinAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - *configs : dict - The configuration dictionaries. - """ - super().__init__(context, action_path, *configs) - self.actions: List[Action] = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] - - def to_python(self) -> None: - return "(" + " + ".join([i.to_python() for i in self.actions]) + ")" - - def python_prelude(self, prelude) -> None: - for i in self.actions: - i.python_prelude(prelude) - - def __repr__(self) -> str: - """Returns a string representation of the JoinAction instance.""" - content: str = "\n".join([str(i) for i in self.actions]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> JoinResult: - """Selects the results for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. 
- - Returns - ------- - JoinResult - The combined result for the given group of dates. - """ - results: List[Result] = [a.select(group_of_dates) for a in self.actions] - return JoinResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py deleted file mode 100644 index 7b8f062f3..000000000 --- a/src/anemoi/datasets/create/input/pipe.py +++ /dev/null @@ -1,77 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -import logging -from typing import Any - -from .action import Action -from .action import action_factory -from .step import step_factory -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class PipeAction(Action): - """A class to represent a pipeline of actions.""" - - def __init__(self, context: Any, action_path: list, *configs: dict) -> None: - """Initialize the PipeAction. - - Parameters - ---------- - context : Any - The context for the action. - action_path : list - The path of the action. - configs : dict - The configurations for the actions. 
- """ - super().__init__(context, action_path, *configs) - if len(configs) <= 1: - raise ValueError( - f"PipeAction requires at least two actions, got {len(configs)}\n{json.dumps(configs, indent=2)}" - ) - - self.actions: list = [] - - current: Any = action_factory(configs[0], context, action_path + ["0"]) - self.actions.append(current) - for i, c in enumerate(configs[1:]): - current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current) - self.actions.append(current) - self.last_step: Any = current - - @trace_select - def select(self, group_of_dates: Any) -> Any: - """Select data based on the group of dates. - - Parameters - ---------- - group_of_dates : Any - The group of dates to select data for. - - Returns - ------- - Any - The selected data. - """ - return self.last_step.select(group_of_dates) - - def __repr__(self) -> str: - """Return a string representation of the PipeAction.""" - return f"PipeAction({self.last_step})" - - def to_python(self) -> str: - return "(" + " | ".join([i.to_python() for i in self.actions]) + ")" - - def python_prelude(self, prelude) -> None: - for i in self.actions: - i.python_prelude(prelude) diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index bc121c5c9..6304fcecc 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -9,7 +9,6 @@ import logging -import warnings from collections import defaultdict from typing import Any from typing import Dict @@ -28,7 +27,7 @@ from .action import Action from .action import action_factory from .join import JoinResult -from .result import Result +from .result.field import Result from .trace import trace_select LOG = logging.getLogger(__name__) @@ -204,21 +203,6 @@ def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) self.day: int = day self.hour: Optional[int] = hour - def to_python(self) -> Dict[str, Any]: - 
"""Convert the DateMapper to Python code. - - Returns - ------- - dict - The Python code representation of the DateMapper. - """ - return { - "mode": "climatology", - "year": self.year, - "day": self.day, - "hour": self.hour, - } - def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: """Transform the group of dates to the specified climatology dates. @@ -369,16 +353,6 @@ def __init__(self, context: Any, action_path: List[str], source: Any, mode: str, self.mode = mode self.kwargs = kwargs - def to_python(self) -> str: - """Convert the action to Python code.""" - warnings.warn("RepeatedDatesAction.to_python is still a work in progress") - args = {"mode": self.mode} - args.update(self.kwargs) - return self._to_python("repeated_dates", {"repeated_dates": args}, source=self.source.to_python()) - - def python_prelude(self, prelude: Any) -> None: - self.source.python_prelude(prelude) - @trace_select def select(self, group_of_dates: Any) -> JoinResult: """Select and transform the group of dates. diff --git a/src/anemoi/datasets/create/input/result/__init__.py b/src/anemoi/datasets/create/input/result/__init__.py new file mode 100644 index 000000000..03a00c51d --- /dev/null +++ b/src/anemoi/datasets/create/input/result/__init__.py @@ -0,0 +1,17 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +from abc import ABC + +LOG = logging.getLogger(__name__) + + +class Result(ABC): + pass diff --git a/src/anemoi/datasets/create/input/result.py b/src/anemoi/datasets/create/input/result/field.py similarity index 87% rename from src/anemoi/datasets/create/input/result.py rename to src/anemoi/datasets/create/input/result/field.py index de5388fd6..dd238998e 100644 --- a/src/anemoi/datasets/create/input/result.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -26,9 +26,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from .action import ActionContext -from .trace import trace -from .trace import trace_datasource +from . import Result LOG = logging.getLogger(__name__) @@ -282,40 +280,22 @@ def sort(old_dic: DefaultDict[str, set]) -> Dict[str, List[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class Result: +class FieldResult(Result): """Class to represent the result of an action in the dataset creation process.""" empty: bool = False _coords_already_built: bool = False - def __init__(self, context: ActionContext, action_path: List[str], dates: Any) -> None: - """Initialize a Result instance. + def __init__(self, context: Any, datasource: Any) -> None: - Parameters - ---------- - context : ActionContext - The context in which the result exists. - action_path : list of str - The action path. - dates : Any - The dates associated with the result. 
- """ from anemoi.datasets.dates.groups import GroupOfDates - assert isinstance(dates, GroupOfDates), dates - - assert isinstance(context, ActionContext), type(context) - assert isinstance(action_path, list), action_path - self.context: Any = context - self.group_of_dates: Any = dates - self.action_path: List[str] = action_path - - @property - @trace_datasource - def datasource(self) -> Any: - """Retrieve the data source for the result.""" - self._raise_not_implemented() + self.datasource = datasource + self.group_of_dates = context.argument + assert isinstance( + self.group_of_dates, GroupOfDates + ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" @property def data_request(self) -> Dict[str, Any]: @@ -330,7 +310,7 @@ def get_cube(self) -> Any: Any The data cube. """ - trace("🧊", f"getting cube from {self.__class__.__name__}") + ds: Any = self.datasource remapping: Any = self.context.remapping @@ -523,66 +503,6 @@ def explain(self, ds: Any, *args: Any, remapping: Any, patches: Any) -> None: print() exit(1) - def _repr(self, *args: Any, _indent_: str = "\n", **kwargs: Any) -> str: - """Return the string representation of the Result instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str - Indentation string. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more: str = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - dates: str = " no-dates" - if self.group_of_dates is not None: - dates = f" {len(self.group_of_dates)} dates" - dates += " (" - dates += "/".join(d.strftime("%Y-%m-%dT%H:%M") for d in self.group_of_dates) - if len(dates) > 100: - dates = dates[:100] + "..." 
- dates += ")" - - more = more[:5000] - txt: str = f"{self.__class__.__name__}:{dates}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Result instance.""" - return self._repr() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Trace the data source for the result. - - Parameters - ---------- - args : Any - Additional positional arguments. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({self.group_of_dates})" - def build_coords(self) -> None: """Build the coordinates for the result.""" if self._coords_already_built: diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py deleted file mode 100644 index a2e3ccd42..000000000 --- a/src/anemoi/datasets/create/input/step.py +++ /dev/null @@ -1,203 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import logging -import warnings -from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Type - -from .action import Action -from .action import ActionContext -from .context import Context -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class StepResult(Result): - """Represents the result of a step in the data processing pipeline.""" - - def __init__( - self, context: Context, action_path: List[str], group_of_dates: Any, action: Action, upstream_result: Result - ) -> None: - """Initialize a StepResult instance. - - Parameters - ---------- - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - group_of_dates - The group of dates associated with this step. - action - The action associated with this step. - upstream_result - The result of the upstream step. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(upstream_result, Result), type(upstream_result) - self.upstream_result: Result = upstream_result - self.action: Action = action - - @property - @notify_result - @trace_datasource - def datasource(self) -> Any: - """Retrieve the datasource associated with this step result.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - -class StepAction(Action): - """Represents an action that is part of a step in the data processing pipeline.""" - - result_class: Optional[Type[StepResult]] = None - - def __init__( - self, context: ActionContext, action_path: List[str], previous_step: Any, *args: Any, **kwargs: Any - ) -> None: - """Initialize a StepAction instance. - - Parameters - ---------- - context - The context in which the action is executed. - action_path - The path of actions leading to this step. 
- previous_step - The previous step in the pipeline. - """ - super().__init__(context, action_path, *args, **kwargs) - self.previous_step: Any = previous_step - - @trace_select - def select(self, group_of_dates: Any) -> StepResult: - """Select the result for a given group of dates. - - Parameters - ---------- - group_of_dates - The group of dates to select the result for. - - Returns - ------- - unknown - The result of the step. - """ - return self.result_class( - self.context, - self.action_path, - group_of_dates, - self, - self.previous_step.select(group_of_dates), - ) - - def __repr__(self) -> str: - """Return a string representation of the StepAction instance. - - Returns - ------- - unknown - String representation of the instance. - """ - return self._repr(self.previous_step, _inline_=str(self.kwargs)) - - -def step_factory(config: Dict[str, Any], context: ActionContext, action_path: List[str], previous_step: Any) -> Any: - """Factory function to create a step action based on the given configuration. - - Parameters - ---------- - config - The configuration dictionary for the step. - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - previous_step - The previous step in the pipeline. - - Returns - ------- - unknown - An instance of a step action. 
- """ - - from .filter import FilterStepAction - from .filter import FunctionStepAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - - config = deepcopy(config) - assert len(config) == 1, config - - key = list(config.keys())[0] - cls = dict( - filter=FilterStepAction, - # rename=RenameAction, - # remapping=RemappingAction, - ).get(key) - - if isinstance(config[key], list): - args, kwargs = config[key], {} - - if isinstance(config[key], dict): - args, kwargs = [], config[key] - - if isinstance(config[key], str): - args, kwargs = [config[key]], {} - - if cls is not None: - return cls(context, action_path, previous_step, *args, **kwargs) - - # Try filters from datasets filter registry - from anemoi.transform.filters import filter_registry as transform_filter_registry - - from ..filters import create_filter as create_datasets_filter - from ..filters import filter_registry as datasets_filter_registry - - if datasets_filter_registry.is_registered(key): - - if transform_filter_registry.is_registered(key): - warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries") - - filter = create_datasets_filter(None, config) - return FunctionStepAction( - context, - action_path + [key], - previous_step, - key, - filter, - config, - ) - - # Use filters from transform registry - - if transform_filter_registry.is_registered(key): - from ..filters.transform import TransformFilter - - return FunctionStepAction( - context, - action_path + [key], - previous_step, - key, - TransformFilter(context, key, config), - config, - ) - - raise ValueError(f"Unknown step action `{key}`") diff --git a/src/anemoi/datasets/create/input/template.py b/src/anemoi/datasets/create/input/template.py deleted file mode 100644 index 8ea1ec275..000000000 --- a/src/anemoi/datasets/create/input/template.py +++ /dev/null @@ -1,162 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. 
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -import re -from abc import ABC -from abc import abstractmethod -from functools import wraps -from typing import Any -from typing import Callable -from typing import List - -from .context import Context - -LOG = logging.getLogger(__name__) - - -def notify_result(method: Callable[..., Any]) -> Callable[..., Any]: - """Decorator to notify the context of the result of the method call. - - Parameters - ---------- - method : Callable[..., Any] - The method to wrap. - - Returns - ------- - Callable[..., Any] - The wrapped method. - """ - - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - result: Any = method(self, *args, **kwargs) - self.context.notify_result(self.action_path, result) - return result - - return wrapper - - -class Substitution(ABC): - """Abstract base class for substitutions in templates.""" - - @abstractmethod - def resolve(self, context: Context) -> Any: - """Resolve the substitution using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - pass - - -class Reference(Substitution): - """A class to represent a reference to another value in the context.""" - - def __init__(self, context: Any, action_path: List[str]) -> None: - """Initialize a Reference instance. - - Parameters - ---------- - context : Any - The context in which the reference exists. - action_path : list of str - The action path to resolve. 
- """ - self.context: Any = context - self.action_path: List[str] = action_path - - def resolve(self, context: Context) -> Any: - """Resolve the reference using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - return context.get_result(self.action_path) - - -def resolve(context: Context, x: Any) -> Any: - """Recursively resolve substitutions in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for resolution. - x : Union[tuple, list, dict, Substitution, Any] - The structure to resolve. - - Returns - ------- - Any - The resolved structure. - """ - if isinstance(x, tuple): - return tuple([resolve(context, y) for y in x]) - - if isinstance(x, list): - return [resolve(context, y) for y in x] - - if isinstance(x, dict): - return {k: resolve(context, v) for k, v in x.items()} - - if isinstance(x, Substitution): - return x.resolve(context) - - return x - - -def substitute(context: Context, x: Any) -> Any: - """Recursively substitute references in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for substitution. - x : Union[tuple, list, dict, str, Any] - The structure to substitute. - - Returns - ------- - Any - The substituted structure. 
- """ - if isinstance(x, tuple): - return tuple([substitute(context, y) for y in x]) - - if isinstance(x, list): - return [substitute(context, y) for y in x] - - if isinstance(x, dict): - return {k: substitute(context, v) for k, v in x.items()} - - if not isinstance(x, str): - return x - - if re.match(r"^\${[\.\w\-]+}$", x): - path = x[2:-1].split(".") - context.will_need_reference(path) - return Reference(context, path) - - return x diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py new file mode 100644 index 000000000..271845c2d --- /dev/null +++ b/src/anemoi/datasets/create/python.py @@ -0,0 +1,174 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import datetime +import json +import re + +from anemoi.utils.dates import frequency_to_string + +# input.python_prelude(code) +# code1 = "\n".join(prelude) +# rich.print(f"Input prelude:\n{code1}") +# code2 = input.to_python() + +# code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" + +# code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) + +# try: +# import black + +# return black.format_str(code, mode=black.Mode()) +# except ImportError: +# LOG.warning("Black not installed, skipping formatting") +# return code +RESERVED_KEYWORDS = ( + "and", + "or", + "not", + "is", + "in", + "if", + "else", + "elif", + "for", + "while", + "return", + "class", + "def", + "with", + "as", + "import", + "from", + "try", + "except", + "finally", + "raise", + "assert", + "break", + "continue", + "pass", +) + + +class PythonCode: + + def call(self, name, argument): + return PythonCall(name, argument) + + def sum(self, actions): + return PythonChain("+", actions) + + def pipe(self, actions): + return PythonChain("|", actions) + + def concat(self, argument): + return PythonConcat(argument) + + +class PythonConcat(PythonCode): + def __init__(self, argument): + self.argument = argument + + def __repr__(self): + return str(self.argument) + + +class PythonCall(PythonCode): + def __init__(self, name, argument): + self.name = name + self.argument = argument + + def __repr__(self): + name = self.name.replace("-", "_") + config = self.argument + + # def convert(obj): + # if isinstance(obj, datetime.datetime): + # return obj.isoformat() + # if isinstance(obj, datetime.date): + # return obj.isoformat() + # if isinstance(obj, datetime.timedelta): + # return frequency_to_string(obj) + # if isinstance(obj, PythonCode): + # return obj + # raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + + # config = json.loads(json.dumps(config, default=convert)) + + params = [] + for k, v in config.items(): + if 
isinstance(k, str): + if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: + return f"r.{name}({config})" + params.append(f"{k}={repr(v)}") + + # for k, v in extra.items(): + # params.append(f"{k}={v}") + + params = ",".join(params) + return f"r.{name}({params})" + # return f"{name}({config})" + return f"{self.name}({self.argument})" + + +class PythonChain(PythonCode): + def __init__(self, op, actions): + self.op = op + self.actions = actions + + def __repr__(self): + return "(" + self.op.join(repr(x) for x in self.actions) + ")" + + +def _python(name, config, **extra) -> str: + """Convert the action to Python code. + + Parameters + ---------- + name : str + The name of the action. + config : dict + The configuration for the action. + extra : Any + Additional keyword arguments. + + Returns + ------- + str + The Python code representation of the action. + """ + + name = name.replace("-", "_") + + def convert(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if isinstance(obj, datetime.date): + return obj.isoformat() + if isinstance(obj, datetime.timedelta): + return frequency_to_string(obj) + raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + + config = json.loads(json.dumps(config, default=convert)) + + params = [] + for k, v in config.items(): + if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: + return f"r.{name}({config})" + params.append(f"{k}={repr(v)}") + + for k, v in extra.items(): + params.append(f"{k}={v}") + + params = ",".join(params) + return f"r.{name}({params})" + # return f"{name}({config})" diff --git a/src/anemoi/datasets/create/sources/accumulations.py b/src/anemoi/datasets/create/sources/accumulations.py index a70af9bdd..405fd4713 100644 --- a/src/anemoi/datasets/create/sources/accumulations.py +++ b/src/anemoi/datasets/create/sources/accumulations.py @@ -459,12 +459,13 @@ def _mars_date_time_step( A tuple representing the MARS date-time step. 
""" assert user_date is None, user_date - assert not frequency, frequency steps = (step1 + add_step, step2 + add_step) if steps[0] == 0: steps = (steps[1],) + assert frequency == 0 or frequency == (step2 - step1), frequency + return ( base_date.year * 10000 + base_date.month * 100 + base_date.day, base_date.hour * 100 + base_date.minute, @@ -824,6 +825,11 @@ def _compute_accumulations( step1, step2 = user_accumulation_period assert step1 < step2, user_accumulation_period + if accumulations_reset_frequency is not None: + AccumulationClass = AccumulationFromLastReset + else: + AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep + if data_accumulation_period is None: data_accumulation_period = user_accumulation_period[1] - user_accumulation_period[0] @@ -838,11 +844,6 @@ def _compute_accumulations( base_times = [t // 100 if t > 100 else t for t in base_times] - if accumulations_reset_frequency is not None: - AccumulationClass = AccumulationFromLastReset - else: - AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep - mars_date_time_steps = AccumulationClass.mars_date_time_steps( dates=dates, step1=step1, diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index 921469025..1958820c4 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -47,7 +47,7 @@ def constants(context: Any, dates: List[str], template: Dict[str, Any], param: s if len(template) == 0: raise ValueError("Forcings template is empty.") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute: Any = constants diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index 
e1944e151..8e2977273 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -36,7 +36,7 @@ def forcings(context: Any, dates: List[str], template: str, param: str) -> Any: Loaded forcing data. """ context.trace("✅", f"from_source(forcings, {template}, {param}") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute = forcings diff --git a/src/anemoi/datasets/create/sources/grib.py b/src/anemoi/datasets/create/sources/grib.py index 8bedd4519..f87bf679c 100644 --- a/src/anemoi/datasets/create/sources/grib.py +++ b/src/anemoi/datasets/create/sources/grib.py @@ -124,7 +124,7 @@ def execute( dates = [d.isoformat() for d in dates] for path in given_paths: - paths = Pattern(path, ignore_missing_keys=True).substitute(*args, date=dates, **kwargs) + paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs) for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): if name in kwargs: diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index d72d0b3f4..c76a11c9b 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -14,8 +14,6 @@ from typing import Any from typing import Callable -from anemoi.datasets.create.input.template import resolve - from ..source import Source from . 
import source_registry @@ -71,13 +69,14 @@ def __call__(self, execute: Callable) -> Callable: name = f"Legacy{self.name.title()}Source" source = ".".join([execute.__module__, execute.__name__]) - def execute_wrapper(self, dates) -> Any: + def execute_wrapper(self, context, dates) -> Any: """Wrapper method to call the execute function.""" - args, kwargs = resolve(self.context, (self.args, self.kwargs)) + # args, kwargs = resolve(context, (self.args, self.kwargs)) + args, kwargs = self.args, self.kwargs try: - return execute(self.context, dates, *args, **kwargs) + return execute(context, dates, *args, **kwargs) except TypeError: LOG.error(f"Error executing source {this.name} from {source}") LOG.error(f"Function signature is: {inspect.signature(execute)}") diff --git a/src/anemoi/datasets/create/sources/patterns.py b/src/anemoi/datasets/create/sources/patterns.py index f3e6334a8..dc105289f 100644 --- a/src/anemoi/datasets/create/sources/patterns.py +++ b/src/anemoi/datasets/create/sources/patterns.py @@ -79,6 +79,6 @@ def iterate_patterns( kwargs["date"] = dates for path in given_paths: - paths = Pattern(path, ignore_missing_keys=True).substitute(**kwargs) + paths = Pattern(path).substitute(allow_extra=True, **kwargs) for path in _expand(paths): yield path, dates diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py new file mode 100644 index 000000000..b710bcbbe --- /dev/null +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -0,0 +1,44 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from . 
import source_registry +from .xarray import XarraySourceBase + + +@source_registry.register("planetary_computer") +class PlanetaryComputerSource(XarraySourceBase): + """An Xarray data source for the planetary_computer.""" + + emoji = "🪐" + + def __init__(self, context, data_catalog_id, version="v1", *args, **kwargs: dict): + + import planetary_computer + import pystac_client + + self.data_catalog_id = data_catalog_id + self.flavour = kwargs.pop("flavour", None) + self.patch = kwargs.pop("patch", None) + self.options = kwargs.pop("options", {}) + + catalog = pystac_client.Client.open( + f"https://planetarycomputer.microsoft.com/api/stac/{version}/", + modifier=planetary_computer.sign_inplace, + ) + collection = catalog.get_collection(self.data_catalog_id) + + asset = collection.assets["zarr-abfs"] + + if "xarray:storage_options" in asset.extra_fields: + self.options["storage_options"] = asset.extra_fields["xarray:storage_options"] + + self.options.update(asset.extra_fields["xarray:open_kwargs"]) + + super().__init__(context, url=asset.href, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py new file mode 100644 index 000000000..d092f08ad --- /dev/null +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -0,0 +1,319 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +import logging +from collections import defaultdict +from typing import Any +from typing import Dict +from typing import Generator +from typing import Optional +from typing import Set +from typing import Tuple + +import numpy as np +import rich +from anemoi.transform.fields import new_field_with_valid_datetime +from anemoi.transform.fields import new_fieldlist_from_list +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry + +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +LOG = logging.getLogger(__name__) + + +class Action: + pass + + +class Result: + pass + + +class DateMapper: + """A factory class to create DateMapper instances based on the given mode.""" + + @staticmethod + def from_mode(mode: str, source: Any, config: Dict[str, Any]) -> "DateMapper": + """Create a DateMapper instance based on the given mode. + + Parameters + ---------- + mode : str + The mode to use for the DateMapper. + source : Any + The data source. + config : dict + Configuration parameters. + + Returns + ------- + DateMapper + An instance of DateMapper. 
+ """ + MODES: dict = dict( + closest=DateMapperClosest, + climatology=DateMapperClimatology, + constant=DateMapperConstant, + ) + + if mode not in MODES: + raise ValueError(f"Invalid mode for DateMapper: {mode}") + + return MODES[mode](source, **config) + + +class DateMapperClosest(DateMapper): + """A DateMapper implementation that maps dates to the closest available dates.""" + + def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", skip_all_nans: bool = False) -> None: + """Initialize DateMapperClosest. + + Parameters + ---------- + source : Any + The data source. + frequency : str + Frequency of the dates. + maximum : str + Maximum time delta. + skip_all_nans : bool + Whether to skip all NaN values. + """ + self.source: Any = source + self.maximum: Any = frequency_to_timedelta(maximum) + self.frequency: Any = frequency_to_timedelta(frequency) + self.skip_all_nans: bool = skip_all_nans + self.tried: Set[Any] = set() + self.found: Set[Any] = set() + + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + """Transform the group of dates to the closest available dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. 
+ """ + from anemoi.datasets.dates.groups import GroupOfDates + + asked_dates = list(group_of_dates) + if not asked_dates: + return [] + + to_try = set() + for date in asked_dates: + start = date + while start >= date - self.maximum: + to_try.add(start) + start -= self.frequency + + end = date + while end <= date + self.maximum: + to_try.add(end) + end += self.frequency + + to_try = sorted(to_try - self.tried) + info = {k: "no-data" for k in to_try} + + if not to_try: + LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}") + # return [] + + if to_try: + result = self.source.select( + GroupOfDates( + sorted(to_try), + group_of_dates.provider, + partial_ok=True, + ) + ) + + cnt = 0 + for f in result.datasource: + cnt += 1 + # We could keep the fields in a dictionary, but we don't want to keep the fields in memory + date = as_datetime(f.metadata("valid_datetime")) + + if self.skip_all_nans: + if np.isnan(f.to_numpy()).all(): + LOG.warning(f"Skipping {date} because all values are NaN") + info[date] = "all-nans" + continue + + info[date] = "ok" + self.found.add(date) + + if cnt == 0: + raise ValueError(f"No data found for {group_of_dates} in {self.source}") + + self.tried.update(to_try) + + if not self.found: + for k, v in info.items(): + LOG.warning(f"{k}: {v}") + + raise ValueError(f"No matching data found for {asked_dates} in {self.source}") + + new_dates = defaultdict(list) + + for date in asked_dates: + best = None + for found_date in sorted(self.found): + delta = abs(date - found_date) + # With < we prefer the first date + # With <= we prefer the last date + if best is None or delta <= best[0]: + best = delta, found_date + new_dates[best[1]].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperClimatology(DateMapper): + """A DateMapper implementation that maps dates to specified climatology dates.""" + + def 
__init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) -> None: + """Initialize DateMapperClimatology. + + Parameters + ---------- + source : Any + The data source. + year : int + The year to map to. + day : int + The day to map to. + hour : Optional[int] + The hour to map to. + """ + self.year: int = year + self.day: int = day + self.hour: Optional[int] = hour + + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + """Transform the group of dates to the specified climatology dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + dates = list(group_of_dates) + if not dates: + return [] + + new_dates = defaultdict(list) + for date in dates: + new_date = date.replace(year=self.year, day=self.day) + if self.hour is not None: + new_date = new_date.replace(hour=self.hour, minute=0, second=0) + new_dates[new_date].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperConstant(DateMapper): + """A DateMapper implementation that maps dates to a constant date.""" + + def __init__(self, source: Any, date: Optional[Any] = None) -> None: + """Initialize DateMapperConstant. + + Parameters + ---------- + source : Any + The data source. + date : Optional[Any] + The constant date to map to. + """ + self.source: Any = source + self.date: Optional[Any] = date + + def transform(self, group_of_dates: Any) -> Tuple[Any, Any]: + """Transform the group of dates to a constant date. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Tuple[Any, Any] + Transformed dates. 
+ """ + from anemoi.datasets.dates.groups import GroupOfDates + + if self.date is None: + return [ + ( + GroupOfDates([], group_of_dates.provider), + group_of_dates, + ) + ] + + return [ + ( + GroupOfDates([self.date], group_of_dates.provider), + group_of_dates, + ) + ] + + +@source_registry.register("repeated_dates") +class RepeatedDatesSource(Source): + + def __init__(self, owner, source: Any, mode: str, **kwargs) -> None: + self.mapper = DateMapper.from_mode(mode, source, kwargs) + self.source = source + + def execute(self, context, group_of_dates): + source = context.create_source(self.source) + + result = [] + for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): + rich.print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") + source_results = source(context, one_date_group) + for field in source_results: + for date in many_dates_group: + result.append(new_field_with_valid_datetime(field, date)) + + return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index 4f4edb46f..665cfdad3 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -20,7 +20,6 @@ from earthkit.data.core.fieldlist import MultiFieldList from anemoi.datasets.create.sources.patterns import iterate_patterns -from anemoi.datasets.data.stores import name_to_zarr_store from ..legacy import legacy_source from .fieldlist import XarrayFieldList @@ -89,37 +88,22 @@ def load_one( The loaded dataset. """ - """ - We manage the S3 client ourselves, bypassing fsspec and s3fs layers, because sometimes something on the stack - zarr/fsspec/s3fs/boto3 (?) seem to flags files as missing when they actually are not (maybe when S3 reports some sort of - connection error). 
In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs. - See https://github.com/pydata/xarray/issues/8842 - - We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`. - """ - if options is None: options = {} context.trace(emoji, dataset, options, kwargs) - if isinstance(dataset, str) and ".zarr" in dataset: - data = xr.open_zarr(name_to_zarr_store(dataset), **options) - elif "planetarycomputer" in dataset: - store = name_to_zarr_store(dataset) - if "store" in store: - data = xr.open_zarr(**store) - if "filename_or_obj" in store: - data = xr.open_dataset(**store) - else: - data = xr.open_dataset(dataset, **options) + if isinstance(dataset, str) and dataset.endswith(".zarr"): + # If the dataset is a zarr store, we need to use the zarr engine + options["engine"] = "zarr" + + data = xr.open_dataset(dataset, **options) fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch) if len(dates) == 0: result = fs.sel(**kwargs) else: - print("dates", dates, kwargs) result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) if len(result) == 0: @@ -130,7 +114,7 @@ def load_one( a = ["valid_datetime", k.metadata("valid_datetime", default=None)] for n in kwargs.keys(): a.extend([n, k.metadata(n, default=None)]) - print([str(x) for x in a]) + LOG.warning(f"{[str(x) for x in a]}") if i > 16: break diff --git a/src/anemoi/datasets/create/sources/xarray_support/coordinates.py b/src/anemoi/datasets/create/sources/xarray_support/coordinates.py index 58df7ad65..161f28b8a 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/coordinates.py +++ b/src/anemoi/datasets/create/sources/xarray_support/coordinates.py @@ -95,6 +95,7 @@ class Coordinate: is_member = False is_x = False is_y = False + is_point = False def __init__(self, variable: xr.DataArray) -> None: """Initialize the coordinate. 
@@ -390,6 +391,13 @@ def normalise(self, value: Any) -> Any: return value +class PointCoordinate(Coordinate): + """Coordinate class for point data.""" + + is_point = True + mars_names = ("point",) + + class LongitudeCoordinate(Coordinate): """Coordinate class for longitude.""" diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index d46613474..6f4ecca7b 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -96,13 +96,10 @@ def __init__(self, owner: Any, selection: Any) -> None: if alias not in self._md: self._md[alias] = value - # print(values.ndim, values.shape, selection.dims) # By now, the only dimensions should be latitude and longitude self._shape = tuple(list(self.selection.shape)[-2:]) if math.prod(self._shape) != math.prod(self.selection.shape): - print(self.selection.ndim, self.selection.shape) - print(self.selection) - raise ValueError("Invalid shape for selection") + raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}") @property def shape(self) -> Tuple[int, int]: diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 4df374148..562401aae 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -26,6 +26,7 @@ from .coordinates import LatitudeCoordinate from .coordinates import LevelCoordinate from .coordinates import LongitudeCoordinate +from .coordinates import PointCoordinate from .coordinates import ScalarCoordinate from .coordinates import StepCoordinate from .coordinates import TimeCoordinate @@ -134,6 +135,10 @@ def _guess(self, coordinate: xr.DataArray, coord: Hashable) -> Coordinate: d: Optional[Coordinate] = None + d = self._is_point(coordinate, 
attributes) + if d is not None: + return d + d = self._is_longitude(coordinate, attributes) if d is not None: return d @@ -308,9 +313,9 @@ def _x_y_provided(self, x: Any, y: Any, variable: Any) -> Any: return self._grid_cache[(x.name, y.name, dim_vars)] grid_mapping = variable.attrs.get("grid_mapping", None) - if grid_mapping is not None: - print(f"grid_mapping: {grid_mapping}") - print(self.ds[grid_mapping]) + # if grid_mapping is not None: + # print(f"grid_mapping: {grid_mapping}") + # print(self.ds[grid_mapping]) if grid_mapping is None: LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'") @@ -392,6 +397,10 @@ def _is_longitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Op """ pass + @abstractmethod + def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]: + pass + @abstractmethod def _is_latitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LatitudeCoordinate]: """Checks if the coordinate is a latitude. @@ -550,6 +559,15 @@ def __init__(self, ds: xr.Dataset) -> None: """ super().__init__(ds) + def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]: + if attributes.standard_name in ["cell", "station", "poi", "point"]: + return PointCoordinate(c) + + if attributes.name in ["cell", "station", "poi", "point"]: # WeatherBench + return PointCoordinate(c) + + return None + def _is_longitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LongitudeCoordinate]: """Checks if the coordinate is a longitude. 
@@ -750,6 +768,9 @@ def _is_level(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Option
         if attributes.standard_name == "air_pressure" and attributes.units == "hPa":
             return LevelCoordinate(c, "pl")
 
+        if attributes.long_name == "pressure" and attributes.units in ["hPa", "Pa"]:
+            return LevelCoordinate(c, "pl")
+
         if attributes.name == "level":
             return LevelCoordinate(c, "pl")
 
@@ -759,9 +780,6 @@ def _is_level(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Option
         if attributes.standard_name == "depth":
             return LevelCoordinate(c, "depth")
 
-        if attributes.name == "vertical" and attributes.units == "hPa":
-            return LevelCoordinate(c, "pl")
-
         return None
 
     def _is_number(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[EnsembleCoordinate]:
@@ -1040,3 +1058,23 @@ def _is_number(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optio
             return EnsembleCoordinate(c)
 
         return None
+
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        """Checks if the coordinate is a point coordinate using the flavour rules.
+
+        Parameters
+        ----------
+        c : xr.DataArray
+            The coordinate to check.
+        attributes : CoordinateAttributes
+            The attributes of the coordinate.
+
+        Returns
+        -------
+        Optional[PointCoordinate]
+            The PointCoordinate if matched, else None.
+        """
+        if self._match(c, "point", attributes):
+            return PointCoordinate(c)
+
+        return None
diff --git a/src/anemoi/datasets/create/sources/xarray_support/patch.py b/src/anemoi/datasets/create/sources/xarray_support/patch.py
index 29ea620dd..da4ce6a9c 100644
--- a/src/anemoi/datasets/create/sources/xarray_support/patch.py
+++ b/src/anemoi/datasets/create/sources/xarray_support/patch.py
@@ -61,9 +61,50 @@ def patch_coordinates(ds: xr.Dataset, coordinates: List[str]) -> Any:
     return ds
 
 
+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+    """Rename variables in the dataset.
+ + Parameters + ---------- + ds : xr.Dataset + The dataset to patch. + renames : dict[str, str] + Mapping from old variable names to new variable names. + + Returns + ------- + Any + The patched dataset. + """ + return ds.rename(renames) + + +def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: List[str]) -> Any: + """Sort the coordinates of the dataset. + + Parameters + ---------- + ds : xr.Dataset + The dataset to patch. + sort_coordinates : List[str] + The coordinates to sort. + + Returns + ------- + Any + The patched dataset. + """ + + for name in sort_coordinates: + ds = ds.sortby(name) + return ds + + PATCHES = { "attributes": patch_attributes, "coordinates": patch_coordinates, + "rename": patch_rename, + "sort_coordinates": patch_sort_coordinate, } @@ -82,7 +123,9 @@ def patch_dataset(ds: xr.Dataset, patch: Dict[str, Dict[str, Any]]) -> Any: Any The patched dataset. """ - for what, values in patch.items(): + + ORDER = ["coordinates", "attributes", "rename", "sort_coordinates"] + for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])): if what not in PATCHES: raise ValueError(f"Unknown patch type {what!r}") diff --git a/src/anemoi/datasets/create/sources/xarray_support/variable.py b/src/anemoi/datasets/create/sources/xarray_support/variable.py index 7b0c67439..ee4e20ac7 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/create/sources/xarray_support/variable.py @@ -82,8 +82,12 @@ def __init__( self.time = time - self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid) - self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid} + self.shape = tuple( + len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point + ) + self.names = { + c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point + } self.by_name 
= {c.variable.name: c for c in coordinates} # We need that alias for the time dimension diff --git a/src/anemoi/datasets/data/complement.py b/src/anemoi/datasets/data/complement.py index 1784a61fe..1b503ba06 100644 --- a/src/anemoi/datasets/data/complement.py +++ b/src/anemoi/datasets/data/complement.py @@ -7,7 +7,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. - import datetime import logging from abc import abstractmethod @@ -19,6 +18,7 @@ from typing import Set from typing import Tuple +import numpy as np from numpy.typing import NDArray from ..grids import nearest_grid_points @@ -85,6 +85,7 @@ def __init__( for v in self._source.variables: if v not in self._target.variables: self._variables.append(v) + LOG.info(f"The following variables will be complemented: {self._variables}") if not self._variables: raise ValueError("Augment: no missing variables") @@ -96,9 +97,11 @@ def variables(self) -> List[str]: @property def statistics(self) -> Dict[str, NDArray[Any]]: - """Returns the statistics of the complemented dataset.""" - index = [self._source.name_to_index[v] for v in self._variables] - return {k: v[index] for k, v in self._source.statistics.items()} + datasets = [self._source, self._target] + return { + k: [d.statistics[k][d.name_to_index[i]] for d in datasets for i in d.variables if i in self.variables] + for k in datasets[0].statistics + } def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]: index = [self._source.name_to_index[v] for v in self._variables] @@ -120,7 +123,11 @@ def shape(self) -> Shape: @property def variables_metadata(self) -> Dict[str, Any]: """Returns the metadata of the variables to be added to the target dataset.""" - return {k: v for k, v in self._source.variables_metadata.items() if k in self._variables} + # Merge the two dicts first + all_meta = {**self._source.variables_metadata, 
**self._target.variables_metadata} + + # Filter to keep only desired variables + return {k: v for k, v in all_meta.items() if k in self._variables} def check_same_variables(self, d1: Dataset, d2: Dataset) -> None: """Checks if the variables in two datasets are the same. @@ -231,7 +238,7 @@ def _get_tuple(self, index: TupleIndex) -> NDArray[Any]: class ComplementNearest(Complement): """A class to complement a target dataset with variables from a source dataset using nearest neighbor interpolation.""" - def __init__(self, target: Any, source: Any, max_distance: float = None) -> None: + def __init__(self, target: Any, source: Any, max_distance: float = None, k: int = 1) -> None: """Initializes the ComplementNearest class. Parameters @@ -242,17 +249,25 @@ def __init__(self, target: Any, source: Any, max_distance: float = None) -> None The source dataset. max_distance : float, optional The maximum distance for nearest neighbor interpolation, default is None. + k : int, optional + The number of k closest neighbors to consider for interpolation """ super().__init__(target, source) - self._nearest_grid_points = nearest_grid_points( + self.k = k + self._distances, self._nearest_grid_points = nearest_grid_points( self._source.latitudes, self._source.longitudes, self._target.latitudes, self._target.longitudes, max_distance=max_distance, + k=k, ) + if k == 1: + self._distances = np.expand_dims(self._distances, axis=1) + self._nearest_grid_points = np.expand_dims(self._nearest_grid_points, axis=1) + def check_compatibility(self, d1: Dataset, d2: Dataset) -> None: """Checks the compatibility of two datasets for nearest neighbor interpolation. @@ -285,7 +300,19 @@ def _get_tuple(self, index: TupleIndex) -> NDArray[Any]: source_data = self._source[index[0], source_index, index[2], ...] 
target_data = source_data[..., self._nearest_grid_points]
 
-        result = target_data[..., index[3]]
+        epsilon = 1e-8  # prevent division by zero
+        weights = 1.0 / (self._distances + epsilon)
+        weights = weights.astype(target_data.dtype)
+        weights /= weights.sum(axis=1, keepdims=True)  # normalize
+
+        # Reshape weights to broadcast correctly
+        # Add leading singleton dimensions so it matches target_data shape
+        while weights.ndim < target_data.ndim:
+            weights = np.expand_dims(weights, axis=0)
+
+        # Compute weighted average along the last dimension
+        final_point = np.sum(target_data * weights, axis=-1)
+        result = final_point[..., index[3]]
 
         return apply_index_to_slices_changes(result, changes)
 
@@ -330,6 +357,13 @@ def complement_factory(args: Tuple, kwargs: dict) -> Dataset:
         "nearest": ComplementNearest,
     }[interpolation]
 
-    complement = Class(target=target, source=source)._subset(**kwargs)
+    if interpolation == "nearest":
+        k = kwargs.pop("k", 1)
+        complement = Class(target=target, source=source, k=k)._subset(**kwargs)
+
+    else:
+        complement = Class(target=target, source=source)._subset(**kwargs)
+
+    joined = _open_dataset([target, complement])
 
-    return _open_dataset([target, complement], reorder=source.variables)
+    return _open_dataset(joined, reorder=sorted(joined.variables))
diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py
index f8592bb3d..0267022d1 100644
--- a/src/anemoi/datasets/data/dataset.py
+++ b/src/anemoi/datasets/data/dataset.py
@@ -179,6 +179,19 @@ def __subset(self, **kwargs: Any) -> "Dataset":
         if "start" in kwargs or "end" in kwargs:
             start = kwargs.pop("start", None)
             end = kwargs.pop("end", None)
+            padding = kwargs.pop("padding", None)
+
+            if padding:
+                if padding != "empty":
+                    raise ValueError(f"Only 'empty' padding is supported, got {padding=}")
+                from .padded import Padded
+
+                frequency = kwargs.pop("frequency", self.frequency)
+                return (
+                    Padded(self, start, end, frequency, dict(start=start, end=end, 
frequency=frequency)) + ._subset(**kwargs) + .mutate() + ) from .subset import Subset @@ -724,6 +737,9 @@ def grids(self) -> TupleIndex: """Return the grid shape of the dataset.""" return (self.shape[-1],) + def empty_item(self) -> NDArray[Any]: + return np.zeros((*self.shape[1:-1], 0), dtype=self.dtype) + def _check(self) -> None: """Check for overridden private methods in the dataset.""" common = Dataset.__dict__.keys() & self.__class__.__dict__.keys() @@ -1075,3 +1091,16 @@ def get_dataset_names(self, names: Set[str]) -> None: The dataset names. """ pass + + def get_latitudes(self, i): + return self.get_aux(i)[0] + + def get_longitudes(self, i): + return self.get_aux(i)[1] + + def get_timedeltas(self, i): + return self.get_aux(i)[2] + + def get_aux(self, i): + # need to decide if Fields datasets need to implement this + raise NotImplementedError(f"get_aux is not implemented for this dataset, {type(self)}") diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index 09d936efa..bb922e96d 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -330,8 +330,14 @@ def check_same_grid(self, d1: Dataset, d2: Dataset) -> None: ValueError If the grids are not the same. """ - if (d1.latitudes != d2.latitudes).any() or (d1.longitudes != d2.longitudes).any(): - raise ValueError(f"Incompatible grid ({d1} {d2})") + + # note: not a proper implementation, should be handled + # in a more consolidated way ... + rtol = 1.0e-7 + if not np.allclose(d1.latitudes, d2.latitudes, rtol=rtol) or not np.allclose( + d1.longitudes, d2.longitudes, rtol=rtol + ): + raise ValueError(f"Incompatible grid ({d1.longitudes} {d2.longitudes})") def check_same_shape(self, d1: Dataset, d2: Dataset) -> None: """Checks if the shapes of two datasets are the same. 
diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 2f9f77b25..b5523ef85 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -11,6 +11,7 @@ import calendar import datetime import logging +import os from pathlib import PurePath from typing import TYPE_CHECKING from typing import Any @@ -22,6 +23,7 @@ import numpy as np import zarr +from anemoi.utils.config import load_any_dict_format from anemoi.utils.config import load_config as load_settings from numpy.typing import NDArray @@ -108,7 +110,10 @@ def round_datetime(d: np.datetime64, dates: NDArray[np.datetime64], up: bool) -> def _as_date( - d: Union[int, str, np.datetime64, datetime.date], dates: NDArray[np.datetime64], last: bool + d: Union[int, str, np.datetime64, datetime.date], + dates: NDArray[np.datetime64], + last: bool, + frequency: Optional[datetime.timedelta] = None, ) -> np.datetime64: """Convert a date to a numpy datetime64 object, rounding to the nearest date in a list of dates. @@ -120,6 +125,8 @@ def _as_date( The list of dates. last : bool Whether to round to the last date. + frequency : Optional[datetime.timedelta] + The frequency of the dataset. 
Returns ------- @@ -142,30 +149,49 @@ def _as_date( pass if isinstance(d, int): + delta = frequency + if delta is None: + delta = np.timedelta64(1, "s") + delta = np.timedelta64(delta, "s") + if len(str(d)) == 4: year = d if last: - return _as_date(np.datetime64(f"{year:04}-12-31T23:59:59"), dates, last) + year = year + 1 + npdate = np.datetime64(f"{year:04}-01-01T00:00:00") + return _as_date(npdate - delta, dates, last, frequency) else: - return _as_date(np.datetime64(f"{year:04}-01-01T00:00:00"), dates, last) + return _as_date(np.datetime64(f"{year:04}-01-01T00:00:00"), dates, last, frequency) if len(str(d)) == 6: year = d // 100 month = d % 100 if last: - _, last_day = calendar.monthrange(year, month) - return _as_date(np.datetime64(f"{year:04}-{month:02}-{last_day:02}T23:59:59"), dates, last) + month = month + 1 + if month > 12: + month = 1 + year += 1 + npdate = np.datetime64(f"{year:04}-{month:02}-01T00:00:00") + return _as_date(npdate - delta, dates, last, frequency) else: - return _as_date(np.datetime64(f"{year:04}-{month:02}-01T00:00:00"), dates, last) + return _as_date(np.datetime64(f"{year:04}-{month:02}-01T00:00:00"), dates, last, frequency) if len(str(d)) == 8: year = d // 10000 month = (d % 10000) // 100 day = d % 100 if last: - return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T23:59:59"), dates, last) + day = day + 1 + if day > calendar.monthrange(year, month)[1]: + day = 1 + month += 1 + if month > 12: + month = 1 + year += 1 + npdate = np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00") + return _as_date(npdate - delta, dates, last, frequency) else: - return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00"), dates, last) + return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00"), dates, last, frequency) if isinstance(d, str): @@ -201,19 +227,20 @@ def isfloat(s: str) -> bool: np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}"), dates, last, + frequency, ) if "-" in d: 
assert ":" not in d bits = d.split("-") if len(bits) == 1: - return _as_date(int(bits[0]), dates, last) + return _as_date(int(bits[0]), dates, last, frequency) if len(bits) == 2: - return _as_date(int(bits[0]) * 100 + int(bits[1]), dates, last) + return _as_date(int(bits[0]) * 100 + int(bits[1]), dates, last, frequency) if len(bits) == 3: - return _as_date(int(bits[0]) * 10000 + int(bits[1]) * 100 + int(bits[2]), dates, last) + return _as_date(int(bits[0]) * 10000 + int(bits[1]) * 100 + int(bits[2]), dates, last, frequency) if ":" in d: assert len(d) == 5 @@ -225,12 +252,16 @@ def isfloat(s: str) -> bool: month = first.month day = first.day - return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour}:00:00"), dates, last) + return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour}:00:00"), dates, last, frequency) raise NotImplementedError(f"Unsupported date: {d} ({type(d)})") -def as_first_date(d: Union[int, str, np.datetime64, datetime.date], dates: NDArray[np.datetime64]) -> np.datetime64: +def as_first_date( + d: Union[int, str, np.datetime64, datetime.date], + dates: NDArray[np.datetime64], + frequency: Optional[datetime.timedelta] = None, +) -> np.datetime64: """Convert a date to the first date in a list of dates. Parameters @@ -239,16 +270,22 @@ def as_first_date(d: Union[int, str, np.datetime64, datetime.date], dates: NDArr The date to convert. dates : NDArray[np.datetime64] The list of dates. + frequency : Optional[datetime.timedelta] + The frequency of the dataset. Returns ------- np.datetime64 The first date. 
""" - return _as_date(d, dates, last=False) + return _as_date(d, dates, last=False, frequency=frequency) -def as_last_date(d: Union[int, str, np.datetime64, datetime.date], dates: NDArray[np.datetime64]) -> np.datetime64: +def as_last_date( + d: Union[int, str, np.datetime64, datetime.date], + dates: NDArray[np.datetime64], + frequency: Optional[datetime.timedelta] = None, +) -> np.datetime64: """Convert a date to the last date in a list of dates. Parameters @@ -257,13 +294,15 @@ def as_last_date(d: Union[int, str, np.datetime64, datetime.date], dates: NDArra The date to convert. dates : NDArray[np.datetime64] The list of dates. + frequency : Optional[datetime.timedelta] + The frequency of the dataset. Returns ------- np.datetime64 The last date. """ - return _as_date(d, dates, last=True) + return _as_date(d, dates, last=True, frequency=frequency) def _concat_or_join(datasets: List["Dataset"], kwargs: Dict[str, Any]) -> Tuple["Dataset", Dict[str, Any]]: @@ -317,6 +356,18 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) - from .stores import Zarr from .stores import zarr_lookup + if isinstance(a, str) and len(a.split(".")) in [2, 3]: + + metadata_path = os.path.join(a, "metadata.json") + if os.path.exists(metadata_path): + metadata = load_any_dict_format(metadata_path) + if "backend" not in metadata: + raise ValueError(f"Metadata for {a} does not contain 'backend' key") + + from anemoi.datasets.data.records import open_records_dataset + + return open_records_dataset(a, backend=metadata["backend"]) + if isinstance(a, Dataset): return a.mutate() @@ -454,6 +505,13 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": for a in args: sets.append(_open(a)) + if "observations" in kwargs: + from .observations import observations_factory + + assert not sets, sets + + return observations_factory(args, kwargs).mutate() + if "xy" in kwargs: # Experimental feature, may be removed from .xy import xy_factory diff --git 
a/src/anemoi/datasets/data/observations/__init__.py b/src/anemoi/datasets/data/observations/__init__.py new file mode 100644 index 000000000..b5f8ec5e9 --- /dev/null +++ b/src/anemoi/datasets/data/observations/__init__.py @@ -0,0 +1,316 @@ +# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import os +from functools import cached_property +from typing import Any +from typing import Dict +from typing import Tuple + +import numpy as np +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.data.dataset import Dataset + +from ..debug import Node + +LOG = logging.getLogger(__name__) + + +def round_datetime(dt, frequency, up=True): + dt = dt.replace(minute=0, second=0, microsecond=0) + hour = dt.hour + if hour % frequency != 0: + dt = dt.replace(hour=(hour // frequency) * frequency) + dt = dt + datetime.timedelta(hours=frequency) + return dt + + +def make_dates(start, end, frequency): + if isinstance(start, np.datetime64): + start = start.astype(datetime.datetime) + if isinstance(end, np.datetime64): + end = end.astype(datetime.datetime) + + dates = [] + current_date = start + while current_date <= end: + dates.append(current_date) + current_date += frequency + + dates = [np.datetime64(d, "s") for d in dates] + dates = np.array(dates, dtype="datetime64[s]") + return dates + + +class ObservationsBase(Dataset): + resolution = None + + @cached_property + def shape(self): + return (len(self.dates), len(self.variables), "dynamic") + + def empty_item(self): + return np.full(self.shape[1:-1] + (0,), 0.0, dtype=np.float32) + + def metadata(self): + 
return dict(observations_datasets="obs datasets currenty have no metadata") + + def _check(self): + pass + + def __len__(self): + return len(self.dates) + + def tree(self): + return Node(self) + + def __getitem__(self, i): + if isinstance(i, int): + return self.getitem(i) + + # The following may would work but is likely to change in the future + # if isinstance(i, slice): + # return [self.getitem(j) for j in range(int(slice.start), int(slice.stop))] + # if isinstance(i, list): + # return [self.getitem(j) for j in i] + + raise ValueError( + ( + f"Expected int, got {i} of type {type(i)}. Only int is supported to index " + "observations datasets. Please use a second [] to select part of the data [i][a,b,c]" + ) + ) + + @property + def variables(self): + raise NotImplementedError() + + def collect_input_sources(self): + LOG.warning("collect_input_sources method is not implemented") + return [] + + def constant_fields(self): + LOG.warning("constant_fields method is not implemented") + return [] + + @property + def dates(self): + return self._dates + + @property + def dtype(self): + return np.float32 + + @property + def field_shape(self): + return self.shape[1:] + + @property + def frequency(self): + assert isinstance(self._frequency, datetime.timedelta), f"Expected timedelta, got {type(self._frequency)}" + return self._frequency + + @property + def latitudes(self): + raise NotImplementedError("latitudes property is not implemented") + + @property + def longitudes(self): + raise NotImplementedError("longitudes property is not implemented") + + @property + def missing(self): + return [] + + def statistics_tendencies(self): + raise NotImplementedError("statistics_tendencies method is not implemented") + + def variables_metadata(self): + raise NotImplementedError("variables_metadata method is not implemented") + + +class ObservationsZarr(ObservationsBase): + def __init__(self, dataset, frequency=None, window=None): + import zarr + + if isinstance(dataset, 
zarr.hierarchy.Group): + dataset = dataset._store.path + + from ..stores import zarr_lookup + + dataset = zarr_lookup(dataset) + self.path = dataset + assert self._probe_attributes["is_observations"], f"Expected observations dataset, got {self.path}" + + if frequency is None: + frequency = self._probe_attributes.get("frequency") + # LOG.warning(f"Frequency not provided, using the one from the dataset: {frequency}") + if frequency is None: + frequency = "6h" + # LOG.warning(f"Frequency not provided in the dataset, using the default : {frequency}") + self._frequency = frequency_to_timedelta(frequency) + assert self.frequency.total_seconds() % 3600 == 0, f"Expected multiple of 3600, got {self.frequency}" + if self.frequency.total_seconds() != 6 * 3600: + LOG.warning("Frequency is not 6h, this has not been tested, behaviour is unknown") + + frequency_hours = int(self.frequency.total_seconds() // 3600) + assert isinstance(frequency_hours, int), f"Expected int, got {type(frequency_hours)}" + + if window is None: + window = (-frequency_hours, 0) + if window != (-frequency_hours, 0): + raise ValueError("For now, only window = (- frequency, 0) are supported") + + self.window = window + + start, end = self._probe_attributes["start_date"], self._probe_attributes["end_date"] + start, end = datetime.datetime.fromisoformat(start), datetime.datetime.fromisoformat(end) + start, end = round_datetime(start, frequency_hours), round_datetime(end, frequency_hours) + + self._dates = make_dates(start + self.frequency, end, self.frequency) + + first_window_begin = start.strftime("%Y%m%d%H%M%S") + first_window_begin = int(first_window_begin) + # last_window_end must be the end of the time window of the last item + last_window_end = int(end.strftime("%Y%m%d%H%M%S")) + + from .legacy_obs_dataset import ObsDataset + + args = [self.path, first_window_begin, last_window_end] + kwargs = dict( + len_hrs=frequency_hours, # length the time windows, i.e. 
the time span of one item + step_hrs=frequency_hours, # frequency of the dataset, i.e. the time shift between two items + ) + self.forward = ObsDataset(*args, **kwargs) + + assert frequency_hours == self.forward.step_hrs, f"Expected {frequency_hours}, got {self.forward.len_hrs}" + assert frequency_hours == self.forward.len_hrs, f"Expected {frequency_hours}, got {self.forward.step_hrs}" + + if len(self.forward) != len(self.dates): + raise ValueError( + ( + f"Dates are not consistent with the number of items in the dataset. " + f"The dataset contains {len(self.forward)} time windows. " + f"This is not compatible with the " + f"{len(self.dates)} requested dates with frequency={frequency_hours}" + f"{self.dates[0]}, {self.dates[1]}, ..., {self.dates[-2]}, {self.dates[-1]} " + ) + ) + + @property + def source(self): + return self.path + + def get_dataset_names(self): + name = os.path.basename(self.path) + if name.endswith(".zarr"): + name = name[:-5] + return [name] + + @cached_property + def _probe_attributes(self): + import zarr + + z = zarr.open(self.path, mode="r") + return dict(z.data.attrs) + + def get_aux(self, i): + data = self.forward[i] + + latitudes = data[:, self.name_to_index["__latitudes"]].numpy() + longitudes = data[:, self.name_to_index["__longitudes"]].numpy() + + reference = self.dates[i] + times = self.forward.get_dates(i) + if str(times.dtype) != "datetime64[s]": + LOG.warning(f"Expected np.datetime64[s], got {times.dtype}. 
") + times = times.astype("datetime64[s]") + assert str(reference.dtype) == "datetime64[s]", f"Expected np.datetime64[s], got {type(reference)}" + timedeltas = times - reference + + assert latitudes.shape == longitudes.shape, f"Expected {latitudes.shape}, got {longitudes.shape}" + assert timedeltas.shape == latitudes.shape, f"Expected {timedeltas.shape}, got {latitudes.shape}" + + return latitudes, longitudes, timedeltas + + def getitem(self, i): + data = self.forward[i] + + data = data.numpy().astype(np.float32) + assert len(data.shape) == 2, f"Expected 2D array, got {data.shape}" + data = data.T + + if not data.size: + data = self.empty_item() + assert ( + data.shape[0] == self.shape[1] + ), f"Data shape {data.shape} does not match {self.shape} : {data.shape[0]} != {self.shape[1]}" + return data + + @cached_property + def variables(self): + colnames = self.forward.colnames + variables = [] + for n in colnames: + if n.startswith("obsvalue_"): + n = n.replace("obsvalue_", "") + if n == "latitude" or n == "lat": + assert "latitudes" not in variables, f"Duplicate latitudes found in {variables}" + variables.append("__latitudes") + continue + if n == "longitude" or n == "lon": + assert "longitudes" not in variables, f"Duplicate longitudes found in {variables}" + variables.append("__longitudes") + continue + assert not n.startswith("__"), f"Invalid name {n} found in {colnames}" + variables.append(n) + return variables + + @property + def name_to_index(self): + return {n: i for i, n in enumerate(self.variables)} + + @property + def statistics(self): + mean = self.forward.properties["means"] + mean = np.array(mean, dtype=np.float32) + + var = self.forward.properties["vars"] + var = np.array(var, dtype=np.float32) + stdev = np.sqrt(var) + + minimum = np.array(self.forward.z.data.attrs["mins"], dtype=np.float32) + maximum = np.array(self.forward.z.data.attrs["maxs"], dtype=np.float32) + + assert isinstance(mean, np.ndarray), f"Expected np.ndarray, got {type(mean)}" + assert 
isinstance(stdev, np.ndarray), f"Expected np.ndarray, got {type(stdev)}" + assert isinstance(minimum, np.ndarray), f"Expected np.ndarray, got {type(minimum)}" + assert isinstance(maximum, np.ndarray), f"Expected np.ndarray, got {type(maximum)}" + return dict(mean=mean, stdev=stdev, minimum=minimum, maximum=maximum) + + def tree(self): + return Node( + self, + [], + path=self.path, + frequency=self.frequency, + ) + + def __repr__(self): + return f"Observations({os.path.basename(self.path)}, {self.dates[0]};{self.dates[-1]}, {len(self)})" + + +def observations_factory(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> ObservationsBase: + observations = kwargs.pop("observations") + + if not isinstance(observations, dict): + observations = dict(dataset=observations) + dataset = ObservationsZarr(**observations) + return dataset._subset(**kwargs) diff --git a/src/anemoi/datasets/data/observations/legacy_obs_dataset.py b/src/anemoi/datasets/data/observations/legacy_obs_dataset.py new file mode 100644 index 000000000..85ab51583 --- /dev/null +++ b/src/anemoi/datasets/data/observations/legacy_obs_dataset.py @@ -0,0 +1,200 @@ +# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import datetime +import logging + +import numpy as np +import pandas as pd +import torch +import zarr +from torch.utils.data import Dataset + +LOG = logging.getLogger(__name__) + + +class ObsDataset(Dataset): + + def __init__( + self, + filename: str, + start: int, + end: int, + len_hrs: int, + step_hrs: int = None, + select: list[str] = None, + drop: list[str] = None, + ) -> None: + + self.filename = filename + self.z = zarr.open(filename, mode="r") + self.data = self.z["data"] + self.dt = self.z["dates"] # datetime only + self.hrly_index = self.z["idx_197001010000_1"] + self.colnames = self.data.attrs["colnames"] + self.selected_colnames = self.colnames + self.selected_cols_idx = np.arange(len(self.colnames)) + self.len_hrs = len_hrs + self.step_hrs = step_hrs if step_hrs else len_hrs + + # Create index for samples + self._setup_sample_index(start, end, self.len_hrs, self.step_hrs) + + self._load_properties() + + if select: + self.select(select) + + if drop: + self.drop(drop) + + def __getitem__( + self, + idx: int, + ) -> torch.tensor: + + start_row = self.indices_start[idx] + end_row = self.indices_end[idx] + + data = self.data.oindex[start_row:end_row, self.selected_cols_idx] + + return torch.from_numpy(data) + + def __len__(self) -> int: + + return len(self.indices_start) + + def get_dates( + self, + idx: int, + ) -> np.ndarray: + + start_row = self.indices_start[idx] + end_row = self.indices_end[idx] + dates = self.dt.oindex[start_row:end_row] + + assert len(dates.shape) == 2, dates.shape + dates = dates[:, 0] + + if len(dates) and dates[0].dtype != np.dtype("datetime64[s]"): + dates = dates.astype("datetime64[s]") + + return dates + + def get_df(self, idx: int) -> pd.DataFrame: + """Convenience function to return data for sample idx packaged in a pandas DataFrame""" + + d = self.__getitem__(idx) + + df = pd.DataFrame(data=d, columns=[self.colnames[i] for i in self.selected_cols_idx]) + + start_row = self.indices_start[idx] + end_row = 
self.indices_end[idx] + + dts = self.dt[start_row:end_row, :] + df["datetime"] = dts + + return df + + def select(self, cols_list: list[str]) -> None: + """Allow user to specify which columns they want to access. + Get functions only returned for these specified columns. + """ + self.selected_colnames = cols_list + self.selected_cols_idx = np.array([self.colnames.index(item) for item in cols_list]) + + def drop(self, cols_to_drop: list[str]) -> None: + """Allow user to drop specific columns from the dataset. + Get functions no longer return data for these columns after being set. + """ + mask = [name not in cols_to_drop for name in self.selected_colnames] + + self.selected_colnames = [name for name, keep in zip(self.selected_colnames, mask) if keep] + self.selected_cols_idx = self.selected_cols_idx[mask] + + def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]: + """Returns a tuple of datetime objects describing the start and end times of the sample at position idx.""" + + if idx < 0: + idx = len(self) + idx + + time_start = self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs), seconds=1) + time_end = min( + self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs + self.len_hrs)), + self.end_dt, + ) + + return (np.datetime64(time_start), np.datetime64(time_end)) + + def first_sample_with_data(self) -> int: + """Returns the position of the first sample which contains data.""" + return int(np.nonzero(self.indices_end)[0][0]) if self.indices_end.max() > 0 else None + + def last_sample_with_data(self) -> int: + """Returns the position of the last sample which contains data.""" + if self.indices_end.max() == 0: + last_sample = None + else: + last_sample = int(np.where(np.diff(np.append(self.indices_end, self.indices_end[-1])) > 0)[0][-1] + 1) + + return last_sample + + def _setup_sample_index(self, start: int, end: int, len_hrs: int, step_hrs: int) -> None: + """Dataset is divided into samples; + - each n_hours long + - sample 0 starts 
at start (yyyymmddhhmm) + - index array has one entry for each sample; contains the index of the first row + containing data for that sample + """ + + try: + from obsdata.config import config + + assert config.base_index_yyyymmddhhmm == 197001010000, "base_index_yyyymmddhhmm must be 197001010000" + except ImportError: + pass + base_yyyymmddhhmm = 197001010000 + + assert start > base_yyyymmddhhmm, ( + f"Abort: ObsDataset sample start (yyyymmddhhmm) must be greater than {base_yyyymmddhhmm}\n" + f" Current value: {start}" + ) + + format_str = "%Y%m%d%H%M%S" + base_dt = datetime.datetime.strptime(str(base_yyyymmddhhmm), format_str) + self.start_dt = datetime.datetime.strptime(str(start), format_str) + self.end_dt = datetime.datetime.strptime(str(end), format_str) + + # Calculate hours since the base date for the requested dataset ranges + diff_in_hours_start = int((self.start_dt - base_dt).total_seconds() // 3600) + diff_in_hours_end = int((self.end_dt - base_dt).total_seconds() // 3600) + + # Find elements that need to be extracted from the hourly index + # + ensuring that the dataset respects the requested end-hour even if it is mid-way through a sample + sample_starts = np.arange(diff_in_hours_start, diff_in_hours_end, step_hrs) + sample_ends = np.minimum(sample_starts + len_hrs, diff_in_hours_end) + + # Initialize local index arrays + self.indices_start = np.zeros(sample_starts.shape, dtype=int) + self.indices_end = np.zeros(self.indices_start.shape, dtype=int) + + max_hrly_index = self.hrly_index.shape[0] - 1 + valid_start_mask = sample_starts <= max_hrly_index + valid_end_mask = (sample_ends > 0) & (sample_ends <= max_hrly_index) + + # Copy elements from the hrly_index into the local index + self.indices_start[valid_start_mask] = self.hrly_index[sample_starts[valid_start_mask]] + self.indices_end[valid_end_mask] = np.maximum(self.hrly_index[sample_ends[valid_end_mask]], 0) + + def _load_properties(self) -> None: + + self.properties = {} + + 
self.properties["means"] = self.data.attrs["means"] + self.properties["vars"] = self.data.attrs["vars"] + self.properties["data_idxs"] = self.data.attrs["data_idxs"] + self.properties["obs_id"] = self.data.attrs["obs_id"] diff --git a/src/anemoi/datasets/data/observations/multi.py b/src/anemoi/datasets/data/observations/multi.py new file mode 100644 index 000000000..af5c02e71 --- /dev/null +++ b/src/anemoi/datasets/data/observations/multi.py @@ -0,0 +1,64 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import os + +from anemoi.datasets.data import open_dataset + +LOG = logging.getLogger(__name__) + + +class LegacyDatasets: + def __init__(self, paths, start=None, end=None, **kwargs): + self.paths = paths + + if not start or not end: + print( + "❌❌ Warning: start and end not provided, using the minima first and maximal last dates of the datasets" + ) + lst = [self._open_dataset(p, **kwargs) for p in paths] + start = min([d.dates[0] for d in lst]) + end = max([d.dates[-1] for d in lst]) + + self._datasets = { + os.path.basename(p).split(".")[0]: self._open_dataset(p, start=start, end=end, padding="empty") + for p in paths + } + + first = list(self._datasets.values())[0] + for name, dataset in self._datasets.items(): + if dataset.dates[0] != first.dates[0] or dataset.dates[-1] != first.dates[-1]: + raise ValueError("Datasets have different start and end times") + if dataset.frequency != first.frequency: + raise ValueError("Datasets have different frequencies") + + self._keys = self._datasets.keys + + self._first = list(self._datasets.values())[0] + + def _open_dataset(self, p, **kwargs): + if 
p.startswith("observations-"): + return open_dataset(observations=p, **kwargs) + else: + print("❗ Opening non-observations dataset:", p) + return open_dataset(p, **kwargs) + + def items(self): + return self._datasets.items() + + @property + def dates(self): + return self._first.dates + + def __len__(self): + return len(self._first) + + def __getitem__(self, i): + return {k: d[i] for k, d in self._datasets.items()} diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py new file mode 100644 index 000000000..dcff11bae --- /dev/null +++ b/src/anemoi/datasets/data/padded.py @@ -0,0 +1,227 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
import datetime
import logging
from functools import cached_property
from typing import Any
from typing import Dict
from typing import Set

import numpy as np
from anemoi.utils.dates import frequency_to_timedelta
from numpy.typing import NDArray

from anemoi.datasets.data.dataset import Dataset
from anemoi.datasets.data.dataset import FullIndex
from anemoi.datasets.data.dataset import Shape
from anemoi.datasets.data.dataset import TupleIndex
from anemoi.datasets.data.debug import Node
from anemoi.datasets.data.debug import debug_indexing
from anemoi.datasets.data.forwards import Forwards
from anemoi.datasets.data.indexing import expand_list_indexing
from anemoi.datasets.data.misc import as_first_date
from anemoi.datasets.data.misc import as_last_date

LOG = logging.getLogger(__name__)


class Padded(Forwards):
    """Wrap `dataset` so it spans the requested [start, end] range, serving
    empty items for the padded dates before and after the wrapped data.
    """

    # Number of dates in the "before" / "after" padding and in the wrapped
    # ("inside") part; set by __init__.
    _before: int = 0
    _after: int = 0
    _inside: int = 0

    def __init__(self, dataset: Dataset, start: str, end: str, frequency: str, reason: Dict[str, Any]) -> None:
        """Create a padded subset of a dataset.

        Attributes:
            dataset (Dataset): The dataset to subset.
            start (str): The start date of the subset.
            end (str): The end date of the subset.
            frequency (str): The frequency of the subset.
            reason (Dict[str, Any]): The reason for the padding.
        """

        self.reason = {k: v for k, v in reason.items() if v is not None}

        if frequency is None:
            frequency = dataset.frequency

        self._frequency = frequency_to_timedelta(frequency)

        if start is None:
            # default is to start at the first date
            start = dataset.dates[0]
        else:
            start = as_first_date(start, None, frequency=self._frequency)

        if end is None:
            # default is to end at the last date
            end = dataset.dates[-1]
        else:
            end = as_last_date(end, None, frequency=self._frequency)

        assert isinstance(dataset.dates[0], np.datetime64), (dataset.dates[0], type(dataset.dates[0]))

        # 'start' is the requested start date
        # 'end' is the requested end date
        # 'first' is the first date of the dataset
        # 'last' is the last date of the dataset
        first = dataset.dates[0]
        last = dataset.dates[-1]

        # NOTE(review): converts the raw `frequency` argument, not
        # self._frequency; assumes numpy can coerce it to timedelta64[s] --
        # confirm this holds when `frequency` is a string such as "6h".
        timedelta = np.array([frequency], dtype="timedelta64[s]")[0]

        parts = []
        # The "before" padding covers [start, first) -- capped at end+1 step so
        # a range lying entirely before the dataset is handled too.
        before_end = min(end + timedelta, first)
        before_part = np.arange(start, before_end, timedelta)
        if start < first:
            # if the start date is before the first date of the dataset, there is a "before" part
            assert len(before_part) > 0, (start, first, before_end)
            parts.append(before_part)
            self._before = len(before_part)
        if start >= first:
            # if the start date is the first date of the dataset, there is no "before" part
            assert len(before_part) == 0, (start, first, before_end)
            self._before = 0

        # if the start date is before the last date of the dataset
        # and the end date is after the first date of the dataset
        # there is an "inside" part
        if start < last and end > first:
            inside_start = max(start, first)
            inside_end = min(end, last)
            self.dataset = dataset._subset(start=inside_start, end=inside_end)
            inside_part = self.dataset.dates
            parts.append(inside_part)
            self._inside = len(inside_part)
        else:
            self.dataset = dataset  # still needed to get the empty_item
            self._inside = 0

        # The "after" padding covers (last, end] at the requested frequency.
        after_start = max(start, last + timedelta)
        after_part = np.arange(after_start, end + timedelta, timedelta)
        if end > last:
            # if the end date is after the last date of the dataset, there is an "after" part
            assert len(after_part) > 0, (end, last, after_start)
            parts.append(after_part)
            self._after = len(after_part)
        if end <= last:
            assert len(after_part) == 0, (end, last, after_start)
            self._after = 0

        self._dates = np.hstack(parts)

        # Sanity: the three parts account for every padded date exactly once.
        assert len(self._dates) == self._before + self._inside + self._after, (
            len(self._dates),
            self._before,
            self._inside,
            self._after,
        )

        assert self._dates[0] == start, (self._dates[0], start)
        assert self._dates[-1] == end, (self._dates[-1], end)

        # Forward other properties to the super dataset
        super().__init__(dataset)

    @debug_indexing
    def __getitem__(self, n: FullIndex) -> NDArray[Any]:
        """Return item `n`, or an empty item if `n` falls in the padding."""
        if isinstance(n, tuple):
            return self._get_tuple(n)

        if isinstance(n, slice):
            return self._get_slice(n)

        if self._i_out_of_range(n):
            return self.empty_item()

        # Shift by the "before" padding to index into the wrapped dataset.
        return self.dataset[n - self._before]

    def _i_out_of_range(self, n: FullIndex) -> bool:
        """Check if the index is out of range."""
        # In the "before" padding.
        if 0 <= n < self._before:
            return True

        # In the "after" padding.
        if (self._before + self._inside) <= n < (self._before + self._inside + self._after):
            return True
        return False

    @debug_indexing
    def _get_slice(self, s: slice) -> NDArray[Any]:
        # NOTE(review): `self._len` is not defined in this class; presumably
        # provided by Forwards -- confirm.
        LOG.warning("Padded subset does not support slice indexing, returning a list")
        return [self[i] for i in range(*s.indices(self._len))]

    @debug_indexing
    @expand_list_indexing
    def _get_tuple(self, n: TupleIndex) -> NDArray[Any]:
        LOG.warning("Padded subset does not support tuple indexing, returning a list")
        return [self[i] for i in n]

    def empty_item(self):
        """Return the wrapped dataset's placeholder item for padded dates."""
        return self.dataset.empty_item()

    def get_aux(self, i: FullIndex) -> NDArray[np.timedelta64]:
        """Return the (latitudes, longitudes, timedeltas)-style auxiliary
        triple for item `i`; empty float32 arrays for padded dates.
        """
        if self._i_out_of_range(i):
            arr = np.array([], dtype=np.float32)
            aux = arr, arr, arr
        else:
            aux = self.dataset.get_aux(i - self._before)

        assert len(aux) == 3, (aux, i)
        return aux

    def __len__(self) -> int:
        return len(self._dates)

    @property
    def frequency(self) -> datetime.timedelta:
        """Get the frequency of the subset."""
        return self._frequency

    @property
    def dates(self) -> NDArray[np.datetime64]:
        """All dates, padding included."""
        return self._dates

    @property
    def shape(self) -> Shape:
        return (len(self.dates),) + self.dataset.shape[1:]

    @cached_property
    def missing(self) -> Set[int]:
        raise NotImplementedError("Need to decide whether to include the added dates as missing or not")
        # return self.forward.missing

    def tree(self) -> Node:
        """Get the tree representation of the subset.

        Returns:
            Node: The tree representation of the subset.
        """
        return Node(self, [self.dataset.tree()], **self.reason)

    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
        """Get the metadata specific to the forwards subclass.

        Returns:
            Dict[str, Any]: The metadata specific to the forwards subclass.
        """
        return {
            # "indices": self.indices,
            "reason": self.reason,
        }

    def __repr__(self) -> str:
        """Get the string representation of the subset.

        Returns:
            str: The string representation of the subset.
        """
        return f"Padded({self.forward}, {self.dates[0]}...{self.dates[-1]}, frequency={self.frequency})"
import datetime
import logging
import os
from collections import defaultdict
from functools import cached_property

import numpy as np
from anemoi.utils.dates import frequency_to_timedelta

from anemoi.datasets.data.records.backends import backend_factory

LOG = logging.getLogger(__name__)

if os.environ.get("ANEMOI_DATASET_COUNTER", "0") == "1":

    def counter(func):
        """Debug decorator: count calls to `func` and report the tally.

        BUG FIX: the previous implementation returned a *generator* that
        called `func` len(args[0]) times, silently changing the decorated
        method's return type; this version is a plain pass-through that
        keeps a running call count on the wrapper.
        """

        def wrapper(*args, **kwargs):
            wrapper.calls += 1
            print(f"Counter: {wrapper.calls} calls to {func.__name__}")
            return func(*args, **kwargs)

        wrapper.calls = 0
        return wrapper

else:

    def counter(func):
        """No-op decorator when ANEMOI_DATASET_COUNTER is not enabled."""
        return func


def open_records_dataset(dataset, **kwargs):
    """Open a records dataset from a .vz path.

    Raises:
        ValueError: if `dataset` does not end with ".vz".
    """
    if not dataset.endswith(".vz"):
        raise ValueError("dataset must be a .vz file")
    return RecordsDataset(dataset, **kwargs)


class BaseRecordsDataset:
    """Common indexing/subsetting behaviour for records datasets.

    Integer indices yield a Record (one date across all groups); string
    indices yield a Tabular view of one group across all dates.
    """

    def __getitem__(self, i):
        if isinstance(i, str):
            return self._getgroup(i)

        if isinstance(i, int):
            return self._getrecord(i)

        raise ValueError(f"Invalid index {i}, must be int or str")

    def _getgroup(self, i):
        # One group, all dates.
        return Tabular(self, i)

    def _getrecord(self, i):
        # One date, all groups.
        return Record(self, i)

    def _load_data(self, i):
        raise NotImplementedError("Must be implemented in subclass")

    @property
    def start_date(self):
        return self.dates[0]

    @property
    def end_date(self):
        """Last date, or None when there are no dates at all."""
        if len(self.dates) == 0:
            return None
        if len(self.dates) == 1:
            return self.dates[0]
        return self.dates[-1]

    @property
    def groups(self):
        return tuple(self.keys())

    def _subset(self, **kwargs):
        """Apply start/end/select subsetting recursively; frequency changes are rejected."""
        start = kwargs.pop("start", None)
        end = kwargs.pop("end", None)
        frequency = kwargs.pop("frequency", self.frequency)

        if frequency != self.frequency:
            raise ValueError(f"Changing the frequency {frequency} (from {self.frequency}) is not implemented yet.")

        if start is not None or end is not None:

            def _dates_to_indices(start, end):
                # Local import to avoid a circular dependency with data.misc.
                from anemoi.datasets.data.misc import as_first_date
                from anemoi.datasets.data.misc import as_last_date

                start = self.dates[0] if start is None else as_first_date(start, self.dates)
                end = self.dates[-1] if end is None else as_last_date(end, self.dates)

                return [i for i, date in enumerate(self.dates) if start <= date <= end]

            return RecordsSubset(
                self, _dates_to_indices(start, end), {"start": start, "end": end, "frequency": frequency}
            )._subset(**kwargs)

        select = kwargs.pop("select", None)
        if select is not None:
            return Select(self, select)._subset(**kwargs)

        return self

    def mutate(self):
        return self

    def _check(self):
        pass

    @property
    def name_to_index(self):
        raise NotImplementedError("Must be implemented in subclass")


class RecordsForward(BaseRecordsDataset):
    """Base for wrappers that forward everything to an underlying dataset."""

    def __init__(self, dataset):
        self.forward = dataset

    @property
    def statistics(self):
        return self.forward.statistics

    @property
    def variables(self):
        return self.forward.variables

    def keys(self):
        return self.forward.keys()

    @property
    def dates(self):
        return self.forward.dates

    @property
    def name_to_index(self):
        return self.forward.name_to_index

    @property
    def frequency(self):
        return self.forward.frequency

    @property
    def shapes(self):
        return self.forward.shapes

    def __len__(self):
        return len(self.forward)


def match_variable(lst, group, name):
    """Return True when `group.name` matches any pattern in `lst`.

    `lst` is a list of strings; an entry without a dot is a group name and
    matches every variable of that group ("grp" -> "grp.*").  Supported
    patterns: "group.name", "group.*", "*.name" and the bare wildcard "*".
    """

    if name == "__latitudes" or name == "__longitudes":
        # This should disappear in the future, when we stop saving a duplicate
        # of lat/lon in the data
        return False

    # BUG FIX: "*" contains no dot and used to be rewritten to "*.*", so the
    # `"*" in lst` test below could never succeed; keep it untouched.
    lst = [k if ("." in k or k == "*") else f"{k}.*" for k in lst]

    key = f"{group}.{name}"
    if key in lst:
        return True
    if f"{group}.*" in lst:
        return True
    if f"*.{name}" in lst:
        return True
    if "*" in lst:
        return True
    return False
in k else f"{k}.*" for k in lst] + + key = f"{group}.{name}" + if key in lst: + return True + if f"{group}.*" in lst: + return True + if f"*.{name}" in lst: + return True + if "*" in lst: + return True + return False + + +class Select(RecordsForward): + def __init__(self, dataset, select): + super().__init__(dataset) + + self.dataset = dataset + + if isinstance(select, dict): + # if a dict is provided, make it a list of strings with '.' + sel = [] + for group, d in select.items(): + for name in d: + sel.append(f"{group}.{name}") + select = sel + + self._select = select + + self.reason = {"select": select} + self._build_indices_and_name_to_index() + + def _build_indices_and_name_to_index(self): + indices = {} + name_to_index = {} + variables = {} + + # this should be revisited to take into account the order requested by the user + # see what is done in the fields datasets + for group, names in self.dataset.variables.items(): + ind = np.zeros(len(names), dtype=bool) + count = 0 + for j, name in enumerate(names): + if self.match_variable(group, name): + assert j == names.index(name), f"Invalid index {j} for {name} in {group}" + ind[j] = True + indices[group] = ind + if group not in name_to_index: + name_to_index[group] = {} + assert group not in variables, (group, j, name, variables, name_to_index) + variables[group] = [] + name_to_index[group][name] = count + variables[group].append(name) + count += 1 + assert np.sum(ind) == count, f"Mismatch in {group}: {names}, {ind}" + self._indices = indices + self._name_to_index = name_to_index + self._variables = variables + + def match_variable(self, *args, **kwargs): + return match_variable(self._select, *args, **kwargs) + + def keys(self): + return self._indices.keys() + + def _load_data(self, i): + forward = self.dataset._load_data(i) + data = {} + for k, v in self._indices.items(): + data[f"latitudes:{k}"] = forward[f"latitudes:{k}"] + data[f"longitudes:{k}"] = forward[f"longitudes:{k}"] + data[f"timedeltas:{k}"] = 
forward[f"timedeltas:{k}"] + data[f"metadata:{k}"] = forward[f"metadata:{k}"] + for k, v in self._indices.items(): + data[f"data:{k}"] = forward[f"data:{k}"][v] # notice the [v] here + return data + + @property + def name_to_index(self): + return self._name_to_index + + @property + def variables(self): + return self._variables + + @property + def statistics(self): + dic = {} + for group, v in self._indices.items(): + stats = self.dataset.statistics[group] + dic[group] = {key: stats[key][v] for key in stats.keys()} + assert "mean" in dic[group], f"Missing mean in {dic[group]}" + return dic + + +class RecordsSubset(RecordsForward): + def __init__(self, dataset, indices, reason): + super().__init__(dataset) + self.dataset = dataset + self.reason = reason + self._indices = indices + + @cached_property + def dates(self): + return self.dataset.dates[self._indices] + + def _load_data(self, i): + return self.dataset._load_data(self._indices[i]) + + def __len__(self): + return len(self._indices) + + +class RecordsDataset(BaseRecordsDataset): + + def __init__(self, path, backend="npz1", **kwargs): + if kwargs: + print("Warning: ignoring additional kwargs", kwargs) + self.path = path + self.backend = backend_factory(backend, path, **kwargs) + self.keys = self.metadata["sources"].keys + + @property + def frequency(self): + frequency = self.metadata["frequency"] + frequency = frequency_to_timedelta(frequency) + return frequency + + @property + def name_to_index(self): + return self.metadata["name_to_index"] + + @property + def variables(self): + return self.metadata["variables"] + + @cached_property + def metadata(self): + return self.backend.read_metadata() + + @property + def shapes(self): + return self.metadata["shapes"] + + def items(self, *args, **kwargs): + return {k: Tabular(self, k) for k in self.keys()}.items(*args, **kwargs) + + @cached_property + def statistics(self): + return self.backend.read_statistics() + + def __len__(self): + return len(self.dates) + + 
@property + def start_date(self): + date = self.metadata["start_date"] + return datetime.datetime.fromisoformat(date) + + @property + def end_date(self): + date = self.metadata["end_date"] + return datetime.datetime.fromisoformat(date) + + @cached_property + def dates(self): + result = [] + delta = self.frequency + d = self.start_date + while d <= self.end_date: + result.append(d) + d += delta + return np.array(result) + + @counter + def _load_data(self, i): + return self.backend.read(i) + + def check(self, i=None): + if i is not None: + dict_of_sets = defaultdict(set) + for key in self._load_data(i).keys(): + kind, group = key.split(":") + dict_of_sets[group].add(kind) + for group, s in dict_of_sets.items(): + assert s == {"latitudes", "longitudes", "timedeltas", "metadata", "data"}, f"Invalid keys {s}" + + +class Record(dict): + def __init__(self, dataset, n): + self.dataset = dataset + self.n = n + + def __repr__(self): + d = {group: "" for group in self.dataset.keys()} + return str(d) + + def items(self): + return self._payload.items() + + @property + def name_to_index(self): + return self.dataset.name_to_index + + @cached_property + def _payload(self): + payload = self.dataset._load_data(self.n) + for k in payload.keys(): + assert len(k.split(":")) == 2, f"Invalid key {k}" + return payload + + def keys(self): + return self.dataset.keys() + + def __getitem__(self, group): + return self._payload["data:" + group] + + def _get_aux(self, name): + try: + return {k: self._payload[name + ":" + k] for k in self.keys()} + except KeyError as e: + e.add_note(f"Available keys are {self._payload.keys()}") + raise + + @property + def latitudes(self): + return self._get_aux("latitudes") + + @property + def longitudes(self): + return self._get_aux("longitudes") + + @property + def timedeltas(self): + return self._get_aux("timedeltas") + + @property + def statistics(self): + return self.dataset.statistics + + @property + def groups(self): + return tuple(self.keys()) + + +class 
Tabular: + def __init__(self, dataset, name): + self.dataset = dataset + self.name = name + + @property + def group(self): + return self.name + + def __getitem__(self, i): + return self.__get(i, "data") + + def __get(self, i, k): + payload = self.dataset._load_data(i) + try: + return payload[k + ":" + self.name] + except KeyError: + print(f"KeyError to retrieve {self.name} available groups are", payload.keys()) + raise + + @property + def variables(self): + return self.dataset.variables[self.name] + + @property + def name_to_index(self): + return self.dataset.name_to_index[self.name] + + @property + def statistics(self): + return self.dataset.statistics[self.name] diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py new file mode 100644 index 000000000..6971cafd4 --- /dev/null +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -0,0 +1,157 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
import json
import os

import numpy as np


def _flatten_statistics(statistics):
    """Flatten {group: {stat: value}} into {"stat:group": value} for np.savez."""
    flat = {}
    for name, d in statistics.items():
        assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}"
        for k, v in d.items():
            assert isinstance(
                v, (int, float, np.ndarray)
            ), f"Statistics value for {k} in {name} must be int, float or ndarray, got {type(v)}"
            flat[k + ":" + name] = v
    return flat


def _unflatten_statistics(path):
    """Load an .npz of "stat:group" keys back into {group: {stat: value}}."""
    dic = {}
    for k, v in dict(np.load(path)).items():
        key, group = k.split(":")
        dic.setdefault(group, {})[key] = v
    return dic


class Backend:
    """Read-side interface to the on-disk layout of a records dataset."""

    def __init__(self, path, **kwargs):
        self.path = path
        self.kwargs = kwargs

    def read(self, i, **kwargs):
        """Return the payload dict for sample `i`."""
        raise NotImplementedError("Must be implemented in subclass")

    def read_metadata(self):
        """Return the dataset metadata dict."""
        raise NotImplementedError("Must be implemented in subclass")

    def read_statistics(self):
        """Return {group: {stat: array}} statistics."""
        raise NotImplementedError("Must be implemented in subclass")


class Npz1Backend(Backend):
    """Layout 'npz1': data/<i//10>/<i>.npz + metadata.json + statistics.npz."""

    def read(self, i, **kwargs):
        # i // 10 buckets files ten per directory (and avoids the float
        # round-trip of the previous int(i / 10)).
        path = os.path.join(self.path, "data", str(i // 10), f"{i}.npz")
        with open(path, "rb") as f:
            return dict(np.load(f))

    def read_metadata(self):
        with open(os.path.join(self.path, "metadata.json"), "r") as f:
            return json.load(f)

    def read_statistics(self):
        return _unflatten_statistics(os.path.join(self.path, "statistics.npz"))


class Npz2Backend(Backend):
    """Layout 'npz2': like 'npz1' but with '_'-suffixed names (data_/<i//10>/<i>_.npz)."""

    def read(self, i, **kwargs):
        path = os.path.join(self.path, "data_", str(i // 10), f"{i}_.npz")
        with open(path, "rb") as f:
            return dict(np.load(f))

    def read_metadata(self):
        with open(os.path.join(self.path, "metadata.json"), "r") as f:
            return json.load(f)

    def read_statistics(self):
        return _unflatten_statistics(os.path.join(self.path, "statistics_.npz"))


def backend_factory(backend, *args, **kwargs):
    """Instantiate a read backend by name ('npz1' or 'npz2')."""
    BACKENDS = dict(
        npz1=Npz1Backend,
        npz2=Npz2Backend,
    )
    return BACKENDS[backend](*args, **kwargs)


class WriteBackend(Backend):
    """Write-side interface mirroring `Backend`."""

    def write(self, i, data, **kwargs):
        """Persist the payload dict for sample `i`."""
        raise NotImplementedError("Must be implemented in subclass")

    def write_metadata(self, metadata):
        raise NotImplementedError("Must be implemented in subclass")

    def write_statistics(self, statistics):
        raise NotImplementedError("Must be implemented in subclass")


class Npz1WriteBackend(WriteBackend):
    """Writer for the 'npz1' layout (see Npz1Backend)."""

    def write(self, i, data, **kwargs):
        path = os.path.join(self.path, "data", str(i // 10))
        os.makedirs(path, exist_ok=True)
        np.savez(os.path.join(path, f"{i}.npz"), **data)

    def write_metadata(self, metadata):
        from anemoi.datasets.create import json_tidy

        os.makedirs(self.path, exist_ok=True)
        with open(os.path.join(self.path, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=2, default=json_tidy)

    def write_statistics(self, statistics):
        # Consistency fix: ensure the target directory exists, as the 'npz2'
        # writer already did; otherwise writing statistics first would fail.
        os.makedirs(self.path, exist_ok=True)
        np.savez(os.path.join(self.path, "statistics.npz"), **_flatten_statistics(statistics))


class Npz2WriteBackend(WriteBackend):
    """Writer for the 'npz2' layout (see Npz2Backend)."""

    def write(self, i, data, **kwargs):
        path = os.path.join(self.path, "data_", str(i // 10))
        os.makedirs(path, exist_ok=True)
        np.savez(os.path.join(path, f"{i}_.npz"), **data)

    def write_metadata(self, metadata):
        from anemoi.datasets.create import json_tidy

        os.makedirs(self.path, exist_ok=True)
        with open(os.path.join(self.path, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=2, default=json_tidy)

    def write_statistics(self, statistics):
        os.makedirs(self.path, exist_ok=True)
        np.savez(os.path.join(self.path, "statistics_.npz"), **_flatten_statistics(statistics))


def writer_backend_factory(backend, *args, **kwargs):
    """Instantiate a write backend by name ('npz1' or 'npz2')."""
    WRITE_BACKENDS = dict(
        npz1=Npz1WriteBackend,
        npz2=Npz2WriteBackend,
    )
    return WRITE_BACKENDS[backend](*args, **kwargs)
store.endswith(".zip"): import multiurl + parsed = urlparse(store) + # Zarr cannot handle zip files over HTTP tmpdir = tempfile.gettempdir() name = os.path.basename(parsed.path) @@ -210,15 +165,7 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore: os.rename(path + ".tmp", path) return name_to_zarr_store(path) - bits = parsed.netloc.split(".") - if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"): - s3_url = f"s3://{bits[0]}{parsed.path}" - store = S3Store(s3_url, region=bits[2]) - elif store.startswith("https://planetarycomputer.microsoft.com/"): - data_catalog_id = store.rsplit("/", 1)[-1] - store = PlanetaryComputerStore(data_catalog_id).store - else: - store = HTTPStore(store) + return HTTPStore(store) return store @@ -565,6 +512,10 @@ def zarr_lookup(name: str, fail: bool = True) -> Optional[str]: config = load_config()["datasets"] use_search_path_not_found = config.get("use_search_path_not_found", False) + if name.endswith(".zarr/"): + LOG.warning("Removing trailing slash from path: %s", name) + name = name[:-1] + if name.endswith(".zarr") or name.endswith(".zip"): if os.path.exists(name): diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 3ea5265d9..2072137a0 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -185,6 +185,11 @@ def __getitem__(self, n: FullIndex) -> NDArray[Any]: n = self.indices[n] return self.dataset[n] + def get_aux(self, n: FullIndex) -> NDArray[Any]: + assert n >= 0, n + n = self.indices[n] + return self.dataset.get_aux(n) + @debug_indexing def _get_slice(self, s: slice) -> NDArray[Any]: """Get slice of data. 
diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index 8a36aad6e..f663dade9 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -605,6 +605,7 @@ def nearest_grid_points( target_latitudes: NDArray[Any], target_longitudes: NDArray[Any], max_distance: float = None, + k: int = 1, ) -> NDArray[Any]: """Find the nearest grid points from source to target coordinates. @@ -621,6 +622,8 @@ def nearest_grid_points( max_distance: float, optional Maximum distance between nearest point and point to interpolate. Defaults to None. For example, 1e-3 is 1 km. + k : int, optional + The number of k closest neighbors to consider for interpolation Returns ------- @@ -637,10 +640,10 @@ def nearest_grid_points( target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) target_points = np.array(target_xyz).transpose() if max_distance is None: - _, indices = cKDTree(source_points).query(target_points, k=1) + distances, indices = cKDTree(source_points).query(target_points, k=k) else: - _, indices = cKDTree(source_points).query(target_points, k=1, distance_upper_bound=max_distance) - return indices + distances, indices = cKDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) + return distances, indices if __name__ == "__main__": diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..210ad7a14 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1 @@ +pytest_plugins = "anemoi.utils.testing" diff --git a/tests/create/__init__.py b/tests/create/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/create/test_create.py b/tests/create/test_create.py index beed90387..ec2d515ce 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -8,24 +8,17 @@ # nor does it submit to any jurisdiction. 
import glob -import hashlib -import json import logging import os import sys -from functools import wraps from unittest.mock import patch -import numpy as np import pytest -from anemoi.utils.testing import get_test_archive -from anemoi.utils.testing import get_test_data from anemoi.utils.testing import skip_if_offline -from earthkit.data import from_source as original_from_source -from anemoi.datasets import open_dataset -from anemoi.datasets.create.testing import create_dataset -from anemoi.datasets.data.stores import open_zarr +from .utils.compare import Comparer +from .utils.create import create_dataset +from .utils.mock_sources import LoadSource HERE = os.path.dirname(__file__) # find_yamls @@ -37,365 +30,41 @@ assert NAMES, "No yaml files found in " + HERE -def mockup_from_source(func: callable) -> callable: - """Decorator to mock the `from_source` function from the `earthkit.data` module. - - Parameters - ---------- - func : function - The function to be wrapped. - - Returns - ------- - function - The wrapped function. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - with patch("earthkit.data.from_source", _from_source): - return func(*args, **kwargs) - - return wrapper - - -class LoadSource: - """Class to load data sources and handle mockup data.""" - - def filename(self, args: tuple, kwargs: dict) -> str: - """Generate a filename based on the arguments and keyword arguments. - - Parameters - ---------- - args : tuple - The positional arguments. - kwargs : dict - The keyword arguments. - - Returns - ------- - str - The generated filename. - """ - string = json.dumps([args, kwargs], sort_keys=True, default=str) - h = hashlib.md5(string.encode("utf8")).hexdigest() - return h + ".grib" - - def get_data(self, args: tuple, kwargs: dict, path: str) -> None: - """Retrieve data and save it to the specified path. - - Parameters - ---------- - args : tuple - The positional arguments. - kwargs : dict - The keyword arguments. 
- path : str - The path to save the data. - - Raises - ------ - ValueError - If the test data is missing. - """ - upload_path = os.path.realpath(path + ".to_upload") - ds = original_from_source("mars", *args, **kwargs) - ds.save(upload_path) - print(f"Mockup: Saving to {upload_path} for {args}, {kwargs}") - print() - print("⚠️ To upload the test data, run this:") - path = os.path.relpath(upload_path, os.getcwd()) - name = os.path.basename(upload_path).replace(".to_upload", "") - print(f"scp {path} data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/{name}") - print() - exit(1) - raise ValueError("Test data is missing") - - def mars(self, args: tuple, kwargs: dict) -> object: - """Load data from the MARS archive. - - Parameters - ---------- - args : tuple - The positional arguments. - kwargs : dict - The keyword arguments. - - Returns - ------- - object - The loaded data source. - """ - - name = self.filename(args, kwargs) - - try: - return original_from_source("file", get_test_data(f"anemoi-datasets/create/mock-mars/{name}")) - except RuntimeError: - raise # If offline - except Exception: - self.get_data(args, kwargs, name) - - def __call__(self, name: str, *args: tuple, **kwargs: dict) -> object: - """Call the appropriate method based on the data source name. - - Parameters - ---------- - name : str - The name of the data source. - args : tuple - The positional arguments. - kwargs : dict - The keyword arguments. - - Returns - ------- - object - The loaded data source. - """ - if name == "mars": - return self.mars(args, kwargs) - - return original_from_source(name, *args, **kwargs) - - -_from_source = LoadSource() - - -def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: - """Compare the attributes of two Zarr datasets. - - Parameters - ---------- - a : dict - The attributes of the first dataset. - b : dict - The attributes of the second dataset. - path : str - The current path in the attribute hierarchy. 
- errors : list - The list to store error messages. - """ - if isinstance(a, dict): - a_keys = list(a.keys()) - b_keys = list(b.keys()) - for k in set(a_keys) | set(b_keys): - if k not in a_keys: - errors.append(f"❌ {path}.{k} : missing key (only in reference)") - continue - if k not in b_keys: - errors.append(f"❌ {path}.{k} : additional key (missing in reference)") - continue - if k in [ - "timestamp", - "uuid", - "latest_write_timestamp", - "history", - "provenance", - "provenance_load", - "description", - "config_path", - "total_size", - ]: - if type(a[k]) is not type(b[k]): - errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") - continue - - compare_dot_zattrs(a[k], b[k], f"{path}.{k}", errors) - - return - - if isinstance(a, list): - if len(a) != len(b): - errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}") - return - - for i, (v, w) in enumerate(zip(a, b)): - compare_dot_zattrs(v, w, f"{path}.{i}", errors) - - return - - if type(a) is not type(b): - msg = f"❌ {path} actual != expected : {a} ({type(a)}) != {b} ({type(b)})" - errors.append(msg) - return - - if a != b: - msg = f"❌ {path} actual != expected : {a} != {b}" - errors.append(msg) - - -def compare_datasets(a: object, b: object) -> None: - """Compare two datasets. - - Parameters - ---------- - a : object - The first dataset. - b : object - The second dataset. - - Raises - ------ - AssertionError - If the datasets do not match. 
- """ - assert a.shape == b.shape, (a.shape, b.shape) - assert (a.dates == b.dates).all(), (a.dates, b.dates) - for a_, b_ in zip(a.variables, b.variables): - assert a_ == b_, (a, b) - assert a.missing == b.missing, "Missing are different" - - for i_date, date in zip(range(a.shape[0]), a.dates): - if i_date in a.missing: - continue - for i_param in range(a.shape[1]): - param = a.variables[i_param] - assert param == b.variables[i_param], ( - date, - param, - a.variables[i_param], - b.variables[i_param], - ) - a_ = a[i_date, i_param] - b_ = b[i_date, i_param] - assert a.shape == b.shape, (date, param, a.shape, b.shape) - - a_nans = np.isnan(a_) - b_nans = np.isnan(b_) - assert np.all(a_nans == b_nans), (date, param, "nans are different") - - a_ = np.where(a_nans, 0, a_) - b_ = np.where(b_nans, 0, b_) - - delta = a_ - b_ - max_delta = np.max(np.abs(delta)) - abs_error = np.abs(a_ - b_) - rel_error = np.abs(a_ - b_) / (np.abs(b_) + 1e-10) # Avoid division by zero - assert max_delta == 0.0, (date, param, a_, b_, a_ - b_, max_delta, np.max(abs_error), np.max(rel_error)) - - -def compare_statistics(ds1: object, ds2: object) -> None: - """Compare the statistics of two datasets. - - Parameters - ---------- - ds1 : object - The first dataset. - ds2 : object - The second dataset. - - Raises - ------ - AssertionError - If the statistics do not match. - """ - vars1 = ds1.variables - vars2 = ds2.variables - assert len(vars1) == len(vars2) - for v1, v2 in zip(vars1, vars2): - idx1 = ds1.name_to_index[v1] - idx2 = ds2.name_to_index[v2] - assert (ds1.statistics["mean"][idx1] == ds2.statistics["mean"][idx2]).all() - assert (ds1.statistics["stdev"][idx1] == ds2.statistics["stdev"][idx2]).all() - assert (ds1.statistics["maximum"][idx1] == ds2.statistics["maximum"][idx2]).all() - assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() - - -class Comparer: - """Class to compare datasets and their metadata. 
- - Parameters - ---------- - name : str - The name of the dataset. - output_path : str, optional - The path to the output dataset. - reference_path : str, optional - The path to the reference dataset. - """ - - def __init__(self, name: str, output_path: str = None, reference_path: str = None) -> None: - """Initialize the Comparer instance. - - Parameters - ---------- - name : str - The name of the dataset. - output_path : str, optional - The path to the output dataset. - reference_path : str, optional - The path to the reference dataset. - """ - self.name = name - self.output_path = output_path or os.path.join(name + ".zarr") - self.reference_path = reference_path - print(f"Comparing {self.output_path} and {self.reference_path}") - - self.z_output = open_zarr(self.output_path) - self.z_reference = open_zarr(self.reference_path) - - self.z_reference["data"] - self.ds_output = open_dataset(self.output_path) - self.ds_reference = open_dataset(self.reference_path) - - def compare(self) -> None: - """Compare the output dataset with the reference dataset. - - Raises - ------ - AssertionError - If the datasets or their metadata do not match. 
- """ - errors = [] - compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors) - if errors: - print("Comparison failed") - print("\n".join(errors)) - - if errors: - print() - - print() - print("⚠️ To update the reference data, run this:") - print("cd " + os.path.dirname(self.output_path)) - base = os.path.basename(self.output_path) - print(f"tar zcf {base}.tgz {base}") - print(f"scp {base}.tgz data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/") - print() - raise AssertionError("Comparison failed") - - compare_datasets(self.ds_output, self.ds_reference) - compare_statistics(self.ds_output, self.ds_reference) - # do not compare tendencies statistics yet, as we don't know yet if they should stay +@pytest.fixture +def load_source(get_test_data: callable) -> LoadSource: + return LoadSource(get_test_data) @skip_if_offline @pytest.mark.parametrize("name", NAMES) -@mockup_from_source -def test_run(name: str) -> None: +def test_run(name: str, get_test_archive: callable, load_source: LoadSource) -> None: """Run the test for the specified dataset. Parameters ---------- name : str The name of the dataset. + get_test_archive : callable + Fixture to retrieve the test archive. + load_source : LoadSource + Fixture to mock data sources. Raises ------ AssertionError If the comparison fails. 
""" - config = os.path.join(HERE, name + ".yaml") - output = os.path.join(HERE, name + ".zarr") - is_test = False + with patch("earthkit.data.from_source", load_source): + config = os.path.join(HERE, name + ".yaml") + output = os.path.join(HERE, name + ".zarr") + is_test = False - create_dataset(config=config, output=output, delta=["12h"], is_test=is_test) + create_dataset(config=config, output=output, delta=["12h"], is_test=is_test) - directory = get_test_archive(f"anemoi-datasets/create/mock-mars/{name}.zarr.tgz") - reference = os.path.join(directory, name + ".zarr") + directory = get_test_archive(f"anemoi-datasets/create/mock-mars/{name}.zarr.tgz") + reference = os.path.join(directory, name + ".zarr") - Comparer(name, output_path=output, reference_path=reference).compare() + Comparer(output_path=output, reference_path=reference).compare() if __name__ == "__main__": diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index f74471911..e8e766b51 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -7,20 +7,22 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import logging import os +import sys -from anemoi.utils.testing import get_test_data +import numpy as np +import pytest from anemoi.utils.testing import skip_if_offline from anemoi.utils.testing import skip_missing_packages from anemoi.utils.testing import skip_slow_tests from anemoi.datasets import open_dataset -from anemoi.datasets.create.testing import create_dataset + +from .utils.create import create_dataset @skip_if_offline -def test_grib() -> None: +def test_grib(get_test_data: callable) -> None: """Test the creation of a dataset from GRIB files. 
This function tests the creation of a dataset using GRIB files from @@ -50,8 +52,123 @@ def test_grib() -> None: assert ds.shape == (8, 12, 1, 162) +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="Type hints from anemoi-transform are not compatible with Python < 3.10" +) +@skip_if_offline +def test_grib_gridfile(get_test_data) -> None: + """Test the creation of a dataset from GRIB files with an unstructured grid. + + This function tests the creation of a dataset using GRIB files from + specific dates and verifies the shape of the resulting dataset. + This GRIB data is defined on an unstructured grid and therefore requires + specifying a grid file. + """ + data1 = get_test_data("anemoi-datasets/create/grib-iconch1-20250101.grib") + data2 = get_test_data("anemoi-datasets/create/grib-iconch1-20250102.grib") + gridfile = get_test_data("anemoi-datasets/create/icon_grid_0001_R19B08_mch.nc") + assert os.path.dirname(data1) == os.path.dirname(data2) + + path = os.path.dirname(data1) + + config = { + "dates": { + "start": "2025-01-01T00:00:00", + "end": "2025-01-02T18:00:00", + "frequency": "6h", + }, + "input": { + "grib": { + "path": os.path.join(path, "grib-iconch1-{date:strftime(%Y%m%d)}.grib"), + "grid_definition": {"icon": {"path": gridfile}}, + "flavour": [[{"levtype": "sfc"}, {"levelist": None}]], + }, + }, + } + + created = create_dataset(config=config, output=None) + ds = open_dataset(created) + assert ds.shape == (8, 1, 1, 1147980) + assert ds.variables == ["2t"] + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="Type hints from anemoi-transform are not compatible with Python < 3.10" +) @skip_if_offline -def test_netcdf() -> None: +@pytest.mark.parametrize( + "refinement_level_c,shape", + ( + (2, (2, 13, 1, 2880)), + (7, (2, 13, 1, 2949120)), + ), +) +def test_grib_gridfile_with_refinement_level( + refinement_level_c: str, shape: tuple[int, int, int, int, int], get_test_data: callable +) -> None: + """Test the creation of a dataset from 
GRIB files with an unstructured grid. + + This function tests the creation of a dataset using GRIB files from + specific dates and verifies the shape of the resulting dataset. + This GRIB data is defined on an unstructured grid and therefore requires + specifying a grid file. The `refinement_level_c` selection key and + strftimedelta are used. + """ + + p = "anemoi-datasets/create/test_grib_gridfile_with_refinement_level/" + data1 = get_test_data(p + "2023010103+fc_R03B07_rea_ml.2023010100") + data2 = get_test_data(p + "2023010106+fc_R03B07_rea_ml.2023010103") + gridfile = get_test_data("dwd/2024-12-11_00/icon_grid_0026_R03B07_subsetAICON.nc") + assert os.path.dirname(data1) == os.path.dirname(data2) + + path = os.path.dirname(data1) + + param = ["pres", "t", "u", "v", "q"] + level = [101, 119] + forcings = ["cos_latitude", "sin_latitude", "cos_julian_day"] + assert len(param) * len(level) + len(forcings) == shape[1] + + grib = { + "path": os.path.join(path, "{date:strftimedelta(+3h;%Y%m%d%H)}+fc_R03B07_rea_ml.{date:strftime(%Y%m%d%H)}"), + "grid_definition": {"icon": {"path": gridfile, "refinement_level_c": refinement_level_c}}, + "param": param, + "level": level, + } + refinement_filter = {"icon_refinement_level": {"grid": gridfile, "refinement_level_c": refinement_level_c}} + + config = { + "dates": { + "start": "2023-01-01T00:00:00", + "end": "2023-01-01T03:00:00", + "frequency": "3h", + }, + "input": { + "pipe": [ + { + "join": [ + {"grib": grib}, + {"forcings": {"param": forcings, "template": "${input.pipe.0.join.0.grib}"}}, + ] + }, + refinement_filter, + ] + }, + } + + created = create_dataset(config=config, output=None) + ds = open_dataset(created) + assert ds.shape == shape + assert np.all(ds.data[ds.to_index(date=0, variable="cos_julian_day", member=0)] == 1.0), "cos(julian_day = 0) == 1" + assert np.all(ds.data[ds.to_index(date=0, variable="u_101", member=0)] == 42.0), "artificially constant data day 0" + assert np.all(ds.data[ds.to_index(date=1, 
variable="v_119", member=0)] == 21.0), "artificially constant data day 1" + assert ds.data[ds.to_index(date=0, variable="cos_latitude", member=0)].max() > 0.9 + assert ds.data[ds.to_index(date=0, variable="cos_latitude", member=0)].min() >= 0 + assert ds.data[ds.to_index(date=0, variable="sin_latitude", member=0)].max() > 0.9 + assert ds.data[ds.to_index(date=0, variable="sin_latitude", member=0)].min() < -0.9 + + +@skip_if_offline +def test_netcdf(get_test_data: callable) -> None: """Test for NetCDF files. This function tests the creation of a dataset from a NetCDF file. @@ -74,7 +191,7 @@ def test_netcdf() -> None: @skip_missing_packages("fstd", "rpnpy.librmn") -def test_eccs_fstd() -> None: +def test_eccs_fstd(get_test_data: callable) -> None: """Test for 'fstd' files from ECCC.""" # See https://github.com/neishm/fstd2nc @@ -98,7 +215,7 @@ def test_eccs_fstd() -> None: @skip_slow_tests @skip_if_offline @skip_missing_packages("kerchunk", "s3fs") -def test_kerchunk() -> None: +def test_kerchunk(get_test_data: callable) -> None: """Test for Kerchunk JSON files. This function tests the creation of a dataset from a Kerchunk JSON file. 
@@ -128,12 +245,44 @@ def test_kerchunk() -> None: assert ds.shape == (4, 1, 1, 1038240) +@skip_if_offline +@skip_missing_packages("planetary_computer", "adlfs") +def test_planetary_computer_conus404() -> None: + """Test loading and validating the planetary_computer_conus404 dataset.""" + + config = { + "dates": { + "start": "2022-01-01", + "end": "2022-01-02", + "frequency": "1d", + }, + "input": { + "planetary_computer": { + "data_catalog_id": "conus404", + "param": ["Z"], + "level": [1], + "patch": { + "coordinates": ["bottom_top_stag"], + "rename": { + "bottom_top_stag": "level", + }, + "attributes": { + "lon": {"standard_name": "longitude", "long_name": "Longitude"}, + "lat": {"standard_name": "latitude", "long_name": "Latitude"}, + }, + }, + } + }, + } + + created = create_dataset(config=config, output=None) + ds = open_dataset(created) + assert ds.shape == (2, 1, 1, 1387505), ds.shape + + if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - test_kerchunk() - exit() - """Run all test functions that start with 'test_'.""" - for name, obj in list(globals().items()): - if name.startswith("test_") and callable(obj): - print(f"Running {name}...") - obj() + test_planetary_computer_conus404() + exit(0) + from anemoi.utils.testing import run_tests + + run_tests(globals()) diff --git a/tests/create/utils/__init__.py b/tests/create/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/create/utils/compare.py b/tests/create/utils/compare.py new file mode 100644 index 000000000..aa6a59dd2 --- /dev/null +++ b/tests/create/utils/compare.py @@ -0,0 +1,218 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import os + +import numpy as np + +from anemoi.datasets import open_dataset +from anemoi.datasets.data.stores import open_zarr + + +class Comparer: + """Class to compare datasets and their metadata. + + Parameters + ---------- + output_path : str, optional + The path to the output dataset. + reference_path : str, optional + The path to the reference dataset. + """ + + def __init__(self, output_path: str = None, reference_path: str = None) -> None: + """Initialize the Comparer instance. + + Parameters + ---------- + output_path : str, optional + The path to the output dataset. + reference_path : str, optional + The path to the reference dataset. + """ + self.output_path = output_path + self.reference_path = reference_path + print(f"Comparing {self.output_path} and {self.reference_path}") + + self.z_output = open_zarr(self.output_path) + self.z_reference = open_zarr(self.reference_path) + + self.z_reference["data"] + self.ds_output = open_dataset(self.output_path) + self.ds_reference = open_dataset(self.reference_path) + + @staticmethod + def compare_datasets(a: object, b: object) -> None: + """Compare two datasets. + + Parameters + ---------- + a : object + The first dataset. + b : object + The second dataset. + + Raises + ------ + AssertionError + If the datasets do not match. 
+ """ + assert a.shape == b.shape, (a.shape, b.shape) + assert (a.dates == b.dates).all(), (a.dates, b.dates) + for a_, b_ in zip(a.variables, b.variables): + assert a_ == b_, (a, b) + assert a.missing == b.missing, "Missing are different" + + for i_date, date in zip(range(a.shape[0]), a.dates): + if i_date in a.missing: + continue + for i_param in range(a.shape[1]): + param = a.variables[i_param] + assert param == b.variables[i_param], ( + date, + param, + a.variables[i_param], + b.variables[i_param], + ) + a_ = a[i_date, i_param] + b_ = b[i_date, i_param] + assert a.shape == b.shape, (date, param, a.shape, b.shape) + + a_nans = np.isnan(a_) + b_nans = np.isnan(b_) + assert np.all(a_nans == b_nans), (date, param, "nans are different") + + a_ = np.where(a_nans, 0, a_) + b_ = np.where(b_nans, 0, b_) + + delta = a_ - b_ + max_delta = np.max(np.abs(delta)) + abs_error = np.abs(a_ - b_) + rel_error = np.abs(a_ - b_) / (np.abs(b_) + 1e-10) # Avoid division by zero + assert max_delta == 0.0, (date, param, a_, b_, a_ - b_, max_delta, np.max(abs_error), np.max(rel_error)) + + @staticmethod + def compare_statistics(ds1: object, ds2: object) -> None: + """Compare the statistics of two datasets. + + Parameters + ---------- + ds1 : object + The first dataset. + ds2 : object + The second dataset. + + Raises + ------ + AssertionError + If the statistics do not match. 
+ """ + vars1 = ds1.variables + vars2 = ds2.variables + assert len(vars1) == len(vars2) + for v1, v2 in zip(vars1, vars2): + idx1 = ds1.name_to_index[v1] + idx2 = ds2.name_to_index[v2] + assert (ds1.statistics["mean"][idx1] == ds2.statistics["mean"][idx2]).all() + assert (ds1.statistics["stdev"][idx1] == ds2.statistics["stdev"][idx2]).all() + assert (ds1.statistics["maximum"][idx1] == ds2.statistics["maximum"][idx2]).all() + assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() + + @staticmethod + def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: + """Compare the attributes of two Zarr datasets. + + Parameters + ---------- + a : dict + The attributes of the first dataset. + b : dict + The attributes of the second dataset. + path : str + The current path in the attribute hierarchy. + errors : list + The list to store error messages. + """ + if isinstance(a, dict): + a_keys = list(a.keys()) + b_keys = list(b.keys()) + for k in set(a_keys) | set(b_keys): + if k not in a_keys: + errors.append(f"❌ {path}.{k} : missing key (only in reference)") + continue + if k not in b_keys: + errors.append(f"❌ {path}.{k} : additional key (missing in reference)") + continue + if k in [ + "timestamp", + "uuid", + "latest_write_timestamp", + "history", + "provenance", + "provenance_load", + "description", + "config_path", + "total_size", + ]: + if type(a[k]) is not type(b[k]): + errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") + continue + + Comparer.compare_dot_zattrs(a[k], b[k], f"{path}.{k}", errors) + + return + + if isinstance(a, list): + if len(a) != len(b): + errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}") + return + + for i, (v, w) in enumerate(zip(a, b)): + Comparer.compare_dot_zattrs(v, w, f"{path}.{i}", errors) + + return + + if type(a) is not type(b): + msg = f"❌ {path} actual != expected : {a} ({type(a)}) != {b} ({type(b)})" + errors.append(msg) + return + + if a != b: + 
msg = f"❌ {path} actual != expected : {a} != {b}" + errors.append(msg) + + def compare(self) -> None: + """Compare the output dataset with the reference dataset. + + Raises + ------ + AssertionError + If the datasets or their metadata do not match. + """ + errors = [] + self.compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors) + if errors: + print("Comparison failed") + print("\n".join(errors)) + + if errors: + print() + + print() + print("⚠️ To update the reference data, run this:") + print("cd " + os.path.dirname(self.output_path)) + base = os.path.basename(self.output_path) + print(f"tar zcf {base}.tgz {base}") + print(f"scp {base}.tgz data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/") + print() + raise AssertionError("Comparison failed") + + self.compare_datasets(self.ds_output, self.ds_reference) + self.compare_statistics(self.ds_output, self.ds_reference) + # do not compare tendencies statistics yet, as we don't know yet if they should stay diff --git a/src/anemoi/datasets/create/testing.py b/tests/create/utils/create.py similarity index 100% rename from src/anemoi/datasets/create/testing.py rename to tests/create/utils/create.py diff --git a/tests/create/utils/mock_sources.py b/tests/create/utils/mock_sources.py new file mode 100644 index 000000000..6c8448372 --- /dev/null +++ b/tests/create/utils/mock_sources.py @@ -0,0 +1,117 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import hashlib +import json +import os + +from earthkit.data import from_source as original_from_source + + +class LoadSource: + """Class to load data sources and handle mockup data.""" + + def __init__(self, get_test_data_func) -> None: + self._get_test_data = get_test_data_func + + def filename(self, args: tuple, kwargs: dict) -> str: + """Generate a filename based on the arguments and keyword arguments. + + Parameters + ---------- + args : tuple + The positional arguments. + kwargs : dict + The keyword arguments. + + Returns + ------- + str + The generated filename. + """ + string = json.dumps([args, kwargs], sort_keys=True, default=str) + h = hashlib.md5(string.encode("utf8")).hexdigest() + return h + ".grib" + + def get_data(self, args: tuple, kwargs: dict, path: str) -> None: + """Retrieve data and save it to the specified path. + + Parameters + ---------- + args : tuple + The positional arguments. + kwargs : dict + The keyword arguments. + path : str + The path to save the data. + + Raises + ------ + ValueError + If the test data is missing. + """ + upload_path = os.path.realpath(path + ".to_upload") + ds = original_from_source("mars", *args, **kwargs) + ds.save(upload_path) + print(f"Mockup: Saving to {upload_path} for {args}, {kwargs}") + print() + print("⚠️ To upload the test data, run this:") + path = os.path.relpath(upload_path, os.getcwd()) + name = os.path.basename(upload_path).replace(".to_upload", "") + print(f"scp {path} data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/{name}") + print() + exit(1) + raise ValueError("Test data is missing") + + def mars(self, args: tuple, kwargs: dict) -> object: + """Load data from the MARS archive. + + Parameters + ---------- + args : tuple + The positional arguments. + kwargs : dict + The keyword arguments. + + Returns + ------- + object + The loaded data source. 
+ """ + + name = self.filename(args, kwargs) + + try: + return original_from_source("file", self._get_test_data(f"anemoi-datasets/create/mock-mars/{name}")) + except RuntimeError: + raise # If offline + except Exception: + self.get_data(args, kwargs, name) + + def __call__(self, name: str, *args: tuple, **kwargs: dict) -> object: + """Call the appropriate method based on the data source name. + + Parameters + ---------- + name : str + The name of the data source. + args : tuple + The positional arguments. + kwargs : dict + The keyword arguments. + + Returns + ------- + object + The loaded data source. + """ + if name == "mars": + return self.mars(args, kwargs) + + return original_from_source(name, *args, **kwargs) diff --git a/tests/test_data.py b/tests/test_data.py index 5c3e4c352..07b35887e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -31,6 +31,7 @@ from anemoi.datasets.data.join import Join from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date +from anemoi.datasets.data.padded import Padded from anemoi.datasets.data.select import Rename from anemoi.datasets.data.select import Select from anemoi.datasets.data.statistics import Statistics @@ -389,6 +390,7 @@ def run( time_increment: datetime.timedelta, statistics_reference_dataset: Optional[Union[str, list]], statistics_reference_variables: Optional[Union[str, list]], + regular_shape: bool = True, ) -> None: """Run the dataset tests. @@ -414,6 +416,8 @@ def run( Reference dataset for statistics. statistics_reference_variables : Optional[Union[str, list]] Reference variables for statistics. + regular_shape : bool, optional + Whether the dataset has a regular shape, by default True. 
""" if isinstance(expected_variables, str): expected_variables = [v for v in expected_variables] @@ -452,7 +456,8 @@ def run( statistics_reference_variables, ) - self.indexing(self.ds) + if regular_shape: + self.indexing(self.ds) self.metadata(self.ds) self.ds.tree() @@ -669,6 +674,25 @@ def test_join_3() -> None: ) +@mockup_open_zarr +def test_padding_1() -> None: + """Test subsetting a dataset (case 2).""" + test = DatasetTester("test-2022-2022-1h-o96-abcd", start="2021-01-01", end="2023-12-31 23:00:00", padding="empty") + test.run( + expected_class=Padded, + expected_length=365 * 24 * 3, + expected_shape=(365 * 24 * 3, 4, 1, VALUES), + expected_variables="abcd", + expected_name_to_index="abcd", + date_to_row=lambda date: simple_row(date, "abcd") if date.year == 2022 else np.zeros((4, 1, 0)), + start_date=datetime.datetime(2021, 1, 1), + time_increment=datetime.timedelta(hours=1), + statistics_reference_dataset="test-2022-2022-1h-o96-abcd", + statistics_reference_variables="abcd", + regular_shape=False, + ) + + @mockup_open_zarr def test_subset_1() -> None: """Test subsetting a dataset (case 1).""" diff --git a/tests/test_records.py b/tests/test_records.py new file mode 100644 index 000000000..896081f9a --- /dev/null +++ b/tests/test_records.py @@ -0,0 +1,160 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+import os + +import numpy as np +import pytest + +from anemoi.datasets.data import open_dataset +from anemoi.datasets.data.records import Record +from anemoi.datasets.data.records import Tabular + + +def check_numpy(x, y): + assert x.shape == y.shape, f"Expected {x.shape} == {y.shape}" + assert type(x) == type(y), f"Expected {type(x)} == {type(y)}" # noqa: E721 + assert np.all(np.isnan(x) == np.isnan(y)) and np.all( + np.nan_to_num(x) == np.nan_to_num(y) + ), f"Expected {x} == {y} (ignoring NaNs)" + + +def _test(ds, nb_dates=None): + grp = "metop_a_ascat" + index_i = 0 + + if nb_dates is not None: + assert len(ds) == nb_dates, f"Expected {nb_dates} dates, got {len(ds)}" + + ################################# + # Order does not matter too much [i] and [grp] are exchangeable + + elt = ds[index_i] + assert isinstance(elt, Record), (type(ds), type(elt)) + assert ds[index_i].dataset == ds, (type(ds[index_i].dataset), type(ds)) + + group = ds[grp] + assert isinstance(group, Tabular), type(group) + + x = ds[grp][index_i] + y = ds[index_i][grp] + check_numpy(x, y) + + ############################################### + # lat and lon and timedelta are not the same for all elements + # but they have the same size + + lat = ds[index_i].latitudes[grp] + assert isinstance(lat, np.ndarray), type(lat) + + # Not implemented yet + # lat = ds[grp].latitudes[index_i] + # assert isinstance(lat, np.ndarray), type(lat) + + # Not implemented yet : do not need ? + # lat = ds.latitudes[grp][index_i] + # assert isinstance(lat, np.ndarray), type(lat) + + # Not implemented yet : do not need ? 
+ # lat = ds.latitudes[index_i][grp] + # assert isinstance(lat, np.ndarray), type(lat) + + lon = ds[index_i].longitudes[grp] + assert isinstance(lon, np.ndarray), type(lon) + assert len(lat) == len(lon), f"Expected same size for lat and lon {len(lat)} == {len(lon)}" + + timedeltas = ds[index_i].timedeltas[grp] + assert isinstance(timedeltas, np.ndarray), type(timedeltas) + assert len(timedeltas) == len(lat), f"Expected same size for lat and timedeltas {len(lat)} == {len(timedeltas)}" + + ############################################# + # name_to_index is must be the same for all elements + # name_to_index is a dict of dict (key is the group name) + + name_to_index = ds.name_to_index + assert isinstance(name_to_index, dict), type(name_to_index) + assert len(name_to_index) > 0, "name_to_index is empty" + assert all(isinstance(k, str) for k in name_to_index.keys()), name_to_index + assert all(isinstance(v, dict) for v in name_to_index.values()), name_to_index + + _name_to_index = ds[index_i].name_to_index + assert list(name_to_index.keys()) == list(_name_to_index.keys()), ( + list(name_to_index.keys()), + list(_name_to_index.keys()), + ) + assert name_to_index == _name_to_index, "name_to_index is not the same for all elements" + + ############################################### + # statistics is not the same for all elements + # statistics is a dict of dict (first key is the group name) + + statistics = ds.statistics + assert isinstance(statistics, dict), type(statistics) + assert len(statistics) > index_i, "statistics is empty" + assert all(isinstance(k, str) for k in statistics.keys()), statistics + assert all(isinstance(v, dict) for v in statistics.values()), statistics + assert grp in statistics, f"statistics does not contain {grp}" + + statistics_ = ds[grp].statistics + assert isinstance(statistics_, dict), type(statistics_) + assert "mean" in statistics_, "statistics does not contain mean" + + # ! 
here, the meaning could be ambiguous, this is the statistics of the whole dataset. + # Do not document this, and maybe remove it. + _statistics = ds[index_i].statistics + assert isinstance(_statistics, dict), type(_statistics) + assert grp in _statistics, f"statistics does not contain {grp}" + assert _statistics.keys() == ds.keys(), (_statistics.keys(), ds.keys()) + for group_name, stats in _statistics.items(): + assert "mean" in stats, f"statistics does not contain mean for {group_name}" + for key, v in stats.items(): + assert np.all(statistics[group_name][key] == v), (key, statistics[group_name][key], v) + + assert statistics[grp].keys() == statistics_.keys(), (statistics[grp].keys(), statistics_.keys()) + for key, v in statistics[grp].items(): + assert np.all(statistics[grp][key] == v), (key, statistics[grp][key], v) + + +@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +def test_open(): + ds = open_dataset("../../data/vz/obs-2018-11.vz") + _test(ds) + + +@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +def test_open_with_subset_dates(): + ds = open_dataset( + "../../data/vz/obs-2018-11.vz", + end="2018-11-30", + select=[ + "metop_a_ascat.*", + "amsr2_h180.rawbt_4", + "amsr2_h180.rawbt_3", + ], + ) + _test(ds, nb_dates=8) + + +@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +def test_open_with_subset_select(): + ds = open_dataset( + "../../data/vz/obs-2018-11.vz", + select=[ + "amsr2_h180.rawbt_4", + "amsr2_h180.rawbt_3", + "metop_a_ascat.*", + ], + ) + _test(ds) + + +if __name__ == "__main__": + + test_open() + test_open_with_subset_select() + test_open_with_subset_dates() diff --git a/tests/xarray/test_flavour.py b/tests/xarray/test_flavour.py new file mode 100644 index 000000000..ba185ca39 --- /dev/null +++ b/tests/xarray/test_flavour.py @@ -0,0 +1,104 @@ +# (C) Copyright 2025 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import numpy as np +import pytest +import xarray as xr + +from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.create.sources.xarray_support.flavour import DefaultCoordinateGuesser + + +def create_ds(var_name, standard_name, long_name, units, coord_length=5): + attrs = { + k: v for k, v in [("standard_name", standard_name), ("long_name", long_name), ("units", units)] if v is not None + } + + ds = xr.Dataset( + {"x_wind": ([var_name], np.random.rand(coord_length))}, + coords={ + var_name: xr.DataArray(np.arange(coord_length), dims=var_name, attrs=attrs), + }, + ) + return ds + + +@pytest.mark.parametrize( + "var_name, standard_name, long_name, units, result", + [ + # longitude + ("longitude", None, 
None, None, LongitudeCoordinate), + ("longitude", None, "longitude", "degrees_east", LongitudeCoordinate), + ("longitude", None, "longitude", "degrees", LongitudeCoordinate), + ("lons", "longitude", None, "degrees", LongitudeCoordinate), + ("lons", None, None, None, UnsupportedCoordinate), + # latitude + ("latitude", None, None, None, LatitudeCoordinate), + ("latitude", None, "latitude", "degrees_north", LatitudeCoordinate), + ("latitude", None, "latitude", "degrees", LatitudeCoordinate), + ("lats", "latitude", None, "degrees", LatitudeCoordinate), + ("lats", None, None, "'degrees", UnsupportedCoordinate), + # x + ("x", None, None, None, XCoordinate), + ("x_coord", "projection_x_coordinate", None, None, XCoordinate), + ("x_coord", "grid_longitude", None, None, XCoordinate), + # y + ("y", None, None, None, YCoordinate), + ("y_coord", "projection_y_coordinate", None, None, YCoordinate), + ("y_coord", "grid_latitude", None, None, YCoordinate), + # time + ("time", "time", None, None, TimeCoordinate), + ("time", None, None, None, TimeCoordinate), + # date + ("t", "forecast_reference_time", None, None, DateCoordinate), + ("forecast_reference_time", None, None, None, DateCoordinate), + ("forecast_reference_time", "forecast_reference_time", None, None, DateCoordinate), + # step + ("fp", "forecast_period", None, None, StepCoordinate), + ("forecast_period", None, "time elapsed since the start of the forecast", None, StepCoordinate), + ("prediction_timedelta", None, None, None, StepCoordinate), + # level + ("lev", "atmosphere_hybrid_sigma_pressure_coordinate", None, None, LevelCoordinate), + ("h", None, "height", "m", LevelCoordinate), + ("level", "air_pressure", None, "hPa", LevelCoordinate), + ("pressure_0", None, "pressure", "hPa", LevelCoordinate), + ("pressure_0", None, "pressure", "Pa", LevelCoordinate), + ("level", None, None, None, LevelCoordinate), + ("lev", None, "level", None, UnsupportedCoordinate), + ("vertical", "vertical", None, "hPa", LevelCoordinate), + 
("depth", "depth", None, "m", LevelCoordinate), + ("depth", "depth", None, None, LevelCoordinate), + # number + ("realization", None, None, None, EnsembleCoordinate), + ("number", None, None, None, EnsembleCoordinate), + ], +) +def test_coordinate_guesser(var_name, standard_name, long_name, units, result): + ds = create_ds(var_name, standard_name, long_name, units) + guesser = DefaultCoordinateGuesser(ds) + guess = guesser.guess(ds[var_name], var_name) + assert isinstance(guess, result) + + +def test_coordinate_guesser_scalar(): + var_name = "height" + ds = create_ds(var_name, None, None, "m", coord_length=1) + guesser = DefaultCoordinateGuesser(ds) + guess = guesser.guess(ds[var_name], var_name) + assert isinstance(guess, ScalarCoordinate) diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index c0a8f8ed2..81d36b923 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -13,7 +13,6 @@ from anemoi.utils.testing import skip_missing_packages from anemoi.datasets.create.sources.xarray import XarrayFieldList -from anemoi.datasets.data.stores import name_to_zarr_store from anemoi.datasets.testing import assert_field_list @@ -133,33 +132,6 @@ def test_noaa_replay() -> None: ) -@skip_if_offline -@skip_missing_packages("planetary_computer", "adlfs") -def test_planetary_computer_conus404() -> None: - """Test loading and validating the planetary_computer_conus404 dataset.""" - url = "https://planetarycomputer.microsoft.com/api/stac/v1/collections/conus404" - ds = xr.open_zarr(**name_to_zarr_store(url)) - - flavour = { - "rules": { - "latitude": {"name": "lat"}, - "longitude": {"name": "lon"}, - "x": {"name": "west_east"}, - "y": {"name": "south_north"}, - "time": {"name": "time"}, - }, - } - - fs = XarrayFieldList.from_xarray(ds, flavour=flavour) - - assert_field_list( - fs, - 74634912, - "1979-10-01T00:00:00", - "2022-09-30T23:00:00", - ) - - if __name__ == "__main__": for name, obj in list(globals().items()): if name.startswith("test_") 
and callable(obj): diff --git a/tools/build-obs.py b/tools/build-obs.py new file mode 100755 index 000000000..e3caff9f9 --- /dev/null +++ b/tools/build-obs.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +import argparse +import logging +import os +import shutil + +import tqdm + +from anemoi.datasets import open_dataset + +LOG = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser(description="open a dataset and build a new one") + parser.add_argument("input", help="input dataset") + parser.add_argument("output", help="output dataset") + parser.add_argument("--backend", help="backend to use", type=str, default="npz1") + parser.add_argument("--overwrite", help="overwrite output directory if it exists", action="store_true") + args = parser.parse_args() + build(**vars(args)) + + +def build(input, output, backend, overwrite=False): + ds = open_dataset(input, backend=backend) + print(f"Using dataset {ds} as input") + print(f"{input} backend is '{ds.metadata['backend']}'") + print(f"Dataset has {len(ds)} records, from {ds.start_date} to {ds.end_date}") + print(f"Converting dataset to {output} using new backend '{backend}'") + + from anemoi.datasets.data.records.backends import writer_backend_factory + + if os.path.exists(output): + if overwrite: + LOG.warning(f"Output directory {output} already exists, removing it") + shutil.rmtree(output) + else: + raise FileExistsError(f"Output directory {output} already exists, use --overwrite to remove it") + writer = writer_backend_factory(backend, output) + + for i in tqdm.tqdm(range(len(ds))): + writer.write(i, ds[i]) + + writer.write_statistics(ds.statistics) + + metadata = ds.metadata.copy() + metadata["backend"] = backend + writer.write_metadata(metadata) + + +if __name__ == "__main__": + main() diff --git a/tools/check-obs.py b/tools/check-obs.py new file mode 100755 index 000000000..283b76aac --- /dev/null +++ b/tools/check-obs.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +import argparse +import logging 
+ +import numpy as np + +from anemoi.datasets import open_dataset + +LOG = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser(description="open two datasets and compare them") + parser.add_argument("dataset", help="dataset to check") + parser.add_argument("reference", help="reference dataset") + args = parser.parse_args() + compare(args.dataset, args.reference) + + +def _compare_nested_dicts(a, b): + if isinstance(a, dict) and isinstance(b, dict): + if a.keys() != b.keys(): + return False + return all(_compare_nested_dicts(a[k], b[k]) for k in a) + elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + if a.shape != b.shape: + return False + return np.array_equal(a, b) + assert False, f"Unsupported types for comparison: {type(a)} and {type(b)}" + + +def compare(input, reference): + ds = open_dataset(input) + ref = open_dataset(reference) + + if len(ds) != len(ref): + raise ValueError(f"Datasets have different lengths: {len(ds)} != {len(ref)}") + + for i in range(len(ds)): + if ds[i] != ref[i]: + raise ValueError(f"Datasets differ at index {i}: {ds[i]} != {ref[i]}") + if ds.dates[i] != ref.dates[i]: + raise ValueError(f"Dates differ at index {i}: {ds.dates[i]} != {ref.dates[i]}") + print("✅ Data and dates are identical") + + ds_metadata = ds.metadata.copy() + ref_metadata = ref.metadata.copy() + ds_metadata.pop("backend", None) + ref_metadata.pop("backend", None) + if ds_metadata != ref_metadata: + raise ValueError("Metadata differs between datasets (excluding backend)") + print("✅ Metadata is identical") + + if not _compare_nested_dicts(ds.statistics, ref.statistics): + raise ValueError("Statistics differ between datasets") + print("✅ Statistics are identical") + + +if __name__ == "__main__": + main() From 83936f75ee3f65982760b0c3d5227df5cf70dd57 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 11 Aug 2025 20:05:53 +0200 Subject: [PATCH 29/79] merge --- 03-constant-fields.rst | 3 --- 1 file changed, 3 deletions(-) 
delete mode 100644 03-constant-fields.rst diff --git a/03-constant-fields.rst b/03-constant-fields.rst deleted file mode 100644 index 2b21505dc..000000000 --- a/03-constant-fields.rst +++ /dev/null @@ -1,3 +0,0 @@ -######################## - Adding constant fields -######################## From 5209f26c0aa229d076834ab895094852c4208100 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 11 Aug 2025 20:24:47 +0200 Subject: [PATCH 30/79] update --- src/anemoi/datasets/create/__init__.py | 14 +++- src/anemoi/datasets/create/python.py | 98 +++++++++++--------------- 2 files changed, 52 insertions(+), 60 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 69b5a0d42..508eb2952 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1667,11 +1667,19 @@ def _tidy(d): def config_to_python(config: Any) -> Any: - from ..create.python import PythonCode + from ..create.python import PythonSource config = loader_config(config) input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - code = PythonCode() - return input.python_code(code) + code = PythonSource() + code = input.python_code(code).source_code() + + try: + import black + + return black.format_str(code, mode=black.Mode()) + except ImportError: + LOG.warning("Black not installed, skipping formatting") + return code diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 271845c2d..8d3312583 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -8,11 +8,8 @@ # nor does it submit to any jurisdiction. 
# -import datetime -import json import re - -from anemoi.utils.dates import frequency_to_string +import textwrap # input.python_prelude(code) # code1 = "\n".join(prelude) @@ -61,21 +58,51 @@ class PythonCode: + def __init__(self, parent): + self.parent = parent + def call(self, name, argument): - return PythonCall(name, argument) + return PythonCall(self, name, argument) def sum(self, actions): - return PythonChain("+", actions) + return PythonChain(self, "+", actions) def pipe(self, actions): - return PythonChain("|", actions) + return PythonChain(self, "|", actions) def concat(self, argument): - return PythonConcat(argument) + return PythonConcat(self, argument) + + def source_code(self, top=None): + return self.parent.source_code(top=top or self) + + +class PythonSource(PythonCode): + + def __init__(self): + super().__init__(parent=None) + self.prelude = "" + + def source_code(self, top): + + return textwrap.dedent( + f""" + # Generated Python code for Anemoi dataset creation + + from anemoi.datasets.recipe import Recipe + r = Recipe() + {self.prelude} + r.input = {repr(top)} + + """ + ) + + return repr(top) class PythonConcat(PythonCode): - def __init__(self, argument): + def __init__(self, parent, argument): + super().__init__(parent=parent) self.argument = argument def __repr__(self): @@ -83,13 +110,14 @@ def __repr__(self): class PythonCall(PythonCode): - def __init__(self, name, argument): + def __init__(self, parent, name, argument): + super().__init__(parent=parent) self.name = name self.argument = argument def __repr__(self): name = self.name.replace("-", "_") - config = self.argument + config = dict(**self.argument) # def convert(obj): # if isinstance(obj, datetime.datetime): @@ -121,54 +149,10 @@ def __repr__(self): class PythonChain(PythonCode): - def __init__(self, op, actions): + def __init__(self, parent, op, actions): + super().__init__(parent=parent) self.op = op self.actions = actions def __repr__(self): return "(" + self.op.join(repr(x) for x 
in self.actions) + ")" - - -def _python(name, config, **extra) -> str: - """Convert the action to Python code. - - Parameters - ---------- - name : str - The name of the action. - config : dict - The configuration for the action. - extra : Any - Additional keyword arguments. - - Returns - ------- - str - The Python code representation of the action. - """ - - name = name.replace("-", "_") - - def convert(obj): - if isinstance(obj, datetime.datetime): - return obj.isoformat() - if isinstance(obj, datetime.date): - return obj.isoformat() - if isinstance(obj, datetime.timedelta): - return frequency_to_string(obj) - raise TypeError(f"Object of type {type(obj)} is not JSON serializable") - - config = json.loads(json.dumps(config, default=convert)) - - params = [] - for k, v in config.items(): - if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: - return f"r.{name}({config})" - params.append(f"{k}={repr(v)}") - - for k, v in extra.items(): - params.append(f"{k}={v}") - - params = ",".join(params) - return f"r.{name}({params})" - # return f"{name}({config})" From c7a0e5d7d50817ad02d1b879b07a6147c9993e35 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 12 Aug 2025 12:21:59 +0200 Subject: [PATCH 31/79] update --- src/anemoi/datasets/create/__init__.py | 6 +- src/anemoi/datasets/create/python.py | 255 ++++++++++++++++++------- 2 files changed, 194 insertions(+), 67 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 508eb2952..7a076aa48 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1674,12 +1674,14 @@ def config_to_python(config: Any) -> Any: input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) code = PythonSource() - code = input.python_code(code).source_code() + x = input.python_code(code) + code = code.source_code(x) try: import black return black.format_str(code, mode=black.Mode()) - except 
ImportError: + # except ImportError: + except Exception: LOG.warning("Black not installed, skipping formatting") return code diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 8d3312583..d54fca1f3 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -9,24 +9,9 @@ # import re -import textwrap +import sys +from collections import defaultdict -# input.python_prelude(code) -# code1 = "\n".join(prelude) -# rich.print(f"Input prelude:\n{code1}") -# code2 = input.to_python() - -# code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" - -# code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) - -# try: -# import black - -# return black.format_str(code, mode=black.Mode()) -# except ImportError: -# LOG.warning("Black not installed, skipping formatting") -# return code RESERVED_KEYWORDS = ( "and", "or", @@ -56,103 +41,243 @@ ) +def _sanitize_name(name): + name = name.replace("-", "_") + if name in RESERVED_KEYWORDS: + name = f"{name}_" + return name + + class PythonCode: - def __init__(self, parent): - self.parent = parent + def __init__(self, top): + print(f"Creating {self.__class__.__name__} from {top.__class__.__name__}", file=sys.stderr) + self.top = top + self.top.register(self) + self.key = str(id(self)) def call(self, name, argument): - return PythonCall(self, name, argument) + return PythonCall(self.top, name, argument) def sum(self, actions): - return PythonChain(self, "+", actions) + return PythonChain(self.top, "+", actions) def pipe(self, actions): - return PythonChain(self, "|", actions) + return PythonChain(self.top, "|", actions) def concat(self, argument): - return PythonConcat(self, argument) + return PythonConcat(self.top, argument) + + def source_code(self): + return self.top.source_code(self) + + def combine(self, nodes): + return None + + +class Argument: + + def __init__(self, name): + self.name = name - 
def source_code(self, top=None): - return self.parent.source_code(top=top or self) + def __repr__(self): + return f"{_sanitize_name(self.name)}" class PythonSource(PythonCode): def __init__(self): - super().__init__(parent=None) - self.prelude = "" + self._prelude = [] + self.nodes = [] + self._count = defaultdict(int) + super().__init__(top=self) + + def register(self, child): + if child is not self: + self.nodes.append(child) + + def prelude(self): + return "\n".join(self._prelude) + + def source_code(self, first): + + which = self.nodes.index(first) + + more = True + while more: + more = False + + by_class = defaultdict(list) + for node in self.nodes: + by_class[(node.__class__, node.key)].append(node) + + for (cls, key), nodes in by_class.items(): + if len(nodes) > 1: + print(f"Found multiple nodes of type {cls.__name__}/{key}, merging them", file=sys.stderr) + print(f"Nodes: {len(nodes)}", file=sys.stderr) + changes = nodes[0].combine(nodes) + if changes: + self.replace_nodes(changes) + more = True + + first = self.nodes[which] + + return "\n\n".join( + [ + "# Generated Python code for Anemoi dataset creation", + "from anemoi.datasets.recipe import Recipe", + "r = Recipe()", + *self._prelude, + f"r.input = {repr(first)}", + "r.dump()", + ] + ) - def source_code(self, top): + def function(self, key, value, node): - return textwrap.dedent( - f""" - # Generated Python code for Anemoi dataset creation + n = self._count[node.name] + self._count[node.name] += 1 - from anemoi.datasets.recipe import Recipe - r = Recipe() - {self.prelude} - r.input = {repr(top)} + name = f"{node.name}_{n}" + name = _sanitize_name(name) + key = _sanitize_name(key) - """ - ) + class Function: + def __init__(self, name, key, value, node): + self.name = name + self.key = key + self.value = value + self.node = node + + def __repr__(self): + return f"{self.name}" + + self._prelude.append(f"def {name}({key}):") + self._prelude.append(f" return {node}") + return Function(name, key, value, 
node) - return repr(top) + def replace_nodes(self, changes): + + for old, new in changes: + assert old in self.nodes, f"Node {old} not found in {self.nodes}" + for i, node in enumerate(self.nodes): + + if node is old: + self.nodes[i] = new + else: + node.replace_node(old, new) class PythonConcat(PythonCode): - def __init__(self, parent, argument): - super().__init__(parent=parent) + def __init__(self, top, argument): + super().__init__(top=top) self.argument = argument + for k, v in self.argument.items(): + assert isinstance(v, PythonCode), f"Value must be a PythonCode instance {v}" def __repr__(self): - return str(self.argument) + return f"r.concat({self.argument})" + + def replace_node(self, old, new): + for k, v in list(self.argument.items()): + if v is old: + self.argument[k] = new + else: + v.replace_node(old, new) class PythonCall(PythonCode): - def __init__(self, parent, name, argument): - super().__init__(parent=parent) + def __init__(self, top, name, argument, parameters=None): + super().__init__(top=top) self.name = name self.argument = argument + self.key = name + self.parameters = parameters def __repr__(self): name = self.name.replace("-", "_") config = dict(**self.argument) - # def convert(obj): - # if isinstance(obj, datetime.datetime): - # return obj.isoformat() - # if isinstance(obj, datetime.date): - # return obj.isoformat() - # if isinstance(obj, datetime.timedelta): - # return frequency_to_string(obj) - # if isinstance(obj, PythonCode): - # return obj - # raise TypeError(f"Object of type {type(obj)} is not JSON serializable") - - # config = json.loads(json.dumps(config, default=convert)) - params = [] + for k, v in config.items(): if isinstance(k, str): - if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: + + if k in RESERVED_KEYWORDS: + k = f"{k}_" + + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k): return f"r.{name}({config})" params.append(f"{k}={repr(v)}") - # for k, v in extra.items(): - # 
params.append(f"{k}={v}") - params = ",".join(params) return f"r.{name}({params})" - # return f"{name}({config})" - return f"{self.name}({self.argument})" + + def replace_node(self, old, new): + pass + + def combine(self, nodes): + + x = defaultdict(list) + for node in nodes: + argument = node.argument + for k, v in argument.items(): + rest = {k2: v2 for k2, v2 in sorted(argument.items()) if k2 != k} + x[str(rest)].append((k, v, node)) + + for i in sorted(x.values(), key=len, reverse=True): + key, value, node = i[0] + if len(i) < 2: + return + + rest = {k: v for k, v in node.argument.items() if k != key} + rest[key] = Argument(key) + call = PythonCall(self.top, self.name, rest) + + func = self.top.function(key, value, node=call) + changes = [] + for key, value, node in i: + + new = PythonFunction( + top=self.top, + func=func, + argument={key: value}, + ) + + changes.append((node, new)) + + return changes class PythonChain(PythonCode): - def __init__(self, parent, op, actions): - super().__init__(parent=parent) + def __init__(self, top, op, actions): + super().__init__(top=top) self.op = op - self.actions = actions + self.actions = list(actions) + self.key = op def __repr__(self): return "(" + self.op.join(repr(x) for x in self.actions) + ")" + + def replace_node(self, old, new): + + for i, node in enumerate(self.actions): + + if node is old: + self.actions[i] = new + else: + node.replace_node(old, new) + + +class PythonFunction(PythonCode): + def __init__(self, top, func, argument): + super().__init__(top=top) + self.func = func + self.argument = argument + self.key = func + + def __repr__(self): + return f"{self.func}({', '.join(f'{_sanitize_name(k)}={repr(v)}' for k, v in self.argument.items())})" + + def replace_node(self, old, new): + pass From b78a098829fe1dac60c28ea866919b29a81218f9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 12 Aug 2025 18:54:04 +0200 Subject: [PATCH 32/79] update --- src/anemoi/datasets/create/__init__.py | 4 +- 
src/anemoi/datasets/create/python.py | 217 ++++++++++++++++++------- 2 files changed, 161 insertions(+), 60 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 7a076aa48..fbe4ab024 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1667,13 +1667,13 @@ def _tidy(d): def config_to_python(config: Any) -> Any: - from ..create.python import PythonSource + from ..create.python import PythonScript config = loader_config(config) input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - code = PythonSource() + code = PythonScript() x = input.python_code(code) code = code.source_code(x) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index d54fca1f3..dccf33642 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -11,6 +11,7 @@ import re import sys from collections import defaultdict +from functools import cached_property RESERVED_KEYWORDS = ( "and", @@ -74,22 +75,74 @@ def source_code(self): def combine(self, nodes): return None + def prelude(self): + return None + class Argument: def __init__(self, name): - self.name = name + self.name = _sanitize_name(name) def __repr__(self): - return f"{_sanitize_name(self.name)}" + return self.name -class PythonSource(PythonCode): +class Parameter: + + def __init__(self, name): + self.name = _sanitize_name(name) + + def __repr__(self): + return self.name + + +class Function: + def __init__(self, name, node, counter, *parameters): + self._name = name + self.node = node + self.used = False + self.counter = counter + # self.parameters = parameters + + def __repr__(self): + return self.name + + def prelude(self): + if self.used: + return None + + self.used = True + + node_prelude = self.node.prelude() + + arguments = self.node.free_arguments() + + return [ + *(node_prelude if node_prelude else []), + f"def 
{self.name}({','.join(repr(p) for p in arguments)}):", + f" return {self.node}", + ] + + def free_arguments(self): + return self.node.free_arguments() + + @cached_property + def name(self): + n = self.counter[self._name] + self.counter[self._name] += 1 + return _sanitize_name(f"{self._name}_{n}") + + def replace_node(self, old, new): + if self.node is old: + self.node = new + + +class PythonScript(PythonCode): def __init__(self): - self._prelude = [] self.nodes = [] - self._count = defaultdict(int) + self.counter = defaultdict(int) super().__init__(top=self) def register(self, child): @@ -97,7 +150,14 @@ def register(self, child): self.nodes.append(child) def prelude(self): - return "\n".join(self._prelude) + result = [] + for node in self.nodes: + prelude = node.prelude() + if prelude: + if not isinstance(prelude, (list, tuple)): + prelude = list(prelude) + result.extend(prelude) + return "\n".join(result) def source_code(self, first): @@ -127,34 +187,17 @@ def source_code(self, first): "# Generated Python code for Anemoi dataset creation", "from anemoi.datasets.recipe import Recipe", "r = Recipe()", - *self._prelude, + self.prelude(), f"r.input = {repr(first)}", "r.dump()", ] ) - def function(self, key, value, node): - - n = self._count[node.name] - self._count[node.name] += 1 + def function0(self, node): + return Function(node.name, node, self.counter) - name = f"{node.name}_{n}" - name = _sanitize_name(name) - key = _sanitize_name(key) - - class Function: - def __init__(self, name, key, value, node): - self.name = name - self.key = key - self.value = value - self.node = node - - def __repr__(self): - return f"{self.name}" - - self._prelude.append(f"def {name}({key}):") - self._prelude.append(f" return {node}") - return Function(name, key, value, node) + def function1(self, node, key): + return Function(node.name, node, self.counter, Parameter(key)) def replace_nodes(self, changes): @@ -172,8 +215,6 @@ class PythonConcat(PythonCode): def __init__(self, top, 
argument): super().__init__(top=top) self.argument = argument - for k, v in self.argument.items(): - assert isinstance(v, PythonCode), f"Value must be a PythonCode instance {v}" def __repr__(self): return f"r.concat({self.argument})" @@ -186,13 +227,39 @@ def replace_node(self, old, new): v.replace_node(old, new) +class PythonChain(PythonCode): + def __init__(self, top, op, actions): + super().__init__(top=top) + self.op = op + self.actions = list(actions) + self.key = op + + def __repr__(self): + return "(" + self.op.join(repr(x) for x in self.actions) + ")" + + def replace_node(self, old, new): + + for i, node in enumerate(self.actions): + + if node is old: + self.actions[i] = new + else: + node.replace_node(old, new) + + class PythonCall(PythonCode): - def __init__(self, top, name, argument, parameters=None): + def __init__(self, top, name, argument): super().__init__(top=top) self.name = name self.argument = argument self.key = name - self.parameters = parameters + + def free_arguments(self): + result = [] + for k, v in self.argument.items(): + if isinstance(v, Argument): + result.append(v) + return result def __repr__(self): name = self.name.replace("-", "_") @@ -210,6 +277,9 @@ def __repr__(self): return f"r.{name}({config})" params.append(f"{k}={repr(v)}") + if params: + params.append("") # For a trailing comma + params = ",".join(params) return f"r.{name}({params})" @@ -218,6 +288,43 @@ def replace_node(self, old, new): def combine(self, nodes): + # Exact similarity + + changes = self._combine0(nodes) + if changes: + return changes + + # On key difference + changes = self._combine1(nodes) + if changes: + return changes + + def _combine0(self, nodes): + + x = defaultdict(list) + for node in nodes: + key = {k2: v2 for k2, v2 in sorted(node.argument.items())} + x[str(key)].append(node) + + for i in sorted(x.values(), key=len, reverse=True): + node = i[0] + if len(i) < 2: + return + + call = PythonCall(self.top, self.name, node.argument) + + func = 
self.top.function0(node=call) + changes = [] + for node in i: + + new = PythonFunction(top=self.top, func=func) + + changes.append((node, new)) + + return changes + + def _combine1(self, nodes): + x = defaultdict(list) for node in nodes: argument = node.argument @@ -234,14 +341,14 @@ def combine(self, nodes): rest[key] = Argument(key) call = PythonCall(self.top, self.name, rest) - func = self.top.function(key, value, node=call) + func = self.top.function1(call, key) changes = [] for key, value, node in i: new = PythonFunction( top=self.top, func=func, - argument={key: value}, + **{key: value}, ) changes.append((node, new)) @@ -249,35 +356,29 @@ def combine(self, nodes): return changes -class PythonChain(PythonCode): - def __init__(self, top, op, actions): +class PythonFunction(PythonCode): + def __init__(self, top, func, **kwargs): super().__init__(top=top) - self.op = op - self.actions = list(actions) - self.key = op + self.func = func + self.kwargs = kwargs def __repr__(self): - return "(" + self.op.join(repr(x) for x in self.actions) + ")" - - def replace_node(self, old, new): - - for i, node in enumerate(self.actions): - - if node is old: - self.actions[i] = new + params = [] + for a in self.func.free_arguments(): + name = _sanitize_name(a.name) + if a.name in self.kwargs: + v = self.kwargs[a.name] + params.append(f"{name}={repr(v)}") else: - node.replace_node(old, new) + params.append(f"{name}={name}") + return f"{self.func}({', '.join(params)})" -class PythonFunction(PythonCode): - def __init__(self, top, func, argument): - super().__init__(top=top) - self.func = func - self.argument = argument - self.key = func + def replace_node(self, old, new): + self.func.replace_node(old, new) - def __repr__(self): - return f"{self.func}({', '.join(f'{_sanitize_name(k)}={repr(v)}' for k, v in self.argument.items())})" + def prelude(self): + return self.func.prelude() - def replace_node(self, old, new): - pass + def free_arguments(self): + return [a for a in 
self.func.free_arguments() if a.name not in self.kwargs] From d641ea78309228bdd870155a1226ebe67095ebbb Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 13 Aug 2025 17:05:18 +0200 Subject: [PATCH 33/79] update --- src/anemoi/datasets/create/input/action.py | 9 +- src/anemoi/datasets/create/python.py | 34 +++-- src/anemoi/datasets/dates/__init__.py | 91 +++++++----- src/anemoi/datasets/dates/groups.py | 155 +++++++++++++++------ 4 files changed, 197 insertions(+), 92 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 8dadf14dc..d8120a289 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -56,7 +56,10 @@ def __call__(self, context, argument): def python_code(self, code): return code.concat( - {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} + { + filtering_dates.to_python(just_dates=True): action.python_code(code) + for filtering_dates, action in self.choices + } ) @@ -124,6 +127,10 @@ def __call__(self, context, argument): return context.register(self.call_object(context, source, argument), self.path) def python_code(self, code) -> str: + # For now... 
+ if "source" in self.config: + source = action_factory(self.config["source"]) + self.config["source"] = source.python_code(code) return code.call(self.name, self.config) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index dccf33642..9e66ccfdd 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -49,6 +49,16 @@ def _sanitize_name(name): return name +def _un_dotdict(x): + if isinstance(x, dict): + return {k: _un_dotdict(v) for k, v in x.items()} + + if isinstance(x, (list, tuple, set)): + return [_un_dotdict(a) for a in x] + + return x + + class PythonCode: def __init__(self, top): @@ -98,12 +108,11 @@ def __repr__(self): class Function: - def __init__(self, name, node, counter, *parameters): + def __init__(self, name, node, counter): self._name = name self.node = node self.used = False self.counter = counter - # self.parameters = parameters def __repr__(self): return self.name @@ -131,6 +140,8 @@ def free_arguments(self): def name(self): n = self.counter[self._name] self.counter[self._name] += 1 + if n == 0: + return _sanitize_name(self._name) return _sanitize_name(f"{self._name}_{n}") def replace_node(self, old, new): @@ -193,12 +204,9 @@ def source_code(self, first): ] ) - def function0(self, node): + def function(self, node): return Function(node.name, node, self.counter) - def function1(self, node, key): - return Function(node.name, node, self.counter, Parameter(key)) - def replace_nodes(self, changes): for old, new in changes: @@ -214,7 +222,7 @@ def replace_nodes(self, changes): class PythonConcat(PythonCode): def __init__(self, top, argument): super().__init__(top=top) - self.argument = argument + self.argument = _un_dotdict(argument) def __repr__(self): return f"r.concat({self.argument})" @@ -251,7 +259,7 @@ class PythonCall(PythonCode): def __init__(self, top, name, argument): super().__init__(top=top) self.name = name - self.argument = argument + self.argument = 
_un_dotdict(argument) self.key = name def free_arguments(self): @@ -313,7 +321,7 @@ def _combine0(self, nodes): call = PythonCall(self.top, self.name, node.argument) - func = self.top.function0(node=call) + func = self.top.function(call) changes = [] for node in i: @@ -341,7 +349,7 @@ def _combine1(self, nodes): rest[key] = Argument(key) call = PythonCall(self.top, self.name, rest) - func = self.top.function1(call, key) + func = self.top.function(call) changes = [] for key, value, node in i: @@ -363,6 +371,12 @@ def __init__(self, top, func, **kwargs): self.kwargs = kwargs def __repr__(self): + + # if len(self.func.free_arguments()) == 0: + # a = repr(self.func.node) + # if '=' not in a: + # return a + params = [] for a in self.func.free_arguments(): name = _sanitize_name(a.name) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 9570d381f..0ae8105e8 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -32,13 +32,15 @@ def extend(x: Union[str, List[Any], Tuple[Any, ...]]) -> Iterator[datetime.datetime]: """Extend a date range or list of dates into individual datetime objects. - Args: - x (Union[str, List[Any], Tuple[Any, ...]]): A date range string or list/tuple of dates. - - Returns - ------- - Iterator[datetime.datetime] - An iterator of datetime objects. + Parameters + ---------- + x : Union[str, List[Any], Tuple[Any, ...]] + A date range string or list/tuple of dates. + + Yields + ------ + datetime.datetime + Individual datetime objects. """ if isinstance(x, (list, tuple)): @@ -63,6 +65,8 @@ def extend(x: Union[str, List[Any], Tuple[Any, ...]]) -> Iterator[datetime.datet class DatesProvider: """Base class for date generation. 
+ Examples + -------- >>> DatesProvider.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0)] @@ -163,9 +167,12 @@ def summary(self) -> str: class ValuesDates(DatesProvider): """Class for handling a list of date values. - Args: - values (List[Union[str, datetime.datetime]]): List of date values. - **kwargs (Any): Additional arguments. + Parameters + ---------- + values : List[Union[str, datetime.datetime]] + List of date values. + **kwargs : Any + Additional arguments. """ def __init__(self, values: List[Union[str, datetime.datetime]], **kwargs: Any) -> None: @@ -202,11 +209,16 @@ def as_dict(self) -> Dict[str, Any]: class StartEndDates(DatesProvider): """Class for generating dates between a start and end date with a specified frequency. - Args: - start (Union[str, datetime.datetime]): Start date. - end (Union[str, datetime.datetime]): End date. - frequency (Union[int, str]): Frequency of dates. - **kwargs (Any): Additional arguments. + Parameters + ---------- + start : Union[str, datetime.datetime] + Start date. + end : Union[str, datetime.datetime] + End date. + frequency : Union[int, str] + Frequency of dates. + **kwargs : Any + Additional arguments. """ def __repr__(self) -> str: @@ -273,26 +285,27 @@ def as_dict(self) -> Dict[str, Any]: "frequency": frequency_to_string(self.frequency), }.update(self.kwargs) - def to_python(self) -> str: - """Convert the StartEndDates instance to a Python string. - - Returns - ------- - str - Python string representation of the instance. 
- """ - # assert self.frequency == frequency_to_timedelta(1), self.frequency - return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) + def to_python(self, just_dates=False) -> str: + """Convert the StartEndDates instance to a tuple of ISO-formatted date strings.""" + if just_dates: + return (self.start.isoformat(), self.end.isoformat()) + else: + return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) class Hindcast: """Class representing a single hindcast date. - Args: - date (datetime.datetime): The date of the hindcast. - refdate (datetime.datetime): The reference date. - hdate (datetime.datetime): The hindcast date. - step (int): The step value. + Parameters + ---------- + date : datetime.datetime + The date of the hindcast. + refdate : datetime.datetime + The reference date. + hdate : datetime.datetime + The hindcast date. + step : int + The step value. """ def __init__( @@ -315,12 +328,18 @@ def __init__( class HindcastsDates(DatesProvider): """Class for generating hindcast dates over a range of years. - Args: - start (Union[str, List[str]]): Start date(s). - end (Union[str, List[str]]): End date(s). - steps (List[int]): List of step values. - years (int): Number of years. - **kwargs (Any): Additional arguments. + Parameters + ---------- + start : Union[str, List[str]] + Start date(s). + end : Union[str, List[str]] + End date(s). + steps : List[int] + List of step values. + years : int + Number of years. + **kwargs : Any + Additional arguments. """ def __init__( diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index f70fe8a57..88165f609 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -27,11 +27,15 @@ def _shorten(dates: Union[List[datetime.datetime], Tuple[datetime.datetime, ...]]) -> Union[str, List[str]]: """Shorten the list of dates for display. 
- Args: - dates (Union[List[datetime.datetime], Tuple[datetime.datetime, ...]]): The list of dates. - - Returns: - Union[str, List[str]]: The shortened list of dates. + Parameters + ---------- + dates : Union[List[datetime.datetime], Tuple[datetime.datetime, ...]] + The list of dates. + + Returns + ------- + Union[str, List[str]] + The shortened list of dates. """ if isinstance(dates, (list, tuple)): dates = [d.isoformat() for d in dates] @@ -44,6 +48,17 @@ class GroupOfDates: """A class to represent a group of dates.""" def __init__(self, dates: List[datetime.datetime], provider: DatesProvider, partial_ok: bool = False) -> None: + """Initialise a GroupOfDates instance. + + Parameters + ---------- + dates : List[datetime.datetime] + List of dates. + provider : DatesProvider + The dates provider. + partial_ok : bool, optional + Whether partial groups are allowed (default is False). + """ assert isinstance(provider, DatesProvider), type(provider) assert isinstance(dates, list) @@ -54,35 +69,45 @@ def __init__(self, dates: List[datetime.datetime], provider: DatesProvider, part def __len__(self) -> int: """Return the number of dates in the group. - Returns: - int: The number of dates. + Returns + ------- + int + The number of dates. """ return len(self.dates) def __iter__(self) -> Iterator[datetime.datetime]: """Return an iterator over the dates in the group. - Returns: - Iterator[datetime.datetime]: The iterator over the dates. + Returns + ------- + Iterator[datetime.datetime] + The iterator over the dates. """ return iter(self.dates) def __repr__(self) -> str: """Return a string representation of the group of dates. - Returns: - str: The string representation. + Returns + ------- + str + The string representation. """ return f"GroupOfDates(dates={_shorten(self.dates)})" def __eq__(self, other: object) -> bool: """Check if two groups of dates are equal. - Args: - other (object): The other group of dates. 
+ Parameters + ---------- + other : object + The other group of dates. - Returns: - bool: True if the groups are equal, False otherwise. + Returns + ------- + bool + True if the groups are equal, False otherwise. """ return isinstance(other, GroupOfDates) and self.dates == other.dates @@ -90,7 +115,8 @@ def __eq__(self, other: object) -> bool: class Groups: """A collection of groups of dates. - Examples: + Examples + -------- >>> list(Groups(group_by="daily", start="2023-01-01 00:00", end="2023-01-05 00:00", frequency=12))[0] [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 1, 12, 0)] @@ -122,7 +148,8 @@ def __init__(self, **kwargs: Any) -> None: Parameters ---------- - **kwargs : Any : Arbitrary keyword arguments. Expected keys include: + **kwargs : Any + Arbitrary keyword arguments. Expected keys include: - group_by: Configuration for the Grouper. - Other keys for DatesProvider configuration. """ @@ -140,8 +167,10 @@ def provider(self) -> DatesProvider: def __iter__(self) -> Iterator[GroupOfDates]: """Return an iterator over the groups of dates. - Returns: - Iterator[GroupOfDates]: The iterator over the groups of dates. + Returns + ------- + Iterator[GroupOfDates] + The iterator over the groups of dates. """ for go in self._grouper(self._dates): dates = self._filter(go.dates) @@ -152,8 +181,10 @@ def __iter__(self) -> Iterator[GroupOfDates]: def __len__(self) -> int: """Return the number of groups of dates. - Returns: - int: The number of groups. + Returns + ------- + int + The number of groups. """ return self._len @@ -171,24 +202,30 @@ def _len(self) -> int: def __repr__(self) -> str: """Return a string representation of the groups of dates. - Returns: - str: The string representation. + Returns + ------- + str + The string representation. """ return f"{self.__class__.__name__}(dates={len(self)},{_shorten(self._dates)})" def describe(self) -> str: """Return a summary description of the dates. - Returns: - str: The summary description. 
+ Returns + ------- + str + The summary description. """ return self._dates.summary def one_date(self) -> GroupOfDates: """Return a group containing only one date. - Returns: - GroupOfDates: The group containing only one date. + Returns + ------- + GroupOfDates + The group containing only one date. """ go = next(iter(self)) return GroupOfDates([go.dates[0]], go.provider) @@ -203,22 +240,24 @@ def __init__(self, missing: List[datetime.datetime]) -> None: def __call__(self, dates: List[datetime.datetime]) -> List[datetime.datetime]: """Filter out missing dates from the list of dates. - Args: - dates (List[datetime.datetime]): The list of dates. + Parameters + ---------- + dates : List[datetime.datetime] + The list of dates. - Returns: - List[datetime.datetime]: The filtered list of dates. + Returns + ------- + List[datetime.datetime] + The filtered list of dates. """ return [d for d in dates if d not in self.missing] class Grouper(ABC): - """Abstract base class for grouping dates.""" @classmethod def from_config(cls, group_by: Any) -> "Grouper": """Create a grouper based on the configuration.""" - if isinstance(group_by, int) and group_by > 0: return GrouperByFixedSize(group_by) @@ -278,11 +317,15 @@ class GrouperOneGroup(Grouper): def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group all dates into a single group. - Args: - dates (DatesProvider): The dates provider. + Parameters + ---------- + dates : DatesProvider + The dates provider. - Returns: - Iterator[GroupOfDates]: The iterator over the groups of dates. + Returns + ------- + Iterator[GroupOfDates] + The iterator over the groups of dates. """ assert isinstance(dates, DatesProvider), type(dates) @@ -293,16 +336,27 @@ class GrouperByKey(Grouper): """Group dates by a key.""" def __init__(self, key: Callable[[datetime.datetime], Any]) -> None: + """Initialise GrouperByKey with a key function. 
+ + Parameters + ---------- + key : Callable[[datetime.datetime], Any] + Function to extract grouping key from a datetime. + """ self.key = key def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group dates based on the provided key. - Args: - dates (DatesProvider): The dates provider. + Parameters + ---------- + dates : DatesProvider + The dates provider. - Returns: - Iterator[GroupOfDates]: The iterator over the groups of dates. + Returns + ------- + Iterator[GroupOfDates] + The iterator over the groups of dates. """ for _, g in itertools.groupby(sorted(dates, key=self.key), key=self.key): yield GroupOfDates(list(g), dates) @@ -312,16 +366,27 @@ class GrouperByFixedSize(Grouper): """Group dates by a fixed size.""" def __init__(self, size: int) -> None: + """Initialise GrouperByFixedSize with batch size. + + Parameters + ---------- + size : int + Number of dates per group. + """ self.size = size def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group dates into fixed-size batches. - Args: - dates (DatesProvider): The dates provider. + Parameters + ---------- + dates : DatesProvider + The dates provider. - Returns: - Iterator[GroupOfDates]: The iterator over the groups of dates. + Returns + ------- + Iterator[GroupOfDates] + The iterator over the groups of dates. 
""" batch = [] From 3754eb267c86b8eace2ffe3b6c1b4170913296dc Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 18:13:40 +0200 Subject: [PATCH 34/79] update --- src/anemoi/datasets/create/input/__init__.py | 22 ++- src/anemoi/datasets/create/input/action.py | 27 +++- src/anemoi/datasets/create/python.py | 161 ++++++++++++++++--- src/anemoi/datasets/recipe.py | 38 ++++- 4 files changed, 210 insertions(+), 38 deletions(-) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 63020324c..2e089a42f 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -18,36 +18,32 @@ class InputBuilder: """Builder class for creating input data from configuration and data sources.""" - def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) -> None: + def __init__(self, config: dict, data_sources: dict, **kwargs: Any) -> None: """Initialize the InputBuilder. Parameters ---------- config : dict Configuration dictionary. - data_sources : Union[dict, list] + data_sources : dict Data sources. **kwargs : Any Additional keyword arguments. """ self.kwargs = kwargs - - config = deepcopy(config) - if data_sources: - config = dict( - data_sources=dict( - sources=data_sources, - input=config, - ) - ) - self.config = config + self.config = deepcopy(config) + self.data_sources = deepcopy(dict(data_sources=data_sources)) @cached_property def action(self) -> Any: """Returns the action object based on the configuration.""" + from .action import Recipe from .action import action_factory - return action_factory(self.config, "input") + sources = action_factory(self.data_sources, "data_sources") + input = action_factory(self.config, "input") + + return Recipe(input, sources) def select(self, argument) -> Any: """Select data based on the group of dates. 
diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index d8120a289..fdb412458 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -198,7 +198,32 @@ def new_filter(name, mixin): ) -KLASS = {"concat": Concat, "join": Join, "pipe": Pipe} +class DataSources(Action): + def __init__(self, config, *path): + self.sources = {k: action_factory(v, *path, k) for k, v in config.items()} + + def python_code(self, code): + return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) + + +class Recipe(Action): + def __init__(self, input, data_sources): + self.input = input + self.data_sources = data_sources + + def python_code(self, code): + return code.recipe( + self.input.python_code(code), + self.data_sources.python_code(code), + ) + + +KLASS = { + "concat": Concat, + "join": Join, + "pipe": Pipe, + "data-sources": DataSources, +} LEN_KLASS = len(KLASS) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 9e66ccfdd..6924af6cc 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -8,11 +8,14 @@ # nor does it submit to any jurisdiction. 
# +import logging import re import sys from collections import defaultdict from functools import cached_property +LOG = logging.getLogger(__name__) + RESERVED_KEYWORDS = ( "and", "or", @@ -62,7 +65,7 @@ def _un_dotdict(x): class PythonCode: def __init__(self, top): - print(f"Creating {self.__class__.__name__} from {top.__class__.__name__}", file=sys.stderr) + # print(f"Creating {self.__class__.__name__} from {top.__class__.__name__}", file=sys.stderr) self.top = top self.top.register(self) self.key = str(id(self)) @@ -71,10 +74,10 @@ def call(self, name, argument): return PythonCall(self.top, name, argument) def sum(self, actions): - return PythonChain(self.top, "+", actions) + return PythonChain(self.top, "join", "+", actions) def pipe(self, actions): - return PythonChain(self.top, "|", actions) + return PythonChain(self.top, "pipe", "|", actions) def concat(self, argument): return PythonConcat(self.top, argument) @@ -85,27 +88,102 @@ def source_code(self): def combine(self, nodes): return None + def recipe(self, input, data_sources): + return PythonRecipe(self.top, input, data_sources) + def prelude(self): return None + def sources(self, sources): + return PythonSources(self.top, sources) -class Argument: - def __init__(self, name): - self.name = _sanitize_name(name) +class PythonRecipe(PythonCode): + def __init__(self, top, input, data_sources): + super().__init__(top) + self.input = input + self.data_sources = data_sources + + def apply_references(self, *path): + self.input.apply_references(*path, "input") + + def replace_node(self, old, new): + if self.input is old: + self.input = new + return + + if self.data_sources is old: + self.data_sources = new + return + + self.input.replace_node(old, new) + self.data_sources.replace_node(old, new) def __repr__(self): - return self.name + return repr(self.input) + + def prelude(self): + return self.data_sources.prelude() -class Parameter: +class Argument(PythonCode): - def __init__(self, name): + def __init__(self, 
top, name): + super().__init__(top=top) self.name = _sanitize_name(name) def __repr__(self): return self.name + def replace_node(self, old, new): + pass + + +class Anchor(PythonCode): + + def __init__(self, node): + super().__init__(top=node.top) + self.node = node + + @cached_property + def name(self): + n = self.top.counter["_anchor"] + self.top.counter["_anchor"] += 1 + return f"_a{n}" + + def __repr__(self): + return f"({self.name} := {repr(self.node)})" + + def replace_node(self, old, new): + pass + + +class Reference(PythonCode): + + def __init__(self, top, path): + super().__init__(top) + self.path = tuple(path) + self.anchor = None + + node = top.by_reference.get(self.path, None) + if node is None: + LOG.warning(f"Reference {self.path} not found") + for p in sorted(top.by_reference): + LOG.warning(f" - {p}") + else: + self.anchor = Anchor(node) + self.top.replace_nodes([(node, self.anchor)]) + + def __repr__(self): + if self.anchor is not None: + print("Reference:", self.path, "->", self.anchor.name, file=sys.stderr) + return self.anchor.name + + return f"'${{{'.'.join(self.path)}}}'" + + def replace_node(self, old, new): + pass + class Function: def __init__(self, name, node, counter): @@ -149,11 +227,38 @@ def replace_node(self, old, new): self.node = new +class PythonSources(PythonCode): + def __init__(self, top, sources): + super().__init__(top) + self.sources = sources + + def __repr__(self): + return "" + + def prelude(self): + result = [] + for k, v in self.sources.items(): + result.append(f"{k}={repr(v)}") + result.append("") + return result + + def replace_node(self, old, new): + for k, v in list(self.sources.items()): + if v is old: + self.sources[k] = new + else: + v.replace_node(old, new) + + def apply_references(self, *path): + self.top.by_reference[path + (self.name,)] = self + + class PythonScript(PythonCode): def __init__(self): self.nodes = [] self.counter = defaultdict(int) + self.by_reference = {} super().__init__(top=self) def 
register(self, child): @@ -174,6 +279,10 @@ def source_code(self, first): which = self.nodes.index(first) + first.apply_references() + # for k, v in self.by_reference.items(): + # print(f"Reference: {k} -> {v}", file=sys.stderr) + more = True while more: more = False @@ -184,8 +293,8 @@ def source_code(self, first): for (cls, key), nodes in by_class.items(): if len(nodes) > 1: - print(f"Found multiple nodes of type {cls.__name__}/{key}, merging them", file=sys.stderr) - print(f"Nodes: {len(nodes)}", file=sys.stderr) + # print(f"Found multiple nodes of type {cls.__name__}/{key}, merging them", file=sys.stderr) + # print(f"Nodes: {len(nodes)}", file=sys.stderr) changes = nodes[0].combine(nodes) if changes: self.replace_nodes(changes) @@ -234,11 +343,18 @@ def replace_node(self, old, new): else: v.replace_node(old, new) + def apply_references(self, *path): + assert "concat" not in path, path + self.top.by_reference[path + ("concat",)] = self + for i, node in enumerate(self.argument.values()): + node.apply_references(*path, "concat", str(i)) + class PythonChain(PythonCode): - def __init__(self, top, op, actions): + def __init__(self, top, kind, op, actions): super().__init__(top=top) self.op = op + self.kind = kind self.actions = list(actions) self.key = op @@ -254,6 +370,11 @@ def replace_node(self, old, new): else: node.replace_node(old, new) + def apply_references(self, *path): + self.top.by_reference[path + (self.kind,)] = self + for i, node in enumerate(self.actions): + node.apply_references(*path, self.kind, str(i)) + class PythonCall(PythonCode): def __init__(self, top, name, argument): @@ -283,6 +404,7 @@ def __repr__(self): if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k): return f"r.{name}({config})" + params.append(f"{k}={repr(v)}") if params: @@ -346,7 +468,7 @@ def _combine1(self, nodes): return rest = {k: v for k, v in node.argument.items() if k != key} - rest[key] = Argument(key) + rest[key] = Argument(self.top, key) call = PythonCall(self.top, self.name, 
rest) func = self.top.function(call) @@ -363,6 +485,14 @@ def _combine1(self, nodes): return changes + def apply_references(self, *path): + self.top.by_reference[path + (self.name,)] = self + + for k, v in self.argument.items(): + if isinstance(v, str) and (m := re.match(r"^\${(\w+(?:\.\w+)+)}$", v)): + path = m.group(1).split(".") + self.argument[k] = Reference(self.top, path) + class PythonFunction(PythonCode): def __init__(self, top, func, **kwargs): @@ -372,11 +502,6 @@ def __init__(self, top, func, **kwargs): def __repr__(self): - # if len(self.func.free_arguments()) == 0: - # a = repr(self.func.node) - # if '=' not in a: - # return a - params = [] for a in self.func.free_arguments(): name = _sanitize_name(a.name) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 69ce8b7df..24fc8caa2 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -12,6 +12,7 @@ import sys from tempfile import TemporaryDirectory +import rich import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.utils.config import DotDict @@ -32,6 +33,11 @@ def __init__(self, index): def __repr__(self): return f"Index({self.name})" + def same(self, other): + if not isinstance(other, Index): + return False + return self.name == other.name + class Step: @@ -41,6 +47,9 @@ def __or__(self, other): def __add__(self, other): return Join(self, other) + def same(self, other): + return self is other + class Chain(Step): def __init__(self, *args): @@ -55,10 +64,18 @@ def as_dict(self, recipe): return self.steps[0].as_dict(recipe) return {self.name: [s.as_dict(recipe) for s in self.steps]} - def __repr__(self): - return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" + # def __repr__(self): + # return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" def path(self, target, result, *path): + + rich.print(f"path: {target=}, {self=}, {result=}, {[s.name for s in path]}") + 
rich.print("-------------") + + if target is self: + result.append([*path, self]) + return + for i, s in enumerate(self.steps): s.path(target, result, *path, self, self.index[i]) @@ -81,6 +98,8 @@ def __init__(self, args): assert isinstance(args, dict), f"Invalid argument {args}" self.params = args + rich.print(f"Concat: {self=}") + def __setitem__(self, key, value): self.params[key] = value @@ -94,10 +113,15 @@ def as_dict(self, recipe): return {"concat": result} def collocated(self, a, b): - return a[0] is b[0] + return a[0].same(b[0]) def path(self, target, result, *path): + rich.print(f"path: {target=}, {self=}, {result=}, {path=}") + rich.print("-------------") + if target is self: + result.append([*path, self]) + return for i, (k, v) in enumerate(sorted(self.params.items())): v.path(target, result, *path, self, Index(i)) @@ -131,8 +155,8 @@ def resolve(params, recipe): return {self.owner.name: resolve(self.params, recipe)} - def __repr__(self): - return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" + # def __repr__(self): + # return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" def path(self, target, result, *path): @@ -233,7 +257,7 @@ def concat(self, *args, **kwargs): return Concat(*args, **kwargs) def resolve(self, source, target): - assert isinstance(target, Source), f"Only sources can be used as template {target}" + # assert isinstance(target, Source), f"Only sources can be used as template {target}" top = Index("input") # So we have 'input' first in the path @@ -263,6 +287,8 @@ def resolve(self, source, target): assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" + rich.print(f"Common ancestor: {common_ancestor=} {a=} {b=}") + if not common_ancestor.collocated(a, b): source = ".".join(s.name for s in path_to_source) target = ".".join(s.name for s in path_to_target) From 
39ebc13063d71c3e020e7ee061737262c1e56dca Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 19:04:12 +0200 Subject: [PATCH 35/79] update --- src/anemoi/datasets/create/python.py | 232 ++++++++++++++++----------- 1 file changed, 134 insertions(+), 98 deletions(-) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 6924af6cc..1681f885b 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -10,7 +10,6 @@ import logging import re -import sys from collections import defaultdict from functools import cached_property @@ -97,6 +96,44 @@ def prelude(self): def sources(self, sources): return PythonSources(self.top, sources) + def update_anchor(self): + pass + + +class Variable(PythonCode): + def __init__(self, name, node): + super().__init__(top=node.top) + self.name = name + self.node = node + + def __repr__(self): + + return "" + + def replace_node(self, old, new): + pass + + def prelude(self): + return [f"{self.name} = {repr(self.node)}", ""] + + +class InLine(PythonCode): + def __init__(self, node): + super().__init__(top=node.top) + self.node = node + + @cached_property + def name(self): + n = self.top.counter["_anchor"] + self.top.counter["_anchor"] += 1 + return f"_a{n}" + + def __repr__(self): + return f"({self.name} := {repr(self.node)})" + + def replace_node(self, old, new): + pass + class PythonRecipe(PythonCode): def __init__(self, top, input, data_sources): @@ -105,6 +142,7 @@ def __init__(self, top, input, data_sources): self.data_sources = data_sources def apply_references(self, *path): + self.data_sources.apply_references(*path, "data_sources") self.input.apply_references(*path, "input") def replace_node(self, old, new): @@ -141,18 +179,17 @@ def replace_node(self, old, new): class Anchor(PythonCode): - def __init__(self, node): - super().__init__(top=node.top) - self.node = node + def __init__(self, identifier): + super().__init__(top=identifier.node.top) 
+ self.identifier = identifier - @cached_property + @property def name(self): - n = self.top.counter["_anchor"] - self.top.counter["_anchor"] += 1 - return f"_a{n}" + return self.identifier.name def __repr__(self): - return f"({self.name} := {repr(self.node)})" + # assert False + return repr(self.identifier) def replace_node(self, old, new): pass @@ -165,18 +202,19 @@ def __init__(self, top, path): self.path = tuple(path) self.anchor = None - node = top.by_reference.get(self.path, None) + def update_anchor(self): + + node = self.top.by_reference.get(self.path, None) if node is None: LOG.warning(f"Reference {self.path} not found") - for p in sorted(top.by_reference): + for p in sorted(self.top.by_reference): LOG.warning(f" - {p}") else: self.anchor = Anchor(node) - self.top.replace_nodes([(node, self.anchor)]) + self.top.replace_nodes([(node.node, self.anchor)]) def __repr__(self): if self.anchor is not None: - print("Reference:", self.path, "->", self.anchor.name, file=sys.stderr) return self.anchor.name return f"'${{{'.'.join(self.path)}}}'" @@ -185,8 +223,9 @@ def replace_node(self, old, new): pass -class Function: +class Function(PythonCode): def __init__(self, name, node, counter): + super().__init__(top=node.top) self._name = name self.node = node self.used = False @@ -236,11 +275,7 @@ def __repr__(self): return "" def prelude(self): - result = [] - for k, v in self.sources.items(): - result.append(f"{k}={repr(v)}") - result.append("") - return result + pass def replace_node(self, old, new): for k, v in list(self.sources.items()): @@ -250,82 +285,8 @@ def replace_node(self, old, new): v.replace_node(old, new) def apply_references(self, *path): - self.top.by_reference[path + (self.name,)] = self - - -class PythonScript(PythonCode): - - def __init__(self): - self.nodes = [] - self.counter = defaultdict(int) - self.by_reference = {} - super().__init__(top=self) - - def register(self, child): - if child is not self: - self.nodes.append(child) - - def 
prelude(self): - result = [] - for node in self.nodes: - prelude = node.prelude() - if prelude: - if not isinstance(prelude, (list, tuple)): - prelude = list(prelude) - result.extend(prelude) - return "\n".join(result) - - def source_code(self, first): - - which = self.nodes.index(first) - - first.apply_references() - # for k, v in self.by_reference.items(): - # print(f"Reference: {k} -> {v}", file=sys.stderr) - - more = True - while more: - more = False - - by_class = defaultdict(list) - for node in self.nodes: - by_class[(node.__class__, node.key)].append(node) - - for (cls, key), nodes in by_class.items(): - if len(nodes) > 1: - # print(f"Found multiple nodes of type {cls.__name__}/{key}, merging them", file=sys.stderr) - # print(f"Nodes: {len(nodes)}", file=sys.stderr) - changes = nodes[0].combine(nodes) - if changes: - self.replace_nodes(changes) - more = True - - first = self.nodes[which] - - return "\n\n".join( - [ - "# Generated Python code for Anemoi dataset creation", - "from anemoi.datasets.recipe import Recipe", - "r = Recipe()", - self.prelude(), - f"r.input = {repr(first)}", - "r.dump()", - ] - ) - - def function(self, node): - return Function(node.name, node, self.counter) - - def replace_nodes(self, changes): - - for old, new in changes: - assert old in self.nodes, f"Node {old} not found in {self.nodes}" - for i, node in enumerate(self.nodes): - - if node is old: - self.nodes[i] = new - else: - node.replace_node(old, new) + for k, v in self.sources.items(): + self.top.by_reference[path + (k,)] = Variable(k, v) class PythonConcat(PythonCode): @@ -345,7 +306,7 @@ def replace_node(self, old, new): def apply_references(self, *path): assert "concat" not in path, path - self.top.by_reference[path + ("concat",)] = self + self.top.by_reference[path + ("concat",)] = InLine(self) for i, node in enumerate(self.argument.values()): node.apply_references(*path, "concat", str(i)) @@ -371,7 +332,7 @@ def replace_node(self, old, new): node.replace_node(old, new) def 
apply_references(self, *path): - self.top.by_reference[path + (self.kind,)] = self + self.top.by_reference[path + (self.kind,)] = InLine(self) for i, node in enumerate(self.actions): node.apply_references(*path, self.kind, str(i)) @@ -486,7 +447,7 @@ def _combine1(self, nodes): return changes def apply_references(self, *path): - self.top.by_reference[path + (self.name,)] = self + self.top.by_reference[path + (self.name,)] = InLine(self) for k, v in self.argument.items(): if isinstance(v, str) and (m := re.match(r"^\${(\w+(?:\.\w+)+)}$", v)): @@ -521,3 +482,78 @@ def prelude(self): def free_arguments(self): return [a for a in self.func.free_arguments() if a.name not in self.kwargs] + + +class PythonScript(PythonCode): + + def __init__(self): + self.nodes = [] + self.counter = defaultdict(int) + self.by_reference = {} + super().__init__(top=self) + + def register(self, child): + if child is not self: + self.nodes.append(child) + + def prelude(self): + result = [] + for node in self.nodes: + prelude = node.prelude() + if prelude: + if not isinstance(prelude, (list, tuple)): + prelude = list(prelude) + result.extend(prelude) + return "\n".join(result) + + def source_code(self, first): + + which = self.nodes.index(first) + + more = True + while more: + more = False + + by_class = defaultdict(list) + for node in self.nodes: + by_class[(node.__class__, node.key)].append(node) + + for nodes in by_class.values(): + if len(nodes) > 1: + changes = nodes[0].combine(nodes) + if changes: + self.replace_nodes(changes) + more = True + + first = self.nodes[which] + + first.apply_references() + for node in self.nodes: + node.update_anchor() + + first = self.nodes[which] + + return "\n\n".join( + [ + "# Generated Python code for Anemoi dataset creation", + "from anemoi.datasets.recipe import Recipe", + "r = Recipe()", + self.prelude(), + f"r.input = {repr(first)}", + "r.dump()", + ] + ) + + def function(self, node): + return Function(node.name, node, self.counter) + + def 
replace_nodes(self, changes): + + for old, new in changes: + assert old in self.nodes, f"Node {old} not found in {self.nodes}" + for i, node in enumerate(self.nodes): + + if node is old: + self.nodes[i] = new + else: + node.replace_node(old, new) From 5d327451ec751a59ef156b7443dc0cd500118505 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 19:08:18 +0200 Subject: [PATCH 36/79] update --- src/anemoi/datasets/create/python.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 1681f885b..d0ab72dd5 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -483,6 +483,9 @@ def prelude(self): def free_arguments(self): return [a for a in self.func.free_arguments() if a.name not in self.kwargs] + def apply_references(self, *path): + pass + class PythonScript(PythonCode): @@ -509,6 +512,9 @@ def prelude(self): def source_code(self, first): which = self.nodes.index(first) + first.apply_references() + for node in self.nodes: + node.update_anchor() more = True while more: @@ -527,12 +533,6 @@ def source_code(self, first): first = self.nodes[which] - first.apply_references() - for node in self.nodes: - node.update_anchor() - - first = self.nodes[which] - return "\n\n".join( [ "# Generated Python code for Anemoi dataset creation", From 6c1f14639e017d4c40ff0902e3ee84d87bc9d454 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 19:39:18 +0200 Subject: [PATCH 37/79] update --- src/anemoi/datasets/recipe.py | 56 ++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 24fc8caa2..e42263343 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -10,9 +10,9 @@ import logging import os import sys +from collections import defaultdict from tempfile import 
TemporaryDirectory -import rich import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.utils.config import DotDict @@ -69,9 +69,6 @@ def as_dict(self, recipe): def path(self, target, result, *path): - rich.print(f"path: {target=}, {self=}, {result=}, {[s.name for s in path]}") - rich.print("-------------") - if target is self: result.append([*path, self]) return @@ -98,8 +95,6 @@ def __init__(self, args): assert isinstance(args, dict), f"Invalid argument {args}" self.params = args - rich.print(f"Concat: {self=}") - def __setitem__(self, key, value): self.params[key] = value @@ -116,9 +111,6 @@ def collocated(self, a, b): return a[0].same(b[0]) def path(self, target, result, *path): - rich.print(f"path: {target=}, {self=}, {result=}, {path=}") - rich.print("-------------") - if target is self: result.append([*path, self]) return @@ -138,9 +130,9 @@ def __init__(self, owner, *args, **kwargs): def as_dict(self, recipe): - def resolve(params, recipe): + def resolve(params, recipe, name=None): if isinstance(params, dict): - return {k: resolve(v, recipe) for k, v in params.items()} + return {k: resolve(v, recipe, name=k) for k, v in params.items()} if isinstance(params, (list, tuple)): return [resolve(v, recipe) for v in params] @@ -149,7 +141,7 @@ def resolve(params, recipe): return [resolve(v, recipe) for v in params] if isinstance(params, Step): - return recipe.resolve(self, params) + return recipe.resolve(self, params, name=name) return params @@ -209,6 +201,9 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): self.statistics = DotDict() self.build = DotDict() + self._data_sources = {} + self._counter = defaultdict(int) + sources = source_registry.factories.copy() filters = transform_filter_registry.factories.copy() @@ -247,6 +242,9 @@ def as_dict(self): "dates": self.dates, } + if self._data_sources: + result["data_sources"] = self._data_sources + for k, v in list(result.items()): if 
v is None: del result[k] @@ -256,7 +254,26 @@ def as_dict(self): def concat(self, *args, **kwargs): return Concat(*args, **kwargs) - def resolve(self, source, target): + # def assert False, (name, target.as_dict(self)) + + def make_data_source(self, name, target): + + target = target.as_dict(self) + + name = name or "source" + if name in self._data_sources: + if self._data_sources[name] == target: + return f"${{data_sources.{name}}}" + + n = self._counter[name] + self._counter[name] += 1 + + name = f"{name}_{n}" if n > 0 else name + + self._data_sources[name] = target.copy() + return f"${{data_sources.{name}}}" + + def resolve(self, source, target, name=None): # assert isinstance(target, Source), f"Only sources can be used as template {target}" top = Index("input") # So we have 'input' first in the path @@ -271,10 +288,13 @@ def resolve(self, source, target): path_to_target = [] self.input.path(target, path_to_target, top) - if len(path_to_target) == 0: - raise ValueError(f"Target {target} not found in recipe") if len(path_to_target) > 1: raise ValueError(f"Target {target} found in multiple locations {path_to_target}") + + if len(path_to_target) == 0: + # Add a `data_sources` entry + return self.make_data_source(name, target) + path_to_target = path_to_target[0] a = [s for s in path_to_target] @@ -287,8 +307,6 @@ def resolve(self, source, target): assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" - rich.print(f"Common ancestor: {common_ancestor=} {a=} {b=}") - if not common_ancestor.collocated(a, b): source = ".".join(s.name for s in path_to_source) target = ".".join(s.name for s in path_to_target) @@ -368,9 +386,11 @@ def dates(self, value): self._dates = self._parse_dates(value) def dump(self, file=sys.stdout): + input = self.input.as_dict(self) # First so we get the data_sources + result = self.as_dict() - result["input"] = self.input.as_dict(self) + result["input"] = input if self.output: result["output"] = 
self.output.as_dict() From 37de369a91b9fd4613f730c6697a73fc85c39f1d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 20:17:15 +0200 Subject: [PATCH 38/79] update --- src/anemoi/datasets/create/python.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index d0ab72dd5..ee33398e9 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -64,7 +64,6 @@ def _un_dotdict(x): class PythonCode: def __init__(self, top): - # print(f"Creating {self.__class__.__name__} from {top.__class__.__name__}", file=sys.stderr) self.top = top self.top.register(self) self.key = str(id(self)) @@ -107,7 +106,6 @@ def __init__(self, name, node): self.node = node def __repr__(self): - return "" def replace_node(self, old, new): From 24f2c2aba8152be268ee914803eccf87e31ce4ce Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 11:13:48 +0200 Subject: [PATCH 39/79] update --- src/anemoi/datasets/create/input/action.py | 5 +---- src/anemoi/datasets/create/python.py | 9 +++------ src/anemoi/datasets/dates/__init__.py | 4 ++-- src/anemoi/datasets/recipe.py | 8 +++++++- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index fdb412458..d173f0996 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -56,10 +56,7 @@ def __call__(self, context, argument): def python_code(self, code): return code.concat( - { - filtering_dates.to_python(just_dates=True): action.python_code(code) - for filtering_dates, action in self.choices - } + {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} ) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index ee33398e9..2a47a5b67 100644 --- a/src/anemoi/datasets/create/python.py +++ 
b/src/anemoi/datasets/create/python.py @@ -356,13 +356,10 @@ def __repr__(self): params = [] for k, v in config.items(): - if isinstance(k, str): + k = _sanitize_name(k) - if k in RESERVED_KEYWORDS: - k = f"{k}_" - - if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k): - return f"r.{name}({config})" + if not k.isidentifier(): + return f"r.{name}({config})" params.append(f"{k}={repr(v)}") diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 0ae8105e8..5e6c1863a 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -285,9 +285,9 @@ def as_dict(self) -> Dict[str, Any]: "frequency": frequency_to_string(self.frequency), }.update(self.kwargs) - def to_python(self, just_dates=False) -> str: + def to_python(self) -> str: """Convert the StartEndDates instance to a tuple of ISO-formatted date strings.""" - if just_dates: + if self.frequency == datetime.timedelta(hours=1): return (self.start.isoformat(), self.end.isoformat()) else: return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index e42263343..455cf689a 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -132,7 +132,13 @@ def as_dict(self, recipe): def resolve(params, recipe, name=None): if isinstance(params, dict): - return {k: resolve(v, recipe, name=k) for k, v in params.items()} + + def _(k): + if k.endswith("_"): + return k[:-1] + return k + + return {_(k): resolve(v, recipe, name=_(k)) for k, v in params.items()} if isinstance(params, (list, tuple)): return [resolve(v, recipe) for v in params] From db4d8956c121fdc47862d9f711aa57b212e568cc Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 12:27:44 +0200 Subject: [PATCH 40/79] add dumper --- src/anemoi/datasets/dumper.py | 71 +++++++++++++++++++++++++++++++++++ src/anemoi/datasets/recipe.py | 12 ++++-- 2 files changed, 80 
insertions(+), 3 deletions(-) create mode 100644 src/anemoi/datasets/dumper.py diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py new file mode 100644 index 000000000..052750d56 --- /dev/null +++ b/src/anemoi/datasets/dumper.py @@ -0,0 +1,71 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging + +import yaml + +LOG = logging.getLogger(__name__) + + +class MyDumper(yaml.SafeDumper): + pass + + +def represent_date(dumper, data): + if data.tzinfo is None: + data = data.replace(tzinfo=datetime.timezone.utc) + data = data.astimezone(datetime.timezone.utc) + iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) + + +# --- Represent multiline strings with | style --- +def represent_multiline_str(dumper, data): + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + +# --- Represent short lists inline (flow style) --- +def represent_inline_list(dumper, data): + + if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data): + return dumper.represent_sequence("tag:yaml.org,2002:seq", data) + + elems = [yaml.dump(i, explicit_start=False, explicit_end=False).replace("\n...\n", "") for i in data] + lines = [] + line = [] + for e in elems: + if sum(len(x) for x in line) + len(e) + 2 * (len(line) + 1) <= 80: + line.append(e) + else: + lines.append(line) + line = [e] + + if line: + lines.append(line) + + if len(lines) == 1: + return 
dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + + block_lines = ["- [" + ", ".join(line) + "]" for line in lines] + return dumper.represent_scalar("tag:yaml.org,2002:str", "\n".join(block_lines), style="|") + + +# Register representers +MyDumper.add_representer(datetime.date, represent_date) +MyDumper.add_representer(datetime.datetime, represent_date) +MyDumper.add_representer(str, represent_multiline_str) +MyDumper.add_representer(list, represent_inline_list) + + +def yaml_dump(obj, **kwargs): + return yaml.dump(obj, Dumper=MyDumper, **kwargs) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 455cf689a..51e83b693 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -13,7 +13,6 @@ from collections import defaultdict from tempfile import TemporaryDirectory -import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.utils.config import DotDict from anemoi.utils.dates import as_datetime @@ -103,7 +102,12 @@ def as_dict(self, recipe): result = [] for k, v in sorted(self.params.items()): - result.append({"dates": dict(start=k[0], end=k[1]), **v.as_dict(recipe)}) + + key = dict(start=as_datetime(k[0]), end=as_datetime(k[1])) + if len(k) == 3: + key["frequency"] = k[2] + + result.append({"dates": key, **v.as_dict(recipe)}) return {"concat": result} @@ -407,7 +411,9 @@ def dump(self, file=sys.stdout): if self.build: result["build"] = self.build.as_dict() - yaml.safe_dump(result, sort_keys=False, indent=2, width=120, stream=file) + from .dumper import yaml_dump + + yaml_dump(result, sort_keys=False, indent=2, width=120, stream=file) def test(self, output="recipe.zarr"): from argparse import ArgumentParser From 55f740d34db13f16553c2e2ea67e8eae50f24753 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 14:04:42 +0200 Subject: [PATCH 41/79] update --- src/anemoi/datasets/commands/format.py | 87 ++++++++++++++++++++++++++ 
src/anemoi/datasets/dumper.py | 16 ++++- 2 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 src/anemoi/datasets/commands/format.py diff --git a/src/anemoi/datasets/commands/format.py b/src/anemoi/datasets/commands/format.py new file mode 100644 index 000000000..42d3aa1f9 --- /dev/null +++ b/src/anemoi/datasets/commands/format.py @@ -0,0 +1,87 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import logging +import sys +from typing import Any + +import yaml + +from ..dumper import yaml_dump +from . import Command + +LOG = logging.getLogger(__name__) + + +def make_dates(config): + if isinstance(config, dict): + return {k: make_dates(v) for k, v in config.items()} + if isinstance(config, list): + return [make_dates(v) for v in config] + if isinstance(config, str): + try: + return datetime.datetime.fromisoformat(config) + except ValueError: + return config + return config + + +ORDER = ( + "name", + "description", + "dataset_status", + "licence", + "attribution", + "env", + "dates", + "common", + "data_sources", + "input", + "output", + "statistics", + "build", + "platform", +) + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. 
+ """ + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + + with open(args.path, "r") as file: + config = yaml.safe_load(file) + + config = make_dates(config) + + text = yaml_dump(config, sort_keys=False, indent=2, width=120, order=ORDER) + # with open(args.path + ".tmp", "w") as f: + f = sys.stdout + for i, line in enumerate(text.splitlines()): + if i and line and line[0] not in (" ", "-"): + line = "\n" + line + print(line, file=f) + + # os.rename(args.path + ".tmp", args.path) + + +command = Recipe diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py index 052750d56..de41b0710 100644 --- a/src/anemoi/datasets/dumper.py +++ b/src/anemoi/datasets/dumper.py @@ -68,4 +68,18 @@ def represent_inline_list(dumper, data): def yaml_dump(obj, **kwargs): - return yaml.dump(obj, Dumper=MyDumper, **kwargs) + + kwargs.setdefault("Dumper", MyDumper) + kwargs.setdefault("sort_keys", False) + kwargs.setdefault("indent", 2) + kwargs.setdefault("width", 120) + + order = kwargs.pop("order", None) + if order: + + def _ordering(k): + return order.index(k) if k in order else len(order) + + obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} + + return yaml.dump(obj, **kwargs) From 014dbbc90948874f1ccf7871d0208441b0a06fa7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 16:23:48 +0200 Subject: [PATCH 42/79] update --- src/anemoi/datasets/create/input/action.py | 9 +++++++++ src/anemoi/datasets/dumper.py | 3 +-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index d173f0996..3085bc60d 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -202,6 +202,10 @@ def __init__(self, config, *path): def python_code(self, code): return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) + def 
__call__(self, context, argument): + for source in self.sources.values(): + source(context, argument) + class Recipe(Action): def __init__(self, input, data_sources): @@ -214,6 +218,11 @@ def python_code(self, code): self.data_sources.python_code(code), ) + def __call__(self, context, argument): + # Load data_sources + self.data_sources(context, argument) + return self.input(context, argument) + KLASS = { "concat": Concat, diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py index de41b0710..7e3867969 100644 --- a/src/anemoi/datasets/dumper.py +++ b/src/anemoi/datasets/dumper.py @@ -67,14 +67,13 @@ def represent_inline_list(dumper, data): MyDumper.add_representer(list, represent_inline_list) -def yaml_dump(obj, **kwargs): +def yaml_dump(obj, order=None, **kwargs): kwargs.setdefault("Dumper", MyDumper) kwargs.setdefault("sort_keys", False) kwargs.setdefault("indent", 2) kwargs.setdefault("width", 120) - order = kwargs.pop("order", None) if order: def _ordering(k): From 8ad939661a73ee81c9ac7c0a52d0c0711aacfe88 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 15:31:27 +0000 Subject: [PATCH 43/79] bug fix in path --- src/anemoi/datasets/create/input/action.py | 23 +++++++++++-------- .../datasets/create/input/context/__init__.py | 5 ++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 3085bc60d..2c157e3f9 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -20,23 +20,27 @@ class Action: def __init__(self, config, *path): self.config = config self.path = path + assert path[0] in ( + "input", + "data_sources", + ), f"{self.__class__.__name__}. 
Path must start with 'input' or 'data_sources': {path}" # rich.print(f"Creating {self.__class__.__name__} {'.'.join(x for x in self.path)} from {config}") class Concat(Action): def __init__(self, config, *path): - super().__init__(config, *path) + super().__init__(config, *path, "concat") assert isinstance(config, list), f"Value must be a dict {list}" self.choices = [] - for item in config: + for i, item in enumerate(config): assert "dates" in item, f"Value must contain the key 'dates' {item}" dates = item["dates"] filtering_dates = DatesProvider.from_config(**dates) - action = action_factory({k: v for k, v in item.items() if k != "dates"}) + action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i)) self.choices.append((filtering_dates, action)) def __repr__(self): @@ -62,11 +66,11 @@ def python_code(self, code): class Join(Action): def __init__(self, config, *path): - super().__init__(config, *path) + super().__init__(config, *path, "join") assert isinstance(config, list), f"Value must be a list {config}" - self.actions = [action_factory(item, *path, "join", str(i)) for i, item in enumerate(config)] + self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] def __repr__(self): return f"Join({self.actions})" @@ -86,8 +90,8 @@ def python_code(self, code) -> None: class Pipe(Action): def __init__(self, config, *path): assert isinstance(config, list), f"Value must be a list {config}" - super().__init__(config, *path) - self.actions = [action_factory(item, *path, "pipe", str(i)) for i, item in enumerate(config)] + super().__init__(config, *path, "pipe") + self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] def __repr__(self): return f"Pipe({self.actions})" @@ -197,14 +201,15 @@ def new_filter(name, mixin): class DataSources(Action): def __init__(self, config, *path): + super().__init__(config, *path) self.sources = {k: action_factory(v, *path, k) for k, v in 
config.items()} def python_code(self, code): return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) def __call__(self, context, argument): - for source in self.sources.values(): - source(context, argument) + for name, source in self.sources.items(): + context.register(source(context, argument), self.path + (name,)) class Recipe(Action): diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 26d449659..81ccbd593 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -34,6 +34,8 @@ def register(self, data: Any, path: list[str]) -> Any: if not path: return data + assert path[0] in ("input", "data_sources"), path + rich.print(f"Registering data at path: {path}") self.results[tuple(path)] = data return data @@ -47,6 +49,9 @@ def resolve(self, config): if path in self.results: config[key] = self.results[path] else: + rich.print(f"Path not found {path}") + for p in sorted(self.results): + rich.print(f" Available paths: {p}") raise KeyError(f"Path {path} not found in results: {self.results.keys()}") return config From 97566189a1930011a365fb07166f189450e5e0f6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 18:16:31 +0200 Subject: [PATCH 44/79] join recipe command --- src/anemoi/datasets/commands/recipe.py | 50 ---------- .../datasets/commands/recipe/__init__.py | 93 +++++++++++++++++++ .../datasets/commands/{ => recipe}/format.py | 48 +++------- .../datasets/commands/{ => recipe}/migrate.py | 61 +++--------- src/anemoi/datasets/create/input/action.py | 13 ++- 5 files changed, 126 insertions(+), 139 deletions(-) delete mode 100644 src/anemoi/datasets/commands/recipe.py create mode 100644 src/anemoi/datasets/commands/recipe/__init__.py rename src/anemoi/datasets/commands/{ => recipe}/format.py (51%) rename src/anemoi/datasets/commands/{ => recipe}/migrate.py (90%) diff --git 
a/src/anemoi/datasets/commands/recipe.py b/src/anemoi/datasets/commands/recipe.py deleted file mode 100644 index f111aee1b..000000000 --- a/src/anemoi/datasets/commands/recipe.py +++ /dev/null @@ -1,50 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -from typing import Any - -import yaml - -from . import Command -from .migrate import migrate - -LOG = logging.getLogger(__name__) - - -class Recipe(Command): - def add_arguments(self, command_parser: Any) -> None: - """Add arguments to the command parser. - - Parameters - ---------- - command_parser : Any - Command parser object. - """ - - command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") - - command_parser.add_argument( - "path", - help="Path to recipe.", - ) - - def run(self, args: Any) -> None: - from anemoi.datasets.create import config_to_python - - with open(args.path, "r") as file: - config = yaml.safe_load(file) - if args.migrate: - config = migrate(config) - - print(config_to_python(config)) - - -command = Recipe diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py new file mode 100644 index 000000000..b45bc34d9 --- /dev/null +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -0,0 +1,93 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import argparse +import logging +import sys +from typing import Any + +import yaml + +from anemoi.datasets.create import config_to_python + +from .. import Command +from .format import format_recipe +from .migrate import migrate_recipe + +LOG = logging.getLogger(__name__) + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. + """ + + command_parser.add_argument("--format", action="store_true", help="Format the recipe.") + command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") + command_parser.add_argument("--python", action="store_true", help="Convert the recipe to a Python script.") + + group = command_parser.add_mutually_exclusive_group() + group.add_argument("--inplace", action="store_true", help="Overwrite the recipe file in place.") + group.add_argument("--output", type=str, help="Output file path for the converted recipe.") + + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + + if not args.format and not args.migrate and not args.python: + args.format = True + + with open(args.path, "r") as file: + config = yaml.safe_load(file) + + assert isinstance(config, dict) + + if args.migrate: + config = migrate_recipe(args, config) + if config is None: + LOG.info(f"{args.path}: No changes needed.") + return + + args.format = True + + if args.format: + formatted = format_recipe(args, config) + f = sys.stdout + if args.output: + f = open(args.output, "w") + + if args.inplace: + f = open(args.path, "w") + + print(formatted, file=f) + + if args.python: + if args.inplace: + argparse.ArgumentError(None, 
"Inplace conversion to Python is not supported.") + + if args.format: + raise argparse.ArgumentError(None, "Formatting is not supported when converting to Python.") + + if args.output: + with open(args.output, "w") as file: + file.write(config_to_python(config)) + else: + print(config_to_python(config)) + + +command = Recipe diff --git a/src/anemoi/datasets/commands/format.py b/src/anemoi/datasets/commands/recipe/format.py similarity index 51% rename from src/anemoi/datasets/commands/format.py rename to src/anemoi/datasets/commands/recipe/format.py index 42d3aa1f9..f221e0b8d 100644 --- a/src/anemoi/datasets/commands/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -9,14 +9,10 @@ import datetime +import io import logging -import sys -from typing import Any -import yaml - -from ..dumper import yaml_dump -from . import Command +from ...dumper import yaml_dump LOG = logging.getLogger(__name__) @@ -52,36 +48,16 @@ def make_dates(config): ) -class Recipe(Command): - def add_arguments(self, command_parser: Any) -> None: - """Add arguments to the command parser. - - Parameters - ---------- - command_parser : Any - Command parser object. 
- """ - command_parser.add_argument( - "path", - help="Path to recipe.", - ) - - def run(self, args: Any) -> None: - - with open(args.path, "r") as file: - config = yaml.safe_load(file) - - config = make_dates(config) - - text = yaml_dump(config, sort_keys=False, indent=2, width=120, order=ORDER) - # with open(args.path + ".tmp", "w") as f: - f = sys.stdout - for i, line in enumerate(text.splitlines()): - if i and line and line[0] not in (" ", "-"): - line = "\n" + line - print(line, file=f) +def format_recipe(args, config: dict) -> str: - # os.rename(args.path + ".tmp", args.path) + config = make_dates(config) + assert config + text = yaml_dump(config, order=ORDER) + f = io.StringIO() + for i, line in enumerate(text.splitlines()): + if i and line and line[0] not in (" ", "-"): + line = "\n" + line + print(line, file=f) -command = Recipe + return f.getvalue() diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py similarity index 90% rename from src/anemoi/datasets/commands/migrate.py rename to src/anemoi/datasets/commands/recipe/migrate.py index fa74886b8..daaa3bf16 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -10,7 +10,6 @@ import datetime import logging -import os from collections.abc import Sequence from typing import Any @@ -22,8 +21,6 @@ from anemoi.datasets.create import validate_config -from . import Command - LOG = logging.getLogger(__name__) @@ -540,54 +537,22 @@ def check(config): raise -class Recipe(Command): - def add_arguments(self, command_parser: Any) -> None: - """Add arguments to the command parser. - - Parameters - ---------- - command_parser : Any - Command parser object. 
- """ - command_parser.add_argument( - "path", - help="Path to recipe.", - ) - - def run(self, args: Any) -> None: - - rich.print(f"Migrating {args.path}") - - with open(args.path, "r") as file: - config = yaml.safe_load(file) - - try: - validate_config(config) - LOG.info(f"{args.path}: Validation successful.") - return - except Exception: - pass +def migrate_recipe(args: Any, config) -> None: - migrated = migrate(config) + rich.print(f"Migrating {args.path}") - migrated = {k: v for k, v in sorted(migrated.items(), key=order) if v} - - check(migrated) - if migrated == config: - LOG.info(f"{args.path}: No changes needed.") - return - - migrated = make_dates(migrated) - text = yaml.dump(migrated, default_flow_style=False, sort_keys=False, indent=2, width=120, Dumper=MyDumper) + try: + validate_config(config) + LOG.info(f"{args.path}: Validation successful.") + except Exception: + pass - LOG.info(f"{args.path}: updating.") - with open(args.path + ".tmp", "w") as f: - for i, line in enumerate(text.splitlines()): - if i and line and line[0] not in (" ", "-"): - line = "\n" + line - print(line, file=f) + migrated = migrate(config) - os.rename(args.path + ".tmp", args.path) + migrated = {k: v for k, v in sorted(migrated.items(), key=order) if v} + check(migrated) + if migrated == config: + return None -command = Recipe + return migrated diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 2c157e3f9..1e3c8175d 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -23,8 +23,7 @@ def __init__(self, config, *path): assert path[0] in ( "input", "data_sources", - ), f"{self.__class__.__name__}. 
Path must start with 'input' or 'data_sources': {path}" - # rich.print(f"Creating {self.__class__.__name__} {'.'.join(x for x in self.path)} from {config}") + ), f"{self.__class__.__name__}: path must start with 'input' or 'data_sources': {path}" class Concat(Action): @@ -130,7 +129,7 @@ def __call__(self, context, argument): def python_code(self, code) -> str: # For now... if "source" in self.config: - source = action_factory(self.config["source"]) + source = action_factory(self.config["source"], *self.path, "source") self.config["source"] = source.python_code(code) return code.call(self.name, self.config) @@ -239,7 +238,7 @@ def __call__(self, context, argument): LEN_KLASS = len(KLASS) -def make(key, config, path): +def make(key, config, *path): if LEN_KLASS == len(KLASS): @@ -272,8 +271,12 @@ def make(key, config, path): def action_factory(data, *path): + + assert len(path) > 0, f"Path must contain at least one element {path}" + assert path[0] in ("input", "data_sources") + assert isinstance(data, dict), f"Input data must be a dictionary {data}" assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}" key, value = next(iter(data.items())) - return make(key, value, path) + return make(key, value, *path) From a493a961c2a6ddd008688db59164ccc3cd6ab597 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 18:23:19 +0200 Subject: [PATCH 45/79] join recipe command --- .../datasets/commands/recipe/__init__.py | 17 ++++++- src/anemoi/datasets/commands/validate.py | 44 ------------------- 2 files changed, 15 insertions(+), 46 deletions(-) delete mode 100644 src/anemoi/datasets/commands/validate.py diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index b45bc34d9..546486f56 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -16,6 +16,7 @@ import yaml from anemoi.datasets.create 
import config_to_python +from anemoi.datasets.create import validate_config from .. import Command from .format import format_recipe @@ -34,6 +35,7 @@ def add_arguments(self, command_parser: Any) -> None: Command parser object. """ + command_parser.add_argument("--validate", action="store_true", help="Validate recipe.") command_parser.add_argument("--format", action="store_true", help="Format the recipe.") command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") command_parser.add_argument("--python", action="store_true", help="Convert the recipe to a Python script.") @@ -49,14 +51,25 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: - if not args.format and not args.migrate and not args.python: - args.format = True + if not args.validate and not args.format and not args.migrate and not args.python: + args.validate = True with open(args.path, "r") as file: config = yaml.safe_load(file) assert isinstance(config, dict) + if args.validate: + if args.inplace and (not args.format and not args.migrate and not args.python): + argparse.ArgumentError(None, "--inplace is not supported with --validate.") + + if args.output and (not args.format and not args.migrate and not args.python): + argparse.ArgumentError(None, "--output is not supported with --validate.") + + validate_config(config) + LOG.info(f"{args.path}: Recipe is valid.") + return + if args.migrate: config = migrate_recipe(args, config) if config is None: diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py deleted file mode 100644 index 84b25c6f8..000000000 --- a/src/anemoi/datasets/commands/validate.py +++ /dev/null @@ -1,44 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
-# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -from typing import Any - -import yaml - -from . import Command - -LOG = logging.getLogger(__name__) - - -class Validate(Command): - def add_arguments(self, command_parser: Any) -> None: - """Add arguments to the command parser. - - Parameters - ---------- - command_parser : Any - Command parser object. - """ - command_parser.add_argument( - "path", - help="Path to recipe.", - ) - - def run(self, args: Any) -> None: - from anemoi.datasets.create import validate_config - - with open(args.path, "r") as file: - config = yaml.safe_load(file) - - validate_config(config) - - -command = Validate From 1cde9f8253332b532c767ea129d51bee9b44bc3a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 18:33:46 +0200 Subject: [PATCH 46/79] use ampersand --- src/anemoi/datasets/create/python.py | 2 +- src/anemoi/datasets/recipe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 2a47a5b67..07c521824 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -72,7 +72,7 @@ def call(self, name, argument): return PythonCall(self.top, name, argument) def sum(self, actions): - return PythonChain(self.top, "join", "+", actions) + return PythonChain(self.top, "join", "&", actions) def pipe(self, actions): return PythonChain(self.top, "pipe", "|", actions) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 51e83b693..20199beae 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -43,7 +43,7 @@ class Step: def __or__(self, other): return Pipe(self, other) - def __add__(self, other): + def __and__(self, other): return Join(self, other) def same(self, other): 
From cdb1a9a08fff142fb6c37c39da6227e98eec2972 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 18:45:43 +0200 Subject: [PATCH 47/79] use ampersand --- src/anemoi/datasets/recipe.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 20199beae..9b8516bf9 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -63,9 +63,6 @@ def as_dict(self, recipe): return self.steps[0].as_dict(recipe) return {self.name: [s.as_dict(recipe) for s in self.steps]} - # def __repr__(self): - # return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" - def path(self, target, result, *path): if target is self: @@ -157,11 +154,7 @@ def _(k): return {self.owner.name: resolve(self.params, recipe)} - # def __repr__(self): - # return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" - def path(self, target, result, *path): - if self is target: result.append([*path, self]) From e69eb1071e38b0ec3f6ccfd45abe146b31a025cd Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 08:54:47 +0200 Subject: [PATCH 48/79] add settings --- src/anemoi/datasets/create/__init__.py | 4 +++- src/anemoi/datasets/create/python.py | 26 +++++++++++++++++++++++--- src/anemoi/datasets/recipe.py | 3 +++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index fbe4ab024..4d270f72d 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1669,13 +1669,15 @@ def config_to_python(config: Any) -> Any: from ..create.python import PythonScript + raw_config = config + config = loader_config(config) input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) code = PythonScript() x = input.python_code(code) - code = code.source_code(x) + code = code.source_code(x, 
raw_config) try: import black diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 07c521824..9b5384b1c 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -494,8 +494,27 @@ def register(self, child): if child is not self: self.nodes.append(child) - def prelude(self): + def prelude(self, config): + + from anemoi.datasets.recipe import Recipe + + SKIP = ( + "input", + "data_sources", + ) + result = [] + + for k, v in config.items(): + + if k in SKIP: + continue + + if not hasattr(Recipe, k): + continue + + result.append(f"r.{k} = {repr(v)}") + for node in self.nodes: prelude = node.prelude() if prelude: @@ -504,7 +523,7 @@ def prelude(self): result.extend(prelude) return "\n".join(result) - def source_code(self, first): + def source_code(self, first, config): which = self.nodes.index(first) first.apply_references() @@ -531,9 +550,10 @@ def source_code(self, first): return "\n\n".join( [ "# Generated Python code for Anemoi dataset creation", + "import datetime", "from anemoi.datasets.recipe import Recipe", "r = Recipe()", - self.prelude(), + self.prelude(config), f"r.input = {repr(first)}", "r.dump()", ] diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 9b8516bf9..394304e9a 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -358,6 +358,9 @@ def dates(self): def _parse_dates(self, value): + if isinstance(value, dict): + return value + start = None end = None frequency = 1 From 92165b466df55baeb6b64e3c6ea2380ad22e9ac6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 08:56:51 +0200 Subject: [PATCH 49/79] add settings --- src/anemoi/datasets/create/python.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 9b5384b1c..d5accc2f4 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ 
-511,6 +511,7 @@ def prelude(self, config): continue if not hasattr(Recipe, k): + LOG.warning(f"Unknown key in recipe: {k}") continue result.append(f"r.{k} = {repr(v)}") From a044e141d5c40f3caf4211cc145188857a3fa83a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 16:43:30 +0000 Subject: [PATCH 50/79] udpate --- src/anemoi/datasets/recipe.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 394304e9a..08b51b981 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -198,6 +198,8 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): self._licence = licence self._name = name self._dates = None + self._statistics = None + self._build = None self.input = Join() self.output = DotDict() @@ -243,6 +245,8 @@ def as_dict(self): "attribution": self.attribution, "licence": self.licence, "dates": self.dates, + "statistics": self.statistics, + "build": self.build, } if self._data_sources: @@ -391,6 +395,22 @@ def _parse_dates(self, value): def dates(self, value): self._dates = self._parse_dates(value) + @property + def statistics(self): + return self._statistics + + @statistics.setter + def statistics(self, value): + self._statistics = value + + @property + def build(self): + return self._build + + @build.setter + def build(self, value): + self._build = value + def dump(self, file=sys.stdout): input = self.input.as_dict(self) # First so we get the data_sources @@ -402,10 +422,10 @@ def dump(self, file=sys.stdout): result["output"] = self.output.as_dict() if self.statistics: - result["statistics"] = self.statistics.as_dict() + result["statistics"] = self.statistics if self.build: - result["build"] = self.build.as_dict() + result["build"] = self.build from .dumper import yaml_dump From cb9c5761e92bd3b468aec3779b4d51ca434b2897 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 
18:46:07 +0200 Subject: [PATCH 51/79] use ruamel --- pyproject.toml | 1 + src/anemoi/datasets/dumper.py | 58 +++++++++++++---------------------- src/anemoi/datasets/recipe.py | 2 +- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 873415a24..c10a858b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "numcodecs<0.16", # Until we move to zarr3 "numpy", "pyyaml", + "ruamel-yaml", "semantic-version", "tqdm", "zarr<=2.18.4", diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py index 7e3867969..69d9f4140 100644 --- a/src/anemoi/datasets/dumper.py +++ b/src/anemoi/datasets/dumper.py @@ -10,15 +10,11 @@ import datetime import logging -import yaml +import ruamel.yaml LOG = logging.getLogger(__name__) -class MyDumper(yaml.SafeDumper): - pass - - def represent_date(dumper, data): if data.tzinfo is None: data = data.replace(tzinfo=datetime.timezone.utc) @@ -30,7 +26,7 @@ def represent_date(dumper, data): # --- Represent multiline strings with | style --- def represent_multiline_str(dumper, data): if "\n" in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data.strip(), style="|") return dumper.represent_scalar("tag:yaml.org,2002:str", data) @@ -40,39 +36,15 @@ def represent_inline_list(dumper, data): if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data): return dumper.represent_sequence("tag:yaml.org,2002:seq", data) - elems = [yaml.dump(i, explicit_start=False, explicit_end=False).replace("\n...\n", "") for i in data] - lines = [] - line = [] - for e in elems: - if sum(len(x) for x in line) + len(e) + 2 * (len(line) + 1) <= 80: - line.append(e) - else: - lines.append(line) - line = [e] - - if line: - lines.append(line) - - if len(lines) == 1: - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) - - block_lines = ["- [" + ", 
".join(line) + "]" for line in lines] - return dumper.represent_scalar("tag:yaml.org,2002:str", "\n".join(block_lines), style="|") - - -# Register representers -MyDumper.add_representer(datetime.date, represent_date) -MyDumper.add_representer(datetime.datetime, represent_date) -MyDumper.add_representer(str, represent_multiline_str) -MyDumper.add_representer(list, represent_inline_list) + return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) def yaml_dump(obj, order=None, **kwargs): - kwargs.setdefault("Dumper", MyDumper) - kwargs.setdefault("sort_keys", False) - kwargs.setdefault("indent", 2) - kwargs.setdefault("width", 120) + # kwargs.setdefault("Dumper", MyDumper) + # kwargs.setdefault("sort_keys", False) + # kwargs.setdefault("indent", 2) + # kwargs.setdefault("width", 120) if order: @@ -81,4 +53,18 @@ def _ordering(k): obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} - return yaml.dump(obj, **kwargs) + # yaml = yaml.YAML(typ='unsafe', pure=True) + yaml = ruamel.yaml.YAML() + yaml.width = 120 # wrap long flow sequences + # yaml.default_flow_style = True + yaml.Representer.add_representer(datetime.date, represent_date) + yaml.Representer.add_representer(datetime.datetime, represent_date) + yaml.Representer.add_representer(str, represent_multiline_str) + yaml.Representer.add_representer(list, represent_inline_list) + + data = ruamel.yaml.comments.CommentedMap() + for k, v in obj.items(): + data[k] = v + data.yaml_set_comment_before_after_key(key=k, before="\n") + + return yaml.dump(data, **kwargs) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 08b51b981..f1bf60ddc 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -429,7 +429,7 @@ def dump(self, file=sys.stdout): from .dumper import yaml_dump - yaml_dump(result, sort_keys=False, indent=2, width=120, stream=file) + yaml_dump(result, stream=file) def test(self, output="recipe.zarr"): from 
argparse import ArgumentParser From 99a5fb781d5682edbacfbb5a268409fec42eb8a7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 17:11:02 +0000 Subject: [PATCH 52/79] fix source as parameters --- src/anemoi/datasets/create/input/action.py | 2 +- src/anemoi/datasets/create/input/context/__init__.py | 9 +++++++-- src/anemoi/datasets/create/sources/repeated_dates.py | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 1e3c8175d..db9d8dace 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -275,7 +275,7 @@ def action_factory(data, *path): assert len(path) > 0, f"Path must contain at least one element {path}" assert path[0] in ("input", "data_sources") - assert isinstance(data, dict), f"Input data must be a dictionary {data}" + assert isinstance(data, dict), f"Input data must be a dictionary, got {type(data)}" assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}" key, value = next(iter(data.items())) diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 81ccbd593..738f6a85b 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -56,10 +56,15 @@ def resolve(self, config): return config - def create_source(self, config: Any) -> Any: + def create_source(self, config: Any, *path) -> Any: from anemoi.datasets.create.input.action import action_factory - return action_factory(config) + if not isinstance(config, dict): + # It is already a result (e.g. ekd.FieldList), loaded from ${a.b.c} + # TODO: something more elegant + return lambda *args, **kwargs: config + + return action_factory(config, *path) @abstractmethod def empty_result(self) -> Any: ... 
diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index d092f08ad..eb235cd99 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -302,11 +302,12 @@ def transform(self, group_of_dates: Any) -> Tuple[Any, Any]: class RepeatedDatesSource(Source): def __init__(self, owner, source: Any, mode: str, **kwargs) -> None: + self.owner = owner self.mapper = DateMapper.from_mode(mode, source, kwargs) self.source = source def execute(self, context, group_of_dates): - source = context.create_source(self.source) + source = context.create_source(self.source, *self.owner.path, "source") result = [] for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): From ce027f434b7ff598e2df7b49ea6a961460dc295b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 18 Aug 2025 18:50:34 +0200 Subject: [PATCH 53/79] update --- .../datasets/commands/recipe/__init__.py | 2 + src/anemoi/datasets/commands/recipe/format.py | 10 +- .../datasets/commands/recipe/migrate.py | 240 +++++++++--------- src/anemoi/datasets/create/config.py | 2 +- src/anemoi/datasets/create/python.py | 3 + src/anemoi/datasets/dumper.py | 36 +-- src/anemoi/datasets/recipe.py | 64 ++++- 7 files changed, 206 insertions(+), 151 deletions(-) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 546486f56..71b116213 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -80,6 +80,7 @@ def run(self, args: Any) -> None: if args.format: formatted = format_recipe(args, config) + assert "dates" in formatted f = sys.stdout if args.output: f = open(args.output, "w") @@ -88,6 +89,7 @@ def run(self, args: Any) -> None: f = open(args.path, "w") print(formatted, file=f) + f.close() if args.python: if args.inplace: diff --git 
a/src/anemoi/datasets/commands/recipe/format.py b/src/anemoi/datasets/commands/recipe/format.py index f221e0b8d..533a569c1 100644 --- a/src/anemoi/datasets/commands/recipe/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -9,7 +9,6 @@ import datetime -import io import logging from ...dumper import yaml_dump @@ -53,11 +52,4 @@ def format_recipe(args, config: dict) -> str: config = make_dates(config) assert config - text = yaml_dump(config, order=ORDER) - f = io.StringIO() - for i, line in enumerate(text.splitlines()): - if i and line and line[0] not in (" ", "-"): - line = "\n" + line - print(line, file=f) - - return f.getvalue() + return yaml_dump(config, order=ORDER) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index daaa3bf16..31c67aa42 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -8,26 +8,22 @@ # nor does it submit to any jurisdiction. -import datetime import logging +import sys from collections.abc import Sequence from typing import Any import rich -import yaml from glom import assign from glom import delete from glom import glom from anemoi.datasets.create import validate_config +from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) -class MyDumper(yaml.SafeDumper): - pass - - def find_paths(data, target_key=None, target_value=None, *path): matches = [] @@ -58,85 +54,23 @@ def find_chevrons(data, *path): return matches -# Custom representer for datetime.date and datetime.datetime -def represent_date(dumper, data): - if isinstance(data, datetime.date) and not isinstance(data, datetime.datetime): - data = datetime.datetime(data.year, data.month, data.day, 0, 0, 0) - # Ensure it's UTC - if data.tzinfo is None: - data = data.replace(tzinfo=datetime.timezone.utc) - data = data.astimezone(datetime.timezone.utc) - # Format as ISO 8601 with 'Z' - iso_str = 
data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" - return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) - - -# Custom representer for multiline strings using the '|' block style -def represent_multiline_str(dumper, data): - if "\n" in data: - text_list = [line.rstrip() for line in data.splitlines()] - fixed_data = "\n".join(text_list) - return dumper.represent_scalar("tag:yaml.org,2002:str", fixed_data, style="|") - return dumper.represent_scalar("tag:yaml.org,2002:str", data) - - -# --- Represent short lists inline (flow style) --- -def represent_inline_list(dumper, data): - # Flow style if list has <= 4 simple elements - if ( - all(isinstance(i, (str, int, float, bool, type(None))) for i in data) - and len(", ".join([str(x) for x in data])) + 2 <= 80 - ): - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) - return dumper.represent_sequence("tag:yaml.org,2002:seq", data) - - -# Register custom representers -MyDumper.add_representer(datetime.date, represent_date) -MyDumper.add_representer(datetime.datetime, represent_date) -MyDumper.add_representer(str, represent_multiline_str) -MyDumper.add_representer(list, represent_inline_list) - - -def make_dates(config): - if isinstance(config, dict): - return {k: make_dates(v) for k, v in config.items()} - if isinstance(config, list): - return [make_dates(v) for v in config] - if isinstance(config, str): - try: - return datetime.datetime.fromisoformat(config) - except ValueError: - return config - return config - - -ORDER = ( - "name", - "description", - "dataset_status", - "licence", - "attribution", - "env", - "dates", - "common", - "data_sources", - "input", - "output", - "statistics", - "build", - "platform", -) -ORDER = {k: i for i, k in enumerate(ORDER)} - - -def order(x: str) -> int: - +def find_paths_in_substrees(path, obj, cur_path=None): + if cur_path is None: + cur_path = [] + matches = [] try: - return ORDER[x[0]] - except KeyError: - 
rich.print(f"Unknown key: {x}") - raise + glom(obj, path) # just to check existence + matches.append(cur_path + path.split(".")) + except Exception: + pass + + if isinstance(obj, dict): + for k, v in obj.items(): + matches.extend(find_paths_in_substrees(path, v, cur_path + [k])) + elif isinstance(obj, list): + for i, v in enumerate(obj): + matches.extend(find_paths_in_substrees(path, v, cur_path + [str(i)])) + return matches MIGRATE = { @@ -153,24 +87,26 @@ def order(x: str) -> int: "loop.0.loop_a.dates": "dates", "dates.stop": "dates.end", "dates.group_by": "build.group_by", - "include": "data_sources", + "include.mars": "data_sources.mars.mars", "ensemble_dimension": "build.ensemble_dimension", "flatten_grid": "build.flatten_grid", } DELETE = [ "purpose", - "input.join.0.label", + # "input.join.0.label", "status", "common", "config_format_version", "aliases", - "platform", + # "platform", "loops.0.loop_a.applies_to", "loop.0.loop_a.applies_to", "dataset_status", "alias", "resources", + "input.dates.<<", + "input.dates.join.0.label.name", ] @@ -186,12 +122,12 @@ def order(x: str) -> int: MARKER = object() -def _delete(config, path, result): +def _delete(config, path): x = glom(config, path, default=MARKER) if x is MARKER: return rich.print(f"Deleting {path}={x}") - delete(result, path) + delete(config, path) def _move(config, path, new_path, result): @@ -203,12 +139,12 @@ def _move(config, path, new_path, result): assign(result, new_path, x, missing=dict) -def _fix_input_0(result, config): +def _fix_input_0(config): if isinstance(config["input"], dict): return input = config["input"] - new_input = result["input"] = [] + new_input = [] blocks = {} first = None @@ -227,26 +163,24 @@ def _fix_input_0(result, config): source_name = values.pop("name", None) if inherit is not None: + if inherit.startswith("$"): + inherit = inherit[1:] inherited = blocks[inherit].copy() inherited.update(values) values = inherited - if "source_or_dataset" in values: - 
values.pop("source_or_dataset", None) - values["template"] = "${input.join.0." + first + "}" - if first is None: first = source_name blocks[block_name] = values.copy() - new_input.append({block_name: {SOURCES.get(source_name, source_name): values.copy()}}) + new_input.append({SOURCES.get(source_name, source_name): values.copy()}) else: assert False, f"Block {block_name} does not have 'kwargs': {values}" blocks[block_name] = values.copy() - config["input"] = result["input"].copy() + config["input"] = dict(join=new_input) def _fix_input_1(result, config): @@ -410,7 +344,7 @@ def _fix_join(result: dict, config: dict) -> None: config["input"] = result["input"].copy() -def _fix_sources(result: dict, config: dict, what) -> None: +def _fix_sources(config: dict, what) -> None: input = config["input"] if what not in input: @@ -420,7 +354,7 @@ def _fix_sources(result: dict, config: dict, what) -> None: new_join = [] for j in join: assert isinstance(j, dict) - assert len(j) == 1 + assert len(j) == 1, j key, values = list(j.items())[0] @@ -432,10 +366,15 @@ def _fix_sources(result: dict, config: dict, what) -> None: } ) - result["input"][what] = new_join + config["input"][what] = new_join config["input"][what] = new_join.copy() +def _assign(config, path, value): + rich.print(f"Assign {path} {value}") + assign(config, path, value) + + def _fix_chevrons(result: dict, config: dict) -> None: rich.print("Fixing chevrons...") paths = find_chevrons(config) @@ -447,23 +386,91 @@ def _fix_chevrons(result: dict, config: dict) -> None: assign(result, ".".join(p[:-1]), a) +def _fix_some(config: dict) -> None: + + paths = find_paths_in_substrees("label.function", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + assert node + _assign(config, ".".join(p[:-2]), node) + + paths = find_paths_in_substrees("constants.source_or_dataset", config) + for p in paths: + node = glom(config, ".".join(p[:-1])) + node["template"] = 
node.pop("source_or_dataset") + if node["template"] == "$previous_data": + node["template"] = "${input.join.0.mars}" + paths = find_paths_in_substrees("constants.template", config) + for p in paths: + node = glom(config, ".".join(p[:-1])) + if node["template"] == "$pl_data": + node["template"] = "${input.join.0.mars}" + for d in ("date", "dates", "time"): + paths = find_paths_in_substrees(d, config) + for p in paths: + if len(p) > 1: + node = glom(config, ".".join(p[:-1])) + if isinstance(node, dict) and isinstance(node[d], str) and node[d].startswith("$"): + del node[d] + + paths = find_paths_in_substrees("source.<<", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + node.update(node.pop("<<")) + parent[node.pop("name")] = node + assert len(parent) == 2 + del parent["source"] + + paths = find_paths_in_substrees("label.mars", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + assert node + assign(config, ".".join(p[:-2]), node) + + paths = find_paths_in_substrees("input.dates.join", config) + for p in paths: + node = glom(config, ".".join(p)) + config["input"]["join"] = node + del config["input"]["dates"] + + paths = find_paths_in_substrees("source.name", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + name = node.pop("name") + assign(config, ".".join(p[:-2]), {name: node}) + + paths = find_paths_in_substrees("function.name", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + name = node.pop("name") + assert node + assign(config, ".".join(p[:-2]), {name: node}) + + def _migrate(config: dict, n) -> dict: result = config.copy() - _fix_input_0(result, config) - _fix_loops(result, config) - _fix_input_1(result, config) - _fix_join(result, config) - _fix_sources(result, config, "join") - _fix_chevrons(result, config) - 
_fix_other(result, config) + _fix_input_0(result) + # _fix_loops(result, config) + # _fix_input_1(result, config) + # _fix_join(result, config) + # _fix_chevrons(result, config) + # _fix_other(result, config) for k, v in MIGRATE.items(): _move(config, k, v, result) + _fix_some(result) + _fix_sources(result, "join") + for k in DELETE: - _delete(config, k, result) + _delete(result, k) remove_empties(result) @@ -513,7 +520,6 @@ def has_value(config, value: str) -> bool: def check(config): - from anemoi.datasets.create import validate_config try: @@ -523,6 +529,7 @@ def check(config): assert not has_key(config, "label") assert not has_key(config, "kwargs") assert not has_value(config, "$previous_data") + assert not has_value(config, "$pl_data") assert not has_value(config, "$dates") assert not has_key(config, "inherit") assert not has_key(config, "source_or_dataset") @@ -532,25 +539,18 @@ def check(config): assert not has_key(config, n), f"Source {n} found in config. Please update to {SOURCES[n]}." except Exception as e: - rich.print(f"Validation failed: {e}") - rich.print(f"Config: {config}") - raise + rich.print("Validation failed:") + rich.print(e) + print(yaml_dump(config)) + sys.exit(1) def migrate_recipe(args: Any, config) -> None: rich.print(f"Migrating {args.path}") - try: - validate_config(config) - LOG.info(f"{args.path}: Validation successful.") - except Exception: - pass - migrated = migrate(config) - migrated = {k: v for k, v in sorted(migrated.items(), key=order) if v} - check(migrated) if migrated == config: return None diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 6de4a06cd..6b336654f 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -93,7 +93,7 @@ def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None: if dic[key] == value: return raise ValueError(f"Cannot use {key}={dic[key]}. 
Must use {value}.") - LOG.info(f"Setting {key}={value} in config") + # LOG.info(f"Setting {key}={value} in config") dic[key] = value diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index d5accc2f4..29b8c611d 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -501,6 +501,8 @@ def prelude(self, config): SKIP = ( "input", "data_sources", + "common", + "aliases", ) result = [] @@ -512,6 +514,7 @@ def prelude(self, config): if not hasattr(Recipe, k): LOG.warning(f"Unknown key in recipe: {k}") + assert False, f"Unknown key in recipe: {k}" continue result.append(f"r.{k} = {repr(v)}") diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py index 69d9f4140..18c8d34d4 100644 --- a/src/anemoi/datasets/dumper.py +++ b/src/anemoi/datasets/dumper.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. import datetime +import io import logging import ruamel.yaml @@ -16,10 +17,15 @@ def represent_date(dumper, data): - if data.tzinfo is None: - data = data.replace(tzinfo=datetime.timezone.utc) - data = data.astimezone(datetime.timezone.utc) - iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + + if isinstance(data, datetime.datetime): + if data.tzinfo is None: + data = data.replace(tzinfo=datetime.timezone.utc) + data = data.astimezone(datetime.timezone.utc) + iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + else: + iso_str = data.isoformat() + return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) @@ -39,12 +45,7 @@ def represent_inline_list(dumper, data): return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) -def yaml_dump(obj, order=None, **kwargs): - - # kwargs.setdefault("Dumper", MyDumper) - # kwargs.setdefault("sort_keys", False) - # kwargs.setdefault("indent", 2) - # kwargs.setdefault("width", 120) +def yaml_dump(obj, order=None, stream=None, **kwargs): if order: @@ -53,18 
+54,23 @@ def _ordering(k): obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} - # yaml = yaml.YAML(typ='unsafe', pure=True) yaml = ruamel.yaml.YAML() yaml.width = 120 # wrap long flow sequences - # yaml.default_flow_style = True + yaml.Representer.add_representer(datetime.date, represent_date) yaml.Representer.add_representer(datetime.datetime, represent_date) yaml.Representer.add_representer(str, represent_multiline_str) yaml.Representer.add_representer(list, represent_inline_list) data = ruamel.yaml.comments.CommentedMap() - for k, v in obj.items(): + for i, (k, v) in enumerate(obj.items()): data[k] = v - data.yaml_set_comment_before_after_key(key=k, before="\n") + if i > 0: + data.yaml_set_comment_before_after_key(key=k, before="\n") + + if stream: + yaml.dump(data, stream=stream, **kwargs) - return yaml.dump(data, **kwargs) + stream = io.StringIO() + yaml.dump(data, stream=stream, **kwargs) + return stream.getvalue() diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index f1bf60ddc..c0dbc1bea 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -25,6 +25,16 @@ LOG = logging.getLogger(__name__) +def _un_dotdict(x): + if isinstance(x, dict): + return {k: _un_dotdict(v) for k, v in x.items()} + + if isinstance(x, (list, tuple, set)): + return [_un_dotdict(a) for a in x] + + return x + + class Index: def __init__(self, index): self.name = str(index) @@ -135,7 +145,7 @@ def resolve(params, recipe, name=None): if isinstance(params, dict): def _(k): - if k.endswith("_"): + if isinstance(k, str) and k.endswith("_"): return k[:-1] return k @@ -200,6 +210,10 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): self._dates = None self._statistics = None self._build = None + self._env = None + self._dataset_status = None + self._output = None + self._platform = None self.input = Join() self.output = DotDict() @@ -261,8 +275,6 @@ def as_dict(self): def 
concat(self, *args, **kwargs): return Concat(*args, **kwargs) - # def assert False, (name, target.as_dict(self)) - def make_data_source(self, name, target): target = target.as_dict(self) @@ -281,7 +293,6 @@ def make_data_source(self, name, target): return f"${{data_sources.{name}}}" def resolve(self, source, target, name=None): - # assert isinstance(target, Source), f"Only sources can be used as template {target}" top = Index("input") # So we have 'input' first in the path @@ -395,6 +406,14 @@ def _parse_dates(self, value): def dates(self, value): self._dates = self._parse_dates(value) + @property + def output(self): + return self._output + + @output.setter + def output(self, value): + self._output = value + @property def statistics(self): return self._statistics @@ -411,6 +430,30 @@ def build(self): def build(self, value): self._build = value + @property + def env(self): + return self._env + + @env.setter + def env(self, value): + self._env = value + + @property + def dataset_status(self): + return self._dataset_status + + @dataset_status.setter + def dataset_status(self, value): + self._dataset_status = value + + @property + def platform(self): + return self._platform + + @platform.setter + def platform(self, value): + self._platform = value + def dump(self, file=sys.stdout): input = self.input.as_dict(self) # First so we get the data_sources @@ -419,7 +462,7 @@ def dump(self, file=sys.stdout): result["input"] = input if self.output: - result["output"] = self.output.as_dict() + result["output"] = self.output if self.statistics: result["statistics"] = self.statistics @@ -427,9 +470,18 @@ def dump(self, file=sys.stdout): if self.build: result["build"] = self.build + if self.env: + result["env"] = self.env + + if self.dataset_status: + result["dataset_status"] = self.dataset_status + + if self.platform: + result["platform"] = self.platform + from .dumper import yaml_dump - yaml_dump(result, stream=file) + yaml_dump(_un_dotdict(result), stream=file) def test(self, 
output="recipe.zarr"): from argparse import ArgumentParser From 3d5f0ef62b29c5d103096e7b1d3026984a5a56e1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 22 Aug 2025 08:38:51 +0100 Subject: [PATCH 54/79] tidy --- .../datasets/commands/recipe/__init__.py | 2 +- .../datasets/commands/recipe/migrate.py | 19 +++++------- src/anemoi/datasets/create/__init__.py | 3 -- src/anemoi/datasets/create/input/action.py | 4 --- .../datasets/create/input/context/__init__.py | 10 +++---- .../datasets/create/input/context/field.py | 3 +- .../datasets/create/sources/repeated_dates.py | 29 ++++++++----------- 7 files changed, 26 insertions(+), 44 deletions(-) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 71b116213..bf08d1ee7 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -54,7 +54,7 @@ def run(self, args: Any) -> None: if not args.validate and not args.format and not args.migrate and not args.python: args.validate = True - with open(args.path, "r") as file: + with open(args.path) as file: config = yaml.safe_load(file) assert isinstance(config, dict) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 31c67aa42..03da61fbc 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -13,7 +13,6 @@ from collections.abc import Sequence from typing import Any -import rich from glom import assign from glom import delete from glom import glom @@ -126,7 +125,6 @@ def _delete(config, path): x = glom(config, path, default=MARKER) if x is MARKER: return - rich.print(f"Deleting {path}={x}") delete(config, path) @@ -134,7 +132,6 @@ def _move(config, path, new_path, result): x = glom(config, path, default=MARKER) if x is MARKER: return - rich.print(f"Moving {path}={x} to {new_path}={x}") delete(result, path) assign(result, new_path, x, 
missing=dict) @@ -265,7 +262,7 @@ def _fix_loops(result: dict, config: dict) -> None: concat = [] result["input"] = {"concat": concat} - rich.print("Found loops:", entries) + print("Found loops:", entries) for block in input: assert isinstance(block, dict), block @@ -296,7 +293,7 @@ def _fix_loops(result: dict, config: dict) -> None: def _fix_other(result: dict, config: dict) -> None: paths = find_paths(config, target_key="source_or_dataset", target_value="$previous_data") for p in paths: - rich.print(f"Fixing {'.'.join(p)}") + print(f"Fixing {'.'.join(p)}") assign(result, ".".join(p[:-1] + ["template"]), "${input.join.0.mars}", missing=dict) delete(result, ".".join(p)) @@ -306,7 +303,7 @@ def _fix_other(result: dict, config: dict) -> None: def _fix_join(result: dict, config: dict) -> None: - rich.print("Fixing join...") + print("Fixing join...") input = config["input"] if "dates" in input and "join" in input["dates"]: result["input"]["join"] = input["dates"]["join"] @@ -371,12 +368,12 @@ def _fix_sources(config: dict, what) -> None: def _assign(config, path, value): - rich.print(f"Assign {path} {value}") + print(f"Assign {path} {value}") assign(config, path, value) def _fix_chevrons(result: dict, config: dict) -> None: - rich.print("Fixing chevrons...") + print("Fixing chevrons...") paths = find_chevrons(config) for p in paths: a = glom(config, ".".join(p)) @@ -539,15 +536,15 @@ def check(config): assert not has_key(config, n), f"Source {n} found in config. Please update to {SOURCES[n]}." 
except Exception as e: - rich.print("Validation failed:") - rich.print(e) + print("Validation failed:") + print(e) print(yaml_dump(config)) sys.exit(1) def migrate_recipe(args: Any, config) -> None: - rich.print(f"Migrating {args.path}") + print(f"Migrating {args.path}") migrated = migrate(config) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 8eb1c87da..3c615e21f 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -19,7 +19,6 @@ import cftime import numpy as np -import rich import tqdm import zarr from anemoi.utils.dates import as_datetime @@ -670,8 +669,6 @@ def _run(self) -> int: LOG.info(f"Missing dates: {len(missing)}") lengths = tuple(len(g) for g in self.groups) - rich.print("Minimal input dates:", self.minimal_input) - variables = self.minimal_input.variables LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index db9d8dace..7afd81e35 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -9,8 +9,6 @@ import logging -import rich - from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) @@ -122,8 +120,6 @@ def __call__(self, context, argument): source = self.create_object(config) - rich.print(f"Executing source {self.name} from {config}") - return context.register(self.call_object(context, source, argument), self.path) def python_code(self, code) -> str: diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 738f6a85b..eef61504c 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -12,8 +12,6 @@ from abc import abstractmethod from typing import Any -import rich - LOG = logging.getLogger(__name__) @@ -27,7 +25,7 @@ def 
__init__(self, /, argument: Any) -> None: def trace(self, emoji, *message) -> None: - rich.print(f"{emoji}: {message}") + print(f"{emoji}: {message}") def register(self, data: Any, path: list[str]) -> Any: @@ -36,7 +34,7 @@ def register(self, data: Any, path: list[str]) -> Any: assert path[0] in ("input", "data_sources"), path - rich.print(f"Registering data at path: {path}") + print(f"Registering data at path: {path}") self.results[tuple(path)] = data return data @@ -49,9 +47,9 @@ def resolve(self, config): if path in self.results: config[key] = self.results[path] else: - rich.print(f"Path not found {path}") + print(f"Path not found {path}") for p in sorted(self.results): - rich.print(f" Available paths: {p}") + print(f" Available paths: {p}") raise KeyError(f"Path {path} not found in results: {self.results.keys()}") return config diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index c3456d89f..23282dce0 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -9,7 +9,6 @@ from typing import Any -from typing import Dict from earthkit.data.core.order import build_remapping @@ -25,7 +24,7 @@ def __init__( argument: Any, order_by: str, flatten_grid: bool, - remapping: Dict[str, Any], + remapping: dict[str, Any], use_grib_paramid: bool, ) -> None: super().__init__(argument) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index eb235cd99..77a06c76c 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -10,15 +10,10 @@ import logging from collections import defaultdict +from collections.abc import Generator from typing import Any -from typing import Dict -from typing import Generator -from typing import Optional -from typing import Set -from typing import Tuple import numpy as np -import rich 
from anemoi.transform.fields import new_field_with_valid_datetime from anemoi.transform.fields import new_fieldlist_from_list from anemoi.utils.dates import as_datetime @@ -52,7 +47,7 @@ class DateMapper: """A factory class to create DateMapper instances based on the given mode.""" @staticmethod - def from_mode(mode: str, source: Any, config: Dict[str, Any]) -> "DateMapper": + def from_mode(mode: str, source: Any, config: dict[str, Any]) -> "DateMapper": """Create a DateMapper instance based on the given mode. Parameters @@ -102,10 +97,10 @@ def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", ski self.maximum: Any = frequency_to_timedelta(maximum) self.frequency: Any = frequency_to_timedelta(frequency) self.skip_all_nans: bool = skip_all_nans - self.tried: Set[Any] = set() - self.found: Set[Any] = set() + self.tried: set[Any] = set() + self.found: set[Any] = set() - def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: """Transform the group of dates to the closest available dates. Parameters @@ -200,7 +195,7 @@ def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, Non class DateMapperClimatology(DateMapper): """A DateMapper implementation that maps dates to specified climatology dates.""" - def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) -> None: + def __init__(self, source: Any, year: int, day: int, hour: int | None = None) -> None: """Initialize DateMapperClimatology. 
Parameters @@ -216,9 +211,9 @@ def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) """ self.year: int = year self.day: int = day - self.hour: Optional[int] = hour + self.hour: int | None = hour - def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: """Transform the group of dates to the specified climatology dates. Parameters @@ -254,7 +249,7 @@ def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, Non class DateMapperConstant(DateMapper): """A DateMapper implementation that maps dates to a constant date.""" - def __init__(self, source: Any, date: Optional[Any] = None) -> None: + def __init__(self, source: Any, date: Any | None = None) -> None: """Initialize DateMapperConstant. Parameters @@ -265,9 +260,9 @@ def __init__(self, source: Any, date: Optional[Any] = None) -> None: The constant date to map to. """ self.source: Any = source - self.date: Optional[Any] = date + self.date: Any | None = date - def transform(self, group_of_dates: Any) -> Tuple[Any, Any]: + def transform(self, group_of_dates: Any) -> tuple[Any, Any]: """Transform the group of dates to a constant date. 
Parameters @@ -311,7 +306,7 @@ def execute(self, context, group_of_dates): result = [] for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): - rich.print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") + print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") source_results = source(context, one_date_group) for field in source_results: for date in many_dates_group: From 70272f69a112c75060a51949122f0beb24c7046d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 22 Aug 2025 09:24:21 +0100 Subject: [PATCH 55/79] update --- src/anemoi/datasets/create/input/action.py | 7 +------ src/anemoi/datasets/create/input/context/field.py | 1 + 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7afd81e35..7bf6bb017 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -242,7 +242,6 @@ def make(key, config, *path): from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.transform.sources import source_registry as transform_source_registry - from anemoi.datasets.create.filters import filter_registry as dataset_filter_registry from anemoi.datasets.create.sources import source_registry as dataset_source_registry # Register sources, local first @@ -254,11 +253,7 @@ def make(key, config, *path): if name not in KLASS: KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) - # Register filters, local first - for name in dataset_filter_registry.registered: - if name not in KLASS: - KLASS[name.replace("_", "-")] = new_filter(name, DatasetFilterMixin) - + # Register filters for name in transform_filter_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) diff --git a/src/anemoi/datasets/create/input/context/field.py 
b/src/anemoi/datasets/create/input/context/field.py index 23282dce0..1dd01340e 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -32,6 +32,7 @@ def __init__( self.flatten_grid = flatten_grid self.remapping = build_remapping(remapping) self.use_grib_paramid = use_grib_paramid + self.partial_ok = False def empty_result(self) -> Any: import earthkit.data as ekd From 38ced18e528a2b4875a5476c971b9d948142d396 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 22 Aug 2025 11:43:25 +0100 Subject: [PATCH 56/79] update tests --- src/anemoi/datasets/create/input/action.py | 6 +++++- tests/create/test_create.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7bf6bb017..25d7690e9 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -197,7 +197,11 @@ def new_filter(name, mixin): class DataSources(Action): def __init__(self, config, *path): super().__init__(config, *path) - self.sources = {k: action_factory(v, *path, k) for k, v in config.items()} + assert isinstance(config, (dict, list)), f"Invalid config type: {type(config)}" + if isinstance(config, dict): + self.sources = {k: action_factory(v, *path, k) for k, v in config.items()} + else: + self.sources = {i: action_factory(v, *path, str(i)) for i, v in enumerate(config)} def python_code(self, code): return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) diff --git a/tests/create/test_create.py b/tests/create/test_create.py index 193c9a26a..dd3f37864 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -14,6 +14,8 @@ from unittest.mock import patch import pytest +from anemoi.transform.filter import Filter +from anemoi.transform.filters import filter_registry from anemoi.utils.testing import GetTestArchive from 
anemoi.utils.testing import GetTestData from anemoi.utils.testing import skip_if_offline @@ -32,6 +34,18 @@ assert NAMES, "No yaml files found in " + HERE +# Used by pipe.yaml +@filter_registry.register("filter") +class TestFilter(Filter): + + def __init__(self, **kwargs): + + self.kwargs = kwargs + + def forward(self, data): + return data.sel(**self.kwargs) + + @pytest.fixture def load_source(get_test_data: GetTestData) -> LoadSource: return LoadSource(get_test_data) From cb3847e4c1b23262658feb7791b0f441b8c764e5 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 22 Aug 2025 12:41:00 +0100 Subject: [PATCH 57/79] fix tests --- src/anemoi/datasets/create/input/action.py | 26 ++++++-------------- src/anemoi/datasets/create/sources/legacy.py | 4 +-- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 25d7690e9..7e164c586 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -118,7 +118,7 @@ def __call__(self, context, argument): config["_type"] = self.name # Find a better way to do this - source = self.create_object(config) + source = self.create_object(context, config) return context.register(self.call_object(context, source, argument), self.path) @@ -131,37 +131,27 @@ def python_code(self, code) -> str: class DatasetSourceMixin: - def create_object(self, config): + def create_object(self, context, config): from anemoi.datasets.create.sources import create_source as create_datasets_source - return create_datasets_source(self, config) + return create_datasets_source(context, config) def call_object(self, context, source, argument): - return source.execute(context, context.source_argument(argument)) - - -class DatasetFilterMixin: - def create_object(self, config): - from anemoi.datasets.create.filters import create_filter as create_datasets_filter - - return create_datasets_filter(self, config) - - def 
call_object(self, context, filter, argument): - return filter.execute(context.filter_argument(argument)) + return source.execute(context.source_argument(argument)) class TransformSourceMixin: - def create_object(self, config): + def create_object(self, context, config): from anemoi.transform.sources import create_source as create_transform_source - return create_transform_source(self, config) + return create_transform_source(context, config) class TransformFilterMixin: - def create_object(self, config): + def create_object(self, context, config): from anemoi.transform.filters import create_filter as create_transform_filter - return create_transform_filter(self, config) + return create_transform_filter(context, config) def call_object(self, context, filter, argument): return filter.forward(context.filter_argument(argument)) diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index d7ab82dd1..d7a15bfe7 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -69,14 +69,14 @@ def __call__(self, execute: Callable) -> Callable: name = f"Legacy{self.name.title()}Source" source = ".".join([execute.__module__, execute.__name__]) - def execute_wrapper(self, context, dates) -> Any: + def execute_wrapper(self, dates) -> Any: """Wrapper method to call the execute function.""" # args, kwargs = resolve(context, (self.args, self.kwargs)) args, kwargs = self.args, self.kwargs try: - return execute(context, dates, *args, **kwargs) + return execute(self.context, dates, *args, **kwargs) except TypeError: LOG.error(f"Error executing source {this.name} from {source}") LOG.error(f"Function signature is: {inspect.signature(execute)}") From 00477c9f805e7e6552b068ffbf0da7a0c572415b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 25 Aug 2025 17:37:44 +0100 Subject: [PATCH 58/79] add missing package --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git 
a/pyproject.toml b/pyproject.toml index 30f82757e..a1f96a221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "anemoi-transform>=0.1.10", "anemoi-utils[provenance]>=0.4.32", "cfunits", + "glom", "numcodecs<0.16", # Until we move to zarr3 "numpy", "pyyaml", From 619c416c3c450709a8b233b5ccbd2e4278a779d9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 15 Sep 2025 14:45:07 +0100 Subject: [PATCH 59/79] remove unused file --- src/anemoi/datasets/create/python.py | 578 --------------------------- 1 file changed, 578 deletions(-) delete mode 100644 src/anemoi/datasets/create/python.py diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py deleted file mode 100644 index 29b8c611d..000000000 --- a/src/anemoi/datasets/create/python.py +++ /dev/null @@ -1,578 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction.
-# - -import logging -import re -from collections import defaultdict -from functools import cached_property - -LOG = logging.getLogger(__name__) - -RESERVED_KEYWORDS = ( - "and", - "or", - "not", - "is", - "in", - "if", - "else", - "elif", - "for", - "while", - "return", - "class", - "def", - "with", - "as", - "import", - "from", - "try", - "except", - "finally", - "raise", - "assert", - "break", - "continue", - "pass", -) - - -def _sanitize_name(name): - name = name.replace("-", "_") - if name in RESERVED_KEYWORDS: - name = f"{name}_" - return name - - -def _un_dotdict(x): - if isinstance(x, dict): - return {k: _un_dotdict(v) for k, v in x.items()} - - if isinstance(x, (list, tuple, set)): - return [_un_dotdict(a) for a in x] - - return x - - -class PythonCode: - - def __init__(self, top): - self.top = top - self.top.register(self) - self.key = str(id(self)) - - def call(self, name, argument): - return PythonCall(self.top, name, argument) - - def sum(self, actions): - return PythonChain(self.top, "join", "&", actions) - - def pipe(self, actions): - return PythonChain(self.top, "pipe", "|", actions) - - def concat(self, argument): - return PythonConcat(self.top, argument) - - def source_code(self): - return self.top.source_code(self) - - def combine(self, nodes): - return None - - def recipe(self, input, data_sources): - return PythonRecipe(self.top, input, data_sources) - - def prelude(self): - return None - - def sources(self, sources): - return PythonSources(self.top, sources) - - def update_anchor(self): - pass - - -class Variable(PythonCode): - def __init__(self, name, node): - super().__init__(top=node.top) - self.name = name - self.node = node - - def __repr__(self): - return "" - - def replace_node(self, old, new): - pass - - def prelude(self): - return [f"{self.name} = {repr(self.node)}", ""] - - -class InLine(PythonCode): - def __init__(self, node): - super().__init__(top=node.top) - self.node = node - - @cached_property - def name(self): - n = 
self.top.counter["_anchor"] - self.top.counter["_anchor"] += 1 - return f"_a{n}" - - def __repr__(self): - return f"({self.name} := {repr(self.node)})" - - def replace_node(self, old, new): - pass - - -class PythonRecipe(PythonCode): - def __init__(self, top, input, data_sources): - super().__init__(top) - self.input = input - self.data_sources = data_sources - - def apply_references(self, *path): - self.data_sources.apply_references(*path, "data_sources") - self.input.apply_references(*path, "input") - - def replace_node(self, old, new): - if self.input is old: - self.input = new - return - - if self.data_sources is old: - self.data_sources = new - return - - self.input.replace_node(old, new) - self.data_sources.replace_node(old, new) - - def __repr__(self): - return repr(self.input) - - def prelude(self): - return self.data_sources.prelude() - - -class Argument(PythonCode): - - def __init__(self, top, name): - super().__init__(top=top) - self.name = _sanitize_name(name) - - def __repr__(self): - return self.name - - def replace_node(self, old, new): - pass - - -class Anchor(PythonCode): - - def __init__(self, identifier): - super().__init__(top=identifier.node.top) - self.identifier = identifier - - @property - def name(self): - return self.identifier.name - - def __repr__(self): - # assert False - return repr(self.identifier) - - def replace_node(self, old, new): - pass - - -class Reference(PythonCode): - - def __init__(self, top, path): - super().__init__(top) - self.path = tuple(path) - self.anchor = None - - def update_anchor(self): - - node = self.top.by_reference.get(self.path, None) - if node is None: - LOG.warning(f"Reference {self.path} not found") - for p in sorted(self.top.by_reference): - LOG.warning(f" - {p}") - else: - self.anchor = Anchor(node) - self.top.replace_nodes([(node.node, self.anchor)]) - - def __repr__(self): - if self.anchor is not None: - return self.anchor.name - - return f"'${{{'.'.join(self.path)}}}'" - - def replace_node(self, old, 
new): - pass - - -class Function(PythonCode): - def __init__(self, name, node, counter): - super().__init__(top=node.top) - self._name = name - self.node = node - self.used = False - self.counter = counter - - def __repr__(self): - return self.name - - def prelude(self): - if self.used: - return None - - self.used = True - - node_prelude = self.node.prelude() - - arguments = self.node.free_arguments() - - return [ - *(node_prelude if node_prelude else []), - f"def {self.name}({','.join(repr(p) for p in arguments)}):", - f" return {self.node}", - ] - - def free_arguments(self): - return self.node.free_arguments() - - @cached_property - def name(self): - n = self.counter[self._name] - self.counter[self._name] += 1 - if n == 0: - return _sanitize_name(self._name) - return _sanitize_name(f"{self._name}_{n}") - - def replace_node(self, old, new): - if self.node is old: - self.node = new - - -class PythonSources(PythonCode): - def __init__(self, top, sources): - super().__init__(top) - self.sources = sources - - def __repr__(self): - return "" - - def prelude(self): - pass - - def replace_node(self, old, new): - for k, v in list(self.sources.items()): - if v is old: - self.sources[k] = new - else: - v.replace_node(old, new) - - def apply_references(self, *path): - for k, v in self.sources.items(): - self.top.by_reference[path + (k,)] = Variable(k, v) - - -class PythonConcat(PythonCode): - def __init__(self, top, argument): - super().__init__(top=top) - self.argument = _un_dotdict(argument) - - def __repr__(self): - return f"r.concat({self.argument})" - - def replace_node(self, old, new): - for k, v in list(self.argument.items()): - if v is old: - self.argument[k] = new - else: - v.replace_node(old, new) - - def apply_references(self, *path): - assert "concat" not in path, path - self.top.by_reference[path + ("concat",)] = InLine(self) - for i, node in enumerate(self.argument.values()): - node.apply_references(*path, "concat", str(i)) - - -class PythonChain(PythonCode): - 
def __init__(self, top, kind, op, actions): - super().__init__(top=top) - self.op = op - self.kind = kind - self.actions = list(actions) - self.key = op - - def __repr__(self): - return "(" + self.op.join(repr(x) for x in self.actions) + ")" - - def replace_node(self, old, new): - - for i, node in enumerate(self.actions): - - if node is old: - self.actions[i] = new - else: - node.replace_node(old, new) - - def apply_references(self, *path): - self.top.by_reference[path + (self.kind,)] = InLine(self) - for i, node in enumerate(self.actions): - node.apply_references(*path, self.kind, str(i)) - - -class PythonCall(PythonCode): - def __init__(self, top, name, argument): - super().__init__(top=top) - self.name = name - self.argument = _un_dotdict(argument) - self.key = name - - def free_arguments(self): - result = [] - for k, v in self.argument.items(): - if isinstance(v, Argument): - result.append(v) - return result - - def __repr__(self): - name = self.name.replace("-", "_") - config = dict(**self.argument) - - params = [] - - for k, v in config.items(): - k = _sanitize_name(k) - - if not k.isidentifier(): - return f"r.{name}({config})" - - params.append(f"{k}={repr(v)}") - - if params: - params.append("") # For a trailing comma - - params = ",".join(params) - return f"r.{name}({params})" - - def replace_node(self, old, new): - pass - - def combine(self, nodes): - - # Exact similarity - - changes = self._combine0(nodes) - if changes: - return changes - - # On key difference - changes = self._combine1(nodes) - if changes: - return changes - - def _combine0(self, nodes): - - x = defaultdict(list) - for node in nodes: - key = {k2: v2 for k2, v2 in sorted(node.argument.items())} - x[str(key)].append(node) - - for i in sorted(x.values(), key=len, reverse=True): - node = i[0] - if len(i) < 2: - return - - call = PythonCall(self.top, self.name, node.argument) - - func = self.top.function(call) - changes = [] - for node in i: - - new = PythonFunction(top=self.top, func=func) 
- - changes.append((node, new)) - - return changes - - def _combine1(self, nodes): - - x = defaultdict(list) - for node in nodes: - argument = node.argument - for k, v in argument.items(): - rest = {k2: v2 for k2, v2 in sorted(argument.items()) if k2 != k} - x[str(rest)].append((k, v, node)) - - for i in sorted(x.values(), key=len, reverse=True): - key, value, node = i[0] - if len(i) < 2: - return - - rest = {k: v for k, v in node.argument.items() if k != key} - rest[key] = Argument(self.top, key) - call = PythonCall(self.top, self.name, rest) - - func = self.top.function(call) - changes = [] - for key, value, node in i: - - new = PythonFunction( - top=self.top, - func=func, - **{key: value}, - ) - - changes.append((node, new)) - - return changes - - def apply_references(self, *path): - self.top.by_reference[path + (self.name,)] = InLine(self) - - for k, v in self.argument.items(): - if isinstance(v, str) and (m := re.match(r"^\${(\w+(?:\.\w+)+)}$", v)): - path = m.group(1).split(".") - self.argument[k] = Reference(self.top, path) - - -class PythonFunction(PythonCode): - def __init__(self, top, func, **kwargs): - super().__init__(top=top) - self.func = func - self.kwargs = kwargs - - def __repr__(self): - - params = [] - for a in self.func.free_arguments(): - name = _sanitize_name(a.name) - if a.name in self.kwargs: - v = self.kwargs[a.name] - params.append(f"{name}={repr(v)}") - else: - params.append(f"{name}={name}") - - return f"{self.func}({', '.join(params)})" - - def replace_node(self, old, new): - self.func.replace_node(old, new) - - def prelude(self): - return self.func.prelude() - - def free_arguments(self): - return [a for a in self.func.free_arguments() if a.name not in self.kwargs] - - def apply_references(self, *path): - pass - - -class PythonScript(PythonCode): - - def __init__(self): - self.nodes = [] - self.counter = defaultdict(int) - self.by_reference = {} - super().__init__(top=self) - - def register(self, child): - if child is not self: - 
self.nodes.append(child) - - def prelude(self, config): - - from anemoi.datasets.recipe import Recipe - - SKIP = ( - "input", - "data_sources", - "common", - "aliases", - ) - - result = [] - - for k, v in config.items(): - - if k in SKIP: - continue - - if not hasattr(Recipe, k): - LOG.warning(f"Unknown key in recipe: {k}") - assert False, f"Unknown key in recipe: {k}" - continue - - result.append(f"r.{k} = {repr(v)}") - - for node in self.nodes: - prelude = node.prelude() - if prelude: - if not isinstance(prelude, (list, tuple)): - prelude = list(prelude) - result.extend(prelude) - return "\n".join(result) - - def source_code(self, first, config): - - which = self.nodes.index(first) - first.apply_references() - for node in self.nodes: - node.update_anchor() - - more = True - while more: - more = False - - by_class = defaultdict(list) - for node in self.nodes: - by_class[(node.__class__, node.key)].append(node) - - for nodes in by_class.values(): - if len(nodes) > 1: - changes = nodes[0].combine(nodes) - if changes: - self.replace_nodes(changes) - more = True - - first = self.nodes[which] - - return "\n\n".join( - [ - "# Generated Python code for Anemoi dataset creation", - "import datetime", - "from anemoi.datasets.recipe import Recipe", - "r = Recipe()", - self.prelude(config), - f"r.input = {repr(first)}", - "r.dump()", - ] - ) - - def function(self, node): - return Function(node.name, node, self.counter) - - def replace_nodes(self, changes): - - for old, new in changes: - assert old in self.nodes, f"Node {old} not found in {self.nodes}" - for i, node in enumerate(self.nodes): - - if node is old: - self.nodes[i] = new - else: - node.replace_node(old, new) From 21321cb4ee0d2c708896e520e8fcef6d665c826d Mon Sep 17 00:00:00 2001 From: Matthew Chantry Date: Mon, 15 Sep 2025 15:46:47 +0100 Subject: [PATCH 60/79] Update copyright year --- src/anemoi/datasets/commands/recipe/format.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/anemoi/datasets/commands/recipe/format.py b/src/anemoi/datasets/commands/recipe/format.py index 533a569c1..872060981 100644 --- a/src/anemoi/datasets/commands/recipe/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2025 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. From 90e21eede3a3dcc833b566e4609c29e26dd0b95e Mon Sep 17 00:00:00 2001 From: Matthew Chantry Date: Mon, 15 Sep 2025 16:06:00 +0100 Subject: [PATCH 61/79] Fix copyright header --- src/anemoi/datasets/create/sources/repeated_dates.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index 77a06c76c..6ba8b1e4b 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -22,16 +22,6 @@ from anemoi.datasets.create.source import Source from anemoi.datasets.create.sources import source_registry -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- - LOG = logging.getLogger(__name__) From fa3575547c56922cdec301e06ece2528ed355bb2 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 18 Sep 2025 10:21:28 +0000 Subject: [PATCH 62/79] bug fix in repeated dates --- src/anemoi/datasets/create/sources/repeated_dates.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index 6ba8b1e4b..b56537979 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -286,18 +286,19 @@ def transform(self, group_of_dates: Any) -> tuple[Any, Any]: @source_registry.register("repeated_dates") class RepeatedDatesSource(Source): - def __init__(self, owner, source: Any, mode: str, **kwargs) -> None: - self.owner = owner + def __init__(self, context, source: Any, mode: str, **kwargs) -> None: + # assert False, (context, source, mode, kwargs) + super().__init__(context, **kwargs) self.mapper = DateMapper.from_mode(mode, source, kwargs) self.source = source - def execute(self, context, group_of_dates): - source = context.create_source(self.source, *self.owner.path, "source") + def execute(self, group_of_dates): + source = self.context.create_source(self.source, "data_sources", str(id(self))) result = [] for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") - source_results = source(context, one_date_group) + source_results = source(self.context, one_date_group) for field in source_results: for date in many_dates_group: result.append(new_field_with_valid_datetime(field, date)) From 2fc4189f0830303775aaeaa058cc1645eec28fa0 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 18 Sep 2025 12:11:43 +0000 Subject: [PATCH 63/79] bug fix in repeated dates --- src/anemoi/datasets/dates/groups.py | 15 +++++++++++++++ 1 file changed, 
15 insertions(+) diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index 547e99892..b5454dbbd 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -73,6 +73,21 @@ def __len__(self) -> int: """ return len(self.dates) + def __getitem__(self, index) -> datetime.datetime: + """Return the date at the specified index. + + Parameters + ---------- + index : int + The index of the date. + + Returns + ------- + datetime.datetime + The date at the specified index. + """ + return self.dates[index] + def __iter__(self) -> Iterator[datetime.datetime]: """Return an iterator over the dates in the group. From 9cb1b631e49ac4673f923e9955e934ad990386c9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 21 Sep 2025 18:40:56 +0100 Subject: [PATCH 64/79] remove python generating code that will be in another PR --- .../datasets/commands/recipe/__init__.py | 21 +- src/anemoi/datasets/create/__init__.py | 22 - src/anemoi/datasets/recipe.py | 539 ------------------ 3 files changed, 3 insertions(+), 579 deletions(-) delete mode 100644 src/anemoi/datasets/recipe.py diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index bf08d1ee7..7630547bc 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -38,7 +38,6 @@ def add_arguments(self, command_parser: Any) -> None: command_parser.add_argument("--validate", action="store_true", help="Validate recipe.") command_parser.add_argument("--format", action="store_true", help="Format the recipe.") command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") - command_parser.add_argument("--python", action="store_true", help="Convert the recipe to a Python script.") group = command_parser.add_mutually_exclusive_group() group.add_argument("--inplace", action="store_true", help="Overwrite the recipe file 
in place.") @@ -51,7 +50,7 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: - if not args.validate and not args.format and not args.migrate and not args.python: + if not args.validate and not args.format and not args.migrate: args.validate = True with open(args.path) as file: @@ -60,10 +59,10 @@ def run(self, args: Any) -> None: assert isinstance(config, dict) if args.validate: - if args.inplace and (not args.format and not args.migrate and not args.python): + if args.inplace and (not args.format and not args.migrate): argparse.ArgumentError(None, "--inplace is not supported with --validate.") - if args.output and (not args.format and not args.migrate and not args.python): + if args.output and (not args.format and not args.migrate): argparse.ArgumentError(None, "--output is not supported with --validate.") validate_config(config) @@ -91,18 +90,4 @@ def run(self, args: Any) -> None: print(formatted, file=f) f.close() - if args.python: - if args.inplace: - argparse.ArgumentError(None, "Inplace conversion to Python is not supported.") - - if args.format: - raise argparse.ArgumentError(None, "Formatting is not supported when converting to Python.") - - if args.output: - with open(args.output, "w") as file: - file.write(config_to_python(config)) - else: - print(config_to_python(config)) - - command = Recipe diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 3c615e21f..352b5a6d6 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1660,25 +1660,3 @@ def _tidy(d): raise -def config_to_python(config: Any) -> Any: - - from ..create.python import PythonScript - - raw_config = config - - config = loader_config(config) - - input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - - code = PythonScript() - x = input.python_code(code) - code = code.source_code(x, raw_config) - - try: - import black - - return 
black.format_str(code, mode=black.Mode()) - # except ImportError: - except Exception: - LOG.warning("Black not installed, skipping formatting") - return code diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py deleted file mode 100644 index c0dbc1bea..000000000 --- a/src/anemoi/datasets/recipe.py +++ /dev/null @@ -1,539 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import os -import sys -from collections import defaultdict -from tempfile import TemporaryDirectory - -from anemoi.transform.filters import filter_registry as transform_filter_registry -from anemoi.utils.config import DotDict -from anemoi.utils.dates import as_datetime -from anemoi.utils.dates import frequency_to_string -from anemoi.utils.dates import frequency_to_timedelta - -from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry -from anemoi.datasets.create.sources import source_registry - -LOG = logging.getLogger(__name__) - - -def _un_dotdict(x): - if isinstance(x, dict): - return {k: _un_dotdict(v) for k, v in x.items()} - - if isinstance(x, (list, tuple, set)): - return [_un_dotdict(a) for a in x] - - return x - - -class Index: - def __init__(self, index): - self.name = str(index) - - def __repr__(self): - return f"Index({self.name})" - - def same(self, other): - if not isinstance(other, Index): - return False - return self.name == other.name - - -class Step: - - def __or__(self, other): - return Pipe(self, other) - - def __and__(self, other): - return Join(self, other) - - def same(self, other): - return self is other - - -class Chain(Step): - def __init__(self, 
*args): - if len(args) > 0 and isinstance(args[0], self.__class__): - args = args[0].steps + args[1:] - - self.steps = args - self.index = [Index(i) for i in range(len(self.steps))] - - def as_dict(self, recipe): - if len(self.steps) == 1: - return self.steps[0].as_dict(recipe) - return {self.name: [s.as_dict(recipe) for s in self.steps]} - - def path(self, target, result, *path): - - if target is self: - result.append([*path, self]) - return - - for i, s in enumerate(self.steps): - s.path(target, result, *path, self, self.index[i]) - - def collocated(self, a, b): - return True - - -class Pipe(Chain): - name = "pipe" - - -class Join(Chain): - name = "join" - - -class Concat(Step): - name = "concat" - - def __init__(self, args): - assert isinstance(args, dict), f"Invalid argument {args}" - self.params = args - - def __setitem__(self, key, value): - self.params[key] = value - - def as_dict(self, recipe): - - result = [] - - for k, v in sorted(self.params.items()): - - key = dict(start=as_datetime(k[0]), end=as_datetime(k[1])) - if len(k) == 3: - key["frequency"] = k[2] - - result.append({"dates": key, **v.as_dict(recipe)}) - - return {"concat": result} - - def collocated(self, a, b): - return a[0].same(b[0]) - - def path(self, target, result, *path): - if target is self: - result.append([*path, self]) - return - for i, (k, v) in enumerate(sorted(self.params.items())): - v.path(target, result, *path, self, Index(i)) - - -class Base(Step): - def __init__(self, owner, *args, **kwargs): - self.owner = owner - self.name = owner.name - self.params = {} - for a in args: - assert isinstance(a, dict), f"Invalid argument {a}" - self.params.update(a) - self.params.update(kwargs) - - def as_dict(self, recipe): - - def resolve(params, recipe, name=None): - if isinstance(params, dict): - - def _(k): - if isinstance(k, str) and k.endswith("_"): - return k[:-1] - return k - - return {_(k): resolve(v, recipe, name=_(k)) for k, v in params.items()} - - if isinstance(params, (list, 
tuple)): - return [resolve(v, recipe) for v in params] - - if isinstance(params, list): - return [resolve(v, recipe) for v in params] - - if isinstance(params, Step): - return recipe.resolve(self, params, name=name) - - return params - - return {self.owner.name: resolve(self.params, recipe)} - - def path(self, target, result, *path): - if self is target: - result.append([*path, self]) - - -class Source(Base): - pass - - -class Filter(Base): - pass - - -class SourceMaker: - def __init__(self, name, factory): - self.name = name - self.factory = factory - - def __call__(self, *args, **kwargs): - return Source(self, *args, **kwargs) - - -class FilterMaker: - def __init__(self, name, factory): - self.name = name - self.factory = factory - - def __call__(self, *args, **kwargs): - if len(args) > 0 and isinstance(args[0], Step): - prev = args[0] - args = args[1:] - return Pipe(prev, Filter(self, *args, **kwargs)) - return Filter(self, *args, **kwargs) - - -class Recipe: - - def __init__(self, name=None, description=None, attribution=None, licence=None): - - self._description = description - self._attribution = attribution - self._licence = licence - self._name = name - self._dates = None - self._statistics = None - self._build = None - self._env = None - self._dataset_status = None - self._output = None - self._platform = None - - self.input = Join() - self.output = DotDict() - self.statistics = DotDict() - self.build = DotDict() - - self._data_sources = {} - self._counter = defaultdict(int) - - sources = source_registry.factories.copy() - filters = transform_filter_registry.factories.copy() - - for key, factory in datasets_filter_registry.factories.items(): - if key in filters: - LOG.warning( - f"Filter `{key}` is registered in anemoi.datasets filter registry and in anemoi.transform filter registry" - ) - filters[key] = factory - - for key, factory in sources.items(): - if key in filters: - LOG.warning( - f"Source `{key}` is registered in anemoi.datasets source registry 
and in anemoi.transform filter registry" - ) - del filters[key] - - for key, factory in sources.items(): - key = key.replace("-", "_") - assert not hasattr(self, key) - setattr(self, key, SourceMaker(key, factory)) - - for key, factory in filters.items(): - key = key.replace("-", "_") - assert not hasattr(self, key) - setattr(self, key, FilterMaker(key, factory)) - - self.repeated_dates = SourceMaker("repeated_dates", None) - - def as_dict(self): - result = { - "name": self.name, - "description": self.description, - "attribution": self.attribution, - "licence": self.licence, - "dates": self.dates, - "statistics": self.statistics, - "build": self.build, - } - - if self._data_sources: - result["data_sources"] = self._data_sources - - for k, v in list(result.items()): - if v is None: - del result[k] - - return result - - def concat(self, *args, **kwargs): - return Concat(*args, **kwargs) - - def make_data_source(self, name, target): - - target = target.as_dict(self) - - name = name or "source" - if name in self._data_sources: - if self._data_sources[name] == target: - return f"${{data_sources.{name}}}" - - n = self._counter[name] - self._counter[name] += 1 - - name = f"{name}_{n}" if n > 0 else name - - self._data_sources[name] = target.copy() - return f"${{data_sources.{name}}}" - - def resolve(self, source, target, name=None): - - top = Index("input") # So we have 'input' first in the path - - path_to_source = [] - self.input.path(source, path_to_source, top) - if len(path_to_source) == 0: - raise ValueError(f"Source {source} not found in recipe") - if len(path_to_source) > 1: - raise ValueError(f"Source {source} found in multiple locations {path_to_source}") - path_to_source = path_to_source[0] - - path_to_target = [] - self.input.path(target, path_to_target, top) - if len(path_to_target) > 1: - raise ValueError(f"Target {target} found in multiple locations {path_to_target}") - - if len(path_to_target) == 0: - # Add a `data_sources` entry - return 
self.make_data_source(name, target) - - path_to_target = path_to_target[0] - - a = [s for s in path_to_target] - b = [s for s in path_to_source] - common_ancestor = None - while a[0] is b[0]: - common_ancestor = a[0] - a = a[1:] - b = b[1:] - - assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" - - if not common_ancestor.collocated(a, b): - source = ".".join(s.name for s in path_to_source) - target = ".".join(s.name for s in path_to_target) - raise ValueError( - f"Source ${{{source}}} and target ${{{target}}} are not collocated (i.e. they are not branch of a 'concat')" - ) - - target = ".".join(s.name for s in path_to_target) - return f"${{{target}}}" - - @property - def description(self): - return self._description - - @description.setter - def description(self, value): - self._description = value - - @property - def attribution(self): - return self._attribution - - @attribution.setter - def attribution(self, value): - self._attribution = value - - @property - def licence(self): - return self._licence - - @licence.setter - def licence(self, value): - self._licence = value - - @property - def name(self): - return self._name - - @name.setter - def name(self, value): - self._name = value - - @property - def dates(self): - return self._dates - - def _parse_dates(self, value): - - if isinstance(value, dict): - return value - - start = None - end = None - frequency = 1 - - if isinstance(value, (list, tuple)): - if len(value) in [2, 3]: - start = value[0] - end = value[1] - - if len(value) == 3: - frequency = frequency_to_string(frequency_to_timedelta(value[2])) - if isinstance(frequency, int): - frequency = f"{frequency}h" - - if start is None or end is None: - raise ValueError(f"Invalid dates {value}") - - if isinstance(frequency, int): - frequency = f"{frequency}h" - - return dict( - start=as_datetime(start), - end=as_datetime(end), - frequency=frequency, - ) - - @dates.setter - def dates(self, value): - self._dates = 
self._parse_dates(value) - - @property - def output(self): - return self._output - - @output.setter - def output(self, value): - self._output = value - - @property - def statistics(self): - return self._statistics - - @statistics.setter - def statistics(self, value): - self._statistics = value - - @property - def build(self): - return self._build - - @build.setter - def build(self, value): - self._build = value - - @property - def env(self): - return self._env - - @env.setter - def env(self, value): - self._env = value - - @property - def dataset_status(self): - return self._dataset_status - - @dataset_status.setter - def dataset_status(self, value): - self._dataset_status = value - - @property - def platform(self): - return self._platform - - @platform.setter - def platform(self, value): - self._platform = value - - def dump(self, file=sys.stdout): - input = self.input.as_dict(self) # First so we get the data_sources - - result = self.as_dict() - - result["input"] = input - - if self.output: - result["output"] = self.output - - if self.statistics: - result["statistics"] = self.statistics - - if self.build: - result["build"] = self.build - - if self.env: - result["env"] = self.env - - if self.dataset_status: - result["dataset_status"] = self.dataset_status - - if self.platform: - result["platform"] = self.platform - - from .dumper import yaml_dump - - yaml_dump(_un_dotdict(result), stream=file) - - def test(self, output="recipe.zarr"): - from argparse import ArgumentParser - - from anemoi.datasets.commands.create import command - - parser = ArgumentParser() - parser.add_argument("command", help="Command to run") - - cmd = command() - cmd.add_arguments(parser) - - with TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "recipe.yaml") - with open(path, "w") as file: - self.dump(file) - - args = parser.parse_args(["create", path, output, "--overwrite", "--test"]) - cmd.run(args) - - -if __name__ == "__main__": - - r = Recipe() - r.description = "test" - - 
r.dates = ("2023-01-01 00:00:00", "2023-12-31 18:00:00", "6h") - - m1 = r.mars(expver="0001", grid=[20, 20]) - m2 = r.mars(expver="0002") - m3 = r.mars(expver="0003") - - r.input = m1 - - r.input += r.forcings(template=m1, param=["cos_latitude", "sin_latitude"]) - - # m0 = r.mars(expver="0000") - # c = r.concat( - # { - # ("190", "2000"): m0, - # ("2001", "2020"): r.mars(expver="0002"), - # ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - # }, - # ) - - # c[("2031", "2033")] = r.mars(expver="0005") - - # r.input += c - - r.output.group_by = "day" - r.build.additions = True - r.statistics.end = "80%" - - r.dump() - r.test() From 5746b04798124cccf47c3a8393cfe5e79690eb14 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 21 Sep 2025 18:41:57 +0100 Subject: [PATCH 65/79] remove python generating code that will be in another PR --- src/anemoi/datasets/commands/recipe/__init__.py | 2 +- src/anemoi/datasets/create/__init__.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 7630547bc..45400806c 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,7 +15,6 @@ import yaml -from anemoi.datasets.create import config_to_python from anemoi.datasets.create import validate_config from .. 
import Command @@ -90,4 +89,5 @@ def run(self, args: Any) -> None: print(formatted, file=f) f.close() + command = Recipe diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 352b5a6d6..acaf3807d 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1658,5 +1658,3 @@ def _tidy(d): LOG.error("❌ Config validation failed (jsonschema):") LOG.error(e.message) raise - - From e4712f0c56184983b1178b642eecfa51426c3cf6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 21 Sep 2025 18:51:18 +0100 Subject: [PATCH 66/79] remove python generating code that will be in another PR --- src/anemoi/datasets/create/input/__init__.py | 3 - src/anemoi/datasets/create/input/action.py | 27 --- .../datasets/create/input/data_sources.py | 5 - src/anemoi/datasets/dates/__init__.py | 80 +++------ src/anemoi/datasets/dates/groups.py | 170 +++++------------- 5 files changed, 71 insertions(+), 214 deletions(-) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 22b98d07e..07eee7c79 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -60,9 +60,6 @@ def select(self, argument) -> Any: context = FieldContext(argument, **self.kwargs) return context.create_result(self.action(context, argument)) - def python_code(self, code): - return self.action.python_code(code) - def build_input(config: dict, data_sources: dict | list, **kwargs: Any) -> InputBuilder: """Build an InputBuilder instance. 
diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7e164c586..6ad2aa822 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -55,11 +55,6 @@ def __call__(self, context, argument): return context.register(results, self.path) - def python_code(self, code): - return code.concat( - {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} - ) - class Join(Action): def __init__(self, config, *path): @@ -80,9 +75,6 @@ def __call__(self, context, argument): return context.register(results, self.path) - def python_code(self, code) -> None: - return code.sum(a.python_code(code) for a in self.actions) - class Pipe(Action): def __init__(self, config, *path): @@ -104,9 +96,6 @@ def __call__(self, context, argument): return context.register(result, self.path) - def python_code(self, code) -> None: - return code.pipe(a.python_code(code) for a in self.actions) - class Function(Action): def __init__(self, config, *path): @@ -122,13 +111,6 @@ def __call__(self, context, argument): return context.register(self.call_object(context, source, argument), self.path) - def python_code(self, code) -> str: - # For now... 
- if "source" in self.config: - source = action_factory(self.config["source"], *self.path, "source") - self.config["source"] = source.python_code(code) - return code.call(self.name, self.config) - class DatasetSourceMixin: def create_object(self, context, config): @@ -193,9 +175,6 @@ def __init__(self, config, *path): else: self.sources = {i: action_factory(v, *path, str(i)) for i, v in enumerate(config)} - def python_code(self, code): - return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) - def __call__(self, context, argument): for name, source in self.sources.items(): context.register(source(context, argument), self.path + (name,)) @@ -206,12 +185,6 @@ def __init__(self, input, data_sources): self.input = input self.data_sources = data_sources - def python_code(self, code): - return code.recipe( - self.input.python_code(code), - self.data_sources.python_code(code), - ) - def __call__(self, context, argument): # Load data_sources self.data_sources(context, argument) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 9aa2429dd..31bf3d8cc 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -84,11 +84,6 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) - def python_code(self, code) -> str: - for n, s in zip(self.names, self.sources): - code.source(n, s.python_code(code)) - return code - class DataSourcesResult(Result): """Class to represent the result of data sources actions in the dataset creation process.""" diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 223736971..bc6dacafd 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -27,15 +27,13 @@ def extend(x: str | list[Any] | tuple[Any, ...]) -> Iterator[datetime.datetime]: """Extend a date range or list 
of dates into individual datetime objects. - Parameters - ---------- - x : Union[str, List[Any], Tuple[Any, ...]] - A date range string or list/tuple of dates. - - Yields - ------ - datetime.datetime - Individual datetime objects. + Args: + x (Union[str, List[Any], Tuple[Any, ...]]): A date range string or list/tuple of dates. + + Returns + ------- + Iterator[datetime.datetime] + An iterator of datetime objects. """ if isinstance(x, (list, tuple)): @@ -162,12 +160,9 @@ def summary(self) -> str: class ValuesDates(DatesProvider): """Class for handling a list of date values. - Parameters - ---------- - values : List[Union[str, datetime.datetime]] - List of date values. - **kwargs : Any - Additional arguments. + Args: + values (List[Union[str, datetime.datetime]]): List of date values. + **kwargs (Any): Additional arguments. """ def __init__(self, values: list[str | datetime.datetime], **kwargs: Any) -> None: @@ -204,16 +199,11 @@ def as_dict(self) -> dict[str, Any]: class StartEndDates(DatesProvider): """Class for generating dates between a start and end date with a specified frequency. - Parameters - ---------- - start : Union[str, datetime.datetime] - Start date. - end : Union[str, datetime.datetime] - End date. - frequency : Union[int, str] - Frequency of dates. - **kwargs : Any - Additional arguments. + Args: + start (Union[str, datetime.datetime]): Start date. + end (Union[str, datetime.datetime]): End date. + frequency (Union[int, str]): Frequency of dates. + **kwargs (Any): Additional arguments. 
""" def __repr__(self) -> str: @@ -280,27 +270,15 @@ def as_dict(self) -> dict[str, Any]: "frequency": frequency_to_string(self.frequency), }.update(self.kwargs) - def to_python(self) -> str: - """Convert the StartEndDates instance to a tuple of ISO-formatted date strings.""" - if self.frequency == datetime.timedelta(hours=1): - return (self.start.isoformat(), self.end.isoformat()) - else: - return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) - class Hindcast: """Class representing a single hindcast date. - Parameters - ---------- - date : datetime.datetime - The date of the hindcast. - refdate : datetime.datetime - The reference date. - hdate : datetime.datetime - The hindcast date. - step : int - The step value. + Args: + date (datetime.datetime): The date of the hindcast. + refdate (datetime.datetime): The reference date. + hdate (datetime.datetime): The hindcast date. + step (int): The step value. """ def __init__( @@ -323,18 +301,12 @@ def __init__( class HindcastsDates(DatesProvider): """Class for generating hindcast dates over a range of years. - Parameters - ---------- - start : Union[str, List[str]] - Start date(s). - end : Union[str, List[str]] - End date(s). - steps : List[int] - List of step values. - years : int - Number of years. - **kwargs : Any - Additional arguments. + Args: + start (Union[str, List[str]]): Start date(s). + end (Union[str, List[str]]): End date(s). + steps (List[int]): List of step values. + years (int): Number of years. + **kwargs (Any): Additional arguments. """ def __init__( diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index b5454dbbd..a72fdaa74 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -24,15 +24,11 @@ def _shorten(dates: list[datetime.datetime] | tuple[datetime.datetime, ...]) -> str | list[str]: """Shorten the list of dates for display. 
- Parameters - ---------- - dates : Union[List[datetime.datetime], Tuple[datetime.datetime, ...]] - The list of dates. - - Returns - ------- - Union[str, List[str]] - The shortened list of dates. + Args: + dates (Union[List[datetime.datetime], Tuple[datetime.datetime, ...]]): The list of dates. + + Returns: + Union[str, List[str]]: The shortened list of dates. """ if isinstance(dates, (list, tuple)): dates = [d.isoformat() for d in dates] @@ -45,17 +41,6 @@ class GroupOfDates: """A class to represent a group of dates.""" def __init__(self, dates: list[datetime.datetime], provider: DatesProvider, partial_ok: bool = False) -> None: - """Initialise a GroupOfDates instance. - - Parameters - ---------- - dates : List[datetime.datetime] - List of dates. - provider : DatesProvider - The dates provider. - partial_ok : bool, optional - Whether partial groups are allowed (default is False). - """ assert isinstance(provider, DatesProvider), type(provider) assert isinstance(dates, list) @@ -66,60 +51,35 @@ def __init__(self, dates: list[datetime.datetime], provider: DatesProvider, part def __len__(self) -> int: """Return the number of dates in the group. - Returns - ------- - int - The number of dates. + Returns: + int: The number of dates. """ return len(self.dates) - def __getitem__(self, index) -> datetime.datetime: - """Return the date at the specified index. - - Parameters - ---------- - index : int - The index of the date. - - Returns - ------- - datetime.datetime - The date at the specified index. - """ - return self.dates[index] - def __iter__(self) -> Iterator[datetime.datetime]: """Return an iterator over the dates in the group. - Returns - ------- - Iterator[datetime.datetime] - The iterator over the dates. + Returns: + Iterator[datetime.datetime]: The iterator over the dates. """ return iter(self.dates) def __repr__(self) -> str: """Return a string representation of the group of dates. - Returns - ------- - str - The string representation. 
+ Returns: + str: The string representation. """ return f"GroupOfDates(dates={_shorten(self.dates)})" def __eq__(self, other: object) -> bool: """Check if two groups of dates are equal. - Parameters - ---------- - other : object - The other group of dates. + Args: + other (object): The other group of dates. - Returns - ------- - bool - True if the groups are equal, False otherwise. + Returns: + bool: True if the groups are equal, False otherwise. """ return isinstance(other, GroupOfDates) and self.dates == other.dates @@ -127,8 +87,7 @@ def __eq__(self, other: object) -> bool: class Groups: """A collection of groups of dates. - Examples - -------- + Examples: >>> list(Groups(group_by="daily", start="2023-01-01 00:00", end="2023-01-05 00:00", frequency=12))[0] [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 1, 12, 0)] @@ -160,8 +119,7 @@ def __init__(self, **kwargs: Any) -> None: Parameters ---------- - **kwargs : Any - Arbitrary keyword arguments. Expected keys include: + **kwargs : Any : Arbitrary keyword arguments. Expected keys include: - group_by: Configuration for the Grouper. - Other keys for DatesProvider configuration. """ @@ -179,10 +137,8 @@ def provider(self) -> DatesProvider: def __iter__(self) -> Iterator[GroupOfDates]: """Return an iterator over the groups of dates. - Returns - ------- - Iterator[GroupOfDates] - The iterator over the groups of dates. + Returns: + Iterator[GroupOfDates]: The iterator over the groups of dates. """ for go in self._grouper(self._dates): dates = self._filter(go.dates) @@ -193,10 +149,8 @@ def __iter__(self) -> Iterator[GroupOfDates]: def __len__(self) -> int: """Return the number of groups of dates. - Returns - ------- - int - The number of groups. + Returns: + int: The number of groups. """ return self._len @@ -214,30 +168,24 @@ def _len(self) -> int: def __repr__(self) -> str: """Return a string representation of the groups of dates. - Returns - ------- - str - The string representation. 
+ Returns: + str: The string representation. """ return f"{self.__class__.__name__}(dates={len(self)},{_shorten(self._dates)})" def describe(self) -> str: """Return a summary description of the dates. - Returns - ------- - str - The summary description. + Returns: + str: The summary description. """ return self._dates.summary def one_date(self) -> GroupOfDates: """Return a group containing only one date. - Returns - ------- - GroupOfDates - The group containing only one date. + Returns: + GroupOfDates: The group containing only one date. """ go = next(iter(self)) return GroupOfDates([go.dates[0]], go.provider) @@ -252,24 +200,22 @@ def __init__(self, missing: list[datetime.datetime]) -> None: def __call__(self, dates: list[datetime.datetime]) -> list[datetime.datetime]: """Filter out missing dates from the list of dates. - Parameters - ---------- - dates : List[datetime.datetime] - The list of dates. + Args: + dates (List[datetime.datetime]): The list of dates. - Returns - ------- - List[datetime.datetime] - The filtered list of dates. + Returns: + List[datetime.datetime]: The filtered list of dates. """ return [d for d in dates if d not in self.missing] class Grouper(ABC): + """Abstract base class for grouping dates.""" @classmethod def from_config(cls, group_by: Any) -> "Grouper": """Create a grouper based on the configuration.""" + if isinstance(group_by, int) and group_by > 0: return GrouperByFixedSize(group_by) @@ -329,15 +275,11 @@ class GrouperOneGroup(Grouper): def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group all dates into a single group. - Parameters - ---------- - dates : DatesProvider - The dates provider. + Args: + dates (DatesProvider): The dates provider. - Returns - ------- - Iterator[GroupOfDates] - The iterator over the groups of dates. + Returns: + Iterator[GroupOfDates]: The iterator over the groups of dates. 
""" assert isinstance(dates, DatesProvider), type(dates) @@ -348,27 +290,16 @@ class GrouperByKey(Grouper): """Group dates by a key.""" def __init__(self, key: Callable[[datetime.datetime], Any]) -> None: - """Initialise GrouperByKey with a key function. - - Parameters - ---------- - key : Callable[[datetime.datetime], Any] - Function to extract grouping key from a datetime. - """ self.key = key def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group dates based on the provided key. - Parameters - ---------- - dates : DatesProvider - The dates provider. + Args: + dates (DatesProvider): The dates provider. - Returns - ------- - Iterator[GroupOfDates] - The iterator over the groups of dates. + Returns: + Iterator[GroupOfDates]: The iterator over the groups of dates. """ for _, g in itertools.groupby(sorted(dates, key=self.key), key=self.key): yield GroupOfDates(list(g), dates) @@ -378,27 +309,16 @@ class GrouperByFixedSize(Grouper): """Group dates by a fixed size.""" def __init__(self, size: int) -> None: - """Initialise GrouperByFixedSize with batch size. - - Parameters - ---------- - size : int - Number of dates per group. - """ self.size = size def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group dates into fixed-size batches. - Parameters - ---------- - dates : DatesProvider - The dates provider. + Args: + dates (DatesProvider): The dates provider. - Returns - ------- - Iterator[GroupOfDates] - The iterator over the groups of dates. + Returns: + Iterator[GroupOfDates]: The iterator over the groups of dates. 
""" batch = [] From 976c6a465e155d4b4191573a2eed00de4468edaf Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 21 Sep 2025 19:05:13 +0100 Subject: [PATCH 67/79] feat: python recipes --- docs/cli/grib-index.rst | 2 +- docs/cli/introduction.rst | 1 - docs/datasets/building/code/using-python-1.py | 3 + docs/datasets/building/code/using-python-2.py | 8 + docs/datasets/building/code/using-python-3.py | 12 + docs/datasets/building/code/using-python-4.py | 0 docs/datasets/building/filters.rst | 14 + docs/datasets/building/introduction.rst | 11 + .../building/sources/repeated-dates.rst | 2 +- docs/datasets/building/using-python.rst | 26 + .../datasets/commands/recipe/__init__.py | 21 +- src/anemoi/datasets/create/__init__.py | 24 + src/anemoi/datasets/create/config.py | 7 + src/anemoi/datasets/create/input/__init__.py | 3 + src/anemoi/datasets/create/input/action.py | 42 +- .../datasets/create/input/data_sources.py | 5 + src/anemoi/datasets/create/python.py | 578 ++++++++++++++++++ src/anemoi/datasets/recipe.py | 531 ++++++++++++++++ 18 files changed, 1283 insertions(+), 7 deletions(-) create mode 100644 docs/datasets/building/code/using-python-1.py create mode 100644 docs/datasets/building/code/using-python-2.py create mode 100644 docs/datasets/building/code/using-python-3.py create mode 100644 docs/datasets/building/code/using-python-4.py create mode 100644 docs/datasets/building/filters.rst create mode 100644 docs/datasets/building/using-python.rst create mode 100644 src/anemoi/datasets/create/python.py create mode 100644 src/anemoi/datasets/recipe.py diff --git a/docs/cli/grib-index.rst b/docs/cli/grib-index.rst index 2b97b0157..61684dbfb 100644 --- a/docs/cli/grib-index.rst +++ b/docs/cli/grib-index.rst @@ -1,7 +1,7 @@ .. _grib-index_command: Grib-index Command -============ +================== The `grib-index` command is used to create an index file for GRIB files. The index file is then used by the `grib-index` :ref:`source `. 
diff --git a/docs/cli/introduction.rst b/docs/cli/introduction.rst index 45facf21e..8c574132c 100644 --- a/docs/cli/introduction.rst +++ b/docs/cli/introduction.rst @@ -20,5 +20,4 @@ The commands are: - :ref:`Inspect Command ` - :ref:`Compare Command ` - :ref:`Scan Command ` -- :ref:`Validate Command ` - :ref:`Compare LAM Command ` diff --git a/docs/datasets/building/code/using-python-1.py b/docs/datasets/building/code/using-python-1.py new file mode 100644 index 000000000..196a25f42 --- /dev/null +++ b/docs/datasets/building/code/using-python-1.py @@ -0,0 +1,3 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() diff --git a/docs/datasets/building/code/using-python-2.py b/docs/datasets/building/code/using-python-2.py new file mode 100644 index 000000000..717129592 --- /dev/null +++ b/docs/datasets/building/code/using-python-2.py @@ -0,0 +1,8 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe( + description="Example dataset recipe", + name="example-dataset", + licence="CC-BY-4.0", + attribution="my-organisation", +) diff --git a/docs/datasets/building/code/using-python-3.py b/docs/datasets/building/code/using-python-3.py new file mode 100644 index 000000000..f21dc3947 --- /dev/null +++ b/docs/datasets/building/code/using-python-3.py @@ -0,0 +1,12 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.description = """ +Example dataset recipe using Python, with attributes set one by one +and a multiline description. +""" + +r.name = "example-dataset" +r.licence = "CC-BY-4.0" +r.attribution = "my-organisation" diff --git a/docs/datasets/building/code/using-python-4.py b/docs/datasets/building/code/using-python-4.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs/datasets/building/filters.rst b/docs/datasets/building/filters.rst new file mode 100644 index 000000000..3b3bd5abf --- /dev/null +++ b/docs/datasets/building/filters.rst @@ -0,0 +1,14 @@ +.. _filters: + +######### + Filters +######### + +.. 
warning:: + + This is still a work-in-progress. Some of the filters may be renamed + later. + +Filters are used to modify the data or metadata in a dataset. + +See :ref:`install ` for more information. diff --git a/docs/datasets/building/introduction.rst b/docs/datasets/building/introduction.rst index 71107baeb..2054c75c5 100644 --- a/docs/datasets/building/introduction.rst +++ b/docs/datasets/building/introduction.rst @@ -105,3 +105,14 @@ operations can be combined to build complex datasets. :caption: Naming Conventions naming-conventions + +**************** + Python recipes +**************** + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Python recipes + + using-python diff --git a/docs/datasets/building/sources/repeated-dates.rst b/docs/datasets/building/sources/repeated-dates.rst index ba16e4707..53baf3283 100644 --- a/docs/datasets/building/sources/repeated-dates.rst +++ b/docs/datasets/building/sources/repeated-dates.rst @@ -10,7 +10,7 @@ dates of the dataset. The general format of the `repeated-dates` source is: -.. literalinclude:: yaml/repeated_dates1.yaml +.. literalinclude:: yaml/repeated-dates1.yaml :language: yaml where ``source`` is any of the :ref:`operations ` or diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst new file mode 100644 index 000000000..fbf2892cf --- /dev/null +++ b/docs/datasets/building/using-python.rst @@ -0,0 +1,26 @@ +############################# + Using Python define recipes +############################# + +You can use Python to define recipes for building datasets. This allows +for more complex logic and flexibility compared to using static +configuration files. + +When executed, the Python code will generate a YAML configuration that +can be used by the dataset building tool. + +Here is an example of how to define a dataset recipe using Python. + +First create a ``Recipe`` object, which will hold the configuration: + +.. 
literalinclude:: code/using-python-1.py + :language: python + +.. literalinclude:: code/using-python-2.py + :language: python + +.. literalinclude:: code/using-python-3.py + :language: python + +.. literalinclude:: code/using-python-4.py + :language: python diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 45400806c..bf08d1ee7 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,6 +15,7 @@ import yaml +from anemoi.datasets.create import config_to_python from anemoi.datasets.create import validate_config from .. import Command @@ -37,6 +38,7 @@ def add_arguments(self, command_parser: Any) -> None: command_parser.add_argument("--validate", action="store_true", help="Validate recipe.") command_parser.add_argument("--format", action="store_true", help="Format the recipe.") command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") + command_parser.add_argument("--python", action="store_true", help="Convert the recipe to a Python script.") group = command_parser.add_mutually_exclusive_group() group.add_argument("--inplace", action="store_true", help="Overwrite the recipe file in place.") @@ -49,7 +51,7 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: - if not args.validate and not args.format and not args.migrate: + if not args.validate and not args.format and not args.migrate and not args.python: args.validate = True with open(args.path) as file: @@ -58,10 +60,10 @@ def run(self, args: Any) -> None: assert isinstance(config, dict) if args.validate: - if args.inplace and (not args.format and not args.migrate): + if args.inplace and (not args.format and not args.migrate and not args.python): argparse.ArgumentError(None, "--inplace is not supported with --validate.") - if args.output and (not args.format and not args.migrate): + if args.output 
and (not args.format and not args.migrate and not args.python): argparse.ArgumentError(None, "--output is not supported with --validate.") validate_config(config) @@ -89,5 +91,18 @@ def run(self, args: Any) -> None: print(formatted, file=f) f.close() + if args.python: + if args.inplace: + argparse.ArgumentError(None, "Inplace conversion to Python is not supported.") + + if args.format: + raise argparse.ArgumentError(None, "Formatting is not supported when converting to Python.") + + if args.output: + with open(args.output, "w") as file: + file.write(config_to_python(config)) + else: + print(config_to_python(config)) + command = Recipe diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index acaf3807d..3c615e21f 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1658,3 +1658,27 @@ def _tidy(d): LOG.error("❌ Config validation failed (jsonschema):") LOG.error(e.message) raise + + +def config_to_python(config: Any) -> Any: + + from ..create.python import PythonScript + + raw_config = config + + config = loader_config(config) + + input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) + + code = PythonScript() + x = input.python_code(code) + code = code.source_code(x, raw_config) + + try: + import black + + return black.format_str(code, mode=black.Mode()) + # except ImportError: + except Exception: + LOG.warning("Black not installed, skipping formatting") + return code diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index ffee2662f..2e5f27de7 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -10,6 +10,8 @@ import datetime import logging import os +import subprocess +import sys from copy import deepcopy from typing import Any @@ -402,6 +404,11 @@ def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: LoadersConfig The validated configuration object. 
""" + + if isinstance(config, str) and config.endswith(".py"): + result = subprocess.run([sys.executable, config], capture_output=True, text=True, check=True) + config = yaml.safe_load(result.stdout) + config = Config(config) if is_test: set_to_test_mode(config) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 07eee7c79..22b98d07e 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -60,6 +60,9 @@ def select(self, argument) -> Any: context = FieldContext(argument, **self.kwargs) return context.create_result(self.action(context, argument)) + def python_code(self, code): + return self.action.python_code(code) + def build_input(config: dict, data_sources: dict | list, **kwargs: Any) -> InputBuilder: """Build an InputBuilder instance. diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 6ad2aa822..bd037c0ea 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -8,13 +8,15 @@ # nor does it submit to any jurisdiction. 
import logging +from abc import ABC +from abc import abstractmethod from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) -class Action: +class Action(ABC): def __init__(self, config, *path): self.config = config self.path = path @@ -23,6 +25,17 @@ def __init__(self, config, *path): "data_sources", ), f"{self.__class__.__name__}: path must start with 'input' or 'data_sources': {path}" + @abstractmethod + def __call__(self, context, argument): + pass + + @abstractmethod + def python_code(self, code): + pass + + def __repr__(self): + return f"{self.__class__.__name__}({'.'.join(str(x) for x in self.path)}, {self.config})" + class Concat(Action): def __init__(self, config, *path): @@ -55,6 +68,11 @@ def __call__(self, context, argument): return context.register(results, self.path) + def python_code(self, code): + return code.concat( + {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} + ) + class Join(Action): def __init__(self, config, *path): @@ -75,6 +93,9 @@ def __call__(self, context, argument): return context.register(results, self.path) + def python_code(self, code) -> None: + return code.sum(a.python_code(code) for a in self.actions) + class Pipe(Action): def __init__(self, config, *path): @@ -96,6 +117,9 @@ def __call__(self, context, argument): return context.register(result, self.path) + def python_code(self, code) -> None: + return code.pipe(a.python_code(code) for a in self.actions) + class Function(Action): def __init__(self, config, *path): @@ -111,6 +135,13 @@ def __call__(self, context, argument): return context.register(self.call_object(context, source, argument), self.path) + def python_code(self, code) -> str: + # For now... 
+ if "source" in self.config: + source = action_factory(self.config["source"], *self.path, "source") + self.config["source"] = source.python_code(code) + return code.call(self.name, self.config) + class DatasetSourceMixin: def create_object(self, context, config): @@ -175,6 +206,9 @@ def __init__(self, config, *path): else: self.sources = {i: action_factory(v, *path, str(i)) for i, v in enumerate(config)} + def python_code(self, code): + return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) + def __call__(self, context, argument): for name, source in self.sources.items(): context.register(source(context, argument), self.path + (name,)) @@ -185,6 +219,12 @@ def __init__(self, input, data_sources): self.input = input self.data_sources = data_sources + def python_code(self, code): + return code.recipe( + self.input.python_code(code), + self.data_sources.python_code(code), + ) + def __call__(self, context, argument): # Load data_sources self.data_sources(context, argument) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 31bf3d8cc..9aa2429dd 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -84,6 +84,11 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) + def python_code(self, code) -> str: + for n, s in zip(self.names, self.sources): + code.source(n, s.python_code(code)) + return code + class DataSourcesResult(Result): """Class to represent the result of data sources actions in the dataset creation process.""" diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py new file mode 100644 index 000000000..29b8c611d --- /dev/null +++ b/src/anemoi/datasets/create/python.py @@ -0,0 +1,578 @@ +# (C) Copyright 2025 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import re +from collections import defaultdict +from functools import cached_property + +LOG = logging.getLogger(__name__) + +RESERVED_KEYWORDS = ( + "and", + "or", + "not", + "is", + "in", + "if", + "else", + "elif", + "for", + "while", + "return", + "class", + "def", + "with", + "as", + "import", + "from", + "try", + "except", + "finally", + "raise", + "assert", + "break", + "continue", + "pass", +) + + +def _sanitize_name(name): + name = name.replace("-", "_") + if name in RESERVED_KEYWORDS: + name = f"{name}_" + return name + + +def _un_dotdict(x): + if isinstance(x, dict): + return {k: _un_dotdict(v) for k, v in x.items()} + + if isinstance(x, (list, tuple, set)): + return [_un_dotdict(a) for a in x] + + return x + + +class PythonCode: + + def __init__(self, top): + self.top = top + self.top.register(self) + self.key = str(id(self)) + + def call(self, name, argument): + return PythonCall(self.top, name, argument) + + def sum(self, actions): + return PythonChain(self.top, "join", "&", actions) + + def pipe(self, actions): + return PythonChain(self.top, "pipe", "|", actions) + + def concat(self, argument): + return PythonConcat(self.top, argument) + + def source_code(self): + return self.top.source_code(self) + + def combine(self, nodes): + return None + + def recipe(self, input, data_sources): + return PythonRecipe(self.top, input, data_sources) + + def prelude(self): + return None + + def sources(self, sources): + return PythonSources(self.top, sources) + + def update_anchor(self): + pass + + +class Variable(PythonCode): + def __init__(self, name, node): + super().__init__(top=node.top) 
+ self.name = name + self.node = node + + def __repr__(self): + return "" + + def replace_node(self, old, new): + pass + + def prelude(self): + return [f"{self.name} = {repr(self.node)}", ""] + + +class InLine(PythonCode): + def __init__(self, node): + super().__init__(top=node.top) + self.node = node + + @cached_property + def name(self): + n = self.top.counter["_anchor"] + self.top.counter["_anchor"] += 1 + return f"_a{n}" + + def __repr__(self): + return f"({self.name} := {repr(self.node)})" + + def replace_node(self, old, new): + pass + + +class PythonRecipe(PythonCode): + def __init__(self, top, input, data_sources): + super().__init__(top) + self.input = input + self.data_sources = data_sources + + def apply_references(self, *path): + self.data_sources.apply_references(*path, "data_sources") + self.input.apply_references(*path, "input") + + def replace_node(self, old, new): + if self.input is old: + self.input = new + return + + if self.data_sources is old: + self.data_sources = new + return + + self.input.replace_node(old, new) + self.data_sources.replace_node(old, new) + + def __repr__(self): + return repr(self.input) + + def prelude(self): + return self.data_sources.prelude() + + +class Argument(PythonCode): + + def __init__(self, top, name): + super().__init__(top=top) + self.name = _sanitize_name(name) + + def __repr__(self): + return self.name + + def replace_node(self, old, new): + pass + + +class Anchor(PythonCode): + + def __init__(self, identifier): + super().__init__(top=identifier.node.top) + self.identifier = identifier + + @property + def name(self): + return self.identifier.name + + def __repr__(self): + # assert False + return repr(self.identifier) + + def replace_node(self, old, new): + pass + + +class Reference(PythonCode): + + def __init__(self, top, path): + super().__init__(top) + self.path = tuple(path) + self.anchor = None + + def update_anchor(self): + + node = self.top.by_reference.get(self.path, None) + if node is None: + 
LOG.warning(f"Reference {self.path} not found") + for p in sorted(self.top.by_reference): + LOG.warning(f" - {p}") + else: + self.anchor = Anchor(node) + self.top.replace_nodes([(node.node, self.anchor)]) + + def __repr__(self): + if self.anchor is not None: + return self.anchor.name + + return f"'${{{'.'.join(self.path)}}}'" + + def replace_node(self, old, new): + pass + + +class Function(PythonCode): + def __init__(self, name, node, counter): + super().__init__(top=node.top) + self._name = name + self.node = node + self.used = False + self.counter = counter + + def __repr__(self): + return self.name + + def prelude(self): + if self.used: + return None + + self.used = True + + node_prelude = self.node.prelude() + + arguments = self.node.free_arguments() + + return [ + *(node_prelude if node_prelude else []), + f"def {self.name}({','.join(repr(p) for p in arguments)}):", + f" return {self.node}", + ] + + def free_arguments(self): + return self.node.free_arguments() + + @cached_property + def name(self): + n = self.counter[self._name] + self.counter[self._name] += 1 + if n == 0: + return _sanitize_name(self._name) + return _sanitize_name(f"{self._name}_{n}") + + def replace_node(self, old, new): + if self.node is old: + self.node = new + + +class PythonSources(PythonCode): + def __init__(self, top, sources): + super().__init__(top) + self.sources = sources + + def __repr__(self): + return "" + + def prelude(self): + pass + + def replace_node(self, old, new): + for k, v in list(self.sources.items()): + if v is old: + self.sources[k] = new + else: + v.replace_node(old, new) + + def apply_references(self, *path): + for k, v in self.sources.items(): + self.top.by_reference[path + (k,)] = Variable(k, v) + + +class PythonConcat(PythonCode): + def __init__(self, top, argument): + super().__init__(top=top) + self.argument = _un_dotdict(argument) + + def __repr__(self): + return f"r.concat({self.argument})" + + def replace_node(self, old, new): + for k, v in 
list(self.argument.items()): + if v is old: + self.argument[k] = new + else: + v.replace_node(old, new) + + def apply_references(self, *path): + assert "concat" not in path, path + self.top.by_reference[path + ("concat",)] = InLine(self) + for i, node in enumerate(self.argument.values()): + node.apply_references(*path, "concat", str(i)) + + +class PythonChain(PythonCode): + def __init__(self, top, kind, op, actions): + super().__init__(top=top) + self.op = op + self.kind = kind + self.actions = list(actions) + self.key = op + + def __repr__(self): + return "(" + self.op.join(repr(x) for x in self.actions) + ")" + + def replace_node(self, old, new): + + for i, node in enumerate(self.actions): + + if node is old: + self.actions[i] = new + else: + node.replace_node(old, new) + + def apply_references(self, *path): + self.top.by_reference[path + (self.kind,)] = InLine(self) + for i, node in enumerate(self.actions): + node.apply_references(*path, self.kind, str(i)) + + +class PythonCall(PythonCode): + def __init__(self, top, name, argument): + super().__init__(top=top) + self.name = name + self.argument = _un_dotdict(argument) + self.key = name + + def free_arguments(self): + result = [] + for k, v in self.argument.items(): + if isinstance(v, Argument): + result.append(v) + return result + + def __repr__(self): + name = self.name.replace("-", "_") + config = dict(**self.argument) + + params = [] + + for k, v in config.items(): + k = _sanitize_name(k) + + if not k.isidentifier(): + return f"r.{name}({config})" + + params.append(f"{k}={repr(v)}") + + if params: + params.append("") # For a trailing comma + + params = ",".join(params) + return f"r.{name}({params})" + + def replace_node(self, old, new): + pass + + def combine(self, nodes): + + # Exact similarity + + changes = self._combine0(nodes) + if changes: + return changes + + # On key difference + changes = self._combine1(nodes) + if changes: + return changes + + def _combine0(self, nodes): + + x = defaultdict(list) + 
for node in nodes: + key = {k2: v2 for k2, v2 in sorted(node.argument.items())} + x[str(key)].append(node) + + for i in sorted(x.values(), key=len, reverse=True): + node = i[0] + if len(i) < 2: + return + + call = PythonCall(self.top, self.name, node.argument) + + func = self.top.function(call) + changes = [] + for node in i: + + new = PythonFunction(top=self.top, func=func) + + changes.append((node, new)) + + return changes + + def _combine1(self, nodes): + + x = defaultdict(list) + for node in nodes: + argument = node.argument + for k, v in argument.items(): + rest = {k2: v2 for k2, v2 in sorted(argument.items()) if k2 != k} + x[str(rest)].append((k, v, node)) + + for i in sorted(x.values(), key=len, reverse=True): + key, value, node = i[0] + if len(i) < 2: + return + + rest = {k: v for k, v in node.argument.items() if k != key} + rest[key] = Argument(self.top, key) + call = PythonCall(self.top, self.name, rest) + + func = self.top.function(call) + changes = [] + for key, value, node in i: + + new = PythonFunction( + top=self.top, + func=func, + **{key: value}, + ) + + changes.append((node, new)) + + return changes + + def apply_references(self, *path): + self.top.by_reference[path + (self.name,)] = InLine(self) + + for k, v in self.argument.items(): + if isinstance(v, str) and (m := re.match(r"^\${(\w+(?:\.\w+)+)}$", v)): + path = m.group(1).split(".") + self.argument[k] = Reference(self.top, path) + + +class PythonFunction(PythonCode): + def __init__(self, top, func, **kwargs): + super().__init__(top=top) + self.func = func + self.kwargs = kwargs + + def __repr__(self): + + params = [] + for a in self.func.free_arguments(): + name = _sanitize_name(a.name) + if a.name in self.kwargs: + v = self.kwargs[a.name] + params.append(f"{name}={repr(v)}") + else: + params.append(f"{name}={name}") + + return f"{self.func}({', '.join(params)})" + + def replace_node(self, old, new): + self.func.replace_node(old, new) + + def prelude(self): + return self.func.prelude() + + 
def free_arguments(self): + return [a for a in self.func.free_arguments() if a.name not in self.kwargs] + + def apply_references(self, *path): + pass + + +class PythonScript(PythonCode): + + def __init__(self): + self.nodes = [] + self.counter = defaultdict(int) + self.by_reference = {} + super().__init__(top=self) + + def register(self, child): + if child is not self: + self.nodes.append(child) + + def prelude(self, config): + + from anemoi.datasets.recipe import Recipe + + SKIP = ( + "input", + "data_sources", + "common", + "aliases", + ) + + result = [] + + for k, v in config.items(): + + if k in SKIP: + continue + + if not hasattr(Recipe, k): + LOG.warning(f"Unknown key in recipe: {k}") + assert False, f"Unknown key in recipe: {k}" + continue + + result.append(f"r.{k} = {repr(v)}") + + for node in self.nodes: + prelude = node.prelude() + if prelude: + if not isinstance(prelude, (list, tuple)): + prelude = list(prelude) + result.extend(prelude) + return "\n".join(result) + + def source_code(self, first, config): + + which = self.nodes.index(first) + first.apply_references() + for node in self.nodes: + node.update_anchor() + + more = True + while more: + more = False + + by_class = defaultdict(list) + for node in self.nodes: + by_class[(node.__class__, node.key)].append(node) + + for nodes in by_class.values(): + if len(nodes) > 1: + changes = nodes[0].combine(nodes) + if changes: + self.replace_nodes(changes) + more = True + + first = self.nodes[which] + + return "\n\n".join( + [ + "# Generated Python code for Anemoi dataset creation", + "import datetime", + "from anemoi.datasets.recipe import Recipe", + "r = Recipe()", + self.prelude(config), + f"r.input = {repr(first)}", + "r.dump()", + ] + ) + + def function(self, node): + return Function(node.name, node, self.counter) + + def replace_nodes(self, changes): + + for old, new in changes: + assert old in self.nodes, f"Node {old} not found in {self.nodes}" + for i, node in enumerate(self.nodes): + + if node is 
old: + self.nodes[i] = new + else: + node.replace_node(old, new) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py new file mode 100644 index 000000000..6f8118174 --- /dev/null +++ b/src/anemoi/datasets/recipe.py @@ -0,0 +1,531 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import os +import sys +from collections import defaultdict +from tempfile import TemporaryDirectory + +from anemoi.transform.filters import filter_registry as transform_filter_registry +from anemoi.utils.config import DotDict +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.create.sources import source_registry + +LOG = logging.getLogger(__name__) + + +def _un_dotdict(x): + if isinstance(x, dict): + return {k: _un_dotdict(v) for k, v in x.items()} + + if isinstance(x, (list, tuple, set)): + return [_un_dotdict(a) for a in x] + + return x + + +class Index: + def __init__(self, index): + self.name = str(index) + + def __repr__(self): + return f"Index({self.name})" + + def same(self, other): + if not isinstance(other, Index): + return False + return self.name == other.name + + +class Step: + + def __or__(self, other): + return Pipe(self, other) + + def __and__(self, other): + return Join(self, other) + + def same(self, other): + return self is other + + +class Chain(Step): + def __init__(self, *args): + if len(args) > 0 and isinstance(args[0], self.__class__): + args = args[0].steps + args[1:] + + self.steps = args + self.index = [Index(i) for i in range(len(self.steps))] + + def 
as_dict(self, recipe): + if len(self.steps) == 1: + return self.steps[0].as_dict(recipe) + return {self.name: [s.as_dict(recipe) for s in self.steps]} + + def path(self, target, result, *path): + + if target is self: + result.append([*path, self]) + return + + for i, s in enumerate(self.steps): + s.path(target, result, *path, self, self.index[i]) + + def collocated(self, a, b): + return True + + +class Pipe(Chain): + name = "pipe" + + +class Join(Chain): + name = "join" + + +class Concat(Step): + name = "concat" + + def __init__(self, args): + assert isinstance(args, dict), f"Invalid argument {args}" + self.params = args + + def __setitem__(self, key, value): + self.params[key] = value + + def as_dict(self, recipe): + + result = [] + + for k, v in sorted(self.params.items()): + + key = dict(start=as_datetime(k[0]), end=as_datetime(k[1])) + if len(k) == 3: + key["frequency"] = k[2] + + result.append({"dates": key, **v.as_dict(recipe)}) + + return {"concat": result} + + def collocated(self, a, b): + return a[0].same(b[0]) + + def path(self, target, result, *path): + if target is self: + result.append([*path, self]) + return + for i, (k, v) in enumerate(sorted(self.params.items())): + v.path(target, result, *path, self, Index(i)) + + +class Base(Step): + def __init__(self, owner, *args, **kwargs): + self.owner = owner + self.name = owner.name + self.params = {} + for a in args: + assert isinstance(a, dict), f"Invalid argument {a}" + self.params.update(a) + self.params.update(kwargs) + + def as_dict(self, recipe): + + def resolve(params, recipe, name=None): + if isinstance(params, dict): + + def _(k): + if isinstance(k, str) and k.endswith("_"): + return k[:-1] + return k + + return {_(k): resolve(v, recipe, name=_(k)) for k, v in params.items()} + + if isinstance(params, (list, tuple)): + return [resolve(v, recipe) for v in params] + + if isinstance(params, list): + return [resolve(v, recipe) for v in params] + + if isinstance(params, Step): + return 
recipe.resolve(self, params, name=name) + + return params + + return {self.owner.name: resolve(self.params, recipe)} + + def path(self, target, result, *path): + if self is target: + result.append([*path, self]) + + +class Source(Base): + pass + + +class Filter(Base): + pass + + +class SourceMaker: + def __init__(self, name, factory): + self.name = name + self.factory = factory + + def __call__(self, *args, **kwargs): + return Source(self, *args, **kwargs) + + +class FilterMaker: + def __init__(self, name, factory): + self.name = name + self.factory = factory + + def __call__(self, *args, **kwargs): + if len(args) > 0 and isinstance(args[0], Step): + prev = args[0] + args = args[1:] + return Pipe(prev, Filter(self, *args, **kwargs)) + return Filter(self, *args, **kwargs) + + +class Recipe: + + def __init__(self, name=None, description=None, attribution=None, licence=None): + + self._description = description + self._attribution = attribution + self._licence = licence + self._name = name + self._dates = None + self._statistics = None + self._build = None + self._env = None + self._dataset_status = None + self._output = None + self._platform = None + + self.input = Join() + self.output = DotDict() + self.statistics = DotDict() + self.build = DotDict() + + self._data_sources = {} + self._counter = defaultdict(int) + + sources = source_registry.factories.copy() + filters = transform_filter_registry.factories.copy() + + for key, factory in sources.items(): + if key in filters: + LOG.warning( + f"Source `{key}` is registered in anemoi.datasets source registry and in anemoi.transform filter registry" + ) + del filters[key] + + for key, factory in sources.items(): + key = key.replace("-", "_") + assert not hasattr(self, key) + setattr(self, key, SourceMaker(key, factory)) + + for key, factory in filters.items(): + key = key.replace("-", "_") + assert not hasattr(self, key) + setattr(self, key, FilterMaker(key, factory)) + + self.repeated_dates = 
SourceMaker("repeated_dates", None) + + def as_dict(self): + result = { + "name": self.name, + "description": self.description, + "attribution": self.attribution, + "licence": self.licence, + "dates": self.dates, + "statistics": self.statistics, + "build": self.build, + } + + if self._data_sources: + result["data_sources"] = self._data_sources + + for k, v in list(result.items()): + if v is None: + del result[k] + + return result + + def concat(self, *args, **kwargs): + return Concat(*args, **kwargs) + + def make_data_source(self, name, target): + + target = target.as_dict(self) + + name = name or "source" + if name in self._data_sources: + if self._data_sources[name] == target: + return f"${{data_sources.{name}}}" + + n = self._counter[name] + self._counter[name] += 1 + + name = f"{name}_{n}" if n > 0 else name + + self._data_sources[name] = target.copy() + return f"${{data_sources.{name}}}" + + def resolve(self, source, target, name=None): + + top = Index("input") # So we have 'input' first in the path + + path_to_source = [] + self.input.path(source, path_to_source, top) + if len(path_to_source) == 0: + raise ValueError(f"Source {source} not found in recipe") + if len(path_to_source) > 1: + raise ValueError(f"Source {source} found in multiple locations {path_to_source}") + path_to_source = path_to_source[0] + + path_to_target = [] + self.input.path(target, path_to_target, top) + if len(path_to_target) > 1: + raise ValueError(f"Target {target} found in multiple locations {path_to_target}") + + if len(path_to_target) == 0: + # Add a `data_sources` entry + return self.make_data_source(name, target) + + path_to_target = path_to_target[0] + + a = [s for s in path_to_target] + b = [s for s in path_to_source] + common_ancestor = None + while a[0] is b[0]: + common_ancestor = a[0] + a = a[1:] + b = b[1:] + + assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" + + if not common_ancestor.collocated(a, b): + source = 
".".join(s.name for s in path_to_source) + target = ".".join(s.name for s in path_to_target) + raise ValueError( + f"Source ${{{source}}} and target ${{{target}}} are not collocated (i.e. they are not branch of a 'concat')" + ) + + target = ".".join(s.name for s in path_to_target) + return f"${{{target}}}" + + @property + def description(self): + return self._description + + @description.setter + def description(self, value): + self._description = value.strip() + + @property + def attribution(self): + return self._attribution + + @attribution.setter + def attribution(self, value): + self._attribution = value.strip() + + @property + def licence(self): + return self._licence + + @licence.setter + def licence(self, value): + self._licence = value.strip() + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._name = value.strip() + + @property + def dates(self): + return self._dates + + def _parse_dates(self, value): + + if isinstance(value, dict): + return value + + start = None + end = None + frequency = 1 + + if isinstance(value, (list, tuple)): + if len(value) in [2, 3]: + start = value[0] + end = value[1] + + if len(value) == 3: + frequency = frequency_to_string(frequency_to_timedelta(value[2])) + if isinstance(frequency, int): + frequency = f"{frequency}h" + + if start is None or end is None: + raise ValueError(f"Invalid dates {value}") + + if isinstance(frequency, int): + frequency = f"{frequency}h" + + return dict( + start=as_datetime(start), + end=as_datetime(end), + frequency=frequency, + ) + + @dates.setter + def dates(self, value): + self._dates = self._parse_dates(value) + + @property + def output(self): + return self._output + + @output.setter + def output(self, value): + self._output = value + + @property + def statistics(self): + return self._statistics + + @statistics.setter + def statistics(self, value): + self._statistics = value + + @property + def build(self): + return self._build + + @build.setter + 
def build(self, value): + self._build = value + + @property + def env(self): + return self._env + + @env.setter + def env(self, value): + self._env = value + + @property + def dataset_status(self): + return self._dataset_status + + @dataset_status.setter + def dataset_status(self, value): + self._dataset_status = value + + @property + def platform(self): + return self._platform + + @platform.setter + def platform(self, value): + self._platform = value + + def dump(self, file=sys.stdout): + input = self.input.as_dict(self) # First so we get the data_sources + + result = self.as_dict() + + result["input"] = input + + if self.output: + result["output"] = self.output + + if self.statistics: + result["statistics"] = self.statistics + + if self.build: + result["build"] = self.build + + if self.env: + result["env"] = self.env + + if self.dataset_status: + result["dataset_status"] = self.dataset_status + + if self.platform: + result["platform"] = self.platform + + from .dumper import yaml_dump + + yaml_dump(_un_dotdict(result), stream=file) + + def test(self, output="recipe.zarr"): + from argparse import ArgumentParser + + from anemoi.datasets.commands.create import command + + parser = ArgumentParser() + parser.add_argument("command", help="Command to run") + + cmd = command() + cmd.add_arguments(parser) + + with TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "recipe.yaml") + with open(path, "w") as file: + self.dump(file) + + args = parser.parse_args(["create", path, output, "--overwrite", "--test"]) + cmd.run(args) + + +if __name__ == "__main__": + + r = Recipe() + r.description = "test" + + r.dates = ("2023-01-01 00:00:00", "2023-12-31 18:00:00", "6h") + + m1 = r.mars(expver="0001", grid=[20, 20]) + m2 = r.mars(expver="0002") + m3 = r.mars(expver="0003") + + r.input = m1 + + r.input += r.forcings(template=m1, param=["cos_latitude", "sin_latitude"]) + + # m0 = r.mars(expver="0000") + # c = r.concat( + # { + # ("190", "2000"): m0, + # ("2001", "2020"): 
r.mars(expver="0002"), + # ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + # }, + # ) + + # c[("2031", "2033")] = r.mars(expver="0005") + + # r.input += c + + r.output.group_by = "day" + r.build.additions = True + r.statistics.end = "80%" + + r.dump() + r.test() From 6a3cab8ea5b2059834814653d1cd79deb36fb617 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 22 Sep 2025 09:29:22 +0100 Subject: [PATCH 68/79] work on doc --- docs/Makefile | 1 - docs/cli/validate.rst | 2 +- docs/datasets/building/code/using-python-4.py | 23 +++++++++++++++++++ .../datasets/commands/recipe/__init__.py | 6 +++-- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/docs/Makefile b/docs/Makefile index 12f080733..6c0762a44 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -19,5 +19,4 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - bash $(SOURCEDIR)/scripts/api_build.sh @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/cli/validate.rst b/docs/cli/validate.rst index 57d5bf9e8..56aa0fbc7 100644 --- a/docs/cli/validate.rst +++ b/docs/cli/validate.rst @@ -1,7 +1,7 @@ .. _validate_command: Validate Command -============ +================ Use this command to validate a zarr dataset, or a class that implements the :class:`anemoi.datasets.Dataset` interface. 
diff --git a/docs/datasets/building/code/using-python-4.py b/docs/datasets/building/code/using-python-4.py index e69de29bb..29ac3b9d4 100644 --- a/docs/datasets/building/code/using-python-4.py +++ b/docs/datasets/building/code/using-python-4.py @@ -0,0 +1,23 @@ +from datetime import datetime + +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +# As a tuple (start, end, frequency) +r.dates = ("2023-01-01T00:00:00", "2023-12-31T18:00:00", "12h") + +# As a dictionary +r.dates = { + "start": "2023-01-01T00:00:00", + "end": "2023-12-31T18:00:00", + "frequency": "12h", +} + +# You can also provide datetime objects + +r.dates = { + "start": datetime(2023, 1, 1, 0, 0), + "end": datetime(2023, 12, 31, 18, 0), + "frequency": "12h", +} diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index bf08d1ee7..7cbc2d9ef 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -19,8 +19,6 @@ from anemoi.datasets.create import validate_config from .. 
import Command -from .format import format_recipe -from .migrate import migrate_recipe LOG = logging.getLogger(__name__) @@ -71,6 +69,8 @@ def run(self, args: Any) -> None: return if args.migrate: + from .migrate import migrate_recipe + config = migrate_recipe(args, config) if config is None: LOG.info(f"{args.path}: No changes needed.") @@ -79,6 +79,8 @@ def run(self, args: Any) -> None: args.format = True if args.format: + from .format import format_recipe + formatted = format_recipe(args, config) assert "dates" in formatted f = sys.stdout From 0cc80268160ef47bd4c92c93be409bb13d2d0304 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 22 Sep 2025 11:59:43 +0100 Subject: [PATCH 69/79] update docs --- docs/datasets/building/code/using-python-4.py | 1 - docs/datasets/building/code/using-python-5.py | 7 ++++ docs/datasets/building/code/using-python-6.py | 7 ++++ docs/datasets/building/code/using-python-7.py | 8 ++++ docs/datasets/building/code/using-python-8.py | 10 +++++ docs/datasets/building/using-python.rst | 40 +++++++++++++++++++ .../building/yaml/using-python-1.yaml | 15 +++++++ src/anemoi/datasets/recipe.py | 22 +++++----- 8 files changed, 97 insertions(+), 13 deletions(-) create mode 100644 docs/datasets/building/code/using-python-5.py create mode 100644 docs/datasets/building/code/using-python-6.py create mode 100644 docs/datasets/building/code/using-python-7.py create mode 100644 docs/datasets/building/code/using-python-8.py create mode 100644 docs/datasets/building/yaml/using-python-1.yaml diff --git a/docs/datasets/building/code/using-python-4.py b/docs/datasets/building/code/using-python-4.py index 29ac3b9d4..169203146 100644 --- a/docs/datasets/building/code/using-python-4.py +++ b/docs/datasets/building/code/using-python-4.py @@ -15,7 +15,6 @@ } # You can also provide datetime objects - r.dates = { "start": datetime(2023, 1, 1, 0, 0), "end": datetime(2023, 12, 31, 18, 0), diff --git a/docs/datasets/building/code/using-python-5.py 
b/docs/datasets/building/code/using-python-5.py new file mode 100644 index 000000000..a6c5881c4 --- /dev/null +++ b/docs/datasets/building/code/using-python-5.py @@ -0,0 +1,7 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.dates = ("2023-01-01T00:00:00", "2023-12-31T18:00:00", "12h") + +r.input = r.grib(path="data/*.grib") | r.clip(minimum=0, maximum=100) diff --git a/docs/datasets/building/code/using-python-6.py b/docs/datasets/building/code/using-python-6.py new file mode 100644 index 000000000..02b18afdf --- /dev/null +++ b/docs/datasets/building/code/using-python-6.py @@ -0,0 +1,7 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.dates = ("2023-01-01T00:00:00", "2023-12-31T18:00:00", "12h") + +r.input = r.grib(path="dir1/*.grib") & r.grib(path="dir2/*.grib") diff --git a/docs/datasets/building/code/using-python-7.py b/docs/datasets/building/code/using-python-7.py new file mode 100644 index 000000000..d6d62de1f --- /dev/null +++ b/docs/datasets/building/code/using-python-7.py @@ -0,0 +1,8 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + + +r.dates = ("2023-01-01T00:00:00", "2023-12-31T18:00:00", "12h") + +r.input = (r.grib(path="dir1/*.grib") & r.grib(path="dir2/*.grib")) | r.clip(minimum=0, maximum=100) diff --git a/docs/datasets/building/code/using-python-8.py b/docs/datasets/building/code/using-python-8.py new file mode 100644 index 000000000..db95ecf9b --- /dev/null +++ b/docs/datasets/building/code/using-python-8.py @@ -0,0 +1,10 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + + +r.dates = ("2023-01-01T00:00:00", "2023-12-31T18:00:00", "12h") + +r.input = (r.grib(path="dir1/*.grib") & r.grib(path="dir2/*.grib")) | r.clip(minimum=0, maximum=100) + +r.dump() diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst index fbf2892cf..0ebec523e 100644 --- a/docs/datasets/building/using-python.rst +++ b/docs/datasets/building/using-python.rst @@ -16,11 +16,51 
@@ First create a ``Recipe`` object, which will hold the configuration: .. literalinclude:: code/using-python-1.py :language: python +you can pass parameters to the ``Recipe`` constructor: + .. literalinclude:: code/using-python-2.py :language: python +or set them later: + .. literalinclude:: code/using-python-3.py :language: python +You need to select which dates to use for building the dataset: + .. literalinclude:: code/using-python-4.py :language: python + +All data sources and filters are defined as method calls on the +``Recipe`` (any hyphen is replaced by an underscore): + +So the ``grib`` source is defined as ``Recipe.grib(...)`` and the +``clip`` filter as ``Recipe.clip(...)``. + +Source and filter methods can be combined together and assigned to +``Recipe.input``. + +Use the pipe operator ``|`` to chain sources and filters: + +.. literalinclude:: code/using-python-5.py + :language: python + +Use the ampersand operator ``&`` to combine multiple inputs: + +.. literalinclude:: code/using-python-6.py + :language: python + +And you can combine both operators: + +.. literalinclude:: code/using-python-7.py + :language: python + +To generate the YAML configuration, call the ``dump()`` method: + +.. literalinclude:: code/using-python-8.py + :language: python + +Which will output: + +.. 
literalinclude:: yaml/using-python-1.yaml + :language: yaml diff --git a/docs/datasets/building/yaml/using-python-1.yaml b/docs/datasets/building/yaml/using-python-1.yaml new file mode 100644 index 000000000..cff307b90 --- /dev/null +++ b/docs/datasets/building/yaml/using-python-1.yaml @@ -0,0 +1,15 @@ +dates: + start: 2023-01-01T00:00:00Z + end: 2023-12-31T18:00:00Z + frequency: 12h + +input: + pipe: + - join: + - grib: + path: dir1/*.grib + - grib: + path: dir2/*.grib + - clip: + minimum: 0 + maximum: 100 diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 6f8118174..fdbcfb738 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -453,23 +453,21 @@ def dump(self, file=sys.stdout): result["input"] = input - if self.output: - result["output"] = self.output + result["output"] = self.output - if self.statistics: - result["statistics"] = self.statistics + result["statistics"] = self.statistics - if self.build: - result["build"] = self.build + result["build"] = self.build - if self.env: - result["env"] = self.env + result["env"] = self.env - if self.dataset_status: - result["dataset_status"] = self.dataset_status + result["dataset_status"] = self.dataset_status - if self.platform: - result["platform"] = self.platform + result["platform"] = self.platform + + for k, v in list(result.items()): + if v is None or v == {} or v == []: + del result[k] from .dumper import yaml_dump From 066d1ce89b19466124a2973d99571f73bb232f05 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 22 Sep 2025 12:10:58 +0100 Subject: [PATCH 70/79] update docs --- docs/cli/introduction.rst | 9 +-------- docs/datasets/building/filters.rst | 3 ++- docs/datasets/building/using-python.rst | 6 ++++++ docs/index.rst | 4 +++- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/cli/introduction.rst b/docs/cli/introduction.rst index 8c574132c..dab10bc09 100644 --- a/docs/cli/introduction.rst +++ b/docs/cli/introduction.rst @@ 
-13,11 +13,4 @@ The tool can provide help with the ``--help`` options: % anemoi-datasets --help -The commands are: - -- :ref:`Create Command ` -- :ref:`Copy Command ` -- :ref:`Inspect Command ` -- :ref:`Compare Command ` -- :ref:`Scan Command ` -- :ref:`Compare LAM Command ` +The commands are listed in the left side menu of the documentation. diff --git a/docs/datasets/building/filters.rst b/docs/datasets/building/filters.rst index 3b3bd5abf..1348d17d5 100644 --- a/docs/datasets/building/filters.rst +++ b/docs/datasets/building/filters.rst @@ -11,4 +11,5 @@ Filters are used to modify the data or metadata in a dataset. -See :ref:`install ` for more information. +See :ref:`anemoi-transform ` for more +information. diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst index 0ebec523e..f17062ab8 100644 --- a/docs/datasets/building/using-python.rst +++ b/docs/datasets/building/using-python.rst @@ -64,3 +64,9 @@ Which will output: .. literalinclude:: yaml/using-python-1.yaml :language: yaml + +.. note:: + + To get you started quickly, you can use the :ref:`anemoi-datasets + recipe --python recipe.yaml ` to transform an + existing YAML recipe into a Python script. diff --git a/docs/index.rst b/docs/index.rst index 4d2247ad8..cfd2c0afc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -156,15 +156,17 @@ You may also have to install pandoc on macOS: :hidden: :caption: CLI + cli/introduction cli/create cli/inspect cli/grib-index cli/compare cli/copy cli/scan - cli/patch + cli/recipe cli/compare-lam cli/validate + cli/patch .. 
toctree:: :maxdepth: 1 From 423f8a31f55c7c810e7adba3787ce3dd8bb0f3fa Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 22 Sep 2025 12:13:15 +0100 Subject: [PATCH 71/79] update docs --- docs/datasets/building/using-python.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst index f17062ab8..d2ab29f2f 100644 --- a/docs/datasets/building/using-python.rst +++ b/docs/datasets/building/using-python.rst @@ -1,6 +1,6 @@ -############################# - Using Python define recipes -############################# +############################## + Using Python defined recipes +############################## You can use Python to define recipes for building datasets. This allows for more complex logic and flexibility compared to using static From 8bfa5d390a4c0ce3523b8f637910c79734fa0266 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 22 Sep 2025 12:47:40 +0100 Subject: [PATCH 72/79] update --- .../datasets/building/code/using-python-10.py | 13 ++ .../datasets/building/code/using-python-11.py | 14 ++ docs/datasets/building/code/using-python-7.py | 1 - docs/datasets/building/code/using-python-8.py | 1 - docs/datasets/building/code/using-python-9.py | 13 ++ docs/datasets/building/using-python.rst | 24 ++++ ...an-oper-0001-mars-o48-2020-2021-6h-v1.yaml | 122 +++++++++--------- 7 files changed, 125 insertions(+), 63 deletions(-) create mode 100644 docs/datasets/building/code/using-python-10.py create mode 100644 docs/datasets/building/code/using-python-11.py create mode 100644 docs/datasets/building/code/using-python-9.py diff --git a/docs/datasets/building/code/using-python-10.py b/docs/datasets/building/code/using-python-10.py new file mode 100644 index 000000000..1a9488927 --- /dev/null +++ b/docs/datasets/building/code/using-python-10.py @@ -0,0 +1,13 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.dates = ("2023-01-01T00:00:00", 
"2023-12-31T18:00:00", "12h") + +r.input = ( + (a := r.grib(path="dir1/*.grib")) + & r.grib(path="dir2/*.grib") + & r.forcings(param=["cos_latitude", "sin_latitude"], template=a) +) + +r.dump() diff --git a/docs/datasets/building/code/using-python-11.py b/docs/datasets/building/code/using-python-11.py new file mode 100644 index 000000000..735d3e8dc --- /dev/null +++ b/docs/datasets/building/code/using-python-11.py @@ -0,0 +1,14 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.dates = ("2023-01-01T00:00:00", "2023-12-31T18:00:00", "12h") + +r.input = r.concat( + { + ("2023-01-01T00:00:00", "2023-06-30T18:00:00", "12h"): r.grib(path="gribs/*.grib"), + ("2023-07-01T00:00:00", "2023-12-31T18:00:00", "12h"): r.netcdf(path="ncdfs/*.nc"), + } +) + +r.dump() diff --git a/docs/datasets/building/code/using-python-7.py b/docs/datasets/building/code/using-python-7.py index d6d62de1f..2f5042f79 100644 --- a/docs/datasets/building/code/using-python-7.py +++ b/docs/datasets/building/code/using-python-7.py @@ -2,7 +2,6 @@ r = Recipe() - r.dates = ("2023-01-01T00:00:00", "2023-12-31T18:00:00", "12h") r.input = (r.grib(path="dir1/*.grib") & r.grib(path="dir2/*.grib")) | r.clip(minimum=0, maximum=100) diff --git a/docs/datasets/building/code/using-python-8.py b/docs/datasets/building/code/using-python-8.py index db95ecf9b..0c24a626c 100644 --- a/docs/datasets/building/code/using-python-8.py +++ b/docs/datasets/building/code/using-python-8.py @@ -2,7 +2,6 @@ r = Recipe() - r.dates = ("2023-01-01T00:00:00", "2023-12-31T18:00:00", "12h") r.input = (r.grib(path="dir1/*.grib") & r.grib(path="dir2/*.grib")) | r.clip(minimum=0, maximum=100) diff --git a/docs/datasets/building/code/using-python-9.py b/docs/datasets/building/code/using-python-9.py new file mode 100644 index 000000000..6aa86608e --- /dev/null +++ b/docs/datasets/building/code/using-python-9.py @@ -0,0 +1,13 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.dates = ("2023-01-01T00:00:00", 
"2023-12-31T18:00:00", "12h") + +a = r.grib(path="dir1/*.grib") +b = r.grib(path="dir2/*.grib") +c = r.forcings(param=["cos_latitude", "sin_latitude"], template=a) + +r.input = a & b & c + +r.dump() diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst index d2ab29f2f..f6b5da168 100644 --- a/docs/datasets/building/using-python.rst +++ b/docs/datasets/building/using-python.rst @@ -65,6 +65,30 @@ Which will output: .. literalinclude:: yaml/using-python-1.yaml :language: yaml +Sometimes you need to refer to part of the input in a source or a +filter, such as when using the :ref:`forcing_variables` source. + +You can do this by assigning the result of a source or filter to a +variable, and using that variable later in the recipe. + +.. literalinclude:: code/using-python-9.py + :language: python + +Or you can assign the result of a source or filter to a variable +using the walrus operator ``:=`` to both assign and use the variable in +the same expression: + +.. literalinclude:: code/using-python-10.py + :language: python + +Finally, if you need different inputs for different dates, you can use +the following ``Recipe.concat``: + +.. literalinclude:: code/using-python-11.py + +Note that the dates can also be ``datetime`` objects and the frequency +can be a ``timedelta`` object. + ..
note:: To get you started quickly, you can use the :ref:`anemoi-datasets diff --git a/docs/usage/yaml/aifs-ea-an-oper-0001-mars-o48-2020-2021-6h-v1.yaml b/docs/usage/yaml/aifs-ea-an-oper-0001-mars-o48-2020-2021-6h-v1.yaml index aff1efbd2..7ecaf966d 100644 --- a/docs/usage/yaml/aifs-ea-an-oper-0001-mars-o48-2020-2021-6h-v1.yaml +++ b/docs/usage/yaml/aifs-ea-an-oper-0001-mars-o48-2020-2021-6h-v1.yaml @@ -7,66 +7,66 @@ attribution: ECMWF/C3S licence: CC-BY-4.0 dates: - start: '2020-01-01T00:00:00' - end: '2021-12-31T23:00:00' - frequency: 6h + start: '2020-01-01T00:00:00' + end: '2021-12-31T23:00:00' + frequency: 6h input: - join: - - mars: - use_cdsapi_dataset: "reanalysis-era5-complete" - class: ea - expver: '0001' - grid: o48 - levtype: sfc - param: - - 10u - - 10v - - 2d - - 2t - - lsm - - msl - - sdor - - skt - - slor - - sp - - tcw - - z - - mars: - use_cdsapi_dataset: "reanalysis-era5-complete" - class: ea - expver: '0001' - grid: o48 - level: - - 250 - - 500 - - 850 - - 1000 - levtype: pl - param: - - u - - v - - q - - t - - z - - accumulations: - use_cdsapi_dataset: "reanalysis-era5-complete" - accumulation_period: 6 - class: ea - expver: '0001' - grid: o48 - param: - - cp - - tp - - constants: - param: - - cos_latitude - - cos_longitude - - sin_latitude - - sin_longitude - - cos_julian_day - - cos_local_time - - sin_julian_day - - sin_local_time - - insolation - template: ${input.join.0.mars} + join: + - mars: + use_cdsapi_dataset: "reanalysis-era5-complete" + class: ea + expver: '0001' + grid: o48 + levtype: sfc + param: + - 10u + - 10v + - 2d + - 2t + - lsm + - msl + - sdor + - skt + - slor + - sp + - tcw + - z + - mars: + use_cdsapi_dataset: "reanalysis-era5-complete" + class: ea + expver: '0001' + grid: o48 + level: + - 250 + - 500 + - 850 + - 1000 + levtype: pl + param: + - u + - v + - q + - t + - z + - accumulations: + use_cdsapi_dataset: "reanalysis-era5-complete" + accumulation_period: 6 + class: ea + expver: '0001' + grid: o48 + param: + - cp + - tp 
+ - forcings: + param: + - cos_latitude + - cos_longitude + - sin_latitude + - sin_longitude + - cos_julian_day + - cos_local_time + - sin_julian_day + - sin_local_time + - insolation + template: ${input.join.0.mars} From e4f3ddb0a52588bd995a685542d808b87cf14b27 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 22 Sep 2025 12:51:03 +0100 Subject: [PATCH 73/79] update --- docs/cli/recipe.rst | 20 ++++++++++++++++++++ docs/datasets/building/using-python.rst | 4 ++-- 2 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 docs/cli/recipe.rst diff --git a/docs/cli/recipe.rst b/docs/cli/recipe.rst new file mode 100644 index 000000000..0fb325ffa --- /dev/null +++ b/docs/cli/recipe.rst @@ -0,0 +1,20 @@ +.. _recipe_command: + +Recipe Command +============== + + +Anemoi datasets are stored in a zarr format and can be located on a local file system or on a remote server. +The `inspect` command is used to inspect the contents of a dataset. +This command will output the metadata of the dataset, including the variables, dimensions, and attributes. + +.. code:: console + + $ anemoi-datasets recipe [options] recipe.yaml + + +.. argparse:: + :module: anemoi.datasets.__main__ + :func: create_parser + :prog: anemoi-datasets + :path: recipe diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst index f6b5da168..959fc002d 100644 --- a/docs/datasets/building/using-python.rst +++ b/docs/datasets/building/using-python.rst @@ -86,8 +86,8 @@ the following ``Recipe.concat``: .. literalinclude:: code/using-python-11.py -Note that the dates can also be ``datetime`` objects and the frequency -can be a ``timedelta`` object. +Note that the dates can also be :class:`datetime.datetime` objects and +the frequency can be a :class:`datetime.timedelta` object. .. 
note:: From 83694ef146aa0fe562342399baad44c858c16ddf Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 22 Sep 2025 12:54:18 +0100 Subject: [PATCH 74/79] update --- docs/datasets/building/using-python.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst index 959fc002d..fffcd4ab5 100644 --- a/docs/datasets/building/using-python.rst +++ b/docs/datasets/building/using-python.rst @@ -82,7 +82,8 @@ the same expression: :language: python Finally, if you need different inputs for different dates, you can use -the following ``Recipe.concat``: +the ``Recipe.concat`` method, which takes a dictionary mapping dates to +inputs: .. literalinclude:: code/using-python-11.py From 66945f0edd5135175d10ed95b371a606c1bff801 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Tue, 23 Sep 2025 09:01:26 +0100 Subject: [PATCH 75/79] Update docs/datasets/building/code/using-python-3.py Co-authored-by: Florian Pinault --- docs/datasets/building/code/using-python-3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/datasets/building/code/using-python-3.py b/docs/datasets/building/code/using-python-3.py index f21dc3947..173ce89ea 100644 --- a/docs/datasets/building/code/using-python-3.py +++ b/docs/datasets/building/code/using-python-3.py @@ -4,7 +4,7 @@ r.description = """ Example dataset recipe using Python, with attributes set one by one -and a multiline description. +and a multi-line description. 
""" r.name = "example-dataset" From e26ce75543cc4b142cf1b39ad4f302dd58ed37ec Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 23 Sep 2025 09:27:35 +0100 Subject: [PATCH 76/79] add docs --- docs/conf.py | 3 + .../datasets/building/code/using-python-12.py | 249 ++++++++++++++++++ .../building/sources/accumulations.rst | 2 + .../building/sources/anemoi-dataset.rst | 2 +- docs/datasets/building/sources/cds.rst | 2 +- docs/datasets/building/sources/forcings.rst | 2 +- docs/datasets/building/sources/grib-index.rst | 2 +- docs/datasets/building/sources/grib.rst | 2 +- docs/datasets/building/sources/mars.rst | 2 + docs/datasets/building/sources/netcdf.rst | 2 + docs/datasets/building/sources/opendap.rst | 2 + docs/datasets/building/sources/recentre.rst | 2 +- .../building/sources/repeated-dates.rst | 2 + .../building/sources/xarray-based.rst | 2 + .../building/sources/xarray-kerchunk.rst | 2 + .../datasets/building/sources/xarray-zarr.rst | 2 +- docs/datasets/building/sources/zenodo.rst | 2 + docs/datasets/building/using-python.rst | 24 +- docs/index.rst | 8 - 19 files changed, 294 insertions(+), 20 deletions(-) create mode 100644 docs/datasets/building/code/using-python-12.py diff --git a/docs/conf.py b/docs/conf.py index f9e5c6aff..51fe4b3a1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -77,6 +77,9 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "'**.ipynb_checkpoints'"] +# To list the symbols: +# python -m sphinx.ext.intersphinx https://anemoi-utils.readthedocs.io/en/latest/objects.inv . 
+ intersphinx_mapping = { "python": ("https://python.readthedocs.io/en/latest", None), "anemoi-docs": ( diff --git a/docs/datasets/building/code/using-python-12.py b/docs/datasets/building/code/using-python-12.py new file mode 100644 index 000000000..8b65664ad --- /dev/null +++ b/docs/datasets/building/code/using-python-12.py @@ -0,0 +1,249 @@ +import datetime + +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.description = """ +This is a complex example of a dataset recipe written in Python. +It uses data from two different ECMWF research experiments for atmospheric and wave data, +from ECMWF's MARS archive. For the atmospheric data, it combines data from two +12-hourly data streams (oper and lwda) to create a dataset with a 6-hourly frequency. +""" + +r.name = "aifs-rd-an-oper-ioku-mars-n320-2024-2024-6h-v1" +r.licence = "CC-BY-4.0" +r.attribution = "ECMWF" + +start_date = datetime.datetime(2024, 5, 2, 0, 0) +end_date = datetime.datetime(2024, 9, 8, 18, 0) + +r.dates = { + "start": start_date, + "end": end_date, + "frequency": "6h", +} + +r.build = {"use_grib_paramid": True} +r.statistics = {"allow_nans": True} + + +grid = "n320" + +ioku = { + "class": "rd", + "grid": grid, + "expver": "ioku", +} + +ikdi = { + "class": "rd", + "grid": grid, + "expver": "ikdi", +} + +accumulations_stream = { + "oper": "lwda", + "lwda": "oper", +} + + +def accumulations(stream): + return r.accumulations( + levtype="sfc", + param=["cp", "tp", "sf", "strd", "ssrd"], + stream=accumulations_stream[stream], + **ioku, + ) + + +def pressure_levels(stream): + return r.mars( + stream=stream, + level=[ + 1, + 10, + 30, + 50, + 70, + 100, + 150, + 200, + 250, + 300, + 400, + 500, + 600, + 700, + 850, + 925, + 1000, + ], + levtype="pl", + param=["t", "u", "v", "w", "z"], + **ioku, + ) + + +def pressure_levels_q(stream): + return r.mars( + levtype="pl", + param=["q"], + level=[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000], + stream=stream, + **ioku, + ) + + 
+def sfc_fields(stream): + return r.mars( + levtype="sfc", + param=[ + "10u", + "10v", + "2d", + "2t", + "lsm", + "msl", + "sdor", + "skt", + "slor", + "tcw", + "z", + # Land parameters below + "stl1", + "stl2", + "tcc", + "mcc", + "hcc", + "lcc", + "100u", + "100v", + ], + stream=stream, + **ioku, + ) + + +def surface_pressure(stream): + return ( + r.mars( + levtype="ml", + levelist=1, + param="lnsp", + stream=stream, + **ioku, + ) + | r.lnsp_to_sp() + ) + + +def apply_mask(): + return r.apply_mask( + path="/data/climate.v015/319_3/lsm.grib", + mask_value=0, + ) + + +def land_params(stream): + soil_params = r.mars( + levtype="sfc", + param=["swvl1", "swvl2", "sd"], + stream=stream, + **ioku, + ) + + snow_cover = ( + r.mars( + levtype="sfc", + param=["sd", "rsn"], + stream=stream, + **ioku, + ) + | r.snow_cover() + ) + + run_off = r.accumulations( + levtype="sfc", + param=["ro"], + stream=accumulations_stream[stream], + **ioku, + ) + + return (soil_params & snow_cover & run_off) | apply_mask() + + +def constants(template): + return r.constants( + param=[ + "cos_latitude", + "cos_longitude", + "sin_latitude", + "sin_longitude", + "cos_julian_day", + "cos_local_time", + "sin_julian_day", + "sin_local_time", + "insolation", + ], + template=template, + ) + + +def wave_data(): + return ( + r.mars( + param=[ + "swh", + "cdww", + "mwp", + "mwd", + "wmb", + "h1012", + "h1214", + "h1417", + "h1721", + "h2125", + "h2530", + ], + stream="wave", + **ikdi, + ) + | r.cos_sin_mean_wave_direction() + ) + + +def atmos_data(stream): + return ( + (a := sfc_fields(stream)) + & surface_pressure(stream) + & pressure_levels(stream) + & pressure_levels_q(stream) + & accumulations(stream) + & land_params(stream) + & constants(template=a) + ) + + +def dates(hour): + s = start_date.replace(hour=hour) + e = end_date.replace(hour=hour + 12) + while s > start_date: + s -= datetime.timedelta(hours=24) + while e < end_date: + e += datetime.timedelta(hours=24) + return (s, e, "12h") + + +def 
input_data(): + return r.concat( + { + dates(0): atmos_data("oper"), + dates(6): atmos_data("lwda"), + } + ) + + +r.input = input_data() & wave_data() + +r.dump() diff --git a/docs/datasets/building/sources/accumulations.rst b/docs/datasets/building/sources/accumulations.rst index c11f4aee8..c9e26af47 100644 --- a/docs/datasets/building/sources/accumulations.rst +++ b/docs/datasets/building/sources/accumulations.rst @@ -1,3 +1,5 @@ +.. _accumulations-source: + ############### accumulations ############### diff --git a/docs/datasets/building/sources/anemoi-dataset.rst b/docs/datasets/building/sources/anemoi-dataset.rst index a8e336318..5f653df6a 100644 --- a/docs/datasets/building/sources/anemoi-dataset.rst +++ b/docs/datasets/building/sources/anemoi-dataset.rst @@ -1,4 +1,4 @@ -.. _anemoi-dataset_source: +.. _anemoi-dataset-source: ################ anemoi-dataset diff --git a/docs/datasets/building/sources/cds.rst b/docs/datasets/building/sources/cds.rst index e7bac5bc8..a8f8505ca 100644 --- a/docs/datasets/building/sources/cds.rst +++ b/docs/datasets/building/sources/cds.rst @@ -1,4 +1,4 @@ -.. _cds_source: +.. _cds-source: ##### cds diff --git a/docs/datasets/building/sources/forcings.rst b/docs/datasets/building/sources/forcings.rst index d6192aa61..47b1951d0 100644 --- a/docs/datasets/building/sources/forcings.rst +++ b/docs/datasets/building/sources/forcings.rst @@ -1,4 +1,4 @@ -.. _forcing_variables: +.. _forcings-source: ########## forcings diff --git a/docs/datasets/building/sources/grib-index.rst b/docs/datasets/building/sources/grib-index.rst index 5feca291a..328c2b676 100644 --- a/docs/datasets/building/sources/grib-index.rst +++ b/docs/datasets/building/sources/grib-index.rst @@ -1,4 +1,4 @@ -.. _grib-index_source: +.. 
_grib-index-source: ############ grib-index diff --git a/docs/datasets/building/sources/grib.rst b/docs/datasets/building/sources/grib.rst index 27ee0cd97..9582106ba 100644 --- a/docs/datasets/building/sources/grib.rst +++ b/docs/datasets/building/sources/grib.rst @@ -1,4 +1,4 @@ -.. _grib_source: +.. _grib-source: ###### grib diff --git a/docs/datasets/building/sources/mars.rst b/docs/datasets/building/sources/mars.rst index f2b7ccaf2..c897f0f20 100644 --- a/docs/datasets/building/sources/mars.rst +++ b/docs/datasets/building/sources/mars.rst @@ -1,3 +1,5 @@ +.. _mars-source: + ###### mars ###### diff --git a/docs/datasets/building/sources/netcdf.rst b/docs/datasets/building/sources/netcdf.rst index 61e0e03cb..dd663b1d4 100644 --- a/docs/datasets/building/sources/netcdf.rst +++ b/docs/datasets/building/sources/netcdf.rst @@ -1,3 +1,5 @@ +.. _netcdf-source: + ######## netcdf ######## diff --git a/docs/datasets/building/sources/opendap.rst b/docs/datasets/building/sources/opendap.rst index 41f6c35ed..df1904761 100644 --- a/docs/datasets/building/sources/opendap.rst +++ b/docs/datasets/building/sources/opendap.rst @@ -1,3 +1,5 @@ +.. _opendap-source: + ######### opendap ######### diff --git a/docs/datasets/building/sources/recentre.rst b/docs/datasets/building/sources/recentre.rst index 091f93003..e7faf82a5 100644 --- a/docs/datasets/building/sources/recentre.rst +++ b/docs/datasets/building/sources/recentre.rst @@ -1,4 +1,4 @@ -.. _recentre: +.. _recentre-source: ########## recentre diff --git a/docs/datasets/building/sources/repeated-dates.rst b/docs/datasets/building/sources/repeated-dates.rst index 53baf3283..241a5e3b0 100644 --- a/docs/datasets/building/sources/repeated-dates.rst +++ b/docs/datasets/building/sources/repeated-dates.rst @@ -1,3 +1,5 @@ +.. 
_repeated-dates-source: + ################ repeated-dates ################ diff --git a/docs/datasets/building/sources/xarray-based.rst b/docs/datasets/building/sources/xarray-based.rst index 44dcc5923..cb7dbb0a8 100644 --- a/docs/datasets/building/sources/xarray-based.rst +++ b/docs/datasets/building/sources/xarray-based.rst @@ -1,3 +1,5 @@ +.. _xarray-based-sources: + ###################### xarray-based Sources ###################### diff --git a/docs/datasets/building/sources/xarray-kerchunk.rst b/docs/datasets/building/sources/xarray-kerchunk.rst index e50543055..1b6a96f0d 100644 --- a/docs/datasets/building/sources/xarray-kerchunk.rst +++ b/docs/datasets/building/sources/xarray-kerchunk.rst @@ -1,3 +1,5 @@ +.. _xarray-kerchunk-source: + ################# xarray-kerchunk ################# diff --git a/docs/datasets/building/sources/xarray-zarr.rst b/docs/datasets/building/sources/xarray-zarr.rst index 0f9ce62c8..4771602ce 100644 --- a/docs/datasets/building/sources/xarray-zarr.rst +++ b/docs/datasets/building/sources/xarray-zarr.rst @@ -1,4 +1,4 @@ -.. _xarray-zarr: +.. _xarray-zarr-source: ############# xarray-zarr diff --git a/docs/datasets/building/sources/zenodo.rst b/docs/datasets/building/sources/zenodo.rst index ce73aca10..93968dbf4 100644 --- a/docs/datasets/building/sources/zenodo.rst +++ b/docs/datasets/building/sources/zenodo.rst @@ -1,3 +1,5 @@ +.. _zenodo-source: + ######## zenodo ######## diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst index fffcd4ab5..b84c1921f 100644 --- a/docs/datasets/building/using-python.rst +++ b/docs/datasets/building/using-python.rst @@ -34,8 +34,9 @@ You need to select which dates to use for building the dataset: All data sources and filters are defined as method calls on the ``Recipe`` (any hyphen is replaced by an underscore): -So the ``grib`` source is defined as ``Recipe.grib(...)`` and the -``clip`` filter as ``Recipe.clip(...)``. 
+So the :ref:`grib ` source is defined as +``Recipe.grib(...)`` and the :ref:`clip ` +filter as ``Recipe.clip(...)``. Source and filter methods can be combined together and assigned to ``Recipe.input``. @@ -55,7 +56,7 @@ And you can combine both operators: .. literalinclude:: code/using-python-7.py :language: python -To generate the YAML configuration, call the ``dump()`` method: +To generate the YAML configuration, call the ``Recipe.dump()`` method: .. literalinclude:: code/using-python-8.py :language: python @@ -82,10 +83,11 @@ the same expression: :language: python Finally, if you need different inputs for different dates, you can use -the ``Recipe.concat`` method, which takes a dictionary mapping dates to -inputs: +the ``Recipe.concat()`` method, which takes a dictionary mapping dates +to inputs: .. literalinclude:: code/using-python-11.py + :language: python Note that the dates can also be :class:`datetime.datetime` objects and the frequency can be a :class:`datetime.timedelta` object. @@ -95,3 +97,15 @@ the frequency can be a :class:`datetime.timedelta` object. To get you started quickly, you can use the :ref:`anemoi-datasets recipe --python recipe.yaml ` to transform an existing YAML recipe into a Python script. + +Below is the complete example. It uses the :ref:`mars-source` and +:ref:`accumulations-source` source to get data from the ECMWF's MARS +archive. In addition, it uses :ref:`lnsp-to-sp +` to convert the logarithm of the +surface pressure to the surface pressure, :ref:`snow-cover +` to compute the snow cover from the +snow depth and snow density and :ref:`apply-mask +` to replace zeros with `NaNs`. + +.. literalinclude:: code/using-python-12.py + :language: python diff --git a/docs/index.rst b/docs/index.rst index cfd2c0afc..03f3ea083 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -168,14 +168,6 @@ You may also have to install pandoc on macOS: cli/validate cli/patch -.. 
toctree:: - :maxdepth: 1 - :glob: - :hidden: - :caption: API Reference - - modules/* - .. toctree:: :maxdepth: 1 :hidden: From ec4e74c7f2aa13a33a0d9417afda968a126ae115 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 13:43:26 +0100 Subject: [PATCH 77/79] update --- src/anemoi/datasets/commands/recipe/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index d6f02901a..773e9e4ab 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -95,6 +95,8 @@ def run(self, args: Any) -> None: f.close() if args.python: + from anemoi.datasets.create import config_to_python + if args.inplace: argparse.ArgumentError(None, "Inplace conversion to Python is not supported.") From 3c0ae14760cc1ae9d8529300eac9f0fae5961ef8 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 13:52:38 +0100 Subject: [PATCH 78/79] update with main --- docs/building/introduction.rst | 11 -- docs/building/sources/accumulations.rst | 2 - docs/building/sources/anemoi-dataset.rst | 2 +- docs/building/sources/cds.rst | 2 +- docs/building/sources/forcings.rst | 2 +- docs/building/sources/grib-index.rst | 2 +- docs/building/sources/grib.rst | 2 +- docs/building/sources/mars.rst | 2 - docs/building/sources/netcdf.rst | 2 - docs/building/sources/opendap.rst | 2 - docs/building/sources/recentre.rst | 2 +- docs/building/sources/repeated-dates.rst | 2 - docs/building/sources/xarray-based.rst | 2 - docs/building/sources/xarray-kerchunk.rst | 2 - docs/building/sources/xarray-zarr.rst | 2 +- docs/building/sources/zenodo.rst | 2 - docs/cli/introduction.rst | 10 +- docs/cli/validate.rst | 2 +- docs/conf.py | 3 - ...an-oper-0001-mars-o48-2020-2021-6h-v1.yaml | 122 +++++++++--------- docs/index.rst | 11 +- src/anemoi/datasets/create/__init__.py | 1 - src/anemoi/datasets/create/config.py | 2 +- 
src/anemoi/datasets/create/input/__init__.py | 94 ++++++-------- src/anemoi/datasets/create/input/action.py | 4 +- 25 files changed, 127 insertions(+), 163 deletions(-) diff --git a/docs/building/introduction.rst b/docs/building/introduction.rst index 5ace4044e..e85a796c9 100644 --- a/docs/building/introduction.rst +++ b/docs/building/introduction.rst @@ -105,14 +105,3 @@ operations can be combined to build complex datasets. :caption: Naming Conventions naming-conventions - -**************** - Python recipes -**************** - -.. toctree:: - :maxdepth: 1 - :hidden: - :caption: Python recipes - - using-python diff --git a/docs/building/sources/accumulations.rst b/docs/building/sources/accumulations.rst index c9e26af47..c11f4aee8 100644 --- a/docs/building/sources/accumulations.rst +++ b/docs/building/sources/accumulations.rst @@ -1,5 +1,3 @@ -.. _accumulations-source: - ############### accumulations ############### diff --git a/docs/building/sources/anemoi-dataset.rst b/docs/building/sources/anemoi-dataset.rst index 5f653df6a..a8e336318 100644 --- a/docs/building/sources/anemoi-dataset.rst +++ b/docs/building/sources/anemoi-dataset.rst @@ -1,4 +1,4 @@ -.. _anemoi-dataset-source: +.. _anemoi-dataset_source: ################ anemoi-dataset diff --git a/docs/building/sources/cds.rst b/docs/building/sources/cds.rst index a8f8505ca..e7bac5bc8 100644 --- a/docs/building/sources/cds.rst +++ b/docs/building/sources/cds.rst @@ -1,4 +1,4 @@ -.. _cds-source: +.. _cds_source: ##### cds diff --git a/docs/building/sources/forcings.rst b/docs/building/sources/forcings.rst index 47b1951d0..d6192aa61 100644 --- a/docs/building/sources/forcings.rst +++ b/docs/building/sources/forcings.rst @@ -1,4 +1,4 @@ -.. _forcings-source: +.. 
_forcing_variables: ########## forcings diff --git a/docs/building/sources/grib-index.rst b/docs/building/sources/grib-index.rst index 328c2b676..5feca291a 100644 --- a/docs/building/sources/grib-index.rst +++ b/docs/building/sources/grib-index.rst @@ -1,4 +1,4 @@ -.. _grib-index-source: +.. _grib-index_source: ############ grib-index diff --git a/docs/building/sources/grib.rst b/docs/building/sources/grib.rst index 9582106ba..27ee0cd97 100644 --- a/docs/building/sources/grib.rst +++ b/docs/building/sources/grib.rst @@ -1,4 +1,4 @@ -.. _grib-source: +.. _grib_source: ###### grib diff --git a/docs/building/sources/mars.rst b/docs/building/sources/mars.rst index c897f0f20..f2b7ccaf2 100644 --- a/docs/building/sources/mars.rst +++ b/docs/building/sources/mars.rst @@ -1,5 +1,3 @@ -.. _mars-source: - ###### mars ###### diff --git a/docs/building/sources/netcdf.rst b/docs/building/sources/netcdf.rst index dd663b1d4..61e0e03cb 100644 --- a/docs/building/sources/netcdf.rst +++ b/docs/building/sources/netcdf.rst @@ -1,5 +1,3 @@ -.. _netcdf-source: - ######## netcdf ######## diff --git a/docs/building/sources/opendap.rst b/docs/building/sources/opendap.rst index df1904761..41f6c35ed 100644 --- a/docs/building/sources/opendap.rst +++ b/docs/building/sources/opendap.rst @@ -1,5 +1,3 @@ -.. _opendap-source: - ######### opendap ######### diff --git a/docs/building/sources/recentre.rst b/docs/building/sources/recentre.rst index e7faf82a5..091f93003 100644 --- a/docs/building/sources/recentre.rst +++ b/docs/building/sources/recentre.rst @@ -1,4 +1,4 @@ -.. _recentre-source: +.. _recentre: ########## recentre diff --git a/docs/building/sources/repeated-dates.rst b/docs/building/sources/repeated-dates.rst index 241a5e3b0..53baf3283 100644 --- a/docs/building/sources/repeated-dates.rst +++ b/docs/building/sources/repeated-dates.rst @@ -1,5 +1,3 @@ -.. 
_repeated-dates-source: - ################ repeated-dates ################ diff --git a/docs/building/sources/xarray-based.rst b/docs/building/sources/xarray-based.rst index cb7dbb0a8..44dcc5923 100644 --- a/docs/building/sources/xarray-based.rst +++ b/docs/building/sources/xarray-based.rst @@ -1,5 +1,3 @@ -.. _xarray-based-sources: - ###################### xarray-based Sources ###################### diff --git a/docs/building/sources/xarray-kerchunk.rst b/docs/building/sources/xarray-kerchunk.rst index 1b6a96f0d..e50543055 100644 --- a/docs/building/sources/xarray-kerchunk.rst +++ b/docs/building/sources/xarray-kerchunk.rst @@ -1,5 +1,3 @@ -.. _xarray-kerchunk-source: - ################# xarray-kerchunk ################# diff --git a/docs/building/sources/xarray-zarr.rst b/docs/building/sources/xarray-zarr.rst index 4771602ce..0f9ce62c8 100644 --- a/docs/building/sources/xarray-zarr.rst +++ b/docs/building/sources/xarray-zarr.rst @@ -1,4 +1,4 @@ -.. _xarray-zarr-source: +.. _xarray-zarr: ############# xarray-zarr diff --git a/docs/building/sources/zenodo.rst b/docs/building/sources/zenodo.rst index 93968dbf4..ce73aca10 100644 --- a/docs/building/sources/zenodo.rst +++ b/docs/building/sources/zenodo.rst @@ -1,5 +1,3 @@ -.. _zenodo-source: - ######## zenodo ######## diff --git a/docs/cli/introduction.rst b/docs/cli/introduction.rst index fd82e4cef..1e4154402 100644 --- a/docs/cli/introduction.rst +++ b/docs/cli/introduction.rst @@ -13,4 +13,12 @@ The tool can provide help with the ``--help`` options: % anemoi-datasets --help -The commands are listed in the left side menu of the documentation. +The commands are: + +- :ref:`Create Command ` +- :ref:`Copy Command ` +- :ref:`Inspect Command ` +- :ref:`Compare Command ` +- :ref:`Scan Command ` +- :ref:`Validate Command ` +- :ref:`Compare LAM Command ` diff --git a/docs/cli/validate.rst b/docs/cli/validate.rst index 56aa0fbc7..57d5bf9e8 100644 --- a/docs/cli/validate.rst +++ b/docs/cli/validate.rst @@ -1,7 +1,7 @@ .. 
_validate_command: Validate Command -================ +============ Use this command to validate a zarr dataset, or a class that implements the :class:`anemoi.datasets.Dataset` interface. diff --git a/docs/conf.py b/docs/conf.py index 51fe4b3a1..f9e5c6aff 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -77,9 +77,6 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "'**.ipynb_checkpoints'"] -# To list the symbols: -# python -m sphinx.ext.intersphinx https://anemoi-utils.readthedocs.io/en/latest/objects.inv . - intersphinx_mapping = { "python": ("https://python.readthedocs.io/en/latest", None), "anemoi-docs": ( diff --git a/docs/howtos/create/yaml/aifs-ea-an-oper-0001-mars-o48-2020-2021-6h-v1.yaml b/docs/howtos/create/yaml/aifs-ea-an-oper-0001-mars-o48-2020-2021-6h-v1.yaml index 7ecaf966d..aff1efbd2 100644 --- a/docs/howtos/create/yaml/aifs-ea-an-oper-0001-mars-o48-2020-2021-6h-v1.yaml +++ b/docs/howtos/create/yaml/aifs-ea-an-oper-0001-mars-o48-2020-2021-6h-v1.yaml @@ -7,66 +7,66 @@ attribution: ECMWF/C3S licence: CC-BY-4.0 dates: - start: '2020-01-01T00:00:00' - end: '2021-12-31T23:00:00' - frequency: 6h + start: '2020-01-01T00:00:00' + end: '2021-12-31T23:00:00' + frequency: 6h input: - join: - - mars: - use_cdsapi_dataset: "reanalysis-era5-complete" - class: ea - expver: '0001' - grid: o48 - levtype: sfc - param: - - 10u - - 10v - - 2d - - 2t - - lsm - - msl - - sdor - - skt - - slor - - sp - - tcw - - z - - mars: - use_cdsapi_dataset: "reanalysis-era5-complete" - class: ea - expver: '0001' - grid: o48 - level: - - 250 - - 500 - - 850 - - 1000 - levtype: pl - param: - - u - - v - - q - - t - - z - - accumulations: - use_cdsapi_dataset: "reanalysis-era5-complete" - accumulation_period: 6 - class: ea - expver: '0001' - grid: o48 - param: - - cp - - tp - - forcings: - param: - - cos_latitude - - cos_longitude - - sin_latitude - - sin_longitude - - cos_julian_day - - cos_local_time - - 
sin_julian_day - - sin_local_time - - insolation - template: ${input.join.0.mars} + join: + - mars: + use_cdsapi_dataset: "reanalysis-era5-complete" + class: ea + expver: '0001' + grid: o48 + levtype: sfc + param: + - 10u + - 10v + - 2d + - 2t + - lsm + - msl + - sdor + - skt + - slor + - sp + - tcw + - z + - mars: + use_cdsapi_dataset: "reanalysis-era5-complete" + class: ea + expver: '0001' + grid: o48 + level: + - 250 + - 500 + - 850 + - 1000 + levtype: pl + param: + - u + - v + - q + - t + - z + - accumulations: + use_cdsapi_dataset: "reanalysis-era5-complete" + accumulation_period: 6 + class: ea + expver: '0001' + grid: o48 + param: + - cp + - tp + - constants: + param: + - cos_latitude + - cos_longitude + - sin_latitude + - sin_longitude + - cos_julian_day + - cos_local_time + - sin_julian_day + - sin_local_time + - insolation + template: ${input.join.0.mars} diff --git a/docs/index.rst b/docs/index.rst index 7cf2d126b..bf0fcaa6b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -149,10 +149,17 @@ You may also have to install pandoc on macOS: cli/compare cli/copy cli/scan - cli/recipe + cli/patch cli/compare-lam cli/validate - cli/patch + +.. toctree:: + :maxdepth: 1 + :glob: + :hidden: + :caption: API Reference + + modules/* .. 
toctree:: :maxdepth: 1 diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 99f279ce2..b60258d02 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -865,7 +865,6 @@ def _run(self) -> None: # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(argument=group) result = self.input.select(argument=group) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 77b51a3fc..b551c6f4d 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -93,7 +93,7 @@ def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None: if dic[key] == value: return raise ValueError(f"Cannot use {key}={dic[key]}. Must use {value}.") - # LOG.info(f"Setting {key}={value} in config") + LOG.info(f"Setting {key}={value} in config") dic[key] = value diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 29194a80a..095ce61ba 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -8,14 +8,14 @@ # nor does it submit to any jurisdiction. from copy import deepcopy +from functools import cached_property +from typing import TYPE_CHECKING from typing import Any -from typing import Union -from anemoi.datasets.dates.groups import GroupOfDates +from anemoi.datasets.create.input.context.field import FieldContext -from .trace import trace_select - -LOG = logging.getLogger(__name__) +if TYPE_CHECKING: + from anemoi.datasets.create.input.action import Recipe class Context: @@ -40,17 +40,17 @@ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> No Additional keyword arguments. 
""" self.kwargs = kwargs + self.config = deepcopy(config) + self.data_sources = deepcopy(dict(data_sources=data_sources)) + + @cached_property + def action(self) -> "Recipe": + """Returns the action object based on the configuration.""" + from .action import Recipe + from .action import action_factory - config = deepcopy(config) - if data_sources: - config = dict( - data_sources=dict( - sources=data_sources, - input=config, - ) - ) - self.config = config - self.action_path = ["input"] + sources = action_factory(self.data_sources, "data_sources") + input = action_factory(self.config, "input") return Recipe(input, sources) @@ -67,45 +67,25 @@ def select(self, argument) -> Any: Any Selected data. """ - from .action import ActionContext - from .action import action_factory - - """This changes the context.""" - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.select(group_of_dates) - - def __repr__(self) -> str: - """Return a string representation of the InputBuilder. - - Returns - ------- - str - String representation. - """ - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - a = action_factory(self.config, context, self.action_path) - return repr(a) - - def _trace_select(self, config: dict, data_sources: Union[dict, list], **kwargs) -> str: - """Trace the select operation. - - Parameters - ---------- - config : dict - Configuration dictionary. - data_sources : Union[dict, list] - Data sources. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - InputBuilder - An instance of InputBuilder. - """ - - return InputBuilder(config, data_sources, **kwargs) + context = FieldContext(argument, **self.kwargs) + return context.create_result(self.action(context, argument)) + + +def build_input(config: dict, data_sources: dict | list, **kwargs: Any) -> InputBuilder: + """Build an InputBuilder instance. 
+ + Parameters + ---------- + config : dict + Configuration dictionary. + data_sources : Union[dict, list] + Data sources. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + InputBuilder + An instance of InputBuilder. + """ + return InputBuilder(config, data_sources, **kwargs) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 62dabed19..5f21503d5 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -124,7 +124,7 @@ class Join(Action): def __init__(self, config, *path): super().__init__(config, *path, "join") - assert isinstance(config, list), f"Value must be a list {config}" + assert isinstance(config, list), f"Value of Join Action must be a list, got: {config}" self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] @@ -162,7 +162,7 @@ class Pipe(Action): """ def __init__(self, config, *path): - assert isinstance(config, list), f"Value must be a list {config}" + assert isinstance(config, list), f"Value of Pipe Action must be a list, got {config}" super().__init__(config, *path, "pipe") self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] From 26dfbaa7815df98d7fbfe51add339391ebba8974 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 13:54:22 +0100 Subject: [PATCH 79/79] update with main --- src/anemoi/datasets/create/input/__init__.py | 6 ---- src/anemoi/datasets/recipe.py | 36 -------------------- 2 files changed, 42 deletions(-) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 095ce61ba..e30ecefb5 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -18,12 +18,6 @@ from anemoi.datasets.create.input.action import Recipe -class Context: - """Context for building input data.""" - - pass - - class InputBuilder: """Builder 
class for creating input data from configuration and data sources.""" diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index fdbcfb738..9acaadfdb 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -491,39 +491,3 @@ def test(self, output="recipe.zarr"): args = parser.parse_args(["create", path, output, "--overwrite", "--test"]) cmd.run(args) - - -if __name__ == "__main__": - - r = Recipe() - r.description = "test" - - r.dates = ("2023-01-01 00:00:00", "2023-12-31 18:00:00", "6h") - - m1 = r.mars(expver="0001", grid=[20, 20]) - m2 = r.mars(expver="0002") - m3 = r.mars(expver="0003") - - r.input = m1 - - r.input += r.forcings(template=m1, param=["cos_latitude", "sin_latitude"]) - - # m0 = r.mars(expver="0000") - # c = r.concat( - # { - # ("190", "2000"): m0, - # ("2001", "2020"): r.mars(expver="0002"), - # ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - # }, - # ) - - # c[("2031", "2033")] = r.mars(expver="0005") - - # r.input += c - - r.output.group_by = "day" - r.build.additions = True - r.statistics.end = "80%" - - r.dump() - r.test()