From 9cb8c52a46f2915185cf2dfe306b29ecbef690d7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 16:40:18 +0000 Subject: [PATCH 001/212] feat: recipe generator --- src/anemoi/datasets/recipe.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 src/anemoi/datasets/recipe.py diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py new file mode 100644 index 000000000..eed457f5b --- /dev/null +++ b/src/anemoi/datasets/recipe.py @@ -0,0 +1,23 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +class Recipe: + pass + + def dump(): + pass + + +if __name__ == "__main__": + r = Recipe() + r.description = "test" + + r.add(r.mars()) + r.add(r.rename(r.mars())) From 01774dda92f39b25aa44e455bfccc08ecfe252f4 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 18:15:39 +0100 Subject: [PATCH 002/212] update --- src/anemoi/datasets/recipe.py | 97 +++++++++++++++++++++++++++++++++-- 1 file changed, 94 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index eed457f5b..06f3e05eb 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -7,12 +7,101 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import logging -class Recipe: +import yaml +from anemoi.transform.filters import filter_registry as transform_filter_registry + +from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry +from anemoi.datasets.create.sources import source_registry + +LOG = logging.getLogger(__name__) + + +class Step: + def __init__(self, owner, *args, **kwargs): + self.owner = owner + self.args = args + self.kwargs = kwargs + + def as_dict(self): + return {self.owner.name: self.kwargs} + + +class Source(Step): pass - def dump(): - pass + +class Filter(Step): + pass + + +class SourceMaker: + def __init__(self, name, factory): + self.name = name + self.factory = factory + + def __call__(self, *args, **kwargs): + return Source(self, *args, **kwargs) + + +class FilterMaker: + def __init__(self, name, factory): + self.name = name + self.factory = factory + + def __call__(self, *args, **kwargs): + return Filter(self, *args, **kwargs) + + +class Recipe: + + def __init__(self): + self.description = None + self._steps = [] + + sources = source_registry.factories.copy() + filters = transform_filter_registry.factories.copy() + + for key, factory in datasets_filter_registry.factories.items(): + if key in filters: + LOG.warning( + f"Filter `{key}` is registered in anemoi.datasets filter registry and in anemoi.transform filter registry" + ) + filters[key] = factory + + for key, factory in sources.items(): + if key in filters: + LOG.warning( + f"Source `{key}` is registered in anemoi.datasets source registry and in anemoi.transform filter registry" + ) + del filters[key] + + for key, factory in sources.items(): + key = key.replace("-", "_") + assert not hasattr(self, key) + setattr(self, key, SourceMaker(key, factory)) + + for key, factory in filters.items(): + key = key.replace("-", "_") + assert not hasattr(self, key) + setattr(self, 
key, FilterMaker(key, factory)) + + def add(self, step): + self._steps.append(step) + + def dump(self): + result = { + "description": self.description, + "input": [s.as_dict() for s in self._steps], + } + + if len(result["input"]) == 1: + result = result["input"][0] + else: + result["input"] = {"join": result["input"]} + + print(yaml.safe_dump(result)) if __name__ == "__main__": @@ -21,3 +110,5 @@ def dump(): r.add(r.mars()) r.add(r.rename(r.mars())) + + r.dump() From 33793a674df19b922f454055b6d4813303cb522f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 18:24:01 +0100 Subject: [PATCH 003/212] update --- src/anemoi/datasets/recipe.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 06f3e05eb..f8b01caa5 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -29,11 +29,23 @@ def as_dict(self): class Source(Step): - pass + def __init__(self, owner, **kwargs): + super().__init__(owner, **kwargs) + self._source = None class Filter(Step): - pass + def __init__(self, owner, previous, **kwargs): + super().__init__(owner, **kwargs) + self.previous = previous + + def as_dict(self): + prev = self.previous.as_dict() + if isinstance(prev, dict) and "pipe" in prev: + prev = prev.copy() + prev["pipe"] = prev["pipe"].copy() + [super().as_dict()] + return prev + return {"pipe": [prev, super().as_dict()]} class SourceMaker: @@ -109,6 +121,6 @@ def dump(self): r.description = "test" r.add(r.mars()) - r.add(r.rename(r.mars())) + r.add(r.rescale(r.rename(r.mars()))) r.dump() From 920d52303ee56836e7303da7cc4d3c6ec680b074 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 19:14:49 +0100 Subject: [PATCH 004/212] update --- src/anemoi/datasets/recipe.py | 114 ++++++++++++++++++++++++++-------- 1 file changed, 88 insertions(+), 26 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index f8b01caa5..5702db119 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -19,33 +19,60 @@ class Step: + + def __or__(self, other): + return Pipe(self, other) + + def __add__(self, other): + return Join(self, other) + + +class Chain(Step): + def __init__(self, *args): + if len(args) > 0 and isinstance(args[0], self.__class__): + args = args[0].steps + args[1:] + + self.steps = args + + def as_dict(self): + if len(self.steps) == 1: + return self.steps[0].as_dict() + return {self.name: [s.as_dict() for s in self.steps]} + + def __repr__(self): + return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" + + +class Pipe(Chain): + name = "pipe" + + +class Join(Chain): + name = "join" + + +class Base(Step): def __init__(self, owner, *args, **kwargs): self.owner = owner - self.args = args - self.kwargs = kwargs + self.params = {} + for a in args: + assert isinstance(a, dict), f"Invalid argument {a}" + self.params.update(a) + self.params.update(kwargs) def as_dict(self): - return {self.owner.name: self.kwargs} + return {self.owner.name: self.params} + def __repr__(self): + return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" -class Source(Step): - def __init__(self, owner, **kwargs): - super().__init__(owner, **kwargs) - self._source = None +class Source(Base): + pass -class Filter(Step): - def __init__(self, owner, previous, **kwargs): - super().__init__(owner, **kwargs) - self.previous = previous - def as_dict(self): - prev = 
self.previous.as_dict() - if isinstance(prev, dict) and "pipe" in prev: - prev = prev.copy() - prev["pipe"] = prev["pipe"].copy() + [super().as_dict()] - return prev - return {"pipe": [prev, super().as_dict()]} +class Filter(Base): + pass class SourceMaker: @@ -63,6 +90,10 @@ def __init__(self, name, factory): self.factory = factory def __call__(self, *args, **kwargs): + if len(args) > 0 and isinstance(args[0], Step): + prev = args[0] + args = args[1:] + return Pipe(prev, Filter(self, *args, **kwargs)) return Filter(self, *args, **kwargs) @@ -105,14 +136,9 @@ def add(self, step): def dump(self): result = { "description": self.description, - "input": [s.as_dict() for s in self._steps], + "input": Join(*self._steps).as_dict(), } - if len(result["input"]) == 1: - result = result["input"][0] - else: - result["input"] = {"join": result["input"]} - print(yaml.safe_dump(result)) @@ -120,7 +146,43 @@ def dump(self): r = Recipe() r.description = "test" - r.add(r.mars()) - r.add(r.rescale(r.rename(r.mars()))) + # r.add( + # r.mars( + # expver="0001", + # levtype="sfc", + # param=["2t"], + # number=[0, 1], + # ) + # ) + + # r.add( + # r.rescale( + # r.rename( + # r.mars( + # expver="0002", + # levtype="sfc", + # param=["2t"], + # number=[0, 1], + # ), + # param={"2t": "2t_0002"}, + # ), + # {"2t_0002": ["mm", "m"]}, + # ) + # ) + + m1 = r.mars(expver="0001", levtype="sfc", param=["2t"], number=[0, 1]) + m2 = r.mars(expver="0002", levtype="sfc", param=["2t"], number=[0, 1]) + + m3 = r.mars(expver="0003", levtype="sfc", param=["2t"], number=[0, 1]) + + r.add( + (m1 + m2 + m3) + | r.rename( + param={"2t": "2t_0002"}, + ) + | r.rescale( + {"2t_0002": ["mm", "m"]}, + ) + ) r.dump() From ed5d190b0fdf13fdcfa8909dacb9522bacec6e87 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 8 May 2025 21:31:40 +0100 Subject: [PATCH 005/212] update --- src/anemoi/datasets/recipe.py | 183 +++++++++++++++++++++++++--------- 1 file changed, 136 insertions(+), 47 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 5702db119..9d8b17545 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -18,6 +18,14 @@ LOG = logging.getLogger(__name__) +class Index: + def __init__(self, index): + self.name = str(index) + + def __repr__(self): + return f"Index({self.name})" + + class Step: def __or__(self, other): @@ -33,15 +41,23 @@ def __init__(self, *args): args = args[0].steps + args[1:] self.steps = args + self.index = [Index(i) for i in range(len(self.steps))] - def as_dict(self): + def as_dict(self, recipe): if len(self.steps) == 1: - return self.steps[0].as_dict() - return {self.name: [s.as_dict() for s in self.steps]} + return self.steps[0].as_dict(recipe) + return {self.name: [s.as_dict(recipe) for s in self.steps]} def __repr__(self): return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" + def path(self, target, result, *path): + for i, s in enumerate(self.steps): + s.path(target, result, *path, self, self.index[i]) + + def collocated(self, a, b): + return True + class Pipe(Chain): name = "pipe" @@ -51,21 +67,71 @@ class Join(Chain): name = "join" +class Concat(Step): + name = "concat" + + def __init__(self, args): + assert isinstance(args, dict), f"Invalid argument {args}" + self.params = args + + def __setitem__(self, key, value): + self.params[key] = value + + def as_dict(self, recipe): + + result = [] + + for k, v in sorted(self.params.items()): + result.append({"dates": dict(start=k[0], end=k[1]), **v.as_dict(recipe)}) + + return 
{"concat": result} + + def collocated(self, a, b): + return a[0] is b[0] + + def path(self, target, result, *path): + + for i, (k, v) in enumerate(sorted(self.params.items())): + v.path(target, result, *path, self, Index(i)) + + class Base(Step): def __init__(self, owner, *args, **kwargs): self.owner = owner + self.name = owner.name self.params = {} for a in args: assert isinstance(a, dict), f"Invalid argument {a}" self.params.update(a) self.params.update(kwargs) - def as_dict(self): - return {self.owner.name: self.params} + def as_dict(self, recipe): + + def resolve(params, recipe): + if isinstance(params, dict): + return {k: resolve(v, recipe) for k, v in params.items()} + + if isinstance(params, (list, tuple)): + return [resolve(v, recipe) for v in params] + + if isinstance(params, list): + return [resolve(v, recipe) for v in params] + + if isinstance(params, Step): + return recipe.resolve(self, params) + + return params + + return {self.owner.name: resolve(self.params, recipe)} def __repr__(self): return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" + def path(self, target, result, *path): + + if self is target: + result.append([*path, self]) + class Source(Base): pass @@ -101,7 +167,7 @@ class Recipe: def __init__(self): self.description = None - self._steps = [] + self.input = Join() sources = source_registry.factories.copy() filters = transform_filter_registry.factories.copy() @@ -130,59 +196,82 @@ def __init__(self): assert not hasattr(self, key) setattr(self, key, FilterMaker(key, factory)) - def add(self, step): - self._steps.append(step) - def dump(self): result = { "description": self.description, - "input": Join(*self._steps).as_dict(), + "input": self.input.as_dict(self), } print(yaml.safe_dump(result)) + def concat(self, *args, **kwargs): + return Concat(*args, **kwargs) + + def resolve(self, source, target): + assert isinstance(target, Source), f"Only sources can be used as template {target}" + + top = Index("input") # So we have 'input' first in the path + + path_to_source = [] + self.input.path(source, path_to_source, top) + if len(path_to_source) == 0: + raise ValueError(f"Source {source} not found in recipe") + if len(path_to_source) > 1: + raise ValueError(f"Source {source} found in multiple locations {path_to_source}") + path_to_source = path_to_source[0] + + path_to_target = [] + self.input.path(target, path_to_target, top) + if len(path_to_target) == 0: + raise ValueError(f"Target {target} not found in recipe") + if len(path_to_target) > 1: + raise ValueError(f"Target {target} found in multiple locations {path_to_target}") + path_to_target = path_to_target[0] + + a = [s for s in path_to_target] + b = [s for s in path_to_source] + common_ancestor = None + while a[0] is b[0]: + common_ancestor = a[0] + a = a[1:] + b = b[1:] + + assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" + + if not common_ancestor.collocated(a, b): + source = ".".join(s.name for s in path_to_source) + target = ".".join(s.name for s in path_to_target) + raise ValueError( + f"Source ${{{source}}} and target ${{{target}}} are not collocated (i.e. 
they are not branch of a 'concat')" + ) + + target = ".".join(s.name for s in path_to_target) + return f"${{{target}}}" + if __name__ == "__main__": r = Recipe() r.description = "test" - # r.add( - # r.mars( - # expver="0001", - # levtype="sfc", - # param=["2t"], - # number=[0, 1], - # ) - # ) - - # r.add( - # r.rescale( - # r.rename( - # r.mars( - # expver="0002", - # levtype="sfc", - # param=["2t"], - # number=[0, 1], - # ), - # param={"2t": "2t_0002"}, - # ), - # {"2t_0002": ["mm", "m"]}, - # ) - # ) - - m1 = r.mars(expver="0001", levtype="sfc", param=["2t"], number=[0, 1]) - m2 = r.mars(expver="0002", levtype="sfc", param=["2t"], number=[0, 1]) - - m3 = r.mars(expver="0003", levtype="sfc", param=["2t"], number=[0, 1]) - - r.add( - (m1 + m2 + m3) - | r.rename( - param={"2t": "2t_0002"}, - ) - | r.rescale( - {"2t_0002": ["mm", "m"]}, - ) + m1 = r.mars(expver="0001") + m2 = r.mars(expver="0002") + m3 = r.mars(expver="0003") + + r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) + + r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) + + m0 = r.mars(expver="0000") + c = r.concat( + { + ("1900", "2000"): m0, + ("2001", "2020"): r.mars(expver="0002"), + ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + }, ) + c[("2031", "2033")] = r.mars(expver="0005") + + r.input += c + r.dump() From 2562baf9ad8cf1be050b2605794bab4a3adf956b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 09:26:08 +0000 Subject: [PATCH 006/212] fix: better handling of xarray metadata --- .../create/sources/xarray_support/field.py | 13 +- .../create/sources/xarray_support/metadata.py | 123 ++---------------- 2 files changed, 19 insertions(+), 117 deletions(-) diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 663aeab54..d46613474 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -80,12 +80,21 @@ def __init__(self, owner: Any, selection: Any) -> None: # Copy the metadata from the owner self._md = owner._metadata.copy() + aliases = {} for coord_name, coord_value in self.selection.coords.items(): if is_scalar(coord_value): # Extract the single value from the scalar dimension # and store it in the metadata coordinate = owner.by_name[coord_name] - self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value)) + normalised = coordinate.normalise(extract_single_value(coord_value)) + self._md[coord_name] = normalised + for alias in coordinate.mars_names: + aliases[alias] = normalised + + # Add metadata aliases (e.g. levelist == level) only if they are not already present + for alias, value in aliases.items(): + if alias not in self._md: + self._md[alias] = value # print(values.ndim, values.shape, selection.dims) # By now, the only dimensions should be latitude and longitude @@ -188,7 +197,7 @@ def forecast_reference_time(self) -> datetime.datetime: def __repr__(self) -> str: """Return a string representation of the field.""" - return repr(self._metadata) + return f"XArrayField({self._metadata})" def _values(self, dtype: Optional[type] = None) -> Any: """Return the values of the selection. 
diff --git a/src/anemoi/datasets/create/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/sources/xarray_support/metadata.py index 2a0e4d9cb..80f633c3c 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/metadata.py +++ b/src/anemoi/datasets/create/sources/xarray_support/metadata.py @@ -23,87 +23,6 @@ LOG = logging.getLogger(__name__) -class _MDMapping: - """A class to handle metadata mapping for variables. - - Attributes - ---------- - variable : Any - The variable to map. - time : Any - The time associated with the variable. - mapping : Dict[str, str] - A dictionary mapping keys to variable names. - """ - - def __init__(self, variable: Any) -> None: - """Initialize the _MDMapping class. - - Parameters - ---------- - variable : Any - The variable to map. - """ - self.variable = variable - self.time = variable.time - self.mapping = dict() - # Aliases - - def _from_user(self, key: str) -> str: - """Get the internal key corresponding to a user-provided key. - - Parameters - ---------- - key : str - The user-provided key. - - Returns - ------- - str - The internal key corresponding to the user-provided key. - """ - return self.mapping.get(key, key) - - def from_user(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: - """Convert user-provided keys to internal keys. - - Parameters - ---------- - kwargs : Dict[str, Any] - A dictionary of user-provided keys and values. - - Returns - ------- - Dict[str, Any] - A dictionary with internal keys and original values. - """ - return {self._from_user(k): v for k, v in kwargs.items()} - - def __repr__(self) -> str: - """Return a string representation of the _MDMapping object. - - Returns - ------- - str - String representation of the _MDMapping object. - """ - return f"MDMapping({self.mapping})" - - def fill_time_metadata(self, field: Any, md: Dict[str, Any]) -> None: - """Fill the time metadata for a field. - - Parameters - ---------- - field : Any - The field to fill metadata for. - md : Dict[str, Any] - The metadata dictionary to update. - """ - valid_datetime = self.variable.time.fill_time_metadata(field._md, md) - if valid_datetime is not None: - md["valid_datetime"] = as_datetime(valid_datetime).isoformat() - - class XArrayMetadata(RawMetadata): """A class to handle metadata for XArray fields. @@ -129,10 +48,16 @@ def __init__(self, field: Any) -> None: field : Any The field to extract metadata from. """ + from .field import XArrayField + + assert isinstance(field, XArrayField), type(field) self._field = field md = field._md.copy() - self._mapping = _MDMapping(field.owner) - self._mapping.fill_time_metadata(field, md) + + valid_datetime = field.owner.time.fill_time_metadata(field._md, md) + if valid_datetime is not None: + md["valid_datetime"] = as_datetime(valid_datetime).isoformat() + super().__init__(md) @cached_property @@ -192,38 +117,6 @@ def _valid_datetime(self) -> Optional[datetime.datetime]: """ return self._get("valid_datetime") - def get(self, key: str, astype: Optional[type] = None, **kwargs: Any) -> Any: - """Get a metadata value by key. - - Parameters - ---------- - key : str - The key to get the value for. - astype : Optional[type] - The type to cast the value to. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - Any - The value for the specified key, optionally cast to the specified type. 
- """ - - if key == "levelist": - # Special case for levelist, for compatibility with GRIB - if key not in self._d and "level" in self._d: - key = "level" - - if key in self._d: - if astype is not None: - return astype(self._d[key]) - return self._d[key] - - key = self._mapping._from_user(key) - - return super().get(key, astype=astype, **kwargs) - class XArrayFieldGeography(Geography): """A class to handle geography information for XArray fields. From f048a4b7c84ba4961b32210503920d29c61711b7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 09:30:52 +0000 Subject: [PATCH 007/212] update --- src/anemoi/datasets/commands/copy.py | 2 + src/anemoi/datasets/create/__init__.py | 18 +++++ src/anemoi/datasets/create/input/__init__.py | 8 +++ src/anemoi/datasets/create/input/action.py | 65 ++++++++++++++++++- src/anemoi/datasets/create/input/filter.py | 17 +++++ src/anemoi/datasets/create/input/function.py | 10 ++- src/anemoi/datasets/create/input/join.py | 3 + src/anemoi/datasets/create/input/pipe.py | 7 ++ .../datasets/create/input/repeated_dates.py | 9 +++ src/anemoi/datasets/create/input/step.py | 16 ++++- 10 files changed, 151 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index ea08c8aab..9cdca5a4d 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -505,6 +505,7 @@ def add_arguments(self, command_parser: Any) -> None: default=100, help="For optimisation purposes, data is transfered by blocks. Default is 100.", ) + command_parser.add_argument("--workdir", help="Working directory for the copy operation.", default=".") command_parser.add_argument("source", help="Source location.") command_parser.add_argument("target", help="Target location.") @@ -534,6 +535,7 @@ def run(self, args: Any) -> None: resume=args.resume, verbosity=args.verbosity, threads=args.transfers, + workdir=args.workdir, ) copier.run() return diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 081c618ed..fb46dd5d1 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1628,3 +1628,21 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An )[name] LOG.debug(f"Creating {cls.__name__} with {kwargs}") return cls(**kwargs) + + +def config_to_python(config: Any) -> Any: + import sys + + config = loader_config(config) + input = build_input_(config, build_output(config.output, None)) + code = input.to_python() + + code = f"from anemoi.datasets.recipe import Recipe;r = Recipe();r.input = {code}; r.dump()" + + try: + import black + + return black.format_str(code, mode=black.Mode()) + except ImportError: + LOG.warning("Black not installed, skipping formatting") + return code diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index cd0b4ad40..734175698 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -75,6 +75,14 @@ def select(self, group_of_dates: GroupOfDates) -> Any: action = action_factory(self.config, context, self.action_path) return action.select(group_of_dates) + def to_python(self) -> str: + from .action import ActionContext + from .action import action_factory + + context = ActionContext(**self.kwargs) + action = action_factory(self.config, context, self.action_path) + return action.to_python() + def __repr__(self) -> str: """Return a string representation of the 
InputBuilder. diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index eadf01339..98936750e 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -9,6 +9,7 @@ import json import logging +import re from copy import deepcopy from typing import Any from typing import Dict @@ -159,6 +160,68 @@ def _trace_select(self, group_of_dates: GroupOfDates) -> str: """ return f"{self.__class__.__name__}({group_of_dates})" + def _to_python(self, name, config): + """Convert the action to Python code. + + Parameters + ---------- + name : str + The name of the action. + config : dict + The configuration for the action. + + Returns + ------- + str + The Python code representation of the action. + """ + import json + + RESERVED_KEYWORDS = ( + "and", + "or", + "not", + "is", + "in", + "if", + "else", + "elif", + "for", + "while", + "return", + "class", + "def", + "with", + "as", + "import", + "from", + "try", + "except", + "finally", + "raise", + "assert", + "break", + "continue", + "pass", + ) + + config = json.loads(json.dumps(config)) + + assert len(config) == 1, (name, config) + assert name in config, (name, config) + + config = config[name] + + params = [] + for k, v in config.items(): + if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: + return f"r.{name}({config})" + params.append(f"{k}={repr(v)}") + + params = ",".join(params) + return f"r.{name}({params})" + # return f"{name}({config})" + class ActionContext(Context): """Represents the context in which an action is performed. @@ -254,6 +317,6 @@ def action_factory(config: Dict[str, Any], context: ActionContext, action_path: from ..sources import create_source source = create_source(None, substitute(context, config)) - return FunctionAction(context, action_path + [key], key, source) + return FunctionAction(context, action_path + [key], key, source, config) return cls(context, action_path + [key], *args, **kwargs) diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py index 289bb3602..f4c55fa67 100644 --- a/src/anemoi/datasets/create/input/filter.py +++ b/src/anemoi/datasets/create/input/filter.py @@ -95,6 +95,7 @@ def __init__( previous_step: StepAction, name: str, filter: Any, + config: dict, *args: Any, **kwargs: Any, ) -> None: @@ -116,3 +117,19 @@ def __init__( super().__init__(context, action_path, previous_step, *args, **kwargs) self.name = name self.filter = filter + self.config = config + + def to_python(self) -> Any: + """Converts the action to Python code. + + Parameters + ---------- + file : str + The file to convert. + + Returns + ------- + Any + The converted Python code. + """ + return self._to_python(self.name, self.config) diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py index 4d3d21b22..bb33a678a 100644 --- a/src/anemoi/datasets/create/input/function.py +++ b/src/anemoi/datasets/create/input/function.py @@ -91,7 +91,9 @@ class FunctionAction(Action): The name of the function. """ - def __init__(self, context: object, action_path: list, _name: str, source, **kwargs: Dict[str, Any]) -> None: + def __init__( + self, context: object, action_path: list, _name: str, source, config: dict, **kwargs: Dict[str, Any] + ) -> None: """Initializes a FunctionAction instance. 
Parameters @@ -108,6 +110,12 @@ def __init__(self, context: object, action_path: list, _name: str, source, **kwa super().__init__(context, action_path, **kwargs) self.name: str = _name self.source = source + self.config = config + + def to_python(self) -> str: + """Returns the Python representation of the function action.""" + + return self._to_python(self.name, self.config) @trace_select def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py index ba24c7072..761922271 100644 --- a/src/anemoi/datasets/create/input/join.py +++ b/src/anemoi/datasets/create/input/join.py @@ -107,6 +107,9 @@ def __init__(self, context: object, action_path: list, *configs: dict) -> None: super().__init__(context, action_path, *configs) self.actions: List[Action] = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] + def to_python(self) -> None: + return "(" + " + ".join([i.to_python() for i in self.actions]) + ")" + def __repr__(self) -> str: """Returns a string representation of the JoinAction instance.""" content: str = "\n".join([str(i) for i in self.actions]) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py index 6c9fea0df..e2fda03e4 100644 --- a/src/anemoi/datasets/create/input/pipe.py +++ b/src/anemoi/datasets/create/input/pipe.py @@ -40,9 +40,13 @@ def __init__(self, context: Any, action_path: list, *configs: dict) -> None: f"PipeAction requires at least two actions, got {len(configs)}\n{json.dumps(configs, indent=2)}" ) + self.actions: list = [] + current: Any = action_factory(configs[0], context, action_path + ["0"]) + self.actions.append(current) for i, c in enumerate(configs[1:]): current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current) + self.actions.append(current) self.last_step: Any = current @trace_select @@ -64,3 +68,6 @@ def select(self, group_of_dates: Any) -> Any: def __repr__(self) -> str: """Return a string representation of the PipeAction.""" return f"PipeAction({self.last_step})" + + def to_python(self) -> str: + return "(" + " | ".join([i.to_python() for i in self.actions]) + ")" diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index ebf13b36e..788bf7f10 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -351,6 +351,15 @@ def __init__(self, context: Any, action_path: List[str], source: Any, mode: str, self.source: Any = action_factory(source, context, action_path + ["source"]) self.mapper: DateMapper = DateMapper.from_mode(mode, self.source, kwargs) + def to_python(self) -> None: + """Convert the action to Python code. + + Args: + file (Any): The file to write the Python code to. + """ + return self.source.to_python() + # self.mapper.to_python(file) + @trace_select def select(self, group_of_dates: Any) -> JoinResult: """Select and transform the group of dates. 
diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py index e99717094..a2e3ccd42 100644 --- a/src/anemoi/datasets/create/input/step.py +++ b/src/anemoi/datasets/create/input/step.py @@ -177,7 +177,14 @@ def step_factory(config: Dict[str, Any], context: ActionContext, action_path: Li warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries") filter = create_datasets_filter(None, config) - return FunctionStepAction(context, action_path + [key], previous_step, key, filter) + return FunctionStepAction( + context, + action_path + [key], + previous_step, + key, + filter, + config, + ) # Use filters from transform registry @@ -185,7 +192,12 @@ def step_factory(config: Dict[str, Any], context: ActionContext, action_path: Li from ..filters.transform import TransformFilter return FunctionStepAction( - context, action_path + [key], previous_step, key, TransformFilter(context, key, config) + context, + action_path + [key], + previous_step, + key, + TransformFilter(context, key, config), + config, ) raise ValueError(f"Unknown step action `{key}`") From 8ad5eb032e26cdace61ade8b5caa8b85c948a4d6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 09:35:19 +0000 Subject: [PATCH 008/212] update --- src/anemoi/datasets/create/__init__.py | 1 - src/anemoi/datasets/create/input/action.py | 2 +- src/anemoi/datasets/create/input/filter.py | 9 +--- .../datasets/create/input/repeated_dates.py | 11 ++--- src/anemoi/datasets/recipe.py | 42 +++++++++++-------- 5 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index fb46dd5d1..34504e974 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1631,7 +1631,6 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An def config_to_python(config: Any) -> Any: - import sys config = loader_config(config) input = build_input_(config, build_output(config.output, None)) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 98936750e..4cadff1c1 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -160,7 +160,7 @@ def _trace_select(self, group_of_dates: GroupOfDates) -> str: """ return f"{self.__class__.__name__}({group_of_dates})" - def _to_python(self, name, config): + def _to_python(self, name: str, config: dict) -> str: """Convert the action to Python code. Parameters diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py index f4c55fa67..298cead41 100644 --- a/src/anemoi/datasets/create/input/filter.py +++ b/src/anemoi/datasets/create/input/filter.py @@ -119,17 +119,12 @@ def __init__( self.filter = filter self.config = config - def to_python(self) -> Any: + def to_python(self) -> str: """Converts the action to Python code. - Parameters - ---------- - file : str - The file to convert. - Returns ------- - Any + str The converted Python code. 
""" return self._to_python(self.name, self.config) diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index 788bf7f10..0f5b5730a 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -9,6 +9,7 @@ import logging +import warnings from collections import defaultdict from typing import Any from typing import Dict @@ -351,14 +352,10 @@ def __init__(self, context: Any, action_path: List[str], source: Any, mode: str, self.source: Any = action_factory(source, context, action_path + ["source"]) self.mapper: DateMapper = DateMapper.from_mode(mode, self.source, kwargs) - def to_python(self) -> None: - """Convert the action to Python code. - - Args: - file (Any): The file to write the Python code to. - """ + def to_python(self) -> str: + """Convert the action to Python code.""" + warnings.warn("RepeatedDatesAction.to_python is still a work in progress") return self.source.to_python() - # self.mapper.to_python(file) @trace_select def select(self, group_of_dates: Any) -> JoinResult: diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 9d8b17545..5594682dd 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -250,28 +250,34 @@ def resolve(self, source, target): if __name__ == "__main__": - r = Recipe() - r.description = "test" - m1 = r.mars(expver="0001") - m2 = r.mars(expver="0002") - m3 = r.mars(expver="0003") + if False: + r = Recipe() + r.description = "test" - r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) + m1 = r.mars(expver="0001") + m2 = r.mars(expver="0002") + m3 = r.mars(expver="0003") - r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) + r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) - m0 = r.mars(expver="0000") - c = r.concat( - { - ("1900", "2000"): m0, - ("2001", "2020"): r.mars(expver="0002"), - ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - }, - ) + r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) - c[("2031", "2033")] = r.mars(expver="0005") + m0 = r.mars(expver="0000") + c = r.concat( + { + ("1900", "2000"): m0, + ("2001", "2020"): r.mars(expver="0002"), + ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + }, + ) - r.input += c + c[("2031", "2033")] = r.mars(expver="0005") - r.dump() + r.input += c + + r.dump() + else: + from anemoi.datasets.create import config_to_python + + print(config_to_python("x.yaml")) From b33f3acfc990670c090f70b345a41ded802e5fae Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 11:42:35 +0000 Subject: [PATCH 009/212] fix: support other keys that param in rename filter --- src/anemoi/datasets/create/filters/rename.py | 22 +++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/anemoi/datasets/create/filters/rename.py b/src/anemoi/datasets/create/filters/rename.py index ddae03f0f..4c87bd4bb 100644 --- a/src/anemoi/datasets/create/filters/rename.py +++ b/src/anemoi/datasets/create/filters/rename.py @@ -45,9 +45,7 @@ def __init__(self, field: Any, what: str, renaming: Dict[str, Dict[str, str]]) - """ self.field = field self.what = what - self.renaming = {} - for k, v in renaming.items(): - self.renaming[k] = {str(a): str(b) for a, b in v.items()} + self.renaming = renaming.copy() def metadata(self, key: 
Optional[str] = None, **kwargs: Any) -> Any: """Get metadata from the original field, with the option to rename the parameter. @@ -69,7 +67,7 @@ def metadata(self, key: Optional[str] = None, **kwargs: Any) -> Any: value = self.field.metadata(key, **kwargs) if key == self.what: - return self.renaming.get(self.what, {}).get(value, value) + return self.renaming.get(value, value) return value @@ -179,7 +177,7 @@ def __repr__(self) -> str: @legacy_filter(__file__) -def execute(context: Any, input: ekd.FieldList, what: str = "param", **kwargs: Any) -> ekd.FieldList: +def execute(context: Any, input: ekd.FieldList, **kwargs: Any) -> ekd.FieldList: """Rename fields based on the value of another field or a format string. Parameters @@ -188,8 +186,6 @@ def execute(context: Any, input: ekd.FieldList, what: str = "param", **kwargs: A The context in which the function is executed. input : List[Any] List of input fields. - what : str, optional - The field to be used for renaming. Defaults to "param". **kwargs : Any Additional keyword arguments for renaming. @@ -199,7 +195,13 @@ def execute(context: Any, input: ekd.FieldList, what: str = "param", **kwargs: A Array of renamed fields. """ - if what in kwargs and isinstance(kwargs[what], str): - return FieldArray([RenamedFieldFormat(fs, what, kwargs[what]) for fs in input]) + for k, v in kwargs.items(): - return FieldArray([RenamedFieldMapping(fs, what, kwargs) for fs in input]) + if not isinstance(v, dict): + input = [RenamedFieldMapping(fs, k, v) for fs in input] + elif isinstance(v, str): + input = [RenamedFieldFormat(fs, k, v) for fs in input] + else: + raise ValueError("Invalid renaming dictionary. Values must be strings or dictionaries.") + + return FieldArray(input) From 6d23027d3668fa0394ce6c4e17799f828199c2d9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 11:45:59 +0000 Subject: [PATCH 010/212] typo --- src/anemoi/datasets/create/filters/rename.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/create/filters/rename.py b/src/anemoi/datasets/create/filters/rename.py index 4c87bd4bb..5905d35b2 100644 --- a/src/anemoi/datasets/create/filters/rename.py +++ b/src/anemoi/datasets/create/filters/rename.py @@ -197,11 +197,11 @@ def execute(context: Any, input: ekd.FieldList, **kwargs: Any) -> ekd.FieldList: for k, v in kwargs.items(): - if not isinstance(v, dict): + if isinstance(v, dict): input = [RenamedFieldMapping(fs, k, v) for fs in input] elif isinstance(v, str): input = [RenamedFieldFormat(fs, k, v) for fs in input] else: - raise ValueError("Invalid renaming dictionary. Values must be strings or dictionaries.") + raise ValueError(f"Invalid renaming dictionary. Values must be strings or dictionaries. 
({type(v)})") return FieldArray(input) From 9179daebf86230785d5e1102f9578b1c640909a2 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 10 May 2025 17:52:28 +0100 Subject: [PATCH 011/212] add command line --- .gitignore | 1 + src/anemoi/datasets/commands/recipe.py | 39 ++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 src/anemoi/datasets/commands/recipe.py diff --git a/.gitignore b/.gitignore index 4b70d9e82..158ba7bdd 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,4 @@ Untitled-*.py *.db *.tgz _api/ +trace.txt diff --git a/src/anemoi/datasets/commands/recipe.py b/src/anemoi/datasets/commands/recipe.py new file mode 100644 index 000000000..b028969f6 --- /dev/null +++ b/src/anemoi/datasets/commands/recipe.py @@ -0,0 +1,39 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from typing import Any + +from . import Command + +LOG = logging.getLogger(__name__) + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. + """ + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + from anemoi.datasets.create import config_to_python + + print(config_to_python(args.path)) + + +command = Recipe From 79a391be1741fffe9bb960590a4a43eef0217050 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 11 May 2025 05:34:18 +0100 Subject: [PATCH 012/212] update --- src/anemoi/datasets/create/__init__.py | 8 +- src/anemoi/datasets/create/config.py | 2 + src/anemoi/datasets/create/input/__init__.py | 9 ++ src/anemoi/datasets/create/input/concat.py | 16 +++ .../datasets/create/input/data_sources.py | 10 ++ src/anemoi/datasets/dates/__init__.py | 11 ++ src/anemoi/datasets/recipe.py | 103 +++++++++++++----- 7 files changed, 130 insertions(+), 29 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 34504e974..0aa8fc8e4 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -11,6 +11,7 @@ import json import logging import os +import re import time import uuid import warnings @@ -1634,9 +1635,12 @@ def config_to_python(config: Any) -> Any: config = loader_config(config) input = build_input_(config, build_output(config.output, None)) - code = input.to_python() + code1 = input.python_prelude() + code2 = input.to_python() - code = f"from anemoi.datasets.recipe import Recipe;r = Recipe();r.input = {code}; r.dump()" + code = f"from anemoi.datasets.recipe import Recipe;r = Recipe();{code1};r.input = {code2}; r.dump()" + + code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) try: import black diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 1042709a6..6de4a06cd 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -282,6 +282,8 @@ def __init__(self, config: dict, *args, **kwargs): self.output.order_by = normalize_order_by(self.output.order_by) + self.setdefault("dates", Config()) + self.dates["group_by"] = 
self.build.group_by ########### diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 734175698..b0b23f1b2 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -81,8 +81,17 @@ def to_python(self) -> str: context = ActionContext(**self.kwargs) action = action_factory(self.config, context, self.action_path) + return action.to_python() + def python_prelude(self) -> str: + from .action import ActionContext + from .action import action_factory + + context = ActionContext(**self.kwargs) + action = action_factory(self.config, context, self.action_path) + return action.python_prelude() + def __repr__(self) -> str: """Return a string representation of the InputBuilder. diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py index 5399bbc1f..777610a18 100644 --- a/src/anemoi/datasets/create/input/concat.py +++ b/src/anemoi/datasets/create/input/concat.py @@ -162,3 +162,19 @@ def select(self, group_of_dates: GroupOfDates) -> Union[ConcatResult, EmptyResul return EmptyResult(self.context, self.action_path, group_of_dates) return ConcatResult(self.context, self.action_path, group_of_dates, results) + + def to_python(self) -> str: + """Returns the Python representation of the ConcatAction instance. + + Returns + ------- + str + The Python representation of the ConcatAction instance. + """ + + result = [] + + for i, (filtering_dates, action) in enumerate(self.parts): + result.append(f"{filtering_dates.to_python()}:{action.to_python()}") + + return f"r.concat({{{','.join(result)})" diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index b95f85568..d1ca2dc6e 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -58,6 +58,7 @@ def __init__( self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs] self.input = action_factory(input, context, ["input"]) + self.names = [a_path for a_path, config in configs] def select(self, group_of_dates: GroupOfDates) -> "DataSourcesResult": """Selects the data sources result for the given group of dates. @@ -86,6 +87,15 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) + def python_prelude(self) -> str: + result = [] + for n, s in zip(self.names, self.sources): + result.append(f"{n}={s.to_python()}") + return ";".join(result) + + def to_python(self) -> str: + return self.input.to_python() + class DataSourcesResult(Result): """Class to represent the result of data sources actions in the dataset creation process.""" diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 99771faf1..9570d381f 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -273,6 +273,17 @@ def as_dict(self) -> Dict[str, Any]: "frequency": frequency_to_string(self.frequency), }.update(self.kwargs) + def to_python(self) -> str: + """Convert the StartEndDates instance to a Python string. + + Returns + ------- + str + Python string representation of the instance. + """ + # assert self.frequency == frequency_to_timedelta(1), self.frequency + return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) + class Hindcast: """Class representing a single hindcast date. 
diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 5594682dd..2cbef29f7 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -11,6 +11,7 @@ import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry +from anemoi.utils.config import DotDict from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry from anemoi.datasets.create.sources import source_registry @@ -165,9 +166,17 @@ def __call__(self, *args, **kwargs): class Recipe: - def __init__(self): - self.description = None + def __init__(self, name=None, description=None, attribution=None, licence=None): + + self._description = description + self._attribution = attribution + self._licence = licence + self._name = name + self.input = Join() + self.output = DotDict() + self.statistics = DotDict() + self.build = DotDict() sources = source_registry.factories.copy() filters = transform_filter_registry.factories.copy() @@ -197,12 +206,25 @@ def __init__(self): setattr(self, key, FilterMaker(key, factory)) def dump(self): + result = self.as_dict(self) + result["input"] = self.input.as_dict(self) + result["output"] = self.description + + print(yaml.safe_dump(result)) + + def as_dict(self): result = { + "name": self.name, "description": self.description, - "input": self.input.as_dict(self), + "attribution": self.attribution, + "licence": self.licence, } - print(yaml.safe_dump(result)) + for k, v in list(result.items()): + if v is None: + del result[k] + + return result def concat(self, *args, **kwargs): return Concat(*args, **kwargs) @@ -248,36 +270,63 @@ def resolve(self, source, target): target = ".".join(s.name for s in path_to_target) return f"${{{target}}}" + @property + def description(self): + return self._description -if __name__ == "__main__": + @description.setter + def description(self, value): + self._description = value - if False: - r = Recipe() - r.description = "test" + @property + def attribution(self): + return self._attribution + + @attribution.setter + def attribution(self, value): + self._attribution = value + + @property + def licence(self): + return self._licence + + @licence.setter + def licence(self, value): + self._licence = value + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._name = value + + +if __name__ == "__main__": - m1 = r.mars(expver="0001") - m2 = r.mars(expver="0002") - m3 = r.mars(expver="0003") + r = Recipe() + r.description = "test" - r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) + m1 = r.mars(expver="0001") + m2 = r.mars(expver="0002") + m3 = r.mars(expver="0003") - r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) + r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) - m0 = r.mars(expver="0000") - c = r.concat( - { - ("1900", "2000"): m0, - ("2001", "2020"): r.mars(expver="0002"), - ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - }, - ) + r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) - c[("2031", "2033")] = r.mars(expver="0005") + m0 = r.mars(expver="0000") + c = r.concat( + { + ("1900", "2000"): m0, + ("2001", "2020"): r.mars(expver="0002"), + ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + }, + ) - r.input += c + c[("2031", "2033")] = r.mars(expver="0005") - r.dump() - else: - from anemoi.datasets.create import 
config_to_python + r.input += c - print(config_to_python("x.yaml")) + r.dump() From 203e09bd5bbef3a06c214cfc975db4bdde9ff469 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 11 May 2025 05:44:35 +0100 Subject: [PATCH 013/212] update --- src/anemoi/datasets/create/__init__.py | 2 +- src/anemoi/datasets/recipe.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 0aa8fc8e4..0b0cb18ba 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1638,7 +1638,7 @@ def config_to_python(config: Any) -> Any: code1 = input.python_prelude() code2 = input.to_python() - code = f"from anemoi.datasets.recipe import Recipe;r = Recipe();{code1};r.input = {code2}; r.dump()" + code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 2cbef29f7..c1b8bf8f7 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -206,9 +206,9 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): setattr(self, key, FilterMaker(key, factory)) def dump(self): - result = self.as_dict(self) + result = self.as_dict() result["input"] = self.input.as_dict(self) - result["output"] = self.description + # result["output"] = self.description print(yaml.safe_dump(result)) From b4433bd35667b4238e4724490c589bb637276212 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 11 May 2025 10:11:34 +0100 Subject: [PATCH 014/212] update --- src/anemoi/datasets/create/__init__.py | 8 +- src/anemoi/datasets/create/input/__init__.py | 4 +- src/anemoi/datasets/create/input/action.py | 10 +- src/anemoi/datasets/create/input/concat.py | 4 + .../datasets/create/input/data_sources.py | 7 +- src/anemoi/datasets/create/input/filter.py | 3 + src/anemoi/datasets/create/input/function.py | 3 + src/anemoi/datasets/create/input/join.py | 4 + src/anemoi/datasets/create/input/pipe.py | 4 + src/anemoi/datasets/recipe.py | 101 ++++++++++++++---- 10 files changed, 120 insertions(+), 28 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 0b0cb18ba..e50695f4e 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1635,7 +1635,11 @@ def config_to_python(config: Any) -> Any: config = loader_config(config) input = build_input_(config, build_output(config.output, None)) - code1 = input.python_prelude() + + prelude = [] + input.python_prelude(prelude) + code1 = "\n".join(prelude) + code2 = input.to_python() code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" @@ -1646,6 +1650,6 @@ def config_to_python(config: Any) -> Any: import black return black.format_str(code, mode=black.Mode()) - except ImportError: + except Exception: LOG.warning("Black not installed, skipping formatting") return code diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index b0b23f1b2..66266d53d 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -84,13 +84,13 @@ def to_python(self) -> str: return action.to_python() - def python_prelude(self) -> str: + def python_prelude(self, prelude) -> str: from .action import ActionContext 
from .action import action_factory context = ActionContext(**self.kwargs) action = action_factory(self.config, context, self.action_path) - return action.python_prelude() + return action.python_prelude(prelude) def __repr__(self) -> str: """Return a string representation of the InputBuilder. diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 4cadff1c1..6057106ef 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -288,9 +288,15 @@ def action_factory(config: Dict[str, Any], context: ActionContext, action_path: assert isinstance(context, Context), (type, context) if not isinstance(config, dict): raise ValueError(f"Invalid input config {config}") + if len(config) != 1: - print(json.dumps(config, indent=2, default=str)) - raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}") + if "label" in config: + config.pop("label") + if "name" in config: + config.pop("name") + if len(config) != 1: + print(json.dumps(config, indent=2, default=str)) + raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}") config = deepcopy(config) key = list(config.keys())[0] diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py index 777610a18..cef2a64c5 100644 --- a/src/anemoi/datasets/create/input/concat.py +++ b/src/anemoi/datasets/create/input/concat.py @@ -178,3 +178,7 @@ def to_python(self) -> str: result.append(f"{filtering_dates.to_python()}:{action.to_python()}") return f"r.concat({{{','.join(result)})" + + def python_prelude(self, prelude) -> None: + for filtering_dates, action in self.parts: + action.python_prelude(prelude) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index d1ca2dc6e..42c610315 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -87,11 +87,10 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) - def python_prelude(self) -> str: - result = [] + def python_prelude(self, prelude) -> str: for n, s in zip(self.names, self.sources): - result.append(f"{n}={s.to_python()}") - return ";".join(result) + self.sources.python_prelude(prelude) + prelude.append(f"{n}={s.to_python()}") def to_python(self) -> str: return self.input.to_python() diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py index 298cead41..9357d2178 100644 --- a/src/anemoi/datasets/create/input/filter.py +++ b/src/anemoi/datasets/create/input/filter.py @@ -128,3 +128,6 @@ def to_python(self) -> str: The converted Python code. """ return self._to_python(self.name, self.config) + + def python_prelude(self, prelude) -> None: + pass diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py index bb33a678a..651b509b2 100644 --- a/src/anemoi/datasets/create/input/function.py +++ b/src/anemoi/datasets/create/input/function.py @@ -117,6 +117,9 @@ def to_python(self) -> str: return self._to_python(self.name, self.config) + def python_prelude(self, prelude) -> str: + pass + @trace_select def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": """Selects the function result for the given group of dates. 
diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py index 761922271..9fc81eac9 100644 --- a/src/anemoi/datasets/create/input/join.py +++ b/src/anemoi/datasets/create/input/join.py @@ -110,6 +110,10 @@ def __init__(self, context: object, action_path: list, *configs: dict) -> None: def to_python(self) -> None: return "(" + " + ".join([i.to_python() for i in self.actions]) + ")" + def python_prelude(self, prelude) -> None: + for i in self.actions: + i.python_prelude(prelude) + def __repr__(self) -> str: """Returns a string representation of the JoinAction instance.""" content: str = "\n".join([str(i) for i in self.actions]) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py index e2fda03e4..7b8f062f3 100644 --- a/src/anemoi/datasets/create/input/pipe.py +++ b/src/anemoi/datasets/create/input/pipe.py @@ -71,3 +71,7 @@ def __repr__(self) -> str: def to_python(self) -> str: return "(" + " | ".join([i.to_python() for i in self.actions]) + ")" + + def python_prelude(self, prelude) -> None: + for i in self.actions: + i.python_prelude(prelude) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index c1b8bf8f7..3b6bad628 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -8,10 +8,16 @@ # nor does it submit to any jurisdiction. import logging +import os +import sys +from tempfile import TemporaryDirectory import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.utils.config import DotDict +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry from anemoi.datasets.create.sources import source_registry @@ -172,6 +178,7 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): self._attribution = attribution self._licence = licence self._name = name + self._dates = None self.input = Join() self.output = DotDict() @@ -205,19 +212,13 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): assert not hasattr(self, key) setattr(self, key, FilterMaker(key, factory)) - def dump(self): - result = self.as_dict() - result["input"] = self.input.as_dict(self) - # result["output"] = self.description - - print(yaml.safe_dump(result)) - def as_dict(self): result = { "name": self.name, "description": self.description, "attribution": self.attribution, "licence": self.licence, + "dates": self.dates, } for k, v in list(result.items()): @@ -302,31 +303,95 @@ def name(self): def name(self, value): self._name = value + @property + def dates(self): + return self._dates + + def _parse_dates(self, value): + + start = None + end = None + frequency = 1 + + if isinstance(value, (list, tuple)): + if len(value) in [2, 3]: + start = value[0] + end = value[1] + + if len(value) == 3: + frequency = frequency_to_string(frequency_to_timedelta(value[2])) + if isinstance(frequency, int): + frequency = f"{frequency}h" + + if start is None or end is None: + raise ValueError(f"Invalid dates {value}") + + if isinstance(frequency, int): + frequency = f"{frequency}h" + + return dict( + start=as_datetime(start), + end=as_datetime(end), + frequency=frequency, + ) + + @dates.setter + def dates(self, value): + self._dates = self._parse_dates(value) + + def dump(self, file=sys.stdout): + result = self.as_dict() + 
result["input"] = self.input.as_dict(self) + + yaml.safe_dump(result, sort_keys=False, indent=2, width=120, stream=file) + + def test(self, output="recipe.zarr"): + from argparse import ArgumentParser + + from anemoi.datasets.commands.create import command + + parser = ArgumentParser() + parser.add_argument("command", help="Command to run") + + cmd = command() + cmd.add_arguments(parser) + + with TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "recipe.yaml") + with open(path, "w") as file: + self.dump(file) + + args = parser.parse_args(["create", path, output, "--overwrite", "--test"]) + cmd.run(args) + if __name__ == "__main__": r = Recipe() r.description = "test" + r.dates = ("1900-01-01", "2023-12-31") + m1 = r.mars(expver="0001") m2 = r.mars(expver="0002") m3 = r.mars(expver="0003") - r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) | r.rescale(tp=["mm", "m"]) + r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) # | r.rescale(tp=["mm", "m"]) r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) - m0 = r.mars(expver="0000") - c = r.concat( - { - ("1900", "2000"): m0, - ("2001", "2020"): r.mars(expver="0002"), - ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - }, - ) + # m0 = r.mars(expver="0000") + # c = r.concat( + # { + # ("190", "2000"): m0, + # ("2001", "2020"): r.mars(expver="0002"), + # ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), + # }, + # ) - c[("2031", "2033")] = r.mars(expver="0005") + # c[("2031", "2033")] = r.mars(expver="0005") - r.input += c + # r.input += c r.dump() + r.test() From 6f3fdb06148426b4042b2fabc46a97cff22f0785 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 11 May 2025 10:11:58 +0100 Subject: [PATCH 015/212] update --- src/anemoi/datasets/commands/migrate.py | 73 +++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/anemoi/datasets/commands/migrate.py diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py new file mode 100644 index 000000000..8e409740d --- /dev/null +++ b/src/anemoi/datasets/commands/migrate.py @@ -0,0 +1,73 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from typing import Any + +import yaml + +from . import Command + +LOG = logging.getLogger(__name__) + +ORDER = ("name", "description", "licence", "input", "output", "statistics", "build") +ORDER = {k: i for i, k in enumerate(ORDER)} + + +def order(x: str) -> int: + + if x[0] not in ORDER: + ORDER[x[0]] = len(ORDER) + + return ORDER[x[0]] + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. 
+ """ + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + with open(args.path, "r") as file: + old = yaml.safe_load(file) + + while True: + new = self.migrate(old) + if new == old: + break + old = new + + print(yaml.safe_dump(new, sort_keys=False, indent=2, width=120)) + + def migrate(self, config: dict) -> dict: + result = config.copy() + # if 'loop' in config: + # result.pop('loop') + # # config['loop'] = config['loop'].replace('(', '[').replace(')', ']') + + if "statistics_end" in config.get("output", {}): + result.setdefault("statistics", {}) + result["statistics"]["end"] = result["output"].pop("statistics_end") + + result = {k: v for k, v in sorted(result.items(), key=order)} + + return result + + +command = Recipe From 45365a107b49f7d7b228ee716831a40b4976b0e2 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 12 May 2025 15:44:27 +0100 Subject: [PATCH 016/212] upadte --- src/anemoi/datasets/commands/migrate.py | 106 ++++++++++++++---- src/anemoi/datasets/commands/recipe.py | 9 +- src/anemoi/datasets/create/__init__.py | 3 +- src/anemoi/datasets/create/input/action.py | 20 +++- .../datasets/create/input/data_sources.py | 2 +- .../datasets/create/input/repeated_dates.py | 24 +++- src/anemoi/datasets/recipe.py | 24 +++- 7 files changed, 156 insertions(+), 32 deletions(-) diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index 8e409740d..cca70ffee 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -29,6 +29,88 @@ def order(x: str) -> int: return ORDER[x[0]] +MIGRATE = { + "output.statistics_end": "statistics.end", + "input.dates.<<": "dates", + "input.dates.join": "input.join", + "input.dates": None, + "has_nans": "statistics.allow_nans", + "loop.dates": "dates", + # 'copyright': 'citation', +} + +SOURCES = {"oper-accumulations": "accumulations", "constants": "forcings"} + + +def _move(config, path, new_path, result): + path = path.split(".") + if new_path is not None: + new_path = new_path.split(".") + + for k in path[:-1]: + if k not in config: + return + config = config[k] + + if path[-1] not in config: + return + + value = config.pop(path[-1]) + + if new_path is None: + return + + for k in new_path[:-1]: + if k not in result: + result[k] = {} + result = result[k] + + result[new_path[-1]] = value + + +def _migrate(config: dict) -> dict: + result = config.copy() + for k, v in MIGRATE.items(): + _move(config, k, v, result) + + if isinstance(result["input"], list): + join = [] + prev = {} + for n in result["input"]: + assert isinstance(n, dict), (n, type(n)) + assert len(n) == 1, (n, type(n)) + name = list(n.keys())[0] + prev[name] = n[name]["kwargs"] + if "inherit" in n[name]: + i = n[name]["inherit"] + n[name]["kwargs"].update(prev[i]) + n[name].pop("inherit") + + data = n[name]["kwargs"] + + src = data.pop("name", "mars") + + join.append({SOURCES.get(src, src): data}) + + result["input"] = dict(join=join) + + result = {k: v for k, v in sorted(result.items(), key=order) if v} + + return result + + +def migrate(old: dict) -> dict: + # return _migrate(old) + for i in range(10): + new = _migrate(old) + if new == old: + # print(json.dumps(new, indent=2, default=str)) + return new + old = new + + return new + + class Recipe(Command): def add_arguments(self, command_parser: Any) -> None: """Add arguments to the command parser. 
@@ -45,29 +127,9 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: with open(args.path, "r") as file: - old = yaml.safe_load(file) - - while True: - new = self.migrate(old) - if new == old: - break - old = new - - print(yaml.safe_dump(new, sort_keys=False, indent=2, width=120)) - - def migrate(self, config: dict) -> dict: - result = config.copy() - # if 'loop' in config: - # result.pop('loop') - # # config['loop'] = config['loop'].replace('(', '[').replace(')', ']') - - if "statistics_end" in config.get("output", {}): - result.setdefault("statistics", {}) - result["statistics"]["end"] = result["output"].pop("statistics_end") - - result = {k: v for k, v in sorted(result.items(), key=order)} + config = yaml.safe_load(file) - return result + print(yaml.safe_dump(migrate(config), sort_keys=False, indent=2, width=120)) command = Recipe diff --git a/src/anemoi/datasets/commands/recipe.py b/src/anemoi/datasets/commands/recipe.py index b028969f6..3045f1a1b 100644 --- a/src/anemoi/datasets/commands/recipe.py +++ b/src/anemoi/datasets/commands/recipe.py @@ -11,7 +11,10 @@ import logging from typing import Any +import yaml + from . import Command +from .migrate import migrate LOG = logging.getLogger(__name__) @@ -33,7 +36,11 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: from anemoi.datasets.create import config_to_python - print(config_to_python(args.path)) + with open(args.path, "r") as file: + config = yaml.safe_load(file) + config = migrate(config) + + print(config_to_python(config)) command = Recipe diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index e50695f4e..b6b51cb69 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1634,6 +1634,7 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An def config_to_python(config: Any) -> Any: config = loader_config(config) + input = build_input_(config, build_output(config.output, None)) prelude = [] @@ -1650,6 +1651,6 @@ def config_to_python(config: Any) -> Any: import black return black.format_str(code, mode=black.Mode()) - except Exception: + except ImportError: LOG.warning("Black not installed, skipping formatting") return code diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 6057106ef..121d1e387 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import datetime import json import logging import re @@ -15,6 +16,7 @@ from typing import Dict from typing import List +from anemoi.utils.dates import frequency_to_string from earthkit.data.core.order import build_remapping from ...dates.groups import GroupOfDates @@ -160,7 +162,7 @@ def _trace_select(self, group_of_dates: GroupOfDates) -> str: """ return f"{self.__class__.__name__}({group_of_dates})" - def _to_python(self, name: str, config: dict) -> str: + def _to_python(self, name: str, config: dict, **extra: Any) -> str: """Convert the action to Python code. Parameters @@ -169,6 +171,8 @@ def _to_python(self, name: str, config: dict) -> str: The name of the action. config : dict The configuration for the action. + extra : Any + Additional keyword arguments. 
Returns ------- @@ -205,7 +209,16 @@ def _to_python(self, name: str, config: dict) -> str: "pass", ) - config = json.loads(json.dumps(config)) + def convert(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if isinstance(obj, datetime.date): + return obj.isoformat() + if isinstance(obj, datetime.timedelta): + return frequency_to_string(obj) + raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + + config = json.loads(json.dumps(config, default=convert)) assert len(config) == 1, (name, config) assert name in config, (name, config) @@ -218,6 +231,9 @@ def _to_python(self, name: str, config: dict) -> str: return f"r.{name}({config})" params.append(f"{k}={repr(v)}") + for k, v in extra.items(): + params.append(f"{k}={v}") + params = ",".join(params) return f"r.{name}({params})" # return f"{name}({config})" diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 42c610315..5b6282469 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -89,7 +89,7 @@ def __repr__(self) -> str: def python_prelude(self, prelude) -> str: for n, s in zip(self.names, self.sources): - self.sources.python_prelude(prelude) + s.python_prelude(prelude) prelude.append(f"{n}={s.to_python()}") def to_python(self) -> str: diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index 0f5b5730a..bc121c5c9 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -204,6 +204,21 @@ def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) self.day: int = day self.hour: Optional[int] = hour + def to_python(self) -> Dict[str, Any]: + """Convert the DateMapper to Python code. + + Returns + ------- + dict + The Python code representation of the DateMapper. + """ + return { + "mode": "climatology", + "year": self.year, + "day": self.day, + "hour": self.hour, + } + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: """Transform the group of dates to the specified climatology dates. 
@@ -351,11 +366,18 @@ def __init__(self, context: Any, action_path: List[str], source: Any, mode: str, self.source: Any = action_factory(source, context, action_path + ["source"]) self.mapper: DateMapper = DateMapper.from_mode(mode, self.source, kwargs) + self.mode = mode + self.kwargs = kwargs def to_python(self) -> str: """Convert the action to Python code.""" warnings.warn("RepeatedDatesAction.to_python is still a work in progress") - return self.source.to_python() + args = {"mode": self.mode} + args.update(self.kwargs) + return self._to_python("repeated_dates", {"repeated_dates": args}, source=self.source.to_python()) + + def python_prelude(self, prelude: Any) -> None: + self.source.python_prelude(prelude) @trace_select def select(self, group_of_dates: Any) -> JoinResult: diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 3b6bad628..69ce8b7df 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -212,6 +212,8 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): assert not hasattr(self, key) setattr(self, key, FilterMaker(key, factory)) + self.repeated_dates = SourceMaker("repeated_dates", None) + def as_dict(self): result = { "name": self.name, @@ -341,8 +343,18 @@ def dates(self, value): def dump(self, file=sys.stdout): result = self.as_dict() + result["input"] = self.input.as_dict(self) + if self.output: + result["output"] = self.output.as_dict() + + if self.statistics: + result["statistics"] = self.statistics.as_dict() + + if self.build: + result["build"] = self.build.as_dict() + yaml.safe_dump(result, sort_keys=False, indent=2, width=120, stream=file) def test(self, output="recipe.zarr"): @@ -370,15 +382,15 @@ def test(self, output="recipe.zarr"): r = Recipe() r.description = "test" - r.dates = ("1900-01-01", "2023-12-31") + r.dates = ("2023-01-01 00:00:00", "2023-12-31 18:00:00", "6h") - m1 = r.mars(expver="0001") + m1 = r.mars(expver="0001", grid=[20, 20]) m2 = r.mars(expver="0002") m3 = r.mars(expver="0003") - r.input = (m1 + m2 + m3) | r.rename(param={"2t": "2t_0002"}) # | r.rescale(tp=["mm", "m"]) + r.input = m1 - r.input += r.forcings(template=m1, param=["cos_lat", "sin_lat"]) + r.input += r.forcings(template=m1, param=["cos_latitude", "sin_latitude"]) # m0 = r.mars(expver="0000") # c = r.concat( @@ -393,5 +405,9 @@ def test(self, output="recipe.zarr"): # r.input += c + r.output.group_by = "day" + r.build.additions = True + r.statistics.end = "80%" + r.dump() r.test() From d7cc82ca6c579bafd830d857178d703c6d9589b7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 12 May 2025 16:12:35 +0100 Subject: [PATCH 017/212] update --- src/anemoi/datasets/commands/migrate.py | 43 ++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index cca70ffee..8dbd2ca7b 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -35,11 +35,17 @@ def order(x: str) -> int: "input.dates.join": "input.join", "input.dates": None, "has_nans": "statistics.allow_nans", + "loop.dates.group_by": "build.group_by", "loop.dates": "dates", - # 'copyright': 'citation', + "copyright": "attribution", } -SOURCES = {"oper-accumulations": "accumulations", "constants": "forcings"} +SOURCES = { + "oper-accumulations": "accumulations", + "era5-accumulations": "accumulations", + "constants": "forcings", + "ensemble-perturbations": "recentre", +} def _move(config, path, 
new_path, result): @@ -68,12 +74,13 @@ def _move(config, path, new_path, result): result[new_path[-1]] = value -def _migrate(config: dict) -> dict: +def _migrate(config: dict, n) -> dict: result = config.copy() for k, v in MIGRATE.items(): _move(config, k, v, result) if isinstance(result["input"], list): + assert n == 0 join = [] prev = {} for n in result["input"]: @@ -94,6 +101,34 @@ def _migrate(config: dict) -> dict: result["input"] = dict(join=join) + if "join" in result["input"] and n == 0: + join = result["input"].pop("join") + new_join = [] + + for j in join: + + if "label" in j: + j = j["label"] + + if "source" in j: + j = j["source"] + + src = j.pop("name", "mars") + data = j + if "<<" in data: + data.update(data.pop("<<")) + + for k, v in list(data.items()): + if k in ("date", "time"): + if isinstance(v, str) and v.startswith("$"): + del data[k] + + new_join.append({SOURCES.get(src, src): data}) + + print(new_join) + + result["input"]["join"] = new_join + result = {k: v for k, v in sorted(result.items(), key=order) if v} return result @@ -102,7 +137,7 @@ def _migrate(config: dict) -> dict: def migrate(old: dict) -> dict: # return _migrate(old) for i in range(10): - new = _migrate(old) + new = _migrate(old, i) if new == old: # print(json.dumps(new, indent=2, default=str)) return new From f381b00a87e9039f71d953331d80157aba0bd6c1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 12 May 2025 16:49:27 +0100 Subject: [PATCH 018/212] update --- src/anemoi/datasets/commands/migrate.py | 20 +++++++++++++++++--- src/anemoi/datasets/create/input/concat.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index 8dbd2ca7b..ec6d2749a 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -108,7 +108,12 @@ def _migrate(config: dict, n) -> dict: for j in join: if "label" in j: - j = j["label"] + if isinstance(j["label"], str): + j.pop("label") + else: + if j["label"] is not None: + j = j["label"] + j.pop("name", None) if "source" in j: j = j["source"] @@ -125,12 +130,21 @@ def _migrate(config: dict, n) -> dict: new_join.append({SOURCES.get(src, src): data}) - print(new_join) - result["input"]["join"] = new_join + if "join" in result["input"]: + for j in result["input"]["join"]: + k = list(j.keys())[0] + j[k].pop("name", None) + + if "source_or_dataset" in j[k]: + j[k].pop("source_or_dataset", None) + j[k]["template"] = "${input.0.join.0.mars}" + result = {k: v for k, v in sorted(result.items(), key=order) if v} + result.pop("loop", None) + return result diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py index cef2a64c5..bd906bd03 100644 --- a/src/anemoi/datasets/create/input/concat.py +++ b/src/anemoi/datasets/create/input/concat.py @@ -177,7 +177,7 @@ def to_python(self) -> str: for i, (filtering_dates, action) in enumerate(self.parts): result.append(f"{filtering_dates.to_python()}:{action.to_python()}") - return f"r.concat({{{','.join(result)})" + return f"r.concat({{{','.join(result)}}})" def python_prelude(self, prelude) -> None: for filtering_dates, action in self.parts: From da93ad651ce742475549c31024dd208470fc2906 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 14 May 2025 14:10:40 +0100 Subject: [PATCH 019/212] update --- .gitignore | 1 + src/anemoi/datasets/commands/migrate.py | 270 +++++++++++++++++++----- 2 files changed, 221 insertions(+), 50 deletions(-) diff --git 
a/.gitignore b/.gitignore index 158ba7bdd..05db8afa7 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,4 @@ Untitled-*.py *.tgz _api/ trace.txt +?/ diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index ec6d2749a..3183508b7 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -9,12 +9,14 @@ import logging +from copy import deepcopy from typing import Any import yaml from . import Command +errors = [] LOG = logging.getLogger(__name__) ORDER = ("name", "description", "licence", "input", "output", "statistics", "build") @@ -31,9 +33,6 @@ def order(x: str) -> int: MIGRATE = { "output.statistics_end": "statistics.end", - "input.dates.<<": "dates", - "input.dates.join": "input.join", - "input.dates": None, "has_nans": "statistics.allow_nans", "loop.dates.group_by": "build.group_by", "loop.dates": "dates", @@ -45,6 +44,9 @@ def order(x: str) -> int: "era5-accumulations": "accumulations", "constants": "forcings", "ensemble-perturbations": "recentre", + "ensemble_perturbations": "recentre", + "perturbations": "recentre", + "custom-regrid": "regrid", } @@ -74,72 +76,238 @@ def _move(config, path, new_path, result): result[new_path[-1]] = value +def _fix_dates(result, config): + dates = config["input"].pop("dates") + assert "join" in dates, dates + result["input"] = dates["join"] + config["input"] = result["input"].copy() + + +def _fix_list(result, config): + result["input"] = dict(join=result["input"]) + config["input"] = result["input"].copy() + + +def _fix_join_0(result, config): + + join = config["input"]["join"] + + new_join = [] + for n in join: + + if "function" in n: + f = n["function"] + name = f.pop("name") + data = _tidy(f) + for k, v in list(data.items()): + if isinstance(v, dict): + if "name" in v: + p = v.pop("name") + data[k] = {SOURCES.get(p, p): _tidy(v)} + + new_join.append({SOURCES.get(name, name): data}) + continue + + new_join.append(n) # {SOURCES.get(src, src): _tidy(data)}) + + result["input"] = dict(join=new_join) + config["input"] = result["input"].copy() + + +def _fix_join_1(result, config): + + join = config["input"].pop("join") + new_join = [] + for n in join: + if isinstance(n, dict): + if len(n) == 1: + if "label" in n: + n = n["label"] + + if isinstance(n, dict): + if len(n) == 2: + if "name" in n and "source" in n: + n.pop("name") + + if isinstance(n, dict): + if len(n) == 1: + if "source" in n: + n = n["source"] + if "<<" in n: + n.update(n.pop("<<")) + name = n.pop("name", "mars") + new_join.append({SOURCES.get(name, name): _tidy(n)}) + continue + + new_join.append(n) + + result["input"] = dict(join=new_join) + config["input"] = result["input"].copy() + + +def _fix_join_3(result, config): + + join = config["input"].pop("join") + new_join = [] + for n in join: + if not isinstance(n, dict): + return + if len(n) != 1: + return + + name = list(n.keys())[0] + data = n[name] + + new_join.append({SOURCES.get(name, name): data}) + + result["input"] = dict(join=new_join) + config["input"] = result["input"].copy() + + +def _tidy(data): + for k, v in list(data.items()): + if k in ("date", "time"): + if isinstance(v, str) and v.startswith("$"): + del data[k] + + if "name" in data: + assert False, data + name = data.pop("name") + return {SOURCES.get(name, name): _tidy(data)} + + return data + + +def _fix_join_2(result, config): + + join = config["input"]["join"] + + previous = {} + + new_join = [] + for n in join: + + if not isinstance(n, dict): + return + + if len(n) != 1: + 
return + + what = list(n.keys())[0] + + if n[what] is None: + assert False, (n, what, config["input"]) + + if "kwargs" not in n[what]: + return + + # assert False + + previous[what] = deepcopy(n[what]["kwargs"]) + if "inherit" in n[what]: + previous[what].update(deepcopy(previous[n[what]["inherit"]])) + + data = previous[what].copy() + src = data.pop("name", "mars") + + new_join.append({SOURCES.get(src, src): _tidy(data)}) + + result["input"] = dict(join=new_join) + config["input"] = result["input"].copy() + + def _migrate(config: dict, n) -> dict: result = config.copy() for k, v in MIGRATE.items(): _move(config, k, v, result) - if isinstance(result["input"], list): - assert n == 0 - join = [] - prev = {} - for n in result["input"]: - assert isinstance(n, dict), (n, type(n)) - assert len(n) == 1, (n, type(n)) - name = list(n.keys())[0] - prev[name] = n[name]["kwargs"] - if "inherit" in n[name]: - i = n[name]["inherit"] - n[name]["kwargs"].update(prev[i]) - n[name].pop("inherit") + if "dates" in config["input"]: + _fix_dates(result, config) - data = n[name]["kwargs"] + if isinstance(config["input"], list): + _fix_list(result, config) - src = data.pop("name", "mars") + if "join" in config["input"]: + _fix_join_0(result, config) - join.append({SOURCES.get(src, src): data}) + if "join" in config["input"]: + _fix_join_1(result, config) - result["input"] = dict(join=join) + if "join" in config["input"]: + _fix_join_2(result, config) - if "join" in result["input"] and n == 0: - join = result["input"].pop("join") - new_join = [] + if "join" in config["input"]: + _fix_join_3(result, config) - for j in join: + # _check(result, "1") - if "label" in j: - if isinstance(j["label"], str): - j.pop("label") - else: - if j["label"] is not None: - j = j["label"] - j.pop("name", None) + # if isinstance(result["input"], list): + # assert n == 0 + # join = [] + # prev = {} + # for n in result["input"]: + # assert isinstance(n, dict), (n, type(n)) + # assert len(n) == 1, (n, type(n)) + # name = list(n.keys())[0] + # prev[name] = n[name]["kwargs"] + # if "inherit" in n[name]: + # i = n[name]["inherit"] + # n[name]["kwargs"].update(prev[i]) + # n[name].pop("inherit") - if "source" in j: - j = j["source"] + # data = n[name]["kwargs"] - src = j.pop("name", "mars") - data = j - if "<<" in data: - data.update(data.pop("<<")) + # src = data.pop("name", "mars") - for k, v in list(data.items()): - if k in ("date", "time"): - if isinstance(v, str) and v.startswith("$"): - del data[k] + # join.append({SOURCES.get(src, src): data}) + + # result["input"] = dict(join=join) + # _check(result, "2") - new_join.append({SOURCES.get(src, src): data}) + # if "join" in result["input"] and n == 0: + # join = result["input"].pop("join") + # new_join = [] - result["input"]["join"] = new_join + # for j in join: - if "join" in result["input"]: - for j in result["input"]["join"]: - k = list(j.keys())[0] - j[k].pop("name", None) + # if "label" in j: + # if isinstance(j["label"], str): + # j.pop("label") + # else: + # if j["label"] is not None: + # j = j["label"] + # j.pop("name", None) - if "source_or_dataset" in j[k]: - j[k].pop("source_or_dataset", None) - j[k]["template"] = "${input.0.join.0.mars}" + # if "source" in j: + # j = j["source"] + + # src = j.pop("name", "mars") + # data = j + # if "<<" in data: + # data.update(data.pop("<<")) + + # for k, v in list(data.items()): + # if k in ("date", "time"): + # if isinstance(v, str) and v.startswith("$"): + # del data[k] + + # if "mars" in data: + # new_join.append(data) + # else: + # 
new_join.append({SOURCES.get(src, src): data}) + + # result["input"]["join"] = new_join + # _check(result, "3") + + # if "join" in result["input"]: + # for j in result["input"]["join"]: + # k = list(j.keys())[0] + # j[k].pop("name", None) + + # if "source_or_dataset" in j[k]: + # j[k].pop("source_or_dataset", None) + # j[k]["template"] = "${input.0.join.0.mars}" + # _check(result, "4") result = {k: v for k, v in sorted(result.items(), key=order) if v} @@ -180,5 +348,7 @@ def run(self, args: Any) -> None: print(yaml.safe_dump(migrate(config), sort_keys=False, indent=2, width=120)) + assert not errors, f"Errors: {errors}" + command = Recipe From 47f650b4ed54fc77482d2c6e48854f8b81939e53 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 28 May 2025 12:23:50 +0000 Subject: [PATCH 020/212] feat: missing features for observations --- src/anemoi/datasets/data/records/backends/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index 6971cafd4..146749e6c 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -110,6 +110,7 @@ def write_statistics(self, statistics): flatten = {} for name, d in statistics.items(): assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" + assert "mean" in d, f"Statistics for {name} must contain 'mean' key but got {d.keys()}" for k, v in d.items(): assert isinstance( v, (int, float, np.ndarray) @@ -138,6 +139,7 @@ def write_statistics(self, statistics): flatten = {} for name, d in statistics.items(): assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" + assert "mean" in d, f"Statistics for {name} must contain 'mean' key but got {d.keys()}" for k, v in d.items(): assert isinstance( v, (int, float, np.ndarray) From eeeedcd2ad2616e9ba69aa60608b07a6dc7a4278 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 2 Jun 2025 14:34:02 +0000 Subject: [PATCH 021/212] change window on the fly --- src/anemoi/datasets/data/misc.py | 2 +- src/anemoi/datasets/data/records/__init__.py | 276 +++++++++++++++++- .../data/records/backends/__init__.py | 37 ++- tests/test_records.py | 58 +++- 4 files changed, 357 insertions(+), 16 deletions(-) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index b5523ef85..dd70967f1 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -356,7 +356,7 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) - from .stores import Zarr from .stores import zarr_lookup - if isinstance(a, str) and len(a.split(".")) in [2, 3]: + if isinstance(a, str) and len(a.split(".")[-1]) in [1, 2, 3]: metadata_path = os.path.join(a, "metadata.json") if os.path.exists(metadata_path): diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index f569a4105..abfc68636 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -14,6 +14,7 @@ from functools import cached_property import numpy as np +from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets.data.records.backends import backend_factory @@ -44,6 +45,33 @@ def open_records_dataset(dataset, **kwargs): return RecordsDataset(dataset, **kwargs) +def merge_data(list_of_dicts): + merged = defaultdict(list) + 
for d in list_of_dicts: + for key, value in d.items(): + merged[key].append(value) + return {k: np.hstack(v) for k, v in merged.items()} + + +def _to_numpy_timedelta(td): + if isinstance(td, np.timedelta64): + assert td.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {td.dtype}" + return td + return np.timedelta64(int(td.total_seconds()), "s") + + +def _to_numpy_date(d): + if isinstance(d, np.datetime64): + assert d.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {d.dtype}" + return d + assert isinstance(d, datetime.datetime), f"date must be a datetime.datetime, got {type(d)}" + return _to_numpy_dates([d])[0] + + +def _to_numpy_dates(d): + return np.array(d, dtype="datetime64[s]") + + class BaseRecordsDataset: def __getitem__(self, i): @@ -55,6 +83,10 @@ def __getitem__(self, i): raise ValueError(f"Invalid index {i}, must be int or str") + @cached_property + def window(self): + return str(self._window) + def _getgroup(self, i): return Tabular(self, i) @@ -107,6 +139,13 @@ def _dates_to_indices(start, end): if select is not None: return Select(self, select)._subset(**kwargs) + window = kwargs.pop("window", None) + if window is not None: + return Rewindowed(self, window)._subset(**kwargs) + + if kwargs: + raise ValueError(f"Invalid kwargs {kwargs}, must be 'start', 'end', 'frequency' or 'select'") + return self def mutate(self): @@ -147,12 +186,16 @@ def name_to_index(self): def frequency(self): return self.forward.frequency + @property + def _window(self): + return self.forward._window + @property def shapes(self): return self.forward.shapes def __len__(self): - return len(self.forward) + return len(self.dates) def match_variable(lst, group, name): @@ -177,6 +220,222 @@ def match_variable(lst, group, name): return False +def window_from_str(txt): + """Parses a window string of the form '(-6h, 0h]' and returns a WindowsSpec object.""" + if txt.startswith("["): + include_start = True + elif txt.startswith("("): + include_start = False + else: + raise ValueError(f"Invalid window {txt}, must start with '(' or '['") + txt = txt[1:] + + if txt.endswith("]"): + include_end = True + elif txt.endswith(")"): + include_end = False + else: + raise ValueError(f"Invalid window {txt}, must end with ')' or ']'") + txt = txt[:-1] + + txt = txt.strip() + if ";" in txt: + txt = txt.replace(";", ",") + lst = txt.split(",") + if len(lst) != 2: + raise ValueError( + f"Invalid window {txt}, must be of the form '(start, end)' or '[start, end]' or '[start, end)' or '(start, end]'" + ) + start, end = lst + start = start.strip() + end = end.strip() + + def _to_timedelta(t): + # This part should go into utils + from anemoi.utils.dates import as_timedelta + + if t.startswith(" ") or t.endswith(" "): + t = t.strip() + if t.startswith("-"): + return -as_timedelta(t[1:]) + if t.startswith("+"): + return as_timedelta(t[1:]) + # end of : This part should go into utils + return as_timedelta(t) + + start = _to_timedelta(start) + end = _to_timedelta(end) + return WindowsSpec( + start=start, + end=end, + include_start=include_start, + include_end=include_end, + ) + + +class WindowsSpec: + def __init__(self, *, start, end, include_start=False, include_end=True): + assert isinstance(start, (str, datetime.timedelta)), f"start must be a str or timedelta, got {type(start)}" + assert isinstance(end, (str, datetime.timedelta)), f"end must be a str or timedelta, got {type(end)}" + assert isinstance(include_start, bool), f"include_start must be a bool, got {type(include_start)}" + assert isinstance(include_end, 
bool), f"include_end must be a bool, got {type(include_end)}" + assert include_start in (True, False), f"Invalid include_start {include_start}" # None is not allowed + assert include_end in (True, False), f"Invalid include_end {include_end}" # None is not allowed + if start >= end: + raise ValueError(f"start {start} must be less than end {end}") + self.start = start + self.end = end + self.include_start = include_start + self.include_end = include_end + + self._start_np = _to_numpy_timedelta(start) + self._end_np = _to_numpy_timedelta(end) + + def __repr__(self): + first = "[" if self.include_start else "(" + last = "]" if self.include_end else ")" + + def _frequency_to_string(t): + if t < datetime.timedelta(0): + return f"-{frequency_to_string(-t)}" + elif t == datetime.timedelta(0): + return "0" + return frequency_to_string(t) + + return f"{first}{_frequency_to_string(self.start)},{_frequency_to_string(self.end)}{last}" + + def compute_mask(self, timedeltas): + if self.include_start: + lower_mask = timedeltas >= self._start_np + else: + lower_mask = timedeltas > self._start_np + + if self.include_end: + upper_mask = timedeltas <= self._end_np + else: + upper_mask = timedeltas < self._end_np + + return lower_mask & upper_mask + + def starts_before(self, my_dates, other_dates, other_window): + assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" + assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" + assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" + + my_start = my_dates[0] + self._start_np + other_start = other_dates[0] + other_window._start_np + + if my_start == other_start: + return (not other_window.include_start) or self.include_start + return my_start <= other_start + + def ends_after(self, my_dates, other_dates, other_window): + assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" + assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" + assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" + + my_end = my_dates[-1] + self._end_np + other_end = other_dates[-1] + other_window._end_np + + if my_end == other_end: + print(".", (not other_window.include_end) or self.include_end) + return (not other_window.include_end) or self.include_end + print(my_end >= other_end) + return my_end >= other_end + + +class Rewindowed(RecordsForward): + def __init__(self, dataset, window): + super().__init__(dataset) + self.dataset = dataset + + # in this class anything with 1 refers to the original window/dataset + # and anything with 2 refers to the new window/dataset + # and we use _Δ for timedeltas + + self._window1 = self.forward._window + self._window2 = window_from_str(window) + self.reason = {"window": self.window} + + self._dates1 = _to_numpy_dates(self.forward.dates) + dates = self._dates1 + self.dates_offset = 0 + while len(dates) > 0 and not self._window1.starts_before(self._dates1, dates, self._window2): + LOG.warning(f"Removing first date {dates[0]} because it is to early") + self.dates_offset += 1 + dates = dates[1:] + while len(dates) > 0 and not self._window1.ends_after(self._dates1, dates, self._window2): + LOG.warning(f"Removing last date {dates[-1]} because it is to late") + dates = dates[:-1] + + if len(dates) == 0: + raise ValueError( + f"No dates left after rewindowing {self._window1} -> 
{self._window2} (frequency={self.frequency}), check your window" + ) + self._dates = dates + + before_span1 = self._window1.start / self.frequency + before_span2 = self._window2.start / self.frequency + delta_before_span = before_span2 - before_span1 + if delta_before_span == int(delta_before_span): + if not self._window1.include_start and self._window2.include_start: + # if the start of the window is not included, we need to read one more index + delta_before_span -= 1 + delta_before_span = int(delta_before_span) + self.delta_before_span = delta_before_span + + after_span1 = self._window1.end / self.frequency + after_span2 = self._window2.end / self.frequency + delta_after_span = after_span2 - after_span1 + if delta_after_span == int(delta_after_span): + if not self._window1.include_end and self._window2.include_end: + # if the end of the window is not included, we need to read one more index + delta_after_span += 1 + delta_after_span = int(delta_after_span) + self.delta_after_span = delta_after_span + + @property + def window(self): + return self._window2 + + @property + def dates(self): + return self._dates + + def __len__(self): + return len(self.dates) + + @property + def frequency(self): + return self.forward.frequency + + def _load_data(self, i): + print(f"Rewindowing data for i={i} (date={self.dates[i]}) : {self._window1} -> {self._window2}") + + first_j = i + self.delta_before_span + last_j = i + self.delta_after_span + + first_j = first_j + self.dates_offset + last_j = last_j + self.dates_offset + print(f"Requested ds({i}) : need to read {list(range(first_j, last_j + 1))} indices") + + # _load_data could support a list of indices, but for now we merge the data ourselves + too_much_data = merge_data(self.forward._load_data(j) for j in range(first_j, last_j + 1)) + + out = {} + for group in self.groups: + timedeltas = too_much_data[f"timedeltas:{group}"] + mask = self._window.compute_mask(timedeltas) + + out[f"data:{group}"] = too_much_data[f"data:{group}"][..., mask] + out[f"latitudes:{group}"] = too_much_data[f"latitudes:{group}"][..., mask] + out[f"longitudes:{group}"] = too_much_data[f"longitudes:{group}"][..., mask] + out[f"timedeltas:{group}"] = too_much_data[f"timedeltas:{group}"][..., mask] + out[f"metadata:{group}"] = too_much_data[f"metadata:{group}"] + + return out + + class Select(RecordsForward): def __init__(self, dataset, select): super().__init__(dataset) @@ -285,6 +544,12 @@ def __init__(self, path, backend="npz1", **kwargs): self.path = path self.backend = backend_factory(backend, path, **kwargs) self.keys = self.metadata["sources"].keys + for k in self.keys(): + assert k == self.normalise_key(k), k + + @classmethod + def normalise_key(cls, k): + return "".join([x.lower() if x.isalnum() else "-" for x in k]) @property def frequency(self): @@ -300,6 +565,11 @@ def name_to_index(self): def variables(self): return self.metadata["variables"] + @cached_property + def _window(self): + window = self.metadata["window"] + return window_from_str(window) + @cached_property def metadata(self): return self.backend.read_metadata() @@ -340,7 +610,9 @@ def dates(self): @counter def _load_data(self, i): - return self.backend.read(i) + data = self.backend.read(i) + self.backend._check_data(data) + return data def check(self, i=None): if i is not None: diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index 146749e6c..bffccec4b 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ 
b/src/anemoi/datasets/data/records/backends/__init__.py @@ -13,6 +13,10 @@ import numpy as np +def normalise_key(k): + return "".join([x.lower() if x.isalnum() else "-" for x in k]) + + class Backend: def __init__(self, path, **kwargs): self.path = path @@ -27,10 +31,19 @@ def read_metadata(self): def read_statistics(self): raise NotImplementedError("Must be implemented in subclass") + def _check_data(self, data): + for k in list(data.keys()): + k = k.split(":")[-1] + if k != normalise_key(k): + raise ValueError(f"{k} must be alphanumerical and '-' only.") + class Npz1Backend(Backend): + number_of_files_per_subdirectory = 10 + def read(self, i, **kwargs): - path = os.path.join(self.path, "data", str(int(i / 10)), f"{i}.npz") + d = str(int(i / self.number_of_files_per_subdirectory)) + path = os.path.join(self.path, "data", d, f"{i}.npz") with open(path, "rb") as f: return dict(np.load(f)) @@ -51,7 +64,7 @@ def read_statistics(self): class Npz2Backend(Backend): def read(self, i, **kwargs): - path = os.path.join(self.path, "data_", str(int(i / 10)), f"{i}_.npz") + path = os.path.join(self.path, "data_", str(int(i / 100)), f"{i}_.npz") with open(path, "rb") as f: return dict(np.load(f)) @@ -70,12 +83,13 @@ def read_statistics(self): return dic -def backend_factory(backend, *args, **kwargs): +def backend_factory(name, *args, **kwargs): BACKENDS = dict( npz1=Npz1Backend, npz2=Npz2Backend, ) - return BACKENDS[backend](*args, **kwargs) + cls = BACKENDS[name] + return cls(*args, **kwargs) class WriteBackend(Backend): @@ -91,10 +105,20 @@ def write_metadata(self, metadata): def write_statistics(self, statistics): raise NotImplementedError("Must be implemented in subclass") + def _check_data(self, data): + for k in list(data.keys()): + k = k.split(":")[-1] + if k != normalise_key(k): + raise ValueError(f"{k} must be alphanumerical and '-' only.") + class Npz1WriteBackend(WriteBackend): + number_of_files_per_subdirectory = 10 + def write(self, i, data, **kwargs): - path = os.path.join(self.path, "data", str(int(i / 10))) + self._check_data(data) + d = str(int(i / self.number_of_files_per_subdirectory)) + path = os.path.join(self.path, "data", d) os.makedirs(path, exist_ok=True) out_path = os.path.join(path, f"{i}.npz") np.savez(out_path, **data) @@ -123,7 +147,8 @@ def write_statistics(self, statistics): class Npz2WriteBackend(WriteBackend): def write(self, i, data, **kwargs): - path = os.path.join(self.path, "data_", str(int(i / 10))) + self._check_data(data) + path = os.path.join(self.path, "data_", str(int(i / 100))) os.makedirs(path, exist_ok=True) out_path = os.path.join(path, f"{i}_.npz") np.savez(out_path, **data) diff --git a/tests/test_records.py b/tests/test_records.py index 896081f9a..1b7cd0ccc 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -25,7 +25,7 @@ def check_numpy(x, y): def _test(ds, nb_dates=None): - grp = "metop_a_ascat" + grp = "metop-a-ascat" index_i = 0 if nb_dates is not None: @@ -132,22 +132,66 @@ def test_open_with_subset_dates(): "../../data/vz/obs-2018-11.vz", end="2018-11-30", select=[ - "metop_a_ascat.*", - "amsr2_h180.rawbt_4", - "amsr2_h180.rawbt_3", + "metop-a-ascat.*", + "amsr2-h180.rawbt_4", + "amsr2-h180.rawbt_3", ], ) _test(ds, nb_dates=8) +@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +def test_open_with_window(): + dates = dict(end="2018-11-30") + ds = open_dataset("../../data/vz/obs-2018-11.vz", window="(-6h, 0h]", **dates) + _test(ds, nb_dates=8) + + ds = 
open_dataset("../../data/vz/obs-2018-11.vz", window="(-1h, 0)", **dates) + _test(ds, nb_dates=8) + + +def test_open_bad_window(): + subset = dict(end="2018-11-30") + with pytest.raises(ValueError, match="No dates left after rewindowing"): + open_dataset("../../data/vz/obs-2018-11.vz", window="(-48h, +48h)", **subset) + + +@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +@pytest.mark.parametrize( + "window, missing_dates", + [ + ("(-12h, 0)", -1), # first window is incomplete + ("[-12h, 0)", -2), # first two windows are incomplete + ("(-3h, +3h)", -1), # last date is incomplete + ("[-6h, 0h)", -1), # incomplete due to rounding + ("(-6h, 0h)", 0), + ("(-1h, 0h]", 0), + ("(-1h, 0)", 0), + ("(-6h, +6h)", -1), + ("(-6h, +5h)", -1), + ("(-12h, +12h)", -3), + ("(-1h, +15h]", -3), + ], +) +def test_open_with_window_parametrized(window, missing_dates): + subset = dict(end="2018-11-30") + + ds = open_dataset("../../data/vz/obs-2018-11.vz", **subset) + assert len(ds) == 8 + nb_dates = len(ds) + missing_dates + + ds = open_dataset("../../data/vz/obs-2018-11.vz", window=window, **subset) + _test(ds, nb_dates=nb_dates) + + @pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") def test_open_with_subset_select(): ds = open_dataset( "../../data/vz/obs-2018-11.vz", select=[ - "amsr2_h180.rawbt_4", - "amsr2_h180.rawbt_3", - "metop_a_ascat.*", + "amsr2-h180.rawbt_4", + "amsr2-h180.rawbt_3", + "metop-a-ascat.*", ], ) _test(ds) From ecba1db5768f5d62dee850d63b7d71af73ea565e Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 4 Jun 2025 15:26:08 +0000 Subject: [PATCH 022/212] implement netcdf backend --- src/anemoi/datasets/data/records/__init__.py | 6 +- .../data/records/backends/__init__.py | 77 +++++++++++++++++++ 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index abfc68636..0db5e7d5c 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -40,8 +40,6 @@ def counter(func): def open_records_dataset(dataset, **kwargs): - if not dataset.endswith(".vz"): - raise ValueError("dataset must be a .vz file") return RecordsDataset(dataset, **kwargs) @@ -143,7 +141,9 @@ def _dates_to_indices(start, end): if window is not None: return Rewindowed(self, window)._subset(**kwargs) - if kwargs: + for k in kwargs: + if k in ["backend"]: + continue raise ValueError(f"Invalid kwargs {kwargs}, must be 'start', 'end', 'frequency' or 'select'") return self diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index bffccec4b..a599de1d7 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -83,10 +83,40 @@ def read_statistics(self): return dic +class Nc1Backend(Backend): + + def read(self, i, **kwargs): + d = str(int(i / self.number_of_files_per_subdirectory)) + path = os.path.join(self.path, "data", d, f"{i}.nc") + import xarray as xr + + ds = xr.open_dataset(path) + return {var: ds[var].values for var in ds.data_vars} + + def read_metadata(self): + with open(os.path.join(self.path, "metadata.json"), "r") as f: + return json.load(f) + + def read_statistics(self): + path = os.path.join(self.path, "statistics.nc") + import xarray as xr + + ds = xr.open_dataset(path) + flatten = {var: ds[var].values for var in ds.data_vars} + dic = {} 
+ for k, v in flatten.items(): + key, group = k.split(":") + if group not in dic: + dic[group] = {} + dic[group][key] = v + return dic + + def backend_factory(name, *args, **kwargs): BACKENDS = dict( npz1=Npz1Backend, npz2=Npz2Backend, + nc1=Nc1Backend, ) cls = BACKENDS[name] return cls(*args, **kwargs) @@ -145,6 +175,52 @@ def write_statistics(self, statistics): np.savez(path, **flatten) +class Nc1WriteBackend(WriteBackend): + number_of_files_per_subdirectory = 10 + + def write(self, i, data, **kwargs): + self._check_data(data) + d = str(int(i / self.number_of_files_per_subdirectory)) + path = os.path.join(self.path, "data", d) + os.makedirs(path, exist_ok=True) + out_path = os.path.join(path, f"{i}.nc") + + import xarray as xr + + ds = xr.Dataset( + {key: ([f"dim_{key}" + str(i) for i in range(value.ndim)], value) for key, value in data.items()} + ) + ds.to_netcdf(out_path) + + def write_metadata(self, metadata): + from anemoi.datasets.create import json_tidy + + os.makedirs(self.path, exist_ok=True) + with open(os.path.join(self.path, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2, default=json_tidy) + + def write_statistics(self, statistics): + flatten = {} + for name, d in statistics.items(): + assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" + assert "mean" in d, f"Statistics for {name} must contain 'mean' key but got {d.keys()}" + for k, v in d.items(): + assert isinstance( + v, (int, float, np.ndarray) + ), f"Statistics value for {k} in {name} must be int, float or ndarray, got {type(v)}" + flatten[k + ":" + name] = v + + path = os.path.join(self.path, "statistics.nc") + + import xarray as xr + + ds = xr.Dataset( + {key: ([f"dim_{key}" + str(i) for i in range(value.ndim)], value) for key, value in flatten.items()} + ) + ds.to_netcdf(path) + np.savez(path, **flatten) + + class Npz2WriteBackend(WriteBackend): def write(self, i, data, **kwargs): self._check_data(data) @@ -180,5 +256,6 @@ def writer_backend_factory(backend, *args, **kwargs): WRITE_BACKENDS = dict( npz1=Npz1WriteBackend, npz2=Npz2WriteBackend, + nc1=Nc1WriteBackend, ) return WRITE_BACKENDS[backend](*args, **kwargs) From 8ccd23732e6a03728b7aa52ee356d42086346787 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 4 Jun 2025 15:30:32 +0000 Subject: [PATCH 023/212] up --- src/anemoi/datasets/data/misc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index dd70967f1..dea006194 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -357,6 +357,8 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) - from .stores import zarr_lookup if isinstance(a, str) and len(a.split(".")[-1]) in [1, 2, 3]: + # This will do nothing if there is no "metadata.json" file + # .zarr datasets do not have "metadata.json" metadata_path = os.path.join(a, "metadata.json") if os.path.exists(metadata_path): From e2b54ad64f6ee5c5346011d584148f677f30e4d9 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Fri, 6 Jun 2025 13:02:09 +0000 Subject: [PATCH 024/212] up --- src/anemoi/datasets/data/records/backends/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index a599de1d7..ba277e32b 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -39,7 +39,7 
@@ def _check_data(self, data): class Npz1Backend(Backend): - number_of_files_per_subdirectory = 10 + number_of_files_per_subdirectory = 100 def read(self, i, **kwargs): d = str(int(i / self.number_of_files_per_subdirectory)) @@ -64,7 +64,7 @@ def read_statistics(self): class Npz2Backend(Backend): def read(self, i, **kwargs): - path = os.path.join(self.path, "data_", str(int(i / 100)), f"{i}_.npz") + path = os.path.join(self.path, "data_", str(int(i / 10)), f"{i}_.npz") with open(path, "rb") as f: return dict(np.load(f)) @@ -176,7 +176,7 @@ def write_statistics(self, statistics): class Nc1WriteBackend(WriteBackend): - number_of_files_per_subdirectory = 10 + number_of_files_per_subdirectory = 100 def write(self, i, data, **kwargs): self._check_data(data) @@ -224,7 +224,7 @@ def write_statistics(self, statistics): class Npz2WriteBackend(WriteBackend): def write(self, i, data, **kwargs): self._check_data(data) - path = os.path.join(self.path, "data_", str(int(i / 100))) + path = os.path.join(self.path, "data_", str(int(i / 10))) os.makedirs(path, exist_ok=True) out_path = os.path.join(path, f"{i}_.npz") np.savez(out_path, **data) From fcd46a3aaeff4005b67e3e27492b472f23be9f44 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 10 Jun 2025 15:22:12 +0000 Subject: [PATCH 025/212] up --- src/anemoi/datasets/data/records/backends/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index ba277e32b..e7629069d 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -39,7 +39,7 @@ def _check_data(self, data): class Npz1Backend(Backend): - number_of_files_per_subdirectory = 100 + number_of_files_per_subdirectory = 10 def read(self, i, **kwargs): d = str(int(i / self.number_of_files_per_subdirectory)) From a02deb3d13dc5a1fc68ae62864ea36b26bb740c3 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 10 Jun 2025 15:41:19 +0000 Subject: [PATCH 026/212] fix --- src/anemoi/datasets/data/records/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 0db5e7d5c..4e46f0061 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -305,6 +305,7 @@ def _frequency_to_string(t): return f"{first}{_frequency_to_string(self.start)},{_frequency_to_string(self.end)}{last}" def compute_mask(self, timedeltas): + assert timedeltas.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {timedeltas.dtype}" if self.include_start: lower_mask = timedeltas >= self._start_np else: @@ -425,6 +426,12 @@ def _load_data(self, i): out = {} for group in self.groups: timedeltas = too_much_data[f"timedeltas:{group}"] + if timedeltas.dtype != "timedelta64[s]": + if len(timedeltas) != 0: + raise ValueError(f"Wrong type for {group}") + else: + LOG.warning(f"TODO: Fixing {group} on the fly") + timedeltas = np.ones_like(timedeltas, dtype="timedelta64[s]") * 0 mask = self._window.compute_mask(timedeltas) out[f"data:{group}"] = too_much_data[f"data:{group}"][..., mask] From 48df56ee08eaa82d67413f591a73d835767973c5 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 10 Jun 2025 20:48:51 +0000 Subject: [PATCH 027/212] creating obs dataset. 
Draft --- .../datasets/create/sources/observations.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 src/anemoi/datasets/create/sources/observations.py diff --git a/src/anemoi/datasets/create/sources/observations.py b/src/anemoi/datasets/create/sources/observations.py new file mode 100644 index 000000000..4f39949eb --- /dev/null +++ b/src/anemoi/datasets/create/sources/observations.py @@ -0,0 +1,43 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import pandas as pd + + +def check_dataframe(df): + """Check the DataFrame for consistency.""" + if df.empty: + pass + if "times" not in df.columns: + raise ValueError("The DataFrame must contain a 'times' column.") + if not pd.api.types.is_datetime64_any_dtype(df["times"]): + raise TypeError("The 'times' column must be of datetime type.") + if "latitudes" not in df.columns or "longitudes" not in df.columns: + raise ValueError("The DataFrame must contain 'latitudes' and 'longitudes' columns.") + + +class ObservationsSource: + def __call__(self, window): + raise NotImplementedError("This method should be implemented by subclasses") + + def _check(self, df): + check_dataframe(df) + return df + + +class ObservationsFilter: + def __call__(self, df): + """Filter the data based on the given window.""" + check_dataframe(df) + return df + + def _check(self, df): + check_dataframe(df) + return df From 48230994c1916865bff048edb7323c3c911a052d Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 10 Jun 2025 20:49:53 +0000 Subject: [PATCH 028/212] creating obs dataset. Draft --- tests/create/test_observations.py | 73 +++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 tests/create/test_observations.py diff --git a/tests/create/test_observations.py b/tests/create/test_observations.py new file mode 100644 index 000000000..2166827b9 --- /dev/null +++ b/tests/create/test_observations.py @@ -0,0 +1,73 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import datetime + +import numpy as np +import pandas as pd + +from anemoi.datasets.create.sources.observations import ObservationsFilter +from anemoi.datasets.create.sources.observations import ObservationsSource +from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import window_from_str + + +class DummpySource(ObservationsSource): + def __init__(self, data): + assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" + self.data = data + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + if window.include_start: + mask = self.data["times"] > window.start + else: + mask = self.data["times"] >= window.start + if window.include_end: + mask &= self.data["times"] <= window.end + else: + mask &= self.data["times"] < window.end + + df = self.data[mask] + + return self._check(df) + + +class DummyFilter(ObservationsFilter): + def __call__(self, df): + """Filter the data based on the given window.""" + self._check(df) + # Here we can add any filtering logic if needed + df["a1"] = df["a1"] + 0.42 + return self._check(df) + + +dates = [datetime.datetime(2023, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] + +N = 100 +source = DummpySource( + pd.DataFrame( + { + "times": np.arange(N) * datetime.timedelta(hours=1) + dates[0], + "latitudes": -0.1 * np.arange(N), + "longitudes": -0.2 * np.arange(N), + "a1": np.arange(N) * 1.0, + "a2": np.arange(N) * 2.0, + } + ) +) +filter = DummyFilter() + +for d in dates: + window = window_from_str("(-5h, 1h]").to_absolute_window(d) + d = source(window) + d = filter(d) + print(window) + print(d) From 89fe75179fe68e40016711aac28492af6fa2abbc Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 12 Jun 2025 10:08:39 +0000 Subject: [PATCH 029/212] feat(obs): add set_group and rename --- src/anemoi/datasets/data/records/__init__.py | 42 ++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 4e46f0061..38abd846a 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -141,6 +141,14 @@ def _dates_to_indices(start, end): if window is not None: return Rewindowed(self, window)._subset(**kwargs) + set_group = kwargs.pop("set_group", None) + if set_group is not None: + return SetGroup(self, set_group)._subset(**kwargs) + + rename = kwargs.pop("rename", None) + if rename is not None: + return Rename(self, rename)._subset(**kwargs) + for k in kwargs: if k in ["backend"]: continue @@ -198,6 +206,40 @@ def __len__(self): return len(self.dates) +class Rename(RecordsForward): + def __init__(self, dataset, rename): + self.forward = dataset + # rename: {"current_group": "new_group"} + assert isinstance(rename, dict) + for k, v in rename.items(): + assert isinstance(k, str), k + assert isinstance(v, str), v + self.rename = rename + + @property + def statistics(self): + return {self.rename.get(k, k): v for k, v in self.forward.statistics.items()} + + @property + def variables(self): + return {self.rename.get(k, k): v for k, v in self.forward.variables.items()} + + @property + def name_to_index(self): + return {self.rename.get(k, k): v for k, v in self.forward.name_to_index.items()} + + def keys(self): + return [self.rename.get(k, k) for k in self.forward.keys()] + + +class SetGroup(Rename): + def __init__(self, dataset, set_group): + if len(dataset.groups) != 1: + raise 
ValueError(f"{self.__class__.__name__} can only be used with datasets containing a single group.") + + super.__init__(dataset, {dataset.groups[0]: set_group}) + + def match_variable(lst, group, name): # lst must be a list of strings with dots (if there is no dot, it is automatically added at the end) # - a dict with keys as group and values as list of strings From 4e6d8774e91dd602e5af688f8761b8f82904f17b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:09:08 +0000 Subject: [PATCH 030/212] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/datasets/data/records/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 38abd846a..43d2669e7 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -219,15 +219,15 @@ def __init__(self, dataset, rename): @property def statistics(self): return {self.rename.get(k, k): v for k, v in self.forward.statistics.items()} - + @property def variables(self): return {self.rename.get(k, k): v for k, v in self.forward.variables.items()} - + @property def name_to_index(self): return {self.rename.get(k, k): v for k, v in self.forward.name_to_index.items()} - + def keys(self): return [self.rename.get(k, k) for k in self.forward.keys()] From 2bd2567972bafa5962a01cc6efd5cba7191bb23b Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 11 Jun 2025 09:05:40 +0000 Subject: [PATCH 031/212] fieldsrecords and set_group. and some refactor. may not be backward compatible with dataset created previously --- src/anemoi/datasets/data/misc.py | 8 ++ src/anemoi/datasets/data/records/__init__.py | 114 +++++++++++++++--- .../data/records/backends/__init__.py | 66 ++++++++-- tests/test_data.py | 11 ++ tests/test_records.py | 2 +- 5 files changed, 173 insertions(+), 28 deletions(-) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index dea006194..6094cdeca 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -514,6 +514,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": return observations_factory(args, kwargs).mutate() + if "xy" in kwargs: # Experimental feature, may be removed from .xy import xy_factory @@ -594,6 +595,13 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": assert len(sets) > 0, (args, kwargs) + if "set_group" in kwargs: + from anemoi.datasets.data.records import FieldsRecords + assert len(sets) == 1, sets + set_group = kwargs.pop("set_group") + + return FieldsRecords(*sets, name=set_group).mutate() + if len(sets) > 1: dataset, kwargs = _concat_or_join(sets, kwargs) return dataset._subset(**kwargs) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 43d2669e7..2e6a21468 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -108,7 +108,7 @@ def end_date(self): @property def groups(self): - return tuple(self.keys()) + raise NotImplementedError("Must be implemented in subclass") def _subset(self, **kwargs): start = kwargs.pop("start", None) @@ -179,8 +179,9 @@ def statistics(self): def variables(self): return self.forward.variables - def keys(self): - return self.forward.keys() + @property + def groups(self): + return self.forward.groups 
@property def dates(self): @@ -206,6 +207,54 @@ def __len__(self): return len(self.dates) +class FieldsRecords(RecordsForward): + """A wrapper around a FieldsDataset to provide a consistent interface for records datasets.""" + + def __init__(self, fields_dataset, name): + self.forward = fields_dataset + self._name = name + self._groups = [name] + + def _nest_in_dict(self, obj): + """Helper to nest the object in a dict with the name as key.""" + return {self._name: obj} + + @property + def groups(self): + return self._groups + + @property + def statistics(self): + return self._nest_in_dict(self.forward.statistics) + + @property + def variables(self): + return self._nest_in_dict(self.forward.variables) + + @property + def dates(self): + return self.forward.dates + + @property + def name_to_index(self): + return self._nest_in_dict(self.forward.name_to_index) + + @property + def frequency(self): + return self.forward.frequency + + @property + def _window(self): + return self.forward._window + + @property + def shapes(self): + return self._nest_in_dict(self.forward.shape) + + def __len__(self): + return len(self.forward.dates) + + class Rename(RecordsForward): def __init__(self, dataset, rename): self.forward = dataset @@ -228,8 +277,9 @@ def variables(self): def name_to_index(self): return {self.rename.get(k, k): v for k, v in self.forward.name_to_index.items()} - def keys(self): - return [self.rename.get(k, k) for k in self.forward.keys()] + @property + def groups(self): + return [self.rename.get(k, k) for k in self.forward.groups] class SetGroup(Rename): @@ -315,6 +365,23 @@ def _to_timedelta(t): ) +class AbsoluteWindow: + def __init__(self, start, end, include_start=True, include_end=True): + assert isinstance(start, datetime.datetime), f"start must be a datetime.datetime, got {type(start)}" + assert isinstance(end, datetime.datetime), f"end must be a datetime.datetime, got {type(end)}" + assert isinstance(include_start, bool), f"include_start must be a bool, got {type(include_start)}" + assert isinstance(include_end, bool), f"include_end must be a bool, got {type(include_end)}" + if start >= end: + raise ValueError(f"start {start} must be less than end {end}") + self.start = start + self.end = end + self.include_start = include_start + self.include_end = include_end + + def __repr__(self): + return f"{'[' if self.include_start else '('}{self.start.isoformat()},{self.end.isoformat()}{']' if self.include_end else ')'}" + + class WindowsSpec: def __init__(self, *, start, end, include_start=False, include_end=True): assert isinstance(start, (str, datetime.timedelta)), f"start must be a str or timedelta, got {type(start)}" @@ -333,6 +400,13 @@ def __init__(self, *, start, end, include_start=False, include_end=True): self._start_np = _to_numpy_timedelta(start) self._end_np = _to_numpy_timedelta(end) + def to_absolute_window(self, date): + """Convert the window to an absolute window based on a date.""" + assert isinstance(date, datetime.datetime), f"date must be a datetime.datetime, got {type(date)}" + start = date + self.start + end = date + self.end + return AbsoluteWindow(start=start, end=end, include_start=self.include_start, include_end=self.include_end) + def __repr__(self): first = "[" if self.include_start else "(" last = "]" if self.include_end else ")" @@ -534,8 +608,9 @@ def _build_indices_and_name_to_index(self): def match_variable(self, *args, **kwargs): return match_variable(self._select, *args, **kwargs) - def keys(self): - return self._indices.keys() + @property + def 
groups(self): + return list(self._indices.keys()) def _load_data(self, i): forward = self.dataset._load_data(i) @@ -592,10 +667,14 @@ def __init__(self, path, backend="npz1", **kwargs): print("Warning: ignoring additional kwargs", kwargs) self.path = path self.backend = backend_factory(backend, path, **kwargs) - self.keys = self.metadata["sources"].keys - for k in self.keys(): + self._groups = list(self.metadata["sources"].keys()) + for k in self.groups: assert k == self.normalise_key(k), k + @property + def groups(self): + return self._groups + @classmethod def normalise_key(cls, k): return "".join([x.lower() if x.isalnum() else "-" for x in k]) @@ -628,7 +707,7 @@ def shapes(self): return self.metadata["shapes"] def items(self, *args, **kwargs): - return {k: Tabular(self, k) for k in self.keys()}.items(*args, **kwargs) + return {k: Tabular(self, k) for k in self.groups}.items(*args, **kwargs) @cached_property def statistics(self): @@ -673,13 +752,13 @@ def check(self, i=None): assert s == {"latitudes", "longitudes", "timedeltas", "metadata", "data"}, f"Invalid keys {s}" -class Record(dict): +class Record: def __init__(self, dataset, n): self.dataset = dataset self.n = n def __repr__(self): - d = {group: "" for group in self.dataset.keys()} + d = {group: "" for group in self.dataset.groups} return str(d) def items(self): @@ -696,15 +775,16 @@ def _payload(self): assert len(k.split(":")) == 2, f"Invalid key {k}" return payload - def keys(self): - return self.dataset.keys() + @cached_property + def groups(self): + return self.dataset.groups def __getitem__(self, group): return self._payload["data:" + group] def _get_aux(self, name): try: - return {k: self._payload[name + ":" + k] for k in self.keys()} + return {k: self._payload[name + ":" + k] for k in self.groups} except KeyError as e: e.add_note(f"Available keys are {self._payload.keys()}") raise @@ -725,10 +805,6 @@ def timedeltas(self): def statistics(self): return self.dataset.statistics - @property - def groups(self): - return tuple(self.keys()) - class Tabular: def __init__(self, dataset, name): diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index e7629069d..e831d8f82 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -8,9 +8,13 @@ # nor does it submit to any jurisdiction. 
import json +import logging import os import numpy as np +from cachetools import LRUCache + +LOG = logging.getLogger(__name__) def normalise_key(k): @@ -39,13 +43,24 @@ def _check_data(self, data): class Npz1Backend(Backend): - number_of_files_per_subdirectory = 10 + number_of_files_per_subdirectory = 100 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._cache = None def read(self, i, **kwargs): + if self._cache is None: + self._cache = LRUCache(maxsize=5) + if i in self._cache: + return self._cache[i] + d = str(int(i / self.number_of_files_per_subdirectory)) path = os.path.join(self.path, "data", d, f"{i}.npz") with open(path, "rb") as f: - return dict(np.load(f)) + data = dict(np.load(f)) + self._cache[i] = data + return data def read_metadata(self): with open(os.path.join(self.path, "metadata.json"), "r") as f: @@ -141,26 +156,60 @@ def _check_data(self, data): if k != normalise_key(k): raise ValueError(f"{k} must be alphanumerical and '-' only.") + def _dataframes_to_record(self, i, data, variables, **kwargs): + + assert isinstance(data, (dict)), type(data) + if not data: + LOG.warning(f"Empty data for index {i}.") + return data + first = data[list(data.keys())[0]] + import pandas as pd + + if isinstance(first, pd.DataFrame): + data = {name: self._dataframe_to_dict(name, df, **kwargs) for name, df in data.items()} + else: + assert False + + return data + + def _dataframe_to_dict(self, name, df, **kwargs): + d = {} + d["timedeltas:" + name] = df["timedeltas"] + d["latitudes:" + name] = df["latitudes"] + d["longitudes:" + name] = df["longitudes"] + d["data:" + name] = df["data"] + d["metadata:" + name] = df["metadata"] + return d + class Npz1WriteBackend(WriteBackend): - number_of_files_per_subdirectory = 10 + number_of_files_per_subdirectory = 100 def write(self, i, data, **kwargs): self._check_data(data) d = str(int(i / self.number_of_files_per_subdirectory)) - path = os.path.join(self.path, "data", d) - os.makedirs(path, exist_ok=True) - out_path = os.path.join(path, f"{i}.npz") - np.savez(out_path, **data) + dir_path = os.path.join(self.path, "data", d) + + out_path = os.path.join(dir_path, f"{i}.npz") + tmp_path = os.path.join(dir_path, f"{i}.tmp.npz") + + os.makedirs(os.path.dirname(tmp_path), exist_ok=True) + np.savez(tmp_path, **data) + os.rename(tmp_path, out_path) def write_metadata(self, metadata): from anemoi.datasets.create import json_tidy os.makedirs(self.path, exist_ok=True) - with open(os.path.join(self.path, "metadata.json"), "w") as f: + + path = os.path.join(self.path, "metadata.json") + tmp_path = path + ".tmp" + with open(tmp_path, "w") as f: json.dump(metadata, f, indent=2, default=json_tidy) + os.rename(tmp_path, path) def write_statistics(self, statistics): + os.makedirs(self.path, exist_ok=True) flatten = {} for name, d in statistics.items(): assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" @@ -200,6 +249,7 @@ def write_metadata(self, metadata): json.dump(metadata, f, indent=2, default=json_tidy) def write_statistics(self, statistics): + os.makedirs(self.path, exist_ok=True) flatten = {} for name, d in statistics.items(): assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" diff --git a/tests/test_data.py b/tests/test_data.py index 07b35887e..d69f53eaa 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1450,6 +1450,17 @@ def test_invalid_trim_edge() -> None: ) +@mockup_open_zarr +def test_fields_to_records() -> None: + """Test joining datasets (case 
2).""" + + key = 'grp' + ds = open_dataset(dataset="test-2021-2021-6h-o96-abcd-1", set_group=key) + unwrapped = open_dataset(dataset="test-2021-2021-6h-o96-abcd-2") + + assert ds.groups == [key] + assert ds.variables == {key:["a", "b", "c", "d"]} + if __name__ == "__main__": for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj): diff --git a/tests/test_records.py b/tests/test_records.py index 1b7cd0ccc..7c065e4a7 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -109,7 +109,7 @@ def _test(ds, nb_dates=None): _statistics = ds[index_i].statistics assert isinstance(_statistics, dict), type(_statistics) assert grp in _statistics, f"statistics does not contain {grp}" - assert _statistics.keys() == ds.keys(), (_statistics.keys(), ds.keys()) + assert list(_statistics.keys()) == ds.groups, (_statistics.keys(), ds.groups) for group_name, stats in _statistics.items(): assert "mean" in stats, f"statistics does not contain mean for {group_name}" for key, v in stats.items(): From ab564021bbdc1c6470dae8512f9f250d1c037225 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Jun 2025 13:23:09 +0000 Subject: [PATCH 032/212] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/datasets/data/misc.py | 2 +- tests/test_data.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 6094cdeca..169830202 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -514,7 +514,6 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": return observations_factory(args, kwargs).mutate() - if "xy" in kwargs: # Experimental feature, may be removed from .xy import xy_factory @@ -597,6 +596,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": if "set_group" in kwargs: from anemoi.datasets.data.records import FieldsRecords + assert len(sets) == 1, sets set_group = kwargs.pop("set_group") diff --git a/tests/test_data.py b/tests/test_data.py index d69f53eaa..488d05bae 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1454,12 +1454,13 @@ def test_invalid_trim_edge() -> None: def test_fields_to_records() -> None: """Test joining datasets (case 2).""" - key = 'grp' + key = "grp" ds = open_dataset(dataset="test-2021-2021-6h-o96-abcd-1", set_group=key) unwrapped = open_dataset(dataset="test-2021-2021-6h-o96-abcd-2") assert ds.groups == [key] - assert ds.variables == {key:["a", "b", "c", "d"]} + assert ds.variables == {key: ["a", "b", "c", "d"]} + if __name__ == "__main__": for name, obj in list(globals().items()): From 111bcc8bb9c8762cc67d3788ac491b1d0522a8e2 Mon Sep 17 00:00:00 2001 From: Ewan Pinnington Date: Tue, 17 Jun 2025 09:53:41 +0000 Subject: [PATCH 033/212] adding odb mars draft example --- tests/create/odb2df.py | 113 ++++++++++++++++++++++ tests/create/test_observations_mars.py | 128 +++++++++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 tests/create/odb2df.py create mode 100644 tests/create/test_observations_mars.py diff --git a/tests/create/odb2df.py b/tests/create/odb2df.py new file mode 100644 index 000000000..be63aa0c4 --- /dev/null +++ b/tests/create/odb2df.py @@ -0,0 +1,113 @@ +import json +import logging +from typing import List, Dict, Optional, Union +import pandas as pd +from earthkit.data.readers.odb import ODBReader + +logging.basicConfig(level=logging.INFO, 
format="%(levelname)s - %(message)s") + + +def load_varno_dict(path: Optional[str] = None) -> Dict: + """Load varno mapping, return empty dict if not found.""" + try: + with open(path or "varno.json") as f: + return json.load(f) + except: + return {"data": []} + + +def get_varno_name(varno: Union[int, str], varno_dict: Dict) -> str: + """Get varno name or return original if not found.""" + try: + v = int(varno) + for entry in varno_dict.get("data", []): + if v in entry: + return str(entry[0]) + except: + pass + return str(varno) + + +def rename_cols(cols: List, extra_obs: List[str] = None, varno_path: str = None) -> List[str]: + """Rename columns: base_name_varno_level""" + varno_dict = load_varno_dict(varno_path) + extra_obs = extra_obs or [] + + result = [] + for col in cols: + if isinstance(col, tuple): + parts = col + ("", "") + name, varno = parts[:2] + level = parts[2] if len(parts) > 2 else "" + else: + name, varno, level = col, "", "" + + base = name.split("@")[0] + if base in extra_obs: + base = f"obsvalue_{base}" + + if varno: + varno_name = get_varno_name(varno, varno_dict) + level_str = str(int(level)) if level and not isinstance(level, (list, tuple)) else "0" + result.append(f"{base}_{varno_name}_{level_str}") + else: + result.append(base) + + return result + + +def process_odb(reader: ODBReader, index: List[str], pivot: List[str], values: List[str], + sort: List[str] = None, extra_obs: List[str] = None, drop_na: bool = False, + datetime_cols: tuple = ("date@hdr", "time@hdr"), varno_path: str = None) -> pd.DataFrame: + """Process ODB data: convert to pandas, pivot, rename columns.""" + + try: + df = reader.to_pandas() + except Exception as e: + logging.error(f"ODB conversion failed: {e}") + return pd.DataFrame() + + if df.empty: + return df + + # Remove duplicates and pivot + df = df.drop_duplicates(subset=index + pivot, keep="first") + df = df.pivot(index=index, columns=pivot, values=values) + + # Sort and reset + if sort and all(c in df.index.names for c in sort): + df = df.sort_values(by=sort, kind="stable") + df = df.reset_index() + + # Reorganize columns + meta = df[index] + obs = df.drop(columns=index, level=0).sort_index(axis=1) + df = pd.concat([meta, obs], axis=1) + + if drop_na: + df = df.dropna() + + # Create datetime if both columns exist + date_col, time_col = datetime_cols + if date_col in df.columns and time_col in df.columns: + try: + df["times"] = pd.to_datetime( + df[date_col].astype(int).astype(str) + + df[time_col].astype(int).astype(str).str.zfill(6), + format="%Y%m%d%H%M%S" + ) + df = df.drop(columns=[date_col, time_col], level=0) + except: + logging.warning("Could not create datetime column") + + # Rename columns + df.columns = rename_cols(df.columns.tolist(), extra_obs, varno_path) + + # Rename lat/lon columns to match expected format + df = df.rename(columns={"lat": "latitudes", "lon": "longitudes"}) + + return df + + +# Example usage: +# df = process_odb(reader, ["seqno@hdr", "lat@hdr", "lon@hdr"], ["varno@body"], ["obsvalue@body"]) \ No newline at end of file diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py new file mode 100644 index 000000000..0e1ea80d6 --- /dev/null +++ b/tests/create/test_observations_mars.py @@ -0,0 +1,128 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime + +import numpy as np +import pandas as pd + +from anemoi.datasets.create.sources.observations import ObservationsFilter +from anemoi.datasets.create.sources.observations import ObservationsSource +from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import window_from_str +from earthkit.data import from_source +from odb2df import process_odb +import logging + +log = logging.getLogger(__name__) + + +class DummpySource(ObservationsSource): + def __init__(self, data): + assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" + self.data = data + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + if window.include_start: + mask = self.data["times"] > window.start + else: + mask = self.data["times"] >= window.start + if window.include_end: + mask &= self.data["times"] <= window.end + else: + mask &= self.data["times"] < window.end + + df = self.data[mask] + + return self._check(df) + + +class MarsSource(ObservationsSource): + def __init__(self, request_dict, post_process_dict): + assert isinstance(request_dict, dict), "request_dict must be a dictionary" + self.request_dict = request_dict + self.post_process_dict = post_process_dict + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + request_dict = self.request_dict + request_dict["date"] = ( + f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" + ) + try: + ekd_ds = from_source("mars", request_dict) + except Exception as e: + if "File is empty" in str(e): + log.warning( + f"Empty file for period {window.start.strftime('%Y%m%d')} to {window.end.strftime('%Y%m%d')}. Skipping." 
+ ) + return + else: + raise # Re-raise if it's a different error + + data = process_odb(ekd_ds, **self.post_process_dict) + + print(data) + print(data.columns) + + if window.include_start: + mask = data["times"] > window.start + else: + mask = data["times"] >= window.start + if window.include_end: + mask &= data["times"] <= window.end + else: + mask &= data["times"] < window.end + + df = data[mask] + + return self._check(df) + + +class DummyFilter(ObservationsFilter): + def __call__(self, df, col_name): + """Filter the data based on the given window.""" + self._check(df) + # Here we can add any filtering logic if needed + df.loc[:, col_name] = df[col_name] + 0.42 + return self._check(df) + + +dates = [datetime.datetime(2025, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] + +N = 100 +source = MarsSource( + request_dict={ + "class": "ea", + "expver": "0001", + "stream": "oper", + "obsgroup": "conv", + "reportype": "16001/16002/16004/16065/16076", + "type": "ofb", + "time": "00/12", + "filter": "'select seqno,reportype,date,time,lat,lon,report_status,report_event1,entryno,varno,statid,stalt,obsvalue,lsm@modsurf,biascorr_fg,final_obs_error,datum_status@body,datum_event1@body,vertco_reference_1,vertco_type where ((varno==39 and abs(fg_depar@body)<20) or (varno in (41,42) and abs(fg_depar@body)<15) or (varno==58 and abs(fg_depar@body)<0.4) or (varno == 110 and entryno == 1 and abs(fg_depar@body)<10000) or (varno == 91)) and time in (000000,030000,060000,090000,120000,150000,180000,210000);'" + }, + post_process_dict={ + "index": ["seqno@hdr", "lat@hdr", "lon@hdr", "date@hdr", "time@hdr", "stalt@hdr", "lsm@modsurf"], + "pivot": ["varno@body"], + "values": ["obsvalue@body"] + } +) +filter = DummyFilter() + +for d in dates: + window = window_from_str("(-5h, 1h]").to_absolute_window(d) + print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) + d = source(window) + d = filter(d, "obsvalue_v10m_0") + print(window) + print(d) From 5738660a3fec3159359b0a2c9fd30ee5e6a256ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Jun 2025 09:54:09 +0000 Subject: [PATCH 034/212] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/create/odb2df.py | 55 +++++++++++++++----------- tests/create/test_observations_mars.py | 25 ++++++------ 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/tests/create/odb2df.py b/tests/create/odb2df.py index be63aa0c4..5ada55252 100644 --- a/tests/create/odb2df.py +++ b/tests/create/odb2df.py @@ -1,6 +1,10 @@ import json import logging -from typing import List, Dict, Optional, Union +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + import pandas as pd from earthkit.data.readers.odb import ODBReader @@ -32,7 +36,7 @@ def rename_cols(cols: List, extra_obs: List[str] = None, varno_path: str = None) """Rename columns: base_name_varno_level""" varno_dict = load_varno_dict(varno_path) extra_obs = extra_obs or [] - + result = [] for col in cols: if isinstance(col, tuple): @@ -41,73 +45,80 @@ def rename_cols(cols: List, extra_obs: List[str] = None, varno_path: str = None) level = parts[2] if len(parts) > 2 else "" else: name, varno, level = col, "", "" - + base = name.split("@")[0] if base in extra_obs: base = f"obsvalue_{base}" - + if varno: varno_name = get_varno_name(varno, varno_dict) level_str = str(int(level)) if level and not isinstance(level, (list, 
tuple)) else "0" result.append(f"{base}_{varno_name}_{level_str}") else: result.append(base) - + return result -def process_odb(reader: ODBReader, index: List[str], pivot: List[str], values: List[str], - sort: List[str] = None, extra_obs: List[str] = None, drop_na: bool = False, - datetime_cols: tuple = ("date@hdr", "time@hdr"), varno_path: str = None) -> pd.DataFrame: +def process_odb( + reader: ODBReader, + index: List[str], + pivot: List[str], + values: List[str], + sort: List[str] = None, + extra_obs: List[str] = None, + drop_na: bool = False, + datetime_cols: tuple = ("date@hdr", "time@hdr"), + varno_path: str = None, +) -> pd.DataFrame: """Process ODB data: convert to pandas, pivot, rename columns.""" - + try: df = reader.to_pandas() except Exception as e: logging.error(f"ODB conversion failed: {e}") return pd.DataFrame() - + if df.empty: return df - + # Remove duplicates and pivot df = df.drop_duplicates(subset=index + pivot, keep="first") df = df.pivot(index=index, columns=pivot, values=values) - + # Sort and reset if sort and all(c in df.index.names for c in sort): df = df.sort_values(by=sort, kind="stable") df = df.reset_index() - + # Reorganize columns meta = df[index] obs = df.drop(columns=index, level=0).sort_index(axis=1) df = pd.concat([meta, obs], axis=1) - + if drop_na: df = df.dropna() - + # Create datetime if both columns exist date_col, time_col = datetime_cols if date_col in df.columns and time_col in df.columns: try: df["times"] = pd.to_datetime( - df[date_col].astype(int).astype(str) + - df[time_col].astype(int).astype(str).str.zfill(6), - format="%Y%m%d%H%M%S" + df[date_col].astype(int).astype(str) + df[time_col].astype(int).astype(str).str.zfill(6), + format="%Y%m%d%H%M%S", ) df = df.drop(columns=[date_col, time_col], level=0) except: logging.warning("Could not create datetime column") - + # Rename columns df.columns = rename_cols(df.columns.tolist(), extra_obs, varno_path) - + # Rename lat/lon columns to match expected format df = df.rename(columns={"lat": "latitudes", "lon": "longitudes"}) - + return df # Example usage: -# df = process_odb(reader, ["seqno@hdr", "lat@hdr", "lon@hdr"], ["varno@body"], ["obsvalue@body"]) \ No newline at end of file +# df = process_odb(reader, ["seqno@hdr", "lat@hdr", "lon@hdr"], ["varno@body"], ["obsvalue@body"]) diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py index 0e1ea80d6..12482fc5f 100644 --- a/tests/create/test_observations_mars.py +++ b/tests/create/test_observations_mars.py @@ -8,17 +8,16 @@ # nor does it submit to any jurisdiction. 
import datetime +import logging -import numpy as np import pandas as pd +from earthkit.data import from_source +from odb2df import process_odb from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource from anemoi.datasets.data.records import AbsoluteWindow from anemoi.datasets.data.records import window_from_str -from earthkit.data import from_source -from odb2df import process_odb -import logging log = logging.getLogger(__name__) @@ -53,11 +52,9 @@ def __init__(self, request_dict, post_process_dict): def __call__(self, window): assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" - + request_dict = self.request_dict - request_dict["date"] = ( - f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" - ) + request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" try: ekd_ds = from_source("mars", request_dict) except Exception as e: @@ -70,7 +67,7 @@ def __call__(self, window): raise # Re-raise if it's a different error data = process_odb(ekd_ds, **self.post_process_dict) - + print(data) print(data.columns) @@ -109,13 +106,13 @@ def __call__(self, df, col_name): "reportype": "16001/16002/16004/16065/16076", "type": "ofb", "time": "00/12", - "filter": "'select seqno,reportype,date,time,lat,lon,report_status,report_event1,entryno,varno,statid,stalt,obsvalue,lsm@modsurf,biascorr_fg,final_obs_error,datum_status@body,datum_event1@body,vertco_reference_1,vertco_type where ((varno==39 and abs(fg_depar@body)<20) or (varno in (41,42) and abs(fg_depar@body)<15) or (varno==58 and abs(fg_depar@body)<0.4) or (varno == 110 and entryno == 1 and abs(fg_depar@body)<10000) or (varno == 91)) and time in (000000,030000,060000,090000,120000,150000,180000,210000);'" + "filter": "'select seqno,reportype,date,time,lat,lon,report_status,report_event1,entryno,varno,statid,stalt,obsvalue,lsm@modsurf,biascorr_fg,final_obs_error,datum_status@body,datum_event1@body,vertco_reference_1,vertco_type where ((varno==39 and abs(fg_depar@body)<20) or (varno in (41,42) and abs(fg_depar@body)<15) or (varno==58 and abs(fg_depar@body)<0.4) or (varno == 110 and entryno == 1 and abs(fg_depar@body)<10000) or (varno == 91)) and time in (000000,030000,060000,090000,120000,150000,180000,210000);'", }, post_process_dict={ - "index": ["seqno@hdr", "lat@hdr", "lon@hdr", "date@hdr", "time@hdr", "stalt@hdr", "lsm@modsurf"], - "pivot": ["varno@body"], - "values": ["obsvalue@body"] - } + "index": ["seqno@hdr", "lat@hdr", "lon@hdr", "date@hdr", "time@hdr", "stalt@hdr", "lsm@modsurf"], + "pivot": ["varno@body"], + "values": ["obsvalue@body"], + }, ) filter = DummyFilter() From 3680c6b6f1a3f50130c0e79d1fbf5745b829f394 Mon Sep 17 00:00:00 2001 From: Ewan Pinnington Date: Tue, 17 Jun 2025 09:58:54 +0000 Subject: [PATCH 035/212] adding odb mars draft example --- tests/create/test_observations_mars.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py index 0e1ea80d6..4ad7bf935 100644 --- a/tests/create/test_observations_mars.py +++ b/tests/create/test_observations_mars.py @@ -71,8 +71,8 @@ def __call__(self, window): data = process_odb(ekd_ds, **self.post_process_dict) - print(data) - print(data.columns) + # print(data) + # print(data.columns) if window.include_start: mask = data["times"] > window.start @@ -89,17 +89,19 @@ def __call__(self, window): class 
DummyFilter(ObservationsFilter): - def __call__(self, df, col_name): + def __init__(self, col_name): + self.col_name = col_name + + def __call__(self, df): """Filter the data based on the given window.""" self._check(df) # Here we can add any filtering logic if needed - df.loc[:, col_name] = df[col_name] + 0.42 + df.loc[:, self.col_name] = df[self.col_name] + 0.42 return self._check(df) dates = [datetime.datetime(2025, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] -N = 100 source = MarsSource( request_dict={ "class": "ea", @@ -117,12 +119,12 @@ def __call__(self, df, col_name): "values": ["obsvalue@body"] } ) -filter = DummyFilter() +filter = DummyFilter("obsvalue_v10m_0") for d in dates: window = window_from_str("(-5h, 1h]").to_absolute_window(d) print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) d = source(window) - d = filter(d, "obsvalue_v10m_0") + d = filter(d) print(window) print(d) From bc7015370508d6599f777a1bdb08f463440801db Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Jun 2025 10:03:05 +0000 Subject: [PATCH 036/212] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/create/test_observations_mars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py index 5f45438af..1e54a4a08 100644 --- a/tests/create/test_observations_mars.py +++ b/tests/create/test_observations_mars.py @@ -67,7 +67,7 @@ def __call__(self, window): raise # Re-raise if it's a different error data = process_odb(ekd_ds, **self.post_process_dict) - + # print(data) # print(data.columns) From a195b03ce8459450786eee0f27c4342d3afbace9 Mon Sep 17 00:00:00 2001 From: Ewan Pinnington Date: Tue, 17 Jun 2025 10:17:35 +0000 Subject: [PATCH 037/212] adding needed metadata file --- tests/create/varno.json | 1 + 1 file changed, 1 insertion(+) create mode 100644 tests/create/varno.json diff --git a/tests/create/varno.json b/tests/create/varno.json new file mode 100644 index 000000000..3ce87746d --- /dev/null +++ b/tests/create/varno.json @@ -0,0 +1 @@ +{"fields": ["name", "code", "description"], "data": [["u", 3, "upper air u component"], ["v", 4, "upper air v component"], ["z", 1, "geopotential"], ["dz", 57, "thickness"], ["rh", 29, "upper air rel. humidity"], ["pwc", 9, "precipitable water content"], ["rh2m", 58, "2m rel. humidity"], ["t", 2, "upper air temperature (K)"], ["td", 59, "upper air dew point (K)"], ["t2m", 39, "2m temperature (K)"], ["td2m", 40, "2m dew point (K)"], ["ts", 11, "surface temperature (K)"], ["ptend", 30, "pressure tendency"], ["w", 60, "past weather (w)"], ["ww", 61, "present weather (ww)"], ["vv", 62, "visibility"], ["ch", 63, "type of high clouds (ch)"], ["cm", 64, "type of middle clouds (cm)"], ["cl", 65, "type of low clouds (cl)"], ["nh", 66, "cloud base height (nh) (meter)"], ["nn", 67, "low cloud amount (n)"], ["hshs", 68, "additional cloud group height (hh)"], ["c", 69, "additional cloud group type (c)"], ["ns", 70, "additional cloud group amount (ns)"], ["sdepth", 71, "snow depth"], ["e", 72, "state of ground (e)"], ["tgtg", 73, "ground temperature (tgtg)"], ["spsp1", 74, "special phenomena (spsp)#1"], ["spsp2", 75, "special phenomena (spsp)#2"], ["rs", 76, "ice code type (rs)"], ["eses", 77, "ice thickness (eses)"], ["is", 78, "ice (is)"], ["trtr", 79, "original time period of rain obs. 
(trtr)"], ["rr", 80, "6hr rain (liquid part)"], ["jj", 81, "max. temperature (jj)"], ["vs", 82, "ship speed (vs)"], ["ds", 83, "ship direction (ds)"], ["hwhw", 84, "wave height"], ["pwpw", 85, "wave period"], ["dwdw", 86, "wave direction"], ["gclg", 87, "general cloud group"], ["rhlc", 88, "rel. humidity from low clouds"], ["rhmc", 89, "rel. humidity from middle clouds"], ["rhhc", 90, "rel. humidity from high clouds"], ["n", 91, "total amount of clouds"], ["sfall", 92, "6hr snowfall (solid part of rain)"], ["ps", 110, "surface pressure"], ["dd", 111, "wind direction"], ["ff", 112, "wind force"], ["rawbt", 119, "brightness temperature (K)"], ["rawra", 120, "raw radiance"], ["satcl", 121, "cloud amount from satellite"], ["scatss", 122, "sigma 0"], ["du", 5, "wind shear (du)"], ["dv", 6, "wind shear (dv)"], ["u10m", 41, "10m u component (m/s)"], ["v10m", 42, "10m v component (m/s)"], ["rhlay", 19, "layer rel. humidity"], ["cllqw", 123, "cloud liquid water"], ["scatdd", 124, "ambiguous v component"], ["scatff", 125, "ambiguous u component"], ["q", 7, "specific humidity (q)"], ["scatwd", 126, "ambiguous wind direction"], ["scatws", 127, "ambiguous wind speed"], ["vsp", 8, "vertical speed"], ["vt", 56, "virtual temperature"], ["o3lay", 206, "layer ozone"], ["height", 156, "height"], ["1dvar", 215, "1d-var model level (pseudo)-variable"], ["w2", 160, "past weather 2 (used in synoptic maps)"], ["cpt", 130, "characteristic of pressure tendency (used in synoptic maps)"], ["tsts", 12, "sea water temperature (used in synoptic maps)"], ["refl", 192, "radar reflectivity"], ["apdss", 128, "atmospheric path delay in satellite signal"], ["bend_angle", 162, "radio occultation bending angle"], ["los", 187, "horizontal line-of-sight wind component"], ["aerod", 174, "aerosol optical depth at 0.55 microns"], ["limb_radiance", 163, "Limb Radiances"], ["chem3", 183, "chem3: co"], ["chem2", 182, "chem2: so2"], ["chem1", 181, "chem1: no2/nox"], ["cod", 175, "cloud optical depth"], ["rao", 176, "Ratio of fine mode to total aerosol optical depth at 0.55 microns"], ["od", 177, "optical depth"], ["rfltnc", 178, "Aerosol reflectance multi-channel"], ["nsoilm", 179, "normalized soil moisture (0-100%)"], ["soilm", 180, "soil moisture"], ["flgt_phase", 201, "phase of aircraft flight"], ["height_assignment_method", 211, "Height assignment method"], ["dopp", 195, "radar doppler wind"], ["ghg1", 186, "ghg1: carbon dioxide"], ["ghg2", 188, "ghg2: methane"], ["ghg3", 189, "ghg3: nitrous oxide"], ["bt_real", 190, "brightness temperature real part"], ["bt_imaginary", 191, "brightness temperature imaginary part"], ["prc", 202, "radar rain rate"], ["lnprc", 203, "log(radar rain rate mm/h + epsilon)"], ["libksc", 222, "lidar backscattering"], ["ralt_swh", 220, "significant wave height (m)"], ["ralt_sws", 221, "surface wind speed (m/s)"], ["rawbt_clear", 193, "brightness temperature for clear (K)"], ["rawbt_cloudy", 194, "brightness temperature for cloudy (K)"], ["binary_snow_cover", 223, "binary snow cover (0: no snow / 1: presence of snow)"], ["salinity", 224, "ocean salinity (PSU)"], ["potential_temp", 225, "potential temperature (Kelvin)"], ["humidity_mixing_ratio", 226, "humidity mixing ratio (kg/kg)"], ["airframe_icing", 227, "airframe icing"], ["turbulence_index", 228, "turbulence index"], ["pstation", 107, "Station pressure (Pa)"], ["pmsl", 108, "Mean sea-level pressure (Pa)"], ["pstandard", 109, "Standard level pressure (Pa)"], ["vert_vv", 218, "Vertical visibility (m)"], ["max_wind_shear1", 219, "Wind shear above and below 
1st maximum wind in sonde profile (s-1)"], ["tot_zen_delay", 229, "Total zenith delay (GPS)"], ["tot_zen_delay_err", 230, "Total zenith delay error (GPS)"], ["cloud_top_temp", 231, "Cloud top temperature (K)"], ["rawsca", 233, "Scaled radiance"], ["cloud_top_press", 235, "Cloud top pressure (Pa)"], ["mean_freq", 241, "GPSRO mean frequency"], ["u_amb", 242, "Ambiguous u-wind component (m/s)"], ["v_amb", 243, "Ambiguous v-wind component (m/s)"], ["lwp", 244, "Liquid water path"], ["tcwv", 245, "Total column water vapour"], ["cloud_frac_clear", 247, "Cloud clear fraction"], ["rawbt_hirs", 248, "Raw brightness temperature specific to HIRS (K)"], ["rawbt_amsu", 249, "Raw brightness temperature specific to AMSU (K)"], ["rawbt_hirs20", 250, "Raw brightness temperature specific to HIRS (K)"], ["sea_ice", 253, "Sea ice fraction"], ["cloud_frac_covered", 257, "Cloud covered fraction"], ["level_mixing_ratio", 258, "humidity_mixing_ratio]"], ["radial_velocity", 259, "Radial velocity from doppler radar"], ["cloud_ice_water", 260, "Cloud ice water"], ["wind_gust", 261, "Maximum wind gust (m/s)"], ["mass_density", 262, "Mass density"], ["atmosphere_number", 263, "SFERICS number of atmospheres"], ["lightning", 265, "Lightning strike observation (ATDNET)"], ["level_cloud", 266, "Cloud fraction (multi-level)"], ["rawbt_amsr_89ghz", 267, "Raw brightness temperature specific to AMSR 89GHz channels (K)"], ["max_wind_shear2", 268, "Wind shear above and below 2nd maximum wind in sonde profile"], ["lower_layer_p", 269, "Pressure at bottom of layer SBUV (Pa)"], ["upper_layer_p", 270, "Pressure at top of later SBUV (Pa)"], ["cloud_cover", 271, "Total cloud cover"], ["depth", 272, "Depth (m)"], ["ssh", 273, "Sea surface height (m)"], ["rawbt_mwts", 274, "Raw brightness temperature specific to MWTS (K)"], ["rawbt_mwhs", 275, "Raw brightness temperature specific to MWHS (K)"], ["tot_lightning_flash_dens", 196, "total (cloud-to-ground plus intra-cloud) lightning flash density (fl/km2/day)"], ["cg_lightning_flash_dens", 197, "cloud-to-ground lightning flash density ( fl/km2/day)"], ["lidar_aerosol_extinction", 236, "lidar aerosol extinction (1/m)"], ["lidar_cloud_backscatter", 237, "lidar cloud backscatter"], ["lidar_cloud_extinction", 238, "lidar cloud extinction"], ["cloud_radar_reflectivity", 239, "cloud radar reflectivity"], ["lidar_aerosol_attenuated_backscatter", 280, "lidar aerosol attenuated backscatter (1/m*sr)"], ["q2m", 281, "specific humidity at 2m (kg/kg)"], ["chem6", 284, "volcanic SO2"], ["sla", 287, "sea level anomaly"], ["ice_freeboard", 286, "Height of sea ice above open water"], ["snow_freeboard", 285, "Height of snow and sea ice above open water"], ["visible_spectral_reflectance", 240, "Visible Spectral Reflectance"], ["od10", 288, "optical depth at 10 microns"], ["chem4", 184, "chem4: hcho"], ["chem5", 185, "chem5: go3"], ["frac_snow_cover", 282, "fractional snow cover"], ["cloud_doppler_velocity", 251, "vertical radar doppler velocity"], ["lidar_rayleigh_backscatter", 252, "lidar Rayleigh backscatter"], ["sigma0_sm", 283, "backscatter coefficient normalized at 40 degree (db)"], ["t2m_min", 37, "minimum 2m temperature (K)"], ["t2m_max", 38, "maximum 2m temperature (K)"], ["ssrd", 25, "downward surface solar radiation (J/m2)"]]} \ No newline at end of file From 1f9f9aad4421284aadaacb05265282a365010cf8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Jun 2025 10:19:30 +0000 Subject: [PATCH 038/212] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/create/varno.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/create/varno.json b/tests/create/varno.json index 3ce87746d..7b54097c7 100644 --- a/tests/create/varno.json +++ b/tests/create/varno.json @@ -1 +1 @@ -{"fields": ["name", "code", "description"], "data": [["u", 3, "upper air u component"], ["v", 4, "upper air v component"], ["z", 1, "geopotential"], ["dz", 57, "thickness"], ["rh", 29, "upper air rel. humidity"], ["pwc", 9, "precipitable water content"], ["rh2m", 58, "2m rel. humidity"], ["t", 2, "upper air temperature (K)"], ["td", 59, "upper air dew point (K)"], ["t2m", 39, "2m temperature (K)"], ["td2m", 40, "2m dew point (K)"], ["ts", 11, "surface temperature (K)"], ["ptend", 30, "pressure tendency"], ["w", 60, "past weather (w)"], ["ww", 61, "present weather (ww)"], ["vv", 62, "visibility"], ["ch", 63, "type of high clouds (ch)"], ["cm", 64, "type of middle clouds (cm)"], ["cl", 65, "type of low clouds (cl)"], ["nh", 66, "cloud base height (nh) (meter)"], ["nn", 67, "low cloud amount (n)"], ["hshs", 68, "additional cloud group height (hh)"], ["c", 69, "additional cloud group type (c)"], ["ns", 70, "additional cloud group amount (ns)"], ["sdepth", 71, "snow depth"], ["e", 72, "state of ground (e)"], ["tgtg", 73, "ground temperature (tgtg)"], ["spsp1", 74, "special phenomena (spsp)#1"], ["spsp2", 75, "special phenomena (spsp)#2"], ["rs", 76, "ice code type (rs)"], ["eses", 77, "ice thickness (eses)"], ["is", 78, "ice (is)"], ["trtr", 79, "original time period of rain obs. (trtr)"], ["rr", 80, "6hr rain (liquid part)"], ["jj", 81, "max. temperature (jj)"], ["vs", 82, "ship speed (vs)"], ["ds", 83, "ship direction (ds)"], ["hwhw", 84, "wave height"], ["pwpw", 85, "wave period"], ["dwdw", 86, "wave direction"], ["gclg", 87, "general cloud group"], ["rhlc", 88, "rel. humidity from low clouds"], ["rhmc", 89, "rel. humidity from middle clouds"], ["rhhc", 90, "rel. humidity from high clouds"], ["n", 91, "total amount of clouds"], ["sfall", 92, "6hr snowfall (solid part of rain)"], ["ps", 110, "surface pressure"], ["dd", 111, "wind direction"], ["ff", 112, "wind force"], ["rawbt", 119, "brightness temperature (K)"], ["rawra", 120, "raw radiance"], ["satcl", 121, "cloud amount from satellite"], ["scatss", 122, "sigma 0"], ["du", 5, "wind shear (du)"], ["dv", 6, "wind shear (dv)"], ["u10m", 41, "10m u component (m/s)"], ["v10m", 42, "10m v component (m/s)"], ["rhlay", 19, "layer rel. 
humidity"], ["cllqw", 123, "cloud liquid water"], ["scatdd", 124, "ambiguous v component"], ["scatff", 125, "ambiguous u component"], ["q", 7, "specific humidity (q)"], ["scatwd", 126, "ambiguous wind direction"], ["scatws", 127, "ambiguous wind speed"], ["vsp", 8, "vertical speed"], ["vt", 56, "virtual temperature"], ["o3lay", 206, "layer ozone"], ["height", 156, "height"], ["1dvar", 215, "1d-var model level (pseudo)-variable"], ["w2", 160, "past weather 2 (used in synoptic maps)"], ["cpt", 130, "characteristic of pressure tendency (used in synoptic maps)"], ["tsts", 12, "sea water temperature (used in synoptic maps)"], ["refl", 192, "radar reflectivity"], ["apdss", 128, "atmospheric path delay in satellite signal"], ["bend_angle", 162, "radio occultation bending angle"], ["los", 187, "horizontal line-of-sight wind component"], ["aerod", 174, "aerosol optical depth at 0.55 microns"], ["limb_radiance", 163, "Limb Radiances"], ["chem3", 183, "chem3: co"], ["chem2", 182, "chem2: so2"], ["chem1", 181, "chem1: no2/nox"], ["cod", 175, "cloud optical depth"], ["rao", 176, "Ratio of fine mode to total aerosol optical depth at 0.55 microns"], ["od", 177, "optical depth"], ["rfltnc", 178, "Aerosol reflectance multi-channel"], ["nsoilm", 179, "normalized soil moisture (0-100%)"], ["soilm", 180, "soil moisture"], ["flgt_phase", 201, "phase of aircraft flight"], ["height_assignment_method", 211, "Height assignment method"], ["dopp", 195, "radar doppler wind"], ["ghg1", 186, "ghg1: carbon dioxide"], ["ghg2", 188, "ghg2: methane"], ["ghg3", 189, "ghg3: nitrous oxide"], ["bt_real", 190, "brightness temperature real part"], ["bt_imaginary", 191, "brightness temperature imaginary part"], ["prc", 202, "radar rain rate"], ["lnprc", 203, "log(radar rain rate mm/h + epsilon)"], ["libksc", 222, "lidar backscattering"], ["ralt_swh", 220, "significant wave height (m)"], ["ralt_sws", 221, "surface wind speed (m/s)"], ["rawbt_clear", 193, "brightness temperature for clear (K)"], ["rawbt_cloudy", 194, "brightness temperature for cloudy (K)"], ["binary_snow_cover", 223, "binary snow cover (0: no snow / 1: presence of snow)"], ["salinity", 224, "ocean salinity (PSU)"], ["potential_temp", 225, "potential temperature (Kelvin)"], ["humidity_mixing_ratio", 226, "humidity mixing ratio (kg/kg)"], ["airframe_icing", 227, "airframe icing"], ["turbulence_index", 228, "turbulence index"], ["pstation", 107, "Station pressure (Pa)"], ["pmsl", 108, "Mean sea-level pressure (Pa)"], ["pstandard", 109, "Standard level pressure (Pa)"], ["vert_vv", 218, "Vertical visibility (m)"], ["max_wind_shear1", 219, "Wind shear above and below 1st maximum wind in sonde profile (s-1)"], ["tot_zen_delay", 229, "Total zenith delay (GPS)"], ["tot_zen_delay_err", 230, "Total zenith delay error (GPS)"], ["cloud_top_temp", 231, "Cloud top temperature (K)"], ["rawsca", 233, "Scaled radiance"], ["cloud_top_press", 235, "Cloud top pressure (Pa)"], ["mean_freq", 241, "GPSRO mean frequency"], ["u_amb", 242, "Ambiguous u-wind component (m/s)"], ["v_amb", 243, "Ambiguous v-wind component (m/s)"], ["lwp", 244, "Liquid water path"], ["tcwv", 245, "Total column water vapour"], ["cloud_frac_clear", 247, "Cloud clear fraction"], ["rawbt_hirs", 248, "Raw brightness temperature specific to HIRS (K)"], ["rawbt_amsu", 249, "Raw brightness temperature specific to AMSU (K)"], ["rawbt_hirs20", 250, "Raw brightness temperature specific to HIRS (K)"], ["sea_ice", 253, "Sea ice fraction"], ["cloud_frac_covered", 257, "Cloud covered fraction"], ["level_mixing_ratio", 258, 
"humidity_mixing_ratio]"], ["radial_velocity", 259, "Radial velocity from doppler radar"], ["cloud_ice_water", 260, "Cloud ice water"], ["wind_gust", 261, "Maximum wind gust (m/s)"], ["mass_density", 262, "Mass density"], ["atmosphere_number", 263, "SFERICS number of atmospheres"], ["lightning", 265, "Lightning strike observation (ATDNET)"], ["level_cloud", 266, "Cloud fraction (multi-level)"], ["rawbt_amsr_89ghz", 267, "Raw brightness temperature specific to AMSR 89GHz channels (K)"], ["max_wind_shear2", 268, "Wind shear above and below 2nd maximum wind in sonde profile"], ["lower_layer_p", 269, "Pressure at bottom of layer SBUV (Pa)"], ["upper_layer_p", 270, "Pressure at top of later SBUV (Pa)"], ["cloud_cover", 271, "Total cloud cover"], ["depth", 272, "Depth (m)"], ["ssh", 273, "Sea surface height (m)"], ["rawbt_mwts", 274, "Raw brightness temperature specific to MWTS (K)"], ["rawbt_mwhs", 275, "Raw brightness temperature specific to MWHS (K)"], ["tot_lightning_flash_dens", 196, "total (cloud-to-ground plus intra-cloud) lightning flash density (fl/km2/day)"], ["cg_lightning_flash_dens", 197, "cloud-to-ground lightning flash density ( fl/km2/day)"], ["lidar_aerosol_extinction", 236, "lidar aerosol extinction (1/m)"], ["lidar_cloud_backscatter", 237, "lidar cloud backscatter"], ["lidar_cloud_extinction", 238, "lidar cloud extinction"], ["cloud_radar_reflectivity", 239, "cloud radar reflectivity"], ["lidar_aerosol_attenuated_backscatter", 280, "lidar aerosol attenuated backscatter (1/m*sr)"], ["q2m", 281, "specific humidity at 2m (kg/kg)"], ["chem6", 284, "volcanic SO2"], ["sla", 287, "sea level anomaly"], ["ice_freeboard", 286, "Height of sea ice above open water"], ["snow_freeboard", 285, "Height of snow and sea ice above open water"], ["visible_spectral_reflectance", 240, "Visible Spectral Reflectance"], ["od10", 288, "optical depth at 10 microns"], ["chem4", 184, "chem4: hcho"], ["chem5", 185, "chem5: go3"], ["frac_snow_cover", 282, "fractional snow cover"], ["cloud_doppler_velocity", 251, "vertical radar doppler velocity"], ["lidar_rayleigh_backscatter", 252, "lidar Rayleigh backscatter"], ["sigma0_sm", 283, "backscatter coefficient normalized at 40 degree (db)"], ["t2m_min", 37, "minimum 2m temperature (K)"], ["t2m_max", 38, "maximum 2m temperature (K)"], ["ssrd", 25, "downward surface solar radiation (J/m2)"]]} \ No newline at end of file +{"fields": ["name", "code", "description"], "data": [["u", 3, "upper air u component"], ["v", 4, "upper air v component"], ["z", 1, "geopotential"], ["dz", 57, "thickness"], ["rh", 29, "upper air rel. humidity"], ["pwc", 9, "precipitable water content"], ["rh2m", 58, "2m rel. 
humidity"], ["t", 2, "upper air temperature (K)"], ["td", 59, "upper air dew point (K)"], ["t2m", 39, "2m temperature (K)"], ["td2m", 40, "2m dew point (K)"], ["ts", 11, "surface temperature (K)"], ["ptend", 30, "pressure tendency"], ["w", 60, "past weather (w)"], ["ww", 61, "present weather (ww)"], ["vv", 62, "visibility"], ["ch", 63, "type of high clouds (ch)"], ["cm", 64, "type of middle clouds (cm)"], ["cl", 65, "type of low clouds (cl)"], ["nh", 66, "cloud base height (nh) (meter)"], ["nn", 67, "low cloud amount (n)"], ["hshs", 68, "additional cloud group height (hh)"], ["c", 69, "additional cloud group type (c)"], ["ns", 70, "additional cloud group amount (ns)"], ["sdepth", 71, "snow depth"], ["e", 72, "state of ground (e)"], ["tgtg", 73, "ground temperature (tgtg)"], ["spsp1", 74, "special phenomena (spsp)#1"], ["spsp2", 75, "special phenomena (spsp)#2"], ["rs", 76, "ice code type (rs)"], ["eses", 77, "ice thickness (eses)"], ["is", 78, "ice (is)"], ["trtr", 79, "original time period of rain obs. (trtr)"], ["rr", 80, "6hr rain (liquid part)"], ["jj", 81, "max. temperature (jj)"], ["vs", 82, "ship speed (vs)"], ["ds", 83, "ship direction (ds)"], ["hwhw", 84, "wave height"], ["pwpw", 85, "wave period"], ["dwdw", 86, "wave direction"], ["gclg", 87, "general cloud group"], ["rhlc", 88, "rel. humidity from low clouds"], ["rhmc", 89, "rel. humidity from middle clouds"], ["rhhc", 90, "rel. humidity from high clouds"], ["n", 91, "total amount of clouds"], ["sfall", 92, "6hr snowfall (solid part of rain)"], ["ps", 110, "surface pressure"], ["dd", 111, "wind direction"], ["ff", 112, "wind force"], ["rawbt", 119, "brightness temperature (K)"], ["rawra", 120, "raw radiance"], ["satcl", 121, "cloud amount from satellite"], ["scatss", 122, "sigma 0"], ["du", 5, "wind shear (du)"], ["dv", 6, "wind shear (dv)"], ["u10m", 41, "10m u component (m/s)"], ["v10m", 42, "10m v component (m/s)"], ["rhlay", 19, "layer rel. 
humidity"], ["cllqw", 123, "cloud liquid water"], ["scatdd", 124, "ambiguous v component"], ["scatff", 125, "ambiguous u component"], ["q", 7, "specific humidity (q)"], ["scatwd", 126, "ambiguous wind direction"], ["scatws", 127, "ambiguous wind speed"], ["vsp", 8, "vertical speed"], ["vt", 56, "virtual temperature"], ["o3lay", 206, "layer ozone"], ["height", 156, "height"], ["1dvar", 215, "1d-var model level (pseudo)-variable"], ["w2", 160, "past weather 2 (used in synoptic maps)"], ["cpt", 130, "characteristic of pressure tendency (used in synoptic maps)"], ["tsts", 12, "sea water temperature (used in synoptic maps)"], ["refl", 192, "radar reflectivity"], ["apdss", 128, "atmospheric path delay in satellite signal"], ["bend_angle", 162, "radio occultation bending angle"], ["los", 187, "horizontal line-of-sight wind component"], ["aerod", 174, "aerosol optical depth at 0.55 microns"], ["limb_radiance", 163, "Limb Radiances"], ["chem3", 183, "chem3: co"], ["chem2", 182, "chem2: so2"], ["chem1", 181, "chem1: no2/nox"], ["cod", 175, "cloud optical depth"], ["rao", 176, "Ratio of fine mode to total aerosol optical depth at 0.55 microns"], ["od", 177, "optical depth"], ["rfltnc", 178, "Aerosol reflectance multi-channel"], ["nsoilm", 179, "normalized soil moisture (0-100%)"], ["soilm", 180, "soil moisture"], ["flgt_phase", 201, "phase of aircraft flight"], ["height_assignment_method", 211, "Height assignment method"], ["dopp", 195, "radar doppler wind"], ["ghg1", 186, "ghg1: carbon dioxide"], ["ghg2", 188, "ghg2: methane"], ["ghg3", 189, "ghg3: nitrous oxide"], ["bt_real", 190, "brightness temperature real part"], ["bt_imaginary", 191, "brightness temperature imaginary part"], ["prc", 202, "radar rain rate"], ["lnprc", 203, "log(radar rain rate mm/h + epsilon)"], ["libksc", 222, "lidar backscattering"], ["ralt_swh", 220, "significant wave height (m)"], ["ralt_sws", 221, "surface wind speed (m/s)"], ["rawbt_clear", 193, "brightness temperature for clear (K)"], ["rawbt_cloudy", 194, "brightness temperature for cloudy (K)"], ["binary_snow_cover", 223, "binary snow cover (0: no snow / 1: presence of snow)"], ["salinity", 224, "ocean salinity (PSU)"], ["potential_temp", 225, "potential temperature (Kelvin)"], ["humidity_mixing_ratio", 226, "humidity mixing ratio (kg/kg)"], ["airframe_icing", 227, "airframe icing"], ["turbulence_index", 228, "turbulence index"], ["pstation", 107, "Station pressure (Pa)"], ["pmsl", 108, "Mean sea-level pressure (Pa)"], ["pstandard", 109, "Standard level pressure (Pa)"], ["vert_vv", 218, "Vertical visibility (m)"], ["max_wind_shear1", 219, "Wind shear above and below 1st maximum wind in sonde profile (s-1)"], ["tot_zen_delay", 229, "Total zenith delay (GPS)"], ["tot_zen_delay_err", 230, "Total zenith delay error (GPS)"], ["cloud_top_temp", 231, "Cloud top temperature (K)"], ["rawsca", 233, "Scaled radiance"], ["cloud_top_press", 235, "Cloud top pressure (Pa)"], ["mean_freq", 241, "GPSRO mean frequency"], ["u_amb", 242, "Ambiguous u-wind component (m/s)"], ["v_amb", 243, "Ambiguous v-wind component (m/s)"], ["lwp", 244, "Liquid water path"], ["tcwv", 245, "Total column water vapour"], ["cloud_frac_clear", 247, "Cloud clear fraction"], ["rawbt_hirs", 248, "Raw brightness temperature specific to HIRS (K)"], ["rawbt_amsu", 249, "Raw brightness temperature specific to AMSU (K)"], ["rawbt_hirs20", 250, "Raw brightness temperature specific to HIRS (K)"], ["sea_ice", 253, "Sea ice fraction"], ["cloud_frac_covered", 257, "Cloud covered fraction"], ["level_mixing_ratio", 258, 
"humidity_mixing_ratio]"], ["radial_velocity", 259, "Radial velocity from doppler radar"], ["cloud_ice_water", 260, "Cloud ice water"], ["wind_gust", 261, "Maximum wind gust (m/s)"], ["mass_density", 262, "Mass density"], ["atmosphere_number", 263, "SFERICS number of atmospheres"], ["lightning", 265, "Lightning strike observation (ATDNET)"], ["level_cloud", 266, "Cloud fraction (multi-level)"], ["rawbt_amsr_89ghz", 267, "Raw brightness temperature specific to AMSR 89GHz channels (K)"], ["max_wind_shear2", 268, "Wind shear above and below 2nd maximum wind in sonde profile"], ["lower_layer_p", 269, "Pressure at bottom of layer SBUV (Pa)"], ["upper_layer_p", 270, "Pressure at top of later SBUV (Pa)"], ["cloud_cover", 271, "Total cloud cover"], ["depth", 272, "Depth (m)"], ["ssh", 273, "Sea surface height (m)"], ["rawbt_mwts", 274, "Raw brightness temperature specific to MWTS (K)"], ["rawbt_mwhs", 275, "Raw brightness temperature specific to MWHS (K)"], ["tot_lightning_flash_dens", 196, "total (cloud-to-ground plus intra-cloud) lightning flash density (fl/km2/day)"], ["cg_lightning_flash_dens", 197, "cloud-to-ground lightning flash density ( fl/km2/day)"], ["lidar_aerosol_extinction", 236, "lidar aerosol extinction (1/m)"], ["lidar_cloud_backscatter", 237, "lidar cloud backscatter"], ["lidar_cloud_extinction", 238, "lidar cloud extinction"], ["cloud_radar_reflectivity", 239, "cloud radar reflectivity"], ["lidar_aerosol_attenuated_backscatter", 280, "lidar aerosol attenuated backscatter (1/m*sr)"], ["q2m", 281, "specific humidity at 2m (kg/kg)"], ["chem6", 284, "volcanic SO2"], ["sla", 287, "sea level anomaly"], ["ice_freeboard", 286, "Height of sea ice above open water"], ["snow_freeboard", 285, "Height of snow and sea ice above open water"], ["visible_spectral_reflectance", 240, "Visible Spectral Reflectance"], ["od10", 288, "optical depth at 10 microns"], ["chem4", 184, "chem4: hcho"], ["chem5", 185, "chem5: go3"], ["frac_snow_cover", 282, "fractional snow cover"], ["cloud_doppler_velocity", 251, "vertical radar doppler velocity"], ["lidar_rayleigh_backscatter", 252, "lidar Rayleigh backscatter"], ["sigma0_sm", 283, "backscatter coefficient normalized at 40 degree (db)"], ["t2m_min", 37, "minimum 2m temperature (K)"], ["t2m_max", 38, "maximum 2m temperature (K)"], ["ssrd", 25, "downward surface solar radiation (J/m2)"]]} From a3e72f7e96abaaa6868c1d620f9024615b1761e9 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 17 Jun 2025 17:44:18 +0200 Subject: [PATCH 039/212] qa --- tests/test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_data.py b/tests/test_data.py index 488d05bae..f2957dd4c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1456,7 +1456,7 @@ def test_fields_to_records() -> None: key = "grp" ds = open_dataset(dataset="test-2021-2021-6h-o96-abcd-1", set_group=key) - unwrapped = open_dataset(dataset="test-2021-2021-6h-o96-abcd-2") + # unwrapped = open_dataset(dataset="test-2021-2021-6h-o96-abcd-2") assert ds.groups == [key] assert ds.variables == {key: ["a", "b", "c", "d"]} From 2b015649ea8514110248e3fb1513a706f99a71c1 Mon Sep 17 00:00:00 2001 From: Ewan Pinnington Date: Wed, 18 Jun 2025 07:57:12 +0000 Subject: [PATCH 040/212] updating mars example --- tests/create/test_observations_mars.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py index 1e54a4a08..f6dbd6e17 100644 --- a/tests/create/test_observations_mars.py +++ 
b/tests/create/test_observations_mars.py @@ -45,10 +45,10 @@ def __call__(self, window): class MarsSource(ObservationsSource): - def __init__(self, request_dict, post_process_dict): + def __init__(self, request_dict, pre_process_dict, post_process_dict): assert isinstance(request_dict, dict), "request_dict must be a dictionary" self.request_dict = request_dict - self.post_process_dict = post_process_dict + self.pre_process_dict = pre_process_dict def __call__(self, window): assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" @@ -66,10 +66,7 @@ def __call__(self, window): else: raise # Re-raise if it's a different error - data = process_odb(ekd_ds, **self.post_process_dict) - - # print(data) - # print(data.columns) + data = process_odb(ekd_ds, **self.pre_process_dict) if window.include_start: mask = data["times"] > window.start @@ -110,10 +107,12 @@ def __call__(self, df): "time": "00/12", "filter": "'select seqno,reportype,date,time,lat,lon,report_status,report_event1,entryno,varno,statid,stalt,obsvalue,lsm@modsurf,biascorr_fg,final_obs_error,datum_status@body,datum_event1@body,vertco_reference_1,vertco_type where ((varno==39 and abs(fg_depar@body)<20) or (varno in (41,42) and abs(fg_depar@body)<15) or (varno==58 and abs(fg_depar@body)<0.4) or (varno == 110 and entryno == 1 and abs(fg_depar@body)<10000) or (varno == 91)) and time in (000000,030000,060000,090000,120000,150000,180000,210000);'", }, - post_process_dict={ + pre_process_dict={ + # "target": odb2df.process_odb, "index": ["seqno@hdr", "lat@hdr", "lon@hdr", "date@hdr", "time@hdr", "stalt@hdr", "lsm@modsurf"], "pivot": ["varno@body"], "values": ["obsvalue@body"], + "drop_na": True, }, ) filter = DummyFilter("obsvalue_v10m_0") From 8d6ccd6e33f6f63d8946c35c1fd9f2135a5cba20 Mon Sep 17 00:00:00 2001 From: Ewan Pinnington Date: Thu, 19 Jun 2025 07:08:07 +0000 Subject: [PATCH 041/212] updating mars odb example --- tests/create/test_observations_mars.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py index f6dbd6e17..c98340f48 100644 --- a/tests/create/test_observations_mars.py +++ b/tests/create/test_observations_mars.py @@ -45,10 +45,11 @@ def __call__(self, window): class MarsSource(ObservationsSource): - def __init__(self, request_dict, pre_process_dict, post_process_dict): + def __init__(self, request_dict, pre_process_dict, process_func): assert isinstance(request_dict, dict), "request_dict must be a dictionary" self.request_dict = request_dict self.pre_process_dict = pre_process_dict + self.process_func = process_func def __call__(self, window): assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" @@ -66,7 +67,7 @@ def __call__(self, window): else: raise # Re-raise if it's a different error - data = process_odb(ekd_ds, **self.pre_process_dict) + data = self.process_func(ekd_ds, **self.pre_process_dict) if window.include_start: mask = data["times"] > window.start @@ -114,6 +115,7 @@ def __call__(self, df): "values": ["obsvalue@body"], "drop_na": True, }, + process_func=process_odb, ) filter = DummyFilter("obsvalue_v10m_0") From 6b2adb8063f26a1774a0bf5e9ce011000ba9bae8 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 23 Jun 2025 15:14:58 +0000 Subject: [PATCH 042/212] more metadata --- src/anemoi/datasets/data/records/__init__.py | 4 ++-- .../datasets/data/records/backends/__init__.py | 17 +++++++++-------- tools/build-obs.py | 5 ++++- 3 files changed, 15 insertions(+), 11 
deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 2e6a21468..f8c2f577f 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -662,11 +662,11 @@ def __len__(self): class RecordsDataset(BaseRecordsDataset): - def __init__(self, path, backend="npz1", **kwargs): + def __init__(self, path, backend=None, **kwargs): if kwargs: print("Warning: ignoring additional kwargs", kwargs) self.path = path - self.backend = backend_factory(backend, path, **kwargs) + self.backend = backend_factory(**backend, path=path) self._groups = list(self.metadata["sources"].keys()) for k in self.groups: assert k == self.normalise_key(k), k diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index e831d8f82..f49b93751 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -43,10 +43,10 @@ def _check_data(self, data): class Npz1Backend(Backend): - number_of_files_per_subdirectory = 100 - def __init__(self, *args, **kwargs): + def __init__(self, *args, number_of_files_per_subdirectory=100, **kwargs): super().__init__(*args, **kwargs) + self.number_of_files_per_subdirectory = number_of_files_per_subdirectory self._cache = None def read(self, i, **kwargs): @@ -99,6 +99,7 @@ def read_statistics(self): class Nc1Backend(Backend): + number_of_files_per_subdirectory = 100 def read(self, i, **kwargs): d = str(int(i / self.number_of_files_per_subdirectory)) @@ -138,8 +139,8 @@ def backend_factory(name, *args, **kwargs): class WriteBackend(Backend): - def __init__(self, path, **kwargs): - super().__init__(path, **kwargs) + def __init__(self, *, target, **kwargs): + super().__init__(target, **kwargs) def write(self, i, data, **kwargs): raise NotImplementedError("Must be implemented in subclass") @@ -183,9 +184,9 @@ def _dataframe_to_dict(self, name, df, **kwargs): class Npz1WriteBackend(WriteBackend): - number_of_files_per_subdirectory = 100 - def write(self, i, data, **kwargs): + def write(self, i, data, number_of_files_per_subdirectory=100, **kwargs): + self.number_of_files_per_subdirectory = number_of_files_per_subdirectory self._check_data(data) d = str(int(i / self.number_of_files_per_subdirectory)) dir_path = os.path.join(self.path, "data", d) @@ -302,10 +303,10 @@ def write_statistics(self, statistics): np.savez(path, **flatten) -def writer_backend_factory(backend, *args, **kwargs): +def writer_backend_factory(name, **kwargs): WRITE_BACKENDS = dict( npz1=Npz1WriteBackend, npz2=Npz2WriteBackend, nc1=Nc1WriteBackend, ) - return WRITE_BACKENDS[backend](*args, **kwargs) + return WRITE_BACKENDS[name](**kwargs) diff --git a/tools/build-obs.py b/tools/build-obs.py index e3caff9f9..566db4df0 100755 --- a/tools/build-obs.py +++ b/tools/build-obs.py @@ -28,6 +28,9 @@ def build(input, output, backend, overwrite=False): print(f"Dataset has {len(ds)} records, from {ds.start_date} to {ds.end_date}") print(f"Converting dataset to {output} using new backend '{backend}'") + if not isinstance(backend, dict): + backend = {"name": backend} + from anemoi.datasets.data.records.backends import writer_backend_factory if os.path.exists(output): @@ -36,7 +39,7 @@ def build(input, output, backend, overwrite=False): shutil.rmtree(output) else: raise FileExistsError(f"Output directory {output} already exists, use --overwrite to remove it") - writer = writer_backend_factory(backend, 
output) + writer = writer_backend_factory(**backend, target=output) for i in tqdm.tqdm(range(len(ds))): writer.write(i, ds[i]) From dcc180299a1e35fccf46b9130585d05bbd88b79d Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 24 Jun 2025 10:56:58 +0000 Subject: [PATCH 043/212] fix timedelta type --- src/anemoi/datasets/data/observations/__init__.py | 1 + src/anemoi/datasets/data/padded.py | 6 ++++-- src/anemoi/datasets/data/records/__init__.py | 7 ++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/anemoi/datasets/data/observations/__init__.py b/src/anemoi/datasets/data/observations/__init__.py index b5f8ec5e9..f19ad996e 100644 --- a/src/anemoi/datasets/data/observations/__init__.py +++ b/src/anemoi/datasets/data/observations/__init__.py @@ -238,6 +238,7 @@ def get_aux(self, i): assert latitudes.shape == longitudes.shape, f"Expected {latitudes.shape}, got {longitudes.shape}" assert timedeltas.shape == latitudes.shape, f"Expected {timedeltas.shape}, got {latitudes.shape}" + assert timedeltas.dtype == "timedelta64[s]", f"Expected timedelta64[s], got {timedeltas.dtype}" return latitudes, longitudes, timedeltas def getitem(self, i): diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py index dcff11bae..2dcc0d71c 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/data/padded.py @@ -170,8 +170,10 @@ def empty_item(self): def get_aux(self, i: FullIndex) -> NDArray[np.timedelta64]: if self._i_out_of_range(i): - arr = np.array([], dtype=np.float32) - aux = arr, arr, arr + lats = np.array([], dtype=np.float32) + lons = lats + timedeltas = np.ones_like(lons, dtype="timedelta64[s]") * 0 + aux = lats, lons, timedeltas else: aux = self.dataset.get_aux(i - self._before) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index f8c2f577f..1a92ad8cc 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -537,17 +537,14 @@ def _load_data(self, i): print(f"Requested ds({i}) : need to read {list(range(first_j, last_j + 1))} indices") # _load_data could support a list of indices, but for now we merge the data ourselves + # we merge the windows that we need, and then remove unnecessary data too_much_data = merge_data(self.forward._load_data(j) for j in range(first_j, last_j + 1)) out = {} for group in self.groups: timedeltas = too_much_data[f"timedeltas:{group}"] if timedeltas.dtype != "timedelta64[s]": - if len(timedeltas) != 0: - raise ValueError(f"Wrong type for {group}") - else: - LOG.warning(f"TODO: Fixing {group} on the fly") - timedeltas = np.ones_like(timedeltas, dtype="timedelta64[s]") * 0 + raise ValueError(f"Wrong type for {group}") mask = self._window.compute_mask(timedeltas) out[f"data:{group}"] = too_much_data[f"data:{group}"][..., mask] From 826dff8d0c9605e8f3dab8623d3ad63cbcaf97e0 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 24 Jun 2025 11:26:18 +0000 Subject: [PATCH 044/212] more logs --- .../datasets/data/observations/__init__.py | 5 ++++- src/anemoi/datasets/data/records/__init__.py | 18 +++++++++++++++--- .../datasets/data/records/backends/__init__.py | 4 ++-- tests/test_records.py | 2 ++ 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/anemoi/datasets/data/observations/__init__.py b/src/anemoi/datasets/data/observations/__init__.py index f19ad996e..b7fd53d23 100644 --- a/src/anemoi/datasets/data/observations/__init__.py +++ 
b/src/anemoi/datasets/data/observations/__init__.py @@ -69,7 +69,10 @@ def __len__(self): return len(self.dates) def tree(self): - return Node(self) + return Node( + self, + [], + ) def __getitem__(self, i): if isinstance(i, int): diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 1a92ad8cc..f47b8a6a3 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -17,6 +17,7 @@ from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta +from anemoi.datasets.data.debug import Node from anemoi.datasets.data.records.backends import backend_factory LOG = logging.getLogger(__name__) @@ -206,6 +207,9 @@ def shapes(self): def __len__(self): return len(self.dates) + def tree(self): + return Node(self, [self.forward.tree()], **self.reason) + class FieldsRecords(RecordsForward): """A wrapper around a FieldsDataset to provide a consistent interface for records datasets.""" @@ -214,6 +218,7 @@ def __init__(self, fields_dataset, name): self.forward = fields_dataset self._name = name self._groups = [name] + self.reason = {"name": name} def _nest_in_dict(self, obj): """Helper to nest the object in a dict with the name as key.""" @@ -255,15 +260,15 @@ def __len__(self): return len(self.forward.dates) -class Rename(RecordsForward): +class GenericRename(RecordsForward): def __init__(self, dataset, rename): self.forward = dataset - # rename: {"current_group": "new_group"} assert isinstance(rename, dict) for k, v in rename.items(): assert isinstance(k, str), k assert isinstance(v, str), v self.rename = rename + self.reason = {"rename": rename} @property def statistics(self): @@ -282,7 +287,11 @@ def groups(self): return [self.rename.get(k, k) for k in self.forward.groups] -class SetGroup(Rename): +class Rename(GenericRename): + pass + + +class SetGroup(GenericRename): def __init__(self, dataset, set_group): if len(dataset.groups) != 1: raise ValueError(f"{self.__class__.__name__} can only be used with datasets containing a single group.") @@ -748,6 +757,9 @@ def check(self, i=None): for group, s in dict_of_sets.items(): assert s == {"latitudes", "longitudes", "timedeltas", "metadata", "data"}, f"Invalid keys {s}" + def tree(self): + return Node(self, [], path=self.path) + class Record: def __init__(self, dataset, n): diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index f49b93751..e9182feef 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -18,7 +18,7 @@ def normalise_key(k): - return "".join([x.lower() if x.isalnum() else "-" for x in k]) + return "".join([x.lower() if x.isalnum() else "_" for x in k]) class Backend: @@ -155,7 +155,7 @@ def _check_data(self, data): for k in list(data.keys()): k = k.split(":")[-1] if k != normalise_key(k): - raise ValueError(f"{k} must be alphanumerical and '-' only.") + raise ValueError(f"{k} must be alphanumerical and '_' only.") def _dataframes_to_record(self, i, data, variables, **kwargs): diff --git a/tests/test_records.py b/tests/test_records.py index 7c065e4a7..bda74d592 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -25,6 +25,8 @@ def check_numpy(x, y): def _test(ds, nb_dates=None): + print(f"💬 Testing {type(ds)} with {len(ds)} dates") + print(ds.tree()) grp = "metop-a-ascat" index_i = 0 From d4be46322a5d499947a5b4f185d37c6b57dae020 Mon Sep 17 
00:00:00 2001 From: Florian Pinault Date: Tue, 24 Jun 2025 12:10:17 +0000 Subject: [PATCH 045/212] typo --- src/anemoi/datasets/data/records/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index f47b8a6a3..1f19d8a96 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -683,7 +683,7 @@ def groups(self): @classmethod def normalise_key(cls, k): - return "".join([x.lower() if x.isalnum() else "-" for x in k]) + return "".join([x.lower() if x.isalnum() else "_" for x in k]) @property def frequency(self): From 552c7970bfbae0a11945ed4850e0f2f66f181fd0 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 24 Jun 2025 14:04:55 +0000 Subject: [PATCH 046/212] update 2025.06.24 --- src/anemoi/datasets/data/debug.py | 1 + src/anemoi/datasets/data/misc.py | 8 +++++-- src/anemoi/datasets/data/records/__init__.py | 25 ++++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/data/debug.py b/src/anemoi/datasets/data/debug.py index 1c9f0fa3d..33200eeef 100644 --- a/src/anemoi/datasets/data/debug.py +++ b/src/anemoi/datasets/data/debug.py @@ -69,6 +69,7 @@ def __init__(self, dataset: "Dataset", kids: List[Any], **kwargs: Any) -> None: Additional keyword arguments. """ self.dataset = dataset + assert isinstance(kids, list), "Kids must be a list" self.kids = kids self.kwargs = kwargs diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 169830202..1a51ef868 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -597,10 +597,14 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": if "set_group" in kwargs: from anemoi.datasets.data.records import FieldsRecords - assert len(sets) == 1, sets set_group = kwargs.pop("set_group") + assert len(sets) == 1, "set_group can only be used with a single dataset" + dataset = sets[0] - return FieldsRecords(*sets, name=set_group).mutate() + from anemoi.datasets.data.dataset import Dataset + + if isinstance(dataset, Dataset): # Fields dataset + return FieldsRecords(dataset, **kwargs, name=set_group).mutate() if len(sets) > 1: dataset, kwargs = _concat_or_join(sets, kwargs) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 1f19d8a96..fecbdcb3c 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -216,6 +216,9 @@ class FieldsRecords(RecordsForward): def __init__(self, fields_dataset, name): self.forward = fields_dataset + from anemoi.datasets.data.dataset import Dataset + + assert isinstance(fields_dataset, Dataset), f"fields_dataset must be a Dataset, got {type(fields_dataset)}" self._name = name self._groups = [name] self.reason = {"name": name} @@ -224,6 +227,17 @@ def _nest_in_dict(self, obj): """Helper to nest the object in a dict with the name as key.""" return {self._name: obj} + def _load_data(self, i): + data = self.forward[i] + out = {} + out[f"data:{self._name}"] = data + # out[f"latitudes:{self._name}"] = self.forward.latitudes + # out[f"longitudes:{self._name}"] = self.forward.longitudes + out[f"timedeltas:{self._name}"] = np.zeros_like(data, dtype="timedelta64[s]") + _to_numpy_date( + self.forward.dates[i] + ) + return out + @property def groups(self): return self._groups @@ -298,6 +312,9 @@ def __init__(self, dataset, set_group): 
super.__init__(dataset, {dataset.groups[0]: set_group}) + def _load_data(self, i): + return self.dataset._load_data(i) + def match_variable(lst, group, name): # lst must be a list of strings with dots (if there is no dot, it is automatically added at the end) @@ -607,6 +624,10 @@ def _build_indices_and_name_to_index(self): variables[group].append(name) count += 1 assert np.sum(ind) == count, f"Mismatch in {group}: {names}, {ind}" + if not variables: + raise ValueError( + f"No variables matched in {self._select} for dataset {self.dataset}. Available groups: {self.dataset.groups} Available variables: {self.dataset.variables} " + ) self._indices = indices self._name_to_index = name_to_index self._variables = variables @@ -814,6 +835,10 @@ def timedeltas(self): def statistics(self): return self.dataset.statistics + def as_dict(self): + """Returns the record as a dictionary with group names as keys.""" + return {group: self[group] for group in self.groups} + class Tabular: def __init__(self, dataset, name): From 0053aec9cd0bf8af8f80ae5bb203ed356397c12d Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 30 Jun 2025 15:08:26 +0000 Subject: [PATCH 047/212] up --- src/anemoi/datasets/commands/inspect.py | 28 ++++++++----- src/anemoi/datasets/data/misc.py | 33 +++++++-------- src/anemoi/datasets/data/records/__init__.py | 11 ++++- src/anemoi/datasets/data/stores.py | 28 +++++++++++-- tests/test_records.py | 42 ++++++++++---------- 5 files changed, 92 insertions(+), 50 deletions(-) diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 400cdcf98..cb3aaf847 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -22,6 +22,7 @@ import numpy as np import semantic_version import tqdm +from anemoi.utils.config import load_any_dict_format from anemoi.utils.humanize import bytes from anemoi.utils.humanize import bytes_to_human from anemoi.utils.humanize import when @@ -31,8 +32,8 @@ from numpy.typing import NDArray from anemoi.datasets import open_dataset +from anemoi.datasets.data.stores import dataset_lookup from anemoi.datasets.data.stores import open_zarr -from anemoi.datasets.data.stores import zarr_lookup from . import Command @@ -300,12 +301,12 @@ def variables(self) -> List[str]: @property def total_size(self) -> Optional[int]: """Get the total size of the dataset.""" - return self.zarr.attrs.get("total_size") + return self.metadata.get("total_size") @property def total_number_of_files(self) -> Optional[int]: """Get the total number of files in the dataset.""" - return self.zarr.attrs.get("total_number_of_files") + return self.metadata.get("total_number_of_files") def print_sizes(self, size: bool) -> None: """Print the size and number of files in the dataset. 
@@ -362,15 +363,14 @@ def build_flags(self) -> Optional[NDArray[Any]]: @cached_property def copy_flags(self) -> Optional[NDArray[Any]]: - """Get the copy flags of the dataset.""" - if "_copy" not in self.zarr: + if not self.zarr or "_copy" not in self.zarr: return None return self.zarr["_copy"][:] @property def copy_in_progress(self) -> bool: """Check if a copy operation is in progress.""" - if "_copy" not in self.zarr: + if not self.zarr or "_copy" not in self.zarr: return False start = self.zarr["_copy"].attrs.get("copy_start_timestamp") @@ -383,6 +383,8 @@ def copy_in_progress(self) -> bool: @property def build_lengths(self) -> Optional[NDArray]: """Get the build lengths of the dataset.""" + if not self.zarr: + return None return self.zarr.get("_build_lengths") def progress(self) -> None: @@ -652,7 +654,7 @@ def details(self) -> None: def ready(self) -> bool: """Check if the dataset is ready.""" - if "_build_flags" not in self.zarr: + if not self.zarr or "_build_flags" not in self.zarr: return False build_flags = self.zarr["_build_flags"] @@ -708,7 +710,7 @@ class Version0_13(Version0_12): @property def build_flags(self) -> Optional[NDArray]: """Get the build flags for the dataset.""" - if "_build" not in self.zarr: + if not self.zarr or "_build" not in self.zarr: return None build = self.zarr["_build"] return build.get("flags") @@ -818,9 +820,15 @@ def _info(self, path: str) -> Version: Version The version object of the dataset. """ - z = open_zarr(zarr_lookup(path)) + resolved_path = dataset_lookup(path) + if resolved_path.endswith(".vz"): + LOG.warning(f"Inspecting a .vz file: {resolved_path}. This is not supported yet.") + metadata = load_any_dict_format(os.path.join(resolved_path, "metadata.json")) + z = None + else: + z = open_zarr(resolved_path) + metadata = dict(z.attrs) - metadata = dict(z.attrs) version = metadata.get("version", "0.0.0") if isinstance(version, int): version = f"0.{version}" diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 1a51ef868..5d7d62f7a 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -354,21 +354,7 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) - """ from .dataset import Dataset from .stores import Zarr - from .stores import zarr_lookup - - if isinstance(a, str) and len(a.split(".")[-1]) in [1, 2, 3]: - # This will do nothing if there is no "metadata.json" file - # .zarr datasets do not have "metadata.json" - - metadata_path = os.path.join(a, "metadata.json") - if os.path.exists(metadata_path): - metadata = load_any_dict_format(metadata_path) - if "backend" not in metadata: - raise ValueError(f"Metadata for {a} does not contain 'backend' key") - - from anemoi.datasets.data.records import open_records_dataset - - return open_records_dataset(a, backend=metadata["backend"]) + from .stores import dataset_lookup if isinstance(a, Dataset): return a.mutate() @@ -377,7 +363,22 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) - return Zarr(a).mutate() if isinstance(a, str): - return Zarr(zarr_lookup(a)).mutate() + path = dataset_lookup(a) + + if path and path.endswith(".zarr") or path.endswith(".zip"): + return Zarr(path).mutate() + + if path and path.endswith(".vz"): + metadata_path = os.path.join(path, "metadata.json") + if os.path.exists(metadata_path): + if "backend" not in load_any_dict_format(metadata_path): + raise ValueError(f"Metadata for {path} does not contain 'backend' key") + + from 
anemoi.datasets.data.records import open_records_dataset + + return open_records_dataset(path) + + raise ValueError(f"Unsupported dataset path: {path}. ") if isinstance(a, PurePath): return _open(str(a)).mutate() diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index fecbdcb3c..07d18d72b 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -14,6 +14,7 @@ from functools import cached_property import numpy as np +from anemoi.utils.config import load_any_dict_format from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta @@ -41,6 +42,11 @@ def counter(func): def open_records_dataset(dataset, **kwargs): + metadata_path = os.path.join(dataset, "metadata.json") + if not os.path.exists(metadata_path): + return None + metadata = load_any_dict_format(metadata_path) + kwargs["backend"] = kwargs.get("backend", metadata["backend"]) return RecordsDataset(dataset, **kwargs) @@ -810,7 +816,10 @@ def groups(self): return self.dataset.groups def __getitem__(self, group): - return self._payload["data:" + group] + k = f"data:{group}" + if k not in self._payload: + raise KeyError(f"Group {group} not found in record {self.n}. Available groups are {self.groups}") + return self._payload[k] def _get_aux(self, name): try: diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index e31d4cfb9..1ad2802c3 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -559,17 +559,21 @@ def label(self) -> str: QUIET = set() -def zarr_lookup(name: str, fail: bool = True) -> Optional[str]: +def zarr_lookup(*args, **kwargs) -> Optional[str]: + return dataset_lookup(*args, **kwargs) + + +def dataset_lookup(name: str, fail: bool = True) -> Optional[str]: """Look up a zarr dataset by name.""" config = load_config()["datasets"] use_search_path_not_found = config.get("use_search_path_not_found", False) - if name.endswith(".zarr/"): + if name.endswith(".zarr/") or name.endswith(".vz/"): LOG.warning("Removing trailing slash from path: %s", name) name = name[:-1] - if name.endswith(".zarr") or name.endswith(".zip"): + if name.endswith(".zarr") or name.endswith(".zip") or name.endswith(".vz"): if os.path.exists(name): return name @@ -591,6 +595,24 @@ def zarr_lookup(name: str, fail: bool = True) -> Optional[str]: for location in config["path"]: if not location.endswith("/"): location += "/" + + full = location + name + ".vz" + tried.append(full) + try: + + from anemoi.datasets.data.records import open_records_dataset + + z = open_records_dataset(full) + if z is not None: + # Cache for next time + config["named"][name] = full + if name not in QUIET: + LOG.info("Opening `%s` as `%s`", name, full) + QUIET.add(name) + return full + except zarr.errors.PathNotFoundError: + pass + full = location + name + ".zarr" tried.append(full) try: diff --git a/tests/test_records.py b/tests/test_records.py index bda74d592..8560e069a 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -15,6 +15,8 @@ from anemoi.datasets.data.records import Record from anemoi.datasets.data.records import Tabular +TEST_DATASET = "../../data/vz/observations-testing-2018-2018-6h-v0.vz" + def check_numpy(x, y): assert x.shape == y.shape, f"Expected {x.shape} == {y.shape}" @@ -27,7 +29,7 @@ def check_numpy(x, y): def _test(ds, nb_dates=None): print(f"💬 Testing {type(ds)} with {len(ds)} dates") print(ds.tree()) - grp = "metop-a-ascat" + grp 
= "metop_a" index_i = 0 if nb_dates is not None: @@ -122,43 +124,43 @@ def _test(ds, nb_dates=None): assert np.all(statistics[grp][key] == v), (key, statistics[grp][key], v) -@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") def test_open(): - ds = open_dataset("../../data/vz/obs-2018-11.vz") + ds = open_dataset(TEST_DATASET) _test(ds) -@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") def test_open_with_subset_dates(): ds = open_dataset( - "../../data/vz/obs-2018-11.vz", + TEST_DATASET, end="2018-11-30", select=[ - "metop-a-ascat.*", - "amsr2-h180.rawbt_4", - "amsr2-h180.rawbt_3", + "metop_a.*", + "amsr2_h180.rawbt_4", + "amsr2_h180.rawbt_3", ], ) _test(ds, nb_dates=8) -@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") def test_open_with_window(): dates = dict(end="2018-11-30") - ds = open_dataset("../../data/vz/obs-2018-11.vz", window="(-6h, 0h]", **dates) + ds = open_dataset(TEST_DATASET, window="(-6h, 0h]", **dates) _test(ds, nb_dates=8) - ds = open_dataset("../../data/vz/obs-2018-11.vz", window="(-1h, 0)", **dates) + ds = open_dataset(TEST_DATASET, window="(-1h, 0)", **dates) _test(ds, nb_dates=8) def test_open_bad_window(): subset = dict(end="2018-11-30") with pytest.raises(ValueError, match="No dates left after rewindowing"): - open_dataset("../../data/vz/obs-2018-11.vz", window="(-48h, +48h)", **subset) + open_dataset(TEST_DATASET, window="(-48h, +48h)", **subset) -@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") @pytest.mark.parametrize( "window, missing_dates", [ @@ -178,22 +180,22 @@ def test_open_bad_window(): def test_open_with_window_parametrized(window, missing_dates): subset = dict(end="2018-11-30") - ds = open_dataset("../../data/vz/obs-2018-11.vz", **subset) + ds = open_dataset(TEST_DATASET, **subset) assert len(ds) == 8 nb_dates = len(ds) + missing_dates - ds = open_dataset("../../data/vz/obs-2018-11.vz", window=window, **subset) + ds = open_dataset(TEST_DATASET, window=window, **subset) _test(ds, nb_dates=nb_dates) -@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found") +@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") def test_open_with_subset_select(): ds = open_dataset( - "../../data/vz/obs-2018-11.vz", + TEST_DATASET, select=[ - "amsr2-h180.rawbt_4", - "amsr2-h180.rawbt_3", - "metop-a-ascat.*", + "amsr2_h180.rawbt_4", + "amsr2_h180.rawbt_3", + "metop_a.*", ], ) _test(ds) From 7c2d4fb576231aed5d46b9c2a886cc43920ee25c Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 1 Jul 2025 11:20:32 +0000 Subject: [PATCH 048/212] padding="missing" or "raise" --- src/anemoi/datasets/data/dataset.py | 6 +++--- src/anemoi/datasets/data/padded.py | 21 +++++++++++++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 0267022d1..b9e29dd24 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -182,13 +182,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": padding = 
kwargs.pop("padding", None) if padding: - if padding != "empty": - raise ValueError(f"Only 'empty' padding is supported, got {padding=}") from .padded import Padded frequency = kwargs.pop("frequency", self.frequency) return ( - Padded(self, start, end, frequency, dict(start=start, end=end, frequency=frequency)) + Padded( + self, start, end, frequency, dict(start=start, end=end, frequency=frequency, padding=padding) + ) ._subset(**kwargs) .mutate() ) diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py index 2dcc0d71c..0160b674f 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/data/padded.py @@ -19,6 +19,7 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray +from anemoi.datasets.data import MissingDateError from anemoi.datasets.data.dataset import Dataset from anemoi.datasets.data.dataset import FullIndex from anemoi.datasets.data.dataset import Shape @@ -38,7 +39,15 @@ class Padded(Forwards): _after: int = 0 _inside: int = 0 - def __init__(self, dataset: Dataset, start: str, end: str, frequency: str, reason: Dict[str, Any]) -> None: + def __init__( + self, + dataset: Dataset, + start: str, + end: str, + frequency: str, + reason: Dict[str, Any], + padding: str, + ) -> None: """Create a padded subset of a dataset. Attributes: @@ -48,6 +57,7 @@ def __init__(self, dataset: Dataset, start: str, end: str, frequency: str, reaso frequency (str): The frequency of the subset. reason (Dict[str, Any]): The reason for the padding. """ + self.padding = padding self.reason = {k: v for k, v in reason.items() if v is not None} @@ -165,8 +175,15 @@ def _get_tuple(self, n: TupleIndex) -> NDArray[Any]: LOG.warning("Padded subset does not support tuple indexing, returning a list") return [self[i] for i in n] + @property def empty_item(self): - return self.dataset.empty_item() + if self.padding == "empty": + return self.dataset.empty_item() + elif self.padding == "raise": + raise ValueError("Padding is set to 'raise', cannot return an empty item.") + elif self.padding == "missing": + raise MissingDateError("Padding is set to 'missing'") + assert False, self.padding def get_aux(self, i: FullIndex) -> NDArray[np.timedelta64]: if self._i_out_of_range(i): From 031e9f222e61e8eaeb4d5110641b6df31c00e6f7 Mon Sep 17 00:00:00 2001 From: "mihai.alexe" Date: Wed, 2 Jul 2025 11:32:37 +0000 Subject: [PATCH 049/212] DOP dataset: first draft, missing len, sample factory index issue --- dop_dataset.py | 399 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 399 insertions(+) create mode 100644 dop_dataset.py diff --git a/dop_dataset.py b/dop_dataset.py new file mode 100644 index 000000000..ee3f6cc68 --- /dev/null +++ b/dop_dataset.py @@ -0,0 +1,399 @@ +from typing import Optional +import numpy as np +import os + +from torch import Tensor +from torch.utils.data import IterableDataset +from torch.utils.data import get_worker_info + +import torch +import random +import json + +import numpy as np +import yaml +from rich.console import Console +from rich.tree import Tree + +from anemoi.datasets import open_dataset + +CONFIG = dict( + data=dict( + # era5=dict( + # dataset=dict(dataset="aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8", set_group="era5"), + # # preprocessors=dict( + # # tp=[dict(normalizer="mean-std")]), + # # ), + # ), + snow=dict(dataset="observations-testing-2018-2018-6h-v1-one-month"), + metop_a=dict(dataset="observations-testing-2018-2018-6h-v1-one-month"), + 
amsr2_h180=dict(dataset="observations-testing-2018-2018-6h-v1-one-month"), + ), + sample=dict( + GROUPS=dict( + input=dict( + GROUPS=dict( + # fields=dict( # "fields" is a user defined key + # STEPS=dict( + # _6h=dict( + # variables=["q_50", "2t"], + # data="era5", + # ), + # _0h=dict( + # variables=["q_50", "2t"], + # data="era5", + # ), + # ), + # ), + # user-friendly config would be: + # fields=dict( + # steps=['-6h', '0h'], + # variables=["q_50", "2t"], + # data="era5", + # ), + ascat_metop_a=dict( # "metar" is a user defined key + STEPS=dict( + _6h=dict( + variables=["scatss_1", "scatss_2"], + data="metop_a", + ), + ), + ), + snow=dict( # "iasi" is a user defined key + STEPS=dict( + _6h=dict( + variables=["sdepth_0"], + data="snow", + ), + ), + ), + amsr2=dict( # "iasi" is a user defined key + STEPS=dict( + _6h=dict( + variables=["rawbt_1", "rawbt_2", "rawbt_3", "rawbt_4"], + data="amsr_h180", + ), + ), + ), + ), + ), + ), + ), +) + + +class Sample: + def __init__(self, datahandlers): + self.datahandlers = datahandlers + + def __repr__(self): + console = Console(record=True, width=120) + tree = self._build_tree() + with console.capture() as capture: + console.print(tree) + return capture.get() + + def _build_tree(self, label="Sample"): + return Tree(label) + + +class GroupedSample(Sample): + def __init__(self, datahandlers, dic): + super().__init__(datahandlers) + self._samples = {k: sample_factory(**v) for k, v in dic.items()} + + def __getitem__(self, item): + return {k: v[item] for k, v in self._samples.items()} + + def _build_tree(self, label="GroupedSample"): + tree = Tree(label) + for k, v in self._samples.items(): + subtree = v._build_tree(label=f"{k}: {type(v).__name__}") + tree.add(subtree) + return tree + + +class StepSample(Sample): + def __init__(self, datahandlers, dic): + super().__init__(datahandlers) + self._samples = {k: sample_factory(**v) for k, v in dic.items()} + + def __getitem__(self, item): + out = [] + for k, v in self._samples.items(): + if k == "_6h": + out.append(v[item - 1]) + elif k == "_0h": + out.append(v[item]) + elif k == "p6h": + out.append(v[item + 1]) + return out + + def _build_tree(self, label="GroupedSample"): + tree = Tree(label) + for k, v in self._samples.items(): + subtree = v._build_tree(label=f"{k}: {type(v).__name__}") + tree.add(subtree) + return tree + + +class Leaf(Sample): + def __init__(self, datahandlers, variables, data): + super().__init__(datahandlers) + self.data_key = data + self.variables = variables + + def __getitem__(self, item): + result = Result(self.data_key, item, variables=self.variables) + return result.load() + + def _build_tree(self, label="Leaf"): + return Tree(f"{label} -> {self.data_key} variables={self.variables}") + + +def sample_factory(datahandlers=None, **kwargs): + kwargs = kwargs.copy() + if datahandlers is None: + datahandlers = [] + if "GROUPS" in kwargs: + return GroupedSample(datahandlers, kwargs["GROUPS"]) + if "STEPS" in kwargs: + return StepSample(datahandlers, kwargs["STEPS"]) + if "variables" in kwargs: + return Leaf(datahandlers, variables=kwargs["variables"], data=kwargs["data"]) + assert False, f"Unknown sample type for kwargs {kwargs}" + + +class Result: + def __init__(self, datahandler_key, *args, variables=[], **kwargs): + cfg = CONFIG["data"][datahandler_key] + assert "select" not in cfg, (cfg, variables) + variables = [f"{datahandler_key}.{v}" for v in variables] + dh = DataHandler(datahandler_key, **cfg, select=variables) + + self.func = dh.__getitem__ + self.args = args + self.kwargs = 
kwargs + + def load(self): + return self.func(*self.args, **self.kwargs) + + def __repr__(self): + inside = [] + inside += [str(arg) for arg in self.args] + inside += [f"{k}={v}" for k, v in self.kwargs.items()] + return f"Result({self.datahandler} ({', '.join(inside)})" + + +class DataHandler: + def __init__(self, name, **config): + self.name = name + if isinstance(config, str): + config = dict(dataset=config) + if isinstance(config["dataset"], str): + config = dict(dataset=config) + + self.config = config + self._config_str = " ".join(f"{k}={v}" for k, v in config.items()) + + def is_grouped_dataset(self, ds): + from anemoi.datasets.data.records import BaseRecordsDataset + + return isinstance(ds, BaseRecordsDataset) + + @property + def ds(self): + ds = open_dataset(**self.config["dataset"]) + print(f"🔍 Opened dataset {self.name} with config: {self._config_str}") + if self.name not in ds.groups: + raise ValueError(f"Group '{self.name}' not found in dataset. Available groups: {ds.groups}") + ds = ds[self.name] + print(f" Available variables for group '{self.name}': {ds.variables}") + return ds + + def __getitem__(self, item): + data = self.ds[item] + assert isinstance(data, np.ndarray), f"Expected np.array, got {type(data)}, {type(self.ds)}" + return data + return f"np.array ds[{item}] with ds from {self._config_str} " + + def __str__(self): + return f"DataHandler({self._config_str})" + + +def show_yaml(structure): + return yaml.dump(structure, indent=2, sort_keys=False) + + +def show_json(structure): + return json.dumps(structure, indent=2, default=shorten_numpy) + + +def shorten_numpy(structure): + if isinstance(structure, np.ndarray): + return f"np.array({structure.shape})" + return structure + + +def get_base_seed(): + """ + Get a base seed for random number generation. + This is a placeholder function; replace with actual logic to get a base seed. + """ + return 42 # Example fixed seed, replace with actual logic as needed + + +class DOPDataset(IterableDataset): + def __init__( + self, + # config: dict, + shuffle: bool = True, + rollout: int = 1, + multistep: int = 1, + task: str = "training", + ) -> None: + + self.shuffle = shuffle + # self.config = config + self.rollout = rollout + self.multistep = multistep + self.task = task + + # lazy init + self.n_samples_per_epoch_total: int = 0 + self.n_samples_per_epoch_per_worker: int = 0 + + # additional state vars (lazy init) + self.n_samples_per_worker = 0 + self.chunk_index_range: Optional[np.ndarray] = None + self.shuffle = shuffle + self.rng: Optional[np.random.Generator] = None + self.worker_id: int = -1 + + # "full" shuffling + self.data_indices: Optional[np.ndarray] = None + + self.seed_comm_group_id = 0 + self.seed_comm_num_groups = 1 + + self._sample_factory = sample_factory(**CONFIG["sample"]) + + self.len = 25 # len(self._sample_factory) + + def __get_sample(self, index: int): + """ + Get a sample from the dataset. + """ + return self._sample_factory[index] + + def per_worker_init(self, n_workers: int, worker_id: int) -> None: + """Called by worker_init_func on each copy of dataset. + + This initialises after the worker process has been spawned. 
+ + Parameters + ---------- + n_workers : int + Number of workers + worker_id : int + Worker ID + """ + self.worker_id = worker_id + + # Total number of valid ICs is dataset length minus rollout minus additional multistep inputs + len_corrected = self.len - self.rollout - self.multistep + 1 + self.data_indices = np.arange(len_corrected, dtype=np.uint32) + + # Divide this equally across shards (one shard per group!) + shard_size = len_corrected // self.seed_comm_num_groups + shard_start = self.seed_comm_group_id * shard_size + shard_end = min((self.seed_comm_group_id + 1) * shard_size, self.len - self.rollout - self.multistep + 1) + + shard_len = shard_end - shard_start + self.n_samples_per_worker = shard_len // n_workers + + low = shard_start + worker_id * self.n_samples_per_worker + high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end) + self.chunk_index_range = np.arange(low, high, dtype=np.uint32) + + seed = get_base_seed() # all workers get the same seed (so they all get the same index shuffle) + torch.manual_seed(seed) + random.seed(seed) + self.rng = np.random.default_rng(seed=seed) + sanity_rnd = self.rng.random(1) + print("Sanity check random number:", sanity_rnd) + + def __iter__(self): + if self.shuffle: + # do a full shuffle, then get my index range + shuffled_data_indices = self.rng.choice(self.data_indices, size=len(self.data_indices), replace=False) + shuffled_chunk_indices = shuffled_data_indices[self.chunk_index_range] + + while True: # the pl.Trainer will break out of this loop after a fixed number of samples + idx = self.rng.choice(shuffled_chunk_indices) + print( + f"TRAINING: Worker {self.worker_id} (pid {os.getpid()}) fetching sample index {idx} ...", + ) + yield self.__get_sample(idx) + + else: + shuffled_chunk_indices = self.data_indices[self.chunk_index_range] + # no shuffle, just iterate over the chunk indices + for idx in self.chunk_index_range: + print( + f"VALIDATION: Worker {self.worker_id} (pid {os.getpid()}) fetching sample index {idx} ...", + ) + yield self.__get_sample(idx) + + +def worker_init_func(worker_id: int) -> None: + """Configures each dataset worker process. + + Calls WeatherBenchDataset.per_worker_init() on each dataset object. + + Parameters + ---------- + worker_id : int + Worker ID + + Raises + ------ + RuntimeError + If worker_info is None + """ + worker_info = get_worker_info() # information specific to each worker process + if worker_info is None: + print("worker_info is None! Set num_workers > 0 in your dataloader!") + raise RuntimeError + dataset_obj = worker_info.dataset # the copy of the dataset held by this worker process. 
+ dataset_obj.per_worker_init( + n_workers=worker_info.num_workers, + worker_id=worker_id, + ) + + +if __name__ == "__main__": + + ds = DOPDataset( + # CONFIG, + shuffle=False, + rollout=1, + multistep=1, + task="training", + ) + + loader_params = { + "batch_size": 1, # must be 1 for the time being + "batch_sampler": None, + "num_workers": 2, + "pin_memory": False, + "worker_init_fn": worker_init_func, + # "collate_fn": None, # collator_wrapper(return_original_metadata=cfg_.dataloader.return_dates), + } + + dl = torch.utils.data.DataLoader(ds, **loader_params, sampler=None) + + for batch_idx, batch in enumerate(dl): + print.info("%s", batch) + if batch_idx >= 1: + break From 5bebd276410d2335456577f106e409443ef32a7e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Jul 2025 11:33:01 +0000 Subject: [PATCH 050/212] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dop_dataset.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/dop_dataset.py b/dop_dataset.py index ee3f6cc68..38348c531 100644 --- a/dop_dataset.py +++ b/dop_dataset.py @@ -1,19 +1,15 @@ -from typing import Optional -import numpy as np +import json import os - -from torch import Tensor -from torch.utils.data import IterableDataset -from torch.utils.data import get_worker_info - -import torch import random -import json +from typing import Optional import numpy as np +import torch import yaml from rich.console import Console from rich.tree import Tree +from torch.utils.data import IterableDataset +from torch.utils.data import get_worker_info from anemoi.datasets import open_dataset @@ -236,8 +232,7 @@ def shorten_numpy(structure): def get_base_seed(): - """ - Get a base seed for random number generation. + """Get a base seed for random number generation. This is a placeholder function; replace with actual logic to get a base seed. """ return 42 # Example fixed seed, replace with actual logic as needed @@ -281,9 +276,7 @@ def __init__( self.len = 25 # len(self._sample_factory) def __get_sample(self, index: int): - """ - Get a sample from the dataset. 
- """ + """Get a sample from the dataset.""" return self._sample_factory[index] def per_worker_init(self, n_workers: int, worker_id: int) -> None: From fbb4121713af1c30668d0bbdcc7cfaa5af962144 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 2 Jul 2025 11:39:26 +0000 Subject: [PATCH 051/212] remove codes that does not belong here --- dop_dataset.py | 392 ------------------------------------------------- 1 file changed, 392 deletions(-) delete mode 100644 dop_dataset.py diff --git a/dop_dataset.py b/dop_dataset.py deleted file mode 100644 index 38348c531..000000000 --- a/dop_dataset.py +++ /dev/null @@ -1,392 +0,0 @@ -import json -import os -import random -from typing import Optional - -import numpy as np -import torch -import yaml -from rich.console import Console -from rich.tree import Tree -from torch.utils.data import IterableDataset -from torch.utils.data import get_worker_info - -from anemoi.datasets import open_dataset - -CONFIG = dict( - data=dict( - # era5=dict( - # dataset=dict(dataset="aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8", set_group="era5"), - # # preprocessors=dict( - # # tp=[dict(normalizer="mean-std")]), - # # ), - # ), - snow=dict(dataset="observations-testing-2018-2018-6h-v1-one-month"), - metop_a=dict(dataset="observations-testing-2018-2018-6h-v1-one-month"), - amsr2_h180=dict(dataset="observations-testing-2018-2018-6h-v1-one-month"), - ), - sample=dict( - GROUPS=dict( - input=dict( - GROUPS=dict( - # fields=dict( # "fields" is a user defined key - # STEPS=dict( - # _6h=dict( - # variables=["q_50", "2t"], - # data="era5", - # ), - # _0h=dict( - # variables=["q_50", "2t"], - # data="era5", - # ), - # ), - # ), - # user-friendly config would be: - # fields=dict( - # steps=['-6h', '0h'], - # variables=["q_50", "2t"], - # data="era5", - # ), - ascat_metop_a=dict( # "metar" is a user defined key - STEPS=dict( - _6h=dict( - variables=["scatss_1", "scatss_2"], - data="metop_a", - ), - ), - ), - snow=dict( # "iasi" is a user defined key - STEPS=dict( - _6h=dict( - variables=["sdepth_0"], - data="snow", - ), - ), - ), - amsr2=dict( # "iasi" is a user defined key - STEPS=dict( - _6h=dict( - variables=["rawbt_1", "rawbt_2", "rawbt_3", "rawbt_4"], - data="amsr_h180", - ), - ), - ), - ), - ), - ), - ), -) - - -class Sample: - def __init__(self, datahandlers): - self.datahandlers = datahandlers - - def __repr__(self): - console = Console(record=True, width=120) - tree = self._build_tree() - with console.capture() as capture: - console.print(tree) - return capture.get() - - def _build_tree(self, label="Sample"): - return Tree(label) - - -class GroupedSample(Sample): - def __init__(self, datahandlers, dic): - super().__init__(datahandlers) - self._samples = {k: sample_factory(**v) for k, v in dic.items()} - - def __getitem__(self, item): - return {k: v[item] for k, v in self._samples.items()} - - def _build_tree(self, label="GroupedSample"): - tree = Tree(label) - for k, v in self._samples.items(): - subtree = v._build_tree(label=f"{k}: {type(v).__name__}") - tree.add(subtree) - return tree - - -class StepSample(Sample): - def __init__(self, datahandlers, dic): - super().__init__(datahandlers) - self._samples = {k: sample_factory(**v) for k, v in dic.items()} - - def __getitem__(self, item): - out = [] - for k, v in self._samples.items(): - if k == "_6h": - out.append(v[item - 1]) - elif k == "_0h": - out.append(v[item]) - elif k == "p6h": - out.append(v[item + 1]) - return out - - def _build_tree(self, label="GroupedSample"): - tree = Tree(label) - for k, v in 
self._samples.items(): - subtree = v._build_tree(label=f"{k}: {type(v).__name__}") - tree.add(subtree) - return tree - - -class Leaf(Sample): - def __init__(self, datahandlers, variables, data): - super().__init__(datahandlers) - self.data_key = data - self.variables = variables - - def __getitem__(self, item): - result = Result(self.data_key, item, variables=self.variables) - return result.load() - - def _build_tree(self, label="Leaf"): - return Tree(f"{label} -> {self.data_key} variables={self.variables}") - - -def sample_factory(datahandlers=None, **kwargs): - kwargs = kwargs.copy() - if datahandlers is None: - datahandlers = [] - if "GROUPS" in kwargs: - return GroupedSample(datahandlers, kwargs["GROUPS"]) - if "STEPS" in kwargs: - return StepSample(datahandlers, kwargs["STEPS"]) - if "variables" in kwargs: - return Leaf(datahandlers, variables=kwargs["variables"], data=kwargs["data"]) - assert False, f"Unknown sample type for kwargs {kwargs}" - - -class Result: - def __init__(self, datahandler_key, *args, variables=[], **kwargs): - cfg = CONFIG["data"][datahandler_key] - assert "select" not in cfg, (cfg, variables) - variables = [f"{datahandler_key}.{v}" for v in variables] - dh = DataHandler(datahandler_key, **cfg, select=variables) - - self.func = dh.__getitem__ - self.args = args - self.kwargs = kwargs - - def load(self): - return self.func(*self.args, **self.kwargs) - - def __repr__(self): - inside = [] - inside += [str(arg) for arg in self.args] - inside += [f"{k}={v}" for k, v in self.kwargs.items()] - return f"Result({self.datahandler} ({', '.join(inside)})" - - -class DataHandler: - def __init__(self, name, **config): - self.name = name - if isinstance(config, str): - config = dict(dataset=config) - if isinstance(config["dataset"], str): - config = dict(dataset=config) - - self.config = config - self._config_str = " ".join(f"{k}={v}" for k, v in config.items()) - - def is_grouped_dataset(self, ds): - from anemoi.datasets.data.records import BaseRecordsDataset - - return isinstance(ds, BaseRecordsDataset) - - @property - def ds(self): - ds = open_dataset(**self.config["dataset"]) - print(f"🔍 Opened dataset {self.name} with config: {self._config_str}") - if self.name not in ds.groups: - raise ValueError(f"Group '{self.name}' not found in dataset. Available groups: {ds.groups}") - ds = ds[self.name] - print(f" Available variables for group '{self.name}': {ds.variables}") - return ds - - def __getitem__(self, item): - data = self.ds[item] - assert isinstance(data, np.ndarray), f"Expected np.array, got {type(data)}, {type(self.ds)}" - return data - return f"np.array ds[{item}] with ds from {self._config_str} " - - def __str__(self): - return f"DataHandler({self._config_str})" - - -def show_yaml(structure): - return yaml.dump(structure, indent=2, sort_keys=False) - - -def show_json(structure): - return json.dumps(structure, indent=2, default=shorten_numpy) - - -def shorten_numpy(structure): - if isinstance(structure, np.ndarray): - return f"np.array({structure.shape})" - return structure - - -def get_base_seed(): - """Get a base seed for random number generation. - This is a placeholder function; replace with actual logic to get a base seed. 
- """ - return 42 # Example fixed seed, replace with actual logic as needed - - -class DOPDataset(IterableDataset): - def __init__( - self, - # config: dict, - shuffle: bool = True, - rollout: int = 1, - multistep: int = 1, - task: str = "training", - ) -> None: - - self.shuffle = shuffle - # self.config = config - self.rollout = rollout - self.multistep = multistep - self.task = task - - # lazy init - self.n_samples_per_epoch_total: int = 0 - self.n_samples_per_epoch_per_worker: int = 0 - - # additional state vars (lazy init) - self.n_samples_per_worker = 0 - self.chunk_index_range: Optional[np.ndarray] = None - self.shuffle = shuffle - self.rng: Optional[np.random.Generator] = None - self.worker_id: int = -1 - - # "full" shuffling - self.data_indices: Optional[np.ndarray] = None - - self.seed_comm_group_id = 0 - self.seed_comm_num_groups = 1 - - self._sample_factory = sample_factory(**CONFIG["sample"]) - - self.len = 25 # len(self._sample_factory) - - def __get_sample(self, index: int): - """Get a sample from the dataset.""" - return self._sample_factory[index] - - def per_worker_init(self, n_workers: int, worker_id: int) -> None: - """Called by worker_init_func on each copy of dataset. - - This initialises after the worker process has been spawned. - - Parameters - ---------- - n_workers : int - Number of workers - worker_id : int - Worker ID - """ - self.worker_id = worker_id - - # Total number of valid ICs is dataset length minus rollout minus additional multistep inputs - len_corrected = self.len - self.rollout - self.multistep + 1 - self.data_indices = np.arange(len_corrected, dtype=np.uint32) - - # Divide this equally across shards (one shard per group!) - shard_size = len_corrected // self.seed_comm_num_groups - shard_start = self.seed_comm_group_id * shard_size - shard_end = min((self.seed_comm_group_id + 1) * shard_size, self.len - self.rollout - self.multistep + 1) - - shard_len = shard_end - shard_start - self.n_samples_per_worker = shard_len // n_workers - - low = shard_start + worker_id * self.n_samples_per_worker - high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end) - self.chunk_index_range = np.arange(low, high, dtype=np.uint32) - - seed = get_base_seed() # all workers get the same seed (so they all get the same index shuffle) - torch.manual_seed(seed) - random.seed(seed) - self.rng = np.random.default_rng(seed=seed) - sanity_rnd = self.rng.random(1) - print("Sanity check random number:", sanity_rnd) - - def __iter__(self): - if self.shuffle: - # do a full shuffle, then get my index range - shuffled_data_indices = self.rng.choice(self.data_indices, size=len(self.data_indices), replace=False) - shuffled_chunk_indices = shuffled_data_indices[self.chunk_index_range] - - while True: # the pl.Trainer will break out of this loop after a fixed number of samples - idx = self.rng.choice(shuffled_chunk_indices) - print( - f"TRAINING: Worker {self.worker_id} (pid {os.getpid()}) fetching sample index {idx} ...", - ) - yield self.__get_sample(idx) - - else: - shuffled_chunk_indices = self.data_indices[self.chunk_index_range] - # no shuffle, just iterate over the chunk indices - for idx in self.chunk_index_range: - print( - f"VALIDATION: Worker {self.worker_id} (pid {os.getpid()}) fetching sample index {idx} ...", - ) - yield self.__get_sample(idx) - - -def worker_init_func(worker_id: int) -> None: - """Configures each dataset worker process. - - Calls WeatherBenchDataset.per_worker_init() on each dataset object. 
- - Parameters - ---------- - worker_id : int - Worker ID - - Raises - ------ - RuntimeError - If worker_info is None - """ - worker_info = get_worker_info() # information specific to each worker process - if worker_info is None: - print("worker_info is None! Set num_workers > 0 in your dataloader!") - raise RuntimeError - dataset_obj = worker_info.dataset # the copy of the dataset held by this worker process. - dataset_obj.per_worker_init( - n_workers=worker_info.num_workers, - worker_id=worker_id, - ) - - -if __name__ == "__main__": - - ds = DOPDataset( - # CONFIG, - shuffle=False, - rollout=1, - multistep=1, - task="training", - ) - - loader_params = { - "batch_size": 1, # must be 1 for the time being - "batch_sampler": None, - "num_workers": 2, - "pin_memory": False, - "worker_init_fn": worker_init_func, - # "collate_fn": None, # collator_wrapper(return_original_metadata=cfg_.dataloader.return_dates), - } - - dl = torch.utils.data.DataLoader(ds, **loader_params, sampler=None) - - for batch_idx, batch in enumerate(dl): - print.info("%s", batch) - if batch_idx >= 1: - break From e6ecbc003b31c95b945d8e8a36c229c2fad15c81 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 2 Jul 2025 13:07:44 +0000 Subject: [PATCH 052/212] up --- src/anemoi/datasets/data/records/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 07d18d72b..5d81dc9cf 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -83,7 +83,7 @@ def __getitem__(self, i): if isinstance(i, str): return self._getgroup(i) - if isinstance(i, int): + if isinstance(i, (int, np.integer)): return self._getrecord(i) raise ValueError(f"Invalid index {i}, must be int or str") From ad61f6595051ee14f9557e59975195d8b987e870 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 3 Jul 2025 11:05:02 +0000 Subject: [PATCH 053/212] bring lats and lons --- src/anemoi/datasets/data/records/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 5d81dc9cf..81b492da7 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -237,8 +237,8 @@ def _load_data(self, i): data = self.forward[i] out = {} out[f"data:{self._name}"] = data - # out[f"latitudes:{self._name}"] = self.forward.latitudes - # out[f"longitudes:{self._name}"] = self.forward.longitudes + out[f"latitudes:{self._name}"] = self.forward.latitudes + out[f"longitudes:{self._name}"] = self.forward.longitudes out[f"timedeltas:{self._name}"] = np.zeros_like(data, dtype="timedelta64[s]") + _to_numpy_date( self.forward.dates[i] ) @@ -260,6 +260,14 @@ def variables(self): def dates(self): return self.forward.dates + @property + def longitudes(self): + return self._nest_in_dict(self.forward.longitudes) + + @property + def latitudes(self): + return self._nest_in_dict(self.forward.latitudes) + @property def name_to_index(self): return self._nest_in_dict(self.forward.name_to_index) From ea2f7d91828bb11ae4f4b9eb4bf5f7fda3a4ff75 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 3 Jul 2025 11:26:04 +0000 Subject: [PATCH 054/212] update timedelta to 0s for field records --- src/anemoi/datasets/data/records/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py 
b/src/anemoi/datasets/data/records/__init__.py index 81b492da7..04352955d 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -239,9 +239,9 @@ def _load_data(self, i): out[f"data:{self._name}"] = data out[f"latitudes:{self._name}"] = self.forward.latitudes out[f"longitudes:{self._name}"] = self.forward.longitudes - out[f"timedeltas:{self._name}"] = np.zeros_like(data, dtype="timedelta64[s]") + _to_numpy_date( - self.forward.dates[i] - ) + out[f"timedeltas:{self._name}"] = np.zeros(data.shape[-1], dtype="timedelta64[s]")# + _to_numpy_date( + # self.forward.dates[i] + #) return out @property From c10ea855b722879e5d6896076aef19b7ca8ca14c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Jul 2025 11:26:31 +0000 Subject: [PATCH 055/212] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/datasets/data/records/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 04352955d..9acce6165 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -237,11 +237,11 @@ def _load_data(self, i): data = self.forward[i] out = {} out[f"data:{self._name}"] = data - out[f"latitudes:{self._name}"] = self.forward.latitudes - out[f"longitudes:{self._name}"] = self.forward.longitudes - out[f"timedeltas:{self._name}"] = np.zeros(data.shape[-1], dtype="timedelta64[s]")# + _to_numpy_date( + out[f"latitudes:{self._name}"] = self.forward.latitudes + out[f"longitudes:{self._name}"] = self.forward.longitudes + out[f"timedeltas:{self._name}"] = np.zeros(data.shape[-1], dtype="timedelta64[s]") # + _to_numpy_date( # self.forward.dates[i] - #) + # ) return out @property From e5658669dc2f601ec773dcc4f82b940ad05ac4ec Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 3 Jul 2025 11:35:41 +0000 Subject: [PATCH 056/212] add metadata --- src/anemoi/datasets/data/records/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 04352955d..df9ae6d1c 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -242,6 +242,7 @@ def _load_data(self, i): out[f"timedeltas:{self._name}"] = np.zeros(data.shape[-1], dtype="timedelta64[s]")# + _to_numpy_date( # self.forward.dates[i] #) + out[f"metadata:{self._name}"] = self.forward.metadata() return out @property From 98dae84b97f9b2357c42ab0874b1cbb75b36d1b1 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 7 Jul 2025 08:09:43 +0000 Subject: [PATCH 057/212] revert inspect --- src/anemoi/datasets/commands/inspect.py | 28 +++++++++---------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index cb3aaf847..400cdcf98 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -22,7 +22,6 @@ import numpy as np import semantic_version import tqdm -from anemoi.utils.config import load_any_dict_format from anemoi.utils.humanize import bytes from anemoi.utils.humanize import bytes_to_human from anemoi.utils.humanize import when @@ -32,8 +31,8 @@ from numpy.typing import NDArray from anemoi.datasets import open_dataset -from 
anemoi.datasets.data.stores import dataset_lookup from anemoi.datasets.data.stores import open_zarr +from anemoi.datasets.data.stores import zarr_lookup from . import Command @@ -301,12 +300,12 @@ def variables(self) -> List[str]: @property def total_size(self) -> Optional[int]: """Get the total size of the dataset.""" - return self.metadata.get("total_size") + return self.zarr.attrs.get("total_size") @property def total_number_of_files(self) -> Optional[int]: """Get the total number of files in the dataset.""" - return self.metadata.get("total_number_of_files") + return self.zarr.attrs.get("total_number_of_files") def print_sizes(self, size: bool) -> None: """Print the size and number of files in the dataset. @@ -363,14 +362,15 @@ def build_flags(self) -> Optional[NDArray[Any]]: @cached_property def copy_flags(self) -> Optional[NDArray[Any]]: - if not self.zarr or "_copy" not in self.zarr: + """Get the copy flags of the dataset.""" + if "_copy" not in self.zarr: return None return self.zarr["_copy"][:] @property def copy_in_progress(self) -> bool: """Check if a copy operation is in progress.""" - if not self.zarr or "_copy" not in self.zarr: + if "_copy" not in self.zarr: return False start = self.zarr["_copy"].attrs.get("copy_start_timestamp") @@ -383,8 +383,6 @@ def copy_in_progress(self) -> bool: @property def build_lengths(self) -> Optional[NDArray]: """Get the build lengths of the dataset.""" - if not self.zarr: - return None return self.zarr.get("_build_lengths") def progress(self) -> None: @@ -654,7 +652,7 @@ def details(self) -> None: def ready(self) -> bool: """Check if the dataset is ready.""" - if not self.zarr or "_build_flags" not in self.zarr: + if "_build_flags" not in self.zarr: return False build_flags = self.zarr["_build_flags"] @@ -710,7 +708,7 @@ class Version0_13(Version0_12): @property def build_flags(self) -> Optional[NDArray]: """Get the build flags for the dataset.""" - if not self.zarr or "_build" not in self.zarr: + if "_build" not in self.zarr: return None build = self.zarr["_build"] return build.get("flags") @@ -820,15 +818,9 @@ def _info(self, path: str) -> Version: Version The version object of the dataset. """ - resolved_path = dataset_lookup(path) - if resolved_path.endswith(".vz"): - LOG.warning(f"Inspecting a .vz file: {resolved_path}. 
This is not supported yet.") - metadata = load_any_dict_format(os.path.join(resolved_path, "metadata.json")) - z = None - else: - z = open_zarr(resolved_path) - metadata = dict(z.attrs) + z = open_zarr(zarr_lookup(path)) + metadata = dict(z.attrs) version = metadata.get("version", "0.0.0") if isinstance(version, int): version = f"0.{version}" From 3082edfa84a3a479747768ef57eac77df48bfb07 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 09:45:02 +0000 Subject: [PATCH 058/212] refactor missing --- .gitignore | 3 + src/anemoi/datasets/create/__init__.py | 3 + src/anemoi/datasets/create/input/__init__.py | 15 +- src/anemoi/datasets/create/input/action.py | 388 +++++++----------- .../datasets/create/sources/__init__.py | 11 + src/anemoi/datasets/create/sources/legacy.py | 6 +- 6 files changed, 174 insertions(+), 252 deletions(-) diff --git a/.gitignore b/.gitignore index f3777cde2..746409a43 100644 --- a/.gitignore +++ b/.gitignore @@ -141,3 +141,6 @@ _version.py *.to_upload tempCodeRunnerFile.python Untitled-*.py +*.prof +prof/ +*.gz diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 8d74975c5..9c8c4613a 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -21,6 +21,7 @@ import cftime import numpy as np +import rich import tqdm import zarr from anemoi.utils.dates import as_datetime @@ -671,6 +672,8 @@ def _run(self) -> int: LOG.info(f"Missing dates: {len(missing)}") lengths = tuple(len(g) for g in self.groups) + rich.print("Minimal input dates:", self.minimal_input) + variables = self.minimal_input.variables LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index c64f8275b..9f30f1abc 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -12,6 +12,8 @@ from typing import Any from typing import Union +import rich + from anemoi.datasets.dates.groups import GroupOfDates from .trace import trace_select @@ -22,7 +24,11 @@ class Context: """Context for building input data.""" - pass + use_grib_paramid = False + + def trace(self, emoji, message) -> None: + + rich.print(f"{emoji}: {message}") class InputBuilder: @@ -67,13 +73,12 @@ def select(self, group_of_dates: GroupOfDates) -> Any: Any Selected data. """ - from .action import ActionContext from .action import action_factory """This changes the context.""" - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.select(group_of_dates) + context = Context() + action = action_factory(self.config, self.action_path) + return action(context, group_of_dates) def __repr__(self) -> str: """Return a string representation of the InputBuilder. diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index eadf01339..7c671db2f 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -7,253 +7,153 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-import json import logging -from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from earthkit.data.core.order import build_remapping +LOG = logging.getLogger(__name__) -from ...dates.groups import GroupOfDates -from .context import Context -from .template import substitute -LOG = logging.getLogger(__name__) +class Predicate: + def __init__(self, config): + self.config = config + + def __repr__(self): + return f"Predicate({self.config})" + + def match(self, dates): + # Just a demo + return True + + +class Concat: + def __init__(self, config): + assert isinstance(config, list), f"Value must be a dict {list}" + + self.choices = [] + + for item in config: + + assert "dates" in item, f"Value must contain the key 'date' {item}" + predicate = Predicate(item.pop("dates")) + action = action_factory(item) + + self.choices.append((predicate, action)) + + def __repr__(self): + return f"Concat({self.choices})" + + def __call__(self, context, group_of_dates): + + for predicate, action in self.choices: + if predicate.match(group_of_dates): + return action(group_of_dates) + + raise ValueError(f"No matching predicate for dates: {group_of_dates}") + + +class Join: + def __init__(self, config): + assert isinstance(config, list), f"Value must be a list {config}" + self.actions = [action_factory(item) for item in config] + + def __repr__(self): + return f"Join({self.actions})" + + def __call__(self, context, group_of_dates): + results = [] + for action in self.actions: + results.append(action(context, group_of_dates)) + return results + + +class Pipe: + def __init__(self, config): + assert isinstance(config, list), f"Value must be a list {config}" + self.actions = [action_factory(item) for item in config] + + def __repr__(self): + return f"Pipe({self.actions})" + + def __call__(self, context, dates): + result = None + for action in self.actions: + if result is None: + result = action(dates) + else: + result = action(result) + return result + + +class Function: + def __init__(self, config): + self.config = config + + # # if self._source: + # # self.source = action_factory(config[self._source]) + # # else: + # # self.source = None + + # def __repr__(self): + # return f"{self.__class__.__name__}({self.config})" + + # def __call__(self, context, dates): + # # Just a demo, in real case it would do something with the dates + + # config = self.config.copy() + # if self.source: + # config[self._source] = self.source(dates) + + # return {self.__class__.__name__: self.config, "dates": dates} + + +class SourceFunction(Function): + def __init__(self, config): + from anemoi.datasets.create.sources import create_source + + super().__init__(config) + config["_type"] = self.name + self.source = create_source(self, config) + + def __call__(self, context, group_of_dates): + return self.source.execute(context, group_of_dates.dates) + + +class FilterFunction(Function): + pass + + +def new_source(name, source=None): + return type(name.title(), (SourceFunction,), {"name": name}) + + +def new_filter(name, source=None): + return type(name.title(), (FilterFunction,), {"name": name}) + + +KLASS = { + "concat": Concat, + "join": Join, + "pipe": Pipe, +} + +LEN_KLASS = len(KLASS) + + +def make(key, config): + + if LEN_KLASS == len(KLASS): + # Load pluggins + from anemoi.datasets.create.sources import registered_sources + + for name in registered_sources(): + assert name not in KLASS, f"Duplicate source name: {name}" + KLASS[name] = new_source(name) + + return KLASS[key](config) + +def 
action_factory(data, *path): + assert isinstance(data, dict), f"Input data must be a dictionary {data}" + assert len(data) == 1, "Input data must contain exactly one key-value pair" -class Action: - """Represents an action to be performed within a given context. - - Attributes - ---------- - context : ActionContext - The context in which the action exists. - kwargs : Dict[str, Any] - Additional keyword arguments. - args : Any - Additional positional arguments. - action_path : List[str] - The action path. - """ - - def __init__( - self, context: "ActionContext", action_path: List[str], /, *args: Any, **kwargs: Dict[str, Any] - ) -> None: - """Initialize an Action instance. - - Parameters - ---------- - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - args : Any - Additional positional arguments. - kwargs : Dict[str, Any] - Additional keyword arguments. - """ - if "args" in kwargs and "kwargs" in kwargs: - """We have: - args = [] - kwargs = {args: [...], kwargs: {...}} - move the content of kwargs to args and kwargs. - """ - assert len(kwargs) == 2, (args, kwargs) - assert not args, (args, kwargs) - args = kwargs.pop("args") - kwargs = kwargs.pop("kwargs") - - assert isinstance(context, ActionContext), type(context) - self.context = context - self.kwargs = kwargs - self.args = args - self.action_path = action_path - - @classmethod - def _short_str(cls, x: str) -> str: - """Shorten the string representation if it exceeds 1000 characters. - - Parameters - ---------- - x : str - The string to shorten. - - Returns - ------- - str - The shortened string. - """ - x = str(x) - if len(x) < 1000: - return x - return x[:1000] + "..." - - def _repr(self, *args: Any, _indent_: str = "\n", _inline_: str = "", **kwargs: Any) -> str: - """Generate a string representation of the Action instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str, optional - The indentation string, by default "\n". - _inline_ : str, optional - The inline string, by default "". - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - more = more[:5000] - txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Action instance. - - Returns - ------- - str - The string representation. - """ - return self._repr() - - def select(self, dates: object, **kwargs: Any) -> None: - """Select dates for the action. - - Parameters - ---------- - dates : object - The dates to select. - kwargs : Any - Additional keyword arguments. - """ - self._raise_not_implemented() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the selection of a group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates to trace. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({group_of_dates})" - - -class ActionContext(Context): - """Represents the context in which an action is performed. 
- - Attributes - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. - """ - - def __init__(self, /, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool) -> None: - """Initialize an ActionContext instance. - - Parameters - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. - """ - super().__init__() - self.order_by = order_by - self.flatten_grid = flatten_grid - self.remapping = build_remapping(remapping) - self.use_grib_paramid = use_grib_paramid - - -def action_factory(config: Dict[str, Any], context: ActionContext, action_path: List[str]) -> Action: - """Factory function to create an Action instance based on the configuration. - - Parameters - ---------- - config : Dict[str, Any] - The action configuration. - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - - Returns - ------- - Action - The created Action instance. - """ - from .concat import ConcatAction - from .data_sources import DataSourcesAction - from .function import FunctionAction - from .join import JoinAction - from .pipe import PipeAction - from .repeated_dates import RepeatedDatesAction - - # from .data_sources import DataSourcesAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - if len(config) != 1: - print(json.dumps(config, indent=2, default=str)) - raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}") - - config = deepcopy(config) - key = list(config.keys())[0] - - if isinstance(config[key], list): - args, kwargs = config[key], {} - elif isinstance(config[key], dict): - args, kwargs = [], config[key] - else: - raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}") - - cls = { - "data_sources": DataSourcesAction, - "data-sources": DataSourcesAction, - "concat": ConcatAction, - "join": JoinAction, - "pipe": PipeAction, - "function": FunctionAction, - "repeated_dates": RepeatedDatesAction, - "repeated-dates": RepeatedDatesAction, - }.get(key) - - if cls is None: - from ..sources import create_source - - source = create_source(None, substitute(context, config)) - return FunctionAction(context, action_path + [key], key, source) - - return cls(context, action_path + [key], *args, **kwargs) + key, value = next(iter(data.items())) + return make(key, value) diff --git a/src/anemoi/datasets/create/sources/__init__.py b/src/anemoi/datasets/create/sources/__init__.py index f8b99f36d..1710d1303 100644 --- a/src/anemoi/datasets/create/sources/__init__.py +++ b/src/anemoi/datasets/create/sources/__init__.py @@ -34,3 +34,14 @@ def create_source(context: Any, config: Any) -> Any: The created source. """ return source_registry.from_config(config, context) + + +def registered_sources() -> list[str]: + """Get a list of registered source names. + + Returns + ------- + list[str] + A list of names of registered sources. 
+ """ + return source_registry.registered diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index d72d0b3f4..f8035ee16 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -71,13 +71,13 @@ def __call__(self, execute: Callable) -> Callable: name = f"Legacy{self.name.title()}Source" source = ".".join([execute.__module__, execute.__name__]) - def execute_wrapper(self, dates) -> Any: + def execute_wrapper(self, context, dates) -> Any: """Wrapper method to call the execute function.""" - args, kwargs = resolve(self.context, (self.args, self.kwargs)) + args, kwargs = resolve(context, (self.args, self.kwargs)) try: - return execute(self.context, dates, *args, **kwargs) + return execute(context, dates, *args, **kwargs) except TypeError: LOG.error(f"Error executing source {this.name} from {source}") LOG.error(f"Function signature is: {inspect.signature(execute)}") From ef4a5c9507dc8e11822bbbcc8d6e448b18054d17 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 10:19:48 +0000 Subject: [PATCH 059/212] add references --- src/anemoi/datasets/create/input/__init__.py | 65 ++++++--- src/anemoi/datasets/create/input/action.py | 141 +++++++++++-------- 2 files changed, 131 insertions(+), 75 deletions(-) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 9f30f1abc..d561a9fb4 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -26,10 +26,56 @@ class Context: use_grib_paramid = False + def __init__(self): + self.results = {} + def trace(self, emoji, message) -> None: rich.print(f"{emoji}: {message}") + def register(self, data: Any, path: list[str]) -> Any: + """Register data in the context. + + Parameters + ---------- + data : Any + Data to register. + path : list[str] + Path where the data should be registered. + + Returns + ------- + Any + Registered data. + """ + # This is a placeholder for actual registration logic. + rich.print(f"Registering data at path: {path}") + self.results[tuple(path)] = data + return data + + def resolve(self, config): + config = config.copy() + + for key, value in list(config.items()): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + path = tuple(value[2:-1].split(".")) + if path in self.results: + config[key] = self.results[path] + else: + raise KeyError(f"Path {path} not found in results: {self.results.keys()}") + + return config + + +class FieldContext(Context): + def empty_result(self) -> Any: + import earthkit.data as ekd + + return ekd.from_source("empty") + + def source_argument(self, argument: Any) -> Any: + return argument.dates + class InputBuilder: """Builder class for creating input data from configuration and data sources.""" @@ -76,25 +122,10 @@ def select(self, group_of_dates: GroupOfDates) -> Any: from .action import action_factory """This changes the context.""" - context = Context() - action = action_factory(self.config, self.action_path) + context = FieldContext() + action = action_factory(self.config, "input") return action(context, group_of_dates) - def __repr__(self) -> str: - """Return a string representation of the InputBuilder. - - Returns - ------- - str - String representation. 
- """ - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - a = action_factory(self.config, context, self.action_path) - return repr(a) - def _trace_select(self, group_of_dates: GroupOfDates) -> str: """Trace the select operation. diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7c671db2f..8fdbf10be 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -9,6 +9,8 @@ import logging +import rich + LOG = logging.getLogger(__name__) @@ -24,8 +26,16 @@ def match(self, dates): return True -class Concat: - def __init__(self, config): +class Action: + def __init__(self, config, *path): + self.config = config + self.path = path + + +class Concat(Action): + def __init__(self, config, *path): + super().__init__(config, *path) + assert isinstance(config, list), f"Value must be a dict {list}" self.choices = [] @@ -41,104 +51,119 @@ def __init__(self, config): def __repr__(self): return f"Concat({self.choices})" - def __call__(self, context, group_of_dates): + def __call__(self, context, argument): for predicate, action in self.choices: - if predicate.match(group_of_dates): - return action(group_of_dates) + if predicate.match(argument): + return context.register( + action(context, argument), + self.path, + ) - raise ValueError(f"No matching predicate for dates: {group_of_dates}") + raise ValueError(f"No matching predicate for dates: {argument}") -class Join: - def __init__(self, config): +class Join(Action): + def __init__(self, config, *path): + super().__init__(config, *path) + assert isinstance(config, list), f"Value must be a list {config}" - self.actions = [action_factory(item) for item in config] + + self.actions = [ + action_factory( + item, + *path, + "join", + str(i), + ) + for i, item in enumerate(config) + ] def __repr__(self): return f"Join({self.actions})" - def __call__(self, context, group_of_dates): - results = [] + def __call__(self, context, argument): + results = context.empty_result() for action in self.actions: - results.append(action(context, group_of_dates)) - return results + results += action(context, argument) + return context.register( + results, + self.path, + ) -class Pipe: - def __init__(self, config): +class Pipe(Action): + def __init__(self, config, *path): assert isinstance(config, list), f"Value must be a list {config}" - self.actions = [action_factory(item) for item in config] + super().__init__(config, *path) + self.actions = [ + action_factory( + item, + *path, + "pipe", + str(i), + ) + for i, item in enumerate(config) + ] def __repr__(self): return f"Pipe({self.actions})" - def __call__(self, context, dates): - result = None - for action in self.actions: - if result is None: - result = action(dates) - else: - result = action(result) - return result - - -class Function: - def __init__(self, config): - self.config = config + def __call__(self, context, argument): + result = context.empty_result() - # # if self._source: - # # self.source = action_factory(config[self._source]) - # # else: - # # self.source = None - - # def __repr__(self): - # return f"{self.__class__.__name__}({self.config})" + for i, action in enumerate(self.actions): + if i == 0: + result = action(context, argument) + else: + result = action(context, result) - # def __call__(self, context, dates): - # # Just a demo, in real case it would do something with the dates + return context.register( + result, + self.path, + ) - # config 
= self.config.copy() - # if self.source: - # config[self._source] = self.source(dates) - # return {self.__class__.__name__: self.config, "dates": dates} +class Function(Action): + def __init__(self, config, *path): + super().__init__(config, *path, self.name) class SourceFunction(Function): - def __init__(self, config): + + def __call__(self, context, argument): from anemoi.datasets.create.sources import create_source - super().__init__(config) - config["_type"] = self.name - self.source = create_source(self, config) + config = context.resolve(self.config) # Substitute the ${} variables in the config + config["_type"] = self.name # Find a better way to do this + source = create_source(self, config) + + rich.print(f"Executing source {self.name} from {config}") - def __call__(self, context, group_of_dates): - return self.source.execute(context, group_of_dates.dates) + return context.register( + source.execute(context, context.source_argument(argument)), + self.path, + ) class FilterFunction(Function): pass -def new_source(name, source=None): +def new_source(name): return type(name.title(), (SourceFunction,), {"name": name}) -def new_filter(name, source=None): +def new_filter(name): return type(name.title(), (FilterFunction,), {"name": name}) -KLASS = { - "concat": Concat, - "join": Join, - "pipe": Pipe, -} +KLASS = {"concat": Concat, "join": Join, "pipe": Pipe} LEN_KLASS = len(KLASS) -def make(key, config): +def make(key, config, path): if LEN_KLASS == len(KLASS): # Load pluggins @@ -148,7 +173,7 @@ def make(key, config): assert name not in KLASS, f"Duplicate source name: {name}" KLASS[name] = new_source(name) - return KLASS[key](config) + return KLASS[key](config, *path) def action_factory(data, *path): @@ -156,4 +181,4 @@ def action_factory(data, *path): assert len(data) == 1, "Input data must contain exactly one key-value pair" key, value = next(iter(data.items())) - return make(key, value) + return make(key, value, path) From 93410d56e5b8c0312da55185b60b85ba4269574a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 10:58:12 +0000 Subject: [PATCH 060/212] refactor --- src/anemoi/datasets/create/input/__init__.py | 71 +------- src/anemoi/datasets/create/input/action.py | 1 + src/anemoi/datasets/create/input/concat.py | 164 ------------------ src/anemoi/datasets/create/input/context.py | 89 ---------- .../datasets/create/input/context/__init__.py | 70 ++++++++ .../datasets/create/input/context/field.py | 37 ++++ .../datasets/create/input/data_sources.py | 2 +- src/anemoi/datasets/create/input/empty.py | 2 +- src/anemoi/datasets/create/input/function.py | 2 +- src/anemoi/datasets/create/input/join.py | 130 -------------- src/anemoi/datasets/create/input/pipe.py | 66 ------- .../datasets/create/input/repeated_dates.py | 2 +- .../datasets/create/input/result/__init__.py | 21 +++ .../input/{result.py => result/field.py} | 6 +- src/anemoi/datasets/create/input/step.py | 2 +- 15 files changed, 141 insertions(+), 524 deletions(-) delete mode 100644 src/anemoi/datasets/create/input/concat.py delete mode 100644 src/anemoi/datasets/create/input/context.py create mode 100644 src/anemoi/datasets/create/input/context/__init__.py create mode 100644 src/anemoi/datasets/create/input/context/field.py delete mode 100644 src/anemoi/datasets/create/input/join.py delete mode 100644 src/anemoi/datasets/create/input/pipe.py create mode 100644 src/anemoi/datasets/create/input/result/__init__.py rename src/anemoi/datasets/create/input/{result.py => result/field.py} (99%) diff --git 
a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index d561a9fb4..19269f587 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024-2025 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,75 +7,13 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import logging from copy import deepcopy from typing import Any from typing import Union -import rich - +from anemoi.datasets.create.input.context.field import FieldContext from anemoi.datasets.dates.groups import GroupOfDates -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class Context: - """Context for building input data.""" - - use_grib_paramid = False - - def __init__(self): - self.results = {} - - def trace(self, emoji, message) -> None: - - rich.print(f"{emoji}: {message}") - - def register(self, data: Any, path: list[str]) -> Any: - """Register data in the context. - - Parameters - ---------- - data : Any - Data to register. - path : list[str] - Path where the data should be registered. - - Returns - ------- - Any - Registered data. - """ - # This is a placeholder for actual registration logic. - rich.print(f"Registering data at path: {path}") - self.results[tuple(path)] = data - return data - - def resolve(self, config): - config = config.copy() - - for key, value in list(config.items()): - if isinstance(value, str) and value.startswith("${") and value.endswith("}"): - path = tuple(value[2:-1].split(".")) - if path in self.results: - config[key] = self.results[path] - else: - raise KeyError(f"Path {path} not found in results: {self.results.keys()}") - - return config - - -class FieldContext(Context): - def empty_result(self) -> Any: - import earthkit.data as ekd - - return ekd.from_source("empty") - - def source_argument(self, argument: Any) -> Any: - return argument.dates - class InputBuilder: """Builder class for creating input data from configuration and data sources.""" @@ -105,7 +43,6 @@ def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) self.config = config self.action_path = ["input"] - @trace_select def select(self, group_of_dates: GroupOfDates) -> Any: """Select data based on the group of dates. @@ -122,9 +59,9 @@ def select(self, group_of_dates: GroupOfDates) -> Any: from .action import action_factory """This changes the context.""" - context = FieldContext() + context = FieldContext(**self.kwargs) action = action_factory(self.config, "input") - return action(context, group_of_dates) + return context.create_result(action(context, group_of_dates)) def _trace_select(self, group_of_dates: GroupOfDates) -> str: """Trace the select operation. 
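The rewritten `InputBuilder.select` above hands a `FieldContext` to the new action factory and registers every intermediate result under its action path. As a standalone illustration only (not part of these patches, using made-up names and values), the following minimal Python sketch mimics the `Context.register`/`Context.resolve` pair: results are stored under a tuple path and later substituted into configs through `${...}` placeholders.

    # Illustrative sketch of the path-reference mechanism (hypothetical names).
    results = {}

    def register(data, path):
        # Store a result under its tuple path, e.g. ("input", "join", "0", "mars").
        results[tuple(path)] = data
        return data

    def resolve(config):
        # Replace "${a.b.c}" placeholder strings with previously registered results.
        resolved = dict(config)
        for key, value in list(resolved.items()):
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                resolved[key] = results[tuple(value[2:-1].split("."))]
        return resolved

    register("some-fieldlist", ["input", "join", "0", "mars"])
    print(resolve({"template": "${input.join.0.mars}", "param": "2t"}))
    # -> {'template': 'some-fieldlist', 'param': '2t'}
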
diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 8fdbf10be..76d99f4b7 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -23,6 +23,7 @@ def __repr__(self): def match(self, dates): # Just a demo + raise NotImplementedError("Not yet implemented") return True diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py deleted file mode 100644 index 5399bbc1f..000000000 --- a/src/anemoi/datasets/create/input/concat.py +++ /dev/null @@ -1,164 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from copy import deepcopy -from functools import cached_property -from typing import Any -from typing import Dict -from typing import List -from typing import Union - -from earthkit.data import FieldList - -from anemoi.datasets.dates import DatesProvider - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class ConcatResult(Result): - """Represents the result of concatenating multiple results.""" - - def __init__( - self, - context: object, - action_path: List[str], - group_of_dates: GroupOfDates, - results: List[Result], - **kwargs: Any, - ) -> None: - """Initializes a ConcatResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - kwargs : Any - Additional keyword arguments. - """ - super().__init__(context, action_path, group_of_dates) - self.results = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the concatenated datasource from all results.""" - ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - @property - def variables(self) -> List[str]: - """Returns the list of variables, ensuring all results have the same variables.""" - variables = None - for f in self.results: - if f.empty: - continue - if variables is None: - variables = f.variables - assert variables == f.variables, (variables, f.variables) - assert variables is not None, self.results - return variables - - def __repr__(self) -> str: - """Returns a string representation of the ConcatResult instance. - - Returns - ------- - str - A string representation of the ConcatResult instance. 
- """ - content = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class ConcatAction(Action): - """Represents an action that concatenates multiple actions based on their dates.""" - - def __init__(self, context: object, action_path: List[str], *configs: Dict[str, Any]) -> None: - """Initializes a ConcatAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - configs : Dict[str, Any] - The configuration dictionaries. - """ - super().__init__(context, action_path, *configs) - parts = [] - for i, cfg in enumerate(configs): - if "dates" not in cfg: - raise ValueError(f"Missing 'dates' in {cfg}") - cfg = deepcopy(cfg) - dates_cfg = cfg.pop("dates") - assert isinstance(dates_cfg, dict), dates_cfg - filtering_dates = DatesProvider.from_config(**dates_cfg) - action = action_factory(cfg, context, action_path + [str(i)]) - parts.append((filtering_dates, action)) - self.parts = parts - - def __repr__(self) -> str: - """Returns a string representation of the ConcatAction instance. - - Returns - ------- - str - A string representation of the ConcatAction instance. - """ - content = "\n".join([str(i) for i in self.parts]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> Union[ConcatResult, EmptyResult]: - """Selects the concatenated result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - Union[ConcatResult, EmptyResult] - The concatenated result or an empty result. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - results = [] - for filtering_dates, action in self.parts: - newdates = GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) - if newdates: - results.append(action.select(newdates)) - if not results: - return EmptyResult(self.context, self.action_path, group_of_dates) - - return ConcatResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py deleted file mode 100644 index 35784dba7..000000000 --- a/src/anemoi/datasets/create/input/context.py +++ /dev/null @@ -1,89 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import textwrap -from typing import Any -from typing import List -from typing import Tuple -from typing import Union - -from anemoi.utils.humanize import plural - -from .trace import step -from .trace import trace - -LOG = logging.getLogger(__name__) - - -class Context: - """Class to handle the build context in the dataset creation process.""" - - def __init__(self) -> None: - """Initializes a Context instance.""" - # used_references is a set of reference paths that will be needed - self.used_references = set() - # results is a dictionary of reference path -> obj - self.results = {} - - def will_need_reference(self, key: Union[List, Tuple]) -> None: - """Marks a reference as needed. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. 
- """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - self.used_references.add(key) - - def notify_result(self, key: Union[List, Tuple], result: Any) -> None: - """Notifies that a result is available for a reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - result : Any - The result object. - """ - trace( - "🎯", - step(key), - "notify result", - textwrap.shorten(repr(result).replace(",", ", "), width=40), - plural(len(result), "field"), - ) - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.used_references: - if key in self.results: - raise ValueError(f"Duplicate result {key}") - self.results[key] = result - - def get_result(self, key: Union[List, Tuple]) -> Any: - """Retrieves the result for a given reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - - Returns - ------- - Any - The result for the given reference. - """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.results: - return self.results[key] - all_keys = sorted(list(self.results.keys())) - raise ValueError(f"Cannot find result {key} in {all_keys}") diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py new file mode 100644 index 000000000..23d31e6f3 --- /dev/null +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -0,0 +1,70 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from abc import ABC +from abc import abstractmethod +from typing import Any + +import rich + +LOG = logging.getLogger(__name__) + + +class Context(ABC): + """Context for building input data.""" + + def __init__(self): + self.results = {} + + def trace(self, emoji, message) -> None: + + rich.print(f"{emoji}: {message}") + + def register(self, data: Any, path: list[str]) -> Any: + """Register data in the context. + + Parameters + ---------- + data : Any + Data to register. + path : list[str] + Path where the data should be registered. + + Returns + ------- + Any + Registered data. + """ + # This is a placeholder for actual registration logic. + rich.print(f"Registering data at path: {path}") + self.results[tuple(path)] = data + return data + + def resolve(self, config): + config = config.copy() + + for key, value in list(config.items()): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + path = tuple(value[2:-1].split(".")) + if path in self.results: + config[key] = self.results[path] + else: + raise KeyError(f"Path {path} not found in results: {self.results.keys()}") + + return config + + @abstractmethod + def empty_result(self) -> Any: ... + + @abstractmethod + def source_argument(self, argument: Any) -> Any: ... + + @abstractmethod + def create_result(self, data: Any) -> Any: ... diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py new file mode 100644 index 000000000..c3a95f8cb --- /dev/null +++ b/src/anemoi/datasets/create/input/context/field.py @@ -0,0 +1,37 @@ +# (C) Copyright 2025 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from typing import Any +from typing import Dict + +from earthkit.data.core.order import build_remapping + +from . import Context + + +class FieldContext(Context): + + def __init__(self, /, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool) -> None: + super().__init__() + self.order_by = order_by + self.flatten_grid = flatten_grid + self.remapping = build_remapping(remapping) + self.use_grib_paramid = use_grib_paramid + + def empty_result(self) -> Any: + import earthkit.data as ekd + + return ekd.from_source("empty") + + def source_argument(self, argument: Any) -> Any: + return argument.dates + + def create_result(self, data): + return FieldResult(self, data) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index b95f85568..a64f9a7ef 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -20,7 +20,7 @@ from .action import Action from .action import action_factory from .misc import _tidy -from .result import Result +from .result.field import Result LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/empty.py b/src/anemoi/datasets/create/input/empty.py index 410b4c973..fdc959f0f 100644 --- a/src/anemoi/datasets/create/input/empty.py +++ b/src/anemoi/datasets/create/input/empty.py @@ -14,7 +14,7 @@ from earthkit.data import FieldList from .misc import assert_fieldlist -from .result import Result +from .result.field import Result from .trace import trace_datasource LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py index 4d3d21b22..586003bb2 100644 --- a/src/anemoi/datasets/create/input/function.py +++ b/src/anemoi/datasets/create/input/function.py @@ -18,7 +18,7 @@ from .action import Action from .misc import _tidy from .misc import assert_fieldlist -from .result import Result +from .result.field import Result from .template import notify_result from .template import substitute from .trace import trace diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py deleted file mode 100644 index ba24c7072..000000000 --- a/src/anemoi/datasets/create/input/join.py +++ /dev/null @@ -1,130 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import logging -from functools import cached_property -from typing import Any -from typing import List - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class JoinResult(Result): - """Represents a result that combines multiple results. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - """ - - def __init__( - self, context: object, action_path: list, group_of_dates: GroupOfDates, results: List[Result], **kwargs: Any - ) -> None: - """Initializes a JoinResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - """ - super().__init__(context, action_path, group_of_dates) - self.results: List[Result] = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the combined datasource from all results.""" - ds: FieldList = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - def __repr__(self) -> str: - """Returns a string representation of the JoinResult instance.""" - content: str = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class JoinAction(Action): - """Represents an action that combines multiple actions. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - actions : List[Action] - The list of actions. - """ - - def __init__(self, context: object, action_path: list, *configs: dict) -> None: - """Initializes a JoinAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - *configs : dict - The configuration dictionaries. - """ - super().__init__(context, action_path, *configs) - self.actions: List[Action] = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] - - def __repr__(self) -> str: - """Returns a string representation of the JoinAction instance.""" - content: str = "\n".join([str(i) for i in self.actions]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> JoinResult: - """Selects the results for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - JoinResult - The combined result for the given group of dates. - """ - results: List[Result] = [a.select(group_of_dates) for a in self.actions] - return JoinResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py deleted file mode 100644 index 6c9fea0df..000000000 --- a/src/anemoi/datasets/create/input/pipe.py +++ /dev/null @@ -1,66 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. 
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -import logging -from typing import Any - -from .action import Action -from .action import action_factory -from .step import step_factory -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class PipeAction(Action): - """A class to represent a pipeline of actions.""" - - def __init__(self, context: Any, action_path: list, *configs: dict) -> None: - """Initialize the PipeAction. - - Parameters - ---------- - context : Any - The context for the action. - action_path : list - The path of the action. - configs : dict - The configurations for the actions. - """ - super().__init__(context, action_path, *configs) - if len(configs) <= 1: - raise ValueError( - f"PipeAction requires at least two actions, got {len(configs)}\n{json.dumps(configs, indent=2)}" - ) - - current: Any = action_factory(configs[0], context, action_path + ["0"]) - for i, c in enumerate(configs[1:]): - current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current) - self.last_step: Any = current - - @trace_select - def select(self, group_of_dates: Any) -> Any: - """Select data based on the group of dates. - - Parameters - ---------- - group_of_dates : Any - The group of dates to select data for. - - Returns - ------- - Any - The selected data. - """ - return self.last_step.select(group_of_dates) - - def __repr__(self) -> str: - """Return a string representation of the PipeAction.""" - return f"PipeAction({self.last_step})" diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index ebf13b36e..37cc2dcf0 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -27,7 +27,7 @@ from .action import Action from .action import action_factory from .join import JoinResult -from .result import Result +from .result.field import Result from .trace import trace_select LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/result/__init__.py b/src/anemoi/datasets/create/input/result/__init__.py new file mode 100644 index 000000000..04c2ab733 --- /dev/null +++ b/src/anemoi/datasets/create/input/result/__init__.py @@ -0,0 +1,21 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
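With pipe.py removed, the pipelining behaviour moves into the much smaller Pipe action added later in this series (see input/action.py in patch 063 below): the first entry of a `pipe` receives the dates argument and every following entry receives the previous entry's result. A rough sketch of that fold; the recipe fragment and helper are illustrative, the real Pipe builds its sub-actions with action_factory and registers the final result on the context:

pipe_config = [
    {"mars": {"param": ["2t"], "levtype": "sfc"}},  # a source: seeded with the dates
    {"rename": {"2t": "t2m"}},                      # a filter: receives the mars result
]

def run_pipe(actions, context, argument):
    result = None
    for action in actions:
        # The first step consumes the dates, later steps consume the previous result.
        result = action(context, argument) if result is None else action(context, result)
    return result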
+ +import logging +from abc import ABC +from abc import abstractmethod +from typing import Any + +import rich + +LOG = logging.getLogger(__name__) + + +class Result(ABC): + pass diff --git a/src/anemoi/datasets/create/input/result.py b/src/anemoi/datasets/create/input/result/field.py similarity index 99% rename from src/anemoi/datasets/create/input/result.py rename to src/anemoi/datasets/create/input/result/field.py index de5388fd6..cd334efd3 100644 --- a/src/anemoi/datasets/create/input/result.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -26,9 +26,9 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from .action import ActionContext -from .trace import trace -from .trace import trace_datasource +from ..action import ActionContext +from ..trace import trace +from ..trace import trace_datasource LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py index e99717094..c88da3e71 100644 --- a/src/anemoi/datasets/create/input/step.py +++ b/src/anemoi/datasets/create/input/step.py @@ -19,7 +19,7 @@ from .action import Action from .action import ActionContext from .context import Context -from .result import Result +from .result.field import Result from .template import notify_result from .trace import trace_datasource from .trace import trace_select From 1df0ef758e3f3d0d9ad0fe57c2c0494c34012401 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 10:58:50 +0000 Subject: [PATCH 061/212] refactor --- src/anemoi/datasets/create/input/context/field.py | 1 + src/anemoi/datasets/create/input/result/field.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index c3a95f8cb..aee08c2c3 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -14,6 +14,7 @@ from earthkit.data.core.order import build_remapping from . 
import Context +from ..result.field import FieldResult class FieldContext(Context): diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index cd334efd3..66e29cec3 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -282,7 +282,7 @@ def sort(old_dic: DefaultDict[str, set]) -> Dict[str, List[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class Result: +class FieldResult: """Class to represent the result of an action in the dataset creation process.""" empty: bool = False From 18df4ebe5fa4d39aef59820d70fbdd23b3d1fbee Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 11:10:07 +0000 Subject: [PATCH 062/212] refactor --- src/anemoi/datasets/create/__init__.py | 2 +- src/anemoi/datasets/create/input/__init__.py | 8 +- .../datasets/create/input/context/__init__.py | 18 +--- .../datasets/create/input/context/field.py | 8 +- .../datasets/create/input/result/field.py | 98 ++----------------- 5 files changed, 21 insertions(+), 113 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 9c8c4613a..635433e47 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -870,7 +870,7 @@ def _run(self) -> None: # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(group_of_dates=group) + result = self.input.select(argument=group) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 19269f587..ed39c9211 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -43,12 +43,12 @@ def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) self.config = config self.action_path = ["input"] - def select(self, group_of_dates: GroupOfDates) -> Any: + def select(self, argument: GroupOfDates) -> Any: """Select data based on the group of dates. Parameters ---------- - group_of_dates : GroupOfDates + argument : GroupOfDates Group of dates to select data for. Returns @@ -59,9 +59,9 @@ def select(self, group_of_dates: GroupOfDates) -> Any: from .action import action_factory """This changes the context.""" - context = FieldContext(**self.kwargs) + context = FieldContext(argument, **self.kwargs) action = action_factory(self.config, "input") - return context.create_result(action(context, group_of_dates)) + return context.create_result(action(context, argument)) def _trace_select(self, group_of_dates: GroupOfDates) -> str: """Trace the select operation. diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 23d31e6f3..11aa1570c 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -20,29 +20,15 @@ class Context(ABC): """Context for building input data.""" - def __init__(self): + def __init__(self, /, argument: Any) -> None: self.results = {} + self.argument = argument def trace(self, emoji, message) -> None: rich.print(f"{emoji}: {message}") def register(self, data: Any, path: list[str]) -> Any: - """Register data in the context. 
- - Parameters - ---------- - data : Any - Data to register. - path : list[str] - Path where the data should be registered. - - Returns - ------- - Any - Registered data. - """ - # This is a placeholder for actual registration logic. rich.print(f"Registering data at path: {path}") self.results[tuple(path)] = data return data diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index aee08c2c3..4586e1d31 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -13,14 +13,16 @@ from earthkit.data.core.order import build_remapping -from . import Context from ..result.field import FieldResult +from . import Context class FieldContext(Context): - def __init__(self, /, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool) -> None: - super().__init__() + def __init__( + self, /, argument: Any, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool + ) -> None: + super().__init__(argument) self.order_by = order_by self.flatten_grid = flatten_grid self.remapping = build_remapping(remapping) diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index 66e29cec3..dd238998e 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -26,9 +26,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from ..action import ActionContext -from ..trace import trace -from ..trace import trace_datasource +from . import Result LOG = logging.getLogger(__name__) @@ -282,40 +280,22 @@ def sort(old_dic: DefaultDict[str, set]) -> Dict[str, List[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class FieldResult: +class FieldResult(Result): """Class to represent the result of an action in the dataset creation process.""" empty: bool = False _coords_already_built: bool = False - def __init__(self, context: ActionContext, action_path: List[str], dates: Any) -> None: - """Initialize a Result instance. + def __init__(self, context: Any, datasource: Any) -> None: - Parameters - ---------- - context : ActionContext - The context in which the result exists. - action_path : list of str - The action path. - dates : Any - The dates associated with the result. - """ from anemoi.datasets.dates.groups import GroupOfDates - assert isinstance(dates, GroupOfDates), dates - - assert isinstance(context, ActionContext), type(context) - assert isinstance(action_path, list), action_path - self.context: Any = context - self.group_of_dates: Any = dates - self.action_path: List[str] = action_path - - @property - @trace_datasource - def datasource(self) -> Any: - """Retrieve the data source for the result.""" - self._raise_not_implemented() + self.datasource = datasource + self.group_of_dates = context.argument + assert isinstance( + self.group_of_dates, GroupOfDates + ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" @property def data_request(self) -> Dict[str, Any]: @@ -330,7 +310,7 @@ def get_cube(self) -> Any: Any The data cube. 
""" - trace("🧊", f"getting cube from {self.__class__.__name__}") + ds: Any = self.datasource remapping: Any = self.context.remapping @@ -523,66 +503,6 @@ def explain(self, ds: Any, *args: Any, remapping: Any, patches: Any) -> None: print() exit(1) - def _repr(self, *args: Any, _indent_: str = "\n", **kwargs: Any) -> str: - """Return the string representation of the Result instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str - Indentation string. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more: str = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - dates: str = " no-dates" - if self.group_of_dates is not None: - dates = f" {len(self.group_of_dates)} dates" - dates += " (" - dates += "/".join(d.strftime("%Y-%m-%dT%H:%M") for d in self.group_of_dates) - if len(dates) > 100: - dates = dates[:100] + "..." - dates += ")" - - more = more[:5000] - txt: str = f"{self.__class__.__name__}:{dates}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Result instance.""" - return self._repr() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Trace the data source for the result. - - Parameters - ---------- - args : Any - Additional positional arguments. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({self.group_of_dates})" - def build_coords(self) -> None: """Build the coordinates for the result.""" if self._coords_already_built: From 3341d4cf1cdf3a4c91de8f0fc8371296379cb888 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 9 Jul 2025 15:20:22 +0000 Subject: [PATCH 063/212] update --- src/anemoi/datasets/create/input/__init__.py | 19 +- src/anemoi/datasets/create/input/action.py | 171 ++++++---- .../datasets/create/input/context/__init__.py | 15 +- .../datasets/create/input/context/field.py | 18 +- src/anemoi/datasets/create/input/empty.py | 54 --- src/anemoi/datasets/create/input/filter.py | 118 ------- src/anemoi/datasets/create/input/function.py | 233 ------------- src/anemoi/datasets/create/input/step.py | 191 ----------- src/anemoi/datasets/create/input/template.py | 162 --------- .../datasets/create/sources/__init__.py | 11 - .../datasets/create/sources/constants.py | 2 +- .../datasets/create/sources/forcings.py | 2 +- src/anemoi/datasets/create/sources/legacy.py | 5 +- .../datasets/create/sources/repeated_dates.py | 319 ++++++++++++++++++ 14 files changed, 455 insertions(+), 865 deletions(-) delete mode 100644 src/anemoi/datasets/create/input/empty.py delete mode 100644 src/anemoi/datasets/create/input/filter.py delete mode 100644 src/anemoi/datasets/create/input/function.py delete mode 100644 src/anemoi/datasets/create/input/step.py delete mode 100644 src/anemoi/datasets/create/input/template.py create mode 100644 src/anemoi/datasets/create/sources/repeated_dates.py diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index ed39c9211..cde4f6eaa 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py 
@@ -12,7 +12,6 @@ from typing import Union from anemoi.datasets.create.input.context.field import FieldContext -from anemoi.datasets.dates.groups import GroupOfDates class InputBuilder: @@ -41,9 +40,8 @@ def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) ) ) self.config = config - self.action_path = ["input"] - def select(self, argument: GroupOfDates) -> Any: + def select(self, argument) -> Any: """Select data based on the group of dates. Parameters @@ -62,18 +60,3 @@ def select(self, argument: GroupOfDates) -> Any: context = FieldContext(argument, **self.kwargs) action = action_factory(self.config, "input") return context.create_result(action(context, argument)) - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the select operation. - - Parameters - ---------- - group_of_dates : GroupOfDates - Group of dates to select data for. - - Returns - ------- - str - Trace string. - """ - return f"InputBuilder({group_of_dates})" diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 76d99f4b7..28e521b97 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -11,20 +11,9 @@ import rich -LOG = logging.getLogger(__name__) - - -class Predicate: - def __init__(self, config): - self.config = config - - def __repr__(self): - return f"Predicate({self.config})" +from anemoi.datasets.dates import DatesProvider - def match(self, dates): - # Just a demo - raise NotImplementedError("Not yet implemented") - return True +LOG = logging.getLogger(__name__) class Action: @@ -44,24 +33,26 @@ def __init__(self, config, *path): for item in config: assert "dates" in item, f"Value must contain the key 'date' {item}" - predicate = Predicate(item.pop("dates")) + dates = item.pop("dates") + filtering_dates = DatesProvider.from_config(**dates) action = action_factory(item) - self.choices.append((predicate, action)) + self.choices.append((filtering_dates, action)) def __repr__(self): return f"Concat({self.choices})" def __call__(self, context, argument): - for predicate, action in self.choices: - if predicate.match(argument): - return context.register( - action(context, argument), - self.path, - ) + results = context.empty_result() - raise ValueError(f"No matching predicate for dates: {argument}") + for filtering_dates, action in self.choices: + dates = context.matching_dates(filtering_dates, argument) + if len(dates) == 0: + continue + results += action(context, dates) + + return context.register(results, self.path) class Join(Action): @@ -70,42 +61,25 @@ def __init__(self, config, *path): assert isinstance(config, list), f"Value must be a list {config}" - self.actions = [ - action_factory( - item, - *path, - "join", - str(i), - ) - for i, item in enumerate(config) - ] + self.actions = [action_factory(item, *path, "join", str(i)) for i, item in enumerate(config)] def __repr__(self): return f"Join({self.actions})" def __call__(self, context, argument): results = context.empty_result() + for action in self.actions: results += action(context, argument) - return context.register( - results, - self.path, - ) + + return context.register(results, self.path) class Pipe(Action): def __init__(self, config, *path): assert isinstance(config, list), f"Value must be a list {config}" super().__init__(config, *path) - self.actions = [ - action_factory( - item, - *path, - "pipe", - str(i), - ) - for i, item in enumerate(config) - ] + self.actions = [action_factory(item, *path, 
"pipe", str(i)) for i, item in enumerate(config)] def __repr__(self): return f"Pipe({self.actions})" @@ -119,44 +93,88 @@ def __call__(self, context, argument): else: result = action(context, result) - return context.register( - result, - self.path, - ) + return context.register(result, self.path) class Function(Action): def __init__(self, config, *path): super().__init__(config, *path, self.name) - -class SourceFunction(Function): - def __call__(self, context, argument): - from anemoi.datasets.create.sources import create_source config = context.resolve(self.config) # Substitute the ${} variables in the config + config["_type"] = self.name # Find a better way to do this - source = create_source(self, config) + + source = self.create_object(config) rich.print(f"Executing source {self.name} from {config}") - return context.register( - source.execute(context, context.source_argument(argument)), - self.path, - ) + return context.register(self.call_object(context, source, argument), self.path) + + +class DatasetSourceMixin: + def create_object(self, config): + from anemoi.datasets.create.sources import create_source as create_datasets_source + + return create_datasets_source(self, config) + + def call_object(self, context, source, argument): + return source.execute(context, context.source_argument(argument)) + + +class DatasetFilterMixin: + def create_object(self, config): + from anemoi.datasets.create.filters import create_filter as create_datasets_filter + + return create_datasets_filter(self, config) + + def call_object(self, context, filter, argument): + return filter.execute(context.filter_argument(argument)) + + +class TransformSourceMixin: + def create_object(self, config): + from anemoi.transform.sources import create_source as create_transform_source + + return create_transform_source(self, config) + + +class TransformFilterMixin: + def create_object(self, config): + from anemoi.transform.filters import create_filter as create_transform_filter + + return create_transform_filter(self, config) + + def call_object(self, context, filter, argument): + return filter.forward(context.filter_argument(argument)) class FilterFunction(Function): - pass + def __call__(self, context, argument): + return self.call(context, argument, context.filter_argument) + +def _make_name(name, what): + name = name.replace("_", "-") + name = "".join(x.title() for x in name.split("-")) + return name + what.title() -def new_source(name): - return type(name.title(), (SourceFunction,), {"name": name}) +def new_source(name, mixin): + return type( + _make_name(name, "source"), + (Function, mixin), + {"name": name}, + ) -def new_filter(name): - return type(name.title(), (FilterFunction,), {"name": name}) + +def new_filter(name, mixin): + return type( + _make_name(name, "filter"), + (Function, mixin), + {"name": name}, + ) KLASS = {"concat": Concat, "join": Join, "pipe": Pipe} @@ -167,14 +185,33 @@ def new_filter(name): def make(key, config, path): if LEN_KLASS == len(KLASS): + # Load pluggins - from anemoi.datasets.create.sources import registered_sources + from anemoi.transform.filters import filter_registry as transform_filter_registry + from anemoi.transform.sources import source_registry as transform_source_registry + + from anemoi.datasets.create.filters import filter_registry as dataset_filter_registry + from anemoi.datasets.create.sources import source_registry as dataset_source_registry + + # Register sources, local first + for name in dataset_source_registry.registered: + if name not in KLASS: + 
KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) + + for name in transform_source_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) + + # Register filters, local first + for name in dataset_filter_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_filter(name, DatasetFilterMixin) - for name in registered_sources(): - assert name not in KLASS, f"Duplicate source name: {name}" - KLASS[name] = new_source(name) + for name in transform_filter_registry.registered: + if name not in KLASS: + KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) - return KLASS[key](config, *path) + return KLASS[key.replace("_", "-")](config, *path) def action_factory(data, *path): diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 11aa1570c..26d449659 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -22,13 +22,18 @@ class Context(ABC): def __init__(self, /, argument: Any) -> None: self.results = {} + self.cache = {} self.argument = argument - def trace(self, emoji, message) -> None: + def trace(self, emoji, *message) -> None: rich.print(f"{emoji}: {message}") def register(self, data: Any, path: list[str]) -> Any: + + if not path: + return data + rich.print(f"Registering data at path: {path}") self.results[tuple(path)] = data return data @@ -46,11 +51,13 @@ def resolve(self, config): return config - @abstractmethod - def empty_result(self) -> Any: ... + def create_source(self, config: Any) -> Any: + from anemoi.datasets.create.input.action import action_factory + + return action_factory(config) @abstractmethod - def source_argument(self, argument: Any) -> Any: ... + def empty_result(self) -> Any: ... @abstractmethod def create_result(self, data: Any) -> Any: ... diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index 4586e1d31..c3456d89f 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -20,7 +20,13 @@ class FieldContext(Context): def __init__( - self, /, argument: Any, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool + self, + /, + argument: Any, + order_by: str, + flatten_grid: bool, + remapping: Dict[str, Any], + use_grib_paramid: bool, ) -> None: super().__init__(argument) self.order_by = order_by @@ -34,7 +40,15 @@ def empty_result(self) -> Any: return ekd.from_source("empty") def source_argument(self, argument: Any) -> Any: - return argument.dates + return argument # .dates + + def filter_argument(self, argument: Any) -> Any: + return argument def create_result(self, data): return FieldResult(self, data) + + def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: + from anemoi.datasets.dates.groups import GroupOfDates + + return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) diff --git a/src/anemoi/datasets/create/input/empty.py b/src/anemoi/datasets/create/input/empty.py deleted file mode 100644 index fdc959f0f..000000000 --- a/src/anemoi/datasets/create/input/empty.py +++ /dev/null @@ -1,54 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
-# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import List - -from earthkit.data import FieldList - -from .misc import assert_fieldlist -from .result.field import Result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class EmptyResult(Result): - """Class to represent an empty result in the dataset creation process.""" - - empty = True - - def __init__(self, context: object, action_path: list, dates: object) -> None: - """Initializes an EmptyResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - dates : object - The dates object. - """ - super().__init__(context, action_path + ["empty"], dates) - - @cached_property - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns an empty datasource.""" - from earthkit.data import from_source - - return from_source("empty") - - @property - def variables(self) -> List[str]: - """Returns an empty list of variables.""" - return [] diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py deleted file mode 100644 index 289bb3602..000000000 --- a/src/anemoi/datasets/create/input/filter.py +++ /dev/null @@ -1,118 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import Type - -from earthkit.data import FieldList - -from .function import FunctionContext -from .misc import _tidy -from .misc import assert_fieldlist -from .step import StepAction -from .step import StepResult -from .template import notify_result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class FilterStepResult(StepResult): - @property - @notify_result - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns the filtered datasource.""" - ds: FieldList = self.upstream_result.datasource - ds = ds.sel(**self.action.kwargs) - return _tidy(ds) - - -class FilterStepAction(StepAction): - """Represents an action to filter a step result.""" - - result_class: Type[FilterStepResult] = FilterStepResult - - -class StepFunctionResult(StepResult): - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource after applying the function.""" - - self.action.filter.context = FunctionContext(self) - try: - return _tidy( - self.action.filter.execute( - self.upstream_result.datasource, - *self.action.args[1:], - **self.action.kwargs, - ) - ) - - except Exception: - LOG.error(f"Error in {self.action.name}", exc_info=True) - raise - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. - - Returns - ------- - str - A string representation of the traced datasource. 
- """ - return f"{self.action.name}({self.group_of_dates})" - - -class FunctionStepAction(StepAction): - """Represents an action to apply a function to a step result.""" - - result_class: Type[StepFunctionResult] = StepFunctionResult - - def __init__( - self, - context: object, - action_path: list, - previous_step: StepAction, - name: str, - filter: Any, - *args: Any, - **kwargs: Any, - ) -> None: - """Initializes a FunctionStepAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - previous_step : StepAction - The previous step action. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. - """ - super().__init__(context, action_path, previous_step, *args, **kwargs) - self.name = name - self.filter = filter diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py deleted file mode 100644 index 586003bb2..000000000 --- a/src/anemoi/datasets/create/input/function.py +++ /dev/null @@ -1,233 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import Dict - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .misc import _tidy -from .misc import assert_fieldlist -from .result.field import Result -from .template import notify_result -from .template import substitute -from .trace import trace -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class FunctionContext: - """A FunctionContext is passed to all functions, it will be used to pass information - to the functions from the other actions and filters and results. - """ - - def __init__(self, owner: Result) -> None: - """Initializes a FunctionContext instance. - - Parameters - ---------- - owner : object - The owner object. - """ - self.owner = owner - self.use_grib_paramid: bool = owner.context.use_grib_paramid - - def trace(self, emoji: str, *args: Any) -> None: - """Traces the given arguments with an emoji. - - Parameters - ---------- - emoji : str - The emoji to use. - *args : Any - The arguments to trace. - """ - trace(emoji, *args) - - def info(self, *args: Any, **kwargs: Any) -> None: - """Logs an info message. - - Parameters - ---------- - *args : Any - The arguments for the log message. - **kwargs : Any - The keyword arguments for the log message. - """ - LOG.info(*args, **kwargs) - - @property - def dates_provider(self) -> object: - """Returns the dates provider.""" - return self.owner.group_of_dates.provider - - @property - def partial_ok(self) -> bool: - """Returns whether partial results are acceptable.""" - return self.owner.group_of_dates.partial_ok - - def get_result(self, *args, **kwargs) -> Any: - return self.owner.context.get_result(*args, **kwargs) - - -class FunctionAction(Action): - """Represents an action that executes a function. - - Attributes - ---------- - name : str - The name of the function. 
- """ - - def __init__(self, context: object, action_path: list, _name: str, source, **kwargs: Dict[str, Any]) -> None: - """Initializes a FunctionAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - _name : str - The name of the function. - **kwargs : Dict[str, Any] - Additional keyword arguments. - """ - super().__init__(context, action_path, **kwargs) - self.name: str = _name - self.source = source - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": - """Selects the function result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - FunctionResult - The function result instance. - """ - return FunctionResult(self.context, self.action_path, group_of_dates, action=self) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionAction instance.""" - content: str = "" - content += ",".join([self._short_str(a) for a in self.args]) - content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]) - content = self._short_str(content) - return self._repr(_inline_=content, _indent_=" ") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Traces the selection of the function for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - str - The trace string. - """ - return f"{self.name}({group_of_dates})" - - -class FunctionResult(Result): - """Represents the result of executing a function. - - Attributes - ---------- - action : Action - The action instance. - args : tuple - The positional arguments for the function. - kwargs : dict - The keyword arguments for the function. - """ - - def __init__(self, context: object, action_path: list, group_of_dates: GroupOfDates, action: Action) -> None: - """Initializes a FunctionResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - action : Action - The action instance. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(action, Action), type(action) - self.action: Action = action - - self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs)) - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.action.name}({self.group_of_dates})" - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource for the function result.""" - # args, kwargs = resolve(self.context, (self.args, self.kwargs)) - self.action.source.context = FunctionContext(self) - - return _tidy( - self.action.source.execute( - list(self.group_of_dates), # Will provide a list of datetime objects - ) - ) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionResult instance.""" - try: - return f"{self.action.name}({self.group_of_dates})" - except Exception: - return f"{self.__class__.__name__}(unitialised)" - - @property - def function(self) -> None: - """Raises NotImplementedError as this property is not implemented. 
- - Raises - ------ - NotImplementedError - Always raised. - """ - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py deleted file mode 100644 index c88da3e71..000000000 --- a/src/anemoi/datasets/create/input/step.py +++ /dev/null @@ -1,191 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import warnings -from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Type - -from .action import Action -from .action import ActionContext -from .context import Context -from .result.field import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class StepResult(Result): - """Represents the result of a step in the data processing pipeline.""" - - def __init__( - self, context: Context, action_path: List[str], group_of_dates: Any, action: Action, upstream_result: Result - ) -> None: - """Initialize a StepResult instance. - - Parameters - ---------- - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - group_of_dates - The group of dates associated with this step. - action - The action associated with this step. - upstream_result - The result of the upstream step. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(upstream_result, Result), type(upstream_result) - self.upstream_result: Result = upstream_result - self.action: Action = action - - @property - @notify_result - @trace_datasource - def datasource(self) -> Any: - """Retrieve the datasource associated with this step result.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - -class StepAction(Action): - """Represents an action that is part of a step in the data processing pipeline.""" - - result_class: Optional[Type[StepResult]] = None - - def __init__( - self, context: ActionContext, action_path: List[str], previous_step: Any, *args: Any, **kwargs: Any - ) -> None: - """Initialize a StepAction instance. - - Parameters - ---------- - context - The context in which the action is executed. - action_path - The path of actions leading to this step. - previous_step - The previous step in the pipeline. - """ - super().__init__(context, action_path, *args, **kwargs) - self.previous_step: Any = previous_step - - @trace_select - def select(self, group_of_dates: Any) -> StepResult: - """Select the result for a given group of dates. - - Parameters - ---------- - group_of_dates - The group of dates to select the result for. - - Returns - ------- - unknown - The result of the step. - """ - return self.result_class( - self.context, - self.action_path, - group_of_dates, - self, - self.previous_step.select(group_of_dates), - ) - - def __repr__(self) -> str: - """Return a string representation of the StepAction instance. - - Returns - ------- - unknown - String representation of the instance. 
- """ - return self._repr(self.previous_step, _inline_=str(self.kwargs)) - - -def step_factory(config: Dict[str, Any], context: ActionContext, action_path: List[str], previous_step: Any) -> Any: - """Factory function to create a step action based on the given configuration. - - Parameters - ---------- - config - The configuration dictionary for the step. - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - previous_step - The previous step in the pipeline. - - Returns - ------- - unknown - An instance of a step action. - """ - - from .filter import FilterStepAction - from .filter import FunctionStepAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - - config = deepcopy(config) - assert len(config) == 1, config - - key = list(config.keys())[0] - cls = dict( - filter=FilterStepAction, - # rename=RenameAction, - # remapping=RemappingAction, - ).get(key) - - if isinstance(config[key], list): - args, kwargs = config[key], {} - - if isinstance(config[key], dict): - args, kwargs = [], config[key] - - if isinstance(config[key], str): - args, kwargs = [config[key]], {} - - if cls is not None: - return cls(context, action_path, previous_step, *args, **kwargs) - - # Try filters from datasets filter registry - from anemoi.transform.filters import filter_registry as transform_filter_registry - - from ..filters import create_filter as create_datasets_filter - from ..filters import filter_registry as datasets_filter_registry - - if datasets_filter_registry.is_registered(key): - - if transform_filter_registry.is_registered(key): - warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries") - - filter = create_datasets_filter(None, config) - return FunctionStepAction(context, action_path + [key], previous_step, key, filter) - - # Use filters from transform registry - - if transform_filter_registry.is_registered(key): - from ..filters.transform import TransformFilter - - return FunctionStepAction( - context, action_path + [key], previous_step, key, TransformFilter(context, key, config) - ) - - raise ValueError(f"Unknown step action `{key}`") diff --git a/src/anemoi/datasets/create/input/template.py b/src/anemoi/datasets/create/input/template.py deleted file mode 100644 index 8ea1ec275..000000000 --- a/src/anemoi/datasets/create/input/template.py +++ /dev/null @@ -1,162 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -import re -from abc import ABC -from abc import abstractmethod -from functools import wraps -from typing import Any -from typing import Callable -from typing import List - -from .context import Context - -LOG = logging.getLogger(__name__) - - -def notify_result(method: Callable[..., Any]) -> Callable[..., Any]: - """Decorator to notify the context of the result of the method call. - - Parameters - ---------- - method : Callable[..., Any] - The method to wrap. - - Returns - ------- - Callable[..., Any] - The wrapped method. 
- """ - - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - result: Any = method(self, *args, **kwargs) - self.context.notify_result(self.action_path, result) - return result - - return wrapper - - -class Substitution(ABC): - """Abstract base class for substitutions in templates.""" - - @abstractmethod - def resolve(self, context: Context) -> Any: - """Resolve the substitution using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - pass - - -class Reference(Substitution): - """A class to represent a reference to another value in the context.""" - - def __init__(self, context: Any, action_path: List[str]) -> None: - """Initialize a Reference instance. - - Parameters - ---------- - context : Any - The context in which the reference exists. - action_path : list of str - The action path to resolve. - """ - self.context: Any = context - self.action_path: List[str] = action_path - - def resolve(self, context: Context) -> Any: - """Resolve the reference using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - return context.get_result(self.action_path) - - -def resolve(context: Context, x: Any) -> Any: - """Recursively resolve substitutions in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for resolution. - x : Union[tuple, list, dict, Substitution, Any] - The structure to resolve. - - Returns - ------- - Any - The resolved structure. - """ - if isinstance(x, tuple): - return tuple([resolve(context, y) for y in x]) - - if isinstance(x, list): - return [resolve(context, y) for y in x] - - if isinstance(x, dict): - return {k: resolve(context, v) for k, v in x.items()} - - if isinstance(x, Substitution): - return x.resolve(context) - - return x - - -def substitute(context: Context, x: Any) -> Any: - """Recursively substitute references in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for substitution. - x : Union[tuple, list, dict, str, Any] - The structure to substitute. - - Returns - ------- - Any - The substituted structure. - """ - if isinstance(x, tuple): - return tuple([substitute(context, y) for y in x]) - - if isinstance(x, list): - return [substitute(context, y) for y in x] - - if isinstance(x, dict): - return {k: substitute(context, v) for k, v in x.items()} - - if not isinstance(x, str): - return x - - if re.match(r"^\${[\.\w\-]+}$", x): - path = x[2:-1].split(".") - context.will_need_reference(path) - return Reference(context, path) - - return x diff --git a/src/anemoi/datasets/create/sources/__init__.py b/src/anemoi/datasets/create/sources/__init__.py index 1710d1303..f8b99f36d 100644 --- a/src/anemoi/datasets/create/sources/__init__.py +++ b/src/anemoi/datasets/create/sources/__init__.py @@ -34,14 +34,3 @@ def create_source(context: Any, config: Any) -> Any: The created source. """ return source_registry.from_config(config, context) - - -def registered_sources() -> list[str]: - """Get a list of registered source names. - - Returns - ------- - list[str] - A list of names of registered sources. 
- """ - return source_registry.registered diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index 921469025..1958820c4 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -47,7 +47,7 @@ def constants(context: Any, dates: List[str], template: Dict[str, Any], param: s if len(template) == 0: raise ValueError("Forcings template is empty.") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute: Any = constants diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index e1944e151..8e2977273 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -36,7 +36,7 @@ def forcings(context: Any, dates: List[str], template: str, param: str) -> Any: Loaded forcing data. """ context.trace("✅", f"from_source(forcings, {template}, {param}") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute = forcings diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index f8035ee16..c76a11c9b 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -14,8 +14,6 @@ from typing import Any from typing import Callable -from anemoi.datasets.create.input.template import resolve - from ..source import Source from . import source_registry @@ -74,7 +72,8 @@ def __call__(self, execute: Callable) -> Callable: def execute_wrapper(self, context, dates) -> Any: """Wrapper method to call the execute function.""" - args, kwargs = resolve(context, (self.args, self.kwargs)) + # args, kwargs = resolve(context, (self.args, self.kwargs)) + args, kwargs = self.args, self.kwargs try: return execute(context, dates, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py new file mode 100644 index 000000000..d092f08ad --- /dev/null +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -0,0 +1,319 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from collections import defaultdict +from typing import Any +from typing import Dict +from typing import Generator +from typing import Optional +from typing import Set +from typing import Tuple + +import numpy as np +import rich +from anemoi.transform.fields import new_field_with_valid_datetime +from anemoi.transform.fields import new_fieldlist_from_list +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry + +# (C) Copyright 2024 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +LOG = logging.getLogger(__name__) + + +class Action: + pass + + +class Result: + pass + + +class DateMapper: + """A factory class to create DateMapper instances based on the given mode.""" + + @staticmethod + def from_mode(mode: str, source: Any, config: Dict[str, Any]) -> "DateMapper": + """Create a DateMapper instance based on the given mode. + + Parameters + ---------- + mode : str + The mode to use for the DateMapper. + source : Any + The data source. + config : dict + Configuration parameters. + + Returns + ------- + DateMapper + An instance of DateMapper. + """ + MODES: dict = dict( + closest=DateMapperClosest, + climatology=DateMapperClimatology, + constant=DateMapperConstant, + ) + + if mode not in MODES: + raise ValueError(f"Invalid mode for DateMapper: {mode}") + + return MODES[mode](source, **config) + + +class DateMapperClosest(DateMapper): + """A DateMapper implementation that maps dates to the closest available dates.""" + + def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", skip_all_nans: bool = False) -> None: + """Initialize DateMapperClosest. + + Parameters + ---------- + source : Any + The data source. + frequency : str + Frequency of the dates. + maximum : str + Maximum time delta. + skip_all_nans : bool + Whether to skip all NaN values. + """ + self.source: Any = source + self.maximum: Any = frequency_to_timedelta(maximum) + self.frequency: Any = frequency_to_timedelta(frequency) + self.skip_all_nans: bool = skip_all_nans + self.tried: Set[Any] = set() + self.found: Set[Any] = set() + + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + """Transform the group of dates to the closest available dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. 
+ """ + from anemoi.datasets.dates.groups import GroupOfDates + + asked_dates = list(group_of_dates) + if not asked_dates: + return [] + + to_try = set() + for date in asked_dates: + start = date + while start >= date - self.maximum: + to_try.add(start) + start -= self.frequency + + end = date + while end <= date + self.maximum: + to_try.add(end) + end += self.frequency + + to_try = sorted(to_try - self.tried) + info = {k: "no-data" for k in to_try} + + if not to_try: + LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}") + # return [] + + if to_try: + result = self.source.select( + GroupOfDates( + sorted(to_try), + group_of_dates.provider, + partial_ok=True, + ) + ) + + cnt = 0 + for f in result.datasource: + cnt += 1 + # We could keep the fields in a dictionary, but we don't want to keep the fields in memory + date = as_datetime(f.metadata("valid_datetime")) + + if self.skip_all_nans: + if np.isnan(f.to_numpy()).all(): + LOG.warning(f"Skipping {date} because all values are NaN") + info[date] = "all-nans" + continue + + info[date] = "ok" + self.found.add(date) + + if cnt == 0: + raise ValueError(f"No data found for {group_of_dates} in {self.source}") + + self.tried.update(to_try) + + if not self.found: + for k, v in info.items(): + LOG.warning(f"{k}: {v}") + + raise ValueError(f"No matching data found for {asked_dates} in {self.source}") + + new_dates = defaultdict(list) + + for date in asked_dates: + best = None + for found_date in sorted(self.found): + delta = abs(date - found_date) + # With < we prefer the first date + # With <= we prefer the last date + if best is None or delta <= best[0]: + best = delta, found_date + new_dates[best[1]].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperClimatology(DateMapper): + """A DateMapper implementation that maps dates to specified climatology dates.""" + + def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) -> None: + """Initialize DateMapperClimatology. + + Parameters + ---------- + source : Any + The data source. + year : int + The year to map to. + day : int + The day to map to. + hour : Optional[int] + The hour to map to. + """ + self.year: int = year + self.day: int = day + self.hour: Optional[int] = hour + + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + """Transform the group of dates to the specified climatology dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + dates = list(group_of_dates) + if not dates: + return [] + + new_dates = defaultdict(list) + for date in dates: + new_date = date.replace(year=self.year, day=self.day) + if self.hour is not None: + new_date = new_date.replace(hour=self.hour, minute=0, second=0) + new_dates[new_date].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperConstant(DateMapper): + """A DateMapper implementation that maps dates to a constant date.""" + + def __init__(self, source: Any, date: Optional[Any] = None) -> None: + """Initialize DateMapperConstant. + + Parameters + ---------- + source : Any + The data source. 
+ date : Optional[Any] + The constant date to map to. + """ + self.source: Any = source + self.date: Optional[Any] = date + + def transform(self, group_of_dates: Any) -> Tuple[Any, Any]: + """Transform the group of dates to a constant date. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Tuple[Any, Any] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + if self.date is None: + return [ + ( + GroupOfDates([], group_of_dates.provider), + group_of_dates, + ) + ] + + return [ + ( + GroupOfDates([self.date], group_of_dates.provider), + group_of_dates, + ) + ] + + +@source_registry.register("repeated_dates") +class RepeatedDatesSource(Source): + + def __init__(self, owner, source: Any, mode: str, **kwargs) -> None: + self.mapper = DateMapper.from_mode(mode, source, kwargs) + self.source = source + + def execute(self, context, group_of_dates): + source = context.create_source(self.source) + + result = [] + for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): + rich.print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") + source_results = source(context, one_date_group) + for field in source_results: + for date in many_dates_group: + result.append(new_field_with_valid_datetime(field, date)) + + return new_fieldlist_from_list(result) From 58dc8a2652482fe0fc72c89e48c71929fd709f16 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 10 Jul 2025 12:44:57 +0000 Subject: [PATCH 064/212] work on migrate --- src/anemoi/datasets/commands/migrate.py | 602 ++++++++++++++--------- src/anemoi/datasets/commands/validate.py | 44 ++ src/anemoi/datasets/create/__init__.py | 41 ++ src/anemoi/datasets/schemas/recipe.json | 131 +++++ 4 files changed, 589 insertions(+), 229 deletions(-) create mode 100644 src/anemoi/datasets/commands/validate.py create mode 100644 src/anemoi/datasets/schemas/recipe.json diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index 3183508b7..8d4b359ac 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -8,326 +8,443 @@ # nor does it submit to any jurisdiction. +import datetime import logging -from copy import deepcopy +import os +from collections.abc import Mapping +from collections.abc import Sequence from typing import Any +import rich import yaml +from glom import assign +from glom import delete +from glom import glom + +from anemoi.datasets.create import validate_config from . 
import Command -errors = [] LOG = logging.getLogger(__name__) -ORDER = ("name", "description", "licence", "input", "output", "statistics", "build") + +class MyDumper(yaml.SafeDumper): + pass + + +def find_paths(data, target_key=None, target_value=None, *path): + + matches = [] + + if isinstance(data, Mapping): + for k, v in data.items(): + if (target_key is not None and k == target_key) or (target_value is not None and v == target_value): + matches.append(list(path) + [k]) + matches.extend(find_paths(v, target_key, target_value, *path, k)) + elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)): + for i, item in enumerate(data): + matches.extend(find_paths(item, target_key, target_value, *path, str(i))) + return matches + + +# Custom representer for datetime.date and datetime.datetime +def represent_date(dumper, data): + if isinstance(data, datetime.date) and not isinstance(data, datetime.datetime): + data = datetime.datetime(data.year, data.month, data.day, 0, 0, 0) + # Ensure it's UTC + if data.tzinfo is None: + data = data.replace(tzinfo=datetime.timezone.utc) + data = data.astimezone(datetime.timezone.utc) + # Format as ISO 8601 with 'Z' + iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) + + +# Custom representer for multiline strings using the '|' block style +def represent_multiline_str(dumper, data): + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + +# --- Represent short lists inline (flow style) --- +def represent_inline_list(dumper, data): + # Flow style if list has <= 4 simple elements + if ( + all(isinstance(i, (str, int, float, bool, type(None))) for i in data) + and len(", ".join([str(x) for x in data])) + 2 <= 80 + ): + return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + return dumper.represent_sequence("tag:yaml.org,2002:seq", data) + + +# Register custom representers +MyDumper.add_representer(datetime.date, represent_date) +MyDumper.add_representer(datetime.datetime, represent_date) +MyDumper.add_representer(str, represent_multiline_str) +MyDumper.add_representer(list, represent_inline_list) + + +def make_dates(config): + if isinstance(config, dict): + return {k: make_dates(v) for k, v in config.items()} + if isinstance(config, list): + return [make_dates(v) for v in config] + if isinstance(config, str): + try: + return datetime.datetime.fromisoformat(config) + except ValueError: + return config + return config + + +ORDER = ( + "name", + "description", + "dataset_status", + "licence", + "attribution", + "env", + "dates", + "common", + "data_sources", + "input", + "output", + "statistics", + "build", + "platform", +) ORDER = {k: i for i, k in enumerate(ORDER)} def order(x: str) -> int: - if x[0] not in ORDER: - ORDER[x[0]] = len(ORDER) - - return ORDER[x[0]] + try: + return ORDER[x[0]] + except KeyError: + rich.print(f"Unknown key: {x}") + raise MIGRATE = { "output.statistics_end": "statistics.end", "has_nans": "statistics.allow_nans", "loop.dates.group_by": "build.group_by", + "loop.0.dates.group_by": "build.group_by", "loop.dates": "dates", + "loop.0.dates": "dates", "copyright": "attribution", + "dates.<<": "dates", + "options.group_by": "build.group_by", + "loops.0.loop_a.dates": "dates", + "loop.0.loop_a.dates": "dates", + "dates.stop": "dates.end", + "dates.group_by": "build.group_by", + "include": "data_sources", + 
"ensemble_dimension": "build.ensemble_dimension", + "flatten_grid": "build.flatten_grid", } +DELETE = [ + "purpose", + "input.join.0.label", + "status", + "common", + "config_format_version", + "aliases", + "platform", + "loops.0.loop_a.applies_to", + "loop.0.loop_a.applies_to", + "dataset_status", + "alias", + "resources", +] + + SOURCES = { "oper-accumulations": "accumulations", "era5-accumulations": "accumulations", - "constants": "forcings", "ensemble-perturbations": "recentre", "ensemble_perturbations": "recentre", "perturbations": "recentre", "custom-regrid": "regrid", } +MARKER = object() -def _move(config, path, new_path, result): - path = path.split(".") - if new_path is not None: - new_path = new_path.split(".") - - for k in path[:-1]: - if k not in config: - return - config = config[k] - if path[-1] not in config: +def _delete(config, path, result): + x = glom(config, path, default=MARKER) + if x is MARKER: return + rich.print(f"Deleting {path}={x}") + delete(result, path) - value = config.pop(path[-1]) - if new_path is None: +def _move(config, path, new_path, result): + x = glom(config, path, default=MARKER) + if x is MARKER: return + rich.print(f"Moving {path}={x} to {new_path}={x}") + delete(result, path) + assign(result, new_path, x, missing=dict) - for k in new_path[:-1]: - if k not in result: - result[k] = {} - result = result[k] - result[new_path[-1]] = value +def _fix_input_0(result, config): + if isinstance(config["input"], dict): + return + input = config["input"] + new_input = result["input"] = [] -def _fix_dates(result, config): - dates = config["input"].pop("dates") - assert "join" in dates, dates - result["input"] = dates["join"] - config["input"] = result["input"].copy() + blocks = {} + first = None + for block in input: + assert isinstance(block, dict), block + assert len(block) == 1, block -def _fix_list(result, config): - result["input"] = dict(join=result["input"]) - config["input"] = result["input"].copy() + block_name, values = list(block.items())[0] + if "kwargs" in values: + inherit = values.pop("inherit", None) + assert len(values) == 1, values + values = values["kwargs"] + values.pop("date", None) + source_name = values.pop("name", None) -def _fix_join_0(result, config): + if inherit is not None: + inherited = blocks[inherit].copy() + inherited.update(values) + values = inherited - join = config["input"]["join"] + if "source_or_dataset" in values: + values.pop("source_or_dataset", None) + values["template"] = "${input.join.0." 
+ first + "}" - new_join = [] - for n in join: + if first is None: + first = source_name - if "function" in n: - f = n["function"] - name = f.pop("name") - data = _tidy(f) - for k, v in list(data.items()): - if isinstance(v, dict): - if "name" in v: - p = v.pop("name") - data[k] = {SOURCES.get(p, p): _tidy(v)} + blocks[block_name] = values.copy() - new_join.append({SOURCES.get(name, name): data}) - continue + new_input.append({block_name: {SOURCES.get(source_name, source_name): values.copy()}}) + else: + assert False, f"Block {block_name} does not have 'kwargs': {values}" - new_join.append(n) # {SOURCES.get(src, src): _tidy(data)}) + blocks[block_name] = values.copy() - result["input"] = dict(join=new_join) config["input"] = result["input"].copy() -def _fix_join_1(result, config): - - join = config["input"].pop("join") - new_join = [] - for n in join: - if isinstance(n, dict): - if len(n) == 1: - if "label" in n: - n = n["label"] - - if isinstance(n, dict): - if len(n) == 2: - if "name" in n and "source" in n: - n.pop("name") - - if isinstance(n, dict): - if len(n) == 1: - if "source" in n: - n = n["source"] - if "<<" in n: - n.update(n.pop("<<")) - name = n.pop("name", "mars") - new_join.append({SOURCES.get(name, name): _tidy(n)}) - continue - - new_join.append(n) - - result["input"] = dict(join=new_join) - config["input"] = result["input"].copy() - - -def _fix_join_3(result, config): - - join = config["input"].pop("join") - new_join = [] - for n in join: - if not isinstance(n, dict): - return - if len(n) != 1: - return - - name = list(n.keys())[0] - data = n[name] +def _fix_input_1(result, config): + if isinstance(config["input"], dict): + return - new_join.append({SOURCES.get(name, name): data}) + input = config["input"] + join = [] + for k in input: + assert isinstance(k, dict) + assert len(k) == 1, f"Input key {k} is not a string: {input}" + name, values = list(k.items())[0] + join.append(values) - result["input"] = dict(join=new_join) + result["input"] = {"join": join} config["input"] = result["input"].copy() -def _tidy(data): - for k, v in list(data.items()): - if k in ("date", "time"): - if isinstance(v, str) and v.startswith("$"): - del data[k] - - if "name" in data: - assert False, data - name = data.pop("name") - return {SOURCES.get(name, name): _tidy(data)} - - return data - +def remove_empties(config: dict) -> None: + """Remove empty dictionaries and lists from the config.""" + if isinstance(config, dict): + keys_to_delete = [k for k, v in config.items() if v in (None, {}, [], [{}])] -def _fix_join_2(result, config): + for k in keys_to_delete: + del config[k] - join = config["input"]["join"] + for k, v in config.items(): + remove_empties(v) - previous = {} + if isinstance(config, list): + for item in config: + remove_empties(item) - new_join = [] - for n in join: - if not isinstance(n, dict): - return - - if len(n) != 1: - return - - what = list(n.keys())[0] +def _fix_loops(result: dict, config: dict) -> None: + if "loops" not in config: + return - if n[what] is None: - assert False, (n, what, config["input"]) + input = config["input"] + loops = config["loops"] + + assert isinstance(loops, list), loops + assert isinstance(input, list), input + + entries = {} + dates_block = None + for loop in loops: + assert isinstance(loop, dict), loop + assert len(loop) == 1, loop + loop = list(loop.values())[0] + applies_to = loop["applies_to"] + dates = loop["dates"] + assert isinstance(applies_to, list), (applies_to, loop) + for a in applies_to: + entries[a] = dates.copy() + + if "start" 
in dates: + start = dates["start"] + else: + start = max(dates["values"]) + + if "end" in dates or "stop" in dates: + end = dates.get("end", dates.get("stop")) + else: + end = min(dates["values"]) + + if dates_block is None: + dates_block = { + "start": start, + "end": end, + } + + if "frequency" in dates: + if "frequency" not in dates_block: + dates_block["frequency"] = dates["frequency"] + else: + assert dates_block["frequency"] == dates["frequency"], (dates_block["frequency"], dates["frequency"]) + + dates_block["start"] = min(dates_block["start"], start) + dates_block["end"] = max(dates_block["end"], end) + + concat = [] + result["input"] = {"concat": concat} + + rich.print("Found loops:", entries) + + for block in input: + assert isinstance(block, dict), block + assert len(block) == 1, block + name, values = list(block.items())[0] + assert name in entries, f"Loop {name} not found in loops: {list(entries.keys())}" + dates = entries[name].copy() + + assert "kwargs" not in values + + concat.append(dict(dates=dates, **values)) + + d = concat[0]["dates"] + if all(c["dates"] == d for c in concat): + join = [] + for c in concat: + del c["dates"] + join.append(c) + result["input"] = {"join": join} + + del config["loops"] + config["input"] = result["input"].copy() + config["dates"] = dates_block.copy() + del result["loops"] + result["dates"] = dates_block - if "kwargs" not in n[what]: - return - # assert False +def _fix_other(result: dict, config: dict) -> None: + paths = find_paths(config, target_key="source_or_dataset", target_value="$previous_data") + for p in paths: + rich.print(f"Fixing {'.'.join(p)}") + assign(result, ".".join(p[:-1] + ["template"]), "${input.join.0.mars}", missing=dict) + delete(result, ".".join(p)) - previous[what] = deepcopy(n[what]["kwargs"]) - if "inherit" in n[what]: - previous[what].update(deepcopy(previous[n[what]["inherit"]])) + paths = find_paths(config, target_key="date", target_value="$dates") + for p in paths: + delete(result, ".".join(p)) - data = previous[what].copy() - src = data.pop("name", "mars") - new_join.append({SOURCES.get(src, src): _tidy(data)}) +def _migrate(config: dict, n) -> dict: - result["input"] = dict(join=new_join) - config["input"] = result["input"].copy() + result = config.copy() + _fix_input_0(result, config) + _fix_loops(result, config) + _fix_input_1(result, config) + _fix_other(result, config) -def _migrate(config: dict, n) -> dict: - result = config.copy() for k, v in MIGRATE.items(): _move(config, k, v, result) - if "dates" in config["input"]: - _fix_dates(result, config) - - if isinstance(config["input"], list): - _fix_list(result, config) - - if "join" in config["input"]: - _fix_join_0(result, config) - - if "join" in config["input"]: - _fix_join_1(result, config) - - if "join" in config["input"]: - _fix_join_2(result, config) - - if "join" in config["input"]: - _fix_join_3(result, config) - - # _check(result, "1") - - # if isinstance(result["input"], list): - # assert n == 0 - # join = [] - # prev = {} - # for n in result["input"]: - # assert isinstance(n, dict), (n, type(n)) - # assert len(n) == 1, (n, type(n)) - # name = list(n.keys())[0] - # prev[name] = n[name]["kwargs"] - # if "inherit" in n[name]: - # i = n[name]["inherit"] - # n[name]["kwargs"].update(prev[i]) - # n[name].pop("inherit") - - # data = n[name]["kwargs"] - - # src = data.pop("name", "mars") - - # join.append({SOURCES.get(src, src): data}) - - # result["input"] = dict(join=join) - # _check(result, "2") - - # if "join" in result["input"] and n == 0: - # join 
= result["input"].pop("join") - # new_join = [] - - # for j in join: - - # if "label" in j: - # if isinstance(j["label"], str): - # j.pop("label") - # else: - # if j["label"] is not None: - # j = j["label"] - # j.pop("name", None) - - # if "source" in j: - # j = j["source"] - - # src = j.pop("name", "mars") - # data = j - # if "<<" in data: - # data.update(data.pop("<<")) + for k in DELETE: + _delete(config, k, result) - # for k, v in list(data.items()): - # if k in ("date", "time"): - # if isinstance(v, str) and v.startswith("$"): - # del data[k] - - # if "mars" in data: - # new_join.append(data) - # else: - # new_join.append({SOURCES.get(src, src): data}) - - # result["input"]["join"] = new_join - # _check(result, "3") - - # if "join" in result["input"]: - # for j in result["input"]["join"]: - # k = list(j.keys())[0] - # j[k].pop("name", None) - - # if "source_or_dataset" in j[k]: - # j[k].pop("source_or_dataset", None) - # j[k]["template"] = "${input.0.join.0.mars}" - # _check(result, "4") - - result = {k: v for k, v in sorted(result.items(), key=order) if v} - - result.pop("loop", None) + remove_empties(result) return result def migrate(old: dict) -> dict: - # return _migrate(old) + for i in range(10): new = _migrate(old, i) if new == old: - # print(json.dumps(new, indent=2, default=str)) return new old = new return new +def has_key(config, key: str) -> bool: + if isinstance(config, dict): + if key in config: + return True + for k, v in config.items(): + if has_key(v, key): + return True + if isinstance(config, list): + for item in config: + if has_key(item, key): + return True + return False + + +def has_value(config, value: str) -> bool: + if isinstance(config, dict): + for k, v in config.items(): + if v == value: + return True + if has_value(v, value): + return True + + if isinstance(config, list): + for item in config: + if item == value: + return True + if has_value(item, value): + return True + return config == value + + +def check(config): + from anemoi.datasets.create import validate_config + + try: + + validate_config(config) + assert config.get("input", {}) + assert config.get("dates", {}) + assert not has_key(config, "label") + assert not has_key(config, "kwargs") + assert not has_value(config, "$previous_data") + assert not has_value(config, "$dates") + assert not has_key(config, "inherit") + assert not has_key(config, "source_or_dataset") + assert not has_key(config, "<<") + + for n in SOURCES.keys(): + assert not has_key(config, n), f"Source {n} found in config. Please update to {SOURCES[n]}." + + except Exception as e: + rich.print(f"Validation failed: {e}") + rich.print(f"Config: {config}") + raise + + class Recipe(Command): def add_arguments(self, command_parser: Any) -> None: """Add arguments to the command parser. 
@@ -343,12 +460,39 @@ def add_arguments(self, command_parser: Any) -> None: ) def run(self, args: Any) -> None: + + rich.print(f"Migrating {args.path}") + with open(args.path, "r") as file: config = yaml.safe_load(file) - print(yaml.safe_dump(migrate(config), sort_keys=False, indent=2, width=120)) + try: + validate_config(config) + LOG.info(f"{args.path}: Validation successful.") + return + except Exception: + pass + + migrated = migrate(config) + + migrated = {k: v for k, v in sorted(migrated.items(), key=order) if v} + + check(migrated) + if migrated == config: + LOG.info(f"{args.path}: No changes needed.") + return + + migrated = make_dates(migrated) + text = yaml.dump(migrated, default_flow_style=False, sort_keys=False, indent=2, width=120, Dumper=MyDumper) + + LOG.info(f"{args.path}: updating.") + with open(args.path + ".tmp", "w") as f: + for i, line in enumerate(text.splitlines()): + if i and line and line[0] not in (" ", "-"): + line = "\n" + line + print(line, file=f) - assert not errors, f"Errors: {errors}" + os.rename(args.path + ".tmp", args.path) command = Recipe diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py new file mode 100644 index 000000000..84b25c6f8 --- /dev/null +++ b/src/anemoi/datasets/commands/validate.py @@ -0,0 +1,44 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from typing import Any + +import yaml + +from . import Command + +LOG = logging.getLogger(__name__) + + +class Validate(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. 
+ """ + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + from anemoi.datasets.create import validate_config + + with open(args.path, "r") as file: + config = yaml.safe_load(file) + + validate_config(config) + + +command = Validate diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index b6b51cb69..37bf11763 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1631,6 +1631,47 @@ def creator_factory(name: str, trace: Optional[str] = None, **kwargs: Any) -> An return cls(**kwargs) +def validate_config(config: Any) -> None: + + import json + + import jsonschema + + def _tidy(d): + if isinstance(d, dict): + return {k: _tidy(v) for k, v in d.items()} + + if isinstance(d, list): + return [_tidy(v) for v in d if v is not None] + + # jsonschema does not support datetime.date + if isinstance(d, datetime.datetime): + return d.isoformat() + + if isinstance(d, datetime.date): + return d.isoformat() + + return d + + # https://json-schema.org + + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "schemas", + "recipe.json", + ) + ) as f: + schema = json.load(f) + + try: + jsonschema.validate(instance=_tidy(config), schema=schema) + except jsonschema.exceptions.ValidationError as e: + LOG.error("❌ Config validation failed (jsonschema):") + LOG.error(e.message) + raise + + def config_to_python(config: Any) -> Any: config = loader_config(config) diff --git a/src/anemoi/datasets/schemas/recipe.json b/src/anemoi/datasets/schemas/recipe.json new file mode 100644 index 000000000..3c02bfd64 --- /dev/null +++ b/src/anemoi/datasets/schemas/recipe.json @@ -0,0 +1,131 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$id": "https://ecmwf.int/anemoi-datasets-recipe.schema.json", + "title": "Product", + "description": "Anemoi datasets recipe configuration", + "additionalProperties": false, + "$defs": { + "source-or-filter": { + "type": "object", + "minProperties": 1, + "maxProperties": 1 + }, + "pipe": { + "type": "array", + "items": { + "$ref": "#/$defs/input-object" + } + }, + "join": { + "type": "array", + "items": { + "$ref": "#/$defs/input-object" + } + }, + "concat": { + "type": "array", + "items": { + "type": "object", + "minProperties": 2, + "maxProperties": 2, + "required": [ + "dates" + ] + } + }, + "input-object": { + "oneOf": [ + { + "$ref": "#/$defs/pipe" + }, + { + "$ref": "#/$defs/join" + }, + { + "$ref": "#/$defs/concat" + }, + { + "$ref": "#/$defs/source-or-filter" + } + ] + } + }, + "properties": { + "env": { + "type": "object" + }, + "description": { + "type": "string" + }, + "name": { + "type": "string" + }, + "licence": { + "type": "string" + }, + "attribution": { + "type": "string" + }, + "dates": { + "type": "object", + "required": [ + "start", + "end" + ], + "properties": { + "start": { + "type": "string", + "format": "date" + }, + "end": { + "type": "string", + "format": "date" + }, + "frequency": { + "type": [ + "integer", + "string" + ] + }, + "group_by": { + "type": [ + "integer", + "string" + ] + } + } + }, + "input": { + "$ref": "#/$defs/input-object" + }, + "data_sources": { + "type": "object", + "patternProperties": { + "^[a-zA-Z_][a-zA-Z0-9_]*$": { + "$ref": "#/$defs/input-object" + } + }, + "additionalProperties": false + }, + "output": { + "type": "object" + }, + "statistics": { + "type": "object" + }, + "build": { + "type": "object" + }, + "common": { + "type": "object" + 
}, + "platform": { + "type": "object" + } + }, + "required": [ + "dates", + "input" + ] +} From 3e180f93d579df109027843607a1ceb9791b4dfc Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 10 Jul 2025 14:53:05 +0000 Subject: [PATCH 065/212] work on migrate --- src/anemoi/datasets/commands/migrate.py | 101 +++++++++++++++++++++++- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/migrate.py index 8d4b359ac..fa74886b8 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/migrate.py @@ -11,7 +11,6 @@ import datetime import logging import os -from collections.abc import Mapping from collections.abc import Sequence from typing import Any @@ -36,7 +35,7 @@ def find_paths(data, target_key=None, target_value=None, *path): matches = [] - if isinstance(data, Mapping): + if isinstance(data, dict): for k, v in data.items(): if (target_key is not None and k == target_key) or (target_value is not None and v == target_value): matches.append(list(path) + [k]) @@ -47,6 +46,21 @@ def find_paths(data, target_key=None, target_value=None, *path): return matches +def find_chevrons(data, *path): + + matches = [] + + if isinstance(data, dict): + for k, v in data.items(): + if k == "<<": + matches.append(list(path) + [k]) + matches.extend(find_chevrons(v, *path, k)) + elif isinstance(data, list): + for i, item in enumerate(data): + matches.extend(find_chevrons(item, *path, str(i))) + return matches + + # Custom representer for datetime.date and datetime.datetime def represent_date(dumper, data): if isinstance(data, datetime.date) and not isinstance(data, datetime.datetime): @@ -63,7 +77,9 @@ def represent_date(dumper, data): # Custom representer for multiline strings using the '|' block style def represent_multiline_str(dumper, data): if "\n" in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + text_list = [line.rstrip() for line in data.splitlines()] + fixed_data = "\n".join(text_list) + return dumper.represent_scalar("tag:yaml.org,2002:str", fixed_data, style="|") return dumper.represent_scalar("tag:yaml.org,2002:str", data) @@ -358,6 +374,82 @@ def _fix_other(result: dict, config: dict) -> None: delete(result, ".".join(p)) +def _fix_join(result: dict, config: dict) -> None: + rich.print("Fixing join...") + input = config["input"] + if "dates" in input and "join" in input["dates"]: + result["input"]["join"] = input["dates"]["join"] + config["input"]["join"] = input["dates"]["join"].copy() + + if "join" not in input: + return + + join = input["join"] + new_join = [] + for j in join: + assert isinstance(j, dict) + assert len(j) == 1 + + key, values = list(j.items())[0] + + if key not in ("label", "source"): + return + + assert isinstance(values, dict), f"Join values for {key} should be a dict: {values}" + if key == "label": + j = values + j.pop("name") + key, values = list(j.items())[0] + + print(values) + source_name = values.pop("name", "mars") + new_join.append( + { + SOURCES.get(source_name, source_name): values, + } + ) + + result["input"] = {"join": new_join} + config["input"] = result["input"].copy() + + +def _fix_sources(result: dict, config: dict, what) -> None: + + input = config["input"] + if what not in input: + return + + join = input[what] + new_join = [] + for j in join: + assert isinstance(j, dict) + assert len(j) == 1 + + key, values = list(j.items())[0] + + key = SOURCES.get(key, key) + + new_join.append( + { + key: values, + } + ) + + 
result["input"][what] = new_join + config["input"][what] = new_join.copy() + + +def _fix_chevrons(result: dict, config: dict) -> None: + rich.print("Fixing chevrons...") + paths = find_chevrons(config) + for p in paths: + a = glom(config, ".".join(p)) + b = glom(config, ".".join(p[:-1])) + delete(result, ".".join(p)) + a.update(b) + assign(result, ".".join(p[:-1]), a) + + def _migrate(config: dict, n) -> dict: result = config.copy() @@ -365,6 +457,9 @@ def _migrate(config: dict, n) -> dict: _fix_input_0(result, config) _fix_loops(result, config) _fix_input_1(result, config) + _fix_join(result, config) + _fix_sources(result, config, "join") + _fix_chevrons(result, config) _fix_other(result, config) for k, v in MIGRATE.items(): From a9816ca11f99b9bb31c3a44264bb1fa943a97ca2 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 16 Jul 2025 15:54:31 +0000 Subject: [PATCH 066/212] fix --- src/anemoi/datasets/data/dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index b9e29dd24..bdcf51813 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -187,7 +187,12 @@ def __subset(self, **kwargs: Any) -> "Dataset": frequency = kwargs.pop("frequency", self.frequency) return ( Padded( - self, start, end, frequency, dict(start=start, end=end, frequency=frequency, padding=padding) + self, + start=start, + end=end, + frequency=frequency, + padding=padding, + reason=dict(start=start, end=end, frequency=frequency, padding=padding), ) ._subset(**kwargs) .mutate() From 20fc1859222a4ab2d7cb0020ca1e30d188fa5fcf Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 17 Jul 2025 16:06:26 +0000 Subject: [PATCH 067/212] fix padding --- src/anemoi/datasets/data/padded.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py index 0160b674f..aab5ed565 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/data/padded.py @@ -175,7 +175,6 @@ def _get_tuple(self, n: TupleIndex) -> NDArray[Any]: LOG.warning("Padded subset does not support tuple indexing, returning a list") return [self[i] for i in n] - @property def empty_item(self): if self.padding == "empty": return self.dataset.empty_item() From 558f1a60786bfa377a82a1823dc5131c3b3e4d76 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 5 Aug 2025 09:49:53 +0000 Subject: [PATCH 068/212] fix: _select with set_group --- src/anemoi/datasets/data/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index bdcf51813..0ab456588 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -141,7 +141,8 @@ def _subset(self, **kwargs: Any) -> "Dataset": if not kwargs: return self.mutate() - name = kwargs.pop("name", None) + name = kwargs.pop("set_group", None) # TODO(Florian) + name = kwargs.pop("name", name) result = self.__subset(**kwargs) result._name = name From 255c22df684878f2379f7da32c35d609cabfc577 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 11 Aug 2025 20:05:00 +0200 Subject: [PATCH 069/212] merge --- CHANGELOG.md | 26 + docs/adr/adr-1.md | 63 ++ .../building/sources/anemoi-dataset.rst | 6 + .../datasets/building/sources/xarray-zarr.rst | 6 + .../sources/yaml/anemoi-zarr-dataset.yaml | 3 + docs/datasets/using/combining.rst | 4 +- docs/howtos/create/02-cf-data.rst | 9 + 
docs/howtos/create/yaml/zarr2.yaml | 8 + docs/howtos/usage/code/cutout-complement1.py | 1 + .../howtos/usage/yaml/cutout-complement1.yaml | 1 + pyproject.toml | 14 +- .../datasets/commands/finalise-additions.py | 3 +- src/anemoi/datasets/commands/finalise.py | 3 +- .../datasets/commands/init-additions.py | 3 +- .../datasets/commands/load-additions.py | 3 +- src/anemoi/datasets/commands/load.py | 3 +- src/anemoi/datasets/commands/recipe.py | 6 +- src/anemoi/datasets/create/__init__.py | 86 ++- src/anemoi/datasets/create/input/__init__.py | 86 +-- src/anemoi/datasets/create/input/action.py | 542 +++++++----------- src/anemoi/datasets/create/input/concat.py | 184 ------ src/anemoi/datasets/create/input/context.py | 89 --- .../datasets/create/input/context/__init__.py | 63 ++ .../datasets/create/input/context/field.py | 54 ++ .../datasets/create/input/data_sources.py | 11 +- src/anemoi/datasets/create/input/empty.py | 54 -- src/anemoi/datasets/create/input/filter.py | 133 ----- src/anemoi/datasets/create/input/function.py | 244 -------- src/anemoi/datasets/create/input/join.py | 137 ----- src/anemoi/datasets/create/input/pipe.py | 77 --- .../datasets/create/input/repeated_dates.py | 28 +- .../datasets/create/input/result/__init__.py | 17 + .../input/{result.py => result/field.py} | 98 +--- src/anemoi/datasets/create/input/step.py | 203 ------- src/anemoi/datasets/create/input/template.py | 162 ------ src/anemoi/datasets/create/python.py | 174 ++++++ .../datasets/create/sources/accumulations.py | 13 +- .../datasets/create/sources/constants.py | 2 +- .../datasets/create/sources/forcings.py | 2 +- src/anemoi/datasets/create/sources/grib.py | 2 +- src/anemoi/datasets/create/sources/legacy.py | 9 +- .../datasets/create/sources/patterns.py | 2 +- .../create/sources/planetary_computer.py | 44 ++ .../datasets/create/sources/repeated_dates.py | 319 +++++++++++ .../create/sources/xarray_support/__init__.py | 28 +- .../sources/xarray_support/coordinates.py | 8 + .../create/sources/xarray_support/field.py | 5 +- .../create/sources/xarray_support/flavour.py | 50 +- .../create/sources/xarray_support/patch.py | 45 +- .../create/sources/xarray_support/variable.py | 8 +- src/anemoi/datasets/data/complement.py | 54 +- src/anemoi/datasets/data/dataset.py | 29 + src/anemoi/datasets/data/forwards.py | 10 +- src/anemoi/datasets/data/misc.py | 90 ++- .../datasets/data/observations/__init__.py | 316 ++++++++++ .../data/observations/legacy_obs_dataset.py | 200 +++++++ .../datasets/data/observations/multi.py | 64 +++ src/anemoi/datasets/data/padded.py | 227 ++++++++ src/anemoi/datasets/data/records/__init__.py | 442 ++++++++++++++ .../data/records/backends/__init__.py | 157 +++++ src/anemoi/datasets/data/stores.py | 63 +- src/anemoi/datasets/data/subset.py | 5 + src/anemoi/datasets/grids.py | 9 +- tests/conftest.py | 1 + tests/create/__init__.py | 0 tests/create/test_create.py | 369 +----------- tests/create/test_sources.py | 179 +++++- tests/create/utils/__init__.py | 0 tests/create/utils/compare.py | 218 +++++++ .../create/utils/create.py | 0 tests/create/utils/mock_sources.py | 117 ++++ tests/test_data.py | 26 +- tests/test_records.py | 160 ++++++ tests/xarray/test_flavour.py | 104 ++++ tests/xarray/test_zarr.py | 28 - tools/build-obs.py | 52 ++ tools/check-obs.py | 60 ++ 77 files changed, 3726 insertions(+), 2395 deletions(-) create mode 100644 docs/adr/adr-1.md create mode 100644 docs/datasets/building/sources/yaml/anemoi-zarr-dataset.yaml create mode 100644 docs/howtos/create/yaml/zarr2.yaml delete mode 100644 
src/anemoi/datasets/create/input/concat.py delete mode 100644 src/anemoi/datasets/create/input/context.py create mode 100644 src/anemoi/datasets/create/input/context/__init__.py create mode 100644 src/anemoi/datasets/create/input/context/field.py delete mode 100644 src/anemoi/datasets/create/input/empty.py delete mode 100644 src/anemoi/datasets/create/input/filter.py delete mode 100644 src/anemoi/datasets/create/input/function.py delete mode 100644 src/anemoi/datasets/create/input/join.py delete mode 100644 src/anemoi/datasets/create/input/pipe.py create mode 100644 src/anemoi/datasets/create/input/result/__init__.py rename src/anemoi/datasets/create/input/{result.py => result/field.py} (87%) delete mode 100644 src/anemoi/datasets/create/input/step.py delete mode 100644 src/anemoi/datasets/create/input/template.py create mode 100644 src/anemoi/datasets/create/python.py create mode 100644 src/anemoi/datasets/create/sources/planetary_computer.py create mode 100644 src/anemoi/datasets/create/sources/repeated_dates.py create mode 100644 src/anemoi/datasets/data/observations/__init__.py create mode 100644 src/anemoi/datasets/data/observations/legacy_obs_dataset.py create mode 100644 src/anemoi/datasets/data/observations/multi.py create mode 100644 src/anemoi/datasets/data/padded.py create mode 100644 src/anemoi/datasets/data/records/__init__.py create mode 100644 src/anemoi/datasets/data/records/backends/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/create/__init__.py create mode 100644 tests/create/utils/__init__.py create mode 100644 tests/create/utils/compare.py rename src/anemoi/datasets/create/testing.py => tests/create/utils/create.py (100%) create mode 100644 tests/create/utils/mock_sources.py create mode 100644 tests/test_records.py create mode 100644 tests/xarray/test_flavour.py create mode 100755 tools/build-obs.py create mode 100755 tools/check-obs.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 73cde7091..226403999 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! +## [0.5.25](https://github.com/ecmwf/anemoi-datasets/compare/0.5.24...0.5.25) (2025-06-11) + + +### Features + +* Integrating non-regular datasets in anemoi for observations. 
([#306](https://github.com/ecmwf/anemoi-datasets/issues/306)) ([95a0fe4](https://github.com/ecmwf/anemoi-datasets/commit/95a0fe4bb10dc48469c0be0efad94f4d5e2a9fe8))
+
+
+### Bug Fixes
+
+* Incremental dataset build tasks called regardless of presence of debug flag in CLI code ([#294](https://github.com/ecmwf/anemoi-datasets/issues/294)) ([37afc0d](https://github.com/ecmwf/anemoi-datasets/commit/37afc0d6489f2d6c4b3ce3f9901c40e4cec5c4eb))
+* Regression in accumulations [#354](https://github.com/ecmwf/anemoi-datasets/issues/354) ([#355](https://github.com/ecmwf/anemoi-datasets/issues/355)) ([f9769d7](https://github.com/ecmwf/anemoi-datasets/commit/f9769d7944738ecbedb6b3cc1f78cd26de36a73f))
+* Remove 2 layers of build function ([#348](https://github.com/ecmwf/anemoi-datasets/issues/348)) ([7a904c4](https://github.com/ecmwf/anemoi-datasets/commit/7a904c451772089f120419a9d39bff746e0aeebb))
+
+## [0.5.24](https://github.com/ecmwf/anemoi-datasets/compare/0.5.23...0.5.24) (2025-05-23)
+
+
+### Features
+
+* verify command ([#279](https://github.com/ecmwf/anemoi-datasets/issues/279)) ([aed36d2](https://github.com/ecmwf/anemoi-datasets/commit/aed36d2ea7a39ea1ae6bbd5f8d01ef0ce7523cde))
+
+
+### Bug Fixes
+
+* adapt to earthkit 0.14 (ignore_keys issue) ([#331](https://github.com/ecmwf/anemoi-datasets/issues/331)) ([fb3ab8d](https://github.com/ecmwf/anemoi-datasets/commit/fb3ab8d46b8e00c62d8d7cbb1d1afae0efea2054))
+
 ## [0.5.23](https://github.com/ecmwf/anemoi-datasets/compare/0.5.22...0.5.23) (2025-05-07)
 
 
diff --git a/docs/adr/adr-1.md b/docs/adr/adr-1.md
new file mode 100644
index 000000000..6e22b2aa2
--- /dev/null
+++ b/docs/adr/adr-1.md
@@ -0,0 +1,63 @@
+# Support irregular observations datasets
+
+## Status
+
+
+
+Proposed - 30/04/2025
+
+## Context
+
+
+
+The objective of this change is to support observations data, which is not regular.
+
+In contrast with the fields data, where each date contains the same number of points,
+in the observations data the number of points can change for every time window.
+
+The Zarr format fits the fields data well, but does not fit the observations data.
+
+To allow storing data with an irregular shape, we need to use another format than the zarr used for fields.
+An experimental implementation using xarray-zarr has been developed, but it is not optimised for ML training.
+
+## Decision
+
+
+
+Add functionality in anemoi-datasets to read observations datasets and provide the data as a dictionary/mapper of numpy arrays.
+
+Mimic as much as possible what is done for field datasets:
+
+`ds = open_dataset(....)`
+`ds[i]` -> provides the data for a given time window, related to a given reference date, as a dictionary-like object.
+`ds.dates` -> list of reference dates; `ds.dates[i]` is the reference date for the data provided in `ds[i]`
+
+Also expose the latitudes and longitudes in a sensible way (as `ds.latitudes` and `ds.longitudes` now depend on the dates), as well as name_to_index, statistics, metadata, etc.
+
+These API choices need to be made on an actual training use case.
+
+Steps to achieve this:
+- Implement a prototype format now, to allow developing ML training code on observation data.
+- Perform extensive benchmarking with various formats (explore parquet and others).
+- As the final format is not defined yet, ensure a flexible architecture to allow switching (this will help for benchmarking).
+
+
+## Scope of Change
+
+
+- anemoi-datasets
+Not a breaking change; this only adds functionality to read observations datasets.
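+
+As an illustration only (the exact names are an assumption, since the API
+choices above still need to be validated on an actual training use case),
+reading such a dataset could look like this:
+
+```python
+from anemoi.datasets import open_dataset
+
+obs = open_dataset("path/to/observations-dataset")  # hypothetical dataset
+
+for i, reference_date in enumerate(obs.dates):
+    window = obs[i]  # dictionary-like object of numpy arrays
+    # the number of points may differ from one time window to the next
+    print(reference_date, {name: values.shape for name, values in window.items()})
+```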
+ +Must be in line with the change related to multi-datasets. + +## Consequences + + + +## Alternatives Considered [Optional] + + + +## References [Optional] + + diff --git a/docs/datasets/building/sources/anemoi-dataset.rst b/docs/datasets/building/sources/anemoi-dataset.rst index d279614cf..a8e336318 100644 --- a/docs/datasets/building/sources/anemoi-dataset.rst +++ b/docs/datasets/building/sources/anemoi-dataset.rst @@ -17,3 +17,9 @@ An anemoi-dataset can be a source for a dataset: The parameters are the same as those used in the ``open_dataset`` function, which allows you to subset and combine datasets. See :ref:`opening-datasets` for more information. + +In particular, this is how local zarr datasets created with anemoi in a +can be used as a source, contrary to :ref:`xarray-zarr` : + +.. literalinclude:: yaml/anemoi-zarr-dataset.yaml + :language: yaml diff --git a/docs/datasets/building/sources/xarray-zarr.rst b/docs/datasets/building/sources/xarray-zarr.rst index 84f5158df..0f9ce62c8 100644 --- a/docs/datasets/building/sources/xarray-zarr.rst +++ b/docs/datasets/building/sources/xarray-zarr.rst @@ -1,3 +1,5 @@ +.. _xarray-zarr: + ############# xarray-zarr ############# @@ -17,4 +19,8 @@ it is necessary to use the :ref:`join ` operation to join separate lists containing 2D variables and 3D variables. If all vertical levels are desired, then it is acceptable to specify a single source. +Also, an ``xarray-zarr`` source uses the ``url`` keyword, and cannot be +used for accessing local datasets. For using local zarr datasets as +sources, use instead :ref:`anemoi-dataset_source`. + See :ref:`create-cf-data` for more information. diff --git a/docs/datasets/building/sources/yaml/anemoi-zarr-dataset.yaml b/docs/datasets/building/sources/yaml/anemoi-zarr-dataset.yaml new file mode 100644 index 000000000..1aa7b4f7e --- /dev/null +++ b/docs/datasets/building/sources/yaml/anemoi-zarr-dataset.yaml @@ -0,0 +1,3 @@ +input: + anemoi-dataset: + dataset: path/to/dataset.zarr diff --git a/docs/datasets/using/combining.rst b/docs/datasets/using/combining.rst index 4eac4c500..358953a26 100644 --- a/docs/datasets/using/combining.rst +++ b/docs/datasets/using/combining.rst @@ -235,13 +235,15 @@ variables of `dataset1` and return the result. source=dataset2, what="variables", interpolate="nearest", + k=1, ) Currently ``what`` can only be ``variables`` and can be omitted. The value for ``interpolate`` can be one of ``none`` (default) or ``nearest``. In the case of ``none``, the grids of the two datasets must -match. +match. In case of ``interpolate``, an additional parameter ``k`` can be +set to specify the number of nearest neighbors to use. This feature was originally designed to be used in conjunction with ``cutout``, where `dataset1` is the lam, and `dataset2` is the global diff --git a/docs/howtos/create/02-cf-data.rst b/docs/howtos/create/02-cf-data.rst index 38050ba42..ba875e3a8 100644 --- a/docs/howtos/create/02-cf-data.rst +++ b/docs/howtos/create/02-cf-data.rst @@ -46,9 +46,18 @@ can contain patterns. See :ref:`file-pattern` for more information. Zarr ****** +For using remote hosted zarr datasets as sources, use +:ref:`xarray-zarr`. + .. literalinclude:: yaml/zarr1.yaml :language: yaml +For using local zarr datasets (such as anemoi-generated datasets), use +:ref:`anemoi-dataset_source`. + +.. 
literalinclude:: yaml/zarr2.yaml + :language: yaml + ********************************************* Handling data that is not 100% CF-compliant ********************************************* diff --git a/docs/howtos/create/yaml/zarr2.yaml b/docs/howtos/create/yaml/zarr2.yaml new file mode 100644 index 000000000..659be1c72 --- /dev/null +++ b/docs/howtos/create/yaml/zarr2.yaml @@ -0,0 +1,8 @@ +dates: + start: 2023-01-01T00:00:00 + end: 2023-01-02T18:00:00 + frequency: 6h + +input: + anemoi-dataset: + dataset: /path/to/input.zarr diff --git a/docs/howtos/usage/code/cutout-complement1.py b/docs/howtos/usage/code/cutout-complement1.py index 347ae420c..dd9106fa9 100644 --- a/docs/howtos/usage/code/cutout-complement1.py +++ b/docs/howtos/usage/code/cutout-complement1.py @@ -14,4 +14,5 @@ }, source="global-dataset", interpolation="nearest", + k=1, ) diff --git a/docs/howtos/usage/yaml/cutout-complement1.yaml b/docs/howtos/usage/yaml/cutout-complement1.yaml index 97109fb5b..b2dfb605d 100644 --- a/docs/howtos/usage/yaml/cutout-complement1.yaml +++ b/docs/howtos/usage/yaml/cutout-complement1.yaml @@ -9,3 +9,4 @@ dataset: adjust: dates source: global-dataset interpolation: nearest + k: 1 diff --git a/pyproject.toml b/pyproject.toml index 799c26daa..873415a24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,8 @@ dynamic = [ "version", ] dependencies = [ - "anemoi-transform>=0.1.9", - "anemoi-utils[provenance]>=0.4.21", + "anemoi-transform>=0.1.10", + "anemoi-utils[provenance]>=0.4.26", "cfunits", "numcodecs<0.16", # Until we move to zarr3 "numpy", @@ -71,7 +71,7 @@ optional-dependencies.comparelam = [ optional-dependencies.create = [ "cachetools", - "earthkit-data[mars]>=0.12.4,<0.14", + "earthkit-data[mars]>=0.14", "earthkit-geo>=0.3", "earthkit-meteo>=0.3", "eccodes>=2.39.1", @@ -100,6 +100,7 @@ optional-dependencies.remote = [ optional-dependencies.tests = [ "anemoi-datasets[xarray]", "pytest", + "pytest-xdist", ] optional-dependencies.xarray = [ @@ -131,6 +132,13 @@ version_file = "src/anemoi/datasets/_version.py" [tool.isort] profile = "black" +[tool.pytest.ini_options] +testpaths = "tests" +addopts = [ + "--numprocesses=auto", + "--strict-config", +] + [tool.mypy] strict = false exclude = [ diff --git a/src/anemoi/datasets/commands/finalise-additions.py b/src/anemoi/datasets/commands/finalise-additions.py index a155f16e3..811760ae9 100644 --- a/src/anemoi/datasets/commands/finalise-additions.py +++ b/src/anemoi/datasets/commands/finalise-additions.py @@ -61,7 +61,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/finalise.py b/src/anemoi/datasets/commands/finalise.py index ee9fa1c00..53999ad50 100644 --- a/src/anemoi/datasets/commands/finalise.py +++ b/src/anemoi/datasets/commands/finalise.py @@ -55,7 +55,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/init-additions.py b/src/anemoi/datasets/commands/init-additions.py index 54a48ed5f..09788f0e4 100644 --- a/src/anemoi/datasets/commands/init-additions.py +++ b/src/anemoi/datasets/commands/init-additions.py @@ -61,7 +61,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + 
task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/load-additions.py b/src/anemoi/datasets/commands/load-additions.py index 68aaa601a..a8cd5d5c9 100644 --- a/src/anemoi/datasets/commands/load-additions.py +++ b/src/anemoi/datasets/commands/load-additions.py @@ -62,7 +62,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/load.py b/src/anemoi/datasets/commands/load.py index 74220025f..3d969f5c3 100644 --- a/src/anemoi/datasets/commands/load.py +++ b/src/anemoi/datasets/commands/load.py @@ -62,7 +62,8 @@ def run(self, args: Any) -> None: if "debug" in options: options.pop("debug") - task(step, options) + + task(step, options) LOG.info(f"Create step '{step}' completed in {seconds_to_human(time.time()-now)}") diff --git a/src/anemoi/datasets/commands/recipe.py b/src/anemoi/datasets/commands/recipe.py index 3045f1a1b..f111aee1b 100644 --- a/src/anemoi/datasets/commands/recipe.py +++ b/src/anemoi/datasets/commands/recipe.py @@ -28,6 +28,9 @@ def add_arguments(self, command_parser: Any) -> None: command_parser : Any Command parser object. """ + + command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") + command_parser.add_argument( "path", help="Path to recipe.", @@ -38,7 +41,8 @@ def run(self, args: Any) -> None: with open(args.path, "r") as file: config = yaml.safe_load(file) - config = migrate(config) + if args.migrate: + config = migrate(config) print(config_to_python(config)) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 37bf11763..69b5a0d42 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -11,7 +11,6 @@ import json import logging import os -import re import time import uuid import warnings @@ -22,6 +21,7 @@ import cftime import numpy as np +import rich import tqdm import zarr from anemoi.utils.dates import as_datetime @@ -45,7 +45,7 @@ from .chunks import ChunkFilter from .config import build_output from .config import loader_config -from .input import build_input +from .input import InputBuilder from .statistics import Summary from .statistics import TmpStatistics from .statistics import check_variance @@ -102,7 +102,9 @@ def json_tidy(o: Any) -> Any: def build_statistics_dates( - dates: list[datetime.datetime], start: Optional[datetime.datetime], end: Optional[datetime.datetime] + dates: list[datetime.datetime], + start: Optional[datetime.datetime], + end: Optional[datetime.datetime], ) -> tuple[str, str]: """Compute the start and end dates for the statistics. @@ -552,36 +554,16 @@ def create_elements(self, config: Any) -> None: self.output = build_output(config.output, parent=self) - self.input = build_input_(main_config=config, output_config=self.output) - # LOG.info("%s", self.input) - - -def build_input_(main_config: Any, output_config: Any) -> Any: - """Build the input for the dataset. - - Parameters - ---------- - main_config : Any - The main configuration. - output_config : Any - The output configuration. - - Returns - ------- - Any - The input builder. 
- """ - builder = build_input( - main_config.input, - data_sources=main_config.get("data_sources", {}), - order_by=output_config.order_by, - flatten_grid=output_config.flatten_grid, - remapping=build_remapping(output_config.remapping), - use_grib_paramid=main_config.build.use_grib_paramid, - ) - LOG.debug("✅ INPUT_BUILDER") - LOG.debug(builder) - return builder + self.input = InputBuilder( + config.input, + data_sources=config.get("data_sources", {}), + order_by=self.output.order_by, + flatten_grid=self.output.flatten_grid, + remapping=build_remapping(self.output.remapping), + use_grib_paramid=config.build.use_grib_paramid, + ) + LOG.debug("✅ INPUT_BUILDER") + LOG.debug(self.input) class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): @@ -690,6 +672,8 @@ def _run(self) -> int: LOG.info(f"Missing dates: {len(missing)}") lengths = tuple(len(g) for g in self.groups) + rich.print("Minimal input dates:", self.minimal_input) + variables = self.minimal_input.variables LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") @@ -886,7 +870,7 @@ def _run(self) -> None: # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(group_of_dates=group) + result = self.input.select(argument=group) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. @@ -1542,7 +1526,16 @@ def run(self) -> None: if not all(self.registry.get_flags(sync=False)): raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") - for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: + for k in [ + "mean", + "stdev", + "minimum", + "maximum", + "sums", + "squares", + "count", + "has_nans", + ]: self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) self.registry.add_to_history("compute_statistics_end") @@ -1674,24 +1667,11 @@ def _tidy(d): def config_to_python(config: Any) -> Any: - config = loader_config(config) - - input = build_input_(config, build_output(config.output, None)) + from ..create.python import PythonCode - prelude = [] - input.python_prelude(prelude) - code1 = "\n".join(prelude) - - code2 = input.to_python() - - code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" - - code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) + config = loader_config(config) - try: - import black + input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - return black.format_str(code, mode=black.Mode()) - except ImportError: - LOG.warning("Black not installed, skipping formatting") - return code + code = PythonCode() + return input.python_code(code) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 66266d53d..63020324c 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024-2025 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -7,22 +7,12 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-import logging from copy import deepcopy +from functools import cached_property from typing import Any from typing import Union -from anemoi.datasets.dates.groups import GroupOfDates - -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class Context: - """Context for building input data.""" - - pass +from anemoi.datasets.create.input.context.field import FieldContext class InputBuilder: @@ -51,15 +41,20 @@ def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) ) ) self.config = config - self.action_path = ["input"] - @trace_select - def select(self, group_of_dates: GroupOfDates) -> Any: + @cached_property + def action(self) -> Any: + """Returns the action object based on the configuration.""" + from .action import action_factory + + return action_factory(self.config, "input") + + def select(self, argument) -> Any: """Select data based on the group of dates. Parameters ---------- - group_of_dates : GroupOfDates + argument : GroupOfDates Group of dates to select data for. Returns @@ -67,60 +62,11 @@ def select(self, group_of_dates: GroupOfDates) -> Any: Any Selected data. """ - from .action import ActionContext - from .action import action_factory - - """This changes the context.""" - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.select(group_of_dates) - - def to_python(self) -> str: - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - - return action.to_python() - - def python_prelude(self, prelude) -> str: - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.python_prelude(prelude) - - def __repr__(self) -> str: - """Return a string representation of the InputBuilder. - - Returns - ------- - str - String representation. - """ - from .action import ActionContext - from .action import action_factory - - context = ActionContext(**self.kwargs) - a = action_factory(self.config, context, self.action_path) - return repr(a) + context = FieldContext(argument, **self.kwargs) + return context.create_result(self.action(context, argument)) - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the select operation. - - Parameters - ---------- - group_of_dates : GroupOfDates - Group of dates to select data for. - - Returns - ------- - str - Trace string. - """ - return f"InputBuilder({group_of_dates})" + def python_code(self, code): + return self.action.python_code(code) def build_input(config: dict, data_sources: Union[dict, list], **kwargs: Any) -> InputBuilder: diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 121d1e387..8dadf14dc 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -7,338 +7,230 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-import datetime -import json import logging -import re -from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from anemoi.utils.dates import frequency_to_string -from earthkit.data.core.order import build_remapping +import rich -from ...dates.groups import GroupOfDates -from .context import Context -from .template import substitute +from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) class Action: - """Represents an action to be performed within a given context. - - Attributes - ---------- - context : ActionContext - The context in which the action exists. - kwargs : Dict[str, Any] - Additional keyword arguments. - args : Any - Additional positional arguments. - action_path : List[str] - The action path. - """ - - def __init__( - self, context: "ActionContext", action_path: List[str], /, *args: Any, **kwargs: Dict[str, Any] - ) -> None: - """Initialize an Action instance. - - Parameters - ---------- - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - args : Any - Additional positional arguments. - kwargs : Dict[str, Any] - Additional keyword arguments. - """ - if "args" in kwargs and "kwargs" in kwargs: - """We have: - args = [] - kwargs = {args: [...], kwargs: {...}} - move the content of kwargs to args and kwargs. - """ - assert len(kwargs) == 2, (args, kwargs) - assert not args, (args, kwargs) - args = kwargs.pop("args") - kwargs = kwargs.pop("kwargs") - - assert isinstance(context, ActionContext), type(context) - self.context = context - self.kwargs = kwargs - self.args = args - self.action_path = action_path - - @classmethod - def _short_str(cls, x: str) -> str: - """Shorten the string representation if it exceeds 1000 characters. - - Parameters - ---------- - x : str - The string to shorten. - - Returns - ------- - str - The shortened string. - """ - x = str(x) - if len(x) < 1000: - return x - return x[:1000] + "..." - - def _repr(self, *args: Any, _indent_: str = "\n", _inline_: str = "", **kwargs: Any) -> str: - """Generate a string representation of the Action instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str, optional - The indentation string, by default "\n". - _inline_ : str, optional - The inline string, by default "". - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - more = more[:5000] - txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Action instance. - - Returns - ------- - str - The string representation. - """ - return self._repr() - - def select(self, dates: object, **kwargs: Any) -> None: - """Select dates for the action. - - Parameters - ---------- - dates : object - The dates to select. - kwargs : Any - Additional keyword arguments. - """ - self._raise_not_implemented() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Trace the selection of a group of dates. 
- - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates to trace. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({group_of_dates})" - - def _to_python(self, name: str, config: dict, **extra: Any) -> str: - """Convert the action to Python code. - - Parameters - ---------- - name : str - The name of the action. - config : dict - The configuration for the action. - extra : Any - Additional keyword arguments. - - Returns - ------- - str - The Python code representation of the action. - """ - import json - - RESERVED_KEYWORDS = ( - "and", - "or", - "not", - "is", - "in", - "if", - "else", - "elif", - "for", - "while", - "return", - "class", - "def", - "with", - "as", - "import", - "from", - "try", - "except", - "finally", - "raise", - "assert", - "break", - "continue", - "pass", + def __init__(self, config, *path): + self.config = config + self.path = path + # rich.print(f"Creating {self.__class__.__name__} {'.'.join(x for x in self.path)} from {config}") + + +class Concat(Action): + def __init__(self, config, *path): + super().__init__(config, *path) + + assert isinstance(config, list), f"Value must be a dict {list}" + + self.choices = [] + + for item in config: + + assert "dates" in item, f"Value must contain the key 'dates' {item}" + dates = item["dates"] + filtering_dates = DatesProvider.from_config(**dates) + action = action_factory({k: v for k, v in item.items() if k != "dates"}) + self.choices.append((filtering_dates, action)) + + def __repr__(self): + return f"Concat({self.choices})" + + def __call__(self, context, argument): + + results = context.empty_result() + + for filtering_dates, action in self.choices: + dates = context.matching_dates(filtering_dates, argument) + if len(dates) == 0: + continue + results += action(context, dates) + + return context.register(results, self.path) + + def python_code(self, code): + return code.concat( + {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} ) - def convert(obj): - if isinstance(obj, datetime.datetime): - return obj.isoformat() - if isinstance(obj, datetime.date): - return obj.isoformat() - if isinstance(obj, datetime.timedelta): - return frequency_to_string(obj) - raise TypeError(f"Object of type {type(obj)} is not JSON serializable") - - config = json.loads(json.dumps(config, default=convert)) - - assert len(config) == 1, (name, config) - assert name in config, (name, config) - - config = config[name] - - params = [] - for k, v in config.items(): - if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: - return f"r.{name}({config})" - params.append(f"{k}={repr(v)}") - - for k, v in extra.items(): - params.append(f"{k}={v}") - - params = ",".join(params) - return f"r.{name}({params})" - # return f"{name}({config})" - - -class ActionContext(Context): - """Represents the context in which an action is performed. - - Attributes - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. - use_grib_paramid : bool - Whether to use GRIB parameter ID. - """ - - def __init__(self, /, order_by: str, flatten_grid: bool, remapping: Dict[str, Any], use_grib_paramid: bool) -> None: - """Initialize an ActionContext instance. - - Parameters - ---------- - order_by : str - The order by criteria. - flatten_grid : bool - Whether to flatten the grid. - remapping : Dict[str, Any] - The remapping configuration. 
- use_grib_paramid : bool - Whether to use GRIB parameter ID. - """ - super().__init__() - self.order_by = order_by - self.flatten_grid = flatten_grid - self.remapping = build_remapping(remapping) - self.use_grib_paramid = use_grib_paramid - - -def action_factory(config: Dict[str, Any], context: ActionContext, action_path: List[str]) -> Action: - """Factory function to create an Action instance based on the configuration. - - Parameters - ---------- - config : Dict[str, Any] - The action configuration. - context : ActionContext - The context in which the action exists. - action_path : List[str] - The action path. - - Returns - ------- - Action - The created Action instance. - """ - from .concat import ConcatAction - from .data_sources import DataSourcesAction - from .function import FunctionAction - from .join import JoinAction - from .pipe import PipeAction - from .repeated_dates import RepeatedDatesAction - - # from .data_sources import DataSourcesAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - - if len(config) != 1: - if "label" in config: - config.pop("label") - if "name" in config: - config.pop("name") - if len(config) != 1: - print(json.dumps(config, indent=2, default=str)) - raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}") - - config = deepcopy(config) - key = list(config.keys())[0] - - if isinstance(config[key], list): - args, kwargs = config[key], {} - elif isinstance(config[key], dict): - args, kwargs = [], config[key] - else: - raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}") - - cls = { - "data_sources": DataSourcesAction, - "data-sources": DataSourcesAction, - "concat": ConcatAction, - "join": JoinAction, - "pipe": PipeAction, - "function": FunctionAction, - "repeated_dates": RepeatedDatesAction, - "repeated-dates": RepeatedDatesAction, - }.get(key) - - if cls is None: - from ..sources import create_source - - source = create_source(None, substitute(context, config)) - return FunctionAction(context, action_path + [key], key, source, config) - - return cls(context, action_path + [key], *args, **kwargs) + +class Join(Action): + def __init__(self, config, *path): + super().__init__(config, *path) + + assert isinstance(config, list), f"Value must be a list {config}" + + self.actions = [action_factory(item, *path, "join", str(i)) for i, item in enumerate(config)] + + def __repr__(self): + return f"Join({self.actions})" + + def __call__(self, context, argument): + results = context.empty_result() + + for action in self.actions: + results += action(context, argument) + + return context.register(results, self.path) + + def python_code(self, code) -> None: + return code.sum(a.python_code(code) for a in self.actions) + + +class Pipe(Action): + def __init__(self, config, *path): + assert isinstance(config, list), f"Value must be a list {config}" + super().__init__(config, *path) + self.actions = [action_factory(item, *path, "pipe", str(i)) for i, item in enumerate(config)] + + def __repr__(self): + return f"Pipe({self.actions})" + + def __call__(self, context, argument): + result = context.empty_result() + + for i, action in enumerate(self.actions): + if i == 0: + result = action(context, argument) + else: + result = action(context, result) + + return context.register(result, self.path) + + def python_code(self, code) -> None: + return code.pipe(a.python_code(code) for a in self.actions) + + +class 
Function(Action):
+    def __init__(self, config, *path):
+        super().__init__(config, *path, self.name)
+
+    def __call__(self, context, argument):
+
+        config = context.resolve(self.config)  # Substitute the ${} variables in the config
+
+        config["_type"] = self.name  # Find a better way to do this
+
+        source = self.create_object(config)
+
+        rich.print(f"Executing source {self.name} from {config}")
+
+        return context.register(self.call_object(context, source, argument), self.path)
+
+    def python_code(self, code) -> str:
+        return code.call(self.name, self.config)
+
+
+class DatasetSourceMixin:
+    def create_object(self, config):
+        from anemoi.datasets.create.sources import create_source as create_datasets_source
+
+        return create_datasets_source(self, config)
+
+    def call_object(self, context, source, argument):
+        return source.execute(context, context.source_argument(argument))
+
+
+class DatasetFilterMixin:
+    def create_object(self, config):
+        from anemoi.datasets.create.filters import create_filter as create_datasets_filter
+
+        return create_datasets_filter(self, config)
+
+    def call_object(self, context, filter, argument):
+        return filter.execute(context.filter_argument(argument))
+
+
+class TransformSourceMixin:
+    def create_object(self, config):
+        from anemoi.transform.sources import create_source as create_transform_source
+
+        return create_transform_source(self, config)
+
+
+class TransformFilterMixin:
+    def create_object(self, config):
+        from anemoi.transform.filters import create_filter as create_transform_filter
+
+        return create_transform_filter(self, config)
+
+    def call_object(self, context, filter, argument):
+        return filter.forward(context.filter_argument(argument))
+
+
+class FilterFunction(Function):
+    def __call__(self, context, argument):
+        return self.call(context, argument, context.filter_argument)
+
+
+def _make_name(name, what):
+    # e.g. _make_name("some_name", "source") -> "SomeNameSource"
+    name = name.replace("_", "-")
+    name = "".join(x.title() for x in name.split("-"))
+    return name + what.title()
+
+
+def new_source(name, mixin):
+    return type(
+        _make_name(name, "source"),
+        (Function, mixin),
+        {"name": name},
+    )
+
+
+def new_filter(name, mixin):
+    return type(
+        _make_name(name, "filter"),
+        (Function, mixin),
+        {"name": name},
+    )
+
+
+# Built-in structural actions; source and filter classes are added lazily the
+# first time make() is called (LEN_KLASS records the initial size so the
+# registration below only runs once).
+KLASS = {"concat": Concat, "join": Join, "pipe": Pipe}
+
+LEN_KLASS = len(KLASS)
+
+
+def make(key, config, path):
+
+    if LEN_KLASS == len(KLASS):
+
+        # Load plugins
+        from anemoi.transform.filters import filter_registry as transform_filter_registry
+        from anemoi.transform.sources import source_registry as transform_source_registry
+
+        from anemoi.datasets.create.filters import filter_registry as dataset_filter_registry
+        from anemoi.datasets.create.sources import source_registry as dataset_source_registry
+
+        # Register sources, local first
+        for name in dataset_source_registry.registered:
+            if name not in KLASS:
+                KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin)
+
+        for name in transform_source_registry.registered:
+            if name not in KLASS:
+                KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin)
+
+        # Register filters, local first
+        for name in dataset_filter_registry.registered:
+            if name not in KLASS:
+                KLASS[name.replace("_", "-")] = new_filter(name, DatasetFilterMixin)
+
+        for name in transform_filter_registry.registered:
+            if name not in KLASS:
+                KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin)
+
+    return KLASS[key.replace("_", "-")](config, *path)
+
+
+def action_factory(data, *path):
+    assert isinstance(data, dict), f"Input data 
must be a dictionary {data}" + assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}" + + key, value = next(iter(data.items())) + return make(key, value, path) diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py deleted file mode 100644 index bd906bd03..000000000 --- a/src/anemoi/datasets/create/input/concat.py +++ /dev/null @@ -1,184 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from copy import deepcopy -from functools import cached_property -from typing import Any -from typing import Dict -from typing import List -from typing import Union - -from earthkit.data import FieldList - -from anemoi.datasets.dates import DatesProvider - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class ConcatResult(Result): - """Represents the result of concatenating multiple results.""" - - def __init__( - self, - context: object, - action_path: List[str], - group_of_dates: GroupOfDates, - results: List[Result], - **kwargs: Any, - ) -> None: - """Initializes a ConcatResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - kwargs : Any - Additional keyword arguments. - """ - super().__init__(context, action_path, group_of_dates) - self.results = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the concatenated datasource from all results.""" - ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - @property - def variables(self) -> List[str]: - """Returns the list of variables, ensuring all results have the same variables.""" - variables = None - for f in self.results: - if f.empty: - continue - if variables is None: - variables = f.variables - assert variables == f.variables, (variables, f.variables) - assert variables is not None, self.results - return variables - - def __repr__(self) -> str: - """Returns a string representation of the ConcatResult instance. - - Returns - ------- - str - A string representation of the ConcatResult instance. - """ - content = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class ConcatAction(Action): - """Represents an action that concatenates multiple actions based on their dates.""" - - def __init__(self, context: object, action_path: List[str], *configs: Dict[str, Any]) -> None: - """Initializes a ConcatAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : List[str] - The action path. 
- configs : Dict[str, Any] - The configuration dictionaries. - """ - super().__init__(context, action_path, *configs) - parts = [] - for i, cfg in enumerate(configs): - if "dates" not in cfg: - raise ValueError(f"Missing 'dates' in {cfg}") - cfg = deepcopy(cfg) - dates_cfg = cfg.pop("dates") - assert isinstance(dates_cfg, dict), dates_cfg - filtering_dates = DatesProvider.from_config(**dates_cfg) - action = action_factory(cfg, context, action_path + [str(i)]) - parts.append((filtering_dates, action)) - self.parts = parts - - def __repr__(self) -> str: - """Returns a string representation of the ConcatAction instance. - - Returns - ------- - str - A string representation of the ConcatAction instance. - """ - content = "\n".join([str(i) for i in self.parts]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> Union[ConcatResult, EmptyResult]: - """Selects the concatenated result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - Union[ConcatResult, EmptyResult] - The concatenated result or an empty result. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - results = [] - for filtering_dates, action in self.parts: - newdates = GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) - if newdates: - results.append(action.select(newdates)) - if not results: - return EmptyResult(self.context, self.action_path, group_of_dates) - - return ConcatResult(self.context, self.action_path, group_of_dates, results) - - def to_python(self) -> str: - """Returns the Python representation of the ConcatAction instance. - - Returns - ------- - str - The Python representation of the ConcatAction instance. - """ - - result = [] - - for i, (filtering_dates, action) in enumerate(self.parts): - result.append(f"{filtering_dates.to_python()}:{action.to_python()}") - - return f"r.concat({{{','.join(result)}}})" - - def python_prelude(self, prelude) -> None: - for filtering_dates, action in self.parts: - action.python_prelude(prelude) diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py deleted file mode 100644 index 35784dba7..000000000 --- a/src/anemoi/datasets/create/input/context.py +++ /dev/null @@ -1,89 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import textwrap -from typing import Any -from typing import List -from typing import Tuple -from typing import Union - -from anemoi.utils.humanize import plural - -from .trace import step -from .trace import trace - -LOG = logging.getLogger(__name__) - - -class Context: - """Class to handle the build context in the dataset creation process.""" - - def __init__(self) -> None: - """Initializes a Context instance.""" - # used_references is a set of reference paths that will be needed - self.used_references = set() - # results is a dictionary of reference path -> obj - self.results = {} - - def will_need_reference(self, key: Union[List, Tuple]) -> None: - """Marks a reference as needed. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. 
- """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - self.used_references.add(key) - - def notify_result(self, key: Union[List, Tuple], result: Any) -> None: - """Notifies that a result is available for a reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - result : Any - The result object. - """ - trace( - "🎯", - step(key), - "notify result", - textwrap.shorten(repr(result).replace(",", ", "), width=40), - plural(len(result), "field"), - ) - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.used_references: - if key in self.results: - raise ValueError(f"Duplicate result {key}") - self.results[key] = result - - def get_result(self, key: Union[List, Tuple]) -> Any: - """Retrieves the result for a given reference. - - Parameters - ---------- - key : Union[List, Tuple] - The reference key. - - Returns - ------- - Any - The result for the given reference. - """ - assert isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.results: - return self.results[key] - all_keys = sorted(list(self.results.keys())) - raise ValueError(f"Cannot find result {key} in {all_keys}") diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py new file mode 100644 index 000000000..26d449659 --- /dev/null +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -0,0 +1,63 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from abc import ABC +from abc import abstractmethod +from typing import Any + +import rich + +LOG = logging.getLogger(__name__) + + +class Context(ABC): + """Context for building input data.""" + + def __init__(self, /, argument: Any) -> None: + self.results = {} + self.cache = {} + self.argument = argument + + def trace(self, emoji, *message) -> None: + + rich.print(f"{emoji}: {message}") + + def register(self, data: Any, path: list[str]) -> Any: + + if not path: + return data + + rich.print(f"Registering data at path: {path}") + self.results[tuple(path)] = data + return data + + def resolve(self, config): + config = config.copy() + + for key, value in list(config.items()): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + path = tuple(value[2:-1].split(".")) + if path in self.results: + config[key] = self.results[path] + else: + raise KeyError(f"Path {path} not found in results: {self.results.keys()}") + + return config + + def create_source(self, config: Any) -> Any: + from anemoi.datasets.create.input.action import action_factory + + return action_factory(config) + + @abstractmethod + def empty_result(self) -> Any: ... + + @abstractmethod + def create_result(self, data: Any) -> Any: ... diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py new file mode 100644 index 000000000..c3456d89f --- /dev/null +++ b/src/anemoi/datasets/create/input/context/field.py @@ -0,0 +1,54 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from typing import Any +from typing import Dict + +from earthkit.data.core.order import build_remapping + +from ..result.field import FieldResult +from . import Context + + +class FieldContext(Context): + + def __init__( + self, + /, + argument: Any, + order_by: str, + flatten_grid: bool, + remapping: Dict[str, Any], + use_grib_paramid: bool, + ) -> None: + super().__init__(argument) + self.order_by = order_by + self.flatten_grid = flatten_grid + self.remapping = build_remapping(remapping) + self.use_grib_paramid = use_grib_paramid + + def empty_result(self) -> Any: + import earthkit.data as ekd + + return ekd.from_source("empty") + + def source_argument(self, argument: Any) -> Any: + return argument # .dates + + def filter_argument(self, argument: Any) -> Any: + return argument + + def create_result(self, data): + return FieldResult(self, data) + + def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: + from anemoi.datasets.dates.groups import GroupOfDates + + return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 5b6282469..f9811d178 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -20,7 +20,7 @@ from .action import Action from .action import action_factory from .misc import _tidy -from .result import Result +from .result.field import Result LOG = logging.getLogger(__name__) @@ -87,13 +87,10 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) - def python_prelude(self, prelude) -> str: + def python_code(self, code) -> str: for n, s in zip(self.names, self.sources): - s.python_prelude(prelude) - prelude.append(f"{n}={s.to_python()}") - - def to_python(self) -> str: - return self.input.to_python() + code.source(n, s.python_code(code)) + return code class DataSourcesResult(Result): diff --git a/src/anemoi/datasets/create/input/empty.py b/src/anemoi/datasets/create/input/empty.py deleted file mode 100644 index 410b4c973..000000000 --- a/src/anemoi/datasets/create/input/empty.py +++ /dev/null @@ -1,54 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import List - -from earthkit.data import FieldList - -from .misc import assert_fieldlist -from .result import Result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class EmptyResult(Result): - """Class to represent an empty result in the dataset creation process.""" - - empty = True - - def __init__(self, context: object, action_path: list, dates: object) -> None: - """Initializes an EmptyResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - dates : object - The dates object. 
- """ - super().__init__(context, action_path + ["empty"], dates) - - @cached_property - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns an empty datasource.""" - from earthkit.data import from_source - - return from_source("empty") - - @property - def variables(self) -> List[str]: - """Returns an empty list of variables.""" - return [] diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py deleted file mode 100644 index 9357d2178..000000000 --- a/src/anemoi/datasets/create/input/filter.py +++ /dev/null @@ -1,133 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import Type - -from earthkit.data import FieldList - -from .function import FunctionContext -from .misc import _tidy -from .misc import assert_fieldlist -from .step import StepAction -from .step import StepResult -from .template import notify_result -from .trace import trace_datasource - -LOG = logging.getLogger(__name__) - - -class FilterStepResult(StepResult): - @property - @notify_result - @assert_fieldlist - @trace_datasource - def datasource(self) -> FieldList: - """Returns the filtered datasource.""" - ds: FieldList = self.upstream_result.datasource - ds = ds.sel(**self.action.kwargs) - return _tidy(ds) - - -class FilterStepAction(StepAction): - """Represents an action to filter a step result.""" - - result_class: Type[FilterStepResult] = FilterStepResult - - -class StepFunctionResult(StepResult): - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource after applying the function.""" - - self.action.filter.context = FunctionContext(self) - try: - return _tidy( - self.action.filter.execute( - self.upstream_result.datasource, - *self.action.args[1:], - **self.action.kwargs, - ) - ) - - except Exception: - LOG.error(f"Error in {self.action.name}", exc_info=True) - raise - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. - - Returns - ------- - str - A string representation of the traced datasource. - """ - return f"{self.action.name}({self.group_of_dates})" - - -class FunctionStepAction(StepAction): - """Represents an action to apply a function to a step result.""" - - result_class: Type[StepFunctionResult] = StepFunctionResult - - def __init__( - self, - context: object, - action_path: list, - previous_step: StepAction, - name: str, - filter: Any, - config: dict, - *args: Any, - **kwargs: Any, - ) -> None: - """Initializes a FunctionStepAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - previous_step : StepAction - The previous step action. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. 
- """ - super().__init__(context, action_path, previous_step, *args, **kwargs) - self.name = name - self.filter = filter - self.config = config - - def to_python(self) -> str: - """Converts the action to Python code. - - Returns - ------- - str - The converted Python code. - """ - return self._to_python(self.name, self.config) - - def python_prelude(self, prelude) -> None: - pass diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py deleted file mode 100644 index 651b509b2..000000000 --- a/src/anemoi/datasets/create/input/function.py +++ /dev/null @@ -1,244 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import Dict - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .template import substitute -from .trace import trace -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class FunctionContext: - """A FunctionContext is passed to all functions, it will be used to pass information - to the functions from the other actions and filters and results. - """ - - def __init__(self, owner: Result) -> None: - """Initializes a FunctionContext instance. - - Parameters - ---------- - owner : object - The owner object. - """ - self.owner = owner - self.use_grib_paramid: bool = owner.context.use_grib_paramid - - def trace(self, emoji: str, *args: Any) -> None: - """Traces the given arguments with an emoji. - - Parameters - ---------- - emoji : str - The emoji to use. - *args : Any - The arguments to trace. - """ - trace(emoji, *args) - - def info(self, *args: Any, **kwargs: Any) -> None: - """Logs an info message. - - Parameters - ---------- - *args : Any - The arguments for the log message. - **kwargs : Any - The keyword arguments for the log message. - """ - LOG.info(*args, **kwargs) - - @property - def dates_provider(self) -> object: - """Returns the dates provider.""" - return self.owner.group_of_dates.provider - - @property - def partial_ok(self) -> bool: - """Returns whether partial results are acceptable.""" - return self.owner.group_of_dates.partial_ok - - def get_result(self, *args, **kwargs) -> Any: - return self.owner.context.get_result(*args, **kwargs) - - -class FunctionAction(Action): - """Represents an action that executes a function. - - Attributes - ---------- - name : str - The name of the function. - """ - - def __init__( - self, context: object, action_path: list, _name: str, source, config: dict, **kwargs: Dict[str, Any] - ) -> None: - """Initializes a FunctionAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - _name : str - The name of the function. - **kwargs : Dict[str, Any] - Additional keyword arguments. 
- """ - super().__init__(context, action_path, **kwargs) - self.name: str = _name - self.source = source - self.config = config - - def to_python(self) -> str: - """Returns the Python representation of the function action.""" - - return self._to_python(self.name, self.config) - - def python_prelude(self, prelude) -> str: - pass - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> "FunctionResult": - """Selects the function result for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - FunctionResult - The function result instance. - """ - return FunctionResult(self.context, self.action_path, group_of_dates, action=self) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionAction instance.""" - content: str = "" - content += ",".join([self._short_str(a) for a in self.args]) - content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]) - content = self._short_str(content) - return self._repr(_inline_=content, _indent_=" ") - - def _trace_select(self, group_of_dates: GroupOfDates) -> str: - """Traces the selection of the function for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - str - The trace string. - """ - return f"{self.name}({group_of_dates})" - - -class FunctionResult(Result): - """Represents the result of executing a function. - - Attributes - ---------- - action : Action - The action instance. - args : tuple - The positional arguments for the function. - kwargs : dict - The keyword arguments for the function. - """ - - def __init__(self, context: object, action_path: list, group_of_dates: GroupOfDates, action: Action) -> None: - """Initializes a FunctionResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - action : Action - The action instance. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(action, Action), type(action) - self.action: Action = action - - self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs)) - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Traces the datasource for the given arguments. - - Parameters - ---------- - *args : Any - The arguments. - **kwargs : Any - The keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.action.name}({self.group_of_dates})" - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the datasource for the function result.""" - # args, kwargs = resolve(self.context, (self.args, self.kwargs)) - self.action.source.context = FunctionContext(self) - - return _tidy( - self.action.source.execute( - list(self.group_of_dates), # Will provide a list of datetime objects - ) - ) - - def __repr__(self) -> str: - """Returns a string representation of the FunctionResult instance.""" - try: - return f"{self.action.name}({self.group_of_dates})" - except Exception: - return f"{self.__class__.__name__}(unitialised)" - - @property - def function(self) -> None: - """Raises NotImplementedError as this property is not implemented. - - Raises - ------ - NotImplementedError - Always raised. 
- """ - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") diff --git a/src/anemoi/datasets/create/input/join.py b/src/anemoi/datasets/create/input/join.py deleted file mode 100644 index 9fc81eac9..000000000 --- a/src/anemoi/datasets/create/input/join.py +++ /dev/null @@ -1,137 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from functools import cached_property -from typing import Any -from typing import List - -from earthkit.data import FieldList - -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .empty import EmptyResult -from .misc import _tidy -from .misc import assert_fieldlist -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class JoinResult(Result): - """Represents a result that combines multiple results. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - """ - - def __init__( - self, context: object, action_path: list, group_of_dates: GroupOfDates, results: List[Result], **kwargs: Any - ) -> None: - """Initializes a JoinResult instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - group_of_dates : GroupOfDates - The group of dates. - results : List[Result] - The list of results. - """ - super().__init__(context, action_path, group_of_dates) - self.results: List[Result] = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self) -> FieldList: - """Returns the combined datasource from all results.""" - ds: FieldList = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - def __repr__(self) -> str: - """Returns a string representation of the JoinResult instance.""" - content: str = "\n".join([str(i) for i in self.results]) - return self._repr(content) - - -class JoinAction(Action): - """Represents an action that combines multiple actions. - - Attributes - ---------- - context : object - The context object. - action_path : list - The action path. - actions : List[Action] - The list of actions. - """ - - def __init__(self, context: object, action_path: list, *configs: dict) -> None: - """Initializes a JoinAction instance. - - Parameters - ---------- - context : object - The context object. - action_path : list - The action path. - *configs : dict - The configuration dictionaries. 
- """ - super().__init__(context, action_path, *configs) - self.actions: List[Action] = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] - - def to_python(self) -> None: - return "(" + " + ".join([i.to_python() for i in self.actions]) + ")" - - def python_prelude(self, prelude) -> None: - for i in self.actions: - i.python_prelude(prelude) - - def __repr__(self) -> str: - """Returns a string representation of the JoinAction instance.""" - content: str = "\n".join([str(i) for i in self.actions]) - return self._repr(content) - - @trace_select - def select(self, group_of_dates: GroupOfDates) -> JoinResult: - """Selects the results for the given group of dates. - - Parameters - ---------- - group_of_dates : GroupOfDates - The group of dates. - - Returns - ------- - JoinResult - The combined result for the given group of dates. - """ - results: List[Result] = [a.select(group_of_dates) for a in self.actions] - return JoinResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/pipe.py b/src/anemoi/datasets/create/input/pipe.py deleted file mode 100644 index 7b8f062f3..000000000 --- a/src/anemoi/datasets/create/input/pipe.py +++ /dev/null @@ -1,77 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -import logging -from typing import Any - -from .action import Action -from .action import action_factory -from .step import step_factory -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class PipeAction(Action): - """A class to represent a pipeline of actions.""" - - def __init__(self, context: Any, action_path: list, *configs: dict) -> None: - """Initialize the PipeAction. - - Parameters - ---------- - context : Any - The context for the action. - action_path : list - The path of the action. - configs : dict - The configurations for the actions. - """ - super().__init__(context, action_path, *configs) - if len(configs) <= 1: - raise ValueError( - f"PipeAction requires at least two actions, got {len(configs)}\n{json.dumps(configs, indent=2)}" - ) - - self.actions: list = [] - - current: Any = action_factory(configs[0], context, action_path + ["0"]) - self.actions.append(current) - for i, c in enumerate(configs[1:]): - current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current) - self.actions.append(current) - self.last_step: Any = current - - @trace_select - def select(self, group_of_dates: Any) -> Any: - """Select data based on the group of dates. - - Parameters - ---------- - group_of_dates : Any - The group of dates to select data for. - - Returns - ------- - Any - The selected data. 
- """ - return self.last_step.select(group_of_dates) - - def __repr__(self) -> str: - """Return a string representation of the PipeAction.""" - return f"PipeAction({self.last_step})" - - def to_python(self) -> str: - return "(" + " | ".join([i.to_python() for i in self.actions]) + ")" - - def python_prelude(self, prelude) -> None: - for i in self.actions: - i.python_prelude(prelude) diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index bc121c5c9..6304fcecc 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -9,7 +9,6 @@ import logging -import warnings from collections import defaultdict from typing import Any from typing import Dict @@ -28,7 +27,7 @@ from .action import Action from .action import action_factory from .join import JoinResult -from .result import Result +from .result.field import Result from .trace import trace_select LOG = logging.getLogger(__name__) @@ -204,21 +203,6 @@ def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) self.day: int = day self.hour: Optional[int] = hour - def to_python(self) -> Dict[str, Any]: - """Convert the DateMapper to Python code. - - Returns - ------- - dict - The Python code representation of the DateMapper. - """ - return { - "mode": "climatology", - "year": self.year, - "day": self.day, - "hour": self.hour, - } - def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: """Transform the group of dates to the specified climatology dates. @@ -369,16 +353,6 @@ def __init__(self, context: Any, action_path: List[str], source: Any, mode: str, self.mode = mode self.kwargs = kwargs - def to_python(self) -> str: - """Convert the action to Python code.""" - warnings.warn("RepeatedDatesAction.to_python is still a work in progress") - args = {"mode": self.mode} - args.update(self.kwargs) - return self._to_python("repeated_dates", {"repeated_dates": args}, source=self.source.to_python()) - - def python_prelude(self, prelude: Any) -> None: - self.source.python_prelude(prelude) - @trace_select def select(self, group_of_dates: Any) -> JoinResult: """Select and transform the group of dates. diff --git a/src/anemoi/datasets/create/input/result/__init__.py b/src/anemoi/datasets/create/input/result/__init__.py new file mode 100644 index 000000000..03a00c51d --- /dev/null +++ b/src/anemoi/datasets/create/input/result/__init__.py @@ -0,0 +1,17 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
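+
+# Orientation note (added for illustration, not part of the original file):
+# concrete results subclass this base, e.g. FieldResult in result/field.py,
+#
+#     class FieldResult(Result):
+#         def __init__(self, context: Any, datasource: Any) -> None: ...
+#
+# pairing a datasource with the group of dates taken from the owning context.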
+ +import logging +from abc import ABC + +LOG = logging.getLogger(__name__) + + +class Result(ABC): + pass diff --git a/src/anemoi/datasets/create/input/result.py b/src/anemoi/datasets/create/input/result/field.py similarity index 87% rename from src/anemoi/datasets/create/input/result.py rename to src/anemoi/datasets/create/input/result/field.py index de5388fd6..dd238998e 100644 --- a/src/anemoi/datasets/create/input/result.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -26,9 +26,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from .action import ActionContext -from .trace import trace -from .trace import trace_datasource +from . import Result LOG = logging.getLogger(__name__) @@ -282,40 +280,22 @@ def sort(old_dic: DefaultDict[str, set]) -> Dict[str, List[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class Result: +class FieldResult(Result): """Class to represent the result of an action in the dataset creation process.""" empty: bool = False _coords_already_built: bool = False - def __init__(self, context: ActionContext, action_path: List[str], dates: Any) -> None: - """Initialize a Result instance. + def __init__(self, context: Any, datasource: Any) -> None: - Parameters - ---------- - context : ActionContext - The context in which the result exists. - action_path : list of str - The action path. - dates : Any - The dates associated with the result. - """ from anemoi.datasets.dates.groups import GroupOfDates - assert isinstance(dates, GroupOfDates), dates - - assert isinstance(context, ActionContext), type(context) - assert isinstance(action_path, list), action_path - self.context: Any = context - self.group_of_dates: Any = dates - self.action_path: List[str] = action_path - - @property - @trace_datasource - def datasource(self) -> Any: - """Retrieve the data source for the result.""" - self._raise_not_implemented() + self.datasource = datasource + self.group_of_dates = context.argument + assert isinstance( + self.group_of_dates, GroupOfDates + ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" @property def data_request(self) -> Dict[str, Any]: @@ -330,7 +310,7 @@ def get_cube(self) -> Any: Any The data cube. """ - trace("🧊", f"getting cube from {self.__class__.__name__}") + ds: Any = self.datasource remapping: Any = self.context.remapping @@ -523,66 +503,6 @@ def explain(self, ds: Any, *args: Any, remapping: Any, patches: Any) -> None: print() exit(1) - def _repr(self, *args: Any, _indent_: str = "\n", **kwargs: Any) -> str: - """Return the string representation of the Result instance. - - Parameters - ---------- - args : Any - Additional positional arguments. - _indent_ : str - Indentation string. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The string representation. - """ - more: str = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - dates: str = " no-dates" - if self.group_of_dates is not None: - dates = f" {len(self.group_of_dates)} dates" - dates += " (" - dates += "/".join(d.strftime("%Y-%m-%dT%H:%M") for d in self.group_of_dates) - if len(dates) > 100: - dates = dates[:100] + "..." 
- dates += ")" - - more = more[:5000] - txt: str = f"{self.__class__.__name__}:{dates}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def __repr__(self) -> str: - """Return the string representation of the Result instance.""" - return self._repr() - - def _raise_not_implemented(self) -> None: - """Raise a NotImplementedError indicating the method is not implemented.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_datasource(self, *args: Any, **kwargs: Any) -> str: - """Trace the data source for the result. - - Parameters - ---------- - args : Any - Additional positional arguments. - kwargs : Any - Additional keyword arguments. - - Returns - ------- - str - The trace string. - """ - return f"{self.__class__.__name__}({self.group_of_dates})" - def build_coords(self) -> None: """Build the coordinates for the result.""" if self._coords_already_built: diff --git a/src/anemoi/datasets/create/input/step.py b/src/anemoi/datasets/create/input/step.py deleted file mode 100644 index a2e3ccd42..000000000 --- a/src/anemoi/datasets/create/input/step.py +++ /dev/null @@ -1,203 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import warnings -from copy import deepcopy -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Type - -from .action import Action -from .action import ActionContext -from .context import Context -from .result import Result -from .template import notify_result -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class StepResult(Result): - """Represents the result of a step in the data processing pipeline.""" - - def __init__( - self, context: Context, action_path: List[str], group_of_dates: Any, action: Action, upstream_result: Result - ) -> None: - """Initialize a StepResult instance. - - Parameters - ---------- - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - group_of_dates - The group of dates associated with this step. - action - The action associated with this step. - upstream_result - The result of the upstream step. - """ - super().__init__(context, action_path, group_of_dates) - assert isinstance(upstream_result, Result), type(upstream_result) - self.upstream_result: Result = upstream_result - self.action: Action = action - - @property - @notify_result - @trace_datasource - def datasource(self) -> Any: - """Retrieve the datasource associated with this step result.""" - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - -class StepAction(Action): - """Represents an action that is part of a step in the data processing pipeline.""" - - result_class: Optional[Type[StepResult]] = None - - def __init__( - self, context: ActionContext, action_path: List[str], previous_step: Any, *args: Any, **kwargs: Any - ) -> None: - """Initialize a StepAction instance. - - Parameters - ---------- - context - The context in which the action is executed. - action_path - The path of actions leading to this step. 
- previous_step - The previous step in the pipeline. - """ - super().__init__(context, action_path, *args, **kwargs) - self.previous_step: Any = previous_step - - @trace_select - def select(self, group_of_dates: Any) -> StepResult: - """Select the result for a given group of dates. - - Parameters - ---------- - group_of_dates - The group of dates to select the result for. - - Returns - ------- - unknown - The result of the step. - """ - return self.result_class( - self.context, - self.action_path, - group_of_dates, - self, - self.previous_step.select(group_of_dates), - ) - - def __repr__(self) -> str: - """Return a string representation of the StepAction instance. - - Returns - ------- - unknown - String representation of the instance. - """ - return self._repr(self.previous_step, _inline_=str(self.kwargs)) - - -def step_factory(config: Dict[str, Any], context: ActionContext, action_path: List[str], previous_step: Any) -> Any: - """Factory function to create a step action based on the given configuration. - - Parameters - ---------- - config - The configuration dictionary for the step. - context - The context in which the step is executed. - action_path - The path of actions leading to this step. - previous_step - The previous step in the pipeline. - - Returns - ------- - unknown - An instance of a step action. - """ - - from .filter import FilterStepAction - from .filter import FunctionStepAction - - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - - config = deepcopy(config) - assert len(config) == 1, config - - key = list(config.keys())[0] - cls = dict( - filter=FilterStepAction, - # rename=RenameAction, - # remapping=RemappingAction, - ).get(key) - - if isinstance(config[key], list): - args, kwargs = config[key], {} - - if isinstance(config[key], dict): - args, kwargs = [], config[key] - - if isinstance(config[key], str): - args, kwargs = [config[key]], {} - - if cls is not None: - return cls(context, action_path, previous_step, *args, **kwargs) - - # Try filters from datasets filter registry - from anemoi.transform.filters import filter_registry as transform_filter_registry - - from ..filters import create_filter as create_datasets_filter - from ..filters import filter_registry as datasets_filter_registry - - if datasets_filter_registry.is_registered(key): - - if transform_filter_registry.is_registered(key): - warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries") - - filter = create_datasets_filter(None, config) - return FunctionStepAction( - context, - action_path + [key], - previous_step, - key, - filter, - config, - ) - - # Use filters from transform registry - - if transform_filter_registry.is_registered(key): - from ..filters.transform import TransformFilter - - return FunctionStepAction( - context, - action_path + [key], - previous_step, - key, - TransformFilter(context, key, config), - config, - ) - - raise ValueError(f"Unknown step action `{key}`") diff --git a/src/anemoi/datasets/create/input/template.py b/src/anemoi/datasets/create/input/template.py deleted file mode 100644 index 8ea1ec275..000000000 --- a/src/anemoi/datasets/create/input/template.py +++ /dev/null @@ -1,162 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
-# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -import re -from abc import ABC -from abc import abstractmethod -from functools import wraps -from typing import Any -from typing import Callable -from typing import List - -from .context import Context - -LOG = logging.getLogger(__name__) - - -def notify_result(method: Callable[..., Any]) -> Callable[..., Any]: - """Decorator to notify the context of the result of the method call. - - Parameters - ---------- - method : Callable[..., Any] - The method to wrap. - - Returns - ------- - Callable[..., Any] - The wrapped method. - """ - - @wraps(method) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - result: Any = method(self, *args, **kwargs) - self.context.notify_result(self.action_path, result) - return result - - return wrapper - - -class Substitution(ABC): - """Abstract base class for substitutions in templates.""" - - @abstractmethod - def resolve(self, context: Context) -> Any: - """Resolve the substitution using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - pass - - -class Reference(Substitution): - """A class to represent a reference to another value in the context.""" - - def __init__(self, context: Any, action_path: List[str]) -> None: - """Initialize a Reference instance. - - Parameters - ---------- - context : Any - The context in which the reference exists. - action_path : list of str - The action path to resolve. - """ - self.context: Any = context - self.action_path: List[str] = action_path - - def resolve(self, context: Context) -> Any: - """Resolve the reference using the given context. - - Parameters - ---------- - context : Context - The context to use for resolution. - - Returns - ------- - Any - The resolved value. - """ - return context.get_result(self.action_path) - - -def resolve(context: Context, x: Any) -> Any: - """Recursively resolve substitutions in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for resolution. - x : Union[tuple, list, dict, Substitution, Any] - The structure to resolve. - - Returns - ------- - Any - The resolved structure. - """ - if isinstance(x, tuple): - return tuple([resolve(context, y) for y in x]) - - if isinstance(x, list): - return [resolve(context, y) for y in x] - - if isinstance(x, dict): - return {k: resolve(context, v) for k, v in x.items()} - - if isinstance(x, Substitution): - return x.resolve(context) - - return x - - -def substitute(context: Context, x: Any) -> Any: - """Recursively substitute references in the given structure using the context. - - Parameters - ---------- - context : Context - The context to use for substitution. - x : Union[tuple, list, dict, str, Any] - The structure to substitute. - - Returns - ------- - Any - The substituted structure. 
- """ - if isinstance(x, tuple): - return tuple([substitute(context, y) for y in x]) - - if isinstance(x, list): - return [substitute(context, y) for y in x] - - if isinstance(x, dict): - return {k: substitute(context, v) for k, v in x.items()} - - if not isinstance(x, str): - return x - - if re.match(r"^\${[\.\w\-]+}$", x): - path = x[2:-1].split(".") - context.will_need_reference(path) - return Reference(context, path) - - return x diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py new file mode 100644 index 000000000..271845c2d --- /dev/null +++ b/src/anemoi/datasets/create/python.py @@ -0,0 +1,174 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import datetime +import json +import re + +from anemoi.utils.dates import frequency_to_string + +# input.python_prelude(code) +# code1 = "\n".join(prelude) +# rich.print(f"Input prelude:\n{code1}") +# code2 = input.to_python() + +# code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" + +# code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) + +# try: +# import black + +# return black.format_str(code, mode=black.Mode()) +# except ImportError: +# LOG.warning("Black not installed, skipping formatting") +# return code +RESERVED_KEYWORDS = ( + "and", + "or", + "not", + "is", + "in", + "if", + "else", + "elif", + "for", + "while", + "return", + "class", + "def", + "with", + "as", + "import", + "from", + "try", + "except", + "finally", + "raise", + "assert", + "break", + "continue", + "pass", +) + + +class PythonCode: + + def call(self, name, argument): + return PythonCall(name, argument) + + def sum(self, actions): + return PythonChain("+", actions) + + def pipe(self, actions): + return PythonChain("|", actions) + + def concat(self, argument): + return PythonConcat(argument) + + +class PythonConcat(PythonCode): + def __init__(self, argument): + self.argument = argument + + def __repr__(self): + return str(self.argument) + + +class PythonCall(PythonCode): + def __init__(self, name, argument): + self.name = name + self.argument = argument + + def __repr__(self): + name = self.name.replace("-", "_") + config = self.argument + + # def convert(obj): + # if isinstance(obj, datetime.datetime): + # return obj.isoformat() + # if isinstance(obj, datetime.date): + # return obj.isoformat() + # if isinstance(obj, datetime.timedelta): + # return frequency_to_string(obj) + # if isinstance(obj, PythonCode): + # return obj + # raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + + # config = json.loads(json.dumps(config, default=convert)) + + params = [] + for k, v in config.items(): + if isinstance(k, str): + if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: + return f"r.{name}({config})" + params.append(f"{k}={repr(v)}") + + # for k, v in extra.items(): + # params.append(f"{k}={v}") + + params = ",".join(params) + return f"r.{name}({params})" + # return f"{name}({config})" + return f"{self.name}({self.argument})" + + +class PythonChain(PythonCode): + def __init__(self, op, actions): + self.op = op + self.actions = actions + + def __repr__(self): + return 
"(" + self.op.join(repr(x) for x in self.actions) + ")" + + +def _python(name, config, **extra) -> str: + """Convert the action to Python code. + + Parameters + ---------- + name : str + The name of the action. + config : dict + The configuration for the action. + extra : Any + Additional keyword arguments. + + Returns + ------- + str + The Python code representation of the action. + """ + + name = name.replace("-", "_") + + def convert(obj): + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if isinstance(obj, datetime.date): + return obj.isoformat() + if isinstance(obj, datetime.timedelta): + return frequency_to_string(obj) + raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + + config = json.loads(json.dumps(config, default=convert)) + + params = [] + for k, v in config.items(): + if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: + return f"r.{name}({config})" + params.append(f"{k}={repr(v)}") + + for k, v in extra.items(): + params.append(f"{k}={v}") + + params = ",".join(params) + return f"r.{name}({params})" + # return f"{name}({config})" diff --git a/src/anemoi/datasets/create/sources/accumulations.py b/src/anemoi/datasets/create/sources/accumulations.py index a70af9bdd..405fd4713 100644 --- a/src/anemoi/datasets/create/sources/accumulations.py +++ b/src/anemoi/datasets/create/sources/accumulations.py @@ -459,12 +459,13 @@ def _mars_date_time_step( A tuple representing the MARS date-time step. """ assert user_date is None, user_date - assert not frequency, frequency steps = (step1 + add_step, step2 + add_step) if steps[0] == 0: steps = (steps[1],) + assert frequency == 0 or frequency == (step2 - step1), frequency + return ( base_date.year * 10000 + base_date.month * 100 + base_date.day, base_date.hour * 100 + base_date.minute, @@ -824,6 +825,11 @@ def _compute_accumulations( step1, step2 = user_accumulation_period assert step1 < step2, user_accumulation_period + if accumulations_reset_frequency is not None: + AccumulationClass = AccumulationFromLastReset + else: + AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep + if data_accumulation_period is None: data_accumulation_period = user_accumulation_period[1] - user_accumulation_period[0] @@ -838,11 +844,6 @@ def _compute_accumulations( base_times = [t // 100 if t > 100 else t for t in base_times] - if accumulations_reset_frequency is not None: - AccumulationClass = AccumulationFromLastReset - else: - AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep - mars_date_time_steps = AccumulationClass.mars_date_time_steps( dates=dates, step1=step1, diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index 921469025..1958820c4 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -47,7 +47,7 @@ def constants(context: Any, dates: List[str], template: Dict[str, Any], param: s if len(template) == 0: raise ValueError("Forcings template is empty.") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute: Any = constants diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index e1944e151..8e2977273 100644 --- a/src/anemoi/datasets/create/sources/forcings.py 
+++ b/src/anemoi/datasets/create/sources/forcings.py @@ -36,7 +36,7 @@ def forcings(context: Any, dates: List[str], template: str, param: str) -> Any: Loaded forcing data. """ context.trace("✅", f"from_source(forcings, {template}, {param}") - return from_source("forcings", source_or_dataset=template, date=dates, param=param) + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) execute = forcings diff --git a/src/anemoi/datasets/create/sources/grib.py b/src/anemoi/datasets/create/sources/grib.py index 8bedd4519..f87bf679c 100644 --- a/src/anemoi/datasets/create/sources/grib.py +++ b/src/anemoi/datasets/create/sources/grib.py @@ -124,7 +124,7 @@ def execute( dates = [d.isoformat() for d in dates] for path in given_paths: - paths = Pattern(path, ignore_missing_keys=True).substitute(*args, date=dates, **kwargs) + paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs) for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): if name in kwargs: diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index d72d0b3f4..c76a11c9b 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -14,8 +14,6 @@ from typing import Any from typing import Callable -from anemoi.datasets.create.input.template import resolve - from ..source import Source from . import source_registry @@ -71,13 +69,14 @@ def __call__(self, execute: Callable) -> Callable: name = f"Legacy{self.name.title()}Source" source = ".".join([execute.__module__, execute.__name__]) - def execute_wrapper(self, dates) -> Any: + def execute_wrapper(self, context, dates) -> Any: """Wrapper method to call the execute function.""" - args, kwargs = resolve(self.context, (self.args, self.kwargs)) + # args, kwargs = resolve(context, (self.args, self.kwargs)) + args, kwargs = self.args, self.kwargs try: - return execute(self.context, dates, *args, **kwargs) + return execute(context, dates, *args, **kwargs) except TypeError: LOG.error(f"Error executing source {this.name} from {source}") LOG.error(f"Function signature is: {inspect.signature(execute)}") diff --git a/src/anemoi/datasets/create/sources/patterns.py b/src/anemoi/datasets/create/sources/patterns.py index f3e6334a8..dc105289f 100644 --- a/src/anemoi/datasets/create/sources/patterns.py +++ b/src/anemoi/datasets/create/sources/patterns.py @@ -79,6 +79,6 @@ def iterate_patterns( kwargs["date"] = dates for path in given_paths: - paths = Pattern(path, ignore_missing_keys=True).substitute(**kwargs) + paths = Pattern(path).substitute(allow_extra=True, **kwargs) for path in _expand(paths): yield path, dates diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py new file mode 100644 index 000000000..b710bcbbe --- /dev/null +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -0,0 +1,44 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from . 
import source_registry +from .xarray import XarraySourceBase + + +@source_registry.register("planetary_computer") +class PlanetaryComputerSource(XarraySourceBase): + """An Xarray data source for the planetary_computer.""" + + emoji = "🪐" + + def __init__(self, context, data_catalog_id, version="v1", *args, **kwargs: dict): + + import planetary_computer + import pystac_client + + self.data_catalog_id = data_catalog_id + self.flavour = kwargs.pop("flavour", None) + self.patch = kwargs.pop("patch", None) + self.options = kwargs.pop("options", {}) + + catalog = pystac_client.Client.open( + f"https://planetarycomputer.microsoft.com/api/stac/{version}/", + modifier=planetary_computer.sign_inplace, + ) + collection = catalog.get_collection(self.data_catalog_id) + + asset = collection.assets["zarr-abfs"] + + if "xarray:storage_options" in asset.extra_fields: + self.options["storage_options"] = asset.extra_fields["xarray:storage_options"] + + self.options.update(asset.extra_fields["xarray:open_kwargs"]) + + super().__init__(context, url=asset.href, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py new file mode 100644 index 000000000..d092f08ad --- /dev/null +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -0,0 +1,319 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from collections import defaultdict +from typing import Any +from typing import Dict +from typing import Generator +from typing import Optional +from typing import Set +from typing import Tuple + +import numpy as np +import rich +from anemoi.transform.fields import new_field_with_valid_datetime +from anemoi.transform.fields import new_fieldlist_from_list +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry + +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +LOG = logging.getLogger(__name__) + + +class Action: + pass + + +class Result: + pass + + +class DateMapper: + """A factory class to create DateMapper instances based on the given mode.""" + + @staticmethod + def from_mode(mode: str, source: Any, config: Dict[str, Any]) -> "DateMapper": + """Create a DateMapper instance based on the given mode. + + Parameters + ---------- + mode : str + The mode to use for the DateMapper. + source : Any + The data source. + config : dict + Configuration parameters. + + Returns + ------- + DateMapper + An instance of DateMapper. 
+ """ + MODES: dict = dict( + closest=DateMapperClosest, + climatology=DateMapperClimatology, + constant=DateMapperConstant, + ) + + if mode not in MODES: + raise ValueError(f"Invalid mode for DateMapper: {mode}") + + return MODES[mode](source, **config) + + +class DateMapperClosest(DateMapper): + """A DateMapper implementation that maps dates to the closest available dates.""" + + def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", skip_all_nans: bool = False) -> None: + """Initialize DateMapperClosest. + + Parameters + ---------- + source : Any + The data source. + frequency : str + Frequency of the dates. + maximum : str + Maximum time delta. + skip_all_nans : bool + Whether to skip all NaN values. + """ + self.source: Any = source + self.maximum: Any = frequency_to_timedelta(maximum) + self.frequency: Any = frequency_to_timedelta(frequency) + self.skip_all_nans: bool = skip_all_nans + self.tried: Set[Any] = set() + self.found: Set[Any] = set() + + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + """Transform the group of dates to the closest available dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + asked_dates = list(group_of_dates) + if not asked_dates: + return [] + + to_try = set() + for date in asked_dates: + start = date + while start >= date - self.maximum: + to_try.add(start) + start -= self.frequency + + end = date + while end <= date + self.maximum: + to_try.add(end) + end += self.frequency + + to_try = sorted(to_try - self.tried) + info = {k: "no-data" for k in to_try} + + if not to_try: + LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}") + # return [] + + if to_try: + result = self.source.select( + GroupOfDates( + sorted(to_try), + group_of_dates.provider, + partial_ok=True, + ) + ) + + cnt = 0 + for f in result.datasource: + cnt += 1 + # We could keep the fields in a dictionary, but we don't want to keep the fields in memory + date = as_datetime(f.metadata("valid_datetime")) + + if self.skip_all_nans: + if np.isnan(f.to_numpy()).all(): + LOG.warning(f"Skipping {date} because all values are NaN") + info[date] = "all-nans" + continue + + info[date] = "ok" + self.found.add(date) + + if cnt == 0: + raise ValueError(f"No data found for {group_of_dates} in {self.source}") + + self.tried.update(to_try) + + if not self.found: + for k, v in info.items(): + LOG.warning(f"{k}: {v}") + + raise ValueError(f"No matching data found for {asked_dates} in {self.source}") + + new_dates = defaultdict(list) + + for date in asked_dates: + best = None + for found_date in sorted(self.found): + delta = abs(date - found_date) + # With < we prefer the first date + # With <= we prefer the last date + if best is None or delta <= best[0]: + best = delta, found_date + new_dates[best[1]].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperClimatology(DateMapper): + """A DateMapper implementation that maps dates to specified climatology dates.""" + + def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) -> None: + """Initialize DateMapperClimatology. + + Parameters + ---------- + source : Any + The data source. + year : int + The year to map to. 
+ day : int + The day to map to. + hour : Optional[int] + The hour to map to. + """ + self.year: int = year + self.day: int = day + self.hour: Optional[int] = hour + + def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + """Transform the group of dates to the specified climatology dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + dates = list(group_of_dates) + if not dates: + return [] + + new_dates = defaultdict(list) + for date in dates: + new_date = date.replace(year=self.year, day=self.day) + if self.hour is not None: + new_date = new_date.replace(hour=self.hour, minute=0, second=0) + new_dates[new_date].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperConstant(DateMapper): + """A DateMapper implementation that maps dates to a constant date.""" + + def __init__(self, source: Any, date: Optional[Any] = None) -> None: + """Initialize DateMapperConstant. + + Parameters + ---------- + source : Any + The data source. + date : Optional[Any] + The constant date to map to. + """ + self.source: Any = source + self.date: Optional[Any] = date + + def transform(self, group_of_dates: Any) -> Tuple[Any, Any]: + """Transform the group of dates to a constant date. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Tuple[Any, Any] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + if self.date is None: + return [ + ( + GroupOfDates([], group_of_dates.provider), + group_of_dates, + ) + ] + + return [ + ( + GroupOfDates([self.date], group_of_dates.provider), + group_of_dates, + ) + ] + + +@source_registry.register("repeated_dates") +class RepeatedDatesSource(Source): + + def __init__(self, owner, source: Any, mode: str, **kwargs) -> None: + self.mapper = DateMapper.from_mode(mode, source, kwargs) + self.source = source + + def execute(self, context, group_of_dates): + source = context.create_source(self.source) + + result = [] + for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): + rich.print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") + source_results = source(context, one_date_group) + for field in source_results: + for date in many_dates_group: + result.append(new_field_with_valid_datetime(field, date)) + + return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index 4f4edb46f..665cfdad3 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -20,7 +20,6 @@ from earthkit.data.core.fieldlist import MultiFieldList from anemoi.datasets.create.sources.patterns import iterate_patterns -from anemoi.datasets.data.stores import name_to_zarr_store from ..legacy import legacy_source from .fieldlist import XarrayFieldList @@ -89,37 +88,22 @@ def load_one( The loaded dataset. """ - """ - We manage the S3 client ourselves, bypassing fsspec and s3fs layers, because sometimes something on the stack - zarr/fsspec/s3fs/boto3 (?) 
seem to flags files as missing when they actually are not (maybe when S3 reports some sort of - connection error). In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs. - See https://github.com/pydata/xarray/issues/8842 - - We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`. - """ - if options is None: options = {} context.trace(emoji, dataset, options, kwargs) - if isinstance(dataset, str) and ".zarr" in dataset: - data = xr.open_zarr(name_to_zarr_store(dataset), **options) - elif "planetarycomputer" in dataset: - store = name_to_zarr_store(dataset) - if "store" in store: - data = xr.open_zarr(**store) - if "filename_or_obj" in store: - data = xr.open_dataset(**store) - else: - data = xr.open_dataset(dataset, **options) + if isinstance(dataset, str) and dataset.endswith(".zarr"): + # If the dataset is a zarr store, we need to use the zarr engine + options["engine"] = "zarr" + + data = xr.open_dataset(dataset, **options) fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch) if len(dates) == 0: result = fs.sel(**kwargs) else: - print("dates", dates, kwargs) result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) if len(result) == 0: @@ -130,7 +114,7 @@ def load_one( a = ["valid_datetime", k.metadata("valid_datetime", default=None)] for n in kwargs.keys(): a.extend([n, k.metadata(n, default=None)]) - print([str(x) for x in a]) + LOG.warning(f"{[str(x) for x in a]}") if i > 16: break diff --git a/src/anemoi/datasets/create/sources/xarray_support/coordinates.py b/src/anemoi/datasets/create/sources/xarray_support/coordinates.py index 58df7ad65..161f28b8a 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/coordinates.py +++ b/src/anemoi/datasets/create/sources/xarray_support/coordinates.py @@ -95,6 +95,7 @@ class Coordinate: is_member = False is_x = False is_y = False + is_point = False def __init__(self, variable: xr.DataArray) -> None: """Initialize the coordinate. 
@@ -390,6 +391,13 @@ def normalise(self, value: Any) -> Any: return value +class PointCoordinate(Coordinate): + """Coordinate class for point data.""" + + is_point = True + mars_names = ("point",) + + class LongitudeCoordinate(Coordinate): """Coordinate class for longitude.""" diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index d46613474..6f4ecca7b 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -96,13 +96,10 @@ def __init__(self, owner: Any, selection: Any) -> None: if alias not in self._md: self._md[alias] = value - # print(values.ndim, values.shape, selection.dims) # By now, the only dimensions should be latitude and longitude self._shape = tuple(list(self.selection.shape)[-2:]) if math.prod(self._shape) != math.prod(self.selection.shape): - print(self.selection.ndim, self.selection.shape) - print(self.selection) - raise ValueError("Invalid shape for selection") + raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}") @property def shape(self) -> Tuple[int, int]: diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 4df374148..562401aae 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -26,6 +26,7 @@ from .coordinates import LatitudeCoordinate from .coordinates import LevelCoordinate from .coordinates import LongitudeCoordinate +from .coordinates import PointCoordinate from .coordinates import ScalarCoordinate from .coordinates import StepCoordinate from .coordinates import TimeCoordinate @@ -134,6 +135,10 @@ def _guess(self, coordinate: xr.DataArray, coord: Hashable) -> Coordinate: d: Optional[Coordinate] = None + d = self._is_point(coordinate, attributes) + if d is not None: + return d + d = self._is_longitude(coordinate, attributes) if d is not None: return d @@ -308,9 +313,9 @@ def _x_y_provided(self, x: Any, y: Any, variable: Any) -> Any: return self._grid_cache[(x.name, y.name, dim_vars)] grid_mapping = variable.attrs.get("grid_mapping", None) - if grid_mapping is not None: - print(f"grid_mapping: {grid_mapping}") - print(self.ds[grid_mapping]) + # if grid_mapping is not None: + # print(f"grid_mapping: {grid_mapping}") + # print(self.ds[grid_mapping]) if grid_mapping is None: LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'") @@ -392,6 +397,10 @@ def _is_longitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Op """ pass + @abstractmethod + def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]: + pass + @abstractmethod def _is_latitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LatitudeCoordinate]: """Checks if the coordinate is a latitude. 
@@ -550,6 +559,15 @@ def __init__(self, ds: xr.Dataset) -> None:
         """
         super().__init__(ds)
 
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        if attributes.standard_name in ["cell", "station", "poi", "point"]:
+            return PointCoordinate(c)
+
+        if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+            return PointCoordinate(c)
+
+        return None
+
     def _is_longitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LongitudeCoordinate]:
         """Checks if the coordinate is a longitude.
 
@@ -750,6 +768,9 @@ def _is_level(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Option
         if attributes.standard_name == "air_pressure" and attributes.units == "hPa":
             return LevelCoordinate(c, "pl")
 
+        if attributes.long_name == "pressure" and attributes.units in ["hPa", "Pa"]:
+            return LevelCoordinate(c, "pl")
+
         if attributes.name == "level":
             return LevelCoordinate(c, "pl")
 
@@ -759,9 +780,6 @@ def _is_level(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Option
         if attributes.standard_name == "depth":
             return LevelCoordinate(c, "depth")
 
-        if attributes.name == "vertical" and attributes.units == "hPa":
-            return LevelCoordinate(c, "pl")
-
         return None
 
     def _is_number(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[EnsembleCoordinate]:
@@ -1040,3 +1058,23 @@ def _is_number(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optio
         return EnsembleCoordinate(c)
 
     return None
+
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        """Checks if the coordinate is a point coordinate using the flavour rules.
+
+        Parameters
+        ----------
+        c : xr.DataArray
+            The coordinate to check.
+        attributes : CoordinateAttributes
+            The attributes of the coordinate.
+
+        Returns
+        -------
+        Optional[PointCoordinate]
+            The PointCoordinate if matched, else None.
+        """
+        if self._match(c, "point", attributes):
+            return PointCoordinate(c)
+
+        return None
diff --git a/src/anemoi/datasets/create/sources/xarray_support/patch.py b/src/anemoi/datasets/create/sources/xarray_support/patch.py
index 29ea620dd..da4ce6a9c 100644
--- a/src/anemoi/datasets/create/sources/xarray_support/patch.py
+++ b/src/anemoi/datasets/create/sources/xarray_support/patch.py
@@ -61,9 +61,50 @@ def patch_coordinates(ds: xr.Dataset, coordinates: List[str]) -> Any:
     return ds
 
 
+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+    """Rename variables in the dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    renames : dict[str, str]
+        Mapping from old variable names to new variable names.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+    return ds.rename(renames)
+
+
+def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: List[str]) -> Any:
+    """Sort the coordinates of the dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    sort_coordinates : List[str]
+        The coordinates to sort.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+
+    for name in sort_coordinates:
+        ds = ds.sortby(name)
+    return ds
+
+
 PATCHES = {
     "attributes": patch_attributes,
     "coordinates": patch_coordinates,
+    "rename": patch_rename,
+    "sort_coordinates": patch_sort_coordinate,
 }
 
 
@@ -82,7 +123,9 @@ def patch_dataset(ds: xr.Dataset, patch: Dict[str, Dict[str, Any]]) -> Any:
     Any
         The patched dataset.
""" - for what, values in patch.items(): + + ORDER = ["coordinates", "attributes", "rename", "sort_coordinates"] + for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])): if what not in PATCHES: raise ValueError(f"Unknown patch type {what!r}") diff --git a/src/anemoi/datasets/create/sources/xarray_support/variable.py b/src/anemoi/datasets/create/sources/xarray_support/variable.py index 7b0c67439..ee4e20ac7 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/create/sources/xarray_support/variable.py @@ -82,8 +82,12 @@ def __init__( self.time = time - self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid) - self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid} + self.shape = tuple( + len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point + ) + self.names = { + c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point + } self.by_name = {c.variable.name: c for c in coordinates} # We need that alias for the time dimension diff --git a/src/anemoi/datasets/data/complement.py b/src/anemoi/datasets/data/complement.py index 1784a61fe..1b503ba06 100644 --- a/src/anemoi/datasets/data/complement.py +++ b/src/anemoi/datasets/data/complement.py @@ -7,7 +7,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. - import datetime import logging from abc import abstractmethod @@ -19,6 +18,7 @@ from typing import Set from typing import Tuple +import numpy as np from numpy.typing import NDArray from ..grids import nearest_grid_points @@ -85,6 +85,7 @@ def __init__( for v in self._source.variables: if v not in self._target.variables: self._variables.append(v) + LOG.info(f"The following variables will be complemented: {self._variables}") if not self._variables: raise ValueError("Augment: no missing variables") @@ -96,9 +97,11 @@ def variables(self) -> List[str]: @property def statistics(self) -> Dict[str, NDArray[Any]]: - """Returns the statistics of the complemented dataset.""" - index = [self._source.name_to_index[v] for v in self._variables] - return {k: v[index] for k, v in self._source.statistics.items()} + datasets = [self._source, self._target] + return { + k: [d.statistics[k][d.name_to_index[i]] for d in datasets for i in d.variables if i in self.variables] + for k in datasets[0].statistics + } def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]: index = [self._source.name_to_index[v] for v in self._variables] @@ -120,7 +123,11 @@ def shape(self) -> Shape: @property def variables_metadata(self) -> Dict[str, Any]: """Returns the metadata of the variables to be added to the target dataset.""" - return {k: v for k, v in self._source.variables_metadata.items() if k in self._variables} + # Merge the two dicts first + all_meta = {**self._source.variables_metadata, **self._target.variables_metadata} + + # Filter to keep only desired variables + return {k: v for k, v in all_meta.items() if k in self._variables} def check_same_variables(self, d1: Dataset, d2: Dataset) -> None: """Checks if the variables in two datasets are the same. 
@@ -231,7 +238,7 @@ def _get_tuple(self, index: TupleIndex) -> NDArray[Any]: class ComplementNearest(Complement): """A class to complement a target dataset with variables from a source dataset using nearest neighbor interpolation.""" - def __init__(self, target: Any, source: Any, max_distance: float = None) -> None: + def __init__(self, target: Any, source: Any, max_distance: float = None, k: int = 1) -> None: """Initializes the ComplementNearest class. Parameters @@ -242,17 +249,25 @@ def __init__(self, target: Any, source: Any, max_distance: float = None) -> None The source dataset. max_distance : float, optional The maximum distance for nearest neighbor interpolation, default is None. + k : int, optional + The number of k closest neighbors to consider for interpolation """ super().__init__(target, source) - self._nearest_grid_points = nearest_grid_points( + self.k = k + self._distances, self._nearest_grid_points = nearest_grid_points( self._source.latitudes, self._source.longitudes, self._target.latitudes, self._target.longitudes, max_distance=max_distance, + k=k, ) + if k == 1: + self._distances = np.expand_dims(self._distances, axis=1) + self._nearest_grid_points = np.expand_dims(self._nearest_grid_points, axis=1) + def check_compatibility(self, d1: Dataset, d2: Dataset) -> None: """Checks the compatibility of two datasets for nearest neighbor interpolation. @@ -285,7 +300,19 @@ def _get_tuple(self, index: TupleIndex) -> NDArray[Any]: source_data = self._source[index[0], source_index, index[2], ...] target_data = source_data[..., self._nearest_grid_points] - result = target_data[..., index[3]] + epsilon = 1e-8 # prevent division by zero + weights = 1.0 / (self._distances + epsilon) + weights = weights.astype(target_data.dtype) + weights /= weights.sum(axis=1, keepdims=True) # normalize + + # Reshape weights to broadcast correctly + # Add leading singleton dimensions so it matches target_data shape + while weights.ndim < target_data.ndim: + weights = np.expand_dims(weights, axis=0) + + # Compute weighted average along the last dimension + final_point = np.sum(target_data * weights, axis=-1) + result = final_point[..., index[3]] return apply_index_to_slices_changes(result, changes) @@ -330,6 +357,13 @@ def complement_factory(args: Tuple, kwargs: dict) -> Dataset: "nearest": ComplementNearest, }[interpolation] - complement = Class(target=target, source=source)._subset(**kwargs) + if interpolation == "nearest": + k = kwargs.pop("k", "1") + complement = Class(target=target, source=source, k=k)._subset(**kwargs) + + else: + complement = Class(target=target, source=source)._subset(**kwargs) + + joined = _open_dataset([target, complement]) - return _open_dataset([target, complement], reorder=source.variables) + return _open_dataset(joined, reorder=sorted(joined.variables)) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index f8592bb3d..0267022d1 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -179,6 +179,19 @@ def __subset(self, **kwargs: Any) -> "Dataset": if "start" in kwargs or "end" in kwargs: start = kwargs.pop("start", None) end = kwargs.pop("end", None) + padding = kwargs.pop("padding", None) + + if padding: + if padding != "empty": + raise ValueError(f"Only 'empty' padding is supported, got {padding=}") + from .padded import Padded + + frequency = kwargs.pop("frequency", self.frequency) + return ( + Padded(self, start, end, frequency, dict(start=start, end=end, frequency=frequency)) + 
._subset(**kwargs) + .mutate() + ) from .subset import Subset @@ -724,6 +737,9 @@ def grids(self) -> TupleIndex: """Return the grid shape of the dataset.""" return (self.shape[-1],) + def empty_item(self) -> NDArray[Any]: + return np.zeros((*self.shape[1:-1], 0), dtype=self.dtype) + def _check(self) -> None: """Check for overridden private methods in the dataset.""" common = Dataset.__dict__.keys() & self.__class__.__dict__.keys() @@ -1075,3 +1091,16 @@ def get_dataset_names(self, names: Set[str]) -> None: The dataset names. """ pass + + def get_latitudes(self, i): + return self.get_aux(i)[0] + + def get_longitudes(self, i): + return self.get_aux(i)[1] + + def get_timedeltas(self, i): + return self.get_aux(i)[2] + + def get_aux(self, i): + # need to decide if Fields datasets need to implement this + raise NotImplementedError(f"get_aux is not implemented for this dataset, {type(self)}") diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index 09d936efa..bb922e96d 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -330,8 +330,14 @@ def check_same_grid(self, d1: Dataset, d2: Dataset) -> None: ValueError If the grids are not the same. """ - if (d1.latitudes != d2.latitudes).any() or (d1.longitudes != d2.longitudes).any(): - raise ValueError(f"Incompatible grid ({d1} {d2})") + + # note: not a proper implementation, should be handled + # in a more consolidated way ... + rtol = 1.0e-7 + if not np.allclose(d1.latitudes, d2.latitudes, rtol=rtol) or not np.allclose( + d1.longitudes, d2.longitudes, rtol=rtol + ): + raise ValueError(f"Incompatible grid ({d1.longitudes} {d2.longitudes})") def check_same_shape(self, d1: Dataset, d2: Dataset) -> None: """Checks if the shapes of two datasets are the same. diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 2f9f77b25..b5523ef85 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -11,6 +11,7 @@ import calendar import datetime import logging +import os from pathlib import PurePath from typing import TYPE_CHECKING from typing import Any @@ -22,6 +23,7 @@ import numpy as np import zarr +from anemoi.utils.config import load_any_dict_format from anemoi.utils.config import load_config as load_settings from numpy.typing import NDArray @@ -108,7 +110,10 @@ def round_datetime(d: np.datetime64, dates: NDArray[np.datetime64], up: bool) -> def _as_date( - d: Union[int, str, np.datetime64, datetime.date], dates: NDArray[np.datetime64], last: bool + d: Union[int, str, np.datetime64, datetime.date], + dates: NDArray[np.datetime64], + last: bool, + frequency: Optional[datetime.timedelta] = None, ) -> np.datetime64: """Convert a date to a numpy datetime64 object, rounding to the nearest date in a list of dates. @@ -120,6 +125,8 @@ def _as_date( The list of dates. last : bool Whether to round to the last date. + frequency : Optional[datetime.timedelta] + The frequency of the dataset. 
Returns ------- @@ -142,30 +149,49 @@ def _as_date( pass if isinstance(d, int): + delta = frequency + if delta is None: + delta = np.timedelta64(1, "s") + delta = np.timedelta64(delta, "s") + if len(str(d)) == 4: year = d if last: - return _as_date(np.datetime64(f"{year:04}-12-31T23:59:59"), dates, last) + year = year + 1 + npdate = np.datetime64(f"{year:04}-01-01T00:00:00") + return _as_date(npdate - delta, dates, last, frequency) else: - return _as_date(np.datetime64(f"{year:04}-01-01T00:00:00"), dates, last) + return _as_date(np.datetime64(f"{year:04}-01-01T00:00:00"), dates, last, frequency) if len(str(d)) == 6: year = d // 100 month = d % 100 if last: - _, last_day = calendar.monthrange(year, month) - return _as_date(np.datetime64(f"{year:04}-{month:02}-{last_day:02}T23:59:59"), dates, last) + month = month + 1 + if month > 12: + month = 1 + year += 1 + npdate = np.datetime64(f"{year:04}-{month:02}-01T00:00:00") + return _as_date(npdate - delta, dates, last, frequency) else: - return _as_date(np.datetime64(f"{year:04}-{month:02}-01T00:00:00"), dates, last) + return _as_date(np.datetime64(f"{year:04}-{month:02}-01T00:00:00"), dates, last, frequency) if len(str(d)) == 8: year = d // 10000 month = (d % 10000) // 100 day = d % 100 if last: - return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T23:59:59"), dates, last) + day = day + 1 + if day > calendar.monthrange(year, month)[1]: + day = 1 + month += 1 + if month > 12: + month = 1 + year += 1 + npdate = np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00") + return _as_date(npdate - delta, dates, last, frequency) else: - return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00"), dates, last) + return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00"), dates, last, frequency) if isinstance(d, str): @@ -201,19 +227,20 @@ def isfloat(s: str) -> bool: np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}"), dates, last, + frequency, ) if "-" in d: assert ":" not in d bits = d.split("-") if len(bits) == 1: - return _as_date(int(bits[0]), dates, last) + return _as_date(int(bits[0]), dates, last, frequency) if len(bits) == 2: - return _as_date(int(bits[0]) * 100 + int(bits[1]), dates, last) + return _as_date(int(bits[0]) * 100 + int(bits[1]), dates, last, frequency) if len(bits) == 3: - return _as_date(int(bits[0]) * 10000 + int(bits[1]) * 100 + int(bits[2]), dates, last) + return _as_date(int(bits[0]) * 10000 + int(bits[1]) * 100 + int(bits[2]), dates, last, frequency) if ":" in d: assert len(d) == 5 @@ -225,12 +252,16 @@ def isfloat(s: str) -> bool: month = first.month day = first.day - return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour}:00:00"), dates, last) + return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour}:00:00"), dates, last, frequency) raise NotImplementedError(f"Unsupported date: {d} ({type(d)})") -def as_first_date(d: Union[int, str, np.datetime64, datetime.date], dates: NDArray[np.datetime64]) -> np.datetime64: +def as_first_date( + d: Union[int, str, np.datetime64, datetime.date], + dates: NDArray[np.datetime64], + frequency: Optional[datetime.timedelta] = None, +) -> np.datetime64: """Convert a date to the first date in a list of dates. Parameters @@ -239,16 +270,22 @@ def as_first_date(d: Union[int, str, np.datetime64, datetime.date], dates: NDArr The date to convert. dates : NDArray[np.datetime64] The list of dates. + frequency : Optional[datetime.timedelta] + The frequency of the dataset. 
Returns ------- np.datetime64 The first date. """ - return _as_date(d, dates, last=False) + return _as_date(d, dates, last=False, frequency=frequency) -def as_last_date(d: Union[int, str, np.datetime64, datetime.date], dates: NDArray[np.datetime64]) -> np.datetime64: +def as_last_date( + d: Union[int, str, np.datetime64, datetime.date], + dates: NDArray[np.datetime64], + frequency: Optional[datetime.timedelta] = None, +) -> np.datetime64: """Convert a date to the last date in a list of dates. Parameters @@ -257,13 +294,15 @@ def as_last_date(d: Union[int, str, np.datetime64, datetime.date], dates: NDArra The date to convert. dates : NDArray[np.datetime64] The list of dates. + frequency : Optional[datetime.timedelta] + The frequency of the dataset. Returns ------- np.datetime64 The last date. """ - return _as_date(d, dates, last=True) + return _as_date(d, dates, last=True, frequency=frequency) def _concat_or_join(datasets: List["Dataset"], kwargs: Dict[str, Any]) -> Tuple["Dataset", Dict[str, Any]]: @@ -317,6 +356,18 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) - from .stores import Zarr from .stores import zarr_lookup + if isinstance(a, str) and len(a.split(".")) in [2, 3]: + + metadata_path = os.path.join(a, "metadata.json") + if os.path.exists(metadata_path): + metadata = load_any_dict_format(metadata_path) + if "backend" not in metadata: + raise ValueError(f"Metadata for {a} does not contain 'backend' key") + + from anemoi.datasets.data.records import open_records_dataset + + return open_records_dataset(a, backend=metadata["backend"]) + if isinstance(a, Dataset): return a.mutate() @@ -454,6 +505,13 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": for a in args: sets.append(_open(a)) + if "observations" in kwargs: + from .observations import observations_factory + + assert not sets, sets + + return observations_factory(args, kwargs).mutate() + if "xy" in kwargs: # Experimental feature, may be removed from .xy import xy_factory diff --git a/src/anemoi/datasets/data/observations/__init__.py b/src/anemoi/datasets/data/observations/__init__.py new file mode 100644 index 000000000..b5f8ec5e9 --- /dev/null +++ b/src/anemoi/datasets/data/observations/__init__.py @@ -0,0 +1,316 @@ +# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+
+import datetime
+import logging
+import os
+from functools import cached_property
+from typing import Any
+from typing import Dict
+from typing import Tuple
+
+import numpy as np
+from anemoi.utils.dates import frequency_to_timedelta
+
+from anemoi.datasets.data.dataset import Dataset
+
+from ..debug import Node
+
+LOG = logging.getLogger(__name__)
+
+
+def round_datetime(dt, frequency, up=True):
+    dt = dt.replace(minute=0, second=0, microsecond=0)
+    hour = dt.hour
+    if hour % frequency != 0:
+        dt = dt.replace(hour=(hour // frequency) * frequency)
+        dt = dt + datetime.timedelta(hours=frequency)
+    return dt
+
+
+def make_dates(start, end, frequency):
+    if isinstance(start, np.datetime64):
+        start = start.astype(datetime.datetime)
+    if isinstance(end, np.datetime64):
+        end = end.astype(datetime.datetime)
+
+    dates = []
+    current_date = start
+    while current_date <= end:
+        dates.append(current_date)
+        current_date += frequency
+
+    dates = [np.datetime64(d, "s") for d in dates]
+    dates = np.array(dates, dtype="datetime64[s]")
+    return dates
+
+
+class ObservationsBase(Dataset):
+    resolution = None
+
+    @cached_property
+    def shape(self):
+        return (len(self.dates), len(self.variables), "dynamic")
+
+    def empty_item(self):
+        return np.full(self.shape[1:-1] + (0,), 0.0, dtype=np.float32)
+
+    def metadata(self):
+        return dict(observations_datasets="obs datasets currently have no metadata")
+
+    def _check(self):
+        pass
+
+    def __len__(self):
+        return len(self.dates)
+
+    def tree(self):
+        return Node(self)
+
+    def __getitem__(self, i):
+        if isinstance(i, int):
+            return self.getitem(i)
+
+        # The following may work but is likely to change in the future
+        # if isinstance(i, slice):
+        #     return [self.getitem(j) for j in range(int(slice.start), int(slice.stop))]
+        # if isinstance(i, list):
+        #     return [self.getitem(j) for j in i]
+
+        raise ValueError(
+            (
+                f"Expected int, got {i} of type {type(i)}. Only int is supported to index "
+                "observations datasets. 
Please use a second [] to select part of the data [i][a,b,c]" + ) + ) + + @property + def variables(self): + raise NotImplementedError() + + def collect_input_sources(self): + LOG.warning("collect_input_sources method is not implemented") + return [] + + def constant_fields(self): + LOG.warning("constant_fields method is not implemented") + return [] + + @property + def dates(self): + return self._dates + + @property + def dtype(self): + return np.float32 + + @property + def field_shape(self): + return self.shape[1:] + + @property + def frequency(self): + assert isinstance(self._frequency, datetime.timedelta), f"Expected timedelta, got {type(self._frequency)}" + return self._frequency + + @property + def latitudes(self): + raise NotImplementedError("latitudes property is not implemented") + + @property + def longitudes(self): + raise NotImplementedError("longitudes property is not implemented") + + @property + def missing(self): + return [] + + def statistics_tendencies(self): + raise NotImplementedError("statistics_tendencies method is not implemented") + + def variables_metadata(self): + raise NotImplementedError("variables_metadata method is not implemented") + + +class ObservationsZarr(ObservationsBase): + def __init__(self, dataset, frequency=None, window=None): + import zarr + + if isinstance(dataset, zarr.hierarchy.Group): + dataset = dataset._store.path + + from ..stores import zarr_lookup + + dataset = zarr_lookup(dataset) + self.path = dataset + assert self._probe_attributes["is_observations"], f"Expected observations dataset, got {self.path}" + + if frequency is None: + frequency = self._probe_attributes.get("frequency") + # LOG.warning(f"Frequency not provided, using the one from the dataset: {frequency}") + if frequency is None: + frequency = "6h" + # LOG.warning(f"Frequency not provided in the dataset, using the default : {frequency}") + self._frequency = frequency_to_timedelta(frequency) + assert self.frequency.total_seconds() % 3600 == 0, f"Expected multiple of 3600, got {self.frequency}" + if self.frequency.total_seconds() != 6 * 3600: + LOG.warning("Frequency is not 6h, this has not been tested, behaviour is unknown") + + frequency_hours = int(self.frequency.total_seconds() // 3600) + assert isinstance(frequency_hours, int), f"Expected int, got {type(frequency_hours)}" + + if window is None: + window = (-frequency_hours, 0) + if window != (-frequency_hours, 0): + raise ValueError("For now, only window = (- frequency, 0) are supported") + + self.window = window + + start, end = self._probe_attributes["start_date"], self._probe_attributes["end_date"] + start, end = datetime.datetime.fromisoformat(start), datetime.datetime.fromisoformat(end) + start, end = round_datetime(start, frequency_hours), round_datetime(end, frequency_hours) + + self._dates = make_dates(start + self.frequency, end, self.frequency) + + first_window_begin = start.strftime("%Y%m%d%H%M%S") + first_window_begin = int(first_window_begin) + # last_window_end must be the end of the time window of the last item + last_window_end = int(end.strftime("%Y%m%d%H%M%S")) + + from .legacy_obs_dataset import ObsDataset + + args = [self.path, first_window_begin, last_window_end] + kwargs = dict( + len_hrs=frequency_hours, # length the time windows, i.e. the time span of one item + step_hrs=frequency_hours, # frequency of the dataset, i.e. 
the time shift between two items + ) + self.forward = ObsDataset(*args, **kwargs) + + assert frequency_hours == self.forward.step_hrs, f"Expected {frequency_hours}, got {self.forward.len_hrs}" + assert frequency_hours == self.forward.len_hrs, f"Expected {frequency_hours}, got {self.forward.step_hrs}" + + if len(self.forward) != len(self.dates): + raise ValueError( + ( + f"Dates are not consistent with the number of items in the dataset. " + f"The dataset contains {len(self.forward)} time windows. " + f"This is not compatible with the " + f"{len(self.dates)} requested dates with frequency={frequency_hours}" + f"{self.dates[0]}, {self.dates[1]}, ..., {self.dates[-2]}, {self.dates[-1]} " + ) + ) + + @property + def source(self): + return self.path + + def get_dataset_names(self): + name = os.path.basename(self.path) + if name.endswith(".zarr"): + name = name[:-5] + return [name] + + @cached_property + def _probe_attributes(self): + import zarr + + z = zarr.open(self.path, mode="r") + return dict(z.data.attrs) + + def get_aux(self, i): + data = self.forward[i] + + latitudes = data[:, self.name_to_index["__latitudes"]].numpy() + longitudes = data[:, self.name_to_index["__longitudes"]].numpy() + + reference = self.dates[i] + times = self.forward.get_dates(i) + if str(times.dtype) != "datetime64[s]": + LOG.warning(f"Expected np.datetime64[s], got {times.dtype}. ") + times = times.astype("datetime64[s]") + assert str(reference.dtype) == "datetime64[s]", f"Expected np.datetime64[s], got {type(reference)}" + timedeltas = times - reference + + assert latitudes.shape == longitudes.shape, f"Expected {latitudes.shape}, got {longitudes.shape}" + assert timedeltas.shape == latitudes.shape, f"Expected {timedeltas.shape}, got {latitudes.shape}" + + return latitudes, longitudes, timedeltas + + def getitem(self, i): + data = self.forward[i] + + data = data.numpy().astype(np.float32) + assert len(data.shape) == 2, f"Expected 2D array, got {data.shape}" + data = data.T + + if not data.size: + data = self.empty_item() + assert ( + data.shape[0] == self.shape[1] + ), f"Data shape {data.shape} does not match {self.shape} : {data.shape[0]} != {self.shape[1]}" + return data + + @cached_property + def variables(self): + colnames = self.forward.colnames + variables = [] + for n in colnames: + if n.startswith("obsvalue_"): + n = n.replace("obsvalue_", "") + if n == "latitude" or n == "lat": + assert "latitudes" not in variables, f"Duplicate latitudes found in {variables}" + variables.append("__latitudes") + continue + if n == "longitude" or n == "lon": + assert "longitudes" not in variables, f"Duplicate longitudes found in {variables}" + variables.append("__longitudes") + continue + assert not n.startswith("__"), f"Invalid name {n} found in {colnames}" + variables.append(n) + return variables + + @property + def name_to_index(self): + return {n: i for i, n in enumerate(self.variables)} + + @property + def statistics(self): + mean = self.forward.properties["means"] + mean = np.array(mean, dtype=np.float32) + + var = self.forward.properties["vars"] + var = np.array(var, dtype=np.float32) + stdev = np.sqrt(var) + + minimum = np.array(self.forward.z.data.attrs["mins"], dtype=np.float32) + maximum = np.array(self.forward.z.data.attrs["maxs"], dtype=np.float32) + + assert isinstance(mean, np.ndarray), f"Expected np.ndarray, got {type(mean)}" + assert isinstance(stdev, np.ndarray), f"Expected np.ndarray, got {type(stdev)}" + assert isinstance(minimum, np.ndarray), f"Expected np.ndarray, got {type(minimum)}" + assert 
isinstance(maximum, np.ndarray), f"Expected np.ndarray, got {type(maximum)}" + return dict(mean=mean, stdev=stdev, minimum=minimum, maximum=maximum) + + def tree(self): + return Node( + self, + [], + path=self.path, + frequency=self.frequency, + ) + + def __repr__(self): + return f"Observations({os.path.basename(self.path)}, {self.dates[0]};{self.dates[-1]}, {len(self)})" + + +def observations_factory(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> ObservationsBase: + observations = kwargs.pop("observations") + + if not isinstance(observations, dict): + observations = dict(dataset=observations) + dataset = ObservationsZarr(**observations) + return dataset._subset(**kwargs) diff --git a/src/anemoi/datasets/data/observations/legacy_obs_dataset.py b/src/anemoi/datasets/data/observations/legacy_obs_dataset.py new file mode 100644 index 000000000..85ab51583 --- /dev/null +++ b/src/anemoi/datasets/data/observations/legacy_obs_dataset.py @@ -0,0 +1,200 @@ +# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging + +import numpy as np +import pandas as pd +import torch +import zarr +from torch.utils.data import Dataset + +LOG = logging.getLogger(__name__) + + +class ObsDataset(Dataset): + + def __init__( + self, + filename: str, + start: int, + end: int, + len_hrs: int, + step_hrs: int = None, + select: list[str] = None, + drop: list[str] = None, + ) -> None: + + self.filename = filename + self.z = zarr.open(filename, mode="r") + self.data = self.z["data"] + self.dt = self.z["dates"] # datetime only + self.hrly_index = self.z["idx_197001010000_1"] + self.colnames = self.data.attrs["colnames"] + self.selected_colnames = self.colnames + self.selected_cols_idx = np.arange(len(self.colnames)) + self.len_hrs = len_hrs + self.step_hrs = step_hrs if step_hrs else len_hrs + + # Create index for samples + self._setup_sample_index(start, end, self.len_hrs, self.step_hrs) + + self._load_properties() + + if select: + self.select(select) + + if drop: + self.drop(drop) + + def __getitem__( + self, + idx: int, + ) -> torch.tensor: + + start_row = self.indices_start[idx] + end_row = self.indices_end[idx] + + data = self.data.oindex[start_row:end_row, self.selected_cols_idx] + + return torch.from_numpy(data) + + def __len__(self) -> int: + + return len(self.indices_start) + + def get_dates( + self, + idx: int, + ) -> np.ndarray: + + start_row = self.indices_start[idx] + end_row = self.indices_end[idx] + dates = self.dt.oindex[start_row:end_row] + + assert len(dates.shape) == 2, dates.shape + dates = dates[:, 0] + + if len(dates) and dates[0].dtype != np.dtype("datetime64[s]"): + dates = dates.astype("datetime64[s]") + + return dates + + def get_df(self, idx: int) -> pd.DataFrame: + """Convenience function to return data for sample idx packaged in a pandas DataFrame""" + + d = self.__getitem__(idx) + + df = pd.DataFrame(data=d, columns=[self.colnames[i] for i in self.selected_cols_idx]) + + start_row = self.indices_start[idx] + end_row = self.indices_end[idx] + + dts = self.dt[start_row:end_row, :] + df["datetime"] = dts + + return df + + def select(self, cols_list: list[str]) -> None: + """Allow user to specify which 
columns they want to access. + Get functions only returned for these specified columns. + """ + self.selected_colnames = cols_list + self.selected_cols_idx = np.array([self.colnames.index(item) for item in cols_list]) + + def drop(self, cols_to_drop: list[str]) -> None: + """Allow user to drop specific columns from the dataset. + Get functions no longer return data for these columns after being set. + """ + mask = [name not in cols_to_drop for name in self.selected_colnames] + + self.selected_colnames = [name for name, keep in zip(self.selected_colnames, mask) if keep] + self.selected_cols_idx = self.selected_cols_idx[mask] + + def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]: + """Returns a tuple of datetime objects describing the start and end times of the sample at position idx.""" + + if idx < 0: + idx = len(self) + idx + + time_start = self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs), seconds=1) + time_end = min( + self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs + self.len_hrs)), + self.end_dt, + ) + + return (np.datetime64(time_start), np.datetime64(time_end)) + + def first_sample_with_data(self) -> int: + """Returns the position of the first sample which contains data.""" + return int(np.nonzero(self.indices_end)[0][0]) if self.indices_end.max() > 0 else None + + def last_sample_with_data(self) -> int: + """Returns the position of the last sample which contains data.""" + if self.indices_end.max() == 0: + last_sample = None + else: + last_sample = int(np.where(np.diff(np.append(self.indices_end, self.indices_end[-1])) > 0)[0][-1] + 1) + + return last_sample + + def _setup_sample_index(self, start: int, end: int, len_hrs: int, step_hrs: int) -> None: + """Dataset is divided into samples; + - each n_hours long + - sample 0 starts at start (yyyymmddhhmm) + - index array has one entry for each sample; contains the index of the first row + containing data for that sample + """ + + try: + from obsdata.config import config + + assert config.base_index_yyyymmddhhmm == 197001010000, "base_index_yyyymmddhhmm must be 197001010000" + except ImportError: + pass + base_yyyymmddhhmm = 197001010000 + + assert start > base_yyyymmddhhmm, ( + f"Abort: ObsDataset sample start (yyyymmddhhmm) must be greater than {base_yyyymmddhhmm}\n" + f" Current value: {start}" + ) + + format_str = "%Y%m%d%H%M%S" + base_dt = datetime.datetime.strptime(str(base_yyyymmddhhmm), format_str) + self.start_dt = datetime.datetime.strptime(str(start), format_str) + self.end_dt = datetime.datetime.strptime(str(end), format_str) + + # Calculate hours since the base date for the requested dataset ranges + diff_in_hours_start = int((self.start_dt - base_dt).total_seconds() // 3600) + diff_in_hours_end = int((self.end_dt - base_dt).total_seconds() // 3600) + + # Find elements that need to be extracted from the hourly index + # + ensuring that the dataset respects the requested end-hour even if it is mid-way through a sample + sample_starts = np.arange(diff_in_hours_start, diff_in_hours_end, step_hrs) + sample_ends = np.minimum(sample_starts + len_hrs, diff_in_hours_end) + + # Initialize local index arrays + self.indices_start = np.zeros(sample_starts.shape, dtype=int) + self.indices_end = np.zeros(self.indices_start.shape, dtype=int) + + max_hrly_index = self.hrly_index.shape[0] - 1 + valid_start_mask = sample_starts <= max_hrly_index + valid_end_mask = (sample_ends > 0) & (sample_ends <= max_hrly_index) + + # Copy elements from the hrly_index into the local index + 
        self.indices_start[valid_start_mask] = self.hrly_index[sample_starts[valid_start_mask]]
+        self.indices_end[valid_end_mask] = np.maximum(self.hrly_index[sample_ends[valid_end_mask]], 0)
+
+    def _load_properties(self) -> None:
+
+        self.properties = {}
+
+        self.properties["means"] = self.data.attrs["means"]
+        self.properties["vars"] = self.data.attrs["vars"]
+        self.properties["data_idxs"] = self.data.attrs["data_idxs"]
+        self.properties["obs_id"] = self.data.attrs["obs_id"]
diff --git a/src/anemoi/datasets/data/observations/multi.py b/src/anemoi/datasets/data/observations/multi.py
new file mode 100644
index 000000000..af5c02e71
--- /dev/null
+++ b/src/anemoi/datasets/data/observations/multi.py
@@ -0,0 +1,64 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+import os
+
+from anemoi.datasets.data import open_dataset
+
+LOG = logging.getLogger(__name__)
+
+
+class LegacyDatasets:
+    def __init__(self, paths, start=None, end=None, **kwargs):
+        self.paths = paths
+
+        if not start or not end:
+            print(
+                "❌❌ Warning: start and end not provided, using the earliest first date and the latest last date of the datasets"
+            )
+            lst = [self._open_dataset(p, **kwargs) for p in paths]
+            start = min([d.dates[0] for d in lst])
+            end = max([d.dates[-1] for d in lst])
+
+        self._datasets = {
+            os.path.basename(p).split(".")[0]: self._open_dataset(p, start=start, end=end, padding="empty")
+            for p in paths
+        }
+
+        first = list(self._datasets.values())[0]
+        for name, dataset in self._datasets.items():
+            if dataset.dates[0] != first.dates[0] or dataset.dates[-1] != first.dates[-1]:
+                raise ValueError("Datasets have different start and end times")
+            if dataset.frequency != first.frequency:
+                raise ValueError("Datasets have different frequencies")
+
+        self._keys = self._datasets.keys
+
+        self._first = list(self._datasets.values())[0]
+
+    def _open_dataset(self, p, **kwargs):
+        if p.startswith("observations-"):
+            return open_dataset(observations=p, **kwargs)
+        else:
+            print("❗ Opening non-observations dataset:", p)
+            return open_dataset(p, **kwargs)
+
+    def items(self):
+        return self._datasets.items()
+
+    @property
+    def dates(self):
+        return self._first.dates
+
+    def __len__(self):
+        return len(self._first)
+
+    def __getitem__(self, i):
+        return {k: d[i] for k, d in self._datasets.items()}
diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py
new file mode 100644
index 000000000..dcff11bae
--- /dev/null
+++ b/src/anemoi/datasets/data/padded.py
@@ -0,0 +1,227 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
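A minimal usage sketch for the legacy ObsDataset introduced above (illustrative only, not part of the patch): the store path and column names are hypothetical, and start/end follow the yyyymmddhhmm integer convention documented in _setup_sample_index.

from anemoi.datasets.data.observations.legacy_obs_dataset import ObsDataset

# Hypothetical legacy observations store; 6-hour samples between the two dates
ds = ObsDataset("obs.zarr", start=202201010000, end=202212312300, len_hrs=6)
ds.select(["obsvalue_t2m", "obsvalue_10u"])  # hypothetical column names
sample = ds[0]              # torch.Tensor with the rows of the first 6h window
dates = ds.get_dates(0)     # matching datetime64[s] values
window = ds.time_window(0)  # (start, end) of that window as np.datetime64
df = ds.get_df(0)           # same sample as a pandas DataFrame with a "datetime" column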
+ + +import datetime +import logging +from functools import cached_property +from typing import Any +from typing import Dict +from typing import Set + +import numpy as np +from anemoi.utils.dates import frequency_to_timedelta +from numpy.typing import NDArray + +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.misc import as_first_date +from anemoi.datasets.data.misc import as_last_date + +LOG = logging.getLogger(__name__) + + +class Padded(Forwards): + _before: int = 0 + _after: int = 0 + _inside: int = 0 + + def __init__(self, dataset: Dataset, start: str, end: str, frequency: str, reason: Dict[str, Any]) -> None: + """Create a padded subset of a dataset. + + Attributes: + dataset (Dataset): The dataset to subset. + start (str): The start date of the subset. + end (str): The end date of the subset. + frequency (str): The frequency of the subset. + reason (Dict[str, Any]): The reason for the padding. + """ + + self.reason = {k: v for k, v in reason.items() if v is not None} + + if frequency is None: + frequency = dataset.frequency + + self._frequency = frequency_to_timedelta(frequency) + + if start is None: + # default is to start at the first date + start = dataset.dates[0] + else: + start = as_first_date(start, None, frequency=self._frequency) + + if end is None: + # default is to end at the last date + end = dataset.dates[-1] + else: + end = as_last_date(end, None, frequency=self._frequency) + + assert isinstance(dataset.dates[0], np.datetime64), (dataset.dates[0], type(dataset.dates[0])) + + # 'start' is the requested start date + # 'end' is the requested end date + # 'first' is the first date of the dataset + # 'last' is the last date of the dataset + first = dataset.dates[0] + last = dataset.dates[-1] + + timedelta = np.array([frequency], dtype="timedelta64[s]")[0] + + parts = [] + before_end = min(end + timedelta, first) + before_part = np.arange(start, before_end, timedelta) + if start < first: + # if the start date is before the first date of the dataset, there is a "before" part + assert len(before_part) > 0, (start, first, before_end) + parts.append(before_part) + self._before = len(before_part) + if start >= first: + # if the start date is the first date of the dataset, there is no "before" part + assert len(before_part) == 0, (start, first, before_end) + self._before = 0 + + # if the start date is before the last date of the dataset + # and the end date is after the first date of the dataset + # there is an "inside" part + if start < last and end > first: + inside_start = max(start, first) + inside_end = min(end, last) + self.dataset = dataset._subset(start=inside_start, end=inside_end) + inside_part = self.dataset.dates + parts.append(inside_part) + self._inside = len(inside_part) + else: + self.dataset = dataset # still needed to get the empty_item + self._inside = 0 + + after_start = max(start, last + timedelta) + after_part = np.arange(after_start, end + timedelta, timedelta) + if end > last: + # if the end date is after the last date of the dataset, there is an "after" part + assert len(after_part) > 0, (end, last, after_start) + parts.append(after_part) + self._after = 
len(after_part) + if end <= last: + assert len(after_part) == 0, (end, last, after_start) + self._after = 0 + + self._dates = np.hstack(parts) + + assert len(self._dates) == self._before + self._inside + self._after, ( + len(self._dates), + self._before, + self._inside, + self._after, + ) + + assert self._dates[0] == start, (self._dates[0], start) + assert self._dates[-1] == end, (self._dates[-1], end) + + # Forward other properties to the super dataset + super().__init__(dataset) + + @debug_indexing + def __getitem__(self, n: FullIndex) -> NDArray[Any]: + if isinstance(n, tuple): + return self._get_tuple(n) + + if isinstance(n, slice): + return self._get_slice(n) + + if self._i_out_of_range(n): + return self.empty_item() + + return self.dataset[n - self._before] + + def _i_out_of_range(self, n: FullIndex) -> bool: + """Check if the index is out of range.""" + if 0 <= n < self._before: + return True + + if (self._before + self._inside) <= n < (self._before + self._inside + self._after): + return True + return False + + @debug_indexing + def _get_slice(self, s: slice) -> NDArray[Any]: + LOG.warning("Padded subset does not support slice indexing, returning a list") + return [self[i] for i in range(*s.indices(self._len))] + + @debug_indexing + @expand_list_indexing + def _get_tuple(self, n: TupleIndex) -> NDArray[Any]: + LOG.warning("Padded subset does not support tuple indexing, returning a list") + return [self[i] for i in n] + + def empty_item(self): + return self.dataset.empty_item() + + def get_aux(self, i: FullIndex) -> NDArray[np.timedelta64]: + if self._i_out_of_range(i): + arr = np.array([], dtype=np.float32) + aux = arr, arr, arr + else: + aux = self.dataset.get_aux(i - self._before) + + assert len(aux) == 3, (aux, i) + return aux + + def __len__(self) -> int: + return len(self._dates) + + @property + def frequency(self) -> datetime.timedelta: + """Get the frequency of the subset.""" + return self._frequency + + @property + def dates(self) -> NDArray[np.datetime64]: + return self._dates + + @property + def shape(self) -> Shape: + return (len(self.dates),) + self.dataset.shape[1:] + + @cached_property + def missing(self) -> Set[int]: + raise NotImplementedError("Need to decide whether to include the added dates as missing or not") + # return self.forward.missing + + def tree(self) -> Node: + """Get the tree representation of the subset. + + Returns: + Node: The tree representation of the subset. + """ + return Node(self, [self.dataset.tree()], **self.reason) + + def forwards_subclass_metadata_specific(self) -> Dict[str, Any]: + """Get the metadata specific to the forwards subclass. + + Returns: + Dict[str, Any]: The metadata specific to the forwards subclass. + """ + return { + # "indices": self.indices, + "reason": self.reason, + } + + def __repr__(self) -> str: + """Get the string representation of the subset. + + Returns: + str: The string representation of the subset. + """ + return f"Padded({self.forward}, {self.dates[0]}...{self.dates[-1]}, frequency={self.frequency})" diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py new file mode 100644 index 000000000..f569a4105 --- /dev/null +++ b/src/anemoi/datasets/data/records/__init__.py @@ -0,0 +1,442 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
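An illustrative sketch of how the Padded wrapper above behaves (not part of the patch), mirroring test_padding_1 added later in this series; it assumes the padding="empty" option is routed through open_dataset.

from anemoi.datasets import open_dataset

# Underlying dataset covers 2022 only; the requested range is padded with empty items
ds = open_dataset(
    "test-2022-2022-1h-o96-abcd",  # dataset name borrowed from the tests below
    start="2021-01-01",
    end="2023-12-31 23:00:00",
    padding="empty",
)
len(ds)  # three years of hourly dates, although data exists only for 2022
ds[0]    # an empty item: this date falls before the wrapped dataset starts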
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import os +from collections import defaultdict +from functools import cached_property + +import numpy as np +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.data.records.backends import backend_factory + +LOG = logging.getLogger(__name__) + +if os.environ.get("ANEMOI_DATASET_COUNTER", "0") == "1": + + def counter(func): + def wrapper(*args, **kwargs): + count = 0 + for i in range(len(args[0])): + count += 1 + yield func(*args, **kwargs) + print(f"Counter: {count} calls to {func.__name__}") + + return wrapper + +else: + + def counter(func): + return func + + +def open_records_dataset(dataset, **kwargs): + if not dataset.endswith(".vz"): + raise ValueError("dataset must be a .vz file") + return RecordsDataset(dataset, **kwargs) + + +class BaseRecordsDataset: + + def __getitem__(self, i): + if isinstance(i, str): + return self._getgroup(i) + + if isinstance(i, int): + return self._getrecord(i) + + raise ValueError(f"Invalid index {i}, must be int or str") + + def _getgroup(self, i): + return Tabular(self, i) + + def _getrecord(self, i): + return Record(self, i) + + def _load_data(self, i): + raise NotImplementedError("Must be implemented in subclass") + + @property + def start_date(self): + return self.dates[0] + + @property + def end_date(self): + if len(self.dates) == 0: + return None + if len(self.dates) == 1: + return self.dates[0] + return self.dates[-1] + + @property + def groups(self): + return tuple(self.keys()) + + def _subset(self, **kwargs): + start = kwargs.pop("start", None) + end = kwargs.pop("end", None) + frequency = kwargs.pop("frequency", self.frequency) + + if frequency != self.frequency: + raise ValueError(f"Changing the frequency {frequency} (from {self.frequency}) is not implemented yet.") + + if start is not None or end is not None: + + def _dates_to_indices(start, end): + from anemoi.datasets.data.misc import as_first_date + from anemoi.datasets.data.misc import as_last_date + + start = self.dates[0] if start is None else as_first_date(start, self.dates) + end = self.dates[-1] if end is None else as_last_date(end, self.dates) + + return [i for i, date in enumerate(self.dates) if start <= date <= end] + + return RecordsSubset( + self, _dates_to_indices(start, end), {"start": start, "end": end, "frequency": frequency} + )._subset(**kwargs) + + select = kwargs.pop("select", None) + if select is not None: + return Select(self, select)._subset(**kwargs) + + return self + + def mutate(self): + return self + + def _check(self): + pass + + @property + def name_to_index(self): + raise NotImplementedError("Must be implemented in subclass") + + +class RecordsForward(BaseRecordsDataset): + def __init__(self, dataset): + self.forward = dataset + + @property + def statistics(self): + return self.forward.statistics + + @property + def variables(self): + return self.forward.variables + + def keys(self): + return self.forward.keys() + + @property + def dates(self): + return self.forward.dates + + @property + def name_to_index(self): + return self.forward.name_to_index + + @property + def frequency(self): + return self.forward.frequency + + @property + def shapes(self): + return self.forward.shapes + + def __len__(self): + return len(self.forward) + + +def match_variable(lst, group, name): + # lst must be a list of 
strings with dots (if there is no dot, it is automatically added at the end) + # - a dict with keys as group and values as list of strings + + if name == "__latitudes" or name == "__longitudes": + # This should disappear in the future, when we stop saving a duplicate of lat/lon in the data + return False + + lst = [k if "." in k else f"{k}.*" for k in lst] + + key = f"{group}.{name}" + if key in lst: + return True + if f"{group}.*" in lst: + return True + if f"*.{name}" in lst: + return True + if "*" in lst: + return True + return False + + +class Select(RecordsForward): + def __init__(self, dataset, select): + super().__init__(dataset) + + self.dataset = dataset + + if isinstance(select, dict): + # if a dict is provided, make it a list of strings with '.' + sel = [] + for group, d in select.items(): + for name in d: + sel.append(f"{group}.{name}") + select = sel + + self._select = select + + self.reason = {"select": select} + self._build_indices_and_name_to_index() + + def _build_indices_and_name_to_index(self): + indices = {} + name_to_index = {} + variables = {} + + # this should be revisited to take into account the order requested by the user + # see what is done in the fields datasets + for group, names in self.dataset.variables.items(): + ind = np.zeros(len(names), dtype=bool) + count = 0 + for j, name in enumerate(names): + if self.match_variable(group, name): + assert j == names.index(name), f"Invalid index {j} for {name} in {group}" + ind[j] = True + indices[group] = ind + if group not in name_to_index: + name_to_index[group] = {} + assert group not in variables, (group, j, name, variables, name_to_index) + variables[group] = [] + name_to_index[group][name] = count + variables[group].append(name) + count += 1 + assert np.sum(ind) == count, f"Mismatch in {group}: {names}, {ind}" + self._indices = indices + self._name_to_index = name_to_index + self._variables = variables + + def match_variable(self, *args, **kwargs): + return match_variable(self._select, *args, **kwargs) + + def keys(self): + return self._indices.keys() + + def _load_data(self, i): + forward = self.dataset._load_data(i) + data = {} + for k, v in self._indices.items(): + data[f"latitudes:{k}"] = forward[f"latitudes:{k}"] + data[f"longitudes:{k}"] = forward[f"longitudes:{k}"] + data[f"timedeltas:{k}"] = forward[f"timedeltas:{k}"] + data[f"metadata:{k}"] = forward[f"metadata:{k}"] + for k, v in self._indices.items(): + data[f"data:{k}"] = forward[f"data:{k}"][v] # notice the [v] here + return data + + @property + def name_to_index(self): + return self._name_to_index + + @property + def variables(self): + return self._variables + + @property + def statistics(self): + dic = {} + for group, v in self._indices.items(): + stats = self.dataset.statistics[group] + dic[group] = {key: stats[key][v] for key in stats.keys()} + assert "mean" in dic[group], f"Missing mean in {dic[group]}" + return dic + + +class RecordsSubset(RecordsForward): + def __init__(self, dataset, indices, reason): + super().__init__(dataset) + self.dataset = dataset + self.reason = reason + self._indices = indices + + @cached_property + def dates(self): + return self.dataset.dates[self._indices] + + def _load_data(self, i): + return self.dataset._load_data(self._indices[i]) + + def __len__(self): + return len(self._indices) + + +class RecordsDataset(BaseRecordsDataset): + + def __init__(self, path, backend="npz1", **kwargs): + if kwargs: + print("Warning: ignoring additional kwargs", kwargs) + self.path = path + self.backend = backend_factory(backend, 
path, **kwargs) + self.keys = self.metadata["sources"].keys + + @property + def frequency(self): + frequency = self.metadata["frequency"] + frequency = frequency_to_timedelta(frequency) + return frequency + + @property + def name_to_index(self): + return self.metadata["name_to_index"] + + @property + def variables(self): + return self.metadata["variables"] + + @cached_property + def metadata(self): + return self.backend.read_metadata() + + @property + def shapes(self): + return self.metadata["shapes"] + + def items(self, *args, **kwargs): + return {k: Tabular(self, k) for k in self.keys()}.items(*args, **kwargs) + + @cached_property + def statistics(self): + return self.backend.read_statistics() + + def __len__(self): + return len(self.dates) + + @property + def start_date(self): + date = self.metadata["start_date"] + return datetime.datetime.fromisoformat(date) + + @property + def end_date(self): + date = self.metadata["end_date"] + return datetime.datetime.fromisoformat(date) + + @cached_property + def dates(self): + result = [] + delta = self.frequency + d = self.start_date + while d <= self.end_date: + result.append(d) + d += delta + return np.array(result) + + @counter + def _load_data(self, i): + return self.backend.read(i) + + def check(self, i=None): + if i is not None: + dict_of_sets = defaultdict(set) + for key in self._load_data(i).keys(): + kind, group = key.split(":") + dict_of_sets[group].add(kind) + for group, s in dict_of_sets.items(): + assert s == {"latitudes", "longitudes", "timedeltas", "metadata", "data"}, f"Invalid keys {s}" + + +class Record(dict): + def __init__(self, dataset, n): + self.dataset = dataset + self.n = n + + def __repr__(self): + d = {group: "" for group in self.dataset.keys()} + return str(d) + + def items(self): + return self._payload.items() + + @property + def name_to_index(self): + return self.dataset.name_to_index + + @cached_property + def _payload(self): + payload = self.dataset._load_data(self.n) + for k in payload.keys(): + assert len(k.split(":")) == 2, f"Invalid key {k}" + return payload + + def keys(self): + return self.dataset.keys() + + def __getitem__(self, group): + return self._payload["data:" + group] + + def _get_aux(self, name): + try: + return {k: self._payload[name + ":" + k] for k in self.keys()} + except KeyError as e: + e.add_note(f"Available keys are {self._payload.keys()}") + raise + + @property + def latitudes(self): + return self._get_aux("latitudes") + + @property + def longitudes(self): + return self._get_aux("longitudes") + + @property + def timedeltas(self): + return self._get_aux("timedeltas") + + @property + def statistics(self): + return self.dataset.statistics + + @property + def groups(self): + return tuple(self.keys()) + + +class Tabular: + def __init__(self, dataset, name): + self.dataset = dataset + self.name = name + + @property + def group(self): + return self.name + + def __getitem__(self, i): + return self.__get(i, "data") + + def __get(self, i, k): + payload = self.dataset._load_data(i) + try: + return payload[k + ":" + self.name] + except KeyError: + print(f"KeyError to retrieve {self.name} available groups are", payload.keys()) + raise + + @property + def variables(self): + return self.dataset.variables[self.name] + + @property + def name_to_index(self): + return self.dataset.name_to_index[self.name] + + @property + def statistics(self): + return self.dataset.statistics[self.name] diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py 
new file mode 100644 index 000000000..6971cafd4 --- /dev/null +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -0,0 +1,157 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import json +import os + +import numpy as np + + +class Backend: + def __init__(self, path, **kwargs): + self.path = path + self.kwargs = kwargs + + def read(self, i, **kwargs): + raise NotImplementedError("Must be implemented in subclass") + + def read_metadata(self): + raise NotImplementedError("Must be implemented in subclass") + + def read_statistics(self): + raise NotImplementedError("Must be implemented in subclass") + + +class Npz1Backend(Backend): + def read(self, i, **kwargs): + path = os.path.join(self.path, "data", str(int(i / 10)), f"{i}.npz") + with open(path, "rb") as f: + return dict(np.load(f)) + + def read_metadata(self): + with open(os.path.join(self.path, "metadata.json"), "r") as f: + return json.load(f) + + def read_statistics(self): + path = os.path.join(self.path, "statistics.npz") + dic = {} + for k, v in dict(np.load(path)).items(): + key, group = k.split(":") + if group not in dic: + dic[group] = {} + dic[group][key] = v + return dic + + +class Npz2Backend(Backend): + def read(self, i, **kwargs): + path = os.path.join(self.path, "data_", str(int(i / 10)), f"{i}_.npz") + with open(path, "rb") as f: + return dict(np.load(f)) + + def read_metadata(self): + with open(os.path.join(self.path, "metadata.json"), "r") as f: + return json.load(f) + + def read_statistics(self): + path = os.path.join(self.path, "statistics_.npz") + dic = {} + for k, v in dict(np.load(path)).items(): + key, group = k.split(":") + if group not in dic: + dic[group] = {} + dic[group][key] = v + return dic + + +def backend_factory(backend, *args, **kwargs): + BACKENDS = dict( + npz1=Npz1Backend, + npz2=Npz2Backend, + ) + return BACKENDS[backend](*args, **kwargs) + + +class WriteBackend(Backend): + def __init__(self, path, **kwargs): + super().__init__(path, **kwargs) + + def write(self, i, data, **kwargs): + raise NotImplementedError("Must be implemented in subclass") + + def write_metadata(self, metadata): + raise NotImplementedError("Must be implemented in subclass") + + def write_statistics(self, statistics): + raise NotImplementedError("Must be implemented in subclass") + + +class Npz1WriteBackend(WriteBackend): + def write(self, i, data, **kwargs): + path = os.path.join(self.path, "data", str(int(i / 10))) + os.makedirs(path, exist_ok=True) + out_path = os.path.join(path, f"{i}.npz") + np.savez(out_path, **data) + + def write_metadata(self, metadata): + from anemoi.datasets.create import json_tidy + + os.makedirs(self.path, exist_ok=True) + with open(os.path.join(self.path, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2, default=json_tidy) + + def write_statistics(self, statistics): + flatten = {} + for name, d in statistics.items(): + assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" + for k, v in d.items(): + assert isinstance( + v, (int, float, np.ndarray) + ), f"Statistics value for {k} in {name} must be int, float or ndarray, got {type(v)}" + flatten[k + ":" + name] = v + + path = os.path.join(self.path, 
"statistics.npz") + np.savez(path, **flatten) + + +class Npz2WriteBackend(WriteBackend): + def write(self, i, data, **kwargs): + path = os.path.join(self.path, "data_", str(int(i / 10))) + os.makedirs(path, exist_ok=True) + out_path = os.path.join(path, f"{i}_.npz") + np.savez(out_path, **data) + + def write_metadata(self, metadata): + from anemoi.datasets.create import json_tidy + + os.makedirs(self.path, exist_ok=True) + with open(os.path.join(self.path, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2, default=json_tidy) + + def write_statistics(self, statistics): + flatten = {} + for name, d in statistics.items(): + assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" + for k, v in d.items(): + assert isinstance( + v, (int, float, np.ndarray) + ), f"Statistics value for {k} in {name} must be int, float or ndarray, got {type(v)}" + flatten[k + ":" + name] = v + + os.makedirs(self.path, exist_ok=True) + path = os.path.join(self.path, "statistics_.npz") + np.savez(path, **flatten) + + +def writer_backend_factory(backend, *args, **kwargs): + WRITE_BACKENDS = dict( + npz1=Npz1WriteBackend, + npz2=Npz2WriteBackend, + ) + return WRITE_BACKENDS[backend](*args, **kwargs) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 4cc160547..3c7442487 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -107,51 +107,6 @@ def __getitem__(self, key: str) -> bytes: return response["Body"].read() -class PlanetaryComputerStore(ReadOnlyStore): - """We write our own Store to access catalogs on Planetary Computer, - as it requires some extra arguments to use xr.open_zarr. - """ - - def __init__(self, data_catalog_id: str) -> None: - """Initialize the PlanetaryComputerStore with a data catalog ID. - - Parameters - ---------- - data_catalog_id : str - The data catalog ID. 
- """ - self.data_catalog_id = data_catalog_id - - import planetary_computer - import pystac_client - - catalog = pystac_client.Client.open( - "https://planetarycomputer.microsoft.com/api/stac/v1/", - modifier=planetary_computer.sign_inplace, - ) - collection = catalog.get_collection(self.data_catalog_id) - - asset = collection.assets["zarr-abfs"] - - if "xarray:storage_options" in asset.extra_fields: - store = { - "store": asset.href, - "storage_options": asset.extra_fields["xarray:storage_options"], - **asset.extra_fields["xarray:open_kwargs"], - } - else: - store = { - "filename_or_obj": asset.href, - **asset.extra_fields["xarray:open_kwargs"], - } - - self.store = store - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - raise NotImplementedError() - - class DebugStore(ReadOnlyStore): """A store to debug the zarr loading.""" @@ -190,11 +145,11 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore: if store.startswith("http://") or store.startswith("https://"): - parsed = urlparse(store) - if store.endswith(".zip"): import multiurl + parsed = urlparse(store) + # Zarr cannot handle zip files over HTTP tmpdir = tempfile.gettempdir() name = os.path.basename(parsed.path) @@ -210,15 +165,7 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore: os.rename(path + ".tmp", path) return name_to_zarr_store(path) - bits = parsed.netloc.split(".") - if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"): - s3_url = f"s3://{bits[0]}{parsed.path}" - store = S3Store(s3_url, region=bits[2]) - elif store.startswith("https://planetarycomputer.microsoft.com/"): - data_catalog_id = store.rsplit("/", 1)[-1] - store = PlanetaryComputerStore(data_catalog_id).store - else: - store = HTTPStore(store) + return HTTPStore(store) return store @@ -565,6 +512,10 @@ def zarr_lookup(name: str, fail: bool = True) -> Optional[str]: config = load_config()["datasets"] use_search_path_not_found = config.get("use_search_path_not_found", False) + if name.endswith(".zarr/"): + LOG.warning("Removing trailing slash from path: %s", name) + name = name[:-1] + if name.endswith(".zarr") or name.endswith(".zip"): if os.path.exists(name): diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 3ea5265d9..2072137a0 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -185,6 +185,11 @@ def __getitem__(self, n: FullIndex) -> NDArray[Any]: n = self.indices[n] return self.dataset[n] + def get_aux(self, n: FullIndex) -> NDArray[Any]: + assert n >= 0, n + n = self.indices[n] + return self.dataset.get_aux(n) + @debug_indexing def _get_slice(self, s: slice) -> NDArray[Any]: """Get slice of data. diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index 8a36aad6e..f663dade9 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -605,6 +605,7 @@ def nearest_grid_points( target_latitudes: NDArray[Any], target_longitudes: NDArray[Any], max_distance: float = None, + k: int = 1, ) -> NDArray[Any]: """Find the nearest grid points from source to target coordinates. @@ -621,6 +622,8 @@ def nearest_grid_points( max_distance: float, optional Maximum distance between nearest point and point to interpolate. Defaults to None. For example, 1e-3 is 1 km. 
+ k : int, optional + The number of k closest neighbors to consider for interpolation Returns ------- @@ -637,10 +640,10 @@ def nearest_grid_points( target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) target_points = np.array(target_xyz).transpose() if max_distance is None: - _, indices = cKDTree(source_points).query(target_points, k=1) + distances, indices = cKDTree(source_points).query(target_points, k=k) else: - _, indices = cKDTree(source_points).query(target_points, k=1, distance_upper_bound=max_distance) - return indices + distances, indices = cKDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) + return distances, indices if __name__ == "__main__": diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..210ad7a14 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1 @@ +pytest_plugins = "anemoi.utils.testing" diff --git a/tests/create/__init__.py b/tests/create/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/create/test_create.py b/tests/create/test_create.py index beed90387..ec2d515ce 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -8,24 +8,17 @@ # nor does it submit to any jurisdiction. import glob -import hashlib -import json import logging import os import sys -from functools import wraps from unittest.mock import patch -import numpy as np import pytest -from anemoi.utils.testing import get_test_archive -from anemoi.utils.testing import get_test_data from anemoi.utils.testing import skip_if_offline -from earthkit.data import from_source as original_from_source -from anemoi.datasets import open_dataset -from anemoi.datasets.create.testing import create_dataset -from anemoi.datasets.data.stores import open_zarr +from .utils.compare import Comparer +from .utils.create import create_dataset +from .utils.mock_sources import LoadSource HERE = os.path.dirname(__file__) # find_yamls @@ -37,365 +30,41 @@ assert NAMES, "No yaml files found in " + HERE -def mockup_from_source(func: callable) -> callable: - """Decorator to mock the `from_source` function from the `earthkit.data` module. - - Parameters - ---------- - func : function - The function to be wrapped. - - Returns - ------- - function - The wrapped function. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - with patch("earthkit.data.from_source", _from_source): - return func(*args, **kwargs) - - return wrapper - - -class LoadSource: - """Class to load data sources and handle mockup data.""" - - def filename(self, args: tuple, kwargs: dict) -> str: - """Generate a filename based on the arguments and keyword arguments. - - Parameters - ---------- - args : tuple - The positional arguments. - kwargs : dict - The keyword arguments. - - Returns - ------- - str - The generated filename. - """ - string = json.dumps([args, kwargs], sort_keys=True, default=str) - h = hashlib.md5(string.encode("utf8")).hexdigest() - return h + ".grib" - - def get_data(self, args: tuple, kwargs: dict, path: str) -> None: - """Retrieve data and save it to the specified path. - - Parameters - ---------- - args : tuple - The positional arguments. - kwargs : dict - The keyword arguments. - path : str - The path to save the data. - - Raises - ------ - ValueError - If the test data is missing. 
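A short sketch of the updated nearest_grid_points call above (illustrative values only): it now returns distances together with indices and accepts the number of neighbours k.

import numpy as np

from anemoi.datasets.grids import nearest_grid_points

source_latitudes = np.array([0.0, 10.0, 20.0])
source_longitudes = np.array([0.0, 10.0, 20.0])
target_latitudes = np.array([1.0, 19.0])
target_longitudes = np.array([1.0, 19.0])

# Two nearest source points per target point
distances, indices = nearest_grid_points(
    source_latitudes, source_longitudes, target_latitudes, target_longitudes, k=2
)
# distances and indices both have shape (2, 2)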
- """ - upload_path = os.path.realpath(path + ".to_upload") - ds = original_from_source("mars", *args, **kwargs) - ds.save(upload_path) - print(f"Mockup: Saving to {upload_path} for {args}, {kwargs}") - print() - print("⚠️ To upload the test data, run this:") - path = os.path.relpath(upload_path, os.getcwd()) - name = os.path.basename(upload_path).replace(".to_upload", "") - print(f"scp {path} data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/{name}") - print() - exit(1) - raise ValueError("Test data is missing") - - def mars(self, args: tuple, kwargs: dict) -> object: - """Load data from the MARS archive. - - Parameters - ---------- - args : tuple - The positional arguments. - kwargs : dict - The keyword arguments. - - Returns - ------- - object - The loaded data source. - """ - - name = self.filename(args, kwargs) - - try: - return original_from_source("file", get_test_data(f"anemoi-datasets/create/mock-mars/{name}")) - except RuntimeError: - raise # If offline - except Exception: - self.get_data(args, kwargs, name) - - def __call__(self, name: str, *args: tuple, **kwargs: dict) -> object: - """Call the appropriate method based on the data source name. - - Parameters - ---------- - name : str - The name of the data source. - args : tuple - The positional arguments. - kwargs : dict - The keyword arguments. - - Returns - ------- - object - The loaded data source. - """ - if name == "mars": - return self.mars(args, kwargs) - - return original_from_source(name, *args, **kwargs) - - -_from_source = LoadSource() - - -def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: - """Compare the attributes of two Zarr datasets. - - Parameters - ---------- - a : dict - The attributes of the first dataset. - b : dict - The attributes of the second dataset. - path : str - The current path in the attribute hierarchy. - errors : list - The list to store error messages. - """ - if isinstance(a, dict): - a_keys = list(a.keys()) - b_keys = list(b.keys()) - for k in set(a_keys) | set(b_keys): - if k not in a_keys: - errors.append(f"❌ {path}.{k} : missing key (only in reference)") - continue - if k not in b_keys: - errors.append(f"❌ {path}.{k} : additional key (missing in reference)") - continue - if k in [ - "timestamp", - "uuid", - "latest_write_timestamp", - "history", - "provenance", - "provenance_load", - "description", - "config_path", - "total_size", - ]: - if type(a[k]) is not type(b[k]): - errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") - continue - - compare_dot_zattrs(a[k], b[k], f"{path}.{k}", errors) - - return - - if isinstance(a, list): - if len(a) != len(b): - errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}") - return - - for i, (v, w) in enumerate(zip(a, b)): - compare_dot_zattrs(v, w, f"{path}.{i}", errors) - - return - - if type(a) is not type(b): - msg = f"❌ {path} actual != expected : {a} ({type(a)}) != {b} ({type(b)})" - errors.append(msg) - return - - if a != b: - msg = f"❌ {path} actual != expected : {a} != {b}" - errors.append(msg) - - -def compare_datasets(a: object, b: object) -> None: - """Compare two datasets. - - Parameters - ---------- - a : object - The first dataset. - b : object - The second dataset. - - Raises - ------ - AssertionError - If the datasets do not match. 
- """ - assert a.shape == b.shape, (a.shape, b.shape) - assert (a.dates == b.dates).all(), (a.dates, b.dates) - for a_, b_ in zip(a.variables, b.variables): - assert a_ == b_, (a, b) - assert a.missing == b.missing, "Missing are different" - - for i_date, date in zip(range(a.shape[0]), a.dates): - if i_date in a.missing: - continue - for i_param in range(a.shape[1]): - param = a.variables[i_param] - assert param == b.variables[i_param], ( - date, - param, - a.variables[i_param], - b.variables[i_param], - ) - a_ = a[i_date, i_param] - b_ = b[i_date, i_param] - assert a.shape == b.shape, (date, param, a.shape, b.shape) - - a_nans = np.isnan(a_) - b_nans = np.isnan(b_) - assert np.all(a_nans == b_nans), (date, param, "nans are different") - - a_ = np.where(a_nans, 0, a_) - b_ = np.where(b_nans, 0, b_) - - delta = a_ - b_ - max_delta = np.max(np.abs(delta)) - abs_error = np.abs(a_ - b_) - rel_error = np.abs(a_ - b_) / (np.abs(b_) + 1e-10) # Avoid division by zero - assert max_delta == 0.0, (date, param, a_, b_, a_ - b_, max_delta, np.max(abs_error), np.max(rel_error)) - - -def compare_statistics(ds1: object, ds2: object) -> None: - """Compare the statistics of two datasets. - - Parameters - ---------- - ds1 : object - The first dataset. - ds2 : object - The second dataset. - - Raises - ------ - AssertionError - If the statistics do not match. - """ - vars1 = ds1.variables - vars2 = ds2.variables - assert len(vars1) == len(vars2) - for v1, v2 in zip(vars1, vars2): - idx1 = ds1.name_to_index[v1] - idx2 = ds2.name_to_index[v2] - assert (ds1.statistics["mean"][idx1] == ds2.statistics["mean"][idx2]).all() - assert (ds1.statistics["stdev"][idx1] == ds2.statistics["stdev"][idx2]).all() - assert (ds1.statistics["maximum"][idx1] == ds2.statistics["maximum"][idx2]).all() - assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() - - -class Comparer: - """Class to compare datasets and their metadata. - - Parameters - ---------- - name : str - The name of the dataset. - output_path : str, optional - The path to the output dataset. - reference_path : str, optional - The path to the reference dataset. - """ - - def __init__(self, name: str, output_path: str = None, reference_path: str = None) -> None: - """Initialize the Comparer instance. - - Parameters - ---------- - name : str - The name of the dataset. - output_path : str, optional - The path to the output dataset. - reference_path : str, optional - The path to the reference dataset. - """ - self.name = name - self.output_path = output_path or os.path.join(name + ".zarr") - self.reference_path = reference_path - print(f"Comparing {self.output_path} and {self.reference_path}") - - self.z_output = open_zarr(self.output_path) - self.z_reference = open_zarr(self.reference_path) - - self.z_reference["data"] - self.ds_output = open_dataset(self.output_path) - self.ds_reference = open_dataset(self.reference_path) - - def compare(self) -> None: - """Compare the output dataset with the reference dataset. - - Raises - ------ - AssertionError - If the datasets or their metadata do not match. 
- """ - errors = [] - compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors) - if errors: - print("Comparison failed") - print("\n".join(errors)) - - if errors: - print() - - print() - print("⚠️ To update the reference data, run this:") - print("cd " + os.path.dirname(self.output_path)) - base = os.path.basename(self.output_path) - print(f"tar zcf {base}.tgz {base}") - print(f"scp {base}.tgz data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/") - print() - raise AssertionError("Comparison failed") - - compare_datasets(self.ds_output, self.ds_reference) - compare_statistics(self.ds_output, self.ds_reference) - # do not compare tendencies statistics yet, as we don't know yet if they should stay +@pytest.fixture +def load_source(get_test_data: callable) -> LoadSource: + return LoadSource(get_test_data) @skip_if_offline @pytest.mark.parametrize("name", NAMES) -@mockup_from_source -def test_run(name: str) -> None: +def test_run(name: str, get_test_archive: callable, load_source: LoadSource) -> None: """Run the test for the specified dataset. Parameters ---------- name : str The name of the dataset. + get_test_archive : callable + Fixture to retrieve the test archive. + load_source : LoadSource + Fixture to mock data sources. Raises ------ AssertionError If the comparison fails. """ - config = os.path.join(HERE, name + ".yaml") - output = os.path.join(HERE, name + ".zarr") - is_test = False + with patch("earthkit.data.from_source", load_source): + config = os.path.join(HERE, name + ".yaml") + output = os.path.join(HERE, name + ".zarr") + is_test = False - create_dataset(config=config, output=output, delta=["12h"], is_test=is_test) + create_dataset(config=config, output=output, delta=["12h"], is_test=is_test) - directory = get_test_archive(f"anemoi-datasets/create/mock-mars/{name}.zarr.tgz") - reference = os.path.join(directory, name + ".zarr") + directory = get_test_archive(f"anemoi-datasets/create/mock-mars/{name}.zarr.tgz") + reference = os.path.join(directory, name + ".zarr") - Comparer(name, output_path=output, reference_path=reference).compare() + Comparer(output_path=output, reference_path=reference).compare() if __name__ == "__main__": diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index f74471911..e8e766b51 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -7,20 +7,22 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import logging import os +import sys -from anemoi.utils.testing import get_test_data +import numpy as np +import pytest from anemoi.utils.testing import skip_if_offline from anemoi.utils.testing import skip_missing_packages from anemoi.utils.testing import skip_slow_tests from anemoi.datasets import open_dataset -from anemoi.datasets.create.testing import create_dataset + +from .utils.create import create_dataset @skip_if_offline -def test_grib() -> None: +def test_grib(get_test_data: callable) -> None: """Test the creation of a dataset from GRIB files. This function tests the creation of a dataset using GRIB files from @@ -50,8 +52,123 @@ def test_grib() -> None: assert ds.shape == (8, 12, 1, 162) +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="Type hints from anemoi-transform are not compatible with Python < 3.10" +) +@skip_if_offline +def test_grib_gridfile(get_test_data) -> None: + """Test the creation of a dataset from GRIB files with an unstructured grid. 
+ + This function tests the creation of a dataset using GRIB files from + specific dates and verifies the shape of the resulting dataset. + This GRIB data is defined on an unstructured grid and therefore requires + specifying a grid file. + """ + data1 = get_test_data("anemoi-datasets/create/grib-iconch1-20250101.grib") + data2 = get_test_data("anemoi-datasets/create/grib-iconch1-20250102.grib") + gridfile = get_test_data("anemoi-datasets/create/icon_grid_0001_R19B08_mch.nc") + assert os.path.dirname(data1) == os.path.dirname(data2) + + path = os.path.dirname(data1) + + config = { + "dates": { + "start": "2025-01-01T00:00:00", + "end": "2025-01-02T18:00:00", + "frequency": "6h", + }, + "input": { + "grib": { + "path": os.path.join(path, "grib-iconch1-{date:strftime(%Y%m%d)}.grib"), + "grid_definition": {"icon": {"path": gridfile}}, + "flavour": [[{"levtype": "sfc"}, {"levelist": None}]], + }, + }, + } + + created = create_dataset(config=config, output=None) + ds = open_dataset(created) + assert ds.shape == (8, 1, 1, 1147980) + assert ds.variables == ["2t"] + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="Type hints from anemoi-transform are not compatible with Python < 3.10" +) @skip_if_offline -def test_netcdf() -> None: +@pytest.mark.parametrize( + "refinement_level_c,shape", + ( + (2, (2, 13, 1, 2880)), + (7, (2, 13, 1, 2949120)), + ), +) +def test_grib_gridfile_with_refinement_level( + refinement_level_c: str, shape: tuple[int, int, int, int, int], get_test_data: callable +) -> None: + """Test the creation of a dataset from GRIB files with an unstructured grid. + + This function tests the creation of a dataset using GRIB files from + specific dates and verifies the shape of the resulting dataset. + This GRIB data is defined on an unstructured grid and therefore requires + specifying a grid file. The `refinement_level_c` selection key and + strftimedelta are used. 
+ """ + + p = "anemoi-datasets/create/test_grib_gridfile_with_refinement_level/" + data1 = get_test_data(p + "2023010103+fc_R03B07_rea_ml.2023010100") + data2 = get_test_data(p + "2023010106+fc_R03B07_rea_ml.2023010103") + gridfile = get_test_data("dwd/2024-12-11_00/icon_grid_0026_R03B07_subsetAICON.nc") + assert os.path.dirname(data1) == os.path.dirname(data2) + + path = os.path.dirname(data1) + + param = ["pres", "t", "u", "v", "q"] + level = [101, 119] + forcings = ["cos_latitude", "sin_latitude", "cos_julian_day"] + assert len(param) * len(level) + len(forcings) == shape[1] + + grib = { + "path": os.path.join(path, "{date:strftimedelta(+3h;%Y%m%d%H)}+fc_R03B07_rea_ml.{date:strftime(%Y%m%d%H)}"), + "grid_definition": {"icon": {"path": gridfile, "refinement_level_c": refinement_level_c}}, + "param": param, + "level": level, + } + refinement_filter = {"icon_refinement_level": {"grid": gridfile, "refinement_level_c": refinement_level_c}} + + config = { + "dates": { + "start": "2023-01-01T00:00:00", + "end": "2023-01-01T03:00:00", + "frequency": "3h", + }, + "input": { + "pipe": [ + { + "join": [ + {"grib": grib}, + {"forcings": {"param": forcings, "template": "${input.pipe.0.join.0.grib}"}}, + ] + }, + refinement_filter, + ] + }, + } + + created = create_dataset(config=config, output=None) + ds = open_dataset(created) + assert ds.shape == shape + assert np.all(ds.data[ds.to_index(date=0, variable="cos_julian_day", member=0)] == 1.0), "cos(julian_day = 0) == 1" + assert np.all(ds.data[ds.to_index(date=0, variable="u_101", member=0)] == 42.0), "artificially constant data day 0" + assert np.all(ds.data[ds.to_index(date=1, variable="v_119", member=0)] == 21.0), "artificially constant data day 1" + assert ds.data[ds.to_index(date=0, variable="cos_latitude", member=0)].max() > 0.9 + assert ds.data[ds.to_index(date=0, variable="cos_latitude", member=0)].min() >= 0 + assert ds.data[ds.to_index(date=0, variable="sin_latitude", member=0)].max() > 0.9 + assert ds.data[ds.to_index(date=0, variable="sin_latitude", member=0)].min() < -0.9 + + +@skip_if_offline +def test_netcdf(get_test_data: callable) -> None: """Test for NetCDF files. This function tests the creation of a dataset from a NetCDF file. @@ -74,7 +191,7 @@ def test_netcdf() -> None: @skip_missing_packages("fstd", "rpnpy.librmn") -def test_eccs_fstd() -> None: +def test_eccs_fstd(get_test_data: callable) -> None: """Test for 'fstd' files from ECCC.""" # See https://github.com/neishm/fstd2nc @@ -98,7 +215,7 @@ def test_eccs_fstd() -> None: @skip_slow_tests @skip_if_offline @skip_missing_packages("kerchunk", "s3fs") -def test_kerchunk() -> None: +def test_kerchunk(get_test_data: callable) -> None: """Test for Kerchunk JSON files. This function tests the creation of a dataset from a Kerchunk JSON file. 
@@ -128,12 +245,44 @@ def test_kerchunk() -> None: assert ds.shape == (4, 1, 1, 1038240) +@skip_if_offline +@skip_missing_packages("planetary_computer", "adlfs") +def test_planetary_computer_conus404() -> None: + """Test loading and validating the planetary_computer_conus404 dataset.""" + + config = { + "dates": { + "start": "2022-01-01", + "end": "2022-01-02", + "frequency": "1d", + }, + "input": { + "planetary_computer": { + "data_catalog_id": "conus404", + "param": ["Z"], + "level": [1], + "patch": { + "coordinates": ["bottom_top_stag"], + "rename": { + "bottom_top_stag": "level", + }, + "attributes": { + "lon": {"standard_name": "longitude", "long_name": "Longitude"}, + "lat": {"standard_name": "latitude", "long_name": "Latitude"}, + }, + }, + } + }, + } + + created = create_dataset(config=config, output=None) + ds = open_dataset(created) + assert ds.shape == (2, 1, 1, 1387505), ds.shape + + if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - test_kerchunk() - exit() - """Run all test functions that start with 'test_'.""" - for name, obj in list(globals().items()): - if name.startswith("test_") and callable(obj): - print(f"Running {name}...") - obj() + test_planetary_computer_conus404() + exit(0) + from anemoi.utils.testing import run_tests + + run_tests(globals()) diff --git a/tests/create/utils/__init__.py b/tests/create/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/create/utils/compare.py b/tests/create/utils/compare.py new file mode 100644 index 000000000..aa6a59dd2 --- /dev/null +++ b/tests/create/utils/compare.py @@ -0,0 +1,218 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import os + +import numpy as np + +from anemoi.datasets import open_dataset +from anemoi.datasets.data.stores import open_zarr + + +class Comparer: + """Class to compare datasets and their metadata. + + Parameters + ---------- + output_path : str, optional + The path to the output dataset. + reference_path : str, optional + The path to the reference dataset. + """ + + def __init__(self, output_path: str = None, reference_path: str = None) -> None: + """Initialize the Comparer instance. + + Parameters + ---------- + output_path : str, optional + The path to the output dataset. + reference_path : str, optional + The path to the reference dataset. + """ + self.output_path = output_path + self.reference_path = reference_path + print(f"Comparing {self.output_path} and {self.reference_path}") + + self.z_output = open_zarr(self.output_path) + self.z_reference = open_zarr(self.reference_path) + + self.z_reference["data"] + self.ds_output = open_dataset(self.output_path) + self.ds_reference = open_dataset(self.reference_path) + + @staticmethod + def compare_datasets(a: object, b: object) -> None: + """Compare two datasets. + + Parameters + ---------- + a : object + The first dataset. + b : object + The second dataset. + + Raises + ------ + AssertionError + If the datasets do not match. 
+ """ + assert a.shape == b.shape, (a.shape, b.shape) + assert (a.dates == b.dates).all(), (a.dates, b.dates) + for a_, b_ in zip(a.variables, b.variables): + assert a_ == b_, (a, b) + assert a.missing == b.missing, "Missing are different" + + for i_date, date in zip(range(a.shape[0]), a.dates): + if i_date in a.missing: + continue + for i_param in range(a.shape[1]): + param = a.variables[i_param] + assert param == b.variables[i_param], ( + date, + param, + a.variables[i_param], + b.variables[i_param], + ) + a_ = a[i_date, i_param] + b_ = b[i_date, i_param] + assert a.shape == b.shape, (date, param, a.shape, b.shape) + + a_nans = np.isnan(a_) + b_nans = np.isnan(b_) + assert np.all(a_nans == b_nans), (date, param, "nans are different") + + a_ = np.where(a_nans, 0, a_) + b_ = np.where(b_nans, 0, b_) + + delta = a_ - b_ + max_delta = np.max(np.abs(delta)) + abs_error = np.abs(a_ - b_) + rel_error = np.abs(a_ - b_) / (np.abs(b_) + 1e-10) # Avoid division by zero + assert max_delta == 0.0, (date, param, a_, b_, a_ - b_, max_delta, np.max(abs_error), np.max(rel_error)) + + @staticmethod + def compare_statistics(ds1: object, ds2: object) -> None: + """Compare the statistics of two datasets. + + Parameters + ---------- + ds1 : object + The first dataset. + ds2 : object + The second dataset. + + Raises + ------ + AssertionError + If the statistics do not match. + """ + vars1 = ds1.variables + vars2 = ds2.variables + assert len(vars1) == len(vars2) + for v1, v2 in zip(vars1, vars2): + idx1 = ds1.name_to_index[v1] + idx2 = ds2.name_to_index[v2] + assert (ds1.statistics["mean"][idx1] == ds2.statistics["mean"][idx2]).all() + assert (ds1.statistics["stdev"][idx1] == ds2.statistics["stdev"][idx2]).all() + assert (ds1.statistics["maximum"][idx1] == ds2.statistics["maximum"][idx2]).all() + assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() + + @staticmethod + def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: + """Compare the attributes of two Zarr datasets. + + Parameters + ---------- + a : dict + The attributes of the first dataset. + b : dict + The attributes of the second dataset. + path : str + The current path in the attribute hierarchy. + errors : list + The list to store error messages. + """ + if isinstance(a, dict): + a_keys = list(a.keys()) + b_keys = list(b.keys()) + for k in set(a_keys) | set(b_keys): + if k not in a_keys: + errors.append(f"❌ {path}.{k} : missing key (only in reference)") + continue + if k not in b_keys: + errors.append(f"❌ {path}.{k} : additional key (missing in reference)") + continue + if k in [ + "timestamp", + "uuid", + "latest_write_timestamp", + "history", + "provenance", + "provenance_load", + "description", + "config_path", + "total_size", + ]: + if type(a[k]) is not type(b[k]): + errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") + continue + + Comparer.compare_dot_zattrs(a[k], b[k], f"{path}.{k}", errors) + + return + + if isinstance(a, list): + if len(a) != len(b): + errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}") + return + + for i, (v, w) in enumerate(zip(a, b)): + Comparer.compare_dot_zattrs(v, w, f"{path}.{i}", errors) + + return + + if type(a) is not type(b): + msg = f"❌ {path} actual != expected : {a} ({type(a)}) != {b} ({type(b)})" + errors.append(msg) + return + + if a != b: + msg = f"❌ {path} actual != expected : {a} != {b}" + errors.append(msg) + + def compare(self) -> None: + """Compare the output dataset with the reference dataset. 
+ + Raises + ------ + AssertionError + If the datasets or their metadata do not match. + """ + errors = [] + self.compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors) + if errors: + print("Comparison failed") + print("\n".join(errors)) + + if errors: + print() + + print() + print("⚠️ To update the reference data, run this:") + print("cd " + os.path.dirname(self.output_path)) + base = os.path.basename(self.output_path) + print(f"tar zcf {base}.tgz {base}") + print(f"scp {base}.tgz data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/") + print() + raise AssertionError("Comparison failed") + + self.compare_datasets(self.ds_output, self.ds_reference) + self.compare_statistics(self.ds_output, self.ds_reference) + # do not compare tendencies statistics yet, as we don't know yet if they should stay diff --git a/src/anemoi/datasets/create/testing.py b/tests/create/utils/create.py similarity index 100% rename from src/anemoi/datasets/create/testing.py rename to tests/create/utils/create.py diff --git a/tests/create/utils/mock_sources.py b/tests/create/utils/mock_sources.py new file mode 100644 index 000000000..6c8448372 --- /dev/null +++ b/tests/create/utils/mock_sources.py @@ -0,0 +1,117 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import hashlib +import json +import os + +from earthkit.data import from_source as original_from_source + + +class LoadSource: + """Class to load data sources and handle mockup data.""" + + def __init__(self, get_test_data_func) -> None: + self._get_test_data = get_test_data_func + + def filename(self, args: tuple, kwargs: dict) -> str: + """Generate a filename based on the arguments and keyword arguments. + + Parameters + ---------- + args : tuple + The positional arguments. + kwargs : dict + The keyword arguments. + + Returns + ------- + str + The generated filename. + """ + string = json.dumps([args, kwargs], sort_keys=True, default=str) + h = hashlib.md5(string.encode("utf8")).hexdigest() + return h + ".grib" + + def get_data(self, args: tuple, kwargs: dict, path: str) -> None: + """Retrieve data and save it to the specified path. + + Parameters + ---------- + args : tuple + The positional arguments. + kwargs : dict + The keyword arguments. + path : str + The path to save the data. + + Raises + ------ + ValueError + If the test data is missing. + """ + upload_path = os.path.realpath(path + ".to_upload") + ds = original_from_source("mars", *args, **kwargs) + ds.save(upload_path) + print(f"Mockup: Saving to {upload_path} for {args}, {kwargs}") + print() + print("⚠️ To upload the test data, run this:") + path = os.path.relpath(upload_path, os.getcwd()) + name = os.path.basename(upload_path).replace(".to_upload", "") + print(f"scp {path} data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/{name}") + print() + exit(1) + raise ValueError("Test data is missing") + + def mars(self, args: tuple, kwargs: dict) -> object: + """Load data from the MARS archive. + + Parameters + ---------- + args : tuple + The positional arguments. + kwargs : dict + The keyword arguments. + + Returns + ------- + object + The loaded data source. 
+        """
+
+        name = self.filename(args, kwargs)
+
+        try:
+            return original_from_source("file", self._get_test_data(f"anemoi-datasets/create/mock-mars/{name}"))
+        except RuntimeError:
+            raise  # If offline
+        except Exception:
+            self.get_data(args, kwargs, name)
+
+    def __call__(self, name: str, *args: tuple, **kwargs: dict) -> object:
+        """Call the appropriate method based on the data source name.
+
+        Parameters
+        ----------
+        name : str
+            The name of the data source.
+        args : tuple
+            The positional arguments.
+        kwargs : dict
+            The keyword arguments.
+
+        Returns
+        -------
+        object
+            The loaded data source.
+        """
+        if name == "mars":
+            return self.mars(args, kwargs)
+
+        return original_from_source(name, *args, **kwargs)
diff --git a/tests/test_data.py b/tests/test_data.py
index 5c3e4c352..07b35887e 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -31,6 +31,7 @@
 from anemoi.datasets.data.join import Join
 from anemoi.datasets.data.misc import as_first_date
 from anemoi.datasets.data.misc import as_last_date
+from anemoi.datasets.data.padded import Padded
 from anemoi.datasets.data.select import Rename
 from anemoi.datasets.data.select import Select
 from anemoi.datasets.data.statistics import Statistics
@@ -389,6 +390,7 @@ def run(
         time_increment: datetime.timedelta,
         statistics_reference_dataset: Optional[Union[str, list]],
         statistics_reference_variables: Optional[Union[str, list]],
+        regular_shape: bool = True,
     ) -> None:
         """Run the dataset tests.
 
@@ -414,6 +416,8 @@
             Reference dataset for statistics.
         statistics_reference_variables : Optional[Union[str, list]]
            Reference variables for statistics.
+        regular_shape : bool, optional
+            Whether the dataset has a regular shape, by default True.
         """
         if isinstance(expected_variables, str):
             expected_variables = [v for v in expected_variables]
@@ -452,7 +456,8 @@ def run(
             statistics_reference_variables,
         )
 
-        self.indexing(self.ds)
+        if regular_shape:
+            self.indexing(self.ds)
         self.metadata(self.ds)
         self.ds.tree()
 
@@ -669,6 +674,25 @@ def test_join_3() -> None:
     )
 
 
+@mockup_open_zarr
+def test_padding_1() -> None:
+    """Test padding a dataset with empty entries outside its date range."""
+    test = DatasetTester("test-2022-2022-1h-o96-abcd", start="2021-01-01", end="2023-12-31 23:00:00", padding="empty")
+    test.run(
+        expected_class=Padded,
+        expected_length=365 * 24 * 3,
+        expected_shape=(365 * 24 * 3, 4, 1, VALUES),
+        expected_variables="abcd",
+        expected_name_to_index="abcd",
+        date_to_row=lambda date: simple_row(date, "abcd") if date.year == 2022 else np.zeros((4, 1, 0)),
+        start_date=datetime.datetime(2021, 1, 1),
+        time_increment=datetime.timedelta(hours=1),
+        statistics_reference_dataset="test-2022-2022-1h-o96-abcd",
+        statistics_reference_variables="abcd",
+        regular_shape=False,
+    )
+
+
 @mockup_open_zarr
 def test_subset_1() -> None:
     """Test subsetting a dataset (case 1)."""
diff --git a/tests/test_records.py b/tests/test_records.py
new file mode 100644
index 000000000..896081f9a
--- /dev/null
+++ b/tests/test_records.py
@@ -0,0 +1,160 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
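+
+# Note: the tests below exercise the records/Tabular API. A minimal sketch of
+# the access patterns they rely on (the ``.vz`` path is a local test fixture
+# and the group name is specific to that dataset):
+#
+#     ds = open_dataset("../../data/vz/obs-2018-11.vz")
+#     record = ds[0]                      # a Record, one entry per date
+#     group = ds["metop_a_ascat"]         # a Tabular, one entry per group
+#     ds[0]["metop_a_ascat"]              # same array as ds["metop_a_ascat"][0]
+#     ds[0].latitudes["metop_a_ascat"]    # per-group latitudes, longitudes, timedeltas
+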
+import os + +import numpy as np +import pytest + +from anemoi.datasets.data import open_dataset +from anemoi.datasets.data.records import Record +from anemoi.datasets.data.records import Tabular + + +def check_numpy(x, y): + assert x.shape == y.shape, f"Expected {x.shape} == {y.shape}" + assert type(x) == type(y), f"Expected {type(x)} == {type(y)}" # noqa: E721 + assert np.all(np.isnan(x) == np.isnan(y)) and np.all( + np.nan_to_num(x) == np.nan_to_num(y) + ), f"Expected {x} == {y} (ignoring NaNs)" + + +def _test(ds, nb_dates=None): + grp = "metop_a_ascat" + index_i = 0 + + if nb_dates is not None: + assert len(ds) == nb_dates, f"Expected {nb_dates} dates, got {len(ds)}" + + ################################# + # Order does not matter too much [i] and [grp] are exchangeable + + elt = ds[index_i] + assert isinstance(elt, Record), (type(ds), type(elt)) + assert ds[index_i].dataset == ds, (type(ds[index_i].dataset), type(ds)) + + group = ds[grp] + assert isinstance(group, Tabular), type(group) + + x = ds[grp][index_i] + y = ds[index_i][grp] + check_numpy(x, y) + + ############################################### + # lat and lon and timedelta are not the same for all elements + # but they have the same size + + lat = ds[index_i].latitudes[grp] + assert isinstance(lat, np.ndarray), type(lat) + + # Not implemented yet + # lat = ds[grp].latitudes[index_i] + # assert isinstance(lat, np.ndarray), type(lat) + + # Not implemented yet : do not need ? + # lat = ds.latitudes[grp][index_i] + # assert isinstance(lat, np.ndarray), type(lat) + + # Not implemented yet : do not need ? + # lat = ds.latitudes[index_i][grp] + # assert isinstance(lat, np.ndarray), type(lat) + + lon = ds[index_i].longitudes[grp] + assert isinstance(lon, np.ndarray), type(lon) + assert len(lat) == len(lon), f"Expected same size for lat and lon {len(lat)} == {len(lon)}" + + timedeltas = ds[index_i].timedeltas[grp] + assert isinstance(timedeltas, np.ndarray), type(timedeltas) + assert len(timedeltas) == len(lat), f"Expected same size for lat and timedeltas {len(lat)} == {len(timedeltas)}" + + ############################################# + # name_to_index is must be the same for all elements + # name_to_index is a dict of dict (key is the group name) + + name_to_index = ds.name_to_index + assert isinstance(name_to_index, dict), type(name_to_index) + assert len(name_to_index) > 0, "name_to_index is empty" + assert all(isinstance(k, str) for k in name_to_index.keys()), name_to_index + assert all(isinstance(v, dict) for v in name_to_index.values()), name_to_index + + _name_to_index = ds[index_i].name_to_index + assert list(name_to_index.keys()) == list(_name_to_index.keys()), ( + list(name_to_index.keys()), + list(_name_to_index.keys()), + ) + assert name_to_index == _name_to_index, "name_to_index is not the same for all elements" + + ############################################### + # statistics is not the same for all elements + # statistics is a dict of dict (first key is the group name) + + statistics = ds.statistics + assert isinstance(statistics, dict), type(statistics) + assert len(statistics) > index_i, "statistics is empty" + assert all(isinstance(k, str) for k in statistics.keys()), statistics + assert all(isinstance(v, dict) for v in statistics.values()), statistics + assert grp in statistics, f"statistics does not contain {grp}" + + statistics_ = ds[grp].statistics + assert isinstance(statistics_, dict), type(statistics_) + assert "mean" in statistics_, "statistics does not contain mean" + + # ! 
here, the meaning could be ambiguous: this is the statistics of the whole dataset.
+    # Do not document this, and maybe remove it.
+    _statistics = ds[index_i].statistics
+    assert isinstance(_statistics, dict), type(_statistics)
+    assert grp in _statistics, f"statistics does not contain {grp}"
+    assert _statistics.keys() == ds.keys(), (_statistics.keys(), ds.keys())
+    for group_name, stats in _statistics.items():
+        assert "mean" in stats, f"statistics does not contain mean for {group_name}"
+        for key, v in stats.items():
+            assert np.all(statistics[group_name][key] == v), (key, statistics[group_name][key], v)
+
+    assert statistics[grp].keys() == statistics_.keys(), (statistics[grp].keys(), statistics_.keys())
+    for key, v in statistics[grp].items():
+        assert np.all(statistics[grp][key] == v), (key, statistics[grp][key], v)
+
+
+@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found")
+def test_open():
+    ds = open_dataset("../../data/vz/obs-2018-11.vz")
+    _test(ds)
+
+
+@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found")
+def test_open_with_subset_dates():
+    ds = open_dataset(
+        "../../data/vz/obs-2018-11.vz",
+        end="2018-11-30",
+        select=[
+            "metop_a_ascat.*",
+            "amsr2_h180.rawbt_4",
+            "amsr2_h180.rawbt_3",
+        ],
+    )
+    _test(ds, nb_dates=8)
+
+
+@pytest.mark.skipif(not os.path.exists("../../data/vz/obs-2018-11.vz"), reason="File not found")
+def test_open_with_subset_select():
+    ds = open_dataset(
+        "../../data/vz/obs-2018-11.vz",
+        select=[
+            "amsr2_h180.rawbt_4",
+            "amsr2_h180.rawbt_3",
+            "metop_a_ascat.*",
+        ],
+    )
+    _test(ds)
+
+
+if __name__ == "__main__":
+
+    test_open()
+    test_open_with_subset_select()
+    test_open_with_subset_dates()
diff --git a/tests/xarray/test_flavour.py b/tests/xarray/test_flavour.py
new file mode 100644
index 000000000..ba185ca39
--- /dev/null
+++ b/tests/xarray/test_flavour.py
@@ -0,0 +1,104 @@
+# (C) Copyright 2025 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
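+
+# Note: the parametrised test below feeds tiny synthetic xarray datasets to the
+# coordinate guesser. A minimal sketch of the pattern, using the ``create_ds``
+# helper defined further down in this module:
+#
+#     ds = create_ds("latitude", None, None, None)
+#     guess = DefaultCoordinateGuesser(ds).guess(ds["latitude"], "latitude")
+#     assert isinstance(guess, LatitudeCoordinate)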
+ +import numpy as np +import pytest +import xarray as xr + +from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.create.sources.xarray_support.flavour import DefaultCoordinateGuesser + + +def create_ds(var_name, standard_name, long_name, units, coord_length=5): + attrs = { + k: v for k, v in [("standard_name", standard_name), ("long_name", long_name), ("units", units)] if v is not None + } + + ds = xr.Dataset( + {"x_wind": ([var_name], np.random.rand(coord_length))}, + coords={ + var_name: xr.DataArray(np.arange(coord_length), dims=var_name, attrs=attrs), + }, + ) + return ds + + +@pytest.mark.parametrize( + "var_name, standard_name, long_name, units, result", + [ + # longitude + ("longitude", None, None, None, LongitudeCoordinate), + ("longitude", None, "longitude", "degrees_east", LongitudeCoordinate), + ("longitude", None, "longitude", "degrees", LongitudeCoordinate), + ("lons", "longitude", None, "degrees", LongitudeCoordinate), + ("lons", None, None, None, UnsupportedCoordinate), + # latitude + ("latitude", None, None, None, LatitudeCoordinate), + ("latitude", None, "latitude", "degrees_north", LatitudeCoordinate), + ("latitude", None, "latitude", "degrees", LatitudeCoordinate), + ("lats", "latitude", None, "degrees", LatitudeCoordinate), + ("lats", None, None, "'degrees", UnsupportedCoordinate), + # x + ("x", None, None, None, XCoordinate), + ("x_coord", "projection_x_coordinate", None, None, XCoordinate), + ("x_coord", "grid_longitude", None, None, XCoordinate), + # y + ("y", None, None, None, YCoordinate), + ("y_coord", "projection_y_coordinate", None, None, YCoordinate), + ("y_coord", "grid_latitude", None, None, YCoordinate), + # time + ("time", "time", None, None, TimeCoordinate), + ("time", None, None, None, TimeCoordinate), + # date + ("t", "forecast_reference_time", None, None, DateCoordinate), + ("forecast_reference_time", None, None, None, DateCoordinate), + ("forecast_reference_time", "forecast_reference_time", None, None, DateCoordinate), + # step + ("fp", "forecast_period", None, None, StepCoordinate), + ("forecast_period", None, "time elapsed since the start of the forecast", None, StepCoordinate), + ("prediction_timedelta", None, None, None, StepCoordinate), + # level + ("lev", "atmosphere_hybrid_sigma_pressure_coordinate", None, None, LevelCoordinate), + ("h", None, "height", "m", LevelCoordinate), + ("level", "air_pressure", None, "hPa", LevelCoordinate), + ("pressure_0", None, "pressure", "hPa", LevelCoordinate), + ("pressure_0", None, "pressure", "Pa", LevelCoordinate), + ("level", None, None, None, LevelCoordinate), + ("lev", None, "level", None, 
UnsupportedCoordinate), + ("vertical", "vertical", None, "hPa", LevelCoordinate), + ("depth", "depth", None, "m", LevelCoordinate), + ("depth", "depth", None, None, LevelCoordinate), + # number + ("realization", None, None, None, EnsembleCoordinate), + ("number", None, None, None, EnsembleCoordinate), + ], +) +def test_coordinate_guesser(var_name, standard_name, long_name, units, result): + ds = create_ds(var_name, standard_name, long_name, units) + guesser = DefaultCoordinateGuesser(ds) + guess = guesser.guess(ds[var_name], var_name) + assert isinstance(guess, result) + + +def test_coordinate_guesser_scalar(): + var_name = "height" + ds = create_ds(var_name, None, None, "m", coord_length=1) + guesser = DefaultCoordinateGuesser(ds) + guess = guesser.guess(ds[var_name], var_name) + assert isinstance(guess, ScalarCoordinate) diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index c0a8f8ed2..81d36b923 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -13,7 +13,6 @@ from anemoi.utils.testing import skip_missing_packages from anemoi.datasets.create.sources.xarray import XarrayFieldList -from anemoi.datasets.data.stores import name_to_zarr_store from anemoi.datasets.testing import assert_field_list @@ -133,33 +132,6 @@ def test_noaa_replay() -> None: ) -@skip_if_offline -@skip_missing_packages("planetary_computer", "adlfs") -def test_planetary_computer_conus404() -> None: - """Test loading and validating the planetary_computer_conus404 dataset.""" - url = "https://planetarycomputer.microsoft.com/api/stac/v1/collections/conus404" - ds = xr.open_zarr(**name_to_zarr_store(url)) - - flavour = { - "rules": { - "latitude": {"name": "lat"}, - "longitude": {"name": "lon"}, - "x": {"name": "west_east"}, - "y": {"name": "south_north"}, - "time": {"name": "time"}, - }, - } - - fs = XarrayFieldList.from_xarray(ds, flavour=flavour) - - assert_field_list( - fs, - 74634912, - "1979-10-01T00:00:00", - "2022-09-30T23:00:00", - ) - - if __name__ == "__main__": for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj): diff --git a/tools/build-obs.py b/tools/build-obs.py new file mode 100755 index 000000000..e3caff9f9 --- /dev/null +++ b/tools/build-obs.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +import argparse +import logging +import os +import shutil + +import tqdm + +from anemoi.datasets import open_dataset + +LOG = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser(description="open a dataset and build a new one") + parser.add_argument("input", help="input dataset") + parser.add_argument("output", help="output dataset") + parser.add_argument("--backend", help="backend to use", type=str, default="npz1") + parser.add_argument("--overwrite", help="overwrite output directory if it exists", action="store_true") + args = parser.parse_args() + build(**vars(args)) + + +def build(input, output, backend, overwrite=False): + ds = open_dataset(input, backend=backend) + print(f"Using dataset {ds} as input") + print(f"{input} backend is '{ds.metadata['backend']}'") + print(f"Dataset has {len(ds)} records, from {ds.start_date} to {ds.end_date}") + print(f"Converting dataset to {output} using new backend '{backend}'") + + from anemoi.datasets.data.records.backends import writer_backend_factory + + if os.path.exists(output): + if overwrite: + LOG.warning(f"Output directory {output} already exists, removing it") + shutil.rmtree(output) + else: + raise FileExistsError(f"Output directory {output} already exists, use --overwrite 
to remove it") + writer = writer_backend_factory(backend, output) + + for i in tqdm.tqdm(range(len(ds))): + writer.write(i, ds[i]) + + writer.write_statistics(ds.statistics) + + metadata = ds.metadata.copy() + metadata["backend"] = backend + writer.write_metadata(metadata) + + +if __name__ == "__main__": + main() diff --git a/tools/check-obs.py b/tools/check-obs.py new file mode 100755 index 000000000..283b76aac --- /dev/null +++ b/tools/check-obs.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +import argparse +import logging + +import numpy as np + +from anemoi.datasets import open_dataset + +LOG = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser(description="open two datasets and compare them") + parser.add_argument("dataset", help="dataset to check") + parser.add_argument("reference", help="reference dataset") + args = parser.parse_args() + compare(args.dataset, args.reference) + + +def _compare_nested_dicts(a, b): + if isinstance(a, dict) and isinstance(b, dict): + if a.keys() != b.keys(): + return False + return all(_compare_nested_dicts(a[k], b[k]) for k in a) + elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + if a.shape != b.shape: + return False + return np.array_equal(a, b) + assert False, f"Unsupported types for comparison: {type(a)} and {type(b)}" + + +def compare(input, reference): + ds = open_dataset(input) + ref = open_dataset(reference) + + if len(ds) != len(ref): + raise ValueError(f"Datasets have different lengths: {len(ds)} != {len(ref)}") + + for i in range(len(ds)): + if ds[i] != ref[i]: + raise ValueError(f"Datasets differ at index {i}: {ds[i]} != {ref[i]}") + if ds.dates[i] != ref.dates[i]: + raise ValueError(f"Dates differ at index {i}: {ds.dates[i]} != {ref.dates[i]}") + print("✅ Data and dates are identical") + + ds_metadata = ds.metadata.copy() + ref_metadata = ref.metadata.copy() + ds_metadata.pop("backend", None) + ref_metadata.pop("backend", None) + if ds_metadata != ref_metadata: + raise ValueError("Metadata differs between datasets (excluding backend)") + print("✅ Metadata is identical") + + if not _compare_nested_dicts(ds.statistics, ref.statistics): + raise ValueError("Statistics differ between datasets") + print("✅ Statistics are identical") + + +if __name__ == "__main__": + main() From 83936f75ee3f65982760b0c3d5227df5cf70dd57 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 11 Aug 2025 20:05:53 +0200 Subject: [PATCH 070/212] merge --- 03-constant-fields.rst | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 03-constant-fields.rst diff --git a/03-constant-fields.rst b/03-constant-fields.rst deleted file mode 100644 index 2b21505dc..000000000 --- a/03-constant-fields.rst +++ /dev/null @@ -1,3 +0,0 @@ -######################## - Adding constant fields -######################## From 5209f26c0aa229d076834ab895094852c4208100 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 11 Aug 2025 20:24:47 +0200 Subject: [PATCH 071/212] update --- src/anemoi/datasets/create/__init__.py | 14 +++- src/anemoi/datasets/create/python.py | 98 +++++++++++--------------- 2 files changed, 52 insertions(+), 60 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 69b5a0d42..508eb2952 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1667,11 +1667,19 @@ def _tidy(d): def config_to_python(config: Any) -> Any: - from ..create.python import PythonCode + from ..create.python import PythonSource config = 
loader_config(config) input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - code = PythonCode() - return input.python_code(code) + code = PythonSource() + code = input.python_code(code).source_code() + + try: + import black + + return black.format_str(code, mode=black.Mode()) + except ImportError: + LOG.warning("Black not installed, skipping formatting") + return code diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 271845c2d..8d3312583 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -8,11 +8,8 @@ # nor does it submit to any jurisdiction. # -import datetime -import json import re - -from anemoi.utils.dates import frequency_to_string +import textwrap # input.python_prelude(code) # code1 = "\n".join(prelude) @@ -61,21 +58,51 @@ class PythonCode: + def __init__(self, parent): + self.parent = parent + def call(self, name, argument): - return PythonCall(name, argument) + return PythonCall(self, name, argument) def sum(self, actions): - return PythonChain("+", actions) + return PythonChain(self, "+", actions) def pipe(self, actions): - return PythonChain("|", actions) + return PythonChain(self, "|", actions) def concat(self, argument): - return PythonConcat(argument) + return PythonConcat(self, argument) + + def source_code(self, top=None): + return self.parent.source_code(top=top or self) + + +class PythonSource(PythonCode): + + def __init__(self): + super().__init__(parent=None) + self.prelude = "" + + def source_code(self, top): + + return textwrap.dedent( + f""" + # Generated Python code for Anemoi dataset creation + + from anemoi.datasets.recipe import Recipe + r = Recipe() + {self.prelude} + r.input = {repr(top)} + + """ + ) + + return repr(top) class PythonConcat(PythonCode): - def __init__(self, argument): + def __init__(self, parent, argument): + super().__init__(parent=parent) self.argument = argument def __repr__(self): @@ -83,13 +110,14 @@ def __repr__(self): class PythonCall(PythonCode): - def __init__(self, name, argument): + def __init__(self, parent, name, argument): + super().__init__(parent=parent) self.name = name self.argument = argument def __repr__(self): name = self.name.replace("-", "_") - config = self.argument + config = dict(**self.argument) # def convert(obj): # if isinstance(obj, datetime.datetime): @@ -121,54 +149,10 @@ def __repr__(self): class PythonChain(PythonCode): - def __init__(self, op, actions): + def __init__(self, parent, op, actions): + super().__init__(parent=parent) self.op = op self.actions = actions def __repr__(self): return "(" + self.op.join(repr(x) for x in self.actions) + ")" - - -def _python(name, config, **extra) -> str: - """Convert the action to Python code. - - Parameters - ---------- - name : str - The name of the action. - config : dict - The configuration for the action. - extra : Any - Additional keyword arguments. - - Returns - ------- - str - The Python code representation of the action. 
- """ - - name = name.replace("-", "_") - - def convert(obj): - if isinstance(obj, datetime.datetime): - return obj.isoformat() - if isinstance(obj, datetime.date): - return obj.isoformat() - if isinstance(obj, datetime.timedelta): - return frequency_to_string(obj) - raise TypeError(f"Object of type {type(obj)} is not JSON serializable") - - config = json.loads(json.dumps(config, default=convert)) - - params = [] - for k, v in config.items(): - if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: - return f"r.{name}({config})" - params.append(f"{k}={repr(v)}") - - for k, v in extra.items(): - params.append(f"{k}={v}") - - params = ",".join(params) - return f"r.{name}({params})" - # return f"{name}({config})" From c7a0e5d7d50817ad02d1b879b07a6147c9993e35 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 12 Aug 2025 12:21:59 +0200 Subject: [PATCH 072/212] update --- src/anemoi/datasets/create/__init__.py | 6 +- src/anemoi/datasets/create/python.py | 255 ++++++++++++++++++------- 2 files changed, 194 insertions(+), 67 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 508eb2952..7a076aa48 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1674,12 +1674,14 @@ def config_to_python(config: Any) -> Any: input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) code = PythonSource() - code = input.python_code(code).source_code() + x = input.python_code(code) + code = code.source_code(x) try: import black return black.format_str(code, mode=black.Mode()) - except ImportError: + # except ImportError: + except Exception: LOG.warning("Black not installed, skipping formatting") return code diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 8d3312583..d54fca1f3 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -9,24 +9,9 @@ # import re -import textwrap +import sys +from collections import defaultdict -# input.python_prelude(code) -# code1 = "\n".join(prelude) -# rich.print(f"Input prelude:\n{code1}") -# code2 = input.to_python() - -# code = f"from anemoi.datasets.recipe import Recipe\nr = Recipe()\n{code1}\nr.input = {code2}\n\nr.dump()" - -# code = re.sub(r"[\"\']?\${data_sources\.(\w+)}[\"\']?", r"\1", code) - -# try: -# import black - -# return black.format_str(code, mode=black.Mode()) -# except ImportError: -# LOG.warning("Black not installed, skipping formatting") -# return code RESERVED_KEYWORDS = ( "and", "or", @@ -56,103 +41,243 @@ ) +def _sanitize_name(name): + name = name.replace("-", "_") + if name in RESERVED_KEYWORDS: + name = f"{name}_" + return name + + class PythonCode: - def __init__(self, parent): - self.parent = parent + def __init__(self, top): + print(f"Creating {self.__class__.__name__} from {top.__class__.__name__}", file=sys.stderr) + self.top = top + self.top.register(self) + self.key = str(id(self)) def call(self, name, argument): - return PythonCall(self, name, argument) + return PythonCall(self.top, name, argument) def sum(self, actions): - return PythonChain(self, "+", actions) + return PythonChain(self.top, "+", actions) def pipe(self, actions): - return PythonChain(self, "|", actions) + return PythonChain(self.top, "|", actions) def concat(self, argument): - return PythonConcat(self, argument) + return PythonConcat(self.top, argument) + + def source_code(self): + return self.top.source_code(self) + + def combine(self, 
nodes): + return None + + +class Argument: + + def __init__(self, name): + self.name = name - def source_code(self, top=None): - return self.parent.source_code(top=top or self) + def __repr__(self): + return f"{_sanitize_name(self.name)}" class PythonSource(PythonCode): def __init__(self): - super().__init__(parent=None) - self.prelude = "" + self._prelude = [] + self.nodes = [] + self._count = defaultdict(int) + super().__init__(top=self) + + def register(self, child): + if child is not self: + self.nodes.append(child) + + def prelude(self): + return "\n".join(self._prelude) + + def source_code(self, first): + + which = self.nodes.index(first) + + more = True + while more: + more = False + + by_class = defaultdict(list) + for node in self.nodes: + by_class[(node.__class__, node.key)].append(node) + + for (cls, key), nodes in by_class.items(): + if len(nodes) > 1: + print(f"Found multiple nodes of type {cls.__name__}/{key}, merging them", file=sys.stderr) + print(f"Nodes: {len(nodes)}", file=sys.stderr) + changes = nodes[0].combine(nodes) + if changes: + self.replace_nodes(changes) + more = True + + first = self.nodes[which] + + return "\n\n".join( + [ + "# Generated Python code for Anemoi dataset creation", + "from anemoi.datasets.recipe import Recipe", + "r = Recipe()", + *self._prelude, + f"r.input = {repr(first)}", + "r.dump()", + ] + ) - def source_code(self, top): + def function(self, key, value, node): - return textwrap.dedent( - f""" - # Generated Python code for Anemoi dataset creation + n = self._count[node.name] + self._count[node.name] += 1 - from anemoi.datasets.recipe import Recipe - r = Recipe() - {self.prelude} - r.input = {repr(top)} + name = f"{node.name}_{n}" + name = _sanitize_name(name) + key = _sanitize_name(key) - """ - ) + class Function: + def __init__(self, name, key, value, node): + self.name = name + self.key = key + self.value = value + self.node = node + + def __repr__(self): + return f"{self.name}" + + self._prelude.append(f"def {name}({key}):") + self._prelude.append(f" return {node}") + return Function(name, key, value, node) - return repr(top) + def replace_nodes(self, changes): + + for old, new in changes: + assert old in self.nodes, f"Node {old} not found in {self.nodes}" + for i, node in enumerate(self.nodes): + + if node is old: + self.nodes[i] = new + else: + node.replace_node(old, new) class PythonConcat(PythonCode): - def __init__(self, parent, argument): - super().__init__(parent=parent) + def __init__(self, top, argument): + super().__init__(top=top) self.argument = argument + for k, v in self.argument.items(): + assert isinstance(v, PythonCode), f"Value must be a PythonCode instance {v}" def __repr__(self): - return str(self.argument) + return f"r.concat({self.argument})" + + def replace_node(self, old, new): + for k, v in list(self.argument.items()): + if v is old: + self.argument[k] = new + else: + v.replace_node(old, new) class PythonCall(PythonCode): - def __init__(self, parent, name, argument): - super().__init__(parent=parent) + def __init__(self, top, name, argument, parameters=None): + super().__init__(top=top) self.name = name self.argument = argument + self.key = name + self.parameters = parameters def __repr__(self): name = self.name.replace("-", "_") config = dict(**self.argument) - # def convert(obj): - # if isinstance(obj, datetime.datetime): - # return obj.isoformat() - # if isinstance(obj, datetime.date): - # return obj.isoformat() - # if isinstance(obj, datetime.timedelta): - # return frequency_to_string(obj) - # if 
isinstance(obj, PythonCode): - # return obj - # raise TypeError(f"Object of type {type(obj)} is not JSON serializable") - - # config = json.loads(json.dumps(config, default=convert)) - params = [] + for k, v in config.items(): if isinstance(k, str): - if k in RESERVED_KEYWORDS or re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k) is None: + + if k in RESERVED_KEYWORDS: + k = f"{k}_" + + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k): return f"r.{name}({config})" params.append(f"{k}={repr(v)}") - # for k, v in extra.items(): - # params.append(f"{k}={v}") - params = ",".join(params) return f"r.{name}({params})" - # return f"{name}({config})" - return f"{self.name}({self.argument})" + + def replace_node(self, old, new): + pass + + def combine(self, nodes): + + x = defaultdict(list) + for node in nodes: + argument = node.argument + for k, v in argument.items(): + rest = {k2: v2 for k2, v2 in sorted(argument.items()) if k2 != k} + x[str(rest)].append((k, v, node)) + + for i in sorted(x.values(), key=len, reverse=True): + key, value, node = i[0] + if len(i) < 2: + return + + rest = {k: v for k, v in node.argument.items() if k != key} + rest[key] = Argument(key) + call = PythonCall(self.top, self.name, rest) + + func = self.top.function(key, value, node=call) + changes = [] + for key, value, node in i: + + new = PythonFunction( + top=self.top, + func=func, + argument={key: value}, + ) + + changes.append((node, new)) + + return changes class PythonChain(PythonCode): - def __init__(self, parent, op, actions): - super().__init__(parent=parent) + def __init__(self, top, op, actions): + super().__init__(top=top) self.op = op - self.actions = actions + self.actions = list(actions) + self.key = op def __repr__(self): return "(" + self.op.join(repr(x) for x in self.actions) + ")" + + def replace_node(self, old, new): + + for i, node in enumerate(self.actions): + + if node is old: + self.actions[i] = new + else: + node.replace_node(old, new) + + +class PythonFunction(PythonCode): + def __init__(self, top, func, argument): + super().__init__(top=top) + self.func = func + self.argument = argument + self.key = func + + def __repr__(self): + return f"{self.func}({', '.join(f'{_sanitize_name(k)}={repr(v)}' for k, v in self.argument.items())})" + + def replace_node(self, old, new): + pass From b78a098829fe1dac60c28ea866919b29a81218f9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 12 Aug 2025 18:54:04 +0200 Subject: [PATCH 073/212] update --- src/anemoi/datasets/create/__init__.py | 4 +- src/anemoi/datasets/create/python.py | 217 ++++++++++++++++++------- 2 files changed, 161 insertions(+), 60 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 7a076aa48..fbe4ab024 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1667,13 +1667,13 @@ def _tidy(d): def config_to_python(config: Any) -> Any: - from ..create.python import PythonSource + from ..create.python import PythonScript config = loader_config(config) input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - code = PythonSource() + code = PythonScript() x = input.python_code(code) code = code.source_code(x) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index d54fca1f3..dccf33642 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -11,6 +11,7 @@ import re import sys from collections import defaultdict +from functools import cached_property 
RESERVED_KEYWORDS = ( "and", @@ -74,22 +75,74 @@ def source_code(self): def combine(self, nodes): return None + def prelude(self): + return None + class Argument: def __init__(self, name): - self.name = name + self.name = _sanitize_name(name) def __repr__(self): - return f"{_sanitize_name(self.name)}" + return self.name -class PythonSource(PythonCode): +class Parameter: + + def __init__(self, name): + self.name = _sanitize_name(name) + + def __repr__(self): + return self.name + + +class Function: + def __init__(self, name, node, counter, *parameters): + self._name = name + self.node = node + self.used = False + self.counter = counter + # self.parameters = parameters + + def __repr__(self): + return self.name + + def prelude(self): + if self.used: + return None + + self.used = True + + node_prelude = self.node.prelude() + + arguments = self.node.free_arguments() + + return [ + *(node_prelude if node_prelude else []), + f"def {self.name}({','.join(repr(p) for p in arguments)}):", + f" return {self.node}", + ] + + def free_arguments(self): + return self.node.free_arguments() + + @cached_property + def name(self): + n = self.counter[self._name] + self.counter[self._name] += 1 + return _sanitize_name(f"{self._name}_{n}") + + def replace_node(self, old, new): + if self.node is old: + self.node = new + + +class PythonScript(PythonCode): def __init__(self): - self._prelude = [] self.nodes = [] - self._count = defaultdict(int) + self.counter = defaultdict(int) super().__init__(top=self) def register(self, child): @@ -97,7 +150,14 @@ def register(self, child): self.nodes.append(child) def prelude(self): - return "\n".join(self._prelude) + result = [] + for node in self.nodes: + prelude = node.prelude() + if prelude: + if not isinstance(prelude, (list, tuple)): + prelude = list(prelude) + result.extend(prelude) + return "\n".join(result) def source_code(self, first): @@ -127,34 +187,17 @@ def source_code(self, first): "# Generated Python code for Anemoi dataset creation", "from anemoi.datasets.recipe import Recipe", "r = Recipe()", - *self._prelude, + self.prelude(), f"r.input = {repr(first)}", "r.dump()", ] ) - def function(self, key, value, node): - - n = self._count[node.name] - self._count[node.name] += 1 + def function0(self, node): + return Function(node.name, node, self.counter) - name = f"{node.name}_{n}" - name = _sanitize_name(name) - key = _sanitize_name(key) - - class Function: - def __init__(self, name, key, value, node): - self.name = name - self.key = key - self.value = value - self.node = node - - def __repr__(self): - return f"{self.name}" - - self._prelude.append(f"def {name}({key}):") - self._prelude.append(f" return {node}") - return Function(name, key, value, node) + def function1(self, node, key): + return Function(node.name, node, self.counter, Parameter(key)) def replace_nodes(self, changes): @@ -172,8 +215,6 @@ class PythonConcat(PythonCode): def __init__(self, top, argument): super().__init__(top=top) self.argument = argument - for k, v in self.argument.items(): - assert isinstance(v, PythonCode), f"Value must be a PythonCode instance {v}" def __repr__(self): return f"r.concat({self.argument})" @@ -186,13 +227,39 @@ def replace_node(self, old, new): v.replace_node(old, new) +class PythonChain(PythonCode): + def __init__(self, top, op, actions): + super().__init__(top=top) + self.op = op + self.actions = list(actions) + self.key = op + + def __repr__(self): + return "(" + self.op.join(repr(x) for x in self.actions) + ")" + + def replace_node(self, old, new): + + for i, node 
in enumerate(self.actions): + + if node is old: + self.actions[i] = new + else: + node.replace_node(old, new) + + class PythonCall(PythonCode): - def __init__(self, top, name, argument, parameters=None): + def __init__(self, top, name, argument): super().__init__(top=top) self.name = name self.argument = argument self.key = name - self.parameters = parameters + + def free_arguments(self): + result = [] + for k, v in self.argument.items(): + if isinstance(v, Argument): + result.append(v) + return result def __repr__(self): name = self.name.replace("-", "_") @@ -210,6 +277,9 @@ def __repr__(self): return f"r.{name}({config})" params.append(f"{k}={repr(v)}") + if params: + params.append("") # For a trailing comma + params = ",".join(params) return f"r.{name}({params})" @@ -218,6 +288,43 @@ def replace_node(self, old, new): def combine(self, nodes): + # Exact similarity + + changes = self._combine0(nodes) + if changes: + return changes + + # On key difference + changes = self._combine1(nodes) + if changes: + return changes + + def _combine0(self, nodes): + + x = defaultdict(list) + for node in nodes: + key = {k2: v2 for k2, v2 in sorted(node.argument.items())} + x[str(key)].append(node) + + for i in sorted(x.values(), key=len, reverse=True): + node = i[0] + if len(i) < 2: + return + + call = PythonCall(self.top, self.name, node.argument) + + func = self.top.function0(node=call) + changes = [] + for node in i: + + new = PythonFunction(top=self.top, func=func) + + changes.append((node, new)) + + return changes + + def _combine1(self, nodes): + x = defaultdict(list) for node in nodes: argument = node.argument @@ -234,14 +341,14 @@ def combine(self, nodes): rest[key] = Argument(key) call = PythonCall(self.top, self.name, rest) - func = self.top.function(key, value, node=call) + func = self.top.function1(call, key) changes = [] for key, value, node in i: new = PythonFunction( top=self.top, func=func, - argument={key: value}, + **{key: value}, ) changes.append((node, new)) @@ -249,35 +356,29 @@ def combine(self, nodes): return changes -class PythonChain(PythonCode): - def __init__(self, top, op, actions): +class PythonFunction(PythonCode): + def __init__(self, top, func, **kwargs): super().__init__(top=top) - self.op = op - self.actions = list(actions) - self.key = op + self.func = func + self.kwargs = kwargs def __repr__(self): - return "(" + self.op.join(repr(x) for x in self.actions) + ")" - - def replace_node(self, old, new): - - for i, node in enumerate(self.actions): - - if node is old: - self.actions[i] = new + params = [] + for a in self.func.free_arguments(): + name = _sanitize_name(a.name) + if a.name in self.kwargs: + v = self.kwargs[a.name] + params.append(f"{name}={repr(v)}") else: - node.replace_node(old, new) + params.append(f"{name}={name}") + return f"{self.func}({', '.join(params)})" -class PythonFunction(PythonCode): - def __init__(self, top, func, argument): - super().__init__(top=top) - self.func = func - self.argument = argument - self.key = func + def replace_node(self, old, new): + self.func.replace_node(old, new) - def __repr__(self): - return f"{self.func}({', '.join(f'{_sanitize_name(k)}={repr(v)}' for k, v in self.argument.items())})" + def prelude(self): + return self.func.prelude() - def replace_node(self, old, new): - pass + def free_arguments(self): + return [a for a in self.func.free_arguments() if a.name not in self.kwargs] From d641ea78309228bdd870155a1226ebe67095ebbb Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 13 Aug 2025 17:05:18 +0200 Subject: 
[PATCH 074/212] update --- src/anemoi/datasets/create/input/action.py | 9 +- src/anemoi/datasets/create/python.py | 34 +++-- src/anemoi/datasets/dates/__init__.py | 91 +++++++----- src/anemoi/datasets/dates/groups.py | 155 +++++++++++++++------ 4 files changed, 197 insertions(+), 92 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 8dadf14dc..d8120a289 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -56,7 +56,10 @@ def __call__(self, context, argument): def python_code(self, code): return code.concat( - {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} + { + filtering_dates.to_python(just_dates=True): action.python_code(code) + for filtering_dates, action in self.choices + } ) @@ -124,6 +127,10 @@ def __call__(self, context, argument): return context.register(self.call_object(context, source, argument), self.path) def python_code(self, code) -> str: + # For now... + if "source" in self.config: + source = action_factory(self.config["source"]) + self.config["source"] = source.python_code(code) return code.call(self.name, self.config) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index dccf33642..9e66ccfdd 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -49,6 +49,16 @@ def _sanitize_name(name): return name +def _un_dotdict(x): + if isinstance(x, dict): + return {k: _un_dotdict(v) for k, v in x.items()} + + if isinstance(x, (list, tuple, set)): + return [_un_dotdict(a) for a in x] + + return x + + class PythonCode: def __init__(self, top): @@ -98,12 +108,11 @@ def __repr__(self): class Function: - def __init__(self, name, node, counter, *parameters): + def __init__(self, name, node, counter): self._name = name self.node = node self.used = False self.counter = counter - # self.parameters = parameters def __repr__(self): return self.name @@ -131,6 +140,8 @@ def free_arguments(self): def name(self): n = self.counter[self._name] self.counter[self._name] += 1 + if n == 0: + return _sanitize_name(self._name) return _sanitize_name(f"{self._name}_{n}") def replace_node(self, old, new): @@ -193,12 +204,9 @@ def source_code(self, first): ] ) - def function0(self, node): + def function(self, node): return Function(node.name, node, self.counter) - def function1(self, node, key): - return Function(node.name, node, self.counter, Parameter(key)) - def replace_nodes(self, changes): for old, new in changes: @@ -214,7 +222,7 @@ def replace_nodes(self, changes): class PythonConcat(PythonCode): def __init__(self, top, argument): super().__init__(top=top) - self.argument = argument + self.argument = _un_dotdict(argument) def __repr__(self): return f"r.concat({self.argument})" @@ -251,7 +259,7 @@ class PythonCall(PythonCode): def __init__(self, top, name, argument): super().__init__(top=top) self.name = name - self.argument = argument + self.argument = _un_dotdict(argument) self.key = name def free_arguments(self): @@ -313,7 +321,7 @@ def _combine0(self, nodes): call = PythonCall(self.top, self.name, node.argument) - func = self.top.function0(node=call) + func = self.top.function(call) changes = [] for node in i: @@ -341,7 +349,7 @@ def _combine1(self, nodes): rest[key] = Argument(key) call = PythonCall(self.top, self.name, rest) - func = self.top.function1(call, key) + func = self.top.function(call) changes = [] for key, value, node in i: @@ -363,6 
+371,12 @@ def __init__(self, top, func, **kwargs): self.kwargs = kwargs def __repr__(self): + + # if len(self.func.free_arguments()) == 0: + # a = repr(self.func.node) + # if '=' not in a: + # return a + params = [] for a in self.func.free_arguments(): name = _sanitize_name(a.name) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 9570d381f..0ae8105e8 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -32,13 +32,15 @@ def extend(x: Union[str, List[Any], Tuple[Any, ...]]) -> Iterator[datetime.datetime]: """Extend a date range or list of dates into individual datetime objects. - Args: - x (Union[str, List[Any], Tuple[Any, ...]]): A date range string or list/tuple of dates. - - Returns - ------- - Iterator[datetime.datetime] - An iterator of datetime objects. + Parameters + ---------- + x : Union[str, List[Any], Tuple[Any, ...]] + A date range string or list/tuple of dates. + + Yields + ------ + datetime.datetime + Individual datetime objects. """ if isinstance(x, (list, tuple)): @@ -63,6 +65,8 @@ def extend(x: Union[str, List[Any], Tuple[Any, ...]]) -> Iterator[datetime.datet class DatesProvider: """Base class for date generation. + Examples + -------- >>> DatesProvider.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0)] @@ -163,9 +167,12 @@ def summary(self) -> str: class ValuesDates(DatesProvider): """Class for handling a list of date values. - Args: - values (List[Union[str, datetime.datetime]]): List of date values. - **kwargs (Any): Additional arguments. + Parameters + ---------- + values : List[Union[str, datetime.datetime]] + List of date values. + **kwargs : Any + Additional arguments. """ def __init__(self, values: List[Union[str, datetime.datetime]], **kwargs: Any) -> None: @@ -202,11 +209,16 @@ def as_dict(self) -> Dict[str, Any]: class StartEndDates(DatesProvider): """Class for generating dates between a start and end date with a specified frequency. - Args: - start (Union[str, datetime.datetime]): Start date. - end (Union[str, datetime.datetime]): End date. - frequency (Union[int, str]): Frequency of dates. - **kwargs (Any): Additional arguments. + Parameters + ---------- + start : Union[str, datetime.datetime] + Start date. + end : Union[str, datetime.datetime] + End date. + frequency : Union[int, str] + Frequency of dates. + **kwargs : Any + Additional arguments. """ def __repr__(self) -> str: @@ -273,26 +285,27 @@ def as_dict(self) -> Dict[str, Any]: "frequency": frequency_to_string(self.frequency), }.update(self.kwargs) - def to_python(self) -> str: - """Convert the StartEndDates instance to a Python string. - - Returns - ------- - str - Python string representation of the instance. - """ - # assert self.frequency == frequency_to_timedelta(1), self.frequency - return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) + def to_python(self, just_dates=False) -> str: + """Convert the StartEndDates instance to a tuple of ISO-formatted date strings.""" + if just_dates: + return (self.start.isoformat(), self.end.isoformat()) + else: + return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) class Hindcast: """Class representing a single hindcast date. - Args: - date (datetime.datetime): The date of the hindcast. - refdate (datetime.datetime): The reference date. 
- hdate (datetime.datetime): The hindcast date. - step (int): The step value. + Parameters + ---------- + date : datetime.datetime + The date of the hindcast. + refdate : datetime.datetime + The reference date. + hdate : datetime.datetime + The hindcast date. + step : int + The step value. """ def __init__( @@ -315,12 +328,18 @@ def __init__( class HindcastsDates(DatesProvider): """Class for generating hindcast dates over a range of years. - Args: - start (Union[str, List[str]]): Start date(s). - end (Union[str, List[str]]): End date(s). - steps (List[int]): List of step values. - years (int): Number of years. - **kwargs (Any): Additional arguments. + Parameters + ---------- + start : Union[str, List[str]] + Start date(s). + end : Union[str, List[str]] + End date(s). + steps : List[int] + List of step values. + years : int + Number of years. + **kwargs : Any + Additional arguments. """ def __init__( diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index f70fe8a57..88165f609 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -27,11 +27,15 @@ def _shorten(dates: Union[List[datetime.datetime], Tuple[datetime.datetime, ...]]) -> Union[str, List[str]]: """Shorten the list of dates for display. - Args: - dates (Union[List[datetime.datetime], Tuple[datetime.datetime, ...]]): The list of dates. - - Returns: - Union[str, List[str]]: The shortened list of dates. + Parameters + ---------- + dates : Union[List[datetime.datetime], Tuple[datetime.datetime, ...]] + The list of dates. + + Returns + ------- + Union[str, List[str]] + The shortened list of dates. """ if isinstance(dates, (list, tuple)): dates = [d.isoformat() for d in dates] @@ -44,6 +48,17 @@ class GroupOfDates: """A class to represent a group of dates.""" def __init__(self, dates: List[datetime.datetime], provider: DatesProvider, partial_ok: bool = False) -> None: + """Initialise a GroupOfDates instance. + + Parameters + ---------- + dates : List[datetime.datetime] + List of dates. + provider : DatesProvider + The dates provider. + partial_ok : bool, optional + Whether partial groups are allowed (default is False). + """ assert isinstance(provider, DatesProvider), type(provider) assert isinstance(dates, list) @@ -54,35 +69,45 @@ def __init__(self, dates: List[datetime.datetime], provider: DatesProvider, part def __len__(self) -> int: """Return the number of dates in the group. - Returns: - int: The number of dates. + Returns + ------- + int + The number of dates. """ return len(self.dates) def __iter__(self) -> Iterator[datetime.datetime]: """Return an iterator over the dates in the group. - Returns: - Iterator[datetime.datetime]: The iterator over the dates. + Returns + ------- + Iterator[datetime.datetime] + The iterator over the dates. """ return iter(self.dates) def __repr__(self) -> str: """Return a string representation of the group of dates. - Returns: - str: The string representation. + Returns + ------- + str + The string representation. """ return f"GroupOfDates(dates={_shorten(self.dates)})" def __eq__(self, other: object) -> bool: """Check if two groups of dates are equal. - Args: - other (object): The other group of dates. + Parameters + ---------- + other : object + The other group of dates. - Returns: - bool: True if the groups are equal, False otherwise. + Returns + ------- + bool + True if the groups are equal, False otherwise. 
""" return isinstance(other, GroupOfDates) and self.dates == other.dates @@ -90,7 +115,8 @@ def __eq__(self, other: object) -> bool: class Groups: """A collection of groups of dates. - Examples: + Examples + -------- >>> list(Groups(group_by="daily", start="2023-01-01 00:00", end="2023-01-05 00:00", frequency=12))[0] [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 1, 12, 0)] @@ -122,7 +148,8 @@ def __init__(self, **kwargs: Any) -> None: Parameters ---------- - **kwargs : Any : Arbitrary keyword arguments. Expected keys include: + **kwargs : Any + Arbitrary keyword arguments. Expected keys include: - group_by: Configuration for the Grouper. - Other keys for DatesProvider configuration. """ @@ -140,8 +167,10 @@ def provider(self) -> DatesProvider: def __iter__(self) -> Iterator[GroupOfDates]: """Return an iterator over the groups of dates. - Returns: - Iterator[GroupOfDates]: The iterator over the groups of dates. + Returns + ------- + Iterator[GroupOfDates] + The iterator over the groups of dates. """ for go in self._grouper(self._dates): dates = self._filter(go.dates) @@ -152,8 +181,10 @@ def __iter__(self) -> Iterator[GroupOfDates]: def __len__(self) -> int: """Return the number of groups of dates. - Returns: - int: The number of groups. + Returns + ------- + int + The number of groups. """ return self._len @@ -171,24 +202,30 @@ def _len(self) -> int: def __repr__(self) -> str: """Return a string representation of the groups of dates. - Returns: - str: The string representation. + Returns + ------- + str + The string representation. """ return f"{self.__class__.__name__}(dates={len(self)},{_shorten(self._dates)})" def describe(self) -> str: """Return a summary description of the dates. - Returns: - str: The summary description. + Returns + ------- + str + The summary description. """ return self._dates.summary def one_date(self) -> GroupOfDates: """Return a group containing only one date. - Returns: - GroupOfDates: The group containing only one date. + Returns + ------- + GroupOfDates + The group containing only one date. """ go = next(iter(self)) return GroupOfDates([go.dates[0]], go.provider) @@ -203,22 +240,24 @@ def __init__(self, missing: List[datetime.datetime]) -> None: def __call__(self, dates: List[datetime.datetime]) -> List[datetime.datetime]: """Filter out missing dates from the list of dates. - Args: - dates (List[datetime.datetime]): The list of dates. + Parameters + ---------- + dates : List[datetime.datetime] + The list of dates. - Returns: - List[datetime.datetime]: The filtered list of dates. + Returns + ------- + List[datetime.datetime] + The filtered list of dates. """ return [d for d in dates if d not in self.missing] class Grouper(ABC): - """Abstract base class for grouping dates.""" @classmethod def from_config(cls, group_by: Any) -> "Grouper": """Create a grouper based on the configuration.""" - if isinstance(group_by, int) and group_by > 0: return GrouperByFixedSize(group_by) @@ -278,11 +317,15 @@ class GrouperOneGroup(Grouper): def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group all dates into a single group. - Args: - dates (DatesProvider): The dates provider. + Parameters + ---------- + dates : DatesProvider + The dates provider. - Returns: - Iterator[GroupOfDates]: The iterator over the groups of dates. + Returns + ------- + Iterator[GroupOfDates] + The iterator over the groups of dates. 
""" assert isinstance(dates, DatesProvider), type(dates) @@ -293,16 +336,27 @@ class GrouperByKey(Grouper): """Group dates by a key.""" def __init__(self, key: Callable[[datetime.datetime], Any]) -> None: + """Initialise GrouperByKey with a key function. + + Parameters + ---------- + key : Callable[[datetime.datetime], Any] + Function to extract grouping key from a datetime. + """ self.key = key def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group dates based on the provided key. - Args: - dates (DatesProvider): The dates provider. + Parameters + ---------- + dates : DatesProvider + The dates provider. - Returns: - Iterator[GroupOfDates]: The iterator over the groups of dates. + Returns + ------- + Iterator[GroupOfDates] + The iterator over the groups of dates. """ for _, g in itertools.groupby(sorted(dates, key=self.key), key=self.key): yield GroupOfDates(list(g), dates) @@ -312,16 +366,27 @@ class GrouperByFixedSize(Grouper): """Group dates by a fixed size.""" def __init__(self, size: int) -> None: + """Initialise GrouperByFixedSize with batch size. + + Parameters + ---------- + size : int + Number of dates per group. + """ self.size = size def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group dates into fixed-size batches. - Args: - dates (DatesProvider): The dates provider. + Parameters + ---------- + dates : DatesProvider + The dates provider. - Returns: - Iterator[GroupOfDates]: The iterator over the groups of dates. + Returns + ------- + Iterator[GroupOfDates] + The iterator over the groups of dates. """ batch = [] From 3754eb267c86b8eace2ffe3b6c1b4170913296dc Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 18:13:40 +0200 Subject: [PATCH 075/212] update --- src/anemoi/datasets/create/input/__init__.py | 22 ++- src/anemoi/datasets/create/input/action.py | 27 +++- src/anemoi/datasets/create/python.py | 161 ++++++++++++++++--- src/anemoi/datasets/recipe.py | 38 ++++- 4 files changed, 210 insertions(+), 38 deletions(-) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 63020324c..2e089a42f 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -18,36 +18,32 @@ class InputBuilder: """Builder class for creating input data from configuration and data sources.""" - def __init__(self, config: dict, data_sources: Union[dict, list], **kwargs: Any) -> None: + def __init__(self, config: dict, data_sources: dict, **kwargs: Any) -> None: """Initialize the InputBuilder. Parameters ---------- config : dict Configuration dictionary. - data_sources : Union[dict, list] + data_sources : dict Data sources. **kwargs : Any Additional keyword arguments. """ self.kwargs = kwargs - - config = deepcopy(config) - if data_sources: - config = dict( - data_sources=dict( - sources=data_sources, - input=config, - ) - ) - self.config = config + self.config = deepcopy(config) + self.data_sources = deepcopy(dict(data_sources=data_sources)) @cached_property def action(self) -> Any: """Returns the action object based on the configuration.""" + from .action import Recipe from .action import action_factory - return action_factory(self.config, "input") + sources = action_factory(self.data_sources, "data_sources") + input = action_factory(self.config, "input") + + return Recipe(input, sources) def select(self, argument) -> Any: """Select data based on the group of dates. 
diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index d8120a289..fdb412458 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -198,7 +198,32 @@ def new_filter(name, mixin): ) -KLASS = {"concat": Concat, "join": Join, "pipe": Pipe} +class DataSources(Action): + def __init__(self, config, *path): + self.sources = {k: action_factory(v, *path, k) for k, v in config.items()} + + def python_code(self, code): + return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) + + +class Recipe(Action): + def __init__(self, input, data_sources): + self.input = input + self.data_sources = data_sources + + def python_code(self, code): + return code.recipe( + self.input.python_code(code), + self.data_sources.python_code(code), + ) + + +KLASS = { + "concat": Concat, + "join": Join, + "pipe": Pipe, + "data-sources": DataSources, +} LEN_KLASS = len(KLASS) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 9e66ccfdd..6924af6cc 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -8,11 +8,14 @@ # nor does it submit to any jurisdiction. # +import logging import re import sys from collections import defaultdict from functools import cached_property +LOG = logging.getLogger(__name__) + RESERVED_KEYWORDS = ( "and", "or", @@ -62,7 +65,7 @@ def _un_dotdict(x): class PythonCode: def __init__(self, top): - print(f"Creating {self.__class__.__name__} from {top.__class__.__name__}", file=sys.stderr) + # print(f"Creating {self.__class__.__name__} from {top.__class__.__name__}", file=sys.stderr) self.top = top self.top.register(self) self.key = str(id(self)) @@ -71,10 +74,10 @@ def call(self, name, argument): return PythonCall(self.top, name, argument) def sum(self, actions): - return PythonChain(self.top, "+", actions) + return PythonChain(self.top, "join", "+", actions) def pipe(self, actions): - return PythonChain(self.top, "|", actions) + return PythonChain(self.top, "pipe", "|", actions) def concat(self, argument): return PythonConcat(self.top, argument) @@ -85,27 +88,102 @@ def source_code(self): def combine(self, nodes): return None + def recipe(self, input, data_sources): + return PythonRecipe(self.top, input, data_sources) + def prelude(self): return None + def sources(self, sources): + return PythonSources(self.top, sources) -class Argument: - def __init__(self, name): - self.name = _sanitize_name(name) +class PythonRecipe(PythonCode): + def __init__(self, top, input, data_sources): + super().__init__(top) + self.input = input + self.data_sources = data_sources + + def apply_references(self, *path): + self.input.apply_references(*path, "input") + + def replace_node(self, old, new): + if self.input is old: + self.input = new + return + + if self.data_sources is old: + self.data_sources = new + return + + self.input.replace_node(old, new) + self.data_sources.replace_node(old, new) def __repr__(self): - return self.name + return repr(self.input) + + def prelude(self): + return self.data_sources.prelude() -class Parameter: +class Argument(PythonCode): - def __init__(self, name): + def __init__(self, top, name): + super().__init__(top=top) self.name = _sanitize_name(name) def __repr__(self): return self.name + def replace_node(self, old, new): + pass + + +class Anchor(PythonCode): + + def __init__(self, node): + super().__init__(top=node.top) + self.node = node + + @cached_property + def name(self): + n = 
self.top.counter["_anchor"] + self.top.counter["_anchor"] += 1 + return f"_a{n}" + + def __repr__(self): + return f"({self.name} := {repr(self.node)})" + + def replace_node(self, old, new): + pass + + +class Reference(PythonCode): + + def __init__(self, top, path): + super().__init__(top) + self.path = tuple(path) + self.anchor = None + + node = top.by_reference.get(self.path, None) + if node is None: + LOG.warning(f"Reference {self.path} not found") + for p in sorted(top.by_reference): + LOG.warning(f" - {p}") + else: + self.anchor = Anchor(node) + self.top.replace_nodes([(node, self.anchor)]) + + def __repr__(self): + if self.anchor is not None: + print("Reference:", self.path, "->", self.anchor.name, file=sys.stderr) + return self.anchor.name + + return f"'${{{'.'.join(self.path)}}}'" + + def replace_node(self, old, new): + pass + class Function: def __init__(self, name, node, counter): @@ -149,11 +227,38 @@ def replace_node(self, old, new): self.node = new +class PythonSources(PythonCode): + def __init__(self, top, sources): + super().__init__(top) + self.sources = sources + + def __repr__(self): + return "" + + def prelude(self): + result = [] + for k, v in self.sources.items(): + result.append(f"{k}={repr(v)}") + result.append("") + return result + + def replace_node(self, old, new): + for k, v in list(self.sources.items()): + if v is old: + self.sources[k] = new + else: + v.replace_node(old, new) + + def apply_references(self, *path): + self.top.by_reference[path + (self.name,)] = self + + class PythonScript(PythonCode): def __init__(self): self.nodes = [] self.counter = defaultdict(int) + self.by_reference = {} super().__init__(top=self) def register(self, child): @@ -174,6 +279,10 @@ def source_code(self, first): which = self.nodes.index(first) + first.apply_references() + # for k, v in self.by_reference.items(): + # print(f"Reference: {k} -> {v}", file=sys.stderr) + more = True while more: more = False @@ -184,8 +293,8 @@ def source_code(self, first): for (cls, key), nodes in by_class.items(): if len(nodes) > 1: - print(f"Found multiple nodes of type {cls.__name__}/{key}, merging them", file=sys.stderr) - print(f"Nodes: {len(nodes)}", file=sys.stderr) + # print(f"Found multiple nodes of type {cls.__name__}/{key}, merging them", file=sys.stderr) + # print(f"Nodes: {len(nodes)}", file=sys.stderr) changes = nodes[0].combine(nodes) if changes: self.replace_nodes(changes) @@ -234,11 +343,18 @@ def replace_node(self, old, new): else: v.replace_node(old, new) + def apply_references(self, *path): + assert "concat" not in path, path + self.top.by_reference[path + ("concat",)] = self + for i, node in enumerate(self.argument.values()): + node.apply_references(*path, "concat", str(i)) + class PythonChain(PythonCode): - def __init__(self, top, op, actions): + def __init__(self, top, kind, op, actions): super().__init__(top=top) self.op = op + self.kind = kind self.actions = list(actions) self.key = op @@ -254,6 +370,11 @@ def replace_node(self, old, new): else: node.replace_node(old, new) + def apply_references(self, *path): + self.top.by_reference[path + (self.kind,)] = self + for i, node in enumerate(self.actions): + node.apply_references(*path, self.kind, str(i)) + class PythonCall(PythonCode): def __init__(self, top, name, argument): @@ -283,6 +404,7 @@ def __repr__(self): if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k): return f"r.{name}({config})" + params.append(f"{k}={repr(v)}") if params: @@ -346,7 +468,7 @@ def _combine1(self, nodes): return rest = {k: v for k, v in 
node.argument.items() if k != key} - rest[key] = Argument(key) + rest[key] = Argument(self.top, key) call = PythonCall(self.top, self.name, rest) func = self.top.function(call) @@ -363,6 +485,14 @@ def _combine1(self, nodes): return changes + def apply_references(self, *path): + self.top.by_reference[path + (self.name,)] = self + + for k, v in self.argument.items(): + if isinstance(v, str) and (m := re.match(r"^\${(\w+(?:\.\w+)+)}$", v)): + path = m.group(1).split(".") + self.argument[k] = Reference(self.top, path) + class PythonFunction(PythonCode): def __init__(self, top, func, **kwargs): @@ -372,11 +502,6 @@ def __init__(self, top, func, **kwargs): def __repr__(self): - # if len(self.func.free_arguments()) == 0: - # a = repr(self.func.node) - # if '=' not in a: - # return a - params = [] for a in self.func.free_arguments(): name = _sanitize_name(a.name) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 69ce8b7df..24fc8caa2 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -12,6 +12,7 @@ import sys from tempfile import TemporaryDirectory +import rich import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.utils.config import DotDict @@ -32,6 +33,11 @@ def __init__(self, index): def __repr__(self): return f"Index({self.name})" + def same(self, other): + if not isinstance(other, Index): + return False + return self.name == other.name + class Step: @@ -41,6 +47,9 @@ def __or__(self, other): def __add__(self, other): return Join(self, other) + def same(self, other): + return self is other + class Chain(Step): def __init__(self, *args): @@ -55,10 +64,18 @@ def as_dict(self, recipe): return self.steps[0].as_dict(recipe) return {self.name: [s.as_dict(recipe) for s in self.steps]} - def __repr__(self): - return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" + # def __repr__(self): + # return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" def path(self, target, result, *path): + + rich.print(f"path: {target=}, {self=}, {result=}, {[s.name for s in path]}") + rich.print("-------------") + + if target is self: + result.append([*path, self]) + return + for i, s in enumerate(self.steps): s.path(target, result, *path, self, self.index[i]) @@ -81,6 +98,8 @@ def __init__(self, args): assert isinstance(args, dict), f"Invalid argument {args}" self.params = args + rich.print(f"Concat: {self=}") + def __setitem__(self, key, value): self.params[key] = value @@ -94,10 +113,15 @@ def as_dict(self, recipe): return {"concat": result} def collocated(self, a, b): - return a[0] is b[0] + return a[0].same(b[0]) def path(self, target, result, *path): + rich.print(f"path: {target=}, {self=}, {result=}, {path=}") + rich.print("-------------") + if target is self: + result.append([*path, self]) + return for i, (k, v) in enumerate(sorted(self.params.items())): v.path(target, result, *path, self, Index(i)) @@ -131,8 +155,8 @@ def resolve(params, recipe): return {self.owner.name: resolve(self.params, recipe)} - def __repr__(self): - return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" + # def __repr__(self): + # return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" def path(self, target, result, *path): @@ -233,7 +257,7 @@ def concat(self, *args, **kwargs): return Concat(*args, **kwargs) def resolve(self, source, target): - assert isinstance(target, 
Source), f"Only sources can be used as template {target}" + # assert isinstance(target, Source), f"Only sources can be used as template {target}" top = Index("input") # So we have 'input' first in the path @@ -263,6 +287,8 @@ def resolve(self, source, target): assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" + rich.print(f"Common ancestor: {common_ancestor=} {a=} {b=}") + if not common_ancestor.collocated(a, b): source = ".".join(s.name for s in path_to_source) target = ".".join(s.name for s in path_to_target) From 39ebc13063d71c3e020e7ee061737262c1e56dca Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 19:04:12 +0200 Subject: [PATCH 076/212] update --- src/anemoi/datasets/create/python.py | 232 ++++++++++++++++----------- 1 file changed, 134 insertions(+), 98 deletions(-) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 6924af6cc..1681f885b 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -10,7 +10,6 @@ import logging import re -import sys from collections import defaultdict from functools import cached_property @@ -97,6 +96,44 @@ def prelude(self): def sources(self, sources): return PythonSources(self.top, sources) + def update_anchor(self): + pass + + +class Variable(PythonCode): + def __init__(self, name, node): + super().__init__(top=node.top) + self.name = name + self.node = node + + def __repr__(self): + + return "" + + def replace_node(self, old, new): + pass + + def prelude(self): + return [f"{self.name} = {repr(self.node)}", ""] + + +class InLine(PythonCode): + def __init__(self, node): + super().__init__(top=node.top) + self.node = node + + @cached_property + def name(self): + n = self.top.counter["_anchor"] + self.top.counter["_anchor"] += 1 + return f"_a{n}" + + def __repr__(self): + return f"({self.name} := {repr(self.node)})" + + def replace_node(self, old, new): + pass + class PythonRecipe(PythonCode): def __init__(self, top, input, data_sources): @@ -105,6 +142,7 @@ def __init__(self, top, input, data_sources): self.data_sources = data_sources def apply_references(self, *path): + self.data_sources.apply_references(*path, "data_sources") self.input.apply_references(*path, "input") def replace_node(self, old, new): @@ -141,18 +179,17 @@ def replace_node(self, old, new): class Anchor(PythonCode): - def __init__(self, node): - super().__init__(top=node.top) - self.node = node + def __init__(self, identifier): + super().__init__(top=identifier.node.top) + self.identifier = identifier - @cached_property + @property def name(self): - n = self.top.counter["_anchor"] - self.top.counter["_anchor"] += 1 - return f"_a{n}" + return self.identifier.name def __repr__(self): - return f"({self.name} := {repr(self.node)})" + # assert False + return repr(self.identifier) def replace_node(self, old, new): pass @@ -165,18 +202,19 @@ def __init__(self, top, path): self.path = tuple(path) self.anchor = None - node = top.by_reference.get(self.path, None) + def update_anchor(self): + + node = self.top.by_reference.get(self.path, None) if node is None: LOG.warning(f"Reference {self.path} not found") - for p in sorted(top.by_reference): + for p in sorted(self.top.by_reference): LOG.warning(f" - {p}") else: self.anchor = Anchor(node) - self.top.replace_nodes([(node, self.anchor)]) + self.top.replace_nodes([(node.node, self.anchor)]) def __repr__(self): if self.anchor is not None: - print("Reference:", self.path, "->", self.anchor.name, 
file=sys.stderr) return self.anchor.name return f"'${{{'.'.join(self.path)}}}'" @@ -185,8 +223,9 @@ def replace_node(self, old, new): pass -class Function: +class Function(PythonCode): def __init__(self, name, node, counter): + super().__init__(top=node.top) self._name = name self.node = node self.used = False @@ -236,11 +275,7 @@ def __repr__(self): return "" def prelude(self): - result = [] - for k, v in self.sources.items(): - result.append(f"{k}={repr(v)}") - result.append("") - return result + pass def replace_node(self, old, new): for k, v in list(self.sources.items()): @@ -250,82 +285,8 @@ def replace_node(self, old, new): v.replace_node(old, new) def apply_references(self, *path): - self.top.by_reference[path + (self.name,)] = self - - -class PythonScript(PythonCode): - - def __init__(self): - self.nodes = [] - self.counter = defaultdict(int) - self.by_reference = {} - super().__init__(top=self) - - def register(self, child): - if child is not self: - self.nodes.append(child) - - def prelude(self): - result = [] - for node in self.nodes: - prelude = node.prelude() - if prelude: - if not isinstance(prelude, (list, tuple)): - prelude = list(prelude) - result.extend(prelude) - return "\n".join(result) - - def source_code(self, first): - - which = self.nodes.index(first) - - first.apply_references() - # for k, v in self.by_reference.items(): - # print(f"Reference: {k} -> {v}", file=sys.stderr) - - more = True - while more: - more = False - - by_class = defaultdict(list) - for node in self.nodes: - by_class[(node.__class__, node.key)].append(node) - - for (cls, key), nodes in by_class.items(): - if len(nodes) > 1: - # print(f"Found multiple nodes of type {cls.__name__}/{key}, merging them", file=sys.stderr) - # print(f"Nodes: {len(nodes)}", file=sys.stderr) - changes = nodes[0].combine(nodes) - if changes: - self.replace_nodes(changes) - more = True - - first = self.nodes[which] - - return "\n\n".join( - [ - "# Generated Python code for Anemoi dataset creation", - "from anemoi.datasets.recipe import Recipe", - "r = Recipe()", - self.prelude(), - f"r.input = {repr(first)}", - "r.dump()", - ] - ) - - def function(self, node): - return Function(node.name, node, self.counter) - - def replace_nodes(self, changes): - - for old, new in changes: - assert old in self.nodes, f"Node {old} not found in {self.nodes}" - for i, node in enumerate(self.nodes): - - if node is old: - self.nodes[i] = new - else: - node.replace_node(old, new) + for k, v in self.sources.items(): + self.top.by_reference[path + (k,)] = Variable(k, v) class PythonConcat(PythonCode): @@ -345,7 +306,7 @@ def replace_node(self, old, new): def apply_references(self, *path): assert "concat" not in path, path - self.top.by_reference[path + ("concat",)] = self + self.top.by_reference[path + ("concat",)] = InLine(self) for i, node in enumerate(self.argument.values()): node.apply_references(*path, "concat", str(i)) @@ -371,7 +332,7 @@ def replace_node(self, old, new): node.replace_node(old, new) def apply_references(self, *path): - self.top.by_reference[path + (self.kind,)] = self + self.top.by_reference[path + (self.kind,)] = InLine(self) for i, node in enumerate(self.actions): node.apply_references(*path, self.kind, str(i)) @@ -486,7 +447,7 @@ def _combine1(self, nodes): return changes def apply_references(self, *path): - self.top.by_reference[path + (self.name,)] = self + self.top.by_reference[path + (self.name,)] = InLine(self) for k, v in self.argument.items(): if isinstance(v, str) and (m := re.match(r"^\${(\w+(?:\.\w+)+)}$", v)): 
@@ -521,3 +482,78 @@ def prelude(self): def free_arguments(self): return [a for a in self.func.free_arguments() if a.name not in self.kwargs] + + +class PythonScript(PythonCode): + + def __init__(self): + self.nodes = [] + self.counter = defaultdict(int) + self.by_reference = {} + super().__init__(top=self) + + def register(self, child): + if child is not self: + self.nodes.append(child) + + def prelude(self): + result = [] + for node in self.nodes: + prelude = node.prelude() + if prelude: + if not isinstance(prelude, (list, tuple)): + prelude = list(prelude) + result.extend(prelude) + return "\n".join(result) + + def source_code(self, first): + + which = self.nodes.index(first) + + more = True + while more: + more = False + + by_class = defaultdict(list) + for node in self.nodes: + by_class[(node.__class__, node.key)].append(node) + + for nodes in by_class.values(): + if len(nodes) > 1: + changes = nodes[0].combine(nodes) + if changes: + self.replace_nodes(changes) + more = True + + first = self.nodes[which] + + first.apply_references() + for node in self.nodes: + node.update_anchor() + + first = self.nodes[which] + + return "\n\n".join( + [ + "# Generated Python code for Anemoi dataset creation", + "from anemoi.datasets.recipe import Recipe", + "r = Recipe()", + self.prelude(), + f"r.input = {repr(first)}", + "r.dump()", + ] + ) + + def function(self, node): + return Function(node.name, node, self.counter) + + def replace_nodes(self, changes): + + for old, new in changes: + assert old in self.nodes, f"Node {old} not found in {self.nodes}" + for i, node in enumerate(self.nodes): + + if node is old: + self.nodes[i] = new + else: + node.replace_node(old, new) From 5d327451ec751a59ef156b7443dc0cd500118505 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 19:08:18 +0200 Subject: [PATCH 077/212] update --- src/anemoi/datasets/create/python.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 1681f885b..d0ab72dd5 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -483,6 +483,9 @@ def prelude(self): def free_arguments(self): return [a for a in self.func.free_arguments() if a.name not in self.kwargs] + def apply_references(self, *path): + pass + class PythonScript(PythonCode): @@ -509,6 +512,9 @@ def prelude(self): def source_code(self, first): which = self.nodes.index(first) + first.apply_references() + for node in self.nodes: + node.update_anchor() more = True while more: @@ -527,12 +533,6 @@ def source_code(self, first): first = self.nodes[which] - first.apply_references() - for node in self.nodes: - node.update_anchor() - - first = self.nodes[which] - return "\n\n".join( [ "# Generated Python code for Anemoi dataset creation", From 6c1f14639e017d4c40ff0902e3ee84d87bc9d454 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 19:39:18 +0200 Subject: [PATCH 078/212] update --- src/anemoi/datasets/recipe.py | 56 ++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 24fc8caa2..e42263343 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -10,9 +10,9 @@ import logging import os import sys +from collections import defaultdict from tempfile import TemporaryDirectory -import rich import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry 
from anemoi.utils.config import DotDict @@ -69,9 +69,6 @@ def as_dict(self, recipe): def path(self, target, result, *path): - rich.print(f"path: {target=}, {self=}, {result=}, {[s.name for s in path]}") - rich.print("-------------") - if target is self: result.append([*path, self]) return @@ -98,8 +95,6 @@ def __init__(self, args): assert isinstance(args, dict), f"Invalid argument {args}" self.params = args - rich.print(f"Concat: {self=}") - def __setitem__(self, key, value): self.params[key] = value @@ -116,9 +111,6 @@ def collocated(self, a, b): return a[0].same(b[0]) def path(self, target, result, *path): - rich.print(f"path: {target=}, {self=}, {result=}, {path=}") - rich.print("-------------") - if target is self: result.append([*path, self]) return @@ -138,9 +130,9 @@ def __init__(self, owner, *args, **kwargs): def as_dict(self, recipe): - def resolve(params, recipe): + def resolve(params, recipe, name=None): if isinstance(params, dict): - return {k: resolve(v, recipe) for k, v in params.items()} + return {k: resolve(v, recipe, name=k) for k, v in params.items()} if isinstance(params, (list, tuple)): return [resolve(v, recipe) for v in params] @@ -149,7 +141,7 @@ def resolve(params, recipe): return [resolve(v, recipe) for v in params] if isinstance(params, Step): - return recipe.resolve(self, params) + return recipe.resolve(self, params, name=name) return params @@ -209,6 +201,9 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): self.statistics = DotDict() self.build = DotDict() + self._data_sources = {} + self._counter = defaultdict(int) + sources = source_registry.factories.copy() filters = transform_filter_registry.factories.copy() @@ -247,6 +242,9 @@ def as_dict(self): "dates": self.dates, } + if self._data_sources: + result["data_sources"] = self._data_sources + for k, v in list(result.items()): if v is None: del result[k] @@ -256,7 +254,26 @@ def as_dict(self): def concat(self, *args, **kwargs): return Concat(*args, **kwargs) - def resolve(self, source, target): + # def assert False, (name, target.as_dict(self)) + + def make_data_source(self, name, target): + + target = target.as_dict(self) + + name = name or "source" + if name in self._data_sources: + if self._data_sources[name] == target: + return f"${{data_sources.{name}}}" + + n = self._counter[name] + self._counter[name] += 1 + + name = f"{name}_{n}" if n > 0 else name + + self._data_sources[name] = target.copy() + return f"${{data_sources.{name}}}" + + def resolve(self, source, target, name=None): # assert isinstance(target, Source), f"Only sources can be used as template {target}" top = Index("input") # So we have 'input' first in the path @@ -271,10 +288,13 @@ def resolve(self, source, target): path_to_target = [] self.input.path(target, path_to_target, top) - if len(path_to_target) == 0: - raise ValueError(f"Target {target} not found in recipe") if len(path_to_target) > 1: raise ValueError(f"Target {target} found in multiple locations {path_to_target}") + + if len(path_to_target) == 0: + # Add a `data_sources` entry + return self.make_data_source(name, target) + path_to_target = path_to_target[0] a = [s for s in path_to_target] @@ -287,8 +307,6 @@ def resolve(self, source, target): assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" - rich.print(f"Common ancestor: {common_ancestor=} {a=} {b=}") - if not common_ancestor.collocated(a, b): source = ".".join(s.name for s in path_to_source) target = ".".join(s.name for s in path_to_target) @@ -368,9 
+386,11 @@ def dates(self, value): self._dates = self._parse_dates(value) def dump(self, file=sys.stdout): + input = self.input.as_dict(self) # First so we get the data_sources + result = self.as_dict() - result["input"] = self.input.as_dict(self) + result["input"] = input if self.output: result["output"] = self.output.as_dict() From 37de369a91b9fd4613f730c6697a73fc85c39f1d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 14 Aug 2025 20:17:15 +0200 Subject: [PATCH 079/212] update --- src/anemoi/datasets/create/python.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index d0ab72dd5..ee33398e9 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -64,7 +64,6 @@ def _un_dotdict(x): class PythonCode: def __init__(self, top): - # print(f"Creating {self.__class__.__name__} from {top.__class__.__name__}", file=sys.stderr) self.top = top self.top.register(self) self.key = str(id(self)) @@ -107,7 +106,6 @@ def __init__(self, name, node): self.node = node def __repr__(self): - return "" def replace_node(self, old, new): From 24f2c2aba8152be268ee914803eccf87e31ce4ce Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 11:13:48 +0200 Subject: [PATCH 080/212] update --- src/anemoi/datasets/create/input/action.py | 5 +---- src/anemoi/datasets/create/python.py | 9 +++------ src/anemoi/datasets/dates/__init__.py | 4 ++-- src/anemoi/datasets/recipe.py | 8 +++++++- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index fdb412458..d173f0996 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -56,10 +56,7 @@ def __call__(self, context, argument): def python_code(self, code): return code.concat( - { - filtering_dates.to_python(just_dates=True): action.python_code(code) - for filtering_dates, action in self.choices - } + {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} ) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index ee33398e9..2a47a5b67 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -356,13 +356,10 @@ def __repr__(self): params = [] for k, v in config.items(): - if isinstance(k, str): + k = _sanitize_name(k) - if k in RESERVED_KEYWORDS: - k = f"{k}_" - - if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", k): - return f"r.{name}({config})" + if not k.isidentifier(): + return f"r.{name}({config})" params.append(f"{k}={repr(v)}") diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 0ae8105e8..5e6c1863a 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -285,9 +285,9 @@ def as_dict(self) -> Dict[str, Any]: "frequency": frequency_to_string(self.frequency), }.update(self.kwargs) - def to_python(self, just_dates=False) -> str: + def to_python(self) -> str: """Convert the StartEndDates instance to a tuple of ISO-formatted date strings.""" - if just_dates: + if self.frequency == datetime.timedelta(hours=1): return (self.start.isoformat(), self.end.isoformat()) else: return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index e42263343..455cf689a 100644 --- 
a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -132,7 +132,13 @@ def as_dict(self, recipe): def resolve(params, recipe, name=None): if isinstance(params, dict): - return {k: resolve(v, recipe, name=k) for k, v in params.items()} + + def _(k): + if k.endswith("_"): + return k[:-1] + return k + + return {_(k): resolve(v, recipe, name=_(k)) for k, v in params.items()} if isinstance(params, (list, tuple)): return [resolve(v, recipe) for v in params] From db4d8956c121fdc47862d9f711aa57b212e568cc Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 12:27:44 +0200 Subject: [PATCH 081/212] add dumper --- src/anemoi/datasets/dumper.py | 71 +++++++++++++++++++++++++++++++++++ src/anemoi/datasets/recipe.py | 12 ++++-- 2 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 src/anemoi/datasets/dumper.py diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py new file mode 100644 index 000000000..052750d56 --- /dev/null +++ b/src/anemoi/datasets/dumper.py @@ -0,0 +1,71 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging + +import yaml + +LOG = logging.getLogger(__name__) + + +class MyDumper(yaml.SafeDumper): + pass + + +def represent_date(dumper, data): + if data.tzinfo is None: + data = data.replace(tzinfo=datetime.timezone.utc) + data = data.astimezone(datetime.timezone.utc) + iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) + + +# --- Represent multiline strings with | style --- +def represent_multiline_str(dumper, data): + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + +# --- Represent short lists inline (flow style) --- +def represent_inline_list(dumper, data): + + if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data): + return dumper.represent_sequence("tag:yaml.org,2002:seq", data) + + elems = [yaml.dump(i, explicit_start=False, explicit_end=False).replace("\n...\n", "") for i in data] + lines = [] + line = [] + for e in elems: + if sum(len(x) for x in line) + len(e) + 2 * (len(line) + 1) <= 80: + line.append(e) + else: + lines.append(line) + line = [e] + + if line: + lines.append(line) + + if len(lines) == 1: + return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + + block_lines = ["- [" + ", ".join(line) + "]" for line in lines] + return dumper.represent_scalar("tag:yaml.org,2002:str", "\n".join(block_lines), style="|") + + +# Register representers +MyDumper.add_representer(datetime.date, represent_date) +MyDumper.add_representer(datetime.datetime, represent_date) +MyDumper.add_representer(str, represent_multiline_str) +MyDumper.add_representer(list, represent_inline_list) + + +def yaml_dump(obj, **kwargs): + return yaml.dump(obj, Dumper=MyDumper, **kwargs) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 455cf689a..51e83b693 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -13,7 +13,6 @@ from collections import defaultdict 
from tempfile import TemporaryDirectory -import yaml from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.utils.config import DotDict from anemoi.utils.dates import as_datetime @@ -103,7 +102,12 @@ def as_dict(self, recipe): result = [] for k, v in sorted(self.params.items()): - result.append({"dates": dict(start=k[0], end=k[1]), **v.as_dict(recipe)}) + + key = dict(start=as_datetime(k[0]), end=as_datetime(k[1])) + if len(k) == 3: + key["frequency"] = k[2] + + result.append({"dates": key, **v.as_dict(recipe)}) return {"concat": result} @@ -407,7 +411,9 @@ def dump(self, file=sys.stdout): if self.build: result["build"] = self.build.as_dict() - yaml.safe_dump(result, sort_keys=False, indent=2, width=120, stream=file) + from .dumper import yaml_dump + + yaml_dump(result, sort_keys=False, indent=2, width=120, stream=file) def test(self, output="recipe.zarr"): from argparse import ArgumentParser From 55f740d34db13f16553c2e2ea67e8eae50f24753 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 14:04:42 +0200 Subject: [PATCH 082/212] update --- src/anemoi/datasets/commands/format.py | 87 ++++++++++++++++++++++++++ src/anemoi/datasets/dumper.py | 16 ++++- 2 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 src/anemoi/datasets/commands/format.py diff --git a/src/anemoi/datasets/commands/format.py b/src/anemoi/datasets/commands/format.py new file mode 100644 index 000000000..42d3aa1f9 --- /dev/null +++ b/src/anemoi/datasets/commands/format.py @@ -0,0 +1,87 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import logging +import sys +from typing import Any + +import yaml + +from ..dumper import yaml_dump +from . import Command + +LOG = logging.getLogger(__name__) + + +def make_dates(config): + if isinstance(config, dict): + return {k: make_dates(v) for k, v in config.items()} + if isinstance(config, list): + return [make_dates(v) for v in config] + if isinstance(config, str): + try: + return datetime.datetime.fromisoformat(config) + except ValueError: + return config + return config + + +ORDER = ( + "name", + "description", + "dataset_status", + "licence", + "attribution", + "env", + "dates", + "common", + "data_sources", + "input", + "output", + "statistics", + "build", + "platform", +) + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. 
+ """ + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + + with open(args.path, "r") as file: + config = yaml.safe_load(file) + + config = make_dates(config) + + text = yaml_dump(config, sort_keys=False, indent=2, width=120, order=ORDER) + # with open(args.path + ".tmp", "w") as f: + f = sys.stdout + for i, line in enumerate(text.splitlines()): + if i and line and line[0] not in (" ", "-"): + line = "\n" + line + print(line, file=f) + + # os.rename(args.path + ".tmp", args.path) + + +command = Recipe diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py index 052750d56..de41b0710 100644 --- a/src/anemoi/datasets/dumper.py +++ b/src/anemoi/datasets/dumper.py @@ -68,4 +68,18 @@ def represent_inline_list(dumper, data): def yaml_dump(obj, **kwargs): - return yaml.dump(obj, Dumper=MyDumper, **kwargs) + + kwargs.setdefault("Dumper", MyDumper) + kwargs.setdefault("sort_keys", False) + kwargs.setdefault("indent", 2) + kwargs.setdefault("width", 120) + + order = kwargs.pop("order", None) + if order: + + def _ordering(k): + return order.index(k) if k in order else len(order) + + obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} + + return yaml.dump(obj, **kwargs) From 014dbbc90948874f1ccf7871d0208441b0a06fa7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 16:23:48 +0200 Subject: [PATCH 083/212] update --- src/anemoi/datasets/create/input/action.py | 9 +++++++++ src/anemoi/datasets/dumper.py | 3 +-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index d173f0996..3085bc60d 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -202,6 +202,10 @@ def __init__(self, config, *path): def python_code(self, code): return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) + def __call__(self, context, argument): + for source in self.sources.values(): + source(context, argument) + class Recipe(Action): def __init__(self, input, data_sources): @@ -214,6 +218,11 @@ def python_code(self, code): self.data_sources.python_code(code), ) + def __call__(self, context, argument): + # Load data_sources + self.data_sources(context, argument) + return self.input(context, argument) + KLASS = { "concat": Concat, diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py index de41b0710..7e3867969 100644 --- a/src/anemoi/datasets/dumper.py +++ b/src/anemoi/datasets/dumper.py @@ -67,14 +67,13 @@ def represent_inline_list(dumper, data): MyDumper.add_representer(list, represent_inline_list) -def yaml_dump(obj, **kwargs): +def yaml_dump(obj, order=None, **kwargs): kwargs.setdefault("Dumper", MyDumper) kwargs.setdefault("sort_keys", False) kwargs.setdefault("indent", 2) kwargs.setdefault("width", 120) - order = kwargs.pop("order", None) if order: def _ordering(k): From 8ad939661a73ee81c9ac7c0a52d0c0711aacfe88 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 15:31:27 +0000 Subject: [PATCH 084/212] bug fix in path --- src/anemoi/datasets/create/input/action.py | 23 +++++++++++-------- .../datasets/create/input/context/__init__.py | 5 ++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 3085bc60d..2c157e3f9 100644 --- a/src/anemoi/datasets/create/input/action.py +++ 
b/src/anemoi/datasets/create/input/action.py @@ -20,23 +20,27 @@ class Action: def __init__(self, config, *path): self.config = config self.path = path + assert path[0] in ( + "input", + "data_sources", + ), f"{self.__class__.__name__}. Path must start with 'input' or 'data_sources': {path}" # rich.print(f"Creating {self.__class__.__name__} {'.'.join(x for x in self.path)} from {config}") class Concat(Action): def __init__(self, config, *path): - super().__init__(config, *path) + super().__init__(config, *path, "concat") assert isinstance(config, list), f"Value must be a dict {list}" self.choices = [] - for item in config: + for i, item in enumerate(config): assert "dates" in item, f"Value must contain the key 'dates' {item}" dates = item["dates"] filtering_dates = DatesProvider.from_config(**dates) - action = action_factory({k: v for k, v in item.items() if k != "dates"}) + action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i)) self.choices.append((filtering_dates, action)) def __repr__(self): @@ -62,11 +66,11 @@ def python_code(self, code): class Join(Action): def __init__(self, config, *path): - super().__init__(config, *path) + super().__init__(config, *path, "join") assert isinstance(config, list), f"Value must be a list {config}" - self.actions = [action_factory(item, *path, "join", str(i)) for i, item in enumerate(config)] + self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] def __repr__(self): return f"Join({self.actions})" @@ -86,8 +90,8 @@ def python_code(self, code) -> None: class Pipe(Action): def __init__(self, config, *path): assert isinstance(config, list), f"Value must be a list {config}" - super().__init__(config, *path) - self.actions = [action_factory(item, *path, "pipe", str(i)) for i, item in enumerate(config)] + super().__init__(config, *path, "pipe") + self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] def __repr__(self): return f"Pipe({self.actions})" @@ -197,14 +201,15 @@ def new_filter(name, mixin): class DataSources(Action): def __init__(self, config, *path): + super().__init__(config, *path) self.sources = {k: action_factory(v, *path, k) for k, v in config.items()} def python_code(self, code): return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) def __call__(self, context, argument): - for source in self.sources.values(): - source(context, argument) + for name, source in self.sources.items(): + context.register(source(context, argument), self.path + (name,)) class Recipe(Action): diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 26d449659..81ccbd593 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -34,6 +34,8 @@ def register(self, data: Any, path: list[str]) -> Any: if not path: return data + assert path[0] in ("input", "data_sources"), path + rich.print(f"Registering data at path: {path}") self.results[tuple(path)] = data return data @@ -47,6 +49,9 @@ def resolve(self, config): if path in self.results: config[key] = self.results[path] else: + rich.print(f"Path not found {path}") + for p in sorted(self.results): + rich.print(f" Available paths: {p}") raise KeyError(f"Path {path} not found in results: {self.results.keys()}") return config From 97566189a1930011a365fb07166f189450e5e0f6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 18:16:31 +0200 
Subject: [PATCH 085/212] join recipe command --- src/anemoi/datasets/commands/recipe.py | 50 ---------- .../datasets/commands/recipe/__init__.py | 93 +++++++++++++++++++ .../datasets/commands/{ => recipe}/format.py | 48 +++------- .../datasets/commands/{ => recipe}/migrate.py | 61 +++--------- src/anemoi/datasets/create/input/action.py | 13 ++- 5 files changed, 126 insertions(+), 139 deletions(-) delete mode 100644 src/anemoi/datasets/commands/recipe.py create mode 100644 src/anemoi/datasets/commands/recipe/__init__.py rename src/anemoi/datasets/commands/{ => recipe}/format.py (51%) rename src/anemoi/datasets/commands/{ => recipe}/migrate.py (90%) diff --git a/src/anemoi/datasets/commands/recipe.py b/src/anemoi/datasets/commands/recipe.py deleted file mode 100644 index f111aee1b..000000000 --- a/src/anemoi/datasets/commands/recipe.py +++ /dev/null @@ -1,50 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -from typing import Any - -import yaml - -from . import Command -from .migrate import migrate - -LOG = logging.getLogger(__name__) - - -class Recipe(Command): - def add_arguments(self, command_parser: Any) -> None: - """Add arguments to the command parser. - - Parameters - ---------- - command_parser : Any - Command parser object. - """ - - command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") - - command_parser.add_argument( - "path", - help="Path to recipe.", - ) - - def run(self, args: Any) -> None: - from anemoi.datasets.create import config_to_python - - with open(args.path, "r") as file: - config = yaml.safe_load(file) - if args.migrate: - config = migrate(config) - - print(config_to_python(config)) - - -command = Recipe diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py new file mode 100644 index 000000000..b45bc34d9 --- /dev/null +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -0,0 +1,93 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import argparse +import logging +import sys +from typing import Any + +import yaml + +from anemoi.datasets.create import config_to_python + +from .. import Command +from .format import format_recipe +from .migrate import migrate_recipe + +LOG = logging.getLogger(__name__) + + +class Recipe(Command): + def add_arguments(self, command_parser: Any) -> None: + """Add arguments to the command parser. + + Parameters + ---------- + command_parser : Any + Command parser object. 
+ """ + + command_parser.add_argument("--format", action="store_true", help="Format the recipe.") + command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") + command_parser.add_argument("--python", action="store_true", help="Convert the recipe to a Python script.") + + group = command_parser.add_mutually_exclusive_group() + group.add_argument("--inplace", action="store_true", help="Overwrite the recipe file in place.") + group.add_argument("--output", type=str, help="Output file path for the converted recipe.") + + command_parser.add_argument( + "path", + help="Path to recipe.", + ) + + def run(self, args: Any) -> None: + + if not args.format and not args.migrate and not args.python: + args.format = True + + with open(args.path, "r") as file: + config = yaml.safe_load(file) + + assert isinstance(config, dict) + + if args.migrate: + config = migrate_recipe(args, config) + if config is None: + LOG.info(f"{args.path}: No changes needed.") + return + + args.format = True + + if args.format: + formatted = format_recipe(args, config) + f = sys.stdout + if args.output: + f = open(args.output, "w") + + if args.inplace: + f = open(args.path, "w") + + print(formatted, file=f) + + if args.python: + if args.inplace: + argparse.ArgumentError(None, "Inplace conversion to Python is not supported.") + + if args.format: + raise argparse.ArgumentError(None, "Formatting is not supported when converting to Python.") + + if args.output: + with open(args.output, "w") as file: + file.write(config_to_python(config)) + else: + print(config_to_python(config)) + + +command = Recipe diff --git a/src/anemoi/datasets/commands/format.py b/src/anemoi/datasets/commands/recipe/format.py similarity index 51% rename from src/anemoi/datasets/commands/format.py rename to src/anemoi/datasets/commands/recipe/format.py index 42d3aa1f9..f221e0b8d 100644 --- a/src/anemoi/datasets/commands/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -9,14 +9,10 @@ import datetime +import io import logging -import sys -from typing import Any -import yaml - -from ..dumper import yaml_dump -from . import Command +from ...dumper import yaml_dump LOG = logging.getLogger(__name__) @@ -52,36 +48,16 @@ def make_dates(config): ) -class Recipe(Command): - def add_arguments(self, command_parser: Any) -> None: - """Add arguments to the command parser. - - Parameters - ---------- - command_parser : Any - Command parser object. 
- """ - command_parser.add_argument( - "path", - help="Path to recipe.", - ) - - def run(self, args: Any) -> None: - - with open(args.path, "r") as file: - config = yaml.safe_load(file) - - config = make_dates(config) - - text = yaml_dump(config, sort_keys=False, indent=2, width=120, order=ORDER) - # with open(args.path + ".tmp", "w") as f: - f = sys.stdout - for i, line in enumerate(text.splitlines()): - if i and line and line[0] not in (" ", "-"): - line = "\n" + line - print(line, file=f) +def format_recipe(args, config: dict) -> str: - # os.rename(args.path + ".tmp", args.path) + config = make_dates(config) + assert config + text = yaml_dump(config, order=ORDER) + f = io.StringIO() + for i, line in enumerate(text.splitlines()): + if i and line and line[0] not in (" ", "-"): + line = "\n" + line + print(line, file=f) -command = Recipe + return f.getvalue() diff --git a/src/anemoi/datasets/commands/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py similarity index 90% rename from src/anemoi/datasets/commands/migrate.py rename to src/anemoi/datasets/commands/recipe/migrate.py index fa74886b8..daaa3bf16 100644 --- a/src/anemoi/datasets/commands/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -10,7 +10,6 @@ import datetime import logging -import os from collections.abc import Sequence from typing import Any @@ -22,8 +21,6 @@ from anemoi.datasets.create import validate_config -from . import Command - LOG = logging.getLogger(__name__) @@ -540,54 +537,22 @@ def check(config): raise -class Recipe(Command): - def add_arguments(self, command_parser: Any) -> None: - """Add arguments to the command parser. - - Parameters - ---------- - command_parser : Any - Command parser object. - """ - command_parser.add_argument( - "path", - help="Path to recipe.", - ) - - def run(self, args: Any) -> None: - - rich.print(f"Migrating {args.path}") - - with open(args.path, "r") as file: - config = yaml.safe_load(file) - - try: - validate_config(config) - LOG.info(f"{args.path}: Validation successful.") - return - except Exception: - pass +def migrate_recipe(args: Any, config) -> None: - migrated = migrate(config) + rich.print(f"Migrating {args.path}") - migrated = {k: v for k, v in sorted(migrated.items(), key=order) if v} - - check(migrated) - if migrated == config: - LOG.info(f"{args.path}: No changes needed.") - return - - migrated = make_dates(migrated) - text = yaml.dump(migrated, default_flow_style=False, sort_keys=False, indent=2, width=120, Dumper=MyDumper) + try: + validate_config(config) + LOG.info(f"{args.path}: Validation successful.") + except Exception: + pass - LOG.info(f"{args.path}: updating.") - with open(args.path + ".tmp", "w") as f: - for i, line in enumerate(text.splitlines()): - if i and line and line[0] not in (" ", "-"): - line = "\n" + line - print(line, file=f) + migrated = migrate(config) - os.rename(args.path + ".tmp", args.path) + migrated = {k: v for k, v in sorted(migrated.items(), key=order) if v} + check(migrated) + if migrated == config: + return None -command = Recipe + return migrated diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 2c157e3f9..1e3c8175d 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -23,8 +23,7 @@ def __init__(self, config, *path): assert path[0] in ( "input", "data_sources", - ), f"{self.__class__.__name__}. 
Path must start with 'input' or 'data_sources': {path}" - # rich.print(f"Creating {self.__class__.__name__} {'.'.join(x for x in self.path)} from {config}") + ), f"{self.__class__.__name__}: path must start with 'input' or 'data_sources': {path}" class Concat(Action): @@ -130,7 +129,7 @@ def __call__(self, context, argument): def python_code(self, code) -> str: # For now... if "source" in self.config: - source = action_factory(self.config["source"]) + source = action_factory(self.config["source"], *self.path, "source") self.config["source"] = source.python_code(code) return code.call(self.name, self.config) @@ -239,7 +238,7 @@ def __call__(self, context, argument): LEN_KLASS = len(KLASS) -def make(key, config, path): +def make(key, config, *path): if LEN_KLASS == len(KLASS): @@ -272,8 +271,12 @@ def make(key, config, path): def action_factory(data, *path): + + assert len(path) > 0, f"Path must contain at least one element {path}" + assert path[0] in ("input", "data_sources") + assert isinstance(data, dict), f"Input data must be a dictionary {data}" assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}" key, value = next(iter(data.items())) - return make(key, value, path) + return make(key, value, *path) From a493a961c2a6ddd008688db59164ccc3cd6ab597 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 18:23:19 +0200 Subject: [PATCH 086/212] join recipe command --- .../datasets/commands/recipe/__init__.py | 17 ++++++- src/anemoi/datasets/commands/validate.py | 44 ------------------- 2 files changed, 15 insertions(+), 46 deletions(-) delete mode 100644 src/anemoi/datasets/commands/validate.py diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index b45bc34d9..546486f56 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -16,6 +16,7 @@ import yaml from anemoi.datasets.create import config_to_python +from anemoi.datasets.create import validate_config from .. import Command from .format import format_recipe @@ -34,6 +35,7 @@ def add_arguments(self, command_parser: Any) -> None: Command parser object. 
""" + command_parser.add_argument("--validate", action="store_true", help="Validate recipe.") command_parser.add_argument("--format", action="store_true", help="Format the recipe.") command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") command_parser.add_argument("--python", action="store_true", help="Convert the recipe to a Python script.") @@ -49,14 +51,25 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: - if not args.format and not args.migrate and not args.python: - args.format = True + if not args.validate and not args.format and not args.migrate and not args.python: + args.validate = True with open(args.path, "r") as file: config = yaml.safe_load(file) assert isinstance(config, dict) + if args.validate: + if args.inplace and (not args.format and not args.migrate and not args.python): + argparse.ArgumentError(None, "--inplace is not supported with --validate.") + + if args.output and (not args.format and not args.migrate and not args.python): + argparse.ArgumentError(None, "--output is not supported with --validate.") + + validate_config(config) + LOG.info(f"{args.path}: Recipe is valid.") + return + if args.migrate: config = migrate_recipe(args, config) if config is None: diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py deleted file mode 100644 index 84b25c6f8..000000000 --- a/src/anemoi/datasets/commands/validate.py +++ /dev/null @@ -1,44 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -from typing import Any - -import yaml - -from . import Command - -LOG = logging.getLogger(__name__) - - -class Validate(Command): - def add_arguments(self, command_parser: Any) -> None: - """Add arguments to the command parser. - - Parameters - ---------- - command_parser : Any - Command parser object. 
- """ - command_parser.add_argument( - "path", - help="Path to recipe.", - ) - - def run(self, args: Any) -> None: - from anemoi.datasets.create import validate_config - - with open(args.path, "r") as file: - config = yaml.safe_load(file) - - validate_config(config) - - -command = Validate From 1cde9f8253332b532c767ea129d51bee9b44bc3a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 18:33:46 +0200 Subject: [PATCH 087/212] use ampersand --- src/anemoi/datasets/create/python.py | 2 +- src/anemoi/datasets/recipe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 2a47a5b67..07c521824 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -72,7 +72,7 @@ def call(self, name, argument): return PythonCall(self.top, name, argument) def sum(self, actions): - return PythonChain(self.top, "join", "+", actions) + return PythonChain(self.top, "join", "&", actions) def pipe(self, actions): return PythonChain(self.top, "pipe", "|", actions) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 51e83b693..20199beae 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -43,7 +43,7 @@ class Step: def __or__(self, other): return Pipe(self, other) - def __add__(self, other): + def __and__(self, other): return Join(self, other) def same(self, other): From cdb1a9a08fff142fb6c37c39da6227e98eec2972 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 15 Aug 2025 18:45:43 +0200 Subject: [PATCH 088/212] use ampersand --- src/anemoi/datasets/recipe.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 20199beae..9b8516bf9 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -63,9 +63,6 @@ def as_dict(self, recipe): return self.steps[0].as_dict(recipe) return {self.name: [s.as_dict(recipe) for s in self.steps]} - # def __repr__(self): - # return f"{self.__class__.name}({','.join([str(s) for s in self.steps])})" - def path(self, target, result, *path): if target is self: @@ -157,11 +154,7 @@ def _(k): return {self.owner.name: resolve(self.params, recipe)} - # def __repr__(self): - # return f"{self.__class__.__name__}({self.owner.name}, {','.join([f'{k}={v}' for k, v in self.params.items()])})" - def path(self, target, result, *path): - if self is target: result.append([*path, self]) From e69eb1071e38b0ec3f6ccfd45abe146b31a025cd Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 08:54:47 +0200 Subject: [PATCH 089/212] add settings --- src/anemoi/datasets/create/__init__.py | 4 +++- src/anemoi/datasets/create/python.py | 26 +++++++++++++++++++++++--- src/anemoi/datasets/recipe.py | 3 +++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index fbe4ab024..4d270f72d 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1669,13 +1669,15 @@ def config_to_python(config: Any) -> Any: from ..create.python import PythonScript + raw_config = config + config = loader_config(config) input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) code = PythonScript() x = input.python_code(code) - code = code.source_code(x) + code = code.source_code(x, raw_config) try: import black diff --git a/src/anemoi/datasets/create/python.py 
b/src/anemoi/datasets/create/python.py index 07c521824..9b5384b1c 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -494,8 +494,27 @@ def register(self, child): if child is not self: self.nodes.append(child) - def prelude(self): + def prelude(self, config): + + from anemoi.datasets.recipe import Recipe + + SKIP = ( + "input", + "data_sources", + ) + result = [] + + for k, v in config.items(): + + if k in SKIP: + continue + + if not hasattr(Recipe, k): + continue + + result.append(f"r.{k} = {repr(v)}") + for node in self.nodes: prelude = node.prelude() if prelude: @@ -504,7 +523,7 @@ def prelude(self): result.extend(prelude) return "\n".join(result) - def source_code(self, first): + def source_code(self, first, config): which = self.nodes.index(first) first.apply_references() @@ -531,9 +550,10 @@ def source_code(self, first): return "\n\n".join( [ "# Generated Python code for Anemoi dataset creation", + "import datetime", "from anemoi.datasets.recipe import Recipe", "r = Recipe()", - self.prelude(), + self.prelude(config), f"r.input = {repr(first)}", "r.dump()", ] diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 9b8516bf9..394304e9a 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -358,6 +358,9 @@ def dates(self): def _parse_dates(self, value): + if isinstance(value, dict): + return value + start = None end = None frequency = 1 From 92165b466df55baeb6b64e3c6ea2380ad22e9ac6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 08:56:51 +0200 Subject: [PATCH 090/212] add settings --- src/anemoi/datasets/create/python.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index 9b5384b1c..d5accc2f4 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -511,6 +511,7 @@ def prelude(self, config): continue if not hasattr(Recipe, k): + LOG.warning(f"Unknown key in recipe: {k}") continue result.append(f"r.{k} = {repr(v)}") From a044e141d5c40f3caf4211cc145188857a3fa83a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 16:43:30 +0000 Subject: [PATCH 091/212] udpate --- src/anemoi/datasets/recipe.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 394304e9a..08b51b981 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -198,6 +198,8 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): self._licence = licence self._name = name self._dates = None + self._statistics = None + self._build = None self.input = Join() self.output = DotDict() @@ -243,6 +245,8 @@ def as_dict(self): "attribution": self.attribution, "licence": self.licence, "dates": self.dates, + "statistics": self.statistics, + "build": self.build, } if self._data_sources: @@ -391,6 +395,22 @@ def _parse_dates(self, value): def dates(self, value): self._dates = self._parse_dates(value) + @property + def statistics(self): + return self._statistics + + @statistics.setter + def statistics(self, value): + self._statistics = value + + @property + def build(self): + return self._build + + @build.setter + def build(self, value): + self._build = value + def dump(self, file=sys.stdout): input = self.input.as_dict(self) # First so we get the data_sources @@ -402,10 +422,10 @@ def dump(self, file=sys.stdout): result["output"] = 
self.output.as_dict() if self.statistics: - result["statistics"] = self.statistics.as_dict() + result["statistics"] = self.statistics if self.build: - result["build"] = self.build.as_dict() + result["build"] = self.build from .dumper import yaml_dump From cb9c5761e92bd3b468aec3779b4d51ca434b2897 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 18:46:07 +0200 Subject: [PATCH 092/212] use ruamel --- pyproject.toml | 1 + src/anemoi/datasets/dumper.py | 58 +++++++++++++---------------------- src/anemoi/datasets/recipe.py | 2 +- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 873415a24..c10a858b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "numcodecs<0.16", # Until we move to zarr3 "numpy", "pyyaml", + "ruamel-yaml", "semantic-version", "tqdm", "zarr<=2.18.4", diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py index 7e3867969..69d9f4140 100644 --- a/src/anemoi/datasets/dumper.py +++ b/src/anemoi/datasets/dumper.py @@ -10,15 +10,11 @@ import datetime import logging -import yaml +import ruamel.yaml LOG = logging.getLogger(__name__) -class MyDumper(yaml.SafeDumper): - pass - - def represent_date(dumper, data): if data.tzinfo is None: data = data.replace(tzinfo=datetime.timezone.utc) @@ -30,7 +26,7 @@ def represent_date(dumper, data): # --- Represent multiline strings with | style --- def represent_multiline_str(dumper, data): if "\n" in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data.strip(), style="|") return dumper.represent_scalar("tag:yaml.org,2002:str", data) @@ -40,39 +36,15 @@ def represent_inline_list(dumper, data): if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data): return dumper.represent_sequence("tag:yaml.org,2002:seq", data) - elems = [yaml.dump(i, explicit_start=False, explicit_end=False).replace("\n...\n", "") for i in data] - lines = [] - line = [] - for e in elems: - if sum(len(x) for x in line) + len(e) + 2 * (len(line) + 1) <= 80: - line.append(e) - else: - lines.append(line) - line = [e] - - if line: - lines.append(line) - - if len(lines) == 1: - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) - - block_lines = ["- [" + ", ".join(line) + "]" for line in lines] - return dumper.represent_scalar("tag:yaml.org,2002:str", "\n".join(block_lines), style="|") - - -# Register representers -MyDumper.add_representer(datetime.date, represent_date) -MyDumper.add_representer(datetime.datetime, represent_date) -MyDumper.add_representer(str, represent_multiline_str) -MyDumper.add_representer(list, represent_inline_list) + return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) def yaml_dump(obj, order=None, **kwargs): - kwargs.setdefault("Dumper", MyDumper) - kwargs.setdefault("sort_keys", False) - kwargs.setdefault("indent", 2) - kwargs.setdefault("width", 120) + # kwargs.setdefault("Dumper", MyDumper) + # kwargs.setdefault("sort_keys", False) + # kwargs.setdefault("indent", 2) + # kwargs.setdefault("width", 120) if order: @@ -81,4 +53,18 @@ def _ordering(k): obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} - return yaml.dump(obj, **kwargs) + # yaml = yaml.YAML(typ='unsafe', pure=True) + yaml = ruamel.yaml.YAML() + yaml.width = 120 # wrap long flow sequences + # yaml.default_flow_style = True + yaml.Representer.add_representer(datetime.date, 
represent_date) + yaml.Representer.add_representer(datetime.datetime, represent_date) + yaml.Representer.add_representer(str, represent_multiline_str) + yaml.Representer.add_representer(list, represent_inline_list) + + data = ruamel.yaml.comments.CommentedMap() + for k, v in obj.items(): + data[k] = v + data.yaml_set_comment_before_after_key(key=k, before="\n") + + return yaml.dump(data, **kwargs) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index 08b51b981..f1bf60ddc 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -429,7 +429,7 @@ def dump(self, file=sys.stdout): from .dumper import yaml_dump - yaml_dump(result, sort_keys=False, indent=2, width=120, stream=file) + yaml_dump(result, stream=file) def test(self, output="recipe.zarr"): from argparse import ArgumentParser From 99a5fb781d5682edbacfbb5a268409fec42eb8a7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 16 Aug 2025 17:11:02 +0000 Subject: [PATCH 093/212] fix source as parameters --- src/anemoi/datasets/create/input/action.py | 2 +- src/anemoi/datasets/create/input/context/__init__.py | 9 +++++++-- src/anemoi/datasets/create/sources/repeated_dates.py | 3 ++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 1e3c8175d..db9d8dace 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -275,7 +275,7 @@ def action_factory(data, *path): assert len(path) > 0, f"Path must contain at least one element {path}" assert path[0] in ("input", "data_sources") - assert isinstance(data, dict), f"Input data must be a dictionary {data}" + assert isinstance(data, dict), f"Input data must be a dictionary, got {type(data)}" assert len(data) == 1, f"Input data must contain exactly one key-value pair {data} {'.'.join(x for x in path)}" key, value = next(iter(data.items())) diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 81ccbd593..738f6a85b 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -56,10 +56,15 @@ def resolve(self, config): return config - def create_source(self, config: Any) -> Any: + def create_source(self, config: Any, *path) -> Any: from anemoi.datasets.create.input.action import action_factory - return action_factory(config) + if not isinstance(config, dict): + # It is already a result (e.g. ekd.FieldList), loaded from ${a.b.c} + # TODO: something more elegant + return lambda *args, **kwargs: config + + return action_factory(config, *path) @abstractmethod def empty_result(self) -> Any: ... 
diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index d092f08ad..eb235cd99 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -302,11 +302,12 @@ def transform(self, group_of_dates: Any) -> Tuple[Any, Any]: class RepeatedDatesSource(Source): def __init__(self, owner, source: Any, mode: str, **kwargs) -> None: + self.owner = owner self.mapper = DateMapper.from_mode(mode, source, kwargs) self.source = source def execute(self, context, group_of_dates): - source = context.create_source(self.source) + source = context.create_source(self.source, *self.owner.path, "source") result = [] for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): From ce027f434b7ff598e2df7b49ea6a961460dc295b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 18 Aug 2025 18:50:34 +0200 Subject: [PATCH 094/212] update --- .../datasets/commands/recipe/__init__.py | 2 + src/anemoi/datasets/commands/recipe/format.py | 10 +- .../datasets/commands/recipe/migrate.py | 240 +++++++++--------- src/anemoi/datasets/create/config.py | 2 +- src/anemoi/datasets/create/python.py | 3 + src/anemoi/datasets/dumper.py | 36 +-- src/anemoi/datasets/recipe.py | 64 ++++- 7 files changed, 206 insertions(+), 151 deletions(-) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 546486f56..71b116213 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -80,6 +80,7 @@ def run(self, args: Any) -> None: if args.format: formatted = format_recipe(args, config) + assert "dates" in formatted f = sys.stdout if args.output: f = open(args.output, "w") @@ -88,6 +89,7 @@ def run(self, args: Any) -> None: f = open(args.path, "w") print(formatted, file=f) + f.close() if args.python: if args.inplace: diff --git a/src/anemoi/datasets/commands/recipe/format.py b/src/anemoi/datasets/commands/recipe/format.py index f221e0b8d..533a569c1 100644 --- a/src/anemoi/datasets/commands/recipe/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -9,7 +9,6 @@ import datetime -import io import logging from ...dumper import yaml_dump @@ -53,11 +52,4 @@ def format_recipe(args, config: dict) -> str: config = make_dates(config) assert config - text = yaml_dump(config, order=ORDER) - f = io.StringIO() - for i, line in enumerate(text.splitlines()): - if i and line and line[0] not in (" ", "-"): - line = "\n" + line - print(line, file=f) - - return f.getvalue() + return yaml_dump(config, order=ORDER) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index daaa3bf16..31c67aa42 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -8,26 +8,22 @@ # nor does it submit to any jurisdiction. 
-import datetime import logging +import sys from collections.abc import Sequence from typing import Any import rich -import yaml from glom import assign from glom import delete from glom import glom from anemoi.datasets.create import validate_config +from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) -class MyDumper(yaml.SafeDumper): - pass - - def find_paths(data, target_key=None, target_value=None, *path): matches = [] @@ -58,85 +54,23 @@ def find_chevrons(data, *path): return matches -# Custom representer for datetime.date and datetime.datetime -def represent_date(dumper, data): - if isinstance(data, datetime.date) and not isinstance(data, datetime.datetime): - data = datetime.datetime(data.year, data.month, data.day, 0, 0, 0) - # Ensure it's UTC - if data.tzinfo is None: - data = data.replace(tzinfo=datetime.timezone.utc) - data = data.astimezone(datetime.timezone.utc) - # Format as ISO 8601 with 'Z' - iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" - return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) - - -# Custom representer for multiline strings using the '|' block style -def represent_multiline_str(dumper, data): - if "\n" in data: - text_list = [line.rstrip() for line in data.splitlines()] - fixed_data = "\n".join(text_list) - return dumper.represent_scalar("tag:yaml.org,2002:str", fixed_data, style="|") - return dumper.represent_scalar("tag:yaml.org,2002:str", data) - - -# --- Represent short lists inline (flow style) --- -def represent_inline_list(dumper, data): - # Flow style if list has <= 4 simple elements - if ( - all(isinstance(i, (str, int, float, bool, type(None))) for i in data) - and len(", ".join([str(x) for x in data])) + 2 <= 80 - ): - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) - return dumper.represent_sequence("tag:yaml.org,2002:seq", data) - - -# Register custom representers -MyDumper.add_representer(datetime.date, represent_date) -MyDumper.add_representer(datetime.datetime, represent_date) -MyDumper.add_representer(str, represent_multiline_str) -MyDumper.add_representer(list, represent_inline_list) - - -def make_dates(config): - if isinstance(config, dict): - return {k: make_dates(v) for k, v in config.items()} - if isinstance(config, list): - return [make_dates(v) for v in config] - if isinstance(config, str): - try: - return datetime.datetime.fromisoformat(config) - except ValueError: - return config - return config - - -ORDER = ( - "name", - "description", - "dataset_status", - "licence", - "attribution", - "env", - "dates", - "common", - "data_sources", - "input", - "output", - "statistics", - "build", - "platform", -) -ORDER = {k: i for i, k in enumerate(ORDER)} - - -def order(x: str) -> int: - +def find_paths_in_substrees(path, obj, cur_path=None): + if cur_path is None: + cur_path = [] + matches = [] try: - return ORDER[x[0]] - except KeyError: - rich.print(f"Unknown key: {x}") - raise + glom(obj, path) # just to check existence + matches.append(cur_path + path.split(".")) + except Exception: + pass + + if isinstance(obj, dict): + for k, v in obj.items(): + matches.extend(find_paths_in_substrees(path, v, cur_path + [k])) + elif isinstance(obj, list): + for i, v in enumerate(obj): + matches.extend(find_paths_in_substrees(path, v, cur_path + [str(i)])) + return matches MIGRATE = { @@ -153,24 +87,26 @@ def order(x: str) -> int: "loop.0.loop_a.dates": "dates", "dates.stop": "dates.end", "dates.group_by": "build.group_by", - "include": "data_sources", + 
"include.mars": "data_sources.mars.mars", "ensemble_dimension": "build.ensemble_dimension", "flatten_grid": "build.flatten_grid", } DELETE = [ "purpose", - "input.join.0.label", + # "input.join.0.label", "status", "common", "config_format_version", "aliases", - "platform", + # "platform", "loops.0.loop_a.applies_to", "loop.0.loop_a.applies_to", "dataset_status", "alias", "resources", + "input.dates.<<", + "input.dates.join.0.label.name", ] @@ -186,12 +122,12 @@ def order(x: str) -> int: MARKER = object() -def _delete(config, path, result): +def _delete(config, path): x = glom(config, path, default=MARKER) if x is MARKER: return rich.print(f"Deleting {path}={x}") - delete(result, path) + delete(config, path) def _move(config, path, new_path, result): @@ -203,12 +139,12 @@ def _move(config, path, new_path, result): assign(result, new_path, x, missing=dict) -def _fix_input_0(result, config): +def _fix_input_0(config): if isinstance(config["input"], dict): return input = config["input"] - new_input = result["input"] = [] + new_input = [] blocks = {} first = None @@ -227,26 +163,24 @@ def _fix_input_0(result, config): source_name = values.pop("name", None) if inherit is not None: + if inherit.startswith("$"): + inherit = inherit[1:] inherited = blocks[inherit].copy() inherited.update(values) values = inherited - if "source_or_dataset" in values: - values.pop("source_or_dataset", None) - values["template"] = "${input.join.0." + first + "}" - if first is None: first = source_name blocks[block_name] = values.copy() - new_input.append({block_name: {SOURCES.get(source_name, source_name): values.copy()}}) + new_input.append({SOURCES.get(source_name, source_name): values.copy()}) else: assert False, f"Block {block_name} does not have 'kwargs': {values}" blocks[block_name] = values.copy() - config["input"] = result["input"].copy() + config["input"] = dict(join=new_input) def _fix_input_1(result, config): @@ -410,7 +344,7 @@ def _fix_join(result: dict, config: dict) -> None: config["input"] = result["input"].copy() -def _fix_sources(result: dict, config: dict, what) -> None: +def _fix_sources(config: dict, what) -> None: input = config["input"] if what not in input: @@ -420,7 +354,7 @@ def _fix_sources(result: dict, config: dict, what) -> None: new_join = [] for j in join: assert isinstance(j, dict) - assert len(j) == 1 + assert len(j) == 1, j key, values = list(j.items())[0] @@ -432,10 +366,15 @@ def _fix_sources(result: dict, config: dict, what) -> None: } ) - result["input"][what] = new_join + config["input"][what] = new_join config["input"][what] = new_join.copy() +def _assign(config, path, value): + rich.print(f"Assign {path} {value}") + assign(config, path, value) + + def _fix_chevrons(result: dict, config: dict) -> None: rich.print("Fixing chevrons...") paths = find_chevrons(config) @@ -447,23 +386,91 @@ def _fix_chevrons(result: dict, config: dict) -> None: assign(result, ".".join(p[:-1]), a) +def _fix_some(config: dict) -> None: + + paths = find_paths_in_substrees("label.function", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + assert node + _assign(config, ".".join(p[:-2]), node) + + paths = find_paths_in_substrees("constants.source_or_dataset", config) + for p in paths: + node = glom(config, ".".join(p[:-1])) + node["template"] = node.pop("source_or_dataset") + if node["template"] == "$previous_data": + node["template"] = "${input.join.0.mars}" + paths = find_paths_in_substrees("constants.template", config) + for p in paths: + 
node = glom(config, ".".join(p[:-1])) + if node["template"] == "$pl_data": + node["template"] = "${input.join.0.mars}" + for d in ("date", "dates", "time"): + paths = find_paths_in_substrees(d, config) + for p in paths: + if len(p) > 1: + node = glom(config, ".".join(p[:-1])) + if isinstance(node, dict) and isinstance(node[d], str) and node[d].startswith("$"): + del node[d] + + paths = find_paths_in_substrees("source.<<", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + node.update(node.pop("<<")) + parent[node.pop("name")] = node + assert len(parent) == 2 + del parent["source"] + + paths = find_paths_in_substrees("label.mars", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + assert node + assign(config, ".".join(p[:-2]), node) + + paths = find_paths_in_substrees("input.dates.join", config) + for p in paths: + node = glom(config, ".".join(p)) + config["input"]["join"] = node + del config["input"]["dates"] + + paths = find_paths_in_substrees("source.name", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + name = node.pop("name") + assign(config, ".".join(p[:-2]), {name: node}) + + paths = find_paths_in_substrees("function.name", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + name = node.pop("name") + assert node + assign(config, ".".join(p[:-2]), {name: node}) + + def _migrate(config: dict, n) -> dict: result = config.copy() - _fix_input_0(result, config) - _fix_loops(result, config) - _fix_input_1(result, config) - _fix_join(result, config) - _fix_sources(result, config, "join") - _fix_chevrons(result, config) - _fix_other(result, config) + _fix_input_0(result) + # _fix_loops(result, config) + # _fix_input_1(result, config) + # _fix_join(result, config) + # _fix_chevrons(result, config) + # _fix_other(result, config) for k, v in MIGRATE.items(): _move(config, k, v, result) + _fix_some(result) + _fix_sources(result, "join") + for k in DELETE: - _delete(config, k, result) + _delete(result, k) remove_empties(result) @@ -513,7 +520,6 @@ def has_value(config, value: str) -> bool: def check(config): - from anemoi.datasets.create import validate_config try: @@ -523,6 +529,7 @@ def check(config): assert not has_key(config, "label") assert not has_key(config, "kwargs") assert not has_value(config, "$previous_data") + assert not has_value(config, "$pl_data") assert not has_value(config, "$dates") assert not has_key(config, "inherit") assert not has_key(config, "source_or_dataset") @@ -532,25 +539,18 @@ def check(config): assert not has_key(config, n), f"Source {n} found in config. Please update to {SOURCES[n]}." 
except Exception as e: - rich.print(f"Validation failed: {e}") - rich.print(f"Config: {config}") - raise + rich.print("Validation failed:") + rich.print(e) + print(yaml_dump(config)) + sys.exit(1) def migrate_recipe(args: Any, config) -> None: rich.print(f"Migrating {args.path}") - try: - validate_config(config) - LOG.info(f"{args.path}: Validation successful.") - except Exception: - pass - migrated = migrate(config) - migrated = {k: v for k, v in sorted(migrated.items(), key=order) if v} - check(migrated) if migrated == config: return None diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 6de4a06cd..6b336654f 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -93,7 +93,7 @@ def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None: if dic[key] == value: return raise ValueError(f"Cannot use {key}={dic[key]}. Must use {value}.") - LOG.info(f"Setting {key}={value} in config") + # LOG.info(f"Setting {key}={value} in config") dic[key] = value diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py index d5accc2f4..29b8c611d 100644 --- a/src/anemoi/datasets/create/python.py +++ b/src/anemoi/datasets/create/python.py @@ -501,6 +501,8 @@ def prelude(self, config): SKIP = ( "input", "data_sources", + "common", + "aliases", ) result = [] @@ -512,6 +514,7 @@ def prelude(self, config): if not hasattr(Recipe, k): LOG.warning(f"Unknown key in recipe: {k}") + assert False, f"Unknown key in recipe: {k}" continue result.append(f"r.{k} = {repr(v)}") diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py index 69d9f4140..18c8d34d4 100644 --- a/src/anemoi/datasets/dumper.py +++ b/src/anemoi/datasets/dumper.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. 
import datetime +import io import logging import ruamel.yaml @@ -16,10 +17,15 @@ def represent_date(dumper, data): - if data.tzinfo is None: - data = data.replace(tzinfo=datetime.timezone.utc) - data = data.astimezone(datetime.timezone.utc) - iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + + if isinstance(data, datetime.datetime): + if data.tzinfo is None: + data = data.replace(tzinfo=datetime.timezone.utc) + data = data.astimezone(datetime.timezone.utc) + iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + else: + iso_str = data.isoformat() + return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) @@ -39,12 +45,7 @@ def represent_inline_list(dumper, data): return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) -def yaml_dump(obj, order=None, **kwargs): - - # kwargs.setdefault("Dumper", MyDumper) - # kwargs.setdefault("sort_keys", False) - # kwargs.setdefault("indent", 2) - # kwargs.setdefault("width", 120) +def yaml_dump(obj, order=None, stream=None, **kwargs): if order: @@ -53,18 +54,23 @@ def _ordering(k): obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} - # yaml = yaml.YAML(typ='unsafe', pure=True) yaml = ruamel.yaml.YAML() yaml.width = 120 # wrap long flow sequences - # yaml.default_flow_style = True + yaml.Representer.add_representer(datetime.date, represent_date) yaml.Representer.add_representer(datetime.datetime, represent_date) yaml.Representer.add_representer(str, represent_multiline_str) yaml.Representer.add_representer(list, represent_inline_list) data = ruamel.yaml.comments.CommentedMap() - for k, v in obj.items(): + for i, (k, v) in enumerate(obj.items()): data[k] = v - data.yaml_set_comment_before_after_key(key=k, before="\n") + if i > 0: + data.yaml_set_comment_before_after_key(key=k, before="\n") + + if stream: + yaml.dump(data, stream=stream, **kwargs) - return yaml.dump(data, **kwargs) + stream = io.StringIO() + yaml.dump(data, stream=stream, **kwargs) + return stream.getvalue() diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index f1bf60ddc..c0dbc1bea 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -25,6 +25,16 @@ LOG = logging.getLogger(__name__) +def _un_dotdict(x): + if isinstance(x, dict): + return {k: _un_dotdict(v) for k, v in x.items()} + + if isinstance(x, (list, tuple, set)): + return [_un_dotdict(a) for a in x] + + return x + + class Index: def __init__(self, index): self.name = str(index) @@ -135,7 +145,7 @@ def resolve(params, recipe, name=None): if isinstance(params, dict): def _(k): - if k.endswith("_"): + if isinstance(k, str) and k.endswith("_"): return k[:-1] return k @@ -200,6 +210,10 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): self._dates = None self._statistics = None self._build = None + self._env = None + self._dataset_status = None + self._output = None + self._platform = None self.input = Join() self.output = DotDict() @@ -261,8 +275,6 @@ def as_dict(self): def concat(self, *args, **kwargs): return Concat(*args, **kwargs) - # def assert False, (name, target.as_dict(self)) - def make_data_source(self, name, target): target = target.as_dict(self) @@ -281,7 +293,6 @@ def make_data_source(self, name, target): return f"${{data_sources.{name}}}" def resolve(self, source, target, name=None): - # assert isinstance(target, Source), f"Only sources can be used as template {target}" top = Index("input") # So we have 'input' first in the 
path @@ -395,6 +406,14 @@ def _parse_dates(self, value): def dates(self, value): self._dates = self._parse_dates(value) + @property + def output(self): + return self._output + + @output.setter + def output(self, value): + self._output = value + @property def statistics(self): return self._statistics @@ -411,6 +430,30 @@ def build(self): def build(self, value): self._build = value + @property + def env(self): + return self._env + + @env.setter + def env(self, value): + self._env = value + + @property + def dataset_status(self): + return self._dataset_status + + @dataset_status.setter + def dataset_status(self, value): + self._dataset_status = value + + @property + def platform(self): + return self._platform + + @platform.setter + def platform(self, value): + self._platform = value + def dump(self, file=sys.stdout): input = self.input.as_dict(self) # First so we get the data_sources @@ -419,7 +462,7 @@ def dump(self, file=sys.stdout): result["input"] = input if self.output: - result["output"] = self.output.as_dict() + result["output"] = self.output if self.statistics: result["statistics"] = self.statistics @@ -427,9 +470,18 @@ def dump(self, file=sys.stdout): if self.build: result["build"] = self.build + if self.env: + result["env"] = self.env + + if self.dataset_status: + result["dataset_status"] = self.dataset_status + + if self.platform: + result["platform"] = self.platform + from .dumper import yaml_dump - yaml_dump(result, stream=file) + yaml_dump(_un_dotdict(result), stream=file) def test(self, output="recipe.zarr"): from argparse import ArgumentParser From f2615d6ee9050176d061e0e954331f20720e0b37 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Aug 2025 10:39:24 +0000 Subject: [PATCH 095/212] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/datasets/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 410a136d7..625b1a562 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -136,7 +136,7 @@ def _subset(self, **kwargs: Any) -> "Dataset": if not kwargs: return self.mutate() - name = kwargs.pop("set_group", None) # TODO(Florian) + name = kwargs.pop("set_group", None) # TODO(Florian) name = kwargs.pop("name", name) result = self.__subset(**kwargs) result._name = name From 3d5f0ef62b29c5d103096e7b1d3026984a5a56e1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 22 Aug 2025 08:38:51 +0100 Subject: [PATCH 096/212] tidy --- .../datasets/commands/recipe/__init__.py | 2 +- .../datasets/commands/recipe/migrate.py | 19 +++++------- src/anemoi/datasets/create/__init__.py | 3 -- src/anemoi/datasets/create/input/action.py | 4 --- .../datasets/create/input/context/__init__.py | 10 +++---- .../datasets/create/input/context/field.py | 3 +- .../datasets/create/sources/repeated_dates.py | 29 ++++++++----------- 7 files changed, 26 insertions(+), 44 deletions(-) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 71b116213..bf08d1ee7 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -54,7 +54,7 @@ def run(self, args: Any) -> None: if not args.validate and not args.format and not args.migrate and not args.python: args.validate = True - with open(args.path, "r") as file: + with 
open(args.path) as file: config = yaml.safe_load(file) assert isinstance(config, dict) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 31c67aa42..03da61fbc 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -13,7 +13,6 @@ from collections.abc import Sequence from typing import Any -import rich from glom import assign from glom import delete from glom import glom @@ -126,7 +125,6 @@ def _delete(config, path): x = glom(config, path, default=MARKER) if x is MARKER: return - rich.print(f"Deleting {path}={x}") delete(config, path) @@ -134,7 +132,6 @@ def _move(config, path, new_path, result): x = glom(config, path, default=MARKER) if x is MARKER: return - rich.print(f"Moving {path}={x} to {new_path}={x}") delete(result, path) assign(result, new_path, x, missing=dict) @@ -265,7 +262,7 @@ def _fix_loops(result: dict, config: dict) -> None: concat = [] result["input"] = {"concat": concat} - rich.print("Found loops:", entries) + print("Found loops:", entries) for block in input: assert isinstance(block, dict), block @@ -296,7 +293,7 @@ def _fix_loops(result: dict, config: dict) -> None: def _fix_other(result: dict, config: dict) -> None: paths = find_paths(config, target_key="source_or_dataset", target_value="$previous_data") for p in paths: - rich.print(f"Fixing {'.'.join(p)}") + print(f"Fixing {'.'.join(p)}") assign(result, ".".join(p[:-1] + ["template"]), "${input.join.0.mars}", missing=dict) delete(result, ".".join(p)) @@ -306,7 +303,7 @@ def _fix_other(result: dict, config: dict) -> None: def _fix_join(result: dict, config: dict) -> None: - rich.print("Fixing join...") + print("Fixing join...") input = config["input"] if "dates" in input and "join" in input["dates"]: result["input"]["join"] = input["dates"]["join"] @@ -371,12 +368,12 @@ def _fix_sources(config: dict, what) -> None: def _assign(config, path, value): - rich.print(f"Assign {path} {value}") + print(f"Assign {path} {value}") assign(config, path, value) def _fix_chevrons(result: dict, config: dict) -> None: - rich.print("Fixing chevrons...") + print("Fixing chevrons...") paths = find_chevrons(config) for p in paths: a = glom(config, ".".join(p)) @@ -539,15 +536,15 @@ def check(config): assert not has_key(config, n), f"Source {n} found in config. Please update to {SOURCES[n]}." 
except Exception as e: - rich.print("Validation failed:") - rich.print(e) + print("Validation failed:") + print(e) print(yaml_dump(config)) sys.exit(1) def migrate_recipe(args: Any, config) -> None: - rich.print(f"Migrating {args.path}") + print(f"Migrating {args.path}") migrated = migrate(config) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 8eb1c87da..3c615e21f 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -19,7 +19,6 @@ import cftime import numpy as np -import rich import tqdm import zarr from anemoi.utils.dates import as_datetime @@ -670,8 +669,6 @@ def _run(self) -> int: LOG.info(f"Missing dates: {len(missing)}") lengths = tuple(len(g) for g in self.groups) - rich.print("Minimal input dates:", self.minimal_input) - variables = self.minimal_input.variables LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index db9d8dace..7afd81e35 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -9,8 +9,6 @@ import logging -import rich - from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) @@ -122,8 +120,6 @@ def __call__(self, context, argument): source = self.create_object(config) - rich.print(f"Executing source {self.name} from {config}") - return context.register(self.call_object(context, source, argument), self.path) def python_code(self, code) -> str: diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 738f6a85b..eef61504c 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -12,8 +12,6 @@ from abc import abstractmethod from typing import Any -import rich - LOG = logging.getLogger(__name__) @@ -27,7 +25,7 @@ def __init__(self, /, argument: Any) -> None: def trace(self, emoji, *message) -> None: - rich.print(f"{emoji}: {message}") + print(f"{emoji}: {message}") def register(self, data: Any, path: list[str]) -> Any: @@ -36,7 +34,7 @@ def register(self, data: Any, path: list[str]) -> Any: assert path[0] in ("input", "data_sources"), path - rich.print(f"Registering data at path: {path}") + print(f"Registering data at path: {path}") self.results[tuple(path)] = data return data @@ -49,9 +47,9 @@ def resolve(self, config): if path in self.results: config[key] = self.results[path] else: - rich.print(f"Path not found {path}") + print(f"Path not found {path}") for p in sorted(self.results): - rich.print(f" Available paths: {p}") + print(f" Available paths: {p}") raise KeyError(f"Path {path} not found in results: {self.results.keys()}") return config diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index c3456d89f..23282dce0 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -9,7 +9,6 @@ from typing import Any -from typing import Dict from earthkit.data.core.order import build_remapping @@ -25,7 +24,7 @@ def __init__( argument: Any, order_by: str, flatten_grid: bool, - remapping: Dict[str, Any], + remapping: dict[str, Any], use_grib_paramid: bool, ) -> None: super().__init__(argument) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index 
eb235cd99..77a06c76c 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -10,15 +10,10 @@ import logging from collections import defaultdict +from collections.abc import Generator from typing import Any -from typing import Dict -from typing import Generator -from typing import Optional -from typing import Set -from typing import Tuple import numpy as np -import rich from anemoi.transform.fields import new_field_with_valid_datetime from anemoi.transform.fields import new_fieldlist_from_list from anemoi.utils.dates import as_datetime @@ -52,7 +47,7 @@ class DateMapper: """A factory class to create DateMapper instances based on the given mode.""" @staticmethod - def from_mode(mode: str, source: Any, config: Dict[str, Any]) -> "DateMapper": + def from_mode(mode: str, source: Any, config: dict[str, Any]) -> "DateMapper": """Create a DateMapper instance based on the given mode. Parameters @@ -102,10 +97,10 @@ def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", ski self.maximum: Any = frequency_to_timedelta(maximum) self.frequency: Any = frequency_to_timedelta(frequency) self.skip_all_nans: bool = skip_all_nans - self.tried: Set[Any] = set() - self.found: Set[Any] = set() + self.tried: set[Any] = set() + self.found: set[Any] = set() - def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: """Transform the group of dates to the closest available dates. Parameters @@ -200,7 +195,7 @@ def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, Non class DateMapperClimatology(DateMapper): """A DateMapper implementation that maps dates to specified climatology dates.""" - def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) -> None: + def __init__(self, source: Any, year: int, day: int, hour: int | None = None) -> None: """Initialize DateMapperClimatology. Parameters @@ -216,9 +211,9 @@ def __init__(self, source: Any, year: int, day: int, hour: Optional[int] = None) """ self.year: int = year self.day: int = day - self.hour: Optional[int] = hour + self.hour: int | None = hour - def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, None]: + def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: """Transform the group of dates to the specified climatology dates. Parameters @@ -254,7 +249,7 @@ def transform(self, group_of_dates: Any) -> Generator[Tuple[Any, Any], None, Non class DateMapperConstant(DateMapper): """A DateMapper implementation that maps dates to a constant date.""" - def __init__(self, source: Any, date: Optional[Any] = None) -> None: + def __init__(self, source: Any, date: Any | None = None) -> None: """Initialize DateMapperConstant. Parameters @@ -265,9 +260,9 @@ def __init__(self, source: Any, date: Optional[Any] = None) -> None: The constant date to map to. """ self.source: Any = source - self.date: Optional[Any] = date + self.date: Any | None = date - def transform(self, group_of_dates: Any) -> Tuple[Any, Any]: + def transform(self, group_of_dates: Any) -> tuple[Any, Any]: """Transform the group of dates to a constant date. 
Parameters @@ -311,7 +306,7 @@ def execute(self, context, group_of_dates): result = [] for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): - rich.print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") + print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") source_results = source(context, one_date_group) for field in source_results: for date in many_dates_group: From 70272f69a112c75060a51949122f0beb24c7046d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 22 Aug 2025 09:24:21 +0100 Subject: [PATCH 097/212] update --- src/anemoi/datasets/create/input/action.py | 7 +------ src/anemoi/datasets/create/input/context/field.py | 1 + 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7afd81e35..7bf6bb017 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -242,7 +242,6 @@ def make(key, config, *path): from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.transform.sources import source_registry as transform_source_registry - from anemoi.datasets.create.filters import filter_registry as dataset_filter_registry from anemoi.datasets.create.sources import source_registry as dataset_source_registry # Register sources, local first @@ -254,11 +253,7 @@ def make(key, config, *path): if name not in KLASS: KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) - # Register filters, local first - for name in dataset_filter_registry.registered: - if name not in KLASS: - KLASS[name.replace("_", "-")] = new_filter(name, DatasetFilterMixin) - + # Register filters for name in transform_filter_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index 23282dce0..1dd01340e 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -32,6 +32,7 @@ def __init__( self.flatten_grid = flatten_grid self.remapping = build_remapping(remapping) self.use_grib_paramid = use_grib_paramid + self.partial_ok = False def empty_result(self) -> Any: import earthkit.data as ekd From 38ced18e528a2b4875a5476c971b9d948142d396 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 22 Aug 2025 11:43:25 +0100 Subject: [PATCH 098/212] update tests --- src/anemoi/datasets/create/input/action.py | 6 +++++- tests/create/test_create.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7bf6bb017..25d7690e9 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -197,7 +197,11 @@ def new_filter(name, mixin): class DataSources(Action): def __init__(self, config, *path): super().__init__(config, *path) - self.sources = {k: action_factory(v, *path, k) for k, v in config.items()} + assert isinstance(config, (dict, list)), f"Invalid config type: {type(config)}" + if isinstance(config, dict): + self.sources = {k: action_factory(v, *path, k) for k, v in config.items()} + else: + self.sources = {i: action_factory(v, *path, str(i)) for i, v in enumerate(config)} def python_code(self, code): return code.sources({k: v.python_code(code) for k, v 
in self.sources.items()}) diff --git a/tests/create/test_create.py b/tests/create/test_create.py index 193c9a26a..dd3f37864 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -14,6 +14,8 @@ from unittest.mock import patch import pytest +from anemoi.transform.filter import Filter +from anemoi.transform.filters import filter_registry from anemoi.utils.testing import GetTestArchive from anemoi.utils.testing import GetTestData from anemoi.utils.testing import skip_if_offline @@ -32,6 +34,18 @@ assert NAMES, "No yaml files found in " + HERE +# Used by pipe.yaml +@filter_registry.register("filter") +class TestFilter(Filter): + + def __init__(self, **kwargs): + + self.kwargs = kwargs + + def forward(self, data): + return data.sel(**self.kwargs) + + @pytest.fixture def load_source(get_test_data: GetTestData) -> LoadSource: return LoadSource(get_test_data) From cb3847e4c1b23262658feb7791b0f441b8c764e5 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 22 Aug 2025 12:41:00 +0100 Subject: [PATCH 099/212] fix tests --- src/anemoi/datasets/create/input/action.py | 26 ++++++-------------- src/anemoi/datasets/create/sources/legacy.py | 4 +-- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 25d7690e9..7e164c586 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -118,7 +118,7 @@ def __call__(self, context, argument): config["_type"] = self.name # Find a better way to do this - source = self.create_object(config) + source = self.create_object(context, config) return context.register(self.call_object(context, source, argument), self.path) @@ -131,37 +131,27 @@ def python_code(self, code) -> str: class DatasetSourceMixin: - def create_object(self, config): + def create_object(self, context, config): from anemoi.datasets.create.sources import create_source as create_datasets_source - return create_datasets_source(self, config) + return create_datasets_source(context, config) def call_object(self, context, source, argument): - return source.execute(context, context.source_argument(argument)) - - -class DatasetFilterMixin: - def create_object(self, config): - from anemoi.datasets.create.filters import create_filter as create_datasets_filter - - return create_datasets_filter(self, config) - - def call_object(self, context, filter, argument): - return filter.execute(context.filter_argument(argument)) + return source.execute(context.source_argument(argument)) class TransformSourceMixin: - def create_object(self, config): + def create_object(self, context, config): from anemoi.transform.sources import create_source as create_transform_source - return create_transform_source(self, config) + return create_transform_source(context, config) class TransformFilterMixin: - def create_object(self, config): + def create_object(self, context, config): from anemoi.transform.filters import create_filter as create_transform_filter - return create_transform_filter(self, config) + return create_transform_filter(context, config) def call_object(self, context, filter, argument): return filter.forward(context.filter_argument(argument)) diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index d7ab82dd1..d7a15bfe7 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -69,14 +69,14 @@ def __call__(self, execute: Callable) -> 
Callable: name = f"Legacy{self.name.title()}Source" source = ".".join([execute.__module__, execute.__name__]) - def execute_wrapper(self, context, dates) -> Any: + def execute_wrapper(self, dates) -> Any: """Wrapper method to call the execute function.""" # args, kwargs = resolve(context, (self.args, self.kwargs)) args, kwargs = self.args, self.kwargs try: - return execute(context, dates, *args, **kwargs) + return execute(self.context, dates, *args, **kwargs) except TypeError: LOG.error(f"Error executing source {this.name} from {source}") LOG.error(f"Function signature is: {inspect.signature(execute)}") From 00477c9f805e7e6552b068ffbf0da7a0c572415b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 25 Aug 2025 17:37:44 +0100 Subject: [PATCH 100/212] add missing package --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 30f82757e..a1f96a221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "anemoi-transform>=0.1.10", "anemoi-utils[provenance]>=0.4.32", "cfunits", + "glom", "numcodecs<0.16", # Until we move to zarr3 "numpy", "pyyaml", From b0508a94c7792e9c7d27c7b8a9524ba43d236e28 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 25 Aug 2025 22:00:42 +0100 Subject: [PATCH 101/212] update --- src/anemoi/datasets/create/input/action.py | 39 +++++++++++-- .../datasets/create/input/context/field.py | 18 ++++++ src/anemoi/datasets/create/input/origin.py | 58 +++++++++++++++++++ .../datasets/create/input/result/field.py | 10 ++++ 4 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 src/anemoi/datasets/create/input/origin.py diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7e164c586..94269e209 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -8,13 +8,15 @@ # nor does it submit to any jurisdiction. 
import logging +from abc import ABC +from abc import abstractmethod from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) -class Action: +class Action(ABC): def __init__(self, config, *path): self.config = config self.path = path @@ -23,6 +25,14 @@ def __init__(self, config, *path): "data_sources", ), f"{self.__class__.__name__}: path must start with 'input' or 'data_sources': {path}" + @abstractmethod + def __call__(self, context, argument): + pass + + @abstractmethod + def python_code(self, code): + pass + class Concat(Action): def __init__(self, config, *path): @@ -137,7 +147,12 @@ def create_object(self, context, config): return create_datasets_source(context, config) def call_object(self, context, source, argument): - return source.execute(context.source_argument(argument)) + return context.origin(source.execute(context.source_argument(argument)), self) + + def origin(self): + from .origin import Source + + return Source(self.path[-1], self.config) class TransformSourceMixin: @@ -146,6 +161,15 @@ def create_object(self, context, config): return create_transform_source(context, config) + def combine_origins(self, current, previous): + assert previous is None, f"Cannot combine origins, previous already exists: {previous}" + return current + + def origin(self): + from .origin import Source + + return Source(self.path[-1], self.config) + class TransformFilterMixin: def create_object(self, context, config): @@ -154,12 +178,15 @@ def create_object(self, context, config): return create_transform_filter(context, config) def call_object(self, context, filter, argument): - return filter.forward(context.filter_argument(argument)) + return context.origin(filter.forward(context.filter_argument(argument)), self) + def origin(self): + from .origin import Filter -class FilterFunction(Function): - def __call__(self, context, argument): - return self.call(context, argument, context.filter_argument) + return Filter(self.path[-1], self.config) + + def combine_origins(self, current, previous): + return {"_apply": current, **(previous or {})} def _make_name(name, what): diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index 1dd01340e..f105ae973 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -10,6 +10,8 @@ from typing import Any +from anemoi.transform.fields import new_field_with_metadata +from anemoi.transform.fields import new_fieldlist_from_list from earthkit.data.core.order import build_remapping from ..result.field import FieldResult @@ -52,3 +54,19 @@ def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: from anemoi.datasets.dates.groups import GroupOfDates return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) + + def origin(self, data: Any, action: Any) -> Any: + # rich.print(f"origin: {data} from {action}") + origin = action.origin() + result = [] + for fs in data: + previous = fs.metadata("_origin", default=None) + origin = origin.combine(previous) + result.append(new_field_with_metadata(fs, _origin=origin)) + + result = new_fieldlist_from_list(result) + + for fs in result: + fs.metadata() + + return result diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py new file mode 100644 index 000000000..378d3f71c --- /dev/null +++ b/src/anemoi/datasets/create/input/origin.py @@ -0,0 +1,58 @@ +# (C) Copyright 2025 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from abc import ABC + +LOG = logging.getLogger(__name__) + + +class Origin(ABC): + pass + + +def _un_dotdict(x): + if isinstance(x, dict): + return {k: _un_dotdict(v) for k, v in x.items()} + + if isinstance(x, (list, tuple, set)): + return [_un_dotdict(a) for a in x] + + return x + + +class Source(Origin): + def __init__(self, name, config): + assert isinstance(config, dict), f"Config must be a dictionary {config}" + self.name = name + self.config = _un_dotdict(config) + + def combine(self, previous): + assert previous is None, f"Cannot combine origins, previous already exists: {previous}" + return self + + def __repr__(self): + return f"Source(name={self.name}, config={self.config})" + + +class Filter(Origin): + def __init__(self, name, config, previous=None): + assert isinstance(config, dict), f"Config must be a dictionary {config}" + self.name = name + self.config = _un_dotdict(config) + self.previous = previous + + def combine(self, previous): + if self.previous is previous: + # Avoid duplication of intermediate origins + return self + return Filter(self.name, self.config, previous) + + def __repr__(self): + return f"Filter(name={self.name}, config={self.config}, previous={self.previous})" diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index 083d2ffd7..6ee63f3b0 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -17,6 +17,7 @@ from typing import DefaultDict import numpy as np +import rich from anemoi.utils.dates import as_timedelta from anemoi.utils.humanize import seconds_to_human from anemoi.utils.humanize import shorten_list @@ -293,6 +294,8 @@ def __init__(self, context: Any, datasource: Any) -> None: self.group_of_dates, GroupOfDates ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" + self._origin = defaultdict(set) + @property def data_request(self) -> dict[str, Any]: """Returns a dictionary with the parameters needed to retrieve the data.""" @@ -556,6 +559,13 @@ def build_coords(self) -> None: self._cube: Any = cube + p = None + for i, fs in enumerate(self.datasource): + o = fs.metadata("_origin") + if p != o: + rich.print(f"🔥🔥🔥🔥🔥🔥 {fs.metadata()}, {o}") + p = o + self._coords_already_built: bool = True @property From 7d494b9faed51659cb4c7f883727f4daeb5ec501 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 26 Aug 2025 16:11:48 +0000 Subject: [PATCH 102/212] update --- src/anemoi/datasets/create/__init__.py | 1 + .../datasets/create/input/context/field.py | 7 +-- src/anemoi/datasets/create/input/origin.py | 34 ++++++++++++-- .../datasets/create/input/result/field.py | 46 +++++++++++-------- 4 files changed, 62 insertions(+), 26 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 3c615e21f..2b649b251 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -743,6 +743,7 @@ def _run(self) -> int: metadata["end_date"] = dates[-1].isoformat() metadata["frequency"] = frequency metadata["missing_dates"] = 
[_.isoformat() for _ in missing] + metadata["origins"] = self.minimal_input.origins metadata["version"] = VERSION diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index f105ae973..fb7d6c254 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -56,13 +56,14 @@ def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) def origin(self, data: Any, action: Any) -> Any: - # rich.print(f"origin: {data} from {action}") + origin = action.origin() + result = [] for fs in data: - previous = fs.metadata("_origin", default=None) + previous = fs.metadata("anemoi_origin", default=None) origin = origin.combine(previous) - result.append(new_field_with_metadata(fs, _origin=origin)) + result.append(new_field_with_metadata(fs, anemoi_origin=origin)) result = new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py index 378d3f71c..673144248 100644 --- a/src/anemoi/datasets/create/input/origin.py +++ b/src/anemoi/datasets/create/input/origin.py @@ -14,7 +14,23 @@ class Origin(ABC): - pass + + def __init__(self): + self._variables = set() + + def __repr__(self): + return repr(self.as_dict()) + + def __eq__(self, other): + if not isinstance(other, Origin): + return False + return self is other # or self.as_dict() == other.as_dict() + + def __hash__(self): + return id(self) + + def add_variable(self, name): + self._variables.add(name) def _un_dotdict(x): @@ -29,6 +45,7 @@ def _un_dotdict(x): class Source(Origin): def __init__(self, name, config): + super().__init__() assert isinstance(config, dict), f"Config must be a dictionary {config}" self.name = name self.config = _un_dotdict(config) @@ -37,12 +54,13 @@ def combine(self, previous): assert previous is None, f"Cannot combine origins, previous already exists: {previous}" return self - def __repr__(self): - return f"Source(name={self.name}, config={self.config})" + def as_dict(self): + return {"type": "source", "name": self.name, "config": self.config, "variables": sorted(self._variables)} class Filter(Origin): def __init__(self, name, config, previous=None): + super().__init__() assert isinstance(config, dict), f"Config must be a dictionary {config}" self.name = name self.config = _un_dotdict(config) @@ -54,5 +72,11 @@ def combine(self, previous): return self return Filter(self.name, self.config, previous) - def __repr__(self): - return f"Filter(name={self.name}, config={self.config}, previous={self.previous})" + def as_dict(self): + return { + "type": "filter", + "name": self.name, + "config": self.config, + "apply_to": self.previous.as_dict(), + "variables": sorted(self._variables), + } diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index 6ee63f3b0..2bb5e199e 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -294,13 +294,18 @@ def __init__(self, context: Any, datasource: Any) -> None: self.group_of_dates, GroupOfDates ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" - self._origin = defaultdict(set) + self._origins = [] @property def data_request(self) -> dict[str, Any]: """Returns a dictionary with the parameters needed to retrieve the data.""" return 
_data_request(self.datasource) + @property + def origins(self) -> dict[str, Any]: + """Returns a dictionary with the parameters needed to retrieve the data.""" + return [o.as_dict() for o in self._origins] + def get_cube(self) -> Any: """Retrieve the data cube for the result. @@ -312,26 +317,26 @@ def get_cube(self) -> Any: ds: Any = self.datasource - remapping: Any = self.context.remapping - order_by: Any = self.context.order_by - flatten_grid: Any = self.context.flatten_grid - start: float = time.time() - LOG.debug("Sorting dataset %s %s", dict(order_by), remapping) - assert order_by, order_by + self.remapping: Any = self.context.remapping + self.order_by: Any = self.context.order_by + self.flatten_grid: Any = self.context.flatten_grid + self.start: float = time.time() + LOG.debug("Sorting dataset %s %s", dict(self.order_by), self.remapping) + assert self.order_by, self.order_by - patches: dict[str, dict[Any | None, int]] = {"number": {None: 0}} + self.patches: dict[str, dict[Any | None, int]] = {"number": {None: 0}} try: cube: Any = ds.cube( - order_by, - remapping=remapping, - flatten_values=flatten_grid, - patches=patches, + self.order_by, + remapping=self.remapping, + flatten_values=self.flatten_grid, + patches=self.patches, ) cube = cube.squeeze() - LOG.debug(f"Sorting done in {seconds_to_human(time.time()-start)}.") + LOG.debug(f"Sorting done in {seconds_to_human(time.time()-self.start)}.") except ValueError: - self.explain(ds, order_by, remapping=remapping, patches=patches) + self.explain(ds, self.order_by, remapping=self.remapping, patches=self.patches) # raise ValueError(f"Error in {self}") exit(1) @@ -559,12 +564,17 @@ def build_coords(self) -> None: self._cube: Any = cube + name_key = list(self.order_by.keys())[1] + p = None - for i, fs in enumerate(self.datasource): - o = fs.metadata("_origin") - if p != o: - rich.print(f"🔥🔥🔥🔥🔥🔥 {fs.metadata()}, {o}") + self._origins = [] + for fs in self.datasource: + o, name = fs.metadata("anemoi_origin", name_key, remapping=self.remapping, patches=self.patches) + o.add_variable(name) + if p is not o: + rich.print(f"🔥🔥🔥🔥🔥🔥 {name}, {o}") p = o + self._origins.append(o) self._coords_already_built: bool = True From dd62e77dc0dcfb15e30c774fe0b4d2fd1f65be22 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 29 Aug 2025 07:03:42 +0000 Subject: [PATCH 103/212] fix icon grid test --- .../datasets/create/input/context/field.py | 7 +--- src/anemoi/datasets/create/sources/grib.py | 7 ++++ tests/create/test_sources.py | 35 +++++++++++-------- tests/create/utils/compare.py | 23 ++++++++---- 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index fb7d6c254..35c63f92b 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -65,9 +65,4 @@ def origin(self, data: Any, action: Any) -> Any: origin = origin.combine(previous) result.append(new_field_with_metadata(fs, anemoi_origin=origin)) - result = new_fieldlist_from_list(result) - - for fs in result: - fs.metadata() - - return result + return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/sources/grib.py b/src/anemoi/datasets/create/sources/grib.py index 03bcda475..66134e86c 100644 --- a/src/anemoi/datasets/create/sources/grib.py +++ b/src/anemoi/datasets/create/sources/grib.py @@ -138,6 +138,13 @@ def execute( check(ds, given_paths, valid_datetime=dates, **kwargs) if grid is not None: + + lat, lon 
= grid.latlon() + + assert len(lat) == len(lon), (len(lat), len(lon)) + for f in ds: + assert len(f.to_numpy(flatten=True)) == len(lat), (len(f.to_numpy(flatten=True)), len(lat)) + ds = new_fieldlist_from_list([new_field_from_grid(f, grid) for f in ds]) if len(ds) == 0: diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index dbf0d746a..e679fb6bf 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -8,7 +8,6 @@ # nor does it submit to any jurisdiction. import os -import sys import numpy as np import pytest @@ -51,9 +50,6 @@ def test_grib(get_test_data: callable) -> None: assert ds.shape == (8, 12, 1, 162) -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="Type hints from anemoi-transform are not compatible with Python < 3.10" -) @skip_if_offline def test_grib_gridfile(get_test_data) -> None: """Test the creation of a dataset from GRIB files with an unstructured grid. @@ -91,19 +87,19 @@ def test_grib_gridfile(get_test_data) -> None: assert ds.variables == ["2t"] -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="Type hints from anemoi-transform are not compatible with Python < 3.10" -) @skip_if_offline @pytest.mark.parametrize( - "refinement_level_c,shape", + "input_refinement_level_c,output_refinement_level_c,shape", ( - (2, (2, 13, 1, 2880)), - (7, (2, 13, 1, 2949120)), + (7, 2, (2, 13, 1, 2880)), + (7, 7, (2, 13, 1, 2949120)), ), ) def test_grib_gridfile_with_refinement_level( - refinement_level_c: str, shape: tuple[int, int, int, int, int], get_test_data: callable + input_refinement_level_c: str, + output_refinement_level_c: str, + shape: tuple[int, int, int, int, int], + get_test_data: callable, ) -> None: """Test the creation of a dataset from GRIB files with an unstructured grid. @@ -129,11 +125,21 @@ def test_grib_gridfile_with_refinement_level( grib = { "path": os.path.join(path, "{date:strftimedelta(+3h;%Y%m%d%H)}+fc_R03B07_rea_ml.{date:strftime(%Y%m%d%H)}"), - "grid_definition": {"icon": {"path": gridfile, "refinement_level_c": refinement_level_c}}, + "grid_definition": { + "icon": { + "path": gridfile, + "refinement_level_c": input_refinement_level_c, + } + }, "param": param, "level": level, } - refinement_filter = {"icon_refinement_level": {"grid": gridfile, "refinement_level_c": refinement_level_c}} + refinement_filter = { + "icon_refinement_level": { + "grid": gridfile, + "refinement_level_c": output_refinement_level_c, + } + } config = { "dates": { @@ -280,8 +286,7 @@ def test_planetary_computer_conus404() -> None: if __name__ == "__main__": - test_planetary_computer_conus404() - exit(0) + from anemoi.utils.testing import run_tests run_tests(globals()) diff --git a/tests/create/utils/compare.py b/tests/create/utils/compare.py index aa6a59dd2..796181853 100644 --- a/tests/create/utils/compare.py +++ b/tests/create/utils/compare.py @@ -125,7 +125,7 @@ def compare_statistics(ds1: object, ds2: object) -> None: assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() @staticmethod - def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: + def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list, ignore: list | None = None) -> None: """Compare the attributes of two Zarr datasets. Parameters @@ -138,11 +138,16 @@ def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: The current path in the attribute hierarchy. errors : list The list to store error messages. + ignore : list|None + A list of keys to ignore during comparison. 
""" + if ignore is None: + ignore = [] + if isinstance(a, dict): - a_keys = list(a.keys()) - b_keys = list(b.keys()) - for k in set(a_keys) | set(b_keys): + a_keys = set(a.keys()) - set(ignore) + b_keys = set(b.keys()) - set(ignore) + for k in a_keys | b_keys: if k not in a_keys: errors.append(f"❌ {path}.{k} : missing key (only in reference)") continue @@ -196,7 +201,13 @@ def compare(self) -> None: If the datasets or their metadata do not match. """ errors = [] - self.compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors) + self.compare_dot_zattrs( + dict(self.z_output.attrs), + dict(self.z_reference.attrs), + "metadata", + errors, + ignore=["origins"], + ) if errors: print("Comparison failed") print("\n".join(errors)) @@ -211,7 +222,7 @@ def compare(self) -> None: print(f"tar zcf {base}.tgz {base}") print(f"scp {base}.tgz data@anemoi.ecmwf.int:public/anemoi-datasets/create/mock-mars/") print() - raise AssertionError("Comparison failed") + raise AssertionError(f"Comparison failed {errors}") self.compare_datasets(self.ds_output, self.ds_reference) self.compare_statistics(self.ds_output, self.ds_reference) From 21208ad2f11ef33f91f263cdd9bfc4c58b707c72 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 31 Aug 2025 06:21:22 +0000 Subject: [PATCH 104/212] review origins --- .../datasets/create/input/context/field.py | 8 ++- src/anemoi/datasets/create/input/origin.py | 56 ++++++++++++------- .../datasets/create/input/result/field.py | 21 ++++--- 3 files changed, 57 insertions(+), 28 deletions(-) diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index 35c63f92b..b89d33dab 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -62,7 +62,11 @@ def origin(self, data: Any, action: Any) -> Any: result = [] for fs in data: previous = fs.metadata("anemoi_origin", default=None) - origin = origin.combine(previous) - result.append(new_field_with_metadata(fs, anemoi_origin=origin)) + fall_through = fs.metadata("anemoi_fall_through", default=False) + if fall_through: + # The field has pass unchanges in a filter + result.append(fs) + else: + result.append(new_field_with_metadata(fs, anemoi_origin=origin.combine(previous))) return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py index 673144248..78274b9a8 100644 --- a/src/anemoi/datasets/create/input/origin.py +++ b/src/anemoi/datasets/create/input/origin.py @@ -15,23 +15,14 @@ class Origin(ABC): - def __init__(self): - self._variables = set() - - def __repr__(self): - return repr(self.as_dict()) - def __eq__(self, other): if not isinstance(other, Origin): return False - return self is other # or self.as_dict() == other.as_dict() + return self is other def __hash__(self): return id(self) - def add_variable(self, name): - self._variables.add(name) - def _un_dotdict(x): if isinstance(x, dict): @@ -43,6 +34,28 @@ def _un_dotdict(x): return x +class Pipe(Origin): + def __init__(self, s1, s2): + super().__init__() + self.steps = [s1, s2] + + if isinstance(s1, Pipe): + assert not isinstance(s2, Pipe), (s1, s2) + self.steps = s1.steps + [s2] + + def combine(self, previous): + assert False, (self, previous) + + def as_dict(self): + return { + "type": "pipe", + "steps": [s.as_dict() for s in self.steps], + } + + def __repr__(self): + return " | ".join(repr(s) for s in self.steps) + + class Source(Origin): def 
__init__(self, name, config): super().__init__() @@ -55,28 +68,33 @@ def combine(self, previous): return self def as_dict(self): - return {"type": "source", "name": self.name, "config": self.config, "variables": sorted(self._variables)} + return {"type": "source", "name": self.name, "config": self.config} + + def __repr__(self): + return f"{self.name}({id(self)})" class Filter(Origin): - def __init__(self, name, config, previous=None): + def __init__(self, name, config): super().__init__() assert isinstance(config, dict), f"Config must be a dictionary {config}" self.name = name self.config = _un_dotdict(config) - self.previous = previous + self._cache = {} def combine(self, previous): - if self.previous is previous: - # Avoid duplication of intermediate origins - return self - return Filter(self.name, self.config, previous) + if previous in self._cache: + # We use a cache to avoid recomputing the same combination + return self._cache[previous] + self._cache[previous] = Pipe(previous, self) + return self._cache[previous] def as_dict(self): return { "type": "filter", "name": self.name, "config": self.config, - "apply_to": self.previous.as_dict(), - "variables": sorted(self._variables), } + + def __repr__(self): + return f"{self.name}({id(self)})" diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index 2bb5e199e..a052fc628 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -17,7 +17,6 @@ from typing import DefaultDict import numpy as np -import rich from anemoi.utils.dates import as_timedelta from anemoi.utils.humanize import seconds_to_human from anemoi.utils.humanize import shorten_list @@ -304,7 +303,7 @@ def data_request(self) -> dict[str, Any]: @property def origins(self) -> dict[str, Any]: """Returns a dictionary with the parameters needed to retrieve the data.""" - return [o.as_dict() for o in self._origins] + return {"version": 1, "origins": self._origins} def get_cube(self) -> Any: """Retrieve the data cube for the result. 
@@ -567,14 +566,22 @@ def build_coords(self) -> None: name_key = list(self.order_by.keys())[1] p = None - self._origins = [] + origins = defaultdict(set) + for fs in self.datasource: - o, name = fs.metadata("anemoi_origin", name_key, remapping=self.remapping, patches=self.patches) - o.add_variable(name) + o = fs.metadata("anemoi_origin", remapping=self.remapping, patches=self.patches) + name = fs.metadata(name_key, remapping=self.remapping, patches=self.patches) + + assert name not in origins[o], (name,) + origins[o].add(name) + if p is not o: - rich.print(f"🔥🔥🔥🔥🔥🔥 {name}, {o}") + LOG.info(f"🔥🔥🔥🔥🔥🔥 Source: {name}, {o}") p = o - self._origins.append(o) + + self._origins = [] + for k, v in origins.items(): + self._origins.append({"origin": k.as_dict(), "variables": sorted(v)}) self._coords_already_built: bool = True From 53f915c0762f49a7391102df853172c005b3fc84 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 1 Sep 2025 18:24:07 +0000 Subject: [PATCH 105/212] work on components --- src/anemoi/datasets/data/components.py | 160 +++++++++++++++++++++++++ src/anemoi/datasets/data/dataset.py | 10 ++ src/anemoi/datasets/data/join.py | 18 +++ src/anemoi/datasets/data/select.py | 23 ++++ src/anemoi/datasets/data/stores.py | 20 ++++ src/anemoi/datasets/data/subset.py | 23 ++++ 6 files changed, 254 insertions(+) create mode 100644 src/anemoi/datasets/data/components.py diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py new file mode 100644 index 000000000..8686b2f39 --- /dev/null +++ b/src/anemoi/datasets/data/components.py @@ -0,0 +1,160 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from rich.tree import Tree + + +def indices_to_slices(indices: list[int]) -> list[slice]: + indices = sorted(indices) + + if len(indices) <= 1: + return [slice(indices[0], indices[0] + 1, 1)] + + diffs = [a - b for a, b in zip(indices[1:], indices[:-1])] + slices = [] + count = 0 + prev = None + i = 0 + for diff in diffs: + if diff != prev: + if count: + slices.append(slice(indices[i], indices[i + count] + 1, prev)) + i += count + count = 1 + prev = diff + continue + count += 1 + + if count: + slices.append(slice(indices[i], indices[i + count] + 1, prev)) + + check = set() + for s in slices: + check.update(range(s.start, s.stop, s.step)) + + assert check == set(indices), (check - set(indices), set(indices) - check, slices, indices) + + return slices + + +def combine_slices(length, *slices): + + start, step, current_length = 0, 1, length + + for s in slices: + new_start, new_stop, new_step = s.indices(current_length) + new_length = len(range(new_start, new_stop, new_step)) + start = start + new_start * step + step = step * new_step + current_length = new_length + + if current_length == 0: + return slice(0, 0, 1) # canonical empty slice + + if current_length == 0: + return slice(0, 0, 1) + + stop = start + current_length * step + + if step > 0 and stop > length: + stop = None + elif step < 0 and stop <= -1: + stop = None + + return slice(start, stop, step) + + +class Component: + + def reduce(self): + result = [] + + for slices, name, shape in self._reduce(): + combined = [] + for i in range(len(slices)): + combined.append(combine_slices(shape[i], *slices[i])) + + result.append((combined, name, shape)) + + return result + + +class ComponentList(Component): + def __init__(self, components: list[Component]) -> None: + self.components = components + + def __repr__(self): + return "ComponentList(" + ",".join(repr(c) for c in self.components) + ")" + + def tree(self, tree=None): + if tree is None: + tree = Tree("Components") + + t = tree.add("ComponentList") + for c in self.components: + c.tree(t) + return tree + + def _reduce(self): + return sum([c._reduce() for c in self.components], []) + + +class ZarrComponent(Component): + def __init__(self, name, shape) -> None: + self.name = name + self.shape = shape + + def __repr__(self): + return f"ZarrComponent({self.name})" + + def tree(self, tree=None): + if tree is None: + tree = Tree("Components") + + tree.add(f"ZarrComponent({self.name})") + return tree + + def _reduce(self): + slices = [[slice(0, s, 1)] for s in self.shape] + return [(slices, self.name, self.shape)] + + +class AxisComponent(Component): + def __init__(self, slice, component) -> None: + self.slice = slice + self.component = component + self.length = len(list(range(*self.slice.indices(self.slice.stop)))) + + def __repr__(self): + + return f"{self.__class__.__name__}({self.slice} ({self.length}),{self.component})" + + def tree(self, tree=None): + if tree is None: + tree = Tree("Components") + + self.component.tree(tree.add(f"{self.__class__.__name__}({self.slice} ({self.length}))")) + + return tree + + def _reduce(self): + result = [] + for slices, name, shape in self.component._reduce(): + slices = slices.copy() + slices[self.axis].append(self.slice) # Add this slice to list + result.append((slices, name, shape)) + return result + + +class DateSpan(AxisComponent): + axis = 0 + + +class VariableSpan(AxisComponent): + axis = 1 diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 4b76d24f5..1cdb5252d 100644 --- 
a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -1004,6 +1004,16 @@ def variables_metadata(self) -> dict[str, Any]: """Return the metadata of the variables in the dataset.""" pass + # @abstractmethod + def origin(self, index) -> Any: + """Return the origin of the variable at the specified index.""" + raise NotImplementedError(f"origin() is not implemented for `{self.__class__.__name__}`") + + # @abstractmethod + def components(self) -> Any: + """Return the components of the variable at the specified index.""" + raise NotImplementedError(f"components() is not implemented for `{self.__class__.__name__}`") + @abstractmethod @cached_property def missing(self) -> set[int]: diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/data/join.py index 59aefd3a4..069729dbd 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/data/join.py @@ -291,6 +291,24 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return {} + def origin(self, index): + assert ( + isinstance(index, tuple) and len(index) == 4 and all(a > b >= 0 for a, b in zip(self.shape, index)) + ), tuple + + i = index[1] + for dataset in self.datasets: + if i < dataset.shape[1]: + return dataset.origin((index[0], i, index[2], index[3])) + i -= dataset.shape[1] + + raise ValueError(f"Invalid index {index} {[d.shape for d in self.datasets]}") + + def components(self): + from .components import ComponentList + + return ComponentList([d.components() for d in self.datasets]) + def join_factory(args: tuple, kwargs: dict) -> Dataset: """Create a joined dataset. diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index 048802892..0cad60a50 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -224,6 +224,29 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: # return dict(indices=self.indices) return dict(reason=self.reason) + def origin(self, index): + assert ( + isinstance(index, tuple) and len(index) == 4 and all(a > b >= 0 for a, b in zip(self.shape, index)) + ), tuple + + return self.dataset.origin((index[0], self.indices[index[1]], index[2], index[3])) + + def components(self): + from .components import ComponentList + from .components import VariableSpan + from .components import indices_to_slices + + slices = indices_to_slices(self.indices) + + forward = self.dataset.components() + + slices = [VariableSpan(s, forward) for s in slices] + + if len(slices) == 1: + return slices[0] + + return ComponentList(slices) + class Rename(Forwards): """Class to rename variables in a dataset.""" diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 78470fec6..d41b59acf 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -424,6 +424,21 @@ def collect_input_sources(self, collected: set) -> None: """Collect input sources.""" pass + def origin(self, index): + # if self.z.attrs.get("origins") is None: + # raise ValueError(f"No origins found in {self}") + return [self.path, self.variables_metadata[self.variables[index[1]]]] + + def components(self): + from .components import ZarrComponent + + return ZarrComponent(self.dataset_name, self.shape) + + @property + def dataset_name(self) -> str: + """Return the name of the dataset.""" + return self.z.attrs.get("recipe", {}).get("name", self.path) + class ZarrWithMissingDates(Zarr): """A zarr dataset with missing dates.""" @@ -497,6 +512,11 @@ def label(self) -> str: 
"""Return the label of the dataset.""" return "zarr*" + def origin(self, index): + if index[0] in self.missing: + self._report_missing(index[0]) + return super().origin(index) + QUIET = set() diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 8954fa5bc..7df41a3bc 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -298,3 +298,26 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: # "indices": self.indices, "reason": self.reason, } + + def origin(self, index): + assert ( + isinstance(index, tuple) and len(index) == 4 and all(a > b >= 0 for a, b in zip(self.shape, index)) + ), tuple + return self.dataset.origin((self.indices[index[0]], index[1], index[2], index[3])) + + def components(self): + + from .components import ComponentList + from .components import DateSpan + from .components import indices_to_slices + + slices = indices_to_slices(self.indices) + + forward = self.dataset.components() + + slices = [DateSpan(s, forward) for s in slices] + + if len(slices) == 1: + return slices[0] + + return ComponentList(slices) From b0348bd4b416fbb765a15f7a57c9063e5c3ca178 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 1 Sep 2025 18:33:59 +0000 Subject: [PATCH 106/212] work on components --- src/anemoi/datasets/data/components.py | 47 ++++++++++++++------------ 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index 8686b2f39..c0d475003 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -12,33 +12,36 @@ def indices_to_slices(indices: list[int]) -> list[slice]: indices = sorted(indices) + assert len(indices) == len(set(indices)), "Duplicate indices are not allowed" - if len(indices) <= 1: - return [slice(indices[0], indices[0] + 1, 1)] + if not indices: + return [] - diffs = [a - b for a, b in zip(indices[1:], indices[:-1])] slices = [] - count = 0 - prev = None + n = len(indices) i = 0 - for diff in diffs: - if diff != prev: - if count: - slices.append(slice(indices[i], indices[i + count] + 1, prev)) - i += count - count = 1 - prev = diff - continue - count += 1 - - if count: - slices.append(slice(indices[i], indices[i + count] + 1, prev)) - - check = set() + + while i < n: + start = indices[i] + # default step = 1 + if i + 1 < n: + step = indices[i + 1] - indices[i] + else: + step = 1 + + j = i + 1 + while j < n and indices[j] - indices[j - 1] == step: + j += 1 + + stop = indices[j - 1] + step + slices.append(slice(start, stop, step)) + i = j + + check = list() for s in slices: - check.update(range(s.start, s.stop, s.step)) + check.extend(range(s.start, s.stop, s.step)) - assert check == set(indices), (check - set(indices), set(indices) - check, slices, indices) + assert check == list(indices), slices return slices @@ -80,7 +83,7 @@ def reduce(self): for i in range(len(slices)): combined.append(combine_slices(shape[i], *slices[i])) - result.append((combined, name, shape)) + result.append((combined, name)) return result From 28d6ffa3e1fc4aff15c331f13e7cd93f2e1a824f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 1 Sep 2025 18:54:37 +0000 Subject: [PATCH 107/212] work on components --- src/anemoi/datasets/data/components.py | 27 +++++++++++++------------- src/anemoi/datasets/data/stores.py | 2 +- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index 
c0d475003..96e5bbd6a 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -66,9 +66,9 @@ def combine_slices(length, *slices): stop = start + current_length * step if step > 0 and stop > length: - stop = None + stop = length elif step < 0 and stop <= -1: - stop = None + stop = 0 return slice(start, stop, step) @@ -78,12 +78,12 @@ class Component: def reduce(self): result = [] - for slices, name, shape in self._reduce(): + for slices, store in self._reduce(): combined = [] for i in range(len(slices)): - combined.append(combine_slices(shape[i], *slices[i])) + combined.append(combine_slices(store.shape[i], *slices[i])) - result.append((combined, name)) + result.append((combined, store)) return result @@ -109,23 +109,22 @@ def _reduce(self): class ZarrComponent(Component): - def __init__(self, name, shape) -> None: - self.name = name - self.shape = shape + def __init__(self, store) -> None: + self.store = store def __repr__(self): - return f"ZarrComponent({self.name})" + return f"ZarrComponent({self.store})" def tree(self, tree=None): if tree is None: tree = Tree("Components") - tree.add(f"ZarrComponent({self.name})") + tree.add(f"ZarrComponent({self.store})") return tree def _reduce(self): - slices = [[slice(0, s, 1)] for s in self.shape] - return [(slices, self.name, self.shape)] + slices = [[slice(0, s, 1)] for s in self.store.shape] + return [(slices, self.store)] class AxisComponent(Component): @@ -148,10 +147,10 @@ def tree(self, tree=None): def _reduce(self): result = [] - for slices, name, shape in self.component._reduce(): + for slices, store in self.component._reduce(): slices = slices.copy() slices[self.axis].append(self.slice) # Add this slice to list - result.append((slices, name, shape)) + result.append((slices, store)) return result diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index d41b59acf..da5146c8b 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -432,7 +432,7 @@ def origin(self, index): def components(self): from .components import ZarrComponent - return ZarrComponent(self.dataset_name, self.shape) + return ZarrComponent(self) @property def dataset_name(self) -> str: From 1a6a3e4299a84350bf5ef28bbc919a93f4698b7a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 2 Sep 2025 17:52:04 +0000 Subject: [PATCH 108/212] add projections --- src/anemoi/datasets/create/input/origin.py | 24 ++- src/anemoi/datasets/data/components.py | 174 ++++++++++++++++++++- src/anemoi/datasets/data/dataset.py | 7 +- src/anemoi/datasets/data/join.py | 14 +- src/anemoi/datasets/data/select.py | 17 +- src/anemoi/datasets/data/stores.py | 35 ++++- src/anemoi/datasets/data/subset.py | 14 +- 7 files changed, 253 insertions(+), 32 deletions(-) diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py index 78274b9a8..5b14cc77f 100644 --- a/src/anemoi/datasets/create/input/origin.py +++ b/src/anemoi/datasets/create/input/origin.py @@ -15,6 +15,9 @@ class Origin(ABC): + def __init__(self, when="dataset-create"): + self.when = when + def __eq__(self, other): if not isinstance(other, Origin): return False @@ -35,8 +38,8 @@ def _un_dotdict(x): class Pipe(Origin): - def __init__(self, s1, s2): - super().__init__() + def __init__(self, s1, s2, when="dataset-create"): + super().__init__(when) self.steps = [s1, s2] if isinstance(s1, Pipe): @@ -50,6 +53,7 @@ def as_dict(self): return { "type": "pipe", "steps": [s.as_dict() for s in 
self.steps], + "when": self.when, } def __repr__(self): @@ -57,8 +61,8 @@ def __repr__(self): class Source(Origin): - def __init__(self, name, config): - super().__init__() + def __init__(self, name, config, when="dataset-create"): + super().__init__(when) assert isinstance(config, dict), f"Config must be a dictionary {config}" self.name = name self.config = _un_dotdict(config) @@ -68,15 +72,20 @@ def combine(self, previous): return self def as_dict(self): - return {"type": "source", "name": self.name, "config": self.config} + return { + "type": "source", + "name": self.name, + "config": self.config, + "when": self.when, + } def __repr__(self): return f"{self.name}({id(self)})" class Filter(Origin): - def __init__(self, name, config): - super().__init__() + def __init__(self, name, config, when="dataset-create"): + super().__init__(when) assert isinstance(config, dict), f"Config must be a dictionary {config}" self.name = name self.config = _un_dotdict(config) @@ -94,6 +103,7 @@ def as_dict(self): "type": "filter", "name": self.name, "config": self.config, + "when": self.when, } def __repr__(self): diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index 96e5bbd6a..d59c057d5 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -7,6 +7,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from functools import cached_property + from rich.tree import Tree @@ -58,9 +60,11 @@ def combine_slices(length, *slices): current_length = new_length if current_length == 0: + # assert False, (length, slices) return slice(0, 0, 1) # canonical empty slice if current_length == 0: + # assert False, (length, slices) return slice(0, 0, 1) stop = start + current_length * step @@ -73,6 +77,149 @@ def combine_slices(length, *slices): return slice(start, stop, step) +class _Base: + + def from_store(self, slices, store): + return ProjectionStore(slices, store) + + def make_new(self, slices): + return Projection(slices) + + def list_or_single(self, projections): + if len(projections) == 1: + return projections[0] + return ProjectionList(projections) + + def ensure_list(self): + return ProjectionList([self]) + + +class Projection(_Base): + + def __init__(self, slices): + assert isinstance(slices, (list, tuple)), slices + assert all(isinstance(s, slice) for s in slices), slices + assert len(slices) == 4, slices + self.slices = tuple(slices) + + def from_indices(self, *, axis, indices): + slices = indices_to_slices(indices) + this_slice = self.slices[axis] + combined = [] + for s in slices: + # combined.append(combine_slices(max(this_slice.stop,s.stop), this_slice, s)) + combined.append(combine_slices(max(this_slice.stop, s.stop), s, this_slice)) + + projections = [ + Projection([c if i == axis else self.slices[i] for i in range(len(self.slices))]) for c in combined + ] + + if len(projections) == 1: + return projections[0] + else: + return ProjectionList(projections) + + # def join(self, *, axis, shapes): + # assert isinstance(shapes, (list, tuple)), shapes + # assert all(isinstance(s, (list, tuple)) for s in shapes), shapes + + # i = 0 + # for s in shapes: + # i += s[axis] + + def advance(self, axis, amount): + this_slice = self.slices[axis] + new_start = this_slice.start + amount + new_stop = this_slice.stop + amount + slices = list(self.slices) + slices[axis] = slice(new_start, new_stop, this_slice.step) + return Projection(slices) + + def __repr__(self): + 
return f"Projection(slices={self.slices})" + + +class ProjectionList(_Base): + + def __init__(self, projections): + assert isinstance(projections, (list, tuple)), projections + assert all(isinstance(p, _Base) for p in projections), projections + self.projections = [] + for p in projections: + if isinstance(p, ProjectionList): + self.projections.extend(p.projections) + else: + self.projections.append(p) + + def from_indices(self, *, axis, indices): + return ProjectionList([p.from_indices(axis=axis, indices=indices) for p in self.projections]) + + # def join(self, *, axis, shapes): + # return ProjectionList([p.join(axis=axis, shapes=shapes) for p in self.projections]) + + # def combine(self, *, axis, projections): + # assert False, projections + + def __repr__(self): + return "ProjectionList(" + ",".join(repr(p) for p in self.projections) + ")" + + def ensure_list(self): + return self + + def __iter__(self): + return iter(self.projections) + + +class ProjectionStore(_Base): + def __init__(self, slices, store): + assert isinstance(slices, (list, tuple)), slices + assert all(isinstance(s, slice) for s in slices), slices + assert len(slices) == 4, slices + + self.slices = slices + self.store = store + + def __repr__(self): + return repr((self.slices, self.store.dataset_name)) + + def apply(self, projection): + + projections = projection.ensure_list() + + result = [] + + for projection in projections: + + # rich.print('apply', projection, 'on', self) + slices = [] + for a, b in zip(self.slices, projection.slices): + slices.append(combine_slices(a.stop, a, b)) + result.append(ProjectionStore(slices, self.store)) + + return self.list_or_single(result) + + +class Mapping: + + def __init__(self, slice: slice, length) -> None: + self.slice = slice + self.length = length + + def __repr__(self): + return f"Mapping(slice={self.slice}, length={self.length})" + + def indices(self): + return self.slice.indices(self.length) + + @property + def start(self): + return self.slice.start + + @cached_property + def mapping(self): + return {j: i for i, j in enumerate(range(*self.indices()))} + + class Component: def reduce(self): @@ -81,16 +228,19 @@ def reduce(self): for slices, store in self._reduce(): combined = [] for i in range(len(slices)): - combined.append(combine_slices(store.shape[i], *slices[i])) + s = combine_slices(store.shape[i], *slices[i]) + combined.append(Mapping(s, store.shape[i])) - result.append((combined, store)) + result.append((combined, store, slices)) return result -class ComponentList(Component): - def __init__(self, components: list[Component]) -> None: +class _ComponentList(Component): + def __init__(self, components: list[Component], what, reason) -> None: self.components = components + self.what = what + self.reason = reason def __repr__(self): return "ComponentList(" + ",".join(repr(c) for c in self.components) + ")" @@ -99,7 +249,7 @@ def tree(self, tree=None): if tree is None: tree = Tree("Components") - t = tree.add("ComponentList") + t = tree.add(f"{self.__class__.__name__}({self.what}, {self.reason})") for c in self.components: c.tree(t) return tree @@ -108,6 +258,20 @@ def _reduce(self): return sum([c._reduce() for c in self.components], []) +class Join(_ComponentList): + # def _reduce(self): + # assert False, self.components + pass + + +class Select(_ComponentList): + pass + + +class Concat(_ComponentList): + pass + + class ZarrComponent(Component): def __init__(self, store) -> None: self.store = store diff --git a/src/anemoi/datasets/data/dataset.py 
b/src/anemoi/datasets/data/dataset.py index 1cdb5252d..904547173 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -1010,10 +1010,15 @@ def origin(self, index) -> Any: raise NotImplementedError(f"origin() is not implemented for `{self.__class__.__name__}`") # @abstractmethod - def components(self) -> Any: + def components(self, slices) -> Any: """Return the components of the variable at the specified index.""" raise NotImplementedError(f"components() is not implemented for `{self.__class__.__name__}`") + # @abstractmethod + def project(self, projection) -> Any: + """Return the project of the variable at the specified index.""" + raise NotImplementedError(f"project() is not implemented for `{self.__class__.__name__}`") + @abstractmethod @cached_property def missing(self) -> set[int]: diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/data/join.py index 069729dbd..c1b3e5120 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/data/join.py @@ -304,10 +304,18 @@ def origin(self, index): raise ValueError(f"Invalid index {index} {[d.shape for d in self.datasets]}") - def components(self): - from .components import ComponentList + def components(self, slices): + from .components import Join - return ComponentList([d.components() for d in self.datasets]) + return Join([d.components(slices) for d in self.datasets], "join", {}) + + def project(self, projection): + projections = [] + for dataset in self.datasets: + projections.append(dataset.project(projection)) + projection = projection.advance(axis=1, amount=dataset.shape[1]) + + return projection.list_or_single(projections) def join_factory(args: tuple, kwargs: dict) -> Dataset: diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index 0cad60a50..a4aa1ad42 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -231,21 +231,23 @@ def origin(self, index): return self.dataset.origin((index[0], self.indices[index[1]], index[2], index[3])) - def components(self): - from .components import ComponentList + def components(self, slices): + from .components import Select from .components import VariableSpan from .components import indices_to_slices slices = indices_to_slices(self.indices) - forward = self.dataset.components() - - slices = [VariableSpan(s, forward) for s in slices] + slices = [VariableSpan(s, self.dataset.components((None, s, None, None))) for s in slices] if len(slices) == 1: return slices[0] - return ComponentList(slices) + return Select(slices, "select", self.reason) + + def project(self, projection): + projection = projection.from_indices(axis=1, indices=self.indices) + return self.dataset.project(projection) class Rename(Forwards): @@ -300,3 +302,6 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: Dict[str, Any]: The metadata specific to the subclass. 
""" return dict(rename=self.rename) + + def components(self, slices): + return self.forward.components(slices) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index da5146c8b..c00538c77 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -424,16 +424,43 @@ def collect_input_sources(self, collected: set) -> None: """Collect input sources.""" pass + @cached_property + def _origins(self): + origins = self.z.attrs.get("origins") + if self.z.attrs.get("origins") is None: + from anemoi.registry import Dataset + + LOG.warning("No 'origins' in %r, trying to get it from the registry", self.dataset_name) + ds = Dataset(self.dataset_name) + origins = ds.record.get("metadata", {}).get("origins") + + if origins is None: + raise ValueError(f"No 'origins' in {self.dataset_name} or in the registry") + + # version = origins["version"] + origins = origins["origins"] + + result = {} + + for origin in origins: + for v in origin["variables"]: + result[v] = origin["origin"] + + return result + def origin(self, index): - # if self.z.attrs.get("origins") is None: - # raise ValueError(f"No origins found in {self}") - return [self.path, self.variables_metadata[self.variables[index[1]]]] + variable = self.variables[index[1]] + return [self.path, self._origins[variable]] - def components(self): + def components(self, slices): from .components import ZarrComponent return ZarrComponent(self) + def project(self, projection): + slices = tuple(slice(0, i, 1) for i in self.shape) + return projection.from_store(slices, self).apply(projection) + @property def dataset_name(self) -> str: """Return the name of the dataset.""" diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 7df41a3bc..cb007f681 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -305,19 +305,21 @@ def origin(self, index): ), tuple return self.dataset.origin((self.indices[index[0]], index[1], index[2], index[3])) - def components(self): + def components(self, slices): - from .components import ComponentList + from .components import Concat from .components import DateSpan from .components import indices_to_slices slices = indices_to_slices(self.indices) - forward = self.dataset.components() - - slices = [DateSpan(s, forward) for s in slices] + slices = [DateSpan(s, self.dataset.components((s, None, None, None))) for s in slices] if len(slices) == 1: return slices[0] - return ComponentList(slices) + return Concat(slices) + + def project(self, projection): + projection = projection.from_indices(axis=0, indices=self.indices) + return self.dataset.project(projection) From 7b332b51a05e1e07861d61959d20163b8cc2ed1f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 3 Sep 2025 08:48:59 +0000 Subject: [PATCH 109/212] add projection --- src/anemoi/datasets/data/components.py | 91 +++++++++++++++++--------- src/anemoi/datasets/data/join.py | 14 ++-- src/anemoi/datasets/data/select.py | 3 + 3 files changed, 73 insertions(+), 35 deletions(-) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index d59c057d5..34766d36c 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -9,6 +9,7 @@ from functools import cached_property +import rich from rich.tree import Tree @@ -53,6 +54,7 @@ def combine_slices(length, *slices): start, step, current_length = 0, 1, length for s in slices: + assert s.stop >= s.start and s.step > 0 new_start, 
new_stop, new_step = s.indices(current_length) new_length = len(range(new_start, new_stop, new_step)) start = start + new_start * step @@ -60,20 +62,13 @@ def combine_slices(length, *slices): current_length = new_length if current_length == 0: - # assert False, (length, slices) return slice(0, 0, 1) # canonical empty slice if current_length == 0: - # assert False, (length, slices) return slice(0, 0, 1) stop = start + current_length * step - if step > 0 and stop > length: - stop = length - elif step < 0 and stop <= -1: - stop = 0 - return slice(start, stop, step) @@ -103,44 +98,71 @@ def __init__(self, slices): self.slices = tuple(slices) def from_indices(self, *, axis, indices): + length = max(indices) + 1 slices = indices_to_slices(indices) this_slice = self.slices[axis] combined = [] + for s in slices: - # combined.append(combine_slices(max(this_slice.stop,s.stop), this_slice, s)) - combined.append(combine_slices(max(this_slice.stop, s.stop), s, this_slice)) + c = combine_slices(max(this_slice.stop, s.stop, length), s, this_slice) + + combined.append(c) projections = [ Projection([c if i == axis else self.slices[i] for i in range(len(self.slices))]) for c in combined ] - if len(projections) == 1: - return projections[0] - else: - return ProjectionList(projections) + return self.list_or_single(projections) - # def join(self, *, axis, shapes): - # assert isinstance(shapes, (list, tuple)), shapes - # assert all(isinstance(s, (list, tuple)) for s in shapes), shapes + def from_slices(self, slices): + return Projection(slices) - # i = 0 - # for s in shapes: - # i += s[axis] + def distribute(self, axis, shapes): - def advance(self, axis, amount): - this_slice = self.slices[axis] - new_start = this_slice.start + amount - new_stop = this_slice.stop + amount - slices = list(self.slices) - slices[axis] = slice(new_start, new_stop, this_slice.step) - return Projection(slices) + rich.print("Distributing", self.slices[axis], [s[axis] for s in shapes]) + result = [] + sizes = [s[axis] for s in shapes] + sizes = [sizes[0]] + [sizes[i] + sizes[i - 1] for i in range(1, len(sizes))] + i = 0 + indices = [] + rich.print("Sizes", sizes) + for indice in range(*self.slices[axis].indices(self.slices[axis].stop)): + if i == len(sizes): + break + if indice < sizes[i]: + indices.append(indice) + continue + + if indices: + for s in indices_to_slices(indices): + result.append(self.make_new([s if j == axis else self.slices[j] for j in range(len(self.slices))])) + indices = [] + indices.append(indice) + i += 1 + if indices: + for s in indices_to_slices(indices): + result.append(self.make_new([s if j == axis else self.slices[j] for j in range(len(self.slices))])) + rich.print("======") + for r in result: + rich.print("Distributing", r) + + # for n in [s[axis] for s in shapes]: + + return self.list_or_single(result) def __repr__(self): return f"Projection(slices={self.slices})" + def offset(self, axis, amount): + return Projection( + [slice(s.start + amount, s.stop + amount, s.step) if i == axis else s for i, s in enumerate(self.slices)] + ) -class ProjectionList(_Base): + def shape(self): + return tuple(len(range(*s.indices(s.stop))) for s in self.slices) + +class ProjectionList(_Base): def __init__(self, projections): assert isinstance(projections, (list, tuple)), projections assert all(isinstance(p, _Base) for p in projections), projections @@ -154,11 +176,18 @@ def __init__(self, projections): def from_indices(self, *, axis, indices): return ProjectionList([p.from_indices(axis=axis, indices=indices) for p in 
self.projections]) - # def join(self, *, axis, shapes): - # return ProjectionList([p.join(axis=axis, shapes=shapes) for p in self.projections]) + def distribute(self, axis, shapes): + + result = [] + offset = 0 + for p in self.projections: + n = p.slices[axis].stop + result.append(p.offset(axis, offset).distribute(axis=axis, shapes=shapes)) + offset += n + + return self.list_or_single(result) - # def combine(self, *, axis, projections): - # assert False, projections + return ProjectionList([p.distribute(axis=axis, shapes=shapes) for p in self.projections]) def __repr__(self): return "ProjectionList(" + ",".join(repr(p) for p in self.projections) + ")" diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/data/join.py index c1b3e5120..6fbcf8825 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/data/join.py @@ -14,6 +14,7 @@ from typing import Any import numpy as np +import rich from numpy.typing import NDArray from .dataset import Dataset @@ -175,6 +176,8 @@ def _overlay(self) -> Dataset: from .select import Select + rich.print("Overlaying join with", variables, len(indices), [d.shape for d in self.datasets]) + return Select(self, indices, {"overlay": variables}) @cached_property @@ -310,12 +313,15 @@ def components(self, slices): return Join([d.components(slices) for d in self.datasets], "join", {}) def project(self, projection): - projections = [] + result = [] + offset = 0 + for dataset in self.datasets: - projections.append(dataset.project(projection)) - projection = projection.advance(axis=1, amount=dataset.shape[1]) + for p in projection.ensure_list(): + result.append(dataset.project(p.offset(axis=1, amount=-offset))) + offset += dataset.shape[1] - return projection.list_or_single(projections) + return projection.list_or_single(result) def join_factory(args: tuple, kwargs: dict) -> Dataset: diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index a4aa1ad42..eb9a7872b 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -305,3 +305,6 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: def components(self, slices): return self.forward.components(slices) + + def project(self, projection): + return self.forward.project(projection) From c40025a06159be9f32d415d90812fbeed0c49a3d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 3 Sep 2025 09:02:14 +0000 Subject: [PATCH 110/212] tidy code --- src/anemoi/datasets/data/components.py | 223 +++---------------------- src/anemoi/datasets/data/select.py | 4 +- src/anemoi/datasets/data/subset.py | 4 +- 3 files changed, 29 insertions(+), 202 deletions(-) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index 34766d36c..3c5578e57 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -7,13 +7,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-from functools import cached_property -import rich -from rich.tree import Tree - - -def indices_to_slices(indices: list[int]) -> list[slice]: +def _indices_to_slices(indices: list[int]) -> list[slice]: indices = sorted(indices) assert len(indices) == len(set(indices)), "Duplicate indices are not allowed" @@ -72,15 +67,17 @@ def combine_slices(length, *slices): return slice(start, stop, step) -class _Base: +class ProjectionBase: def from_store(self, slices, store): return ProjectionStore(slices, store) - def make_new(self, slices): + @classmethod + def from_slices(cls, slices): return Projection(slices) - def list_or_single(self, projections): + @classmethod + def list_or_single(cls, projections): if len(projections) == 1: return projections[0] return ProjectionList(projections) @@ -89,7 +86,7 @@ def ensure_list(self): return ProjectionList([self]) -class Projection(_Base): +class Projection(ProjectionBase): def __init__(self, slices): assert isinstance(slices, (list, tuple)), slices @@ -99,14 +96,12 @@ def __init__(self, slices): def from_indices(self, *, axis, indices): length = max(indices) + 1 - slices = indices_to_slices(indices) + slices = _indices_to_slices(indices) this_slice = self.slices[axis] combined = [] for s in slices: - c = combine_slices(max(this_slice.stop, s.stop, length), s, this_slice) - - combined.append(c) + combined.append(combine_slices(max(this_slice.stop, s.stop, length), s, this_slice)) projections = [ Projection([c if i == axis else self.slices[i] for i in range(len(self.slices))]) for c in combined @@ -114,58 +109,31 @@ def from_indices(self, *, axis, indices): return self.list_or_single(projections) - def from_slices(self, slices): - return Projection(slices) - - def distribute(self, axis, shapes): - - rich.print("Distributing", self.slices[axis], [s[axis] for s in shapes]) - result = [] - sizes = [s[axis] for s in shapes] - sizes = [sizes[0]] + [sizes[i] + sizes[i - 1] for i in range(1, len(sizes))] - i = 0 - indices = [] - rich.print("Sizes", sizes) - for indice in range(*self.slices[axis].indices(self.slices[axis].stop)): - if i == len(sizes): - break - if indice < sizes[i]: - indices.append(indice) - continue - - if indices: - for s in indices_to_slices(indices): - result.append(self.make_new([s if j == axis else self.slices[j] for j in range(len(self.slices))])) - indices = [] - indices.append(indice) - i += 1 - if indices: - for s in indices_to_slices(indices): - result.append(self.make_new([s if j == axis else self.slices[j] for j in range(len(self.slices))])) - rich.print("======") - for r in result: - rich.print("Distributing", r) - - # for n in [s[axis] for s in shapes]: - - return self.list_or_single(result) - def __repr__(self): return f"Projection(slices={self.slices})" def offset(self, axis, amount): return Projection( - [slice(s.start + amount, s.stop + amount, s.step) if i == axis else s for i, s in enumerate(self.slices)] + [ + ( + slice( + s.start + amount, + s.stop + amount, + s.step, + ) + if i == axis + else s + ) + for i, s in enumerate(self.slices) + ] ) - def shape(self): - return tuple(len(range(*s.indices(s.stop))) for s in self.slices) - -class ProjectionList(_Base): +class ProjectionList(ProjectionBase): def __init__(self, projections): assert isinstance(projections, (list, tuple)), projections - assert all(isinstance(p, _Base) for p in projections), projections + assert all(isinstance(p, ProjectionBase) for p in projections), projections + self.projections = [] for p in projections: if isinstance(p, ProjectionList): @@ -176,19 
+144,6 @@ def __init__(self, projections): def from_indices(self, *, axis, indices): return ProjectionList([p.from_indices(axis=axis, indices=indices) for p in self.projections]) - def distribute(self, axis, shapes): - - result = [] - offset = 0 - for p in self.projections: - n = p.slices[axis].stop - result.append(p.offset(axis, offset).distribute(axis=axis, shapes=shapes)) - offset += n - - return self.list_or_single(result) - - return ProjectionList([p.distribute(axis=axis, shapes=shapes) for p in self.projections]) - def __repr__(self): return "ProjectionList(" + ",".join(repr(p) for p in self.projections) + ")" @@ -199,7 +154,7 @@ def __iter__(self): return iter(self.projections) -class ProjectionStore(_Base): +class ProjectionStore(ProjectionBase): def __init__(self, slices, store): assert isinstance(slices, (list, tuple)), slices assert all(isinstance(s, slice) for s in slices), slices @@ -219,137 +174,9 @@ def apply(self, projection): for projection in projections: - # rich.print('apply', projection, 'on', self) slices = [] for a, b in zip(self.slices, projection.slices): slices.append(combine_slices(a.stop, a, b)) result.append(ProjectionStore(slices, self.store)) return self.list_or_single(result) - - -class Mapping: - - def __init__(self, slice: slice, length) -> None: - self.slice = slice - self.length = length - - def __repr__(self): - return f"Mapping(slice={self.slice}, length={self.length})" - - def indices(self): - return self.slice.indices(self.length) - - @property - def start(self): - return self.slice.start - - @cached_property - def mapping(self): - return {j: i for i, j in enumerate(range(*self.indices()))} - - -class Component: - - def reduce(self): - result = [] - - for slices, store in self._reduce(): - combined = [] - for i in range(len(slices)): - s = combine_slices(store.shape[i], *slices[i]) - combined.append(Mapping(s, store.shape[i])) - - result.append((combined, store, slices)) - - return result - - -class _ComponentList(Component): - def __init__(self, components: list[Component], what, reason) -> None: - self.components = components - self.what = what - self.reason = reason - - def __repr__(self): - return "ComponentList(" + ",".join(repr(c) for c in self.components) + ")" - - def tree(self, tree=None): - if tree is None: - tree = Tree("Components") - - t = tree.add(f"{self.__class__.__name__}({self.what}, {self.reason})") - for c in self.components: - c.tree(t) - return tree - - def _reduce(self): - return sum([c._reduce() for c in self.components], []) - - -class Join(_ComponentList): - # def _reduce(self): - # assert False, self.components - pass - - -class Select(_ComponentList): - pass - - -class Concat(_ComponentList): - pass - - -class ZarrComponent(Component): - def __init__(self, store) -> None: - self.store = store - - def __repr__(self): - return f"ZarrComponent({self.store})" - - def tree(self, tree=None): - if tree is None: - tree = Tree("Components") - - tree.add(f"ZarrComponent({self.store})") - return tree - - def _reduce(self): - slices = [[slice(0, s, 1)] for s in self.store.shape] - return [(slices, self.store)] - - -class AxisComponent(Component): - def __init__(self, slice, component) -> None: - self.slice = slice - self.component = component - self.length = len(list(range(*self.slice.indices(self.slice.stop)))) - - def __repr__(self): - - return f"{self.__class__.__name__}({self.slice} ({self.length}),{self.component})" - - def tree(self, tree=None): - if tree is None: - tree = Tree("Components") - - 
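A hedged sketch of the offset operation defined earlier in this file, assuming the usual four axes (dates, variables, ensembles, grid): only the slice on the selected axis is shifted, the other axes pass through unchanged.

slices = [slice(0, 10, 1), slice(3, 7, 1), slice(0, 1, 1), slice(0, 40, 1)]
axis, amount = 1, -3
shifted = [
    slice(s.start + amount, s.stop + amount, s.step) if i == axis else s
    for i, s in enumerate(slices)
]
assert shifted[1] == slice(0, 4, 1) and shifted[0] == slices[0]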
self.component.tree(tree.add(f"{self.__class__.__name__}({self.slice} ({self.length}))")) - - return tree - - def _reduce(self): - result = [] - for slices, store in self.component._reduce(): - slices = slices.copy() - slices[self.axis].append(self.slice) # Add this slice to list - result.append((slices, store)) - return result - - -class DateSpan(AxisComponent): - axis = 0 - - -class VariableSpan(AxisComponent): - axis = 1 diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index eb9a7872b..6b1fdf454 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -234,9 +234,9 @@ def origin(self, index): def components(self, slices): from .components import Select from .components import VariableSpan - from .components import indices_to_slices + from .components import _indices_to_slices - slices = indices_to_slices(self.indices) + slices = _indices_to_slices(self.indices) slices = [VariableSpan(s, self.dataset.components((None, s, None, None))) for s in slices] diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index cb007f681..518d36e5f 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -309,9 +309,9 @@ def components(self, slices): from .components import Concat from .components import DateSpan - from .components import indices_to_slices + from .components import _indices_to_slices - slices = indices_to_slices(self.indices) + slices = _indices_to_slices(self.indices) slices = [DateSpan(s, self.dataset.components((s, None, None, None))) for s in slices] From 81e355bda7fb871ec72be12f51e58665a8ed38eb Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 3 Sep 2025 11:01:59 +0000 Subject: [PATCH 111/212] tidy --- src/anemoi/datasets/data/components.py | 15 ++++++++++++--- src/anemoi/datasets/data/dataset.py | 9 +++++---- src/anemoi/datasets/data/forwards.py | 5 +++++ src/anemoi/datasets/data/join.py | 5 ----- src/anemoi/datasets/data/select.py | 19 +------------------ src/anemoi/datasets/data/stores.py | 11 +---------- src/anemoi/datasets/data/subset.py | 17 +---------------- 7 files changed, 25 insertions(+), 56 deletions(-) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index 3c5578e57..ce3e0793a 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -44,7 +44,7 @@ def _indices_to_slices(indices: list[int]) -> list[slice]: return slices -def combine_slices(length, *slices): +def _combine_slices(length, *slices): start, step, current_length = 0, 1, length @@ -101,7 +101,7 @@ def from_indices(self, *, axis, indices): combined = [] for s in slices: - combined.append(combine_slices(max(this_slice.stop, s.stop, length), s, this_slice)) + combined.append(_combine_slices(max(this_slice.stop, s.stop, length), s, this_slice)) projections = [ Projection([c if i == axis else self.slices[i] for i in range(len(self.slices))]) for c in combined @@ -176,7 +176,16 @@ def apply(self, projection): slices = [] for a, b in zip(self.slices, projection.slices): - slices.append(combine_slices(a.stop, a, b)) + slices.append(_combine_slices(a.stop, a, b)) result.append(ProjectionStore(slices, self.store)) return self.list_or_single(result) + + def variables(self): + return self.store.variables[self.slices[1]] + + def origins(self): + result = {} + for variable in self.variables(): + result[variable] = self.store.origins[variable] + return result diff --git 
a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 904547173..55203371b 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -1009,10 +1009,11 @@ def origin(self, index) -> Any: """Return the origin of the variable at the specified index.""" raise NotImplementedError(f"origin() is not implemented for `{self.__class__.__name__}`") - # @abstractmethod - def components(self, slices) -> Any: - """Return the components of the variable at the specified index.""" - raise NotImplementedError(f"components() is not implemented for `{self.__class__.__name__}`") + def components(self) -> Any: + from anemoi.datasets.data.components import Projection + + slices = tuple(slice(0, i, 1) for i in self.shape) + return self.project(Projection(slices)) # @abstractmethod def project(self, projection) -> Any: diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index 4e2219b1c..78d632db2 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -240,6 +240,11 @@ def constant_fields(self) -> list[str]: """Returns the constant fields of the forward dataset.""" return self.forward.constant_fields + def origin(self, index): + origin = self.forward_subclass_origin(index) + self.annotate_origin(origin) + return origin + class Combined(Forwards): """A class to combine multiple datasets into a single dataset.""" diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/data/join.py index 6fbcf8825..5eaf9c022 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/data/join.py @@ -307,11 +307,6 @@ def origin(self, index): raise ValueError(f"Invalid index {index} {[d.shape for d in self.datasets]}") - def components(self, slices): - from .components import Join - - return Join([d.components(slices) for d in self.datasets], "join", {}) - def project(self, projection): result = [] offset = 0 diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index 6b1fdf454..90ba7e344 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -224,27 +224,13 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: # return dict(indices=self.indices) return dict(reason=self.reason) - def origin(self, index): + def forward_subclass_origin(self, index): assert ( isinstance(index, tuple) and len(index) == 4 and all(a > b >= 0 for a, b in zip(self.shape, index)) ), tuple return self.dataset.origin((index[0], self.indices[index[1]], index[2], index[3])) - def components(self, slices): - from .components import Select - from .components import VariableSpan - from .components import _indices_to_slices - - slices = _indices_to_slices(self.indices) - - slices = [VariableSpan(s, self.dataset.components((None, s, None, None))) for s in slices] - - if len(slices) == 1: - return slices[0] - - return Select(slices, "select", self.reason) - def project(self, projection): projection = projection.from_indices(axis=1, indices=self.indices) return self.dataset.project(projection) @@ -303,8 +289,5 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return dict(rename=self.rename) - def components(self, slices): - return self.forward.components(slices) - def project(self, projection): return self.forward.project(projection) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index c00538c77..79dfb6a55 100644 --- a/src/anemoi/datasets/data/stores.py +++ 
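The components() entry point added to Dataset above starts from an identity projection, one full-extent slice per axis of the dataset shape, before delegating to project(); the snippet below spells that out with illustrative sizes.

shape = (1460, 12, 1, 40320)   # (dates, variables, ensembles, grid points) - illustrative values
identity = tuple(slice(0, n, 1) for n in shape)
# identity == (slice(0, 1460, 1), slice(0, 12, 1), slice(0, 1, 1), slice(0, 40320, 1))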
b/src/anemoi/datasets/data/stores.py @@ -425,7 +425,7 @@ def collect_input_sources(self, collected: set) -> None: pass @cached_property - def _origins(self): + def origins(self): origins = self.z.attrs.get("origins") if self.z.attrs.get("origins") is None: from anemoi.registry import Dataset @@ -448,15 +448,6 @@ def _origins(self): return result - def origin(self, index): - variable = self.variables[index[1]] - return [self.path, self._origins[variable]] - - def components(self, slices): - from .components import ZarrComponent - - return ZarrComponent(self) - def project(self, projection): slices = tuple(slice(0, i, 1) for i in self.shape) return projection.from_store(slices, self).apply(projection) diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 518d36e5f..6d68c61a8 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -299,27 +299,12 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: "reason": self.reason, } - def origin(self, index): + def forward_subclass_origin(self, index): assert ( isinstance(index, tuple) and len(index) == 4 and all(a > b >= 0 for a, b in zip(self.shape, index)) ), tuple return self.dataset.origin((self.indices[index[0]], index[1], index[2], index[3])) - def components(self, slices): - - from .components import Concat - from .components import DateSpan - from .components import _indices_to_slices - - slices = _indices_to_slices(self.indices) - - slices = [DateSpan(s, self.dataset.components((s, None, None, None))) for s in slices] - - if len(slices) == 1: - return slices[0] - - return Concat(slices) - def project(self, projection): projection = projection.from_indices(axis=0, indices=self.indices) return self.dataset.project(projection) From 06850d82d12c8d56eb10c22021fcf4bcba9610d5 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 3 Sep 2025 16:00:17 +0000 Subject: [PATCH 112/212] add transformations --- src/anemoi/datasets/data/components.py | 28 ++++++++++++++++++++++++-- src/anemoi/datasets/data/forwards.py | 6 ++---- src/anemoi/datasets/data/select.py | 7 +++++-- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index ce3e0793a..eb315719e 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -155,13 +155,14 @@ def __iter__(self): class ProjectionStore(ProjectionBase): - def __init__(self, slices, store): + def __init__(self, slices, store, transformations=None): assert isinstance(slices, (list, tuple)), slices assert all(isinstance(s, slice) for s in slices), slices assert len(slices) == 4, slices self.slices = slices self.store = store + self.transformations = transformations or [] def __repr__(self): return repr((self.slices, self.store.dataset_name)) @@ -187,5 +188,28 @@ def variables(self): def origins(self): result = {} for variable in self.variables(): - result[variable] = self.store.origins[variable] + + origins = self.store.origins[variable] + + pipe = [] + for transformation in self.transformations: + + action = transformation.origin_transformation(variable, origins) + action = action.copy() + action.setdefault("when", "dataset-usage") + action.setdefault("type", "filter") + pipe.append(action) + + if pipe: + origins = { + "type": "pipe", + "when": "dataset-usage", + "steps": [origins] + pipe, + } + + result[variable] = origins + return result + + def add_transformation(self, transformation): + return 
ProjectionStore(self.slices, self.store, self.transformations + [transformation]) diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index 78d632db2..cef5d9ec2 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -240,10 +240,8 @@ def constant_fields(self) -> list[str]: """Returns the constant fields of the forward dataset.""" return self.forward.constant_fields - def origin(self, index): - origin = self.forward_subclass_origin(index) - self.annotate_origin(origin) - return origin + def project(self, projection): + return self.forward.project(projection).add_transformation(self) class Combined(Forwards): diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index 90ba7e344..f783276d3 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -289,5 +289,8 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return dict(rename=self.rename) - def project(self, projection): - return self.forward.project(projection) + def origin_transformation(self, variable, origins): + return { + "name": "rename", + "config": {"rename": self.rename}, + } From 1d94d5a63fa3c23a895529ee1db943061a6b657e Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Fri, 5 Sep 2025 11:04:28 +0000 Subject: [PATCH 113/212] up --- src/anemoi/datasets/data/records/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index a0f13652f..b8c568296 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -229,6 +229,10 @@ def __init__(self, fields_dataset, name): self._groups = [name] self.reason = {"name": name} + @property + def metadata(self): + return self.forward.metadata + def _nest_in_dict(self, obj): """Helper to nest the object in a dict with the name as key.""" return {self._name: obj} @@ -616,6 +620,10 @@ def __init__(self, dataset, select): self.reason = {"select": select} self._build_indices_and_name_to_index() + @property + def metadata(self): + return dict(select=self._select, forward=self.dataset.metadata) + def _build_indices_and_name_to_index(self): indices = {} name_to_index = {} From 6941fdfa6866b7eba39cddf9fc02c056ce275c4b Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 5 Sep 2025 11:20:54 +0000 Subject: [PATCH 114/212] add type hint --- src/anemoi/datasets/data/stores.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 17290862c..122da8ce7 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -14,7 +14,7 @@ import tempfile import warnings from functools import cached_property -from typing import Any +from typing import Any, Optional from urllib.parse import urlparse import numpy as np From 9ef1fc3533c4719fcb41f6ca0963f0d988c8d4b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Sep 2025 11:21:48 +0000 Subject: [PATCH 115/212] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/datasets/data/stores.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 122da8ce7..30a932ba9 100644 --- a/src/anemoi/datasets/data/stores.py +++ 
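With the transformation mechanism above, a variable that has been renamed no longer reports the raw store origin but a small pipe. The structure below is an illustrative example of what ProjectionStore.origins() builds for one variable; the first step is a placeholder for whatever the store's zarr attributes record, and the rename mapping is made up.

origin = {
    "type": "pipe",
    "when": "dataset-usage",
    "steps": [
        {"...": "origin recorded in the store's zarr attributes"},   # placeholder
        {
            "type": "filter",
            "when": "dataset-usage",
            "name": "rename",
            "config": {"rename": {"2t": "temperature"}},
        },
    ],
}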
b/src/anemoi/datasets/data/stores.py @@ -14,7 +14,8 @@ import tempfile import warnings from functools import cached_property -from typing import Any, Optional +from typing import Any +from typing import Optional from urllib.parse import urlparse import numpy as np From e09ed7ed9b3920b809beea60f31b0b99144bae2f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 6 Sep 2025 05:19:14 +0000 Subject: [PATCH 116/212] rename variables --- src/anemoi/datasets/data/components.py | 8 +++++++- src/anemoi/datasets/data/missing.py | 4 ++-- src/anemoi/datasets/data/select.py | 2 +- src/anemoi/datasets/data/stores.py | 28 ++++++++++++++++---------- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index eb315719e..bfd9fea78 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -153,6 +153,9 @@ def ensure_list(self): def __iter__(self): return iter(self.projections) + def add_transformation(self, transformation): + return ProjectionList([p.add_transformation(transformation) for p in self.projections]) + class ProjectionStore(ProjectionBase): def __init__(self, slices, store, transformations=None): @@ -194,7 +197,7 @@ def origins(self): pipe = [] for transformation in self.transformations: - action = transformation.origin_transformation(variable, origins) + action, variable = transformation.origin_transformation(variable, origins) action = action.copy() action.setdefault("when", "dataset-usage") action.setdefault("type", "filter") @@ -213,3 +216,6 @@ def origins(self): def add_transformation(self, transformation): return ProjectionStore(self.slices, self.store, self.transformations + [transformation]) + + def __iter__(self): + return iter([self]) diff --git a/src/anemoi/datasets/data/missing.py b/src/anemoi/datasets/data/missing.py index 5e6530bda..8e0fb44ff 100644 --- a/src/anemoi/datasets/data/missing.py +++ b/src/anemoi/datasets/data/missing.py @@ -59,14 +59,14 @@ def __init__(self, dataset: Dataset, missing_dates: list[int | str]) -> None: self._missing = set() - other = [] + other = set() for date in missing_dates: if isinstance(date, int): self._missing.add(date) self.missing_dates.append(dataset.dates[date]) else: date = to_datetime(date) - other.append(date) + other.add(date) if other: for i, date in enumerate(dataset.dates): diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index f783276d3..b422899ee 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -293,4 +293,4 @@ def origin_transformation(self, variable, origins): return { "name": "rename", "config": {"rename": self.rename}, - } + }, self.rename.get(variable, variable) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 79dfb6a55..9c88a2b30 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -85,21 +85,24 @@ class S3Store(ReadOnlyStore): options using the anemoi configs. 
""" - def __init__(self, url: str, region: str | None = None) -> None: - """Initialize the S3Store with a URL and optional region.""" - from anemoi.utils.remote.s3 import s3_client + def __init__(self, url: str) -> None: + """Initialize the S3Store with a URL.""" + + LOG.warning("Accessing dataset using %s", url) + LOG.warning("Data access may be slow") - _, _, self.bucket, self.key = url.split("/", 3) - self.s3 = s3_client(self.bucket, region=region) + self.url = url def __getitem__(self, key: str) -> bytes: """Retrieve an item from the store.""" - try: - response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key) - except self.s3.exceptions.NoSuchKey: - raise KeyError(key) + from anemoi.utils.remote.s3 import get_object + + target = self.url + "/" + key - return response["Body"].read() + try: + return get_object(target).bytes() + except FileNotFoundError: + raise KeyError(target) class DebugStore(ReadOnlyStore): @@ -586,6 +589,9 @@ def zarr_lookup(name: str, fail: bool = True) -> str | None: pass if fail: - raise ValueError(f"Cannot find a dataset that matched '{name}'. Tried: {tried}") + LOG.error(f"Failed to find dataset '{name}'. Tried:") + for path in tried: + LOG.error(f" - {path}") + raise ValueError(f"Cannot find a dataset that matched '{name}'") return None From cd06c9832019efef7f330fd4fe46633d7d3ec530 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 6 Sep 2025 08:20:43 +0000 Subject: [PATCH 117/212] add origins test --- src/anemoi/datasets/data/stores.py | 13 ++--- tests/test_data.py | 2 +- tests/test_data_gridded.py | 2 +- tests/test_origins.py | 93 ++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 10 deletions(-) create mode 100644 tests/test_origins.py diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 9c88a2b30..90e635cd6 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -188,7 +188,7 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.hie if cache is not None: store = zarr.LRUStoreCache(store, max_size=cache) - return zarr.convenience.open(store, "r") + return zarr.open(store, "r") except zarr.errors.PathNotFoundError: if not dont_fail: raise zarr.errors.PathNotFoundError(path) @@ -430,15 +430,12 @@ def collect_input_sources(self, collected: set) -> None: @cached_property def origins(self): origins = self.z.attrs.get("origins") - if self.z.attrs.get("origins") is None: - from anemoi.registry import Dataset - - LOG.warning("No 'origins' in %r, trying to get it from the registry", self.dataset_name) - ds = Dataset(self.dataset_name) - origins = ds.record.get("metadata", {}).get("origins") if origins is None: - raise ValueError(f"No 'origins' in {self.dataset_name} or in the registry") + import rich + + rich.print(dict(self.z.attrs)) + raise ValueError(f"No 'origins' in {self.dataset_name}") # version = origins["version"] origins = origins["origins"] diff --git a/tests/test_data.py b/tests/test_data.py index f81b16350..9b33ca92e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -54,7 +54,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): - with patch("zarr.convenience.open", zarr_from_str): + with patch("zarr.open", zarr_from_str): with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py index d493a50e7..9e9c7d1ec 100644 --- a/tests/test_data_gridded.py +++ 
b/tests/test_data_gridded.py @@ -41,7 +41,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): - with patch("zarr.convenience.open", zarr_from_str): + with patch("zarr.open", zarr_from_str): with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) diff --git a/tests/test_origins.py b/tests/test_origins.py new file mode 100644 index 000000000..ef5a134fd --- /dev/null +++ b/tests/test_origins.py @@ -0,0 +1,93 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from collections.abc import Callable +from functools import wraps +from unittest.mock import patch + +from anemoi.utils.testing import skip_if_offline + +from anemoi.datasets import open_dataset + + +def _tests_zarrs(name: str) -> str: + return f"https://anemoi-test.ecmwf.int/test-zarrs/{name}.zarr" + + +def zarr_tests(func: Callable) -> Callable: + """Decorator to mock the zarr_lookup function. + + Parameters + ---------- + func : Callable + Function to wrap. + + Returns + ------- + Callable + Wrapped function. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + with patch("anemoi.datasets.data.stores.zarr_lookup", _tests_zarrs): + return func(*args, **kwargs) + + return wrapper + + +@skip_if_offline +@zarr_tests +def test_origins_rename() -> None: + ds = open_dataset( + [ + { + "dataset": "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "select": ["cp", "tp"], + "end": 2023, + "frequency": "6h", + }, + { + "dataset": "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v1-precipitations", + "end": 2023, + "frequency": "6h", + "rename": {"tp_0h_12h": "tp"}, + "select": ["tp_0h_12h"], + }, + ], + end=2022, + ) + + for p in ds.components(): + print(p) + print(p.origins()) + + +@skip_if_offline +@zarr_tests +def test_origins_cutout() -> None: + ds = open_dataset( + cutout=[ + "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + ], + adjust="all", + ) + + for p in ds.components(): + print(p) + print(p.origins()) + + +if __name__ == "__main__": + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() From f9fd3a02c52a094dddf4ef95bb85522cc0a6ec88 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 8 Sep 2025 14:36:41 +0000 Subject: [PATCH 118/212] tidy --- src/anemoi/datasets/data/components.py | 10 + src/anemoi/datasets/data/dataset.py | 7 +- src/anemoi/datasets/data/forwards.py | 14 +- src/anemoi/datasets/data/masked.py | 15 + src/anemoi/datasets/data/padded.py | 6 +- src/anemoi/datasets/data/rescale.py | 12 + src/anemoi/datasets/data/select.py | 3 + src/anemoi/datasets/data/stores.py | 8 +- tests/test_classes.py | 408 +++++++++++++++++++++++++ tests/test_origins.py | 93 ------ 10 files changed, 469 insertions(+), 107 deletions(-) create mode 100644 tests/test_classes.py delete mode 100644 tests/test_origins.py diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index bfd9fea78..db5bca108 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -8,6 +8,9 @@ # nor does it submit to any 
jurisdiction. +from collections import defaultdict + + def _indices_to_slices(indices: list[int]) -> list[slice]: indices = sorted(indices) assert len(indices) == len(set(indices)), "Duplicate indices are not allowed" @@ -85,6 +88,13 @@ def list_or_single(cls, projections): def ensure_list(self): return ProjectionList([self]) + def compressed_origins(self): + result = defaultdict(list) + for p in self.ensure_list(): + for k, v in p.origins().items(): + result[k].append(v) + return result + class Projection(ProjectionBase): diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 55203371b..4338d4ab3 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -1004,10 +1004,9 @@ def variables_metadata(self) -> dict[str, Any]: """Return the metadata of the variables in the dataset.""" pass - # @abstractmethod - def origin(self, index) -> Any: - """Return the origin of the variable at the specified index.""" - raise NotImplementedError(f"origin() is not implemented for `{self.__class__.__name__}`") + def origins(self) -> Any: + for p in self.components().ensure_list(): + print(p.origins()) def components(self) -> Any: from anemoi.datasets.data.components import Projection diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index cef5d9ec2..294aca063 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -240,9 +240,6 @@ def constant_fields(self) -> list[str]: """Returns the constant fields of the forward dataset.""" return self.forward.constant_fields - def project(self, projection): - return self.forward.project(projection).add_transformation(self) - class Combined(Forwards): """A class to combine multiple datasets into a single dataset.""" @@ -664,3 +661,14 @@ def missing(self) -> set[int]: if self.axis == 0: # Advance if axis is time offset += len(d) return result + + def project(self, projection): + result = [] + offset = 0 + + for dataset in self.datasets: + for p in projection.ensure_list(): + result.append(dataset.project(p.offset(axis=self.axis, amount=-offset))) + offset += dataset.shape[self.axis] + + return projection.list_or_single(result) diff --git a/src/anemoi/datasets/data/masked.py b/src/anemoi/datasets/data/masked.py index f7eeea03d..0d4435d57 100644 --- a/src/anemoi/datasets/data/masked.py +++ b/src/anemoi/datasets/data/masked.py @@ -125,6 +125,9 @@ def collect_supporting_arrays(self, collected: list[tuple], *path: Any) -> None: super().collect_supporting_arrays(collected, *path) collected.append((path, self.mask_name, self.mask)) + def project(self, projection): + return self.forward.project(projection).add_transformation(self) + class Thinning(Masked): """A class to represent a thinned dataset.""" @@ -200,6 +203,12 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return dict(thinning=self.thinning, method=self.method) + def origin_transformation(self, variable, origins): + return { + "name": "thinning", + "config": dict(thinning=self.thinning, method=self.method), + }, variable + class Cropping(Masked): """A class to represent a cropped dataset.""" @@ -250,6 +259,12 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return dict(area=self.area) + def origin_transformation(self, variable, origins): + return { + "name": "cropping", + "config": dict(area=self.area), + }, variable + class TrimEdge(Masked): """A class that removes the boundary of a dataset.""" diff --git 
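compressed_origins(), added above, simply groups the per-projection origins by variable name; the standalone sketch below reproduces that aggregation with made-up origin values.

from collections import defaultdict

per_projection = [{"2t": "store-a"}, {"2t": "store-b", "msl": "store-b"}]
result = defaultdict(list)
for origins in per_projection:
    for variable, origin in origins.items():
        result[variable].append(origin)
assert dict(result) == {"2t": ["store-a", "store-b"], "msl": ["store-b"]}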
a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py index d0bebb6fc..54c2f8d7e 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/data/padded.py @@ -32,9 +32,6 @@ class Padded(Forwards): - _before: int = 0 - _after: int = 0 - _inside: int = 0 def __init__(self, dataset: Dataset, start: str, end: str, frequency: str, reason: dict[str, Any]) -> None: """Create a padded subset of a dataset. @@ -53,6 +50,9 @@ def __init__(self, dataset: Dataset, start: str, end: str, frequency: str, reaso frequency = dataset.frequency self._frequency = frequency_to_timedelta(frequency) + self._before: int = 0 + self._after: int = 0 + self._inside: int = 0 if start is None: # default is to start at the first date diff --git a/src/anemoi/datasets/data/rescale.py b/src/anemoi/datasets/data/rescale.py index 613bbe93e..df19e31e8 100644 --- a/src/anemoi/datasets/data/rescale.py +++ b/src/anemoi/datasets/data/rescale.py @@ -242,3 +242,15 @@ def statistics_tendencies(self, delta: datetime.timedelta | None = None) -> dict raise NotImplementedError("rescale tendencies statistics", k) return result + + def project(self, projection): + return self.forward.project(projection).add_transformation(self) + + def origin_transformation(self, variable, origins): + config = {} + for variable, (a, b) in self.rescale.items(): + config[variable] = {"scale": a, "offset": b} + return { + "name": "rescale", + "config": config, + }, variable diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index b422899ee..a93cf983e 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -289,6 +289,9 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return dict(rename=self.rename) + def project(self, projection): + return self.forward.project(projection).add_transformation(self) + def origin_transformation(self, variable, origins): return { "name": "rename", diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 90e635cd6..6f13dbc33 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -530,10 +530,10 @@ def label(self) -> str: """Return the label of the dataset.""" return "zarr*" - def origin(self, index): - if index[0] in self.missing: - self._report_missing(index[0]) - return super().origin(index) + # def origin(self, index): + # if index[0] in self.missing: + # self._report_missing(index[0]) + # return super().origin(index) QUIET = set() diff --git a/tests/test_classes.py b/tests/test_classes.py new file mode 100644 index 000000000..e6cf36f50 --- /dev/null +++ b/tests/test_classes.py @@ -0,0 +1,408 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
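Rescale and Select above, like the Masked subclasses in the previous file, now share the same two-line projection hook: delegate to the wrapped dataset, then register the wrapper so that origins() can later report it as a pipe step. Schematically (names as in this patch, the rescale values are illustrative):

def project(self, projection):
    return self.forward.project(projection).add_transformation(self)

def origin_transformation(self, variable, origins):
    # each wrapper returns an (action, variable) pair; variable is passed through unchanged here
    return {"name": "rescale", "config": {"2t": {"scale": 1.0, "offset": -273.15}}}, variable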
+ + +from collections.abc import Callable +from functools import wraps +from unittest.mock import patch + +import pytest +from anemoi.utils.testing import skip_if_offline + +from anemoi.datasets import open_dataset + + +def _tests_zarrs(name: str) -> str: + return f"https://anemoi-test.ecmwf.int/test-zarrs/{name}.zarr" + + +def zarr_tests(func: Callable) -> Callable: + + @wraps(func) + def wrapper(*args, **kwargs): + with patch("anemoi.datasets.data.stores.zarr_lookup", _tests_zarrs): + return func(*args, **kwargs) + + return wrapper + + +def _test_dataset(ds): + for p in ds.components(): + print(p) + print(p.origins()) + + +not_ready = pytest.mark.skip(reason="Not ready yet") + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_complement_none(): + pass + # ds = open_dataset( + # source="cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + # complement="aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + # # adjust="all", + # ) + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_complement_nearest(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_concat(): + pass + + +@skip_if_offline +@zarr_tests +def test_class_number(): + ds = open_dataset( + "aifs-ea-an-enda-0001-mars-o96-1979-2022-6h-v6", + number=[1, 5, 6], + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_ensemble(): + ds = open_dataset( + ensemble=[ + "aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6", + "aifs-ea-em-enda-0001-mars-o96-1979-2022-6h-v6", + ] + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_missing_dates_fill(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_missing_dates_closest(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_missing_dates_interpolate(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_grids(): + ds = open_dataset( + grids=[ + "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + ], + adjust="all", + ) + + for p in ds.components(): + print(p.origins()["2t"]) + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_cutout() -> None: + ds = open_dataset( + cutout=[ + "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + ], + adjust="all", + ) + + for p in ds.components(): + print(p) + print(p.origins()) + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_missing_date_error(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_interpolate_frequency(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_interpolate_nearest(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_join(): + pass + + +@skip_if_offline +@zarr_tests +def test_class_thinning(): + ds = open_dataset( + "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + thinning=100, + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_cropping(): + ds = open_dataset( + "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + area=[80, -10, 30, 40], + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_trim_edge(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_merge(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_missing_dates(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_skip_missing_dates(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def 
test_class_missing_dataset(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_padded(): + pass + + +@skip_if_offline +@zarr_tests +def test_class_rescale_1(): + ds = open_dataset( + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + rescale={"2t": (1.0, -273.15)}, + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_rescale_2(): + try: + import cfunits # noqa: F401 + except FileNotFoundError: + # cfunits requires the library udunits2 to be installed + raise pytest.skip("udunits2 library not installed") + + ds = open_dataset( + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + rescale={"2t": ("K", "degC")}, + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_rescale_3(): + ds = open_dataset( + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + rescale={ + "2t": {"scale": 1.0, "offset": -273.15}, + }, + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_select_select_1(): + ds = open_dataset( + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + select=["msl", "2t"], + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_select_select_2(): + ds = open_dataset( + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + select={"msl", "2t"}, + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_select_drop(): + ds = open_dataset( + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + drop=["2t", "msl"], + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_rename() -> None: + ds = open_dataset( + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + rename={"2t": "temperature", "msl": "pressure"}, + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_rename_with_overlap() -> None: + ds = open_dataset( + [ + { + "dataset": "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "select": ["cp", "tp"], + "end": 2023, + "frequency": "6h", + }, + { + "dataset": "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v1-precipitations", + "end": 2023, + "frequency": "6h", + "rename": {"tp_0h_12h": "tp"}, + "select": ["tp_0h_12h"], + }, + ], + end=2022, + ) + + for p in ds.components(): + print(p) + print(p.origins()) + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_statistics(): + pass + + +@skip_if_offline +@zarr_tests +def test_class_zarr(): + ds = open_dataset("aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6") + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_zarr_with_missing_dates(): + ds = open_dataset("rodeo-opera-files-o96-2013-2023-6h-v5") + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +def test_class_subset(): + ds = open_dataset( + "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + frequency="12h", + start=2017, + end=2018, + ) + _test_dataset(ds) + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_chain(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_zipbase(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_zip(): + pass + + +@skip_if_offline +@zarr_tests +@not_ready +def test_class_xy(): + pass + + +if __name__ == "__main__": + test_class_rescale_2() + exit(0) + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() diff --git a/tests/test_origins.py b/tests/test_origins.py deleted file mode 100644 index ef5a134fd..000000000 --- a/tests/test_origins.py +++ /dev/null @@ -1,93 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. 
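The three test_class_rescale_* tests above exercise the three accepted ways of specifying a rescale for a variable (values as in the tests; the unit form requires cfunits and the udunits2 library):

rescale_tuple = {"2t": (1.0, -273.15)}                      # (scale, offset)
rescale_units = {"2t": ("K", "degC")}                       # converted via cfunits
rescale_dict  = {"2t": {"scale": 1.0, "offset": -273.15}}   # explicit mapping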
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -from collections.abc import Callable -from functools import wraps -from unittest.mock import patch - -from anemoi.utils.testing import skip_if_offline - -from anemoi.datasets import open_dataset - - -def _tests_zarrs(name: str) -> str: - return f"https://anemoi-test.ecmwf.int/test-zarrs/{name}.zarr" - - -def zarr_tests(func: Callable) -> Callable: - """Decorator to mock the zarr_lookup function. - - Parameters - ---------- - func : Callable - Function to wrap. - - Returns - ------- - Callable - Wrapped function. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - with patch("anemoi.datasets.data.stores.zarr_lookup", _tests_zarrs): - return func(*args, **kwargs) - - return wrapper - - -@skip_if_offline -@zarr_tests -def test_origins_rename() -> None: - ds = open_dataset( - [ - { - "dataset": "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", - "select": ["cp", "tp"], - "end": 2023, - "frequency": "6h", - }, - { - "dataset": "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v1-precipitations", - "end": 2023, - "frequency": "6h", - "rename": {"tp_0h_12h": "tp"}, - "select": ["tp_0h_12h"], - }, - ], - end=2022, - ) - - for p in ds.components(): - print(p) - print(p.origins()) - - -@skip_if_offline -@zarr_tests -def test_origins_cutout() -> None: - ds = open_dataset( - cutout=[ - "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", - ], - adjust="all", - ) - - for p in ds.components(): - print(p) - print(p.origins()) - - -if __name__ == "__main__": - for name, obj in list(globals().items()): - if name.startswith("test_") and callable(obj): - print(f"Running {name}...") - obj() From 4f5acbbcf4d8c84964e6ca8cad5fb63fdd18f46c Mon Sep 17 00:00:00 2001 From: Ewan Pinnington Date: Tue, 9 Sep 2025 09:51:00 +0000 Subject: [PATCH 119/212] adding bufr examples to obs-data building --- tests/create/bufr2df.py | 122 +++++++ tests/create/bufr2df_parallel.py | 314 ++++++++++++++++++ tests/create/odb2df.py | 10 +- tests/create/test_observations.py | 2 +- tests/create/test_observations_mars.py | 8 +- tests/create/test_observations_mars_bufr.py | 127 +++++++ .../test_observations_mars_bufr_complex.py | 147 ++++++++ .../test_observations_mars_bufr_parallel.py | 128 +++++++ 8 files changed, 848 insertions(+), 10 deletions(-) create mode 100644 tests/create/bufr2df.py create mode 100644 tests/create/bufr2df_parallel.py create mode 100644 tests/create/test_observations_mars_bufr.py create mode 100644 tests/create/test_observations_mars_bufr_complex.py create mode 100644 tests/create/test_observations_mars_bufr_parallel.py diff --git a/tests/create/bufr2df.py b/tests/create/bufr2df.py new file mode 100644 index 000000000..892141041 --- /dev/null +++ b/tests/create/bufr2df.py @@ -0,0 +1,122 @@ +import eccodes +import numpy as np +import pandas as pd +import tqdm +from earthkit.data.readers.bufr.bufr import BUFRReader +from gribapi.errors import KeyValueNotFoundError + + +def filter_values(df: pd.DataFrame, filters: dict) -> pd.DataFrame: + """Filter the DataFrame based on the specified conditions""" + for col, condition in filters.items(): + if isinstance(condition, str): + condition = eval(condition) + if 
callable(condition): + df = df[df[col].apply(condition)] + elif isinstance(condition, slice): + start, stop = condition.start, condition.stop + query_str = f"({start} <= {col}) & ({col} < {stop})" + df = df.query(query_str) + elif isinstance(condition, (list, set)): + df = df[df[col].isin(condition)] + else: + raise ValueError(f"Invalid condition for column '{col}': {condition}") + return df + + +def bufr_get_array(bid: int, element: str, typ: type, nsubsets: int, missing_val=np.nan) -> np.ndarray: + """Wrapper for codes_get_array to work around the inconsistent handling of arrays in eccodes when data is constant""" + try: + arr = eccodes.codes_get_array(bid, element, typ) + if len(arr) == 1: + arr = np.ones(nsubsets, dtype=typ) * arr + except KeyValueNotFoundError: + arr = np.ones(nsubsets, dtype=typ) * missing_val + return arr + + +def extract_datetimes(bid: int, nreports: int) -> pd.DataFrame: + """Extracts and parses the date/time info from a bufr message + and returns as an array of datetime objects + """ + df = pd.DataFrame( + dict( + years=bufr_get_array(bid, "year", int, nreports), + months=bufr_get_array(bid, "month", int, nreports), + days=bufr_get_array(bid, "day", int, nreports), + hours=bufr_get_array(bid, "hour", int, nreports), + minutes=bufr_get_array(bid, "minute", int, nreports), + seconds=bufr_get_array(bid, "second", int, nreports, missing_val=0), + ) + ) + # Create the datetime series using pandas + datetimes = pd.to_datetime(df) + return datetimes + + +def get_msg(f, i, per_report_dict, per_datum_dict=None, filters=None) -> pd.DataFrame: + bid = eccodes.codes_bufr_new_from_file(f) + eccodes.codes_set(bid, "unpack", 1) + nreports = eccodes.codes_get(bid, "numberOfSubsets") + + data_dict = { + item: bufr_get_array(bid, col, float, nreports).astype(np.float32) for col, item in per_report_dict.items() + } + data_dict["times"] = extract_datetimes(bid, nreports) + + if per_datum_dict: + for col, sub_dict in per_datum_dict.items(): + ndatum = eccodes.codes_get_size(bid, next(iter(per_datum_dict))) // nreports + vals = bufr_get_array(bid, col, float, nreports * ndatum).astype(np.float32) + try: + vals_2d = vals.reshape(ndatum, nreports).T + except ValueError as e: + if "cannot reshape array" in str(e): + import warnings + + warnings.warn( + f"Reshape error in file {f}, message {i}: Cannot reshape array of size {len(vals)} " + f"into shape ({ndatum}, {nreports}). 
Skipping this message.", + RuntimeWarning, + ) + eccodes.codes_release(bid) + return None + else: + raise # Re-raise if it's a different ValueError + + for col_rename, slice_str in sub_dict.items(): + vals_col = vals_2d[:, eval(slice_str)] + for i in range(vals_col.shape[1]): + data_dict[f"{col_rename}_{i+1}"] = vals_col[:, i] + + df = pd.DataFrame(data_dict) + + if filters: + df = filter_values(df, filters) + + eccodes.codes_release(bid) + return df + + +def bufr2df( + ekd_ds: BUFRReader, + per_report: dict, + per_datum: dict = None, + filter: dict = None, +) -> pd.DataFrame: + """Extracts data from a BUFR file into a pandas DataFrame + -info on what to extract (and how it should be named in the dataframe) are + provided by input dictionaries; one at the per-report level and another for the per-datum + """ + fname = ekd_ds.path + with open(fname, "rb") as f: + nmessages = eccodes.codes_count_in_file(f) + bar = tqdm.tqdm( + iterable=range(nmessages), + desc="Processing bufr messages...", + mininterval=20.0, + ) + df_lst = [get_msg(f, i, per_report, per_datum, filter) for i in bar] + df = pd.concat(df_lst) + df = df.sort_values(by=["times"]).reset_index(drop=True) + return df diff --git a/tests/create/bufr2df_parallel.py b/tests/create/bufr2df_parallel.py new file mode 100644 index 000000000..04dd20e7a --- /dev/null +++ b/tests/create/bufr2df_parallel.py @@ -0,0 +1,314 @@ +import logging +import mmap +import os +from multiprocessing import Pool + +import eccodes +import numpy as np +import pandas as pd +from earthkit.data.readers.bufr.bufr import BUFRReader +from gribapi.errors import KeyValueNotFoundError + + +def filter_values(df: pd.DataFrame, filters: dict) -> pd.DataFrame: + """Filter the DataFrame based on the specified conditions""" + for col, condition in filters.items(): + if isinstance(condition, str): + condition = eval(condition) + if callable(condition): + df = df[df[col].apply(condition)] + elif isinstance(condition, slice): + start, stop = condition.start, condition.stop + query_str = f"({start} <= {col}) & ({col} < {stop})" + df = df.query(query_str) + elif isinstance(condition, (list, set)): + df = df[df[col].isin(condition)] + else: + raise ValueError(f"Invalid condition for column '{col}': {condition}") + return df + + +log = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(processName)s - %(levelname)s - %(message)s", + force=True, +) + + +def filter_bufr_message(bid: int, filter_config: dict) -> bool: + """Check if BUFR message meets filtering conditions specified in filter_config + Returns True if message should be kept, False if it should be filtered out + """ + namespace = {"inf": float("inf")} + + for key, condition in filter_config.items(): + try: + # Get the value from BUFR + value = eccodes.codes_get(bid, key) + + if isinstance(condition, str) and condition.startswith("lambda"): + # Lambda expression case + filter_condition = eval(condition, namespace) + if not filter_condition(value): + return False + else: + + # Direct value comparison case + if value != condition: + return False + + except eccodes.KeyValueNotFoundError: + logging.warning(f"Key {key} not found in BUFR message") + return False + except Exception as e: + logging.error(f"Error evaluating condition for {key}: {e}") + return False + + return True + + +def bufr_get_array(bid: int, element: str, typ: type, nsubsets: int, missing_val=np.nan) -> np.ndarray: + """Wrapper for codes_get_array to work around the inconsistent handling of arrays in eccodes when 
data is constant""" + try: + arr = eccodes.codes_get_array(bid, element, typ) + if len(arr) == 1: + arr = np.ones(nsubsets, dtype=typ) * arr + except KeyValueNotFoundError: + arr = np.ones(nsubsets, dtype=typ) * missing_val + return arr + + +def extract_datetimes(bid: int, nreports: int, position_prefix: str = "") -> pd.DataFrame: + """Extracts and parses the date/time info from a bufr message + and returns as an array of datetime objects + """ + df = pd.DataFrame( + dict( + years=bufr_get_array(bid, position_prefix + "year", int, nreports), + months=bufr_get_array(bid, position_prefix + "month", int, nreports), + days=bufr_get_array(bid, position_prefix + "day", int, nreports), + hours=bufr_get_array(bid, position_prefix + "hour", int, nreports), + minutes=bufr_get_array(bid, position_prefix + "minute", int, nreports), + seconds=bufr_get_array(bid, position_prefix + "second", int, nreports, missing_val=0), + ) + ) + # Create the datetime series using pandas + datetimes = pd.to_datetime(df) + return datetimes + + +def get_msg( + bufr_msg, + per_report: dict, + prefilter_msg_header: dict = {}, + prefilter_msg_data: dict = {}, + datetime_position_prefix: str = "", + per_datum: dict = None, + filters: dict = None, +) -> pd.DataFrame: + try: + bid = eccodes.codes_new_from_message(bufr_msg) + nreports = eccodes.codes_get(bid, "numberOfSubsets") + eccodes.codes_set(bid, "skipExtraKeyAttributes", 1) + + # Optionally filter messages based on header section entries + if prefilter_msg_header and not filter_bufr_message(bid, prefilter_msg_header): + eccodes.codes_release(bid) + return pd.DataFrame() + + eccodes.codes_set(bid, "unpack", 1) + + # Optionally filter messages based on data section entries + if prefilter_msg_data and not filter_bufr_message(bid, prefilter_msg_data): + eccodes.codes_release(bid) + return pd.DataFrame() + + data_dict = { + item: bufr_get_array(bid, col, float, nreports).astype(np.float32) for col, item in per_report.items() + } + + data_dict["times"] = extract_datetimes(bid, nreports, datetime_position_prefix) + + if per_datum: + for col, sub_dict in per_datum.items(): + ndatum = eccodes.codes_get_size(bid, next(iter(per_datum))) // nreports + vals = bufr_get_array(bid, col, float, nreports * ndatum).astype(np.float32) + try: + vals_2d = vals.reshape(ndatum, nreports).T + except ValueError as e: + if "cannot reshape array" in str(e): + import warnings + + warnings.warn( + f"Reshape error in bufr message {bufr_msg}: Cannot reshape array of size {len(vals)} " + f"into shape ({ndatum}, {nreports}). Skipping this message.", + RuntimeWarning, + ) + eccodes.codes_release(bid) + return None + else: + raise # Re-raise if it's a different ValueError + + for col_rename, slice_str in sub_dict.items(): + vals_col = vals_2d[:, eval(slice_str)] + for k in range(vals_col.shape[1]): + data_dict[f"{col_rename}_{k+1}"] = vals_col[:, k] + + df = pd.DataFrame(data_dict) + + if filters: + df = filter_values(df, filters) + + eccodes.codes_release(bid) + return df + except Exception as e: + import warnings + + warnings.warn( + f"Unexpected error in message: {str(e)}. 
Skipping this message.", + RuntimeWarning, + ) + if "bid" in locals(): + eccodes.codes_release(bid) + return None + + +class BufrData(object): + def __init__(self, BufrFileName): + self._filename = BufrFileName + self._fobj = open(self._filename, "rb") + self._fileno = self._fobj.fileno() + self._nmsg = eccodes.codes_count_in_file(self._fobj) + self._dataBlock = self.get_datablock() + self._lstOffsets = self.get_list_offsets() + + @property + def dataBlock(self): + return self._dataBlock + + @property + def nmsg(self): + return self._nmsg + + @property + def lstOffsets(self): + return self._lstOffsets + + def get_datablock(self): + with mmap.mmap(self._fileno, length=0, access=mmap.ACCESS_READ) as mobj: + data = mobj.read() + return data + + def get_list_offsets(self): + lstOffsets = [] + for _ in range(0, self._nmsg): + bid = eccodes.codes_bufr_new_from_file(self._fobj) + offset = eccodes.codes_get_message_offset(bid) + size = eccodes.codes_get_message_size(bid) + lstOffsets.append((offset, size)) + eccodes.codes_release(bid) + return lstOffsets + + def __del__(self): + self._fobj.close() + + +def read_block( + sublist, + dataBlock, + per_report: dict, + prefilter_msg_header: dict = None, + prefilter_msg_data: dict = None, + datetime_position_prefix: str = "", + per_datum: dict = None, + filters: dict = None, +): + log.info(f"PID : {os.getpid()} in read block sublist has {len(sublist)} elements") + try: + df_lst = [ + get_msg( + dataBlock[offset : offset + ch_size], + per_report, + prefilter_msg_header, + prefilter_msg_data, + datetime_position_prefix, + per_datum, + filters, + ) + for offset, ch_size in sublist + ] + return pd.concat(df_lst) + except Exception as e: + log.error(f"Error in read_block: {str(e)}") + raise + + +def split_list(alist, nparts): + nelem = len(alist) + chunkSize = nelem // (nparts) + sublists = [] + for i in range(0, nelem, chunkSize): + slist = alist[i : i + chunkSize] + sublists.append(slist) + return sublists + + +def bufr2df_parallel( + ekd_ds: BUFRReader, + per_report: dict, + nproc: int = 1, + prefilter_msg_header: dict = None, + prefilter_msg_data: dict = None, + datetime_position_prefix: str = "", + per_datum: dict = None, + filters: dict = None, +) -> pd.DataFrame: + fname = ekd_ds.path + mbfo = BufrData(fname) + fullDataBlock = mbfo.dataBlock + log.info(f"number of messages {mbfo.nmsg}") + sublists = split_list(mbfo.lstOffsets, nproc) + + nSubLists = len(sublists) + + pool = Pool(processes=nproc) + try: + results = [ + pool.apply_async( + read_block, + args=( + sublists[i], + fullDataBlock, + per_report, + prefilter_msg_header, + prefilter_msg_data, + datetime_position_prefix, + per_datum, + filters, + ), + ) + for i in range(0, nSubLists) + ] + all_lst = [] + for r in results: + try: + df = r.get() + all_lst.append(df) + except Exception as e: + log.error(f"Error getting result from worker process: {str(e)}") + continue + if not all_lst: + raise ValueError("No valid results were returned from any worker process") + finally: + pool.close() # Stop accepting new tasks + pool.join() # Wait for workers to finish with timeout + pool.terminate() # Force terminate if still running + + df = pd.concat(all_lst) + if len(df) > 0: + df = df.sort_values(by=["times"]).reset_index(drop=True) + + log.info(f"Number of rows in the dataframe {len(df)}") + + return df diff --git a/tests/create/odb2df.py b/tests/create/odb2df.py index 5ada55252..9ad31f1df 100644 --- a/tests/create/odb2df.py +++ b/tests/create/odb2df.py @@ -16,7 +16,7 @@ def load_varno_dict(path: 
Optional[str] = None) -> Dict: try: with open(path or "varno.json") as f: return json.load(f) - except: + except (ValueError, Exception): return {"data": []} @@ -27,7 +27,7 @@ def get_varno_name(varno: Union[int, str], varno_dict: Dict) -> str: for entry in varno_dict.get("data", []): if v in entry: return str(entry[0]) - except: + except (ValueError, Exception): pass return str(varno) @@ -75,7 +75,7 @@ def process_odb( try: df = reader.to_pandas() - except Exception as e: + except (ValueError, Exception) as e: logging.error(f"ODB conversion failed: {e}") return pd.DataFrame() @@ -108,14 +108,14 @@ def process_odb( format="%Y%m%d%H%M%S", ) df = df.drop(columns=[date_col, time_col], level=0) - except: + except (ValueError, Exception): logging.warning("Could not create datetime column") # Rename columns df.columns = rename_cols(df.columns.tolist(), extra_obs, varno_path) # Rename lat/lon columns to match expected format - df = df.rename(columns={"lat": "latitudes", "lon": "longitudes"}) + df = df.rename(columns={"lat": "latitudes", "lon": "longitudes"}).sort_values(by="times") return df diff --git a/tests/create/test_observations.py b/tests/create/test_observations.py index 2166827b9..6bdf577ca 100644 --- a/tests/create/test_observations.py +++ b/tests/create/test_observations.py @@ -45,7 +45,7 @@ def __call__(self, df): """Filter the data based on the given window.""" self._check(df) # Here we can add any filtering logic if needed - df["a1"] = df["a1"] + 0.42 + df.loc[:, "a1"] = df["a1"] + 0.42 return self._check(df) diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py index c98340f48..c5823476b 100644 --- a/tests/create/test_observations_mars.py +++ b/tests/create/test_observations_mars.py @@ -44,7 +44,7 @@ def __call__(self, window): return self._check(df) -class MarsSource(ObservationsSource): +class MarsObsSource(ObservationsSource): def __init__(self, request_dict, pre_process_dict, process_func): assert isinstance(request_dict, dict), "request_dict must be a dictionary" self.request_dict = request_dict @@ -83,7 +83,7 @@ def __call__(self, window): return self._check(df) -class DummyFilter(ObservationsFilter): +class ColFilter(ObservationsFilter): def __init__(self, col_name): self.col_name = col_name @@ -97,7 +97,7 @@ def __call__(self, df): dates = [datetime.datetime(2025, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] -source = MarsSource( +source = MarsObsSource( request_dict={ "class": "ea", "expver": "0001", @@ -117,7 +117,7 @@ def __call__(self, df): }, process_func=process_odb, ) -filter = DummyFilter("obsvalue_v10m_0") +filter = ColFilter("obsvalue_v10m_0") for d in dates: window = window_from_str("(-5h, 1h]").to_absolute_window(d) diff --git a/tests/create/test_observations_mars_bufr.py b/tests/create/test_observations_mars_bufr.py new file mode 100644 index 000000000..a22c7c3b9 --- /dev/null +++ b/tests/create/test_observations_mars_bufr.py @@ -0,0 +1,127 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
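+
+# Overview: this script exercises the serial BUFR observation path. It
+# retrieves NEXRAD radar rainfall-rate BUFR messages from MARS via
+# earthkit-data, decodes them into a pandas DataFrame with bufr2df, selects a
+# "(-5h, 1h]" window around each reference date with MarsObsSource, and offsets
+# the "obsvalue_precip1h_0" column with ColFilter before printing the result.
+# Running it assumes access to MARS and to the local bufr2df helper module.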
+ +import datetime +import logging + +import pandas as pd +from bufr2df import bufr2df +from earthkit.data import from_source + +from anemoi.datasets.create.sources.observations import ObservationsFilter +from anemoi.datasets.create.sources.observations import ObservationsSource +from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import window_from_str + +log = logging.getLogger(__name__) + + +class DummpySource(ObservationsSource): + def __init__(self, data): + assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" + self.data = data + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + if window.include_start: + mask = self.data["times"] > window.start + else: + mask = self.data["times"] >= window.start + if window.include_end: + mask &= self.data["times"] <= window.end + else: + mask &= self.data["times"] < window.end + + df = self.data[mask] + + return self._check(df) + + +class MarsObsSource(ObservationsSource): + def __init__(self, request_dict, pre_process_dict, process_func): + assert isinstance(request_dict, dict), "request_dict must be a dictionary" + self.request_dict = request_dict + self.pre_process_dict = pre_process_dict + self.process_func = process_func + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + request_dict = self.request_dict + request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" + try: + ekd_ds = from_source("mars", request_dict) + except Exception as e: + if "File is empty" in str(e): + log.warning( + f"Empty file for period {window.start.strftime('%Y%m%d')} to {window.end.strftime('%Y%m%d')}. Skipping." 
+ ) + return + else: + raise # Re-raise if it's a different error + + data = self.process_func(ekd_ds, **self.pre_process_dict) + + if window.include_start: + mask = data["times"] > window.start + else: + mask = data["times"] >= window.start + if window.include_end: + mask &= data["times"] <= window.end + else: + mask &= data["times"] < window.end + + df = data[mask] + + return self._check(df) + + +class ColFilter(ObservationsFilter): + def __init__(self, col_name): + self.col_name = col_name + + def __call__(self, df): + """Filter the data based on the given window.""" + self._check(df) + # Here we can add any filtering logic if needed + df.loc[:, self.col_name] = df[self.col_name] + 0.42 + return self._check(df) + + +dates = [datetime.datetime(2025, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] + +source = MarsObsSource( + request_dict={ + "class": "od", + "expver": "0001", + "stream": "LWDA", + "type": "ai", + "obstype": "nexrad_rr", + "times": "00/06/12/18", + }, + pre_process_dict={ + # "target": odb2df.process_odb, + "per_report": { + "latitude": "latitudes", + "longitude": "longitudes", + "radarRainfallIntensity": "obsvalue_precip1h_0", + }, + }, + process_func=bufr2df, +) +filter = ColFilter("obsvalue_precip1h_0") + +for d in dates: + window = window_from_str("(-5h, 1h]").to_absolute_window(d) + print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) + d = source(window) + d = filter(d) + print(window) + print(d) diff --git a/tests/create/test_observations_mars_bufr_complex.py b/tests/create/test_observations_mars_bufr_complex.py new file mode 100644 index 000000000..ddb8afbac --- /dev/null +++ b/tests/create/test_observations_mars_bufr_complex.py @@ -0,0 +1,147 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
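+
+# Overview: a more involved variant of the BUFR test that uses the parallel
+# decoder. It retrieves SSMIS BUFR data from MARS for October 2015, decodes it
+# with bufr2df_parallel across 12 worker processes, keeps only messages whose
+# header reports satelliteID 286, maps the rank-prefixed "#9#" to "#18#"
+# brightnessTemperature keys to obsvalue_rawbt_* columns, and filters out rows
+# with non-finite latitudes or longitudes. Running it assumes MARS access and
+# a machine with enough cores for nproc=12.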
+ +import datetime +import logging + +import pandas as pd +from bufr2df_parallel import bufr2df_parallel +from earthkit.data import from_source + +from anemoi.datasets.create.sources.observations import ObservationsFilter +from anemoi.datasets.create.sources.observations import ObservationsSource +from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import window_from_str + +log = logging.getLogger(__name__) + + +class DummpySource(ObservationsSource): + def __init__(self, data): + assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" + self.data = data + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + if window.include_start: + mask = self.data["times"] > window.start + else: + mask = self.data["times"] >= window.start + if window.include_end: + mask &= self.data["times"] <= window.end + else: + mask &= self.data["times"] < window.end + + df = self.data[mask] + + return self._check(df) + + +class MarsObsSource(ObservationsSource): + def __init__(self, request_dict, pre_process_dict, process_func): + assert isinstance(request_dict, dict), "request_dict must be a dictionary" + self.request_dict = request_dict + self.pre_process_dict = pre_process_dict + self.process_func = process_func + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + request_dict = self.request_dict + request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" + try: + ekd_ds = from_source("mars", request_dict) + except Exception as e: + if "File is empty" in str(e): + log.warning( + f"Empty file for period {window.start.strftime('%Y%m%d')} to {window.end.strftime('%Y%m%d')}. Skipping." 
+ ) + return + else: + raise # Re-raise if it's a different error + + data = self.process_func(ekd_ds, **self.pre_process_dict) + + if window.include_start: + mask = data["times"] > window.start + else: + mask = data["times"] >= window.start + if window.include_end: + mask &= data["times"] <= window.end + else: + mask &= data["times"] < window.end + + df = data[mask] + + return self._check(df) + + +class ColFilter(ObservationsFilter): + def __init__(self, col_name): + self.col_name = col_name + + def __call__(self, df): + """Filter the data based on the given window.""" + self._check(df) + # Here we can add any filtering logic if needed + df.loc[:, self.col_name] = df[self.col_name] + 0.42 + return self._check(df) + + +dates = [datetime.datetime(2015, 10, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] + +source = MarsObsSource( + request_dict={ + "class": "od", + "expver": "0001", + "stream": "DCDA/LWDA", + "type": "ai", + "obstype": "ssmis", + "times": "00/06/12/18", + }, + pre_process_dict={ + # "target": odb2df.process_odb, + "nproc": 12, + "prefilter_msg_header": {"satelliteID": 286.0}, + "datetime_position_prefix": "#1#", + "per_report": { + "satelliteID": "satelliteID", + "#1#latitude": "latitudes", + "#1#longitude": "longitudes", + # bearingOrAzimuth: azimuth + "fieldOfViewNumber": "fov_num", + "#9#brightnessTemperature": "obsvalue_rawbt_9", + "#10#brightnessTemperature": "obsvalue_rawbt_10", + "#11#brightnessTemperature": "obsvalue_rawbt_11", + "#12#brightnessTemperature": "obsvalue_rawbt_12", + "#13#brightnessTemperature": "obsvalue_rawbt_13", + "#14#brightnessTemperature": "obsvalue_rawbt_14", + "#15#brightnessTemperature": "obsvalue_rawbt_15", + "#16#brightnessTemperature": "obsvalue_rawbt_16", + "#17#brightnessTemperature": "obsvalue_rawbt_17", + "#18#brightnessTemperature": "obsvalue_rawbt_18", + }, + "filters": { + "longitudes": "lambda x: np.isfinite(x)", + "latitudes": "lambda x: np.isfinite(x)", + }, + }, + process_func=bufr2df_parallel, +) +filter = ColFilter("obsvalue_rawbt_9") + +for d in dates: + window = window_from_str("(-5h, 1h]").to_absolute_window(d) + print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) + d = source(window) + d = filter(d) + print(window) + print(d) + print(d["satelliteID"].unique()) diff --git a/tests/create/test_observations_mars_bufr_parallel.py b/tests/create/test_observations_mars_bufr_parallel.py new file mode 100644 index 000000000..d743efa8e --- /dev/null +++ b/tests/create/test_observations_mars_bufr_parallel.py @@ -0,0 +1,128 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
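+
+# Overview: same NEXRAD rainfall-rate retrieval as test_observations_mars_bufr,
+# but decoded with bufr2df_parallel (nproc=12) instead of the serial bufr2df,
+# so the BUFR message offsets are split across worker processes before the
+# per-window filtering is applied. Running it assumes MARS access.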
+ +import datetime +import logging + +import pandas as pd +from bufr2df_parallel import bufr2df_parallel +from earthkit.data import from_source + +from anemoi.datasets.create.sources.observations import ObservationsFilter +from anemoi.datasets.create.sources.observations import ObservationsSource +from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import window_from_str + +log = logging.getLogger(__name__) + + +class DummpySource(ObservationsSource): + def __init__(self, data): + assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" + self.data = data + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + if window.include_start: + mask = self.data["times"] > window.start + else: + mask = self.data["times"] >= window.start + if window.include_end: + mask &= self.data["times"] <= window.end + else: + mask &= self.data["times"] < window.end + + df = self.data[mask] + + return self._check(df) + + +class MarsObsSource(ObservationsSource): + def __init__(self, request_dict, pre_process_dict, process_func): + assert isinstance(request_dict, dict), "request_dict must be a dictionary" + self.request_dict = request_dict + self.pre_process_dict = pre_process_dict + self.process_func = process_func + + def __call__(self, window): + assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + + request_dict = self.request_dict + request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" + try: + ekd_ds = from_source("mars", request_dict) + except Exception as e: + if "File is empty" in str(e): + log.warning( + f"Empty file for period {window.start.strftime('%Y%m%d')} to {window.end.strftime('%Y%m%d')}. Skipping." 
+ ) + return + else: + raise # Re-raise if it's a different error + + data = self.process_func(ekd_ds, **self.pre_process_dict) + + if window.include_start: + mask = data["times"] > window.start + else: + mask = data["times"] >= window.start + if window.include_end: + mask &= data["times"] <= window.end + else: + mask &= data["times"] < window.end + + df = data[mask] + + return self._check(df) + + +class ColFilter(ObservationsFilter): + def __init__(self, col_name): + self.col_name = col_name + + def __call__(self, df): + """Filter the data based on the given window.""" + self._check(df) + # Here we can add any filtering logic if needed + df.loc[:, self.col_name] = df[self.col_name] + 0.42 + return self._check(df) + + +dates = [datetime.datetime(2025, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] + +source = MarsObsSource( + request_dict={ + "class": "od", + "expver": "0001", + "stream": "LWDA", + "type": "ai", + "obstype": "nexrad_rr", + "times": "00/06/12/18", + }, + pre_process_dict={ + # "target": odb2df.process_odb, + "nproc": 12, + "per_report": { + "latitude": "latitudes", + "longitude": "longitudes", + "radarRainfallIntensity": "obsvalue_precip1h_0", + }, + }, + process_func=bufr2df_parallel, +) +filter = ColFilter("obsvalue_precip1h_0") + +for d in dates: + window = window_from_str("(-5h, 1h]").to_absolute_window(d) + print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) + d = source(window) + d = filter(d) + print(window) + print(d) From c9259c62e6a290ff719a56de2fd8b41d83e002c5 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 9 Sep 2025 10:07:15 +0000 Subject: [PATCH 120/212] fix skipped origins --- src/anemoi/datasets/create/input/action.py | 9 +++- .../datasets/create/input/context/field.py | 8 ++- src/anemoi/datasets/create/input/origin.py | 53 +++++++++++++++++-- tests/test_classes.py | 20 +++---- 4 files changed, 71 insertions(+), 19 deletions(-) diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 94269e209..2d64c047c 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -33,6 +33,9 @@ def __call__(self, context, argument): def python_code(self, code): pass + def __repr__(self): + return f"{self.__class__.__name__}({'.'.join(str(x) for x in self.path)}, {self.config})" + class Concat(Action): def __init__(self, config, *path): @@ -147,7 +150,8 @@ def create_object(self, context, config): return create_datasets_source(context, config) def call_object(self, context, source, argument): - return context.origin(source.execute(context.source_argument(argument)), self) + result = source.execute(context.source_argument(argument)) + return context.origin(result, self, argument) def origin(self): from .origin import Source @@ -178,7 +182,8 @@ def create_object(self, context, config): return create_transform_filter(context, config) def call_object(self, context, filter, argument): - return context.origin(filter.forward(context.filter_argument(argument)), self) + result = filter.forward(context.filter_argument(argument)) + return context.origin(result, self, argument) def origin(self): from .origin import Filter diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index b89d33dab..8503e618a 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. 
+import logging from typing import Any from anemoi.transform.fields import new_field_with_metadata @@ -17,6 +18,8 @@ from ..result.field import FieldResult from . import Context +LOG = logging.getLogger(__name__) + class FieldContext(Context): @@ -55,7 +58,7 @@ def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) - def origin(self, data: Any, action: Any) -> Any: + def origin(self, data: Any, action: Any, action_arguments: Any) -> Any: origin = action.origin() @@ -67,6 +70,7 @@ def origin(self, data: Any, action: Any) -> Any: # The field has pass unchanges in a filter result.append(fs) else: - result.append(new_field_with_metadata(fs, anemoi_origin=origin.combine(previous))) + anemoi_origin = origin.combine(previous, action, action_arguments) + result.append(new_field_with_metadata(fs, anemoi_origin=anemoi_origin)) return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py index 5b14cc77f..d60bc202d 100644 --- a/src/anemoi/datasets/create/input/origin.py +++ b/src/anemoi/datasets/create/input/origin.py @@ -42,11 +42,14 @@ def __init__(self, s1, s2, when="dataset-create"): super().__init__(when) self.steps = [s1, s2] + assert s1 is not None, (s1, s2) + assert s2 is not None, (s1, s2) + if isinstance(s1, Pipe): assert not isinstance(s2, Pipe), (s1, s2) self.steps = s1.steps + [s2] - def combine(self, previous): + def combine(self, previous, action, action_arguments): assert False, (self, previous) def as_dict(self): @@ -60,6 +63,28 @@ def __repr__(self): return " | ".join(repr(s) for s in self.steps) +class Join(Origin): + def __init__(self, origins, when="dataset-create"): + assert isinstance(origins, (list, tuple, set)), origins + super().__init__(when) + self.steps = list(origins) + + assert all(o is not None for o in origins), origins + + def combine(self, previous, action, action_arguments): + assert False, (self, previous) + + def as_dict(self): + return { + "type": "join", + "steps": [s.as_dict() for s in self.steps], + "when": self.when, + } + + def __repr__(self): + return " & ".join(repr(s) for s in self.steps) + + class Source(Origin): def __init__(self, name, config, when="dataset-create"): super().__init__(when) @@ -67,7 +92,7 @@ def __init__(self, name, config, when="dataset-create"): self.name = name self.config = _un_dotdict(config) - def combine(self, previous): + def combine(self, previous, action, action_arguments): assert previous is None, f"Cannot combine origins, previous already exists: {previous}" return self @@ -91,10 +116,32 @@ def __init__(self, name, config, when="dataset-create"): self.config = _un_dotdict(config) self._cache = {} - def combine(self, previous): + def combine(self, previous, action, action_arguments): + + if previous is None: + key = (id(action), id(action_arguments)) + if key not in self._cache: + + LOG.warning(f"No previous origin to combine with: {self}. 
Action: {action}") + LOG.warning(f"Connecting to action argumentsm {action_arguments}") + origins = set() + for k in action_arguments: + o = k.metadata("anemoi_origin", default=None) + if o is None: + raise ValueError( + f"Cannot combine origins, previous is None and action_arguments {action_arguments} has no origin" + ) + origins.add(o) + if len(origins) == 1: + self._cache[key] = origins.pop() + else: + self._cache[key] = Join(origins) + previous = self._cache[key] + if previous in self._cache: # We use a cache to avoid recomputing the same combination return self._cache[previous] + self._cache[previous] = Pipe(previous, self) return self._cache[previous] diff --git a/tests/test_classes.py b/tests/test_classes.py index e6cf36f50..2067dff46 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -121,9 +121,7 @@ def test_class_grids(): ], adjust="all", ) - - for p in ds.components(): - print(p.origins()["2t"]) + _test_dataset(ds) @skip_if_offline @@ -137,10 +135,7 @@ def test_class_cutout() -> None: ], adjust="all", ) - - for p in ds.components(): - print(p) - print(p.origins()) + _test_dataset(ds) @skip_if_offline @@ -195,7 +190,11 @@ def test_class_cropping(): @zarr_tests @not_ready def test_class_trim_edge(): - pass + ds = open_dataset( + "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + trim_edge=(1, 2, 3, 4), + ) + _test_dataset(ds) @skip_if_offline @@ -332,10 +331,7 @@ def test_class_rename_with_overlap() -> None: ], end=2022, ) - - for p in ds.components(): - print(p) - print(p.origins()) + _test_dataset(ds) @skip_if_offline From 8df9cc5f1b595c542bcd1d8421da3ce1a4273716 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 10 Sep 2025 05:50:48 +0000 Subject: [PATCH 121/212] work on origin --- .../datasets/commands/recipe/migrate.py | 2 +- src/anemoi/datasets/create/input/origin.py | 4 +++- .../datasets/create/input/result/field.py | 21 ++++++++++++++++--- src/anemoi/datasets/data/components.py | 6 +++++- src/anemoi/datasets/data/ensemble.py | 6 ++++++ src/anemoi/datasets/data/forwards.py | 3 +++ src/anemoi/datasets/data/masked.py | 10 ++------- src/anemoi/datasets/data/rescale.py | 9 ++------ src/anemoi/datasets/data/select.py | 3 --- tests/test_classes.py | 4 ++-- 10 files changed, 42 insertions(+), 26 deletions(-) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 03da61fbc..6a3c6301d 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -417,7 +417,7 @@ def _fix_some(config: dict) -> None: node = glom(config, ".".join(p[:-1])) node.update(node.pop("<<")) parent[node.pop("name")] = node - assert len(parent) == 2 + assert len(parent) == 2, parent del parent["source"] paths = find_paths_in_substrees("label.mars", config) diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py index d60bc202d..9f5173afc 100644 --- a/src/anemoi/datasets/create/input/origin.py +++ b/src/anemoi/datasets/create/input/origin.py @@ -119,11 +119,13 @@ def __init__(self, name, config, when="dataset-create"): def combine(self, previous, action, action_arguments): if previous is None: + # This can happen if the filter does not tag its output with an origin + # (e.g. a user plugin). In that case we try to get the origin from the action arguments key = (id(action), id(action_arguments)) if key not in self._cache: LOG.warning(f"No previous origin to combine with: {self}. 
Action: {action}") - LOG.warning(f"Connecting to action argumentsm {action_arguments}") + LOG.warning(f"Connecting to action arguments {action_arguments}") origins = set() for k in action_arguments: o = k.metadata("anemoi_origin", default=None) diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index a052fc628..7363ebf00 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -566,19 +566,34 @@ def build_coords(self) -> None: name_key = list(self.order_by.keys())[1] p = None - origins = defaultdict(set) + origins_per_number = defaultdict(lambda: defaultdict(set)) for fs in self.datasource: o = fs.metadata("anemoi_origin", remapping=self.remapping, patches=self.patches) name = fs.metadata(name_key, remapping=self.remapping, patches=self.patches) + number = fs.metadata("number", remapping=self.remapping, patches=self.patches) - assert name not in origins[o], (name,) - origins[o].add(name) + assert name not in origins_per_number[number][o], name + origins_per_number[number][o].add(name) if p is not o: LOG.info(f"🔥🔥🔥🔥🔥🔥 Source: {name}, {o}") p = o + origins_per_variables = defaultdict(lambda: defaultdict(set)) + for number, origins in origins_per_number.items(): + for origin, names in origins.items(): + for name in names: + origins_per_variables[name][origin].add(number) + + origins = defaultdict(set) + + # Check if all members of a variable have the same origins + for name, origin_number in origins_per_variables.items(): + # For now we do not support variables with members from different origins + assert len(origin_number) == 1, origin_number + origins[list(origin_number.keys())[0]].add(name) + self._origins = [] for k, v in origins.items(): self._origins.append({"origin": k.as_dict(), "variables": sorted(v)}) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index db5bca108..6fdb781c1 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -207,7 +207,11 @@ def origins(self): pipe = [] for transformation in self.transformations: - action, variable = transformation.origin_transformation(variable, origins) + action = transformation.origin_transformation(variable, origins) + if isinstance(action, tuple): + # Needed to support 'rename' + action, variable = action + action = action.copy() action.setdefault("when", "dataset-usage") action.setdefault("type", "filter") diff --git a/src/anemoi/datasets/data/ensemble.py b/src/anemoi/datasets/data/ensemble.py index 50725c2c1..4826fa81d 100644 --- a/src/anemoi/datasets/data/ensemble.py +++ b/src/anemoi/datasets/data/ensemble.py @@ -124,6 +124,12 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """Returns metadata specific to the Number object.""" return {} + def origin_transformation(self, variable, origins): + return { + "name": "number", + "config": {"members": self.members}, + } + class Ensemble(GivenAxis): """A class to represent an ensemble of datasets combined along a given axis.""" diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index 294aca063..b9f887774 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -240,6 +240,9 @@ def constant_fields(self) -> list[str]: """Returns the constant fields of the forward dataset.""" return self.forward.constant_fields + def project(self, projection): + return 
self.forward.project(projection).add_transformation(self) + class Combined(Forwards): """A class to combine multiple datasets into a single dataset.""" diff --git a/src/anemoi/datasets/data/masked.py b/src/anemoi/datasets/data/masked.py index 0d4435d57..32148d7b0 100644 --- a/src/anemoi/datasets/data/masked.py +++ b/src/anemoi/datasets/data/masked.py @@ -125,9 +125,6 @@ def collect_supporting_arrays(self, collected: list[tuple], *path: Any) -> None: super().collect_supporting_arrays(collected, *path) collected.append((path, self.mask_name, self.mask)) - def project(self, projection): - return self.forward.project(projection).add_transformation(self) - class Thinning(Masked): """A class to represent a thinned dataset.""" @@ -207,7 +204,7 @@ def origin_transformation(self, variable, origins): return { "name": "thinning", "config": dict(thinning=self.thinning, method=self.method), - }, variable + } class Cropping(Masked): @@ -260,10 +257,7 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: return dict(area=self.area) def origin_transformation(self, variable, origins): - return { - "name": "cropping", - "config": dict(area=self.area), - }, variable + return {"name": "cropping", "config": dict(area=self.area)} class TrimEdge(Masked): diff --git a/src/anemoi/datasets/data/rescale.py b/src/anemoi/datasets/data/rescale.py index df19e31e8..b6071f3c7 100644 --- a/src/anemoi/datasets/data/rescale.py +++ b/src/anemoi/datasets/data/rescale.py @@ -243,14 +243,9 @@ def statistics_tendencies(self, delta: datetime.timedelta | None = None) -> dict return result - def project(self, projection): - return self.forward.project(projection).add_transformation(self) - def origin_transformation(self, variable, origins): config = {} for variable, (a, b) in self.rescale.items(): config[variable] = {"scale": a, "offset": b} - return { - "name": "rescale", - "config": config, - }, variable + + return {"name": "rescale", "config": config} diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index a93cf983e..b422899ee 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -289,9 +289,6 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return dict(rename=self.rename) - def project(self, projection): - return self.forward.project(projection).add_transformation(self) - def origin_transformation(self, variable, origins): return { "name": "rename", diff --git a/tests/test_classes.py b/tests/test_classes.py index 2067dff46..b98649bdc 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -71,7 +71,7 @@ def test_class_concat(): @zarr_tests def test_class_number(): ds = open_dataset( - "aifs-ea-an-enda-0001-mars-o96-1979-2022-6h-v6", + "aifs-ea-an-enda-0001-mars-20p0-2020-2020-24h-v6", number=[1, 5, 6], ) _test_dataset(ds) @@ -396,7 +396,7 @@ def test_class_xy(): if __name__ == "__main__": - test_class_rescale_2() + test_class_number() exit(0) for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj): From 4c06588dc1c2fc0aca64ce38e11f986af5f222d6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 13 Sep 2025 16:43:58 +0000 Subject: [PATCH 122/212] python recipes --- src/anemoi/datasets/create/config.py | 7 + .../datasets/create/input/context/__init__.py | 2 +- src/anemoi/datasets/data/complement.py | 8 +- src/anemoi/datasets/data/concat.py | 9 + src/anemoi/datasets/data/forwards.py | 32 ++- src/anemoi/datasets/grids.py | 6 +- src/anemoi/datasets/recipe.py | 17 +- 
tests/test_classes.py | 209 ++++++++++++++---- 8 files changed, 225 insertions(+), 65 deletions(-) diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index ffee2662f..2e5f27de7 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -10,6 +10,8 @@ import datetime import logging import os +import subprocess +import sys from copy import deepcopy from typing import Any @@ -402,6 +404,11 @@ def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: LoadersConfig The validated configuration object. """ + + if isinstance(config, str) and config.endswith(".py"): + result = subprocess.run([sys.executable, config], capture_output=True, text=True, check=True) + config = yaml.safe_load(result.stdout) + config = Config(config) if is_test: set_to_test_mode(config) diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index eef61504c..578ddaf66 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -34,7 +34,7 @@ def register(self, data: Any, path: list[str]) -> Any: assert path[0] in ("input", "data_sources"), path - print(f"Registering data at path: {path}") + LOG.info(f"Registering data at path: {'.'.join(str(x) for x in path)}") self.results[tuple(path)] = data return data diff --git a/src/anemoi/datasets/data/complement.py b/src/anemoi/datasets/data/complement.py index be5f84409..7f6f4484e 100644 --- a/src/anemoi/datasets/data/complement.py +++ b/src/anemoi/datasets/data/complement.py @@ -249,7 +249,13 @@ def __init__(self, target: Any, source: Any, max_distance: float = None, k: int """ super().__init__(target, source) + if isinstance(k, str): + assert False + LOG.warning(f"ComplementNearest: Interpreting k={k} ({type(k)}) as integer") + k = int(k) + self.k = k + self._distances, self._nearest_grid_points = nearest_grid_points( self._source.latitudes, self._source.longitudes, @@ -353,7 +359,7 @@ def complement_factory(args: tuple, kwargs: dict) -> Dataset: }[interpolation] if interpolation == "nearest": - k = kwargs.pop("k", "1") + k = kwargs.pop("k", 1) complement = Class(target=target, source=source, k=k)._subset(**kwargs) else: diff --git a/src/anemoi/datasets/data/concat.py b/src/anemoi/datasets/data/concat.py index 234001c8c..4398c15eb 100644 --- a/src/anemoi/datasets/data/concat.py +++ b/src/anemoi/datasets/data/concat.py @@ -255,6 +255,15 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return {} + def project(self, projection): + result = [] + + for dataset in self.datasets: + for p in projection.ensure_list(): + result.append(dataset.project(p)) + + return projection.list_or_single(result) + def concat_factory(args: tuple[Any, ...], kwargs: dict) -> Concat: """Factory function to create a Concat object. 
diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index b9f887774..decadabdd 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -294,7 +294,9 @@ def check_same_resolution(self, d1: Dataset, d2: Dataset) -> None: return if d1.resolution != d2.resolution: - raise ValueError(f"Incompatible resolutions: {d1.resolution} and {d2.resolution} ({d1} {d2})") + raise ValueError( + f"{self.__class__.__name__}: Incompatible resolutions: {d1.resolution} and {d2.resolution} ({d1} {d2})" + ) def check_same_frequency(self, d1: Dataset, d2: Dataset) -> None: """Checks if the frequencies of two datasets are the same. @@ -312,7 +314,9 @@ def check_same_frequency(self, d1: Dataset, d2: Dataset) -> None: If the frequencies are not the same. """ if d1.frequency != d2.frequency: - raise ValueError(f"Incompatible frequencies: {d1.frequency} and {d2.frequency} ({d1} {d2})") + raise ValueError( + f"{self.__class__.__name__}: Incompatible frequencies: {d1.frequency} and {d2.frequency} ({d1} {d2})" + ) def check_same_grid(self, d1: Dataset, d2: Dataset) -> None: """Checks if the grids of two datasets are the same. @@ -336,7 +340,7 @@ def check_same_grid(self, d1: Dataset, d2: Dataset) -> None: if not np.allclose(d1.latitudes, d2.latitudes, rtol=rtol) or not np.allclose( d1.longitudes, d2.longitudes, rtol=rtol ): - raise ValueError(f"Incompatible grid ({d1.longitudes} {d2.longitudes})") + raise ValueError(f"{self.__class__.__name__}: Incompatible grid ({d1.longitudes} {d2.longitudes})") def check_same_shape(self, d1: Dataset, d2: Dataset) -> None: """Checks if the shapes of two datasets are the same. @@ -354,10 +358,12 @@ def check_same_shape(self, d1: Dataset, d2: Dataset) -> None: If the shapes are not the same. """ if d1.shape[1:] != d2.shape[1:]: - raise ValueError(f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})") + raise ValueError(f"{self.__class__.__name__}: Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})") if d1.variables != d2.variables: - raise ValueError(f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})") + raise ValueError( + f"{self.__class__.__name__}: Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})" + ) def check_same_sub_shapes(self, d1: Any, d2: Any, drop_axis: int) -> None: """Checks if the sub-shapes of two datasets are the same along a given axis. @@ -380,7 +386,7 @@ def check_same_sub_shapes(self, d1: Any, d2: Any, drop_axis: int) -> None: shape2 = d2.sub_shape(drop_axis) if shape1 != shape2: - raise ValueError(f"Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})") + raise ValueError(f"{self.__class__.__name__}: Incompatible shapes: {d1.shape} and {d2.shape} ({d1} {d2})") def check_same_variables(self, d1: Dataset, d2: Dataset) -> None: """Checks if the variables of two datasets are the same. @@ -398,7 +404,9 @@ def check_same_variables(self, d1: Dataset, d2: Dataset) -> None: If the variables are not the same. """ if d1.variables != d2.variables: - raise ValueError(f"Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})") + raise ValueError( + f"{self.__class__.__name__}: Incompatible variables: {d1.variables} and {d2.variables} ({d1} {d2})" + ) def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None: """Checks if the lengths of two datasets are the same. @@ -416,7 +424,7 @@ def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None: If the lengths are not the same. 
""" if d1._len != d2._len: - raise ValueError(f"Incompatible lengths: {d1._len} and {d2._len}") + raise ValueError(f"{self.__class__.__name__}: Incompatible lengths: {d1._len} and {d2._len}") def check_same_dates(self, d1: Dataset, d2: Dataset) -> None: """Checks if the dates of two datasets are the same. @@ -436,10 +444,14 @@ def check_same_dates(self, d1: Dataset, d2: Dataset) -> None: self.check_same_frequency(d1, d2) if d1.dates[0] != d2.dates[0]: - raise ValueError(f"Incompatible start dates: {d1.dates[0]} and {d2.dates[0]} ({d1} {d2})") + raise ValueError( + f"{self.__class__.__name__}: Incompatible start dates: {d1.dates[0]} and {d2.dates[0]} ({d1} {d2})" + ) if d1.dates[-1] != d2.dates[-1]: - raise ValueError(f"Incompatible end dates: {d1.dates[-1]} and {d2.dates[-1]} ({d1} {d2})") + raise ValueError( + f"{self.__class__.__name__}: Incompatible end dates: {d1.dates[-1]} and {d2.dates[-1]} ({d1} {d2})" + ) def check_compatibility(self, d1: Dataset, d2: Dataset) -> None: """Checks if two datasets are compatible. diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index 26f675526..ffec5e351 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -628,7 +628,7 @@ def nearest_grid_points( """ # TODO: Use the one from anemoi.utils.grids instead # from anemoi.utils.grids import ... - from scipy.spatial import cKDTree + from scipy.spatial import KDTree source_xyz = latlon_to_xyz(source_latitudes, source_longitudes) source_points = np.array(source_xyz).transpose() @@ -636,9 +636,9 @@ def nearest_grid_points( target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) target_points = np.array(target_xyz).transpose() if max_distance is None: - distances, indices = cKDTree(source_points).query(target_points, k=k) + distances, indices = KDTree(source_points).query(target_points, k=k) else: - distances, indices = cKDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) + distances, indices = KDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) return distances, indices diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index c0dbc1bea..a7057c1c2 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -19,7 +19,7 @@ from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry +# from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry from anemoi.datasets.create.sources import source_registry LOG = logging.getLogger(__name__) @@ -226,13 +226,6 @@ def __init__(self, name=None, description=None, attribution=None, licence=None): sources = source_registry.factories.copy() filters = transform_filter_registry.factories.copy() - for key, factory in datasets_filter_registry.factories.items(): - if key in filters: - LOG.warning( - f"Filter `{key}` is registered in anemoi.datasets filter registry and in anemoi.transform filter registry" - ) - filters[key] = factory - for key, factory in sources.items(): if key in filters: LOG.warning( @@ -341,7 +334,7 @@ def description(self): @description.setter def description(self, value): - self._description = value + self._description = value.strip() @property def attribution(self): @@ -349,7 +342,7 @@ def attribution(self): @attribution.setter def attribution(self, value): - self._attribution = value + self._attribution = value.strip() @property def 
licence(self): @@ -357,7 +350,7 @@ def licence(self): @licence.setter def licence(self, value): - self._licence = value + self._licence = value.strip() @property def name(self): @@ -365,7 +358,7 @@ def name(self): @name.setter def name(self, value): - self._name = value + self._name = value.strip() @property def dates(self): diff --git a/tests/test_classes.py b/tests/test_classes.py index b98649bdc..7a82f334c 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -8,18 +8,20 @@ # nor does it submit to any jurisdiction. +import os from collections.abc import Callable from functools import wraps from unittest.mock import patch import pytest +from anemoi.utils.testing import TEST_DATA_URL from anemoi.utils.testing import skip_if_offline from anemoi.datasets import open_dataset def _tests_zarrs(name: str) -> str: - return f"https://anemoi-test.ecmwf.int/test-zarrs/{name}.zarr" + return os.path.join(TEST_DATA_URL, "anemoi-datasets", f"{name}.zarr") def zarr_tests(func: Callable) -> Callable: @@ -32,7 +34,15 @@ def wrapper(*args, **kwargs): return wrapper -def _test_dataset(ds): +def _test_dataset(ds, variables=None): + + if variables is not None: + assert ds.variables == variables, ( + set(ds.variables) - set(variables), + set(variables) - set(ds.variables), + ds.variables, + ) + for p in ds.components(): print(p) print(p.origins()) @@ -47,34 +57,122 @@ def _test_dataset(ds): def test_class_complement_none(): pass # ds = open_dataset( - # source="cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", - # complement="aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + # source="cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", + # complement="aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", # # adjust="all", # ) @skip_if_offline @zarr_tests -@not_ready -def test_class_complement_nearest(): - pass +def test_class_complement_nearest_1(): + ds = open_dataset( + complement="cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", + source="aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", + interpolation="nearest", + ) + _test_dataset( + ds, + variables=[ + "2t", + "cos_latitude", + "cp", + "insolation", + "lsm", + "msl", + "orog", + "sf", + "t_500", + "t_850", + "tp", + "z", + "z_500", + "z_850", + ], + ) + + +@skip_if_offline +@zarr_tests +def test_class_complement_nearest_2(): + ds = open_dataset( + source="cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", + complement="aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", + interpolation="nearest", + ) + _test_dataset( + ds, + variables=[ + "2t", + "cos_latitude", + "cp", + "insolation", + "lsm", + "msl", + "orog", + "sf", + "t_500", + "t_850", + "tp", + "z", + "z_500", + "z_850", + ], + ) @skip_if_offline @zarr_tests -@not_ready def test_class_concat(): - pass + ds = open_dataset( + [ + "aifs-ea-an-oper-0001-mars-20p0-2016-2016-6h-v1", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", + ] + ) + _test_dataset( + ds, + variables=[ + "2t", + "cos_latitude", + "cp", + "insolation", + "lsm", + "msl", + "t_500", + "t_850", + "tp", + "z", + "z_500", + "z_850", + ], + ) @skip_if_offline @zarr_tests def test_class_number(): ds = open_dataset( - "aifs-ea-an-enda-0001-mars-20p0-2020-2020-24h-v6", - number=[1, 5, 6], + "aifs-ea-an-enda-0001-mars-20p0-2017-2017-6h-v1", + members=[0, 2], + ) + _test_dataset( + ds, + variables=[ + "2t", + "cos_latitude", + "cp", + "insolation", + "lsm", + "msl", + "t_500", + "t_850", + "tp", + "z", + "z_500", + "z_850", + ], ) - _test_dataset(ds) @skip_if_offline @@ -82,11 +180,27 @@ def test_class_number(): def 
test_class_ensemble(): ds = open_dataset( ensemble=[ - "aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6", - "aifs-ea-em-enda-0001-mars-o96-1979-2022-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", + "aifs-ea-em-enda-0001-mars-20p0-2017-2017-6h-v1", ] ) - _test_dataset(ds) + _test_dataset( + ds, + variables=[ + "2t", + "cos_latitude", + "cp", + "insolation", + "lsm", + "msl", + "t_500", + "t_850", + "tp", + "z", + "z_500", + "z_850", + ], + ) @skip_if_offline @@ -112,12 +226,11 @@ def test_class_missing_dates_interpolate(): @skip_if_offline @zarr_tests -@not_ready def test_class_grids(): ds = open_dataset( grids=[ - "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", + "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", ], adjust="all", ) @@ -130,8 +243,8 @@ def test_class_grids(): def test_class_cutout() -> None: ds = open_dataset( cutout=[ - "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", + "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", ], adjust="all", ) @@ -161,17 +274,34 @@ def test_class_interpolate_nearest(): @skip_if_offline @zarr_tests -@not_ready -def test_class_join(): - pass +def test_class_join_1(): + ds = open_dataset( + [ + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1-sfc", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1-pl", + ], + ) + _test_dataset(ds, ["2t", "lsm", "msl", "z", "t_500", "t_850", "z_500", "z_850"]) + + +@skip_if_offline +@zarr_tests +def test_class_join_2(): + ds = open_dataset( + [ + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1-pl", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1-sfc", + ], + ) + _test_dataset(ds, ["t_500", "t_850", "z_500", "z_850", "2t", "lsm", "msl", "z"]) @skip_if_offline @zarr_tests def test_class_thinning(): ds = open_dataset( - "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", - thinning=100, + "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", + thinning=4, ) _test_dataset(ds) @@ -180,7 +310,7 @@ def test_class_thinning(): @zarr_tests def test_class_cropping(): ds = open_dataset( - "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", area=[80, -10, 30, 40], ) _test_dataset(ds) @@ -191,7 +321,7 @@ def test_class_cropping(): @not_ready def test_class_trim_edge(): ds = open_dataset( - "cerra-rr-an-oper-0001-mars-5p5km-1984-2020-6h-v2-hmsi", + "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", trim_edge=(1, 2, 3, 4), ) _test_dataset(ds) @@ -236,7 +366,7 @@ def test_class_padded(): @zarr_tests def test_class_rescale_1(): ds = open_dataset( - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", rescale={"2t": (1.0, -273.15)}, ) _test_dataset(ds) @@ -252,7 +382,7 @@ def test_class_rescale_2(): raise pytest.skip("udunits2 library not installed") ds = open_dataset( - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", rescale={"2t": ("K", "degC")}, ) _test_dataset(ds) @@ -262,7 +392,7 @@ def test_class_rescale_2(): @zarr_tests def test_class_rescale_3(): ds = open_dataset( - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", rescale={ "2t": {"scale": 1.0, "offset": -273.15}, }, @@ -274,7 +404,7 @@ def test_class_rescale_3(): @zarr_tests def test_class_select_select_1(): ds = open_dataset( - 
"aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", select=["msl", "2t"], ) _test_dataset(ds) @@ -284,7 +414,7 @@ def test_class_select_select_1(): @zarr_tests def test_class_select_select_2(): ds = open_dataset( - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", select={"msl", "2t"}, ) _test_dataset(ds) @@ -294,7 +424,7 @@ def test_class_select_select_2(): @zarr_tests def test_class_select_drop(): ds = open_dataset( - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", drop=["2t", "msl"], ) _test_dataset(ds) @@ -304,7 +434,7 @@ def test_class_select_drop(): @zarr_tests def test_class_rename() -> None: ds = open_dataset( - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", rename={"2t": "temperature", "msl": "pressure"}, ) _test_dataset(ds) @@ -312,11 +442,12 @@ def test_class_rename() -> None: @skip_if_offline @zarr_tests +@not_ready def test_class_rename_with_overlap() -> None: ds = open_dataset( [ { - "dataset": "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "dataset": "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", "select": ["cp", "tp"], "end": 2023, "frequency": "6h", @@ -344,12 +475,13 @@ def test_class_statistics(): @skip_if_offline @zarr_tests def test_class_zarr(): - ds = open_dataset("aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6") + ds = open_dataset("aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1") _test_dataset(ds) @skip_if_offline @zarr_tests +@not_ready def test_class_zarr_with_missing_dates(): ds = open_dataset("rodeo-opera-files-o96-2013-2023-6h-v5") _test_dataset(ds) @@ -359,7 +491,7 @@ def test_class_zarr_with_missing_dates(): @zarr_tests def test_class_subset(): ds = open_dataset( - "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v6", + "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", frequency="12h", start=2017, end=2018, @@ -396,7 +528,8 @@ def test_class_xy(): if __name__ == "__main__": - test_class_number() + test_class_complement_nearest_1() + test_class_complement_nearest_2() exit(0) for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj): From a917c49dcc8a9ca98d0d6dd3eb65011ad7545bf1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 13 Sep 2025 17:46:18 +0000 Subject: [PATCH 123/212] add doc --- docs/datasets/building/code/using-python-1.py | 7 ++++++ docs/datasets/building/code/using-python-2.py | 0 docs/datasets/building/code/using-python-3.py | 0 docs/datasets/building/code/using-python-4.py | 0 docs/datasets/building/introduction.rst | 1 + docs/datasets/building/using-python.rst | 24 +++++++++++++++++++ 6 files changed, 32 insertions(+) create mode 100644 docs/datasets/building/code/using-python-1.py create mode 100644 docs/datasets/building/code/using-python-2.py create mode 100644 docs/datasets/building/code/using-python-3.py create mode 100644 docs/datasets/building/code/using-python-4.py create mode 100644 docs/datasets/building/using-python.rst diff --git a/docs/datasets/building/code/using-python-1.py b/docs/datasets/building/code/using-python-1.py new file mode 100644 index 000000000..939a771fa --- /dev/null +++ b/docs/datasets/building/code/using-python-1.py @@ -0,0 +1,7 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.input = r.grib("input_data.grib", param=["2t", "msl"]) + +r.dump() diff --git a/docs/datasets/building/code/using-python-2.py b/docs/datasets/building/code/using-python-2.py new 
file mode 100644 index 000000000..e69de29bb diff --git a/docs/datasets/building/code/using-python-3.py b/docs/datasets/building/code/using-python-3.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs/datasets/building/code/using-python-4.py b/docs/datasets/building/code/using-python-4.py new file mode 100644 index 000000000..e69de29bb diff --git a/docs/datasets/building/introduction.rst b/docs/datasets/building/introduction.rst index 71107baeb..ff075c548 100644 --- a/docs/datasets/building/introduction.rst +++ b/docs/datasets/building/introduction.rst @@ -94,6 +94,7 @@ operations can be combined to build complex datasets. statistics incremental advanced-options + using-python ******************** Naming Conventions diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst new file mode 100644 index 000000000..40708e2b0 --- /dev/null +++ b/docs/datasets/building/using-python.rst @@ -0,0 +1,24 @@ +############################# + Using Python define recipes +############################# + +You can use Python to define recipes for building datasets. This allows +for more complex logic and flexibility compared to using static +configuration files. + +When executed, the Python code will generate a YAML configuration that +can be used by the dataset building tool. + +Here is an example of how to define a dataset recipe using Python: + +.. literalinclude:: code/using-python-1.py + :language: python + +.. literalinclude:: code/using-python-2.py + :language: python + +.. literalinclude:: code/using-python-3.py + :language: python + +.. literalinclude:: code/using-python-4.py + :language: python From 39bedac79052b807080e1ee93f2382c4a9b521da Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 14 Sep 2025 07:50:23 +0100 Subject: [PATCH 124/212] add origins to metadata --- src/anemoi/datasets/data/dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 4b76d24f5..8d6d693d6 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -565,6 +565,11 @@ def metadata(self) -> dict[str, Any]: supporting_arrays=source_to_arrays[id(self)], ) + try: + md["origins"] = self.origins() + except Exception: + LOG.exception("Failed to get origins") + try: return json.loads(json.dumps(_tidy(md))) except Exception: From 9309cfeef5ddd825db032e25ce31e18ebac1f4bc Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 14 Sep 2025 08:03:39 +0100 Subject: [PATCH 125/212] update --- src/anemoi/datasets/data/dataset.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 145075784..4338d4ab3 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -565,11 +565,6 @@ def metadata(self) -> dict[str, Any]: supporting_arrays=source_to_arrays[id(self)], ) - try: - md["origins"] = self.origins() - except Exception: - LOG.exception("Failed to get origins") - try: return json.loads(json.dumps(_tidy(md))) except Exception: From 437e4aa177e0685eed15b5a7b9b0d84fcb546c22 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 14 Sep 2025 11:12:16 +0100 Subject: [PATCH 126/212] docs --- docs/Makefile | 1 - docs/cli/grib-index.rst | 2 +- docs/datasets/building/introduction.rst | 12 +- .../building/sources/repeated-dates.rst | 2 +- .../datasets/create/input/repeated_dates.py | 386 ------------------ 5 files changed, 13 insertions(+), 390 
deletions(-) delete mode 100644 src/anemoi/datasets/create/input/repeated_dates.py diff --git a/docs/Makefile b/docs/Makefile index 12f080733..6c0762a44 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -19,5 +19,4 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - bash $(SOURCEDIR)/scripts/api_build.sh @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/cli/grib-index.rst b/docs/cli/grib-index.rst index 2b97b0157..61684dbfb 100644 --- a/docs/cli/grib-index.rst +++ b/docs/cli/grib-index.rst @@ -1,7 +1,7 @@ .. _grib-index_command: Grib-index Command -============ +================== The `grib-index` command is used to create an index file for GRIB files. The index file is then used by the `grib-index` :ref:`source `. diff --git a/docs/datasets/building/introduction.rst b/docs/datasets/building/introduction.rst index ff075c548..2054c75c5 100644 --- a/docs/datasets/building/introduction.rst +++ b/docs/datasets/building/introduction.rst @@ -94,7 +94,6 @@ operations can be combined to build complex datasets. statistics incremental advanced-options - using-python ******************** Naming Conventions @@ -106,3 +105,14 @@ operations can be combined to build complex datasets. :caption: Naming Conventions naming-conventions + +**************** + Python recipes +**************** + +.. toctree:: + :maxdepth: 1 + :hidden: + :caption: Python recipes + + using-python diff --git a/docs/datasets/building/sources/repeated-dates.rst b/docs/datasets/building/sources/repeated-dates.rst index ba16e4707..53baf3283 100644 --- a/docs/datasets/building/sources/repeated-dates.rst +++ b/docs/datasets/building/sources/repeated-dates.rst @@ -10,7 +10,7 @@ dates of the dataset. The general format of the `repeated-dates` source is: -.. literalinclude:: yaml/repeated_dates1.yaml +.. literalinclude:: yaml/repeated-dates1.yaml :language: yaml where ``source`` is any of the :ref:`operations ` or diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py deleted file mode 100644 index ad46fe208..000000000 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ /dev/null @@ -1,386 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -from collections import defaultdict -from collections.abc import Generator -from typing import Any - -import numpy as np -from anemoi.transform.fields import new_field_with_valid_datetime -from anemoi.transform.fields import new_fieldlist_from_list -from anemoi.utils.dates import as_datetime -from anemoi.utils.dates import frequency_to_timedelta - -from .action import Action -from .action import action_factory -from .join import JoinResult -from .result.field import Result -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -class DateMapper: - """A factory class to create DateMapper instances based on the given mode.""" - - @staticmethod - def from_mode(mode: str, source: Any, config: dict[str, Any]) -> "DateMapper": - """Create a DateMapper instance based on the given mode. 
- - Parameters - ---------- - mode : str - The mode to use for the DateMapper. - source : Any - The data source. - config : dict - Configuration parameters. - - Returns - ------- - DateMapper - An instance of DateMapper. - """ - MODES: dict = dict( - closest=DateMapperClosest, - climatology=DateMapperClimatology, - constant=DateMapperConstant, - ) - - if mode not in MODES: - raise ValueError(f"Invalid mode for DateMapper: {mode}") - - return MODES[mode](source, **config) - - -class DateMapperClosest(DateMapper): - """A DateMapper implementation that maps dates to the closest available dates.""" - - def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", skip_all_nans: bool = False) -> None: - """Initialize DateMapperClosest. - - Parameters - ---------- - source : Any - The data source. - frequency : str - Frequency of the dates. - maximum : str - Maximum time delta. - skip_all_nans : bool - Whether to skip all NaN values. - """ - self.source: Any = source - self.maximum: Any = frequency_to_timedelta(maximum) - self.frequency: Any = frequency_to_timedelta(frequency) - self.skip_all_nans: bool = skip_all_nans - self.tried: set[Any] = set() - self.found: set[Any] = set() - - def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: - """Transform the group of dates to the closest available dates. - - Parameters - ---------- - group_of_dates : Any - The group of dates to transform. - - Returns - ------- - Generator[Tuple[Any, Any], None, None] - Transformed dates. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - asked_dates = list(group_of_dates) - if not asked_dates: - return [] - - to_try = set() - for date in asked_dates: - start = date - while start >= date - self.maximum: - to_try.add(start) - start -= self.frequency - - end = date - while end <= date + self.maximum: - to_try.add(end) - end += self.frequency - - to_try = sorted(to_try - self.tried) - info = {k: "no-data" for k in to_try} - - if not to_try: - LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}") - # return [] - - if to_try: - result = self.source.select( - GroupOfDates( - sorted(to_try), - group_of_dates.provider, - partial_ok=True, - ) - ) - - cnt = 0 - for f in result.datasource: - cnt += 1 - # We could keep the fields in a dictionary, but we don't want to keep the fields in memory - date = as_datetime(f.metadata("valid_datetime")) - - if self.skip_all_nans: - if np.isnan(f.to_numpy()).all(): - LOG.warning(f"Skipping {date} because all values are NaN") - info[date] = "all-nans" - continue - - info[date] = "ok" - self.found.add(date) - - if cnt == 0: - raise ValueError(f"No data found for {group_of_dates} in {self.source}") - - self.tried.update(to_try) - - if not self.found: - for k, v in info.items(): - LOG.warning(f"{k}: {v}") - - raise ValueError(f"No matching data found for {asked_dates} in {self.source}") - - new_dates = defaultdict(list) - - for date in asked_dates: - best = None - for found_date in sorted(self.found): - delta = abs(date - found_date) - # With < we prefer the first date - # With <= we prefer the last date - if best is None or delta <= best[0]: - best = delta, found_date - new_dates[best[1]].append(date) - - for date, dates in new_dates.items(): - yield ( - GroupOfDates([date], group_of_dates.provider), - GroupOfDates(dates, group_of_dates.provider), - ) - - -class DateMapperClimatology(DateMapper): - """A DateMapper implementation that maps dates to specified climatology dates.""" - - def __init__(self, source: Any, 
year: int, day: int, hour: int | None = None) -> None: - """Initialize DateMapperClimatology. - - Parameters - ---------- - source : Any - The data source. - year : int - The year to map to. - day : int - The day to map to. - hour : Optional[int] - The hour to map to. - """ - self.year: int = year - self.day: int = day - self.hour: int | None = hour - - def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: - """Transform the group of dates to the specified climatology dates. - - Parameters - ---------- - group_of_dates : Any - The group of dates to transform. - - Returns - ------- - Generator[Tuple[Any, Any], None, None] - Transformed dates. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - dates = list(group_of_dates) - if not dates: - return [] - - new_dates = defaultdict(list) - for date in dates: - new_date = date.replace(year=self.year, day=self.day) - if self.hour is not None: - new_date = new_date.replace(hour=self.hour, minute=0, second=0) - new_dates[new_date].append(date) - - for date, dates in new_dates.items(): - yield ( - GroupOfDates([date], group_of_dates.provider), - GroupOfDates(dates, group_of_dates.provider), - ) - - -class DateMapperConstant(DateMapper): - """A DateMapper implementation that maps dates to a constant date.""" - - def __init__(self, source: Any, date: Any | None = None) -> None: - """Initialize DateMapperConstant. - - Parameters - ---------- - source : Any - The data source. - date : Optional[Any] - The constant date to map to. - """ - self.source: Any = source - self.date: Any | None = date - - def transform(self, group_of_dates: Any) -> tuple[Any, Any]: - """Transform the group of dates to a constant date. - - Parameters - ---------- - group_of_dates : Any - The group of dates to transform. - - Returns - ------- - Tuple[Any, Any] - Transformed dates. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - if self.date is None: - return [ - ( - GroupOfDates([], group_of_dates.provider), - group_of_dates, - ) - ] - - return [ - ( - GroupOfDates([self.date], group_of_dates.provider), - group_of_dates, - ) - ] - - -class DateMapperResult(Result): - """A Result implementation that updates the valid datetime of the datasource.""" - - def __init__( - self, - context: Any, - action_path: list[str], - group_of_dates: Any, - source_result: Any, - mapper: DateMapper, - original_group_of_dates: Any, - ) -> None: - """Initialize DateMapperResult. - - Parameters - ---------- - context : Any - The context. - action_path : list of str - The action path. - group_of_dates : Any - The group of dates. - source_result : Any - The source result. - mapper : DateMapper - The date mapper. - original_group_of_dates : Any - The original group of dates. 
- """ - super().__init__(context, action_path, group_of_dates) - - self.source_results: Any = source_result - self.mapper: DateMapper = mapper - self.original_group_of_dates: Any = original_group_of_dates - - @property - def datasource(self) -> Any: - """Get the datasource with updated valid datetime.""" - result: list = [] - - for field in self.source_results.datasource: - for date in self.original_group_of_dates: - result.append(new_field_with_valid_datetime(field, date)) - - if not result: - raise ValueError("repeated_dates: no input data found") - - return new_fieldlist_from_list(result) - - -class RepeatedDatesAction(Action): - """An Action implementation that selects and transforms a group of dates.""" - - def __init__(self, context: Any, action_path: list[str], source: Any, mode: str, **kwargs: Any) -> None: - """Initialize RepeatedDatesAction. - - Args: - context (Any): The context. - action_path (List[str]): The action path. - source (Any): The data source. - mode (str): The mode for date mapping. - **kwargs (Any): Additional arguments. - """ - super().__init__(context, action_path, source, mode, **kwargs) - - self.source: Any = action_factory(source, context, action_path + ["source"]) - self.mapper: DateMapper = DateMapper.from_mode(mode, self.source, kwargs) - self.mode = mode - self.kwargs = kwargs - - @trace_select - def select(self, group_of_dates: Any) -> JoinResult: - """Select and transform the group of dates. - - Args: - group_of_dates (Any): The group of dates to select. - - Returns - ------- - JoinResult - The result of the join operation. - """ - results: list = [] - for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): - results.append( - DateMapperResult( - self.context, - self.action_path, - one_date_group, - self.source.select(one_date_group), - self.mapper, - many_dates_group, - ) - ) - - return JoinResult(self.context, self.action_path, group_of_dates, results) - - def __repr__(self) -> str: - """Get the string representation of the action. - - Returns - ------- - str - The string representation. - """ - return f"MultiDateMatchAction({self.source}, {self.mapper})" From 79f170623d43715028b889fe5c5525f8ea7d46d2 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 16 Sep 2025 07:46:54 +0100 Subject: [PATCH 127/212] add filter.rst --- docs/datasets/building/filters.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 docs/datasets/building/filters.rst diff --git a/docs/datasets/building/filters.rst b/docs/datasets/building/filters.rst new file mode 100644 index 000000000..3b3bd5abf --- /dev/null +++ b/docs/datasets/building/filters.rst @@ -0,0 +1,14 @@ +.. _filters: + +######### + Filters +######### + +.. warning:: + + This is still a work-in-progress. Some of the filters may be renamed + later. + +Filters are used to modify the data or metadata in a dataset. + +See :ref:`install ` for more information. 
From ef431f8a3d5a9176d040e93800d1ab5e9cd2504f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 21 Sep 2025 18:19:19 +0100 Subject: [PATCH 128/212] add docs --- docs/datasets/building/code/using-python-1.py | 4 ---- docs/datasets/building/code/using-python-2.py | 8 ++++++++ docs/datasets/building/code/using-python-3.py | 12 ++++++++++++ docs/datasets/building/using-python.rst | 4 +++- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/docs/datasets/building/code/using-python-1.py b/docs/datasets/building/code/using-python-1.py index 939a771fa..196a25f42 100644 --- a/docs/datasets/building/code/using-python-1.py +++ b/docs/datasets/building/code/using-python-1.py @@ -1,7 +1,3 @@ from anemoi.datasets.recipe import Recipe r = Recipe() - -r.input = r.grib("input_data.grib", param=["2t", "msl"]) - -r.dump() diff --git a/docs/datasets/building/code/using-python-2.py b/docs/datasets/building/code/using-python-2.py index e69de29bb..717129592 100644 --- a/docs/datasets/building/code/using-python-2.py +++ b/docs/datasets/building/code/using-python-2.py @@ -0,0 +1,8 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe( + description="Example dataset recipe", + name="example-dataset", + licence="CC-BY-4.0", + attribution="my-organisation", +) diff --git a/docs/datasets/building/code/using-python-3.py b/docs/datasets/building/code/using-python-3.py index e69de29bb..f21dc3947 100644 --- a/docs/datasets/building/code/using-python-3.py +++ b/docs/datasets/building/code/using-python-3.py @@ -0,0 +1,12 @@ +from anemoi.datasets.recipe import Recipe + +r = Recipe() + +r.description = """ +Example dataset recipe using Python, with attributes set one by one +and a multiline description. +""" + +r.name = "example-dataset" +r.licence = "CC-BY-4.0" +r.attribution = "my-organisation" diff --git a/docs/datasets/building/using-python.rst b/docs/datasets/building/using-python.rst index 40708e2b0..fbf2892cf 100644 --- a/docs/datasets/building/using-python.rst +++ b/docs/datasets/building/using-python.rst @@ -9,7 +9,9 @@ configuration files. When executed, the Python code will generate a YAML configuration that can be used by the dataset building tool. -Here is an example of how to define a dataset recipe using Python: +Here is an example of how to define a dataset recipe using Python. + +First create a ``Recipe`` object, which will hold the configuration: .. 
literalinclude:: code/using-python-1.py :language: python From 48c9d075a58c0764b5efb86f6ca9946a80ab7926 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 25 Sep 2025 18:43:57 +0000 Subject: [PATCH 129/212] compress origins --- src/anemoi/datasets/data/components.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/components.py b/src/anemoi/datasets/data/components.py index 6fdb781c1..c16c20c6f 100644 --- a/src/anemoi/datasets/data/components.py +++ b/src/anemoi/datasets/data/components.py @@ -198,7 +198,7 @@ def apply(self, projection): def variables(self): return self.store.variables[self.slices[1]] - def origins(self): + def origins(self, compressed=False): result = {} for variable in self.variables(): @@ -226,6 +226,25 @@ def origins(self): result[variable] = origins + if compressed: + + def _hashable(v): + if isinstance(v, dict): + return tuple((k, _hashable(vv)) for k, vv in sorted(v.items())) + if isinstance(v, list): + return tuple(_hashable(vv) for vv in v) + return v + + compressed_result = defaultdict(list) + for k, v in result.items(): + compressed_result[_hashable(v)].append((k, v)) + + result = {} + for v in compressed_result.values(): + key = tuple(sorted(k for k, _ in v)) + value = v[0][1] + result[key] = value + return result def add_transformation(self, transformation): @@ -233,3 +252,7 @@ def add_transformation(self, transformation): def __iter__(self): return iter([self]) + + @property + def dataset_name(self): + return self.store.dataset_name From 8b6b76585762444aefb76362eb02735b8c9fd789 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Fri, 26 Sep 2025 10:58:31 +0200 Subject: [PATCH 130/212] added some comments --- src/anemoi/datasets/data/dataset.py | 3 + src/anemoi/datasets/data/records/__init__.py | 95 ++++++++++++++++--- .../data/records/backends/__init__.py | 16 ++++ 3 files changed, 103 insertions(+), 11 deletions(-) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 625b1a562..fe9d71607 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -410,6 +410,9 @@ def _select_to_columns(self, vars: str | list[str] | tuple[str] | set) -> list[i if not isinstance(vars, (list, tuple)): vars = [vars] + for v in vars: + if v not in self.name_to_index: + raise ValueError(f"select: unknown variable: {v}, available: {list(self.name_to_index)}") return [self.name_to_index[v] for v in vars] def _drop_to_columns(self, vars: str | Sequence[str]) -> list[int]: diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index b8c568296..ea901c7c8 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -78,8 +78,20 @@ def _to_numpy_dates(d): class BaseRecordsDataset: +    """This is the base class for all datasets based on records. +    Records datasets are datasets that can be indexed by time (int) or by group (str). +    A record dataset is designed for observations, where multiple arrays of different shapes need to be stored for each date. +    They have the same concepts of start_date, end_date and frequency as fields datasets, but each date corresponds to a window. +    All windows have the same size (the window span can be different from the dataset frequency) - variables in a records dataset are identified by a group and a name. 
+ """ + + # Depending on the context, a variable is identified by "group.name", + # or using a dict with keys as groups and values as list of names. + # most of the code should be agnostic and transform one format to the other when needed. + + def __getitem__(self, i: int | str): if isinstance(i, str): return self._getgroup(i) @@ -90,15 +102,31 @@ def __getitem__(self, i): @cached_property def window(self): + """Returns a string representation of the relative window of the dataset, such as '(-3h, 3h]'.""" return str(self._window) - def _getgroup(self, i): - return Tabular(self, i) + def _getgroup(self, group: str): + """Returns a Tabular object for the group. As a partial function when argument group is given but i is not.""" + return Tabular(self, group) - def _getrecord(self, i): + def _getrecord(self, i: int): + """Returns a Record object for the time step i. As a partial function when argument i is given but group is not.""" return Record(self, i) - def _load_data(self, i): + def _load_data(self, i: int) -> dict: + """ + Load the data for a specific time step or window (i). + It is expected to return a dict containing keys of the form: + + - "data:group1" : numpy array + - "latitudes:group1" : numpy array + - "longitudes:group1" : numpy array + - "metadata:group1" : + - ... + - "data:group2" : numpy array + - "latitudes:group2" : numpy array + - ... + """ raise NotImplementedError("Must be implemented in subclass") @property @@ -221,6 +249,13 @@ class FieldsRecords(RecordsForward): """A wrapper around a FieldsDataset to provide a consistent interface for records datasets.""" def __init__(self, fields_dataset, name): + """wrapper around a fields dataset to provide a consistent interface for records datasets. + A FieldsRecords appears as a RecordsDataset with a single group. + This allows merging fields datasets with other records datasets. + Parameters: + fields_dataset: must be a regular fields dataset + name: the name of the group + .""" self.forward = fields_dataset from anemoi.datasets.data.dataset import Dataset @@ -293,7 +328,9 @@ def __len__(self): return len(self.forward.dates) -class GenericRename(RecordsForward): +class BaseRename(RecordsForward): + """Renames variables in a records dataset.""" + def __init__(self, dataset, rename): self.forward = dataset assert isinstance(rename, dict) @@ -320,16 +357,16 @@ def groups(self): return [self.rename.get(k, k) for k in self.forward.groups] -class Rename(GenericRename): +class Rename(BaseRename): pass -class SetGroup(GenericRename): +class SetGroup(BaseRename): def __init__(self, dataset, set_group): if len(dataset.groups) != 1: raise ValueError(f"{self.__class__.__name__} can only be used with datasets containing a single group.") - super.__init__(dataset, {dataset.groups[0]: set_group}) + super().__init__(dataset, {dataset.groups[0]: set_group}) def _load_data(self, i): return self.dataset._load_data(i) @@ -411,6 +448,7 @@ def _to_timedelta(t): class AbsoluteWindow: + # not used but expected to be useful when building datasets. 
And used in tests def __init__(self, start, end, include_start=True, include_end=True): assert isinstance(start, datetime.datetime), f"start must be a datetime.datetime, got {type(start)}" assert isinstance(end, datetime.datetime), f"end must be a datetime.datetime, got {type(end)}" @@ -428,6 +466,14 @@ def __repr__(self): class WindowsSpec: + # A window specified by relative timedeltas, such as (-6h, 0h] + # + # the term "WindowSpec" is used here to avoid confusion between + # - a relative window, such as (-6h, 0h] which this class represents (WindowsSpec) + # - an actual time interval, such as [2023-01-01 00:00, 2023-01-01 06:00] which is an (AbsoluteWindow) + # + # but is is more confusing, it should be renamed as Window. + def __init__(self, *, start, end, include_start=False, include_end=True): assert isinstance(start, (str, datetime.timedelta)), f"start must be a str or timedelta, got {type(start)}" assert isinstance(end, (str, datetime.timedelta)), f"end must be a str or timedelta, got {type(end)}" @@ -447,6 +493,7 @@ def __init__(self, *, start, end, include_start=False, include_end=True): def to_absolute_window(self, date): """Convert the window to an absolute window based on a date.""" + # not used but expected to be useful when building datasets. And used in tests assert isinstance(date, datetime.datetime), f"date must be a datetime.datetime, got {type(date)}" start = date + self.start end = date + self.end @@ -466,6 +513,8 @@ def _frequency_to_string(t): return f"{first}{_frequency_to_string(self.start)},{_frequency_to_string(self.end)}{last}" def compute_mask(self, timedeltas): + """Returns a boolean numpy array of the same shape as timedeltas.""" + assert timedeltas.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {timedeltas.dtype}" if self.include_start: lower_mask = timedeltas >= self._start_np @@ -480,6 +529,9 @@ def compute_mask(self, timedeltas): return lower_mask & upper_mask def starts_before(self, my_dates, other_dates, other_window): + # apply this window to my_dates[0] and the other_window to other_dates[0] + # return True if this window starts before the other window + assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" @@ -492,6 +544,7 @@ def starts_before(self, my_dates, other_dates, other_window): return my_start <= other_start def ends_after(self, my_dates, other_dates, other_window): + # same as starts_before assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" @@ -507,13 +560,15 @@ def ends_after(self, my_dates, other_dates, other_window): class Rewindowed(RecordsForward): + # change the window of a records dataset + # similar to changing the frequency of a dataset + def __init__(self, dataset, window): super().__init__(dataset) self.dataset = dataset # in this class anything with 1 refers to the original window/dataset # and anything with 2 refers to the new window/dataset - # and we use _Δ for timedeltas self._window1 = self.forward._window self._window2 = window_from_str(window) @@ -602,6 +657,13 @@ def _load_data(self, i): class 
Select(RecordsForward): + # Select a subset of variables from a records dataset + # select can be a list of strings with dots (or a dict with keys as groups and values as list of strings) + # + # the selection is a filter, not a reordering, which is different from fields datasets and should be documented/fixed + # + # Drop should be implemented + def __init__(self, dataset, select): super().__init__(dataset) @@ -693,6 +755,8 @@ def statistics(self): class RecordsSubset(RecordsForward): + """Subset of a records dataset based on a list of integer indices.""" + def __init__(self, dataset, indices, reason): super().__init__(dataset) self.dataset = dataset @@ -711,6 +775,7 @@ def __len__(self): class RecordsDataset(BaseRecordsDataset): + """This is the base class for all datasets based on records stored on disk.""" def __init__(self, path, backend=None, **kwargs): if kwargs: @@ -806,7 +871,13 @@ def tree(self): class Record: - def __init__(self, dataset, n): + """A record corresponds to a single time step in a record dataset.""" + + def __init__(self, dataset: RecordsDataset, n: int): + """A record corresponds to a single time step in a record dataset. + n : int, the index of the time step in the dataset. + dataset : RecordsDataset, the dataset this record belongs to. + """ self.dataset = dataset self.n = n @@ -867,6 +938,8 @@ def as_dict(self): class Tabular: + """A RecordsDataset for a single group, similar to a fields dataset, but allowing different shapes for each date.""" + def __init__(self, dataset, name): self.dataset = dataset self.name = name diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index 705c6b107..d5342d8e9 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -27,15 +27,19 @@ def __init__(self, path, **kwargs): self.kwargs = kwargs def read(self, i, **kwargs): + """Read the i-th record and return a dictionary of numpy arrays.""" raise NotImplementedError("Must be implemented in subclass") def read_metadata(self): + """Read the metadata of a record dataset. The metadata does not depend on the record index.""" raise NotImplementedError("Must be implemented in subclass") def read_statistics(self): + """Read the statistics of a record dataset. 
The statistics does not depend on the record index.""" raise NotImplementedError("Must be implemented in subclass") def _check_data(self, data): + # Check that all keys are normalised for k in list(data.keys()): k = k.split(":")[-1] if k != normalise_key(k): @@ -139,16 +143,22 @@ def backend_factory(name, *args, **kwargs): class WriteBackend(Backend): + # Write backend base class, not used for reading + # provides implementation to write data def __init__(self, *, target, **kwargs): super().__init__(target, **kwargs) def write(self, i, data, **kwargs): + # expects data to be a dict of numpy arrays raise NotImplementedError("Must be implemented in subclass") def write_metadata(self, metadata): + # expects metadata to be a dict raise NotImplementedError("Must be implemented in subclass") def write_statistics(self, statistics): + # expects statistics to be a dict of dicts with the right keys: + # {group: {mean:..., std:..., min:..., max:...}} raise NotImplementedError("Must be implemented in subclass") def _check_data(self, data): @@ -158,6 +168,8 @@ def _check_data(self, data): raise ValueError(f"{k} must be alphanumerical and '_' only.") def _dataframes_to_record(self, i, data, variables, **kwargs): + # Convert data from pandas DataFrames to a record format + # will be used for writing, building obs datasets assert isinstance(data, (dict)), type(data) if not data: @@ -174,6 +186,8 @@ def _dataframes_to_record(self, i, data, variables, **kwargs): return data def _dataframe_to_dict(self, name, df, **kwargs): + # will be used for writing, building obs datasets + d = {} d["timedeltas:" + name] = df["timedeltas"] d["latitudes:" + name] = df["latitudes"] @@ -304,6 +318,8 @@ def write_statistics(self, statistics): def writer_backend_factory(name, **kwargs): + # choose the right backend for writing + # this is intended to make benchmarking easier WRITE_BACKENDS = dict( npz1=Npz1WriteBackend, npz2=Npz2WriteBackend, From c6da1cf2b81ad56d8c966ea9b5b32bc55a2be10c Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 08:00:22 +0000 Subject: [PATCH 131/212] check that file exists --- src/anemoi/datasets/data/misc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 6e764b903..d6b88c04c 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -364,6 +364,10 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " return Zarr(path).mutate() if path and path.endswith(".vz"): + + if not os.path.exists(path): + raise FileNotFoundError(f"File not found: {path}") + metadata_path = os.path.join(path, "metadata.json") if os.path.exists(metadata_path): if "backend" not in load_any_dict_format(metadata_path): From 3830d1c64a6720741874935ad206fe5aa4492af8 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 08:03:24 +0000 Subject: [PATCH 132/212] make Record a mapping --- src/anemoi/datasets/data/records/__init__.py | 32 ++++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index b8c568296..474b88e69 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -11,6 +11,7 @@ import logging import os from collections import defaultdict +from collections.abc import Mapping from functools import cached_property import numpy as np @@ -805,7 +806,17 @@ def tree(self): return Node(self, 
[], path=self.path) -class Record: +class Record(Mapping): + """A record representing data for each group in the dataset. + + Parameters + ---------- + dataset : BaseRecordsDataset + The dataset containing the record. + n : int + The index of the record. + """ + def __init__(self, dataset, n): self.dataset = dataset self.n = n @@ -817,6 +828,15 @@ def __repr__(self): def items(self): return self._payload.items() + def __iter__(self): + return iter(self.groups) + + def __len__(self): + return len(self.groups) + + def __contains__(self, group): + return group in self.groups + @property def name_to_index(self): return self.dataset.name_to_index @@ -861,8 +881,14 @@ def timedeltas(self): def statistics(self): return self.dataset.statistics - def as_dict(self): - """Returns the record as a dictionary with group names as keys.""" + def as_dict(self) -> dict: + """Returns the record as a dictionary with group names as keys. + + Returns + ------- + dict + Dictionary mapping group names to their data. + """ return {group: self[group] for group in self.groups} From 3ba6c1116220435751d2d48c664661a6d81be064 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 08:42:07 +0000 Subject: [PATCH 133/212] add windows.py --- src/anemoi/datasets/data/records/__init__.py | 179 +----------------- src/anemoi/datasets/data/records/windows.py | 185 +++++++++++++++++++ 2 files changed, 189 insertions(+), 175 deletions(-) create mode 100644 src/anemoi/datasets/data/records/windows.py diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 7b74516b1..0c55d988a 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -16,12 +16,13 @@ import numpy as np from anemoi.utils.config import load_any_dict_format -from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets.data.debug import Node from anemoi.datasets.data.records.backends import backend_factory +from .windows import window_from_str + LOG = logging.getLogger(__name__) if os.environ.get("ANEMOI_DATASET_COUNTER", "0") == "1": @@ -59,13 +60,6 @@ def merge_data(list_of_dicts): return {k: np.hstack(v) for k, v in merged.items()} -def _to_numpy_timedelta(td): - if isinstance(td, np.timedelta64): - assert td.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {td.dtype}" - return td - return np.timedelta64(int(td.total_seconds()), "s") - - def _to_numpy_date(d): if isinstance(d, np.datetime64): assert d.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {d.dtype}" @@ -395,171 +389,6 @@ def match_variable(lst, group, name): return False -def window_from_str(txt): - """Parses a window string of the form '(-6h, 0h]' and returns a WindowsSpec object.""" - if txt.startswith("["): - include_start = True - elif txt.startswith("("): - include_start = False - else: - raise ValueError(f"Invalid window {txt}, must start with '(' or '['") - txt = txt[1:] - - if txt.endswith("]"): - include_end = True - elif txt.endswith(")"): - include_end = False - else: - raise ValueError(f"Invalid window {txt}, must end with ')' or ']'") - txt = txt[:-1] - - txt = txt.strip() - if ";" in txt: - txt = txt.replace(";", ",") - lst = txt.split(",") - if len(lst) != 2: - raise ValueError( - f"Invalid window {txt}, must be of the form '(start, end)' or '[start, end]' or '[start, end)' or '(start, end]'" - ) - start, end = lst - start = start.strip() - end = end.strip() - - def 
_to_timedelta(t): - # This part should go into utils - from anemoi.utils.dates import as_timedelta - - if t.startswith(" ") or t.endswith(" "): - t = t.strip() - if t.startswith("-"): - return -as_timedelta(t[1:]) - if t.startswith("+"): - return as_timedelta(t[1:]) - # end of : This part should go into utils - return as_timedelta(t) - - start = _to_timedelta(start) - end = _to_timedelta(end) - return WindowsSpec( - start=start, - end=end, - include_start=include_start, - include_end=include_end, - ) - - -class AbsoluteWindow: - # not used but expected to be useful when building datasets. And used in tests - def __init__(self, start, end, include_start=True, include_end=True): - assert isinstance(start, datetime.datetime), f"start must be a datetime.datetime, got {type(start)}" - assert isinstance(end, datetime.datetime), f"end must be a datetime.datetime, got {type(end)}" - assert isinstance(include_start, bool), f"include_start must be a bool, got {type(include_start)}" - assert isinstance(include_end, bool), f"include_end must be a bool, got {type(include_end)}" - if start >= end: - raise ValueError(f"start {start} must be less than end {end}") - self.start = start - self.end = end - self.include_start = include_start - self.include_end = include_end - - def __repr__(self): - return f"{'[' if self.include_start else '('}{self.start.isoformat()},{self.end.isoformat()}{']' if self.include_end else ')'}" - - -class WindowsSpec: - # A window specified by relative timedeltas, such as (-6h, 0h] - # - # the term "WindowSpec" is used here to avoid confusion between - # - a relative window, such as (-6h, 0h] which this class represents (WindowsSpec) - # - an actual time interval, such as [2023-01-01 00:00, 2023-01-01 06:00] which is an (AbsoluteWindow) - # - # but is is more confusing, it should be renamed as Window. - - def __init__(self, *, start, end, include_start=False, include_end=True): - assert isinstance(start, (str, datetime.timedelta)), f"start must be a str or timedelta, got {type(start)}" - assert isinstance(end, (str, datetime.timedelta)), f"end must be a str or timedelta, got {type(end)}" - assert isinstance(include_start, bool), f"include_start must be a bool, got {type(include_start)}" - assert isinstance(include_end, bool), f"include_end must be a bool, got {type(include_end)}" - assert include_start in (True, False), f"Invalid include_start {include_start}" # None is not allowed - assert include_end in (True, False), f"Invalid include_end {include_end}" # None is not allowed - if start >= end: - raise ValueError(f"start {start} must be less than end {end}") - self.start = start - self.end = end - self.include_start = include_start - self.include_end = include_end - - self._start_np = _to_numpy_timedelta(start) - self._end_np = _to_numpy_timedelta(end) - - def to_absolute_window(self, date): - """Convert the window to an absolute window based on a date.""" - # not used but expected to be useful when building datasets. 
And used in tests - assert isinstance(date, datetime.datetime), f"date must be a datetime.datetime, got {type(date)}" - start = date + self.start - end = date + self.end - return AbsoluteWindow(start=start, end=end, include_start=self.include_start, include_end=self.include_end) - - def __repr__(self): - first = "[" if self.include_start else "(" - last = "]" if self.include_end else ")" - - def _frequency_to_string(t): - if t < datetime.timedelta(0): - return f"-{frequency_to_string(-t)}" - elif t == datetime.timedelta(0): - return "0" - return frequency_to_string(t) - - return f"{first}{_frequency_to_string(self.start)},{_frequency_to_string(self.end)}{last}" - - def compute_mask(self, timedeltas): - """Returns a boolean numpy array of the same shape as timedeltas.""" - - assert timedeltas.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {timedeltas.dtype}" - if self.include_start: - lower_mask = timedeltas >= self._start_np - else: - lower_mask = timedeltas > self._start_np - - if self.include_end: - upper_mask = timedeltas <= self._end_np - else: - upper_mask = timedeltas < self._end_np - - return lower_mask & upper_mask - - def starts_before(self, my_dates, other_dates, other_window): - # apply this window to my_dates[0] and the other_window to other_dates[0] - # return True if this window starts before the other window - - assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" - assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" - assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" - - my_start = my_dates[0] + self._start_np - other_start = other_dates[0] + other_window._start_np - - if my_start == other_start: - return (not other_window.include_start) or self.include_start - return my_start <= other_start - - def ends_after(self, my_dates, other_dates, other_window): - # same as starts_before - assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" - assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" - assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" - - my_end = my_dates[-1] + self._end_np - other_end = other_dates[-1] + other_window._end_np - - if my_end == other_end: - print(".", (not other_window.include_end) or self.include_end) - return (not other_window.include_end) or self.include_end - print(my_end >= other_end) - return my_end >= other_end - - class Rewindowed(RecordsForward): # change the window of a records dataset # similar to changing the frequency of a dataset @@ -903,14 +732,14 @@ def name_to_index(self): return self.dataset.name_to_index @cached_property - def _payload(self): + def _payload(self) -> dict: payload = self.dataset._load_data(self.n) for k in payload.keys(): assert len(k.split(":")) == 2, f"Invalid key {k}" return payload @cached_property - def groups(self): + def groups(self) -> list[str]: return self.dataset.groups def __getitem__(self, group): diff --git a/src/anemoi/datasets/data/records/windows.py b/src/anemoi/datasets/data/records/windows.py new file mode 100644 index 000000000..bfa3950a7 --- /dev/null +++ b/src/anemoi/datasets/data/records/windows.py @@ -0,0 +1,185 @@ +# (C) Copyright 2025 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime + +import numpy as np +from anemoi.utils.dates import frequency_to_string + + +def _to_numpy_timedelta(td): + if isinstance(td, np.timedelta64): + assert td.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {td.dtype}" + return td + return np.timedelta64(int(td.total_seconds()), "s") + + +def window_from_str(txt): + """Parses a window string of the form '(-6h, 0h]' and returns a WindowsSpec object.""" + if txt.startswith("["): + include_start = True + elif txt.startswith("("): + include_start = False + else: + raise ValueError(f"Invalid window {txt}, must start with '(' or '['") + txt = txt[1:] + + if txt.endswith("]"): + include_end = True + elif txt.endswith(")"): + include_end = False + else: + raise ValueError(f"Invalid window {txt}, must end with ')' or ']'") + txt = txt[:-1] + + txt = txt.strip() + if ";" in txt: + txt = txt.replace(";", ",") + lst = txt.split(",") + if len(lst) != 2: + raise ValueError( + f"Invalid window {txt}, must be of the form '(start, end)' or '[start, end]' or '[start, end)' or '(start, end]'" + ) + start, end = lst + start = start.strip() + end = end.strip() + + def _to_timedelta(t): + # This part should go into utils + from anemoi.utils.dates import as_timedelta + + if t.startswith(" ") or t.endswith(" "): + t = t.strip() + if t.startswith("-"): + return -as_timedelta(t[1:]) + if t.startswith("+"): + return as_timedelta(t[1:]) + # end of : This part should go into utils + return as_timedelta(t) + + start = _to_timedelta(start) + end = _to_timedelta(end) + return WindowsSpec( + start=start, + end=end, + include_start=include_start, + include_end=include_end, + ) + + +class AbsoluteWindow: + # not used but expected to be useful when building datasets. And used in tests + def __init__(self, start, end, include_start=True, include_end=True): + assert isinstance(start, datetime.datetime), f"start must be a datetime.datetime, got {type(start)}" + assert isinstance(end, datetime.datetime), f"end must be a datetime.datetime, got {type(end)}" + assert isinstance(include_start, bool), f"include_start must be a bool, got {type(include_start)}" + assert isinstance(include_end, bool), f"include_end must be a bool, got {type(include_end)}" + if start >= end: + raise ValueError(f"start {start} must be less than end {end}") + self.start = start + self.end = end + self.include_start = include_start + self.include_end = include_end + + def __repr__(self): + return f"{'[' if self.include_start else '('}{self.start.isoformat()},{self.end.isoformat()}{']' if self.include_end else ')'}" + + +class WindowsSpec: + # A window specified by relative timedeltas, such as (-6h, 0h] + # + # the term "WindowSpec" is used here to avoid confusion between + # - a relative window, such as (-6h, 0h] which this class represents (WindowsSpec) + # - an actual time interval, such as [2023-01-01 00:00, 2023-01-01 06:00] which is an (AbsoluteWindow) + # + # but is is more confusing, it should be renamed as Window. 
+ + def __init__(self, *, start, end, include_start=False, include_end=True): + assert isinstance(start, (str, datetime.timedelta)), f"start must be a str or timedelta, got {type(start)}" + assert isinstance(end, (str, datetime.timedelta)), f"end must be a str or timedelta, got {type(end)}" + assert isinstance(include_start, bool), f"include_start must be a bool, got {type(include_start)}" + assert isinstance(include_end, bool), f"include_end must be a bool, got {type(include_end)}" + assert include_start in (True, False), f"Invalid include_start {include_start}" # None is not allowed + assert include_end in (True, False), f"Invalid include_end {include_end}" # None is not allowed + if start >= end: + raise ValueError(f"start {start} must be less than end {end}") + self.start = start + self.end = end + self.include_start = include_start + self.include_end = include_end + + self._start_np = _to_numpy_timedelta(start) + self._end_np = _to_numpy_timedelta(end) + + def to_absolute_window(self, date): + """Convert the window to an absolute window based on a date.""" + # not used but expected to be useful when building datasets. And used in tests + assert isinstance(date, datetime.datetime), f"date must be a datetime.datetime, got {type(date)}" + start = date + self.start + end = date + self.end + return AbsoluteWindow(start=start, end=end, include_start=self.include_start, include_end=self.include_end) + + def __repr__(self): + first = "[" if self.include_start else "(" + last = "]" if self.include_end else ")" + + def _frequency_to_string(t): + if t < datetime.timedelta(0): + return f"-{frequency_to_string(-t)}" + elif t == datetime.timedelta(0): + return "0" + return frequency_to_string(t) + + return f"{first}{_frequency_to_string(self.start)},{_frequency_to_string(self.end)}{last}" + + def compute_mask(self, timedeltas): + """Returns a boolean numpy array of the same shape as timedeltas.""" + + assert timedeltas.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {timedeltas.dtype}" + if self.include_start: + lower_mask = timedeltas >= self._start_np + else: + lower_mask = timedeltas > self._start_np + + if self.include_end: + upper_mask = timedeltas <= self._end_np + else: + upper_mask = timedeltas < self._end_np + + return lower_mask & upper_mask + + def starts_before(self, my_dates, other_dates, other_window): + # apply this window to my_dates[0] and the other_window to other_dates[0] + # return True if this window starts before the other window + + assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" + assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" + assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" + + my_start = my_dates[0] + self._start_np + other_start = other_dates[0] + other_window._start_np + + if my_start == other_start: + return (not other_window.include_start) or self.include_start + return my_start <= other_start + + def ends_after(self, my_dates, other_dates, other_window): + # same as starts_before + assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" + assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" + assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" + + my_end = my_dates[-1] + self._end_np + other_end = other_dates[-1] + other_window._end_np + + if my_end == 
other_end: + print(".", (not other_window.include_end) or self.include_end) + return (not other_window.include_end) or self.include_end + print(my_end >= other_end) + return my_end >= other_end From 61c8dc0cc06156029cfac0c090b74e5243305ada Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 17:57:11 +0000 Subject: [PATCH 134/212] move code around --- src/anemoi/datasets/commands/create.py | 2 +- .../datasets/commands/recipe/__init__.py | 4 +- .../datasets/commands/recipe/migrate.py | 2 +- .../datasets/create/contexts/__init__.py | 0 .../datasets/create/{ => fields}/__init__.py | 47 +-- src/anemoi/datasets/create/input/__init__.py | 8 +- src/anemoi/datasets/create/input/action.py | 14 +- .../create/input/context/observations.py | 68 +++++ .../create/input/result/observations.py | 275 ++++++++++++++++++ .../datasets/create/observations/__init__.py | 0 src/anemoi/datasets/create/source.py | 43 +++ src/anemoi/datasets/create/sources/csv.py | 42 +++ .../data/records/backends/__init__.py | 6 +- src/anemoi/datasets/data/records/windows.py | 63 +++- tests/create/test_observations.py | 6 +- tests/create/test_observations_mars.py | 8 +- tests/create/test_observations_mars_bufr.py | 8 +- .../test_observations_mars_bufr_complex.py | 8 +- .../test_observations_mars_bufr_parallel.py | 8 +- tests/create/utils/create.py | 2 +- 20 files changed, 558 insertions(+), 56 deletions(-) create mode 100644 src/anemoi/datasets/create/contexts/__init__.py rename src/anemoi/datasets/create/{ => fields}/__init__.py (98%) create mode 100644 src/anemoi/datasets/create/input/context/observations.py create mode 100644 src/anemoi/datasets/create/input/result/observations.py create mode 100644 src/anemoi/datasets/create/observations/__init__.py create mode 100644 src/anemoi/datasets/create/sources/csv.py diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 86332cfcc..0fc7d04f1 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -45,7 +45,7 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") - from anemoi.datasets.create import creator_factory + from anemoi.datasets.create.fields import creator_factory options = {k: v for k, v in options.items() if v is not None} diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index bf08d1ee7..9fe7ec3ff 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,8 +15,8 @@ import yaml -from anemoi.datasets.create import config_to_python -from anemoi.datasets.create import validate_config +from anemoi.datasets.create.fields import config_to_python +from anemoi.datasets.create.fields import validate_config from .. 
import Command from .format import format_recipe diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 6a3c6301d..2a6112410 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -17,7 +17,7 @@ from glom import delete from glom import glom -from anemoi.datasets.create import validate_config +from anemoi.datasets.create.fields import validate_config from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/contexts/__init__.py b/src/anemoi/datasets/create/contexts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/fields/__init__.py similarity index 98% rename from src/anemoi/datasets/create/__init__.py rename to src/anemoi/datasets/create/fields/__init__.py index 2b649b251..9c83320b7 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/fields/__init__.py @@ -37,20 +37,20 @@ from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups -from .check import DatasetName -from .check import check_data_values -from .chunks import ChunkFilter -from .config import build_output -from .config import loader_config -from .input import InputBuilder -from .statistics import Summary -from .statistics import TmpStatistics -from .statistics import check_variance -from .statistics import compute_statistics -from .statistics import default_statistics_dates -from .statistics import fix_variance -from .utils import normalize_and_check_dates -from .writer import ViewCacheArray +from ..check import DatasetName +from ..check import check_data_values +from ..chunks import ChunkFilter +from ..config import build_output +from ..config import loader_config +from ..input import InputBuilder +from ..statistics import Summary +from ..statistics import TmpStatistics +from ..statistics import check_variance +from ..statistics import compute_statistics +from ..statistics import default_statistics_dates +from ..statistics import fix_variance +from ..utils import normalize_and_check_dates +from ..writer import ViewCacheArray LOG = logging.getLogger(__name__) @@ -193,7 +193,7 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: import zarr z = zarr.open(self.path, mode=mode) - from .zarr import add_zarr_dataset + from ..zarr import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -397,7 +397,7 @@ def _cache_context(self) -> Any: Any The cache context. 
""" - from .utils import cache_context + from ..utils import cache_context return cache_context(self.cache) @@ -473,7 +473,7 @@ def __init__(self, path: str, options: dict = None, **kwargs: Any): def run(self) -> None: """Run the patch.""" - from .patch import apply_patch + from ..patch import apply_patch apply_patch(self.path, **self.options) @@ -493,7 +493,7 @@ def __init__(self, path: str, **kwargs: Any): def run(self) -> None: """Run the size computation.""" - from .size import compute_directory_sizes + from ..size import compute_directory_sizes metadata = compute_directory_sizes(self.path) self.update_metadata(**metadata) @@ -515,7 +515,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from .zarr import ZarrBuiltRegistry + from ..zarr import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) @@ -625,9 +625,12 @@ def __init__( LOG.info(f"Groups: {self.groups}") + window = self.main_config.dates.get("window") + one_date = self.groups.one_date() - # assert False, (type(one_date), type(self.groups)) - self.minimal_input = self.input.select(one_date) + + self.minimal_input = self.input.select(one_date, window) + LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") LOG.info(self.minimal_input) @@ -866,7 +869,7 @@ def _run(self) -> None: # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(argument=group) + result = self.input.select(argument=group, window=None) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 22b98d07e..fc2083b6f 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -12,6 +12,7 @@ from typing import Any from anemoi.datasets.create.input.context.field import FieldContext +from anemoi.datasets.create.input.context.observations import ObservationContext class InputBuilder: @@ -44,20 +45,23 @@ def action(self) -> Any: return Recipe(input, sources) - def select(self, argument) -> Any: + def select(self, argument, window) -> Any: """Select data based on the group of dates. Parameters ---------- argument : GroupOfDates Group of dates to select data for. + window : str | None + Window specification. Returns ------- Any Selected data. 
""" - context = FieldContext(argument, **self.kwargs) + # TODO: move me elsewhere + context = ObservationContext(argument, **self.kwargs) if window else FieldContext(argument, **self.kwargs) return context.create_result(self.action(context, argument)) def python_code(self, code): diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 2d64c047c..5928f5301 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -258,6 +258,7 @@ def __call__(self, context, argument): } LEN_KLASS = len(KLASS) +TYPES = {} def make(key, config, *path): @@ -274,17 +275,28 @@ def make(key, config, *path): for name in dataset_source_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) + TYPES[name.replace("_", "-")] = "source" for name in transform_source_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) + TYPES[name.replace("_", "-")] = "source" # Register filters for name in transform_filter_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) + TYPES[name.replace("_", "-")] = "filter" - return KLASS[key.replace("_", "-")](config, *path) + key = key.replace("_", "-") + + if key not in KLASS: + LOG.error(f"Unknown action '{key}' in {'.'.join(x for x in path)}") + for available in sorted(KLASS): + LOG.error(f" Available: {available} (type={TYPES.get(available, 'built-in')})") + raise ValueError(f"Unknown action '{key}' in {'.'.join(x for x in path)}") + + return KLASS[key](config, *path) def action_factory(data, *path): diff --git a/src/anemoi/datasets/create/input/context/observations.py b/src/anemoi/datasets/create/input/context/observations.py new file mode 100644 index 000000000..9963652df --- /dev/null +++ b/src/anemoi/datasets/create/input/context/observations.py @@ -0,0 +1,68 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +import warnings +from typing import Any + +from anemoi.transform.fields import new_field_with_metadata +from anemoi.transform.fields import new_fieldlist_from_list + +from ..result.observations import ObservationsResult +from . 
import Context + +LOG = logging.getLogger(__name__) + + +class ObservationContext(Context): + + def __init__( + self, + /, + argument: Any, + **kwargs: Any, + ) -> None: + super().__init__(argument) + + def empty_result(self) -> Any: + return [] + + def source_argument(self, argument: Any) -> Any: + return argument # .dates + + def filter_argument(self, argument: Any) -> Any: + return argument + + def create_result(self, data): + return ObservationsResult(self, data) + + def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: + from anemoi.datasets.dates.groups import GroupOfDates + + return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) + + def origin(self, data: Any, action: Any, action_arguments: Any) -> Any: + warnings.warn("ObservationContext.origin is not implemented", UserWarning) + return data + + origin = action.origin() + + result = [] + for fs in data: + previous = fs.metadata("anemoi_origin", default=None) + fall_through = fs.metadata("anemoi_fall_through", default=False) + if fall_through: + # The field has pass unchanges in a filter + result.append(fs) + else: + anemoi_origin = origin.combine(previous, action, action_arguments) + result.append(new_field_with_metadata(fs, anemoi_origin=anemoi_origin)) + + return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/input/result/observations.py b/src/anemoi/datasets/create/input/result/observations.py new file mode 100644 index 000000000..fea3d81d0 --- /dev/null +++ b/src/anemoi/datasets/create/input/result/observations.py @@ -0,0 +1,275 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from collections import defaultdict +from typing import Any +from typing import DefaultDict + +from anemoi.utils.dates import as_timedelta + +from . import Result + +LOG = logging.getLogger(__name__) + + +def _fields_metatata(variables: tuple[str, ...], cube: Any) -> dict[str, Any]: + """Retrieve metadata for the given variables and cube. + + Parameters + ---------- + variables : tuple of str + The variables to retrieve metadata for. + cube : Any + The data cube. + + Returns + ------- + dict + The metadata dictionary. 
+ """ + assert isinstance(variables, tuple), variables + + KNOWN: dict[str, dict[str, bool]] = { + "cos_julian_day": dict(computed_forcing=True, constant_in_time=False), + "cos_latitude": dict(computed_forcing=True, constant_in_time=True), + "cos_local_time": dict(computed_forcing=True, constant_in_time=False), + "cos_longitude": dict(computed_forcing=True, constant_in_time=True), + "cos_solar_zenith_angle": dict(computed_forcing=True, constant_in_time=False), + "insolation": dict(computed_forcing=True, constant_in_time=False), + "latitude": dict(computed_forcing=True, constant_in_time=True), + "longitude": dict(computed_forcing=True, constant_in_time=True), + "sin_julian_day": dict(computed_forcing=True, constant_in_time=False), + "sin_latitude": dict(computed_forcing=True, constant_in_time=True), + "sin_local_time": dict(computed_forcing=True, constant_in_time=False), + "sin_longitude": dict(computed_forcing=True, constant_in_time=True), + } + + def _merge(md1: dict[str, Any], md2: dict[str, Any]) -> dict[str, Any]: + assert set(md1.keys()) == set(md2.keys()), (set(md1.keys()), set(md2.keys())) + result: dict[str, Any] = {} + for k in md1.keys(): + v1 = md1[k] + v2 = md2[k] + + if v1 == v2: + result[k] = v1 + continue + + if isinstance(v1, list): + assert v2 not in v1, (v1, v2) + result[k] = sorted(v1 + [v2]) + continue + + if isinstance(v2, list): + assert v1 not in v2, (v1, v2) + result[k] = sorted(v2 + [v1]) + continue + + result[k] = sorted([v1, v2]) + + return result + + mars: dict[str, Any] = {} + other: DefaultDict[str, dict[str, Any]] = defaultdict(dict) + i: int = -1 + date: str | None = None + for c in cube.iterate_cubelets(): + + if date is None: + date = c._coords_names[0] + + if date != c._coords_names[0]: + continue + + if i == -1 or c._coords_names[1] != variables[i]: + i += 1 + + f = cube[c.coords] + md = f.metadata(namespace="mars") + if not md: + md = f.metadata(namespace="default") + + if md.get("param") == "~": + md["param"] = f.metadata("param") + assert md["param"] not in ("~", "unknown"), (md, f.metadata("param")) + + if md.get("param") == "unknown": + md["param"] = str(f.metadata("paramId", default="unknown")) + # assert md['param'] != 'unknown', (md, f.metadata('param')) + + startStep = f.metadata("startStep", default=None) + if startStep is not None: + startStep = as_timedelta(startStep) + + endStep = f.metadata("endStep", default=None) + if endStep is not None: + endStep = as_timedelta(endStep) + + stepTypeForConversion = f.metadata("stepTypeForConversion", default=None) + typeOfStatisticalProcessing = f.metadata("typeOfStatisticalProcessing", default=None) + timeRangeIndicator = f.metadata("timeRangeIndicator", default=None) + + # GRIB1 precipitation accumulations are not correctly encoded + if startStep == endStep and stepTypeForConversion == "accum": + endStep = f.metadata("P1") + startStep = f.metadata("P2") + + if startStep != endStep: + # https://codes.ecmwf.int/grib/format/grib2/ctables/4/10/ + TYPE_OF_STATISTICAL_PROCESSING: dict[int | None, str | None] = { + None: None, + 0: "average", + 1: "accumulation", + 2: "maximum", + 3: "minimum", + 4: "difference(end-start)", + 5: "root_mean_square", + 6: "standard_deviation", + 7: "covariance", + 8: "difference(start-end)", + 9: "ratio", + 10: "standardized_anomaly", + 11: "summation", + 100: "severity", + 101: "mode", + } + + # https://codes.ecmwf.int/grib/format/grib1/ctable/5/ + + TIME_RANGE_INDICATOR: dict[int, str] = { + 4: "accumulation", + 3: "average", + } + + STEP_TYPE_FOR_CONVERSION: dict[str, str] = 
{ + "min": "minimum", + "max": "maximum", + "accum": "accumulation", + } + + # + # A few patches + # + + PATCHES: dict[str, str] = { + "10fg6": "maximum", + "mntpr3": "minimum", # Not in param db + "mntpr6": "minimum", # Not in param db + "mxtpr3": "maximum", # Not in param db + "mxtpr6": "maximum", # Not in param db + } + + process = TYPE_OF_STATISTICAL_PROCESSING.get(typeOfStatisticalProcessing) + if process is None: + process = TIME_RANGE_INDICATOR.get(timeRangeIndicator) + if process is None: + process = STEP_TYPE_FOR_CONVERSION.get(stepTypeForConversion) + if process is None: + process = PATCHES.get(md["param"]) + if process is not None: + LOG.error(f"Unknown process {stepTypeForConversion} for {md['param']}, using {process} instead") + + if process is None: + raise ValueError( + f"Unknown for {md['param']}:" + f" {stepTypeForConversion=} ({STEP_TYPE_FOR_CONVERSION.get('stepTypeForConversion')})," + f" {typeOfStatisticalProcessing=} ({TYPE_OF_STATISTICAL_PROCESSING.get(typeOfStatisticalProcessing)})," + f" {timeRangeIndicator=} ({TIME_RANGE_INDICATOR.get(timeRangeIndicator)})" + ) + + # print(md["param"], "startStep", startStep, "endStep", endStep, "process", process, "typeOfStatisticalProcessing", typeOfStatisticalProcessing) + other[variables[i]]["process"] = process + other[variables[i]]["period"] = (startStep, endStep) + + for k in md.copy().keys(): + if k.startswith("_"): + md.pop(k) + + if variables[i] in mars: + mars[variables[i]] = _merge(md, mars[variables[i]]) + else: + mars[variables[i]] = md + + result: dict[str, dict[str, Any]] = {} + for k, v in mars.items(): + result[k] = dict(mars=v) if v else {} + result[k].update(other[k]) + result[k].update(KNOWN.get(k, {})) + # assert result[k], k + + assert i + 1 == len(variables), (i + 1, len(variables)) + return result + + +def _data_request(data: Any) -> dict[str, Any]: + """Build a data request dictionary from the given data. + + Parameters + ---------- + data : Any + The data to build the request from. + + Returns + ------- + dict + The data request dictionary. 
+ """ + date: Any | None = None + params_levels: DefaultDict[str, set] = defaultdict(set) + params_steps: DefaultDict[str, set] = defaultdict(set) + + area: Any | None = None + grid: Any | None = None + + for field in data: + try: + if date is None: + date = field.metadata("valid_datetime") + + if field.metadata("valid_datetime") != date: + continue + + as_mars = field.metadata(namespace="mars") + if not as_mars: + continue + step = as_mars.get("step") + levtype = as_mars.get("levtype", "sfc") + param = as_mars["param"] + levelist = as_mars.get("levelist", None) + area = field.mars_area + grid = field.mars_grid + + if levelist is None: + params_levels[levtype].add(param) + else: + params_levels[levtype].add((param, levelist)) + + if step: + params_steps[levtype].add((param, step)) + except Exception: + LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True) + + def sort(old_dic: DefaultDict[str, set]) -> dict[str, list[Any]]: + new_dic: dict[str, list[Any]] = {} + for k, v in old_dic.items(): + new_dic[k] = sorted(list(v)) + return new_dic + + params_steps = sort(params_steps) + params_levels = sort(params_levels) + + return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) + + +class ObservationsResult(Result): + + def __init__(self, context: Any, datasource: Any) -> None: + + pass diff --git a/src/anemoi/datasets/create/observations/__init__.py b/src/anemoi/datasets/create/observations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/anemoi/datasets/create/source.py b/src/anemoi/datasets/create/source.py index f79b0e9dd..8c9c3044d 100644 --- a/src/anemoi/datasets/create/source.py +++ b/src/anemoi/datasets/create/source.py @@ -49,3 +49,46 @@ def execute(self, dates: DateList) -> ekd.FieldList: """ pass + + +class FieldSource(Source): + """A source that returns a predefined FieldList.""" + + def __init__(self, context: any, data: ekd.FieldList, *args: tuple, **kwargs: dict): + """Initialise the FieldSource. + + Parameters + ---------- + context : Any + The context for the data source. + data : ekd.FieldList + The predefined data to return. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + """ + super().__init__(context, *args, **kwargs) + self.data = data + + def execute(self, dates: DateList) -> ekd.FieldList: + """Return the predefined FieldList. + + Parameters + ---------- + dates : DateList + The input dates (not used in this implementation). + + Returns + ------- + ekd.FieldList + The predefined data. + """ + self.context.trace(self.emoji, f"FieldSource returning {len(self.data)} fields") + return self.data + + +class ObservationsSource(Source): + """A source that retrieves observational data.""" + + pass diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py new file mode 100644 index 000000000..7cc38b56e --- /dev/null +++ b/src/anemoi/datasets/create/sources/csv.py @@ -0,0 +1,42 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from ..source import ObservationsSource +from . 
import source_registry + + +@source_registry.register("csv") +class CSVSource(ObservationsSource): + """A source that reads data from a CSV file.""" + + emoji = "📄" # For tracing + + def __init__(self, context: any, path: str, *args: tuple, **kwargs: dict): + """Initialise the CSVSource. + + Parameters + ---------- + context : Any + The context for the data source. + filepath : str + The path to the CSV file. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + """ + super().__init__(context, *args, **kwargs) + self.path = path + + def execute(self, dates): + import pandas as pd + + frame = pd.read_csv(self.path) + print(frame) diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index d5342d8e9..e5aab1bd6 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -213,7 +213,7 @@ def write(self, i, data, number_of_files_per_subdirectory=100, **kwargs): os.rename(tmp_path, out_path) def write_metadata(self, metadata): - from anemoi.datasets.create import json_tidy + from anemoi.datasets.create.fields import json_tidy os.makedirs(self.path, exist_ok=True) @@ -257,7 +257,7 @@ def write(self, i, data, **kwargs): ds.to_netcdf(out_path) def write_metadata(self, metadata): - from anemoi.datasets.create import json_tidy + from anemoi.datasets.create.fields import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: @@ -295,7 +295,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.create import json_tidy + from anemoi.datasets.create.fields import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: diff --git a/src/anemoi/datasets/data/records/windows.py b/src/anemoi/datasets/data/records/windows.py index bfa3950a7..43ee671d7 100644 --- a/src/anemoi/datasets/data/records/windows.py +++ b/src/anemoi/datasets/data/records/windows.py @@ -73,7 +73,7 @@ def _to_timedelta(t): ) -class AbsoluteWindow: +class Interval: # not used but expected to be useful when building datasets. 
And used in tests def __init__(self, start, end, include_start=True, include_end=True): assert isinstance(start, datetime.datetime), f"start must be a datetime.datetime, got {type(start)}" @@ -90,13 +90,66 @@ def __init__(self, start, end, include_start=True, include_end=True): def __repr__(self): return f"{'[' if self.include_start else '('}{self.start.isoformat()},{self.end.isoformat()}{']' if self.include_end else ')'}" + def intersection(self, other): + assert isinstance(other, Interval), f"`other` must be a Interval, got {type(other)}" + + if self._start_np > other._end_np or other._start_np > self._end_np: + return None # no intersection + + if self._start_np < other._start_np: + start = other._start_np + include_start = other.include_start + elif self._start_np > other._start_np: + start = self._start_np + include_start = self.include_start + else: # equal + start = self._start_np + include_start = self.include_start and other.include_start + + if self._end_np < other._end_np: + end = self._end_np + include_end = self.include_end + elif self._end_np > other._end_np: + end = other._end_np + include_end = other.include_end + else: # equal + end = self._end_np + include_end = self.include_end and other.include_end + + return Interval(start=start, end=end, include_start=include_start, include_end=include_end) + + def union(self, other): + assert isinstance(other, Interval), f"`other` must be a Interval, got {type(other)}" + + if self._start_np < other._start_np: + start = self._start_np + include_start = self.include_start + elif self._start_np > other._start_np: + start = other._start_np + include_start = other.include_start + else: # equal + start = self._start_np + include_start = self.include_start or other.include_start + + if self._end_np > other._end_np: + end = self._end_np + include_end = self.include_end + elif self._end_np < other._end_np: + end = other._end_np + include_end = other.include_end + else: # equal + end = self._end_np + include_end = self.include_end or other.include_end + + return Interval(start=start, end=end, include_start=include_start, include_end=include_end) + class WindowsSpec: # A window specified by relative timedeltas, such as (-6h, 0h] # # the term "WindowSpec" is used here to avoid confusion between # - a relative window, such as (-6h, 0h] which this class represents (WindowsSpec) - # - an actual time interval, such as [2023-01-01 00:00, 2023-01-01 06:00] which is an (AbsoluteWindow) + # - an actual time interval, such as [2023-01-01 00:00, 2023-01-01 06:00] which is an (Interval) # # but is is more confusing, it should be renamed as Window. @@ -107,8 +160,10 @@ def __init__(self, *, start, end, include_start=False, include_end=True): assert isinstance(include_end, bool), f"include_end must be a bool, got {type(include_end)}" assert include_start in (True, False), f"Invalid include_start {include_start}" # None is not allowed assert include_end in (True, False), f"Invalid include_end {include_end}" # None is not allowed + if start >= end: raise ValueError(f"start {start} must be less than end {end}") + self.start = start self.end = end self.include_start = include_start @@ -117,13 +172,13 @@ def __init__(self, *, start, end, include_start=False, include_end=True): self._start_np = _to_numpy_timedelta(start) self._end_np = _to_numpy_timedelta(end) - def to_absolute_window(self, date): + def to_interval(self, date): """Convert the window to an absolute window based on a date.""" # not used but expected to be useful when building datasets. 
And used in tests assert isinstance(date, datetime.datetime), f"date must be a datetime.datetime, got {type(date)}" start = date + self.start end = date + self.end - return AbsoluteWindow(start=start, end=end, include_start=self.include_start, include_end=self.include_end) + return Interval(start=start, end=end, include_start=self.include_start, include_end=self.include_end) def __repr__(self): first = "[" if self.include_start else "(" diff --git a/tests/create/test_observations.py b/tests/create/test_observations.py index 6bdf577ca..01410cf2f 100644 --- a/tests/create/test_observations.py +++ b/tests/create/test_observations.py @@ -14,7 +14,7 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import Interval from anemoi.datasets.data.records import window_from_str @@ -24,7 +24,7 @@ def __init__(self, data): self.data = data def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" if window.include_start: mask = self.data["times"] > window.start @@ -66,7 +66,7 @@ def __call__(self, df): filter = DummyFilter() for d in dates: - window = window_from_str("(-5h, 1h]").to_absolute_window(d) + window = window_from_str("(-5h, 1h]").to_interval(d) d = source(window) d = filter(d) print(window) diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py index c5823476b..b28e00708 100644 --- a/tests/create/test_observations_mars.py +++ b/tests/create/test_observations_mars.py @@ -16,7 +16,7 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import Interval from anemoi.datasets.data.records import window_from_str log = logging.getLogger(__name__) @@ -28,7 +28,7 @@ def __init__(self, data): self.data = data def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" if window.include_start: mask = self.data["times"] > window.start @@ -52,7 +52,7 @@ def __init__(self, request_dict, pre_process_dict, process_func): self.process_func = process_func def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" request_dict = self.request_dict request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" @@ -120,7 +120,7 @@ def __call__(self, df): filter = ColFilter("obsvalue_v10m_0") for d in dates: - window = window_from_str("(-5h, 1h]").to_absolute_window(d) + window = window_from_str("(-5h, 1h]").to_interval(d) print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) d = source(window) d = filter(d) diff --git a/tests/create/test_observations_mars_bufr.py b/tests/create/test_observations_mars_bufr.py index a22c7c3b9..747274af5 100644 --- a/tests/create/test_observations_mars_bufr.py +++ b/tests/create/test_observations_mars_bufr.py @@ -16,7 +16,7 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource 
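The renamed API is exercised the same way in every test touched here: a relative window spec is parsed with window_from_str() and converted to an absolute Interval for a given reference date. A minimal usage sketch, assuming the patched anemoi-datasets package is importable (the reference date below is an arbitrary placeholder):

import datetime

from anemoi.datasets.data.records import window_from_str

date = datetime.datetime(2023, 1, 1, 6, 0)

# "(-5h, 1h]" is a relative window: open at the start, closed at the end.
# to_interval() shifts it by the reference date and returns an absolute Interval.
interval = window_from_str("(-5h, 1h]").to_interval(date)

# The repr encodes inclusivity with bracket types, e.g.
# (2023-01-01T01:00:00,2023-01-01T07:00:00]
print(interval)
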
-from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import Interval from anemoi.datasets.data.records import window_from_str log = logging.getLogger(__name__) @@ -28,7 +28,7 @@ def __init__(self, data): self.data = data def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" if window.include_start: mask = self.data["times"] > window.start @@ -52,7 +52,7 @@ def __init__(self, request_dict, pre_process_dict, process_func): self.process_func = process_func def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" request_dict = self.request_dict request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" @@ -119,7 +119,7 @@ def __call__(self, df): filter = ColFilter("obsvalue_precip1h_0") for d in dates: - window = window_from_str("(-5h, 1h]").to_absolute_window(d) + window = window_from_str("(-5h, 1h]").to_interval(d) print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) d = source(window) d = filter(d) diff --git a/tests/create/test_observations_mars_bufr_complex.py b/tests/create/test_observations_mars_bufr_complex.py index ddb8afbac..2901e9cf6 100644 --- a/tests/create/test_observations_mars_bufr_complex.py +++ b/tests/create/test_observations_mars_bufr_complex.py @@ -16,7 +16,7 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import Interval from anemoi.datasets.data.records import window_from_str log = logging.getLogger(__name__) @@ -28,7 +28,7 @@ def __init__(self, data): self.data = data def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" if window.include_start: mask = self.data["times"] > window.start @@ -52,7 +52,7 @@ def __init__(self, request_dict, pre_process_dict, process_func): self.process_func = process_func def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" request_dict = self.request_dict request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" @@ -138,7 +138,7 @@ def __call__(self, df): filter = ColFilter("obsvalue_rawbt_9") for d in dates: - window = window_from_str("(-5h, 1h]").to_absolute_window(d) + window = window_from_str("(-5h, 1h]").to_interval(d) print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) d = source(window) d = filter(d) diff --git a/tests/create/test_observations_mars_bufr_parallel.py b/tests/create/test_observations_mars_bufr_parallel.py index d743efa8e..05c9397b3 100644 --- a/tests/create/test_observations_mars_bufr_parallel.py +++ b/tests/create/test_observations_mars_bufr_parallel.py @@ -16,7 +16,7 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import AbsoluteWindow +from anemoi.datasets.data.records import Interval from anemoi.datasets.data.records import window_from_str log = 
logging.getLogger(__name__) @@ -28,7 +28,7 @@ def __init__(self, data): self.data = data def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" if window.include_start: mask = self.data["times"] > window.start @@ -52,7 +52,7 @@ def __init__(self, request_dict, pre_process_dict, process_func): self.process_func = process_func def __call__(self, window): - assert isinstance(window, AbsoluteWindow), "window must be an AbsoluteWindow" + assert isinstance(window, Interval), "window must be an Interval" request_dict = self.request_dict request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" @@ -120,7 +120,7 @@ def __call__(self, df): filter = ColFilter("obsvalue_precip1h_0") for d in dates: - window = window_from_str("(-5h, 1h]").to_absolute_window(d) + window = window_from_str("(-5h, 1h]").to_interval(d) print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) d = source(window) d = filter(d) diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index 78dc09133..d5cc3585c 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -12,7 +12,7 @@ import yaml -from anemoi.datasets.create import creator_factory +from anemoi.datasets.create.fields import creator_factory class TestingContext: From 373792b6f4263832662c0170e926053b9b3043f3 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 18:27:33 +0000 Subject: [PATCH 135/212] update --- src/anemoi/datasets/create/fields/__init__.py | 16 ++++++++++------ .../context/field.py => fields/context.py} | 14 ++++++++------ .../result/field.py => fields/result.py} | 6 +++--- src/anemoi/datasets/create/input/__init__.py | 19 +++++++++---------- .../datasets/create/input/context/__init__.py | 3 +-- 5 files changed, 31 insertions(+), 27 deletions(-) rename src/anemoi/datasets/create/{input/context/field.py => fields/context.py} (91%) rename src/anemoi/datasets/create/{input/result/field.py => fields/result.py} (99%) diff --git a/src/anemoi/datasets/create/fields/__init__.py b/src/anemoi/datasets/create/fields/__init__.py index 9c83320b7..7b595406c 100644 --- a/src/anemoi/datasets/create/fields/__init__.py +++ b/src/anemoi/datasets/create/fields/__init__.py @@ -51,6 +51,7 @@ from ..statistics import fix_variance from ..utils import normalize_and_check_dates from ..writer import ViewCacheArray +from .context import FieldContext LOG = logging.getLogger(__name__) @@ -551,14 +552,17 @@ def create_elements(self, config: Any) -> None: self.output = build_output(config.output, parent=self) - self.input = InputBuilder( - config.input, - data_sources=config.get("data_sources", {}), + self.context = FieldContext( order_by=self.output.order_by, flatten_grid=self.output.flatten_grid, remapping=build_remapping(self.output.remapping), use_grib_paramid=config.build.use_grib_paramid, ) + + self.input = InputBuilder( + config.input, + data_sources=config.get("data_sources", {}), + ) LOG.debug("✅ INPUT_BUILDER") LOG.debug(self.input) @@ -625,11 +629,11 @@ def __init__( LOG.info(f"Groups: {self.groups}") - window = self.main_config.dates.get("window") + # window = self.main_config.dates.get("window") one_date = self.groups.one_date() - self.minimal_input = self.input.select(one_date, window) + self.minimal_input = self.input.select(self.context, one_date) LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") 
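The hunk above changes the calling convention: the keyword arguments that were previously forwarded to InputBuilder are now used to build a FieldContext once, and that context is passed explicitly to select(). A schematic sketch of the new convention, using placeholder literal values rather than values read from a real recipe:

from anemoi.datasets.create.fields.context import FieldContext

# Placeholder values; in the loader these come from the recipe's output and build sections.
context = FieldContext(
    order_by="valid_datetime",
    flatten_grid=True,
    remapping={},
    use_grib_paramid=False,
)

# The builder itself no longer takes these keyword arguments; selection becomes
#     result = input_builder.select(context, group_of_dates)
# and the context, not the builder, decides which Result class wraps the data.
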
LOG.info(self.minimal_input) @@ -869,7 +873,7 @@ def _run(self) -> None: # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(argument=group, window=None) + result = self.input.select(self.context, argument=group) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/fields/context.py similarity index 91% rename from src/anemoi/datasets/create/input/context/field.py rename to src/anemoi/datasets/create/fields/context.py index 8503e618a..f4face597 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/fields/context.py @@ -15,8 +15,7 @@ from anemoi.transform.fields import new_fieldlist_from_list from earthkit.data.core.order import build_remapping -from ..result.field import FieldResult -from . import Context +from anemoi.datasets.create.input.context import Context LOG = logging.getLogger(__name__) @@ -26,13 +25,14 @@ class FieldContext(Context): def __init__( self, /, - argument: Any, order_by: str, flatten_grid: bool, remapping: dict[str, Any], use_grib_paramid: bool, ) -> None: - super().__init__(argument) + + super().__init__() + self.order_by = order_by self.flatten_grid = flatten_grid self.remapping = build_remapping(remapping) @@ -50,8 +50,10 @@ def source_argument(self, argument: Any) -> Any: def filter_argument(self, argument: Any) -> Any: return argument - def create_result(self, data): - return FieldResult(self, data) + def create_result(self, argument, data): + from .result import FieldResult + + return FieldResult(self, argument, data) def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: from anemoi.datasets.dates.groups import GroupOfDates diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/fields/result.py similarity index 99% rename from src/anemoi/datasets/create/input/result/field.py rename to src/anemoi/datasets/create/fields/result.py index 7363ebf00..d4bcf58ea 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/fields/result.py @@ -22,7 +22,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from . 
import Result +from anemoi.datasets.create.input.result import Result LOG = logging.getLogger(__name__) @@ -282,13 +282,13 @@ class FieldResult(Result): empty: bool = False _coords_already_built: bool = False - def __init__(self, context: Any, datasource: Any) -> None: + def __init__(self, context: Any, argument: Any, datasource: Any) -> None: from anemoi.datasets.dates.groups import GroupOfDates self.context: Any = context self.datasource = datasource - self.group_of_dates = context.argument + self.group_of_dates = argument assert isinstance( self.group_of_dates, GroupOfDates ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index fc2083b6f..d29cbc2a1 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -11,14 +11,11 @@ from functools import cached_property from typing import Any -from anemoi.datasets.create.input.context.field import FieldContext -from anemoi.datasets.create.input.context.observations import ObservationContext - class InputBuilder: """Builder class for creating input data from configuration and data sources.""" - def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> None: + def __init__(self, config: dict, data_sources: dict | list) -> None: """Initialize the InputBuilder. Parameters @@ -30,7 +27,6 @@ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> No **kwargs : Any Additional keyword arguments. """ - self.kwargs = kwargs self.config = deepcopy(config) self.data_sources = deepcopy(dict(data_sources=data_sources)) @@ -45,15 +41,15 @@ def action(self) -> Any: return Recipe(input, sources) - def select(self, argument, window) -> Any: + def select(self, context, argument) -> Any: """Select data based on the group of dates. Parameters ---------- + context : Any + The context for the data selection. argument : GroupOfDates Group of dates to select data for. - window : str | None - Window specification. Returns ------- @@ -61,8 +57,11 @@ def select(self, argument, window) -> Any: Selected data. 
""" # TODO: move me elsewhere - context = ObservationContext(argument, **self.kwargs) if window else FieldContext(argument, **self.kwargs) - return context.create_result(self.action(context, argument)) + + return context.create_result( + argument, + self.action(context, argument), + ) def python_code(self, code): return self.action.python_code(code) diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context/__init__.py index 578ddaf66..28c797dd5 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/create/input/context/__init__.py @@ -18,10 +18,9 @@ class Context(ABC): """Context for building input data.""" - def __init__(self, /, argument: Any) -> None: + def __init__(self) -> None: self.results = {} self.cache = {} - self.argument = argument def trace(self, emoji, *message) -> None: From d58ae51f5014f0a69684d31a4c248bea9d8133e9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 18:29:12 +0000 Subject: [PATCH 136/212] refactor --- .../input/{context/__init__.py => context.py} | 0 .../create/input/context/observations.py | 68 ----- .../input/{result/__init__.py => result.py} | 0 .../create/input/result/observations.py | 275 ------------------ 4 files changed, 343 deletions(-) rename src/anemoi/datasets/create/input/{context/__init__.py => context.py} (100%) delete mode 100644 src/anemoi/datasets/create/input/context/observations.py rename src/anemoi/datasets/create/input/{result/__init__.py => result.py} (100%) delete mode 100644 src/anemoi/datasets/create/input/result/observations.py diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/create/input/context.py similarity index 100% rename from src/anemoi/datasets/create/input/context/__init__.py rename to src/anemoi/datasets/create/input/context.py diff --git a/src/anemoi/datasets/create/input/context/observations.py b/src/anemoi/datasets/create/input/context/observations.py deleted file mode 100644 index 9963652df..000000000 --- a/src/anemoi/datasets/create/input/context/observations.py +++ /dev/null @@ -1,68 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -import warnings -from typing import Any - -from anemoi.transform.fields import new_field_with_metadata -from anemoi.transform.fields import new_fieldlist_from_list - -from ..result.observations import ObservationsResult -from . 
import Context - -LOG = logging.getLogger(__name__) - - -class ObservationContext(Context): - - def __init__( - self, - /, - argument: Any, - **kwargs: Any, - ) -> None: - super().__init__(argument) - - def empty_result(self) -> Any: - return [] - - def source_argument(self, argument: Any) -> Any: - return argument # .dates - - def filter_argument(self, argument: Any) -> Any: - return argument - - def create_result(self, data): - return ObservationsResult(self, data) - - def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: - from anemoi.datasets.dates.groups import GroupOfDates - - return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) - - def origin(self, data: Any, action: Any, action_arguments: Any) -> Any: - warnings.warn("ObservationContext.origin is not implemented", UserWarning) - return data - - origin = action.origin() - - result = [] - for fs in data: - previous = fs.metadata("anemoi_origin", default=None) - fall_through = fs.metadata("anemoi_fall_through", default=False) - if fall_through: - # The field has pass unchanges in a filter - result.append(fs) - else: - anemoi_origin = origin.combine(previous, action, action_arguments) - result.append(new_field_with_metadata(fs, anemoi_origin=anemoi_origin)) - - return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/input/result/__init__.py b/src/anemoi/datasets/create/input/result.py similarity index 100% rename from src/anemoi/datasets/create/input/result/__init__.py rename to src/anemoi/datasets/create/input/result.py diff --git a/src/anemoi/datasets/create/input/result/observations.py b/src/anemoi/datasets/create/input/result/observations.py deleted file mode 100644 index fea3d81d0..000000000 --- a/src/anemoi/datasets/create/input/result/observations.py +++ /dev/null @@ -1,275 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from collections import defaultdict -from typing import Any -from typing import DefaultDict - -from anemoi.utils.dates import as_timedelta - -from . import Result - -LOG = logging.getLogger(__name__) - - -def _fields_metatata(variables: tuple[str, ...], cube: Any) -> dict[str, Any]: - """Retrieve metadata for the given variables and cube. - - Parameters - ---------- - variables : tuple of str - The variables to retrieve metadata for. - cube : Any - The data cube. - - Returns - ------- - dict - The metadata dictionary. 
- """ - assert isinstance(variables, tuple), variables - - KNOWN: dict[str, dict[str, bool]] = { - "cos_julian_day": dict(computed_forcing=True, constant_in_time=False), - "cos_latitude": dict(computed_forcing=True, constant_in_time=True), - "cos_local_time": dict(computed_forcing=True, constant_in_time=False), - "cos_longitude": dict(computed_forcing=True, constant_in_time=True), - "cos_solar_zenith_angle": dict(computed_forcing=True, constant_in_time=False), - "insolation": dict(computed_forcing=True, constant_in_time=False), - "latitude": dict(computed_forcing=True, constant_in_time=True), - "longitude": dict(computed_forcing=True, constant_in_time=True), - "sin_julian_day": dict(computed_forcing=True, constant_in_time=False), - "sin_latitude": dict(computed_forcing=True, constant_in_time=True), - "sin_local_time": dict(computed_forcing=True, constant_in_time=False), - "sin_longitude": dict(computed_forcing=True, constant_in_time=True), - } - - def _merge(md1: dict[str, Any], md2: dict[str, Any]) -> dict[str, Any]: - assert set(md1.keys()) == set(md2.keys()), (set(md1.keys()), set(md2.keys())) - result: dict[str, Any] = {} - for k in md1.keys(): - v1 = md1[k] - v2 = md2[k] - - if v1 == v2: - result[k] = v1 - continue - - if isinstance(v1, list): - assert v2 not in v1, (v1, v2) - result[k] = sorted(v1 + [v2]) - continue - - if isinstance(v2, list): - assert v1 not in v2, (v1, v2) - result[k] = sorted(v2 + [v1]) - continue - - result[k] = sorted([v1, v2]) - - return result - - mars: dict[str, Any] = {} - other: DefaultDict[str, dict[str, Any]] = defaultdict(dict) - i: int = -1 - date: str | None = None - for c in cube.iterate_cubelets(): - - if date is None: - date = c._coords_names[0] - - if date != c._coords_names[0]: - continue - - if i == -1 or c._coords_names[1] != variables[i]: - i += 1 - - f = cube[c.coords] - md = f.metadata(namespace="mars") - if not md: - md = f.metadata(namespace="default") - - if md.get("param") == "~": - md["param"] = f.metadata("param") - assert md["param"] not in ("~", "unknown"), (md, f.metadata("param")) - - if md.get("param") == "unknown": - md["param"] = str(f.metadata("paramId", default="unknown")) - # assert md['param'] != 'unknown', (md, f.metadata('param')) - - startStep = f.metadata("startStep", default=None) - if startStep is not None: - startStep = as_timedelta(startStep) - - endStep = f.metadata("endStep", default=None) - if endStep is not None: - endStep = as_timedelta(endStep) - - stepTypeForConversion = f.metadata("stepTypeForConversion", default=None) - typeOfStatisticalProcessing = f.metadata("typeOfStatisticalProcessing", default=None) - timeRangeIndicator = f.metadata("timeRangeIndicator", default=None) - - # GRIB1 precipitation accumulations are not correctly encoded - if startStep == endStep and stepTypeForConversion == "accum": - endStep = f.metadata("P1") - startStep = f.metadata("P2") - - if startStep != endStep: - # https://codes.ecmwf.int/grib/format/grib2/ctables/4/10/ - TYPE_OF_STATISTICAL_PROCESSING: dict[int | None, str | None] = { - None: None, - 0: "average", - 1: "accumulation", - 2: "maximum", - 3: "minimum", - 4: "difference(end-start)", - 5: "root_mean_square", - 6: "standard_deviation", - 7: "covariance", - 8: "difference(start-end)", - 9: "ratio", - 10: "standardized_anomaly", - 11: "summation", - 100: "severity", - 101: "mode", - } - - # https://codes.ecmwf.int/grib/format/grib1/ctable/5/ - - TIME_RANGE_INDICATOR: dict[int, str] = { - 4: "accumulation", - 3: "average", - } - - STEP_TYPE_FOR_CONVERSION: dict[str, str] = 
{ - "min": "minimum", - "max": "maximum", - "accum": "accumulation", - } - - # - # A few patches - # - - PATCHES: dict[str, str] = { - "10fg6": "maximum", - "mntpr3": "minimum", # Not in param db - "mntpr6": "minimum", # Not in param db - "mxtpr3": "maximum", # Not in param db - "mxtpr6": "maximum", # Not in param db - } - - process = TYPE_OF_STATISTICAL_PROCESSING.get(typeOfStatisticalProcessing) - if process is None: - process = TIME_RANGE_INDICATOR.get(timeRangeIndicator) - if process is None: - process = STEP_TYPE_FOR_CONVERSION.get(stepTypeForConversion) - if process is None: - process = PATCHES.get(md["param"]) - if process is not None: - LOG.error(f"Unknown process {stepTypeForConversion} for {md['param']}, using {process} instead") - - if process is None: - raise ValueError( - f"Unknown for {md['param']}:" - f" {stepTypeForConversion=} ({STEP_TYPE_FOR_CONVERSION.get('stepTypeForConversion')})," - f" {typeOfStatisticalProcessing=} ({TYPE_OF_STATISTICAL_PROCESSING.get(typeOfStatisticalProcessing)})," - f" {timeRangeIndicator=} ({TIME_RANGE_INDICATOR.get(timeRangeIndicator)})" - ) - - # print(md["param"], "startStep", startStep, "endStep", endStep, "process", process, "typeOfStatisticalProcessing", typeOfStatisticalProcessing) - other[variables[i]]["process"] = process - other[variables[i]]["period"] = (startStep, endStep) - - for k in md.copy().keys(): - if k.startswith("_"): - md.pop(k) - - if variables[i] in mars: - mars[variables[i]] = _merge(md, mars[variables[i]]) - else: - mars[variables[i]] = md - - result: dict[str, dict[str, Any]] = {} - for k, v in mars.items(): - result[k] = dict(mars=v) if v else {} - result[k].update(other[k]) - result[k].update(KNOWN.get(k, {})) - # assert result[k], k - - assert i + 1 == len(variables), (i + 1, len(variables)) - return result - - -def _data_request(data: Any) -> dict[str, Any]: - """Build a data request dictionary from the given data. - - Parameters - ---------- - data : Any - The data to build the request from. - - Returns - ------- - dict - The data request dictionary. 
- """ - date: Any | None = None - params_levels: DefaultDict[str, set] = defaultdict(set) - params_steps: DefaultDict[str, set] = defaultdict(set) - - area: Any | None = None - grid: Any | None = None - - for field in data: - try: - if date is None: - date = field.metadata("valid_datetime") - - if field.metadata("valid_datetime") != date: - continue - - as_mars = field.metadata(namespace="mars") - if not as_mars: - continue - step = as_mars.get("step") - levtype = as_mars.get("levtype", "sfc") - param = as_mars["param"] - levelist = as_mars.get("levelist", None) - area = field.mars_area - grid = field.mars_grid - - if levelist is None: - params_levels[levtype].add(param) - else: - params_levels[levtype].add((param, levelist)) - - if step: - params_steps[levtype].add((param, step)) - except Exception: - LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True) - - def sort(old_dic: DefaultDict[str, set]) -> dict[str, list[Any]]: - new_dic: dict[str, list[Any]] = {} - for k, v in old_dic.items(): - new_dic[k] = sorted(list(v)) - return new_dic - - params_steps = sort(params_steps) - params_levels = sort(params_levels) - - return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) - - -class ObservationsResult(Result): - - def __init__(self, context: Any, datasource: Any) -> None: - - pass From b193993eb641cb18b4bc90068cf48c47f7843549 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 18:35:30 +0000 Subject: [PATCH 137/212] absolute imports --- src/anemoi/datasets/__init__.py | 12 ++--- src/anemoi/datasets/__main__.py | 4 +- src/anemoi/datasets/commands/check.py | 3 +- src/anemoi/datasets/commands/cleanup.py | 3 +- src/anemoi/datasets/commands/compare-lam.py | 3 +- src/anemoi/datasets/commands/compare.py | 3 +- src/anemoi/datasets/commands/copy.py | 3 +- src/anemoi/datasets/commands/create.py | 2 +- .../datasets/commands/finalise-additions.py | 3 +- src/anemoi/datasets/commands/finalise.py | 3 +- src/anemoi/datasets/commands/grib-index.py | 2 +- .../datasets/commands/init-additions.py | 3 +- src/anemoi/datasets/commands/init.py | 3 +- src/anemoi/datasets/commands/inspect.py | 3 +- .../datasets/commands/load-additions.py | 3 +- src/anemoi/datasets/commands/load.py | 3 +- src/anemoi/datasets/commands/patch.py | 3 +- src/anemoi/datasets/commands/publish.py | 2 +- .../datasets/commands/recipe/__init__.py | 7 ++- src/anemoi/datasets/commands/recipe/format.py | 2 +- src/anemoi/datasets/commands/scan.py | 2 +- src/anemoi/datasets/commands/validate.py | 3 +- .../datasets/create/contexts/__init__.py | 0 src/anemoi/datasets/create/fields/__init__.py | 43 ++++++++--------- src/anemoi/datasets/create/fields/context.py | 2 +- src/anemoi/datasets/create/input/__init__.py | 4 +- src/anemoi/datasets/create/input/action.py | 6 +-- .../datasets/create/input/data_sources.py | 10 ++-- .../datasets/create/sources/accumulations.py | 7 ++- .../datasets/create/sources/accumulations2.py | 3 +- .../datasets/create/sources/anemoi_dataset.py | 2 +- .../datasets/create/sources/constants.py | 2 +- src/anemoi/datasets/create/sources/csv.py | 4 +- .../datasets/create/sources/eccc_fstd.py | 4 +- src/anemoi/datasets/create/sources/empty.py | 2 +- src/anemoi/datasets/create/sources/fdb.py | 5 +- .../datasets/create/sources/forcings.py | 2 +- src/anemoi/datasets/create/sources/grib.py | 2 +- .../datasets/create/sources/grib_index.py | 2 +- .../datasets/create/sources/hindcasts.py | 3 +- src/anemoi/datasets/create/sources/legacy.py | 4 +- 
src/anemoi/datasets/create/sources/mars.py | 3 +- src/anemoi/datasets/create/sources/netcdf.py | 4 +- src/anemoi/datasets/create/sources/opendap.py | 4 +- .../create/sources/planetary_computer.py | 4 +- .../datasets/create/sources/recentre.py | 5 +- src/anemoi/datasets/create/sources/source.py | 3 +- .../datasets/create/sources/tendencies.py | 3 +- src/anemoi/datasets/create/sources/xarray.py | 9 ++-- .../create/sources/xarray_kerchunk.py | 4 +- .../create/sources/xarray_support/__init__.py | 5 +- .../create/sources/xarray_support/field.py | 6 +-- .../sources/xarray_support/fieldlist.py | 12 ++--- .../create/sources/xarray_support/flavour.py | 38 +++++++-------- .../create/sources/xarray_support/metadata.py | 2 +- .../create/sources/xarray_support/time.py | 4 +- .../create/sources/xarray_support/variable.py | 2 +- .../datasets/create/sources/xarray_zarr.py | 4 +- src/anemoi/datasets/create/sources/zenodo.py | 6 +-- .../datasets/create/statistics/__init__.py | 4 +- .../datasets/create/statistics/summary.py | 6 +-- src/anemoi/datasets/data/__init__.py | 10 ++-- src/anemoi/datasets/data/complement.py | 24 +++++----- src/anemoi/datasets/data/concat.py | 30 ++++++------ src/anemoi/datasets/data/dataset.py | 48 +++++++++---------- src/anemoi/datasets/data/debug.py | 2 +- src/anemoi/datasets/data/ensemble.py | 22 ++++----- src/anemoi/datasets/data/fill_missing.py | 21 ++++---- src/anemoi/datasets/data/forwards.py | 20 ++++---- src/anemoi/datasets/data/grids.py | 30 ++++++------ src/anemoi/datasets/data/indexing.py | 6 +-- src/anemoi/datasets/data/interpolate.py | 24 +++++----- src/anemoi/datasets/data/join.py | 30 ++++++------ src/anemoi/datasets/data/masked.py | 26 +++++----- src/anemoi/datasets/data/merge.py | 26 +++++----- src/anemoi/datasets/data/misc.py | 36 +++++++------- src/anemoi/datasets/data/missing.py | 17 ++++--- .../datasets/data/observations/__init__.py | 7 ++- src/anemoi/datasets/data/records/__init__.py | 3 +- src/anemoi/datasets/data/rescale.py | 20 ++++---- src/anemoi/datasets/data/select.py | 24 +++++----- src/anemoi/datasets/data/statistics.py | 8 ++-- src/anemoi/datasets/data/stores.py | 22 ++++----- src/anemoi/datasets/data/subset.py | 30 ++++++------ src/anemoi/datasets/data/unchecked.py | 16 +++---- src/anemoi/datasets/data/xy.py | 12 ++--- src/anemoi/datasets/recipe.py | 2 +- 87 files changed, 398 insertions(+), 428 deletions(-) delete mode 100644 src/anemoi/datasets/create/contexts/__init__.py diff --git a/src/anemoi/datasets/__init__.py b/src/anemoi/datasets/__init__.py index fe6ca61f1..620f5e80f 100644 --- a/src/anemoi/datasets/__init__.py +++ b/src/anemoi/datasets/__init__.py @@ -8,16 +8,16 @@ # nor does it submit to any jurisdiction. 
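The rest of this patch applies the same mechanical rewrite throughout the package: every relative import is replaced by the equivalent absolute import rooted at anemoi.datasets. The pattern, taken from the first hunk below:

# Relative form removed by this patch:
# from .data import open_dataset
# Absolute form added in its place:
from anemoi.datasets.data import open_dataset
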
-from .data import MissingDateError -from .data import add_dataset_path -from .data import add_named_dataset -from .data import list_dataset_names -from .data import open_dataset +from anemoi.datasets.data import MissingDateError +from anemoi.datasets.data import add_dataset_path +from anemoi.datasets.data import add_named_dataset +from anemoi.datasets.data import list_dataset_names +from anemoi.datasets.data import open_dataset try: # NOTE: the `_version.py` file must not be present in the git repository # as it is generated by setuptools at install time - from ._version import __version__ # type: ignore + from anemoi.datasets._version import __version__ # type: ignore except ImportError: # pragma: no cover # Local copy or not installed with setuptools __version__ = "999" diff --git a/src/anemoi/datasets/__main__.py b/src/anemoi/datasets/__main__.py index 62b7d7c73..f47c46050 100644 --- a/src/anemoi/datasets/__main__.py +++ b/src/anemoi/datasets/__main__.py @@ -12,8 +12,8 @@ from anemoi.utils.cli import cli_main from anemoi.utils.cli import make_parser -from . import __version__ -from .commands import COMMANDS +from anemoi.datasets import __version__ +from anemoi.datasets.commands import COMMANDS # For read-the-docs diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index 61b29bf23..4202ed09f 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -13,10 +13,9 @@ import yaml +from anemoi.datasets.commands import Command from anemoi.datasets.create.check import DatasetName -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/cleanup.py b/src/anemoi/datasets/commands/cleanup.py index 0b3a393bd..25b5b9ca0 100644 --- a/src/anemoi/datasets/commands/cleanup.py +++ b/src/anemoi/datasets/commands/cleanup.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/compare-lam.py b/src/anemoi/datasets/commands/compare-lam.py index 74d97bb48..92ea9a6af 100644 --- a/src/anemoi/datasets/commands/compare-lam.py +++ b/src/anemoi/datasets/commands/compare-lam.py @@ -12,8 +12,7 @@ import os from anemoi.datasets import open_dataset - -from . import Command +from anemoi.datasets.commands import Command RADIUS_EARTH_KM = 6371.0 # Earth's radius in kilometers diff --git a/src/anemoi/datasets/commands/compare.py b/src/anemoi/datasets/commands/compare.py index ffe1ec02e..bbd121bd1 100644 --- a/src/anemoi/datasets/commands/compare.py +++ b/src/anemoi/datasets/commands/compare.py @@ -15,8 +15,7 @@ import zarr from anemoi.datasets import open_dataset - -from . import Command +from anemoi.datasets.commands import Command class Compare(Command): diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 406c13de7..9628bae8e 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -20,8 +20,7 @@ from anemoi.utils.remote import TransferMethodNotImplementedError from anemoi.datasets.check import check_zarr - -from . 
import Command +from anemoi.datasets.commands import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 0fc7d04f1..45af78a44 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -18,7 +18,7 @@ import tqdm from anemoi.utils.humanize import seconds_to_human -from . import Command +from anemoi.datasets.commands import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/finalise-additions.py b/src/anemoi/datasets/commands/finalise-additions.py index 811760ae9..25380fbba 100644 --- a/src/anemoi/datasets/commands/finalise-additions.py +++ b/src/anemoi/datasets/commands/finalise-additions.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/finalise.py b/src/anemoi/datasets/commands/finalise.py index 53999ad50..5197fb73c 100644 --- a/src/anemoi/datasets/commands/finalise.py +++ b/src/anemoi/datasets/commands/finalise.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/grib-index.py b/src/anemoi/datasets/commands/grib-index.py index cfd7a08e8..b5cc910d2 100644 --- a/src/anemoi/datasets/commands/grib-index.py +++ b/src/anemoi/datasets/commands/grib-index.py @@ -13,7 +13,7 @@ import tqdm -from . import Command +from anemoi.datasets.commands import Command class GribIndexCmd(Command): diff --git a/src/anemoi/datasets/commands/init-additions.py b/src/anemoi/datasets/commands/init-additions.py index 09788f0e4..c49bbf76c 100644 --- a/src/anemoi/datasets/commands/init-additions.py +++ b/src/anemoi/datasets/commands/init-additions.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/init.py b/src/anemoi/datasets/commands/init.py index 0ca540b86..c5aa515fd 100644 --- a/src/anemoi/datasets/commands/init.py +++ b/src/anemoi/datasets/commands/init.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index ad26c454e..71376ea56 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -27,11 +27,10 @@ from numpy.typing import NDArray from anemoi.datasets import open_dataset +from anemoi.datasets.commands import Command from anemoi.datasets.data.stores import open_zarr from anemoi.datasets.data.stores import zarr_lookup -from . 
import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/load-additions.py b/src/anemoi/datasets/commands/load-additions.py index a8cd5d5c9..82dec8b0a 100644 --- a/src/anemoi/datasets/commands/load-additions.py +++ b/src/anemoi/datasets/commands/load-additions.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/load.py b/src/anemoi/datasets/commands/load.py index 3d969f5c3..7b1c1f684 100644 --- a/src/anemoi/datasets/commands/load.py +++ b/src/anemoi/datasets/commands/load.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/patch.py b/src/anemoi/datasets/commands/patch.py index dc8356126..1920fc420 100644 --- a/src/anemoi/datasets/commands/patch.py +++ b/src/anemoi/datasets/commands/patch.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/publish.py b/src/anemoi/datasets/commands/publish.py index 7f719543e..47282e65b 100644 --- a/src/anemoi/datasets/commands/publish.py +++ b/src/anemoi/datasets/commands/publish.py @@ -10,7 +10,7 @@ import logging from typing import Any -from . import Command +from anemoi.datasets.commands import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 9fe7ec3ff..26f0d486f 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,13 +15,12 @@ import yaml +from anemoi.datasets.commands import Command +from anemoi.datasets.commands.recipe.format import format_recipe +from anemoi.datasets.commands.recipe.migrate import migrate_recipe from anemoi.datasets.create.fields import config_to_python from anemoi.datasets.create.fields import validate_config -from .. import Command -from .format import format_recipe -from .migrate import migrate_recipe - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/format.py b/src/anemoi/datasets/commands/recipe/format.py index 533a569c1..a291d9573 100644 --- a/src/anemoi/datasets/commands/recipe/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -11,7 +11,7 @@ import datetime import logging -from ...dumper import yaml_dump +from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/scan.py b/src/anemoi/datasets/commands/scan.py index 8a048125e..37c8d0cfd 100644 --- a/src/anemoi/datasets/commands/scan.py +++ b/src/anemoi/datasets/commands/scan.py @@ -17,7 +17,7 @@ import tqdm import yaml -from . 
import Command +from anemoi.datasets.commands import Command KEYS = ("class", "type", "stream", "expver", "levtype", "domain") diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py index 1382814a7..03691541c 100644 --- a/src/anemoi/datasets/commands/validate.py +++ b/src/anemoi/datasets/commands/validate.py @@ -10,10 +10,9 @@ import logging from typing import Any +from anemoi.datasets.commands import Command from anemoi.datasets.validate import validate_dataset -from . import Command - LOG = logging.getLogger(__name__) DEFAULT_DATASET = "aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8" diff --git a/src/anemoi/datasets/create/contexts/__init__.py b/src/anemoi/datasets/create/contexts/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/anemoi/datasets/create/fields/__init__.py b/src/anemoi/datasets/create/fields/__init__.py index 7b595406c..6301d45ee 100644 --- a/src/anemoi/datasets/create/fields/__init__.py +++ b/src/anemoi/datasets/create/fields/__init__.py @@ -31,28 +31,27 @@ from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset +from anemoi.datasets.create.check import DatasetName +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.create.config import build_output +from anemoi.datasets.create.config import loader_config +from anemoi.datasets.create.fields.context import FieldContext +from anemoi.datasets.create.input import InputBuilder from anemoi.datasets.create.input.trace import enable_trace from anemoi.datasets.create.persistent import build_storage +from anemoi.datasets.create.statistics import Summary +from anemoi.datasets.create.statistics import TmpStatistics +from anemoi.datasets.create.statistics import check_variance +from anemoi.datasets.create.statistics import compute_statistics +from anemoi.datasets.create.statistics import default_statistics_dates +from anemoi.datasets.create.statistics import fix_variance +from anemoi.datasets.create.utils import normalize_and_check_dates +from anemoi.datasets.create.writer import ViewCacheArray from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups -from ..check import DatasetName -from ..check import check_data_values -from ..chunks import ChunkFilter -from ..config import build_output -from ..config import loader_config -from ..input import InputBuilder -from ..statistics import Summary -from ..statistics import TmpStatistics -from ..statistics import check_variance -from ..statistics import compute_statistics -from ..statistics import default_statistics_dates -from ..statistics import fix_variance -from ..utils import normalize_and_check_dates -from ..writer import ViewCacheArray -from .context import FieldContext - LOG = logging.getLogger(__name__) VERSION = "0.30" @@ -194,7 +193,7 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: import zarr z = zarr.open(self.path, mode=mode) - from ..zarr import add_zarr_dataset + from anemoi.datasets.create.zarr import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -398,7 +397,7 @@ def _cache_context(self) -> Any: Any The cache context. 
""" - from ..utils import cache_context + from anemoi.datasets.create.utils import cache_context return cache_context(self.cache) @@ -474,7 +473,7 @@ def __init__(self, path: str, options: dict = None, **kwargs: Any): def run(self) -> None: """Run the patch.""" - from ..patch import apply_patch + from anemoi.datasets.create.patch import apply_patch apply_patch(self.path, **self.options) @@ -494,7 +493,7 @@ def __init__(self, path: str, **kwargs: Any): def run(self) -> None: """Run the size computation.""" - from ..size import compute_directory_sizes + from anemoi.datasets.create.size import compute_directory_sizes metadata = compute_directory_sizes(self.path) self.update_metadata(**metadata) @@ -516,7 +515,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from ..zarr import ZarrBuiltRegistry + from anemoi.datasets.create.zarr import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) @@ -1670,7 +1669,7 @@ def _tidy(d): def config_to_python(config: Any) -> Any: - from ..create.python import PythonScript + from anemoi.datasets.create.create.python import PythonScript raw_config = config diff --git a/src/anemoi/datasets/create/fields/context.py b/src/anemoi/datasets/create/fields/context.py index f4face597..ef3ebeca5 100644 --- a/src/anemoi/datasets/create/fields/context.py +++ b/src/anemoi/datasets/create/fields/context.py @@ -51,7 +51,7 @@ def filter_argument(self, argument: Any) -> Any: return argument def create_result(self, argument, data): - from .result import FieldResult + from anemoi.datasets.create.fields.result import FieldResult return FieldResult(self, argument, data) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index d29cbc2a1..f56bbd067 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -33,8 +33,8 @@ def __init__(self, config: dict, data_sources: dict | list) -> None: @cached_property def action(self) -> Any: """Returns the action object based on the configuration.""" - from .action import Recipe - from .action import action_factory + from anemoi.datasets.create.input.action import Recipe + from anemoi.datasets.create.input.action import action_factory sources = action_factory(self.data_sources, "data_sources") input = action_factory(self.config, "input") diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 5928f5301..831456435 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -154,7 +154,7 @@ def call_object(self, context, source, argument): return context.origin(result, self, argument) def origin(self): - from .origin import Source + from anemoi.datasets.create.input.origin import Source return Source(self.path[-1], self.config) @@ -170,7 +170,7 @@ def combine_origins(self, current, previous): return current def origin(self): - from .origin import Source + from anemoi.datasets.create.input.origin import Source return Source(self.path[-1], self.config) @@ -186,7 +186,7 @@ def call_object(self, context, filter, argument): return context.origin(result, self, argument) def origin(self): - from .origin import Filter + from anemoi.datasets.create.input.origin import Filter return Filter(self.path[-1], self.config) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 9aa2429dd..7a706c8ef 100644 --- 
a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -13,11 +13,11 @@ from earthkit.data import FieldList -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .misc import _tidy -from .result.field import Result +from anemoi.datasets.create.input.action import Action +from anemoi.datasets.create.input.action import action_factory +from anemoi.datasets.create.input.misc import _tidy +from anemoi.datasets.create.input.result.field import Result +from anemoi.datasets.dates.groups import GroupOfDates LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/accumulations.py b/src/anemoi/datasets/create/sources/accumulations.py index 6acecbf98..40b8749f6 100644 --- a/src/anemoi/datasets/create/sources/accumulations.py +++ b/src/anemoi/datasets/create/sources/accumulations.py @@ -20,11 +20,10 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.mars import mars from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source -from .mars import mars - LOG = logging.getLogger(__name__) @@ -994,7 +993,7 @@ def accumulations( and request.get("stream", "oper") == "oper" and request.get("accumulation_period") == 24 ): - from .accumulations2 import accumulations as accumulations2 + from anemoi.datasets.create.sources.accumulations2 import accumulations as accumulations2 LOG.warning( "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" diff --git a/src/anemoi/datasets/create/sources/accumulations2.py b/src/anemoi/datasets/create/sources/accumulations2.py index f9ddf3b3a..3c34d392e 100644 --- a/src/anemoi/datasets/create/sources/accumulations2.py +++ b/src/anemoi/datasets/create/sources/accumulations2.py @@ -18,11 +18,10 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.mars import mars from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/anemoi_dataset.py b/src/anemoi/datasets/create/sources/anemoi_dataset.py index 12d41db23..a05e7df51 100644 --- a/src/anemoi/datasets/create/sources/anemoi_dataset.py +++ b/src/anemoi/datasets/create/sources/anemoi_dataset.py @@ -9,7 +9,7 @@ import numpy as np -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index 104f24863..accde7936 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -11,7 +11,7 @@ from earthkit.data import from_source -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py index 7cc38b56e..0b293845e 100644 --- a/src/anemoi/datasets/create/sources/csv.py +++ b/src/anemoi/datasets/create/sources/csv.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. 
-from ..source import ObservationsSource -from . import source_registry +from anemoi.datasets.create.source import ObservationsSource +from anemoi.datasets.create.sources import source_registry @source_registry.register("csv") diff --git a/src/anemoi/datasets/create/sources/eccc_fstd.py b/src/anemoi/datasets/create/sources/eccc_fstd.py index 41734e9b6..fdd79af8d 100644 --- a/src/anemoi/datasets/create/sources/eccc_fstd.py +++ b/src/anemoi/datasets/create/sources/eccc_fstd.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("eccc_fstd") diff --git a/src/anemoi/datasets/create/sources/empty.py b/src/anemoi/datasets/create/sources/empty.py index fb7fcd906..f948810f5 100644 --- a/src/anemoi/datasets/create/sources/empty.py +++ b/src/anemoi/datasets/create/sources/empty.py @@ -12,7 +12,7 @@ import earthkit.data as ekd -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/fdb.py b/src/anemoi/datasets/create/sources/fdb.py index bb33f7d50..81cdb7e13 100644 --- a/src/anemoi/datasets/create/sources/fdb.py +++ b/src/anemoi/datasets/create/sources/fdb.py @@ -16,11 +16,10 @@ from anemoi.transform.flavour import RuleBasedFlavour from anemoi.transform.grids import grid_registry +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry from anemoi.datasets.create.typing import DateList -from ..source import Source -from . import source_registry - @source_registry.register("fdb") class FdbSource(Source): diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index bbafaa465..88eca92e4 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -11,7 +11,7 @@ from earthkit.data import from_source -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/grib.py b/src/anemoi/datasets/create/sources/grib.py index 66134e86c..550709f98 100644 --- a/src/anemoi/datasets/create/sources/grib.py +++ b/src/anemoi/datasets/create/sources/grib.py @@ -20,7 +20,7 @@ from earthkit.data import from_source from earthkit.data.utils.patterns import Pattern -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/grib_index.py b/src/anemoi/datasets/create/sources/grib_index.py index ea6878929..160ff3f3a 100644 --- a/src/anemoi/datasets/create/sources/grib_index.py +++ b/src/anemoi/datasets/create/sources/grib_index.py @@ -19,7 +19,7 @@ from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/hindcasts.py b/src/anemoi/datasets/create/sources/hindcasts.py index 9c470218c..d796a74af 100644 --- a/src/anemoi/datasets/create/sources/hindcasts.py +++ b/src/anemoi/datasets/create/sources/hindcasts.py @@ -12,10 +12,9 @@ from earthkit.data.core.fieldlist import MultiFieldList +from 
anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.mars import mars -from .legacy import legacy_source - LOGGER = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index d7a15bfe7..0de230d29 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -14,8 +14,8 @@ from collections.abc import Callable from typing import Any -from ..source import Source -from . import source_registry +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/mars.py b/src/anemoi/datasets/create/sources/mars.py index 1a419f691..d59f6034d 100644 --- a/src/anemoi/datasets/create/sources/mars.py +++ b/src/anemoi/datasets/create/sources/mars.py @@ -16,10 +16,9 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source - DEBUG = False diff --git a/src/anemoi/datasets/create/sources/netcdf.py b/src/anemoi/datasets/create/sources/netcdf.py index a73c095d3..606a8dd53 100644 --- a/src/anemoi/datasets/create/sources/netcdf.py +++ b/src/anemoi/datasets/create/sources/netcdf.py @@ -12,8 +12,8 @@ import earthkit.data as ekd -from .legacy import legacy_source -from .xarray import load_many +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/opendap.py b/src/anemoi/datasets/create/sources/opendap.py index 483295a8b..34e3fe94d 100644 --- a/src/anemoi/datasets/create/sources/opendap.py +++ b/src/anemoi/datasets/create/sources/opendap.py @@ -12,8 +12,8 @@ import earthkit.data as ekd -from .legacy import legacy_source -from .xarray import load_many +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py index b710bcbbe..07e8f0203 100644 --- a/src/anemoi/datasets/create/sources/planetary_computer.py +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . 
import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("planetary_computer") diff --git a/src/anemoi/datasets/create/sources/recentre.py b/src/anemoi/datasets/create/sources/recentre.py index 53ace8152..d0959f664 100644 --- a/src/anemoi/datasets/create/sources/recentre.py +++ b/src/anemoi/datasets/create/sources/recentre.py @@ -11,9 +11,8 @@ from typing import Any from anemoi.datasets.compute.recentre import recentre as _recentre - -from .legacy import legacy_source -from .mars import mars +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.mars import mars def to_list(x: list | tuple | str) -> list: diff --git a/src/anemoi/datasets/create/sources/source.py b/src/anemoi/datasets/create/sources/source.py index 0db02e6db..1bac545d8 100644 --- a/src/anemoi/datasets/create/sources/source.py +++ b/src/anemoi/datasets/create/sources/source.py @@ -12,10 +12,9 @@ from earthkit.data import from_source +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source - @legacy_source(__file__) def source(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any: diff --git a/src/anemoi/datasets/create/sources/tendencies.py b/src/anemoi/datasets/create/sources/tendencies.py index 01c4d1bda..222dca9a4 100644 --- a/src/anemoi/datasets/create/sources/tendencies.py +++ b/src/anemoi/datasets/create/sources/tendencies.py @@ -14,10 +14,9 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source - def _date_to_datetime(d: Any) -> Any: """Converts a date string or a list/tuple of date strings to datetime objects. diff --git a/src/anemoi/datasets/create/sources/xarray.py b/src/anemoi/datasets/create/sources/xarray.py index d63b708d6..5e3cc4c10 100644 --- a/src/anemoi/datasets/create/sources/xarray.py +++ b/src/anemoi/datasets/create/sources/xarray.py @@ -11,13 +11,12 @@ import earthkit.data as ekd +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources.xarray_support import XarrayFieldList +from anemoi.datasets.create.sources.xarray_support import load_many +from anemoi.datasets.create.sources.xarray_support import load_one from anemoi.datasets.create.typing import DateList -from ..source import Source -from .xarray_support import XarrayFieldList -from .xarray_support import load_many -from .xarray_support import load_one - __all__ = ["load_many", "load_one", "XarrayFieldList"] diff --git a/src/anemoi/datasets/create/sources/xarray_kerchunk.py b/src/anemoi/datasets/create/sources/xarray_kerchunk.py index 056d756ca..632a7cae2 100644 --- a/src/anemoi/datasets/create/sources/xarray_kerchunk.py +++ b/src/anemoi/datasets/create/sources/xarray_kerchunk.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . 
import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("xarray_kerchunk") diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index 33a057520..c33ce7bfc 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -15,10 +15,9 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.patterns import iterate_patterns - -from ..legacy import legacy_source -from .fieldlist import XarrayFieldList +from anemoi.datasets.create.sources.xarray_support.fieldlist import XarrayFieldList LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 78f7de041..85f9970f8 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -17,9 +17,9 @@ from earthkit.data.core.fieldlist import math from numpy.typing import NDArray -from .coordinates import extract_single_value -from .coordinates import is_scalar -from .metadata import XArrayMetadata +from anemoi.datasets.create.sources.xarray_support.coordinates import extract_single_value +from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar +from anemoi.datasets.create.sources.xarray_support.metadata import XArrayMetadata LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py index 48f9cf0e1..174cb2716 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py +++ b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py @@ -16,12 +16,12 @@ import yaml from earthkit.data import FieldList -from .field import EmptyFieldList -from .flavour import CoordinateGuesser -from .patch import patch_dataset -from .time import Time -from .variable import FilteredVariable -from .variable import Variable +from anemoi.datasets.create.sources.xarray_support.field import EmptyFieldList +from anemoi.datasets.create.sources.xarray_support.flavour import CoordinateGuesser +from anemoi.datasets.create.sources.xarray_support.patch import patch_dataset +from anemoi.datasets.create.sources.xarray_support.time import Time +from anemoi.datasets.create.sources.xarray_support.variable import FilteredVariable +from anemoi.datasets.create.sources.xarray_support.variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 80f0b6a62..74fcdbd03 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -17,25 +17,25 @@ import xarray as xr from anemoi.utils.config import DotDict -from .coordinates import Coordinate -from .coordinates import DateCoordinate -from .coordinates import EnsembleCoordinate -from .coordinates import LatitudeCoordinate -from .coordinates import LevelCoordinate -from .coordinates import LongitudeCoordinate -from .coordinates import PointCoordinate 
-from .coordinates import ScalarCoordinate -from .coordinates import StepCoordinate -from .coordinates import TimeCoordinate -from .coordinates import UnsupportedCoordinate -from .coordinates import XCoordinate -from .coordinates import YCoordinate -from .coordinates import is_scalar -from .grid import Grid -from .grid import MeshedGrid -from .grid import MeshProjectionGrid -from .grid import UnstructuredGrid -from .grid import UnstructuredProjectionGrid +from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import PointCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar +from anemoi.datasets.create.sources.xarray_support.grid import Grid +from anemoi.datasets.create.sources.xarray_support.grid import MeshedGrid +from anemoi.datasets.create.sources.xarray_support.grid import MeshProjectionGrid +from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredGrid +from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredProjectionGrid LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/sources/xarray_support/metadata.py index 23713ae74..2230db3ef 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/metadata.py +++ b/src/anemoi/datasets/create/sources/xarray_support/metadata.py @@ -46,7 +46,7 @@ def __init__(self, field: Any) -> None: field : Any The field to extract metadata from. 
""" - from .field import XArrayField + from anemoi.datasets.create.sources.xarray_support.field import XArrayField assert isinstance(field, XArrayField), type(field) self._field = field diff --git a/src/anemoi/datasets/create/sources/xarray_support/time.py b/src/anemoi/datasets/create/sources/xarray_support/time.py index 847b21598..7b1f60e58 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/time.py +++ b/src/anemoi/datasets/create/sources/xarray_support/time.py @@ -16,8 +16,8 @@ from anemoi.utils.dates import as_datetime -from .coordinates import Coordinate -from .variable import Variable +from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate +from anemoi.datasets.create.sources.xarray_support.variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/variable.py b/src/anemoi/datasets/create/sources/xarray_support/variable.py index 5d2c1c5b1..13d6fa4e2 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/create/sources/xarray_support/variable.py @@ -17,7 +17,7 @@ import numpy as np import xarray as xr -from .field import XArrayField +from anemoi.datasets.create.sources.xarray_support.field import XArrayField LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_zarr.py b/src/anemoi/datasets/create/sources/xarray_zarr.py index e91de781e..2f96ab207 100644 --- a/src/anemoi/datasets/create/sources/xarray_zarr.py +++ b/src/anemoi/datasets/create/sources/xarray_zarr.py @@ -11,8 +11,8 @@ import earthkit.data as ekd -from .legacy import legacy_source -from .xarray import load_many +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/zenodo.py b/src/anemoi/datasets/create/sources/zenodo.py index 1b746bb42..e23b8fa47 100644 --- a/src/anemoi/datasets/create/sources/zenodo.py +++ b/src/anemoi/datasets/create/sources/zenodo.py @@ -14,9 +14,9 @@ from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.sources.url import download_and_cache -from .legacy import legacy_source -from .patterns import iterate_patterns -from .xarray import load_one +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.patterns import iterate_patterns +from anemoi.datasets.create.sources.xarray import load_one @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/statistics/__init__.py b/src/anemoi/datasets/create/statistics/__init__.py index f74cbf364..e8e71c45a 100644 --- a/src/anemoi/datasets/create/statistics/__init__.py +++ b/src/anemoi/datasets/create/statistics/__init__.py @@ -23,8 +23,8 @@ from anemoi.utils.provenance import gather_provenance_info from numpy.typing import NDArray -from ..check import check_data_values -from .summary import Summary +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.statistics.summary import Summary LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/statistics/summary.py b/src/anemoi/datasets/create/statistics/summary.py index 6c7bbb433..8b6c29eb0 100644 --- a/src/anemoi/datasets/create/statistics/summary.py +++ b/src/anemoi/datasets/create/statistics/summary.py @@ -13,9 +13,9 @@ import numpy as np -from ..check import StatisticsValueError -from ..check import check_data_values -from ..check import check_stats +from 
anemoi.datasets.create.check import StatisticsValueError +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.check import check_stats class Summary(dict): diff --git a/src/anemoi/datasets/data/__init__.py b/src/anemoi/datasets/data/__init__.py index f32d83bb2..fc2b0839b 100644 --- a/src/anemoi/datasets/data/__init__.py +++ b/src/anemoi/datasets/data/__init__.py @@ -15,13 +15,13 @@ # from .dataset import FullIndex # from .dataset import Shape # from .dataset import TupleIndex -from .misc import _open_dataset -from .misc import _save_dataset -from .misc import add_dataset_path -from .misc import add_named_dataset +from anemoi.datasets.data.misc import _open_dataset +from anemoi.datasets.data.misc import _save_dataset +from anemoi.datasets.data.misc import add_dataset_path +from anemoi.datasets.data.misc import add_named_dataset if TYPE_CHECKING: - from .dataset import Dataset + from anemoi.datasets.data.dataset import Dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/complement.py b/src/anemoi/datasets/data/complement.py index 7f6f4484e..e8698b833 100644 --- a/src/anemoi/datasets/data/complement.py +++ b/src/anemoi/datasets/data/complement.py @@ -16,18 +16,18 @@ import numpy as np from numpy.typing import NDArray -from ..grids import nearest_grid_points -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .forwards import Combined -from .indexing import apply_index_to_slices_changes -from .indexing import index_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open_dataset +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open_dataset +from anemoi.datasets.grids import nearest_grid_points LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/concat.py b/src/anemoi/datasets/data/concat.py index 4398c15eb..4afffc04f 100644 --- a/src/anemoi/datasets/data/concat.py +++ b/src/anemoi/datasets/data/concat.py @@ -16,20 +16,20 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Combined -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import length_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Combined +from 
anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import length_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) @@ -229,7 +229,7 @@ def check_dataset_compatibility(cls, datasets: list[Any], fill_missing_gaps: boo s = ranges[i + 1] if r[1] + frequency != s[0]: if fill_missing_gaps: - from .missing import MissingDataset + from anemoi.datasets.data.missing import MissingDataset result.append(MissingDataset(datasets[i], r[1] + frequency, s[0] - frequency)) else: diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 20ac70cd8..3b0cbb3f5 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -34,8 +34,8 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from .debug import Node -from .debug import Source +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source if TYPE_CHECKING: import matplotlib @@ -166,7 +166,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": # This one must be first if "fill_missing_dates" in kwargs: - from .fill_missing import fill_missing_dates_factory + from anemoi.datasets.data.fill_missing import fill_missing_dates_factory fill_missing_dates = kwargs.pop("fill_missing_dates") ds = fill_missing_dates_factory(self, fill_missing_dates, kwargs) @@ -178,7 +178,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": padding = kwargs.pop("padding", None) if padding: - from .padded import Padded + from anemoi.datasets.data.padded import Padded frequency = kwargs.pop("frequency", self.frequency) return ( @@ -194,14 +194,14 @@ def __subset(self, **kwargs: Any) -> "Dataset": .mutate() ) - from .subset import Subset + from anemoi.datasets.data.subset import Subset return ( Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs).mutate() ) if "frequency" in kwargs: - from .subset import Subset + from anemoi.datasets.data.subset import Subset if "interpolate_frequency" in kwargs: raise ValueError("Cannot use both `frequency` and `interpolate_frequency`") @@ -214,38 +214,38 @@ def __subset(self, **kwargs: Any) -> "Dataset": ) if "select" in kwargs: - from .select import Select + from anemoi.datasets.data.select import Select select = kwargs.pop("select") return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs).mutate() if "drop" in kwargs: - from .select import Select + from anemoi.datasets.data.select import Select drop = kwargs.pop("drop") return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs).mutate() if "reorder" in kwargs: - from .select import Select + from anemoi.datasets.data.select import Select reorder = kwargs.pop("reorder") return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate() if "rename" in kwargs: - from .select import Rename + from anemoi.datasets.data.select import Rename rename = kwargs.pop("rename") return Rename(self, rename)._subset(**kwargs).mutate() if "rescale" in kwargs: - from .rescale import Rescale + from anemoi.datasets.data.rescale import Rescale rescale = kwargs.pop("rescale") return Rescale(self, rescale)._subset(**kwargs).mutate() if 
"statistics" in kwargs: - from ..data import open_dataset - from .statistics import Statistics + from anemoi.datasets.data import open_dataset + from anemoi.datasets.data.statistics import Statistics statistics = kwargs.pop("statistics") @@ -253,26 +253,26 @@ def __subset(self, **kwargs: Any) -> "Dataset": # Note: trim_edge should go before thinning if "trim_edge" in kwargs: - from .masked import TrimEdge + from anemoi.datasets.data.masked import TrimEdge edge = kwargs.pop("trim_edge") return TrimEdge(self, edge)._subset(**kwargs).mutate() if "thinning" in kwargs: - from .masked import Thinning + from anemoi.datasets.data.masked import Thinning thinning = kwargs.pop("thinning") method = kwargs.pop("method", "every-nth") return Thinning(self, thinning, method)._subset(**kwargs).mutate() if "area" in kwargs: - from .masked import Cropping + from anemoi.datasets.data.masked import Cropping bbox = kwargs.pop("area") return Cropping(self, bbox)._subset(**kwargs).mutate() if "number" in kwargs or "numbers" in kwargs or "member" in kwargs or "members" in kwargs: - from .ensemble import Number + from anemoi.datasets.data.ensemble import Number members = {} for key in ["number", "numbers", "member", "members"]: @@ -282,13 +282,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": return Number(self, **members)._subset(**kwargs).mutate() if "set_missing_dates" in kwargs: - from .missing import MissingDates + from anemoi.datasets.data.missing import MissingDates set_missing_dates = kwargs.pop("set_missing_dates") return MissingDates(self, set_missing_dates)._subset(**kwargs).mutate() if "skip_missing_dates" in kwargs: - from .missing import SkipMissingDates + from anemoi.datasets.data.missing import SkipMissingDates if "expected_access" not in kwargs: raise ValueError("`expected_access` is required with `skip_missing_dates`") @@ -300,13 +300,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate() if "interpolate_frequency" in kwargs: - from .interpolate import InterpolateFrequency + from anemoi.datasets.data.interpolate import InterpolateFrequency interpolate_frequency = kwargs.pop("interpolate_frequency") return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate() if "interpolate_variables" in kwargs: - from .interpolate import InterpolateNearest + from anemoi.datasets.data.interpolate import InterpolateNearest interpolate_variables = kwargs.pop("interpolate_variables") max_distance = kwargs.pop("max_distance", None) @@ -314,7 +314,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": # Keep last if "shuffle" in kwargs: - from .subset import Subset + from anemoi.datasets.data.subset import Subset shuffle = kwargs.pop("shuffle") @@ -378,8 +378,8 @@ def _dates_to_indices( list of int The list of indices. 
""" - from .misc import as_first_date - from .misc import as_last_date + from anemoi.datasets.data.misc import as_first_date + from anemoi.datasets.data.misc import as_last_date # TODO: optimize diff --git a/src/anemoi/datasets/data/debug.py b/src/anemoi/datasets/data/debug.py index 8af296b3d..8623c1307 100644 --- a/src/anemoi/datasets/data/debug.py +++ b/src/anemoi/datasets/data/debug.py @@ -20,7 +20,7 @@ from numpy.typing import NDArray if TYPE_CHECKING: - from .dataset import Dataset + from anemoi.datasets.data.dataset import Dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/ensemble.py b/src/anemoi/datasets/data/ensemble.py index 4826fa81d..b94c20f54 100644 --- a/src/anemoi/datasets/data/ensemble.py +++ b/src/anemoi/datasets/data/ensemble.py @@ -14,17 +14,17 @@ import numpy as np from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .debug import Node -from .forwards import Forwards -from .forwards import GivenAxis -from .indexing import apply_index_to_slices_changes -from .indexing import index_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.forwards import GivenAxis +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/fill_missing.py b/src/anemoi/datasets/data/fill_missing.py index d705b1d75..0cc1b0ee2 100644 --- a/src/anemoi/datasets/data/fill_missing.py +++ b/src/anemoi/datasets/data/fill_missing.py @@ -15,17 +15,16 @@ from numpy.typing import NDArray from anemoi.datasets.data import MissingDateError - -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index decadabdd..9b40859bb 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -18,16 +18,16 @@ import numpy as np from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import 
debug_indexing -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import length_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import length_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/grids.py b/src/anemoi/datasets/data/grids.py index 1e3a40cf7..fee2c792e 100644 --- a/src/anemoi/datasets/data/grids.py +++ b/src/anemoi/datasets/data/grids.py @@ -16,21 +16,21 @@ from numpy.typing import NDArray from scipy.spatial import cKDTree -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Combined -from .forwards import GivenAxis -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import length_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.forwards import GivenAxis +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import length_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/indexing.py b/src/anemoi/datasets/data/indexing.py index 106023ccb..7c4bb4be3 100644 --- a/src/anemoi/datasets/data/indexing.py +++ b/src/anemoi/datasets/data/indexing.py @@ -15,9 +15,9 @@ import numpy as np from numpy.typing import NDArray -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex def _tuple_with_slices(t: TupleIndex, shape: Shape) -> tuple[TupleIndex, tuple[int, ...]]: diff --git a/src/anemoi/datasets/data/interpolate.py b/src/anemoi/datasets/data/interpolate.py index b03404645..1f64d21a9 100644 --- a/src/anemoi/datasets/data/interpolate.py +++ b/src/anemoi/datasets/data/interpolate.py @@ -17,17 +17,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex 
-from .debug import Node -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -227,7 +227,7 @@ def __init__(self, dataset: Dataset, interpolate_variables: list[str], max_dista max_distance : Optional[float], optional The maximum distance for nearest neighbor search, by default None. """ - from ..grids import nearest_grid_points + from anemoi.datasets.grids import nearest_grid_points super().__init__(dataset) self.vars = interpolate_variables diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/data/join.py index 5eaf9c022..bc1de23a3 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/data/join.py @@ -17,20 +17,20 @@ import rich from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import Source -from .debug import debug_indexing -from .forwards import Combined -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) @@ -174,7 +174,7 @@ def _overlay(self) -> Dataset: if not ok: LOG.warning("Dataset %r completely overridden.", d) - from .select import Select + from anemoi.datasets.data.select import Select rich.print("Overlaying join with", variables, len(indices), [d.shape for d in self.datasets]) diff --git a/src/anemoi/datasets/data/masked.py b/src/anemoi/datasets/data/masked.py index 32148d7b0..39b3d9dc9 100644 --- a/src/anemoi/datasets/data/masked.py +++ b/src/anemoi/datasets/data/masked.py @@ -15,18 +15,18 @@ import numpy as np from numpy.typing import NDArray -from ..grids import cropping_mask -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import 
debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.grids import cropping_mask LOG = logging.getLogger(__name__) @@ -220,7 +220,7 @@ def __init__(self, forward: Dataset, area: Dataset | tuple[float, float, float, area : Union[Dataset, Tuple[float, float, float, float]] The cropping area. """ - from ..data import open_dataset + from anemoi.datasets.data import open_dataset area = area if isinstance(area, (list, tuple)) else open_dataset(area) diff --git a/src/anemoi/datasets/data/merge.py b/src/anemoi/datasets/data/merge.py index ca2697dda..b974a6afb 100644 --- a/src/anemoi/datasets/data/merge.py +++ b/src/anemoi/datasets/data/merge.py @@ -16,19 +16,19 @@ import numpy as np from numpy.typing import NDArray -from . import MissingDateError -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Combined -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data import MissingDateError +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index d6b88c04c..416458d82 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -23,7 +23,7 @@ from numpy.typing import NDArray if TYPE_CHECKING: - from .dataset import Dataset + from anemoi.datasets.data.dataset import Dataset LOG = logging.getLogger(__name__) @@ -323,11 +323,11 @@ def _concat_or_join(datasets: list["Dataset"], kwargs: dict[str, Any]) -> tuple[ ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets] if len(set(ranges)) == 1: - from .join import Join + from anemoi.datasets.data.join import Join return Join(datasets)._overlay(), kwargs - from .concat import Concat + from anemoi.datasets.data.concat import Concat Concat.check_dataset_compatibility(datasets) @@ -347,9 +347,9 @@ def 
_open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " Dataset The opened dataset. """ - from .dataset import Dataset - from .stores import Zarr - from .stores import dataset_lookup + from anemoi.datasets.data.dataset import Dataset + from anemoi.datasets.data.stores import Zarr + from anemoi.datasets.data.stores import dataset_lookup if isinstance(a, Dataset): return a.mutate() @@ -508,7 +508,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": sets.append(_open(a)) if "observations" in kwargs: - from .observations import observations_factory + from anemoi.datasets.data.observations import observations_factory assert not sets, sets @@ -516,70 +516,70 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": if "xy" in kwargs: # Experimental feature, may be removed - from .xy import xy_factory + from anemoi.datasets.data.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "x" in kwargs and "y" in kwargs: # Experimental feature, may be removed - from .xy import xy_factory + from anemoi.datasets.data.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "zip" in kwargs: # Experimental feature, may be removed - from .xy import zip_factory + from anemoi.datasets.data.xy import zip_factory assert not sets, sets return zip_factory(args, kwargs).mutate() if "chain" in kwargs: # Experimental feature, may be removed - from .unchecked import chain_factory + from anemoi.datasets.data.unchecked import chain_factory assert not sets, sets return chain_factory(args, kwargs).mutate() if "join" in kwargs: - from .join import join_factory + from anemoi.datasets.data.join import join_factory assert not sets, sets return join_factory(args, kwargs).mutate() if "concat" in kwargs: - from .concat import concat_factory + from anemoi.datasets.data.concat import concat_factory assert not sets, sets return concat_factory(args, kwargs).mutate() if "merge" in kwargs: - from .merge import merge_factory + from anemoi.datasets.data.merge import merge_factory assert not sets, sets return merge_factory(args, kwargs).mutate() if "ensemble" in kwargs: - from .ensemble import ensemble_factory + from anemoi.datasets.data.ensemble import ensemble_factory assert not sets, sets return ensemble_factory(args, kwargs).mutate() if "grids" in kwargs: - from .grids import grids_factory + from anemoi.datasets.data.grids import grids_factory assert not sets, sets return grids_factory(args, kwargs).mutate() if "cutout" in kwargs: - from .grids import cutout_factory + from anemoi.datasets.data.grids import cutout_factory assert not sets, sets return cutout_factory(args, kwargs).mutate() if "complement" in kwargs: - from .complement import complement_factory + from anemoi.datasets.data.complement import complement_factory assert not sets, sets return complement_factory(args, kwargs).mutate() diff --git a/src/anemoi/datasets/data/missing.py b/src/anemoi/datasets/data/missing.py index 8e0fb44ff..f34904d23 100644 --- a/src/anemoi/datasets/data/missing.py +++ b/src/anemoi/datasets/data/missing.py @@ -18,15 +18,14 @@ from anemoi.datasets.create.utils import to_datetime from anemoi.datasets.data import MissingDateError - -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import expand_list_indexing -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from 
anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/observations/__init__.py b/src/anemoi/datasets/data/observations/__init__.py index 019359846..c5cead4bc 100644 --- a/src/anemoi/datasets/data/observations/__init__.py +++ b/src/anemoi/datasets/data/observations/__init__.py @@ -15,8 +15,7 @@ from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets.data.dataset import Dataset - -from ..debug import Node +from anemoi.datasets.data.debug import Node LOG = logging.getLogger(__name__) @@ -142,7 +141,7 @@ def __init__(self, dataset, frequency=None, window=None): if isinstance(dataset, zarr.hierarchy.Group): dataset = dataset._store.path - from ..stores import zarr_lookup + from anemoi.datasets.data.stores import zarr_lookup dataset = zarr_lookup(dataset) self.path = dataset @@ -180,7 +179,7 @@ def __init__(self, dataset, frequency=None, window=None): # last_window_end must be the end of the time window of the last item last_window_end = int(end.strftime("%Y%m%d%H%M%S")) - from .legacy_obs_dataset import ObsDataset + from anemoi.datasets.data.observations.legacy_obs_dataset import ObsDataset args = [self.path, first_window_begin, last_window_end] kwargs = dict( diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 0c55d988a..01298561e 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -20,8 +20,7 @@ from anemoi.datasets.data.debug import Node from anemoi.datasets.data.records.backends import backend_factory - -from .windows import window_from_str +from anemoi.datasets.data.records.windows import window_from_str LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/rescale.py b/src/anemoi/datasets/data/rescale.py index b6071f3c7..44cbba349 100644 --- a/src/anemoi/datasets/data/rescale.py +++ b/src/anemoi/datasets/data/rescale.py @@ -16,16 +16,16 @@ import numpy as np from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index b422899ee..8f091de3e 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -15,18 +15,18 @@ from 
numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import Source -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/statistics.py b/src/anemoi/datasets/data/statistics.py index af0d4bc6e..2bb26b3d6 100644 --- a/src/anemoi/datasets/data/statistics.py +++ b/src/anemoi/datasets/data/statistics.py @@ -15,10 +15,10 @@ from numpy.typing import NDArray -from . import open_dataset -from .dataset import Dataset -from .debug import Node -from .forwards import Forwards +from anemoi.datasets.data import open_dataset +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Forwards LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index c5b3b4bc2..bf1e74cad 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -23,17 +23,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from . 
import MissingDateError -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import DEBUG_ZARR_LOADING -from .debug import Node -from .debug import Source -from .debug import debug_indexing -from .indexing import expand_list_indexing -from .misc import load_config +from anemoi.datasets.data import MissingDateError +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import DEBUG_ZARR_LOADING +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.misc import load_config LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 6d68c61a8..26d82ecf1 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -19,19 +19,19 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import Source -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import make_slice_or_index_from_list_or_tuple -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import make_slice_or_index_from_list_or_tuple +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -61,7 +61,7 @@ def _start(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the start date. """ - from .misc import as_first_date + from anemoi.datasets.data.misc import as_first_date c = as_first_date(a, dates) d = as_first_date(b, dates) @@ -82,7 +82,7 @@ def _end(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the end date. 
""" - from .misc import as_last_date + from anemoi.datasets.data.misc import as_last_date c = as_last_date(a, dates) d = as_last_date(b, dates) diff --git a/src/anemoi/datasets/data/unchecked.py b/src/anemoi/datasets/data/unchecked.py index cb4a1304c..478c8c1eb 100644 --- a/src/anemoi/datasets/data/unchecked.py +++ b/src/anemoi/datasets/data/unchecked.py @@ -18,14 +18,14 @@ import numpy as np from numpy.typing import NDArray -from .concat import ConcatMixin -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .debug import Node -from .forwards import Combined -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.concat import ConcatMixin +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/xy.py b/src/anemoi/datasets/data/xy.py index d3ae622bb..e181dc9aa 100644 --- a/src/anemoi/datasets/data/xy.py +++ b/src/anemoi/datasets/data/xy.py @@ -12,12 +12,12 @@ from functools import cached_property from typing import Any -from .dataset import Dataset -from .dataset import FullIndex -from .debug import Node -from .forwards import Combined -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py index a7057c1c2..134f1cc27 100644 --- a/src/anemoi/datasets/recipe.py +++ b/src/anemoi/datasets/recipe.py @@ -472,7 +472,7 @@ def dump(self, file=sys.stdout): if self.platform: result["platform"] = self.platform - from .dumper import yaml_dump + from anemoi.datasets.dumper import yaml_dump yaml_dump(_un_dotdict(result), stream=file) From 2b9b425f161dfd28eabf9e64be43592d50bd0b46 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 18:38:02 +0000 Subject: [PATCH 138/212] absolute imports --- src/anemoi/datasets/commands/create.py | 2 +- .../datasets/commands/recipe/__init__.py | 4 +- .../datasets/commands/recipe/migrate.py | 2 +- src/anemoi/datasets/create/fields/__init__.py | 1691 ----------------- .../data/records/backends/__init__.py | 6 +- tests/create/utils/create.py | 2 +- 6 files changed, 8 insertions(+), 1699 deletions(-) delete mode 100644 src/anemoi/datasets/create/fields/__init__.py diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 45af78a44..787f0fc89 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -45,7 +45,7 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") - from anemoi.datasets.create.fields import creator_factory + from anemoi.datasets.create.fields.actors import creator_factory options = {k: v for k, v in options.items() if v is not None} diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py 
index 26f0d486f..5c7b6f176 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -18,8 +18,8 @@ from anemoi.datasets.commands import Command from anemoi.datasets.commands.recipe.format import format_recipe from anemoi.datasets.commands.recipe.migrate import migrate_recipe -from anemoi.datasets.create.fields import config_to_python -from anemoi.datasets.create.fields import validate_config +from anemoi.datasets.create.fields.actors import config_to_python +from anemoi.datasets.create.fields.actors import validate_config LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 2a6112410..dbfde4143 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -17,7 +17,7 @@ from glom import delete from glom import glom -from anemoi.datasets.create.fields import validate_config +from anemoi.datasets.create.fields.actors import validate_config from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/fields/__init__.py b/src/anemoi/datasets/create/fields/__init__.py deleted file mode 100644 index 6301d45ee..000000000 --- a/src/anemoi/datasets/create/fields/__init__.py +++ /dev/null @@ -1,1691 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import datetime -import json -import logging -import os -import time -import uuid -import warnings -from functools import cached_property -from typing import Any - -import cftime -import numpy as np -import tqdm -import zarr -from anemoi.utils.dates import as_datetime -from anemoi.utils.dates import frequency_to_string -from anemoi.utils.dates import frequency_to_timedelta -from anemoi.utils.humanize import compress_dates -from anemoi.utils.humanize import seconds_to_human -from anemoi.utils.sanitise import sanitise -from earthkit.data.core.order import build_remapping - -from anemoi.datasets import MissingDateError -from anemoi.datasets import open_dataset -from anemoi.datasets.create.check import DatasetName -from anemoi.datasets.create.check import check_data_values -from anemoi.datasets.create.chunks import ChunkFilter -from anemoi.datasets.create.config import build_output -from anemoi.datasets.create.config import loader_config -from anemoi.datasets.create.fields.context import FieldContext -from anemoi.datasets.create.input import InputBuilder -from anemoi.datasets.create.input.trace import enable_trace -from anemoi.datasets.create.persistent import build_storage -from anemoi.datasets.create.statistics import Summary -from anemoi.datasets.create.statistics import TmpStatistics -from anemoi.datasets.create.statistics import check_variance -from anemoi.datasets.create.statistics import compute_statistics -from anemoi.datasets.create.statistics import default_statistics_dates -from anemoi.datasets.create.statistics import fix_variance -from anemoi.datasets.create.utils import normalize_and_check_dates -from anemoi.datasets.create.writer import ViewCacheArray -from anemoi.datasets.data.misc import as_first_date -from anemoi.datasets.data.misc import as_last_date -from anemoi.datasets.dates.groups import Groups - -LOG = logging.getLogger(__name__) - -VERSION = "0.30" - - -def json_tidy(o: Any) -> Any: - """Convert various types to JSON serializable format. - - Parameters - ---------- - o : Any - The object to convert. - - Returns - ------- - Any - The JSON serializable object. - """ - if isinstance(o, datetime.datetime): - return o.isoformat() - - if isinstance(o, datetime.datetime): - return o.isoformat() - - if isinstance(o, datetime.timedelta): - return frequency_to_string(o) - - if isinstance(o, cftime.DatetimeJulian): - import pandas as pd - - o = pd.Timestamp( - o.year, - o.month, - o.day, - o.hour, - o.minute, - o.second, - ) - return o.isoformat() - - if isinstance(o, (np.float32, np.float64)): - return float(o) - - raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}") - - -def build_statistics_dates( - dates: list[datetime.datetime], - start: datetime.datetime | None, - end: datetime.datetime | None, -) -> tuple[str, str]: - """Compute the start and end dates for the statistics. - - Parameters - ---------- - dates : list of datetime.datetime - The list of dates. - start : Optional[datetime.datetime] - The start date. - end : Optional[datetime.datetime] - The end date. - - Returns - ------- - tuple of str - The start and end dates in ISO format. 
- """ - # if not specified, use the default statistics dates - default_start, default_end = default_statistics_dates(dates) - if start is None: - start = default_start - if end is None: - end = default_end - - # in any case, adapt to the actual dates in the dataset - start = as_first_date(start, dates) - end = as_last_date(end, dates) - - # and convert to datetime to isoformat - start = start.astype(datetime.datetime) - end = end.astype(datetime.datetime) - return (start.isoformat(), end.isoformat()) - - -def _path_readable(path: str) -> bool: - """Check if the path is readable. - - Parameters - ---------- - path : str - The path to check. - - Returns - ------- - bool - True if the path is readable, False otherwise. - """ - import zarr - - try: - zarr.open(path, "r") - return True - except zarr.errors.PathNotFoundError: - return False - - -class Dataset: - """A class to represent a dataset.""" - - def __init__(self, path: str): - """Initialize a Dataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - self.path = path - - _, ext = os.path.splitext(self.path) - if ext != ".zarr": - raise ValueError(f"Unsupported extension={ext} for path={self.path}") - - def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: - """Add a dataset to the Zarr store. - - Parameters - ---------- - mode : str, optional - The mode to open the Zarr store. - **kwargs - Additional arguments for the dataset. - - Returns - ------- - zarr.Array - The added dataset. - """ - import zarr - - z = zarr.open(self.path, mode=mode) - from anemoi.datasets.create.zarr import add_zarr_dataset - - return add_zarr_dataset(zarr_root=z, **kwargs) - - def update_metadata(self, **kwargs: Any) -> None: - """Update the metadata of the dataset. - - Parameters - ---------- - **kwargs - The metadata to update. - """ - import zarr - - LOG.debug(f"Updating metadata {kwargs}") - z = zarr.open(self.path, mode="w+") - for k, v in kwargs.items(): - if isinstance(v, np.datetime64): - v = v.astype(datetime.datetime) - if isinstance(v, datetime.date): - v = v.isoformat() - z.attrs[k] = json.loads(json.dumps(v, default=json_tidy)) - - @cached_property - def anemoi_dataset(self) -> Any: - """Get the Anemoi dataset.""" - return open_dataset(self.path) - - @cached_property - def zarr_metadata(self) -> dict: - """Get the Zarr metadata.""" - import zarr - - return dict(zarr.open(self.path, mode="r").attrs) - - def print_info(self) -> None: - """Print information about the dataset.""" - import zarr - - z = zarr.open(self.path, mode="r") - try: - LOG.info(z["data"].info) - except Exception as e: - LOG.info(e) - - def get_zarr_chunks(self) -> tuple: - """Get the chunks of the Zarr dataset. - - Returns - ------- - tuple - The chunks of the Zarr dataset. - """ - import zarr - - z = zarr.open(self.path, mode="r") - return z["data"].chunks - - def check_name( - self, - resolution: str, - dates: list[datetime.datetime], - frequency: datetime.timedelta, - raise_exception: bool = True, - is_test: bool = False, - ) -> None: - """Check the name of the dataset. - - Parameters - ---------- - resolution : str - The resolution of the dataset. - dates : list of datetime.datetime - The dates of the dataset. - frequency : datetime.timedelta - The frequency of the dataset. - raise_exception : bool, optional - Whether to raise an exception if the name is invalid. - is_test : bool, optional - Whether this is a test. 
- """ - basename, _ = os.path.splitext(os.path.basename(self.path)) - try: - DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() - except Exception as e: - if raise_exception and not is_test: - raise e - else: - LOG.warning(f"Dataset name error: {e}") - - def get_main_config(self) -> Any: - """Get the main configuration of the dataset. - - Returns - ------- - Any - The main configuration. - """ - import zarr - - z = zarr.open(self.path, mode="r") - config = loader_config(z.attrs.get("_create_yaml_config")) - - if "env" in config: - for k, v in config["env"].items(): - LOG.info(f"Setting env variable {k}={v}") - os.environ[k] = str(v) - - return config - - -class WritableDataset(Dataset): - """A class to represent a writable dataset.""" - - def __init__(self, path: str): - """Initialize a WritableDataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - self.path = path - - import zarr - - self.z = zarr.open(self.path, mode="r+") - - @cached_property - def data_array(self) -> Any: - """Get the data array of the dataset.""" - import zarr - - return zarr.open(self.path, mode="r+")["data"] - - -class NewDataset(Dataset): - """A class to represent a new dataset.""" - - def __init__(self, path: str, overwrite: bool = False): - """Initialize a NewDataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - overwrite : bool, optional - Whether to overwrite the existing dataset. - """ - super().__init__(path) - self.path = path - - import zarr - - self.z = zarr.open(self.path, mode="w") - self.z.create_group("_build") - - -class Actor: # TODO: rename to Creator - """A base class for dataset creation actors.""" - - dataset_class = WritableDataset - - def __init__(self, path: str, cache: str | None = None): - """Initialize an Actor instance. - - Parameters - ---------- - path : str - The path to the dataset. - cache : Optional[str], optional - The cache directory. - """ - # Catch all floating point errors, including overflow, sqrt(<0), etc - np.seterr(all="raise", under="warn") - - self.path = path - self.cache = cache - self.dataset = self.dataset_class(self.path) - - def run(self) -> None: - """Run the actor.""" - # to be implemented in the sub-classes - raise NotImplementedError() - - def update_metadata(self, **kwargs: Any) -> None: - """Update the metadata of the dataset. - - Parameters - ---------- - **kwargs - The metadata to update. - """ - self.dataset.update_metadata(**kwargs) - - def _cache_context(self) -> Any: - """Get the cache context. - - Returns - ------- - Any - The cache context. - """ - from anemoi.datasets.create.utils import cache_context - - return cache_context(self.cache) - - def check_unkown_kwargs(self, kwargs: dict) -> None: - """Check for unknown keyword arguments. - - Parameters - ---------- - kwargs : dict - The keyword arguments. - """ - # remove this latter - LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}") - - def read_dataset_metadata(self, path: str) -> None: - """Read the metadata of the dataset. - - Parameters - ---------- - path : str - The path to the dataset. 
- """ - ds = open_dataset(path) - self.dataset_shape = ds.shape - self.variables_names = ds.variables - assert len(self.variables_names) == ds.shape[1], self.dataset_shape - self.dates = ds.dates - - self.missing_dates = sorted(list([self.dates[i] for i in ds.missing])) - - def check_missing_dates(expected: list[np.datetime64]) -> None: - """Check if the missing dates in the dataset match the expected dates. - - Parameters - ---------- - expected : list of np.datetime64 - The expected missing dates. - - Raises - ------ - ValueError - If the missing dates in the dataset do not match the expected dates. - """ - import zarr - - z = zarr.open(path, "r") - missing_dates = z.attrs.get("missing_dates", []) - missing_dates = sorted([np.datetime64(d) for d in missing_dates]) - if missing_dates != expected: - LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.") - LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") - LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") - raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") - - check_missing_dates(self.missing_dates) - - -class Patch(Actor): - """A class to apply patches to a dataset.""" - - def __init__(self, path: str, options: dict = None, **kwargs: Any): - """Initialize a Patch instance. - - Parameters - ---------- - path : str - The path to the dataset. - options : dict, optional - The patch options. - """ - self.path = path - self.options = options or {} - - def run(self) -> None: - """Run the patch.""" - from anemoi.datasets.create.patch import apply_patch - - apply_patch(self.path, **self.options) - - -class Size(Actor): - """A class to compute the size of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Size instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - - def run(self) -> None: - """Run the size computation.""" - from anemoi.datasets.create.size import compute_directory_sizes - - metadata = compute_directory_sizes(self.path) - self.update_metadata(**metadata) - - # Look for constant fields - ds = open_dataset(self.path) - constants = ds.computed_constant_fields() - - variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() - for k in constants: - variables_metadata[k]["constant_in_time"] = True - - self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) - - -class HasRegistryMixin: - """A mixin class to provide registry functionality.""" - - @cached_property - def registry(self) -> Any: - """Get the registry.""" - from anemoi.datasets.create.zarr import ZarrBuiltRegistry - - return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) - - -class HasStatisticTempMixin: - """A mixin class to provide temporary statistics functionality.""" - - @cached_property - def tmp_statistics(self) -> TmpStatistics: - """Get the temporary statistics.""" - directory = self.statistics_temp_dir or os.path.join(self.path + ".storage_for_statistics.tmp") - return TmpStatistics(directory) - - -class HasElementForDataMixin: - """A mixin class to provide element creation functionality for data.""" - - def create_elements(self, config: Any) -> None: - """Create elements for the dataset. - - Parameters - ---------- - config : Any - The configuration. 
- """ - assert self.registry - assert self.tmp_statistics - - LOG.info(dict(config.dates)) - - self.groups = Groups(**config.dates) - LOG.info(self.groups) - - self.output = build_output(config.output, parent=self) - - self.context = FieldContext( - order_by=self.output.order_by, - flatten_grid=self.output.flatten_grid, - remapping=build_remapping(self.output.remapping), - use_grib_paramid=config.build.use_grib_paramid, - ) - - self.input = InputBuilder( - config.input, - data_sources=config.get("data_sources", {}), - ) - LOG.debug("✅ INPUT_BUILDER") - LOG.debug(self.input) - - -class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to initialize a new dataset.""" - - dataset_class = NewDataset - - def __init__( - self, - path: str, - config: dict, - check_name: bool = False, - overwrite: bool = False, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - test: bool = False, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize an Init instance. - - Parameters - ---------- - path : str - The path to the dataset. - config : dict - The configuration. - check_name : bool, optional - Whether to check the dataset name. - overwrite : bool, optional - Whether to overwrite the existing dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - test : bool, optional - Whether this is a test. - cache : Optional[str], optional - The cache directory. - """ - if _path_readable(path) and not overwrite: - raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") - - super().__init__(path, cache=cache) - self.config = config - self.check_name = check_name - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.test = test - - self.main_config = loader_config(config, is_test=test) - - # self.registry.delete() ?? - self.tmp_statistics.delete() - - assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by - self.create_elements(self.main_config) - - LOG.info(f"Groups: {self.groups}") - - # window = self.main_config.dates.get("window") - - one_date = self.groups.one_date() - - self.minimal_input = self.input.select(self.context, one_date) - - LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") - LOG.info(self.minimal_input) - - def run(self) -> int: - """Run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - with self._cache_context(): - return self._run() - - def _run(self) -> int: - """Internal method to run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - """Create an empty dataset of the right final shape. - - Read a small part of the data to get the shape of the data and the resolution and more metadata. 
- """ - - LOG.info("Config loaded ok:") - # LOG.info(self.main_config) - - dates = self.groups.provider.values - frequency = self.groups.provider.frequency - missing = self.groups.provider.missing - - assert isinstance(frequency, datetime.timedelta), frequency - - LOG.info(f"Found {len(dates)} datetimes.") - LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") - LOG.info(f"Missing dates: {len(missing)}") - lengths = tuple(len(g) for g in self.groups) - - variables = self.minimal_input.variables - LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") - - variables_with_nans = self.main_config.statistics.get("allow_nans", []) - - ensembles = self.minimal_input.ensembles - LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") - - grid_points = self.minimal_input.grid_points - LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") - - resolution = self.minimal_input.resolution - LOG.info(f"{resolution=}") - - coords = self.minimal_input.coords - coords["dates"] = dates - total_shape = self.minimal_input.shape - total_shape[0] = len(dates) - LOG.info(f"total_shape = {total_shape}") - - chunks = self.output.get_chunking(coords) - LOG.info(f"{chunks=}") - dtype = self.output.dtype - - LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") - - metadata = {} - metadata["uuid"] = str(uuid.uuid4()) - - metadata.update(self.main_config.get("add_metadata", {})) - - metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() - - recipe = sanitise(self.main_config.get_serialisable_dict()) - - # Remove stuff added by prepml - for k in [ - "build_dataset", - "config_format_version", - "config_path", - "dataset_status", - "ecflow", - "metadata", - "platform", - "reading_chunks", - "upload", - ]: - recipe.pop(k, None) - - metadata["recipe"] = recipe - - metadata["description"] = self.main_config.description - metadata["licence"] = self.main_config["licence"] - metadata["attribution"] = self.main_config["attribution"] - - metadata["remapping"] = self.output.remapping - metadata["order_by"] = self.output.order_by_as_list - metadata["flatten_grid"] = self.output.flatten_grid - - metadata["ensemble_dimension"] = len(ensembles) - metadata["variables"] = variables - metadata["variables_with_nans"] = variables_with_nans - metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) - metadata["resolution"] = resolution - - metadata["data_request"] = self.minimal_input.data_request - metadata["field_shape"] = self.minimal_input.field_shape - metadata["proj_string"] = self.minimal_input.proj_string - metadata["variables_metadata"] = self.minimal_input.variables_metadata - - metadata["start_date"] = dates[0].isoformat() - metadata["end_date"] = dates[-1].isoformat() - metadata["frequency"] = frequency - metadata["missing_dates"] = [_.isoformat() for _ in missing] - metadata["origins"] = self.minimal_input.origins - - metadata["version"] = VERSION - - self.dataset.check_name( - raise_exception=self.check_name, - is_test=self.test, - resolution=resolution, - dates=dates, - frequency=frequency, - ) - - if len(dates) != total_shape[0]: - raise ValueError( - f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " - f"does not match data shape {total_shape[0]}. 
{total_shape=}" - ) - - dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) - - metadata.update(self.main_config.get("force_metadata", {})) - - ############################################################### - # write metadata - ############################################################### - - self.update_metadata(**metadata) - - self.dataset.add_dataset( - name="data", - chunks=chunks, - dtype=dtype, - shape=total_shape, - dimensions=("time", "variable", "ensemble", "cell"), - ) - self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) - self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) - self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) - - self.registry.create(lengths=lengths) - self.tmp_statistics.create(exist_ok=False) - self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) - - statistics_start, statistics_end = build_statistics_dates( - dates, - self.main_config.statistics.get("start"), - self.main_config.statistics.get("end"), - ) - self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) - LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") - - self.registry.add_to_history("init finished") - - assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) - - # Return the number of groups to process, so we can show a nice progress bar - return len(lengths) - - -class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to load data into a dataset.""" - - def __init__( - self, - path: str, - parts: str | None = None, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize a Load instance. - - Parameters - ---------- - path : str - The path to the dataset. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - cache : Optional[str], optional - The cache directory. 
- """ - super().__init__(path, cache=cache) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.parts = parts - self.dataset = WritableDataset(self.path) - - self.main_config = self.dataset.get_main_config() - self.create_elements(self.main_config) - self.read_dataset_metadata(self.dataset.path) - - total = len(self.registry.get_flags()) - self.chunk_filter = ChunkFilter(parts=self.parts, total=total) - - self.data_array = self.dataset.data_array - self.n_groups = len(self.groups) - - def run(self) -> None: - """Run the data loading.""" - with self._cache_context(): - self._run() - - def _run(self) -> None: - """Internal method to run the data loading.""" - for igroup, group in enumerate(self.groups): - if not self.chunk_filter(igroup): - continue - if self.registry.get_flag(igroup): - LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") - continue - - # assert isinstance(group[0], datetime.datetime), type(group[0]) - LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - - result = self.input.select(self.context, argument=group) - assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) - - # There are several groups. - # There is one result to load for each group. - self.load_result(result) - self.registry.set_flag(igroup) - - self.registry.add_provenance(name="provenance_load") - self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) - - self.dataset.print_info() - - def load_result(self, result: Any) -> None: - """Load the result into the dataset. - - Parameters - ---------- - result : Any - The result to load. - """ - # There is one cube to load for each result. - dates = list(result.group_of_dates) - - LOG.debug(f"Loading cube for {len(dates)} dates") - - cube = result.get_cube() - shape = cube.extended_user_shape - dates_in_data = cube.user_coords["valid_datetime"] - - LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") - - def check_shape(cube, dates, dates_in_data): - if cube.extended_user_shape[0] != len(dates): - print( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - print("Requested dates", compress_dates(dates)) - print("Cube dates", compress_dates(dates_in_data)) - - a = {as_datetime(_) for _ in dates} - b = {as_datetime(_) for _ in dates_in_data} - - print("Missing dates", compress_dates(a - b)) - print("Extra dates", compress_dates(b - a)) - - raise ValueError( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - - check_shape(cube, dates, dates_in_data) - - def check_dates_in_data(dates_in_data, requested_dates): - _requested_dates = [np.datetime64(_) for _ in requested_dates] - _dates_in_data = [np.datetime64(_) for _ in dates_in_data] - if _dates_in_data != _requested_dates: - LOG.error("Dates in data are not the requested ones:") - - dates_in_data = set(dates_in_data) - requested_dates = set(requested_dates) - - missing = sorted(requested_dates - dates_in_data) - extra = sorted(dates_in_data - requested_dates) - - if missing: - LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") - if extra: - LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") - - raise ValueError("Dates in data are not the requested ones") - - check_dates_in_data(dates_in_data, dates) - - def dates_to_indexes(dates, all_dates): - x = np.array(dates, dtype=np.datetime64) - y = np.array(all_dates, 
dtype=np.datetime64) - bitmap = np.isin(x, y) - return np.where(bitmap)[0] - - indexes = dates_to_indexes(self.dates, dates_in_data) - - array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) - LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") - self.load_cube(cube, array) - - stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) - self.tmp_statistics.write(indexes, stats, dates=dates_in_data) - LOG.info("Flush data array") - array.flush() - LOG.info("Flushed data array") - - def _get_allow_nans(self) -> bool | list: - """Get the allow_nans configuration. - - Returns - ------- - bool | list - The allow_nans configuration. - """ - config = self.main_config - if "allow_nans" in config.build: - return config.build.allow_nans - - return config.statistics.get("allow_nans", []) - - def load_cube(self, cube: Any, array: ViewCacheArray) -> None: - """Load the cube into the array. - - Parameters - ---------- - cube : Any - The cube to load. - array : ViewCacheArray - The array to load into. - """ - # There are several cubelets for each cube - start = time.time() - load = 0 - save = 0 - - reading_chunks = None - total = cube.count(reading_chunks) - LOG.debug(f"Loading datacube: {cube}") - - def position(x: Any) -> int | None: - if isinstance(x, str) and "/" in x: - x = x.split("/") - return int(x[0]) - return None - - bar = tqdm.tqdm( - iterable=cube.iterate_cubelets(reading_chunks), - total=total, - desc=f"Loading datacube {cube}", - position=position(self.parts), - ) - for i, cubelet in enumerate(bar): - bar.set_description(f"Loading {i}/{total}") - - now = time.time() - data = cubelet.to_numpy() - local_indexes = cubelet.coords - load += time.time() - now - - name = self.variables_names[local_indexes[1]] - check_data_values( - data[:], - name=name, - log=[i, data.shape, local_indexes], - allow_nans=self._get_allow_nans(), - ) - - now = time.time() - array[local_indexes] = data - save += time.time() - now - - now = time.time() - save += time.time() - now - LOG.debug( - f"Elapsed: {seconds_to_human(time.time() - start)}, " - f"load time: {seconds_to_human(load)}, " - f"write time: {seconds_to_human(save)}." - ) - - -class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin): - """A class to clean up temporary data and registry entries.""" - - def __init__( - self, - path: str, - statistics_temp_dir: str | None = None, - delta: list = [], - use_threads: bool = False, - **kwargs: Any, - ): - """Initialize a Cleanup instance. - - Parameters - ---------- - path : str - The path to the dataset. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - delta : list, optional - The delta values. - use_threads : bool, optional - Whether to use threads. - """ - super().__init__(path) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.additinon_temp_dir = statistics_temp_dir - self.actors = [ - _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) - for d in delta - ] - - def run(self) -> None: - """Run the cleanup.""" - - self.tmp_statistics.delete() - self.registry.clean() - for actor in self.actors: - actor.cleanup() - - -class Verify(Actor): - """A class to verify the integrity of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Verify instance. - - Parameters - ---------- - path : str - The path to the dataset. 
- """ - super().__init__(path) - - def run(self) -> None: - """Run the verification.""" - LOG.info(f"Verifying dataset at {self.path}") - LOG.info(str(self.dataset.anemoi_dataset)) - - -class AdditionsMixin: - """A mixin class to handle dataset additions.""" - - def skip(self) -> bool: - """Check if the additions should be skipped. - - Returns - ------- - bool - Whether to skip the additions. - """ - frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - if not self.delta.total_seconds() % frequency.total_seconds() == 0: - LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.") - return True - - if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: - LOG.warning(f"Additions are disabled for {self.path} in the recipe.") - return True - - return False - - @cached_property - def tmp_storage_path(self) -> str: - """Get the path to the temporary storage.""" - name = "storage_for_additions" - if self.delta: - name += frequency_to_string(self.delta) - return os.path.join(f"{self.path}.{name}.tmp") - - def read_from_dataset(self) -> None: - """Read data from the dataset.""" - self.variables = self.dataset.anemoi_dataset.variables - self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - start = self.dataset.zarr_metadata["statistics_start_date"] - end = self.dataset.zarr_metadata["statistics_end_date"] - self.start = datetime.datetime.fromisoformat(start) - self.end = datetime.datetime.fromisoformat(end) - - ds = open_dataset(self.path, start=self.start, end=self.end) - self.dates = ds.dates - self.total = len(self.dates) - - idelta = self.delta.total_seconds() // self.frequency.total_seconds() - assert int(idelta) == idelta, idelta - idelta = int(idelta) - self.ds = DeltaDataset(ds, idelta) - - -class DeltaDataset: - """A class to represent a dataset with delta values.""" - - def __init__(self, ds: Any, idelta: int): - """Initialize a DeltaDataset instance. - - Parameters - ---------- - ds : Any - The dataset. - idelta : int - The delta value. - """ - self.ds = ds - self.idelta = idelta - - def __getitem__(self, i: int) -> Any: - """Get an item from the dataset. - - Parameters - ---------- - i : int - The index. - - Returns - ------- - Any - The item. - """ - j = i - self.idelta - if j < 0: - raise MissingDateError(f"Missing date {j}") - return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] - - -class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to initialize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize an _InitAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. 
- """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - def run(self) -> None: - """Run the additions initialization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) - self.tmp_storage.delete() - self.tmp_storage.create() - LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") - - def cleanup(self) -> None: - """Clean up the temporary storage.""" - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - self.tmp_storage.delete() - LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") - - -class _RunAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to run dataset additions.""" - - def __init__( - self, - path: str, - delta: str, - parts: str | None = None, - use_threads: bool = False, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a _RunAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - self.parts = parts - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Writing in {self.tmp_storage_path}") - - def run(self) -> None: - """Run the additions.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.read_from_dataset() - - chunk_filter = ChunkFilter(parts=self.parts, total=self.total) - for i in range(0, self.total): - if not chunk_filter(i): - continue - date = self.dates[i] - try: - arr = self.ds[i] - stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) - self.tmp_storage.add([date, i, stats], key=date) - except MissingDateError: - self.tmp_storage.add([date, i, "missing"], key=date) - self.tmp_storage.flush() - LOG.debug(f"Dataset {self.path} additions run.") - - def allow_nans(self) -> bool: - """Check if NaNs are allowed. - - Returns - ------- - bool - Whether NaNs are allowed. - """ - if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): - return True - - variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) - if variables_with_nans is not None: - return variables_with_nans - warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") - return True - - -class _FinaliseAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to finalize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize a _FinaliseAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. 
- """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Reading from {self.tmp_storage_path}.") - - def run(self) -> None: - """Run the additions finalization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}.") - return - - self.read_from_dataset() - - shape = (len(self.dates), len(self.variables)) - agg = dict( - minimum=np.full(shape, np.nan, dtype=np.float64), - maximum=np.full(shape, np.nan, dtype=np.float64), - sums=np.full(shape, np.nan, dtype=np.float64), - squares=np.full(shape, np.nan, dtype=np.float64), - count=np.full(shape, -1, dtype=np.int64), - has_nans=np.full(shape, False, dtype=np.bool_), - ) - LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") - - found = set() - ifound = set() - missing = set() - for _date, (date, i, stats) in self.tmp_storage.items(): - assert _date == date - if stats == "missing": - missing.add(date) - continue - - assert date not in found, f"Duplicates found {date}" - found.add(date) - ifound.add(i) - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k][i, ...] = stats[k] - - assert len(found) + len(missing) == len(self.dates), ( - len(found), - len(missing), - len(self.dates), - ) - assert found.union(missing) == set(self.dates), ( - found, - missing, - set(self.dates), - ) - - if len(ifound) < 2: - LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") - self.tmp_storage.delete() - return - - mask = sorted(list(ifound)) - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k] = agg[k][mask, ...] - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - assert agg[k].shape == agg["count"].shape, ( - agg[k].shape, - agg["count"].shape, - ) - - minimum = np.nanmin(agg["minimum"], axis=0) - maximum = np.nanmax(agg["maximum"], axis=0) - sums = np.nansum(agg["sums"], axis=0) - squares = np.nansum(agg["squares"], axis=0) - count = np.nansum(agg["count"], axis=0) - has_nans = np.any(agg["has_nans"], axis=0) - - assert sums.shape == count.shape - assert sums.shape == squares.shape - assert sums.shape == minimum.shape - assert sums.shape == maximum.shape - assert sums.shape == has_nans.shape - - mean = sums / count - assert sums.shape == mean.shape - - x = squares / count - mean * mean - # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 - # remove negative variance due to numerical errors - for i, name in enumerate(self.variables): - x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) - check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) - - stdev = np.sqrt(x) - assert sums.shape == stdev.shape - - self.summary = Summary( - minimum=minimum, - maximum=maximum, - mean=mean, - count=count, - sums=sums, - squares=squares, - stdev=stdev, - variables_names=self.variables, - has_nans=has_nans, - ) - LOG.info(f"Dataset {self.path} additions finalised.") - # self.check_statistics() - self._write(self.summary) - self.tmp_storage.delete() - - def _write(self, summary: Summary) -> None: - """Write the summary to the dataset. - - Parameters - ---------- - summary : Summary - The summary to write. 
- """ - for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: - name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" - self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) - self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") - LOG.debug(f"Wrote additions in {self.path}") - - -def multi_addition(cls: type) -> type: - """Create a class to handle multiple additions. - - Parameters - ---------- - cls : type - The class to handle additions. - - Returns - ------- - type - The class to handle multiple additions. - """ - - class MultiAdditions: - def __init__(self, *args, **kwargs: Any): - self.actors = [] - - for k in kwargs.pop("delta", []): - self.actors.append(cls(*args, delta=k, **kwargs)) - - if not self.actors: - LOG.warning("No delta found in kwargs, no additions will be computed.") - - def run(self) -> None: - """Run the additions.""" - for actor in self.actors: - actor.run() - - return MultiAdditions - - -InitAdditions = multi_addition(_InitAdditions) -RunAdditions = multi_addition(_RunAdditions) -FinaliseAdditions = multi_addition(_FinaliseAdditions) - - -class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin): - """A class to compute statistics for a dataset.""" - - def __init__( - self, - path: str, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a Statistics instance. - - Parameters - ---------- - path : str - The path to the dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.use_threads = use_threads - self.progress = progress - self.statistics_temp_dir = statistics_temp_dir - - def run(self) -> None: - """Run the statistics computation.""" - start, end = ( - self.dataset.zarr_metadata["statistics_start_date"], - self.dataset.zarr_metadata["statistics_end_date"], - ) - start, end = np.datetime64(start), np.datetime64(end) - dates = self.dataset.anemoi_dataset.dates - - assert type(dates[0]) is type(start), (type(dates[0]), type(start)) - - dates = [d for d in dates if d >= start and d <= end] - dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] - variables = self.dataset.anemoi_dataset.variables - stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) - - LOG.info(stats) - - if not all(self.registry.get_flags(sync=False)): - raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") - - for k in [ - "mean", - "stdev", - "minimum", - "maximum", - "sums", - "squares", - "count", - "has_nans", - ]: - self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) - - self.registry.add_to_history("compute_statistics_end") - LOG.info(f"Wrote statistics in {self.path}") - - @cached_property - def allow_nans(self) -> bool | list: - """Check if NaNs are allowed.""" - import zarr - - z = zarr.open(self.path, mode="r") - if "allow_nans" in z.attrs: - return z.attrs["allow_nans"] - - if "variables_with_nans" in z.attrs: - return z.attrs["variables_with_nans"] - - warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") - return True - - -def chain(tasks: list) -> type: - """Create a class to chain multiple tasks. 
- - Parameters - ---------- - tasks : list - The list of tasks to chain. - - Returns - ------- - type - The class to chain multiple tasks. - """ - - class Chain(Actor): - def __init__(self, **kwargs: Any): - self.kwargs = kwargs - - def run(self) -> None: - """Run the chained tasks.""" - for cls in tasks: - t = cls(**self.kwargs) - t.run() - - return Chain - - -def creator_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: - """Create a dataset creator. - - Parameters - ---------- - name : str - The name of the creator. - trace : Optional[str], optional - The trace file. - **kwargs - Additional arguments for the creator. - - Returns - ------- - Any - The dataset creator. - """ - if trace: - - enable_trace(trace) - - cls = dict( - init=Init, - load=Load, - size=Size, - patch=Patch, - statistics=Statistics, - finalise=chain([Statistics, Size, Cleanup]), - cleanup=Cleanup, - verify=Verify, - init_additions=InitAdditions, - load_additions=RunAdditions, - run_additions=RunAdditions, - finalise_additions=chain([FinaliseAdditions, Size]), - additions=chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup]), - )[name] - LOG.debug(f"Creating {cls.__name__} with {kwargs}") - return cls(**kwargs) - - -def validate_config(config: Any) -> None: - - import json - - import jsonschema - - def _tidy(d): - if isinstance(d, dict): - return {k: _tidy(v) for k, v in d.items()} - - if isinstance(d, list): - return [_tidy(v) for v in d if v is not None] - - # jsonschema does not support datetime.date - if isinstance(d, datetime.datetime): - return d.isoformat() - - if isinstance(d, datetime.date): - return d.isoformat() - - return d - - # https://json-schema.org - - with open( - os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "schemas", - "recipe.json", - ) - ) as f: - schema = json.load(f) - - try: - jsonschema.validate(instance=_tidy(config), schema=schema) - except jsonschema.exceptions.ValidationError as e: - LOG.error("❌ Config validation failed (jsonschema):") - LOG.error(e.message) - raise - - -def config_to_python(config: Any) -> Any: - - from anemoi.datasets.create.create.python import PythonScript - - raw_config = config - - config = loader_config(config) - - input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - - code = PythonScript() - x = input.python_code(code) - code = code.source_code(x, raw_config) - - try: - import black - - return black.format_str(code, mode=black.Mode()) - # except ImportError: - except Exception: - LOG.warning("Black not installed, skipping formatting") - return code diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index e5aab1bd6..62785c9cd 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -213,7 +213,7 @@ def write(self, i, data, number_of_files_per_subdirectory=100, **kwargs): os.rename(tmp_path, out_path) def write_metadata(self, metadata): - from anemoi.datasets.create.fields import json_tidy + from anemoi.datasets.create.fields.actors import json_tidy os.makedirs(self.path, exist_ok=True) @@ -257,7 +257,7 @@ def write(self, i, data, **kwargs): ds.to_netcdf(out_path) def write_metadata(self, metadata): - from anemoi.datasets.create.fields import json_tidy + from anemoi.datasets.create.fields.actors import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: @@ -295,7 +295,7 @@ def 
write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.create.fields import json_tidy + from anemoi.datasets.create.fields.actors import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index d5cc3585c..4ca48f98d 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -12,7 +12,7 @@ import yaml -from anemoi.datasets.create.fields import creator_factory +from anemoi.datasets.create.fields.actors import creator_factory class TestingContext: From 68fdf5b1e6630039fcab5776ef7ed9a8259860a4 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 18:38:40 +0000 Subject: [PATCH 139/212] refactor --- src/anemoi/datasets/create/fields/__init__.py | 0 src/anemoi/datasets/create/fields/actors.py | 1691 +++++++++++++++++ 2 files changed, 1691 insertions(+) create mode 100644 src/anemoi/datasets/create/fields/__init__.py create mode 100644 src/anemoi/datasets/create/fields/actors.py diff --git a/src/anemoi/datasets/create/fields/__init__.py b/src/anemoi/datasets/create/fields/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/anemoi/datasets/create/fields/actors.py b/src/anemoi/datasets/create/fields/actors.py new file mode 100644 index 000000000..6301d45ee --- /dev/null +++ b/src/anemoi/datasets/create/fields/actors.py @@ -0,0 +1,1691 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
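+
+# Typical usage (sketch, following tests/create/utils/create.py): each step of
+# the dataset creation is built with ``creator_factory`` and run in sequence,
+# for example:
+#
+#     creator_factory("init", config=config, path=path, overwrite=True).run()
+#     creator_factory("load", path=path).run()
+#     creator_factory("finalise", path=path).run()
+#     creator_factory("cleanup", path=path).run()
+#     creator_factory("verify", path=path).run()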
+ +import datetime +import json +import logging +import os +import time +import uuid +import warnings +from functools import cached_property +from typing import Any + +import cftime +import numpy as np +import tqdm +import zarr +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta +from anemoi.utils.humanize import compress_dates +from anemoi.utils.humanize import seconds_to_human +from anemoi.utils.sanitise import sanitise +from earthkit.data.core.order import build_remapping + +from anemoi.datasets import MissingDateError +from anemoi.datasets import open_dataset +from anemoi.datasets.create.check import DatasetName +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.create.config import build_output +from anemoi.datasets.create.config import loader_config +from anemoi.datasets.create.fields.context import FieldContext +from anemoi.datasets.create.input import InputBuilder +from anemoi.datasets.create.input.trace import enable_trace +from anemoi.datasets.create.persistent import build_storage +from anemoi.datasets.create.statistics import Summary +from anemoi.datasets.create.statistics import TmpStatistics +from anemoi.datasets.create.statistics import check_variance +from anemoi.datasets.create.statistics import compute_statistics +from anemoi.datasets.create.statistics import default_statistics_dates +from anemoi.datasets.create.statistics import fix_variance +from anemoi.datasets.create.utils import normalize_and_check_dates +from anemoi.datasets.create.writer import ViewCacheArray +from anemoi.datasets.data.misc import as_first_date +from anemoi.datasets.data.misc import as_last_date +from anemoi.datasets.dates.groups import Groups + +LOG = logging.getLogger(__name__) + +VERSION = "0.30" + + +def json_tidy(o: Any) -> Any: + """Convert various types to JSON serializable format. + + Parameters + ---------- + o : Any + The object to convert. + + Returns + ------- + Any + The JSON serializable object. + """ + if isinstance(o, datetime.datetime): + return o.isoformat() + + if isinstance(o, datetime.datetime): + return o.isoformat() + + if isinstance(o, datetime.timedelta): + return frequency_to_string(o) + + if isinstance(o, cftime.DatetimeJulian): + import pandas as pd + + o = pd.Timestamp( + o.year, + o.month, + o.day, + o.hour, + o.minute, + o.second, + ) + return o.isoformat() + + if isinstance(o, (np.float32, np.float64)): + return float(o) + + raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}") + + +def build_statistics_dates( + dates: list[datetime.datetime], + start: datetime.datetime | None, + end: datetime.datetime | None, +) -> tuple[str, str]: + """Compute the start and end dates for the statistics. + + Parameters + ---------- + dates : list of datetime.datetime + The list of dates. + start : Optional[datetime.datetime] + The start date. + end : Optional[datetime.datetime] + The end date. + + Returns + ------- + tuple of str + The start and end dates in ISO format. 
+ """ + # if not specified, use the default statistics dates + default_start, default_end = default_statistics_dates(dates) + if start is None: + start = default_start + if end is None: + end = default_end + + # in any case, adapt to the actual dates in the dataset + start = as_first_date(start, dates) + end = as_last_date(end, dates) + + # and convert to datetime to isoformat + start = start.astype(datetime.datetime) + end = end.astype(datetime.datetime) + return (start.isoformat(), end.isoformat()) + + +def _path_readable(path: str) -> bool: + """Check if the path is readable. + + Parameters + ---------- + path : str + The path to check. + + Returns + ------- + bool + True if the path is readable, False otherwise. + """ + import zarr + + try: + zarr.open(path, "r") + return True + except zarr.errors.PathNotFoundError: + return False + + +class Dataset: + """A class to represent a dataset.""" + + def __init__(self, path: str): + """Initialize a Dataset instance. + + Parameters + ---------- + path : str + The path to the dataset. + """ + self.path = path + + _, ext = os.path.splitext(self.path) + if ext != ".zarr": + raise ValueError(f"Unsupported extension={ext} for path={self.path}") + + def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: + """Add a dataset to the Zarr store. + + Parameters + ---------- + mode : str, optional + The mode to open the Zarr store. + **kwargs + Additional arguments for the dataset. + + Returns + ------- + zarr.Array + The added dataset. + """ + import zarr + + z = zarr.open(self.path, mode=mode) + from anemoi.datasets.create.zarr import add_zarr_dataset + + return add_zarr_dataset(zarr_root=z, **kwargs) + + def update_metadata(self, **kwargs: Any) -> None: + """Update the metadata of the dataset. + + Parameters + ---------- + **kwargs + The metadata to update. + """ + import zarr + + LOG.debug(f"Updating metadata {kwargs}") + z = zarr.open(self.path, mode="w+") + for k, v in kwargs.items(): + if isinstance(v, np.datetime64): + v = v.astype(datetime.datetime) + if isinstance(v, datetime.date): + v = v.isoformat() + z.attrs[k] = json.loads(json.dumps(v, default=json_tidy)) + + @cached_property + def anemoi_dataset(self) -> Any: + """Get the Anemoi dataset.""" + return open_dataset(self.path) + + @cached_property + def zarr_metadata(self) -> dict: + """Get the Zarr metadata.""" + import zarr + + return dict(zarr.open(self.path, mode="r").attrs) + + def print_info(self) -> None: + """Print information about the dataset.""" + import zarr + + z = zarr.open(self.path, mode="r") + try: + LOG.info(z["data"].info) + except Exception as e: + LOG.info(e) + + def get_zarr_chunks(self) -> tuple: + """Get the chunks of the Zarr dataset. + + Returns + ------- + tuple + The chunks of the Zarr dataset. + """ + import zarr + + z = zarr.open(self.path, mode="r") + return z["data"].chunks + + def check_name( + self, + resolution: str, + dates: list[datetime.datetime], + frequency: datetime.timedelta, + raise_exception: bool = True, + is_test: bool = False, + ) -> None: + """Check the name of the dataset. + + Parameters + ---------- + resolution : str + The resolution of the dataset. + dates : list of datetime.datetime + The dates of the dataset. + frequency : datetime.timedelta + The frequency of the dataset. + raise_exception : bool, optional + Whether to raise an exception if the name is invalid. + is_test : bool, optional + Whether this is a test. 
+ """ + basename, _ = os.path.splitext(os.path.basename(self.path)) + try: + DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() + except Exception as e: + if raise_exception and not is_test: + raise e + else: + LOG.warning(f"Dataset name error: {e}") + + def get_main_config(self) -> Any: + """Get the main configuration of the dataset. + + Returns + ------- + Any + The main configuration. + """ + import zarr + + z = zarr.open(self.path, mode="r") + config = loader_config(z.attrs.get("_create_yaml_config")) + + if "env" in config: + for k, v in config["env"].items(): + LOG.info(f"Setting env variable {k}={v}") + os.environ[k] = str(v) + + return config + + +class WritableDataset(Dataset): + """A class to represent a writable dataset.""" + + def __init__(self, path: str): + """Initialize a WritableDataset instance. + + Parameters + ---------- + path : str + The path to the dataset. + """ + super().__init__(path) + self.path = path + + import zarr + + self.z = zarr.open(self.path, mode="r+") + + @cached_property + def data_array(self) -> Any: + """Get the data array of the dataset.""" + import zarr + + return zarr.open(self.path, mode="r+")["data"] + + +class NewDataset(Dataset): + """A class to represent a new dataset.""" + + def __init__(self, path: str, overwrite: bool = False): + """Initialize a NewDataset instance. + + Parameters + ---------- + path : str + The path to the dataset. + overwrite : bool, optional + Whether to overwrite the existing dataset. + """ + super().__init__(path) + self.path = path + + import zarr + + self.z = zarr.open(self.path, mode="w") + self.z.create_group("_build") + + +class Actor: # TODO: rename to Creator + """A base class for dataset creation actors.""" + + dataset_class = WritableDataset + + def __init__(self, path: str, cache: str | None = None): + """Initialize an Actor instance. + + Parameters + ---------- + path : str + The path to the dataset. + cache : Optional[str], optional + The cache directory. + """ + # Catch all floating point errors, including overflow, sqrt(<0), etc + np.seterr(all="raise", under="warn") + + self.path = path + self.cache = cache + self.dataset = self.dataset_class(self.path) + + def run(self) -> None: + """Run the actor.""" + # to be implemented in the sub-classes + raise NotImplementedError() + + def update_metadata(self, **kwargs: Any) -> None: + """Update the metadata of the dataset. + + Parameters + ---------- + **kwargs + The metadata to update. + """ + self.dataset.update_metadata(**kwargs) + + def _cache_context(self) -> Any: + """Get the cache context. + + Returns + ------- + Any + The cache context. + """ + from anemoi.datasets.create.utils import cache_context + + return cache_context(self.cache) + + def check_unkown_kwargs(self, kwargs: dict) -> None: + """Check for unknown keyword arguments. + + Parameters + ---------- + kwargs : dict + The keyword arguments. + """ + # remove this latter + LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}") + + def read_dataset_metadata(self, path: str) -> None: + """Read the metadata of the dataset. + + Parameters + ---------- + path : str + The path to the dataset. 
+ """ + ds = open_dataset(path) + self.dataset_shape = ds.shape + self.variables_names = ds.variables + assert len(self.variables_names) == ds.shape[1], self.dataset_shape + self.dates = ds.dates + + self.missing_dates = sorted(list([self.dates[i] for i in ds.missing])) + + def check_missing_dates(expected: list[np.datetime64]) -> None: + """Check if the missing dates in the dataset match the expected dates. + + Parameters + ---------- + expected : list of np.datetime64 + The expected missing dates. + + Raises + ------ + ValueError + If the missing dates in the dataset do not match the expected dates. + """ + import zarr + + z = zarr.open(path, "r") + missing_dates = z.attrs.get("missing_dates", []) + missing_dates = sorted([np.datetime64(d) for d in missing_dates]) + if missing_dates != expected: + LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.") + LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") + LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") + raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") + + check_missing_dates(self.missing_dates) + + +class Patch(Actor): + """A class to apply patches to a dataset.""" + + def __init__(self, path: str, options: dict = None, **kwargs: Any): + """Initialize a Patch instance. + + Parameters + ---------- + path : str + The path to the dataset. + options : dict, optional + The patch options. + """ + self.path = path + self.options = options or {} + + def run(self) -> None: + """Run the patch.""" + from anemoi.datasets.create.patch import apply_patch + + apply_patch(self.path, **self.options) + + +class Size(Actor): + """A class to compute the size of a dataset.""" + + def __init__(self, path: str, **kwargs: Any): + """Initialize a Size instance. + + Parameters + ---------- + path : str + The path to the dataset. + """ + super().__init__(path) + + def run(self) -> None: + """Run the size computation.""" + from anemoi.datasets.create.size import compute_directory_sizes + + metadata = compute_directory_sizes(self.path) + self.update_metadata(**metadata) + + # Look for constant fields + ds = open_dataset(self.path) + constants = ds.computed_constant_fields() + + variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() + for k in constants: + variables_metadata[k]["constant_in_time"] = True + + self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) + + +class HasRegistryMixin: + """A mixin class to provide registry functionality.""" + + @cached_property + def registry(self) -> Any: + """Get the registry.""" + from anemoi.datasets.create.zarr import ZarrBuiltRegistry + + return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) + + +class HasStatisticTempMixin: + """A mixin class to provide temporary statistics functionality.""" + + @cached_property + def tmp_statistics(self) -> TmpStatistics: + """Get the temporary statistics.""" + directory = self.statistics_temp_dir or os.path.join(self.path + ".storage_for_statistics.tmp") + return TmpStatistics(directory) + + +class HasElementForDataMixin: + """A mixin class to provide element creation functionality for data.""" + + def create_elements(self, config: Any) -> None: + """Create elements for the dataset. + + Parameters + ---------- + config : Any + The configuration. 
+ """ + assert self.registry + assert self.tmp_statistics + + LOG.info(dict(config.dates)) + + self.groups = Groups(**config.dates) + LOG.info(self.groups) + + self.output = build_output(config.output, parent=self) + + self.context = FieldContext( + order_by=self.output.order_by, + flatten_grid=self.output.flatten_grid, + remapping=build_remapping(self.output.remapping), + use_grib_paramid=config.build.use_grib_paramid, + ) + + self.input = InputBuilder( + config.input, + data_sources=config.get("data_sources", {}), + ) + LOG.debug("✅ INPUT_BUILDER") + LOG.debug(self.input) + + +class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): + """A class to initialize a new dataset.""" + + dataset_class = NewDataset + + def __init__( + self, + path: str, + config: dict, + check_name: bool = False, + overwrite: bool = False, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + test: bool = False, + cache: str | None = None, + **kwargs: Any, + ): + """Initialize an Init instance. + + Parameters + ---------- + path : str + The path to the dataset. + config : dict + The configuration. + check_name : bool, optional + Whether to check the dataset name. + overwrite : bool, optional + Whether to overwrite the existing dataset. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + test : bool, optional + Whether this is a test. + cache : Optional[str], optional + The cache directory. + """ + if _path_readable(path) and not overwrite: + raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") + + super().__init__(path, cache=cache) + self.config = config + self.check_name = check_name + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.progress = progress + self.test = test + + self.main_config = loader_config(config, is_test=test) + + # self.registry.delete() ?? + self.tmp_statistics.delete() + + assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by + self.create_elements(self.main_config) + + LOG.info(f"Groups: {self.groups}") + + # window = self.main_config.dates.get("window") + + one_date = self.groups.one_date() + + self.minimal_input = self.input.select(self.context, one_date) + + LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") + LOG.info(self.minimal_input) + + def run(self) -> int: + """Run the initialization. + + Returns + ------- + int + The number of groups to process. + """ + with self._cache_context(): + return self._run() + + def _run(self) -> int: + """Internal method to run the initialization. + + Returns + ------- + int + The number of groups to process. + """ + """Create an empty dataset of the right final shape. + + Read a small part of the data to get the shape of the data and the resolution and more metadata. 
+ """ + + LOG.info("Config loaded ok:") + # LOG.info(self.main_config) + + dates = self.groups.provider.values + frequency = self.groups.provider.frequency + missing = self.groups.provider.missing + + assert isinstance(frequency, datetime.timedelta), frequency + + LOG.info(f"Found {len(dates)} datetimes.") + LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") + LOG.info(f"Missing dates: {len(missing)}") + lengths = tuple(len(g) for g in self.groups) + + variables = self.minimal_input.variables + LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") + + variables_with_nans = self.main_config.statistics.get("allow_nans", []) + + ensembles = self.minimal_input.ensembles + LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") + + grid_points = self.minimal_input.grid_points + LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") + + resolution = self.minimal_input.resolution + LOG.info(f"{resolution=}") + + coords = self.minimal_input.coords + coords["dates"] = dates + total_shape = self.minimal_input.shape + total_shape[0] = len(dates) + LOG.info(f"total_shape = {total_shape}") + + chunks = self.output.get_chunking(coords) + LOG.info(f"{chunks=}") + dtype = self.output.dtype + + LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") + + metadata = {} + metadata["uuid"] = str(uuid.uuid4()) + + metadata.update(self.main_config.get("add_metadata", {})) + + metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() + + recipe = sanitise(self.main_config.get_serialisable_dict()) + + # Remove stuff added by prepml + for k in [ + "build_dataset", + "config_format_version", + "config_path", + "dataset_status", + "ecflow", + "metadata", + "platform", + "reading_chunks", + "upload", + ]: + recipe.pop(k, None) + + metadata["recipe"] = recipe + + metadata["description"] = self.main_config.description + metadata["licence"] = self.main_config["licence"] + metadata["attribution"] = self.main_config["attribution"] + + metadata["remapping"] = self.output.remapping + metadata["order_by"] = self.output.order_by_as_list + metadata["flatten_grid"] = self.output.flatten_grid + + metadata["ensemble_dimension"] = len(ensembles) + metadata["variables"] = variables + metadata["variables_with_nans"] = variables_with_nans + metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) + metadata["resolution"] = resolution + + metadata["data_request"] = self.minimal_input.data_request + metadata["field_shape"] = self.minimal_input.field_shape + metadata["proj_string"] = self.minimal_input.proj_string + metadata["variables_metadata"] = self.minimal_input.variables_metadata + + metadata["start_date"] = dates[0].isoformat() + metadata["end_date"] = dates[-1].isoformat() + metadata["frequency"] = frequency + metadata["missing_dates"] = [_.isoformat() for _ in missing] + metadata["origins"] = self.minimal_input.origins + + metadata["version"] = VERSION + + self.dataset.check_name( + raise_exception=self.check_name, + is_test=self.test, + resolution=resolution, + dates=dates, + frequency=frequency, + ) + + if len(dates) != total_shape[0]: + raise ValueError( + f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " + f"does not match data shape {total_shape[0]}. 
{total_shape=}" + ) + + dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) + + metadata.update(self.main_config.get("force_metadata", {})) + + ############################################################### + # write metadata + ############################################################### + + self.update_metadata(**metadata) + + self.dataset.add_dataset( + name="data", + chunks=chunks, + dtype=dtype, + shape=total_shape, + dimensions=("time", "variable", "ensemble", "cell"), + ) + self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) + self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) + self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) + + self.registry.create(lengths=lengths) + self.tmp_statistics.create(exist_ok=False) + self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) + + statistics_start, statistics_end = build_statistics_dates( + dates, + self.main_config.statistics.get("start"), + self.main_config.statistics.get("end"), + ) + self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) + LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") + + self.registry.add_to_history("init finished") + + assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) + + # Return the number of groups to process, so we can show a nice progress bar + return len(lengths) + + +class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): + """A class to load data into a dataset.""" + + def __init__( + self, + path: str, + parts: str | None = None, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + cache: str | None = None, + **kwargs: Any, + ): + """Initialize a Load instance. + + Parameters + ---------- + path : str + The path to the dataset. + parts : Optional[str], optional + The parts to load. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + cache : Optional[str], optional + The cache directory. 
+ """ + super().__init__(path, cache=cache) + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.progress = progress + self.parts = parts + self.dataset = WritableDataset(self.path) + + self.main_config = self.dataset.get_main_config() + self.create_elements(self.main_config) + self.read_dataset_metadata(self.dataset.path) + + total = len(self.registry.get_flags()) + self.chunk_filter = ChunkFilter(parts=self.parts, total=total) + + self.data_array = self.dataset.data_array + self.n_groups = len(self.groups) + + def run(self) -> None: + """Run the data loading.""" + with self._cache_context(): + self._run() + + def _run(self) -> None: + """Internal method to run the data loading.""" + for igroup, group in enumerate(self.groups): + if not self.chunk_filter(igroup): + continue + if self.registry.get_flag(igroup): + LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") + continue + + # assert isinstance(group[0], datetime.datetime), type(group[0]) + LOG.debug(f"Building data for group {igroup}/{self.n_groups}") + + result = self.input.select(self.context, argument=group) + assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) + + # There are several groups. + # There is one result to load for each group. + self.load_result(result) + self.registry.set_flag(igroup) + + self.registry.add_provenance(name="provenance_load") + self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) + + self.dataset.print_info() + + def load_result(self, result: Any) -> None: + """Load the result into the dataset. + + Parameters + ---------- + result : Any + The result to load. + """ + # There is one cube to load for each result. + dates = list(result.group_of_dates) + + LOG.debug(f"Loading cube for {len(dates)} dates") + + cube = result.get_cube() + shape = cube.extended_user_shape + dates_in_data = cube.user_coords["valid_datetime"] + + LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") + + def check_shape(cube, dates, dates_in_data): + if cube.extended_user_shape[0] != len(dates): + print( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) + print("Requested dates", compress_dates(dates)) + print("Cube dates", compress_dates(dates_in_data)) + + a = {as_datetime(_) for _ in dates} + b = {as_datetime(_) for _ in dates_in_data} + + print("Missing dates", compress_dates(a - b)) + print("Extra dates", compress_dates(b - a)) + + raise ValueError( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) + + check_shape(cube, dates, dates_in_data) + + def check_dates_in_data(dates_in_data, requested_dates): + _requested_dates = [np.datetime64(_) for _ in requested_dates] + _dates_in_data = [np.datetime64(_) for _ in dates_in_data] + if _dates_in_data != _requested_dates: + LOG.error("Dates in data are not the requested ones:") + + dates_in_data = set(dates_in_data) + requested_dates = set(requested_dates) + + missing = sorted(requested_dates - dates_in_data) + extra = sorted(dates_in_data - requested_dates) + + if missing: + LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") + if extra: + LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") + + raise ValueError("Dates in data are not the requested ones") + + check_dates_in_data(dates_in_data, dates) + + def dates_to_indexes(dates, all_dates): + x = np.array(dates, dtype=np.datetime64) + y = np.array(all_dates, 
dtype=np.datetime64) + bitmap = np.isin(x, y) + return np.where(bitmap)[0] + + indexes = dates_to_indexes(self.dates, dates_in_data) + + array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) + LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") + self.load_cube(cube, array) + + stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) + self.tmp_statistics.write(indexes, stats, dates=dates_in_data) + LOG.info("Flush data array") + array.flush() + LOG.info("Flushed data array") + + def _get_allow_nans(self) -> bool | list: + """Get the allow_nans configuration. + + Returns + ------- + bool | list + The allow_nans configuration. + """ + config = self.main_config + if "allow_nans" in config.build: + return config.build.allow_nans + + return config.statistics.get("allow_nans", []) + + def load_cube(self, cube: Any, array: ViewCacheArray) -> None: + """Load the cube into the array. + + Parameters + ---------- + cube : Any + The cube to load. + array : ViewCacheArray + The array to load into. + """ + # There are several cubelets for each cube + start = time.time() + load = 0 + save = 0 + + reading_chunks = None + total = cube.count(reading_chunks) + LOG.debug(f"Loading datacube: {cube}") + + def position(x: Any) -> int | None: + if isinstance(x, str) and "/" in x: + x = x.split("/") + return int(x[0]) + return None + + bar = tqdm.tqdm( + iterable=cube.iterate_cubelets(reading_chunks), + total=total, + desc=f"Loading datacube {cube}", + position=position(self.parts), + ) + for i, cubelet in enumerate(bar): + bar.set_description(f"Loading {i}/{total}") + + now = time.time() + data = cubelet.to_numpy() + local_indexes = cubelet.coords + load += time.time() - now + + name = self.variables_names[local_indexes[1]] + check_data_values( + data[:], + name=name, + log=[i, data.shape, local_indexes], + allow_nans=self._get_allow_nans(), + ) + + now = time.time() + array[local_indexes] = data + save += time.time() - now + + now = time.time() + save += time.time() - now + LOG.debug( + f"Elapsed: {seconds_to_human(time.time() - start)}, " + f"load time: {seconds_to_human(load)}, " + f"write time: {seconds_to_human(save)}." + ) + + +class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin): + """A class to clean up temporary data and registry entries.""" + + def __init__( + self, + path: str, + statistics_temp_dir: str | None = None, + delta: list = [], + use_threads: bool = False, + **kwargs: Any, + ): + """Initialize a Cleanup instance. + + Parameters + ---------- + path : str + The path to the dataset. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + delta : list, optional + The delta values. + use_threads : bool, optional + Whether to use threads. + """ + super().__init__(path) + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.additinon_temp_dir = statistics_temp_dir + self.actors = [ + _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) + for d in delta + ] + + def run(self) -> None: + """Run the cleanup.""" + + self.tmp_statistics.delete() + self.registry.clean() + for actor in self.actors: + actor.cleanup() + + +class Verify(Actor): + """A class to verify the integrity of a dataset.""" + + def __init__(self, path: str, **kwargs: Any): + """Initialize a Verify instance. + + Parameters + ---------- + path : str + The path to the dataset. 
+ """ + super().__init__(path) + + def run(self) -> None: + """Run the verification.""" + LOG.info(f"Verifying dataset at {self.path}") + LOG.info(str(self.dataset.anemoi_dataset)) + + +class AdditionsMixin: + """A mixin class to handle dataset additions.""" + + def skip(self) -> bool: + """Check if the additions should be skipped. + + Returns + ------- + bool + Whether to skip the additions. + """ + frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) + if not self.delta.total_seconds() % frequency.total_seconds() == 0: + LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.") + return True + + if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: + LOG.warning(f"Additions are disabled for {self.path} in the recipe.") + return True + + return False + + @cached_property + def tmp_storage_path(self) -> str: + """Get the path to the temporary storage.""" + name = "storage_for_additions" + if self.delta: + name += frequency_to_string(self.delta) + return os.path.join(f"{self.path}.{name}.tmp") + + def read_from_dataset(self) -> None: + """Read data from the dataset.""" + self.variables = self.dataset.anemoi_dataset.variables + self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) + start = self.dataset.zarr_metadata["statistics_start_date"] + end = self.dataset.zarr_metadata["statistics_end_date"] + self.start = datetime.datetime.fromisoformat(start) + self.end = datetime.datetime.fromisoformat(end) + + ds = open_dataset(self.path, start=self.start, end=self.end) + self.dates = ds.dates + self.total = len(self.dates) + + idelta = self.delta.total_seconds() // self.frequency.total_seconds() + assert int(idelta) == idelta, idelta + idelta = int(idelta) + self.ds = DeltaDataset(ds, idelta) + + +class DeltaDataset: + """A class to represent a dataset with delta values.""" + + def __init__(self, ds: Any, idelta: int): + """Initialize a DeltaDataset instance. + + Parameters + ---------- + ds : Any + The dataset. + idelta : int + The delta value. + """ + self.ds = ds + self.idelta = idelta + + def __getitem__(self, i: int) -> Any: + """Get an item from the dataset. + + Parameters + ---------- + i : int + The index. + + Returns + ------- + Any + The item. + """ + j = i - self.idelta + if j < 0: + raise MissingDateError(f"Missing date {j}") + return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] + + +class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin): + """A class to initialize dataset additions.""" + + def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): + """Initialize an _InitAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. 
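+
+        Notes
+        -----
+        ``delta`` is given as a frequency string (for example ``"6h"`` or
+        ``"12h"``) and is converted to a ``datetime.timedelta`` with
+        ``frequency_to_timedelta``.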
+ """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + + def run(self) -> None: + """Run the additions initialization.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}") + return + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) + self.tmp_storage.delete() + self.tmp_storage.create() + LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") + + def cleanup(self) -> None: + """Clean up the temporary storage.""" + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + self.tmp_storage.delete() + LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") + + +class _RunAdditions(Actor, HasRegistryMixin, AdditionsMixin): + """A class to run dataset additions.""" + + def __init__( + self, + path: str, + delta: str, + parts: str | None = None, + use_threads: bool = False, + progress: Any = None, + **kwargs: Any, + ): + """Initialize a _RunAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + parts : Optional[str], optional + The parts to load. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + self.parts = parts + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + LOG.info(f"Writing in {self.tmp_storage_path}") + + def run(self) -> None: + """Run the additions.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}") + return + + self.read_from_dataset() + + chunk_filter = ChunkFilter(parts=self.parts, total=self.total) + for i in range(0, self.total): + if not chunk_filter(i): + continue + date = self.dates[i] + try: + arr = self.ds[i] + stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) + self.tmp_storage.add([date, i, stats], key=date) + except MissingDateError: + self.tmp_storage.add([date, i, "missing"], key=date) + self.tmp_storage.flush() + LOG.debug(f"Dataset {self.path} additions run.") + + def allow_nans(self) -> bool: + """Check if NaNs are allowed. + + Returns + ------- + bool + Whether NaNs are allowed. + """ + if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): + return True + + variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) + if variables_with_nans is not None: + return variables_with_nans + warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") + return True + + +class _FinaliseAdditions(Actor, HasRegistryMixin, AdditionsMixin): + """A class to finalize dataset additions.""" + + def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): + """Initialize a _FinaliseAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. 
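+
+        Notes
+        -----
+        The aggregated tendency statistics are written back to the dataset by
+        ``_write`` as ``statistics_tendencies_<delta>_<name>`` arrays, e.g.
+        ``statistics_tendencies_12h_mean`` and
+        ``statistics_tendencies_12h_stdev`` for a 12h delta.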
+ """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + LOG.info(f"Reading from {self.tmp_storage_path}.") + + def run(self) -> None: + """Run the additions finalization.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}.") + return + + self.read_from_dataset() + + shape = (len(self.dates), len(self.variables)) + agg = dict( + minimum=np.full(shape, np.nan, dtype=np.float64), + maximum=np.full(shape, np.nan, dtype=np.float64), + sums=np.full(shape, np.nan, dtype=np.float64), + squares=np.full(shape, np.nan, dtype=np.float64), + count=np.full(shape, -1, dtype=np.int64), + has_nans=np.full(shape, False, dtype=np.bool_), + ) + LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") + + found = set() + ifound = set() + missing = set() + for _date, (date, i, stats) in self.tmp_storage.items(): + assert _date == date + if stats == "missing": + missing.add(date) + continue + + assert date not in found, f"Duplicates found {date}" + found.add(date) + ifound.add(i) + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k][i, ...] = stats[k] + + assert len(found) + len(missing) == len(self.dates), ( + len(found), + len(missing), + len(self.dates), + ) + assert found.union(missing) == set(self.dates), ( + found, + missing, + set(self.dates), + ) + + if len(ifound) < 2: + LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") + self.tmp_storage.delete() + return + + mask = sorted(list(ifound)) + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k] = agg[k][mask, ...] + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + assert agg[k].shape == agg["count"].shape, ( + agg[k].shape, + agg["count"].shape, + ) + + minimum = np.nanmin(agg["minimum"], axis=0) + maximum = np.nanmax(agg["maximum"], axis=0) + sums = np.nansum(agg["sums"], axis=0) + squares = np.nansum(agg["squares"], axis=0) + count = np.nansum(agg["count"], axis=0) + has_nans = np.any(agg["has_nans"], axis=0) + + assert sums.shape == count.shape + assert sums.shape == squares.shape + assert sums.shape == minimum.shape + assert sums.shape == maximum.shape + assert sums.shape == has_nans.shape + + mean = sums / count + assert sums.shape == mean.shape + + x = squares / count - mean * mean + # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 + # remove negative variance due to numerical errors + for i, name in enumerate(self.variables): + x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) + check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) + + stdev = np.sqrt(x) + assert sums.shape == stdev.shape + + self.summary = Summary( + minimum=minimum, + maximum=maximum, + mean=mean, + count=count, + sums=sums, + squares=squares, + stdev=stdev, + variables_names=self.variables, + has_nans=has_nans, + ) + LOG.info(f"Dataset {self.path} additions finalised.") + # self.check_statistics() + self._write(self.summary) + self.tmp_storage.delete() + + def _write(self, summary: Summary) -> None: + """Write the summary to the dataset. + + Parameters + ---------- + summary : Summary + The summary to write. 
+ """ + for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: + name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" + self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) + self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") + LOG.debug(f"Wrote additions in {self.path}") + + +def multi_addition(cls: type) -> type: + """Create a class to handle multiple additions. + + Parameters + ---------- + cls : type + The class to handle additions. + + Returns + ------- + type + The class to handle multiple additions. + """ + + class MultiAdditions: + def __init__(self, *args, **kwargs: Any): + self.actors = [] + + for k in kwargs.pop("delta", []): + self.actors.append(cls(*args, delta=k, **kwargs)) + + if not self.actors: + LOG.warning("No delta found in kwargs, no additions will be computed.") + + def run(self) -> None: + """Run the additions.""" + for actor in self.actors: + actor.run() + + return MultiAdditions + + +InitAdditions = multi_addition(_InitAdditions) +RunAdditions = multi_addition(_RunAdditions) +FinaliseAdditions = multi_addition(_FinaliseAdditions) + + +class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin): + """A class to compute statistics for a dataset.""" + + def __init__( + self, + path: str, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + **kwargs: Any, + ): + """Initialize a Statistics instance. + + Parameters + ---------- + path : str + The path to the dataset. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.use_threads = use_threads + self.progress = progress + self.statistics_temp_dir = statistics_temp_dir + + def run(self) -> None: + """Run the statistics computation.""" + start, end = ( + self.dataset.zarr_metadata["statistics_start_date"], + self.dataset.zarr_metadata["statistics_end_date"], + ) + start, end = np.datetime64(start), np.datetime64(end) + dates = self.dataset.anemoi_dataset.dates + + assert type(dates[0]) is type(start), (type(dates[0]), type(start)) + + dates = [d for d in dates if d >= start and d <= end] + dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] + variables = self.dataset.anemoi_dataset.variables + stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) + + LOG.info(stats) + + if not all(self.registry.get_flags(sync=False)): + raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") + + for k in [ + "mean", + "stdev", + "minimum", + "maximum", + "sums", + "squares", + "count", + "has_nans", + ]: + self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) + + self.registry.add_to_history("compute_statistics_end") + LOG.info(f"Wrote statistics in {self.path}") + + @cached_property + def allow_nans(self) -> bool | list: + """Check if NaNs are allowed.""" + import zarr + + z = zarr.open(self.path, mode="r") + if "allow_nans" in z.attrs: + return z.attrs["allow_nans"] + + if "variables_with_nans" in z.attrs: + return z.attrs["variables_with_nans"] + + warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") + return True + + +def chain(tasks: list) -> type: + """Create a class to chain multiple tasks. 
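+
+    For example, ``creator_factory`` below uses
+    ``chain([Statistics, Size, Cleanup])`` for the ``finalise`` task: the
+    returned class instantiates each task with the same keyword arguments
+    (such as ``path``) and runs them in order.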
+ + Parameters + ---------- + tasks : list + The list of tasks to chain. + + Returns + ------- + type + The class to chain multiple tasks. + """ + + class Chain(Actor): + def __init__(self, **kwargs: Any): + self.kwargs = kwargs + + def run(self) -> None: + """Run the chained tasks.""" + for cls in tasks: + t = cls(**self.kwargs) + t.run() + + return Chain + + +def creator_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: + """Create a dataset creator. + + Parameters + ---------- + name : str + The name of the creator. + trace : Optional[str], optional + The trace file. + **kwargs + Additional arguments for the creator. + + Returns + ------- + Any + The dataset creator. + """ + if trace: + + enable_trace(trace) + + cls = dict( + init=Init, + load=Load, + size=Size, + patch=Patch, + statistics=Statistics, + finalise=chain([Statistics, Size, Cleanup]), + cleanup=Cleanup, + verify=Verify, + init_additions=InitAdditions, + load_additions=RunAdditions, + run_additions=RunAdditions, + finalise_additions=chain([FinaliseAdditions, Size]), + additions=chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup]), + )[name] + LOG.debug(f"Creating {cls.__name__} with {kwargs}") + return cls(**kwargs) + + +def validate_config(config: Any) -> None: + + import json + + import jsonschema + + def _tidy(d): + if isinstance(d, dict): + return {k: _tidy(v) for k, v in d.items()} + + if isinstance(d, list): + return [_tidy(v) for v in d if v is not None] + + # jsonschema does not support datetime.date + if isinstance(d, datetime.datetime): + return d.isoformat() + + if isinstance(d, datetime.date): + return d.isoformat() + + return d + + # https://json-schema.org + + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "schemas", + "recipe.json", + ) + ) as f: + schema = json.load(f) + + try: + jsonschema.validate(instance=_tidy(config), schema=schema) + except jsonschema.exceptions.ValidationError as e: + LOG.error("❌ Config validation failed (jsonschema):") + LOG.error(e.message) + raise + + +def config_to_python(config: Any) -> Any: + + from anemoi.datasets.create.create.python import PythonScript + + raw_config = config + + config = loader_config(config) + + input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) + + code = PythonScript() + x = input.python_code(code) + code = code.source_code(x, raw_config) + + try: + import black + + return black.format_str(code, mode=black.Mode()) + # except ImportError: + except Exception: + LOG.warning("Black not installed, skipping formatting") + return code From 64ba82af6544f41d41af71fef02716e2e3d9f569 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 18:40:45 +0000 Subject: [PATCH 140/212] refactor --- src/anemoi/datasets/commands/create.py | 2 +- .../datasets/commands/recipe/__init__.py | 4 ++-- .../datasets/commands/recipe/migrate.py | 2 +- .../create/fields/{actors.py => tasks.py} | 24 +++++++++---------- .../data/records/backends/__init__.py | 6 ++--- tests/create/utils/create.py | 2 +- 6 files changed, 20 insertions(+), 20 deletions(-) rename src/anemoi/datasets/create/fields/{actors.py => tasks.py} (98%) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 787f0fc89..215e5ca0e 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -45,7 +45,7 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: now = datetime.datetime.now() LOG.info(f"🎬 Task 
{what}({args},{kwargs}) starting") - from anemoi.datasets.create.fields.actors import creator_factory + from anemoi.datasets.create.fields.tasks import creator_factory options = {k: v for k, v in options.items() if v is not None} diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 5c7b6f176..69831e6da 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -18,8 +18,8 @@ from anemoi.datasets.commands import Command from anemoi.datasets.commands.recipe.format import format_recipe from anemoi.datasets.commands.recipe.migrate import migrate_recipe -from anemoi.datasets.create.fields.actors import config_to_python -from anemoi.datasets.create.fields.actors import validate_config +from anemoi.datasets.create.fields.tasks import config_to_python +from anemoi.datasets.create.fields.tasks import validate_config LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index dbfde4143..071dbab89 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -17,7 +17,7 @@ from glom import delete from glom import glom -from anemoi.datasets.create.fields.actors import validate_config +from anemoi.datasets.create.fields.tasks import validate_config from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/fields/actors.py b/src/anemoi/datasets/create/fields/tasks.py similarity index 98% rename from src/anemoi/datasets/create/fields/actors.py rename to src/anemoi/datasets/create/fields/tasks.py index 6301d45ee..9412f7186 100644 --- a/src/anemoi/datasets/create/fields/actors.py +++ b/src/anemoi/datasets/create/fields/tasks.py @@ -352,7 +352,7 @@ def __init__(self, path: str, overwrite: bool = False): self.z.create_group("_build") -class Actor: # TODO: rename to Creator +class Task: # TODO: rename to Creator """A base class for dataset creation actors.""" dataset_class = WritableDataset @@ -455,7 +455,7 @@ def check_missing_dates(expected: list[np.datetime64]) -> None: check_missing_dates(self.missing_dates) -class Patch(Actor): +class Patch(Task): """A class to apply patches to a dataset.""" def __init__(self, path: str, options: dict = None, **kwargs: Any): @@ -478,7 +478,7 @@ def run(self) -> None: apply_patch(self.path, **self.options) -class Size(Actor): +class Size(Task): """A class to compute the size of a dataset.""" def __init__(self, path: str, **kwargs: Any): @@ -566,7 +566,7 @@ def create_elements(self, config: Any) -> None: LOG.debug(self.input) -class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): +class Init(Task, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): """A class to initialize a new dataset.""" dataset_class = NewDataset @@ -808,7 +808,7 @@ def _run(self) -> int: return len(lengths) -class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): +class Load(Task, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): """A class to load data into a dataset.""" def __init__( @@ -1037,7 +1037,7 @@ def position(x: Any) -> int | None: ) -class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin): +class Cleanup(Task, HasRegistryMixin, HasStatisticTempMixin): """A class to clean up temporary data and registry entries.""" def __init__( @@ -1079,7 +1079,7 @@ def run(self) -> None: 
actor.cleanup() -class Verify(Actor): +class Verify(Task): """A class to verify the integrity of a dataset.""" def __init__(self, path: str, **kwargs: Any): @@ -1182,7 +1182,7 @@ def __getitem__(self, i: int) -> Any: return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] -class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin): +class _InitAdditions(Task, HasRegistryMixin, AdditionsMixin): """A class to initialize dataset additions.""" def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): @@ -1222,7 +1222,7 @@ def cleanup(self) -> None: LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") -class _RunAdditions(Actor, HasRegistryMixin, AdditionsMixin): +class _RunAdditions(Task, HasRegistryMixin, AdditionsMixin): """A class to run dataset additions.""" def __init__( @@ -1298,7 +1298,7 @@ def allow_nans(self) -> bool: return True -class _FinaliseAdditions(Actor, HasRegistryMixin, AdditionsMixin): +class _FinaliseAdditions(Task, HasRegistryMixin, AdditionsMixin): """A class to finalize dataset additions.""" def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): @@ -1478,7 +1478,7 @@ def run(self) -> None: FinaliseAdditions = multi_addition(_FinaliseAdditions) -class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin): +class Statistics(Task, HasStatisticTempMixin, HasRegistryMixin): """A class to compute statistics for a dataset.""" def __init__( @@ -1573,7 +1573,7 @@ def chain(tasks: list) -> type: The class to chain multiple tasks. """ - class Chain(Actor): + class Chain(Task): def __init__(self, **kwargs: Any): self.kwargs = kwargs diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index 62785c9cd..1c9e0e96c 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -213,7 +213,7 @@ def write(self, i, data, number_of_files_per_subdirectory=100, **kwargs): os.rename(tmp_path, out_path) def write_metadata(self, metadata): - from anemoi.datasets.create.fields.actors import json_tidy + from anemoi.datasets.create.fields.tasks import json_tidy os.makedirs(self.path, exist_ok=True) @@ -257,7 +257,7 @@ def write(self, i, data, **kwargs): ds.to_netcdf(out_path) def write_metadata(self, metadata): - from anemoi.datasets.create.fields.actors import json_tidy + from anemoi.datasets.create.fields.tasks import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: @@ -295,7 +295,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.create.fields.actors import json_tidy + from anemoi.datasets.create.fields.tasks import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index 4ca48f98d..b1c9ee07e 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -12,7 +12,7 @@ import yaml -from anemoi.datasets.create.fields.actors import creator_factory +from anemoi.datasets.create.fields.tasks import creator_factory class TestingContext: From 6519c5ab6f8a374465d2b68f11d7ce9fd5614857 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Sep 2025 18:44:47 +0000 Subject: [PATCH 141/212] refactor --- src/anemoi/datasets/commands/create.py | 4 ++-- 
src/anemoi/datasets/create/fields/tasks.py | 2 +- tests/create/utils/create.py | 22 +++++++++++----------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 215e5ca0e..8530cf127 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -45,11 +45,11 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") - from anemoi.datasets.create.fields.tasks import creator_factory + from anemoi.datasets.create.fields.tasks import task_factory options = {k: v for k, v in options.items() if v is not None} - c = creator_factory(what.replace("-", "_"), **options) + c = task_factory(what.replace("-", "_"), **options) result = c.run() LOG.info(f"🏁 Task {what}({args},{kwargs}) completed ({datetime.datetime.now()-now})") diff --git a/src/anemoi/datasets/create/fields/tasks.py b/src/anemoi/datasets/create/fields/tasks.py index 9412f7186..4f6a66818 100644 --- a/src/anemoi/datasets/create/fields/tasks.py +++ b/src/anemoi/datasets/create/fields/tasks.py @@ -1586,7 +1586,7 @@ def run(self) -> None: return Chain -def creator_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: +def task_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: """Create a dataset creator. Parameters diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index b1c9ee07e..0addb122b 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -12,7 +12,7 @@ import yaml -from anemoi.datasets.create.fields.tasks import creator_factory +from anemoi.datasets.create.fields.tasks import task_factory class TestingContext: @@ -52,21 +52,21 @@ def create_dataset( if output is None: output = tempfile.mkdtemp(suffix=".zarr") - creator_factory("init", config=config, path=output, overwrite=True, test=is_test).run() - creator_factory("load", path=output).run() - creator_factory("finalise", path=output).run() - creator_factory("patch", path=output).run() + task_factory("init", config=config, path=output, overwrite=True, test=is_test).run() + task_factory("load", path=output).run() + task_factory("finalise", path=output).run() + task_factory("patch", path=output).run() if delta is not None: - creator_factory("init_additions", path=output, delta=delta).run() - creator_factory("run_additions", path=output, delta=delta).run() - creator_factory("finalise_additions", path=output, delta=delta).run() + task_factory("init_additions", path=output, delta=delta).run() + task_factory("run_additions", path=output, delta=delta).run() + task_factory("finalise_additions", path=output, delta=delta).run() - creator_factory("cleanup", path=output).run() + task_factory("cleanup", path=output).run() if delta is not None: - creator_factory("cleanup", path=output, delta=delta).run() + task_factory("cleanup", path=output, delta=delta).run() - creator_factory("verify", path=output).run() + task_factory("verify", path=output).run() return output From 38475ad510f0d3abae30bd816843e91a783183fa Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 30 Sep 2025 08:24:09 +0000 Subject: [PATCH 142/212] refactor --- src/anemoi/datasets/commands/create.py | 2 +- src/anemoi/datasets/create/__init__.py | 0 src/anemoi/datasets/create/fields/tasks.py | 16 ++++++++-------- src/anemoi/datasets/create/observations/tasks.py | 0 src/anemoi/datasets/create/tasks.py | 9 +++++++++ 5 files 
changed, 18 insertions(+), 9 deletions(-) create mode 100644 src/anemoi/datasets/create/__init__.py create mode 100644 src/anemoi/datasets/create/observations/tasks.py create mode 100644 src/anemoi/datasets/create/tasks.py diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 8530cf127..b7e93186a 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -45,7 +45,7 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") - from anemoi.datasets.create.fields.tasks import task_factory + from anemoi.datasets.create.tasks import task_factory options = {k: v for k, v in options.items() if v is not None} diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/anemoi/datasets/create/fields/tasks.py b/src/anemoi/datasets/create/fields/tasks.py index 4f6a66818..0fdb746d4 100644 --- a/src/anemoi/datasets/create/fields/tasks.py +++ b/src/anemoi/datasets/create/fields/tasks.py @@ -352,8 +352,8 @@ def __init__(self, path: str, overwrite: bool = False): self.z.create_group("_build") -class Task: # TODO: rename to Creator - """A base class for dataset creation actors.""" +class Task: + """A base class for dataset creation tasks.""" dataset_class = WritableDataset @@ -1065,7 +1065,7 @@ def __init__( self.use_threads = use_threads self.statistics_temp_dir = statistics_temp_dir self.additinon_temp_dir = statistics_temp_dir - self.actors = [ + self.tasks = [ _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) for d in delta ] @@ -1075,7 +1075,7 @@ def run(self) -> None: self.tmp_statistics.delete() self.registry.clean() - for actor in self.actors: + for actor in self.tasks: actor.cleanup() @@ -1457,17 +1457,17 @@ def multi_addition(cls: type) -> type: class MultiAdditions: def __init__(self, *args, **kwargs: Any): - self.actors = [] + self.tasks = [] for k in kwargs.pop("delta", []): - self.actors.append(cls(*args, delta=k, **kwargs)) + self.tasks.append(cls(*args, delta=k, **kwargs)) - if not self.actors: + if not self.tasks: LOG.warning("No delta found in kwargs, no additions will be computed.") def run(self) -> None: """Run the additions.""" - for actor in self.actors: + for actor in self.tasks: actor.run() return MultiAdditions diff --git a/src/anemoi/datasets/create/observations/tasks.py b/src/anemoi/datasets/create/observations/tasks.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/anemoi/datasets/create/tasks.py b/src/anemoi/datasets/create/tasks.py new file mode 100644 index 000000000..cff391eab --- /dev/null +++ b/src/anemoi/datasets/create/tasks.py @@ -0,0 +1,9 @@ +def task_factory(name: str, fields: bool, trace: str | None = None, **kwargs): + if fields: + from anemoi.datasets.create.fields.tasks import task_factory as fields_task_factory + + return fields_task_factory(name, trace=trace, **kwargs) + else: + from anemoi.datasets.create.observations.tasks import task_factory as observations_task_factory + + return observations_task_factory(name, trace=trace, **kwargs) From 734a36fb7afb294ddd7bbc4a6e4c9cc5893d6d49 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 30 Sep 2025 09:59:29 +0100 Subject: [PATCH 143/212] refactor --- src/anemoi/datasets/commands/create.py | 58 +- src/anemoi/datasets/create/fields/cleanup.py | 60 ++ 
src/anemoi/datasets/create/fields/init.py | 293 ++++++++ src/anemoi/datasets/create/fields/load.py | 260 +++++++ src/anemoi/datasets/create/fields/tasks.py | 683 ++---------------- .../datasets/create/observations/tasks.py | 131 ++++ src/anemoi/datasets/create/tasks.py | 59 +- 7 files changed, 881 insertions(+), 663 deletions(-) create mode 100644 src/anemoi/datasets/create/fields/cleanup.py create mode 100644 src/anemoi/datasets/create/fields/init.py create mode 100644 src/anemoi/datasets/create/fields/load.py diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index b7e93186a..1ca332f80 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -23,25 +23,8 @@ LOG = logging.getLogger(__name__) -def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: - """Make sure `import Creator` is done in the sub-processes, and not in the main one. - - Parameters - ---------- - what : str - The task to be executed. - options : dict - Options for the task. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - Any - The result of the task. - """ +def task(what: str, fields: bool, options: dict, *args: Any, **kwargs: Any) -> Any: + """Make sure `import Creator` is done in the sub-processes, and not in the main one.""" now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") @@ -49,7 +32,7 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: options = {k: v for k, v in options.items() if v is not None} - c = task_factory(what.replace("-", "_"), **options) + c = task_factory(what.replace("-", "_"), fields, **options) result = c.run() LOG.info(f"🏁 Task {what}({args},{kwargs}) completed ({datetime.datetime.now()-now})") @@ -115,18 +98,20 @@ def serial_create(self, args: Any) -> None: options.pop("threads") options.pop("processes") - task("init", options) - task("load", options) - task("finalise", options) + fields = args.path.endswith(".zarr") or args.path.endswith(".zarr/") - task("init_additions", options) - task("run_additions", options) - task("finalise_additions", options) + task("init", fields, options) + task("load", fields, options) + task("finalise", fields, options) - task("patch", options) + task("init_additions", fields, options) + task("run_additions", fields, options) + task("finalise_additions", fields, options) - task("cleanup", options) - task("verify", options) + task("patch", fields, options) + + task("cleanup", fields, options) + task("verify", fields, options) def parallel_create(self, args: Any) -> None: """Create the dataset in parallel mode. 
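For orientation: this patch derives a `fields` flag from the output path suffix and threads it through `task()` into `task_factory`, which then picks either the gridded-fields or the observations implementation. A minimal illustrative sketch of that call path (the path and recipe name are hypothetical example values, not code from this patch):

# Sketch only: "dataset.zarr" and "recipe.yaml" are made-up example values.
path = "dataset.zarr"
options = {"config": "recipe.yaml", "path": path}
# A ".zarr" output selects the gridded-fields pipeline, anything else the observations one.
fields = path.endswith(".zarr") or path.endswith(".zarr/")
task("init", fields, options)  # internally: task_factory("init", fields, **options).run()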
@@ -147,6 +132,7 @@ def parallel_create(self, args: Any) -> None: threads = options.pop("threads") processes = options.pop("processes") + fields = args.path.endswith(".zarr") or args.path.endswith(".zarr/") use_threads = threads > 0 options["use_threads"] = use_threads @@ -157,7 +143,7 @@ def parallel_create(self, args: Any) -> None: ExecutorClass = ProcessPoolExecutor with ExecutorClass(max_workers=1) as executor: - total = executor.submit(task, "init", options).result() + total = executor.submit(task, "init", fields, options).result() futures = [] @@ -166,7 +152,7 @@ def parallel_create(self, args: Any) -> None: for n in range(total): opt = options.copy() opt["parts"] = f"{n+1}/{total}" - futures.append(executor.submit(task, "load", opt)) + futures.append(executor.submit(task, "load", fields, opt)) for future in tqdm.tqdm( as_completed(futures), desc="Loading", total=len(futures), colour="green", position=parallel + 1 @@ -177,7 +163,7 @@ def parallel_create(self, args: Any) -> None: executor.submit(task, "finalise", options).result() with ExecutorClass(max_workers=1) as executor: - executor.submit(task, "init-additions", options).result() + executor.submit(task, "init-additions", fields, options).result() with ExecutorClass(max_workers=parallel) as executor: for n in range(total): @@ -195,10 +181,10 @@ def parallel_create(self, args: Any) -> None: future.result() with ExecutorClass(max_workers=1) as executor: - executor.submit(task, "finalise-additions", options).result() - executor.submit(task, "patch", options).result() - executor.submit(task, "cleanup", options).result() - executor.submit(task, "verify", options).result() + executor.submit(task, "finalise-additions", fields, options).result() + executor.submit(task, "patch", fields, options).result() + executor.submit(task, "cleanup", fields, options).result() + executor.submit(task, "verify", fields, options).result() command = Create diff --git a/src/anemoi/datasets/create/fields/cleanup.py b/src/anemoi/datasets/create/fields/cleanup.py new file mode 100644 index 000000000..77b601e58 --- /dev/null +++ b/src/anemoi/datasets/create/fields/cleanup.py @@ -0,0 +1,60 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any + +from .tasks import FieldTask +from .tasks import HasRegistryMixin +from .tasks import HasStatisticTempMixin +from .tasks import _InitAdditions + +LOG = logging.getLogger(__name__) + + +class Cleanup(FieldTask, HasRegistryMixin, HasStatisticTempMixin): + """A class to clean up temporary data and registry entries.""" + + def __init__( + self, + path: str, + statistics_temp_dir: str | None = None, + delta: list = [], + use_threads: bool = False, + **kwargs: Any, + ): + """Initialize a Cleanup instance. + + Parameters + ---------- + path : str + The path to the dataset. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + delta : list, optional + The delta values. + use_threads : bool, optional + Whether to use threads. 
+ """ + super().__init__(path) + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.additinon_temp_dir = statistics_temp_dir + self.tasks = [ + _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) + for d in delta + ] + + def run(self) -> None: + """Run the cleanup.""" + + self.tmp_statistics.delete() + self.registry.clean() + for actor in self.tasks: + actor.cleanup() diff --git a/src/anemoi/datasets/create/fields/init.py b/src/anemoi/datasets/create/fields/init.py new file mode 100644 index 000000000..094f60922 --- /dev/null +++ b/src/anemoi/datasets/create/fields/init.py @@ -0,0 +1,293 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import uuid +from typing import Any + +import zarr +from anemoi.utils.sanitise import sanitise + +from anemoi.datasets.create.config import loader_config +from anemoi.datasets.create.utils import normalize_and_check_dates + +from .tasks import FieldTask +from .tasks import HasElementForDataMixin +from .tasks import HasRegistryMixin +from .tasks import HasStatisticTempMixin +from .tasks import NewDataset +from .tasks import build_statistics_dates + +LOG = logging.getLogger(__name__) + +VERSION = "0.30" + + +def _path_readable(path: str) -> bool: + """Check if the path is readable. + + Parameters + ---------- + path : str + The path to check. + + Returns + ------- + bool + True if the path is readable, False otherwise. + """ + + try: + zarr.open(path, "r") + return True + except zarr.errors.PathNotFoundError: + return False + + +class Init(FieldTask, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): + """A class to initialize a new dataset.""" + + dataset_class = NewDataset + + def __init__( + self, + path: str, + config: dict, + check_name: bool = False, + overwrite: bool = False, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + test: bool = False, + cache: str | None = None, + **kwargs: Any, + ): + """Initialize an Init instance. + + Parameters + ---------- + path : str + The path to the dataset. + config : dict + The configuration. + check_name : bool, optional + Whether to check the dataset name. + overwrite : bool, optional + Whether to overwrite the existing dataset. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + test : bool, optional + Whether this is a test. + cache : Optional[str], optional + The cache directory. + """ + if _path_readable(path) and not overwrite: + raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") + + super().__init__(path, cache=cache) + self.config = config + self.check_name = check_name + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.progress = progress + self.test = test + + self.main_config = loader_config(config, is_test=test) + + # self.registry.delete() ?? 
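+        # Start from a clean slate: remove any temporary statistics left over from a
+        # previous run (they are recreated later, in `_run`).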
+ self.tmp_statistics.delete() + + assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by + self.create_elements(self.main_config) + + LOG.info(f"Groups: {self.groups}") + + # window = self.main_config.dates.get("window") + + one_date = self.groups.one_date() + + self.minimal_input = self.input.select(self.context, one_date) + + LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") + LOG.info(self.minimal_input) + + def run(self) -> int: + """Run the initialization. + + Returns + ------- + int + The number of groups to process. + """ + with self._cache_context(): + return self._run() + + def _run(self) -> int: + """Internal method to run the initialization. + + Returns + ------- + int + The number of groups to process. + """ + """Create an empty dataset of the right final shape. + + Read a small part of the data to get the shape of the data and the resolution and more metadata. + """ + + LOG.info("Config loaded ok:") + # LOG.info(self.main_config) + + dates = self.groups.provider.values + frequency = self.groups.provider.frequency + missing = self.groups.provider.missing + + assert isinstance(frequency, datetime.timedelta), frequency + + LOG.info(f"Found {len(dates)} datetimes.") + LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") + LOG.info(f"Missing dates: {len(missing)}") + lengths = tuple(len(g) for g in self.groups) + + variables = self.minimal_input.variables + LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") + + variables_with_nans = self.main_config.statistics.get("allow_nans", []) + + ensembles = self.minimal_input.ensembles + LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") + + grid_points = self.minimal_input.grid_points + LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") + + resolution = self.minimal_input.resolution + LOG.info(f"{resolution=}") + + coords = self.minimal_input.coords + coords["dates"] = dates + total_shape = self.minimal_input.shape + total_shape[0] = len(dates) + LOG.info(f"total_shape = {total_shape}") + + chunks = self.output.get_chunking(coords) + LOG.info(f"{chunks=}") + dtype = self.output.dtype + + LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") + + metadata = {} + metadata["uuid"] = str(uuid.uuid4()) + + metadata.update(self.main_config.get("add_metadata", {})) + + metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() + + recipe = sanitise(self.main_config.get_serialisable_dict()) + + # Remove stuff added by prepml + for k in [ + "build_dataset", + "config_format_version", + "config_path", + "dataset_status", + "ecflow", + "metadata", + "platform", + "reading_chunks", + "upload", + ]: + recipe.pop(k, None) + + metadata["recipe"] = recipe + + metadata["description"] = self.main_config.description + metadata["licence"] = self.main_config["licence"] + metadata["attribution"] = self.main_config["attribution"] + + metadata["remapping"] = self.output.remapping + metadata["order_by"] = self.output.order_by_as_list + metadata["flatten_grid"] = self.output.flatten_grid + + metadata["ensemble_dimension"] = len(ensembles) + metadata["variables"] = variables + metadata["variables_with_nans"] = variables_with_nans + metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) + metadata["resolution"] = resolution + + metadata["data_request"] = self.minimal_input.data_request + metadata["field_shape"] = 
self.minimal_input.field_shape + metadata["proj_string"] = self.minimal_input.proj_string + metadata["variables_metadata"] = self.minimal_input.variables_metadata + + metadata["start_date"] = dates[0].isoformat() + metadata["end_date"] = dates[-1].isoformat() + metadata["frequency"] = frequency + metadata["missing_dates"] = [_.isoformat() for _ in missing] + metadata["origins"] = self.minimal_input.origins + + metadata["version"] = VERSION + + self.dataset.check_name( + raise_exception=self.check_name, + is_test=self.test, + resolution=resolution, + dates=dates, + frequency=frequency, + ) + + if len(dates) != total_shape[0]: + raise ValueError( + f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " + f"does not match data shape {total_shape[0]}. {total_shape=}" + ) + + dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) + + metadata.update(self.main_config.get("force_metadata", {})) + + ############################################################### + # write metadata + ############################################################### + + self.update_metadata(**metadata) + + self.dataset.add_dataset( + name="data", + chunks=chunks, + dtype=dtype, + shape=total_shape, + dimensions=("time", "variable", "ensemble", "cell"), + ) + self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) + self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) + self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) + + self.registry.create(lengths=lengths) + self.tmp_statistics.create(exist_ok=False) + self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) + + statistics_start, statistics_end = build_statistics_dates( + dates, + self.main_config.statistics.get("start"), + self.main_config.statistics.get("end"), + ) + self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) + LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") + + self.registry.add_to_history("init finished") + + assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) + + # Return the number of groups to process, so we can show a nice progress bar + return len(lengths) diff --git a/src/anemoi/datasets/create/fields/load.py b/src/anemoi/datasets/create/fields/load.py new file mode 100644 index 000000000..bab731cb2 --- /dev/null +++ b/src/anemoi/datasets/create/fields/load.py @@ -0,0 +1,260 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
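+
+# This module provides the `Load` task: it reads the selected input for each group
+# of dates, writes the resulting datacube into the dataset's zarr `data` array and
+# accumulates temporary statistics for the later statistics step.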
+ +import logging +import time +from typing import Any + +import numpy as np +import tqdm +from anemoi.utils.dates import as_datetime +from anemoi.utils.humanize import compress_dates +from anemoi.utils.humanize import seconds_to_human + +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.create.statistics import compute_statistics +from anemoi.datasets.create.writer import ViewCacheArray + +from .tasks import FieldTask +from .tasks import HasElementForDataMixin +from .tasks import HasRegistryMixin +from .tasks import HasStatisticTempMixin +from .tasks import WritableDataset + +LOG = logging.getLogger(__name__) + + +class Load(FieldTask, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): + """A class to load data into a dataset.""" + + def __init__( + self, + path: str, + parts: str | None = None, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + cache: str | None = None, + **kwargs: Any, + ): + """Initialize a Load instance. + + Parameters + ---------- + path : str + The path to the dataset. + parts : Optional[str], optional + The parts to load. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + cache : Optional[str], optional + The cache directory. + """ + super().__init__(path, cache=cache) + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.progress = progress + self.parts = parts + self.dataset = WritableDataset(self.path) + + self.main_config = self.dataset.get_main_config() + self.create_elements(self.main_config) + self.read_dataset_metadata(self.dataset.path) + + total = len(self.registry.get_flags()) + self.chunk_filter = ChunkFilter(parts=self.parts, total=total) + + self.data_array = self.dataset.data_array + self.n_groups = len(self.groups) + + def run(self) -> None: + """Run the data loading.""" + with self._cache_context(): + self._run() + + def _run(self) -> None: + """Internal method to run the data loading.""" + for igroup, group in enumerate(self.groups): + if not self.chunk_filter(igroup): + continue + if self.registry.get_flag(igroup): + LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") + continue + + # assert isinstance(group[0], datetime.datetime), type(group[0]) + LOG.debug(f"Building data for group {igroup}/{self.n_groups}") + + result = self.input.select(self.context, argument=group) + assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) + + # There are several groups. + # There is one result to load for each group. + self.load_result(result) + self.registry.set_flag(igroup) + + self.registry.add_provenance(name="provenance_load") + self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) + + self.dataset.print_info() + + def load_result(self, result: Any) -> None: + """Load the result into the dataset. + + Parameters + ---------- + result : Any + The result to load. + """ + # There is one cube to load for each result. 
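+        # Below: build the datacube for this group of dates, check that its dates
+        # match the requested ones, copy the cubelets into the zarr data array
+        # through a cached view, then compute and store per-variable statistics.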
+ dates = list(result.group_of_dates) + + LOG.debug(f"Loading cube for {len(dates)} dates") + + cube = result.get_cube() + shape = cube.extended_user_shape + dates_in_data = cube.user_coords["valid_datetime"] + + LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") + + def check_shape(cube, dates, dates_in_data): + if cube.extended_user_shape[0] != len(dates): + print( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) + print("Requested dates", compress_dates(dates)) + print("Cube dates", compress_dates(dates_in_data)) + + a = {as_datetime(_) for _ in dates} + b = {as_datetime(_) for _ in dates_in_data} + + print("Missing dates", compress_dates(a - b)) + print("Extra dates", compress_dates(b - a)) + + raise ValueError( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) + + check_shape(cube, dates, dates_in_data) + + def check_dates_in_data(dates_in_data, requested_dates): + _requested_dates = [np.datetime64(_) for _ in requested_dates] + _dates_in_data = [np.datetime64(_) for _ in dates_in_data] + if _dates_in_data != _requested_dates: + LOG.error("Dates in data are not the requested ones:") + + dates_in_data = set(dates_in_data) + requested_dates = set(requested_dates) + + missing = sorted(requested_dates - dates_in_data) + extra = sorted(dates_in_data - requested_dates) + + if missing: + LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") + if extra: + LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") + + raise ValueError("Dates in data are not the requested ones") + + check_dates_in_data(dates_in_data, dates) + + def dates_to_indexes(dates, all_dates): + x = np.array(dates, dtype=np.datetime64) + y = np.array(all_dates, dtype=np.datetime64) + bitmap = np.isin(x, y) + return np.where(bitmap)[0] + + indexes = dates_to_indexes(self.dates, dates_in_data) + + array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) + LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") + self.load_cube(cube, array) + + stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) + self.tmp_statistics.write(indexes, stats, dates=dates_in_data) + LOG.info("Flush data array") + array.flush() + LOG.info("Flushed data array") + + def _get_allow_nans(self) -> bool | list: + """Get the allow_nans configuration. + + Returns + ------- + bool | list + The allow_nans configuration. + """ + config = self.main_config + if "allow_nans" in config.build: + return config.build.allow_nans + + return config.statistics.get("allow_nans", []) + + def load_cube(self, cube: Any, array: ViewCacheArray) -> None: + """Load the cube into the array. + + Parameters + ---------- + cube : Any + The cube to load. + array : ViewCacheArray + The array to load into. 
+ """ + # There are several cubelets for each cube + start = time.time() + load = 0 + save = 0 + + reading_chunks = None + total = cube.count(reading_chunks) + LOG.debug(f"Loading datacube: {cube}") + + def position(x: Any) -> int | None: + if isinstance(x, str) and "/" in x: + x = x.split("/") + return int(x[0]) + return None + + bar = tqdm.tqdm( + iterable=cube.iterate_cubelets(reading_chunks), + total=total, + desc=f"Loading datacube {cube}", + position=position(self.parts), + ) + for i, cubelet in enumerate(bar): + bar.set_description(f"Loading {i}/{total}") + + now = time.time() + data = cubelet.to_numpy() + local_indexes = cubelet.coords + load += time.time() - now + + name = self.variables_names[local_indexes[1]] + check_data_values( + data[:], + name=name, + log=[i, data.shape, local_indexes], + allow_nans=self._get_allow_nans(), + ) + + now = time.time() + array[local_indexes] = data + save += time.time() - now + + now = time.time() + save += time.time() - now + LOG.debug( + f"Elapsed: {seconds_to_human(time.time() - start)}, " + f"load time: {seconds_to_human(load)}, " + f"write time: {seconds_to_human(save)}." + ) diff --git a/src/anemoi/datasets/create/fields/tasks.py b/src/anemoi/datasets/create/fields/tasks.py index 0fdb746d4..14725e56a 100644 --- a/src/anemoi/datasets/create/fields/tasks.py +++ b/src/anemoi/datasets/create/fields/tasks.py @@ -11,34 +11,25 @@ import json import logging import os -import time -import uuid import warnings from functools import cached_property from typing import Any import cftime import numpy as np -import tqdm import zarr -from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta -from anemoi.utils.humanize import compress_dates -from anemoi.utils.humanize import seconds_to_human -from anemoi.utils.sanitise import sanitise from earthkit.data.core.order import build_remapping from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset from anemoi.datasets.create.check import DatasetName -from anemoi.datasets.create.check import check_data_values from anemoi.datasets.create.chunks import ChunkFilter from anemoi.datasets.create.config import build_output from anemoi.datasets.create.config import loader_config from anemoi.datasets.create.fields.context import FieldContext from anemoi.datasets.create.input import InputBuilder -from anemoi.datasets.create.input.trace import enable_trace from anemoi.datasets.create.persistent import build_storage from anemoi.datasets.create.statistics import Summary from anemoi.datasets.create.statistics import TmpStatistics @@ -46,15 +37,13 @@ from anemoi.datasets.create.statistics import compute_statistics from anemoi.datasets.create.statistics import default_statistics_dates from anemoi.datasets.create.statistics import fix_variance -from anemoi.datasets.create.utils import normalize_and_check_dates -from anemoi.datasets.create.writer import ViewCacheArray from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups -LOG = logging.getLogger(__name__) +from ..tasks import chain -VERSION = "0.30" +LOG = logging.getLogger(__name__) def json_tidy(o: Any) -> Any: @@ -136,28 +125,6 @@ def build_statistics_dates( return (start.isoformat(), end.isoformat()) -def _path_readable(path: str) -> bool: - """Check if the path is readable. - - Parameters - ---------- - path : str - The path to check. 
- - Returns - ------- - bool - True if the path is readable, False otherwise. - """ - import zarr - - try: - zarr.open(path, "r") - return True - except zarr.errors.PathNotFoundError: - return False - - class Dataset: """A class to represent a dataset.""" @@ -352,7 +319,7 @@ def __init__(self, path: str, overwrite: bool = False): self.z.create_group("_build") -class Task: +class FieldTask: """A base class for dataset creation tasks.""" dataset_class = WritableDataset @@ -455,7 +422,7 @@ def check_missing_dates(expected: list[np.datetime64]) -> None: check_missing_dates(self.missing_dates) -class Patch(Task): +class Patch(FieldTask): """A class to apply patches to a dataset.""" def __init__(self, path: str, options: dict = None, **kwargs: Any): @@ -478,7 +445,7 @@ def run(self) -> None: apply_patch(self.path, **self.options) -class Size(Task): +class Size(FieldTask): """A class to compute the size of a dataset.""" def __init__(self, path: str, **kwargs: Any): @@ -566,520 +533,7 @@ def create_elements(self, config: Any) -> None: LOG.debug(self.input) -class Init(Task, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to initialize a new dataset.""" - - dataset_class = NewDataset - - def __init__( - self, - path: str, - config: dict, - check_name: bool = False, - overwrite: bool = False, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - test: bool = False, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize an Init instance. - - Parameters - ---------- - path : str - The path to the dataset. - config : dict - The configuration. - check_name : bool, optional - Whether to check the dataset name. - overwrite : bool, optional - Whether to overwrite the existing dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - test : bool, optional - Whether this is a test. - cache : Optional[str], optional - The cache directory. - """ - if _path_readable(path) and not overwrite: - raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") - - super().__init__(path, cache=cache) - self.config = config - self.check_name = check_name - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.test = test - - self.main_config = loader_config(config, is_test=test) - - # self.registry.delete() ?? - self.tmp_statistics.delete() - - assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by - self.create_elements(self.main_config) - - LOG.info(f"Groups: {self.groups}") - - # window = self.main_config.dates.get("window") - - one_date = self.groups.one_date() - - self.minimal_input = self.input.select(self.context, one_date) - - LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") - LOG.info(self.minimal_input) - - def run(self) -> int: - """Run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - with self._cache_context(): - return self._run() - - def _run(self) -> int: - """Internal method to run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - """Create an empty dataset of the right final shape. - - Read a small part of the data to get the shape of the data and the resolution and more metadata. 
- """ - - LOG.info("Config loaded ok:") - # LOG.info(self.main_config) - - dates = self.groups.provider.values - frequency = self.groups.provider.frequency - missing = self.groups.provider.missing - - assert isinstance(frequency, datetime.timedelta), frequency - - LOG.info(f"Found {len(dates)} datetimes.") - LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") - LOG.info(f"Missing dates: {len(missing)}") - lengths = tuple(len(g) for g in self.groups) - - variables = self.minimal_input.variables - LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") - - variables_with_nans = self.main_config.statistics.get("allow_nans", []) - - ensembles = self.minimal_input.ensembles - LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") - - grid_points = self.minimal_input.grid_points - LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") - - resolution = self.minimal_input.resolution - LOG.info(f"{resolution=}") - - coords = self.minimal_input.coords - coords["dates"] = dates - total_shape = self.minimal_input.shape - total_shape[0] = len(dates) - LOG.info(f"total_shape = {total_shape}") - - chunks = self.output.get_chunking(coords) - LOG.info(f"{chunks=}") - dtype = self.output.dtype - - LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") - - metadata = {} - metadata["uuid"] = str(uuid.uuid4()) - - metadata.update(self.main_config.get("add_metadata", {})) - - metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() - - recipe = sanitise(self.main_config.get_serialisable_dict()) - - # Remove stuff added by prepml - for k in [ - "build_dataset", - "config_format_version", - "config_path", - "dataset_status", - "ecflow", - "metadata", - "platform", - "reading_chunks", - "upload", - ]: - recipe.pop(k, None) - - metadata["recipe"] = recipe - - metadata["description"] = self.main_config.description - metadata["licence"] = self.main_config["licence"] - metadata["attribution"] = self.main_config["attribution"] - - metadata["remapping"] = self.output.remapping - metadata["order_by"] = self.output.order_by_as_list - metadata["flatten_grid"] = self.output.flatten_grid - - metadata["ensemble_dimension"] = len(ensembles) - metadata["variables"] = variables - metadata["variables_with_nans"] = variables_with_nans - metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) - metadata["resolution"] = resolution - - metadata["data_request"] = self.minimal_input.data_request - metadata["field_shape"] = self.minimal_input.field_shape - metadata["proj_string"] = self.minimal_input.proj_string - metadata["variables_metadata"] = self.minimal_input.variables_metadata - - metadata["start_date"] = dates[0].isoformat() - metadata["end_date"] = dates[-1].isoformat() - metadata["frequency"] = frequency - metadata["missing_dates"] = [_.isoformat() for _ in missing] - metadata["origins"] = self.minimal_input.origins - - metadata["version"] = VERSION - - self.dataset.check_name( - raise_exception=self.check_name, - is_test=self.test, - resolution=resolution, - dates=dates, - frequency=frequency, - ) - - if len(dates) != total_shape[0]: - raise ValueError( - f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " - f"does not match data shape {total_shape[0]}. 
{total_shape=}" - ) - - dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) - - metadata.update(self.main_config.get("force_metadata", {})) - - ############################################################### - # write metadata - ############################################################### - - self.update_metadata(**metadata) - - self.dataset.add_dataset( - name="data", - chunks=chunks, - dtype=dtype, - shape=total_shape, - dimensions=("time", "variable", "ensemble", "cell"), - ) - self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) - self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) - self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) - - self.registry.create(lengths=lengths) - self.tmp_statistics.create(exist_ok=False) - self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) - - statistics_start, statistics_end = build_statistics_dates( - dates, - self.main_config.statistics.get("start"), - self.main_config.statistics.get("end"), - ) - self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) - LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") - - self.registry.add_to_history("init finished") - - assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) - - # Return the number of groups to process, so we can show a nice progress bar - return len(lengths) - - -class Load(Task, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to load data into a dataset.""" - - def __init__( - self, - path: str, - parts: str | None = None, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize a Load instance. - - Parameters - ---------- - path : str - The path to the dataset. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - cache : Optional[str], optional - The cache directory. 
- """ - super().__init__(path, cache=cache) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.parts = parts - self.dataset = WritableDataset(self.path) - - self.main_config = self.dataset.get_main_config() - self.create_elements(self.main_config) - self.read_dataset_metadata(self.dataset.path) - - total = len(self.registry.get_flags()) - self.chunk_filter = ChunkFilter(parts=self.parts, total=total) - - self.data_array = self.dataset.data_array - self.n_groups = len(self.groups) - - def run(self) -> None: - """Run the data loading.""" - with self._cache_context(): - self._run() - - def _run(self) -> None: - """Internal method to run the data loading.""" - for igroup, group in enumerate(self.groups): - if not self.chunk_filter(igroup): - continue - if self.registry.get_flag(igroup): - LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") - continue - - # assert isinstance(group[0], datetime.datetime), type(group[0]) - LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - - result = self.input.select(self.context, argument=group) - assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) - - # There are several groups. - # There is one result to load for each group. - self.load_result(result) - self.registry.set_flag(igroup) - - self.registry.add_provenance(name="provenance_load") - self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) - - self.dataset.print_info() - - def load_result(self, result: Any) -> None: - """Load the result into the dataset. - - Parameters - ---------- - result : Any - The result to load. - """ - # There is one cube to load for each result. - dates = list(result.group_of_dates) - - LOG.debug(f"Loading cube for {len(dates)} dates") - - cube = result.get_cube() - shape = cube.extended_user_shape - dates_in_data = cube.user_coords["valid_datetime"] - - LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") - - def check_shape(cube, dates, dates_in_data): - if cube.extended_user_shape[0] != len(dates): - print( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - print("Requested dates", compress_dates(dates)) - print("Cube dates", compress_dates(dates_in_data)) - - a = {as_datetime(_) for _ in dates} - b = {as_datetime(_) for _ in dates_in_data} - - print("Missing dates", compress_dates(a - b)) - print("Extra dates", compress_dates(b - a)) - - raise ValueError( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - - check_shape(cube, dates, dates_in_data) - - def check_dates_in_data(dates_in_data, requested_dates): - _requested_dates = [np.datetime64(_) for _ in requested_dates] - _dates_in_data = [np.datetime64(_) for _ in dates_in_data] - if _dates_in_data != _requested_dates: - LOG.error("Dates in data are not the requested ones:") - - dates_in_data = set(dates_in_data) - requested_dates = set(requested_dates) - - missing = sorted(requested_dates - dates_in_data) - extra = sorted(dates_in_data - requested_dates) - - if missing: - LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") - if extra: - LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") - - raise ValueError("Dates in data are not the requested ones") - - check_dates_in_data(dates_in_data, dates) - - def dates_to_indexes(dates, all_dates): - x = np.array(dates, dtype=np.datetime64) - y = np.array(all_dates, 
dtype=np.datetime64) - bitmap = np.isin(x, y) - return np.where(bitmap)[0] - - indexes = dates_to_indexes(self.dates, dates_in_data) - - array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) - LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") - self.load_cube(cube, array) - - stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) - self.tmp_statistics.write(indexes, stats, dates=dates_in_data) - LOG.info("Flush data array") - array.flush() - LOG.info("Flushed data array") - - def _get_allow_nans(self) -> bool | list: - """Get the allow_nans configuration. - - Returns - ------- - bool | list - The allow_nans configuration. - """ - config = self.main_config - if "allow_nans" in config.build: - return config.build.allow_nans - - return config.statistics.get("allow_nans", []) - - def load_cube(self, cube: Any, array: ViewCacheArray) -> None: - """Load the cube into the array. - - Parameters - ---------- - cube : Any - The cube to load. - array : ViewCacheArray - The array to load into. - """ - # There are several cubelets for each cube - start = time.time() - load = 0 - save = 0 - - reading_chunks = None - total = cube.count(reading_chunks) - LOG.debug(f"Loading datacube: {cube}") - - def position(x: Any) -> int | None: - if isinstance(x, str) and "/" in x: - x = x.split("/") - return int(x[0]) - return None - - bar = tqdm.tqdm( - iterable=cube.iterate_cubelets(reading_chunks), - total=total, - desc=f"Loading datacube {cube}", - position=position(self.parts), - ) - for i, cubelet in enumerate(bar): - bar.set_description(f"Loading {i}/{total}") - - now = time.time() - data = cubelet.to_numpy() - local_indexes = cubelet.coords - load += time.time() - now - - name = self.variables_names[local_indexes[1]] - check_data_values( - data[:], - name=name, - log=[i, data.shape, local_indexes], - allow_nans=self._get_allow_nans(), - ) - - now = time.time() - array[local_indexes] = data - save += time.time() - now - - now = time.time() - save += time.time() - now - LOG.debug( - f"Elapsed: {seconds_to_human(time.time() - start)}, " - f"load time: {seconds_to_human(load)}, " - f"write time: {seconds_to_human(save)}." - ) - - -class Cleanup(Task, HasRegistryMixin, HasStatisticTempMixin): - """A class to clean up temporary data and registry entries.""" - - def __init__( - self, - path: str, - statistics_temp_dir: str | None = None, - delta: list = [], - use_threads: bool = False, - **kwargs: Any, - ): - """Initialize a Cleanup instance. - - Parameters - ---------- - path : str - The path to the dataset. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - delta : list, optional - The delta values. - use_threads : bool, optional - Whether to use threads. - """ - super().__init__(path) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.additinon_temp_dir = statistics_temp_dir - self.tasks = [ - _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) - for d in delta - ] - - def run(self) -> None: - """Run the cleanup.""" - - self.tmp_statistics.delete() - self.registry.clean() - for actor in self.tasks: - actor.cleanup() - - -class Verify(Task): +class Verify(FieldTask): """A class to verify the integrity of a dataset.""" def __init__(self, path: str, **kwargs: Any): @@ -1182,7 +636,7 @@ def __getitem__(self, i: int) -> Any: return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] 
-class _InitAdditions(Task, HasRegistryMixin, AdditionsMixin): +class _InitAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): """A class to initialize dataset additions.""" def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): @@ -1222,7 +676,7 @@ def cleanup(self) -> None: LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") -class _RunAdditions(Task, HasRegistryMixin, AdditionsMixin): +class _RunAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): """A class to run dataset additions.""" def __init__( @@ -1298,7 +752,7 @@ def allow_nans(self) -> bool: return True -class _FinaliseAdditions(Task, HasRegistryMixin, AdditionsMixin): +class _FinaliseAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): """A class to finalize dataset additions.""" def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): @@ -1478,7 +932,7 @@ def run(self) -> None: FinaliseAdditions = multi_addition(_FinaliseAdditions) -class Statistics(Task, HasStatisticTempMixin, HasRegistryMixin): +class Statistics(FieldTask, HasStatisticTempMixin, HasRegistryMixin): """A class to compute statistics for a dataset.""" def __init__( @@ -1559,73 +1013,6 @@ def allow_nans(self) -> bool | list: return True -def chain(tasks: list) -> type: - """Create a class to chain multiple tasks. - - Parameters - ---------- - tasks : list - The list of tasks to chain. - - Returns - ------- - type - The class to chain multiple tasks. - """ - - class Chain(Task): - def __init__(self, **kwargs: Any): - self.kwargs = kwargs - - def run(self) -> None: - """Run the chained tasks.""" - for cls in tasks: - t = cls(**self.kwargs) - t.run() - - return Chain - - -def task_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: - """Create a dataset creator. - - Parameters - ---------- - name : str - The name of the creator. - trace : Optional[str], optional - The trace file. - **kwargs - Additional arguments for the creator. - - Returns - ------- - Any - The dataset creator. 
- """ - if trace: - - enable_trace(trace) - - cls = dict( - init=Init, - load=Load, - size=Size, - patch=Patch, - statistics=Statistics, - finalise=chain([Statistics, Size, Cleanup]), - cleanup=Cleanup, - verify=Verify, - init_additions=InitAdditions, - load_additions=RunAdditions, - run_additions=RunAdditions, - finalise_additions=chain([FinaliseAdditions, Size]), - additions=chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup]), - )[name] - LOG.debug(f"Creating {cls.__name__} with {kwargs}") - return cls(**kwargs) - - def validate_config(config: Any) -> None: import json @@ -1689,3 +1076,53 @@ def config_to_python(config: Any) -> Any: except Exception: LOG.warning("Black not installed, skipping formatting") return code + + +class TaskCreator: + """A class to create and run dataset creation tasks.""" + + def init(self, *args: Any, **kwargs: Any): + from .init import Init + + return Init(*args, **kwargs) + + def load(self, *args: Any, **kwargs: Any): + from .load import Load + + return Load(*args, **kwargs) + + def size(self, *args: Any, **kwargs: Any): + return Size(*args, **kwargs) + + def patch(self, *args: Any, **kwargs: Any): + return Patch(*args, **kwargs) + + def statistics(self, *args: Any, **kwargs: Any): + return Statistics(*args, **kwargs) + + def finalise(self, *args: Any, **kwargs: Any): + from .cleanup import Cleanup + + return chain([Statistics, Size, Cleanup])(*args, **kwargs) + + def cleanup(self, *args: Any, **kwargs: Any): + from .cleanup import Cleanup + + return Cleanup(*args, **kwargs) + + def verify(self, *args: Any, **kwargs: Any): + return Verify(*args, **kwargs) + + def init_additions(self, *args: Any, **kwargs: Any): + return InitAdditions(*args, **kwargs) + + def run_additions(self, *args: Any, **kwargs: Any): + return RunAdditions(*args, **kwargs) + + def finalise_additions(self, *args: Any, **kwargs: Any): + return chain([FinaliseAdditions, Size])(*args, **kwargs) + + def additions(self, *args: Any, **kwargs: Any): + from .cleanup import Cleanup + + return chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup])(*args, **kwargs) diff --git a/src/anemoi/datasets/create/observations/tasks.py b/src/anemoi/datasets/create/observations/tasks.py index e69de29bb..d8432b1ad 100644 --- a/src/anemoi/datasets/create/observations/tasks.py +++ b/src/anemoi/datasets/create/observations/tasks.py @@ -0,0 +1,131 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
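+
+# Placeholder implementations of the dataset-creation tasks for observation
+# datasets: each task currently only prints what it would do, so the command-line
+# flow can be exercised while the real logic is still to be filled in.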
+ +from typing import Any + +from anemoi.datasets.create.tasks import Task + + +class Init(Task): + def __init__(self, *, config: str, path: str, overwrite: bool = False, test: bool = False, **kwargs: Any): + self.config = config + self.path = path + self.overwrite = overwrite + self.test = test + + def run(self) -> None: + print(f"Init dataset at {self.path} with config {self.config}, overwrite={self.overwrite}, test={self.test}") + # Here would be the logic to initialize the dataset + + +class Load(Task): + def __init__(self, *, path: str, parts: str | None = None, use_threads: bool = False, **kwargs: Any): + self.path = path + self.parts = parts + self.use_threads = use_threads + + def run(self) -> None: + print(f"Load data into dataset at {self.path}, parts={self.parts}, use_threads={self.use_threads}") + # Here would be the logic to load data into the dataset + + +class Finalise(Task): + def __init__(self, *, path: str, **kwargs: Any): + self.path = path + + def run(self) -> None: + print(f"Finalise dataset at {self.path}") + # Here would be the logic to finalise the dataset + + +class InitAdditions(Task): + def __init__(self, *, path: str, **kwargs: Any): + self.path = path + + def run(self) -> None: + print(f"Init additions for dataset at {self.path}") + # Here would be the logic to initialize additions + + +class RunAdditions(Task): + def __init__(self, *, path: str, **kwargs: Any): + self.path = path + + def run(self) -> None: + print(f"Run additions for dataset at {self.path}") + # Here would be the logic to run additions + + +class FinaliseAdditions(Task): + def __init__(self, *, path: str, **kwargs: Any): + self.path = path + + def run(self) -> None: + print(f"Finalise additions for dataset at {self.path}") + # Here would be the logic to finalise additions + + +class Patch(Task): + def __init__(self, *, path: str, **kwargs: Any): + self.path = path + + def run(self) -> None: + print(f"Patch dataset at {self.path}") + # Here would be the logic to patch the dataset + + +class Cleanup(Task): + def __init__(self, *, path: str, **kwargs: Any): + self.path = path + + def run(self) -> None: + print(f"Cleanup dataset at {self.path}") + # Here would be the logic to cleanup the dataset + + +class Verify(Task): + def __init__(self, *, path: str, **kwargs: Any): + self.path = path + + def run(self) -> None: + print(f"Verify dataset at {self.path}") + # Here would be the logic to verify the dataset + + +class TaskCreator: + """A class to create and run dataset creation tasks.""" + + def init(self, *args: Any, **kwargs: Any): + return Init(*args, **kwargs) + + def load(self, *args: Any, **kwargs: Any): + + return Load(*args, **kwargs) + + def finalise(self, *args: Any, **kwargs: Any): + return Finalise(*args, **kwargs) + + def init_additions(self, *args: Any, **kwargs: Any): + return InitAdditions(*args, **kwargs) + + def run_additions(self, *args: Any, **kwargs: Any): + return RunAdditions(*args, **kwargs) + + def finalise_additions(self, *args: Any, **kwargs: Any): + return FinaliseAdditions(*args, **kwargs) + + def patch(self, *args: Any, **kwargs: Any): + return Patch(*args, **kwargs) + + def cleanup(self, *args: Any, **kwargs: Any): + return Cleanup(*args, **kwargs) + + def verify(self, *args: Any, **kwargs: Any): + + return Verify(*args, **kwargs) diff --git a/src/anemoi/datasets/create/tasks.py b/src/anemoi/datasets/create/tasks.py index cff391eab..f06f8fe5d 100644 --- a/src/anemoi/datasets/create/tasks.py +++ b/src/anemoi/datasets/create/tasks.py @@ -1,9 +1,60 @@ +# (C) Copyright 
2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from abc import ABC +from abc import abstractmethod +from typing import Any + + +class Task(ABC): + @abstractmethod + def run(self) -> None: + """Run the task.""" + pass + + +def chain(tasks: list) -> type: + """Create a class to chain multiple tasks. + + Parameters + ---------- + tasks : list + The list of tasks to chain. + + Returns + ------- + type + The class to chain multiple tasks. + """ + + class Chain(Task): + def __init__(self, **kwargs: Any): + self.kwargs = kwargs + + def run(self) -> None: + """Run the chained tasks.""" + for cls in tasks: + t = cls(**self.kwargs) + t.run() + + return Chain + + def task_factory(name: str, fields: bool, trace: str | None = None, **kwargs): + if fields: - from anemoi.datasets.create.fields.tasks import task_factory as fields_task_factory + from anemoi.datasets.create.fields.tasks import TaskCreator - return fields_task_factory(name, trace=trace, **kwargs) + creator = TaskCreator() else: - from anemoi.datasets.create.observations.tasks import task_factory as observations_task_factory + from anemoi.datasets.create.observations.tasks import TaskCreator + + creator = TaskCreator() - return observations_task_factory(name, trace=trace, **kwargs) + return getattr(creator, name)(trace=trace, **kwargs) From bd09239766859bff2a91d50a1a5794f6c63c1274 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 30 Sep 2025 10:22:01 +0100 Subject: [PATCH 144/212] refactor --- .../datasets/commands/recipe/__init__.py | 10 +- .../datasets/commands/recipe/migrate.py | 4 +- .../datasets/create/fields/additions.py | 413 +++++++++++++ src/anemoi/datasets/create/fields/cleanup.py | 2 +- src/anemoi/datasets/create/fields/init.py | 4 +- src/anemoi/datasets/create/fields/patch.py | 38 ++ src/anemoi/datasets/create/fields/size.py | 48 ++ .../datasets/create/fields/statistics.py | 102 ++++ src/anemoi/datasets/create/fields/tasks.py | 574 +----------------- src/anemoi/datasets/create/fields/verify.py | 34 ++ src/anemoi/datasets/data/padded.py | 1 + .../data/records/backends/__init__.py | 12 +- 12 files changed, 678 insertions(+), 564 deletions(-) create mode 100644 src/anemoi/datasets/create/fields/additions.py create mode 100644 src/anemoi/datasets/create/fields/patch.py create mode 100644 src/anemoi/datasets/create/fields/size.py create mode 100644 src/anemoi/datasets/create/fields/statistics.py create mode 100644 src/anemoi/datasets/create/fields/verify.py diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 69831e6da..07952ecbb 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -18,8 +18,8 @@ from anemoi.datasets.commands import Command from anemoi.datasets.commands.recipe.format import format_recipe from anemoi.datasets.commands.recipe.migrate import migrate_recipe -from anemoi.datasets.create.fields.tasks import config_to_python -from anemoi.datasets.create.fields.tasks import validate_config +from anemoi.datasets.create.fields.tasks import _config_to_python +from anemoi.datasets.create.fields.tasks import _validate_config LOG = logging.getLogger(__name__) @@ 
-65,7 +65,7 @@ def run(self, args: Any) -> None: if args.output and (not args.format and not args.migrate and not args.python): argparse.ArgumentError(None, "--output is not supported with --validate.") - validate_config(config) + _validate_config(config) LOG.info(f"{args.path}: Recipe is valid.") return @@ -99,9 +99,9 @@ def run(self, args: Any) -> None: if args.output: with open(args.output, "w") as file: - file.write(config_to_python(config)) + file.write(_config_to_python(config)) else: - print(config_to_python(config)) + print(_config_to_python(config)) command = Recipe diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 071dbab89..65a03ec76 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -17,7 +17,7 @@ from glom import delete from glom import glom -from anemoi.datasets.create.fields.tasks import validate_config +from anemoi.datasets.create.fields.tasks import _validate_config from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) @@ -520,7 +520,7 @@ def check(config): try: - validate_config(config) + _validate_config(config) assert config.get("input", {}) assert config.get("dates", {}) assert not has_key(config, "label") diff --git a/src/anemoi/datasets/create/fields/additions.py b/src/anemoi/datasets/create/fields/additions.py new file mode 100644 index 000000000..9068f8181 --- /dev/null +++ b/src/anemoi/datasets/create/fields/additions.py @@ -0,0 +1,413 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import os +import warnings +from functools import cached_property +from typing import Any + +import numpy as np +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets import MissingDateError +from anemoi.datasets import open_dataset +from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.create.persistent import build_storage +from anemoi.datasets.create.statistics import Summary +from anemoi.datasets.create.statistics import check_variance +from anemoi.datasets.create.statistics import compute_statistics +from anemoi.datasets.create.statistics import fix_variance + +from .tasks import FieldTask +from .tasks import HasRegistryMixin + +LOG = logging.getLogger(__name__) + + +class AdditionsMixin: + """A mixin class to handle dataset additions.""" + + def skip(self) -> bool: + """Check if the additions should be skipped. + + Returns + ------- + bool + Whether to skip the additions. + """ + frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) + if not self.delta.total_seconds() % frequency.total_seconds() == 0: + LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. 
Skipping.") + return True + + if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: + LOG.warning(f"Additions are disabled for {self.path} in the recipe.") + return True + + return False + + @cached_property + def tmp_storage_path(self) -> str: + """Get the path to the temporary storage.""" + name = "storage_for_additions" + if self.delta: + name += frequency_to_string(self.delta) + return os.path.join(f"{self.path}.{name}.tmp") + + def read_from_dataset(self) -> None: + """Read data from the dataset.""" + self.variables = self.dataset.anemoi_dataset.variables + self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) + start = self.dataset.zarr_metadata["statistics_start_date"] + end = self.dataset.zarr_metadata["statistics_end_date"] + self.start = datetime.datetime.fromisoformat(start) + self.end = datetime.datetime.fromisoformat(end) + + ds = open_dataset(self.path, start=self.start, end=self.end) + self.dates = ds.dates + self.total = len(self.dates) + + idelta = self.delta.total_seconds() // self.frequency.total_seconds() + assert int(idelta) == idelta, idelta + idelta = int(idelta) + self.ds = DeltaDataset(ds, idelta) + + +class DeltaDataset: + """A class to represent a dataset with delta values.""" + + def __init__(self, ds: Any, idelta: int): + """Initialize a DeltaDataset instance. + + Parameters + ---------- + ds : Any + The dataset. + idelta : int + The delta value. + """ + self.ds = ds + self.idelta = idelta + + def __getitem__(self, i: int) -> Any: + """Get an item from the dataset. + + Parameters + ---------- + i : int + The index. + + Returns + ------- + Any + The item. + """ + j = i - self.idelta + if j < 0: + raise MissingDateError(f"Missing date {j}") + return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] + + +class _InitAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): + """A class to initialize dataset additions.""" + + def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): + """Initialize an _InitAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + + def run(self) -> None: + """Run the additions initialization.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}") + return + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) + self.tmp_storage.delete() + self.tmp_storage.create() + LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") + + def cleanup(self) -> None: + """Clean up the temporary storage.""" + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + self.tmp_storage.delete() + LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") + + +class _RunAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): + """A class to run dataset additions.""" + + def __init__( + self, + path: str, + delta: str, + parts: str | None = None, + use_threads: bool = False, + progress: Any = None, + **kwargs: Any, + ): + """Initialize a _RunAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + parts : Optional[str], optional + The parts to load. + use_threads : bool, optional + Whether to use threads. 
+ progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + self.parts = parts + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + LOG.info(f"Writing in {self.tmp_storage_path}") + + def run(self) -> None: + """Run the additions.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}") + return + + self.read_from_dataset() + + chunk_filter = ChunkFilter(parts=self.parts, total=self.total) + for i in range(0, self.total): + if not chunk_filter(i): + continue + date = self.dates[i] + try: + arr = self.ds[i] + stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) + self.tmp_storage.add([date, i, stats], key=date) + except MissingDateError: + self.tmp_storage.add([date, i, "missing"], key=date) + self.tmp_storage.flush() + LOG.debug(f"Dataset {self.path} additions run.") + + def allow_nans(self) -> bool: + """Check if NaNs are allowed. + + Returns + ------- + bool + Whether NaNs are allowed. + """ + if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): + return True + + variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) + if variables_with_nans is not None: + return variables_with_nans + warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") + return True + + +class _FinaliseAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): + """A class to finalize dataset additions.""" + + def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): + """Initialize a _FinaliseAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + LOG.info(f"Reading from {self.tmp_storage_path}.") + + def run(self) -> None: + """Run the additions finalization.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}.") + return + + self.read_from_dataset() + + shape = (len(self.dates), len(self.variables)) + agg = dict( + minimum=np.full(shape, np.nan, dtype=np.float64), + maximum=np.full(shape, np.nan, dtype=np.float64), + sums=np.full(shape, np.nan, dtype=np.float64), + squares=np.full(shape, np.nan, dtype=np.float64), + count=np.full(shape, -1, dtype=np.int64), + has_nans=np.full(shape, False, dtype=np.bool_), + ) + LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") + + found = set() + ifound = set() + missing = set() + for _date, (date, i, stats) in self.tmp_storage.items(): + assert _date == date + if stats == "missing": + missing.add(date) + continue + + assert date not in found, f"Duplicates found {date}" + found.add(date) + ifound.add(i) + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k][i, ...] 
= stats[k] + + assert len(found) + len(missing) == len(self.dates), ( + len(found), + len(missing), + len(self.dates), + ) + assert found.union(missing) == set(self.dates), ( + found, + missing, + set(self.dates), + ) + + if len(ifound) < 2: + LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") + self.tmp_storage.delete() + return + + mask = sorted(list(ifound)) + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k] = agg[k][mask, ...] + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + assert agg[k].shape == agg["count"].shape, ( + agg[k].shape, + agg["count"].shape, + ) + + minimum = np.nanmin(agg["minimum"], axis=0) + maximum = np.nanmax(agg["maximum"], axis=0) + sums = np.nansum(agg["sums"], axis=0) + squares = np.nansum(agg["squares"], axis=0) + count = np.nansum(agg["count"], axis=0) + has_nans = np.any(agg["has_nans"], axis=0) + + assert sums.shape == count.shape + assert sums.shape == squares.shape + assert sums.shape == minimum.shape + assert sums.shape == maximum.shape + assert sums.shape == has_nans.shape + + mean = sums / count + assert sums.shape == mean.shape + + x = squares / count - mean * mean + # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 + # remove negative variance due to numerical errors + for i, name in enumerate(self.variables): + x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) + check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) + + stdev = np.sqrt(x) + assert sums.shape == stdev.shape + + self.summary = Summary( + minimum=minimum, + maximum=maximum, + mean=mean, + count=count, + sums=sums, + squares=squares, + stdev=stdev, + variables_names=self.variables, + has_nans=has_nans, + ) + LOG.info(f"Dataset {self.path} additions finalised.") + # self.check_statistics() + self._write(self.summary) + self.tmp_storage.delete() + + def _write(self, summary: Summary) -> None: + """Write the summary to the dataset. + + Parameters + ---------- + summary : Summary + The summary to write. + """ + for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: + name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" + self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) + self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") + LOG.debug(f"Wrote additions in {self.path}") + + +def multi_addition(cls: type) -> type: + """Create a class to handle multiple additions. + + Parameters + ---------- + cls : type + The class to handle additions. + + Returns + ------- + type + The class to handle multiple additions. 
+ """ + + class MultiAdditions: + def __init__(self, *args, **kwargs: Any): + self.tasks = [] + + for k in kwargs.pop("delta", []): + self.tasks.append(cls(*args, delta=k, **kwargs)) + + if not self.tasks: + LOG.warning("No delta found in kwargs, no additions will be computed.") + + def run(self) -> None: + """Run the additions.""" + for actor in self.tasks: + actor.run() + + return MultiAdditions + + +InitAdditions = multi_addition(_InitAdditions) +RunAdditions = multi_addition(_RunAdditions) +FinaliseAdditions = multi_addition(_FinaliseAdditions) diff --git a/src/anemoi/datasets/create/fields/cleanup.py b/src/anemoi/datasets/create/fields/cleanup.py index 77b601e58..8b87ba3cc 100644 --- a/src/anemoi/datasets/create/fields/cleanup.py +++ b/src/anemoi/datasets/create/fields/cleanup.py @@ -10,10 +10,10 @@ import logging from typing import Any +from .additions import _InitAdditions from .tasks import FieldTask from .tasks import HasRegistryMixin from .tasks import HasStatisticTempMixin -from .tasks import _InitAdditions LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/fields/init.py b/src/anemoi/datasets/create/fields/init.py index 094f60922..77e1f36e1 100644 --- a/src/anemoi/datasets/create/fields/init.py +++ b/src/anemoi/datasets/create/fields/init.py @@ -23,7 +23,7 @@ from .tasks import HasRegistryMixin from .tasks import HasStatisticTempMixin from .tasks import NewDataset -from .tasks import build_statistics_dates +from .tasks import _build_statistics_dates LOG = logging.getLogger(__name__) @@ -277,7 +277,7 @@ def _run(self) -> int: self.tmp_statistics.create(exist_ok=False) self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) - statistics_start, statistics_end = build_statistics_dates( + statistics_start, statistics_end = _build_statistics_dates( dates, self.main_config.statistics.get("start"), self.main_config.statistics.get("end"), diff --git a/src/anemoi/datasets/create/fields/patch.py b/src/anemoi/datasets/create/fields/patch.py new file mode 100644 index 000000000..546d53a13 --- /dev/null +++ b/src/anemoi/datasets/create/fields/patch.py @@ -0,0 +1,38 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any + +from .tasks import FieldTask + +LOG = logging.getLogger(__name__) + + +class Patch(FieldTask): + """A class to apply patches to a dataset.""" + + def __init__(self, path: str, options: dict = None, **kwargs: Any): + """Initialize a Patch instance. + + Parameters + ---------- + path : str + The path to the dataset. + options : dict, optional + The patch options. + """ + self.path = path + self.options = options or {} + + def run(self) -> None: + """Run the patch.""" + from anemoi.datasets.create.patch import apply_patch + + apply_patch(self.path, **self.options) diff --git a/src/anemoi/datasets/create/fields/size.py b/src/anemoi/datasets/create/fields/size.py new file mode 100644 index 000000000..10c64d4d7 --- /dev/null +++ b/src/anemoi/datasets/create/fields/size.py @@ -0,0 +1,48 @@ +# (C) Copyright 2024 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any + +from anemoi.datasets import open_dataset + +from .tasks import FieldTask + +LOG = logging.getLogger(__name__) + + +class Size(FieldTask): + """A class to compute the size of a dataset.""" + + def __init__(self, path: str, **kwargs: Any): + """Initialize a Size instance. + + Parameters + ---------- + path : str + The path to the dataset. + """ + super().__init__(path) + + def run(self) -> None: + """Run the size computation.""" + from anemoi.datasets.create.size import compute_directory_sizes + + metadata = compute_directory_sizes(self.path) + self.update_metadata(**metadata) + + # Look for constant fields + ds = open_dataset(self.path) + constants = ds.computed_constant_fields() + + variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() + for k in constants: + variables_metadata[k]["constant_in_time"] = True + + self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) diff --git a/src/anemoi/datasets/create/fields/statistics.py b/src/anemoi/datasets/create/fields/statistics.py new file mode 100644 index 000000000..b199fd052 --- /dev/null +++ b/src/anemoi/datasets/create/fields/statistics.py @@ -0,0 +1,102 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import warnings +from functools import cached_property +from typing import Any + +import numpy as np +import zarr + +from .tasks import FieldTask +from .tasks import HasRegistryMixin +from .tasks import HasStatisticTempMixin + +LOG = logging.getLogger(__name__) + + +class Statistics(FieldTask, HasStatisticTempMixin, HasRegistryMixin): + """A class to compute statistics for a dataset.""" + + def __init__( + self, + path: str, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + **kwargs: Any, + ): + """Initialize a Statistics instance. + + Parameters + ---------- + path : str + The path to the dataset. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. 
+ """ + super().__init__(path) + self.use_threads = use_threads + self.progress = progress + self.statistics_temp_dir = statistics_temp_dir + + def run(self) -> None: + """Run the statistics computation.""" + start, end = ( + self.dataset.zarr_metadata["statistics_start_date"], + self.dataset.zarr_metadata["statistics_end_date"], + ) + start, end = np.datetime64(start), np.datetime64(end) + dates = self.dataset.anemoi_dataset.dates + + assert type(dates[0]) is type(start), (type(dates[0]), type(start)) + + dates = [d for d in dates if d >= start and d <= end] + dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] + variables = self.dataset.anemoi_dataset.variables + stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) + + LOG.info(stats) + + if not all(self.registry.get_flags(sync=False)): + raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") + + for k in [ + "mean", + "stdev", + "minimum", + "maximum", + "sums", + "squares", + "count", + "has_nans", + ]: + self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) + + self.registry.add_to_history("compute_statistics_end") + LOG.info(f"Wrote statistics in {self.path}") + + @cached_property + def allow_nans(self) -> bool | list: + """Check if NaNs are allowed.""" + + z = zarr.open(self.path, mode="r") + if "allow_nans" in z.attrs: + return z.attrs["allow_nans"] + + if "variables_with_nans" in z.attrs: + return z.attrs["variables_with_nans"] + + warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") + return True diff --git a/src/anemoi/datasets/create/fields/tasks.py b/src/anemoi/datasets/create/fields/tasks.py index 14725e56a..22a6dfcb4 100644 --- a/src/anemoi/datasets/create/fields/tasks.py +++ b/src/anemoi/datasets/create/fields/tasks.py @@ -11,7 +11,6 @@ import json import logging import os -import warnings from functools import cached_property from typing import Any @@ -19,24 +18,16 @@ import numpy as np import zarr from anemoi.utils.dates import frequency_to_string -from anemoi.utils.dates import frequency_to_timedelta from earthkit.data.core.order import build_remapping -from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset from anemoi.datasets.create.check import DatasetName -from anemoi.datasets.create.chunks import ChunkFilter from anemoi.datasets.create.config import build_output from anemoi.datasets.create.config import loader_config from anemoi.datasets.create.fields.context import FieldContext from anemoi.datasets.create.input import InputBuilder -from anemoi.datasets.create.persistent import build_storage -from anemoi.datasets.create.statistics import Summary from anemoi.datasets.create.statistics import TmpStatistics -from anemoi.datasets.create.statistics import check_variance -from anemoi.datasets.create.statistics import compute_statistics from anemoi.datasets.create.statistics import default_statistics_dates -from anemoi.datasets.create.statistics import fix_variance from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups @@ -46,7 +37,7 @@ LOG = logging.getLogger(__name__) -def json_tidy(o: Any) -> Any: +def _json_tidy(o: Any) -> Any: """Convert various types to JSON serializable format. 
Parameters @@ -87,7 +78,7 @@ def json_tidy(o: Any) -> Any: raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}") -def build_statistics_dates( +def _build_statistics_dates( dates: list[datetime.datetime], start: datetime.datetime | None, end: datetime.datetime | None, @@ -181,7 +172,7 @@ def update_metadata(self, **kwargs: Any) -> None: v = v.astype(datetime.datetime) if isinstance(v, datetime.date): v = v.isoformat() - z.attrs[k] = json.loads(json.dumps(v, default=json_tidy)) + z.attrs[k] = json.loads(json.dumps(v, default=_json_tidy)) @cached_property def anemoi_dataset(self) -> Any: @@ -422,60 +413,6 @@ def check_missing_dates(expected: list[np.datetime64]) -> None: check_missing_dates(self.missing_dates) -class Patch(FieldTask): - """A class to apply patches to a dataset.""" - - def __init__(self, path: str, options: dict = None, **kwargs: Any): - """Initialize a Patch instance. - - Parameters - ---------- - path : str - The path to the dataset. - options : dict, optional - The patch options. - """ - self.path = path - self.options = options or {} - - def run(self) -> None: - """Run the patch.""" - from anemoi.datasets.create.patch import apply_patch - - apply_patch(self.path, **self.options) - - -class Size(FieldTask): - """A class to compute the size of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Size instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - - def run(self) -> None: - """Run the size computation.""" - from anemoi.datasets.create.size import compute_directory_sizes - - metadata = compute_directory_sizes(self.path) - self.update_metadata(**metadata) - - # Look for constant fields - ds = open_dataset(self.path) - constants = ds.computed_constant_fields() - - variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() - for k in constants: - variables_metadata[k]["constant_in_time"] = True - - self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) - - class HasRegistryMixin: """A mixin class to provide registry functionality.""" @@ -533,487 +470,7 @@ def create_elements(self, config: Any) -> None: LOG.debug(self.input) -class Verify(FieldTask): - """A class to verify the integrity of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Verify instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - - def run(self) -> None: - """Run the verification.""" - LOG.info(f"Verifying dataset at {self.path}") - LOG.info(str(self.dataset.anemoi_dataset)) - - -class AdditionsMixin: - """A mixin class to handle dataset additions.""" - - def skip(self) -> bool: - """Check if the additions should be skipped. - - Returns - ------- - bool - Whether to skip the additions. - """ - frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - if not self.delta.total_seconds() % frequency.total_seconds() == 0: - LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. 
Skipping.") - return True - - if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: - LOG.warning(f"Additions are disabled for {self.path} in the recipe.") - return True - - return False - - @cached_property - def tmp_storage_path(self) -> str: - """Get the path to the temporary storage.""" - name = "storage_for_additions" - if self.delta: - name += frequency_to_string(self.delta) - return os.path.join(f"{self.path}.{name}.tmp") - - def read_from_dataset(self) -> None: - """Read data from the dataset.""" - self.variables = self.dataset.anemoi_dataset.variables - self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - start = self.dataset.zarr_metadata["statistics_start_date"] - end = self.dataset.zarr_metadata["statistics_end_date"] - self.start = datetime.datetime.fromisoformat(start) - self.end = datetime.datetime.fromisoformat(end) - - ds = open_dataset(self.path, start=self.start, end=self.end) - self.dates = ds.dates - self.total = len(self.dates) - - idelta = self.delta.total_seconds() // self.frequency.total_seconds() - assert int(idelta) == idelta, idelta - idelta = int(idelta) - self.ds = DeltaDataset(ds, idelta) - - -class DeltaDataset: - """A class to represent a dataset with delta values.""" - - def __init__(self, ds: Any, idelta: int): - """Initialize a DeltaDataset instance. - - Parameters - ---------- - ds : Any - The dataset. - idelta : int - The delta value. - """ - self.ds = ds - self.idelta = idelta - - def __getitem__(self, i: int) -> Any: - """Get an item from the dataset. - - Parameters - ---------- - i : int - The index. - - Returns - ------- - Any - The item. - """ - j = i - self.idelta - if j < 0: - raise MissingDateError(f"Missing date {j}") - return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] - - -class _InitAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): - """A class to initialize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize an _InitAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - def run(self) -> None: - """Run the additions initialization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) - self.tmp_storage.delete() - self.tmp_storage.create() - LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") - - def cleanup(self) -> None: - """Clean up the temporary storage.""" - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - self.tmp_storage.delete() - LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") - - -class _RunAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): - """A class to run dataset additions.""" - - def __init__( - self, - path: str, - delta: str, - parts: str | None = None, - use_threads: bool = False, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a _RunAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. 
- progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - self.parts = parts - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Writing in {self.tmp_storage_path}") - - def run(self) -> None: - """Run the additions.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.read_from_dataset() - - chunk_filter = ChunkFilter(parts=self.parts, total=self.total) - for i in range(0, self.total): - if not chunk_filter(i): - continue - date = self.dates[i] - try: - arr = self.ds[i] - stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) - self.tmp_storage.add([date, i, stats], key=date) - except MissingDateError: - self.tmp_storage.add([date, i, "missing"], key=date) - self.tmp_storage.flush() - LOG.debug(f"Dataset {self.path} additions run.") - - def allow_nans(self) -> bool: - """Check if NaNs are allowed. - - Returns - ------- - bool - Whether NaNs are allowed. - """ - if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): - return True - - variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) - if variables_with_nans is not None: - return variables_with_nans - warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") - return True - - -class _FinaliseAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): - """A class to finalize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize a _FinaliseAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Reading from {self.tmp_storage_path}.") - - def run(self) -> None: - """Run the additions finalization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}.") - return - - self.read_from_dataset() - - shape = (len(self.dates), len(self.variables)) - agg = dict( - minimum=np.full(shape, np.nan, dtype=np.float64), - maximum=np.full(shape, np.nan, dtype=np.float64), - sums=np.full(shape, np.nan, dtype=np.float64), - squares=np.full(shape, np.nan, dtype=np.float64), - count=np.full(shape, -1, dtype=np.int64), - has_nans=np.full(shape, False, dtype=np.bool_), - ) - LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") - - found = set() - ifound = set() - missing = set() - for _date, (date, i, stats) in self.tmp_storage.items(): - assert _date == date - if stats == "missing": - missing.add(date) - continue - - assert date not in found, f"Duplicates found {date}" - found.add(date) - ifound.add(i) - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k][i, ...] 
= stats[k] - - assert len(found) + len(missing) == len(self.dates), ( - len(found), - len(missing), - len(self.dates), - ) - assert found.union(missing) == set(self.dates), ( - found, - missing, - set(self.dates), - ) - - if len(ifound) < 2: - LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") - self.tmp_storage.delete() - return - - mask = sorted(list(ifound)) - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k] = agg[k][mask, ...] - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - assert agg[k].shape == agg["count"].shape, ( - agg[k].shape, - agg["count"].shape, - ) - - minimum = np.nanmin(agg["minimum"], axis=0) - maximum = np.nanmax(agg["maximum"], axis=0) - sums = np.nansum(agg["sums"], axis=0) - squares = np.nansum(agg["squares"], axis=0) - count = np.nansum(agg["count"], axis=0) - has_nans = np.any(agg["has_nans"], axis=0) - - assert sums.shape == count.shape - assert sums.shape == squares.shape - assert sums.shape == minimum.shape - assert sums.shape == maximum.shape - assert sums.shape == has_nans.shape - - mean = sums / count - assert sums.shape == mean.shape - - x = squares / count - mean * mean - # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 - # remove negative variance due to numerical errors - for i, name in enumerate(self.variables): - x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) - check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) - - stdev = np.sqrt(x) - assert sums.shape == stdev.shape - - self.summary = Summary( - minimum=minimum, - maximum=maximum, - mean=mean, - count=count, - sums=sums, - squares=squares, - stdev=stdev, - variables_names=self.variables, - has_nans=has_nans, - ) - LOG.info(f"Dataset {self.path} additions finalised.") - # self.check_statistics() - self._write(self.summary) - self.tmp_storage.delete() - - def _write(self, summary: Summary) -> None: - """Write the summary to the dataset. - - Parameters - ---------- - summary : Summary - The summary to write. - """ - for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: - name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" - self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) - self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") - LOG.debug(f"Wrote additions in {self.path}") - - -def multi_addition(cls: type) -> type: - """Create a class to handle multiple additions. - - Parameters - ---------- - cls : type - The class to handle additions. - - Returns - ------- - type - The class to handle multiple additions. 
- """ - - class MultiAdditions: - def __init__(self, *args, **kwargs: Any): - self.tasks = [] - - for k in kwargs.pop("delta", []): - self.tasks.append(cls(*args, delta=k, **kwargs)) - - if not self.tasks: - LOG.warning("No delta found in kwargs, no additions will be computed.") - - def run(self) -> None: - """Run the additions.""" - for actor in self.tasks: - actor.run() - - return MultiAdditions - - -InitAdditions = multi_addition(_InitAdditions) -RunAdditions = multi_addition(_RunAdditions) -FinaliseAdditions = multi_addition(_FinaliseAdditions) - - -class Statistics(FieldTask, HasStatisticTempMixin, HasRegistryMixin): - """A class to compute statistics for a dataset.""" - - def __init__( - self, - path: str, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a Statistics instance. - - Parameters - ---------- - path : str - The path to the dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.use_threads = use_threads - self.progress = progress - self.statistics_temp_dir = statistics_temp_dir - - def run(self) -> None: - """Run the statistics computation.""" - start, end = ( - self.dataset.zarr_metadata["statistics_start_date"], - self.dataset.zarr_metadata["statistics_end_date"], - ) - start, end = np.datetime64(start), np.datetime64(end) - dates = self.dataset.anemoi_dataset.dates - - assert type(dates[0]) is type(start), (type(dates[0]), type(start)) - - dates = [d for d in dates if d >= start and d <= end] - dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] - variables = self.dataset.anemoi_dataset.variables - stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) - - LOG.info(stats) - - if not all(self.registry.get_flags(sync=False)): - raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") - - for k in [ - "mean", - "stdev", - "minimum", - "maximum", - "sums", - "squares", - "count", - "has_nans", - ]: - self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) - - self.registry.add_to_history("compute_statistics_end") - LOG.info(f"Wrote statistics in {self.path}") - - @cached_property - def allow_nans(self) -> bool | list: - """Check if NaNs are allowed.""" - import zarr - - z = zarr.open(self.path, mode="r") - if "allow_nans" in z.attrs: - return z.attrs["allow_nans"] - - if "variables_with_nans" in z.attrs: - return z.attrs["variables_with_nans"] - - warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") - return True - - -def validate_config(config: Any) -> None: +def _validate_config(config: Any) -> None: import json @@ -1054,7 +511,7 @@ def _tidy(d): raise -def config_to_python(config: Any) -> Any: +def _config_to_python(config: Any) -> Any: from anemoi.datasets.create.create.python import PythonScript @@ -1092,16 +549,24 @@ def load(self, *args: Any, **kwargs: Any): return Load(*args, **kwargs) def size(self, *args: Any, **kwargs: Any): + from .size import Size + return Size(*args, **kwargs) def patch(self, *args: Any, **kwargs: Any): + from .patch import Patch + return Patch(*args, **kwargs) def statistics(self, *args: Any, **kwargs: Any): + from .statistics import Statistics + return Statistics(*args, **kwargs) def finalise(self, *args: Any, **kwargs: Any): 
from .cleanup import Cleanup + from .size import Size + from .statistics import Statistics return chain([Statistics, Size, Cleanup])(*args, **kwargs) @@ -1111,18 +576,31 @@ def cleanup(self, *args: Any, **kwargs: Any): return Cleanup(*args, **kwargs) def verify(self, *args: Any, **kwargs: Any): + from .verify import Verify + return Verify(*args, **kwargs) def init_additions(self, *args: Any, **kwargs: Any): + from .additions import InitAdditions + return InitAdditions(*args, **kwargs) def run_additions(self, *args: Any, **kwargs: Any): + from .additions import RunAdditions + return RunAdditions(*args, **kwargs) def finalise_additions(self, *args: Any, **kwargs: Any): + from .additions import FinaliseAdditions + from .size import Size + return chain([FinaliseAdditions, Size])(*args, **kwargs) def additions(self, *args: Any, **kwargs: Any): + from .additions import FinaliseAdditions + from .additions import InitAdditions + from .additions import RunAdditions from .cleanup import Cleanup + from .size import Size return chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup])(*args, **kwargs) diff --git a/src/anemoi/datasets/create/fields/verify.py b/src/anemoi/datasets/create/fields/verify.py new file mode 100644 index 000000000..27b3e5f24 --- /dev/null +++ b/src/anemoi/datasets/create/fields/verify.py @@ -0,0 +1,34 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any + +from .tasks import FieldTask + +LOG = logging.getLogger(__name__) + + +class Verify(FieldTask): + """A class to verify the integrity of a dataset.""" + + def __init__(self, path: str, **kwargs: Any): + """Initialize a Verify instance. + + Parameters + ---------- + path : str + The path to the dataset. 
+ """ + super().__init__(path) + + def run(self) -> None: + """Run the verification.""" + LOG.info(f"Verifying dataset at {self.path}") + LOG.info(str(self.dataset.anemoi_dataset)) diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/data/padded.py index 1784604c3..be91c15cf 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/data/padded.py @@ -12,6 +12,7 @@ import logging from functools import cached_property from typing import Any +from typing import Dict import numpy as np from anemoi.utils.dates import frequency_to_timedelta diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/data/records/backends/__init__.py index 1c9e0e96c..699b58c1a 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/data/records/backends/__init__.py @@ -213,14 +213,14 @@ def write(self, i, data, number_of_files_per_subdirectory=100, **kwargs): os.rename(tmp_path, out_path) def write_metadata(self, metadata): - from anemoi.datasets.create.fields.tasks import json_tidy + from anemoi.datasets.create.fields.tasks import _json_tidy os.makedirs(self.path, exist_ok=True) path = os.path.join(self.path, "metadata.json") tmp_path = path + ".tmp" with open(tmp_path, "w") as f: - json.dump(metadata, f, indent=2, default=json_tidy) + json.dump(metadata, f, indent=2, default=_json_tidy) os.rename(tmp_path, path) def write_statistics(self, statistics): @@ -257,11 +257,11 @@ def write(self, i, data, **kwargs): ds.to_netcdf(out_path) def write_metadata(self, metadata): - from anemoi.datasets.create.fields.tasks import json_tidy + from anemoi.datasets.create.fields.tasks import _json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: - json.dump(metadata, f, indent=2, default=json_tidy) + json.dump(metadata, f, indent=2, default=_json_tidy) def write_statistics(self, statistics): os.makedirs(self.path, exist_ok=True) @@ -295,11 +295,11 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.create.fields.tasks import json_tidy + from anemoi.datasets.create.fields.tasks import _json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: - json.dump(metadata, f, indent=2, default=json_tidy) + json.dump(metadata, f, indent=2, default=_json_tidy) def write_statistics(self, statistics): flatten = {} From 8be0cb515c173daa1fc2be5214af65c063b6e7ff Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Tue, 30 Sep 2025 10:29:48 +0000 Subject: [PATCH 145/212] same as #423 --- src/anemoi/datasets/commands/create.py | 2 +- .../datasets/create/fields/additions.py | 6 ++--- src/anemoi/datasets/create/fields/tasks.py | 10 ++++----- .../datasets/create/observations/tasks.py | 6 ++--- tests/create/run.sh | 22 +++++++++---------- tests/create/utils/create.py | 2 +- 6 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 1ca332f80..151b175d9 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -105,7 +105,7 @@ def serial_create(self, args: Any) -> None: task("finalise", fields, options) task("init_additions", fields, options) - task("run_additions", fields, options) + task("load_additions", fields, options) task("finalise_additions", fields, options) task("patch", fields, options) diff --git 
a/src/anemoi/datasets/create/fields/additions.py b/src/anemoi/datasets/create/fields/additions.py index 9068f8181..94972e1c4 100644 --- a/src/anemoi/datasets/create/fields/additions.py +++ b/src/anemoi/datasets/create/fields/additions.py @@ -157,7 +157,7 @@ def cleanup(self) -> None: LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") -class _RunAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): +class _LoadAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): """A class to run dataset additions.""" def __init__( @@ -169,7 +169,7 @@ def __init__( progress: Any = None, **kwargs: Any, ): - """Initialize a _RunAdditions instance. + """Initialize a _LoadAdditions instance. Parameters ---------- @@ -409,5 +409,5 @@ def run(self) -> None: InitAdditions = multi_addition(_InitAdditions) -RunAdditions = multi_addition(_RunAdditions) +LoadAdditions = multi_addition(_LoadAdditions) FinaliseAdditions = multi_addition(_FinaliseAdditions) diff --git a/src/anemoi/datasets/create/fields/tasks.py b/src/anemoi/datasets/create/fields/tasks.py index 22a6dfcb4..97beef80f 100644 --- a/src/anemoi/datasets/create/fields/tasks.py +++ b/src/anemoi/datasets/create/fields/tasks.py @@ -585,10 +585,10 @@ def init_additions(self, *args: Any, **kwargs: Any): return InitAdditions(*args, **kwargs) - def run_additions(self, *args: Any, **kwargs: Any): - from .additions import RunAdditions + def load_additions(self, *args: Any, **kwargs: Any): + from .additions import LoadAdditions - return RunAdditions(*args, **kwargs) + return LoadAdditions(*args, **kwargs) def finalise_additions(self, *args: Any, **kwargs: Any): from .additions import FinaliseAdditions @@ -599,8 +599,8 @@ def finalise_additions(self, *args: Any, **kwargs: Any): def additions(self, *args: Any, **kwargs: Any): from .additions import FinaliseAdditions from .additions import InitAdditions - from .additions import RunAdditions + from .additions import LoadAdditions from .cleanup import Cleanup from .size import Size - return chain([InitAdditions, RunAdditions, FinaliseAdditions, Size, Cleanup])(*args, **kwargs) + return chain([InitAdditions, LoadAdditions, FinaliseAdditions, Size, Cleanup])(*args, **kwargs) diff --git a/src/anemoi/datasets/create/observations/tasks.py b/src/anemoi/datasets/create/observations/tasks.py index d8432b1ad..62cd4c052 100644 --- a/src/anemoi/datasets/create/observations/tasks.py +++ b/src/anemoi/datasets/create/observations/tasks.py @@ -53,7 +53,7 @@ def run(self) -> None: # Here would be the logic to initialize additions -class RunAdditions(Task): +class LoadAdditions(Task): def __init__(self, *, path: str, **kwargs: Any): self.path = path @@ -114,8 +114,8 @@ def finalise(self, *args: Any, **kwargs: Any): def init_additions(self, *args: Any, **kwargs: Any): return InitAdditions(*args, **kwargs) - def run_additions(self, *args: Any, **kwargs: Any): - return RunAdditions(*args, **kwargs) + def load_additions(self, *args: Any, **kwargs: Any): + return LoadAdditions(*args, **kwargs) def finalise_additions(self, *args: Any, **kwargs: Any): return FinaliseAdditions(*args, **kwargs) diff --git a/tests/create/run.sh b/tests/create/run.sh index 432acec2e..61d6f6380 100755 --- a/tests/create/run.sh +++ b/tests/create/run.sh @@ -2,19 +2,17 @@ set -eux NAME=${1:-join} -anemoi-datasets create-step init $NAME.yaml $NAME.zarr --overwrite -anemoi-datasets create-step load $NAME.zarr --part 1/2 -anemoi-datasets create-step load $NAME.zarr --part 2/2 +anemoi-datasets init $NAME.yaml $NAME.zarr --overwrite +anemoi-datasets load 
$NAME.zarr --part 1/2 +anemoi-datasets load $NAME.zarr --part 2/2 -anemoi-datasets create-step statistics $NAME.zarr -anemoi-datasets create-step size $NAME.zarr -# anemoi-datasets create-step finalise $NAME.zarr +anemoi-datasets finalise $NAME.zarr -anemoi-datasets create-step patch $NAME.zarr +anemoi-datasets patch $NAME.zarr -anemoi-datasets create-step init-additions $NAME.zarr --delta 12h -anemoi-datasets create-step run-additions $NAME.zarr --part 1/2 --delta 12h -anemoi-datasets create-step run-additions $NAME.zarr --part 2/2 --delta 12h -anemoi-datasets create-step finalise-additions $NAME.zarr --delta 12h +anemoi-datasets init-additions $NAME.zarr --delta 12h +anemoi-datasets load-additions $NAME.zarr --part 1/2 --delta 12h +anemoi-datasets load-additions $NAME.zarr --part 2/2 --delta 12h +anemoi-datasets finalise-additions $NAME.zarr --delta 12h -anemoi-datasets create-step cleanup $NAME.zarr +anemoi-datasets cleanup $NAME.zarr diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index 0addb122b..245968ada 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -59,7 +59,7 @@ def create_dataset( if delta is not None: task_factory("init_additions", path=output, delta=delta).run() - task_factory("run_additions", path=output, delta=delta).run() + task_factory("load_additions", path=output, delta=delta).run() task_factory("finalise_additions", path=output, delta=delta).run() task_factory("cleanup", path=output).run() From 475eae7752a63f3075c36593341abd422900bc15 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 14:53:16 +0000 Subject: [PATCH 146/212] make all import absolute --- src/anemoi/datasets/__init__.py | 12 ++--- src/anemoi/datasets/__main__.py | 4 +- src/anemoi/datasets/commands/check.py | 3 +- src/anemoi/datasets/commands/cleanup.py | 3 +- src/anemoi/datasets/commands/compare-lam.py | 3 +- src/anemoi/datasets/commands/compare.py | 3 +- src/anemoi/datasets/commands/copy.py | 3 +- src/anemoi/datasets/commands/create.py | 2 +- .../datasets/commands/finalise-additions.py | 3 +- src/anemoi/datasets/commands/finalise.py | 3 +- src/anemoi/datasets/commands/grib-index.py | 2 +- .../datasets/commands/init-additions.py | 3 +- src/anemoi/datasets/commands/init.py | 3 +- src/anemoi/datasets/commands/inspect.py | 3 +- .../datasets/commands/load-additions.py | 3 +- src/anemoi/datasets/commands/load.py | 3 +- src/anemoi/datasets/commands/patch.py | 3 +- src/anemoi/datasets/commands/publish.py | 2 +- .../datasets/commands/recipe/__init__.py | 7 ++- src/anemoi/datasets/commands/recipe/format.py | 2 +- src/anemoi/datasets/commands/scan.py | 2 +- src/anemoi/datasets/commands/validate.py | 3 +- src/anemoi/datasets/create/__init__.py | 39 ++++++++------- src/anemoi/datasets/create/input/__init__.py | 4 +- .../datasets/create/input/context/field.py | 4 +- .../datasets/create/input/data_sources.py | 10 ++-- .../datasets/create/input/repeated_dates.py | 10 ++-- .../datasets/create/input/result/field.py | 2 +- .../datasets/create/sources/accumulations.py | 7 ++- .../datasets/create/sources/accumulations2.py | 3 +- .../datasets/create/sources/anemoi_dataset.py | 2 +- .../datasets/create/sources/constants.py | 2 +- .../datasets/create/sources/eccc_fstd.py | 4 +- src/anemoi/datasets/create/sources/empty.py | 2 +- src/anemoi/datasets/create/sources/fdb.py | 5 +- .../datasets/create/sources/forcings.py | 2 +- src/anemoi/datasets/create/sources/grib.py | 2 +- .../datasets/create/sources/grib_index.py | 2 +- 
.../datasets/create/sources/hindcasts.py | 3 +- src/anemoi/datasets/create/sources/legacy.py | 4 +- src/anemoi/datasets/create/sources/mars.py | 3 +- src/anemoi/datasets/create/sources/netcdf.py | 4 +- src/anemoi/datasets/create/sources/opendap.py | 4 +- .../create/sources/planetary_computer.py | 4 +- .../datasets/create/sources/recentre.py | 5 +- src/anemoi/datasets/create/sources/source.py | 3 +- .../datasets/create/sources/tendencies.py | 3 +- src/anemoi/datasets/create/sources/xarray.py | 9 ++-- .../create/sources/xarray_kerchunk.py | 4 +- .../create/sources/xarray_support/__init__.py | 5 +- .../create/sources/xarray_support/field.py | 6 +-- .../sources/xarray_support/fieldlist.py | 12 ++--- .../create/sources/xarray_support/flavour.py | 38 +++++++-------- .../create/sources/xarray_support/metadata.py | 2 +- .../create/sources/xarray_support/time.py | 4 +- .../create/sources/xarray_support/variable.py | 2 +- .../datasets/create/sources/xarray_zarr.py | 4 +- src/anemoi/datasets/create/sources/zenodo.py | 6 +-- .../datasets/create/statistics/__init__.py | 4 +- .../datasets/create/statistics/summary.py | 6 +-- src/anemoi/datasets/data/__init__.py | 10 ++-- src/anemoi/datasets/data/complement.py | 24 +++++----- src/anemoi/datasets/data/concat.py | 30 ++++++------ src/anemoi/datasets/data/dataset.py | 48 +++++++++---------- src/anemoi/datasets/data/debug.py | 2 +- src/anemoi/datasets/data/ensemble.py | 22 ++++----- src/anemoi/datasets/data/fill_missing.py | 21 ++++---- src/anemoi/datasets/data/forwards.py | 20 ++++---- src/anemoi/datasets/data/grids.py | 30 ++++++------ src/anemoi/datasets/data/indexing.py | 6 +-- src/anemoi/datasets/data/interpolate.py | 24 +++++----- src/anemoi/datasets/data/join.py | 30 ++++++------ src/anemoi/datasets/data/masked.py | 26 +++++----- src/anemoi/datasets/data/merge.py | 26 +++++----- src/anemoi/datasets/data/misc.py | 36 +++++++------- src/anemoi/datasets/data/missing.py | 17 ++++--- .../datasets/data/observations/__init__.py | 7 ++- src/anemoi/datasets/data/rescale.py | 20 ++++---- src/anemoi/datasets/data/select.py | 24 +++++----- src/anemoi/datasets/data/statistics.py | 8 ++-- src/anemoi/datasets/data/stores.py | 22 ++++----- src/anemoi/datasets/data/subset.py | 30 ++++++------ src/anemoi/datasets/data/unchecked.py | 16 +++---- src/anemoi/datasets/data/xy.py | 12 ++--- 84 files changed, 396 insertions(+), 425 deletions(-) diff --git a/src/anemoi/datasets/__init__.py b/src/anemoi/datasets/__init__.py index fe6ca61f1..620f5e80f 100644 --- a/src/anemoi/datasets/__init__.py +++ b/src/anemoi/datasets/__init__.py @@ -8,16 +8,16 @@ # nor does it submit to any jurisdiction. 
-from .data import MissingDateError -from .data import add_dataset_path -from .data import add_named_dataset -from .data import list_dataset_names -from .data import open_dataset +from anemoi.datasets.data import MissingDateError +from anemoi.datasets.data import add_dataset_path +from anemoi.datasets.data import add_named_dataset +from anemoi.datasets.data import list_dataset_names +from anemoi.datasets.data import open_dataset try: # NOTE: the `_version.py` file must not be present in the git repository # as it is generated by setuptools at install time - from ._version import __version__ # type: ignore + from anemoi.datasets._version import __version__ # type: ignore except ImportError: # pragma: no cover # Local copy or not installed with setuptools __version__ = "999" diff --git a/src/anemoi/datasets/__main__.py b/src/anemoi/datasets/__main__.py index 62b7d7c73..f47c46050 100644 --- a/src/anemoi/datasets/__main__.py +++ b/src/anemoi/datasets/__main__.py @@ -12,8 +12,8 @@ from anemoi.utils.cli import cli_main from anemoi.utils.cli import make_parser -from . import __version__ -from .commands import COMMANDS +from anemoi.datasets import __version__ +from anemoi.datasets.commands import COMMANDS # For read-the-docs diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index 61b29bf23..4202ed09f 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -13,10 +13,9 @@ import yaml +from anemoi.datasets.commands import Command from anemoi.datasets.create.check import DatasetName -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/cleanup.py b/src/anemoi/datasets/commands/cleanup.py index 0b3a393bd..25b5b9ca0 100644 --- a/src/anemoi/datasets/commands/cleanup.py +++ b/src/anemoi/datasets/commands/cleanup.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/compare-lam.py b/src/anemoi/datasets/commands/compare-lam.py index 74d97bb48..92ea9a6af 100644 --- a/src/anemoi/datasets/commands/compare-lam.py +++ b/src/anemoi/datasets/commands/compare-lam.py @@ -12,8 +12,7 @@ import os from anemoi.datasets import open_dataset - -from . import Command +from anemoi.datasets.commands import Command RADIUS_EARTH_KM = 6371.0 # Earth's radius in kilometers diff --git a/src/anemoi/datasets/commands/compare.py b/src/anemoi/datasets/commands/compare.py index ffe1ec02e..bbd121bd1 100644 --- a/src/anemoi/datasets/commands/compare.py +++ b/src/anemoi/datasets/commands/compare.py @@ -15,8 +15,7 @@ import zarr from anemoi.datasets import open_dataset - -from . import Command +from anemoi.datasets.commands import Command class Compare(Command): diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 5020a208d..5c5768714 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -20,8 +20,7 @@ from anemoi.utils.remote import TransferMethodNotImplementedError from anemoi.datasets.check import check_zarr - -from . 
import Command +from anemoi.datasets.commands import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 3f6bbe7dd..6601d0ee4 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -18,7 +18,7 @@ import tqdm from anemoi.utils.humanize import seconds_to_human -from . import Command +from anemoi.datasets.commands import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/finalise-additions.py b/src/anemoi/datasets/commands/finalise-additions.py index 811760ae9..25380fbba 100644 --- a/src/anemoi/datasets/commands/finalise-additions.py +++ b/src/anemoi/datasets/commands/finalise-additions.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/finalise.py b/src/anemoi/datasets/commands/finalise.py index 53999ad50..5197fb73c 100644 --- a/src/anemoi/datasets/commands/finalise.py +++ b/src/anemoi/datasets/commands/finalise.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/grib-index.py b/src/anemoi/datasets/commands/grib-index.py index cfd7a08e8..b5cc910d2 100644 --- a/src/anemoi/datasets/commands/grib-index.py +++ b/src/anemoi/datasets/commands/grib-index.py @@ -13,7 +13,7 @@ import tqdm -from . import Command +from anemoi.datasets.commands import Command class GribIndexCmd(Command): diff --git a/src/anemoi/datasets/commands/init-additions.py b/src/anemoi/datasets/commands/init-additions.py index 09788f0e4..c49bbf76c 100644 --- a/src/anemoi/datasets/commands/init-additions.py +++ b/src/anemoi/datasets/commands/init-additions.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/init.py b/src/anemoi/datasets/commands/init.py index 0ca540b86..c5aa515fd 100644 --- a/src/anemoi/datasets/commands/init.py +++ b/src/anemoi/datasets/commands/init.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 384ee7d34..52b7e689d 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -27,11 +27,10 @@ from numpy.typing import NDArray from anemoi.datasets import open_dataset +from anemoi.datasets.commands import Command from anemoi.datasets.data.stores import open_zarr from anemoi.datasets.data.stores import zarr_lookup -from . 
import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/load-additions.py b/src/anemoi/datasets/commands/load-additions.py index a8cd5d5c9..82dec8b0a 100644 --- a/src/anemoi/datasets/commands/load-additions.py +++ b/src/anemoi/datasets/commands/load-additions.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/load.py b/src/anemoi/datasets/commands/load.py index 3d969f5c3..7b1c1f684 100644 --- a/src/anemoi/datasets/commands/load.py +++ b/src/anemoi/datasets/commands/load.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/patch.py b/src/anemoi/datasets/commands/patch.py index dc8356126..1920fc420 100644 --- a/src/anemoi/datasets/commands/patch.py +++ b/src/anemoi/datasets/commands/patch.py @@ -13,10 +13,9 @@ from anemoi.utils.humanize import seconds_to_human +from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task -from . import Command - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/publish.py b/src/anemoi/datasets/commands/publish.py index 7f719543e..47282e65b 100644 --- a/src/anemoi/datasets/commands/publish.py +++ b/src/anemoi/datasets/commands/publish.py @@ -10,7 +10,7 @@ import logging from typing import Any -from . import Command +from anemoi.datasets.commands import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 45400806c..e708d8b50 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,12 +15,11 @@ import yaml +from anemoi.datasets.commands import Command +from anemoi.datasets.commands.recipe.format import format_recipe +from anemoi.datasets.commands.recipe.migrate import migrate_recipe from anemoi.datasets.create import validate_config -from .. import Command -from .format import format_recipe -from .migrate import migrate_recipe - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/format.py b/src/anemoi/datasets/commands/recipe/format.py index 872060981..328e6d756 100644 --- a/src/anemoi/datasets/commands/recipe/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -11,7 +11,7 @@ import datetime import logging -from ...dumper import yaml_dump +from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/scan.py b/src/anemoi/datasets/commands/scan.py index 8a048125e..37c8d0cfd 100644 --- a/src/anemoi/datasets/commands/scan.py +++ b/src/anemoi/datasets/commands/scan.py @@ -17,7 +17,7 @@ import tqdm import yaml -from . 
import Command +from anemoi.datasets.commands import Command KEYS = ("class", "type", "stream", "expver", "levtype", "domain") diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py index 1382814a7..03691541c 100644 --- a/src/anemoi/datasets/commands/validate.py +++ b/src/anemoi/datasets/commands/validate.py @@ -10,10 +10,9 @@ import logging from typing import Any +from anemoi.datasets.commands import Command from anemoi.datasets.validate import validate_dataset -from . import Command - LOG = logging.getLogger(__name__) DEFAULT_DATASET = "aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8" diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 5600cb254..cec8b95da 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -31,27 +31,26 @@ from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset +from anemoi.datasets.create.check import DatasetName +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.create.config import build_output +from anemoi.datasets.create.config import loader_config +from anemoi.datasets.create.input import InputBuilder from anemoi.datasets.create.input.trace import enable_trace from anemoi.datasets.create.persistent import build_storage +from anemoi.datasets.create.statistics import Summary +from anemoi.datasets.create.statistics import TmpStatistics +from anemoi.datasets.create.statistics import check_variance +from anemoi.datasets.create.statistics import compute_statistics +from anemoi.datasets.create.statistics import default_statistics_dates +from anemoi.datasets.create.statistics import fix_variance +from anemoi.datasets.create.utils import normalize_and_check_dates +from anemoi.datasets.create.writer import ViewCacheArray from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups -from .check import DatasetName -from .check import check_data_values -from .chunks import ChunkFilter -from .config import build_output -from .config import loader_config -from .input import InputBuilder -from .statistics import Summary -from .statistics import TmpStatistics -from .statistics import check_variance -from .statistics import compute_statistics -from .statistics import default_statistics_dates -from .statistics import fix_variance -from .utils import normalize_and_check_dates -from .writer import ViewCacheArray - LOG = logging.getLogger(__name__) VERSION = "0.30" @@ -193,7 +192,7 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: import zarr z = zarr.open(self.path, mode=mode) - from .zarr import add_zarr_dataset + from anemoi.datasets.create.zarr import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -397,7 +396,7 @@ def _cache_context(self) -> Any: Any The cache context. 
""" - from .utils import cache_context + from anemoi.datasets.create.utils import cache_context return cache_context(self.cache) @@ -473,7 +472,7 @@ def __init__(self, path: str, options: dict = None, **kwargs: Any): def run(self) -> None: """Run the patch.""" - from .patch import apply_patch + from anemoi.datasets.create.patch import apply_patch apply_patch(self.path, **self.options) @@ -493,7 +492,7 @@ def __init__(self, path: str, **kwargs: Any): def run(self) -> None: """Run the size computation.""" - from .size import compute_directory_sizes + from anemoi.datasets.create.size import compute_directory_sizes metadata = compute_directory_sizes(self.path) self.update_metadata(**metadata) @@ -515,7 +514,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from .zarr import ZarrBuiltRegistry + from anemoi.datasets.create.zarr import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index e30ecefb5..2fe695781 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -40,8 +40,8 @@ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> No @cached_property def action(self) -> "Recipe": """Returns the action object based on the configuration.""" - from .action import Recipe - from .action import action_factory + from anemoi.datasets.create.input.action import Recipe + from anemoi.datasets.create.input.action import action_factory sources = action_factory(self.data_sources, "data_sources") input = action_factory(self.config, "input") diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/create/input/context/field.py index 1dd01340e..e92a1ebbd 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/create/input/context/field.py @@ -12,8 +12,8 @@ from earthkit.data.core.order import build_remapping -from ..result.field import FieldResult -from . 
import Context +from anemoi.datasets.create.input.context import Context +from anemoi.datasets.create.input.result.field import FieldResult class FieldContext(Context): diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 31bf3d8cc..31956d602 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -13,11 +13,11 @@ from earthkit.data import FieldList -from ...dates.groups import GroupOfDates -from .action import Action -from .action import action_factory -from .misc import _tidy -from .result.field import Result +from anemoi.datasets.create.input.action import Action +from anemoi.datasets.create.input.action import action_factory +from anemoi.datasets.create.input.misc import _tidy +from anemoi.datasets.create.input.result.field import Result +from anemoi.datasets.dates.groups import GroupOfDates LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py index ad46fe208..962b82717 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -19,11 +19,11 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta -from .action import Action -from .action import action_factory -from .join import JoinResult -from .result.field import Result -from .trace import trace_select +from anemoi.datasets.create.input.action import Action +from anemoi.datasets.create.input.action import action_factory +from anemoi.datasets.create.input.join import JoinResult +from anemoi.datasets.create.input.result.field import Result +from anemoi.datasets.create.input.trace import trace_select LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/create/input/result/field.py index 083d2ffd7..dbcf8fbd4 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/create/input/result/field.py @@ -22,7 +22,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from . 
import Result +from anemoi.datasets.create.input.result import Result LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/accumulations.py b/src/anemoi/datasets/create/sources/accumulations.py index 6acecbf98..40b8749f6 100644 --- a/src/anemoi/datasets/create/sources/accumulations.py +++ b/src/anemoi/datasets/create/sources/accumulations.py @@ -20,11 +20,10 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.mars import mars from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source -from .mars import mars - LOG = logging.getLogger(__name__) @@ -994,7 +993,7 @@ def accumulations( and request.get("stream", "oper") == "oper" and request.get("accumulation_period") == 24 ): - from .accumulations2 import accumulations as accumulations2 + from anemoi.datasets.create.sources.accumulations2 import accumulations as accumulations2 LOG.warning( "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" diff --git a/src/anemoi/datasets/create/sources/accumulations2.py b/src/anemoi/datasets/create/sources/accumulations2.py index f9ddf3b3a..3c34d392e 100644 --- a/src/anemoi/datasets/create/sources/accumulations2.py +++ b/src/anemoi/datasets/create/sources/accumulations2.py @@ -18,11 +18,10 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.mars import mars from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source - LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/anemoi_dataset.py b/src/anemoi/datasets/create/sources/anemoi_dataset.py index 12d41db23..a05e7df51 100644 --- a/src/anemoi/datasets/create/sources/anemoi_dataset.py +++ b/src/anemoi/datasets/create/sources/anemoi_dataset.py @@ -9,7 +9,7 @@ import numpy as np -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index 104f24863..accde7936 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -11,7 +11,7 @@ from earthkit.data import from_source -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/eccc_fstd.py b/src/anemoi/datasets/create/sources/eccc_fstd.py index 41734e9b6..fdd79af8d 100644 --- a/src/anemoi/datasets/create/sources/eccc_fstd.py +++ b/src/anemoi/datasets/create/sources/eccc_fstd.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . 
import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("eccc_fstd") diff --git a/src/anemoi/datasets/create/sources/empty.py b/src/anemoi/datasets/create/sources/empty.py index fb7fcd906..f948810f5 100644 --- a/src/anemoi/datasets/create/sources/empty.py +++ b/src/anemoi/datasets/create/sources/empty.py @@ -12,7 +12,7 @@ import earthkit.data as ekd -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/fdb.py b/src/anemoi/datasets/create/sources/fdb.py index bb33f7d50..81cdb7e13 100644 --- a/src/anemoi/datasets/create/sources/fdb.py +++ b/src/anemoi/datasets/create/sources/fdb.py @@ -16,11 +16,10 @@ from anemoi.transform.flavour import RuleBasedFlavour from anemoi.transform.grids import grid_registry +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry from anemoi.datasets.create.typing import DateList -from ..source import Source -from . import source_registry - @source_registry.register("fdb") class FdbSource(Source): diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index bbafaa465..88eca92e4 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -11,7 +11,7 @@ from earthkit.data import from_source -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/grib.py b/src/anemoi/datasets/create/sources/grib.py index 03bcda475..e1eaed2da 100644 --- a/src/anemoi/datasets/create/sources/grib.py +++ b/src/anemoi/datasets/create/sources/grib.py @@ -20,7 +20,7 @@ from earthkit.data import from_source from earthkit.data.utils.patterns import Pattern -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/grib_index.py b/src/anemoi/datasets/create/sources/grib_index.py index ea6878929..160ff3f3a 100644 --- a/src/anemoi/datasets/create/sources/grib_index.py +++ b/src/anemoi/datasets/create/sources/grib_index.py @@ -19,7 +19,7 @@ from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray -from .legacy import legacy_source +from anemoi.datasets.create.sources.legacy import legacy_source LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/hindcasts.py b/src/anemoi/datasets/create/sources/hindcasts.py index 9c470218c..d796a74af 100644 --- a/src/anemoi/datasets/create/sources/hindcasts.py +++ b/src/anemoi/datasets/create/sources/hindcasts.py @@ -12,10 +12,9 @@ from earthkit.data.core.fieldlist import MultiFieldList +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.mars import mars -from .legacy import legacy_source - LOGGER = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index 4dbd481cd..352ae207e 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -14,8 +14,8 @@ from collections.abc import Callable from typing import Any -from ..source import Source -from . 
import source_registry +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/mars.py b/src/anemoi/datasets/create/sources/mars.py index 1a419f691..d59f6034d 100644 --- a/src/anemoi/datasets/create/sources/mars.py +++ b/src/anemoi/datasets/create/sources/mars.py @@ -16,10 +16,9 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source - DEBUG = False diff --git a/src/anemoi/datasets/create/sources/netcdf.py b/src/anemoi/datasets/create/sources/netcdf.py index a73c095d3..606a8dd53 100644 --- a/src/anemoi/datasets/create/sources/netcdf.py +++ b/src/anemoi/datasets/create/sources/netcdf.py @@ -12,8 +12,8 @@ import earthkit.data as ekd -from .legacy import legacy_source -from .xarray import load_many +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/opendap.py b/src/anemoi/datasets/create/sources/opendap.py index 483295a8b..34e3fe94d 100644 --- a/src/anemoi/datasets/create/sources/opendap.py +++ b/src/anemoi/datasets/create/sources/opendap.py @@ -12,8 +12,8 @@ import earthkit.data as ekd -from .legacy import legacy_source -from .xarray import load_many +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py index b710bcbbe..07e8f0203 100644 --- a/src/anemoi/datasets/create/sources/planetary_computer.py +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . 
import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("planetary_computer") diff --git a/src/anemoi/datasets/create/sources/recentre.py b/src/anemoi/datasets/create/sources/recentre.py index 53ace8152..d0959f664 100644 --- a/src/anemoi/datasets/create/sources/recentre.py +++ b/src/anemoi/datasets/create/sources/recentre.py @@ -11,9 +11,8 @@ from typing import Any from anemoi.datasets.compute.recentre import recentre as _recentre - -from .legacy import legacy_source -from .mars import mars +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.mars import mars def to_list(x: list | tuple | str) -> list: diff --git a/src/anemoi/datasets/create/sources/source.py b/src/anemoi/datasets/create/sources/source.py index 0db02e6db..1bac545d8 100644 --- a/src/anemoi/datasets/create/sources/source.py +++ b/src/anemoi/datasets/create/sources/source.py @@ -12,10 +12,9 @@ from earthkit.data import from_source +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source - @legacy_source(__file__) def source(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any: diff --git a/src/anemoi/datasets/create/sources/tendencies.py b/src/anemoi/datasets/create/sources/tendencies.py index 01c4d1bda..222dca9a4 100644 --- a/src/anemoi/datasets/create/sources/tendencies.py +++ b/src/anemoi/datasets/create/sources/tendencies.py @@ -14,10 +14,9 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.utils import to_datetime_list -from .legacy import legacy_source - def _date_to_datetime(d: Any) -> Any: """Converts a date string or a list/tuple of date strings to datetime objects. diff --git a/src/anemoi/datasets/create/sources/xarray.py b/src/anemoi/datasets/create/sources/xarray.py index d63b708d6..5e3cc4c10 100644 --- a/src/anemoi/datasets/create/sources/xarray.py +++ b/src/anemoi/datasets/create/sources/xarray.py @@ -11,13 +11,12 @@ import earthkit.data as ekd +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources.xarray_support import XarrayFieldList +from anemoi.datasets.create.sources.xarray_support import load_many +from anemoi.datasets.create.sources.xarray_support import load_one from anemoi.datasets.create.typing import DateList -from ..source import Source -from .xarray_support import XarrayFieldList -from .xarray_support import load_many -from .xarray_support import load_one - __all__ = ["load_many", "load_one", "XarrayFieldList"] diff --git a/src/anemoi/datasets/create/sources/xarray_kerchunk.py b/src/anemoi/datasets/create/sources/xarray_kerchunk.py index 056d756ca..632a7cae2 100644 --- a/src/anemoi/datasets/create/sources/xarray_kerchunk.py +++ b/src/anemoi/datasets/create/sources/xarray_kerchunk.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . 
import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("xarray_kerchunk") diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index 33a057520..c33ce7bfc 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -15,10 +15,9 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.patterns import iterate_patterns - -from ..legacy import legacy_source -from .fieldlist import XarrayFieldList +from anemoi.datasets.create.sources.xarray_support.fieldlist import XarrayFieldList LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 78f7de041..85f9970f8 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -17,9 +17,9 @@ from earthkit.data.core.fieldlist import math from numpy.typing import NDArray -from .coordinates import extract_single_value -from .coordinates import is_scalar -from .metadata import XArrayMetadata +from anemoi.datasets.create.sources.xarray_support.coordinates import extract_single_value +from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar +from anemoi.datasets.create.sources.xarray_support.metadata import XArrayMetadata LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py index 48f9cf0e1..174cb2716 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py +++ b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py @@ -16,12 +16,12 @@ import yaml from earthkit.data import FieldList -from .field import EmptyFieldList -from .flavour import CoordinateGuesser -from .patch import patch_dataset -from .time import Time -from .variable import FilteredVariable -from .variable import Variable +from anemoi.datasets.create.sources.xarray_support.field import EmptyFieldList +from anemoi.datasets.create.sources.xarray_support.flavour import CoordinateGuesser +from anemoi.datasets.create.sources.xarray_support.patch import patch_dataset +from anemoi.datasets.create.sources.xarray_support.time import Time +from anemoi.datasets.create.sources.xarray_support.variable import FilteredVariable +from anemoi.datasets.create.sources.xarray_support.variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 80f0b6a62..74fcdbd03 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -17,25 +17,25 @@ import xarray as xr from anemoi.utils.config import DotDict -from .coordinates import Coordinate -from .coordinates import DateCoordinate -from .coordinates import EnsembleCoordinate -from .coordinates import LatitudeCoordinate -from .coordinates import LevelCoordinate -from .coordinates import LongitudeCoordinate -from .coordinates import PointCoordinate 
-from .coordinates import ScalarCoordinate -from .coordinates import StepCoordinate -from .coordinates import TimeCoordinate -from .coordinates import UnsupportedCoordinate -from .coordinates import XCoordinate -from .coordinates import YCoordinate -from .coordinates import is_scalar -from .grid import Grid -from .grid import MeshedGrid -from .grid import MeshProjectionGrid -from .grid import UnstructuredGrid -from .grid import UnstructuredProjectionGrid +from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import PointCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar +from anemoi.datasets.create.sources.xarray_support.grid import Grid +from anemoi.datasets.create.sources.xarray_support.grid import MeshedGrid +from anemoi.datasets.create.sources.xarray_support.grid import MeshProjectionGrid +from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredGrid +from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredProjectionGrid LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/sources/xarray_support/metadata.py index 23713ae74..2230db3ef 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/metadata.py +++ b/src/anemoi/datasets/create/sources/xarray_support/metadata.py @@ -46,7 +46,7 @@ def __init__(self, field: Any) -> None: field : Any The field to extract metadata from. 
""" - from .field import XArrayField + from anemoi.datasets.create.sources.xarray_support.field import XArrayField assert isinstance(field, XArrayField), type(field) self._field = field diff --git a/src/anemoi/datasets/create/sources/xarray_support/time.py b/src/anemoi/datasets/create/sources/xarray_support/time.py index 847b21598..7b1f60e58 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/time.py +++ b/src/anemoi/datasets/create/sources/xarray_support/time.py @@ -16,8 +16,8 @@ from anemoi.utils.dates import as_datetime -from .coordinates import Coordinate -from .variable import Variable +from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate +from anemoi.datasets.create.sources.xarray_support.variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/variable.py b/src/anemoi/datasets/create/sources/xarray_support/variable.py index 5d2c1c5b1..13d6fa4e2 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/create/sources/xarray_support/variable.py @@ -17,7 +17,7 @@ import numpy as np import xarray as xr -from .field import XArrayField +from anemoi.datasets.create.sources.xarray_support.field import XArrayField LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_zarr.py b/src/anemoi/datasets/create/sources/xarray_zarr.py index e91de781e..2f96ab207 100644 --- a/src/anemoi/datasets/create/sources/xarray_zarr.py +++ b/src/anemoi/datasets/create/sources/xarray_zarr.py @@ -11,8 +11,8 @@ import earthkit.data as ekd -from .legacy import legacy_source -from .xarray import load_many +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/zenodo.py b/src/anemoi/datasets/create/sources/zenodo.py index 1b746bb42..e23b8fa47 100644 --- a/src/anemoi/datasets/create/sources/zenodo.py +++ b/src/anemoi/datasets/create/sources/zenodo.py @@ -14,9 +14,9 @@ from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.sources.url import download_and_cache -from .legacy import legacy_source -from .patterns import iterate_patterns -from .xarray import load_one +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.patterns import iterate_patterns +from anemoi.datasets.create.sources.xarray import load_one @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/statistics/__init__.py b/src/anemoi/datasets/create/statistics/__init__.py index f74cbf364..e8e71c45a 100644 --- a/src/anemoi/datasets/create/statistics/__init__.py +++ b/src/anemoi/datasets/create/statistics/__init__.py @@ -23,8 +23,8 @@ from anemoi.utils.provenance import gather_provenance_info from numpy.typing import NDArray -from ..check import check_data_values -from .summary import Summary +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.statistics.summary import Summary LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/statistics/summary.py b/src/anemoi/datasets/create/statistics/summary.py index 6c7bbb433..8b6c29eb0 100644 --- a/src/anemoi/datasets/create/statistics/summary.py +++ b/src/anemoi/datasets/create/statistics/summary.py @@ -13,9 +13,9 @@ import numpy as np -from ..check import StatisticsValueError -from ..check import check_data_values -from ..check import check_stats +from 
anemoi.datasets.create.check import StatisticsValueError +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.check import check_stats class Summary(dict): diff --git a/src/anemoi/datasets/data/__init__.py b/src/anemoi/datasets/data/__init__.py index f32d83bb2..fc2b0839b 100644 --- a/src/anemoi/datasets/data/__init__.py +++ b/src/anemoi/datasets/data/__init__.py @@ -15,13 +15,13 @@ # from .dataset import FullIndex # from .dataset import Shape # from .dataset import TupleIndex -from .misc import _open_dataset -from .misc import _save_dataset -from .misc import add_dataset_path -from .misc import add_named_dataset +from anemoi.datasets.data.misc import _open_dataset +from anemoi.datasets.data.misc import _save_dataset +from anemoi.datasets.data.misc import add_dataset_path +from anemoi.datasets.data.misc import add_named_dataset if TYPE_CHECKING: - from .dataset import Dataset + from anemoi.datasets.data.dataset import Dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/complement.py b/src/anemoi/datasets/data/complement.py index be5f84409..87c65a5b4 100644 --- a/src/anemoi/datasets/data/complement.py +++ b/src/anemoi/datasets/data/complement.py @@ -16,18 +16,18 @@ import numpy as np from numpy.typing import NDArray -from ..grids import nearest_grid_points -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .forwards import Combined -from .indexing import apply_index_to_slices_changes -from .indexing import index_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open_dataset +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open_dataset +from anemoi.datasets.grids import nearest_grid_points LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/concat.py b/src/anemoi/datasets/data/concat.py index 234001c8c..fcdc768fc 100644 --- a/src/anemoi/datasets/data/concat.py +++ b/src/anemoi/datasets/data/concat.py @@ -16,20 +16,20 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Combined -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import length_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Combined +from 
anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import length_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) @@ -229,7 +229,7 @@ def check_dataset_compatibility(cls, datasets: list[Any], fill_missing_gaps: boo s = ranges[i + 1] if r[1] + frequency != s[0]: if fill_missing_gaps: - from .missing import MissingDataset + from anemoi.datasets.data.missing import MissingDataset result.append(MissingDataset(datasets[i], r[1] + frequency, s[0] - frequency)) else: diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 4b76d24f5..021e385a2 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -34,8 +34,8 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from .debug import Node -from .debug import Source +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source if TYPE_CHECKING: import matplotlib @@ -165,7 +165,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": # This one must be first if "fill_missing_dates" in kwargs: - from .fill_missing import fill_missing_dates_factory + from anemoi.datasets.data.fill_missing import fill_missing_dates_factory fill_missing_dates = kwargs.pop("fill_missing_dates") ds = fill_missing_dates_factory(self, fill_missing_dates, kwargs) @@ -179,7 +179,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": if padding: if padding != "empty": raise ValueError(f"Only 'empty' padding is supported, got {padding=}") - from .padded import Padded + from anemoi.datasets.data.padded import Padded frequency = kwargs.pop("frequency", self.frequency) return ( @@ -188,14 +188,14 @@ def __subset(self, **kwargs: Any) -> "Dataset": .mutate() ) - from .subset import Subset + from anemoi.datasets.data.subset import Subset return ( Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs).mutate() ) if "frequency" in kwargs: - from .subset import Subset + from anemoi.datasets.data.subset import Subset if "interpolate_frequency" in kwargs: raise ValueError("Cannot use both `frequency` and `interpolate_frequency`") @@ -208,38 +208,38 @@ def __subset(self, **kwargs: Any) -> "Dataset": ) if "select" in kwargs: - from .select import Select + from anemoi.datasets.data.select import Select select = kwargs.pop("select") return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs).mutate() if "drop" in kwargs: - from .select import Select + from anemoi.datasets.data.select import Select drop = kwargs.pop("drop") return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs).mutate() if "reorder" in kwargs: - from .select import Select + from anemoi.datasets.data.select import Select reorder = kwargs.pop("reorder") return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate() if "rename" in kwargs: - from .select import Rename + from anemoi.datasets.data.select import Rename rename = kwargs.pop("rename") return Rename(self, rename)._subset(**kwargs).mutate() if "rescale" in kwargs: - from .rescale import Rescale + from anemoi.datasets.data.rescale import Rescale rescale = kwargs.pop("rescale") return 
Rescale(self, rescale)._subset(**kwargs).mutate() if "statistics" in kwargs: - from ..data import open_dataset - from .statistics import Statistics + from anemoi.datasets.data import open_dataset + from anemoi.datasets.data.statistics import Statistics statistics = kwargs.pop("statistics") @@ -247,26 +247,26 @@ def __subset(self, **kwargs: Any) -> "Dataset": # Note: trim_edge should go before thinning if "trim_edge" in kwargs: - from .masked import TrimEdge + from anemoi.datasets.data.masked import TrimEdge edge = kwargs.pop("trim_edge") return TrimEdge(self, edge)._subset(**kwargs).mutate() if "thinning" in kwargs: - from .masked import Thinning + from anemoi.datasets.data.masked import Thinning thinning = kwargs.pop("thinning") method = kwargs.pop("method", "every-nth") return Thinning(self, thinning, method)._subset(**kwargs).mutate() if "area" in kwargs: - from .masked import Cropping + from anemoi.datasets.data.masked import Cropping bbox = kwargs.pop("area") return Cropping(self, bbox)._subset(**kwargs).mutate() if "number" in kwargs or "numbers" in kwargs or "member" in kwargs or "members" in kwargs: - from .ensemble import Number + from anemoi.datasets.data.ensemble import Number members = {} for key in ["number", "numbers", "member", "members"]: @@ -276,13 +276,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": return Number(self, **members)._subset(**kwargs).mutate() if "set_missing_dates" in kwargs: - from .missing import MissingDates + from anemoi.datasets.data.missing import MissingDates set_missing_dates = kwargs.pop("set_missing_dates") return MissingDates(self, set_missing_dates)._subset(**kwargs).mutate() if "skip_missing_dates" in kwargs: - from .missing import SkipMissingDates + from anemoi.datasets.data.missing import SkipMissingDates if "expected_access" not in kwargs: raise ValueError("`expected_access` is required with `skip_missing_dates`") @@ -294,13 +294,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate() if "interpolate_frequency" in kwargs: - from .interpolate import InterpolateFrequency + from anemoi.datasets.data.interpolate import InterpolateFrequency interpolate_frequency = kwargs.pop("interpolate_frequency") return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate() if "interpolate_variables" in kwargs: - from .interpolate import InterpolateNearest + from anemoi.datasets.data.interpolate import InterpolateNearest interpolate_variables = kwargs.pop("interpolate_variables") max_distance = kwargs.pop("max_distance", None) @@ -308,7 +308,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": # Keep last if "shuffle" in kwargs: - from .subset import Subset + from anemoi.datasets.data.subset import Subset shuffle = kwargs.pop("shuffle") @@ -372,8 +372,8 @@ def _dates_to_indices( list of int The list of indices. 
""" - from .misc import as_first_date - from .misc import as_last_date + from anemoi.datasets.data.misc import as_first_date + from anemoi.datasets.data.misc import as_last_date # TODO: optimize diff --git a/src/anemoi/datasets/data/debug.py b/src/anemoi/datasets/data/debug.py index 0c58dafa1..ca80ace41 100644 --- a/src/anemoi/datasets/data/debug.py +++ b/src/anemoi/datasets/data/debug.py @@ -20,7 +20,7 @@ from numpy.typing import NDArray if TYPE_CHECKING: - from .dataset import Dataset + from anemoi.datasets.data.dataset import Dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/ensemble.py b/src/anemoi/datasets/data/ensemble.py index 50725c2c1..a6c59c812 100644 --- a/src/anemoi/datasets/data/ensemble.py +++ b/src/anemoi/datasets/data/ensemble.py @@ -14,17 +14,17 @@ import numpy as np from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .debug import Node -from .forwards import Forwards -from .forwards import GivenAxis -from .indexing import apply_index_to_slices_changes -from .indexing import index_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.forwards import GivenAxis +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/fill_missing.py b/src/anemoi/datasets/data/fill_missing.py index d705b1d75..0cc1b0ee2 100644 --- a/src/anemoi/datasets/data/fill_missing.py +++ b/src/anemoi/datasets/data/fill_missing.py @@ -15,17 +15,16 @@ from numpy.typing import NDArray from anemoi.datasets.data import MissingDateError - -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index 4e2219b1c..463aebf3e 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -18,16 +18,16 @@ import numpy as np from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import 
debug_indexing -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import length_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import length_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/grids.py b/src/anemoi/datasets/data/grids.py index 1e3a40cf7..fee2c792e 100644 --- a/src/anemoi/datasets/data/grids.py +++ b/src/anemoi/datasets/data/grids.py @@ -16,21 +16,21 @@ from numpy.typing import NDArray from scipy.spatial import cKDTree -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Combined -from .forwards import GivenAxis -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import length_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.forwards import GivenAxis +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import length_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/indexing.py b/src/anemoi/datasets/data/indexing.py index 106023ccb..7c4bb4be3 100644 --- a/src/anemoi/datasets/data/indexing.py +++ b/src/anemoi/datasets/data/indexing.py @@ -15,9 +15,9 @@ import numpy as np from numpy.typing import NDArray -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex def _tuple_with_slices(t: TupleIndex, shape: Shape) -> tuple[TupleIndex, tuple[int, ...]]: diff --git a/src/anemoi/datasets/data/interpolate.py b/src/anemoi/datasets/data/interpolate.py index b03404645..1f64d21a9 100644 --- a/src/anemoi/datasets/data/interpolate.py +++ b/src/anemoi/datasets/data/interpolate.py @@ -17,17 +17,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex 
-from .debug import Node -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -227,7 +227,7 @@ def __init__(self, dataset: Dataset, interpolate_variables: list[str], max_dista max_distance : Optional[float], optional The maximum distance for nearest neighbor search, by default None. """ - from ..grids import nearest_grid_points + from anemoi.datasets.grids import nearest_grid_points super().__init__(dataset) self.vars = interpolate_variables diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/data/join.py index 59aefd3a4..f7f1caa03 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/data/join.py @@ -16,20 +16,20 @@ import numpy as np from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import Source -from .debug import debug_indexing -from .forwards import Combined -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) @@ -173,7 +173,7 @@ def _overlay(self) -> Dataset: if not ok: LOG.warning("Dataset %r completely overridden.", d) - from .select import Select + from anemoi.datasets.data.select import Select return Select(self, indices, {"overlay": variables}) diff --git a/src/anemoi/datasets/data/masked.py b/src/anemoi/datasets/data/masked.py index f7eeea03d..b4982d696 100644 --- a/src/anemoi/datasets/data/masked.py +++ b/src/anemoi/datasets/data/masked.py @@ -15,18 +15,18 @@ import numpy as np from numpy.typing import NDArray -from ..grids import cropping_mask -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards 
import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.grids import cropping_mask LOG = logging.getLogger(__name__) @@ -214,7 +214,7 @@ def __init__(self, forward: Dataset, area: Dataset | tuple[float, float, float, area : Union[Dataset, Tuple[float, float, float, float]] The cropping area. """ - from ..data import open_dataset + from anemoi.datasets.data import open_dataset area = area if isinstance(area, (list, tuple)) else open_dataset(area) diff --git a/src/anemoi/datasets/data/merge.py b/src/anemoi/datasets/data/merge.py index ca2697dda..b974a6afb 100644 --- a/src/anemoi/datasets/data/merge.py +++ b/src/anemoi/datasets/data/merge.py @@ -16,19 +16,19 @@ import numpy as np from numpy.typing import NDArray -from . import MissingDateError -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Combined -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data import MissingDateError +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 3252e345a..53ae8bb07 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -23,7 +23,7 @@ from numpy.typing import NDArray if TYPE_CHECKING: - from .dataset import Dataset + from anemoi.datasets.data.dataset import Dataset LOG = logging.getLogger(__name__) @@ -323,11 +323,11 @@ def _concat_or_join(datasets: list["Dataset"], kwargs: dict[str, Any]) -> tuple[ ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets] if len(set(ranges)) == 1: - from .join import Join + from anemoi.datasets.data.join import Join return Join(datasets)._overlay(), kwargs - from .concat import Concat + from anemoi.datasets.data.concat import Concat Concat.check_dataset_compatibility(datasets) @@ -347,9 +347,9 @@ def _open(a: str | PurePath | dict[str, 
Any] | list[Any] | tuple[Any, ...]) -> " Dataset The opened dataset. """ - from .dataset import Dataset - from .stores import Zarr - from .stores import zarr_lookup + from anemoi.datasets.data.dataset import Dataset + from anemoi.datasets.data.stores import Zarr + from anemoi.datasets.data.stores import zarr_lookup if isinstance(a, str) and len(a.split(".")) in [2, 3]: @@ -501,7 +501,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": sets.append(_open(a)) if "observations" in kwargs: - from .observations import observations_factory + from anemoi.datasets.data.observations import observations_factory assert not sets, sets @@ -509,70 +509,70 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": if "xy" in kwargs: # Experimental feature, may be removed - from .xy import xy_factory + from anemoi.datasets.data.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "x" in kwargs and "y" in kwargs: # Experimental feature, may be removed - from .xy import xy_factory + from anemoi.datasets.data.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "zip" in kwargs: # Experimental feature, may be removed - from .xy import zip_factory + from anemoi.datasets.data.xy import zip_factory assert not sets, sets return zip_factory(args, kwargs).mutate() if "chain" in kwargs: # Experimental feature, may be removed - from .unchecked import chain_factory + from anemoi.datasets.data.unchecked import chain_factory assert not sets, sets return chain_factory(args, kwargs).mutate() if "join" in kwargs: - from .join import join_factory + from anemoi.datasets.data.join import join_factory assert not sets, sets return join_factory(args, kwargs).mutate() if "concat" in kwargs: - from .concat import concat_factory + from anemoi.datasets.data.concat import concat_factory assert not sets, sets return concat_factory(args, kwargs).mutate() if "merge" in kwargs: - from .merge import merge_factory + from anemoi.datasets.data.merge import merge_factory assert not sets, sets return merge_factory(args, kwargs).mutate() if "ensemble" in kwargs: - from .ensemble import ensemble_factory + from anemoi.datasets.data.ensemble import ensemble_factory assert not sets, sets return ensemble_factory(args, kwargs).mutate() if "grids" in kwargs: - from .grids import grids_factory + from anemoi.datasets.data.grids import grids_factory assert not sets, sets return grids_factory(args, kwargs).mutate() if "cutout" in kwargs: - from .grids import cutout_factory + from anemoi.datasets.data.grids import cutout_factory assert not sets, sets return cutout_factory(args, kwargs).mutate() if "complement" in kwargs: - from .complement import complement_factory + from anemoi.datasets.data.complement import complement_factory assert not sets, sets return complement_factory(args, kwargs).mutate() diff --git a/src/anemoi/datasets/data/missing.py b/src/anemoi/datasets/data/missing.py index 5e6530bda..5a6e8a5f8 100644 --- a/src/anemoi/datasets/data/missing.py +++ b/src/anemoi/datasets/data/missing.py @@ -18,15 +18,14 @@ from anemoi.datasets.create.utils import to_datetime from anemoi.datasets.data import MissingDateError - -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import expand_list_indexing -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import 
FullIndex +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/observations/__init__.py b/src/anemoi/datasets/data/observations/__init__.py index bb9595da9..23413e05d 100644 --- a/src/anemoi/datasets/data/observations/__init__.py +++ b/src/anemoi/datasets/data/observations/__init__.py @@ -15,8 +15,7 @@ from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets.data.dataset import Dataset - -from ..debug import Node +from anemoi.datasets.data.debug import Node LOG = logging.getLogger(__name__) @@ -139,7 +138,7 @@ def __init__(self, dataset, frequency=None, window=None): if isinstance(dataset, zarr.hierarchy.Group): dataset = dataset._store.path - from ..stores import zarr_lookup + from anemoi.datasets.data.stores import zarr_lookup dataset = zarr_lookup(dataset) self.path = dataset @@ -177,7 +176,7 @@ def __init__(self, dataset, frequency=None, window=None): # last_window_end must be the end of the time window of the last item last_window_end = int(end.strftime("%Y%m%d%H%M%S")) - from .legacy_obs_dataset import ObsDataset + from anemoi.datasets.data.observations.legacy_obs_dataset import ObsDataset args = [self.path, first_window_begin, last_window_end] kwargs = dict( diff --git a/src/anemoi/datasets/data/rescale.py b/src/anemoi/datasets/data/rescale.py index 613bbe93e..f5d8734fe 100644 --- a/src/anemoi/datasets/data/rescale.py +++ b/src/anemoi/datasets/data/rescale.py @@ -16,16 +16,16 @@ import numpy as np from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import TupleIndex -from .debug import Node -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/data/select.py index 048802892..e27b94f76 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/data/select.py @@ -15,18 +15,18 @@ from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import Source -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from 
anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/statistics.py b/src/anemoi/datasets/data/statistics.py index af0d4bc6e..2bb26b3d6 100644 --- a/src/anemoi/datasets/data/statistics.py +++ b/src/anemoi/datasets/data/statistics.py @@ -15,10 +15,10 @@ from numpy.typing import NDArray -from . import open_dataset -from .dataset import Dataset -from .debug import Node -from .forwards import Forwards +from anemoi.datasets.data import open_dataset +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Forwards LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 78470fec6..9224c22d3 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -22,17 +22,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from . import MissingDateError -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import DEBUG_ZARR_LOADING -from .debug import Node -from .debug import Source -from .debug import debug_indexing -from .indexing import expand_list_indexing -from .misc import load_config +from anemoi.datasets.data import MissingDateError +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import DEBUG_ZARR_LOADING +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source +from anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.misc import load_config LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 8954fa5bc..22eef70da 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -19,19 +19,19 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .dataset import TupleIndex -from .debug import Node -from .debug import Source -from .debug import debug_indexing -from .forwards import Forwards -from .indexing import apply_index_to_slices_changes -from .indexing import expand_list_indexing -from .indexing import index_to_slices -from .indexing import make_slice_or_index_from_list_or_tuple -from .indexing import update_tuple +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.debug import Source +from 
anemoi.datasets.data.debug import debug_indexing +from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.data.indexing import apply_index_to_slices_changes +from anemoi.datasets.data.indexing import expand_list_indexing +from anemoi.datasets.data.indexing import index_to_slices +from anemoi.datasets.data.indexing import make_slice_or_index_from_list_or_tuple +from anemoi.datasets.data.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -61,7 +61,7 @@ def _start(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the start date. """ - from .misc import as_first_date + from anemoi.datasets.data.misc import as_first_date c = as_first_date(a, dates) d = as_first_date(b, dates) @@ -82,7 +82,7 @@ def _end(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the end date. """ - from .misc import as_last_date + from anemoi.datasets.data.misc import as_last_date c = as_last_date(a, dates) d = as_last_date(b, dates) diff --git a/src/anemoi/datasets/data/unchecked.py b/src/anemoi/datasets/data/unchecked.py index cb4a1304c..478c8c1eb 100644 --- a/src/anemoi/datasets/data/unchecked.py +++ b/src/anemoi/datasets/data/unchecked.py @@ -18,14 +18,14 @@ import numpy as np from numpy.typing import NDArray -from .concat import ConcatMixin -from .dataset import Dataset -from .dataset import FullIndex -from .dataset import Shape -from .debug import Node -from .forwards import Combined -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.concat import ConcatMixin +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.dataset import Shape +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/xy.py b/src/anemoi/datasets/data/xy.py index d3ae622bb..e181dc9aa 100644 --- a/src/anemoi/datasets/data/xy.py +++ b/src/anemoi/datasets/data/xy.py @@ -12,12 +12,12 @@ from functools import cached_property from typing import Any -from .dataset import Dataset -from .dataset import FullIndex -from .debug import Node -from .forwards import Combined -from .misc import _auto_adjust -from .misc import _open +from anemoi.datasets.data.dataset import Dataset +from anemoi.datasets.data.dataset import FullIndex +from anemoi.datasets.data.debug import Node +from anemoi.datasets.data.forwards import Combined +from anemoi.datasets.data.misc import _auto_adjust +from anemoi.datasets.data.misc import _open LOG = logging.getLogger(__name__) From 19b9171f3bf4430680365af756a75ed723e96115 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 15:11:44 +0000 Subject: [PATCH 147/212] rename files --- src/anemoi/datasets/.gitignore | 1 + src/anemoi/datasets/__init__.py | 10 ++-- .../datasets/{create => build}/__init__.py | 46 +++++++++--------- .../datasets/{create => build}/check.py | 0 .../datasets/{create => build}/chunks.py | 0 .../datasets/{create => build}/config.py | 0 .../datasets/{create => build}/filter.py | 0 .../gridded}/sources/__init__.py | 0 .../gridded}/sources/accumulations.py | 8 ++-- .../gridded}/sources/accumulations2.py | 6 +-- .../gridded}/sources/anemoi_dataset.py | 2 +- .../gridded}/sources/constants.py | 2 +- .../gridded}/sources/eccc_fstd.py | 4 +- .../gridded}/sources/empty.py | 2 +- .../{create => 
build/gridded}/sources/fdb.py | 8 ++-- .../gridded}/sources/forcings.py | 2 +- .../{create => build/gridded}/sources/grib.py | 2 +- .../gridded}/sources/grib_index.py | 2 +- .../gridded}/sources/hindcasts.py | 4 +- .../gridded}/sources/legacy.py | 4 +- .../{create => build/gridded}/sources/mars.py | 4 +- .../gridded}/sources/netcdf.py | 4 +- .../gridded}/sources/opendap.py | 4 +- .../gridded}/sources/patterns.py | 0 .../gridded}/sources/planetary_computer.py | 4 +- .../gridded}/sources/recentre.py | 4 +- .../gridded}/sources/repeated_dates.py | 4 +- .../gridded}/sources/source.py | 4 +- .../gridded}/sources/tendencies.py | 6 +-- .../gridded}/sources/xarray.py | 10 ++-- .../gridded}/sources/xarray_kerchunk.py | 4 +- .../gridded}/sources/xarray_support/README.md | 0 .../sources/xarray_support/__init__.py | 6 +-- .../sources/xarray_support/coordinates.py | 0 .../gridded}/sources/xarray_support/field.py | 6 +-- .../sources/xarray_support/fieldlist.py | 12 ++--- .../sources/xarray_support/flavour.py | 38 +++++++-------- .../gridded}/sources/xarray_support/grid.py | 0 .../sources/xarray_support/metadata.py | 2 +- .../gridded}/sources/xarray_support/patch.py | 0 .../gridded}/sources/xarray_support/time.py | 4 +- .../sources/xarray_support/variable.py | 2 +- .../gridded}/sources/xarray_zarr.py | 4 +- .../gridded}/sources/zenodo.py | 6 +-- .../{create => build}/input/__init__.py | 8 ++-- .../{create => build}/input/action.py | 4 +- .../input/context/__init__.py | 2 +- .../{create => build}/input/context/field.py | 4 +- .../{create => build}/input/data_sources.py | 8 ++-- .../datasets/{create => build}/input/misc.py | 0 .../{create => build}/input/repeated_dates.py | 10 ++-- .../input/result/__init__.py | 0 .../{create => build}/input/result/field.py | 2 +- .../datasets/{create => build}/input/trace.py | 0 .../datasets/{create => build}/patch.py | 0 .../datasets/{create => build}/persistent.py | 0 src/anemoi/datasets/{create => build}/size.py | 0 .../datasets/{create => build}/source.py | 2 +- .../{create => build}/statistics/__init__.py | 4 +- .../{create => build}/statistics/summary.py | 6 +-- .../datasets/{create => build}/testing.py | 0 .../datasets/{create => build}/typing.py | 0 .../datasets/{create => build}/utils.py | 0 .../datasets/{create => build}/writer.py | 0 src/anemoi/datasets/{create => build}/zarr.py | 0 src/anemoi/datasets/commands/check.py | 2 +- src/anemoi/datasets/commands/create.py | 2 +- src/anemoi/datasets/commands/grib-index.py | 2 +- src/anemoi/datasets/commands/inspect.py | 4 +- .../datasets/commands/recipe/__init__.py | 2 +- .../datasets/commands/recipe/migrate.py | 2 +- .../{data => use/gridded}/__init__.py | 10 ++-- .../{data => use/gridded}/complement.py | 22 ++++----- .../datasets/{data => use/gridded}/concat.py | 30 ++++++------ .../datasets/{data => use/gridded}/dataset.py | 48 +++++++++---------- .../datasets/{data => use/gridded}/debug.css | 0 .../datasets/{data => use/gridded}/debug.py | 2 +- .../{data => use/gridded}/ensemble.py | 22 ++++----- .../{data => use/gridded}/fill_missing.py | 22 ++++----- .../{data => use/gridded}/forwards.py | 20 ++++---- .../datasets/{data => use/gridded}/grids.py | 30 ++++++------ .../{data => use/gridded}/indexing.py | 6 +-- .../{data => use/gridded}/interpolate.py | 22 ++++----- .../datasets/{data => use/gridded}/join.py | 30 ++++++------ .../datasets/{data => use/gridded}/masked.py | 24 +++++----- .../datasets/{data => use/gridded}/merge.py | 26 +++++----- .../datasets/{data => use/gridded}/misc.py | 38 +++++++-------- 
.../datasets/{data => use/gridded}/missing.py | 20 ++++---- .../datasets/{data => use/gridded}/padded.py | 20 ++++---- .../datasets/{data => use/gridded}/rescale.py | 20 ++++---- .../datasets/{data => use/gridded}/select.py | 24 +++++----- .../{data => use/gridded}/statistics.py | 8 ++-- .../datasets/{data => use/gridded}/stores.py | 22 ++++----- .../datasets/{data => use/gridded}/subset.py | 30 ++++++------ .../{data => use/gridded}/unchecked.py | 16 +++---- .../datasets/{data => use/gridded}/xy.py | 12 ++--- .../{data => use}/observations/__init__.py | 8 ++-- .../observations/legacy_obs_dataset.py | 0 .../{data => use}/observations/multi.py | 2 +- .../{data => use}/records/__init__.py | 6 +-- .../records/backends/__init__.py | 4 +- src/anemoi/datasets/validate.py | 2 +- tests/create/utils/compare.py | 2 +- tests/create/utils/create.py | 2 +- tests/test_chunks.py | 2 +- tests/test_data.py | 26 +++++----- tests/test_dates.py | 2 +- tests/test_indexing.py | 2 +- tests/test_records.py | 6 +-- tests/xarray/test_flavour.py | 24 +++++----- tests/xarray/test_netcdf.py | 2 +- tests/xarray/test_opendap.py | 2 +- tests/xarray/test_variable.py | 16 +++---- tests/xarray/test_zarr.py | 2 +- tools/build-obs.py | 2 +- 115 files changed, 466 insertions(+), 465 deletions(-) create mode 100644 src/anemoi/datasets/.gitignore rename src/anemoi/datasets/{create => build}/__init__.py (97%) rename src/anemoi/datasets/{create => build}/check.py (100%) rename src/anemoi/datasets/{create => build}/chunks.py (100%) rename src/anemoi/datasets/{create => build}/config.py (100%) rename src/anemoi/datasets/{create => build}/filter.py (100%) rename src/anemoi/datasets/{create => build/gridded}/sources/__init__.py (100%) rename src/anemoi/datasets/{create => build/gridded}/sources/accumulations.py (99%) rename src/anemoi/datasets/{create => build/gridded}/sources/accumulations2.py (99%) rename src/anemoi/datasets/{create => build/gridded}/sources/anemoi_dataset.py (96%) rename src/anemoi/datasets/{create => build/gridded}/sources/constants.py (95%) rename src/anemoi/datasets/{create => build/gridded}/sources/eccc_fstd.py (82%) rename src/anemoi/datasets/{create => build/gridded}/sources/empty.py (93%) rename src/anemoi/datasets/{create => build/gridded}/sources/fdb.py (94%) rename src/anemoi/datasets/{create => build/gridded}/sources/forcings.py (94%) rename src/anemoi/datasets/{create => build/gridded}/sources/grib.py (98%) rename src/anemoi/datasets/{create => build/gridded}/sources/grib_index.py (99%) rename src/anemoi/datasets/{create => build/gridded}/sources/hindcasts.py (95%) rename src/anemoi/datasets/{create => build/gridded}/sources/legacy.py (95%) rename src/anemoi/datasets/{create => build/gridded}/sources/mars.py (99%) rename src/anemoi/datasets/{create => build/gridded}/sources/netcdf.py (90%) rename src/anemoi/datasets/{create => build/gridded}/sources/opendap.py (90%) rename src/anemoi/datasets/{create => build/gridded}/sources/patterns.py (100%) rename src/anemoi/datasets/{create => build/gridded}/sources/planetary_computer.py (92%) rename src/anemoi/datasets/{create => build/gridded}/sources/recentre.py (97%) rename src/anemoi/datasets/{create => build/gridded}/sources/repeated_dates.py (98%) rename src/anemoi/datasets/{create => build/gridded}/sources/source.py (94%) rename src/anemoi/datasets/{create => build/gridded}/sources/tendencies.py (96%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray.py (88%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_kerchunk.py 
(89%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/README.md (100%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/__init__.py (95%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/coordinates.py (100%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/field.py (96%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/fieldlist.py (94%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/flavour.py (95%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/grid.py (100%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/metadata.py (98%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/patch.py (100%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/time.py (98%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_support/variable.py (99%) rename src/anemoi/datasets/{create => build/gridded}/sources/xarray_zarr.py (89%) rename src/anemoi/datasets/{create => build/gridded}/sources/zenodo.py (90%) rename src/anemoi/datasets/{create => build}/input/__init__.py (89%) rename src/anemoi/datasets/{create => build}/input/action.py (98%) rename src/anemoi/datasets/{create => build}/input/context/__init__.py (96%) rename src/anemoi/datasets/{create => build}/input/context/field.py (92%) rename src/anemoi/datasets/{create => build}/input/data_sources.py (94%) rename src/anemoi/datasets/{create => build}/input/misc.py (100%) rename src/anemoi/datasets/{create => build}/input/repeated_dates.py (97%) rename src/anemoi/datasets/{create => build}/input/result/__init__.py (100%) rename src/anemoi/datasets/{create => build}/input/result/field.py (99%) rename src/anemoi/datasets/{create => build}/input/trace.py (100%) rename src/anemoi/datasets/{create => build}/patch.py (100%) rename src/anemoi/datasets/{create => build}/persistent.py (100%) rename src/anemoi/datasets/{create => build}/size.py (100%) rename src/anemoi/datasets/{create => build}/source.py (96%) rename src/anemoi/datasets/{create => build}/statistics/__init__.py (99%) rename src/anemoi/datasets/{create => build}/statistics/summary.py (96%) rename src/anemoi/datasets/{create => build}/testing.py (100%) rename src/anemoi/datasets/{create => build}/typing.py (100%) rename src/anemoi/datasets/{create => build}/utils.py (100%) rename src/anemoi/datasets/{create => build}/writer.py (100%) rename src/anemoi/datasets/{create => build}/zarr.py (100%) rename src/anemoi/datasets/{data => use/gridded}/__init__.py (91%) rename src/anemoi/datasets/{data => use/gridded}/complement.py (95%) rename src/anemoi/datasets/{data => use/gridded}/concat.py (91%) rename src/anemoi/datasets/{data => use/gridded}/dataset.py (95%) rename src/anemoi/datasets/{data => use/gridded}/debug.css (100%) rename src/anemoi/datasets/{data => use/gridded}/debug.py (99%) rename src/anemoi/datasets/{data => use/gridded}/ensemble.py (89%) rename src/anemoi/datasets/{data => use/gridded}/fill_missing.py (93%) rename src/anemoi/datasets/{data => use/gridded}/forwards.py (97%) rename src/anemoi/datasets/{data => use/gridded}/grids.py (96%) rename src/anemoi/datasets/{data => use/gridded}/indexing.py (98%) rename src/anemoi/datasets/{data => use/gridded}/interpolate.py (93%) rename src/anemoi/datasets/{data => use/gridded}/join.py (91%) rename src/anemoi/datasets/{data => use/gridded}/masked.py (93%) rename 
src/anemoi/datasets/{data => use/gridded}/merge.py (92%) rename src/anemoi/datasets/{data => use/gridded}/misc.py (95%) rename src/anemoi/datasets/{data => use/gridded}/missing.py (95%) rename src/anemoi/datasets/{data => use/gridded}/padded.py (93%) rename src/anemoi/datasets/{data => use/gridded}/rescale.py (92%) rename src/anemoi/datasets/{data => use/gridded}/select.py (92%) rename src/anemoi/datasets/{data => use/gridded}/statistics.py (94%) rename src/anemoi/datasets/{data => use/gridded}/stores.py (96%) rename src/anemoi/datasets/{data => use/gridded}/subset.py (90%) rename src/anemoi/datasets/{data => use/gridded}/unchecked.py (94%) rename src/anemoi/datasets/{data => use/gridded}/xy.py (96%) rename src/anemoi/datasets/{data => use}/observations/__init__.py (97%) rename src/anemoi/datasets/{data => use}/observations/legacy_obs_dataset.py (100%) rename src/anemoi/datasets/{data => use}/observations/multi.py (97%) rename src/anemoi/datasets/{data => use}/records/__init__.py (98%) rename src/anemoi/datasets/{data => use}/records/backends/__init__.py (97%) diff --git a/src/anemoi/datasets/.gitignore b/src/anemoi/datasets/.gitignore new file mode 100644 index 000000000..0aba28e9b --- /dev/null +++ b/src/anemoi/datasets/.gitignore @@ -0,0 +1 @@ +!build/ diff --git a/src/anemoi/datasets/__init__.py b/src/anemoi/datasets/__init__.py index 620f5e80f..c38a5de68 100644 --- a/src/anemoi/datasets/__init__.py +++ b/src/anemoi/datasets/__init__.py @@ -8,11 +8,11 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.data import MissingDateError -from anemoi.datasets.data import add_dataset_path -from anemoi.datasets.data import add_named_dataset -from anemoi.datasets.data import list_dataset_names -from anemoi.datasets.data import open_dataset +from anemoi.datasets.use import MissingDateError +from anemoi.datasets.use import add_dataset_path +from anemoi.datasets.use import add_named_dataset +from anemoi.datasets.use import list_dataset_names +from anemoi.datasets.use import open_dataset try: # NOTE: the `_version.py` file must not be present in the git repository diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/build/__init__.py similarity index 97% rename from src/anemoi/datasets/create/__init__.py rename to src/anemoi/datasets/build/__init__.py index cec8b95da..f28955dd8 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/build/__init__.py @@ -31,25 +31,25 @@ from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset -from anemoi.datasets.create.check import DatasetName -from anemoi.datasets.create.check import check_data_values -from anemoi.datasets.create.chunks import ChunkFilter -from anemoi.datasets.create.config import build_output -from anemoi.datasets.create.config import loader_config -from anemoi.datasets.create.input import InputBuilder -from anemoi.datasets.create.input.trace import enable_trace -from anemoi.datasets.create.persistent import build_storage -from anemoi.datasets.create.statistics import Summary -from anemoi.datasets.create.statistics import TmpStatistics -from anemoi.datasets.create.statistics import check_variance -from anemoi.datasets.create.statistics import compute_statistics -from anemoi.datasets.create.statistics import default_statistics_dates -from anemoi.datasets.create.statistics import fix_variance -from anemoi.datasets.create.utils import normalize_and_check_dates -from anemoi.datasets.create.writer import ViewCacheArray -from anemoi.datasets.data.misc import 
as_first_date -from anemoi.datasets.data.misc import as_last_date +from anemoi.datasets.build.check import DatasetName +from anemoi.datasets.build.check import check_data_values +from anemoi.datasets.build.chunks import ChunkFilter +from anemoi.datasets.build.config import build_output +from anemoi.datasets.build.config import loader_config +from anemoi.datasets.build.input import InputBuilder +from anemoi.datasets.build.input.trace import enable_trace +from anemoi.datasets.build.persistent import build_storage +from anemoi.datasets.build.statistics import Summary +from anemoi.datasets.build.statistics import TmpStatistics +from anemoi.datasets.build.statistics import check_variance +from anemoi.datasets.build.statistics import compute_statistics +from anemoi.datasets.build.statistics import default_statistics_dates +from anemoi.datasets.build.statistics import fix_variance +from anemoi.datasets.build.utils import normalize_and_check_dates +from anemoi.datasets.build.writer import ViewCacheArray from anemoi.datasets.dates.groups import Groups +from anemoi.datasets.use.misc import as_first_date +from anemoi.datasets.use.misc import as_last_date LOG = logging.getLogger(__name__) @@ -192,7 +192,7 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: import zarr z = zarr.open(self.path, mode=mode) - from anemoi.datasets.create.zarr import add_zarr_dataset + from anemoi.datasets.build.zarr import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -396,7 +396,7 @@ def _cache_context(self) -> Any: Any The cache context. """ - from anemoi.datasets.create.utils import cache_context + from anemoi.datasets.build.utils import cache_context return cache_context(self.cache) @@ -472,7 +472,7 @@ def __init__(self, path: str, options: dict = None, **kwargs: Any): def run(self) -> None: """Run the patch.""" - from anemoi.datasets.create.patch import apply_patch + from anemoi.datasets.build.patch import apply_patch apply_patch(self.path, **self.options) @@ -492,7 +492,7 @@ def __init__(self, path: str, **kwargs: Any): def run(self) -> None: """Run the size computation.""" - from anemoi.datasets.create.size import compute_directory_sizes + from anemoi.datasets.build.size import compute_directory_sizes metadata = compute_directory_sizes(self.path) self.update_metadata(**metadata) @@ -514,7 +514,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from anemoi.datasets.create.zarr import ZarrBuiltRegistry + from anemoi.datasets.build.zarr import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) diff --git a/src/anemoi/datasets/create/check.py b/src/anemoi/datasets/build/check.py similarity index 100% rename from src/anemoi/datasets/create/check.py rename to src/anemoi/datasets/build/check.py diff --git a/src/anemoi/datasets/create/chunks.py b/src/anemoi/datasets/build/chunks.py similarity index 100% rename from src/anemoi/datasets/create/chunks.py rename to src/anemoi/datasets/build/chunks.py diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/build/config.py similarity index 100% rename from src/anemoi/datasets/create/config.py rename to src/anemoi/datasets/build/config.py diff --git a/src/anemoi/datasets/create/filter.py b/src/anemoi/datasets/build/filter.py similarity index 100% rename from src/anemoi/datasets/create/filter.py rename to src/anemoi/datasets/build/filter.py diff --git a/src/anemoi/datasets/create/sources/__init__.py 
b/src/anemoi/datasets/build/gridded/sources/__init__.py similarity index 100% rename from src/anemoi/datasets/create/sources/__init__.py rename to src/anemoi/datasets/build/gridded/sources/__init__.py diff --git a/src/anemoi/datasets/create/sources/accumulations.py b/src/anemoi/datasets/build/gridded/sources/accumulations.py similarity index 99% rename from src/anemoi/datasets/create/sources/accumulations.py rename to src/anemoi/datasets/build/gridded/sources/accumulations.py index 40b8749f6..2d45b164a 100644 --- a/src/anemoi/datasets/create/sources/accumulations.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations.py @@ -20,9 +20,9 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.mars import mars -from anemoi.datasets.create.utils import to_datetime_list +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.mars import mars +from anemoi.datasets.build.utils import to_datetime_list LOG = logging.getLogger(__name__) @@ -993,7 +993,7 @@ def accumulations( and request.get("stream", "oper") == "oper" and request.get("accumulation_period") == 24 ): - from anemoi.datasets.create.sources.accumulations2 import accumulations as accumulations2 + from anemoi.datasets.build.sources.accumulations2 import accumulations as accumulations2 LOG.warning( "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" diff --git a/src/anemoi/datasets/create/sources/accumulations2.py b/src/anemoi/datasets/build/gridded/sources/accumulations2.py similarity index 99% rename from src/anemoi/datasets/create/sources/accumulations2.py rename to src/anemoi/datasets/build/gridded/sources/accumulations2.py index 3c34d392e..eb560b4b2 100644 --- a/src/anemoi/datasets/create/sources/accumulations2.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations2.py @@ -18,9 +18,9 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.mars import mars -from anemoi.datasets.create.utils import to_datetime_list +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.mars import mars +from anemoi.datasets.build.utils import to_datetime_list LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/anemoi_dataset.py b/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py similarity index 96% rename from src/anemoi/datasets/create/sources/anemoi_dataset.py rename to src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py index a05e7df51..e890f8130 100644 --- a/src/anemoi/datasets/create/sources/anemoi_dataset.py +++ b/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py @@ -9,7 +9,7 @@ import numpy as np -from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.build.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/build/gridded/sources/constants.py similarity index 95% rename from src/anemoi/datasets/create/sources/constants.py rename to src/anemoi/datasets/build/gridded/sources/constants.py index accde7936..b0c15ce94 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ 
b/src/anemoi/datasets/build/gridded/sources/constants.py @@ -11,7 +11,7 @@ from earthkit.data import from_source -from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.build.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/eccc_fstd.py b/src/anemoi/datasets/build/gridded/sources/eccc_fstd.py similarity index 82% rename from src/anemoi/datasets/create/sources/eccc_fstd.py rename to src/anemoi/datasets/build/gridded/sources/eccc_fstd.py index fdd79af8d..59be1ea81 100644 --- a/src/anemoi/datasets/create/sources/eccc_fstd.py +++ b/src/anemoi/datasets/build/gridded/sources/eccc_fstd.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from anemoi.datasets.build.sources import source_registry +from anemoi.datasets.build.sources.xarray import XarraySourceBase @source_registry.register("eccc_fstd") diff --git a/src/anemoi/datasets/create/sources/empty.py b/src/anemoi/datasets/build/gridded/sources/empty.py similarity index 93% rename from src/anemoi/datasets/create/sources/empty.py rename to src/anemoi/datasets/build/gridded/sources/empty.py index f948810f5..fbcfdecf1 100644 --- a/src/anemoi/datasets/create/sources/empty.py +++ b/src/anemoi/datasets/build/gridded/sources/empty.py @@ -12,7 +12,7 @@ import earthkit.data as ekd -from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.build.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/fdb.py b/src/anemoi/datasets/build/gridded/sources/fdb.py similarity index 94% rename from src/anemoi/datasets/create/sources/fdb.py rename to src/anemoi/datasets/build/gridded/sources/fdb.py index 81cdb7e13..bdadb9d83 100644 --- a/src/anemoi/datasets/create/sources/fdb.py +++ b/src/anemoi/datasets/build/gridded/sources/fdb.py @@ -16,9 +16,9 @@ from anemoi.transform.flavour import RuleBasedFlavour from anemoi.transform.grids import grid_registry -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.typing import DateList +from anemoi.datasets.build.source import Source +from anemoi.datasets.build.sources import source_registry +from anemoi.datasets.build.typing import DateList @source_registry.register("fdb") @@ -124,7 +124,7 @@ def _time_request_keys(dt: datetime, offset_from_date: bool | None = None) -> st def _shortname_to_paramid(shortname: list[str], param_id_map: dict[str, int] | None = None) -> list[int]: - from anemoi.datasets.create.sources.mars import use_grib_paramid + from anemoi.datasets.build.sources.mars import use_grib_paramid """Convert a shortname to a parameter ID.""" if param_id_map is None: diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/build/gridded/sources/forcings.py similarity index 94% rename from src/anemoi/datasets/create/sources/forcings.py rename to src/anemoi/datasets/build/gridded/sources/forcings.py index 88eca92e4..ae3545b3f 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/build/gridded/sources/forcings.py @@ -11,7 +11,7 @@ from earthkit.data import from_source -from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.build.sources.legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/grib.py 
b/src/anemoi/datasets/build/gridded/sources/grib.py similarity index 98% rename from src/anemoi/datasets/create/sources/grib.py rename to src/anemoi/datasets/build/gridded/sources/grib.py index e1eaed2da..2d5932347 100644 --- a/src/anemoi/datasets/create/sources/grib.py +++ b/src/anemoi/datasets/build/gridded/sources/grib.py @@ -20,7 +20,7 @@ from earthkit.data import from_source from earthkit.data.utils.patterns import Pattern -from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.build.sources.legacy import legacy_source LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/grib_index.py b/src/anemoi/datasets/build/gridded/sources/grib_index.py similarity index 99% rename from src/anemoi/datasets/create/sources/grib_index.py rename to src/anemoi/datasets/build/gridded/sources/grib_index.py index 160ff3f3a..9c52c462f 100644 --- a/src/anemoi/datasets/create/sources/grib_index.py +++ b/src/anemoi/datasets/build/gridded/sources/grib_index.py @@ -19,7 +19,7 @@ from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray -from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.build.sources.legacy import legacy_source LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/hindcasts.py b/src/anemoi/datasets/build/gridded/sources/hindcasts.py similarity index 95% rename from src/anemoi/datasets/create/sources/hindcasts.py rename to src/anemoi/datasets/build/gridded/sources/hindcasts.py index d796a74af..b633b320c 100644 --- a/src/anemoi/datasets/create/sources/hindcasts.py +++ b/src/anemoi/datasets/build/gridded/sources/hindcasts.py @@ -12,8 +12,8 @@ from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.mars import mars +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.mars import mars LOGGER = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/build/gridded/sources/legacy.py similarity index 95% rename from src/anemoi/datasets/create/sources/legacy.py rename to src/anemoi/datasets/build/gridded/sources/legacy.py index 352ae207e..058443293 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/build/gridded/sources/legacy.py @@ -14,8 +14,8 @@ from collections.abc import Callable from typing import Any -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.build.source import Source +from anemoi.datasets.build.sources import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/mars.py b/src/anemoi/datasets/build/gridded/sources/mars.py similarity index 99% rename from src/anemoi/datasets/create/sources/mars.py rename to src/anemoi/datasets/build/gridded/sources/mars.py index d59f6034d..5ba70950e 100644 --- a/src/anemoi/datasets/create/sources/mars.py +++ b/src/anemoi/datasets/build/gridded/sources/mars.py @@ -16,8 +16,8 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.utils import to_datetime_list +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.utils import to_datetime_list DEBUG = False diff --git 
a/src/anemoi/datasets/create/sources/netcdf.py b/src/anemoi/datasets/build/gridded/sources/netcdf.py similarity index 90% rename from src/anemoi/datasets/create/sources/netcdf.py rename to src/anemoi/datasets/build/gridded/sources/netcdf.py index 606a8dd53..175b97a65 100644 --- a/src/anemoi/datasets/create/sources/netcdf.py +++ b/src/anemoi/datasets/build/gridded/sources/netcdf.py @@ -12,8 +12,8 @@ import earthkit.data as ekd -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.xarray import load_many +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/opendap.py b/src/anemoi/datasets/build/gridded/sources/opendap.py similarity index 90% rename from src/anemoi/datasets/create/sources/opendap.py rename to src/anemoi/datasets/build/gridded/sources/opendap.py index 34e3fe94d..09c4a0986 100644 --- a/src/anemoi/datasets/create/sources/opendap.py +++ b/src/anemoi/datasets/build/gridded/sources/opendap.py @@ -12,8 +12,8 @@ import earthkit.data as ekd -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.xarray import load_many +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/patterns.py b/src/anemoi/datasets/build/gridded/sources/patterns.py similarity index 100% rename from src/anemoi/datasets/create/sources/patterns.py rename to src/anemoi/datasets/build/gridded/sources/patterns.py diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/build/gridded/sources/planetary_computer.py similarity index 92% rename from src/anemoi/datasets/create/sources/planetary_computer.py rename to src/anemoi/datasets/build/gridded/sources/planetary_computer.py index 07e8f0203..538857a32 100644 --- a/src/anemoi/datasets/create/sources/planetary_computer.py +++ b/src/anemoi/datasets/build/gridded/sources/planetary_computer.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction.
-from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from anemoi.datasets.build.sources import source_registry +from anemoi.datasets.build.sources.xarray import XarraySourceBase @source_registry.register("planetary_computer") diff --git a/src/anemoi/datasets/create/sources/recentre.py b/src/anemoi/datasets/build/gridded/sources/recentre.py similarity index 97% rename from src/anemoi/datasets/create/sources/recentre.py rename to src/anemoi/datasets/build/gridded/sources/recentre.py index d0959f664..c989dadb6 100644 --- a/src/anemoi/datasets/create/sources/recentre.py +++ b/src/anemoi/datasets/build/gridded/sources/recentre.py @@ -10,9 +10,9 @@ from copy import deepcopy from typing import Any +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.mars import mars from anemoi.datasets.compute.recentre import recentre as _recentre -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.mars import mars def to_list(x: list | tuple | str) -> list: diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/build/gridded/sources/repeated_dates.py similarity index 98% rename from src/anemoi/datasets/create/sources/repeated_dates.py rename to src/anemoi/datasets/build/gridded/sources/repeated_dates.py index b56537979..cdc4b5926 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/build/gridded/sources/repeated_dates.py @@ -19,8 +19,8 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.build.source import Source +from anemoi.datasets.build.sources import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/source.py b/src/anemoi/datasets/build/gridded/sources/source.py similarity index 94% rename from src/anemoi/datasets/create/sources/source.py rename to src/anemoi/datasets/build/gridded/sources/source.py index 1bac545d8..5d724f4fd 100644 --- a/src/anemoi/datasets/create/sources/source.py +++ b/src/anemoi/datasets/build/gridded/sources/source.py @@ -12,8 +12,8 @@ from earthkit.data import from_source -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.utils import to_datetime_list +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.utils import to_datetime_list @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/tendencies.py b/src/anemoi/datasets/build/gridded/sources/tendencies.py similarity index 96% rename from src/anemoi/datasets/create/sources/tendencies.py rename to src/anemoi/datasets/build/gridded/sources/tendencies.py index 222dca9a4..0f716f803 100644 --- a/src/anemoi/datasets/create/sources/tendencies.py +++ b/src/anemoi/datasets/build/gridded/sources/tendencies.py @@ -14,8 +14,8 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.utils import to_datetime_list +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.utils import to_datetime_list def _date_to_datetime(d: Any) -> Any: @@ -107,7 +107,7 @@ def tendencies(dates: 
list[datetime.datetime], time_increment: Any, **kwargs: An all_dates = sorted(list(set(dates + shifted_dates))) # from .mars import execute as mars - from anemoi.datasets.create.mars import execute as mars + from anemoi.datasets.build.mars import execute as mars ds = mars(dates=all_dates, **kwargs) diff --git a/src/anemoi/datasets/create/sources/xarray.py b/src/anemoi/datasets/build/gridded/sources/xarray.py similarity index 88% rename from src/anemoi/datasets/create/sources/xarray.py rename to src/anemoi/datasets/build/gridded/sources/xarray.py index 5e3cc4c10..077bcd63a 100644 --- a/src/anemoi/datasets/create/sources/xarray.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray.py @@ -11,11 +11,11 @@ import earthkit.data as ekd -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources.xarray_support import XarrayFieldList -from anemoi.datasets.create.sources.xarray_support import load_many -from anemoi.datasets.create.sources.xarray_support import load_one -from anemoi.datasets.create.typing import DateList +from anemoi.datasets.build.source import Source +from anemoi.datasets.build.sources.xarray_support import XarrayFieldList +from anemoi.datasets.build.sources.xarray_support import load_many +from anemoi.datasets.build.sources.xarray_support import load_one +from anemoi.datasets.build.typing import DateList __all__ = ["load_many", "load_one", "XarrayFieldList"] diff --git a/src/anemoi/datasets/create/sources/xarray_kerchunk.py b/src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py similarity index 89% rename from src/anemoi/datasets/create/sources/xarray_kerchunk.py rename to src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py index 632a7cae2..caeb5e01a 100644 --- a/src/anemoi/datasets/create/sources/xarray_kerchunk.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction.
-from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from anemoi.datasets.build.sources import source_registry +from anemoi.datasets.build.sources.xarray import XarraySourceBase @source_registry.register("xarray_kerchunk") diff --git a/src/anemoi/datasets/create/sources/xarray_support/README.md b/src/anemoi/datasets/build/gridded/sources/xarray_support/README.md similarity index 100% rename from src/anemoi/datasets/create/sources/xarray_support/README.md rename to src/anemoi/datasets/build/gridded/sources/xarray_support/README.md diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py similarity index 95% rename from src/anemoi/datasets/create/sources/xarray_support/__init__.py rename to src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py index c33ce7bfc..c40bd5fcd 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py @@ -15,9 +15,9 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.patterns import iterate_patterns -from anemoi.datasets.create.sources.xarray_support.fieldlist import XarrayFieldList +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.patterns import iterate_patterns +from anemoi.datasets.build.sources.xarray_support.fieldlist import XarrayFieldList LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/coordinates.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/coordinates.py similarity index 100% rename from src/anemoi/datasets/create/sources/xarray_support/coordinates.py rename to src/anemoi/datasets/build/gridded/sources/xarray_support/coordinates.py diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/field.py similarity index 96% rename from src/anemoi/datasets/create/sources/xarray_support/field.py rename to src/anemoi/datasets/build/gridded/sources/xarray_support/field.py index 85f9970f8..7de7e6046 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/field.py @@ -17,9 +17,9 @@ from earthkit.data.core.fieldlist import math from numpy.typing import NDArray -from anemoi.datasets.create.sources.xarray_support.coordinates import extract_single_value -from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar -from anemoi.datasets.create.sources.xarray_support.metadata import XArrayMetadata +from anemoi.datasets.build.sources.xarray_support.coordinates import extract_single_value +from anemoi.datasets.build.sources.xarray_support.coordinates import is_scalar +from anemoi.datasets.build.sources.xarray_support.metadata import XArrayMetadata LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py similarity index 94% rename from src/anemoi/datasets/create/sources/xarray_support/fieldlist.py rename to src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py index 174cb2716..1798a1d4d 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py 
+++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py
@@ -16,12 +16,12 @@
import yaml
from earthkit.data import FieldList
-from anemoi.datasets.create.sources.xarray_support.field import EmptyFieldList
-from anemoi.datasets.create.sources.xarray_support.flavour import CoordinateGuesser
-from anemoi.datasets.create.sources.xarray_support.patch import patch_dataset
-from anemoi.datasets.create.sources.xarray_support.time import Time
-from anemoi.datasets.create.sources.xarray_support.variable import FilteredVariable
-from anemoi.datasets.create.sources.xarray_support.variable import Variable
+from anemoi.datasets.build.sources.xarray_support.field import EmptyFieldList
+from anemoi.datasets.build.sources.xarray_support.flavour import CoordinateGuesser
+from anemoi.datasets.build.sources.xarray_support.patch import patch_dataset
+from anemoi.datasets.build.sources.xarray_support.time import Time
+from anemoi.datasets.build.sources.xarray_support.variable import FilteredVariable
+from anemoi.datasets.build.sources.xarray_support.variable import Variable
LOG = logging.getLogger(__name__)
diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py
similarity index 95%
rename from src/anemoi/datasets/create/sources/xarray_support/flavour.py
rename to src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py
index 74fcdbd03..94d1424ef 100644
--- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py
+++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py
@@ -17,25 +17,25 @@
import xarray as xr
from anemoi.utils.config import DotDict
-from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import PointCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate
-from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar
-from anemoi.datasets.create.sources.xarray_support.grid import Grid
-from anemoi.datasets.create.sources.xarray_support.grid import MeshedGrid
-from anemoi.datasets.create.sources.xarray_support.grid import MeshProjectionGrid
-from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredGrid
-from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredProjectionGrid
+from anemoi.datasets.build.sources.xarray_support.coordinates import Coordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import DateCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import EnsembleCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import LatitudeCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import LevelCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import LongitudeCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import PointCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import ScalarCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import StepCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import TimeCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import UnsupportedCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import XCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import YCoordinate
+from anemoi.datasets.build.sources.xarray_support.coordinates import is_scalar
+from anemoi.datasets.build.sources.xarray_support.grid import Grid
+from anemoi.datasets.build.sources.xarray_support.grid import MeshedGrid
+from anemoi.datasets.build.sources.xarray_support.grid import MeshProjectionGrid
+from anemoi.datasets.build.sources.xarray_support.grid import UnstructuredGrid
+from anemoi.datasets.build.sources.xarray_support.grid import UnstructuredProjectionGrid
LOG = logging.getLogger(__name__)
diff --git a/src/anemoi/datasets/create/sources/xarray_support/grid.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/grid.py
similarity index 100%
rename from src/anemoi/datasets/create/sources/xarray_support/grid.py
rename to src/anemoi/datasets/build/gridded/sources/xarray_support/grid.py
diff --git a/src/anemoi/datasets/create/sources/xarray_support/metadata.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py
similarity index 98%
rename from src/anemoi/datasets/create/sources/xarray_support/metadata.py
rename to src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py
index 2230db3ef..104d1fb62 100644
--- a/src/anemoi/datasets/create/sources/xarray_support/metadata.py
+++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py
@@ -46,7 +46,7 @@ def __init__(self, field: Any) -> None:
field : Any
The field to extract metadata from.
""" - from anemoi.datasets.create.sources.xarray_support.field import XArrayField + from anemoi.datasets.build.sources.xarray_support.field import XArrayField assert isinstance(field, XArrayField), type(field) self._field = field diff --git a/src/anemoi/datasets/create/sources/xarray_support/patch.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/patch.py similarity index 100% rename from src/anemoi/datasets/create/sources/xarray_support/patch.py rename to src/anemoi/datasets/build/gridded/sources/xarray_support/patch.py diff --git a/src/anemoi/datasets/create/sources/xarray_support/time.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/time.py similarity index 98% rename from src/anemoi/datasets/create/sources/xarray_support/time.py rename to src/anemoi/datasets/build/gridded/sources/xarray_support/time.py index 7b1f60e58..1a875473f 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/time.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/time.py @@ -16,8 +16,8 @@ from anemoi.utils.dates import as_datetime -from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate -from anemoi.datasets.create.sources.xarray_support.variable import Variable +from anemoi.datasets.build.sources.xarray_support.coordinates import Coordinate +from anemoi.datasets.build.sources.xarray_support.variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/variable.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py similarity index 99% rename from src/anemoi/datasets/create/sources/xarray_support/variable.py rename to src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py index 13d6fa4e2..541e60d32 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py @@ -17,7 +17,7 @@ import numpy as np import xarray as xr -from anemoi.datasets.create.sources.xarray_support.field import XArrayField +from anemoi.datasets.build.sources.xarray_support.field import XArrayField LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_zarr.py b/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py similarity index 89% rename from src/anemoi/datasets/create/sources/xarray_zarr.py rename to src/anemoi/datasets/build/gridded/sources/xarray_zarr.py index 2f96ab207..5e9da7f44 100644 --- a/src/anemoi/datasets/create/sources/xarray_zarr.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py @@ -11,8 +11,8 @@ import earthkit.data as ekd -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.xarray import load_many +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/sources/zenodo.py b/src/anemoi/datasets/build/gridded/sources/zenodo.py similarity index 90% rename from src/anemoi/datasets/create/sources/zenodo.py rename to src/anemoi/datasets/build/gridded/sources/zenodo.py index e23b8fa47..774afd277 100644 --- a/src/anemoi/datasets/create/sources/zenodo.py +++ b/src/anemoi/datasets/build/gridded/sources/zenodo.py @@ -14,9 +14,9 @@ from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.sources.url import download_and_cache -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.patterns import 
iterate_patterns -from anemoi.datasets.create.sources.xarray import load_one +from anemoi.datasets.build.sources.legacy import legacy_source +from anemoi.datasets.build.sources.patterns import iterate_patterns +from anemoi.datasets.build.sources.xarray import load_one @legacy_source(__file__) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/build/input/__init__.py similarity index 89% rename from src/anemoi/datasets/create/input/__init__.py rename to src/anemoi/datasets/build/input/__init__.py index 2fe695781..4fd558242 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/build/input/__init__.py @@ -12,10 +12,10 @@ from typing import TYPE_CHECKING from typing import Any -from anemoi.datasets.create.input.context.field import FieldContext +from anemoi.datasets.build.input.context.field import FieldContext if TYPE_CHECKING: - from anemoi.datasets.create.input.action import Recipe + from anemoi.datasets.build.input.action import Recipe class InputBuilder: @@ -40,8 +40,8 @@ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> No @cached_property def action(self) -> "Recipe": """Returns the action object based on the configuration.""" - from anemoi.datasets.create.input.action import Recipe - from anemoi.datasets.create.input.action import action_factory + from anemoi.datasets.build.input.action import Recipe + from anemoi.datasets.build.input.action import action_factory sources = action_factory(self.data_sources, "data_sources") input = action_factory(self.config, "input") diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/build/input/action.py similarity index 98% rename from src/anemoi/datasets/create/input/action.py rename to src/anemoi/datasets/build/input/action.py index 7808ae717..8a5cab48c 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/build/input/action.py @@ -181,7 +181,7 @@ class DatasetSourceMixin: """Mixin class for sources defined in anemoi-datasets""" def create_object(self, context, config): - from anemoi.datasets.create.sources import create_source as create_datasets_source + from anemoi.datasets.build.sources import create_source as create_datasets_source return create_datasets_source(context, config) @@ -286,7 +286,7 @@ def make(key, config, *path): from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.transform.sources import source_registry as transform_source_registry - from anemoi.datasets.create.sources import source_registry as dataset_source_registry + from anemoi.datasets.build.sources import source_registry as dataset_source_registry # Register sources, local first for name in dataset_source_registry.registered: diff --git a/src/anemoi/datasets/create/input/context/__init__.py b/src/anemoi/datasets/build/input/context/__init__.py similarity index 96% rename from src/anemoi/datasets/create/input/context/__init__.py rename to src/anemoi/datasets/build/input/context/__init__.py index 89df7a727..e8572ba78 100644 --- a/src/anemoi/datasets/create/input/context/__init__.py +++ b/src/anemoi/datasets/build/input/context/__init__.py @@ -55,7 +55,7 @@ def resolve(self, config): return config def create_source(self, config: Any, *path) -> Any: - from anemoi.datasets.create.input.action import action_factory + from anemoi.datasets.build.input.action import action_factory if not isinstance(config, dict): # It is already a result (e.g. 
ekd.FieldList), loaded from ${a.b.c} diff --git a/src/anemoi/datasets/create/input/context/field.py b/src/anemoi/datasets/build/input/context/field.py similarity index 92% rename from src/anemoi/datasets/create/input/context/field.py rename to src/anemoi/datasets/build/input/context/field.py index e92a1ebbd..1a03a603a 100644 --- a/src/anemoi/datasets/create/input/context/field.py +++ b/src/anemoi/datasets/build/input/context/field.py @@ -12,8 +12,8 @@ from earthkit.data.core.order import build_remapping -from anemoi.datasets.create.input.context import Context -from anemoi.datasets.create.input.result.field import FieldResult +from anemoi.datasets.build.input.context import Context +from anemoi.datasets.build.input.result.field import FieldResult class FieldContext(Context): diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/build/input/data_sources.py similarity index 94% rename from src/anemoi/datasets/create/input/data_sources.py rename to src/anemoi/datasets/build/input/data_sources.py index 31956d602..ab5cd5d50 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/build/input/data_sources.py @@ -13,10 +13,10 @@ from earthkit.data import FieldList -from anemoi.datasets.create.input.action import Action -from anemoi.datasets.create.input.action import action_factory -from anemoi.datasets.create.input.misc import _tidy -from anemoi.datasets.create.input.result.field import Result +from anemoi.datasets.build.input.action import Action +from anemoi.datasets.build.input.action import action_factory +from anemoi.datasets.build.input.misc import _tidy +from anemoi.datasets.build.input.result.field import Result from anemoi.datasets.dates.groups import GroupOfDates LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/misc.py b/src/anemoi/datasets/build/input/misc.py similarity index 100% rename from src/anemoi/datasets/create/input/misc.py rename to src/anemoi/datasets/build/input/misc.py diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/build/input/repeated_dates.py similarity index 97% rename from src/anemoi/datasets/create/input/repeated_dates.py rename to src/anemoi/datasets/build/input/repeated_dates.py index 962b82717..925886c00 100644 --- a/src/anemoi/datasets/create/input/repeated_dates.py +++ b/src/anemoi/datasets/build/input/repeated_dates.py @@ -19,11 +19,11 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.create.input.action import Action -from anemoi.datasets.create.input.action import action_factory -from anemoi.datasets.create.input.join import JoinResult -from anemoi.datasets.create.input.result.field import Result -from anemoi.datasets.create.input.trace import trace_select +from anemoi.datasets.build.input.action import Action +from anemoi.datasets.build.input.action import action_factory +from anemoi.datasets.build.input.join import JoinResult +from anemoi.datasets.build.input.result.field import Result +from anemoi.datasets.build.input.trace import trace_select LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/result/__init__.py b/src/anemoi/datasets/build/input/result/__init__.py similarity index 100% rename from src/anemoi/datasets/create/input/result/__init__.py rename to src/anemoi/datasets/build/input/result/__init__.py diff --git a/src/anemoi/datasets/create/input/result/field.py b/src/anemoi/datasets/build/input/result/field.py similarity 
index 99% rename from src/anemoi/datasets/create/input/result/field.py rename to src/anemoi/datasets/build/input/result/field.py index dbcf8fbd4..a80fdb3e6 100644 --- a/src/anemoi/datasets/create/input/result/field.py +++ b/src/anemoi/datasets/build/input/result/field.py @@ -22,7 +22,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from anemoi.datasets.create.input.result import Result +from anemoi.datasets.build.input.result import Result LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/trace.py b/src/anemoi/datasets/build/input/trace.py similarity index 100% rename from src/anemoi/datasets/create/input/trace.py rename to src/anemoi/datasets/build/input/trace.py diff --git a/src/anemoi/datasets/create/patch.py b/src/anemoi/datasets/build/patch.py similarity index 100% rename from src/anemoi/datasets/create/patch.py rename to src/anemoi/datasets/build/patch.py diff --git a/src/anemoi/datasets/create/persistent.py b/src/anemoi/datasets/build/persistent.py similarity index 100% rename from src/anemoi/datasets/create/persistent.py rename to src/anemoi/datasets/build/persistent.py diff --git a/src/anemoi/datasets/create/size.py b/src/anemoi/datasets/build/size.py similarity index 100% rename from src/anemoi/datasets/create/size.py rename to src/anemoi/datasets/build/size.py diff --git a/src/anemoi/datasets/create/source.py b/src/anemoi/datasets/build/source.py similarity index 96% rename from src/anemoi/datasets/create/source.py rename to src/anemoi/datasets/build/source.py index f79b0e9dd..df4911690 100644 --- a/src/anemoi/datasets/create/source.py +++ b/src/anemoi/datasets/build/source.py @@ -12,7 +12,7 @@ import earthkit.data as ekd -from anemoi.datasets.create.typing import DateList +from anemoi.datasets.build.typing import DateList class Source(ABC): diff --git a/src/anemoi/datasets/create/statistics/__init__.py b/src/anemoi/datasets/build/statistics/__init__.py similarity index 99% rename from src/anemoi/datasets/create/statistics/__init__.py rename to src/anemoi/datasets/build/statistics/__init__.py index e8e71c45a..f7ece19bb 100644 --- a/src/anemoi/datasets/create/statistics/__init__.py +++ b/src/anemoi/datasets/build/statistics/__init__.py @@ -23,8 +23,8 @@ from anemoi.utils.provenance import gather_provenance_info from numpy.typing import NDArray -from anemoi.datasets.create.check import check_data_values -from anemoi.datasets.create.statistics.summary import Summary +from anemoi.datasets.build.check import check_data_values +from anemoi.datasets.build.statistics.summary import Summary LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/statistics/summary.py b/src/anemoi/datasets/build/statistics/summary.py similarity index 96% rename from src/anemoi/datasets/create/statistics/summary.py rename to src/anemoi/datasets/build/statistics/summary.py index 8b6c29eb0..59f3998b4 100644 --- a/src/anemoi/datasets/create/statistics/summary.py +++ b/src/anemoi/datasets/build/statistics/summary.py @@ -13,9 +13,9 @@ import numpy as np -from anemoi.datasets.create.check import StatisticsValueError -from anemoi.datasets.create.check import check_data_values -from anemoi.datasets.create.check import check_stats +from anemoi.datasets.build.check import StatisticsValueError +from anemoi.datasets.build.check import check_data_values +from anemoi.datasets.build.check import check_stats class Summary(dict): diff --git a/src/anemoi/datasets/create/testing.py b/src/anemoi/datasets/build/testing.py 
similarity index 100% rename from src/anemoi/datasets/create/testing.py rename to src/anemoi/datasets/build/testing.py diff --git a/src/anemoi/datasets/create/typing.py b/src/anemoi/datasets/build/typing.py similarity index 100% rename from src/anemoi/datasets/create/typing.py rename to src/anemoi/datasets/build/typing.py diff --git a/src/anemoi/datasets/create/utils.py b/src/anemoi/datasets/build/utils.py similarity index 100% rename from src/anemoi/datasets/create/utils.py rename to src/anemoi/datasets/build/utils.py diff --git a/src/anemoi/datasets/create/writer.py b/src/anemoi/datasets/build/writer.py similarity index 100% rename from src/anemoi/datasets/create/writer.py rename to src/anemoi/datasets/build/writer.py diff --git a/src/anemoi/datasets/create/zarr.py b/src/anemoi/datasets/build/zarr.py similarity index 100% rename from src/anemoi/datasets/create/zarr.py rename to src/anemoi/datasets/build/zarr.py diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index 4202ed09f..4ac355515 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -13,8 +13,8 @@ import yaml +from anemoi.datasets.build.check import DatasetName from anemoi.datasets.commands import Command -from anemoi.datasets.create.check import DatasetName LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 6601d0ee4..30df82783 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -45,7 +45,7 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") - from anemoi.datasets.create import creator_factory + from anemoi.datasets.build import creator_factory options = {k: v for k, v in options.items() if v is not None} diff --git a/src/anemoi/datasets/commands/grib-index.py b/src/anemoi/datasets/commands/grib-index.py index b5cc910d2..072099bdd 100644 --- a/src/anemoi/datasets/commands/grib-index.py +++ b/src/anemoi/datasets/commands/grib-index.py @@ -83,7 +83,7 @@ def match(path: str) -> bool: """ return fnmatch.fnmatch(os.path.basename(path), args.match) - from anemoi.datasets.create.sources.grib_index import GribIndex + from anemoi.datasets.build.sources.grib_index import GribIndex index = GribIndex( args.index, diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 52b7e689d..59490bd33 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -28,8 +28,8 @@ from anemoi.datasets import open_dataset from anemoi.datasets.commands import Command -from anemoi.datasets.data.stores import open_zarr -from anemoi.datasets.data.stores import zarr_lookup +from anemoi.datasets.use.stores import open_zarr +from anemoi.datasets.use.stores import zarr_lookup LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index e708d8b50..813ca47b8 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,10 +15,10 @@ import yaml +from anemoi.datasets.build import validate_config from anemoi.datasets.commands import Command from anemoi.datasets.commands.recipe.format import format_recipe from anemoi.datasets.commands.recipe.migrate import migrate_recipe -from anemoi.datasets.create import validate_config LOG = 
logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 03da61fbc..ffaa3ddd1 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -17,7 +17,7 @@ from glom import delete from glom import glom -from anemoi.datasets.create import validate_config +from anemoi.datasets.build import validate_config from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/__init__.py b/src/anemoi/datasets/use/gridded/__init__.py similarity index 91% rename from src/anemoi/datasets/data/__init__.py rename to src/anemoi/datasets/use/gridded/__init__.py index fc2b0839b..f6f8f5a3d 100644 --- a/src/anemoi/datasets/data/__init__.py +++ b/src/anemoi/datasets/use/gridded/__init__.py @@ -15,13 +15,13 @@ # from .dataset import FullIndex # from .dataset import Shape # from .dataset import TupleIndex -from anemoi.datasets.data.misc import _open_dataset -from anemoi.datasets.data.misc import _save_dataset -from anemoi.datasets.data.misc import add_dataset_path -from anemoi.datasets.data.misc import add_named_dataset +from anemoi.datasets.use.misc import _open_dataset +from anemoi.datasets.use.misc import _save_dataset +from anemoi.datasets.use.misc import add_dataset_path +from anemoi.datasets.use.misc import add_named_dataset if TYPE_CHECKING: - from anemoi.datasets.data.dataset import Dataset + from anemoi.datasets.use.dataset import Dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/complement.py b/src/anemoi/datasets/use/gridded/complement.py similarity index 95% rename from src/anemoi/datasets/data/complement.py rename to src/anemoi/datasets/use/gridded/complement.py index 87c65a5b4..df9b5cc86 100644 --- a/src/anemoi/datasets/data/complement.py +++ b/src/anemoi/datasets/use/gridded/complement.py @@ -16,18 +16,18 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.forwards import Combined -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple -from anemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.data.misc import _open_dataset from anemoi.datasets.grids import nearest_grid_points +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.forwards import Combined +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.misc import _auto_adjust +from anemoi.datasets.use.misc import _open_dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/concat.py b/src/anemoi/datasets/use/gridded/concat.py similarity index 91% rename from src/anemoi/datasets/data/concat.py rename to src/anemoi/datasets/use/gridded/concat.py index fcdc768fc..9b9968468 100644 --- a/src/anemoi/datasets/data/concat.py 
+++ b/src/anemoi/datasets/use/gridded/concat.py @@ -16,20 +16,20 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Combined -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import length_to_slices -from anemoi.datasets.data.indexing import update_tuple -from anemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.data.misc import _open +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Combined +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import length_to_slices +from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.misc import _auto_adjust +from anemoi.datasets.use.misc import _open LOG = logging.getLogger(__name__) @@ -229,7 +229,7 @@ def check_dataset_compatibility(cls, datasets: list[Any], fill_missing_gaps: boo s = ranges[i + 1] if r[1] + frequency != s[0]: if fill_missing_gaps: - from anemoi.datasets.data.missing import MissingDataset + from anemoi.datasets.use.missing import MissingDataset result.append(MissingDataset(datasets[i], r[1] + frequency, s[0] - frequency)) else: diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/use/gridded/dataset.py similarity index 95% rename from src/anemoi/datasets/data/dataset.py rename to src/anemoi/datasets/use/gridded/dataset.py index 021e385a2..cbfdfd2b6 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/use/gridded/dataset.py @@ -34,8 +34,8 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import Source +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import Source if TYPE_CHECKING: import matplotlib @@ -165,7 +165,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": # This one must be first if "fill_missing_dates" in kwargs: - from anemoi.datasets.data.fill_missing import fill_missing_dates_factory + from anemoi.datasets.use.fill_missing import fill_missing_dates_factory fill_missing_dates = kwargs.pop("fill_missing_dates") ds = fill_missing_dates_factory(self, fill_missing_dates, kwargs) @@ -179,7 +179,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": if padding: if padding != "empty": raise ValueError(f"Only 'empty' padding is supported, got {padding=}") - from anemoi.datasets.data.padded import Padded + from anemoi.datasets.use.padded import Padded frequency = kwargs.pop("frequency", self.frequency) return ( @@ -188,14 +188,14 @@ def __subset(self, **kwargs: Any) -> "Dataset": .mutate() ) - from 
anemoi.datasets.data.subset import Subset + from anemoi.datasets.use.subset import Subset return ( Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs).mutate() ) if "frequency" in kwargs: - from anemoi.datasets.data.subset import Subset + from anemoi.datasets.use.subset import Subset if "interpolate_frequency" in kwargs: raise ValueError("Cannot use both `frequency` and `interpolate_frequency`") @@ -208,38 +208,38 @@ def __subset(self, **kwargs: Any) -> "Dataset": ) if "select" in kwargs: - from anemoi.datasets.data.select import Select + from anemoi.datasets.use.select import Select select = kwargs.pop("select") return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs).mutate() if "drop" in kwargs: - from anemoi.datasets.data.select import Select + from anemoi.datasets.use.select import Select drop = kwargs.pop("drop") return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs).mutate() if "reorder" in kwargs: - from anemoi.datasets.data.select import Select + from anemoi.datasets.use.select import Select reorder = kwargs.pop("reorder") return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate() if "rename" in kwargs: - from anemoi.datasets.data.select import Rename + from anemoi.datasets.use.select import Rename rename = kwargs.pop("rename") return Rename(self, rename)._subset(**kwargs).mutate() if "rescale" in kwargs: - from anemoi.datasets.data.rescale import Rescale + from anemoi.datasets.use.rescale import Rescale rescale = kwargs.pop("rescale") return Rescale(self, rescale)._subset(**kwargs).mutate() if "statistics" in kwargs: - from anemoi.datasets.data import open_dataset - from anemoi.datasets.data.statistics import Statistics + from anemoi.datasets.use import open_dataset + from anemoi.datasets.use.statistics import Statistics statistics = kwargs.pop("statistics") @@ -247,26 +247,26 @@ def __subset(self, **kwargs: Any) -> "Dataset": # Note: trim_edge should go before thinning if "trim_edge" in kwargs: - from anemoi.datasets.data.masked import TrimEdge + from anemoi.datasets.use.masked import TrimEdge edge = kwargs.pop("trim_edge") return TrimEdge(self, edge)._subset(**kwargs).mutate() if "thinning" in kwargs: - from anemoi.datasets.data.masked import Thinning + from anemoi.datasets.use.masked import Thinning thinning = kwargs.pop("thinning") method = kwargs.pop("method", "every-nth") return Thinning(self, thinning, method)._subset(**kwargs).mutate() if "area" in kwargs: - from anemoi.datasets.data.masked import Cropping + from anemoi.datasets.use.masked import Cropping bbox = kwargs.pop("area") return Cropping(self, bbox)._subset(**kwargs).mutate() if "number" in kwargs or "numbers" in kwargs or "member" in kwargs or "members" in kwargs: - from anemoi.datasets.data.ensemble import Number + from anemoi.datasets.use.ensemble import Number members = {} for key in ["number", "numbers", "member", "members"]: @@ -276,13 +276,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": return Number(self, **members)._subset(**kwargs).mutate() if "set_missing_dates" in kwargs: - from anemoi.datasets.data.missing import MissingDates + from anemoi.datasets.use.missing import MissingDates set_missing_dates = kwargs.pop("set_missing_dates") return MissingDates(self, set_missing_dates)._subset(**kwargs).mutate() if "skip_missing_dates" in kwargs: - from anemoi.datasets.data.missing import SkipMissingDates + from anemoi.datasets.use.missing import 
SkipMissingDates if "expected_access" not in kwargs: raise ValueError("`expected_access` is required with `skip_missing_dates`") @@ -294,13 +294,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate() if "interpolate_frequency" in kwargs: - from anemoi.datasets.data.interpolate import InterpolateFrequency + from anemoi.datasets.use.interpolate import InterpolateFrequency interpolate_frequency = kwargs.pop("interpolate_frequency") return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate() if "interpolate_variables" in kwargs: - from anemoi.datasets.data.interpolate import InterpolateNearest + from anemoi.datasets.use.interpolate import InterpolateNearest interpolate_variables = kwargs.pop("interpolate_variables") max_distance = kwargs.pop("max_distance", None) @@ -308,7 +308,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": # Keep last if "shuffle" in kwargs: - from anemoi.datasets.data.subset import Subset + from anemoi.datasets.use.subset import Subset shuffle = kwargs.pop("shuffle") @@ -372,8 +372,8 @@ def _dates_to_indices( list of int The list of indices. """ - from anemoi.datasets.data.misc import as_first_date - from anemoi.datasets.data.misc import as_last_date + from anemoi.datasets.use.misc import as_first_date + from anemoi.datasets.use.misc import as_last_date # TODO: optimize diff --git a/src/anemoi/datasets/data/debug.css b/src/anemoi/datasets/use/gridded/debug.css similarity index 100% rename from src/anemoi/datasets/data/debug.css rename to src/anemoi/datasets/use/gridded/debug.css diff --git a/src/anemoi/datasets/data/debug.py b/src/anemoi/datasets/use/gridded/debug.py similarity index 99% rename from src/anemoi/datasets/data/debug.py rename to src/anemoi/datasets/use/gridded/debug.py index ca80ace41..84c6f0b64 100644 --- a/src/anemoi/datasets/data/debug.py +++ b/src/anemoi/datasets/use/gridded/debug.py @@ -20,7 +20,7 @@ from numpy.typing import NDArray if TYPE_CHECKING: - from anemoi.datasets.data.dataset import Dataset + from anemoi.datasets.use.dataset import Dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/ensemble.py b/src/anemoi/datasets/use/gridded/ensemble.py similarity index 89% rename from src/anemoi/datasets/data/ensemble.py rename to src/anemoi/datasets/use/gridded/ensemble.py index a6c59c812..1cf4d885b 100644 --- a/src/anemoi/datasets/data/ensemble.py +++ b/src/anemoi/datasets/use/gridded/ensemble.py @@ -14,17 +14,17 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.forwards import Forwards -from anemoi.datasets.data.forwards import GivenAxis -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple -from anemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.data.misc import _open +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.forwards import GivenAxis +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from 
anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.misc import _auto_adjust +from anemoi.datasets.use.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/fill_missing.py b/src/anemoi/datasets/use/gridded/fill_missing.py similarity index 93% rename from src/anemoi/datasets/data/fill_missing.py rename to src/anemoi/datasets/use/gridded/fill_missing.py index 0cc1b0ee2..649a7e08b 100644 --- a/src/anemoi/datasets/data/fill_missing.py +++ b/src/anemoi/datasets/use/gridded/fill_missing.py @@ -14,17 +14,17 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data import MissingDateError -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Forwards -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.use import MissingDateError +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/use/gridded/forwards.py similarity index 97% rename from src/anemoi/datasets/data/forwards.py rename to src/anemoi/datasets/use/gridded/forwards.py index 463aebf3e..058c66e9c 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/use/gridded/forwards.py @@ -18,16 +18,16 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import length_to_slices -from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import length_to_slices +from anemoi.datasets.use.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/grids.py 
b/src/anemoi/datasets/use/gridded/grids.py similarity index 96% rename from src/anemoi/datasets/data/grids.py rename to src/anemoi/datasets/use/gridded/grids.py index fee2c792e..423a57deb 100644 --- a/src/anemoi/datasets/data/grids.py +++ b/src/anemoi/datasets/use/gridded/grids.py @@ -16,21 +16,21 @@ from numpy.typing import NDArray from scipy.spatial import cKDTree -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Combined -from anemoi.datasets.data.forwards import GivenAxis -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import length_to_slices -from anemoi.datasets.data.indexing import update_tuple -from anemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.data.misc import _open +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Combined +from anemoi.datasets.use.forwards import GivenAxis +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import length_to_slices +from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.misc import _auto_adjust +from anemoi.datasets.use.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/indexing.py b/src/anemoi/datasets/use/gridded/indexing.py similarity index 98% rename from src/anemoi/datasets/data/indexing.py rename to src/anemoi/datasets/use/gridded/indexing.py index 7c4bb4be3..f152e907f 100644 --- a/src/anemoi/datasets/data/indexing.py +++ b/src/anemoi/datasets/use/gridded/indexing.py @@ -15,9 +15,9 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex def _tuple_with_slices(t: TupleIndex, shape: Shape) -> tuple[TupleIndex, tuple[int, ...]]: diff --git a/src/anemoi/datasets/data/interpolate.py b/src/anemoi/datasets/use/gridded/interpolate.py similarity index 93% rename from src/anemoi/datasets/data/interpolate.py rename to src/anemoi/datasets/use/gridded/interpolate.py index 1f64d21a9..5d8f70bf3 100644 --- a/src/anemoi/datasets/data/interpolate.py +++ b/src/anemoi/datasets/use/gridded/interpolate.py @@ -17,17 +17,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node 
-from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Forwards -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/use/gridded/join.py similarity index 91% rename from src/anemoi/datasets/data/join.py rename to src/anemoi/datasets/use/gridded/join.py index f7f1caa03..b852ab19f 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/use/gridded/join.py @@ -16,20 +16,20 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import Source -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Combined -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple -from anemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.data.misc import _open +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import Source +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Combined +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.misc import _auto_adjust +from anemoi.datasets.use.misc import _open LOG = logging.getLogger(__name__) @@ -173,7 +173,7 @@ def _overlay(self) -> Dataset: if not ok: LOG.warning("Dataset %r completely overridden.", d) - from anemoi.datasets.data.select import Select + from anemoi.datasets.use.select import Select return Select(self, indices, {"overlay": variables}) diff --git a/src/anemoi/datasets/data/masked.py b/src/anemoi/datasets/use/gridded/masked.py similarity index 93% rename from src/anemoi/datasets/data/masked.py rename to src/anemoi/datasets/use/gridded/masked.py index b4982d696..f64bb2f59 100644 --- a/src/anemoi/datasets/data/masked.py +++ b/src/anemoi/datasets/use/gridded/masked.py @@ -15,18 +15,18 @@ import numpy as np from 
numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Forwards -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple from anemoi.datasets.grids import cropping_mask +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -214,7 +214,7 @@ def __init__(self, forward: Dataset, area: Dataset | tuple[float, float, float, area : Union[Dataset, Tuple[float, float, float, float]] The cropping area. """ - from anemoi.datasets.data import open_dataset + from anemoi.datasets.use import open_dataset area = area if isinstance(area, (list, tuple)) else open_dataset(area) diff --git a/src/anemoi/datasets/data/merge.py b/src/anemoi/datasets/use/gridded/merge.py similarity index 92% rename from src/anemoi/datasets/data/merge.py rename to src/anemoi/datasets/use/gridded/merge.py index b974a6afb..a2f3a83bd 100644 --- a/src/anemoi/datasets/data/merge.py +++ b/src/anemoi/datasets/use/gridded/merge.py @@ -16,19 +16,19 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data import MissingDateError -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Combined -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple -from anemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.data.misc import _open +from anemoi.datasets.use import MissingDateError +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Combined +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.misc import _auto_adjust +from anemoi.datasets.use.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/misc.py 
b/src/anemoi/datasets/use/gridded/misc.py similarity index 95% rename from src/anemoi/datasets/data/misc.py rename to src/anemoi/datasets/use/gridded/misc.py index 53ae8bb07..4709265be 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/use/gridded/misc.py @@ -23,7 +23,7 @@ from numpy.typing import NDArray if TYPE_CHECKING: - from anemoi.datasets.data.dataset import Dataset + from anemoi.datasets.use.dataset import Dataset LOG = logging.getLogger(__name__) @@ -323,11 +323,11 @@ def _concat_or_join(datasets: list["Dataset"], kwargs: dict[str, Any]) -> tuple[ ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets] if len(set(ranges)) == 1: - from anemoi.datasets.data.join import Join + from anemoi.datasets.use.join import Join return Join(datasets)._overlay(), kwargs - from anemoi.datasets.data.concat import Concat + from anemoi.datasets.use.concat import Concat Concat.check_dataset_compatibility(datasets) @@ -347,9 +347,9 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " Dataset The opened dataset. """ - from anemoi.datasets.data.dataset import Dataset - from anemoi.datasets.data.stores import Zarr - from anemoi.datasets.data.stores import zarr_lookup + from anemoi.datasets.use.dataset import Dataset + from anemoi.datasets.use.stores import Zarr + from anemoi.datasets.use.stores import zarr_lookup if isinstance(a, str) and len(a.split(".")) in [2, 3]: @@ -359,7 +359,7 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " if "backend" not in metadata: raise ValueError(f"Metadata for {a} does not contain 'backend' key") - from anemoi.datasets.data.records import open_records_dataset + from anemoi.datasets.use.records import open_records_dataset return open_records_dataset(a, backend=metadata["backend"]) @@ -501,7 +501,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": sets.append(_open(a)) if "observations" in kwargs: - from anemoi.datasets.data.observations import observations_factory + from anemoi.datasets.use.observations import observations_factory assert not sets, sets @@ -509,70 +509,70 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": if "xy" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.data.xy import xy_factory + from anemoi.datasets.use.gridded.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "x" in kwargs and "y" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.data.xy import xy_factory + from anemoi.datasets.use.gridded.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "zip" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.data.xy import zip_factory + from anemoi.datasets.use.gridded.xy import zip_factory assert not sets, sets return zip_factory(args, kwargs).mutate() if "chain" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.data.unchecked import chain_factory + from anemoi.datasets.use.unchecked import chain_factory assert not sets, sets return chain_factory(args, kwargs).mutate() if "join" in kwargs: - from anemoi.datasets.data.join import join_factory + from anemoi.datasets.use.join import join_factory assert not sets, sets return join_factory(args, kwargs).mutate() if "concat" in kwargs: - from anemoi.datasets.data.concat import concat_factory + from anemoi.datasets.use.concat import concat_factory assert not sets, sets return concat_factory(args, kwargs).mutate() if 
"merge" in kwargs: - from anemoi.datasets.data.merge import merge_factory + from anemoi.datasets.use.merge import merge_factory assert not sets, sets return merge_factory(args, kwargs).mutate() if "ensemble" in kwargs: - from anemoi.datasets.data.ensemble import ensemble_factory + from anemoi.datasets.use.ensemble import ensemble_factory assert not sets, sets return ensemble_factory(args, kwargs).mutate() if "grids" in kwargs: - from anemoi.datasets.data.grids import grids_factory + from anemoi.datasets.use.grids import grids_factory assert not sets, sets return grids_factory(args, kwargs).mutate() if "cutout" in kwargs: - from anemoi.datasets.data.grids import cutout_factory + from anemoi.datasets.use.grids import cutout_factory assert not sets, sets return cutout_factory(args, kwargs).mutate() if "complement" in kwargs: - from anemoi.datasets.data.complement import complement_factory + from anemoi.datasets.use.complement import complement_factory assert not sets, sets return complement_factory(args, kwargs).mutate() diff --git a/src/anemoi/datasets/data/missing.py b/src/anemoi/datasets/use/gridded/missing.py similarity index 95% rename from src/anemoi/datasets/data/missing.py rename to src/anemoi/datasets/use/gridded/missing.py index 5a6e8a5f8..32a0d5c69 100644 --- a/src/anemoi/datasets/data/missing.py +++ b/src/anemoi/datasets/use/gridded/missing.py @@ -16,16 +16,16 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.create.utils import to_datetime -from anemoi.datasets.data import MissingDateError -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Forwards -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.build.utils import to_datetime +from anemoi.datasets.use import MissingDateError +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/padded.py b/src/anemoi/datasets/use/gridded/padded.py similarity index 93% rename from src/anemoi/datasets/data/padded.py rename to src/anemoi/datasets/use/gridded/padded.py index d0bebb6fc..53c45071d 100644 --- a/src/anemoi/datasets/data/padded.py +++ b/src/anemoi/datasets/use/gridded/padded.py @@ -17,16 +17,16 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Forwards -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.misc import as_first_date -from anemoi.datasets.data.misc import as_last_date +from anemoi.datasets.use.dataset import Dataset 
+from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.misc import as_first_date +from anemoi.datasets.use.misc import as_last_date LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/rescale.py b/src/anemoi/datasets/use/gridded/rescale.py similarity index 92% rename from src/anemoi/datasets/data/rescale.py rename to src/anemoi/datasets/use/gridded/rescale.py index f5d8734fe..630199efc 100644 --- a/src/anemoi/datasets/data/rescale.py +++ b/src/anemoi/datasets/use/gridded/rescale.py @@ -16,16 +16,16 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Forwards -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/select.py b/src/anemoi/datasets/use/gridded/select.py similarity index 92% rename from src/anemoi/datasets/data/select.py rename to src/anemoi/datasets/use/gridded/select.py index e27b94f76..7a57639ce 100644 --- a/src/anemoi/datasets/data/select.py +++ b/src/anemoi/datasets/use/gridded/select.py @@ -15,18 +15,18 @@ from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import Source -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Forwards -from anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import Source +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.indexing import 
apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/statistics.py b/src/anemoi/datasets/use/gridded/statistics.py similarity index 94% rename from src/anemoi/datasets/data/statistics.py rename to src/anemoi/datasets/use/gridded/statistics.py index 2bb26b3d6..e6439ecec 100644 --- a/src/anemoi/datasets/data/statistics.py +++ b/src/anemoi/datasets/use/gridded/statistics.py @@ -15,10 +15,10 @@ from numpy.typing import NDArray -from anemoi.datasets.data import open_dataset -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.forwards import Forwards +from anemoi.datasets.use import open_dataset +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.forwards import Forwards LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/use/gridded/stores.py similarity index 96% rename from src/anemoi/datasets/data/stores.py rename to src/anemoi/datasets/use/gridded/stores.py index 9224c22d3..4514f06c6 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/use/gridded/stores.py @@ -22,17 +22,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.data import MissingDateError -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import DEBUG_ZARR_LOADING -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import Source -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.misc import load_config +from anemoi.datasets.use import MissingDateError +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import DEBUG_ZARR_LOADING +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import Source +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.misc import load_config LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/use/gridded/subset.py similarity index 90% rename from src/anemoi/datasets/data/subset.py rename to src/anemoi/datasets/use/gridded/subset.py index 22eef70da..bf65bb4a2 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/use/gridded/subset.py @@ -19,19 +19,19 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.debug import Source -from anemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.data.forwards import Forwards -from 
anemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.data.indexing import make_slice_or_index_from_list_or_tuple -from anemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.debug import Source +from anemoi.datasets.use.debug import debug_indexing +from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.indexing import expand_list_indexing +from anemoi.datasets.use.indexing import index_to_slices +from anemoi.datasets.use.indexing import make_slice_or_index_from_list_or_tuple +from anemoi.datasets.use.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -61,7 +61,7 @@ def _start(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the start date. """ - from anemoi.datasets.data.misc import as_first_date + from anemoi.datasets.use.misc import as_first_date c = as_first_date(a, dates) d = as_first_date(b, dates) @@ -82,7 +82,7 @@ def _end(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the end date. """ - from anemoi.datasets.data.misc import as_last_date + from anemoi.datasets.use.misc import as_last_date c = as_last_date(a, dates) d = as_last_date(b, dates) diff --git a/src/anemoi/datasets/data/unchecked.py b/src/anemoi/datasets/use/gridded/unchecked.py similarity index 94% rename from src/anemoi/datasets/data/unchecked.py rename to src/anemoi/datasets/use/gridded/unchecked.py index 478c8c1eb..5f3a9fb4a 100644 --- a/src/anemoi/datasets/data/unchecked.py +++ b/src/anemoi/datasets/use/gridded/unchecked.py @@ -18,14 +18,14 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.data.concat import ConcatMixin -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.dataset import Shape -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.forwards import Combined -from anemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.data.misc import _open +from anemoi.datasets.use.concat import ConcatMixin +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.dataset import Shape +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.forwards import Combined +from anemoi.datasets.use.misc import _auto_adjust +from anemoi.datasets.use.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/xy.py b/src/anemoi/datasets/use/gridded/xy.py similarity index 96% rename from src/anemoi/datasets/data/xy.py rename to src/anemoi/datasets/use/gridded/xy.py index e181dc9aa..7d65201b3 100644 --- a/src/anemoi/datasets/data/xy.py +++ b/src/anemoi/datasets/use/gridded/xy.py @@ -12,12 +12,12 @@ from functools import cached_property from typing import Any -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.data.debug import Node -from anemoi.datasets.data.forwards import Combined -from anemoi.datasets.data.misc import _auto_adjust 
-from anemoi.datasets.data.misc import _open +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.dataset import FullIndex +from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.forwards import Combined +from anemoi.datasets.use.misc import _auto_adjust +from anemoi.datasets.use.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/observations/__init__.py b/src/anemoi/datasets/use/observations/__init__.py similarity index 97% rename from src/anemoi/datasets/data/observations/__init__.py rename to src/anemoi/datasets/use/observations/__init__.py index 23413e05d..58e7fa822 100644 --- a/src/anemoi/datasets/data/observations/__init__.py +++ b/src/anemoi/datasets/use/observations/__init__.py @@ -14,8 +14,8 @@ import numpy as np from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.data.dataset import Dataset -from anemoi.datasets.data.debug import Node +from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.use.debug import Node LOG = logging.getLogger(__name__) @@ -138,7 +138,7 @@ def __init__(self, dataset, frequency=None, window=None): if isinstance(dataset, zarr.hierarchy.Group): dataset = dataset._store.path - from anemoi.datasets.data.stores import zarr_lookup + from anemoi.datasets.use.stores import zarr_lookup dataset = zarr_lookup(dataset) self.path = dataset @@ -176,7 +176,7 @@ def __init__(self, dataset, frequency=None, window=None): # last_window_end must be the end of the time window of the last item last_window_end = int(end.strftime("%Y%m%d%H%M%S")) - from anemoi.datasets.data.observations.legacy_obs_dataset import ObsDataset + from anemoi.datasets.use.observations.legacy_obs_dataset import ObsDataset args = [self.path, first_window_begin, last_window_end] kwargs = dict( diff --git a/src/anemoi/datasets/data/observations/legacy_obs_dataset.py b/src/anemoi/datasets/use/observations/legacy_obs_dataset.py similarity index 100% rename from src/anemoi/datasets/data/observations/legacy_obs_dataset.py rename to src/anemoi/datasets/use/observations/legacy_obs_dataset.py diff --git a/src/anemoi/datasets/data/observations/multi.py b/src/anemoi/datasets/use/observations/multi.py similarity index 97% rename from src/anemoi/datasets/data/observations/multi.py rename to src/anemoi/datasets/use/observations/multi.py index af5c02e71..a6b6be176 100644 --- a/src/anemoi/datasets/data/observations/multi.py +++ b/src/anemoi/datasets/use/observations/multi.py @@ -10,7 +10,7 @@ import logging import os -from anemoi.datasets.data import open_dataset +from anemoi.datasets.use import open_dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/use/records/__init__.py similarity index 98% rename from src/anemoi/datasets/data/records/__init__.py rename to src/anemoi/datasets/use/records/__init__.py index f569a4105..efd368606 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/use/records/__init__.py @@ -16,7 +16,7 @@ import numpy as np from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.data.records.backends import backend_factory +from anemoi.datasets.use.records.backends import backend_factory LOG = logging.getLogger(__name__) @@ -91,8 +91,8 @@ def _subset(self, **kwargs): if start is not None or end is not None: def _dates_to_indices(start, end): - from anemoi.datasets.data.misc import as_first_date - from anemoi.datasets.data.misc import as_last_date + from 
anemoi.datasets.use.misc import as_first_date + from anemoi.datasets.use.misc import as_last_date start = self.dates[0] if start is None else as_first_date(start, self.dates) end = self.dates[-1] if end is None else as_last_date(end, self.dates) diff --git a/src/anemoi/datasets/data/records/backends/__init__.py b/src/anemoi/datasets/use/records/backends/__init__.py similarity index 97% rename from src/anemoi/datasets/data/records/backends/__init__.py rename to src/anemoi/datasets/use/records/backends/__init__.py index 817d3fc88..f09c32e4d 100644 --- a/src/anemoi/datasets/data/records/backends/__init__.py +++ b/src/anemoi/datasets/use/records/backends/__init__.py @@ -100,7 +100,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.create import json_tidy + from anemoi.datasets.build import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: @@ -128,7 +128,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.create import json_tidy + from anemoi.datasets.build import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: diff --git a/src/anemoi/datasets/validate.py b/src/anemoi/datasets/validate.py index 9e1c6c891..a1e168116 100644 --- a/src/anemoi/datasets/validate.py +++ b/src/anemoi/datasets/validate.py @@ -14,8 +14,8 @@ import numpy as np -from anemoi.datasets.data.dataset import Dataset from anemoi.datasets.testing import default_test_indexing +from anemoi.datasets.use.dataset import Dataset LOG = logging.getLogger(__name__) # List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1 diff --git a/tests/create/utils/compare.py b/tests/create/utils/compare.py index b0648d252..6da96ae95 100644 --- a/tests/create/utils/compare.py +++ b/tests/create/utils/compare.py @@ -12,7 +12,7 @@ import numpy as np from anemoi.datasets import open_dataset -from anemoi.datasets.data.stores import open_zarr +from anemoi.datasets.use.stores import open_zarr class Comparer: diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index a57022bab..a9fd74575 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -12,7 +12,7 @@ import yaml -from anemoi.datasets.create import creator_factory +from anemoi.datasets.build import creator_factory class TestingContext: diff --git a/tests/test_chunks.py b/tests/test_chunks.py index 18337b689..b2aa3c789 100644 --- a/tests/test_chunks.py +++ b/tests/test_chunks.py @@ -11,7 +11,7 @@ import pytest -from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.build.chunks import ChunkFilter def test_chunk_filter(): diff --git a/tests/test_data.py b/tests/test_data.py index d29e14059..c19f54b6c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -25,20 +25,20 @@ from anemoi.datasets import open_dataset from anemoi.datasets.commands.inspect import InspectZarr from anemoi.datasets.commands.inspect import NoVersion -from anemoi.datasets.data import save_dataset -from anemoi.datasets.data.concat import Concat -from anemoi.datasets.data.ensemble import Ensemble -from anemoi.datasets.data.grids import GridsBase -from anemoi.datasets.data.join import Join -from anemoi.datasets.data.misc import as_first_date -from anemoi.datasets.data.misc import as_last_date -from anemoi.datasets.data.padded import Padded -from 
anemoi.datasets.data.select import Rename -from anemoi.datasets.data.select import Select -from anemoi.datasets.data.statistics import Statistics -from anemoi.datasets.data.stores import Zarr -from anemoi.datasets.data.subset import Subset from anemoi.datasets.testing import default_test_indexing +from anemoi.datasets.use import save_dataset +from anemoi.datasets.use.concat import Concat +from anemoi.datasets.use.ensemble import Ensemble +from anemoi.datasets.use.grids import GridsBase +from anemoi.datasets.use.join import Join +from anemoi.datasets.use.misc import as_first_date +from anemoi.datasets.use.misc import as_last_date +from anemoi.datasets.use.padded import Padded +from anemoi.datasets.use.select import Rename +from anemoi.datasets.use.select import Select +from anemoi.datasets.use.statistics import Statistics +from anemoi.datasets.use.stores import Zarr +from anemoi.datasets.use.subset import Subset VALUES = 10 diff --git a/tests/test_dates.py b/tests/test_dates.py index 7d7613506..d169498bb 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -14,7 +14,7 @@ import numpy as np import pytest -from anemoi.datasets.create.statistics import default_statistics_dates +from anemoi.datasets.build.statistics import default_statistics_dates _ = datetime.datetime diff --git a/tests/test_indexing.py b/tests/test_indexing.py index bc53462ac..cd5c6f25d 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -10,7 +10,7 @@ import numpy as np -from anemoi.datasets.data.indexing import length_to_slices +from anemoi.datasets.use.indexing import length_to_slices def test_length_to_slices() -> None: diff --git a/tests/test_records.py b/tests/test_records.py index 896081f9a..c96041cb7 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -11,9 +11,9 @@ import numpy as np import pytest -from anemoi.datasets.data import open_dataset -from anemoi.datasets.data.records import Record -from anemoi.datasets.data.records import Tabular +from anemoi.datasets.use import open_dataset +from anemoi.datasets.use.records import Record +from anemoi.datasets.use.records import Tabular def check_numpy(x, y): diff --git a/tests/xarray/test_flavour.py b/tests/xarray/test_flavour.py index 7b2bb33e5..cdf093e5f 100644 --- a/tests/xarray/test_flavour.py +++ b/tests/xarray/test_flavour.py @@ -11,18 +11,18 @@ import pytest import xarray as xr -from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate -from anemoi.datasets.create.sources.xarray_support.flavour import DefaultCoordinateGuesser +from anemoi.datasets.build.sources.xarray_support.coordinates import DateCoordinate +from 
anemoi.datasets.build.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.build.sources.xarray_support.flavour import DefaultCoordinateGuesser def create_ds(var_name, standard_name, long_name, units, coord_length=5): diff --git a/tests/xarray/test_netcdf.py b/tests/xarray/test_netcdf.py index f25d8c4d7..1619a47ac 100644 --- a/tests/xarray/test_netcdf.py +++ b/tests/xarray/test_netcdf.py @@ -12,7 +12,7 @@ import xarray as xr from multiurl import download -from anemoi.datasets.create.sources.xarray import XarrayFieldList +from anemoi.datasets.build.sources.xarray import XarrayFieldList URLS = { "https://get.ecmwf.int/repository/test-data/earthkit-data/examples/efas.nc": dict(length=3), diff --git a/tests/xarray/test_opendap.py b/tests/xarray/test_opendap.py index fb855ca94..b8f4eac9e 100644 --- a/tests/xarray/test_opendap.py +++ b/tests/xarray/test_opendap.py @@ -12,7 +12,7 @@ import xarray as xr from anemoi.utils.testing import skip_if_offline -from anemoi.datasets.create.sources.xarray import XarrayFieldList +from anemoi.datasets.build.sources.xarray import XarrayFieldList from anemoi.datasets.testing import assert_field_list diff --git a/tests/xarray/test_variable.py b/tests/xarray/test_variable.py index ff43da389..afb82ecbf 100644 --- a/tests/xarray/test_variable.py +++ b/tests/xarray/test_variable.py @@ -13,14 +13,14 @@ import pytest import xarray as xr -from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.create.sources.xarray_support.time import ForecastFromValidTimeAndStep -from anemoi.datasets.create.sources.xarray_support.variable import Variable +from anemoi.datasets.build.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.build.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.build.sources.xarray_support.time import ForecastFromValidTimeAndStep +from 
anemoi.datasets.build.sources.xarray_support.variable import Variable @pytest.fixture diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 742bfae80..5f166be22 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -12,7 +12,7 @@ from anemoi.utils.testing import skip_if_offline from anemoi.utils.testing import skip_missing_packages -from anemoi.datasets.create.sources.xarray import XarrayFieldList +from anemoi.datasets.build.sources.xarray import XarrayFieldList from anemoi.datasets.testing import assert_field_list diff --git a/tools/build-obs.py b/tools/build-obs.py index e3caff9f9..2ccd1c1a2 100755 --- a/tools/build-obs.py +++ b/tools/build-obs.py @@ -28,7 +28,7 @@ def build(input, output, backend, overwrite=False): print(f"Dataset has {len(ds)} records, from {ds.start_date} to {ds.end_date}") print(f"Converting dataset to {output} using new backend '{backend}'") - from anemoi.datasets.data.records.backends import writer_backend_factory + from anemoi.datasets.use.records.backends import writer_backend_factory if os.path.exists(output): if overwrite: From 6d84c5391ba0057e346d0fa7f37e1c3487a4696a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 15:19:30 +0000 Subject: [PATCH 148/212] rename files --- src/anemoi/datasets/use/__init__.py | 0 src/anemoi/datasets/use/{ => tabular}/observations/__init__.py | 0 .../datasets/use/{ => tabular}/observations/legacy_obs_dataset.py | 0 src/anemoi/datasets/use/{ => tabular}/observations/multi.py | 0 src/anemoi/datasets/use/{ => tabular}/records/__init__.py | 0 .../datasets/use/{ => tabular}/records/backends/__init__.py | 0 6 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/anemoi/datasets/use/__init__.py rename src/anemoi/datasets/use/{ => tabular}/observations/__init__.py (100%) rename src/anemoi/datasets/use/{ => tabular}/observations/legacy_obs_dataset.py (100%) rename src/anemoi/datasets/use/{ => tabular}/observations/multi.py (100%) rename src/anemoi/datasets/use/{ => tabular}/records/__init__.py (100%) rename src/anemoi/datasets/use/{ => tabular}/records/backends/__init__.py (100%) diff --git a/src/anemoi/datasets/use/__init__.py b/src/anemoi/datasets/use/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/anemoi/datasets/use/observations/__init__.py b/src/anemoi/datasets/use/tabular/observations/__init__.py similarity index 100% rename from src/anemoi/datasets/use/observations/__init__.py rename to src/anemoi/datasets/use/tabular/observations/__init__.py diff --git a/src/anemoi/datasets/use/observations/legacy_obs_dataset.py b/src/anemoi/datasets/use/tabular/observations/legacy_obs_dataset.py similarity index 100% rename from src/anemoi/datasets/use/observations/legacy_obs_dataset.py rename to src/anemoi/datasets/use/tabular/observations/legacy_obs_dataset.py diff --git a/src/anemoi/datasets/use/observations/multi.py b/src/anemoi/datasets/use/tabular/observations/multi.py similarity index 100% rename from src/anemoi/datasets/use/observations/multi.py rename to src/anemoi/datasets/use/tabular/observations/multi.py diff --git a/src/anemoi/datasets/use/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py similarity index 100% rename from src/anemoi/datasets/use/records/__init__.py rename to src/anemoi/datasets/use/tabular/records/__init__.py diff --git a/src/anemoi/datasets/use/records/backends/__init__.py b/src/anemoi/datasets/use/tabular/records/backends/__init__.py similarity index 100% rename from 
src/anemoi/datasets/use/records/backends/__init__.py
rename to src/anemoi/datasets/use/tabular/records/backends/__init__.py

From a0b8e3488ecdc01487fc92490fdc86b665fca869 Mon Sep 17 00:00:00 2001
From: Baudouin Raoult
Date: Mon, 6 Oct 2025 15:24:30 +0000
Subject: [PATCH 149/212] rename files

---
 src/anemoi/datasets/build/__init__.py | 1658 -----------------
 src/anemoi/datasets/build/check.py | 328 ----
 src/anemoi/datasets/build/chunks.py | 138 --
 src/anemoi/datasets/build/config.py | 445 -----
 src/anemoi/datasets/build/filter.py | 47 -
 src/anemoi/datasets/build/patch.py | 188 --
 src/anemoi/datasets/build/persistent.py | 269 ---
 src/anemoi/datasets/build/size.py | 47 -
 src/anemoi/datasets/build/source.py | 51 -
 .../datasets/build/statistics/__init__.py | 561 ------
 .../datasets/build/statistics/summary.py | 152 --
 src/anemoi/datasets/build/testing.py | 4 -
 src/anemoi/datasets/build/typing.py | 14 -
 src/anemoi/datasets/build/utils.py | 198 --
 src/anemoi/datasets/build/writer.py | 64 -
 src/anemoi/datasets/build/zarr.py | 331 ----
 src/anemoi/datasets/check.py | 93 -
 src/anemoi/datasets/commands/create.py | 2 +-
 .../datasets/commands/recipe/__init__.py | 2 +-
 .../datasets/commands/recipe/migrate.py | 2 +-
 src/anemoi/datasets/dumper.py | 76 -
 src/anemoi/datasets/grids.py | 668 ------
 src/anemoi/datasets/schemas/recipe.json | 131 --
 src/anemoi/datasets/testing.py | 173 --
 .../use/tabular/records/backends/__init__.py | 4 +-
 src/anemoi/datasets/validate.py | 598 ------
 tests/create/utils/create.py | 2 +-
 27 files changed, 6 insertions(+), 6240 deletions(-)
 delete mode 100644 src/anemoi/datasets/build/__init__.py
 delete mode 100644 src/anemoi/datasets/build/check.py
 delete mode 100644 src/anemoi/datasets/build/chunks.py
 delete mode 100644 src/anemoi/datasets/build/config.py
 delete mode 100644 src/anemoi/datasets/build/filter.py
 delete mode 100755 src/anemoi/datasets/build/patch.py
 delete mode 100644 src/anemoi/datasets/build/persistent.py
 delete mode 100644 src/anemoi/datasets/build/size.py
 delete mode 100644 src/anemoi/datasets/build/source.py
 delete mode 100644 src/anemoi/datasets/build/statistics/__init__.py
 delete mode 100644 src/anemoi/datasets/build/statistics/summary.py
 delete mode 100644 src/anemoi/datasets/build/testing.py
 delete mode 100644 src/anemoi/datasets/build/typing.py
 delete mode 100644 src/anemoi/datasets/build/utils.py
 delete mode 100644 src/anemoi/datasets/build/writer.py
 delete mode 100644 src/anemoi/datasets/build/zarr.py
 delete mode 100644 src/anemoi/datasets/check.py
 delete mode 100644 src/anemoi/datasets/dumper.py
 delete mode 100644 src/anemoi/datasets/grids.py
 delete mode 100644 src/anemoi/datasets/schemas/recipe.json
 delete mode 100644 src/anemoi/datasets/testing.py
 delete mode 100644 src/anemoi/datasets/validate.py

diff --git a/src/anemoi/datasets/build/__init__.py b/src/anemoi/datasets/build/__init__.py
deleted file mode 100644
index f28955dd8..000000000
--- a/src/anemoi/datasets/build/__init__.py
+++ /dev/null
@@ -1,1658 +0,0 @@
-# (C) Copyright 2024 Anemoi contributors.
-#
-# This software is licensed under the terms of the Apache Licence Version 2.0
-# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
-#
-# In applying this licence, ECMWF does not waive the privileges and immunities
-# granted to it by virtue of its status as an intergovernmental organisation
-# nor does it submit to any jurisdiction.
-
-import datetime
-import json
-import logging
-import os
-import time
-import uuid
-import warnings
-from functools import cached_property
-from typing import Any
-
-import cftime
-import numpy as np
-import tqdm
-import zarr
-from anemoi.utils.dates import as_datetime
-from anemoi.utils.dates import frequency_to_string
-from anemoi.utils.dates import frequency_to_timedelta
-from anemoi.utils.humanize import compress_dates
-from anemoi.utils.humanize import seconds_to_human
-from anemoi.utils.sanitise import sanitise
-from earthkit.data.core.order import build_remapping
-
-from anemoi.datasets import MissingDateError
-from anemoi.datasets import open_dataset
-from anemoi.datasets.build.check import DatasetName
-from anemoi.datasets.build.check import check_data_values
-from anemoi.datasets.build.chunks import ChunkFilter
-from anemoi.datasets.build.config import build_output
-from anemoi.datasets.build.config import loader_config
-from anemoi.datasets.build.input import InputBuilder
-from anemoi.datasets.build.input.trace import enable_trace
-from anemoi.datasets.build.persistent import build_storage
-from anemoi.datasets.build.statistics import Summary
-from anemoi.datasets.build.statistics import TmpStatistics
-from anemoi.datasets.build.statistics import check_variance
-from anemoi.datasets.build.statistics import compute_statistics
-from anemoi.datasets.build.statistics import default_statistics_dates
-from anemoi.datasets.build.statistics import fix_variance
-from anemoi.datasets.build.utils import normalize_and_check_dates
-from anemoi.datasets.build.writer import ViewCacheArray
-from anemoi.datasets.dates.groups import Groups
-from anemoi.datasets.use.misc import as_first_date
-from anemoi.datasets.use.misc import as_last_date
-
-LOG = logging.getLogger(__name__)
-
-VERSION = "0.30"
-
-
-def json_tidy(o: Any) -> Any:
-    """Convert various types to JSON serializable format.
-
-    Parameters
-    ----------
-    o : Any
-        The object to convert.
-
-    Returns
-    -------
-    Any
-        The JSON serializable object.
-    """
-    if isinstance(o, datetime.datetime):
-        return o.isoformat()
-
-    if isinstance(o, datetime.datetime):
-        return o.isoformat()
-
-    if isinstance(o, datetime.timedelta):
-        return frequency_to_string(o)
-
-    if isinstance(o, cftime.DatetimeJulian):
-        import pandas as pd
-
-        o = pd.Timestamp(
-            o.year,
-            o.month,
-            o.day,
-            o.hour,
-            o.minute,
-            o.second,
-        )
-        return o.isoformat()
-
-    if isinstance(o, (np.float32, np.float64)):
-        return float(o)
-
-    raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}")
-
-
-def build_statistics_dates(
-    dates: list[datetime.datetime],
-    start: datetime.datetime | None,
-    end: datetime.datetime | None,
-) -> tuple[str, str]:
-    """Compute the start and end dates for the statistics.
-
-    Parameters
-    ----------
-    dates : list of datetime.datetime
-        The list of dates.
-    start : Optional[datetime.datetime]
-        The start date.
-    end : Optional[datetime.datetime]
-        The end date.
-
-    Returns
-    -------
-    tuple of str
-        The start and end dates in ISO format.
- """ - # if not specified, use the default statistics dates - default_start, default_end = default_statistics_dates(dates) - if start is None: - start = default_start - if end is None: - end = default_end - - # in any case, adapt to the actual dates in the dataset - start = as_first_date(start, dates) - end = as_last_date(end, dates) - - # and convert to datetime to isoformat - start = start.astype(datetime.datetime) - end = end.astype(datetime.datetime) - return (start.isoformat(), end.isoformat()) - - -def _path_readable(path: str) -> bool: - """Check if the path is readable. - - Parameters - ---------- - path : str - The path to check. - - Returns - ------- - bool - True if the path is readable, False otherwise. - """ - import zarr - - try: - zarr.open(path, "r") - return True - except zarr.errors.PathNotFoundError: - return False - - -class Dataset: - """A class to represent a dataset.""" - - def __init__(self, path: str): - """Initialize a Dataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - self.path = path - - _, ext = os.path.splitext(self.path) - if ext != ".zarr": - raise ValueError(f"Unsupported extension={ext} for path={self.path}") - - def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: - """Add a dataset to the Zarr store. - - Parameters - ---------- - mode : str, optional - The mode to open the Zarr store. - **kwargs - Additional arguments for the dataset. - - Returns - ------- - zarr.Array - The added dataset. - """ - import zarr - - z = zarr.open(self.path, mode=mode) - from anemoi.datasets.build.zarr import add_zarr_dataset - - return add_zarr_dataset(zarr_root=z, **kwargs) - - def update_metadata(self, **kwargs: Any) -> None: - """Update the metadata of the dataset. - - Parameters - ---------- - **kwargs - The metadata to update. - """ - import zarr - - LOG.debug(f"Updating metadata {kwargs}") - z = zarr.open(self.path, mode="w+") - for k, v in kwargs.items(): - if isinstance(v, np.datetime64): - v = v.astype(datetime.datetime) - if isinstance(v, datetime.date): - v = v.isoformat() - z.attrs[k] = json.loads(json.dumps(v, default=json_tidy)) - - @cached_property - def anemoi_dataset(self) -> Any: - """Get the Anemoi dataset.""" - return open_dataset(self.path) - - @cached_property - def zarr_metadata(self) -> dict: - """Get the Zarr metadata.""" - import zarr - - return dict(zarr.open(self.path, mode="r").attrs) - - def print_info(self) -> None: - """Print information about the dataset.""" - import zarr - - z = zarr.open(self.path, mode="r") - try: - LOG.info(z["data"].info) - except Exception as e: - LOG.info(e) - - def get_zarr_chunks(self) -> tuple: - """Get the chunks of the Zarr dataset. - - Returns - ------- - tuple - The chunks of the Zarr dataset. - """ - import zarr - - z = zarr.open(self.path, mode="r") - return z["data"].chunks - - def check_name( - self, - resolution: str, - dates: list[datetime.datetime], - frequency: datetime.timedelta, - raise_exception: bool = True, - is_test: bool = False, - ) -> None: - """Check the name of the dataset. - - Parameters - ---------- - resolution : str - The resolution of the dataset. - dates : list of datetime.datetime - The dates of the dataset. - frequency : datetime.timedelta - The frequency of the dataset. - raise_exception : bool, optional - Whether to raise an exception if the name is invalid. - is_test : bool, optional - Whether this is a test. 
- """ - basename, _ = os.path.splitext(os.path.basename(self.path)) - try: - DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() - except Exception as e: - if raise_exception and not is_test: - raise e - else: - LOG.warning(f"Dataset name error: {e}") - - def get_main_config(self) -> Any: - """Get the main configuration of the dataset. - - Returns - ------- - Any - The main configuration. - """ - import zarr - - z = zarr.open(self.path, mode="r") - config = loader_config(z.attrs.get("_create_yaml_config")) - - if "env" in config: - for k, v in config["env"].items(): - LOG.info(f"Setting env variable {k}={v}") - os.environ[k] = str(v) - - return config - - -class WritableDataset(Dataset): - """A class to represent a writable dataset.""" - - def __init__(self, path: str): - """Initialize a WritableDataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - self.path = path - - import zarr - - self.z = zarr.open(self.path, mode="r+") - - @cached_property - def data_array(self) -> Any: - """Get the data array of the dataset.""" - import zarr - - return zarr.open(self.path, mode="r+")["data"] - - -class NewDataset(Dataset): - """A class to represent a new dataset.""" - - def __init__(self, path: str, overwrite: bool = False): - """Initialize a NewDataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - overwrite : bool, optional - Whether to overwrite the existing dataset. - """ - super().__init__(path) - self.path = path - - import zarr - - self.z = zarr.open(self.path, mode="w") - self.z.create_group("_build") - - -class Actor: # TODO: rename to Creator - """A base class for dataset creation actors.""" - - dataset_class = WritableDataset - - def __init__(self, path: str, cache: str | None = None): - """Initialize an Actor instance. - - Parameters - ---------- - path : str - The path to the dataset. - cache : Optional[str], optional - The cache directory. - """ - # Catch all floating point errors, including overflow, sqrt(<0), etc - np.seterr(all="raise", under="warn") - - self.path = path - self.cache = cache - self.dataset = self.dataset_class(self.path) - - def run(self) -> None: - """Run the actor.""" - # to be implemented in the sub-classes - raise NotImplementedError() - - def update_metadata(self, **kwargs: Any) -> None: - """Update the metadata of the dataset. - - Parameters - ---------- - **kwargs - The metadata to update. - """ - self.dataset.update_metadata(**kwargs) - - def _cache_context(self) -> Any: - """Get the cache context. - - Returns - ------- - Any - The cache context. - """ - from anemoi.datasets.build.utils import cache_context - - return cache_context(self.cache) - - def check_unkown_kwargs(self, kwargs: dict) -> None: - """Check for unknown keyword arguments. - - Parameters - ---------- - kwargs : dict - The keyword arguments. - """ - # remove this latter - LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}") - - def read_dataset_metadata(self, path: str) -> None: - """Read the metadata of the dataset. - - Parameters - ---------- - path : str - The path to the dataset. 
- """ - ds = open_dataset(path) - self.dataset_shape = ds.shape - self.variables_names = ds.variables - assert len(self.variables_names) == ds.shape[1], self.dataset_shape - self.dates = ds.dates - - self.missing_dates = sorted(list([self.dates[i] for i in ds.missing])) - - def check_missing_dates(expected: list[np.datetime64]) -> None: - """Check if the missing dates in the dataset match the expected dates. - - Parameters - ---------- - expected : list of np.datetime64 - The expected missing dates. - - Raises - ------ - ValueError - If the missing dates in the dataset do not match the expected dates. - """ - import zarr - - z = zarr.open(path, "r") - missing_dates = z.attrs.get("missing_dates", []) - missing_dates = sorted([np.datetime64(d) for d in missing_dates]) - if missing_dates != expected: - LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.") - LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") - LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") - raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") - - check_missing_dates(self.missing_dates) - - -class Patch(Actor): - """A class to apply patches to a dataset.""" - - def __init__(self, path: str, options: dict = None, **kwargs: Any): - """Initialize a Patch instance. - - Parameters - ---------- - path : str - The path to the dataset. - options : dict, optional - The patch options. - """ - self.path = path - self.options = options or {} - - def run(self) -> None: - """Run the patch.""" - from anemoi.datasets.build.patch import apply_patch - - apply_patch(self.path, **self.options) - - -class Size(Actor): - """A class to compute the size of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Size instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - - def run(self) -> None: - """Run the size computation.""" - from anemoi.datasets.build.size import compute_directory_sizes - - metadata = compute_directory_sizes(self.path) - self.update_metadata(**metadata) - - # Look for constant fields - ds = open_dataset(self.path) - constants = ds.computed_constant_fields() - - variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() - for k in constants: - variables_metadata[k]["constant_in_time"] = True - - self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) - - -class HasRegistryMixin: - """A mixin class to provide registry functionality.""" - - @cached_property - def registry(self) -> Any: - """Get the registry.""" - from anemoi.datasets.build.zarr import ZarrBuiltRegistry - - return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) - - -class HasStatisticTempMixin: - """A mixin class to provide temporary statistics functionality.""" - - @cached_property - def tmp_statistics(self) -> TmpStatistics: - """Get the temporary statistics.""" - directory = self.statistics_temp_dir or os.path.join(self.path + ".storage_for_statistics.tmp") - return TmpStatistics(directory) - - -class HasElementForDataMixin: - """A mixin class to provide element creation functionality for data.""" - - def create_elements(self, config: Any) -> None: - """Create elements for the dataset. - - Parameters - ---------- - config : Any - The configuration. 
- """ - assert self.registry - assert self.tmp_statistics - - LOG.info(dict(config.dates)) - - self.groups = Groups(**config.dates) - LOG.info(self.groups) - - self.output = build_output(config.output, parent=self) - - self.input = InputBuilder( - config.input, - data_sources=config.get("data_sources", {}), - order_by=self.output.order_by, - flatten_grid=self.output.flatten_grid, - remapping=build_remapping(self.output.remapping), - use_grib_paramid=config.build.use_grib_paramid, - ) - LOG.debug("✅ INPUT_BUILDER") - LOG.debug(self.input) - - -class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to initialize a new dataset.""" - - dataset_class = NewDataset - - def __init__( - self, - path: str, - config: dict, - check_name: bool = False, - overwrite: bool = False, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - test: bool = False, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize an Init instance. - - Parameters - ---------- - path : str - The path to the dataset. - config : dict - The configuration. - check_name : bool, optional - Whether to check the dataset name. - overwrite : bool, optional - Whether to overwrite the existing dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - test : bool, optional - Whether this is a test. - cache : Optional[str], optional - The cache directory. - """ - if _path_readable(path) and not overwrite: - raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") - - super().__init__(path, cache=cache) - self.config = config - self.check_name = check_name - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.test = test - - self.main_config = loader_config(config, is_test=test) - - # self.registry.delete() ?? - self.tmp_statistics.delete() - - assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by - self.create_elements(self.main_config) - - LOG.info(f"Groups: {self.groups}") - - one_date = self.groups.one_date() - # assert False, (type(one_date), type(self.groups)) - self.minimal_input = self.input.select(one_date) - LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") - LOG.info(self.minimal_input) - - def run(self) -> int: - """Run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - with self._cache_context(): - return self._run() - - def _run(self) -> int: - """Internal method to run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - """Create an empty dataset of the right final shape. - - Read a small part of the data to get the shape of the data and the resolution and more metadata. 
- """ - - LOG.info("Config loaded ok:") - # LOG.info(self.main_config) - - dates = self.groups.provider.values - frequency = self.groups.provider.frequency - missing = self.groups.provider.missing - - assert isinstance(frequency, datetime.timedelta), frequency - - LOG.info(f"Found {len(dates)} datetimes.") - LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") - LOG.info(f"Missing dates: {len(missing)}") - lengths = tuple(len(g) for g in self.groups) - - variables = self.minimal_input.variables - LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") - - variables_with_nans = self.main_config.statistics.get("allow_nans", []) - - ensembles = self.minimal_input.ensembles - LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") - - grid_points = self.minimal_input.grid_points - LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") - - resolution = self.minimal_input.resolution - LOG.info(f"{resolution=}") - - coords = self.minimal_input.coords - coords["dates"] = dates - total_shape = self.minimal_input.shape - total_shape[0] = len(dates) - LOG.info(f"total_shape = {total_shape}") - - chunks = self.output.get_chunking(coords) - LOG.info(f"{chunks=}") - dtype = self.output.dtype - - LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") - - metadata = {} - metadata["uuid"] = str(uuid.uuid4()) - - metadata.update(self.main_config.get("add_metadata", {})) - - metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() - - recipe = sanitise(self.main_config.get_serialisable_dict()) - - # Remove stuff added by prepml - for k in [ - "build_dataset", - "config_format_version", - "config_path", - "dataset_status", - "ecflow", - "metadata", - "platform", - "reading_chunks", - "upload", - ]: - recipe.pop(k, None) - - metadata["recipe"] = recipe - - metadata["description"] = self.main_config.description - metadata["licence"] = self.main_config["licence"] - metadata["attribution"] = self.main_config["attribution"] - - metadata["remapping"] = self.output.remapping - metadata["order_by"] = self.output.order_by_as_list - metadata["flatten_grid"] = self.output.flatten_grid - - metadata["ensemble_dimension"] = len(ensembles) - metadata["variables"] = variables - metadata["variables_with_nans"] = variables_with_nans - metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) - metadata["resolution"] = resolution - - metadata["data_request"] = self.minimal_input.data_request - metadata["field_shape"] = self.minimal_input.field_shape - metadata["proj_string"] = self.minimal_input.proj_string - metadata["variables_metadata"] = self.minimal_input.variables_metadata - - metadata["start_date"] = dates[0].isoformat() - metadata["end_date"] = dates[-1].isoformat() - metadata["frequency"] = frequency - metadata["missing_dates"] = [_.isoformat() for _ in missing] - - metadata["version"] = VERSION - - self.dataset.check_name( - raise_exception=self.check_name, - is_test=self.test, - resolution=resolution, - dates=dates, - frequency=frequency, - ) - - if len(dates) != total_shape[0]: - raise ValueError( - f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " - f"does not match data shape {total_shape[0]}. 
{total_shape=}" - ) - - dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) - - metadata.update(self.main_config.get("force_metadata", {})) - - ############################################################### - # write metadata - ############################################################### - - self.update_metadata(**metadata) - - self.dataset.add_dataset( - name="data", - chunks=chunks, - dtype=dtype, - shape=total_shape, - dimensions=("time", "variable", "ensemble", "cell"), - ) - self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) - self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) - self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) - - self.registry.create(lengths=lengths) - self.tmp_statistics.create(exist_ok=False) - self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) - - statistics_start, statistics_end = build_statistics_dates( - dates, - self.main_config.statistics.get("start"), - self.main_config.statistics.get("end"), - ) - self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) - LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") - - self.registry.add_to_history("init finished") - - assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) - - # Return the number of groups to process, so we can show a nice progress bar - return len(lengths) - - -class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to load data into a dataset.""" - - def __init__( - self, - path: str, - parts: str | None = None, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize a Load instance. - - Parameters - ---------- - path : str - The path to the dataset. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - cache : Optional[str], optional - The cache directory. 
- """ - super().__init__(path, cache=cache) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.parts = parts - self.dataset = WritableDataset(self.path) - - self.main_config = self.dataset.get_main_config() - self.create_elements(self.main_config) - self.read_dataset_metadata(self.dataset.path) - - total = len(self.registry.get_flags()) - self.chunk_filter = ChunkFilter(parts=self.parts, total=total) - - self.data_array = self.dataset.data_array - self.n_groups = len(self.groups) - - def run(self) -> None: - """Run the data loading.""" - with self._cache_context(): - self._run() - - def _run(self) -> None: - """Internal method to run the data loading.""" - for igroup, group in enumerate(self.groups): - if not self.chunk_filter(igroup): - continue - if self.registry.get_flag(igroup): - LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") - continue - - # assert isinstance(group[0], datetime.datetime), type(group[0]) - LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - - result = self.input.select(argument=group) - assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) - - # There are several groups. - # There is one result to load for each group. - self.load_result(result) - self.registry.set_flag(igroup) - - self.registry.add_provenance(name="provenance_load") - self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) - - self.dataset.print_info() - - def load_result(self, result: Any) -> None: - """Load the result into the dataset. - - Parameters - ---------- - result : Any - The result to load. - """ - # There is one cube to load for each result. - dates = list(result.group_of_dates) - - LOG.debug(f"Loading cube for {len(dates)} dates") - - cube = result.get_cube() - shape = cube.extended_user_shape - dates_in_data = cube.user_coords["valid_datetime"] - - LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") - - def check_shape(cube, dates, dates_in_data): - if cube.extended_user_shape[0] != len(dates): - print( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - print("Requested dates", compress_dates(dates)) - print("Cube dates", compress_dates(dates_in_data)) - - a = {as_datetime(_) for _ in dates} - b = {as_datetime(_) for _ in dates_in_data} - - print("Missing dates", compress_dates(a - b)) - print("Extra dates", compress_dates(b - a)) - - raise ValueError( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - - check_shape(cube, dates, dates_in_data) - - def check_dates_in_data(dates_in_data, requested_dates): - _requested_dates = [np.datetime64(_) for _ in requested_dates] - _dates_in_data = [np.datetime64(_) for _ in dates_in_data] - if _dates_in_data != _requested_dates: - LOG.error("Dates in data are not the requested ones:") - - dates_in_data = set(dates_in_data) - requested_dates = set(requested_dates) - - missing = sorted(requested_dates - dates_in_data) - extra = sorted(dates_in_data - requested_dates) - - if missing: - LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") - if extra: - LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") - - raise ValueError("Dates in data are not the requested ones") - - check_dates_in_data(dates_in_data, dates) - - def dates_to_indexes(dates, all_dates): - x = np.array(dates, dtype=np.datetime64) - y = np.array(all_dates, 
dtype=np.datetime64) - bitmap = np.isin(x, y) - return np.where(bitmap)[0] - - indexes = dates_to_indexes(self.dates, dates_in_data) - - array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) - LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") - self.load_cube(cube, array) - - stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) - self.tmp_statistics.write(indexes, stats, dates=dates_in_data) - LOG.info("Flush data array") - array.flush() - LOG.info("Flushed data array") - - def _get_allow_nans(self) -> bool | list: - """Get the allow_nans configuration. - - Returns - ------- - bool | list - The allow_nans configuration. - """ - config = self.main_config - if "allow_nans" in config.build: - return config.build.allow_nans - - return config.statistics.get("allow_nans", []) - - def load_cube(self, cube: Any, array: ViewCacheArray) -> None: - """Load the cube into the array. - - Parameters - ---------- - cube : Any - The cube to load. - array : ViewCacheArray - The array to load into. - """ - # There are several cubelets for each cube - start = time.time() - load = 0 - save = 0 - - reading_chunks = None - total = cube.count(reading_chunks) - LOG.debug(f"Loading datacube: {cube}") - - def position(x: Any) -> int | None: - if isinstance(x, str) and "/" in x: - x = x.split("/") - return int(x[0]) - return None - - bar = tqdm.tqdm( - iterable=cube.iterate_cubelets(reading_chunks), - total=total, - desc=f"Loading datacube {cube}", - position=position(self.parts), - ) - for i, cubelet in enumerate(bar): - bar.set_description(f"Loading {i}/{total}") - - now = time.time() - data = cubelet.to_numpy() - local_indexes = cubelet.coords - load += time.time() - now - - name = self.variables_names[local_indexes[1]] - check_data_values( - data[:], - name=name, - log=[i, data.shape, local_indexes], - allow_nans=self._get_allow_nans(), - ) - - now = time.time() - array[local_indexes] = data - save += time.time() - now - - now = time.time() - save += time.time() - now - LOG.debug( - f"Elapsed: {seconds_to_human(time.time() - start)}, " - f"load time: {seconds_to_human(load)}, " - f"write time: {seconds_to_human(save)}." - ) - - -class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin): - """A class to clean up temporary data and registry entries.""" - - def __init__( - self, - path: str, - statistics_temp_dir: str | None = None, - delta: list = [], - use_threads: bool = False, - **kwargs: Any, - ): - """Initialize a Cleanup instance. - - Parameters - ---------- - path : str - The path to the dataset. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - delta : list, optional - The delta values. - use_threads : bool, optional - Whether to use threads. - """ - super().__init__(path) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.additinon_temp_dir = statistics_temp_dir - self.actors = [ - _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) - for d in delta - ] - - def run(self) -> None: - """Run the cleanup.""" - - self.tmp_statistics.delete() - self.registry.clean() - for actor in self.actors: - actor.cleanup() - - -class Verify(Actor): - """A class to verify the integrity of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Verify instance. - - Parameters - ---------- - path : str - The path to the dataset. 
- """ - super().__init__(path) - - def run(self) -> None: - """Run the verification.""" - LOG.info(f"Verifying dataset at {self.path}") - LOG.info(str(self.dataset.anemoi_dataset)) - - -class AdditionsMixin: - """A mixin class to handle dataset additions.""" - - def skip(self) -> bool: - """Check if the additions should be skipped. - - Returns - ------- - bool - Whether to skip the additions. - """ - frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - if not self.delta.total_seconds() % frequency.total_seconds() == 0: - LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.") - return True - - if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: - LOG.warning(f"Additions are disabled for {self.path} in the recipe.") - return True - - return False - - @cached_property - def tmp_storage_path(self) -> str: - """Get the path to the temporary storage.""" - name = "storage_for_additions" - if self.delta: - name += frequency_to_string(self.delta) - return os.path.join(f"{self.path}.{name}.tmp") - - def read_from_dataset(self) -> None: - """Read data from the dataset.""" - self.variables = self.dataset.anemoi_dataset.variables - self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - start = self.dataset.zarr_metadata["statistics_start_date"] - end = self.dataset.zarr_metadata["statistics_end_date"] - self.start = datetime.datetime.fromisoformat(start) - self.end = datetime.datetime.fromisoformat(end) - - ds = open_dataset(self.path, start=self.start, end=self.end) - self.dates = ds.dates - self.total = len(self.dates) - - idelta = self.delta.total_seconds() // self.frequency.total_seconds() - assert int(idelta) == idelta, idelta - idelta = int(idelta) - self.ds = DeltaDataset(ds, idelta) - - -class DeltaDataset: - """A class to represent a dataset with delta values.""" - - def __init__(self, ds: Any, idelta: int): - """Initialize a DeltaDataset instance. - - Parameters - ---------- - ds : Any - The dataset. - idelta : int - The delta value. - """ - self.ds = ds - self.idelta = idelta - - def __getitem__(self, i: int) -> Any: - """Get an item from the dataset. - - Parameters - ---------- - i : int - The index. - - Returns - ------- - Any - The item. - """ - j = i - self.idelta - if j < 0: - raise MissingDateError(f"Missing date {j}") - return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] - - -class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to initialize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize an _InitAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. 
- """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - def run(self) -> None: - """Run the additions initialization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) - self.tmp_storage.delete() - self.tmp_storage.create() - LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") - - def cleanup(self) -> None: - """Clean up the temporary storage.""" - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - self.tmp_storage.delete() - LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") - - -class _LoadAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to run dataset additions.""" - - def __init__( - self, - path: str, - delta: str, - parts: str | None = None, - use_threads: bool = False, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a _LoadAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - self.parts = parts - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Writing in {self.tmp_storage_path}") - - def run(self) -> None: - """Run the additions.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.read_from_dataset() - - chunk_filter = ChunkFilter(parts=self.parts, total=self.total) - for i in range(0, self.total): - if not chunk_filter(i): - continue - date = self.dates[i] - try: - arr = self.ds[i] - stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) - self.tmp_storage.add([date, i, stats], key=date) - except MissingDateError: - self.tmp_storage.add([date, i, "missing"], key=date) - self.tmp_storage.flush() - LOG.debug(f"Dataset {self.path} additions run.") - - def allow_nans(self) -> bool: - """Check if NaNs are allowed. - - Returns - ------- - bool - Whether NaNs are allowed. - """ - if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): - return True - - variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) - if variables_with_nans is not None: - return variables_with_nans - warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") - return True - - -class _FinaliseAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to finalize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize a _FinaliseAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. 
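The NaN policy resolved by allow_nans above can be summarised with a plain dict standing in for the dataset metadata (the keys are the ones read by the method, the values are toy examples):

def resolve_allow_nans(metadata: dict):
    # Either a blanket "allow everywhere", an explicit list of variables, or a permissive default.
    if metadata.get("allow_nans", False):
        return True
    variables_with_nans = metadata.get("variables_with_nans")
    if variables_with_nans is not None:
        return variables_with_nans
    return True

assert resolve_allow_nans({"variables_with_nans": ["sst"]}) == ["sst"]
assert resolve_allow_nans({"allow_nans": True}) is True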
- """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Reading from {self.tmp_storage_path}.") - - def run(self) -> None: - """Run the additions finalization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}.") - return - - self.read_from_dataset() - - shape = (len(self.dates), len(self.variables)) - agg = dict( - minimum=np.full(shape, np.nan, dtype=np.float64), - maximum=np.full(shape, np.nan, dtype=np.float64), - sums=np.full(shape, np.nan, dtype=np.float64), - squares=np.full(shape, np.nan, dtype=np.float64), - count=np.full(shape, -1, dtype=np.int64), - has_nans=np.full(shape, False, dtype=np.bool_), - ) - LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") - - found = set() - ifound = set() - missing = set() - for _date, (date, i, stats) in self.tmp_storage.items(): - assert _date == date - if stats == "missing": - missing.add(date) - continue - - assert date not in found, f"Duplicates found {date}" - found.add(date) - ifound.add(i) - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k][i, ...] = stats[k] - - assert len(found) + len(missing) == len(self.dates), ( - len(found), - len(missing), - len(self.dates), - ) - assert found.union(missing) == set(self.dates), ( - found, - missing, - set(self.dates), - ) - - if len(ifound) < 2: - LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") - self.tmp_storage.delete() - return - - mask = sorted(list(ifound)) - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k] = agg[k][mask, ...] - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - assert agg[k].shape == agg["count"].shape, ( - agg[k].shape, - agg["count"].shape, - ) - - minimum = np.nanmin(agg["minimum"], axis=0) - maximum = np.nanmax(agg["maximum"], axis=0) - sums = np.nansum(agg["sums"], axis=0) - squares = np.nansum(agg["squares"], axis=0) - count = np.nansum(agg["count"], axis=0) - has_nans = np.any(agg["has_nans"], axis=0) - - assert sums.shape == count.shape - assert sums.shape == squares.shape - assert sums.shape == minimum.shape - assert sums.shape == maximum.shape - assert sums.shape == has_nans.shape - - mean = sums / count - assert sums.shape == mean.shape - - x = squares / count - mean * mean - # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 - # remove negative variance due to numerical errors - for i, name in enumerate(self.variables): - x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) - check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) - - stdev = np.sqrt(x) - assert sums.shape == stdev.shape - - self.summary = Summary( - minimum=minimum, - maximum=maximum, - mean=mean, - count=count, - sums=sums, - squares=squares, - stdev=stdev, - variables_names=self.variables, - has_nans=has_nans, - ) - LOG.info(f"Dataset {self.path} additions finalised.") - # self.check_statistics() - self._write(self.summary) - self.tmp_storage.delete() - - def _write(self, summary: Summary) -> None: - """Write the summary to the dataset. - - Parameters - ---------- - summary : Summary - The summary to write. 
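The finalisation step above combines per-date partial statistics into dataset-wide values; the core arithmetic (mean from sums and counts, variance from the sums of squares) can be sketched independently of the storage machinery, with toy numbers:

import numpy as np

sums = np.array([10.0, 12.0, 8.0])      # per-date sum of values for one variable
squares = np.array([60.0, 80.0, 40.0])  # per-date sum of squared values
count = np.array([2, 2, 2])             # per-date number of valid grid points

total = count.sum()
mean = sums.sum() / total                        # 5.0
variance = squares.sum() / total - mean * mean   # 5.0; clamped to 0 if it dips below
stdev = np.sqrt(max(variance, 0.0))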
- """ - for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: - name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" - self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) - self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") - LOG.debug(f"Wrote additions in {self.path}") - - -def multi_addition(cls: type) -> type: - """Create a class to handle multiple additions. - - Parameters - ---------- - cls : type - The class to handle additions. - - Returns - ------- - type - The class to handle multiple additions. - """ - - class MultiAdditions: - def __init__(self, *args, **kwargs: Any): - self.actors = [] - - for k in kwargs.pop("delta", []): - self.actors.append(cls(*args, delta=k, **kwargs)) - - if not self.actors: - LOG.warning("No delta found in kwargs, no additions will be computed.") - - def run(self) -> None: - """Run the additions.""" - for actor in self.actors: - actor.run() - - return MultiAdditions - - -InitAdditions = multi_addition(_InitAdditions) -LoadAdditions = multi_addition(_LoadAdditions) -FinaliseAdditions = multi_addition(_FinaliseAdditions) - - -class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin): - """A class to compute statistics for a dataset.""" - - def __init__( - self, - path: str, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a Statistics instance. - - Parameters - ---------- - path : str - The path to the dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.use_threads = use_threads - self.progress = progress - self.statistics_temp_dir = statistics_temp_dir - - def run(self) -> None: - """Run the statistics computation.""" - start, end = ( - self.dataset.zarr_metadata["statistics_start_date"], - self.dataset.zarr_metadata["statistics_end_date"], - ) - start, end = np.datetime64(start), np.datetime64(end) - dates = self.dataset.anemoi_dataset.dates - - assert type(dates[0]) is type(start), (type(dates[0]), type(start)) - - dates = [d for d in dates if d >= start and d <= end] - dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] - variables = self.dataset.anemoi_dataset.variables - stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) - - LOG.info(stats) - - if not all(self.registry.get_flags(sync=False)): - raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") - - for k in [ - "mean", - "stdev", - "minimum", - "maximum", - "sums", - "squares", - "count", - "has_nans", - ]: - self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) - - self.registry.add_to_history("compute_statistics_end") - LOG.info(f"Wrote statistics in {self.path}") - - @cached_property - def allow_nans(self) -> bool | list: - """Check if NaNs are allowed.""" - import zarr - - z = zarr.open(self.path, mode="r") - if "allow_nans" in z.attrs: - return z.attrs["allow_nans"] - - if "variables_with_nans" in z.attrs: - return z.attrs["variables_with_nans"] - - warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") - return True - - -def chain(tasks: list) -> type: - """Create a class to chain multiple tasks. 
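multi_addition above is a small class factory that fans one actor out over several deltas; a stripped-down usage of the same pattern (with a toy actor, not the real ones):

def multi(cls):
    class Multi:
        def __init__(self, *args, **kwargs):
            self.actors = [cls(*args, delta=d) for d in kwargs.pop("delta", [])]

        def run(self):
            for actor in self.actors:
                actor.run()

    return Multi

class EchoActor:
    def __init__(self, path, delta):
        self.path, self.delta = path, delta

    def run(self):
        print(f"running {self.path} delta={self.delta}")

multi(EchoActor)("dataset.zarr", delta=["6h", "12h"]).run()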
- - Parameters - ---------- - tasks : list - The list of tasks to chain. - - Returns - ------- - type - The class to chain multiple tasks. - """ - - class Chain(Actor): - def __init__(self, **kwargs: Any): - self.kwargs = kwargs - - def run(self) -> None: - """Run the chained tasks.""" - for cls in tasks: - t = cls(**self.kwargs) - t.run() - - return Chain - - -def creator_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: - """Create a dataset creator. - - Parameters - ---------- - name : str - The name of the creator. - trace : Optional[str], optional - The trace file. - **kwargs - Additional arguments for the creator. - - Returns - ------- - Any - The dataset creator. - """ - if trace: - - enable_trace(trace) - - cls = dict( - init=Init, - load=Load, - size=Size, - patch=Patch, - statistics=Statistics, - finalise=chain([Statistics, Size, Cleanup]), - cleanup=Cleanup, - verify=Verify, - init_additions=InitAdditions, - load_additions=LoadAdditions, - finalise_additions=chain([FinaliseAdditions, Size]), - additions=chain([InitAdditions, LoadAdditions, FinaliseAdditions, Size, Cleanup]), - )[name] - LOG.debug(f"Creating {cls.__name__} with {kwargs}") - return cls(**kwargs) - - -def validate_config(config: Any) -> None: - - import json - - import jsonschema - - def _tidy(d): - if isinstance(d, dict): - return {k: _tidy(v) for k, v in d.items()} - - if isinstance(d, list): - return [_tidy(v) for v in d if v is not None] - - # jsonschema does not support datetime.date - if isinstance(d, datetime.datetime): - return d.isoformat() - - if isinstance(d, datetime.date): - return d.isoformat() - - return d - - # https://json-schema.org - - with open( - os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "schemas", - "recipe.json", - ) - ) as f: - schema = json.load(f) - - try: - jsonschema.validate(instance=_tidy(config), schema=schema) - except jsonschema.exceptions.ValidationError as e: - LOG.error("❌ Config validation failed (jsonschema):") - LOG.error(e.message) - raise diff --git a/src/anemoi/datasets/build/check.py b/src/anemoi/datasets/build/check.py deleted file mode 100644 index 3c09cc80b..000000000 --- a/src/anemoi/datasets/build/check.py +++ /dev/null @@ -1,328 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import datetime -import logging -import re -import warnings -from collections.abc import Callable -from typing import Any - -import numpy as np -from anemoi.utils.config import load_config -from anemoi.utils.dates import frequency_to_string -from numpy.typing import NDArray - -LOG = logging.getLogger(__name__) - - -class DatasetName: - """Validate and parse dataset names according to naming conventions.""" - - def __init__( - self, - name: str, - resolution: str | None = None, - start_date: datetime.date | None = None, - end_date: datetime.date | None = None, - frequency: datetime.timedelta | None = None, - ): - """Initialize a DatasetName instance. - - Parameters - ---------- - name : str - The name of the dataset. - resolution : Optional[str], optional - The resolution of the dataset. - start_date : Optional[datetime.date], optional - The start date of the dataset. 
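validate_config above delegates to jsonschema after tidying dates into ISO strings; the same call with a toy schema (standing in for schemas/recipe.json) looks like this:

import jsonschema

schema = {
    "type": "object",
    "properties": {"description": {"type": "string"}},
    "required": ["description"],
}

config = {"description": "test dataset"}
jsonschema.validate(instance=config, schema=schema)  # raises ValidationError if invalid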
- end_date : Optional[datetime.date], optional - The end date of the dataset. - frequency : Optional[datetime.timedelta], optional - The frequency of the dataset. - """ - self.name = name - self.parsed = self._parse(name) - print("---------------") - print(self.parsed) - print("---------------") - - self.messages = [] - - config = load_config().get("datasets", {}) - - if config.get("ignore_naming_conventions", False): - # setting the env variable ANEMOI_CONFIG_DATASETS_IGNORE_NAMING_CONVENTIONS=1 - # will ignore the naming conventions - return - - self.check_characters() - self.check_parsed() - self.check_resolution(resolution) - self.check_frequency(frequency) - self.check_start_date(start_date) - self.check_end_date(end_date) - - if self.messages: - self.messages.append(f"{self} is parsed as :" + "/".join(f"{k}={v}" for k, v in self.parsed.items())) - - @property - def error_message(self) -> str: - """Generate an error message based on the collected messages.""" - out = " And ".join(self.messages) - if out: - out[0].upper() + out[1:] - return out - - def raise_if_not_valid(self, print: Callable = print) -> None: - """Raise a ValueError if the dataset name is not valid. - - Parameters - ---------- - print : Callable - The function to use for printing messages. - """ - if self.messages: - for m in self.messages: - print(m) - raise ValueError(self.error_message) - - def _parse(self, name: str) -> dict: - """Parse the dataset name into its components. - - Parameters - ---------- - name : str - The name of the dataset. - - Returns - ------- - dict - The parsed components of the dataset name. - """ - pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$" - match = re.match(pattern, name) - - if not match: - raise ValueError(f"the dataset name '{name}' does not follow naming convention. Does not match {pattern}") - - parsed = {} - if match: - keys = [ - "purpose", - "labelling", - "source", - "resolution", - "start_date", - "end_date", - "frequency", - "version", - "additional", - ] - parsed = {k: v for k, v in zip(keys, match.groups())} - - return parsed - - def __str__(self) -> str: - """Return the string representation of the dataset name.""" - return self.name - - def check_parsed(self) -> None: - """Check if the dataset name was parsed correctly.""" - if not self.parsed: - self.messages.append( - f"the dataset name {self} does not follow naming convention. " - "See here for details: " - "https://anemoi-registry.readthedocs.io/en/latest/naming-conventions.html" - ) - - def check_resolution(self, resolution: str | None) -> None: - """Check if the resolution matches the expected format. - - Parameters - ---------- - resolution : str or None - The expected resolution. - """ - if self.parsed.get("resolution") and self.parsed["resolution"][0] not in "0123456789on": - self.messages.append( - f"the resolution {self.parsed['resolution'] } should start " - f"with a number or 'o' or 'n' in the dataset name {self}." 
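An example name that satisfies the naming-convention pattern above, mapped onto the same capture-group keys (the name itself is only an illustration):

import re

pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$"
keys = ["purpose", "labelling", "source", "resolution",
        "start_date", "end_date", "frequency", "version", "additional"]

name = "aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6"
parsed = dict(zip(keys, re.match(pattern, name).groups()))
# parsed["resolution"] == "o96", parsed["frequency"] == "6h", parsed["version"] == "6"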
- ) - - if resolution is None: - return - resolution_str = str(resolution).replace(".", "p").lower() - self._check_missing("resolution", resolution_str) - self._check_mismatch("resolution", resolution_str) - - def check_characters(self) -> None: - if not self.name.islower(): - self.messages.append(f"the {self.name} should be in lower case.") - if "_" in self.name: - self.messages.append(f"the {self.name} should use '-' instead of '_'.") - for c in self.name: - if not c.isalnum() and c not in "-": - self.messages.append(f"the {self.name} should only contain alphanumeric characters and '-'.") - - def check_frequency(self, frequency: datetime.timedelta | None) -> None: - """Check if the frequency matches the expected format. - - Parameters - ---------- - frequency : datetime.timedelta or None - The expected frequency. - """ - if frequency is None: - return - frequency_str = frequency_to_string(frequency) - self._check_missing("frequency", frequency_str) - self._check_mismatch("frequency", frequency_str) - - def check_start_date(self, start_date: datetime.date | None) -> None: - """Check if the start date matches the expected format. - - Parameters - ---------- - start_date : datetime.date or None - The expected start date. - """ - if start_date is None: - return - start_date_str = str(start_date.year) - self._check_missing("start_date", start_date_str) - self._check_mismatch("start_date", start_date_str) - - def check_end_date(self, end_date: datetime.date | None) -> None: - """Check if the end date matches the expected format. - - Parameters - ---------- - end_date : datetime.date or None - The expected end date. - """ - if end_date is None: - return - end_date_str = str(end_date.year) - self._check_missing("end_date", end_date_str) - self._check_mismatch("end_date", end_date_str) - - def _check_missing(self, key: str, value: str) -> None: - """Check if a component is missing from the dataset name. - - Parameters - ---------- - key : str - The component key. - value : str - The expected value. - """ - if value not in self.name: - self.messages.append(f"the {key} is {value}, but is missing in {self.name}.") - - def _check_mismatch(self, key: str, value: str) -> None: - """Check if a component value mismatches the expected value. - - Parameters - ---------- - key : str - The component key. - value : str - The expected value. - """ - if self.parsed.get(key) and self.parsed[key] != value: - self.messages.append(f"the {key} is {value}, but is {self.parsed[key]} in {self.name}.") - - -class StatisticsValueError(ValueError): - """Custom error for statistics value issues.""" - - pass - - -def check_data_values( - arr: NDArray[Any], *, name: str, log: list = [], allow_nans: bool | list | set | tuple | dict = False -) -> None: - """Check the values in the data array for validity. - - Parameters - ---------- - arr : NDArray[Any] - The data array to check. - name : str - The name of the data array. - log : list, optional - A list to log messages. - allow_nans : bool or list or set or tuple or dict, optional - Whether to allow NaNs in the data array. 
- """ - shape = arr.shape - - if (isinstance(allow_nans, (set, list, tuple, dict)) and name in allow_nans) or allow_nans: - arr = arr[~np.isnan(arr)] - - if arr.size == 0: - warnings.warn(f"Empty array for {name} ({shape})") - return - - assert arr.size > 0, (name, *log) - - min, max = arr.min(), arr.max() - assert not (np.isnan(arr).any()), (name, min, max, *log) - - if min == 9999.0: - warnings.warn(f"Min value 9999 for {name}") - - if max == 9999.0: - warnings.warn(f"Max value 9999 for {name}") - - in_minus_1_plus_1 = dict(minimum=-1, maximum=1) - limits = { - "cos_latitude": in_minus_1_plus_1, - "sin_latitude": in_minus_1_plus_1, - "cos_longitude": in_minus_1_plus_1, - "sin_longitude": in_minus_1_plus_1, - } - - if name in limits: - if min < limits[name]["minimum"]: - warnings.warn( - f"For {name}: minimum value in the data is {min}. " - "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" - ) - if max > limits[name]["maximum"]: - warnings.warn( - f"For {name}: maximum value in the data is {max}. " - "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" - ) - - -def check_stats(minimum: float, maximum: float, mean: float, msg: str, **kwargs: Any) -> None: - """Check if the mean value is within the min/max interval. - - Parameters - ---------- - minimum : float - The minimum value. - maximum : float - The maximum value. - mean : float - The mean value. - msg : str - The message to include in the error. - **kwargs : Any - Additional keyword arguments. - """ - tolerance = (abs(minimum) + abs(maximum)) * 0.01 - if (mean - minimum < -tolerance) or (mean - minimum < -tolerance): - raise StatisticsValueError( - f"Mean is not in min/max interval{msg} : we should have {minimum} <= {mean} <= {maximum}" - ) diff --git a/src/anemoi/datasets/build/chunks.py b/src/anemoi/datasets/build/chunks.py deleted file mode 100644 index 08cc1edfd..000000000 --- a/src/anemoi/datasets/build/chunks.py +++ /dev/null @@ -1,138 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import warnings - -LOG = logging.getLogger(__name__) - -ALL = object() - - -class ChunkFilter: - """A filter to determine which chunks to process based on the specified parts. - - Attributes - ---------- - total : int - The total number of chunks. - allowed : object or list - The chunks that are allowed to be processed. - """ - - def __init__(self, *, parts: str | list, total: int): - """Initializes the ChunkFilter with the given parts and total number of chunks. - - Parameters - ---------- - parts : str or list - The parts to process, specified as 'i/n' or a list of such strings. - total : int - The total number of chunks. - - Raises - ------ - ValueError - If the parts format is invalid. - AssertionError - If the chunk number is invalid. - Warning - If the number of chunks is larger than the total number of chunks. - """ - self.total = total - - if isinstance(parts, list): - if len(parts) == 1: - parts = parts[0] - elif len(parts) == 0: - parts = None - else: - raise ValueError(f"Invalid parts format: {parts}. 
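check_stats above is meant to verify that the mean lies between the minimum and the maximum within a small tolerance; the intended comparison, written symmetrically for both bounds, is presumably:

def mean_within_bounds(minimum: float, maximum: float, mean: float) -> bool:
    tolerance = (abs(minimum) + abs(maximum)) * 0.01
    return (mean - minimum >= -tolerance) and (maximum - mean >= -tolerance)

assert mean_within_bounds(0.0, 10.0, 5.0)
assert not mean_within_bounds(0.0, 10.0, 11.0)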
Must be in the form 'i/n'.") - - if not parts: - parts = "all" - - assert isinstance(parts, str), f"Argument parts must be a string, got {parts}." - - if parts.lower() == "all" or parts == "*": - self.allowed = ALL - return - - assert "/" in parts, f"Invalid parts format: {parts}. Must be in the form 'i/n'." - - i, n = parts.split("/") - i, n = int(i), int(n) - - assert i > 0, f"Chunk number {i} must be positive." - assert i <= n, f"Chunk number {i} must be less than total chunks {n}." - if n > total: - warnings.warn( - f"Number of chunks {n} is larger than the total number of chunks: {total}. " - "Some chunks will be empty." - ) - - chunk_size = total / n - parts = [x for x in range(total) if x >= (i - 1) * chunk_size and x < i * chunk_size] - - for i in parts: - if i < 0 or i >= total: - raise AssertionError(f"Invalid chunk number {i}. Must be between 0 and {total - 1}.") - if not parts: - warnings.warn(f"Nothing to do for chunk {i}/{n}.") - - LOG.debug(f"Running parts: {parts}") - - self.allowed = parts - - def __call__(self, i: int) -> bool: - """Checks if the given chunk number is allowed to be processed. - - Parameters - ---------- - i : int - The chunk number to check. - - Returns - ------- - bool - True if the chunk is allowed, False otherwise. - - Raises - ------ - AssertionError - If the chunk number is invalid. - """ - if i < 0 or i >= self.total: - raise AssertionError(f"Invalid chunk number {i}. Must be between 0 and {self.total - 1}.") - - if self.allowed == ALL: - return True - return i in self.allowed - - def __iter__(self) -> iter: - """Iterates over the allowed chunks. - - Yields - ------ - int - The next allowed chunk number. - """ - for i in range(self.total): - if self(i): - yield i - - def __len__(self) -> int: - """Returns the number of allowed chunks. - - Returns - ------- - int - The number of allowed chunks. - """ - return len([_ for _ in self]) diff --git a/src/anemoi/datasets/build/config.py b/src/anemoi/datasets/build/config.py deleted file mode 100644 index 4720ebb6b..000000000 --- a/src/anemoi/datasets/build/config.py +++ /dev/null @@ -1,445 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import logging -import os -from copy import deepcopy -from typing import Any - -import yaml -from anemoi.utils.config import DotDict -from anemoi.utils.config import load_any_dict_format -from earthkit.data.core.order import normalize_order_by - -from anemoi.datasets.dates.groups import Groups - -LOG = logging.getLogger(__name__) - - -def _get_first_key_if_dict(x: str | dict) -> str: - """Returns the first key if the input is a dictionary, otherwise returns the input string. - - Parameters - ---------- - x : str or dict - Input string or dictionary. - - Returns - ------- - str - The first key if input is a dictionary, otherwise the input string. - """ - if isinstance(x, str): - return x - return list(x.keys())[0] - - -def ensure_element_in_list(lst: list, elt: str, index: int) -> list: - """Ensures that a specified element is present at a given index in a list. - - Parameters - ---------- - lst : list - The list to check. - elt : str - The element to ensure is in the list. 
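The 'i/n' arithmetic above splits the chunk indices into n roughly equal slices; for example, 10 chunks split as part 2 of 3:

total, i, n = 10, 2, 3
chunk_size = total / n
allowed = [x for x in range(total) if (i - 1) * chunk_size <= x < i * chunk_size]
assert allowed == [4, 5, 6]   # part 1/3 gets [0, 1, 2, 3], part 3/3 gets [7, 8, 9]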
- index : int - The index at which the element should be present. - - Returns - ------- - list - The modified list with the element at the specified index. - """ - if elt in lst: - assert lst[index] == elt - return lst - - _lst = [_get_first_key_if_dict(d) for d in lst] - if elt in _lst: - assert _lst[index] == elt - return lst - - return lst[:index] + [elt] + lst[index:] - - -def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None: - """Checks if a dictionary contains a specific key-value pair and sets it if not present. - - Parameters - ---------- - dic : dict - The dictionary to check. - key : str - The key to check in the dictionary. - value : Any - The value to set if the key is not present. - - Raises - ------ - ValueError - If the key is present but with a different value. - """ - if key in dic: - if dic[key] == value: - return - raise ValueError(f"Cannot use {key}={dic[key]}. Must use {value}.") - LOG.info(f"Setting {key}={value} in config") - dic[key] = value - - -def resolve_includes(config: dict | list) -> dict | list: - """Resolves '<<' includes in a configuration dictionary or list. - - Parameters - ---------- - config : dict or list - The configuration to resolve includes for. - - Returns - ------- - dict or list - The configuration with includes resolved. - """ - if isinstance(config, list): - return [resolve_includes(c) for c in config] - if isinstance(config, dict): - include = config.pop("<<", {}) - new = deepcopy(include) - new.update(config) - return {k: resolve_includes(v) for k, v in new.items()} - return config - - -class Config(DotDict): - """Configuration class that extends DotDict to handle configuration loading and processing.""" - - def __init__(self, config: str | dict | None = None, **kwargs): - """Initializes the Config object. - - Parameters - ---------- - config : str or dict, optional - Path to the configuration file or a dictionary. Defaults to None. - **kwargs - Additional keyword arguments to update the configuration. - """ - if isinstance(config, str): - config = load_any_dict_format(config) - else: - config = deepcopy(config if config is not None else {}) - config = resolve_includes(config) - config.update(kwargs) - super().__init__(config) - - -class OutputSpecs: - """Class to handle output specifications for datasets.""" - - def __init__(self, config: Config, parent: Any): - """Initializes the OutputSpecs object. - - Parameters - ---------- - config : Config - The configuration object. - parent : Any - The parent object. - """ - self.config = config - if "order_by" in config: - assert isinstance(config.order_by, dict), config.order_by - - self.parent = parent - - @property - def dtype(self) -> str: - """Returns the data type for the output.""" - return self.config.dtype - - @property - def order_by_as_list(self) -> list[dict]: - """Returns the order_by configuration as a list of dictionaries.""" - return [{k: v} for k, v in self.config.order_by.items()] - - def get_chunking(self, coords: dict) -> tuple: - """Returns the chunking configuration based on coordinates. - - Parameters - ---------- - coords : dict - The coordinates dictionary. - - Returns - ------- - tuple - The chunking configuration. 
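resolve_includes above merges a shared block referenced by the '<<' key into the enclosing mapping, local keys overriding included ones; a dict-only version of the idea:

from copy import deepcopy

def resolve(config):
    if isinstance(config, dict):
        include = config.pop("<<", {})
        new = deepcopy(include)
        new.update(config)   # local keys win over included ones
        return {k: resolve(v) for k, v in new.items()}
    return config

assert resolve({"<<": {"class": "ea", "grid": "o96"}, "grid": "n320"}) == {"class": "ea", "grid": "n320"}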
- """ - user = deepcopy(self.config.chunking) - chunks = [] - for k, v in coords.items(): - if k in user: - chunks.append(user.pop(k)) - else: - chunks.append(len(v)) - if user: - raise ValueError( - f"Unused chunking keys from config: {list(user.keys())}, not in known keys : {list(coords.keys())}" - ) - return tuple(chunks) - - @property - def order_by(self) -> dict: - """Returns the order_by configuration.""" - return self.config.order_by - - @property - def remapping(self) -> dict: - """Returns the remapping configuration.""" - return self.config.remapping - - @property - def flatten_grid(self) -> bool: - """Returns whether the grid should be flattened.""" - return self.config.flatten_grid - - @property - def statistics(self) -> str: - """Returns the statistics configuration.""" - return self.config.statistics - - -class LoadersConfig(Config): - """Configuration class for dataset loaders.""" - - def __init__(self, config: dict, *args, **kwargs): - """Initializes the LoadersConfig object. - - Parameters - ---------- - config : dict - The configuration dictionary. - *args - Additional positional arguments. - **kwargs - Additional keyword arguments. - """ - super().__init__(config, *args, **kwargs) - - # TODO: should use a json schema to validate the config - - self.setdefault("dataset_status", "experimental") - self.setdefault("description", "No description provided.") - self.setdefault("licence", "unknown") - self.setdefault("attribution", "unknown") - - self.setdefault("build", Config()) - self.build.setdefault("group_by", "monthly") - self.build.setdefault("use_grib_paramid", False) - self.build.setdefault("variable_naming", "default") - variable_naming = dict( - param="{param}", - param_levelist="{param}_{levelist}", - default="{param}_{levelist}", - ).get(self.build.variable_naming, self.build.variable_naming) - - self.setdefault("output", Config()) - self.output.setdefault("order_by", ["valid_datetime", "param_level", "number"]) - self.output.setdefault("remapping", Config(param_level=variable_naming)) - self.output.setdefault("statistics", "param_level") - self.output.setdefault("chunking", Config(dates=1, ensembles=1)) - self.output.setdefault("dtype", "float32") - - if "statistics_start" in self.output: - raise ValueError("statistics_start is not supported anymore. Use 'statistics:start:' instead") - if "statistics_end" in self.output: - raise ValueError("statistics_end is not supported anymore. Use 'statistics:end:' instead") - - self.setdefault("statistics", Config()) - if "allow_nans" not in self.statistics: - self.statistics.allow_nans = [] - - check_dict_value_and_set(self.output, "flatten_grid", True) - check_dict_value_and_set(self.output, "ensemble_dimension", 2) - - assert isinstance(self.output.order_by, (list, tuple)), self.output.order_by - self.output.order_by = ensure_element_in_list(self.output.order_by, "number", self.output.ensemble_dimension) - - order_by = self.output.order_by - assert len(order_by) == 3, order_by - assert _get_first_key_if_dict(order_by[0]) == "valid_datetime", order_by - assert _get_first_key_if_dict(order_by[2]) == "number", order_by - - self.output.order_by = normalize_order_by(self.output.order_by) - - self.setdefault("dates", Config()) - - self.dates["group_by"] = self.build.group_by - - ########### - - self.reading_chunks = self.get("reading_chunks") - - def get_serialisable_dict(self) -> dict: - """Returns a serializable dictionary representation of the configuration. - - Returns - ------- - dict - The serializable dictionary. 
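get_chunking above pairs each output coordinate with either the user-supplied chunk size or the full coordinate length; a standalone version of that logic (the coordinate names here are only illustrative):

def chunking(coords: dict, user: dict) -> tuple:
    user = dict(user)
    chunks = [user.pop(k, len(v)) for k, v in coords.items()]
    if user:
        raise ValueError(f"Unused chunking keys: {list(user)}")
    return tuple(chunks)

coords = {"dates": range(744), "variables": range(80), "ensembles": range(1), "cells": range(40320)}
assert chunking(coords, {"dates": 1, "ensembles": 1}) == (1, 80, 1, 40320)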
- """ - return _prepare_serialisation(self) - - -def _prepare_serialisation(o: Any) -> Any: - """Prepares an object for serialization. - - Parameters - ---------- - o : Any - The object to prepare. - - Returns - ------- - Any - The prepared object. - """ - if isinstance(o, dict): - dic = {} - for k, v in o.items(): - v = _prepare_serialisation(v) - if k == "order_by" and isinstance(v, dict): - # zarr attributes are saved with sort_keys=True - # and ordered dict are reordered. - # This is a problem for "order_by" - # We ensure here that the order_by key contains - # a list of dict - v = [{kk: vv} for kk, vv in v.items()] - dic[k] = v - return dic - - if isinstance(o, (list, tuple)): - return [_prepare_serialisation(v) for v in o] - - if o in (None, True, False): - return o - - if isinstance(o, (str, int, float)): - return o - - if isinstance(o, (datetime.date, datetime.datetime)): - return o.isoformat() - - return str(o) - - -def set_to_test_mode(cfg: dict) -> None: - """Modifies the configuration to run in test mode. - - Parameters - ---------- - cfg : dict - The configuration dictionary. - """ - NUMBER_OF_DATES = 4 - - LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.") - groups = Groups(**LoadersConfig(cfg).dates) - - dates = groups.provider.values - cfg["dates"] = dict( - start=dates[0], - end=dates[NUMBER_OF_DATES - 1], - frequency=groups.provider.frequency, - group_by=NUMBER_OF_DATES, - ) - - def set_element_to_test(obj): - if isinstance(obj, (list, tuple)): - for v in obj: - set_element_to_test(v) - return - if isinstance(obj, (dict, DotDict)): - if "grid" in obj: - previous = obj["grid"] - obj["grid"] = "20./20." - LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}") - if "number" in obj: - if isinstance(obj["number"], (list, tuple)): - previous = obj["number"] - obj["number"] = previous[0:3] - LOG.warning(f"Running in test mode. Setting number to {obj['number']} instead of {previous}") - for k, v in obj.items(): - set_element_to_test(v) - if "constants" in obj: - constants = obj["constants"] - if "param" in constants and isinstance(constants["param"], list): - constants["param"] = ["cos_latitude"] - - set_element_to_test(cfg) - - -def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: - """Loads and validates the configuration for dataset loaders. - - Parameters - ---------- - config : dict - The configuration dictionary. - is_test : bool, optional - Whether to run in test mode. Defaults to False. - - Returns - ------- - LoadersConfig - The validated configuration object. - """ - config = Config(config) - if is_test: - set_to_test_mode(config) - obj = LoadersConfig(config) - - # yaml round trip to check that serialisation works as expected - copy = obj.get_serialisable_dict() - copy = yaml.load(yaml.dump(copy), Loader=yaml.SafeLoader) - copy = Config(copy) - copy = LoadersConfig(config) - - a = yaml.dump(obj) - b = yaml.dump(copy) - if a != b: - print(a) - print(b) - raise ValueError("Serialisation failed") - - if "env" in copy: - for k, v in copy["env"].items(): - LOG.info(f"Setting env variable {k}={v}") - os.environ[k] = str(v) - - return copy - - -def build_output(*args, **kwargs) -> OutputSpecs: - """Builds the output specifications. - - Parameters - ---------- - *args - Additional positional arguments. - **kwargs - Additional keyword arguments. - - Returns - ------- - OutputSpecs - The output specifications object. 
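Because zarr attributes are saved with sorted keys, _prepare_serialisation above turns an ordered 'order_by' mapping into a list of single-key dicts so the ordering survives; for instance (the values are illustrative):

order_by = {"valid_datetime": "ascending", "param_level": "ascending", "number": "ascending"}
as_list = [{k: v} for k, v in order_by.items()]
# [{"valid_datetime": "ascending"}, {"param_level": "ascending"}, {"number": "ascending"}]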
- """ - return OutputSpecs(*args, **kwargs) diff --git a/src/anemoi/datasets/build/filter.py b/src/anemoi/datasets/build/filter.py deleted file mode 100644 index 4544db8f2..000000000 --- a/src/anemoi/datasets/build/filter.py +++ /dev/null @@ -1,47 +0,0 @@ -# (C) Copyright 2025- Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -from typing import Any - -import earthkit.data as ekd - - -class TransformFilter: - """Calls filters from anemoi.transform.filters - - Parameters - ---------- - context : Any - The context in which the filter is created. - name : str - The name of the filter. - config : Dict[str, Any] - The configuration for the filter. - """ - - def __init__(self, context: Any, name: str, config: dict[str, Any]) -> None: - from anemoi.transform.filters import create_filter - - self.name = name - self.transform_filter = create_filter(context, config) - - def execute(self, input: ekd.FieldList) -> ekd.FieldList: - """Execute the transformation filter. - - Parameters - ---------- - input : ekd.FieldList - The input data to be transformed. - - Returns - ------- - ekd.FieldList - The transformed data. - """ - return self.transform_filter.forward(input) diff --git a/src/anemoi/datasets/build/patch.py b/src/anemoi/datasets/build/patch.py deleted file mode 100755 index 5cb08ec82..000000000 --- a/src/anemoi/datasets/build/patch.py +++ /dev/null @@ -1,188 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -import logging -import os - -import zarr - -LOG = logging.getLogger(__name__) - - -def fix_order_by(order_by: dict | list) -> list[dict]: - """Fix the order_by attribute to ensure it is a list of dictionaries. - - Parameters - ---------- - order_by : dict or list - The order_by attribute to fix. - - Returns - ------- - list[dict] - The fixed order_by attribute. - """ - if isinstance(order_by, list): - return order_by - - assert isinstance(order_by, dict), order_by - assert len(order_by) <= 3, order_by - lst = [] - lst.append({"valid_datetime": order_by["valid_datetime"]}) - lst.append({"param_level": order_by["param_level"]}) - lst.append({"number": order_by["number"]}) - return lst - - -def fix_history(history: list[dict]) -> list[dict]: - """Fix the history attribute by removing specific actions. - - Parameters - ---------- - history : list[dict] - The history attribute to fix. - - Returns - ------- - list[dict] - The fixed history attribute. - """ - new = history - new = [d for d in new if d.get("action") != "loading_data_start"] - new = [d for d in new if d.get("action") != "loading_data_end"] - return new - - -def fix_provenance(provenance: dict) -> dict: - """Fix the provenance attribute by adding missing fields and removing unnecessary ones. - - Parameters - ---------- - provenance : dict - The provenance attribute to fix. - - Returns - ------- - dict - The fixed provenance attribute. 
- """ - if "python" not in provenance: - provenance["python"] = provenance["platform"]["python_version"] - - for q in ( - "args", - "config_paths", - "executable", - "gpus", - "platform", - "python_path", - "assets", - ): - if q in provenance: - del provenance[q] - - for k, v in list(provenance["module_versions"].items()): - if v.startswith("<"): - del provenance["module_versions"][k] - if v.startswith("/"): - provenance["module_versions"][k] = os.path.join("...", os.path.basename(v)) - - for k, v in list(provenance["git_versions"].items()): - LOG.debug(k, v) - modified_files = v["git"].get("modified_files", []) - untracked_files = v["git"].get("untracked_files", []) - if not isinstance(modified_files, int): - modified_files = len(modified_files) - if not isinstance(untracked_files, int): - untracked_files = len(untracked_files) - provenance["git_versions"][k] = dict( - git={ - "sha1": v["git"]["sha1"], - "modified_files": modified_files, - "untracked_files": untracked_files, - } - ) - - LOG.debug(json.dumps(provenance, indent=2)) - # assert False - return provenance - - -def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None: - """Apply a patch to the dataset at the given path. - - Parameters - ---------- - path : str - The path to the dataset. - verbose : bool, optional - Whether to log detailed information. Defaults to True. - dry_run : bool, optional - If True, do not actually apply the patch. Defaults to False. - """ - LOG.debug("====================") - LOG.debug(f"Patching {path}") - LOG.debug("====================") - - try: - attrs = zarr.open(path, mode="r").attrs.asdict() - except zarr.errors.PathNotFoundError as e: - LOG.error(f"Failed to open {path}") - LOG.error(e) - exit(0) - - FIXES = { - "history": fix_history, - "provenance_load": fix_provenance, - "provenance_statistics": fix_provenance, - "order_by": fix_order_by, - } - REMOVE = ["_create_yaml_config"] - - before = json.dumps(attrs, sort_keys=True) - - fixed_attrs = {} - for k, v in attrs.items(): - v = attrs[k] - if k in REMOVE: - LOG.info(f"✅ Remove {k}") - continue - - if k not in FIXES: - assert not k.startswith("provenance"), f"[{k}]" - LOG.debug(f"✅ Don't fix {k}") - fixed_attrs[k] = v - continue - - new_v = FIXES[k](v) - if json.dumps(new_v, sort_keys=True) != json.dumps(v, sort_keys=True): - LOG.info(f"✅ Fix {k}") - if verbose: - LOG.info(f" Before : {k}= {v}") - LOG.info(f" After : {k}= {new_v}") - else: - LOG.debug(f"✅ Unchanged {k}") - fixed_attrs[k] = new_v - - if dry_run: - return - z = zarr.open(path, mode="r+") - - for k in list(z.attrs.keys()): - if k not in fixed_attrs: - del z.attrs[k] - for k, v in fixed_attrs.items(): - z.attrs[k] = v - - after = json.dumps(z.attrs.asdict(), sort_keys=True) - if before != after: - LOG.info("Dataset changed by patch") - - assert json.dumps(z.attrs.asdict(), sort_keys=True) == json.dumps(fixed_attrs, sort_keys=True) diff --git a/src/anemoi/datasets/build/persistent.py b/src/anemoi/datasets/build/persistent.py deleted file mode 100644 index e52938507..000000000 --- a/src/anemoi/datasets/build/persistent.py +++ /dev/null @@ -1,269 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- - -import glob -import hashlib -import json -import logging -import os -import pickle -import shutil -import socket -from collections.abc import Iterator -from typing import Any - -import numpy as np -from anemoi.utils.provenance import gather_provenance_info - -LOG = logging.getLogger(__name__) - - -class PersistentDict: - """A dictionary-like object that persists its contents to disk using pickle files. - - Attributes - ---------- - version : int - The version of the PersistentDict. - dirname : str - The directory where the data is stored. - name : str - The name of the directory. - ext : str - The extension of the directory. - """ - - version = 3 - - # Used in parrallel, during data loading, - # to write data in pickle files. - def __init__(self, directory: str, create: bool = True): - """Initialize the PersistentDict. - - Parameters - ---------- - directory : str - The directory where the data will be stored. - create : bool, optional - Whether to create the directory if it doesn't exist. - """ - self.dirname = directory - self.name, self.ext = os.path.splitext(os.path.basename(self.dirname)) - if create: - self.create() - - def create(self) -> None: - """Create the directory if it doesn't exist.""" - os.makedirs(self.dirname, exist_ok=True) - - def delete(self) -> None: - """Delete the directory and its contents.""" - try: - shutil.rmtree(self.dirname) - except FileNotFoundError: - pass - - def __str__(self) -> str: - """Return a string representation of the PersistentDict.""" - return f"{self.__class__.__name__}({self.dirname})" - - def items(self) -> Iterator[Any]: - """Yield items stored in the directory. - - Yields - ------ - Iterator[Any] - An iterator over the items. - """ - # use glob to read all pickles - files = glob.glob(self.dirname + "/*.pickle") - LOG.debug(f"Reading {self.name} data, found {len(files)} files in {self.dirname}") - assert len(files) > 0, f"No files found in {self.dirname}" - for f in files: - with open(f, "rb") as f: - yield pickle.load(f) - - def add_provenance(self, **kwargs: Any) -> None: - """Add provenance information to the directory. - - Parameters - ---------- - **kwargs : Any - Additional provenance information. - """ - path = os.path.join(self.dirname, "provenance.json") - if os.path.exists(path): - return - out = dict(provenance=gather_provenance_info(), **kwargs) - with open(path, "w") as f: - json.dump(out, f) - - def add(self, elt: Any, *, key: Any) -> None: - """Add an element to the PersistentDict. - - Parameters - ---------- - elt : Any - The element to add. - key : Any - The key associated with the element. - """ - self[key] = elt - - def __setitem__(self, key: Any, elt: Any) -> None: - """Set an item in the PersistentDict. - - Parameters - ---------- - key : Any - The key associated with the element. - elt : Any - The element to set. - """ - h = hashlib.sha256(str(key).encode("utf-8")).hexdigest() - path = os.path.join(self.dirname, f"{h}.pickle") - - if os.path.exists(path): - LOG.warning(f"{path} already exists") - - tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}" - with open(tmp_path, "wb") as f: - pickle.dump((key, elt), f) - shutil.move(tmp_path, path) - - LOG.debug(f"Written {self.name} data for len {key} in {path}") - - def flush(self) -> None: - """Flush the PersistentDict (no-op).""" - pass - - -class BufferedPersistentDict(PersistentDict): - """A buffered version of PersistentDict that stores elements in memory before persisting them to disk. 
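Each entry written by PersistentDict above is keyed by a hash of the stringified key and written through a temporary file so the final rename is atomic; the essential pattern, stripped of the class:

import hashlib
import os
import pickle
import shutil

def write_entry(dirname: str, key, value) -> str:
    h = hashlib.sha256(str(key).encode("utf-8")).hexdigest()
    path = os.path.join(dirname, f"{h}.pickle")
    tmp = path + f".tmp-{os.getpid()}"
    with open(tmp, "wb") as f:
        pickle.dump((key, value), f)
    shutil.move(tmp, path)   # rename, atomic within one filesystem
    return path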
- - Attributes - ---------- - buffer_size : int - The size of the buffer. - elements : list - The list of elements in the buffer. - keys : list - The list of keys in the buffer. - storage : PersistentDict - The underlying PersistentDict used for storage. - """ - - def __init__(self, buffer_size: int = 1000, **kwargs: Any): - """Initialize the BufferedPersistentDict. - - Parameters - ---------- - buffer_size : int, optional - The size of the buffer. - **kwargs : Any - Additional arguments for PersistentDict. - """ - self.buffer_size = buffer_size - self.elements = [] - self.keys = [] - self.storage = PersistentDict(**kwargs) - - def add(self, elt: Any, *, key: Any) -> None: - """Add an element to the BufferedPersistentDict. - - Parameters - ---------- - elt : Any - The element to add. - key : Any - The key associated with the element. - """ - self.elements.append(elt) - self.keys.append(key) - if len(self.keys) > self.buffer_size: - self.flush() - - def flush(self) -> None: - """Flush the buffer and store the elements in PersistentDict.""" - k = sorted(self.keys) - self.storage.add(self.elements, key=k) - self.elements = [] - self.keys = [] - - def items(self) -> Iterator[tuple[Any, Any]]: - """Yield items stored in the BufferedPersistentDict. - - Yields - ------ - Iterator[Tuple[Any, Any]] - An iterator over the items. - """ - for keys, elements in self.storage.items(): - yield from zip(keys, elements) - - def delete(self) -> None: - """Delete the storage directory and its contents.""" - self.storage.delete() - - def create(self) -> None: - """Create the storage directory if it doesn't exist.""" - self.storage.create() - - -def build_storage(directory: str, create: bool = True) -> BufferedPersistentDict: - """Build a BufferedPersistentDict storage. - - Parameters - ---------- - directory : str - The directory where the data will be stored. - create : bool, optional - Whether to create the directory if it doesn't exist. - - Returns - ------- - BufferedPersistentDict - The created BufferedPersistentDict. - """ - return BufferedPersistentDict(directory=directory, create=create) - - -if __name__ == "__main__": - N = 3 - P = 2 - directory = "h" - p = PersistentDict(directory=directory) - print(p) - assert os.path.exists(directory) - import numpy as np - - arrs = [np.random.randint(1, 101, size=(P,)) for _ in range(N)] - dates = [np.array([np.datetime64(f"2021-01-0{_+1}") + np.timedelta64(i, "h") for i in range(P)]) for _ in range(N)] - - print() - print("Writing the data") - for i in range(N): - _arr = arrs[i] - _dates = dates[i] - print(f"Writing : {i=}, {_arr=} {_dates=}") - p[_dates] = (i, _arr) - - print() - print("Reading the data back") - - p = PersistentDict(directory="h") - for _dates, (i, _arr) in p.items(): - print(f"{i=}, {_arr=}, {_dates=}") - - assert np.allclose(_arr, arrs[i]) - - assert len(_dates) == len(dates[i]) - for a, b in zip(_dates, dates[i]): - assert a == b diff --git a/src/anemoi/datasets/build/size.py b/src/anemoi/datasets/build/size.py deleted file mode 100644 index 4cffd66d7..000000000 --- a/src/anemoi/datasets/build/size.py +++ /dev/null @@ -1,47 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- - -import logging -import os - -import tqdm -from anemoi.utils.humanize import bytes_to_human - -LOG = logging.getLogger(__name__) - - -def compute_directory_sizes(path: str) -> dict[str, int] | None: - """Computes the total size and number of files in a directory. - - Parameters - ---------- - path : str - The path to the directory. - - Returns - ------- - dict of str to int or None - A dictionary with the total size and number of files, or None if the path is not a directory. - """ - if not os.path.isdir(path): - return None - - size, n = 0, 0 - bar = tqdm.tqdm(iterable=os.walk(path), desc=f"Computing size of {path}") - for dirpath, _, filenames in bar: - for filename in filenames: - file_path = os.path.join(dirpath, filename) - size += os.path.getsize(file_path) - n += 1 - - LOG.info(f"Total size: {bytes_to_human(size)}") - LOG.info(f"Total number of files: {n}") - - return dict(total_size=size, total_number_of_files=n) diff --git a/src/anemoi/datasets/build/source.py b/src/anemoi/datasets/build/source.py deleted file mode 100644 index df4911690..000000000 --- a/src/anemoi/datasets/build/source.py +++ /dev/null @@ -1,51 +0,0 @@ -# (C) Copyright 2025- Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -from abc import ABC -from abc import abstractmethod - -import earthkit.data as ekd - -from anemoi.datasets.build.typing import DateList - - -class Source(ABC): - """Represents a data source with a given context.""" - - emoji = "📦" # For tracing - - def __init__(self, context: any, *args: tuple, **kwargs: dict): - """Initialise the source. - Parameters - ---------- - context : Any - The context for the data source. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - """ - self.context = context - - @abstractmethod - def execute(self, dates: DateList) -> ekd.FieldList: - """Execute the filter. - - Parameters - ---------- - dates : DateList - The input dates. - - Returns - ------- - ekd.FieldList - The output data. - """ - - pass diff --git a/src/anemoi/datasets/build/statistics/__init__.py b/src/anemoi/datasets/build/statistics/__init__.py deleted file mode 100644 index f7ece19bb..000000000 --- a/src/anemoi/datasets/build/statistics/__init__.py +++ /dev/null @@ -1,561 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import datetime -import glob -import hashlib -import json -import logging -import os -import pickle -import shutil -import socket -from typing import Any - -import numpy as np -import tqdm -from anemoi.utils.provenance import gather_provenance_info -from numpy.typing import NDArray - -from anemoi.datasets.build.check import check_data_values -from anemoi.datasets.build.statistics.summary import Summary - -LOG = logging.getLogger(__name__) - - -def default_statistics_dates(dates: list[datetime.datetime]) -> tuple[datetime.datetime, datetime.datetime]: - """Calculate default statistics dates based on the given list of dates. - - Parameters - ---------- - dates : list of datetime.datetime - List of datetime objects representing dates. - - Returns - ------- - tuple of datetime.datetime - A tuple containing the default start and end dates. - """ - - def to_datetime(d): - if isinstance(d, np.datetime64): - return d.tolist() - assert isinstance(d, datetime.datetime), d - return d - - first = dates[0] - last = dates[-1] - - first = to_datetime(first) - last = to_datetime(last) - - n_years = round((last - first).total_seconds() / (365.25 * 24 * 60 * 60)) - - if n_years < 10: - # leave out 20% of the data - k = int(len(dates) * 0.8) - end = dates[k - 1] - LOG.info(f"Number of years {n_years} < 10, leaving out 20%. {end=}") - return dates[0], end - - delta = 1 - if n_years >= 20: - delta = 3 - LOG.info(f"Number of years {n_years}, leaving out {delta} years.") - end_year = last.year - delta - - end = max(d for d in dates if to_datetime(d).year == end_year) - return dates[0], end - - -def to_datetime(date: str | datetime.datetime) -> np.datetime64: - """Convert a date to numpy datetime64 format. - - Parameters - ---------- - date : str or datetime.datetime - The date to convert. - - Returns - ------- - numpy.datetime64 - The converted date. - """ - if isinstance(date, str): - return np.datetime64(date) - if isinstance(date, datetime.datetime): - return np.datetime64(date, "s") - return date - - -def to_datetimes(dates: list[str | datetime.datetime]) -> list[np.datetime64]: - """Convert a list of dates to numpy datetime64 format. - - Parameters - ---------- - dates : list of str or datetime.datetime - List of dates to convert. - - Returns - ------- - list of numpy.datetime64 - List of converted dates. - """ - return [to_datetime(d) for d in dates] - - -def fix_variance(x: float, name: str, count: NDArray[Any], sums: NDArray[Any], squares: NDArray[Any]) -> float: - """Fix negative variance values due to numerical errors. - - Parameters - ---------- - x : float - The variance value. - name : str - The variable name. - count : numpy.ndarray - The count array. - sums : numpy.ndarray - The sums array. - squares : numpy.ndarray - The squares array. - - Returns - ------- - float - The fixed variance value. 
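default_statistics_dates above leaves out the most recent data: roughly the last 20% of the dates when the dataset spans fewer than ten years, otherwise the last one (or, from twenty years on, three) calendar years. The short-dataset branch, on toy dates:

import datetime

dates = [datetime.datetime(2018, 1, 1) + datetime.timedelta(hours=6 * i) for i in range(8)]
n_years = round((dates[-1] - dates[0]).total_seconds() / (365.25 * 24 * 3600))   # 0 here, so < 10
k = int(len(dates) * 0.8)
start, end = dates[0], dates[k - 1]   # keep the first 80% of the dates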
- """ - assert count.shape == sums.shape == squares.shape - assert isinstance(x, float) - - mean = sums / count - assert mean.shape == count.shape - - if x >= 0: - return x - - LOG.warning(f"Negative variance for {name=}, variance={x}") - magnitude = np.sqrt((squares / count + mean * mean) / 2) - LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}") - LOG.warning(f"Variable span order of magnitude is {magnitude}.") - LOG.warning(f"Count is {count}.") - - variances = squares / count - mean * mean - assert variances.shape == squares.shape == mean.shape - if np.all(variances >= 0): - LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.") - return 0 - - # if abs(x) < magnitude * 1e-6 and abs(x) < range * 1e-6: - # LOG.warning("Variance is negative but very small.") - # variances = squares / count - mean * mean - # return 0 - - LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).") - return 0 - - -def check_variance( - x: NDArray[Any], - variables_names: list[str], - minimum: NDArray[Any], - maximum: NDArray[Any], - mean: NDArray[Any], - count: NDArray[Any], - sums: NDArray[Any], - squares: NDArray[Any], -) -> None: - """Check for negative variance values and raise an error if found. - - Parameters - ---------- - x : numpy.ndarray - The variance array. - variables_names : list of str - List of variable names. - minimum : numpy.ndarray - The minimum values array. - maximum : numpy.ndarray - The maximum values array. - mean : numpy.ndarray - The mean values array. - count : numpy.ndarray - The count array. - sums : numpy.ndarray - The sums array. - squares : numpy.ndarray - The squares array. - - Raises - ------ - ValueError - If negative variance is found. - """ - if (x >= 0).all(): - return - print(x) - print(variables_names) - for i, (name, y) in enumerate(zip(variables_names, x)): - if y >= 0: - continue - print("---") - print(f"❗ Negative variance for {name=}, variance={y}") - print(f" min={minimum[i]} max={maximum[i]} mean={mean[i]} count={count[i]} sums={sums[i]} squares={squares[i]}") - print(f" -> sums: min={np.min(sums[i])}, max={np.max(sums[i])}, argmin={np.argmin(sums[i])}") - print(f" -> squares: min={np.min(squares[i])}, max={np.max(squares[i])}, argmin={np.argmin(squares[i])}") - print(f" -> count: min={np.min(count[i])}, max={np.max(count[i])}, argmin={np.argmin(count[i])}") - print( - f" squares / count - mean * mean = {squares[i] / count[i]} - {mean[i] * mean[i]} = {squares[i] / count[i] - mean[i] * mean[i]}" - ) - - raise ValueError("Negative variance") - - -def compute_statistics( - array: NDArray[Any], check_variables_names: list[str] | None = None, allow_nans: bool = False -) -> dict[str, np.ndarray]: - """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary. - - Parameters - ---------- - array : numpy.ndarray - The array to compute statistics for. - check_variables_names : list of str, optional - List of variable names to check. Defaults to None. - allow_nans : bool, optional - Whether to allow NaN values. Defaults to False. - - Returns - ------- - dict of str to numpy.ndarray - A dictionary containing the computed statistics. 
- """ - LOG.info(f"Computing statistics for {array.shape} array") - nvars = array.shape[1] - - LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}") - if check_variables_names: - assert nvars == len(check_variables_names), (nvars, check_variables_names) - stats_shape = (array.shape[0], nvars) - - count = np.zeros(stats_shape, dtype=np.int64) - sums = np.zeros(stats_shape, dtype=np.float64) - squares = np.zeros(stats_shape, dtype=np.float64) - minimum = np.zeros(stats_shape, dtype=np.float64) - maximum = np.zeros(stats_shape, dtype=np.float64) - has_nans = np.zeros(stats_shape, dtype=np.bool_) - - for i, chunk in tqdm.tqdm(enumerate(array), delay=1, total=array.shape[0], desc="Computing statistics"): - values = chunk.reshape((nvars, -1)) - - for j, name in enumerate(check_variables_names): - check_data_values(values[j, :], name=name, allow_nans=allow_nans) - if np.isnan(values[j, :]).all(): - # LOG.warning(f"All NaN values for {name} ({j}) for date {i}") - LOG.warning(f"All NaN values for {name} ({j}) for date {i}") - - # Ignore NaN values - minimum[i] = np.nanmin(values, axis=1) - maximum[i] = np.nanmax(values, axis=1) - sums[i] = np.nansum(values, axis=1) - squares[i] = np.nansum(values * values, axis=1) - count[i] = np.sum(~np.isnan(values), axis=1) - has_nans[i] = np.isnan(values).any() - - LOG.info(f"Statistics computed for {nvars} variables.") - - return { - "minimum": minimum, - "maximum": maximum, - "sums": sums, - "squares": squares, - "count": count, - "has_nans": has_nans, - } - - -class TmpStatistics: - """Temporary statistics storage class.""" - - version = 3 - # Used in parrallel, during data loading, - # to write statistics in pickled npz files. - # can provide statistics for a subset of dates. - - def __init__(self, dirname: str, overwrite: bool = False) -> None: - """Initialize TmpStatistics. - - Parameters - ---------- - dirname : str - Directory name for storing statistics. - overwrite : bool, optional - Whether to overwrite existing files. Defaults to False. - """ - self.dirname = dirname - self.overwrite = overwrite - - def add_provenance(self, **kwargs: dict) -> None: - """Add provenance information. - - Parameters - ---------- - **kwargs : dict - Additional provenance information. - """ - self.create(exist_ok=True) - path = os.path.join(self.dirname, "provenance.json") - if os.path.exists(path): - return - out = dict(provenance=gather_provenance_info(), **kwargs) - with open(path, "w") as f: - json.dump(out, f) - - def create(self, exist_ok: bool) -> None: - """Create the directory for storing statistics. - - Parameters - ---------- - exist_ok : bool - Whether to ignore if the directory already exists. - """ - os.makedirs(self.dirname, exist_ok=exist_ok) - - def delete(self) -> None: - """Delete the directory for storing statistics.""" - try: - shutil.rmtree(self.dirname) - except FileNotFoundError: - pass - - def write(self, key: str, data: any, dates: list[datetime.datetime]) -> None: - """Write statistics data to a file. - - Parameters - ---------- - key : str - The key for the data. - data : any - The data to write. - dates : list of datetime.datetime - List of dates associated with the data. 
- """ - self.create(exist_ok=True) - h = hashlib.sha256(str(dates).encode("utf-8")).hexdigest() - path = os.path.join(self.dirname, f"{h}.npz") - - if not self.overwrite: - assert not os.path.exists(path), f"{path} already exists" - - tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}" - with open(tmp_path, "wb") as f: - pickle.dump((key, dates, data), f) - shutil.move(tmp_path, path) - - LOG.debug(f"Written statistics data for {len(dates)} dates in {path} ({dates})") - - def _gather_data(self) -> tuple[str, list[datetime.datetime], dict]: - """Gather data from stored files. - - Yields - ------ - tuple of str, list of datetime.datetime, dict - A tuple containing key, dates, and data. - """ - # use glob to read all pickles - files = glob.glob(self.dirname + "/*.npz") - LOG.debug(f"Reading stats data, found {len(files)} files in {self.dirname}") - assert len(files) > 0, f"No files found in {self.dirname}" - for f in files: - with open(f, "rb") as f: - yield pickle.load(f) - - def get_aggregated(self, *args: Any, **kwargs: Any) -> Summary: - """Get aggregated statistics. - - Parameters - ---------- - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - Summary - The aggregated statistics summary. - """ - aggregator = StatAggregator(self, *args, **kwargs) - return aggregator.aggregate() - - def __str__(self) -> str: - """String representation of TmpStatistics. - - Returns - ------- - str - The string representation. - """ - return f"TmpStatistics({self.dirname})" - - -class StatAggregator: - """Statistics aggregator class.""" - - NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"] - - def __init__( - self, owner: TmpStatistics, dates: list[datetime.datetime], variables_names: list[str], allow_nans: bool - ) -> None: - """Initialize StatAggregator. - - Parameters - ---------- - owner : TmpStatistics - The owner TmpStatistics instance. - dates : list of datetime.datetime - List of dates to aggregate. - variables_names : list of str - List of variable names. - allow_nans : bool - Whether to allow NaN values. - """ - dates = sorted(dates) - dates = to_datetimes(dates) - assert dates, "No dates selected" - self.owner = owner - self.dates = dates - self._number_of_dates = len(dates) - self._set_of_dates = set(dates) - self.variables_names = variables_names - self.allow_nans = allow_nans - - self.shape = (self._number_of_dates, len(self.variables_names)) - LOG.debug(f"Aggregating statistics on shape={self.shape}. 
Variables : {self.variables_names}") - - self.minimum = np.full(self.shape, np.nan, dtype=np.float64) - self.maximum = np.full(self.shape, np.nan, dtype=np.float64) - self.sums = np.full(self.shape, np.nan, dtype=np.float64) - self.squares = np.full(self.shape, np.nan, dtype=np.float64) - self.count = np.full(self.shape, -1, dtype=np.int64) - self.has_nans = np.full(self.shape, False, dtype=np.bool_) - - self._read() - - def _read(self) -> None: - """Read and aggregate statistics data from files.""" - - def check_type(a, b): - if not isinstance(a, set): - a = set(list(a)) - if not isinstance(b, set): - b = set(list(b)) - a = next(iter(a)) if a else None - b = next(iter(b)) if b else None - assert type(a) is type(b), (type(a), type(b)) - - found = set() - offset = 0 - - for _, _dates, stats in self.owner._gather_data(): - assert isinstance(stats, dict), stats - assert stats["minimum"].shape[0] == len(_dates), (stats["minimum"].shape, len(_dates)) - assert stats["minimum"].shape[1] == len(self.variables_names), ( - stats["minimum"].shape, - len(self.variables_names), - ) - for n in self.NAMES: - assert n in stats, (n, list(stats.keys())) - _dates = to_datetimes(_dates) - check_type(_dates, self._set_of_dates) - if found: - check_type(found, self._set_of_dates) - assert found.isdisjoint(_dates), "Duplicate dates found in precomputed statistics" - - # filter dates - dates = set(_dates) & self._set_of_dates - - if not dates: - # dates have been completely filtered for this chunk - continue - - # filter data - bitmap = np.array([d in self._set_of_dates for d in _dates]) - for k in self.NAMES: - stats[k] = stats[k][bitmap] - - assert stats["minimum"].shape[0] == len(dates), (stats["minimum"].shape, len(dates)) - - # store data in self - found |= set(dates) - for name in self.NAMES: - array = getattr(self, name) - assert stats[name].shape[0] == len(dates), (stats[name].shape, len(dates)) - array[offset : offset + len(dates)] = stats[name] - offset += len(dates) - - for d in self.dates: - assert d in found, f"Statistics for date {d} not precomputed." - assert self._number_of_dates == len(found), "Not all dates found in precomputed statistics" - assert self._number_of_dates == offset, "Not all dates found in precomputed statistics." - LOG.debug(f"Statistics for {len(found)} dates found.") - - def aggregate(self) -> Summary: - """Aggregate the statistics data. - - Returns - ------- - Summary - The aggregated statistics summary. 
- """ - minimum = np.nanmin(self.minimum, axis=0) - maximum = np.nanmax(self.maximum, axis=0) - - sums = np.nansum(self.sums, axis=0) - squares = np.nansum(self.squares, axis=0) - count = np.nansum(self.count, axis=0) - has_nans = np.any(self.has_nans, axis=0) - assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape - - mean = sums / count - assert mean.shape == minimum.shape - - x = squares / count - mean * mean - assert x.shape == minimum.shape - - for i, name in enumerate(self.variables_names): - # remove negative variance due to numerical errors - x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1]) - - for i, name in enumerate(self.variables_names): - check_variance( - x[i : i + 1], - [name], - minimum[i : i + 1], - maximum[i : i + 1], - mean[i : i + 1], - count[i : i + 1], - sums[i : i + 1], - squares[i : i + 1], - ) - check_data_values(np.array([mean[i]]), name=name, allow_nans=False) - - stdev = np.sqrt(x) - - return Summary( - minimum=minimum, - maximum=maximum, - mean=mean, - count=count, - sums=sums, - squares=squares, - stdev=stdev, - variables_names=self.variables_names, - has_nans=has_nans, - ) diff --git a/src/anemoi/datasets/build/statistics/summary.py b/src/anemoi/datasets/build/statistics/summary.py deleted file mode 100644 index 59f3998b4..000000000 --- a/src/anemoi/datasets/build/statistics/summary.py +++ /dev/null @@ -1,152 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import json -from collections import defaultdict -from typing import Any - -import numpy as np - -from anemoi.datasets.build.check import StatisticsValueError -from anemoi.datasets.build.check import check_data_values -from anemoi.datasets.build.check import check_stats - - -class Summary(dict): - """This class is used to store the summary statistics of a dataset. It can be saved and loaded from a json file. And does some basic checks on the data.""" - - STATS_NAMES = [ - "minimum", - "maximum", - "mean", - "stdev", - "has_nans", - ] # order matter for __str__. - - def __init__(self, **kwargs: Any) -> None: - """Initialize the Summary object with given keyword arguments. - - Parameters - ---------- - **kwargs : Any - Arbitrary keyword arguments representing summary statistics. - """ - super().__init__(**kwargs) - self.check() - - @property - def size(self) -> int: - """Get the size of the summary, which is the number of variables.""" - return len(self["variables_names"]) - - def check(self) -> None: - """Perform checks on the summary statistics to ensure they are valid. - - Raises - ------ - AssertionError - If any of the checks fail. - StatisticsValueError - If any of the statistical checks fail. 
- """ - for k, v in self.items(): - if k == "variables_names": - assert len(v) == self.size - continue - assert v.shape == (self.size,) - if k == "count": - assert (v >= 0).all(), (k, v) - assert v.dtype == np.int64, (k, v) - continue - if k == "has_nans": - assert v.dtype == np.bool_, (k, v) - continue - if k == "stdev": - assert (v >= 0).all(), (k, v) - assert v.dtype == np.float64, (k, v) - - for i, name in enumerate(self["variables_names"]): - try: - check_stats(**{k: v[i] for k, v in self.items()}, msg=f"{i} {name}") - check_data_values(self["minimum"][i], name=name) - check_data_values(self["maximum"][i], name=name) - check_data_values(self["mean"][i], name=name) - except StatisticsValueError as e: - e.args += (i, name) - raise - - def __str__(self) -> str: - """Return a string representation of the summary statistics. - - Returns - ------- - str - A formatted string of the summary statistics. - """ - header = ["Variables"] + self.STATS_NAMES - out = [" ".join(header)] - - out += [ - " ".join([v] + [f"{self[n][i]:.2f}" for n in self.STATS_NAMES]) - for i, v in enumerate(self["variables_names"]) - ] - return "\n".join(out) - - def save(self, filename: str, **metadata: Any) -> None: - """Save the summary statistics to a JSON file. - - Parameters - ---------- - filename : str - The name of the file to save the summary statistics. - **metadata : Any - Additional metadata to include in the JSON file. - """ - assert filename.endswith(".json"), filename - dic = {} - for k in self.STATS_NAMES: - dic[k] = list(self[k]) - - out = dict(data=defaultdict(dict)) - for i, name in enumerate(self["variables_names"]): - for k in self.STATS_NAMES: - out["data"][name][k] = dic[k][i] - - out["metadata"] = metadata - - with open(filename, "w") as f: - json.dump(out, f, indent=2) - - def load(self, filename: str) -> "Summary": - """Load the summary statistics from a JSON file. - - Parameters - ---------- - filename : str - The name of the file to load the summary statistics from. - - Returns - ------- - Summary - The loaded Summary object. - """ - assert filename.endswith(".json"), filename - with open(filename) as f: - dic = json.load(f) - - dic_ = {} - for k, v in dic.items(): - if k == "count": - dic_[k] = np.array(v, dtype=np.int64) - continue - if k == "variables": - dic_[k] = v - continue - dic_[k] = np.array(v, dtype=np.float64) - return Summary(dic_) diff --git a/src/anemoi/datasets/build/testing.py b/src/anemoi/datasets/build/testing.py deleted file mode 100644 index 5363cd9f7..000000000 --- a/src/anemoi/datasets/build/testing.py +++ /dev/null @@ -1,4 +0,0 @@ -class TestingContext: - """A context for testing plugins.""" - - pass diff --git a/src/anemoi/datasets/build/typing.py b/src/anemoi/datasets/build/typing.py deleted file mode 100644 index 0eafdb193..000000000 --- a/src/anemoi/datasets/build/typing.py +++ /dev/null @@ -1,14 +0,0 @@ -# (C) Copyright 2025- Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
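# A minimal sketch of the on-disk layout targeted by Summary.save() above: one JSON
# object with a "data" block of per-variable statistics plus a "metadata" block.
# Variable names and values here are invented for illustration.
import json

example = {
    "data": {
        "2t": {"minimum": 210.3, "maximum": 320.1, "mean": 287.4, "stdev": 15.2, "has_nans": False},
        "msl": {"minimum": 93000.0, "maximum": 106000.0, "mean": 101300.0, "stdev": 980.0, "has_nans": False},
    },
    "metadata": {"created": "2024-01-01T00:00:00"},
}

text = json.dumps(example, indent=2)       # roughly the layout Summary.save() writes
stats = json.loads(text)["data"]["2t"]     # read one variable back
assert stats["mean"] == 287.4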
- -import datetime - -Date = datetime.datetime - -DateList = list[Date] diff --git a/src/anemoi/datasets/build/utils.py b/src/anemoi/datasets/build/utils.py deleted file mode 100644 index 00ea89e7b..000000000 --- a/src/anemoi/datasets/build/utils.py +++ /dev/null @@ -1,198 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import datetime -import os -import warnings -from contextlib import contextmanager -from typing import Any - -import numpy as np -from earthkit.data import settings -from numpy.typing import NDArray - - -def cache_context(dirname: str) -> contextmanager: - """Context manager for setting a temporary cache directory. - - Parameters - ---------- - dirname : str - The directory name for the cache. - - Returns - ------- - contextmanager - A context manager that sets the cache directory. - """ - - @contextmanager - def no_cache_context(): - yield - - if dirname is None: - return no_cache_context() - - os.makedirs(dirname, exist_ok=True) - # return settings.temporary("cache-directory", dirname) - return settings.temporary({"cache-policy": "user", "user-cache-directory": dirname}) - - -def to_datetime_list(*args: Any, **kwargs: Any) -> list[datetime.datetime]: - """Convert various date formats to a list of datetime objects. - - Parameters - ---------- - *args : Any - Positional arguments for date conversion. - **kwargs : Any - Keyword arguments for date conversion. - - Returns - ------- - list[datetime.datetime] - A list of datetime objects. - """ - from earthkit.data.utils.dates import to_datetime_list as to_datetime_list_ - - warnings.warn( - "to_datetime_list() is deprecated. Call earthkit.data.utils.dates.to_datetime_list() instead.", - DeprecationWarning, - stacklevel=2, - ) - return to_datetime_list_(*args, **kwargs) - - -def to_datetime(*args: Any, **kwargs: Any) -> datetime.datetime: - """Convert various date formats to a single datetime object. - - Parameters - ---------- - *args : Any - Positional arguments for date conversion. - **kwargs : Any - Keyword arguments for date conversion. - - Returns - ------- - datetime.datetime - A datetime object. - """ - from earthkit.data.utils.dates import to_datetime as to_datetime_ - - warnings.warn( - "to_datetime() is deprecated. Call earthkit.data.utils.dates.to_datetime() instead.", - DeprecationWarning, - stacklevel=2, - ) - - return to_datetime_(*args, **kwargs) - - -def make_list_int(value: str | list | tuple | int) -> list[int]: - """Convert a string, list, tuple, or integer to a list of integers. - - Parameters - ---------- - value : str or list or tuple or int - The value to convert. - - Returns - ------- - list[int] - A list of integers. - - Raises - ------ - ValueError - If the value cannot be converted to a list of integers. - """ - # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers. 
- # Moved to anemoi.utils.humanize - # replace with from anemoi.utils.humanize import make_list_int - # when anemoi-utils is released and pyproject.toml is updated - if isinstance(value, str): - if "/" not in value: - return [value] - bits = value.split("/") - if len(bits) == 3 and bits[1].lower() == "to": - value = list(range(int(bits[0]), int(bits[2]) + 1, 1)) - - elif len(bits) == 5 and bits[1].lower() == "to" and bits[3].lower() == "by": - value = list(range(int(bits[0]), int(bits[2]) + int(bits[4]), int(bits[4]))) - - if isinstance(value, list): - return value - if isinstance(value, tuple): - return value - if isinstance(value, int): - return [value] - - raise ValueError(f"Cannot make list from {value}") - - -def normalize_and_check_dates( - dates: list[datetime.datetime], - start: datetime.datetime, - end: datetime.datetime, - frequency: datetime.timedelta, - dtype: str = "datetime64[s]", -) -> NDArray[Any]: - """Normalize and check a list of dates against a specified frequency. - - Parameters - ---------- - dates : list[datetime.datetime] - The list of dates to check. - start : datetime.datetime - The start date. - end : datetime.datetime - The end date. - frequency : datetime.timedelta - The frequency of the dates. - dtype : str, optional - The data type of the dates, by default "datetime64[s]". - - Returns - ------- - NDArray[Any] - An array of normalized dates. - - Raises - ------ - ValueError - If the final date size does not match the data shape. - """ - dates = [d.hdate if hasattr(d, "hdate") else d for d in dates] - - assert isinstance(frequency, datetime.timedelta), frequency - start = np.datetime64(start) - end = np.datetime64(end) - delta = np.timedelta64(frequency) - - res = [] - while start <= end: - res.append(start) - start += delta - dates_ = np.array(res).astype(dtype) - - if len(dates_) != len(dates): - raise ValueError( - f"Final date size {len(dates_)} (from {dates_[0]} to {dates_[-1]}, " - f"{frequency=}) does not match data shape {len(dates)} (from {dates[0]} to " - f"{dates[-1]})." - ) - - for i, (d1, d2) in enumerate(zip(dates, dates_)): - d1 = np.datetime64(d1) - d2 = np.datetime64(d2) - assert d1 == d2, (i, d1, d2) - - return dates_ diff --git a/src/anemoi/datasets/build/writer.py b/src/anemoi/datasets/build/writer.py deleted file mode 100644 index d573c1ca5..000000000 --- a/src/anemoi/datasets/build/writer.py +++ /dev/null @@ -1,64 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -from typing import Any - -import numpy as np -from numpy.typing import NDArray - -LOG = logging.getLogger(__name__) - - -class ViewCacheArray: - """A class that provides a caching mechanism for writing to a NumPy-like array. - - The is initialised with a NumPy-like array, a shape and a list to reindex the first - dimension. The array is used to store the final data, while the cache is used to - temporarily store the data before flushing it to the array. - - The `flush` method copies the contents of the cache to the final array. - """ - - def __init__(self, array: NDArray[Any], *, shape: tuple[int, ...], indexes: list[int]): - """Initialize the ViewCacheArray. 
- - Parameters - ---------- - array : NDArray[Any] - The NumPy-like array to store the final data. - shape : tuple[int, ...] - The shape of the cache array. - indexes : list[int] - List to reindex the first dimension. - """ - assert len(indexes) == shape[0], (len(indexes), shape[0]) - self.array = array - self.dtype = array.dtype - self.cache = np.full(shape, np.nan, dtype=self.dtype) - self.indexes = indexes - - def __setitem__(self, key: tuple[int, ...], value: NDArray[Any]) -> None: - """Set the value in the cache array at the specified key. - - Parameters - ---------- - key : tuple[int, ...] - The index key to set the value. - value : NDArray[Any] - The value to set in the cache array. - """ - self.cache[key] = value - - def flush(self) -> None: - """Copy the contents of the cache to the final array.""" - for i in range(self.cache.shape[0]): - global_i = self.indexes[i] - self.array[global_i] = self.cache[i] diff --git a/src/anemoi/datasets/build/zarr.py b/src/anemoi/datasets/build/zarr.py deleted file mode 100644 index 32b493dd3..000000000 --- a/src/anemoi/datasets/build/zarr.py +++ /dev/null @@ -1,331 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import logging -import shutil -from typing import Any - -import numpy as np -import zarr -from numpy.typing import NDArray - -LOG = logging.getLogger(__name__) - - -def add_zarr_dataset( - *, - name: str, - dtype: np.dtype = None, - fill_value: np.generic = None, - zarr_root: zarr.Group, - shape: tuple[int, ...] = None, - array: NDArray[Any] = None, - overwrite: bool = True, - dimensions: tuple[str, ...] = None, - **kwargs, -) -> zarr.Array: - """Add a dataset to a Zarr group. - - Parameters - ---------- - name : str - Name of the dataset. - dtype : np.dtype, optional - Data type of the dataset. - fill_value : np.generic, optional - Fill value for the dataset. - zarr_root : zarr.Group - Root Zarr group. - shape : tuple[int, ...], optional - Shape of the dataset. - array : NDArray[Any], optional - Array to initialize the dataset with. - overwrite : bool - Whether to overwrite existing dataset. - dimensions : tuple[str, ...] - Dimensions of the dataset. - **kwargs - Additional arguments for Zarr dataset creation. - - Returns - ------- - zarr.Array - The created Zarr array. - """ - assert dimensions is not None, "Please pass dimensions to add_zarr_dataset." - assert isinstance(dimensions, (tuple, list)) - - if dtype is None: - assert array is not None, (name, shape, array, dtype, zarr_root) - dtype = array.dtype - - if shape is None: - assert array is not None, (name, shape, array, dtype, zarr_root) - shape = array.shape - - if array is not None: - assert array.shape == shape, (array.shape, shape) - a = zarr_root.create_dataset( - name, - shape=shape, - dtype=dtype, - overwrite=overwrite, - **kwargs, - ) - a[...] 
= array - a.attrs["_ARRAY_DIMENSIONS"] = dimensions - return a - - if "fill_value" not in kwargs: - if str(dtype).startswith("float") or str(dtype).startswith("numpy.float"): - kwargs["fill_value"] = np.nan - elif str(dtype).startswith("datetime64") or str(dtype).startswith("numpy.datetime64"): - kwargs["fill_value"] = np.datetime64("NaT") - # elif str(dtype).startswith("timedelta64") or str(dtype).startswith( - # "numpy.timedelta64" - # ): - # kwargs["fill_value"] = np.timedelta64("NaT") - elif str(dtype).startswith("int") or str(dtype).startswith("numpy.int"): - kwargs["fill_value"] = 0 - elif str(dtype).startswith("bool") or str(dtype).startswith("numpy.bool"): - kwargs["fill_value"] = False - else: - raise ValueError(f"No fill_value for dtype={dtype}") - - a = zarr_root.create_dataset( - name, - shape=shape, - dtype=dtype, - overwrite=overwrite, - **kwargs, - ) - a.attrs["_ARRAY_DIMENSIONS"] = dimensions - return a - - -class ZarrBuiltRegistry: - """A class to manage the creation and access of Zarr datasets.""" - - name_lengths = "lengths" - name_flags = "flags" - lengths = None - flags = None - z = None - - def __init__(self, path: str, synchronizer_path: str | None = None, use_threads: bool = False): - """Initialize the ZarrBuiltRegistry. - - Parameters - ---------- - path : str - Path to the Zarr store. - synchronizer_path : Optional[str], optional - Path to the synchronizer. - use_threads : bool - Whether to use thread-based synchronization. - """ - import zarr - - assert isinstance(path, str), path - self.zarr_path = path - - if use_threads: - self.synchronizer = zarr.ThreadSynchronizer() - self.synchronizer_path = None - else: - if synchronizer_path is None: - synchronizer_path = self.zarr_path + ".sync" - self.synchronizer_path = synchronizer_path - self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path) - - def clean(self) -> None: - """Clean up the synchronizer path.""" - if self.synchronizer_path is not None: - try: - shutil.rmtree(self.synchronizer_path) - except FileNotFoundError: - pass - - _build = self.zarr_path + "/_build" - try: - shutil.rmtree(_build) - except FileNotFoundError: - pass - - def _open_write(self) -> zarr.Group: - """Open the Zarr store in write mode.""" - import zarr - - return zarr.open(self.zarr_path, mode="r+", synchronizer=self.synchronizer) - - def _open_read(self, sync: bool = True) -> zarr.Group: - """Open the Zarr store in read mode. - - Parameters - ---------- - sync : bool - Whether to use synchronization. - - Returns - ------- - zarr.Group - The opened Zarr group. - """ - import zarr - - if sync: - return zarr.open(self.zarr_path, mode="r", synchronizer=self.synchronizer) - else: - return zarr.open(self.zarr_path, mode="r") - - def new_dataset(self, *args, **kwargs) -> None: - """Create a new dataset in the Zarr store. - - Parameters - ---------- - *args - Positional arguments for dataset creation. - **kwargs - Keyword arguments for dataset creation. - """ - z = self._open_write() - zarr_root = z["_build"] - add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs) - - def add_to_history(self, action: str, **kwargs) -> None: - """Add an action to the history attribute of the Zarr store. - - Parameters - ---------- - action : str - The action to record. - **kwargs - Additional information about the action. 
- """ - new = dict( - action=action, - timestamp=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat(), - ) - new.update(kwargs) - - z = self._open_write() - history = z.attrs.get("history", []) - history.append(new) - z.attrs["history"] = history - - def get_lengths(self) -> list[int]: - """Get the lengths dataset. - - Returns - ------- - list[int] - The lengths dataset. - """ - z = self._open_read() - return list(z["_build"][self.name_lengths][:]) - - def get_flags(self, **kwargs) -> list[bool]: - """Get the flags dataset. - - Parameters - ---------- - **kwargs - Additional arguments for reading the dataset. - - Returns - ------- - list[bool] - The flags dataset. - """ - z = self._open_read(**kwargs) - return list(z["_build"][self.name_flags][:]) - - def get_flag(self, i: int) -> bool: - """Get a specific flag. - - Parameters - ---------- - i : int - Index of the flag. - - Returns - ------- - bool - The flag value. - """ - z = self._open_read() - return z["_build"][self.name_flags][i] - - def set_flag(self, i: int, value: bool = True) -> None: - """Set a specific flag. - - Parameters - ---------- - i : int - Index of the flag. - value : bool - Value to set the flag to. - """ - z = self._open_write() - z.attrs["latest_write_timestamp"] = ( - datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() - ) - z["_build"][self.name_flags][i] = value - - def ready(self) -> bool: - """Check if all flags are set. - - Returns - ------- - bool - True if all flags are set, False otherwise. - """ - return all(self.get_flags()) - - def create(self, lengths: list[int], overwrite: bool = False) -> None: - """Create the lengths and flags datasets. - - Parameters - ---------- - lengths : list[int] - Lengths to initialize the dataset with. - overwrite : bool - Whether to overwrite existing datasets. - """ - self.new_dataset(name=self.name_lengths, array=np.array(lengths, dtype="i4")) - self.new_dataset(name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool)) - self.add_to_history("initialised") - - def reset(self, lengths: list[int]) -> None: - """Reset the lengths and flags datasets. - - Parameters - ---------- - lengths : list[int] - Lengths to initialize the dataset with. - """ - return self.create(lengths, overwrite=True) - - def add_provenance(self, name: str) -> None: - """Add provenance information to the Zarr store. - - Parameters - ---------- - name : str - Name of the provenance attribute. - """ - z = self._open_write() - - if name in z.attrs: - return - - from anemoi.utils.provenance import gather_provenance_info - - z.attrs[name] = gather_provenance_info() diff --git a/src/anemoi/datasets/check.py b/src/anemoi/datasets/check.py deleted file mode 100644 index d795d13f9..000000000 --- a/src/anemoi/datasets/check.py +++ /dev/null @@ -1,93 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
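# A minimal sketch of how the ZarrBuiltRegistry above is driven during a build,
# assuming the class is in scope, a zarr version compatible with its synchronizers,
# and a store that already holds a "_build" group; the path and lengths are
# illustrative assumptions.
import zarr

path = "dataset.zarr"                       # hypothetical store
root = zarr.open(path, mode="w")
root.create_group("_build")                 # registry datasets live under "_build"

registry = ZarrBuiltRegistry(path, use_threads=True)
registry.create(lengths=[24, 24, 24])       # one entry per group of dates
registry.set_flag(0)                        # mark the first group as written
print(registry.get_flags())                 # [True, False, False]
print(registry.ready())                     # False until every flag is set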
- - -# A collection of functions to support pytest testing - -import logging -import math -import os -import re - -LOG = logging.getLogger(__name__) - - -def _check_group(group, verbosity: int, *path) -> None: - import zarr - - group_keys = sorted(group.keys()) - if not group_keys: - raise ValueError(f"Check group: {group} is empty.") - - for name in sorted(group_keys): - if name.startswith("."): - if verbosity > 1: - LOG.info(f"Check group: skipping {name}") - continue - - if isinstance(group[name], zarr.hierarchy.Group): - _check_group(group[name], verbosity, *path, name) - else: - _check_array(group[name], verbosity, *path, name) - - -def _check_array(array, verbosity: int, *path) -> None: - assert len(array.chunks) == len(array.shape) - assert math.prod(array.shape) % math.prod(array.chunks) == 0 - - file_count = math.prod(array.shape) // math.prod(array.chunks) - - full = os.path.join(*path) - - chunks = array.chunks - - count = 0 - for f in os.listdir(full): - if verbosity > 1: - LOG.info(f"Check array: checking {f}") - - if f.startswith("."): - if verbosity > 1: - LOG.info(f"Check array: skipping {f}") - continue - - bits = f.split(".") - - if len(bits) != len(chunks): - raise ValueError(f"File {f} is not a valid chunk file.") - - if not all(re.match(r"^\d+$", bit) for bit in bits): - raise ValueError(f"File {f} is not a valid chunk file.") - - count += 1 - - if count != file_count: - raise ValueError(f"File count {count} does not match expected {file_count} for {array.name}.") - - -def check_zarr(path: str, verbosity: int = 0) -> None: - """Check if a Zarr archive is valid, that no files are missing, and that the chunking is correct. - - Parameters - ---------- - path : str - Path to the Zarr archive. - verbosity : int, optional - Verbosity level for logging. Default is 0 (no logging). 
- """ - import zarr - - if verbosity > 0: - LOG.info(f"Checking Zarr archive {path}") - - if not os.path.exists(path) and not os.path.isdir(path): - # This does not work with non-directory Zarr archives - raise ValueError(f"Path {path} does not exist.") - - _check_group(zarr.open(path, mode="r"), verbosity, path) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 30df82783..2b92718ae 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -45,7 +45,7 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") - from anemoi.datasets.build import creator_factory + from anemoi.datasets.build.gridded import creator_factory options = {k: v for k, v in options.items() if v is not None} diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 813ca47b8..6a93af8e6 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,7 +15,7 @@ import yaml -from anemoi.datasets.build import validate_config +from anemoi.datasets.build.gridded import validate_config from anemoi.datasets.commands import Command from anemoi.datasets.commands.recipe.format import format_recipe from anemoi.datasets.commands.recipe.migrate import migrate_recipe diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index ffaa3ddd1..dc337d0ff 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -17,7 +17,7 @@ from glom import delete from glom import glom -from anemoi.datasets.build import validate_config +from anemoi.datasets.build.gridded import validate_config from anemoi.datasets.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/dumper.py b/src/anemoi/datasets/dumper.py deleted file mode 100644 index 18c8d34d4..000000000 --- a/src/anemoi/datasets/dumper.py +++ /dev/null @@ -1,76 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
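# A minimal sketch of the chunk accounting enforced by _check_array() in the deleted
# check.py above, with made-up shapes: each chunk file is named with one integer per
# dimension ("<i>.<j>"), dot-files are skipped, and the number of chunk files must
# equal prod(shape) / prod(chunks).
import math
import re

shape, chunks = (8, 4), (2, 4)                         # illustrative values
expected_files = math.prod(shape) // math.prod(chunks)
assert expected_files == 4

valid_name = re.compile(r"^\d+(\.\d+)*$")
assert valid_name.match("0.0") and valid_name.match("3.0")
assert not valid_name.match(".zattrs")                 # hidden files are ignored by the check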
- -import datetime -import io -import logging - -import ruamel.yaml - -LOG = logging.getLogger(__name__) - - -def represent_date(dumper, data): - - if isinstance(data, datetime.datetime): - if data.tzinfo is None: - data = data.replace(tzinfo=datetime.timezone.utc) - data = data.astimezone(datetime.timezone.utc) - iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" - else: - iso_str = data.isoformat() - - return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) - - -# --- Represent multiline strings with | style --- -def represent_multiline_str(dumper, data): - if "\n" in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data.strip(), style="|") - return dumper.represent_scalar("tag:yaml.org,2002:str", data) - - -# --- Represent short lists inline (flow style) --- -def represent_inline_list(dumper, data): - - if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data): - return dumper.represent_sequence("tag:yaml.org,2002:seq", data) - - return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) - - -def yaml_dump(obj, order=None, stream=None, **kwargs): - - if order: - - def _ordering(k): - return order.index(k) if k in order else len(order) - - obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} - - yaml = ruamel.yaml.YAML() - yaml.width = 120 # wrap long flow sequences - - yaml.Representer.add_representer(datetime.date, represent_date) - yaml.Representer.add_representer(datetime.datetime, represent_date) - yaml.Representer.add_representer(str, represent_multiline_str) - yaml.Representer.add_representer(list, represent_inline_list) - - data = ruamel.yaml.comments.CommentedMap() - for i, (k, v) in enumerate(obj.items()): - data[k] = v - if i > 0: - data.yaml_set_comment_before_after_key(key=k, before="\n") - - if stream: - yaml.dump(data, stream=stream, **kwargs) - - stream = io.StringIO() - yaml.dump(data, stream=stream, **kwargs) - return stream.getvalue() diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py deleted file mode 100644 index 26f675526..000000000 --- a/src/anemoi/datasets/grids.py +++ /dev/null @@ -1,668 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import base64 -import logging -from typing import Any - -import numpy as np -from numpy.typing import NDArray - -LOG = logging.getLogger(__name__) - - -def plot_mask( - path: str, - mask: NDArray[Any], - lats: NDArray[Any], - lons: NDArray[Any], - global_lats: NDArray[Any], - global_lons: NDArray[Any], -) -> None: - """Plot and save various visualizations of the mask and coordinates. - - Parameters - ---------- - path : str - The base path for saving the plots. - mask : NDArray[Any] - The mask array. - lats : NDArray[Any] - Latitude coordinates. - lons : NDArray[Any] - Longitude coordinates. - global_lats : NDArray[Any] - Global latitude coordinates. - global_lons : NDArray[Any] - Global longitude coordinates. 
- """ - import matplotlib.pyplot as plt - - s = 1 - - global_lons[global_lons >= 180] -= 360 - - plt.figure(figsize=(10, 5)) - plt.scatter(global_lons, global_lats, s=s, marker="o", c="r") - if isinstance(path, str): - plt.savefig(path + "-global.png") - - plt.figure(figsize=(10, 5)) - plt.scatter(global_lons[mask], global_lats[mask], s=s, c="k") - if isinstance(path, str): - plt.savefig(path + "-cutout.png") - - plt.figure(figsize=(10, 5)) - plt.scatter(lons, lats, s=s) - if isinstance(path, str): - plt.savefig(path + "-lam.png") - # plt.scatter(lons, lats, s=0.01) - - plt.figure(figsize=(10, 5)) - plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") - plt.scatter(lons, lats, s=s) - if isinstance(path, str): - plt.savefig(path + "-both.png") - # plt.scatter(lons, lats, s=0.01) - - plt.figure(figsize=(10, 5)) - plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") - plt.scatter(lons, lats, s=s) - plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1) - plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1) - if isinstance(path, str): - plt.savefig(path + "-both-zoomed.png") - - plt.figure(figsize=(10, 5)) - plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") - plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1) - plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1) - if isinstance(path, str): - plt.savefig(path + "-global-zoomed.png") - - -# TODO: Use the one from anemoi.utils.grids instead -# from anemoi.utils.grids import ... -def xyz_to_latlon(x: NDArray[Any], y: NDArray[Any], z: NDArray[Any]) -> tuple[NDArray[Any], NDArray[Any]]: - """Convert Cartesian coordinates to latitude and longitude. - - Parameters - ---------- - x : NDArray[Any] - X coordinates. - y : NDArray[Any] - Y coordinates. - z : NDArray[Any] - Z coordinates. - - Returns - ------- - Tuple[NDArray[Any], NDArray[Any]] - Latitude and longitude coordinates. - """ - return ( - np.rad2deg(np.arcsin(np.minimum(1.0, np.maximum(-1.0, z)))), - np.rad2deg(np.arctan2(y, x)), - ) - - -# TODO: Use the one from anemoi.utils.grids instead -# from anemoi.utils.grids import ... -def latlon_to_xyz( - lat: NDArray[Any], lon: NDArray[Any], radius: float = 1.0 -) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any]]: - """Convert latitude and longitude to Cartesian coordinates. - - Parameters - ---------- - lat : NDArray[Any] - Latitude coordinates. - lon : NDArray[Any] - Longitude coordinates. - radius : float, optional - Radius of the sphere. Defaults to 1.0. - - Returns - ------- - Tuple[NDArray[Any], NDArray[Any], NDArray[Any]] - X, Y, and Z coordinates. - """ - # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates - # We assume that the Earth is a sphere of radius 1 so N(phi) = 1 - # We assume h = 0 - # - phi = np.deg2rad(lat) - lda = np.deg2rad(lon) - - cos_phi = np.cos(phi) - cos_lda = np.cos(lda) - sin_phi = np.sin(phi) - sin_lda = np.sin(lda) - - x = cos_phi * cos_lda * radius - y = cos_phi * sin_lda * radius - z = sin_phi * radius - - return x, y, z - - -class Triangle3D: - """A class to represent a 3D triangle and perform intersection tests with rays.""" - - def __init__(self, v0: NDArray[Any], v1: NDArray[Any], v2: NDArray[Any]) -> None: - """Initialize the Triangle3D object. - - Parameters - ---------- - v0 : NDArray[Any] - First vertex of the triangle. - v1 : NDArray[Any] - Second vertex of the triangle. - v2 : NDArray[Any] - Third vertex of the triangle. 
- """ - self.v0 = v0 - self.v1 = v1 - self.v2 = v2 - - def intersect(self, ray_origin: NDArray[Any], ray_direction: NDArray[Any]) -> bool: - """Check if a ray intersects with the triangle. - - Parameters - ---------- - ray_origin : NDArray[Any] - Origin of the ray. - ray_direction : NDArray[Any] - Direction of the ray. - - Returns - ------- - bool - True if the ray intersects with the triangle, False otherwise. - """ - # Möller–Trumbore intersection algorithm - # https://en.wikipedia.org/wiki/M%C3%B6ller%E2%80%93Trumbore_intersection_algorithm - - epsilon = 0.0000001 - - h = np.cross(ray_direction, self.v2 - self.v0) - a = np.dot(self.v1 - self.v0, h) - - if -epsilon < a < epsilon: - return False - - f = 1.0 / a - s = ray_origin - self.v0 - u = f * np.dot(s, h) - - if u < 0.0 or u > 1.0: - return False - - q = np.cross(s, self.v1 - self.v0) - v = f * np.dot(ray_direction, q) - - if v < 0.0 or u + v > 1.0: - return False - - t = f * np.dot(self.v2 - self.v0, q) - - if t > epsilon: - return True - - return False - - -def cropping_mask( - lats: NDArray[Any], - lons: NDArray[Any], - north: float, - west: float, - south: float, - east: float, -) -> NDArray[Any]: - """Create a mask for the points within the specified latitude and longitude bounds. - - Parameters - ---------- - lats : NDArray[Any] - Latitude coordinates. - lons : NDArray[Any] - Longitude coordinates. - north : float - Northern boundary. - west : float - Western boundary. - south : float - Southern boundary. - east : float - Eastern boundary. - - Returns - ------- - NDArray[Any] - Mask array. - """ - mask = ( - (lats >= south) - & (lats <= north) - & ( - ((lons >= west) & (lons <= east)) - | ((lons >= west + 360) & (lons <= east + 360)) - | ((lons >= west - 360) & (lons <= east - 360)) - ) - ) - return mask - - -def cutout_mask( - lats: NDArray[Any], - lons: NDArray[Any], - global_lats: NDArray[Any], - global_lons: NDArray[Any], - cropping_distance: float = 2.0, - neighbours: int = 5, - min_distance_km: int | float | None = None, - plot: str | None = None, -) -> NDArray[Any]: - """Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]. - - Parameters - ---------- - lats : NDArray[Any] - Latitude coordinates. - lons : NDArray[Any] - Longitude coordinates. - global_lats : NDArray[Any] - Global latitude coordinates. - global_lons : NDArray[Any] - Global longitude coordinates. - cropping_distance : float, optional - Cropping distance. Defaults to 2.0. - neighbours : int, optional - Number of neighbours. Defaults to 5. - min_distance_km : Optional[Union[int, float]], optional - Minimum distance in kilometers. Defaults to None. - plot : Optional[str], optional - Path for saving the plot. Defaults to None. - - Returns - ------- - NDArray[Any] - Mask array. 
- """ - from scipy.spatial import cKDTree - - # TODO: transform min_distance from lat/lon to xyz - - assert global_lats.ndim == 1 - assert global_lons.ndim == 1 - assert lats.ndim == 1 - assert lons.ndim == 1 - - assert global_lats.shape == global_lons.shape - assert lats.shape == lons.shape - - north = np.amax(lats) - south = np.amin(lats) - east = np.amax(lons) - west = np.amin(lons) - - # Reduce the global grid to the area of interest - - mask = cropping_mask( - global_lats, - global_lons, - np.min([90.0, north + cropping_distance]), - west - cropping_distance, - np.max([-90.0, south - cropping_distance]), - east + cropping_distance, - ) - - # return mask - # mask = np.array([True] * len(global_lats), dtype=bool) - global_lats_masked = global_lats[mask] - global_lons_masked = global_lons[mask] - - global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked) - global_points = np.array(global_xyx).transpose() - - xyx = latlon_to_xyz(lats, lons) - lam_points = np.array(xyx).transpose() - - if isinstance(min_distance_km, (int, float)): - min_distance = min_distance_km / 6371.0 - else: - points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km] - distances, _ = cKDTree(points).query(points, k=2) - min_distance = np.min(distances[:, 1]) - - LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km") - - # Use a cKDTree to find the nearest points - distances, indices = cKDTree(lam_points).query(global_points, k=neighbours) - - # Centre of the Earth - zero = np.array([0.0, 0.0, 0.0]) - - # After the loop, 'inside_lam' will contain a list point to EXCLUDE - inside_lam = [] - - for i, (global_point, distance, index) in enumerate(zip(global_points, distances, indices)): - - # We check more than one triangle in case te global point - # is near the edge of triangle, (the lam point and global points are colinear) - - inside = False - for j in range(neighbours): - t = Triangle3D( - lam_points[index[j]], lam_points[index[(j + 1) % neighbours]], lam_points[index[(j + 2) % neighbours]] - ) - inside = t.intersect(zero, global_point) - if inside: - break - - close = np.min(distance) <= min_distance - - inside_lam.append(inside or close) - - j = 0 - inside_lam_array = np.array(inside_lam) - for i, m in enumerate(mask): - if not m: - continue - - mask[i] = inside_lam_array[j] - j += 1 - - assert j == len(inside_lam_array) - - # Invert the mask, so we have only the points outside the cutout - mask = ~mask - - if plot: - plot_mask(plot, mask, lats, lons, global_lats, global_lons) - - return mask - - -def thinning_mask( - lats: NDArray[Any], - lons: NDArray[Any], - global_lats: NDArray[Any], - global_lons: NDArray[Any], - cropping_distance: float = 2.0, -) -> NDArray[Any]: - """Return the list of points in [lats, lons] closest to [global_lats, global_lons]. - - Parameters - ---------- - lats : NDArray[Any] - Latitude coordinates. - lons : NDArray[Any] - Longitude coordinates. - global_lats : NDArray[Any] - Global latitude coordinates. - global_lons : NDArray[Any] - Global longitude coordinates. - cropping_distance : float, optional - Cropping distance. Defaults to 2.0. - - Returns - ------- - NDArray[Any] - Array of indices of the closest points. 
- """ - from scipy.spatial import cKDTree - - assert global_lats.ndim == 1 - assert global_lons.ndim == 1 - assert lats.ndim == 1 - assert lons.ndim == 1 - - assert global_lats.shape == global_lons.shape - assert lats.shape == lons.shape - - north = np.amax(lats) - south = np.amin(lats) - east = np.amax(lons) - west = np.amin(lons) - - # Reduce the global grid to the area of interest - - mask = cropping_mask( - global_lats, - global_lons, - np.min([90.0, north + cropping_distance]), - west - cropping_distance, - np.max([-90.0, south - cropping_distance]), - east + cropping_distance, - ) - - # return mask - global_lats_masked = global_lats[mask] - global_lons_masked = global_lons[mask] - - global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked) - global_points = np.array(global_xyx).transpose() - - xyx = latlon_to_xyz(lats, lons) - points = np.array(xyx).transpose() - - # Use a cKDTree to find the nearest points - _, indices = cKDTree(points).query(global_points, k=1) - - return np.array([i for i in indices]) - - -def outline(lats: NDArray[Any], lons: NDArray[Any], neighbours: int = 5) -> list[int]: - """Find the outline of the grid points. - - Parameters - ---------- - lats : NDArray[Any] - Latitude coordinates. - lons : NDArray[Any] - Longitude coordinates. - neighbours : int, optional - Number of neighbours. Defaults to 5. - - Returns - ------- - List[int] - Indices of the outline points. - """ - from scipy.spatial import cKDTree - - xyx = latlon_to_xyz(lats, lons) - grid_points = np.array(xyx).transpose() - - # Use a cKDTree to find the nearest points - _, indices = cKDTree(grid_points).query(grid_points, k=neighbours) - - # Centre of the Earth - zero = np.array([0.0, 0.0, 0.0]) - - outside = [] - - for i, (point, index) in enumerate(zip(grid_points, indices)): - inside = False - for j in range(1, neighbours): - t = Triangle3D( - grid_points[index[j]], - grid_points[index[(j + 1) % neighbours]], - grid_points[index[(j + 2) % neighbours]], - ) - inside = t.intersect(zero, point) - if inside: - break - - if not inside: - outside.append(i) - - return outside - - -def deserialise_mask(encoded: str) -> NDArray[Any]: - """Deserialise a mask from a base64 encoded string. - - Parameters - ---------- - encoded : str - Base64 encoded string. - - Returns - ------- - NDArray[Any] - Deserialised mask array. - """ - import pickle - import zlib - - packed = pickle.loads(zlib.decompress(base64.b64decode(encoded))) - - mask = [] - value = False - for count in packed: - mask.extend([value] * count) - value = not value - return np.array(mask, dtype=bool) - - -def _serialise_mask(mask: NDArray[Any]) -> str: - """Serialise a mask to a base64 encoded string. - - Parameters - ---------- - mask : NDArray[Any] - Mask array. - - Returns - ------- - str - Base64 encoded string. - """ - import pickle - import zlib - - assert len(mask.shape) == 1 - assert len(mask) - - packed = [] - last = mask[0] - count = 1 - - for value in mask[1:]: - if value == last: - count += 1 - else: - packed.append(count) - last = value - count = 1 - - packed.append(count) - - # We always start with an 'off' value - # So if the first value is 'on', we need to add a zero - if mask[0]: - packed.insert(0, 0) - - return base64.b64encode(zlib.compress(pickle.dumps(packed))).decode("utf-8") - - -def serialise_mask(mask: NDArray[Any]) -> str: - """Serialise a mask and ensure it can be deserialised. - - Parameters - ---------- - mask : NDArray[Any] - Mask array. - - Returns - ------- - str - Base64 encoded string. 
- """ - result = _serialise_mask(mask) - # Make sure we can deserialise it - assert np.all(mask == deserialise_mask(result)) - return result - - -def nearest_grid_points( - source_latitudes: NDArray[Any], - source_longitudes: NDArray[Any], - target_latitudes: NDArray[Any], - target_longitudes: NDArray[Any], - max_distance: float = None, - k: int = 1, -) -> NDArray[Any]: - """Find the nearest grid points from source to target coordinates. - - Parameters - ---------- - source_latitudes : NDArray[Any] - Source latitude coordinates. - source_longitudes : NDArray[Any] - Source longitude coordinates. - target_latitudes : NDArray[Any] - Target latitude coordinates. - target_longitudes : NDArray[Any] - Target longitude coordinates. - max_distance: float, optional - Maximum distance between nearest point and point to interpolate. Defaults to None. - For example, 1e-3 is 1 km. - k : int, optional - The number of k closest neighbors to consider for interpolation - - Returns - ------- - NDArray[Any] - Indices of the nearest grid points. - """ - # TODO: Use the one from anemoi.utils.grids instead - # from anemoi.utils.grids import ... - from scipy.spatial import cKDTree - - source_xyz = latlon_to_xyz(source_latitudes, source_longitudes) - source_points = np.array(source_xyz).transpose() - - target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) - target_points = np.array(target_xyz).transpose() - if max_distance is None: - distances, indices = cKDTree(source_points).query(target_points, k=k) - else: - distances, indices = cKDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) - return distances, indices - - -if __name__ == "__main__": - global_lats, global_lons = np.meshgrid( - np.linspace(90, -90, 90), - np.linspace(-180, 180, 180), - ) - global_lats = global_lats.flatten() - global_lons = global_lons.flatten() - - lats, lons = np.meshgrid( - np.linspace(50, 40, 100), - np.linspace(-10, 15, 100), - ) - lats = lats.flatten() - lons = lons.flatten() - - mask = cutout_mask(lats, lons, global_lats, global_lons, cropping_distance=5.0) - - import matplotlib.pyplot as plt - - fig = plt.figure(figsize=(10, 5)) - plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r") - plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k") - # plt.scatter(lons, lats, s=0.01) - plt.savefig("cutout.png") diff --git a/src/anemoi/datasets/schemas/recipe.json b/src/anemoi/datasets/schemas/recipe.json deleted file mode 100644 index 3c02bfd64..000000000 --- a/src/anemoi/datasets/schemas/recipe.json +++ /dev/null @@ -1,131 +0,0 @@ -{ - "$schema": "https://json-schema.org/draft/2020-12/schema", - "type": "object", - "$id": "https://ecmwf.int/anemoi-datasets-recipe.schema.json", - "title": "Product", - "description": "Anemoi datasets recipe configuration", - "additionalProperties": false, - "$defs": { - "source-or-filter": { - "type": "object", - "minProperties": 1, - "maxProperties": 1 - }, - "pipe": { - "type": "array", - "items": { - "$ref": "#/$defs/input-object" - } - }, - "join": { - "type": "array", - "items": { - "$ref": "#/$defs/input-object" - } - }, - "concat": { - "type": "array", - "items": { - "type": "object", - "minProperties": 2, - "maxProperties": 2, - "required": [ - "dates" - ] - } - }, - "input-object": { - "oneOf": [ - { - "$ref": "#/$defs/pipe" - }, - { - "$ref": "#/$defs/join" - }, - { - "$ref": "#/$defs/concat" - }, - { - "$ref": "#/$defs/source-or-filter" - } - ] - } - }, - "properties": { - "env": { - "type": "object" - }, - "description": { - "type": 
"string" - }, - "name": { - "type": "string" - }, - "licence": { - "type": "string" - }, - "attribution": { - "type": "string" - }, - "dates": { - "type": "object", - "required": [ - "start", - "end" - ], - "properties": { - "start": { - "type": "string", - "format": "date" - }, - "end": { - "type": "string", - "format": "date" - }, - "frequency": { - "type": [ - "integer", - "string" - ] - }, - "group_by": { - "type": [ - "integer", - "string" - ] - } - } - }, - "input": { - "$ref": "#/$defs/input-object" - }, - "data_sources": { - "type": "object", - "patternProperties": { - "^[a-zA-Z_][a-zA-Z0-9_]*$": { - "$ref": "#/$defs/input-object" - } - }, - "additionalProperties": false - }, - "output": { - "type": "object" - }, - "statistics": { - "type": "object" - }, - "build": { - "type": "object" - }, - "common": { - "type": "object" - }, - "platform": { - "type": "object" - } - }, - "required": [ - "dates", - "input" - ] -} diff --git a/src/anemoi/datasets/testing.py b/src/anemoi/datasets/testing.py deleted file mode 100644 index a15c7fd7e..000000000 --- a/src/anemoi/datasets/testing.py +++ /dev/null @@ -1,173 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -# A collection of functions to support pytest testing - -import logging -from typing import Any - -LOG = logging.getLogger(__name__) - - -def assert_field_list( - fs: list[Any], - size: int | None = None, - start: Any | None = None, - end: Any | None = None, - constant: bool = False, - skip: Any | None = None, -) -> None: - """Asserts various properties of a list of fields. - - Parameters - ---------- - fs : List[Any] - List of fields to be checked. - size : Optional[int], optional - Expected size of the list. If None, the list must be non-empty. - start : Optional[Any], optional - Expected start metadata value. If None, no check is performed. - end : Optional[Any], optional - Expected end metadata value. If None, no check is performed. - constant : bool, optional - If True, checks that all fields are constant. - skip : Optional[Any], optional - Placeholder for future use. 
- """ - import numpy as np - - if size is None: - assert len(fs) > 0, fs - else: - assert len(fs) == size, (len(fs), size) - - first = fs[0] - last = fs[-1] - - if constant: - # TODO: add a check for constant fields - pass - else: - assert start is None or first.metadata("valid_datetime") == start, (first.metadata("valid_datetime"), start) - assert end is None or last.metadata("valid_datetime") == end, (last.metadata("valid_datetime"), end) - print(first.datetime()) - - print(last.metadata()) - - first = first - latitudes, longitudes = first.grid_points() - - assert len(latitudes.shape) == 1, latitudes.shape - assert len(longitudes.shape) == 1, longitudes.shape - - assert len(latitudes) == len(longitudes), (len(latitudes), len(longitudes)) - data = first.to_numpy(flatten=True) - - assert len(data) == len(latitudes), (len(data), len(latitudes)) - - north = np.max(latitudes) - south = np.min(latitudes) - east = np.max(longitudes) - west = np.min(longitudes) - - assert north >= south, (north, south) - assert east >= west, (east, west) - assert north <= 90, north - assert south >= -90, south - assert east <= 360, east - assert west >= -180, west - - -class IndexTester: - """Class to test indexing of datasets.""" - - def __init__(self, ds: Any) -> None: - """Initialise the IndexTester. - - Parameters - ---------- - ds : Any - Dataset. - """ - self.ds = ds - self.np = ds[:] # Numpy array - - assert self.ds.shape == self.np.shape, (self.ds.shape, self.np.shape) - assert (self.ds == self.np).all() - - def __getitem__(self, index: Any) -> None: - """Test indexing. - - Parameters - ---------- - index : Any - Index. - """ - LOG.info("IndexTester: %s", index) - if self.ds[index] is None: - assert False, (self.ds, index) - - if not (self.ds[index] == self.np[index]).all(): - assert (self.ds[index] == self.np[index]).all() - - -def default_test_indexing(ds): - - t = IndexTester(ds) - - t[0:10, :, 0] - t[:, 0:3, 0] - # t[:, :, 0] - t[0:10, 0:3, 0] - t[:, :, :] - - if ds.shape[1] > 2: # Variable dimension - t[:, (1, 2), :] - t[:, (1, 2)] - - t[0] - t[0, :] - t[0, 0, :] - t[0, 0, 0, :] - - if ds.shape[2] > 1: # Ensemble dimension - t[0:10, :, (0, 1)] - - for i in range(3): - t[i] - start = 5 * i - end = len(ds) - 5 * i - step = len(ds) // 10 - - t[start:end:step] - t[start:end] - t[start:] - t[:end] - t[::step] - - -class Trace: - - def __init__(self, ds): - self.ds = ds - self.f = open("trace.txt", "a") - - def __getattr__(self, name: str) -> Any: - - print(name, file=self.f, flush=True) - return getattr(self.ds, name) - - def __len__(self) -> int: - print("__len__", file=self.f, flush=True) - return len(self.ds) - - def __getitem__(self, index: Any) -> Any: - print("__getitem__", file=self.f, flush=True) - return self.ds[index] diff --git a/src/anemoi/datasets/use/tabular/records/backends/__init__.py b/src/anemoi/datasets/use/tabular/records/backends/__init__.py index f09c32e4d..786202908 100644 --- a/src/anemoi/datasets/use/tabular/records/backends/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/backends/__init__.py @@ -100,7 +100,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.build import json_tidy + from anemoi.datasets.build.gridded import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: @@ -128,7 +128,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.build import json_tidy + 
from anemoi.datasets.build.gridded import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: diff --git a/src/anemoi/datasets/validate.py b/src/anemoi/datasets/validate.py deleted file mode 100644 index a1e168116..000000000 --- a/src/anemoi/datasets/validate.py +++ /dev/null @@ -1,598 +0,0 @@ -# (C) Copyright 2025- Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -import math -from collections import defaultdict - -import numpy as np - -from anemoi.datasets.testing import default_test_indexing -from anemoi.datasets.use.dataset import Dataset - -LOG = logging.getLogger(__name__) -# List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1 - -TRAINING_METHODS = [ - "__getitem__", - "__len__", - "latitudes", - "longitudes", - "metadata", # Accessed when checkpointing - "missing", - "name_to_index", - "shape", - "statistics", - "supporting_arrays", # Accessed when checkpointing - "variables", -] - -EXTRA_TRAINING_METHODS = [ - "statistics_tendencies", -] - -DEBUGGING_METHODS = [ - "plot", - "to_index", - "tree", - "source", -] - -PUBLIC_METADATA_METHODS = [ - "arguments", - "dtype", - "end_date", - "resolution", - "start_date", - "field_shape", - "frequency", - "dates", - "typed_variables", - "variables_metadata", -] - -PRIVATE_METADATA_METHODS = [ - "computed_constant_fields", - "constant_fields", - "dataset_metadata", - "label", - "metadata_specific", - "provenance", -] - -INTERNAL_METHODS = [ - "mutate", - "swap_with_parent", - "dates_interval_to_indices", -] - -EXPERIMENTAL_METHODS = [ - "get_dataset_names", - "name", - "grids", -] - -OTHER_METHODS = [ - "collect_input_sources", - "collect_supporting_arrays", - "sub_shape", -] - - -METHODS_CATEGORIES = {k: v for k, v in list(globals().items()) if k.endswith("_METHODS")} - - -METHODS = set(sum(METHODS_CATEGORIES.values(), [])) - - -KWARGS = { - "__len__": {}, - "__getitem__": {"index": 0}, - "get_dataset_names": {"names": set()}, - "metadata": {}, - "metadata_specific": {}, - "mutate": {}, - "plot": {"date": 0, "variable": 0}, - "provenance": {}, - "source": {"index": 0}, - "statistics_tendencies": {}, - "sub_shape": {}, - "supporting_arrays": {}, - "swap_with_parent": {}, - "to_index": {"date": 0, "variable": 0}, - "tree": {}, -} - - -class Unknown: - emoji = "❓" - - -class Success: - emoji = "✅" - success = True - - def __repr__(self): - return "Success" - - -class Error: - success = False - - def __init__(self, message): - self.message = message - - def __repr__(self): - return str(self.message) or repr(self.message) or "Error" - - -class Failure(Error): - emoji = "💥" - - -class Internal(Error): - emoji = "💣" - - -class Invalid(Error): - emoji = "❌" - - -class Report: - - def __init__(self): - self.report = {} - self.methods = {} - self.warnings = defaultdict(list) - - def method(self, name, method): - self.methods[name] = method - - def success(self, name): - self.report[name] = Success() - - def failure(self, name, message): - self.report[name] = Failure(message) - - def internal(self, name, message): - self.report[name] = Internal(message) - - def invalid(self, name, exception): - 
self.report[name] = Invalid(exception) - - def warning(self, name, message): - self.warnings[name].append(message) - - def summary(self, detailed=False): - - maxlen = max(len(name) for name in self.report.keys()) - - for name, methods in METHODS_CATEGORIES.items(): - print() - print(f"{name.title().replace('_', ' ')}:") - print("-" * (len(name) + 1)) - print() - - for method in methods: - r = self.report.get(method, Unknown()) - msg = repr(r) - if not msg.endswith("."): - msg += "." - print(f"{r.emoji} {method.ljust(maxlen)}: {msg}") - - for w in self.warnings.get(method, []): - print(" " * (maxlen + 4), "⚠️", w) - - if r.success: - continue - - if not detailed: - continue - - if method not in self.methods: - continue - - proc = self.methods[method] - - doc = proc.__doc__ - if doc: - width = 80 - indent = maxlen + 4 - doc = "\n".join(["=" * width, "", doc, "=" * width]) - indented_doc = "\n".join(" " * indent + line for line in doc.splitlines()) - print() - print(indented_doc) - print() - print() - - print() - - -def _no_validate(report, dataset, name, result): - report.warning(name, f"Validation for {name} not implemented. Result: {type(result)}") - - -def validate_variables(report, dataset, name, result): - """Validate the variables of the dataset.""" - - if not isinstance(result, (list, tuple)): - raise ValueError(f"Result is not a list or tuple {type(result)}") - - if len(result) != dataset.shape[1]: - raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[1]}") - - for value in result: - if not isinstance(value, str): - raise ValueError(f"`{value}` is not a string") - - -def validate_latitudes(report, dataset, name, result): - """Validate the latitudes of the dataset.""" - - if not isinstance(result, np.ndarray): - raise ValueError(f"Result is not a np.ndarray {type(result)}") - - if len(result) != dataset.shape[3]: - raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}") - - if not np.all(np.isfinite(result)): - raise ValueError("Result contains non-finite values") - - if np.isnan(result).any(): - report.invalid(name, ValueError("Result contains NaN values")) - return - - if not np.all((result >= -90) & (result <= 90)): - raise ValueError("Result contains values outside the range [-90, 90]") - - if np.all((result >= -np.pi) & (result <= np.pi)): - report.warning(name, "All latitudes are in the range [-π, π]. Are they in radians?") - - -def validate_longitudes(report, dataset, name, result): - """Validate the longitudes of the dataset.""" - - if not isinstance(result, np.ndarray): - raise ValueError(f"Result is not a np.ndarray {type(result)}") - - if len(result) != dataset.shape[3]: - raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[2]}") - - if not np.all(np.isfinite(result)): - raise ValueError("Result contains non-finite values") - - if np.isnan(result).any(): - report.invalid(name, ValueError("Result contains NaN values")) - return - - if not np.all((result >= -180) & (result <= 360)): - raise ValueError("Result contains values outside the range [-180, 360]") - - if np.all((result >= -np.pi) & (result <= 2 * np.pi)): - report.warning(name, "All longitudes are in the range [-π, 2π]. 
Are they in radians?") - - -def validate_statistics(report, dataset, name, result): - """Validate the statistics of the dataset.""" - - if not isinstance(result, dict): - raise ValueError(f"Result is not a dict {type(result)}") - - for key in ["mean", "stdev", "minimum", "maximum"]: - - if key not in result: - raise ValueError(f"Result does not contain `{key}`") - - if not isinstance(result[key], np.ndarray): - raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}") - - if len(result[key].shape) != 1: - raise ValueError(f"Result[{key}] has wrong shape: {len(result[key].shape)} != 1") - - if result[key].shape[0] != len(dataset.variables): - raise ValueError(f"Result[{key}] has wrong length: {result[key].shape[0]} != {len(dataset.variables)}") - - if not np.all(np.isfinite(result[key])): - raise ValueError(f"Result[{key}] contains non-finite values") - - if np.isnan(result[key]).any(): - report.invalid(name, ValueError(f"Result[{key}] contains NaN values")) - - -def validate_shape(report, dataset, name, result): - """Validate the shape of the dataset.""" - - if not isinstance(result, tuple): - raise ValueError(f"Result is not a tuple {type(result)}") - - if len(result) != 4: - raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.shape)}") - - if result[0] != len(dataset): - raise ValueError(f"Result[0] has wrong length: {result[0]} != {len(dataset)}") - - if result[1] != len(dataset.variables): - raise ValueError(f"Result[1] has wrong length: {result[1]} != {len(dataset.variables)}") - - if result[2] != 1: # We ignore ensemble dimension for now - pass - - if result[3] != len(dataset.latitudes): - raise ValueError(f"Result[3] has wrong length: {result[3]} != {len(dataset.latitudes)}") - - -def validate_supporting_arrays(report, dataset, name, result): - """Validate the supporting arrays of the dataset.""" - - if not isinstance(result, dict): - raise ValueError(f"Result is not a dict {type(result)}") - - if "latitudes" not in result: - raise ValueError("Result does not contain `latitudes`") - - if "longitudes" not in result: - raise ValueError("Result does not contain `longitudes`") - - if not isinstance(result["latitudes"], np.ndarray): - raise ValueError(f"Result[latitudes] is not a np.ndarray {type(result['latitudes'])}") - - if not isinstance(result["longitudes"], np.ndarray): - raise ValueError(f"Result[longitudes] is not a np.ndarray {type(result['longitudes'])}") - - if np.any(result["latitudes"] != dataset.latitudes): - raise ValueError("Result[latitudes] does not match dataset.latitudes") - - if np.any(result["longitudes"] != dataset.longitudes): - raise ValueError("Result[longitudes] does not match dataset.longitudes") - - -def validate_dates(report, dataset, name, result): - """Validate the dates of the dataset.""" - - if not isinstance(result, np.ndarray): - raise ValueError(f"Result is not a np.ndarray {type(result)}") - - if len(result.shape) != 1: - raise ValueError(f"Result has wrong shape: {len(result.shape)} != 1") - - if result.shape[0] != len(dataset.dates): - raise ValueError(f"Result has wrong length: {result.shape[0]} != {len(dataset.dates)}") - - if not np.issubdtype(result.dtype, np.datetime64): - raise ValueError(f"Result is not a datetime64 array {result.dtype}") - - if len(result) != len(dataset.dates): - raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.dates)}") - - if not np.all(np.isfinite(result)): - raise ValueError("Result contains non-finite values") - - if np.isnan(result).any(): - 
report.invalid(name, ValueError("Result contains NaN values")) - return - - for d1, d2 in zip(result[:-1], result[1:]): - if d1 >= d2: - raise ValueError(f"Result contains non-increasing dates: {d1} >= {d2}") - - frequency = np.diff(result) - if not np.all(frequency == frequency[0]): - raise ValueError("Result contains non-constant frequency") - - -def validate_metadata(report, dataset, name, result): - """Validate the metadata of the dataset.""" - - if not isinstance(result, dict): - raise ValueError(f"Result is not a dict {type(result)}") - - -def validate_missing(report, dataset, name, result): - """Validate the missing values of the dataset.""" - - if not isinstance(result, set): - raise ValueError(f"Result is not a set {type(result)}") - - if not all(isinstance(item, int) for item in result): - raise ValueError("Result contains non-integer values") - - if len(result) > 0: - if min(result) < 0: - raise ValueError("Result contains negative values") - - if max(result) >= len(dataset): - raise ValueError(f"Result contains values greater than {len(dataset)}") - - -def validate_name_to_index(report, dataset, name, result): - """Validate the name to index mapping of the dataset.""" - - if not isinstance(result, dict): - raise ValueError(f"Result is not a dict {type(result)}") - - for key in dataset.variables: - if key not in result: - raise ValueError(f"Result does not contain `{key}`") - - if not isinstance(result[key], int): - raise ValueError(f"Result[{key}] is not an int {type(result[key])}") - - if result[key] < 0 or result[key] >= len(dataset.variables): - raise ValueError(f"Result[{key}] is out of bounds: {result[key]}") - - index_to_name = {v: k for k, v in result.items()} - for i in range(len(dataset.variables)): - if i not in index_to_name: - raise ValueError(f"Result does not contain index `{i}`") - - if not isinstance(index_to_name[i], str): - raise ValueError(f"Result[{i}] is not a string {type(index_to_name[i])}") - - if index_to_name[i] != dataset.variables[i]: - raise ValueError( - f"Result[{i}] does not match dataset.variables[{i}]: {index_to_name[i]} != {dataset.variables[i]}" - ) - - -def validate___getitem__(report, dataset, name, result): - """Validate the __getitem__ method of the dataset.""" - - if not isinstance(result, np.ndarray): - raise ValueError(f"Result is not a np.ndarray {type(result)}") - - if result.shape != dataset.shape[1:]: - raise ValueError(f"Result has wrong shape: {result.shape} != {dataset.shape[1:]}") - - -def validate___len__(report, dataset, name, result): - """Validate the __len__ method of the dataset.""" - - if not isinstance(result, int): - raise ValueError(f"Result is not an int {type(result)}") - - if result != dataset.shape[0]: - raise ValueError(f"Result has wrong length: {result} != {len(dataset)}") - - if result != len(dataset.dates): - raise ValueError(f"Result has wrong length: {result} != {len(dataset.dates)}") - - -def validate_start_date(report, dataset, name, result): - """Validate the start date of the dataset.""" - - if not isinstance(result, np.datetime64): - raise ValueError(f"Result is not a datetime64 {type(result)}") - - if result != dataset.dates[0]: - raise ValueError(f"Result has wrong start date: {result} != {dataset.dates[0]}") - - -def validate_end_date(report, dataset, name, result): - """Validate the end date of the dataset.""" - - if not isinstance(result, np.datetime64): - raise ValueError(f"Result is not a datetime64 {type(result)}") - - if result != dataset.dates[-1]: - raise ValueError(f"Result has wrong end 
date: {result} != {dataset.dates[-1]}") - - -def validate_field_shape(report, dataset, name, result): - """Validate the field shape of the dataset.""" - - if not isinstance(result, tuple): - raise ValueError(f"Result is not a tuple {type(result)}") - - if math.prod(result) != dataset.shape[-1]: - raise ValueError(f"Result has wrong shape: {result} != {dataset.shape[-1]}") - - -def validate(report, dataset, name, kwargs=None): - - try: - - validate_fn = globals().get(f"validate_{name}", _no_validate) - - # Check if the method is still in the Dataset class - try: - report.method(name, getattr(Dataset, name)) - except AttributeError: - report.internal(name, "Attribute not found in Dataset class. Please update the list of methods.") - return - - # Check if the method is supported by the dataset instance - try: - result = getattr(dataset, name) - except AttributeError as e: - report.failure(name, e) - return - - # Check if the method is callable - if callable(result): - if kwargs is None: - report.internal( - name, f"`{name}` is a callable method, not an attribute. Please update KWARGS accordingly." - ) - return - else: - if kwargs is not None: - report.internal(name, f"`{name}` is not callable. Please remove entry from KWARGS.") - return - - if kwargs is not None: - result = result(**kwargs) - - if isinstance(result, np.ndarray) and np.isnan(result).any(): - report.invalid(name, ValueError("Result contains NaN values")) - return - - try: - validate_fn(report, dataset, name, result) - except Exception as e: - report.invalid(name, e) - return - - report.success(name) - - except Exception as e: - report.failure(name, e) - - -def validate_dtype(report, dataset, name, result): - """Validate the dtype of the dataset.""" - - if not isinstance(result, np.dtype): - raise ValueError(f"Result is not a np.dtype {type(result)}") - - -def validate_dataset(dataset, costly_checks=False, detailed=False): - """Validate the dataset.""" - - report = Report() - - if costly_checks: - # This check is expensive as it loads the entire dataset into memory - # so we make it optional - default_test_indexing(dataset) - - for i, x in enumerate(dataset): - y = dataset[i] - assert (x == y).all(), f"Dataset indexing failed at index {i}: {x} != {y}" - - for name in METHODS: - validate(report, dataset, name, kwargs=KWARGS.get(name)) - - report.summary(detailed=detailed) - - -if __name__ == "__main__": - methods = METHODS_CATEGORIES.copy() - methods.pop("OTHER_METHODS") - - o = set(OTHER_METHODS) - overlap = False - for m in methods: - if set(methods[m]).intersection(set(OTHER_METHODS)): - print( - f"WARNING: {m} contains methods from OTHER_METHODS: {set(methods[m]).intersection(set(OTHER_METHODS))}" - ) - o = o - set(methods[m]) - overlap = True - - for m in methods: - for n in methods: - if n is not m: - if set(methods[m]).intersection(set(methods[n])): - print( - f"WARNING: {m} and {n} have methods in common: {set(methods[m]).intersection(set(methods[n]))}" - ) - - if overlap: - print(sorted(o)) diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index a9fd74575..a10c83132 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -12,7 +12,7 @@ import yaml -from anemoi.datasets.build import creator_factory +from anemoi.datasets.build.gridded import creator_factory class TestingContext: From 6c5ce246e23a4e110ec32d7066d33854c16c6163 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 15:25:04 +0000 Subject: [PATCH 150/212] rename files --- 
src/anemoi/datasets/misc/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/anemoi/datasets/misc/__init__.py diff --git a/src/anemoi/datasets/misc/__init__.py b/src/anemoi/datasets/misc/__init__.py new file mode 100644 index 000000000..e69de29bb From 6b1e5b5ae9cfb477490c0d226be3b57ecc6d34fe Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 15:25:38 +0000 Subject: [PATCH 151/212] rename files --- src/anemoi/datasets/build/gridded/__init__.py | 1658 +++++++++++++++++ src/anemoi/datasets/build/gridded/check.py | 328 ++++ src/anemoi/datasets/build/gridded/chunks.py | 138 ++ src/anemoi/datasets/build/gridded/config.py | 445 +++++ src/anemoi/datasets/build/gridded/filter.py | 47 + src/anemoi/datasets/build/gridded/patch.py | 188 ++ .../datasets/build/gridded/persistent.py | 269 +++ src/anemoi/datasets/build/gridded/size.py | 47 + src/anemoi/datasets/build/gridded/source.py | 51 + .../build/gridded/statistics/__init__.py | 561 ++++++ .../build/gridded/statistics/summary.py | 152 ++ src/anemoi/datasets/build/gridded/testing.py | 4 + src/anemoi/datasets/build/gridded/typing.py | 14 + src/anemoi/datasets/build/gridded/utils.py | 198 ++ src/anemoi/datasets/build/gridded/writer.py | 64 + src/anemoi/datasets/build/gridded/zarr.py | 331 ++++ src/anemoi/datasets/misc/check.py | 93 + src/anemoi/datasets/misc/dumper.py | 76 + src/anemoi/datasets/misc/grids.py | 668 +++++++ src/anemoi/datasets/misc/testing.py | 173 ++ src/anemoi/datasets/misc/validate.py | 598 ++++++ 21 files changed, 6103 insertions(+) create mode 100644 src/anemoi/datasets/build/gridded/__init__.py create mode 100644 src/anemoi/datasets/build/gridded/check.py create mode 100644 src/anemoi/datasets/build/gridded/chunks.py create mode 100644 src/anemoi/datasets/build/gridded/config.py create mode 100644 src/anemoi/datasets/build/gridded/filter.py create mode 100755 src/anemoi/datasets/build/gridded/patch.py create mode 100644 src/anemoi/datasets/build/gridded/persistent.py create mode 100644 src/anemoi/datasets/build/gridded/size.py create mode 100644 src/anemoi/datasets/build/gridded/source.py create mode 100644 src/anemoi/datasets/build/gridded/statistics/__init__.py create mode 100644 src/anemoi/datasets/build/gridded/statistics/summary.py create mode 100644 src/anemoi/datasets/build/gridded/testing.py create mode 100644 src/anemoi/datasets/build/gridded/typing.py create mode 100644 src/anemoi/datasets/build/gridded/utils.py create mode 100644 src/anemoi/datasets/build/gridded/writer.py create mode 100644 src/anemoi/datasets/build/gridded/zarr.py create mode 100644 src/anemoi/datasets/misc/check.py create mode 100644 src/anemoi/datasets/misc/dumper.py create mode 100644 src/anemoi/datasets/misc/grids.py create mode 100644 src/anemoi/datasets/misc/testing.py create mode 100644 src/anemoi/datasets/misc/validate.py diff --git a/src/anemoi/datasets/build/gridded/__init__.py b/src/anemoi/datasets/build/gridded/__init__.py new file mode 100644 index 000000000..f28955dd8 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/__init__.py @@ -0,0 +1,1658 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import datetime +import json +import logging +import os +import time +import uuid +import warnings +from functools import cached_property +from typing import Any + +import cftime +import numpy as np +import tqdm +import zarr +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta +from anemoi.utils.humanize import compress_dates +from anemoi.utils.humanize import seconds_to_human +from anemoi.utils.sanitise import sanitise +from earthkit.data.core.order import build_remapping + +from anemoi.datasets import MissingDateError +from anemoi.datasets import open_dataset +from anemoi.datasets.build.check import DatasetName +from anemoi.datasets.build.check import check_data_values +from anemoi.datasets.build.chunks import ChunkFilter +from anemoi.datasets.build.config import build_output +from anemoi.datasets.build.config import loader_config +from anemoi.datasets.build.input import InputBuilder +from anemoi.datasets.build.input.trace import enable_trace +from anemoi.datasets.build.persistent import build_storage +from anemoi.datasets.build.statistics import Summary +from anemoi.datasets.build.statistics import TmpStatistics +from anemoi.datasets.build.statistics import check_variance +from anemoi.datasets.build.statistics import compute_statistics +from anemoi.datasets.build.statistics import default_statistics_dates +from anemoi.datasets.build.statistics import fix_variance +from anemoi.datasets.build.utils import normalize_and_check_dates +from anemoi.datasets.build.writer import ViewCacheArray +from anemoi.datasets.dates.groups import Groups +from anemoi.datasets.use.misc import as_first_date +from anemoi.datasets.use.misc import as_last_date + +LOG = logging.getLogger(__name__) + +VERSION = "0.30" + + +def json_tidy(o: Any) -> Any: + """Convert various types to JSON serializable format. + + Parameters + ---------- + o : Any + The object to convert. + + Returns + ------- + Any + The JSON serializable object. + """ + if isinstance(o, datetime.datetime): + return o.isoformat() + + if isinstance(o, datetime.datetime): + return o.isoformat() + + if isinstance(o, datetime.timedelta): + return frequency_to_string(o) + + if isinstance(o, cftime.DatetimeJulian): + import pandas as pd + + o = pd.Timestamp( + o.year, + o.month, + o.day, + o.hour, + o.minute, + o.second, + ) + return o.isoformat() + + if isinstance(o, (np.float32, np.float64)): + return float(o) + + raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}") + + +def build_statistics_dates( + dates: list[datetime.datetime], + start: datetime.datetime | None, + end: datetime.datetime | None, +) -> tuple[str, str]: + """Compute the start and end dates for the statistics. + + Parameters + ---------- + dates : list of datetime.datetime + The list of dates. + start : Optional[datetime.datetime] + The start date. + end : Optional[datetime.datetime] + The end date. + + Returns + ------- + tuple of str + The start and end dates in ISO format. 
+ """ + # if not specified, use the default statistics dates + default_start, default_end = default_statistics_dates(dates) + if start is None: + start = default_start + if end is None: + end = default_end + + # in any case, adapt to the actual dates in the dataset + start = as_first_date(start, dates) + end = as_last_date(end, dates) + + # and convert to datetime to isoformat + start = start.astype(datetime.datetime) + end = end.astype(datetime.datetime) + return (start.isoformat(), end.isoformat()) + + +def _path_readable(path: str) -> bool: + """Check if the path is readable. + + Parameters + ---------- + path : str + The path to check. + + Returns + ------- + bool + True if the path is readable, False otherwise. + """ + import zarr + + try: + zarr.open(path, "r") + return True + except zarr.errors.PathNotFoundError: + return False + + +class Dataset: + """A class to represent a dataset.""" + + def __init__(self, path: str): + """Initialize a Dataset instance. + + Parameters + ---------- + path : str + The path to the dataset. + """ + self.path = path + + _, ext = os.path.splitext(self.path) + if ext != ".zarr": + raise ValueError(f"Unsupported extension={ext} for path={self.path}") + + def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: + """Add a dataset to the Zarr store. + + Parameters + ---------- + mode : str, optional + The mode to open the Zarr store. + **kwargs + Additional arguments for the dataset. + + Returns + ------- + zarr.Array + The added dataset. + """ + import zarr + + z = zarr.open(self.path, mode=mode) + from anemoi.datasets.build.zarr import add_zarr_dataset + + return add_zarr_dataset(zarr_root=z, **kwargs) + + def update_metadata(self, **kwargs: Any) -> None: + """Update the metadata of the dataset. + + Parameters + ---------- + **kwargs + The metadata to update. + """ + import zarr + + LOG.debug(f"Updating metadata {kwargs}") + z = zarr.open(self.path, mode="w+") + for k, v in kwargs.items(): + if isinstance(v, np.datetime64): + v = v.astype(datetime.datetime) + if isinstance(v, datetime.date): + v = v.isoformat() + z.attrs[k] = json.loads(json.dumps(v, default=json_tidy)) + + @cached_property + def anemoi_dataset(self) -> Any: + """Get the Anemoi dataset.""" + return open_dataset(self.path) + + @cached_property + def zarr_metadata(self) -> dict: + """Get the Zarr metadata.""" + import zarr + + return dict(zarr.open(self.path, mode="r").attrs) + + def print_info(self) -> None: + """Print information about the dataset.""" + import zarr + + z = zarr.open(self.path, mode="r") + try: + LOG.info(z["data"].info) + except Exception as e: + LOG.info(e) + + def get_zarr_chunks(self) -> tuple: + """Get the chunks of the Zarr dataset. + + Returns + ------- + tuple + The chunks of the Zarr dataset. + """ + import zarr + + z = zarr.open(self.path, mode="r") + return z["data"].chunks + + def check_name( + self, + resolution: str, + dates: list[datetime.datetime], + frequency: datetime.timedelta, + raise_exception: bool = True, + is_test: bool = False, + ) -> None: + """Check the name of the dataset. + + Parameters + ---------- + resolution : str + The resolution of the dataset. + dates : list of datetime.datetime + The dates of the dataset. + frequency : datetime.timedelta + The frequency of the dataset. + raise_exception : bool, optional + Whether to raise an exception if the name is invalid. + is_test : bool, optional + Whether this is a test. 
+ """ + basename, _ = os.path.splitext(os.path.basename(self.path)) + try: + DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() + except Exception as e: + if raise_exception and not is_test: + raise e + else: + LOG.warning(f"Dataset name error: {e}") + + def get_main_config(self) -> Any: + """Get the main configuration of the dataset. + + Returns + ------- + Any + The main configuration. + """ + import zarr + + z = zarr.open(self.path, mode="r") + config = loader_config(z.attrs.get("_create_yaml_config")) + + if "env" in config: + for k, v in config["env"].items(): + LOG.info(f"Setting env variable {k}={v}") + os.environ[k] = str(v) + + return config + + +class WritableDataset(Dataset): + """A class to represent a writable dataset.""" + + def __init__(self, path: str): + """Initialize a WritableDataset instance. + + Parameters + ---------- + path : str + The path to the dataset. + """ + super().__init__(path) + self.path = path + + import zarr + + self.z = zarr.open(self.path, mode="r+") + + @cached_property + def data_array(self) -> Any: + """Get the data array of the dataset.""" + import zarr + + return zarr.open(self.path, mode="r+")["data"] + + +class NewDataset(Dataset): + """A class to represent a new dataset.""" + + def __init__(self, path: str, overwrite: bool = False): + """Initialize a NewDataset instance. + + Parameters + ---------- + path : str + The path to the dataset. + overwrite : bool, optional + Whether to overwrite the existing dataset. + """ + super().__init__(path) + self.path = path + + import zarr + + self.z = zarr.open(self.path, mode="w") + self.z.create_group("_build") + + +class Actor: # TODO: rename to Creator + """A base class for dataset creation actors.""" + + dataset_class = WritableDataset + + def __init__(self, path: str, cache: str | None = None): + """Initialize an Actor instance. + + Parameters + ---------- + path : str + The path to the dataset. + cache : Optional[str], optional + The cache directory. + """ + # Catch all floating point errors, including overflow, sqrt(<0), etc + np.seterr(all="raise", under="warn") + + self.path = path + self.cache = cache + self.dataset = self.dataset_class(self.path) + + def run(self) -> None: + """Run the actor.""" + # to be implemented in the sub-classes + raise NotImplementedError() + + def update_metadata(self, **kwargs: Any) -> None: + """Update the metadata of the dataset. + + Parameters + ---------- + **kwargs + The metadata to update. + """ + self.dataset.update_metadata(**kwargs) + + def _cache_context(self) -> Any: + """Get the cache context. + + Returns + ------- + Any + The cache context. + """ + from anemoi.datasets.build.utils import cache_context + + return cache_context(self.cache) + + def check_unkown_kwargs(self, kwargs: dict) -> None: + """Check for unknown keyword arguments. + + Parameters + ---------- + kwargs : dict + The keyword arguments. + """ + # remove this latter + LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}") + + def read_dataset_metadata(self, path: str) -> None: + """Read the metadata of the dataset. + + Parameters + ---------- + path : str + The path to the dataset. 
+ """ + ds = open_dataset(path) + self.dataset_shape = ds.shape + self.variables_names = ds.variables + assert len(self.variables_names) == ds.shape[1], self.dataset_shape + self.dates = ds.dates + + self.missing_dates = sorted(list([self.dates[i] for i in ds.missing])) + + def check_missing_dates(expected: list[np.datetime64]) -> None: + """Check if the missing dates in the dataset match the expected dates. + + Parameters + ---------- + expected : list of np.datetime64 + The expected missing dates. + + Raises + ------ + ValueError + If the missing dates in the dataset do not match the expected dates. + """ + import zarr + + z = zarr.open(path, "r") + missing_dates = z.attrs.get("missing_dates", []) + missing_dates = sorted([np.datetime64(d) for d in missing_dates]) + if missing_dates != expected: + LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.") + LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") + LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") + raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") + + check_missing_dates(self.missing_dates) + + +class Patch(Actor): + """A class to apply patches to a dataset.""" + + def __init__(self, path: str, options: dict = None, **kwargs: Any): + """Initialize a Patch instance. + + Parameters + ---------- + path : str + The path to the dataset. + options : dict, optional + The patch options. + """ + self.path = path + self.options = options or {} + + def run(self) -> None: + """Run the patch.""" + from anemoi.datasets.build.patch import apply_patch + + apply_patch(self.path, **self.options) + + +class Size(Actor): + """A class to compute the size of a dataset.""" + + def __init__(self, path: str, **kwargs: Any): + """Initialize a Size instance. + + Parameters + ---------- + path : str + The path to the dataset. + """ + super().__init__(path) + + def run(self) -> None: + """Run the size computation.""" + from anemoi.datasets.build.size import compute_directory_sizes + + metadata = compute_directory_sizes(self.path) + self.update_metadata(**metadata) + + # Look for constant fields + ds = open_dataset(self.path) + constants = ds.computed_constant_fields() + + variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() + for k in constants: + variables_metadata[k]["constant_in_time"] = True + + self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) + + +class HasRegistryMixin: + """A mixin class to provide registry functionality.""" + + @cached_property + def registry(self) -> Any: + """Get the registry.""" + from anemoi.datasets.build.zarr import ZarrBuiltRegistry + + return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) + + +class HasStatisticTempMixin: + """A mixin class to provide temporary statistics functionality.""" + + @cached_property + def tmp_statistics(self) -> TmpStatistics: + """Get the temporary statistics.""" + directory = self.statistics_temp_dir or os.path.join(self.path + ".storage_for_statistics.tmp") + return TmpStatistics(directory) + + +class HasElementForDataMixin: + """A mixin class to provide element creation functionality for data.""" + + def create_elements(self, config: Any) -> None: + """Create elements for the dataset. + + Parameters + ---------- + config : Any + The configuration. 
+ """ + assert self.registry + assert self.tmp_statistics + + LOG.info(dict(config.dates)) + + self.groups = Groups(**config.dates) + LOG.info(self.groups) + + self.output = build_output(config.output, parent=self) + + self.input = InputBuilder( + config.input, + data_sources=config.get("data_sources", {}), + order_by=self.output.order_by, + flatten_grid=self.output.flatten_grid, + remapping=build_remapping(self.output.remapping), + use_grib_paramid=config.build.use_grib_paramid, + ) + LOG.debug("✅ INPUT_BUILDER") + LOG.debug(self.input) + + +class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): + """A class to initialize a new dataset.""" + + dataset_class = NewDataset + + def __init__( + self, + path: str, + config: dict, + check_name: bool = False, + overwrite: bool = False, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + test: bool = False, + cache: str | None = None, + **kwargs: Any, + ): + """Initialize an Init instance. + + Parameters + ---------- + path : str + The path to the dataset. + config : dict + The configuration. + check_name : bool, optional + Whether to check the dataset name. + overwrite : bool, optional + Whether to overwrite the existing dataset. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + test : bool, optional + Whether this is a test. + cache : Optional[str], optional + The cache directory. + """ + if _path_readable(path) and not overwrite: + raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") + + super().__init__(path, cache=cache) + self.config = config + self.check_name = check_name + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.progress = progress + self.test = test + + self.main_config = loader_config(config, is_test=test) + + # self.registry.delete() ?? + self.tmp_statistics.delete() + + assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by + self.create_elements(self.main_config) + + LOG.info(f"Groups: {self.groups}") + + one_date = self.groups.one_date() + # assert False, (type(one_date), type(self.groups)) + self.minimal_input = self.input.select(one_date) + LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") + LOG.info(self.minimal_input) + + def run(self) -> int: + """Run the initialization. + + Returns + ------- + int + The number of groups to process. + """ + with self._cache_context(): + return self._run() + + def _run(self) -> int: + """Internal method to run the initialization. + + Returns + ------- + int + The number of groups to process. + """ + """Create an empty dataset of the right final shape. + + Read a small part of the data to get the shape of the data and the resolution and more metadata. 
+ """ + + LOG.info("Config loaded ok:") + # LOG.info(self.main_config) + + dates = self.groups.provider.values + frequency = self.groups.provider.frequency + missing = self.groups.provider.missing + + assert isinstance(frequency, datetime.timedelta), frequency + + LOG.info(f"Found {len(dates)} datetimes.") + LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") + LOG.info(f"Missing dates: {len(missing)}") + lengths = tuple(len(g) for g in self.groups) + + variables = self.minimal_input.variables + LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") + + variables_with_nans = self.main_config.statistics.get("allow_nans", []) + + ensembles = self.minimal_input.ensembles + LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") + + grid_points = self.minimal_input.grid_points + LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") + + resolution = self.minimal_input.resolution + LOG.info(f"{resolution=}") + + coords = self.minimal_input.coords + coords["dates"] = dates + total_shape = self.minimal_input.shape + total_shape[0] = len(dates) + LOG.info(f"total_shape = {total_shape}") + + chunks = self.output.get_chunking(coords) + LOG.info(f"{chunks=}") + dtype = self.output.dtype + + LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") + + metadata = {} + metadata["uuid"] = str(uuid.uuid4()) + + metadata.update(self.main_config.get("add_metadata", {})) + + metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() + + recipe = sanitise(self.main_config.get_serialisable_dict()) + + # Remove stuff added by prepml + for k in [ + "build_dataset", + "config_format_version", + "config_path", + "dataset_status", + "ecflow", + "metadata", + "platform", + "reading_chunks", + "upload", + ]: + recipe.pop(k, None) + + metadata["recipe"] = recipe + + metadata["description"] = self.main_config.description + metadata["licence"] = self.main_config["licence"] + metadata["attribution"] = self.main_config["attribution"] + + metadata["remapping"] = self.output.remapping + metadata["order_by"] = self.output.order_by_as_list + metadata["flatten_grid"] = self.output.flatten_grid + + metadata["ensemble_dimension"] = len(ensembles) + metadata["variables"] = variables + metadata["variables_with_nans"] = variables_with_nans + metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) + metadata["resolution"] = resolution + + metadata["data_request"] = self.minimal_input.data_request + metadata["field_shape"] = self.minimal_input.field_shape + metadata["proj_string"] = self.minimal_input.proj_string + metadata["variables_metadata"] = self.minimal_input.variables_metadata + + metadata["start_date"] = dates[0].isoformat() + metadata["end_date"] = dates[-1].isoformat() + metadata["frequency"] = frequency + metadata["missing_dates"] = [_.isoformat() for _ in missing] + + metadata["version"] = VERSION + + self.dataset.check_name( + raise_exception=self.check_name, + is_test=self.test, + resolution=resolution, + dates=dates, + frequency=frequency, + ) + + if len(dates) != total_shape[0]: + raise ValueError( + f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " + f"does not match data shape {total_shape[0]}. 
{total_shape=}" + ) + + dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) + + metadata.update(self.main_config.get("force_metadata", {})) + + ############################################################### + # write metadata + ############################################################### + + self.update_metadata(**metadata) + + self.dataset.add_dataset( + name="data", + chunks=chunks, + dtype=dtype, + shape=total_shape, + dimensions=("time", "variable", "ensemble", "cell"), + ) + self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) + self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) + self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) + + self.registry.create(lengths=lengths) + self.tmp_statistics.create(exist_ok=False) + self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) + + statistics_start, statistics_end = build_statistics_dates( + dates, + self.main_config.statistics.get("start"), + self.main_config.statistics.get("end"), + ) + self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) + LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") + + self.registry.add_to_history("init finished") + + assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) + + # Return the number of groups to process, so we can show a nice progress bar + return len(lengths) + + +class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): + """A class to load data into a dataset.""" + + def __init__( + self, + path: str, + parts: str | None = None, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + cache: str | None = None, + **kwargs: Any, + ): + """Initialize a Load instance. + + Parameters + ---------- + path : str + The path to the dataset. + parts : Optional[str], optional + The parts to load. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + cache : Optional[str], optional + The cache directory. 
+ """ + super().__init__(path, cache=cache) + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.progress = progress + self.parts = parts + self.dataset = WritableDataset(self.path) + + self.main_config = self.dataset.get_main_config() + self.create_elements(self.main_config) + self.read_dataset_metadata(self.dataset.path) + + total = len(self.registry.get_flags()) + self.chunk_filter = ChunkFilter(parts=self.parts, total=total) + + self.data_array = self.dataset.data_array + self.n_groups = len(self.groups) + + def run(self) -> None: + """Run the data loading.""" + with self._cache_context(): + self._run() + + def _run(self) -> None: + """Internal method to run the data loading.""" + for igroup, group in enumerate(self.groups): + if not self.chunk_filter(igroup): + continue + if self.registry.get_flag(igroup): + LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") + continue + + # assert isinstance(group[0], datetime.datetime), type(group[0]) + LOG.debug(f"Building data for group {igroup}/{self.n_groups}") + + result = self.input.select(argument=group) + assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) + + # There are several groups. + # There is one result to load for each group. + self.load_result(result) + self.registry.set_flag(igroup) + + self.registry.add_provenance(name="provenance_load") + self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) + + self.dataset.print_info() + + def load_result(self, result: Any) -> None: + """Load the result into the dataset. + + Parameters + ---------- + result : Any + The result to load. + """ + # There is one cube to load for each result. + dates = list(result.group_of_dates) + + LOG.debug(f"Loading cube for {len(dates)} dates") + + cube = result.get_cube() + shape = cube.extended_user_shape + dates_in_data = cube.user_coords["valid_datetime"] + + LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") + + def check_shape(cube, dates, dates_in_data): + if cube.extended_user_shape[0] != len(dates): + print( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) + print("Requested dates", compress_dates(dates)) + print("Cube dates", compress_dates(dates_in_data)) + + a = {as_datetime(_) for _ in dates} + b = {as_datetime(_) for _ in dates_in_data} + + print("Missing dates", compress_dates(a - b)) + print("Extra dates", compress_dates(b - a)) + + raise ValueError( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) + + check_shape(cube, dates, dates_in_data) + + def check_dates_in_data(dates_in_data, requested_dates): + _requested_dates = [np.datetime64(_) for _ in requested_dates] + _dates_in_data = [np.datetime64(_) for _ in dates_in_data] + if _dates_in_data != _requested_dates: + LOG.error("Dates in data are not the requested ones:") + + dates_in_data = set(dates_in_data) + requested_dates = set(requested_dates) + + missing = sorted(requested_dates - dates_in_data) + extra = sorted(dates_in_data - requested_dates) + + if missing: + LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") + if extra: + LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") + + raise ValueError("Dates in data are not the requested ones") + + check_dates_in_data(dates_in_data, dates) + + def dates_to_indexes(dates, all_dates): + x = np.array(dates, dtype=np.datetime64) + y = np.array(all_dates, 
dtype=np.datetime64) + bitmap = np.isin(x, y) + return np.where(bitmap)[0] + + indexes = dates_to_indexes(self.dates, dates_in_data) + + array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) + LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") + self.load_cube(cube, array) + + stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) + self.tmp_statistics.write(indexes, stats, dates=dates_in_data) + LOG.info("Flush data array") + array.flush() + LOG.info("Flushed data array") + + def _get_allow_nans(self) -> bool | list: + """Get the allow_nans configuration. + + Returns + ------- + bool | list + The allow_nans configuration. + """ + config = self.main_config + if "allow_nans" in config.build: + return config.build.allow_nans + + return config.statistics.get("allow_nans", []) + + def load_cube(self, cube: Any, array: ViewCacheArray) -> None: + """Load the cube into the array. + + Parameters + ---------- + cube : Any + The cube to load. + array : ViewCacheArray + The array to load into. + """ + # There are several cubelets for each cube + start = time.time() + load = 0 + save = 0 + + reading_chunks = None + total = cube.count(reading_chunks) + LOG.debug(f"Loading datacube: {cube}") + + def position(x: Any) -> int | None: + if isinstance(x, str) and "/" in x: + x = x.split("/") + return int(x[0]) + return None + + bar = tqdm.tqdm( + iterable=cube.iterate_cubelets(reading_chunks), + total=total, + desc=f"Loading datacube {cube}", + position=position(self.parts), + ) + for i, cubelet in enumerate(bar): + bar.set_description(f"Loading {i}/{total}") + + now = time.time() + data = cubelet.to_numpy() + local_indexes = cubelet.coords + load += time.time() - now + + name = self.variables_names[local_indexes[1]] + check_data_values( + data[:], + name=name, + log=[i, data.shape, local_indexes], + allow_nans=self._get_allow_nans(), + ) + + now = time.time() + array[local_indexes] = data + save += time.time() - now + + now = time.time() + save += time.time() - now + LOG.debug( + f"Elapsed: {seconds_to_human(time.time() - start)}, " + f"load time: {seconds_to_human(load)}, " + f"write time: {seconds_to_human(save)}." + ) + + +class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin): + """A class to clean up temporary data and registry entries.""" + + def __init__( + self, + path: str, + statistics_temp_dir: str | None = None, + delta: list = [], + use_threads: bool = False, + **kwargs: Any, + ): + """Initialize a Cleanup instance. + + Parameters + ---------- + path : str + The path to the dataset. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + delta : list, optional + The delta values. + use_threads : bool, optional + Whether to use threads. + """ + super().__init__(path) + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.additinon_temp_dir = statistics_temp_dir + self.actors = [ + _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) + for d in delta + ] + + def run(self) -> None: + """Run the cleanup.""" + + self.tmp_statistics.delete() + self.registry.clean() + for actor in self.actors: + actor.cleanup() + + +class Verify(Actor): + """A class to verify the integrity of a dataset.""" + + def __init__(self, path: str, **kwargs: Any): + """Initialize a Verify instance. + + Parameters + ---------- + path : str + The path to the dataset. 
+ """ + super().__init__(path) + + def run(self) -> None: + """Run the verification.""" + LOG.info(f"Verifying dataset at {self.path}") + LOG.info(str(self.dataset.anemoi_dataset)) + + +class AdditionsMixin: + """A mixin class to handle dataset additions.""" + + def skip(self) -> bool: + """Check if the additions should be skipped. + + Returns + ------- + bool + Whether to skip the additions. + """ + frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) + if not self.delta.total_seconds() % frequency.total_seconds() == 0: + LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.") + return True + + if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: + LOG.warning(f"Additions are disabled for {self.path} in the recipe.") + return True + + return False + + @cached_property + def tmp_storage_path(self) -> str: + """Get the path to the temporary storage.""" + name = "storage_for_additions" + if self.delta: + name += frequency_to_string(self.delta) + return os.path.join(f"{self.path}.{name}.tmp") + + def read_from_dataset(self) -> None: + """Read data from the dataset.""" + self.variables = self.dataset.anemoi_dataset.variables + self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) + start = self.dataset.zarr_metadata["statistics_start_date"] + end = self.dataset.zarr_metadata["statistics_end_date"] + self.start = datetime.datetime.fromisoformat(start) + self.end = datetime.datetime.fromisoformat(end) + + ds = open_dataset(self.path, start=self.start, end=self.end) + self.dates = ds.dates + self.total = len(self.dates) + + idelta = self.delta.total_seconds() // self.frequency.total_seconds() + assert int(idelta) == idelta, idelta + idelta = int(idelta) + self.ds = DeltaDataset(ds, idelta) + + +class DeltaDataset: + """A class to represent a dataset with delta values.""" + + def __init__(self, ds: Any, idelta: int): + """Initialize a DeltaDataset instance. + + Parameters + ---------- + ds : Any + The dataset. + idelta : int + The delta value. + """ + self.ds = ds + self.idelta = idelta + + def __getitem__(self, i: int) -> Any: + """Get an item from the dataset. + + Parameters + ---------- + i : int + The index. + + Returns + ------- + Any + The item. + """ + j = i - self.idelta + if j < 0: + raise MissingDateError(f"Missing date {j}") + return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] + + +class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin): + """A class to initialize dataset additions.""" + + def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): + """Initialize an _InitAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. 
+ """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + + def run(self) -> None: + """Run the additions initialization.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}") + return + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) + self.tmp_storage.delete() + self.tmp_storage.create() + LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") + + def cleanup(self) -> None: + """Clean up the temporary storage.""" + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + self.tmp_storage.delete() + LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") + + +class _LoadAdditions(Actor, HasRegistryMixin, AdditionsMixin): + """A class to run dataset additions.""" + + def __init__( + self, + path: str, + delta: str, + parts: str | None = None, + use_threads: bool = False, + progress: Any = None, + **kwargs: Any, + ): + """Initialize a _LoadAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + parts : Optional[str], optional + The parts to load. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + self.parts = parts + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + LOG.info(f"Writing in {self.tmp_storage_path}") + + def run(self) -> None: + """Run the additions.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}") + return + + self.read_from_dataset() + + chunk_filter = ChunkFilter(parts=self.parts, total=self.total) + for i in range(0, self.total): + if not chunk_filter(i): + continue + date = self.dates[i] + try: + arr = self.ds[i] + stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) + self.tmp_storage.add([date, i, stats], key=date) + except MissingDateError: + self.tmp_storage.add([date, i, "missing"], key=date) + self.tmp_storage.flush() + LOG.debug(f"Dataset {self.path} additions run.") + + def allow_nans(self) -> bool: + """Check if NaNs are allowed. + + Returns + ------- + bool + Whether NaNs are allowed. + """ + if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): + return True + + variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) + if variables_with_nans is not None: + return variables_with_nans + warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") + return True + + +class _FinaliseAdditions(Actor, HasRegistryMixin, AdditionsMixin): + """A class to finalize dataset additions.""" + + def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): + """Initialize a _FinaliseAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. 
+ """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + LOG.info(f"Reading from {self.tmp_storage_path}.") + + def run(self) -> None: + """Run the additions finalization.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}.") + return + + self.read_from_dataset() + + shape = (len(self.dates), len(self.variables)) + agg = dict( + minimum=np.full(shape, np.nan, dtype=np.float64), + maximum=np.full(shape, np.nan, dtype=np.float64), + sums=np.full(shape, np.nan, dtype=np.float64), + squares=np.full(shape, np.nan, dtype=np.float64), + count=np.full(shape, -1, dtype=np.int64), + has_nans=np.full(shape, False, dtype=np.bool_), + ) + LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") + + found = set() + ifound = set() + missing = set() + for _date, (date, i, stats) in self.tmp_storage.items(): + assert _date == date + if stats == "missing": + missing.add(date) + continue + + assert date not in found, f"Duplicates found {date}" + found.add(date) + ifound.add(i) + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k][i, ...] = stats[k] + + assert len(found) + len(missing) == len(self.dates), ( + len(found), + len(missing), + len(self.dates), + ) + assert found.union(missing) == set(self.dates), ( + found, + missing, + set(self.dates), + ) + + if len(ifound) < 2: + LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") + self.tmp_storage.delete() + return + + mask = sorted(list(ifound)) + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k] = agg[k][mask, ...] + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + assert agg[k].shape == agg["count"].shape, ( + agg[k].shape, + agg["count"].shape, + ) + + minimum = np.nanmin(agg["minimum"], axis=0) + maximum = np.nanmax(agg["maximum"], axis=0) + sums = np.nansum(agg["sums"], axis=0) + squares = np.nansum(agg["squares"], axis=0) + count = np.nansum(agg["count"], axis=0) + has_nans = np.any(agg["has_nans"], axis=0) + + assert sums.shape == count.shape + assert sums.shape == squares.shape + assert sums.shape == minimum.shape + assert sums.shape == maximum.shape + assert sums.shape == has_nans.shape + + mean = sums / count + assert sums.shape == mean.shape + + x = squares / count - mean * mean + # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 + # remove negative variance due to numerical errors + for i, name in enumerate(self.variables): + x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) + check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) + + stdev = np.sqrt(x) + assert sums.shape == stdev.shape + + self.summary = Summary( + minimum=minimum, + maximum=maximum, + mean=mean, + count=count, + sums=sums, + squares=squares, + stdev=stdev, + variables_names=self.variables, + has_nans=has_nans, + ) + LOG.info(f"Dataset {self.path} additions finalised.") + # self.check_statistics() + self._write(self.summary) + self.tmp_storage.delete() + + def _write(self, summary: Summary) -> None: + """Write the summary to the dataset. + + Parameters + ---------- + summary : Summary + The summary to write. 
+ """ + for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: + name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" + self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) + self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") + LOG.debug(f"Wrote additions in {self.path}") + + +def multi_addition(cls: type) -> type: + """Create a class to handle multiple additions. + + Parameters + ---------- + cls : type + The class to handle additions. + + Returns + ------- + type + The class to handle multiple additions. + """ + + class MultiAdditions: + def __init__(self, *args, **kwargs: Any): + self.actors = [] + + for k in kwargs.pop("delta", []): + self.actors.append(cls(*args, delta=k, **kwargs)) + + if not self.actors: + LOG.warning("No delta found in kwargs, no additions will be computed.") + + def run(self) -> None: + """Run the additions.""" + for actor in self.actors: + actor.run() + + return MultiAdditions + + +InitAdditions = multi_addition(_InitAdditions) +LoadAdditions = multi_addition(_LoadAdditions) +FinaliseAdditions = multi_addition(_FinaliseAdditions) + + +class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin): + """A class to compute statistics for a dataset.""" + + def __init__( + self, + path: str, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + **kwargs: Any, + ): + """Initialize a Statistics instance. + + Parameters + ---------- + path : str + The path to the dataset. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.use_threads = use_threads + self.progress = progress + self.statistics_temp_dir = statistics_temp_dir + + def run(self) -> None: + """Run the statistics computation.""" + start, end = ( + self.dataset.zarr_metadata["statistics_start_date"], + self.dataset.zarr_metadata["statistics_end_date"], + ) + start, end = np.datetime64(start), np.datetime64(end) + dates = self.dataset.anemoi_dataset.dates + + assert type(dates[0]) is type(start), (type(dates[0]), type(start)) + + dates = [d for d in dates if d >= start and d <= end] + dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] + variables = self.dataset.anemoi_dataset.variables + stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) + + LOG.info(stats) + + if not all(self.registry.get_flags(sync=False)): + raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") + + for k in [ + "mean", + "stdev", + "minimum", + "maximum", + "sums", + "squares", + "count", + "has_nans", + ]: + self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) + + self.registry.add_to_history("compute_statistics_end") + LOG.info(f"Wrote statistics in {self.path}") + + @cached_property + def allow_nans(self) -> bool | list: + """Check if NaNs are allowed.""" + import zarr + + z = zarr.open(self.path, mode="r") + if "allow_nans" in z.attrs: + return z.attrs["allow_nans"] + + if "variables_with_nans" in z.attrs: + return z.attrs["variables_with_nans"] + + warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") + return True + + +def chain(tasks: list) -> type: + """Create a class to chain multiple tasks. 
+ + Parameters + ---------- + tasks : list + The list of tasks to chain. + + Returns + ------- + type + The class to chain multiple tasks. + """ + + class Chain(Actor): + def __init__(self, **kwargs: Any): + self.kwargs = kwargs + + def run(self) -> None: + """Run the chained tasks.""" + for cls in tasks: + t = cls(**self.kwargs) + t.run() + + return Chain + + +def creator_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: + """Create a dataset creator. + + Parameters + ---------- + name : str + The name of the creator. + trace : Optional[str], optional + The trace file. + **kwargs + Additional arguments for the creator. + + Returns + ------- + Any + The dataset creator. + """ + if trace: + + enable_trace(trace) + + cls = dict( + init=Init, + load=Load, + size=Size, + patch=Patch, + statistics=Statistics, + finalise=chain([Statistics, Size, Cleanup]), + cleanup=Cleanup, + verify=Verify, + init_additions=InitAdditions, + load_additions=LoadAdditions, + finalise_additions=chain([FinaliseAdditions, Size]), + additions=chain([InitAdditions, LoadAdditions, FinaliseAdditions, Size, Cleanup]), + )[name] + LOG.debug(f"Creating {cls.__name__} with {kwargs}") + return cls(**kwargs) + + +def validate_config(config: Any) -> None: + + import json + + import jsonschema + + def _tidy(d): + if isinstance(d, dict): + return {k: _tidy(v) for k, v in d.items()} + + if isinstance(d, list): + return [_tidy(v) for v in d if v is not None] + + # jsonschema does not support datetime.date + if isinstance(d, datetime.datetime): + return d.isoformat() + + if isinstance(d, datetime.date): + return d.isoformat() + + return d + + # https://json-schema.org + + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "schemas", + "recipe.json", + ) + ) as f: + schema = json.load(f) + + try: + jsonschema.validate(instance=_tidy(config), schema=schema) + except jsonschema.exceptions.ValidationError as e: + LOG.error("❌ Config validation failed (jsonschema):") + LOG.error(e.message) + raise diff --git a/src/anemoi/datasets/build/gridded/check.py b/src/anemoi/datasets/build/gridded/check.py new file mode 100644 index 000000000..3c09cc80b --- /dev/null +++ b/src/anemoi/datasets/build/gridded/check.py @@ -0,0 +1,328 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import logging +import re +import warnings +from collections.abc import Callable +from typing import Any + +import numpy as np +from anemoi.utils.config import load_config +from anemoi.utils.dates import frequency_to_string +from numpy.typing import NDArray + +LOG = logging.getLogger(__name__) + + +class DatasetName: + """Validate and parse dataset names according to naming conventions.""" + + def __init__( + self, + name: str, + resolution: str | None = None, + start_date: datetime.date | None = None, + end_date: datetime.date | None = None, + frequency: datetime.timedelta | None = None, + ): + """Initialize a DatasetName instance. + + Parameters + ---------- + name : str + The name of the dataset. + resolution : Optional[str], optional + The resolution of the dataset. + start_date : Optional[datetime.date], optional + The start date of the dataset. 
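# Editor's sketch (not part of the patch): each build step is obtained by name
# from the creator_factory() mapping above and run; composite names such as
# "finalise" or "additions" expand to chains of the elementary actors. The path
# is hypothetical and, in practice, each step accepts further keyword arguments.
creator_factory("statistics", path="/path/to/dataset.zarr").run()
creator_factory("verify", path="/path/to/dataset.zarr").run()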
+        end_date : Optional[datetime.date], optional
+            The end date of the dataset.
+        frequency : Optional[datetime.timedelta], optional
+            The frequency of the dataset.
+        """
+        self.name = name
+        self.parsed = self._parse(name)
+        LOG.debug(f"Parsed dataset name: {self.parsed}")
+
+        self.messages = []
+
+        config = load_config().get("datasets", {})
+
+        if config.get("ignore_naming_conventions", False):
+            # setting the env variable ANEMOI_CONFIG_DATASETS_IGNORE_NAMING_CONVENTIONS=1
+            # will ignore the naming conventions
+            return
+
+        self.check_characters()
+        self.check_parsed()
+        self.check_resolution(resolution)
+        self.check_frequency(frequency)
+        self.check_start_date(start_date)
+        self.check_end_date(end_date)
+
+        if self.messages:
+            self.messages.append(f"{self} is parsed as :" + "/".join(f"{k}={v}" for k, v in self.parsed.items()))
+
+    @property
+    def error_message(self) -> str:
+        """Generate an error message based on the collected messages."""
+        out = " And ".join(self.messages)
+        if out:
+            out = out[0].upper() + out[1:]
+        return out
+
+    def raise_if_not_valid(self, print: Callable = print) -> None:
+        """Raise a ValueError if the dataset name is not valid.
+
+        Parameters
+        ----------
+        print : Callable
+            The function to use for printing messages.
+        """
+        if self.messages:
+            for m in self.messages:
+                print(m)
+            raise ValueError(self.error_message)
+
+    def _parse(self, name: str) -> dict:
+        """Parse the dataset name into its components.
+
+        Parameters
+        ----------
+        name : str
+            The name of the dataset.
+
+        Returns
+        -------
+        dict
+            The parsed components of the dataset name.
+        """
+        pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$"
+        match = re.match(pattern, name)
+
+        if not match:
+            raise ValueError(f"the dataset name '{name}' does not follow naming convention. Does not match {pattern}")
+
+        parsed = {}
+        if match:
+            keys = [
+                "purpose",
+                "labelling",
+                "source",
+                "resolution",
+                "start_date",
+                "end_date",
+                "frequency",
+                "version",
+                "additional",
+            ]
+            parsed = {k: v for k, v in zip(keys, match.groups())}
+
+        return parsed
+
+    def __str__(self) -> str:
+        """Return the string representation of the dataset name."""
+        return self.name
+
+    def check_parsed(self) -> None:
+        """Check if the dataset name was parsed correctly."""
+        if not self.parsed:
+            self.messages.append(
+                f"the dataset name {self} does not follow naming convention. "
+                "See here for details: "
+                "https://anemoi-registry.readthedocs.io/en/latest/naming-conventions.html"
+            )
+
+    def check_resolution(self, resolution: str | None) -> None:
+        """Check if the resolution matches the expected format.
+
+        Parameters
+        ----------
+        resolution : str or None
+            The expected resolution.
+        """
+        if self.parsed.get("resolution") and self.parsed["resolution"][0] not in "0123456789on":
+            self.messages.append(
+                f"the resolution {self.parsed['resolution']} should start "
+                f"with a number or 'o' or 'n' in the dataset name {self}."
+ ) + + if resolution is None: + return + resolution_str = str(resolution).replace(".", "p").lower() + self._check_missing("resolution", resolution_str) + self._check_mismatch("resolution", resolution_str) + + def check_characters(self) -> None: + if not self.name.islower(): + self.messages.append(f"the {self.name} should be in lower case.") + if "_" in self.name: + self.messages.append(f"the {self.name} should use '-' instead of '_'.") + for c in self.name: + if not c.isalnum() and c not in "-": + self.messages.append(f"the {self.name} should only contain alphanumeric characters and '-'.") + + def check_frequency(self, frequency: datetime.timedelta | None) -> None: + """Check if the frequency matches the expected format. + + Parameters + ---------- + frequency : datetime.timedelta or None + The expected frequency. + """ + if frequency is None: + return + frequency_str = frequency_to_string(frequency) + self._check_missing("frequency", frequency_str) + self._check_mismatch("frequency", frequency_str) + + def check_start_date(self, start_date: datetime.date | None) -> None: + """Check if the start date matches the expected format. + + Parameters + ---------- + start_date : datetime.date or None + The expected start date. + """ + if start_date is None: + return + start_date_str = str(start_date.year) + self._check_missing("start_date", start_date_str) + self._check_mismatch("start_date", start_date_str) + + def check_end_date(self, end_date: datetime.date | None) -> None: + """Check if the end date matches the expected format. + + Parameters + ---------- + end_date : datetime.date or None + The expected end date. + """ + if end_date is None: + return + end_date_str = str(end_date.year) + self._check_missing("end_date", end_date_str) + self._check_mismatch("end_date", end_date_str) + + def _check_missing(self, key: str, value: str) -> None: + """Check if a component is missing from the dataset name. + + Parameters + ---------- + key : str + The component key. + value : str + The expected value. + """ + if value not in self.name: + self.messages.append(f"the {key} is {value}, but is missing in {self.name}.") + + def _check_mismatch(self, key: str, value: str) -> None: + """Check if a component value mismatches the expected value. + + Parameters + ---------- + key : str + The component key. + value : str + The expected value. + """ + if self.parsed.get(key) and self.parsed[key] != value: + self.messages.append(f"the {key} is {value}, but is {self.parsed[key]} in {self.name}.") + + +class StatisticsValueError(ValueError): + """Custom error for statistics value issues.""" + + pass + + +def check_data_values( + arr: NDArray[Any], *, name: str, log: list = [], allow_nans: bool | list | set | tuple | dict = False +) -> None: + """Check the values in the data array for validity. + + Parameters + ---------- + arr : NDArray[Any] + The data array to check. + name : str + The name of the data array. + log : list, optional + A list to log messages. + allow_nans : bool or list or set or tuple or dict, optional + Whether to allow NaNs in the data array. 
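# Editor's sketch (not part of the patch): how the naming-convention pattern in
# _parse() above splits a hypothetical dataset name into its components.
import re

pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$"
match = re.match(pattern, "aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6")
# match.groups() -> purpose="aifs", labelling="ea-an-oper-0001", source="mars",
# resolution="o96", start_date="1979", end_date="2022", frequency="6h",
# version="6", additional=None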
+ """ + shape = arr.shape + + if (isinstance(allow_nans, (set, list, tuple, dict)) and name in allow_nans) or allow_nans: + arr = arr[~np.isnan(arr)] + + if arr.size == 0: + warnings.warn(f"Empty array for {name} ({shape})") + return + + assert arr.size > 0, (name, *log) + + min, max = arr.min(), arr.max() + assert not (np.isnan(arr).any()), (name, min, max, *log) + + if min == 9999.0: + warnings.warn(f"Min value 9999 for {name}") + + if max == 9999.0: + warnings.warn(f"Max value 9999 for {name}") + + in_minus_1_plus_1 = dict(minimum=-1, maximum=1) + limits = { + "cos_latitude": in_minus_1_plus_1, + "sin_latitude": in_minus_1_plus_1, + "cos_longitude": in_minus_1_plus_1, + "sin_longitude": in_minus_1_plus_1, + } + + if name in limits: + if min < limits[name]["minimum"]: + warnings.warn( + f"For {name}: minimum value in the data is {min}. " + "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" + ) + if max > limits[name]["maximum"]: + warnings.warn( + f"For {name}: maximum value in the data is {max}. " + "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" + ) + + +def check_stats(minimum: float, maximum: float, mean: float, msg: str, **kwargs: Any) -> None: + """Check if the mean value is within the min/max interval. + + Parameters + ---------- + minimum : float + The minimum value. + maximum : float + The maximum value. + mean : float + The mean value. + msg : str + The message to include in the error. + **kwargs : Any + Additional keyword arguments. + """ + tolerance = (abs(minimum) + abs(maximum)) * 0.01 + if (mean - minimum < -tolerance) or (mean - minimum < -tolerance): + raise StatisticsValueError( + f"Mean is not in min/max interval{msg} : we should have {minimum} <= {mean} <= {maximum}" + ) diff --git a/src/anemoi/datasets/build/gridded/chunks.py b/src/anemoi/datasets/build/gridded/chunks.py new file mode 100644 index 000000000..08cc1edfd --- /dev/null +++ b/src/anemoi/datasets/build/gridded/chunks.py @@ -0,0 +1,138 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import warnings + +LOG = logging.getLogger(__name__) + +ALL = object() + + +class ChunkFilter: + """A filter to determine which chunks to process based on the specified parts. + + Attributes + ---------- + total : int + The total number of chunks. + allowed : object or list + The chunks that are allowed to be processed. + """ + + def __init__(self, *, parts: str | list, total: int): + """Initializes the ChunkFilter with the given parts and total number of chunks. + + Parameters + ---------- + parts : str or list + The parts to process, specified as 'i/n' or a list of such strings. + total : int + The total number of chunks. + + Raises + ------ + ValueError + If the parts format is invalid. + AssertionError + If the chunk number is invalid. + Warning + If the number of chunks is larger than the total number of chunks. + """ + self.total = total + + if isinstance(parts, list): + if len(parts) == 1: + parts = parts[0] + elif len(parts) == 0: + parts = None + else: + raise ValueError(f"Invalid parts format: {parts}. 
Must be in the form 'i/n'.") + + if not parts: + parts = "all" + + assert isinstance(parts, str), f"Argument parts must be a string, got {parts}." + + if parts.lower() == "all" or parts == "*": + self.allowed = ALL + return + + assert "/" in parts, f"Invalid parts format: {parts}. Must be in the form 'i/n'." + + i, n = parts.split("/") + i, n = int(i), int(n) + + assert i > 0, f"Chunk number {i} must be positive." + assert i <= n, f"Chunk number {i} must be less than total chunks {n}." + if n > total: + warnings.warn( + f"Number of chunks {n} is larger than the total number of chunks: {total}. " + "Some chunks will be empty." + ) + + chunk_size = total / n + parts = [x for x in range(total) if x >= (i - 1) * chunk_size and x < i * chunk_size] + + for i in parts: + if i < 0 or i >= total: + raise AssertionError(f"Invalid chunk number {i}. Must be between 0 and {total - 1}.") + if not parts: + warnings.warn(f"Nothing to do for chunk {i}/{n}.") + + LOG.debug(f"Running parts: {parts}") + + self.allowed = parts + + def __call__(self, i: int) -> bool: + """Checks if the given chunk number is allowed to be processed. + + Parameters + ---------- + i : int + The chunk number to check. + + Returns + ------- + bool + True if the chunk is allowed, False otherwise. + + Raises + ------ + AssertionError + If the chunk number is invalid. + """ + if i < 0 or i >= self.total: + raise AssertionError(f"Invalid chunk number {i}. Must be between 0 and {self.total - 1}.") + + if self.allowed == ALL: + return True + return i in self.allowed + + def __iter__(self) -> iter: + """Iterates over the allowed chunks. + + Yields + ------ + int + The next allowed chunk number. + """ + for i in range(self.total): + if self(i): + yield i + + def __len__(self) -> int: + """Returns the number of allowed chunks. + + Returns + ------- + int + The number of allowed chunks. + """ + return len([_ for _ in self]) diff --git a/src/anemoi/datasets/build/gridded/config.py b/src/anemoi/datasets/build/gridded/config.py new file mode 100644 index 000000000..4720ebb6b --- /dev/null +++ b/src/anemoi/datasets/build/gridded/config.py @@ -0,0 +1,445 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import os +from copy import deepcopy +from typing import Any + +import yaml +from anemoi.utils.config import DotDict +from anemoi.utils.config import load_any_dict_format +from earthkit.data.core.order import normalize_order_by + +from anemoi.datasets.dates.groups import Groups + +LOG = logging.getLogger(__name__) + + +def _get_first_key_if_dict(x: str | dict) -> str: + """Returns the first key if the input is a dictionary, otherwise returns the input string. + + Parameters + ---------- + x : str or dict + Input string or dictionary. + + Returns + ------- + str + The first key if input is a dictionary, otherwise the input string. + """ + if isinstance(x, str): + return x + return list(x.keys())[0] + + +def ensure_element_in_list(lst: list, elt: str, index: int) -> list: + """Ensures that a specified element is present at a given index in a list. + + Parameters + ---------- + lst : list + The list to check. + elt : str + The element to ensure is in the list. 
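# Editor's sketch (not part of the patch): the "i/n" splitting implemented in
# ChunkFilter.__init__ above. With total=10 and parts="2/3", chunk_size is
# 10/3 ≈ 3.33 and this worker owns indices 4, 5 and 6.
total, i, n = 10, 2, 3
chunk_size = total / n
owned = [x for x in range(total) if x >= (i - 1) * chunk_size and x < i * chunk_size]
# -> [4, 5, 6]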
+ index : int + The index at which the element should be present. + + Returns + ------- + list + The modified list with the element at the specified index. + """ + if elt in lst: + assert lst[index] == elt + return lst + + _lst = [_get_first_key_if_dict(d) for d in lst] + if elt in _lst: + assert _lst[index] == elt + return lst + + return lst[:index] + [elt] + lst[index:] + + +def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None: + """Checks if a dictionary contains a specific key-value pair and sets it if not present. + + Parameters + ---------- + dic : dict + The dictionary to check. + key : str + The key to check in the dictionary. + value : Any + The value to set if the key is not present. + + Raises + ------ + ValueError + If the key is present but with a different value. + """ + if key in dic: + if dic[key] == value: + return + raise ValueError(f"Cannot use {key}={dic[key]}. Must use {value}.") + LOG.info(f"Setting {key}={value} in config") + dic[key] = value + + +def resolve_includes(config: dict | list) -> dict | list: + """Resolves '<<' includes in a configuration dictionary or list. + + Parameters + ---------- + config : dict or list + The configuration to resolve includes for. + + Returns + ------- + dict or list + The configuration with includes resolved. + """ + if isinstance(config, list): + return [resolve_includes(c) for c in config] + if isinstance(config, dict): + include = config.pop("<<", {}) + new = deepcopy(include) + new.update(config) + return {k: resolve_includes(v) for k, v in new.items()} + return config + + +class Config(DotDict): + """Configuration class that extends DotDict to handle configuration loading and processing.""" + + def __init__(self, config: str | dict | None = None, **kwargs): + """Initializes the Config object. + + Parameters + ---------- + config : str or dict, optional + Path to the configuration file or a dictionary. Defaults to None. + **kwargs + Additional keyword arguments to update the configuration. + """ + if isinstance(config, str): + config = load_any_dict_format(config) + else: + config = deepcopy(config if config is not None else {}) + config = resolve_includes(config) + config.update(kwargs) + super().__init__(config) + + +class OutputSpecs: + """Class to handle output specifications for datasets.""" + + def __init__(self, config: Config, parent: Any): + """Initializes the OutputSpecs object. + + Parameters + ---------- + config : Config + The configuration object. + parent : Any + The parent object. + """ + self.config = config + if "order_by" in config: + assert isinstance(config.order_by, dict), config.order_by + + self.parent = parent + + @property + def dtype(self) -> str: + """Returns the data type for the output.""" + return self.config.dtype + + @property + def order_by_as_list(self) -> list[dict]: + """Returns the order_by configuration as a list of dictionaries.""" + return [{k: v} for k, v in self.config.order_by.items()] + + def get_chunking(self, coords: dict) -> tuple: + """Returns the chunking configuration based on coordinates. + + Parameters + ---------- + coords : dict + The coordinates dictionary. + + Returns + ------- + tuple + The chunking configuration. 
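# Editor's sketch (not part of the patch): how "<<" includes are merged by
# resolve_includes() above. The keys are made up; explicit values win over the
# included defaults.
cfg = {"mars_hires": {"<<": {"class": "ea", "grid": "0.25/0.25"}, "grid": "1/1"}}
resolved = resolve_includes(cfg)
# -> {"mars_hires": {"class": "ea", "grid": "1/1"}}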
+ """ + user = deepcopy(self.config.chunking) + chunks = [] + for k, v in coords.items(): + if k in user: + chunks.append(user.pop(k)) + else: + chunks.append(len(v)) + if user: + raise ValueError( + f"Unused chunking keys from config: {list(user.keys())}, not in known keys : {list(coords.keys())}" + ) + return tuple(chunks) + + @property + def order_by(self) -> dict: + """Returns the order_by configuration.""" + return self.config.order_by + + @property + def remapping(self) -> dict: + """Returns the remapping configuration.""" + return self.config.remapping + + @property + def flatten_grid(self) -> bool: + """Returns whether the grid should be flattened.""" + return self.config.flatten_grid + + @property + def statistics(self) -> str: + """Returns the statistics configuration.""" + return self.config.statistics + + +class LoadersConfig(Config): + """Configuration class for dataset loaders.""" + + def __init__(self, config: dict, *args, **kwargs): + """Initializes the LoadersConfig object. + + Parameters + ---------- + config : dict + The configuration dictionary. + *args + Additional positional arguments. + **kwargs + Additional keyword arguments. + """ + super().__init__(config, *args, **kwargs) + + # TODO: should use a json schema to validate the config + + self.setdefault("dataset_status", "experimental") + self.setdefault("description", "No description provided.") + self.setdefault("licence", "unknown") + self.setdefault("attribution", "unknown") + + self.setdefault("build", Config()) + self.build.setdefault("group_by", "monthly") + self.build.setdefault("use_grib_paramid", False) + self.build.setdefault("variable_naming", "default") + variable_naming = dict( + param="{param}", + param_levelist="{param}_{levelist}", + default="{param}_{levelist}", + ).get(self.build.variable_naming, self.build.variable_naming) + + self.setdefault("output", Config()) + self.output.setdefault("order_by", ["valid_datetime", "param_level", "number"]) + self.output.setdefault("remapping", Config(param_level=variable_naming)) + self.output.setdefault("statistics", "param_level") + self.output.setdefault("chunking", Config(dates=1, ensembles=1)) + self.output.setdefault("dtype", "float32") + + if "statistics_start" in self.output: + raise ValueError("statistics_start is not supported anymore. Use 'statistics:start:' instead") + if "statistics_end" in self.output: + raise ValueError("statistics_end is not supported anymore. Use 'statistics:end:' instead") + + self.setdefault("statistics", Config()) + if "allow_nans" not in self.statistics: + self.statistics.allow_nans = [] + + check_dict_value_and_set(self.output, "flatten_grid", True) + check_dict_value_and_set(self.output, "ensemble_dimension", 2) + + assert isinstance(self.output.order_by, (list, tuple)), self.output.order_by + self.output.order_by = ensure_element_in_list(self.output.order_by, "number", self.output.ensemble_dimension) + + order_by = self.output.order_by + assert len(order_by) == 3, order_by + assert _get_first_key_if_dict(order_by[0]) == "valid_datetime", order_by + assert _get_first_key_if_dict(order_by[2]) == "number", order_by + + self.output.order_by = normalize_order_by(self.output.order_by) + + self.setdefault("dates", Config()) + + self.dates["group_by"] = self.build.group_by + + ########### + + self.reading_chunks = self.get("reading_chunks") + + def get_serialisable_dict(self) -> dict: + """Returns a serializable dictionary representation of the configuration. + + Returns + ------- + dict + The serializable dictionary. 
+ """ + return _prepare_serialisation(self) + + +def _prepare_serialisation(o: Any) -> Any: + """Prepares an object for serialization. + + Parameters + ---------- + o : Any + The object to prepare. + + Returns + ------- + Any + The prepared object. + """ + if isinstance(o, dict): + dic = {} + for k, v in o.items(): + v = _prepare_serialisation(v) + if k == "order_by" and isinstance(v, dict): + # zarr attributes are saved with sort_keys=True + # and ordered dict are reordered. + # This is a problem for "order_by" + # We ensure here that the order_by key contains + # a list of dict + v = [{kk: vv} for kk, vv in v.items()] + dic[k] = v + return dic + + if isinstance(o, (list, tuple)): + return [_prepare_serialisation(v) for v in o] + + if o in (None, True, False): + return o + + if isinstance(o, (str, int, float)): + return o + + if isinstance(o, (datetime.date, datetime.datetime)): + return o.isoformat() + + return str(o) + + +def set_to_test_mode(cfg: dict) -> None: + """Modifies the configuration to run in test mode. + + Parameters + ---------- + cfg : dict + The configuration dictionary. + """ + NUMBER_OF_DATES = 4 + + LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.") + groups = Groups(**LoadersConfig(cfg).dates) + + dates = groups.provider.values + cfg["dates"] = dict( + start=dates[0], + end=dates[NUMBER_OF_DATES - 1], + frequency=groups.provider.frequency, + group_by=NUMBER_OF_DATES, + ) + + def set_element_to_test(obj): + if isinstance(obj, (list, tuple)): + for v in obj: + set_element_to_test(v) + return + if isinstance(obj, (dict, DotDict)): + if "grid" in obj: + previous = obj["grid"] + obj["grid"] = "20./20." + LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}") + if "number" in obj: + if isinstance(obj["number"], (list, tuple)): + previous = obj["number"] + obj["number"] = previous[0:3] + LOG.warning(f"Running in test mode. Setting number to {obj['number']} instead of {previous}") + for k, v in obj.items(): + set_element_to_test(v) + if "constants" in obj: + constants = obj["constants"] + if "param" in constants and isinstance(constants["param"], list): + constants["param"] = ["cos_latitude"] + + set_element_to_test(cfg) + + +def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: + """Loads and validates the configuration for dataset loaders. + + Parameters + ---------- + config : dict + The configuration dictionary. + is_test : bool, optional + Whether to run in test mode. Defaults to False. + + Returns + ------- + LoadersConfig + The validated configuration object. + """ + config = Config(config) + if is_test: + set_to_test_mode(config) + obj = LoadersConfig(config) + + # yaml round trip to check that serialisation works as expected + copy = obj.get_serialisable_dict() + copy = yaml.load(yaml.dump(copy), Loader=yaml.SafeLoader) + copy = Config(copy) + copy = LoadersConfig(config) + + a = yaml.dump(obj) + b = yaml.dump(copy) + if a != b: + print(a) + print(b) + raise ValueError("Serialisation failed") + + if "env" in copy: + for k, v in copy["env"].items(): + LOG.info(f"Setting env variable {k}={v}") + os.environ[k] = str(v) + + return copy + + +def build_output(*args, **kwargs) -> OutputSpecs: + """Builds the output specifications. + + Parameters + ---------- + *args + Additional positional arguments. + **kwargs + Additional keyword arguments. + + Returns + ------- + OutputSpecs + The output specifications object. 
+ """ + return OutputSpecs(*args, **kwargs) diff --git a/src/anemoi/datasets/build/gridded/filter.py b/src/anemoi/datasets/build/gridded/filter.py new file mode 100644 index 000000000..4544db8f2 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/filter.py @@ -0,0 +1,47 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from typing import Any + +import earthkit.data as ekd + + +class TransformFilter: + """Calls filters from anemoi.transform.filters + + Parameters + ---------- + context : Any + The context in which the filter is created. + name : str + The name of the filter. + config : Dict[str, Any] + The configuration for the filter. + """ + + def __init__(self, context: Any, name: str, config: dict[str, Any]) -> None: + from anemoi.transform.filters import create_filter + + self.name = name + self.transform_filter = create_filter(context, config) + + def execute(self, input: ekd.FieldList) -> ekd.FieldList: + """Execute the transformation filter. + + Parameters + ---------- + input : ekd.FieldList + The input data to be transformed. + + Returns + ------- + ekd.FieldList + The transformed data. + """ + return self.transform_filter.forward(input) diff --git a/src/anemoi/datasets/build/gridded/patch.py b/src/anemoi/datasets/build/gridded/patch.py new file mode 100755 index 000000000..5cb08ec82 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/patch.py @@ -0,0 +1,188 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import json +import logging +import os + +import zarr + +LOG = logging.getLogger(__name__) + + +def fix_order_by(order_by: dict | list) -> list[dict]: + """Fix the order_by attribute to ensure it is a list of dictionaries. + + Parameters + ---------- + order_by : dict or list + The order_by attribute to fix. + + Returns + ------- + list[dict] + The fixed order_by attribute. + """ + if isinstance(order_by, list): + return order_by + + assert isinstance(order_by, dict), order_by + assert len(order_by) <= 3, order_by + lst = [] + lst.append({"valid_datetime": order_by["valid_datetime"]}) + lst.append({"param_level": order_by["param_level"]}) + lst.append({"number": order_by["number"]}) + return lst + + +def fix_history(history: list[dict]) -> list[dict]: + """Fix the history attribute by removing specific actions. + + Parameters + ---------- + history : list[dict] + The history attribute to fix. + + Returns + ------- + list[dict] + The fixed history attribute. + """ + new = history + new = [d for d in new if d.get("action") != "loading_data_start"] + new = [d for d in new if d.get("action") != "loading_data_end"] + return new + + +def fix_provenance(provenance: dict) -> dict: + """Fix the provenance attribute by adding missing fields and removing unnecessary ones. + + Parameters + ---------- + provenance : dict + The provenance attribute to fix. 
+ + Returns + ------- + dict + The fixed provenance attribute. + """ + if "python" not in provenance: + provenance["python"] = provenance["platform"]["python_version"] + + for q in ( + "args", + "config_paths", + "executable", + "gpus", + "platform", + "python_path", + "assets", + ): + if q in provenance: + del provenance[q] + + for k, v in list(provenance["module_versions"].items()): + if v.startswith("<"): + del provenance["module_versions"][k] + if v.startswith("/"): + provenance["module_versions"][k] = os.path.join("...", os.path.basename(v)) + + for k, v in list(provenance["git_versions"].items()): + LOG.debug(k, v) + modified_files = v["git"].get("modified_files", []) + untracked_files = v["git"].get("untracked_files", []) + if not isinstance(modified_files, int): + modified_files = len(modified_files) + if not isinstance(untracked_files, int): + untracked_files = len(untracked_files) + provenance["git_versions"][k] = dict( + git={ + "sha1": v["git"]["sha1"], + "modified_files": modified_files, + "untracked_files": untracked_files, + } + ) + + LOG.debug(json.dumps(provenance, indent=2)) + # assert False + return provenance + + +def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None: + """Apply a patch to the dataset at the given path. + + Parameters + ---------- + path : str + The path to the dataset. + verbose : bool, optional + Whether to log detailed information. Defaults to True. + dry_run : bool, optional + If True, do not actually apply the patch. Defaults to False. + """ + LOG.debug("====================") + LOG.debug(f"Patching {path}") + LOG.debug("====================") + + try: + attrs = zarr.open(path, mode="r").attrs.asdict() + except zarr.errors.PathNotFoundError as e: + LOG.error(f"Failed to open {path}") + LOG.error(e) + exit(0) + + FIXES = { + "history": fix_history, + "provenance_load": fix_provenance, + "provenance_statistics": fix_provenance, + "order_by": fix_order_by, + } + REMOVE = ["_create_yaml_config"] + + before = json.dumps(attrs, sort_keys=True) + + fixed_attrs = {} + for k, v in attrs.items(): + v = attrs[k] + if k in REMOVE: + LOG.info(f"✅ Remove {k}") + continue + + if k not in FIXES: + assert not k.startswith("provenance"), f"[{k}]" + LOG.debug(f"✅ Don't fix {k}") + fixed_attrs[k] = v + continue + + new_v = FIXES[k](v) + if json.dumps(new_v, sort_keys=True) != json.dumps(v, sort_keys=True): + LOG.info(f"✅ Fix {k}") + if verbose: + LOG.info(f" Before : {k}= {v}") + LOG.info(f" After : {k}= {new_v}") + else: + LOG.debug(f"✅ Unchanged {k}") + fixed_attrs[k] = new_v + + if dry_run: + return + z = zarr.open(path, mode="r+") + + for k in list(z.attrs.keys()): + if k not in fixed_attrs: + del z.attrs[k] + for k, v in fixed_attrs.items(): + z.attrs[k] = v + + after = json.dumps(z.attrs.asdict(), sort_keys=True) + if before != after: + LOG.info("Dataset changed by patch") + + assert json.dumps(z.attrs.asdict(), sort_keys=True) == json.dumps(fixed_attrs, sort_keys=True) diff --git a/src/anemoi/datasets/build/gridded/persistent.py b/src/anemoi/datasets/build/gridded/persistent.py new file mode 100644 index 000000000..e52938507 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/persistent.py @@ -0,0 +1,269 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import glob +import hashlib +import json +import logging +import os +import pickle +import shutil +import socket +from collections.abc import Iterator +from typing import Any + +import numpy as np +from anemoi.utils.provenance import gather_provenance_info + +LOG = logging.getLogger(__name__) + + +class PersistentDict: + """A dictionary-like object that persists its contents to disk using pickle files. + + Attributes + ---------- + version : int + The version of the PersistentDict. + dirname : str + The directory where the data is stored. + name : str + The name of the directory. + ext : str + The extension of the directory. + """ + + version = 3 + + # Used in parrallel, during data loading, + # to write data in pickle files. + def __init__(self, directory: str, create: bool = True): + """Initialize the PersistentDict. + + Parameters + ---------- + directory : str + The directory where the data will be stored. + create : bool, optional + Whether to create the directory if it doesn't exist. + """ + self.dirname = directory + self.name, self.ext = os.path.splitext(os.path.basename(self.dirname)) + if create: + self.create() + + def create(self) -> None: + """Create the directory if it doesn't exist.""" + os.makedirs(self.dirname, exist_ok=True) + + def delete(self) -> None: + """Delete the directory and its contents.""" + try: + shutil.rmtree(self.dirname) + except FileNotFoundError: + pass + + def __str__(self) -> str: + """Return a string representation of the PersistentDict.""" + return f"{self.__class__.__name__}({self.dirname})" + + def items(self) -> Iterator[Any]: + """Yield items stored in the directory. + + Yields + ------ + Iterator[Any] + An iterator over the items. + """ + # use glob to read all pickles + files = glob.glob(self.dirname + "/*.pickle") + LOG.debug(f"Reading {self.name} data, found {len(files)} files in {self.dirname}") + assert len(files) > 0, f"No files found in {self.dirname}" + for f in files: + with open(f, "rb") as f: + yield pickle.load(f) + + def add_provenance(self, **kwargs: Any) -> None: + """Add provenance information to the directory. + + Parameters + ---------- + **kwargs : Any + Additional provenance information. + """ + path = os.path.join(self.dirname, "provenance.json") + if os.path.exists(path): + return + out = dict(provenance=gather_provenance_info(), **kwargs) + with open(path, "w") as f: + json.dump(out, f) + + def add(self, elt: Any, *, key: Any) -> None: + """Add an element to the PersistentDict. + + Parameters + ---------- + elt : Any + The element to add. + key : Any + The key associated with the element. + """ + self[key] = elt + + def __setitem__(self, key: Any, elt: Any) -> None: + """Set an item in the PersistentDict. + + Parameters + ---------- + key : Any + The key associated with the element. + elt : Any + The element to set. 
+ """ + h = hashlib.sha256(str(key).encode("utf-8")).hexdigest() + path = os.path.join(self.dirname, f"{h}.pickle") + + if os.path.exists(path): + LOG.warning(f"{path} already exists") + + tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}" + with open(tmp_path, "wb") as f: + pickle.dump((key, elt), f) + shutil.move(tmp_path, path) + + LOG.debug(f"Written {self.name} data for len {key} in {path}") + + def flush(self) -> None: + """Flush the PersistentDict (no-op).""" + pass + + +class BufferedPersistentDict(PersistentDict): + """A buffered version of PersistentDict that stores elements in memory before persisting them to disk. + + Attributes + ---------- + buffer_size : int + The size of the buffer. + elements : list + The list of elements in the buffer. + keys : list + The list of keys in the buffer. + storage : PersistentDict + The underlying PersistentDict used for storage. + """ + + def __init__(self, buffer_size: int = 1000, **kwargs: Any): + """Initialize the BufferedPersistentDict. + + Parameters + ---------- + buffer_size : int, optional + The size of the buffer. + **kwargs : Any + Additional arguments for PersistentDict. + """ + self.buffer_size = buffer_size + self.elements = [] + self.keys = [] + self.storage = PersistentDict(**kwargs) + + def add(self, elt: Any, *, key: Any) -> None: + """Add an element to the BufferedPersistentDict. + + Parameters + ---------- + elt : Any + The element to add. + key : Any + The key associated with the element. + """ + self.elements.append(elt) + self.keys.append(key) + if len(self.keys) > self.buffer_size: + self.flush() + + def flush(self) -> None: + """Flush the buffer and store the elements in PersistentDict.""" + k = sorted(self.keys) + self.storage.add(self.elements, key=k) + self.elements = [] + self.keys = [] + + def items(self) -> Iterator[tuple[Any, Any]]: + """Yield items stored in the BufferedPersistentDict. + + Yields + ------ + Iterator[Tuple[Any, Any]] + An iterator over the items. + """ + for keys, elements in self.storage.items(): + yield from zip(keys, elements) + + def delete(self) -> None: + """Delete the storage directory and its contents.""" + self.storage.delete() + + def create(self) -> None: + """Create the storage directory if it doesn't exist.""" + self.storage.create() + + +def build_storage(directory: str, create: bool = True) -> BufferedPersistentDict: + """Build a BufferedPersistentDict storage. + + Parameters + ---------- + directory : str + The directory where the data will be stored. + create : bool, optional + Whether to create the directory if it doesn't exist. + + Returns + ------- + BufferedPersistentDict + The created BufferedPersistentDict. 
+ """ + return BufferedPersistentDict(directory=directory, create=create) + + +if __name__ == "__main__": + N = 3 + P = 2 + directory = "h" + p = PersistentDict(directory=directory) + print(p) + assert os.path.exists(directory) + import numpy as np + + arrs = [np.random.randint(1, 101, size=(P,)) for _ in range(N)] + dates = [np.array([np.datetime64(f"2021-01-0{_+1}") + np.timedelta64(i, "h") for i in range(P)]) for _ in range(N)] + + print() + print("Writing the data") + for i in range(N): + _arr = arrs[i] + _dates = dates[i] + print(f"Writing : {i=}, {_arr=} {_dates=}") + p[_dates] = (i, _arr) + + print() + print("Reading the data back") + + p = PersistentDict(directory="h") + for _dates, (i, _arr) in p.items(): + print(f"{i=}, {_arr=}, {_dates=}") + + assert np.allclose(_arr, arrs[i]) + + assert len(_dates) == len(dates[i]) + for a, b in zip(_dates, dates[i]): + assert a == b diff --git a/src/anemoi/datasets/build/gridded/size.py b/src/anemoi/datasets/build/gridded/size.py new file mode 100644 index 000000000..4cffd66d7 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/size.py @@ -0,0 +1,47 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +import os + +import tqdm +from anemoi.utils.humanize import bytes_to_human + +LOG = logging.getLogger(__name__) + + +def compute_directory_sizes(path: str) -> dict[str, int] | None: + """Computes the total size and number of files in a directory. + + Parameters + ---------- + path : str + The path to the directory. + + Returns + ------- + dict of str to int or None + A dictionary with the total size and number of files, or None if the path is not a directory. + """ + if not os.path.isdir(path): + return None + + size, n = 0, 0 + bar = tqdm.tqdm(iterable=os.walk(path), desc=f"Computing size of {path}") + for dirpath, _, filenames in bar: + for filename in filenames: + file_path = os.path.join(dirpath, filename) + size += os.path.getsize(file_path) + n += 1 + + LOG.info(f"Total size: {bytes_to_human(size)}") + LOG.info(f"Total number of files: {n}") + + return dict(total_size=size, total_number_of_files=n) diff --git a/src/anemoi/datasets/build/gridded/source.py b/src/anemoi/datasets/build/gridded/source.py new file mode 100644 index 000000000..df4911690 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/source.py @@ -0,0 +1,51 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from abc import ABC +from abc import abstractmethod + +import earthkit.data as ekd + +from anemoi.datasets.build.typing import DateList + + +class Source(ABC): + """Represents a data source with a given context.""" + + emoji = "📦" # For tracing + + def __init__(self, context: any, *args: tuple, **kwargs: dict): + """Initialise the source. + Parameters + ---------- + context : Any + The context for the data source. 
+ *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + """ + self.context = context + + @abstractmethod + def execute(self, dates: DateList) -> ekd.FieldList: + """Execute the filter. + + Parameters + ---------- + dates : DateList + The input dates. + + Returns + ------- + ekd.FieldList + The output data. + """ + + pass diff --git a/src/anemoi/datasets/build/gridded/statistics/__init__.py b/src/anemoi/datasets/build/gridded/statistics/__init__.py new file mode 100644 index 000000000..f7ece19bb --- /dev/null +++ b/src/anemoi/datasets/build/gridded/statistics/__init__.py @@ -0,0 +1,561 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import glob +import hashlib +import json +import logging +import os +import pickle +import shutil +import socket +from typing import Any + +import numpy as np +import tqdm +from anemoi.utils.provenance import gather_provenance_info +from numpy.typing import NDArray + +from anemoi.datasets.build.check import check_data_values +from anemoi.datasets.build.statistics.summary import Summary + +LOG = logging.getLogger(__name__) + + +def default_statistics_dates(dates: list[datetime.datetime]) -> tuple[datetime.datetime, datetime.datetime]: + """Calculate default statistics dates based on the given list of dates. + + Parameters + ---------- + dates : list of datetime.datetime + List of datetime objects representing dates. + + Returns + ------- + tuple of datetime.datetime + A tuple containing the default start and end dates. + """ + + def to_datetime(d): + if isinstance(d, np.datetime64): + return d.tolist() + assert isinstance(d, datetime.datetime), d + return d + + first = dates[0] + last = dates[-1] + + first = to_datetime(first) + last = to_datetime(last) + + n_years = round((last - first).total_seconds() / (365.25 * 24 * 60 * 60)) + + if n_years < 10: + # leave out 20% of the data + k = int(len(dates) * 0.8) + end = dates[k - 1] + LOG.info(f"Number of years {n_years} < 10, leaving out 20%. {end=}") + return dates[0], end + + delta = 1 + if n_years >= 20: + delta = 3 + LOG.info(f"Number of years {n_years}, leaving out {delta} years.") + end_year = last.year - delta + + end = max(d for d in dates if to_datetime(d).year == end_year) + return dates[0], end + + +def to_datetime(date: str | datetime.datetime) -> np.datetime64: + """Convert a date to numpy datetime64 format. + + Parameters + ---------- + date : str or datetime.datetime + The date to convert. + + Returns + ------- + numpy.datetime64 + The converted date. + """ + if isinstance(date, str): + return np.datetime64(date) + if isinstance(date, datetime.datetime): + return np.datetime64(date, "s") + return date + + +def to_datetimes(dates: list[str | datetime.datetime]) -> list[np.datetime64]: + """Convert a list of dates to numpy datetime64 format. + + Parameters + ---------- + dates : list of str or datetime.datetime + List of dates to convert. + + Returns + ------- + list of numpy.datetime64 + List of converted dates. 
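# Editor's sketch (not part of the patch): the default statistics period for a
# short synthetic dataset. With about 5 years of daily dates (< 10 years),
# default_statistics_dates() above leaves the last 20% of the dates out.
import datetime

dates = [datetime.datetime(2000, 1, 1) + datetime.timedelta(days=i) for i in range(5 * 365)]
start, end = default_statistics_dates(dates)  # start = 2000-01-01, end ≈ 2003-12-30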
+ """ + return [to_datetime(d) for d in dates] + + +def fix_variance(x: float, name: str, count: NDArray[Any], sums: NDArray[Any], squares: NDArray[Any]) -> float: + """Fix negative variance values due to numerical errors. + + Parameters + ---------- + x : float + The variance value. + name : str + The variable name. + count : numpy.ndarray + The count array. + sums : numpy.ndarray + The sums array. + squares : numpy.ndarray + The squares array. + + Returns + ------- + float + The fixed variance value. + """ + assert count.shape == sums.shape == squares.shape + assert isinstance(x, float) + + mean = sums / count + assert mean.shape == count.shape + + if x >= 0: + return x + + LOG.warning(f"Negative variance for {name=}, variance={x}") + magnitude = np.sqrt((squares / count + mean * mean) / 2) + LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}") + LOG.warning(f"Variable span order of magnitude is {magnitude}.") + LOG.warning(f"Count is {count}.") + + variances = squares / count - mean * mean + assert variances.shape == squares.shape == mean.shape + if np.all(variances >= 0): + LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.") + return 0 + + # if abs(x) < magnitude * 1e-6 and abs(x) < range * 1e-6: + # LOG.warning("Variance is negative but very small.") + # variances = squares / count - mean * mean + # return 0 + + LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).") + return 0 + + +def check_variance( + x: NDArray[Any], + variables_names: list[str], + minimum: NDArray[Any], + maximum: NDArray[Any], + mean: NDArray[Any], + count: NDArray[Any], + sums: NDArray[Any], + squares: NDArray[Any], +) -> None: + """Check for negative variance values and raise an error if found. + + Parameters + ---------- + x : numpy.ndarray + The variance array. + variables_names : list of str + List of variable names. + minimum : numpy.ndarray + The minimum values array. + maximum : numpy.ndarray + The maximum values array. + mean : numpy.ndarray + The mean values array. + count : numpy.ndarray + The count array. + sums : numpy.ndarray + The sums array. + squares : numpy.ndarray + The squares array. + + Raises + ------ + ValueError + If negative variance is found. + """ + if (x >= 0).all(): + return + print(x) + print(variables_names) + for i, (name, y) in enumerate(zip(variables_names, x)): + if y >= 0: + continue + print("---") + print(f"❗ Negative variance for {name=}, variance={y}") + print(f" min={minimum[i]} max={maximum[i]} mean={mean[i]} count={count[i]} sums={sums[i]} squares={squares[i]}") + print(f" -> sums: min={np.min(sums[i])}, max={np.max(sums[i])}, argmin={np.argmin(sums[i])}") + print(f" -> squares: min={np.min(squares[i])}, max={np.max(squares[i])}, argmin={np.argmin(squares[i])}") + print(f" -> count: min={np.min(count[i])}, max={np.max(count[i])}, argmin={np.argmin(count[i])}") + print( + f" squares / count - mean * mean = {squares[i] / count[i]} - {mean[i] * mean[i]} = {squares[i] / count[i] - mean[i] * mean[i]}" + ) + + raise ValueError("Negative variance") + + +def compute_statistics( + array: NDArray[Any], check_variables_names: list[str] | None = None, allow_nans: bool = False +) -> dict[str, np.ndarray]: + """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary. + + Parameters + ---------- + array : numpy.ndarray + The array to compute statistics for. 
+ check_variables_names : list of str, optional + List of variable names to check. Defaults to None. + allow_nans : bool, optional + Whether to allow NaN values. Defaults to False. + + Returns + ------- + dict of str to numpy.ndarray + A dictionary containing the computed statistics. + """ + LOG.info(f"Computing statistics for {array.shape} array") + nvars = array.shape[1] + + LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}") + if check_variables_names: + assert nvars == len(check_variables_names), (nvars, check_variables_names) + stats_shape = (array.shape[0], nvars) + + count = np.zeros(stats_shape, dtype=np.int64) + sums = np.zeros(stats_shape, dtype=np.float64) + squares = np.zeros(stats_shape, dtype=np.float64) + minimum = np.zeros(stats_shape, dtype=np.float64) + maximum = np.zeros(stats_shape, dtype=np.float64) + has_nans = np.zeros(stats_shape, dtype=np.bool_) + + for i, chunk in tqdm.tqdm(enumerate(array), delay=1, total=array.shape[0], desc="Computing statistics"): + values = chunk.reshape((nvars, -1)) + + for j, name in enumerate(check_variables_names): + check_data_values(values[j, :], name=name, allow_nans=allow_nans) + if np.isnan(values[j, :]).all(): + # LOG.warning(f"All NaN values for {name} ({j}) for date {i}") + LOG.warning(f"All NaN values for {name} ({j}) for date {i}") + + # Ignore NaN values + minimum[i] = np.nanmin(values, axis=1) + maximum[i] = np.nanmax(values, axis=1) + sums[i] = np.nansum(values, axis=1) + squares[i] = np.nansum(values * values, axis=1) + count[i] = np.sum(~np.isnan(values), axis=1) + has_nans[i] = np.isnan(values).any() + + LOG.info(f"Statistics computed for {nvars} variables.") + + return { + "minimum": minimum, + "maximum": maximum, + "sums": sums, + "squares": squares, + "count": count, + "has_nans": has_nans, + } + + +class TmpStatistics: + """Temporary statistics storage class.""" + + version = 3 + # Used in parrallel, during data loading, + # to write statistics in pickled npz files. + # can provide statistics for a subset of dates. + + def __init__(self, dirname: str, overwrite: bool = False) -> None: + """Initialize TmpStatistics. + + Parameters + ---------- + dirname : str + Directory name for storing statistics. + overwrite : bool, optional + Whether to overwrite existing files. Defaults to False. + """ + self.dirname = dirname + self.overwrite = overwrite + + def add_provenance(self, **kwargs: dict) -> None: + """Add provenance information. + + Parameters + ---------- + **kwargs : dict + Additional provenance information. + """ + self.create(exist_ok=True) + path = os.path.join(self.dirname, "provenance.json") + if os.path.exists(path): + return + out = dict(provenance=gather_provenance_info(), **kwargs) + with open(path, "w") as f: + json.dump(out, f) + + def create(self, exist_ok: bool) -> None: + """Create the directory for storing statistics. + + Parameters + ---------- + exist_ok : bool + Whether to ignore if the directory already exists. + """ + os.makedirs(self.dirname, exist_ok=exist_ok) + + def delete(self) -> None: + """Delete the directory for storing statistics.""" + try: + shutil.rmtree(self.dirname) + except FileNotFoundError: + pass + + def write(self, key: str, data: any, dates: list[datetime.datetime]) -> None: + """Write statistics data to a file. + + Parameters + ---------- + key : str + The key for the data. + data : any + The data to write. + dates : list of datetime.datetime + List of dates associated with the data. 
+ """ + self.create(exist_ok=True) + h = hashlib.sha256(str(dates).encode("utf-8")).hexdigest() + path = os.path.join(self.dirname, f"{h}.npz") + + if not self.overwrite: + assert not os.path.exists(path), f"{path} already exists" + + tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}" + with open(tmp_path, "wb") as f: + pickle.dump((key, dates, data), f) + shutil.move(tmp_path, path) + + LOG.debug(f"Written statistics data for {len(dates)} dates in {path} ({dates})") + + def _gather_data(self) -> tuple[str, list[datetime.datetime], dict]: + """Gather data from stored files. + + Yields + ------ + tuple of str, list of datetime.datetime, dict + A tuple containing key, dates, and data. + """ + # use glob to read all pickles + files = glob.glob(self.dirname + "/*.npz") + LOG.debug(f"Reading stats data, found {len(files)} files in {self.dirname}") + assert len(files) > 0, f"No files found in {self.dirname}" + for f in files: + with open(f, "rb") as f: + yield pickle.load(f) + + def get_aggregated(self, *args: Any, **kwargs: Any) -> Summary: + """Get aggregated statistics. + + Parameters + ---------- + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + Summary + The aggregated statistics summary. + """ + aggregator = StatAggregator(self, *args, **kwargs) + return aggregator.aggregate() + + def __str__(self) -> str: + """String representation of TmpStatistics. + + Returns + ------- + str + The string representation. + """ + return f"TmpStatistics({self.dirname})" + + +class StatAggregator: + """Statistics aggregator class.""" + + NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"] + + def __init__( + self, owner: TmpStatistics, dates: list[datetime.datetime], variables_names: list[str], allow_nans: bool + ) -> None: + """Initialize StatAggregator. + + Parameters + ---------- + owner : TmpStatistics + The owner TmpStatistics instance. + dates : list of datetime.datetime + List of dates to aggregate. + variables_names : list of str + List of variable names. + allow_nans : bool + Whether to allow NaN values. + """ + dates = sorted(dates) + dates = to_datetimes(dates) + assert dates, "No dates selected" + self.owner = owner + self.dates = dates + self._number_of_dates = len(dates) + self._set_of_dates = set(dates) + self.variables_names = variables_names + self.allow_nans = allow_nans + + self.shape = (self._number_of_dates, len(self.variables_names)) + LOG.debug(f"Aggregating statistics on shape={self.shape}. 
Variables : {self.variables_names}") + + self.minimum = np.full(self.shape, np.nan, dtype=np.float64) + self.maximum = np.full(self.shape, np.nan, dtype=np.float64) + self.sums = np.full(self.shape, np.nan, dtype=np.float64) + self.squares = np.full(self.shape, np.nan, dtype=np.float64) + self.count = np.full(self.shape, -1, dtype=np.int64) + self.has_nans = np.full(self.shape, False, dtype=np.bool_) + + self._read() + + def _read(self) -> None: + """Read and aggregate statistics data from files.""" + + def check_type(a, b): + if not isinstance(a, set): + a = set(list(a)) + if not isinstance(b, set): + b = set(list(b)) + a = next(iter(a)) if a else None + b = next(iter(b)) if b else None + assert type(a) is type(b), (type(a), type(b)) + + found = set() + offset = 0 + + for _, _dates, stats in self.owner._gather_data(): + assert isinstance(stats, dict), stats + assert stats["minimum"].shape[0] == len(_dates), (stats["minimum"].shape, len(_dates)) + assert stats["minimum"].shape[1] == len(self.variables_names), ( + stats["minimum"].shape, + len(self.variables_names), + ) + for n in self.NAMES: + assert n in stats, (n, list(stats.keys())) + _dates = to_datetimes(_dates) + check_type(_dates, self._set_of_dates) + if found: + check_type(found, self._set_of_dates) + assert found.isdisjoint(_dates), "Duplicate dates found in precomputed statistics" + + # filter dates + dates = set(_dates) & self._set_of_dates + + if not dates: + # dates have been completely filtered for this chunk + continue + + # filter data + bitmap = np.array([d in self._set_of_dates for d in _dates]) + for k in self.NAMES: + stats[k] = stats[k][bitmap] + + assert stats["minimum"].shape[0] == len(dates), (stats["minimum"].shape, len(dates)) + + # store data in self + found |= set(dates) + for name in self.NAMES: + array = getattr(self, name) + assert stats[name].shape[0] == len(dates), (stats[name].shape, len(dates)) + array[offset : offset + len(dates)] = stats[name] + offset += len(dates) + + for d in self.dates: + assert d in found, f"Statistics for date {d} not precomputed." + assert self._number_of_dates == len(found), "Not all dates found in precomputed statistics" + assert self._number_of_dates == offset, "Not all dates found in precomputed statistics." + LOG.debug(f"Statistics for {len(found)} dates found.") + + def aggregate(self) -> Summary: + """Aggregate the statistics data. + + Returns + ------- + Summary + The aggregated statistics summary. 
+ """ + minimum = np.nanmin(self.minimum, axis=0) + maximum = np.nanmax(self.maximum, axis=0) + + sums = np.nansum(self.sums, axis=0) + squares = np.nansum(self.squares, axis=0) + count = np.nansum(self.count, axis=0) + has_nans = np.any(self.has_nans, axis=0) + assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape + + mean = sums / count + assert mean.shape == minimum.shape + + x = squares / count - mean * mean + assert x.shape == minimum.shape + + for i, name in enumerate(self.variables_names): + # remove negative variance due to numerical errors + x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1]) + + for i, name in enumerate(self.variables_names): + check_variance( + x[i : i + 1], + [name], + minimum[i : i + 1], + maximum[i : i + 1], + mean[i : i + 1], + count[i : i + 1], + sums[i : i + 1], + squares[i : i + 1], + ) + check_data_values(np.array([mean[i]]), name=name, allow_nans=False) + + stdev = np.sqrt(x) + + return Summary( + minimum=minimum, + maximum=maximum, + mean=mean, + count=count, + sums=sums, + squares=squares, + stdev=stdev, + variables_names=self.variables_names, + has_nans=has_nans, + ) diff --git a/src/anemoi/datasets/build/gridded/statistics/summary.py b/src/anemoi/datasets/build/gridded/statistics/summary.py new file mode 100644 index 000000000..59f3998b4 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/statistics/summary.py @@ -0,0 +1,152 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import json +from collections import defaultdict +from typing import Any + +import numpy as np + +from anemoi.datasets.build.check import StatisticsValueError +from anemoi.datasets.build.check import check_data_values +from anemoi.datasets.build.check import check_stats + + +class Summary(dict): + """This class is used to store the summary statistics of a dataset. It can be saved and loaded from a json file. And does some basic checks on the data.""" + + STATS_NAMES = [ + "minimum", + "maximum", + "mean", + "stdev", + "has_nans", + ] # order matter for __str__. + + def __init__(self, **kwargs: Any) -> None: + """Initialize the Summary object with given keyword arguments. + + Parameters + ---------- + **kwargs : Any + Arbitrary keyword arguments representing summary statistics. + """ + super().__init__(**kwargs) + self.check() + + @property + def size(self) -> int: + """Get the size of the summary, which is the number of variables.""" + return len(self["variables_names"]) + + def check(self) -> None: + """Perform checks on the summary statistics to ensure they are valid. + + Raises + ------ + AssertionError + If any of the checks fail. + StatisticsValueError + If any of the statistical checks fail. 
+ """ + for k, v in self.items(): + if k == "variables_names": + assert len(v) == self.size + continue + assert v.shape == (self.size,) + if k == "count": + assert (v >= 0).all(), (k, v) + assert v.dtype == np.int64, (k, v) + continue + if k == "has_nans": + assert v.dtype == np.bool_, (k, v) + continue + if k == "stdev": + assert (v >= 0).all(), (k, v) + assert v.dtype == np.float64, (k, v) + + for i, name in enumerate(self["variables_names"]): + try: + check_stats(**{k: v[i] for k, v in self.items()}, msg=f"{i} {name}") + check_data_values(self["minimum"][i], name=name) + check_data_values(self["maximum"][i], name=name) + check_data_values(self["mean"][i], name=name) + except StatisticsValueError as e: + e.args += (i, name) + raise + + def __str__(self) -> str: + """Return a string representation of the summary statistics. + + Returns + ------- + str + A formatted string of the summary statistics. + """ + header = ["Variables"] + self.STATS_NAMES + out = [" ".join(header)] + + out += [ + " ".join([v] + [f"{self[n][i]:.2f}" for n in self.STATS_NAMES]) + for i, v in enumerate(self["variables_names"]) + ] + return "\n".join(out) + + def save(self, filename: str, **metadata: Any) -> None: + """Save the summary statistics to a JSON file. + + Parameters + ---------- + filename : str + The name of the file to save the summary statistics. + **metadata : Any + Additional metadata to include in the JSON file. + """ + assert filename.endswith(".json"), filename + dic = {} + for k in self.STATS_NAMES: + dic[k] = list(self[k]) + + out = dict(data=defaultdict(dict)) + for i, name in enumerate(self["variables_names"]): + for k in self.STATS_NAMES: + out["data"][name][k] = dic[k][i] + + out["metadata"] = metadata + + with open(filename, "w") as f: + json.dump(out, f, indent=2) + + def load(self, filename: str) -> "Summary": + """Load the summary statistics from a JSON file. + + Parameters + ---------- + filename : str + The name of the file to load the summary statistics from. + + Returns + ------- + Summary + The loaded Summary object. + """ + assert filename.endswith(".json"), filename + with open(filename) as f: + dic = json.load(f) + + dic_ = {} + for k, v in dic.items(): + if k == "count": + dic_[k] = np.array(v, dtype=np.int64) + continue + if k == "variables": + dic_[k] = v + continue + dic_[k] = np.array(v, dtype=np.float64) + return Summary(dic_) diff --git a/src/anemoi/datasets/build/gridded/testing.py b/src/anemoi/datasets/build/gridded/testing.py new file mode 100644 index 000000000..5363cd9f7 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/testing.py @@ -0,0 +1,4 @@ +class TestingContext: + """A context for testing plugins.""" + + pass diff --git a/src/anemoi/datasets/build/gridded/typing.py b/src/anemoi/datasets/build/gridded/typing.py new file mode 100644 index 000000000..0eafdb193 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/typing.py @@ -0,0 +1,14 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import datetime + +Date = datetime.datetime + +DateList = list[Date] diff --git a/src/anemoi/datasets/build/gridded/utils.py b/src/anemoi/datasets/build/gridded/utils.py new file mode 100644 index 000000000..00ea89e7b --- /dev/null +++ b/src/anemoi/datasets/build/gridded/utils.py @@ -0,0 +1,198 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import os +import warnings +from contextlib import contextmanager +from typing import Any + +import numpy as np +from earthkit.data import settings +from numpy.typing import NDArray + + +def cache_context(dirname: str) -> contextmanager: + """Context manager for setting a temporary cache directory. + + Parameters + ---------- + dirname : str + The directory name for the cache. + + Returns + ------- + contextmanager + A context manager that sets the cache directory. + """ + + @contextmanager + def no_cache_context(): + yield + + if dirname is None: + return no_cache_context() + + os.makedirs(dirname, exist_ok=True) + # return settings.temporary("cache-directory", dirname) + return settings.temporary({"cache-policy": "user", "user-cache-directory": dirname}) + + +def to_datetime_list(*args: Any, **kwargs: Any) -> list[datetime.datetime]: + """Convert various date formats to a list of datetime objects. + + Parameters + ---------- + *args : Any + Positional arguments for date conversion. + **kwargs : Any + Keyword arguments for date conversion. + + Returns + ------- + list[datetime.datetime] + A list of datetime objects. + """ + from earthkit.data.utils.dates import to_datetime_list as to_datetime_list_ + + warnings.warn( + "to_datetime_list() is deprecated. Call earthkit.data.utils.dates.to_datetime_list() instead.", + DeprecationWarning, + stacklevel=2, + ) + return to_datetime_list_(*args, **kwargs) + + +def to_datetime(*args: Any, **kwargs: Any) -> datetime.datetime: + """Convert various date formats to a single datetime object. + + Parameters + ---------- + *args : Any + Positional arguments for date conversion. + **kwargs : Any + Keyword arguments for date conversion. + + Returns + ------- + datetime.datetime + A datetime object. + """ + from earthkit.data.utils.dates import to_datetime as to_datetime_ + + warnings.warn( + "to_datetime() is deprecated. Call earthkit.data.utils.dates.to_datetime() instead.", + DeprecationWarning, + stacklevel=2, + ) + + return to_datetime_(*args, **kwargs) + + +def make_list_int(value: str | list | tuple | int) -> list[int]: + """Convert a string, list, tuple, or integer to a list of integers. + + Parameters + ---------- + value : str or list or tuple or int + The value to convert. + + Returns + ------- + list[int] + A list of integers. + + Raises + ------ + ValueError + If the value cannot be converted to a list of integers. + """ + # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers. 
+ # Moved to anemoi.utils.humanize + # replace with from anemoi.utils.humanize import make_list_int + # when anemoi-utils is released and pyproject.toml is updated + if isinstance(value, str): + if "/" not in value: + return [value] + bits = value.split("/") + if len(bits) == 3 and bits[1].lower() == "to": + value = list(range(int(bits[0]), int(bits[2]) + 1, 1)) + + elif len(bits) == 5 and bits[1].lower() == "to" and bits[3].lower() == "by": + value = list(range(int(bits[0]), int(bits[2]) + int(bits[4]), int(bits[4]))) + + if isinstance(value, list): + return value + if isinstance(value, tuple): + return value + if isinstance(value, int): + return [value] + + raise ValueError(f"Cannot make list from {value}") + + +def normalize_and_check_dates( + dates: list[datetime.datetime], + start: datetime.datetime, + end: datetime.datetime, + frequency: datetime.timedelta, + dtype: str = "datetime64[s]", +) -> NDArray[Any]: + """Normalize and check a list of dates against a specified frequency. + + Parameters + ---------- + dates : list[datetime.datetime] + The list of dates to check. + start : datetime.datetime + The start date. + end : datetime.datetime + The end date. + frequency : datetime.timedelta + The frequency of the dates. + dtype : str, optional + The data type of the dates, by default "datetime64[s]". + + Returns + ------- + NDArray[Any] + An array of normalized dates. + + Raises + ------ + ValueError + If the final date size does not match the data shape. + """ + dates = [d.hdate if hasattr(d, "hdate") else d for d in dates] + + assert isinstance(frequency, datetime.timedelta), frequency + start = np.datetime64(start) + end = np.datetime64(end) + delta = np.timedelta64(frequency) + + res = [] + while start <= end: + res.append(start) + start += delta + dates_ = np.array(res).astype(dtype) + + if len(dates_) != len(dates): + raise ValueError( + f"Final date size {len(dates_)} (from {dates_[0]} to {dates_[-1]}, " + f"{frequency=}) does not match data shape {len(dates)} (from {dates[0]} to " + f"{dates[-1]})." + ) + + for i, (d1, d2) in enumerate(zip(dates, dates_)): + d1 = np.datetime64(d1) + d2 = np.datetime64(d2) + assert d1 == d2, (i, d1, d2) + + return dates_ diff --git a/src/anemoi/datasets/build/gridded/writer.py b/src/anemoi/datasets/build/gridded/writer.py new file mode 100644 index 000000000..d573c1ca5 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/writer.py @@ -0,0 +1,64 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +LOG = logging.getLogger(__name__) + + +class ViewCacheArray: + """A class that provides a caching mechanism for writing to a NumPy-like array. + + The is initialised with a NumPy-like array, a shape and a list to reindex the first + dimension. The array is used to store the final data, while the cache is used to + temporarily store the data before flushing it to the array. + + The `flush` method copies the contents of the cache to the final array. + """ + + def __init__(self, array: NDArray[Any], *, shape: tuple[int, ...], indexes: list[int]): + """Initialize the ViewCacheArray. 
+ + Parameters + ---------- + array : NDArray[Any] + The NumPy-like array to store the final data. + shape : tuple[int, ...] + The shape of the cache array. + indexes : list[int] + List to reindex the first dimension. + """ + assert len(indexes) == shape[0], (len(indexes), shape[0]) + self.array = array + self.dtype = array.dtype + self.cache = np.full(shape, np.nan, dtype=self.dtype) + self.indexes = indexes + + def __setitem__(self, key: tuple[int, ...], value: NDArray[Any]) -> None: + """Set the value in the cache array at the specified key. + + Parameters + ---------- + key : tuple[int, ...] + The index key to set the value. + value : NDArray[Any] + The value to set in the cache array. + """ + self.cache[key] = value + + def flush(self) -> None: + """Copy the contents of the cache to the final array.""" + for i in range(self.cache.shape[0]): + global_i = self.indexes[i] + self.array[global_i] = self.cache[i] diff --git a/src/anemoi/datasets/build/gridded/zarr.py b/src/anemoi/datasets/build/gridded/zarr.py new file mode 100644 index 000000000..32b493dd3 --- /dev/null +++ b/src/anemoi/datasets/build/gridded/zarr.py @@ -0,0 +1,331 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import shutil +from typing import Any + +import numpy as np +import zarr +from numpy.typing import NDArray + +LOG = logging.getLogger(__name__) + + +def add_zarr_dataset( + *, + name: str, + dtype: np.dtype = None, + fill_value: np.generic = None, + zarr_root: zarr.Group, + shape: tuple[int, ...] = None, + array: NDArray[Any] = None, + overwrite: bool = True, + dimensions: tuple[str, ...] = None, + **kwargs, +) -> zarr.Array: + """Add a dataset to a Zarr group. + + Parameters + ---------- + name : str + Name of the dataset. + dtype : np.dtype, optional + Data type of the dataset. + fill_value : np.generic, optional + Fill value for the dataset. + zarr_root : zarr.Group + Root Zarr group. + shape : tuple[int, ...], optional + Shape of the dataset. + array : NDArray[Any], optional + Array to initialize the dataset with. + overwrite : bool + Whether to overwrite existing dataset. + dimensions : tuple[str, ...] + Dimensions of the dataset. + **kwargs + Additional arguments for Zarr dataset creation. + + Returns + ------- + zarr.Array + The created Zarr array. + """ + assert dimensions is not None, "Please pass dimensions to add_zarr_dataset." + assert isinstance(dimensions, (tuple, list)) + + if dtype is None: + assert array is not None, (name, shape, array, dtype, zarr_root) + dtype = array.dtype + + if shape is None: + assert array is not None, (name, shape, array, dtype, zarr_root) + shape = array.shape + + if array is not None: + assert array.shape == shape, (array.shape, shape) + a = zarr_root.create_dataset( + name, + shape=shape, + dtype=dtype, + overwrite=overwrite, + **kwargs, + ) + a[...] 
= array + a.attrs["_ARRAY_DIMENSIONS"] = dimensions + return a + + if "fill_value" not in kwargs: + if str(dtype).startswith("float") or str(dtype).startswith("numpy.float"): + kwargs["fill_value"] = np.nan + elif str(dtype).startswith("datetime64") or str(dtype).startswith("numpy.datetime64"): + kwargs["fill_value"] = np.datetime64("NaT") + # elif str(dtype).startswith("timedelta64") or str(dtype).startswith( + # "numpy.timedelta64" + # ): + # kwargs["fill_value"] = np.timedelta64("NaT") + elif str(dtype).startswith("int") or str(dtype).startswith("numpy.int"): + kwargs["fill_value"] = 0 + elif str(dtype).startswith("bool") or str(dtype).startswith("numpy.bool"): + kwargs["fill_value"] = False + else: + raise ValueError(f"No fill_value for dtype={dtype}") + + a = zarr_root.create_dataset( + name, + shape=shape, + dtype=dtype, + overwrite=overwrite, + **kwargs, + ) + a.attrs["_ARRAY_DIMENSIONS"] = dimensions + return a + + +class ZarrBuiltRegistry: + """A class to manage the creation and access of Zarr datasets.""" + + name_lengths = "lengths" + name_flags = "flags" + lengths = None + flags = None + z = None + + def __init__(self, path: str, synchronizer_path: str | None = None, use_threads: bool = False): + """Initialize the ZarrBuiltRegistry. + + Parameters + ---------- + path : str + Path to the Zarr store. + synchronizer_path : Optional[str], optional + Path to the synchronizer. + use_threads : bool + Whether to use thread-based synchronization. + """ + import zarr + + assert isinstance(path, str), path + self.zarr_path = path + + if use_threads: + self.synchronizer = zarr.ThreadSynchronizer() + self.synchronizer_path = None + else: + if synchronizer_path is None: + synchronizer_path = self.zarr_path + ".sync" + self.synchronizer_path = synchronizer_path + self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path) + + def clean(self) -> None: + """Clean up the synchronizer path.""" + if self.synchronizer_path is not None: + try: + shutil.rmtree(self.synchronizer_path) + except FileNotFoundError: + pass + + _build = self.zarr_path + "/_build" + try: + shutil.rmtree(_build) + except FileNotFoundError: + pass + + def _open_write(self) -> zarr.Group: + """Open the Zarr store in write mode.""" + import zarr + + return zarr.open(self.zarr_path, mode="r+", synchronizer=self.synchronizer) + + def _open_read(self, sync: bool = True) -> zarr.Group: + """Open the Zarr store in read mode. + + Parameters + ---------- + sync : bool + Whether to use synchronization. + + Returns + ------- + zarr.Group + The opened Zarr group. + """ + import zarr + + if sync: + return zarr.open(self.zarr_path, mode="r", synchronizer=self.synchronizer) + else: + return zarr.open(self.zarr_path, mode="r") + + def new_dataset(self, *args, **kwargs) -> None: + """Create a new dataset in the Zarr store. + + Parameters + ---------- + *args + Positional arguments for dataset creation. + **kwargs + Keyword arguments for dataset creation. + """ + z = self._open_write() + zarr_root = z["_build"] + add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs) + + def add_to_history(self, action: str, **kwargs) -> None: + """Add an action to the history attribute of the Zarr store. + + Parameters + ---------- + action : str + The action to record. + **kwargs + Additional information about the action. 
+ """ + new = dict( + action=action, + timestamp=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat(), + ) + new.update(kwargs) + + z = self._open_write() + history = z.attrs.get("history", []) + history.append(new) + z.attrs["history"] = history + + def get_lengths(self) -> list[int]: + """Get the lengths dataset. + + Returns + ------- + list[int] + The lengths dataset. + """ + z = self._open_read() + return list(z["_build"][self.name_lengths][:]) + + def get_flags(self, **kwargs) -> list[bool]: + """Get the flags dataset. + + Parameters + ---------- + **kwargs + Additional arguments for reading the dataset. + + Returns + ------- + list[bool] + The flags dataset. + """ + z = self._open_read(**kwargs) + return list(z["_build"][self.name_flags][:]) + + def get_flag(self, i: int) -> bool: + """Get a specific flag. + + Parameters + ---------- + i : int + Index of the flag. + + Returns + ------- + bool + The flag value. + """ + z = self._open_read() + return z["_build"][self.name_flags][i] + + def set_flag(self, i: int, value: bool = True) -> None: + """Set a specific flag. + + Parameters + ---------- + i : int + Index of the flag. + value : bool + Value to set the flag to. + """ + z = self._open_write() + z.attrs["latest_write_timestamp"] = ( + datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() + ) + z["_build"][self.name_flags][i] = value + + def ready(self) -> bool: + """Check if all flags are set. + + Returns + ------- + bool + True if all flags are set, False otherwise. + """ + return all(self.get_flags()) + + def create(self, lengths: list[int], overwrite: bool = False) -> None: + """Create the lengths and flags datasets. + + Parameters + ---------- + lengths : list[int] + Lengths to initialize the dataset with. + overwrite : bool + Whether to overwrite existing datasets. + """ + self.new_dataset(name=self.name_lengths, array=np.array(lengths, dtype="i4")) + self.new_dataset(name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool)) + self.add_to_history("initialised") + + def reset(self, lengths: list[int]) -> None: + """Reset the lengths and flags datasets. + + Parameters + ---------- + lengths : list[int] + Lengths to initialize the dataset with. + """ + return self.create(lengths, overwrite=True) + + def add_provenance(self, name: str) -> None: + """Add provenance information to the Zarr store. + + Parameters + ---------- + name : str + Name of the provenance attribute. + """ + z = self._open_write() + + if name in z.attrs: + return + + from anemoi.utils.provenance import gather_provenance_info + + z.attrs[name] = gather_provenance_info() diff --git a/src/anemoi/datasets/misc/check.py b/src/anemoi/datasets/misc/check.py new file mode 100644 index 000000000..d795d13f9 --- /dev/null +++ b/src/anemoi/datasets/misc/check.py @@ -0,0 +1,93 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +# A collection of functions to support pytest testing + +import logging +import math +import os +import re + +LOG = logging.getLogger(__name__) + + +def _check_group(group, verbosity: int, *path) -> None: + import zarr + + group_keys = sorted(group.keys()) + if not group_keys: + raise ValueError(f"Check group: {group} is empty.") + + for name in sorted(group_keys): + if name.startswith("."): + if verbosity > 1: + LOG.info(f"Check group: skipping {name}") + continue + + if isinstance(group[name], zarr.hierarchy.Group): + _check_group(group[name], verbosity, *path, name) + else: + _check_array(group[name], verbosity, *path, name) + + +def _check_array(array, verbosity: int, *path) -> None: + assert len(array.chunks) == len(array.shape) + assert math.prod(array.shape) % math.prod(array.chunks) == 0 + + file_count = math.prod(array.shape) // math.prod(array.chunks) + + full = os.path.join(*path) + + chunks = array.chunks + + count = 0 + for f in os.listdir(full): + if verbosity > 1: + LOG.info(f"Check array: checking {f}") + + if f.startswith("."): + if verbosity > 1: + LOG.info(f"Check array: skipping {f}") + continue + + bits = f.split(".") + + if len(bits) != len(chunks): + raise ValueError(f"File {f} is not a valid chunk file.") + + if not all(re.match(r"^\d+$", bit) for bit in bits): + raise ValueError(f"File {f} is not a valid chunk file.") + + count += 1 + + if count != file_count: + raise ValueError(f"File count {count} does not match expected {file_count} for {array.name}.") + + +def check_zarr(path: str, verbosity: int = 0) -> None: + """Check if a Zarr archive is valid, that no files are missing, and that the chunking is correct. + + Parameters + ---------- + path : str + Path to the Zarr archive. + verbosity : int, optional + Verbosity level for logging. Default is 0 (no logging). + """ + import zarr + + if verbosity > 0: + LOG.info(f"Checking Zarr archive {path}") + + if not os.path.exists(path) and not os.path.isdir(path): + # This does not work with non-directory Zarr archives + raise ValueError(f"Path {path} does not exist.") + + _check_group(zarr.open(path, mode="r"), verbosity, path) diff --git a/src/anemoi/datasets/misc/dumper.py b/src/anemoi/datasets/misc/dumper.py new file mode 100644 index 000000000..18c8d34d4 --- /dev/null +++ b/src/anemoi/datasets/misc/dumper.py @@ -0,0 +1,76 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
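Editor's note: check_zarr() above walks a directory-based Zarr store and verifies that every array has the expected number of chunk files (prod(shape) // prod(chunks)), each with a well-formed name. The sketch below is not part of the patch; it assumes the zarr v2 directory layout the check relies on and the module path introduced by this patch, and builds a tiny store where an (8, 4) array chunked as (2, 4) must yield 32 / 8 = 4 chunk files.

import numpy as np
import zarr

from anemoi.datasets.misc.check import check_zarr  # module path assumed from this patch

root = zarr.open("checked.zarr", mode="w")   # hypothetical local directory store (zarr v2)
root.create_dataset("data", shape=(8, 4), chunks=(2, 4), dtype="f8")
root["data"][:] = np.arange(32, dtype="f8").reshape(8, 4)  # write all four chunks to disk

check_zarr("checked.zarr", verbosity=1)      # raises ValueError if chunk files are missing or misnamed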
+ +import datetime +import io +import logging + +import ruamel.yaml + +LOG = logging.getLogger(__name__) + + +def represent_date(dumper, data): + + if isinstance(data, datetime.datetime): + if data.tzinfo is None: + data = data.replace(tzinfo=datetime.timezone.utc) + data = data.astimezone(datetime.timezone.utc) + iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z" + else: + iso_str = data.isoformat() + + return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str) + + +# --- Represent multiline strings with | style --- +def represent_multiline_str(dumper, data): + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data.strip(), style="|") + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + +# --- Represent short lists inline (flow style) --- +def represent_inline_list(dumper, data): + + if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data): + return dumper.represent_sequence("tag:yaml.org,2002:seq", data) + + return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True) + + +def yaml_dump(obj, order=None, stream=None, **kwargs): + + if order: + + def _ordering(k): + return order.index(k) if k in order else len(order) + + obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))} + + yaml = ruamel.yaml.YAML() + yaml.width = 120 # wrap long flow sequences + + yaml.Representer.add_representer(datetime.date, represent_date) + yaml.Representer.add_representer(datetime.datetime, represent_date) + yaml.Representer.add_representer(str, represent_multiline_str) + yaml.Representer.add_representer(list, represent_inline_list) + + data = ruamel.yaml.comments.CommentedMap() + for i, (k, v) in enumerate(obj.items()): + data[k] = v + if i > 0: + data.yaml_set_comment_before_after_key(key=k, before="\n") + + if stream: + yaml.dump(data, stream=stream, **kwargs) + + stream = io.StringIO() + yaml.dump(data, stream=stream, **kwargs) + return stream.getvalue() diff --git a/src/anemoi/datasets/misc/grids.py b/src/anemoi/datasets/misc/grids.py new file mode 100644 index 000000000..26f675526 --- /dev/null +++ b/src/anemoi/datasets/misc/grids.py @@ -0,0 +1,668 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import base64 +import logging +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +LOG = logging.getLogger(__name__) + + +def plot_mask( + path: str, + mask: NDArray[Any], + lats: NDArray[Any], + lons: NDArray[Any], + global_lats: NDArray[Any], + global_lons: NDArray[Any], +) -> None: + """Plot and save various visualizations of the mask and coordinates. + + Parameters + ---------- + path : str + The base path for saving the plots. + mask : NDArray[Any] + The mask array. + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + global_lats : NDArray[Any] + Global latitude coordinates. + global_lons : NDArray[Any] + Global longitude coordinates. 
+ """ + import matplotlib.pyplot as plt + + s = 1 + + global_lons[global_lons >= 180] -= 360 + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons, global_lats, s=s, marker="o", c="r") + if isinstance(path, str): + plt.savefig(path + "-global.png") + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="k") + if isinstance(path, str): + plt.savefig(path + "-cutout.png") + + plt.figure(figsize=(10, 5)) + plt.scatter(lons, lats, s=s) + if isinstance(path, str): + plt.savefig(path + "-lam.png") + # plt.scatter(lons, lats, s=0.01) + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") + plt.scatter(lons, lats, s=s) + if isinstance(path, str): + plt.savefig(path + "-both.png") + # plt.scatter(lons, lats, s=0.01) + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") + plt.scatter(lons, lats, s=s) + plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1) + plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1) + if isinstance(path, str): + plt.savefig(path + "-both-zoomed.png") + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") + plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1) + plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1) + if isinstance(path, str): + plt.savefig(path + "-global-zoomed.png") + + +# TODO: Use the one from anemoi.utils.grids instead +# from anemoi.utils.grids import ... +def xyz_to_latlon(x: NDArray[Any], y: NDArray[Any], z: NDArray[Any]) -> tuple[NDArray[Any], NDArray[Any]]: + """Convert Cartesian coordinates to latitude and longitude. + + Parameters + ---------- + x : NDArray[Any] + X coordinates. + y : NDArray[Any] + Y coordinates. + z : NDArray[Any] + Z coordinates. + + Returns + ------- + Tuple[NDArray[Any], NDArray[Any]] + Latitude and longitude coordinates. + """ + return ( + np.rad2deg(np.arcsin(np.minimum(1.0, np.maximum(-1.0, z)))), + np.rad2deg(np.arctan2(y, x)), + ) + + +# TODO: Use the one from anemoi.utils.grids instead +# from anemoi.utils.grids import ... +def latlon_to_xyz( + lat: NDArray[Any], lon: NDArray[Any], radius: float = 1.0 +) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any]]: + """Convert latitude and longitude to Cartesian coordinates. + + Parameters + ---------- + lat : NDArray[Any] + Latitude coordinates. + lon : NDArray[Any] + Longitude coordinates. + radius : float, optional + Radius of the sphere. Defaults to 1.0. + + Returns + ------- + Tuple[NDArray[Any], NDArray[Any], NDArray[Any]] + X, Y, and Z coordinates. + """ + # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates + # We assume that the Earth is a sphere of radius 1 so N(phi) = 1 + # We assume h = 0 + # + phi = np.deg2rad(lat) + lda = np.deg2rad(lon) + + cos_phi = np.cos(phi) + cos_lda = np.cos(lda) + sin_phi = np.sin(phi) + sin_lda = np.sin(lda) + + x = cos_phi * cos_lda * radius + y = cos_phi * sin_lda * radius + z = sin_phi * radius + + return x, y, z + + +class Triangle3D: + """A class to represent a 3D triangle and perform intersection tests with rays.""" + + def __init__(self, v0: NDArray[Any], v1: NDArray[Any], v2: NDArray[Any]) -> None: + """Initialize the Triangle3D object. + + Parameters + ---------- + v0 : NDArray[Any] + First vertex of the triangle. + v1 : NDArray[Any] + Second vertex of the triangle. + v2 : NDArray[Any] + Third vertex of the triangle. 
+ """ + self.v0 = v0 + self.v1 = v1 + self.v2 = v2 + + def intersect(self, ray_origin: NDArray[Any], ray_direction: NDArray[Any]) -> bool: + """Check if a ray intersects with the triangle. + + Parameters + ---------- + ray_origin : NDArray[Any] + Origin of the ray. + ray_direction : NDArray[Any] + Direction of the ray. + + Returns + ------- + bool + True if the ray intersects with the triangle, False otherwise. + """ + # Möller–Trumbore intersection algorithm + # https://en.wikipedia.org/wiki/M%C3%B6ller%E2%80%93Trumbore_intersection_algorithm + + epsilon = 0.0000001 + + h = np.cross(ray_direction, self.v2 - self.v0) + a = np.dot(self.v1 - self.v0, h) + + if -epsilon < a < epsilon: + return False + + f = 1.0 / a + s = ray_origin - self.v0 + u = f * np.dot(s, h) + + if u < 0.0 or u > 1.0: + return False + + q = np.cross(s, self.v1 - self.v0) + v = f * np.dot(ray_direction, q) + + if v < 0.0 or u + v > 1.0: + return False + + t = f * np.dot(self.v2 - self.v0, q) + + if t > epsilon: + return True + + return False + + +def cropping_mask( + lats: NDArray[Any], + lons: NDArray[Any], + north: float, + west: float, + south: float, + east: float, +) -> NDArray[Any]: + """Create a mask for the points within the specified latitude and longitude bounds. + + Parameters + ---------- + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + north : float + Northern boundary. + west : float + Western boundary. + south : float + Southern boundary. + east : float + Eastern boundary. + + Returns + ------- + NDArray[Any] + Mask array. + """ + mask = ( + (lats >= south) + & (lats <= north) + & ( + ((lons >= west) & (lons <= east)) + | ((lons >= west + 360) & (lons <= east + 360)) + | ((lons >= west - 360) & (lons <= east - 360)) + ) + ) + return mask + + +def cutout_mask( + lats: NDArray[Any], + lons: NDArray[Any], + global_lats: NDArray[Any], + global_lons: NDArray[Any], + cropping_distance: float = 2.0, + neighbours: int = 5, + min_distance_km: int | float | None = None, + plot: str | None = None, +) -> NDArray[Any]: + """Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]. + + Parameters + ---------- + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + global_lats : NDArray[Any] + Global latitude coordinates. + global_lons : NDArray[Any] + Global longitude coordinates. + cropping_distance : float, optional + Cropping distance. Defaults to 2.0. + neighbours : int, optional + Number of neighbours. Defaults to 5. + min_distance_km : Optional[Union[int, float]], optional + Minimum distance in kilometers. Defaults to None. + plot : Optional[str], optional + Path for saving the plot. Defaults to None. + + Returns + ------- + NDArray[Any] + Mask array. 
+ """ + from scipy.spatial import cKDTree + + # TODO: transform min_distance from lat/lon to xyz + + assert global_lats.ndim == 1 + assert global_lons.ndim == 1 + assert lats.ndim == 1 + assert lons.ndim == 1 + + assert global_lats.shape == global_lons.shape + assert lats.shape == lons.shape + + north = np.amax(lats) + south = np.amin(lats) + east = np.amax(lons) + west = np.amin(lons) + + # Reduce the global grid to the area of interest + + mask = cropping_mask( + global_lats, + global_lons, + np.min([90.0, north + cropping_distance]), + west - cropping_distance, + np.max([-90.0, south - cropping_distance]), + east + cropping_distance, + ) + + # return mask + # mask = np.array([True] * len(global_lats), dtype=bool) + global_lats_masked = global_lats[mask] + global_lons_masked = global_lons[mask] + + global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked) + global_points = np.array(global_xyx).transpose() + + xyx = latlon_to_xyz(lats, lons) + lam_points = np.array(xyx).transpose() + + if isinstance(min_distance_km, (int, float)): + min_distance = min_distance_km / 6371.0 + else: + points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km] + distances, _ = cKDTree(points).query(points, k=2) + min_distance = np.min(distances[:, 1]) + + LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km") + + # Use a cKDTree to find the nearest points + distances, indices = cKDTree(lam_points).query(global_points, k=neighbours) + + # Centre of the Earth + zero = np.array([0.0, 0.0, 0.0]) + + # After the loop, 'inside_lam' will contain a list point to EXCLUDE + inside_lam = [] + + for i, (global_point, distance, index) in enumerate(zip(global_points, distances, indices)): + + # We check more than one triangle in case te global point + # is near the edge of triangle, (the lam point and global points are colinear) + + inside = False + for j in range(neighbours): + t = Triangle3D( + lam_points[index[j]], lam_points[index[(j + 1) % neighbours]], lam_points[index[(j + 2) % neighbours]] + ) + inside = t.intersect(zero, global_point) + if inside: + break + + close = np.min(distance) <= min_distance + + inside_lam.append(inside or close) + + j = 0 + inside_lam_array = np.array(inside_lam) + for i, m in enumerate(mask): + if not m: + continue + + mask[i] = inside_lam_array[j] + j += 1 + + assert j == len(inside_lam_array) + + # Invert the mask, so we have only the points outside the cutout + mask = ~mask + + if plot: + plot_mask(plot, mask, lats, lons, global_lats, global_lons) + + return mask + + +def thinning_mask( + lats: NDArray[Any], + lons: NDArray[Any], + global_lats: NDArray[Any], + global_lons: NDArray[Any], + cropping_distance: float = 2.0, +) -> NDArray[Any]: + """Return the list of points in [lats, lons] closest to [global_lats, global_lons]. + + Parameters + ---------- + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + global_lats : NDArray[Any] + Global latitude coordinates. + global_lons : NDArray[Any] + Global longitude coordinates. + cropping_distance : float, optional + Cropping distance. Defaults to 2.0. + + Returns + ------- + NDArray[Any] + Array of indices of the closest points. 
+ """ + from scipy.spatial import cKDTree + + assert global_lats.ndim == 1 + assert global_lons.ndim == 1 + assert lats.ndim == 1 + assert lons.ndim == 1 + + assert global_lats.shape == global_lons.shape + assert lats.shape == lons.shape + + north = np.amax(lats) + south = np.amin(lats) + east = np.amax(lons) + west = np.amin(lons) + + # Reduce the global grid to the area of interest + + mask = cropping_mask( + global_lats, + global_lons, + np.min([90.0, north + cropping_distance]), + west - cropping_distance, + np.max([-90.0, south - cropping_distance]), + east + cropping_distance, + ) + + # return mask + global_lats_masked = global_lats[mask] + global_lons_masked = global_lons[mask] + + global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked) + global_points = np.array(global_xyx).transpose() + + xyx = latlon_to_xyz(lats, lons) + points = np.array(xyx).transpose() + + # Use a cKDTree to find the nearest points + _, indices = cKDTree(points).query(global_points, k=1) + + return np.array([i for i in indices]) + + +def outline(lats: NDArray[Any], lons: NDArray[Any], neighbours: int = 5) -> list[int]: + """Find the outline of the grid points. + + Parameters + ---------- + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + neighbours : int, optional + Number of neighbours. Defaults to 5. + + Returns + ------- + List[int] + Indices of the outline points. + """ + from scipy.spatial import cKDTree + + xyx = latlon_to_xyz(lats, lons) + grid_points = np.array(xyx).transpose() + + # Use a cKDTree to find the nearest points + _, indices = cKDTree(grid_points).query(grid_points, k=neighbours) + + # Centre of the Earth + zero = np.array([0.0, 0.0, 0.0]) + + outside = [] + + for i, (point, index) in enumerate(zip(grid_points, indices)): + inside = False + for j in range(1, neighbours): + t = Triangle3D( + grid_points[index[j]], + grid_points[index[(j + 1) % neighbours]], + grid_points[index[(j + 2) % neighbours]], + ) + inside = t.intersect(zero, point) + if inside: + break + + if not inside: + outside.append(i) + + return outside + + +def deserialise_mask(encoded: str) -> NDArray[Any]: + """Deserialise a mask from a base64 encoded string. + + Parameters + ---------- + encoded : str + Base64 encoded string. + + Returns + ------- + NDArray[Any] + Deserialised mask array. + """ + import pickle + import zlib + + packed = pickle.loads(zlib.decompress(base64.b64decode(encoded))) + + mask = [] + value = False + for count in packed: + mask.extend([value] * count) + value = not value + return np.array(mask, dtype=bool) + + +def _serialise_mask(mask: NDArray[Any]) -> str: + """Serialise a mask to a base64 encoded string. + + Parameters + ---------- + mask : NDArray[Any] + Mask array. + + Returns + ------- + str + Base64 encoded string. + """ + import pickle + import zlib + + assert len(mask.shape) == 1 + assert len(mask) + + packed = [] + last = mask[0] + count = 1 + + for value in mask[1:]: + if value == last: + count += 1 + else: + packed.append(count) + last = value + count = 1 + + packed.append(count) + + # We always start with an 'off' value + # So if the first value is 'on', we need to add a zero + if mask[0]: + packed.insert(0, 0) + + return base64.b64encode(zlib.compress(pickle.dumps(packed))).decode("utf-8") + + +def serialise_mask(mask: NDArray[Any]) -> str: + """Serialise a mask and ensure it can be deserialised. + + Parameters + ---------- + mask : NDArray[Any] + Mask array. + + Returns + ------- + str + Base64 encoded string. 
+ """ + result = _serialise_mask(mask) + # Make sure we can deserialise it + assert np.all(mask == deserialise_mask(result)) + return result + + +def nearest_grid_points( + source_latitudes: NDArray[Any], + source_longitudes: NDArray[Any], + target_latitudes: NDArray[Any], + target_longitudes: NDArray[Any], + max_distance: float = None, + k: int = 1, +) -> NDArray[Any]: + """Find the nearest grid points from source to target coordinates. + + Parameters + ---------- + source_latitudes : NDArray[Any] + Source latitude coordinates. + source_longitudes : NDArray[Any] + Source longitude coordinates. + target_latitudes : NDArray[Any] + Target latitude coordinates. + target_longitudes : NDArray[Any] + Target longitude coordinates. + max_distance: float, optional + Maximum distance between nearest point and point to interpolate. Defaults to None. + For example, 1e-3 is 1 km. + k : int, optional + The number of k closest neighbors to consider for interpolation + + Returns + ------- + NDArray[Any] + Indices of the nearest grid points. + """ + # TODO: Use the one from anemoi.utils.grids instead + # from anemoi.utils.grids import ... + from scipy.spatial import cKDTree + + source_xyz = latlon_to_xyz(source_latitudes, source_longitudes) + source_points = np.array(source_xyz).transpose() + + target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) + target_points = np.array(target_xyz).transpose() + if max_distance is None: + distances, indices = cKDTree(source_points).query(target_points, k=k) + else: + distances, indices = cKDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) + return distances, indices + + +if __name__ == "__main__": + global_lats, global_lons = np.meshgrid( + np.linspace(90, -90, 90), + np.linspace(-180, 180, 180), + ) + global_lats = global_lats.flatten() + global_lons = global_lons.flatten() + + lats, lons = np.meshgrid( + np.linspace(50, 40, 100), + np.linspace(-10, 15, 100), + ) + lats = lats.flatten() + lons = lons.flatten() + + mask = cutout_mask(lats, lons, global_lats, global_lons, cropping_distance=5.0) + + import matplotlib.pyplot as plt + + fig = plt.figure(figsize=(10, 5)) + plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r") + plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k") + # plt.scatter(lons, lats, s=0.01) + plt.savefig("cutout.png") diff --git a/src/anemoi/datasets/misc/testing.py b/src/anemoi/datasets/misc/testing.py new file mode 100644 index 000000000..a15c7fd7e --- /dev/null +++ b/src/anemoi/datasets/misc/testing.py @@ -0,0 +1,173 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +# A collection of functions to support pytest testing + +import logging +from typing import Any + +LOG = logging.getLogger(__name__) + + +def assert_field_list( + fs: list[Any], + size: int | None = None, + start: Any | None = None, + end: Any | None = None, + constant: bool = False, + skip: Any | None = None, +) -> None: + """Asserts various properties of a list of fields. + + Parameters + ---------- + fs : List[Any] + List of fields to be checked. + size : Optional[int], optional + Expected size of the list. If None, the list must be non-empty. 
+ start : Optional[Any], optional + Expected start metadata value. If None, no check is performed. + end : Optional[Any], optional + Expected end metadata value. If None, no check is performed. + constant : bool, optional + If True, checks that all fields are constant. + skip : Optional[Any], optional + Placeholder for future use. + """ + import numpy as np + + if size is None: + assert len(fs) > 0, fs + else: + assert len(fs) == size, (len(fs), size) + + first = fs[0] + last = fs[-1] + + if constant: + # TODO: add a check for constant fields + pass + else: + assert start is None or first.metadata("valid_datetime") == start, (first.metadata("valid_datetime"), start) + assert end is None or last.metadata("valid_datetime") == end, (last.metadata("valid_datetime"), end) + print(first.datetime()) + + print(last.metadata()) + + first = first + latitudes, longitudes = first.grid_points() + + assert len(latitudes.shape) == 1, latitudes.shape + assert len(longitudes.shape) == 1, longitudes.shape + + assert len(latitudes) == len(longitudes), (len(latitudes), len(longitudes)) + data = first.to_numpy(flatten=True) + + assert len(data) == len(latitudes), (len(data), len(latitudes)) + + north = np.max(latitudes) + south = np.min(latitudes) + east = np.max(longitudes) + west = np.min(longitudes) + + assert north >= south, (north, south) + assert east >= west, (east, west) + assert north <= 90, north + assert south >= -90, south + assert east <= 360, east + assert west >= -180, west + + +class IndexTester: + """Class to test indexing of datasets.""" + + def __init__(self, ds: Any) -> None: + """Initialise the IndexTester. + + Parameters + ---------- + ds : Any + Dataset. + """ + self.ds = ds + self.np = ds[:] # Numpy array + + assert self.ds.shape == self.np.shape, (self.ds.shape, self.np.shape) + assert (self.ds == self.np).all() + + def __getitem__(self, index: Any) -> None: + """Test indexing. + + Parameters + ---------- + index : Any + Index. + """ + LOG.info("IndexTester: %s", index) + if self.ds[index] is None: + assert False, (self.ds, index) + + if not (self.ds[index] == self.np[index]).all(): + assert (self.ds[index] == self.np[index]).all() + + +def default_test_indexing(ds): + + t = IndexTester(ds) + + t[0:10, :, 0] + t[:, 0:3, 0] + # t[:, :, 0] + t[0:10, 0:3, 0] + t[:, :, :] + + if ds.shape[1] > 2: # Variable dimension + t[:, (1, 2), :] + t[:, (1, 2)] + + t[0] + t[0, :] + t[0, 0, :] + t[0, 0, 0, :] + + if ds.shape[2] > 1: # Ensemble dimension + t[0:10, :, (0, 1)] + + for i in range(3): + t[i] + start = 5 * i + end = len(ds) - 5 * i + step = len(ds) // 10 + + t[start:end:step] + t[start:end] + t[start:] + t[:end] + t[::step] + + +class Trace: + + def __init__(self, ds): + self.ds = ds + self.f = open("trace.txt", "a") + + def __getattr__(self, name: str) -> Any: + + print(name, file=self.f, flush=True) + return getattr(self.ds, name) + + def __len__(self) -> int: + print("__len__", file=self.f, flush=True) + return len(self.ds) + + def __getitem__(self, index: Any) -> Any: + print("__getitem__", file=self.f, flush=True) + return self.ds[index] diff --git a/src/anemoi/datasets/misc/validate.py b/src/anemoi/datasets/misc/validate.py new file mode 100644 index 000000000..a1e168116 --- /dev/null +++ b/src/anemoi/datasets/misc/validate.py @@ -0,0 +1,598 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +import math +from collections import defaultdict + +import numpy as np + +from anemoi.datasets.testing import default_test_indexing +from anemoi.datasets.use.dataset import Dataset + +LOG = logging.getLogger(__name__) +# List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1 + +TRAINING_METHODS = [ + "__getitem__", + "__len__", + "latitudes", + "longitudes", + "metadata", # Accessed when checkpointing + "missing", + "name_to_index", + "shape", + "statistics", + "supporting_arrays", # Accessed when checkpointing + "variables", +] + +EXTRA_TRAINING_METHODS = [ + "statistics_tendencies", +] + +DEBUGGING_METHODS = [ + "plot", + "to_index", + "tree", + "source", +] + +PUBLIC_METADATA_METHODS = [ + "arguments", + "dtype", + "end_date", + "resolution", + "start_date", + "field_shape", + "frequency", + "dates", + "typed_variables", + "variables_metadata", +] + +PRIVATE_METADATA_METHODS = [ + "computed_constant_fields", + "constant_fields", + "dataset_metadata", + "label", + "metadata_specific", + "provenance", +] + +INTERNAL_METHODS = [ + "mutate", + "swap_with_parent", + "dates_interval_to_indices", +] + +EXPERIMENTAL_METHODS = [ + "get_dataset_names", + "name", + "grids", +] + +OTHER_METHODS = [ + "collect_input_sources", + "collect_supporting_arrays", + "sub_shape", +] + + +METHODS_CATEGORIES = {k: v for k, v in list(globals().items()) if k.endswith("_METHODS")} + + +METHODS = set(sum(METHODS_CATEGORIES.values(), [])) + + +KWARGS = { + "__len__": {}, + "__getitem__": {"index": 0}, + "get_dataset_names": {"names": set()}, + "metadata": {}, + "metadata_specific": {}, + "mutate": {}, + "plot": {"date": 0, "variable": 0}, + "provenance": {}, + "source": {"index": 0}, + "statistics_tendencies": {}, + "sub_shape": {}, + "supporting_arrays": {}, + "swap_with_parent": {}, + "to_index": {"date": 0, "variable": 0}, + "tree": {}, +} + + +class Unknown: + emoji = "❓" + + +class Success: + emoji = "✅" + success = True + + def __repr__(self): + return "Success" + + +class Error: + success = False + + def __init__(self, message): + self.message = message + + def __repr__(self): + return str(self.message) or repr(self.message) or "Error" + + +class Failure(Error): + emoji = "💥" + + +class Internal(Error): + emoji = "💣" + + +class Invalid(Error): + emoji = "❌" + + +class Report: + + def __init__(self): + self.report = {} + self.methods = {} + self.warnings = defaultdict(list) + + def method(self, name, method): + self.methods[name] = method + + def success(self, name): + self.report[name] = Success() + + def failure(self, name, message): + self.report[name] = Failure(message) + + def internal(self, name, message): + self.report[name] = Internal(message) + + def invalid(self, name, exception): + self.report[name] = Invalid(exception) + + def warning(self, name, message): + self.warnings[name].append(message) + + def summary(self, detailed=False): + + maxlen = max(len(name) for name in self.report.keys()) + + for name, methods in METHODS_CATEGORIES.items(): + print() + print(f"{name.title().replace('_', ' ')}:") + print("-" * (len(name) + 1)) + print() + + for method in methods: + r = self.report.get(method, Unknown()) + msg = repr(r) + if not msg.endswith("."): + msg += "." 
+ print(f"{r.emoji} {method.ljust(maxlen)}: {msg}") + + for w in self.warnings.get(method, []): + print(" " * (maxlen + 4), "⚠️", w) + + if r.success: + continue + + if not detailed: + continue + + if method not in self.methods: + continue + + proc = self.methods[method] + + doc = proc.__doc__ + if doc: + width = 80 + indent = maxlen + 4 + doc = "\n".join(["=" * width, "", doc, "=" * width]) + indented_doc = "\n".join(" " * indent + line for line in doc.splitlines()) + print() + print(indented_doc) + print() + print() + + print() + + +def _no_validate(report, dataset, name, result): + report.warning(name, f"Validation for {name} not implemented. Result: {type(result)}") + + +def validate_variables(report, dataset, name, result): + """Validate the variables of the dataset.""" + + if not isinstance(result, (list, tuple)): + raise ValueError(f"Result is not a list or tuple {type(result)}") + + if len(result) != dataset.shape[1]: + raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[1]}") + + for value in result: + if not isinstance(value, str): + raise ValueError(f"`{value}` is not a string") + + +def validate_latitudes(report, dataset, name, result): + """Validate the latitudes of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if len(result) != dataset.shape[3]: + raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}") + + if not np.all(np.isfinite(result)): + raise ValueError("Result contains non-finite values") + + if np.isnan(result).any(): + report.invalid(name, ValueError("Result contains NaN values")) + return + + if not np.all((result >= -90) & (result <= 90)): + raise ValueError("Result contains values outside the range [-90, 90]") + + if np.all((result >= -np.pi) & (result <= np.pi)): + report.warning(name, "All latitudes are in the range [-π, π]. Are they in radians?") + + +def validate_longitudes(report, dataset, name, result): + """Validate the longitudes of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if len(result) != dataset.shape[3]: + raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[2]}") + + if not np.all(np.isfinite(result)): + raise ValueError("Result contains non-finite values") + + if np.isnan(result).any(): + report.invalid(name, ValueError("Result contains NaN values")) + return + + if not np.all((result >= -180) & (result <= 360)): + raise ValueError("Result contains values outside the range [-180, 360]") + + if np.all((result >= -np.pi) & (result <= 2 * np.pi)): + report.warning(name, "All longitudes are in the range [-π, 2π]. 
Are they in radians?") + + +def validate_statistics(report, dataset, name, result): + """Validate the statistics of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + for key in ["mean", "stdev", "minimum", "maximum"]: + + if key not in result: + raise ValueError(f"Result does not contain `{key}`") + + if not isinstance(result[key], np.ndarray): + raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}") + + if len(result[key].shape) != 1: + raise ValueError(f"Result[{key}] has wrong shape: {len(result[key].shape)} != 1") + + if result[key].shape[0] != len(dataset.variables): + raise ValueError(f"Result[{key}] has wrong length: {result[key].shape[0]} != {len(dataset.variables)}") + + if not np.all(np.isfinite(result[key])): + raise ValueError(f"Result[{key}] contains non-finite values") + + if np.isnan(result[key]).any(): + report.invalid(name, ValueError(f"Result[{key}] contains NaN values")) + + +def validate_shape(report, dataset, name, result): + """Validate the shape of the dataset.""" + + if not isinstance(result, tuple): + raise ValueError(f"Result is not a tuple {type(result)}") + + if len(result) != 4: + raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.shape)}") + + if result[0] != len(dataset): + raise ValueError(f"Result[0] has wrong length: {result[0]} != {len(dataset)}") + + if result[1] != len(dataset.variables): + raise ValueError(f"Result[1] has wrong length: {result[1]} != {len(dataset.variables)}") + + if result[2] != 1: # We ignore ensemble dimension for now + pass + + if result[3] != len(dataset.latitudes): + raise ValueError(f"Result[3] has wrong length: {result[3]} != {len(dataset.latitudes)}") + + +def validate_supporting_arrays(report, dataset, name, result): + """Validate the supporting arrays of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + if "latitudes" not in result: + raise ValueError("Result does not contain `latitudes`") + + if "longitudes" not in result: + raise ValueError("Result does not contain `longitudes`") + + if not isinstance(result["latitudes"], np.ndarray): + raise ValueError(f"Result[latitudes] is not a np.ndarray {type(result['latitudes'])}") + + if not isinstance(result["longitudes"], np.ndarray): + raise ValueError(f"Result[longitudes] is not a np.ndarray {type(result['longitudes'])}") + + if np.any(result["latitudes"] != dataset.latitudes): + raise ValueError("Result[latitudes] does not match dataset.latitudes") + + if np.any(result["longitudes"] != dataset.longitudes): + raise ValueError("Result[longitudes] does not match dataset.longitudes") + + +def validate_dates(report, dataset, name, result): + """Validate the dates of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if len(result.shape) != 1: + raise ValueError(f"Result has wrong shape: {len(result.shape)} != 1") + + if result.shape[0] != len(dataset.dates): + raise ValueError(f"Result has wrong length: {result.shape[0]} != {len(dataset.dates)}") + + if not np.issubdtype(result.dtype, np.datetime64): + raise ValueError(f"Result is not a datetime64 array {result.dtype}") + + if len(result) != len(dataset.dates): + raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.dates)}") + + if not np.all(np.isfinite(result)): + raise ValueError("Result contains non-finite values") + + if np.isnan(result).any(): + 
report.invalid(name, ValueError("Result contains NaN values")) + return + + for d1, d2 in zip(result[:-1], result[1:]): + if d1 >= d2: + raise ValueError(f"Result contains non-increasing dates: {d1} >= {d2}") + + frequency = np.diff(result) + if not np.all(frequency == frequency[0]): + raise ValueError("Result contains non-constant frequency") + + +def validate_metadata(report, dataset, name, result): + """Validate the metadata of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + +def validate_missing(report, dataset, name, result): + """Validate the missing values of the dataset.""" + + if not isinstance(result, set): + raise ValueError(f"Result is not a set {type(result)}") + + if not all(isinstance(item, int) for item in result): + raise ValueError("Result contains non-integer values") + + if len(result) > 0: + if min(result) < 0: + raise ValueError("Result contains negative values") + + if max(result) >= len(dataset): + raise ValueError(f"Result contains values greater than {len(dataset)}") + + +def validate_name_to_index(report, dataset, name, result): + """Validate the name to index mapping of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + for key in dataset.variables: + if key not in result: + raise ValueError(f"Result does not contain `{key}`") + + if not isinstance(result[key], int): + raise ValueError(f"Result[{key}] is not an int {type(result[key])}") + + if result[key] < 0 or result[key] >= len(dataset.variables): + raise ValueError(f"Result[{key}] is out of bounds: {result[key]}") + + index_to_name = {v: k for k, v in result.items()} + for i in range(len(dataset.variables)): + if i not in index_to_name: + raise ValueError(f"Result does not contain index `{i}`") + + if not isinstance(index_to_name[i], str): + raise ValueError(f"Result[{i}] is not a string {type(index_to_name[i])}") + + if index_to_name[i] != dataset.variables[i]: + raise ValueError( + f"Result[{i}] does not match dataset.variables[{i}]: {index_to_name[i]} != {dataset.variables[i]}" + ) + + +def validate___getitem__(report, dataset, name, result): + """Validate the __getitem__ method of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if result.shape != dataset.shape[1:]: + raise ValueError(f"Result has wrong shape: {result.shape} != {dataset.shape[1:]}") + + +def validate___len__(report, dataset, name, result): + """Validate the __len__ method of the dataset.""" + + if not isinstance(result, int): + raise ValueError(f"Result is not an int {type(result)}") + + if result != dataset.shape[0]: + raise ValueError(f"Result has wrong length: {result} != {len(dataset)}") + + if result != len(dataset.dates): + raise ValueError(f"Result has wrong length: {result} != {len(dataset.dates)}") + + +def validate_start_date(report, dataset, name, result): + """Validate the start date of the dataset.""" + + if not isinstance(result, np.datetime64): + raise ValueError(f"Result is not a datetime64 {type(result)}") + + if result != dataset.dates[0]: + raise ValueError(f"Result has wrong start date: {result} != {dataset.dates[0]}") + + +def validate_end_date(report, dataset, name, result): + """Validate the end date of the dataset.""" + + if not isinstance(result, np.datetime64): + raise ValueError(f"Result is not a datetime64 {type(result)}") + + if result != dataset.dates[-1]: + raise ValueError(f"Result has wrong end 
date: {result} != {dataset.dates[-1]}") + + +def validate_field_shape(report, dataset, name, result): + """Validate the field shape of the dataset.""" + + if not isinstance(result, tuple): + raise ValueError(f"Result is not a tuple {type(result)}") + + if math.prod(result) != dataset.shape[-1]: + raise ValueError(f"Result has wrong shape: {result} != {dataset.shape[-1]}") + + +def validate(report, dataset, name, kwargs=None): + + try: + + validate_fn = globals().get(f"validate_{name}", _no_validate) + + # Check if the method is still in the Dataset class + try: + report.method(name, getattr(Dataset, name)) + except AttributeError: + report.internal(name, "Attribute not found in Dataset class. Please update the list of methods.") + return + + # Check if the method is supported by the dataset instance + try: + result = getattr(dataset, name) + except AttributeError as e: + report.failure(name, e) + return + + # Check if the method is callable + if callable(result): + if kwargs is None: + report.internal( + name, f"`{name}` is a callable method, not an attribute. Please update KWARGS accordingly." + ) + return + else: + if kwargs is not None: + report.internal(name, f"`{name}` is not callable. Please remove entry from KWARGS.") + return + + if kwargs is not None: + result = result(**kwargs) + + if isinstance(result, np.ndarray) and np.isnan(result).any(): + report.invalid(name, ValueError("Result contains NaN values")) + return + + try: + validate_fn(report, dataset, name, result) + except Exception as e: + report.invalid(name, e) + return + + report.success(name) + + except Exception as e: + report.failure(name, e) + + +def validate_dtype(report, dataset, name, result): + """Validate the dtype of the dataset.""" + + if not isinstance(result, np.dtype): + raise ValueError(f"Result is not a np.dtype {type(result)}") + + +def validate_dataset(dataset, costly_checks=False, detailed=False): + """Validate the dataset.""" + + report = Report() + + if costly_checks: + # This check is expensive as it loads the entire dataset into memory + # so we make it optional + default_test_indexing(dataset) + + for i, x in enumerate(dataset): + y = dataset[i] + assert (x == y).all(), f"Dataset indexing failed at index {i}: {x} != {y}" + + for name in METHODS: + validate(report, dataset, name, kwargs=KWARGS.get(name)) + + report.summary(detailed=detailed) + + +if __name__ == "__main__": + methods = METHODS_CATEGORIES.copy() + methods.pop("OTHER_METHODS") + + o = set(OTHER_METHODS) + overlap = False + for m in methods: + if set(methods[m]).intersection(set(OTHER_METHODS)): + print( + f"WARNING: {m} contains methods from OTHER_METHODS: {set(methods[m]).intersection(set(OTHER_METHODS))}" + ) + o = o - set(methods[m]) + overlap = True + + for m in methods: + for n in methods: + if n is not m: + if set(methods[m]).intersection(set(methods[n])): + print( + f"WARNING: {m} and {n} have methods in common: {set(methods[m]).intersection(set(methods[n]))}" + ) + + if overlap: + print(sorted(o)) From ba24f4cb50e2c89813c9855f58fe450d0c300dee Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 15:31:11 +0000 Subject: [PATCH 152/212] rename files --- src/anemoi/datasets/misc/validate.py | 598 --------------------------- 1 file changed, 598 deletions(-) delete mode 100644 src/anemoi/datasets/misc/validate.py diff --git a/src/anemoi/datasets/misc/validate.py b/src/anemoi/datasets/misc/validate.py deleted file mode 100644 index a1e168116..000000000 --- a/src/anemoi/datasets/misc/validate.py +++ /dev/null @@ -1,598 
+0,0 @@ -# (C) Copyright 2025- Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -import math -from collections import defaultdict - -import numpy as np - -from anemoi.datasets.testing import default_test_indexing -from anemoi.datasets.use.dataset import Dataset - -LOG = logging.getLogger(__name__) -# List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1 - -TRAINING_METHODS = [ - "__getitem__", - "__len__", - "latitudes", - "longitudes", - "metadata", # Accessed when checkpointing - "missing", - "name_to_index", - "shape", - "statistics", - "supporting_arrays", # Accessed when checkpointing - "variables", -] - -EXTRA_TRAINING_METHODS = [ - "statistics_tendencies", -] - -DEBUGGING_METHODS = [ - "plot", - "to_index", - "tree", - "source", -] - -PUBLIC_METADATA_METHODS = [ - "arguments", - "dtype", - "end_date", - "resolution", - "start_date", - "field_shape", - "frequency", - "dates", - "typed_variables", - "variables_metadata", -] - -PRIVATE_METADATA_METHODS = [ - "computed_constant_fields", - "constant_fields", - "dataset_metadata", - "label", - "metadata_specific", - "provenance", -] - -INTERNAL_METHODS = [ - "mutate", - "swap_with_parent", - "dates_interval_to_indices", -] - -EXPERIMENTAL_METHODS = [ - "get_dataset_names", - "name", - "grids", -] - -OTHER_METHODS = [ - "collect_input_sources", - "collect_supporting_arrays", - "sub_shape", -] - - -METHODS_CATEGORIES = {k: v for k, v in list(globals().items()) if k.endswith("_METHODS")} - - -METHODS = set(sum(METHODS_CATEGORIES.values(), [])) - - -KWARGS = { - "__len__": {}, - "__getitem__": {"index": 0}, - "get_dataset_names": {"names": set()}, - "metadata": {}, - "metadata_specific": {}, - "mutate": {}, - "plot": {"date": 0, "variable": 0}, - "provenance": {}, - "source": {"index": 0}, - "statistics_tendencies": {}, - "sub_shape": {}, - "supporting_arrays": {}, - "swap_with_parent": {}, - "to_index": {"date": 0, "variable": 0}, - "tree": {}, -} - - -class Unknown: - emoji = "❓" - - -class Success: - emoji = "✅" - success = True - - def __repr__(self): - return "Success" - - -class Error: - success = False - - def __init__(self, message): - self.message = message - - def __repr__(self): - return str(self.message) or repr(self.message) or "Error" - - -class Failure(Error): - emoji = "💥" - - -class Internal(Error): - emoji = "💣" - - -class Invalid(Error): - emoji = "❌" - - -class Report: - - def __init__(self): - self.report = {} - self.methods = {} - self.warnings = defaultdict(list) - - def method(self, name, method): - self.methods[name] = method - - def success(self, name): - self.report[name] = Success() - - def failure(self, name, message): - self.report[name] = Failure(message) - - def internal(self, name, message): - self.report[name] = Internal(message) - - def invalid(self, name, exception): - self.report[name] = Invalid(exception) - - def warning(self, name, message): - self.warnings[name].append(message) - - def summary(self, detailed=False): - - maxlen = max(len(name) for name in self.report.keys()) - - for name, methods in METHODS_CATEGORIES.items(): - print() - print(f"{name.title().replace('_', ' ')}:") - print("-" * (len(name) + 1)) - 
print() - - for method in methods: - r = self.report.get(method, Unknown()) - msg = repr(r) - if not msg.endswith("."): - msg += "." - print(f"{r.emoji} {method.ljust(maxlen)}: {msg}") - - for w in self.warnings.get(method, []): - print(" " * (maxlen + 4), "⚠️", w) - - if r.success: - continue - - if not detailed: - continue - - if method not in self.methods: - continue - - proc = self.methods[method] - - doc = proc.__doc__ - if doc: - width = 80 - indent = maxlen + 4 - doc = "\n".join(["=" * width, "", doc, "=" * width]) - indented_doc = "\n".join(" " * indent + line for line in doc.splitlines()) - print() - print(indented_doc) - print() - print() - - print() - - -def _no_validate(report, dataset, name, result): - report.warning(name, f"Validation for {name} not implemented. Result: {type(result)}") - - -def validate_variables(report, dataset, name, result): - """Validate the variables of the dataset.""" - - if not isinstance(result, (list, tuple)): - raise ValueError(f"Result is not a list or tuple {type(result)}") - - if len(result) != dataset.shape[1]: - raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[1]}") - - for value in result: - if not isinstance(value, str): - raise ValueError(f"`{value}` is not a string") - - -def validate_latitudes(report, dataset, name, result): - """Validate the latitudes of the dataset.""" - - if not isinstance(result, np.ndarray): - raise ValueError(f"Result is not a np.ndarray {type(result)}") - - if len(result) != dataset.shape[3]: - raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}") - - if not np.all(np.isfinite(result)): - raise ValueError("Result contains non-finite values") - - if np.isnan(result).any(): - report.invalid(name, ValueError("Result contains NaN values")) - return - - if not np.all((result >= -90) & (result <= 90)): - raise ValueError("Result contains values outside the range [-90, 90]") - - if np.all((result >= -np.pi) & (result <= np.pi)): - report.warning(name, "All latitudes are in the range [-π, π]. Are they in radians?") - - -def validate_longitudes(report, dataset, name, result): - """Validate the longitudes of the dataset.""" - - if not isinstance(result, np.ndarray): - raise ValueError(f"Result is not a np.ndarray {type(result)}") - - if len(result) != dataset.shape[3]: - raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[2]}") - - if not np.all(np.isfinite(result)): - raise ValueError("Result contains non-finite values") - - if np.isnan(result).any(): - report.invalid(name, ValueError("Result contains NaN values")) - return - - if not np.all((result >= -180) & (result <= 360)): - raise ValueError("Result contains values outside the range [-180, 360]") - - if np.all((result >= -np.pi) & (result <= 2 * np.pi)): - report.warning(name, "All longitudes are in the range [-π, 2π]. 
Are they in radians?") - - -def validate_statistics(report, dataset, name, result): - """Validate the statistics of the dataset.""" - - if not isinstance(result, dict): - raise ValueError(f"Result is not a dict {type(result)}") - - for key in ["mean", "stdev", "minimum", "maximum"]: - - if key not in result: - raise ValueError(f"Result does not contain `{key}`") - - if not isinstance(result[key], np.ndarray): - raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}") - - if len(result[key].shape) != 1: - raise ValueError(f"Result[{key}] has wrong shape: {len(result[key].shape)} != 1") - - if result[key].shape[0] != len(dataset.variables): - raise ValueError(f"Result[{key}] has wrong length: {result[key].shape[0]} != {len(dataset.variables)}") - - if not np.all(np.isfinite(result[key])): - raise ValueError(f"Result[{key}] contains non-finite values") - - if np.isnan(result[key]).any(): - report.invalid(name, ValueError(f"Result[{key}] contains NaN values")) - - -def validate_shape(report, dataset, name, result): - """Validate the shape of the dataset.""" - - if not isinstance(result, tuple): - raise ValueError(f"Result is not a tuple {type(result)}") - - if len(result) != 4: - raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.shape)}") - - if result[0] != len(dataset): - raise ValueError(f"Result[0] has wrong length: {result[0]} != {len(dataset)}") - - if result[1] != len(dataset.variables): - raise ValueError(f"Result[1] has wrong length: {result[1]} != {len(dataset.variables)}") - - if result[2] != 1: # We ignore ensemble dimension for now - pass - - if result[3] != len(dataset.latitudes): - raise ValueError(f"Result[3] has wrong length: {result[3]} != {len(dataset.latitudes)}") - - -def validate_supporting_arrays(report, dataset, name, result): - """Validate the supporting arrays of the dataset.""" - - if not isinstance(result, dict): - raise ValueError(f"Result is not a dict {type(result)}") - - if "latitudes" not in result: - raise ValueError("Result does not contain `latitudes`") - - if "longitudes" not in result: - raise ValueError("Result does not contain `longitudes`") - - if not isinstance(result["latitudes"], np.ndarray): - raise ValueError(f"Result[latitudes] is not a np.ndarray {type(result['latitudes'])}") - - if not isinstance(result["longitudes"], np.ndarray): - raise ValueError(f"Result[longitudes] is not a np.ndarray {type(result['longitudes'])}") - - if np.any(result["latitudes"] != dataset.latitudes): - raise ValueError("Result[latitudes] does not match dataset.latitudes") - - if np.any(result["longitudes"] != dataset.longitudes): - raise ValueError("Result[longitudes] does not match dataset.longitudes") - - -def validate_dates(report, dataset, name, result): - """Validate the dates of the dataset.""" - - if not isinstance(result, np.ndarray): - raise ValueError(f"Result is not a np.ndarray {type(result)}") - - if len(result.shape) != 1: - raise ValueError(f"Result has wrong shape: {len(result.shape)} != 1") - - if result.shape[0] != len(dataset.dates): - raise ValueError(f"Result has wrong length: {result.shape[0]} != {len(dataset.dates)}") - - if not np.issubdtype(result.dtype, np.datetime64): - raise ValueError(f"Result is not a datetime64 array {result.dtype}") - - if len(result) != len(dataset.dates): - raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.dates)}") - - if not np.all(np.isfinite(result)): - raise ValueError("Result contains non-finite values") - - if np.isnan(result).any(): - 
report.invalid(name, ValueError("Result contains NaN values")) - return - - for d1, d2 in zip(result[:-1], result[1:]): - if d1 >= d2: - raise ValueError(f"Result contains non-increasing dates: {d1} >= {d2}") - - frequency = np.diff(result) - if not np.all(frequency == frequency[0]): - raise ValueError("Result contains non-constant frequency") - - -def validate_metadata(report, dataset, name, result): - """Validate the metadata of the dataset.""" - - if not isinstance(result, dict): - raise ValueError(f"Result is not a dict {type(result)}") - - -def validate_missing(report, dataset, name, result): - """Validate the missing values of the dataset.""" - - if not isinstance(result, set): - raise ValueError(f"Result is not a set {type(result)}") - - if not all(isinstance(item, int) for item in result): - raise ValueError("Result contains non-integer values") - - if len(result) > 0: - if min(result) < 0: - raise ValueError("Result contains negative values") - - if max(result) >= len(dataset): - raise ValueError(f"Result contains values greater than {len(dataset)}") - - -def validate_name_to_index(report, dataset, name, result): - """Validate the name to index mapping of the dataset.""" - - if not isinstance(result, dict): - raise ValueError(f"Result is not a dict {type(result)}") - - for key in dataset.variables: - if key not in result: - raise ValueError(f"Result does not contain `{key}`") - - if not isinstance(result[key], int): - raise ValueError(f"Result[{key}] is not an int {type(result[key])}") - - if result[key] < 0 or result[key] >= len(dataset.variables): - raise ValueError(f"Result[{key}] is out of bounds: {result[key]}") - - index_to_name = {v: k for k, v in result.items()} - for i in range(len(dataset.variables)): - if i not in index_to_name: - raise ValueError(f"Result does not contain index `{i}`") - - if not isinstance(index_to_name[i], str): - raise ValueError(f"Result[{i}] is not a string {type(index_to_name[i])}") - - if index_to_name[i] != dataset.variables[i]: - raise ValueError( - f"Result[{i}] does not match dataset.variables[{i}]: {index_to_name[i]} != {dataset.variables[i]}" - ) - - -def validate___getitem__(report, dataset, name, result): - """Validate the __getitem__ method of the dataset.""" - - if not isinstance(result, np.ndarray): - raise ValueError(f"Result is not a np.ndarray {type(result)}") - - if result.shape != dataset.shape[1:]: - raise ValueError(f"Result has wrong shape: {result.shape} != {dataset.shape[1:]}") - - -def validate___len__(report, dataset, name, result): - """Validate the __len__ method of the dataset.""" - - if not isinstance(result, int): - raise ValueError(f"Result is not an int {type(result)}") - - if result != dataset.shape[0]: - raise ValueError(f"Result has wrong length: {result} != {len(dataset)}") - - if result != len(dataset.dates): - raise ValueError(f"Result has wrong length: {result} != {len(dataset.dates)}") - - -def validate_start_date(report, dataset, name, result): - """Validate the start date of the dataset.""" - - if not isinstance(result, np.datetime64): - raise ValueError(f"Result is not a datetime64 {type(result)}") - - if result != dataset.dates[0]: - raise ValueError(f"Result has wrong start date: {result} != {dataset.dates[0]}") - - -def validate_end_date(report, dataset, name, result): - """Validate the end date of the dataset.""" - - if not isinstance(result, np.datetime64): - raise ValueError(f"Result is not a datetime64 {type(result)}") - - if result != dataset.dates[-1]: - raise ValueError(f"Result has wrong end 
date: {result} != {dataset.dates[-1]}") - - -def validate_field_shape(report, dataset, name, result): - """Validate the field shape of the dataset.""" - - if not isinstance(result, tuple): - raise ValueError(f"Result is not a tuple {type(result)}") - - if math.prod(result) != dataset.shape[-1]: - raise ValueError(f"Result has wrong shape: {result} != {dataset.shape[-1]}") - - -def validate(report, dataset, name, kwargs=None): - - try: - - validate_fn = globals().get(f"validate_{name}", _no_validate) - - # Check if the method is still in the Dataset class - try: - report.method(name, getattr(Dataset, name)) - except AttributeError: - report.internal(name, "Attribute not found in Dataset class. Please update the list of methods.") - return - - # Check if the method is supported by the dataset instance - try: - result = getattr(dataset, name) - except AttributeError as e: - report.failure(name, e) - return - - # Check if the method is callable - if callable(result): - if kwargs is None: - report.internal( - name, f"`{name}` is a callable method, not an attribute. Please update KWARGS accordingly." - ) - return - else: - if kwargs is not None: - report.internal(name, f"`{name}` is not callable. Please remove entry from KWARGS.") - return - - if kwargs is not None: - result = result(**kwargs) - - if isinstance(result, np.ndarray) and np.isnan(result).any(): - report.invalid(name, ValueError("Result contains NaN values")) - return - - try: - validate_fn(report, dataset, name, result) - except Exception as e: - report.invalid(name, e) - return - - report.success(name) - - except Exception as e: - report.failure(name, e) - - -def validate_dtype(report, dataset, name, result): - """Validate the dtype of the dataset.""" - - if not isinstance(result, np.dtype): - raise ValueError(f"Result is not a np.dtype {type(result)}") - - -def validate_dataset(dataset, costly_checks=False, detailed=False): - """Validate the dataset.""" - - report = Report() - - if costly_checks: - # This check is expensive as it loads the entire dataset into memory - # so we make it optional - default_test_indexing(dataset) - - for i, x in enumerate(dataset): - y = dataset[i] - assert (x == y).all(), f"Dataset indexing failed at index {i}: {x} != {y}" - - for name in METHODS: - validate(report, dataset, name, kwargs=KWARGS.get(name)) - - report.summary(detailed=detailed) - - -if __name__ == "__main__": - methods = METHODS_CATEGORIES.copy() - methods.pop("OTHER_METHODS") - - o = set(OTHER_METHODS) - overlap = False - for m in methods: - if set(methods[m]).intersection(set(OTHER_METHODS)): - print( - f"WARNING: {m} contains methods from OTHER_METHODS: {set(methods[m]).intersection(set(OTHER_METHODS))}" - ) - o = o - set(methods[m]) - overlap = True - - for m in methods: - for n in methods: - if n is not m: - if set(methods[m]).intersection(set(methods[n])): - print( - f"WARNING: {m} and {n} have methods in common: {set(methods[m]).intersection(set(methods[n]))}" - ) - - if overlap: - print(sorted(o)) From fa6f799263519dedef218719fda7f90644e5779e Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 15:35:16 +0000 Subject: [PATCH 153/212] rename files --- src/anemoi/datasets/build/gridded/validate.py | 598 ++++++++++++++++++ 1 file changed, 598 insertions(+) create mode 100644 src/anemoi/datasets/build/gridded/validate.py diff --git a/src/anemoi/datasets/build/gridded/validate.py b/src/anemoi/datasets/build/gridded/validate.py new file mode 100644 index 000000000..a1e168116 --- /dev/null +++ 
b/src/anemoi/datasets/build/gridded/validate.py @@ -0,0 +1,598 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +import math +from collections import defaultdict + +import numpy as np + +from anemoi.datasets.testing import default_test_indexing +from anemoi.datasets.use.dataset import Dataset + +LOG = logging.getLogger(__name__) +# List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1 + +TRAINING_METHODS = [ + "__getitem__", + "__len__", + "latitudes", + "longitudes", + "metadata", # Accessed when checkpointing + "missing", + "name_to_index", + "shape", + "statistics", + "supporting_arrays", # Accessed when checkpointing + "variables", +] + +EXTRA_TRAINING_METHODS = [ + "statistics_tendencies", +] + +DEBUGGING_METHODS = [ + "plot", + "to_index", + "tree", + "source", +] + +PUBLIC_METADATA_METHODS = [ + "arguments", + "dtype", + "end_date", + "resolution", + "start_date", + "field_shape", + "frequency", + "dates", + "typed_variables", + "variables_metadata", +] + +PRIVATE_METADATA_METHODS = [ + "computed_constant_fields", + "constant_fields", + "dataset_metadata", + "label", + "metadata_specific", + "provenance", +] + +INTERNAL_METHODS = [ + "mutate", + "swap_with_parent", + "dates_interval_to_indices", +] + +EXPERIMENTAL_METHODS = [ + "get_dataset_names", + "name", + "grids", +] + +OTHER_METHODS = [ + "collect_input_sources", + "collect_supporting_arrays", + "sub_shape", +] + + +METHODS_CATEGORIES = {k: v for k, v in list(globals().items()) if k.endswith("_METHODS")} + + +METHODS = set(sum(METHODS_CATEGORIES.values(), [])) + + +KWARGS = { + "__len__": {}, + "__getitem__": {"index": 0}, + "get_dataset_names": {"names": set()}, + "metadata": {}, + "metadata_specific": {}, + "mutate": {}, + "plot": {"date": 0, "variable": 0}, + "provenance": {}, + "source": {"index": 0}, + "statistics_tendencies": {}, + "sub_shape": {}, + "supporting_arrays": {}, + "swap_with_parent": {}, + "to_index": {"date": 0, "variable": 0}, + "tree": {}, +} + + +class Unknown: + emoji = "❓" + + +class Success: + emoji = "✅" + success = True + + def __repr__(self): + return "Success" + + +class Error: + success = False + + def __init__(self, message): + self.message = message + + def __repr__(self): + return str(self.message) or repr(self.message) or "Error" + + +class Failure(Error): + emoji = "💥" + + +class Internal(Error): + emoji = "💣" + + +class Invalid(Error): + emoji = "❌" + + +class Report: + + def __init__(self): + self.report = {} + self.methods = {} + self.warnings = defaultdict(list) + + def method(self, name, method): + self.methods[name] = method + + def success(self, name): + self.report[name] = Success() + + def failure(self, name, message): + self.report[name] = Failure(message) + + def internal(self, name, message): + self.report[name] = Internal(message) + + def invalid(self, name, exception): + self.report[name] = Invalid(exception) + + def warning(self, name, message): + self.warnings[name].append(message) + + def summary(self, detailed=False): + + maxlen = max(len(name) for name in self.report.keys()) + + for name, methods in METHODS_CATEGORIES.items(): + print() + 
print(f"{name.title().replace('_', ' ')}:") + print("-" * (len(name) + 1)) + print() + + for method in methods: + r = self.report.get(method, Unknown()) + msg = repr(r) + if not msg.endswith("."): + msg += "." + print(f"{r.emoji} {method.ljust(maxlen)}: {msg}") + + for w in self.warnings.get(method, []): + print(" " * (maxlen + 4), "⚠️", w) + + if r.success: + continue + + if not detailed: + continue + + if method not in self.methods: + continue + + proc = self.methods[method] + + doc = proc.__doc__ + if doc: + width = 80 + indent = maxlen + 4 + doc = "\n".join(["=" * width, "", doc, "=" * width]) + indented_doc = "\n".join(" " * indent + line for line in doc.splitlines()) + print() + print(indented_doc) + print() + print() + + print() + + +def _no_validate(report, dataset, name, result): + report.warning(name, f"Validation for {name} not implemented. Result: {type(result)}") + + +def validate_variables(report, dataset, name, result): + """Validate the variables of the dataset.""" + + if not isinstance(result, (list, tuple)): + raise ValueError(f"Result is not a list or tuple {type(result)}") + + if len(result) != dataset.shape[1]: + raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[1]}") + + for value in result: + if not isinstance(value, str): + raise ValueError(f"`{value}` is not a string") + + +def validate_latitudes(report, dataset, name, result): + """Validate the latitudes of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if len(result) != dataset.shape[3]: + raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}") + + if not np.all(np.isfinite(result)): + raise ValueError("Result contains non-finite values") + + if np.isnan(result).any(): + report.invalid(name, ValueError("Result contains NaN values")) + return + + if not np.all((result >= -90) & (result <= 90)): + raise ValueError("Result contains values outside the range [-90, 90]") + + if np.all((result >= -np.pi) & (result <= np.pi)): + report.warning(name, "All latitudes are in the range [-π, π]. Are they in radians?") + + +def validate_longitudes(report, dataset, name, result): + """Validate the longitudes of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if len(result) != dataset.shape[3]: + raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[2]}") + + if not np.all(np.isfinite(result)): + raise ValueError("Result contains non-finite values") + + if np.isnan(result).any(): + report.invalid(name, ValueError("Result contains NaN values")) + return + + if not np.all((result >= -180) & (result <= 360)): + raise ValueError("Result contains values outside the range [-180, 360]") + + if np.all((result >= -np.pi) & (result <= 2 * np.pi)): + report.warning(name, "All longitudes are in the range [-π, 2π]. 
Are they in radians?") + + +def validate_statistics(report, dataset, name, result): + """Validate the statistics of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + for key in ["mean", "stdev", "minimum", "maximum"]: + + if key not in result: + raise ValueError(f"Result does not contain `{key}`") + + if not isinstance(result[key], np.ndarray): + raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}") + + if len(result[key].shape) != 1: + raise ValueError(f"Result[{key}] has wrong shape: {len(result[key].shape)} != 1") + + if result[key].shape[0] != len(dataset.variables): + raise ValueError(f"Result[{key}] has wrong length: {result[key].shape[0]} != {len(dataset.variables)}") + + if not np.all(np.isfinite(result[key])): + raise ValueError(f"Result[{key}] contains non-finite values") + + if np.isnan(result[key]).any(): + report.invalid(name, ValueError(f"Result[{key}] contains NaN values")) + + +def validate_shape(report, dataset, name, result): + """Validate the shape of the dataset.""" + + if not isinstance(result, tuple): + raise ValueError(f"Result is not a tuple {type(result)}") + + if len(result) != 4: + raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.shape)}") + + if result[0] != len(dataset): + raise ValueError(f"Result[0] has wrong length: {result[0]} != {len(dataset)}") + + if result[1] != len(dataset.variables): + raise ValueError(f"Result[1] has wrong length: {result[1]} != {len(dataset.variables)}") + + if result[2] != 1: # We ignore ensemble dimension for now + pass + + if result[3] != len(dataset.latitudes): + raise ValueError(f"Result[3] has wrong length: {result[3]} != {len(dataset.latitudes)}") + + +def validate_supporting_arrays(report, dataset, name, result): + """Validate the supporting arrays of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + if "latitudes" not in result: + raise ValueError("Result does not contain `latitudes`") + + if "longitudes" not in result: + raise ValueError("Result does not contain `longitudes`") + + if not isinstance(result["latitudes"], np.ndarray): + raise ValueError(f"Result[latitudes] is not a np.ndarray {type(result['latitudes'])}") + + if not isinstance(result["longitudes"], np.ndarray): + raise ValueError(f"Result[longitudes] is not a np.ndarray {type(result['longitudes'])}") + + if np.any(result["latitudes"] != dataset.latitudes): + raise ValueError("Result[latitudes] does not match dataset.latitudes") + + if np.any(result["longitudes"] != dataset.longitudes): + raise ValueError("Result[longitudes] does not match dataset.longitudes") + + +def validate_dates(report, dataset, name, result): + """Validate the dates of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if len(result.shape) != 1: + raise ValueError(f"Result has wrong shape: {len(result.shape)} != 1") + + if result.shape[0] != len(dataset.dates): + raise ValueError(f"Result has wrong length: {result.shape[0]} != {len(dataset.dates)}") + + if not np.issubdtype(result.dtype, np.datetime64): + raise ValueError(f"Result is not a datetime64 array {result.dtype}") + + if len(result) != len(dataset.dates): + raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.dates)}") + + if not np.all(np.isfinite(result)): + raise ValueError("Result contains non-finite values") + + if np.isnan(result).any(): + 
report.invalid(name, ValueError("Result contains NaN values")) + return + + for d1, d2 in zip(result[:-1], result[1:]): + if d1 >= d2: + raise ValueError(f"Result contains non-increasing dates: {d1} >= {d2}") + + frequency = np.diff(result) + if not np.all(frequency == frequency[0]): + raise ValueError("Result contains non-constant frequency") + + +def validate_metadata(report, dataset, name, result): + """Validate the metadata of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + +def validate_missing(report, dataset, name, result): + """Validate the missing values of the dataset.""" + + if not isinstance(result, set): + raise ValueError(f"Result is not a set {type(result)}") + + if not all(isinstance(item, int) for item in result): + raise ValueError("Result contains non-integer values") + + if len(result) > 0: + if min(result) < 0: + raise ValueError("Result contains negative values") + + if max(result) >= len(dataset): + raise ValueError(f"Result contains values greater than {len(dataset)}") + + +def validate_name_to_index(report, dataset, name, result): + """Validate the name to index mapping of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + for key in dataset.variables: + if key not in result: + raise ValueError(f"Result does not contain `{key}`") + + if not isinstance(result[key], int): + raise ValueError(f"Result[{key}] is not an int {type(result[key])}") + + if result[key] < 0 or result[key] >= len(dataset.variables): + raise ValueError(f"Result[{key}] is out of bounds: {result[key]}") + + index_to_name = {v: k for k, v in result.items()} + for i in range(len(dataset.variables)): + if i not in index_to_name: + raise ValueError(f"Result does not contain index `{i}`") + + if not isinstance(index_to_name[i], str): + raise ValueError(f"Result[{i}] is not a string {type(index_to_name[i])}") + + if index_to_name[i] != dataset.variables[i]: + raise ValueError( + f"Result[{i}] does not match dataset.variables[{i}]: {index_to_name[i]} != {dataset.variables[i]}" + ) + + +def validate___getitem__(report, dataset, name, result): + """Validate the __getitem__ method of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if result.shape != dataset.shape[1:]: + raise ValueError(f"Result has wrong shape: {result.shape} != {dataset.shape[1:]}") + + +def validate___len__(report, dataset, name, result): + """Validate the __len__ method of the dataset.""" + + if not isinstance(result, int): + raise ValueError(f"Result is not an int {type(result)}") + + if result != dataset.shape[0]: + raise ValueError(f"Result has wrong length: {result} != {len(dataset)}") + + if result != len(dataset.dates): + raise ValueError(f"Result has wrong length: {result} != {len(dataset.dates)}") + + +def validate_start_date(report, dataset, name, result): + """Validate the start date of the dataset.""" + + if not isinstance(result, np.datetime64): + raise ValueError(f"Result is not a datetime64 {type(result)}") + + if result != dataset.dates[0]: + raise ValueError(f"Result has wrong start date: {result} != {dataset.dates[0]}") + + +def validate_end_date(report, dataset, name, result): + """Validate the end date of the dataset.""" + + if not isinstance(result, np.datetime64): + raise ValueError(f"Result is not a datetime64 {type(result)}") + + if result != dataset.dates[-1]: + raise ValueError(f"Result has wrong end 
date: {result} != {dataset.dates[-1]}") + + +def validate_field_shape(report, dataset, name, result): + """Validate the field shape of the dataset.""" + + if not isinstance(result, tuple): + raise ValueError(f"Result is not a tuple {type(result)}") + + if math.prod(result) != dataset.shape[-1]: + raise ValueError(f"Result has wrong shape: {result} != {dataset.shape[-1]}") + + +def validate(report, dataset, name, kwargs=None): + + try: + + validate_fn = globals().get(f"validate_{name}", _no_validate) + + # Check if the method is still in the Dataset class + try: + report.method(name, getattr(Dataset, name)) + except AttributeError: + report.internal(name, "Attribute not found in Dataset class. Please update the list of methods.") + return + + # Check if the method is supported by the dataset instance + try: + result = getattr(dataset, name) + except AttributeError as e: + report.failure(name, e) + return + + # Check if the method is callable + if callable(result): + if kwargs is None: + report.internal( + name, f"`{name}` is a callable method, not an attribute. Please update KWARGS accordingly." + ) + return + else: + if kwargs is not None: + report.internal(name, f"`{name}` is not callable. Please remove entry from KWARGS.") + return + + if kwargs is not None: + result = result(**kwargs) + + if isinstance(result, np.ndarray) and np.isnan(result).any(): + report.invalid(name, ValueError("Result contains NaN values")) + return + + try: + validate_fn(report, dataset, name, result) + except Exception as e: + report.invalid(name, e) + return + + report.success(name) + + except Exception as e: + report.failure(name, e) + + +def validate_dtype(report, dataset, name, result): + """Validate the dtype of the dataset.""" + + if not isinstance(result, np.dtype): + raise ValueError(f"Result is not a np.dtype {type(result)}") + + +def validate_dataset(dataset, costly_checks=False, detailed=False): + """Validate the dataset.""" + + report = Report() + + if costly_checks: + # This check is expensive as it loads the entire dataset into memory + # so we make it optional + default_test_indexing(dataset) + + for i, x in enumerate(dataset): + y = dataset[i] + assert (x == y).all(), f"Dataset indexing failed at index {i}: {x} != {y}" + + for name in METHODS: + validate(report, dataset, name, kwargs=KWARGS.get(name)) + + report.summary(detailed=detailed) + + +if __name__ == "__main__": + methods = METHODS_CATEGORIES.copy() + methods.pop("OTHER_METHODS") + + o = set(OTHER_METHODS) + overlap = False + for m in methods: + if set(methods[m]).intersection(set(OTHER_METHODS)): + print( + f"WARNING: {m} contains methods from OTHER_METHODS: {set(methods[m]).intersection(set(OTHER_METHODS))}" + ) + o = o - set(methods[m]) + overlap = True + + for m in methods: + for n in methods: + if n is not m: + if set(methods[m]).intersection(set(methods[n])): + print( + f"WARNING: {m} and {n} have methods in common: {set(methods[m]).intersection(set(methods[n]))}" + ) + + if overlap: + print(sorted(o)) From 3dbc49c791433bc3d44b5db7306f77885e0b5f32 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 16:35:53 +0000 Subject: [PATCH 154/212] revert to some relative import when meaningful --- src/anemoi/datasets/__init__.py | 10 ++-- src/anemoi/datasets/build/gridded/__init__.py | 42 ++++++++-------- .../build/gridded/sources/accumulations.py | 7 +-- .../build/gridded/sources/accumulations2.py | 5 +- .../build/gridded/sources/anemoi_dataset.py | 2 +- .../build/gridded/sources/constants.py | 2 +- 
.../build/gridded/sources/eccc_fstd.py | 4 +- .../datasets/build/gridded/sources/empty.py | 2 +- .../datasets/build/gridded/sources/fdb.py | 7 +-- .../build/gridded/sources/forcings.py | 2 +- .../datasets/build/gridded/sources/grib.py | 2 +- .../build/gridded/sources/grib_index.py | 2 +- .../build/gridded/sources/hindcasts.py | 4 +- .../datasets/build/gridded/sources/legacy.py | 4 +- .../datasets/build/gridded/sources/mars.py | 3 +- .../datasets/build/gridded/sources/netcdf.py | 4 +- .../datasets/build/gridded/sources/opendap.py | 4 +- .../gridded/sources/planetary_computer.py | 4 +- .../build/gridded/sources/recentre.py | 5 +- .../build/gridded/sources/repeated_dates.py | 4 +- .../datasets/build/gridded/sources/source.py | 3 +- .../build/gridded/sources/tendencies.py | 3 +- .../datasets/build/gridded/sources/xarray.py | 11 +++-- .../build/gridded/sources/xarray_kerchunk.py | 4 +- .../sources/xarray_support/__init__.py | 7 +-- .../gridded/sources/xarray_support/field.py | 6 +-- .../sources/xarray_support/fieldlist.py | 12 ++--- .../gridded/sources/xarray_support/flavour.py | 38 +++++++-------- .../sources/xarray_support/metadata.py | 2 +- .../gridded/sources/xarray_support/time.py | 4 +- .../sources/xarray_support/variable.py | 2 +- .../build/gridded/sources/xarray_zarr.py | 4 +- .../datasets/build/gridded/sources/zenodo.py | 6 +-- .../build/gridded/statistics/__init__.py | 4 +- src/anemoi/datasets/build/gridded/validate.py | 4 +- src/anemoi/datasets/commands/check.py | 2 +- src/anemoi/datasets/commands/grib-index.py | 2 +- src/anemoi/datasets/commands/inspect.py | 4 +- src/anemoi/datasets/misc/__init__.py | 8 ++++ src/anemoi/datasets/use/__init__.py | 8 ++++ src/anemoi/datasets/use/gridded/__init__.py | 10 ++-- src/anemoi/datasets/use/gridded/complement.py | 24 +++++----- src/anemoi/datasets/use/gridded/concat.py | 30 ++++++------ src/anemoi/datasets/use/gridded/dataset.py | 48 +++++++++---------- src/anemoi/datasets/use/gridded/debug.py | 2 +- src/anemoi/datasets/use/gridded/ensemble.py | 22 ++++----- .../datasets/use/gridded/fill_missing.py | 22 ++++----- src/anemoi/datasets/use/gridded/forwards.py | 20 ++++---- src/anemoi/datasets/use/gridded/grids.py | 32 ++++++------- src/anemoi/datasets/use/gridded/indexing.py | 6 +-- .../datasets/use/gridded/interpolate.py | 24 +++++----- src/anemoi/datasets/use/gridded/join.py | 30 ++++++------ src/anemoi/datasets/use/gridded/masked.py | 26 +++++----- src/anemoi/datasets/use/gridded/merge.py | 26 +++++----- src/anemoi/datasets/use/gridded/misc.py | 38 +++++++-------- src/anemoi/datasets/use/gridded/missing.py | 20 ++++---- src/anemoi/datasets/use/gridded/padded.py | 20 ++++---- src/anemoi/datasets/use/gridded/rescale.py | 20 ++++---- src/anemoi/datasets/use/gridded/select.py | 24 +++++----- src/anemoi/datasets/use/gridded/statistics.py | 8 ++-- src/anemoi/datasets/use/gridded/stores.py | 22 ++++----- src/anemoi/datasets/use/gridded/subset.py | 30 ++++++------ src/anemoi/datasets/use/gridded/unchecked.py | 16 +++---- src/anemoi/datasets/use/gridded/xy.py | 12 ++--- src/anemoi/datasets/use/tabular/__init__.py | 8 ++++ .../use/tabular/observations/__init__.py | 8 ++-- .../use/tabular/observations/multi.py | 2 +- .../datasets/use/tabular/records/__init__.py | 6 +-- tests/create/utils/compare.py | 2 +- tests/test_data.py | 30 ++++++------ tests/test_data_gridded.py | 2 +- tests/test_indexing.py | 2 +- tests/test_records.py | 6 +-- tests/test_validate.py | 2 +- tests/xarray/test_flavour.py | 24 +++++----- tests/xarray/test_netcdf.py | 2 +- 
tests/xarray/test_opendap.py | 2 +- tests/xarray/test_variable.py | 16 +++---- tests/xarray/test_zarr.py | 2 +- tools/build-obs.py | 2 +- 80 files changed, 467 insertions(+), 434 deletions(-) create mode 100644 src/anemoi/datasets/use/tabular/__init__.py diff --git a/src/anemoi/datasets/__init__.py b/src/anemoi/datasets/__init__.py index c38a5de68..84264bc23 100644 --- a/src/anemoi/datasets/__init__.py +++ b/src/anemoi/datasets/__init__.py @@ -8,11 +8,11 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.use import MissingDateError -from anemoi.datasets.use import add_dataset_path -from anemoi.datasets.use import add_named_dataset -from anemoi.datasets.use import list_dataset_names -from anemoi.datasets.use import open_dataset +from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets.use.gridded import add_dataset_path +from anemoi.datasets.use.gridded import add_named_dataset +from anemoi.datasets.use.gridded import list_dataset_names +from anemoi.datasets.use.gridded import open_dataset try: # NOTE: the `_version.py` file must not be present in the git repository diff --git a/src/anemoi/datasets/build/gridded/__init__.py b/src/anemoi/datasets/build/gridded/__init__.py index f28955dd8..0773a5e60 100644 --- a/src/anemoi/datasets/build/gridded/__init__.py +++ b/src/anemoi/datasets/build/gridded/__init__.py @@ -31,25 +31,25 @@ from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset -from anemoi.datasets.build.check import DatasetName -from anemoi.datasets.build.check import check_data_values -from anemoi.datasets.build.chunks import ChunkFilter -from anemoi.datasets.build.config import build_output -from anemoi.datasets.build.config import loader_config +from anemoi.datasets.build.gridded.check import DatasetName +from anemoi.datasets.build.gridded.check import check_data_values +from anemoi.datasets.build.gridded.chunks import ChunkFilter +from anemoi.datasets.build.gridded.config import build_output +from anemoi.datasets.build.gridded.config import loader_config +from anemoi.datasets.build.gridded.persistent import build_storage +from anemoi.datasets.build.gridded.statistics import Summary +from anemoi.datasets.build.gridded.statistics import TmpStatistics +from anemoi.datasets.build.gridded.statistics import check_variance +from anemoi.datasets.build.gridded.statistics import compute_statistics +from anemoi.datasets.build.gridded.statistics import default_statistics_dates +from anemoi.datasets.build.gridded.statistics import fix_variance +from anemoi.datasets.build.gridded.utils import normalize_and_check_dates +from anemoi.datasets.build.gridded.writer import ViewCacheArray from anemoi.datasets.build.input import InputBuilder from anemoi.datasets.build.input.trace import enable_trace -from anemoi.datasets.build.persistent import build_storage -from anemoi.datasets.build.statistics import Summary -from anemoi.datasets.build.statistics import TmpStatistics -from anemoi.datasets.build.statistics import check_variance -from anemoi.datasets.build.statistics import compute_statistics -from anemoi.datasets.build.statistics import default_statistics_dates -from anemoi.datasets.build.statistics import fix_variance -from anemoi.datasets.build.utils import normalize_and_check_dates -from anemoi.datasets.build.writer import ViewCacheArray from anemoi.datasets.dates.groups import Groups -from anemoi.datasets.use.misc import as_first_date -from anemoi.datasets.use.misc import as_last_date +from anemoi.datasets.misc import 
as_first_date +from anemoi.datasets.misc import as_last_date LOG = logging.getLogger(__name__) @@ -192,7 +192,7 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: import zarr z = zarr.open(self.path, mode=mode) - from anemoi.datasets.build.zarr import add_zarr_dataset + from anemoi.datasets.build.gridded.zarr import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -396,7 +396,7 @@ def _cache_context(self) -> Any: Any The cache context. """ - from anemoi.datasets.build.utils import cache_context + from anemoi.datasets.build.gridded.utils import cache_context return cache_context(self.cache) @@ -472,7 +472,7 @@ def __init__(self, path: str, options: dict = None, **kwargs: Any): def run(self) -> None: """Run the patch.""" - from anemoi.datasets.build.patch import apply_patch + from anemoi.datasets.build.gridded.patch import apply_patch apply_patch(self.path, **self.options) @@ -492,7 +492,7 @@ def __init__(self, path: str, **kwargs: Any): def run(self) -> None: """Run the size computation.""" - from anemoi.datasets.build.size import compute_directory_sizes + from anemoi.datasets.build.gridded.size import compute_directory_sizes metadata = compute_directory_sizes(self.path) self.update_metadata(**metadata) @@ -514,7 +514,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from anemoi.datasets.build.zarr import ZarrBuiltRegistry + from anemoi.datasets.build.gridded.zarr import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations.py b/src/anemoi/datasets/build/gridded/sources/accumulations.py index 2d45b164a..c01c0fe54 100644 --- a/src/anemoi/datasets/build/gridded/sources/accumulations.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations.py @@ -20,9 +20,10 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.mars import mars -from anemoi.datasets.build.utils import to_datetime_list +from anemoi.datasets.create.utils import to_datetime_list + +from .legacy import legacy_source +from .mars import mars LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations2.py b/src/anemoi/datasets/build/gridded/sources/accumulations2.py index eb560b4b2..1a15badfa 100644 --- a/src/anemoi/datasets/build/gridded/sources/accumulations2.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations2.py @@ -18,10 +18,11 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.mars import mars from anemoi.datasets.build.utils import to_datetime_list +from .legacy import legacy_source +from .mars import mars + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py b/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py index e890f8130..12d41db23 100644 --- a/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py +++ b/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py @@ -9,7 +9,7 @@ import numpy as np -from anemoi.datasets.build.sources.legacy import legacy_source +from .legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/build/gridded/sources/constants.py 
b/src/anemoi/datasets/build/gridded/sources/constants.py index b0c15ce94..104f24863 100644 --- a/src/anemoi/datasets/build/gridded/sources/constants.py +++ b/src/anemoi/datasets/build/gridded/sources/constants.py @@ -11,7 +11,7 @@ from earthkit.data import from_source -from anemoi.datasets.build.sources.legacy import legacy_source +from .legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/build/gridded/sources/eccc_fstd.py b/src/anemoi/datasets/build/gridded/sources/eccc_fstd.py index 59be1ea81..41734e9b6 100644 --- a/src/anemoi/datasets/build/gridded/sources/eccc_fstd.py +++ b/src/anemoi/datasets/build/gridded/sources/eccc_fstd.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.build.sources import source_registry -from anemoi.datasets.build.sources.xarray import XarraySourceBase +from . import source_registry +from .xarray import XarraySourceBase @source_registry.register("eccc_fstd") diff --git a/src/anemoi/datasets/build/gridded/sources/empty.py b/src/anemoi/datasets/build/gridded/sources/empty.py index fbcfdecf1..fb7fcd906 100644 --- a/src/anemoi/datasets/build/gridded/sources/empty.py +++ b/src/anemoi/datasets/build/gridded/sources/empty.py @@ -12,7 +12,7 @@ import earthkit.data as ekd -from anemoi.datasets.build.sources.legacy import legacy_source +from .legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/build/gridded/sources/fdb.py b/src/anemoi/datasets/build/gridded/sources/fdb.py index bdadb9d83..fb918b353 100644 --- a/src/anemoi/datasets/build/gridded/sources/fdb.py +++ b/src/anemoi/datasets/build/gridded/sources/fdb.py @@ -16,9 +16,10 @@ from anemoi.transform.flavour import RuleBasedFlavour from anemoi.transform.grids import grid_registry -from anemoi.datasets.build.source import Source -from anemoi.datasets.build.sources import source_registry -from anemoi.datasets.build.typing import DateList +from anemoi.datasets.build.gridded.typing import DateList + +from ..source import Source +from . 
import source_registry @source_registry.register("fdb") diff --git a/src/anemoi/datasets/build/gridded/sources/forcings.py b/src/anemoi/datasets/build/gridded/sources/forcings.py index ae3545b3f..bbafaa465 100644 --- a/src/anemoi/datasets/build/gridded/sources/forcings.py +++ b/src/anemoi/datasets/build/gridded/sources/forcings.py @@ -11,7 +11,7 @@ from earthkit.data import from_source -from anemoi.datasets.build.sources.legacy import legacy_source +from .legacy import legacy_source @legacy_source(__file__) diff --git a/src/anemoi/datasets/build/gridded/sources/grib.py b/src/anemoi/datasets/build/gridded/sources/grib.py index 2d5932347..03bcda475 100644 --- a/src/anemoi/datasets/build/gridded/sources/grib.py +++ b/src/anemoi/datasets/build/gridded/sources/grib.py @@ -20,7 +20,7 @@ from earthkit.data import from_source from earthkit.data.utils.patterns import Pattern -from anemoi.datasets.build.sources.legacy import legacy_source +from .legacy import legacy_source LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/grib_index.py b/src/anemoi/datasets/build/gridded/sources/grib_index.py index 9c52c462f..ea6878929 100644 --- a/src/anemoi/datasets/build/gridded/sources/grib_index.py +++ b/src/anemoi/datasets/build/gridded/sources/grib_index.py @@ -19,7 +19,7 @@ from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray -from anemoi.datasets.build.sources.legacy import legacy_source +from .legacy import legacy_source LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/hindcasts.py b/src/anemoi/datasets/build/gridded/sources/hindcasts.py index b633b320c..3a7f5eac8 100644 --- a/src/anemoi/datasets/build/gridded/sources/hindcasts.py +++ b/src/anemoi/datasets/build/gridded/sources/hindcasts.py @@ -12,8 +12,8 @@ from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.mars import mars +from .legacy import legacy_source +from .mars import mars LOGGER = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/legacy.py b/src/anemoi/datasets/build/gridded/sources/legacy.py index 058443293..4dbd481cd 100644 --- a/src/anemoi/datasets/build/gridded/sources/legacy.py +++ b/src/anemoi/datasets/build/gridded/sources/legacy.py @@ -14,8 +14,8 @@ from collections.abc import Callable from typing import Any -from anemoi.datasets.build.source import Source -from anemoi.datasets.build.sources import source_registry +from ..source import Source +from . 
import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/mars.py b/src/anemoi/datasets/build/gridded/sources/mars.py index 5ba70950e..db075321e 100644 --- a/src/anemoi/datasets/build/gridded/sources/mars.py +++ b/src/anemoi/datasets/build/gridded/sources/mars.py @@ -16,9 +16,10 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability -from anemoi.datasets.build.sources.legacy import legacy_source from anemoi.datasets.build.utils import to_datetime_list +from .legacy import legacy_source + DEBUG = False diff --git a/src/anemoi/datasets/build/gridded/sources/netcdf.py b/src/anemoi/datasets/build/gridded/sources/netcdf.py index 175b97a65..a73c095d3 100644 --- a/src/anemoi/datasets/build/gridded/sources/netcdf.py +++ b/src/anemoi/datasets/build/gridded/sources/netcdf.py @@ -12,8 +12,8 @@ import earthkit.data as ekd -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.xarray import load_many +from .legacy import legacy_source +from .xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/build/gridded/sources/opendap.py b/src/anemoi/datasets/build/gridded/sources/opendap.py index 09c4a0986..483295a8b 100644 --- a/src/anemoi/datasets/build/gridded/sources/opendap.py +++ b/src/anemoi/datasets/build/gridded/sources/opendap.py @@ -12,8 +12,8 @@ import earthkit.data as ekd -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.xarray import load_many +from .legacy import legacy_source +from .xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/build/gridded/sources/planetary_computer.py b/src/anemoi/datasets/build/gridded/sources/planetary_computer.py index 538857a32..b710bcbbe 100644 --- a/src/anemoi/datasets/build/gridded/sources/planetary_computer.py +++ b/src/anemoi/datasets/build/gridded/sources/planetary_computer.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.build.sources import source_registry -from anemoi.datasets.build.sources.xarray import XarraySourceBase +from . import source_registry +from .xarray import XarraySourceBase @source_registry.register("planetary_computer") diff --git a/src/anemoi/datasets/build/gridded/sources/recentre.py b/src/anemoi/datasets/build/gridded/sources/recentre.py index c989dadb6..53ace8152 100644 --- a/src/anemoi/datasets/build/gridded/sources/recentre.py +++ b/src/anemoi/datasets/build/gridded/sources/recentre.py @@ -10,10 +10,11 @@ from copy import deepcopy from typing import Any -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.mars import mars from anemoi.datasets.compute.recentre import recentre as _recentre +from .legacy import legacy_source +from .mars import mars + def to_list(x: list | tuple | str) -> list: """Converts the input to a list. If the input is a string, it splits it by '/'. 
diff --git a/src/anemoi/datasets/build/gridded/sources/repeated_dates.py b/src/anemoi/datasets/build/gridded/sources/repeated_dates.py index cdc4b5926..d337cead8 100644 --- a/src/anemoi/datasets/build/gridded/sources/repeated_dates.py +++ b/src/anemoi/datasets/build/gridded/sources/repeated_dates.py @@ -19,8 +19,8 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.build.source import Source -from anemoi.datasets.build.sources import source_registry +from .source import Source +from .sources import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/source.py b/src/anemoi/datasets/build/gridded/sources/source.py index 5d724f4fd..3338daf02 100644 --- a/src/anemoi/datasets/build/gridded/sources/source.py +++ b/src/anemoi/datasets/build/gridded/sources/source.py @@ -12,9 +12,10 @@ from earthkit.data import from_source -from anemoi.datasets.build.sources.legacy import legacy_source from anemoi.datasets.build.utils import to_datetime_list +from .legacy import legacy_source + @legacy_source(__file__) def source(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any: diff --git a/src/anemoi/datasets/build/gridded/sources/tendencies.py b/src/anemoi/datasets/build/gridded/sources/tendencies.py index 0f716f803..2f357b008 100644 --- a/src/anemoi/datasets/build/gridded/sources/tendencies.py +++ b/src/anemoi/datasets/build/gridded/sources/tendencies.py @@ -14,9 +14,10 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.build.sources.legacy import legacy_source from anemoi.datasets.build.utils import to_datetime_list +from .legacy import legacy_source + def _date_to_datetime(d: Any) -> Any: """Converts a date string or a list/tuple of date strings to datetime objects. diff --git a/src/anemoi/datasets/build/gridded/sources/xarray.py b/src/anemoi/datasets/build/gridded/sources/xarray.py index 077bcd63a..fb10dab8e 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray.py @@ -11,11 +11,12 @@ import earthkit.data as ekd -from anemoi.datasets.build.source import Source -from anemoi.datasets.build.sources.xarray_support import XarrayFieldList -from anemoi.datasets.build.sources.xarray_support import load_many -from anemoi.datasets.build.sources.xarray_support import load_one -from anemoi.datasets.build.typing import DateList +from anemoi.datasets.build.gridded.typing import DateList + +from ..source import Source +from .xarray_support import XarrayFieldList +from .xarray_support import load_many +from .xarray_support import load_one __all__ = ["load_many", "load_one", "XarrayFieldList"] diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py b/src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py index caeb5e01a..056d756ca 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.build.sources import source_registry -from anemoi.datasets.build.sources.xarray import XarraySourceBase +from . 
import source_registry +from .xarray import XarraySourceBase @source_registry.register("xarray_kerchunk") diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py index c40bd5fcd..33a057520 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py @@ -15,9 +15,10 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.patterns import iterate_patterns -from anemoi.datasets.build.sources.xarray_support.fieldlist import XarrayFieldList +from anemoi.datasets.create.sources.patterns import iterate_patterns + +from ..legacy import legacy_source +from .fieldlist import XarrayFieldList LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/field.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/field.py index 7de7e6046..78f7de041 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/field.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/field.py @@ -17,9 +17,9 @@ from earthkit.data.core.fieldlist import math from numpy.typing import NDArray -from anemoi.datasets.build.sources.xarray_support.coordinates import extract_single_value -from anemoi.datasets.build.sources.xarray_support.coordinates import is_scalar -from anemoi.datasets.build.sources.xarray_support.metadata import XArrayMetadata +from .coordinates import extract_single_value +from .coordinates import is_scalar +from .metadata import XArrayMetadata LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py index 1798a1d4d..48f9cf0e1 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py @@ -16,12 +16,12 @@ import yaml from earthkit.data import FieldList -from anemoi.datasets.build.sources.xarray_support.field import EmptyFieldList -from anemoi.datasets.build.sources.xarray_support.flavour import CoordinateGuesser -from anemoi.datasets.build.sources.xarray_support.patch import patch_dataset -from anemoi.datasets.build.sources.xarray_support.time import Time -from anemoi.datasets.build.sources.xarray_support.variable import FilteredVariable -from anemoi.datasets.build.sources.xarray_support.variable import Variable +from .field import EmptyFieldList +from .flavour import CoordinateGuesser +from .patch import patch_dataset +from .time import Time +from .variable import FilteredVariable +from .variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py index 94d1424ef..80f0b6a62 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py @@ -17,25 +17,25 @@ import xarray as xr from anemoi.utils.config import DotDict -from anemoi.datasets.build.sources.xarray_support.coordinates import Coordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import EnsembleCoordinate 
-from anemoi.datasets.build.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import PointCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import ScalarCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import UnsupportedCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import XCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import YCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import is_scalar -from anemoi.datasets.build.sources.xarray_support.grid import Grid -from anemoi.datasets.build.sources.xarray_support.grid import MeshedGrid -from anemoi.datasets.build.sources.xarray_support.grid import MeshProjectionGrid -from anemoi.datasets.build.sources.xarray_support.grid import UnstructuredGrid -from anemoi.datasets.build.sources.xarray_support.grid import UnstructuredProjectionGrid +from .coordinates import Coordinate +from .coordinates import DateCoordinate +from .coordinates import EnsembleCoordinate +from .coordinates import LatitudeCoordinate +from .coordinates import LevelCoordinate +from .coordinates import LongitudeCoordinate +from .coordinates import PointCoordinate +from .coordinates import ScalarCoordinate +from .coordinates import StepCoordinate +from .coordinates import TimeCoordinate +from .coordinates import UnsupportedCoordinate +from .coordinates import XCoordinate +from .coordinates import YCoordinate +from .coordinates import is_scalar +from .grid import Grid +from .grid import MeshedGrid +from .grid import MeshProjectionGrid +from .grid import UnstructuredGrid +from .grid import UnstructuredProjectionGrid LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py index 104d1fb62..23713ae74 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py @@ -46,7 +46,7 @@ def __init__(self, field: Any) -> None: field : Any The field to extract metadata from. 
""" - from anemoi.datasets.build.sources.xarray_support.field import XArrayField + from .field import XArrayField assert isinstance(field, XArrayField), type(field) self._field = field diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/time.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/time.py index 1a875473f..847b21598 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/time.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/time.py @@ -16,8 +16,8 @@ from anemoi.utils.dates import as_datetime -from anemoi.datasets.build.sources.xarray_support.coordinates import Coordinate -from anemoi.datasets.build.sources.xarray_support.variable import Variable +from .coordinates import Coordinate +from .variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py index 541e60d32..5d2c1c5b1 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py @@ -17,7 +17,7 @@ import numpy as np import xarray as xr -from anemoi.datasets.build.sources.xarray_support.field import XArrayField +from .field import XArrayField LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py b/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py index 5e9da7f44..e91de781e 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py @@ -11,8 +11,8 @@ import earthkit.data as ekd -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.xarray import load_many +from .legacy import legacy_source +from .xarray import load_many @legacy_source(__file__) diff --git a/src/anemoi/datasets/build/gridded/sources/zenodo.py b/src/anemoi/datasets/build/gridded/sources/zenodo.py index 774afd277..1b746bb42 100644 --- a/src/anemoi/datasets/build/gridded/sources/zenodo.py +++ b/src/anemoi/datasets/build/gridded/sources/zenodo.py @@ -14,9 +14,9 @@ from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.sources.url import download_and_cache -from anemoi.datasets.build.sources.legacy import legacy_source -from anemoi.datasets.build.sources.patterns import iterate_patterns -from anemoi.datasets.build.sources.xarray import load_one +from .legacy import legacy_source +from .patterns import iterate_patterns +from .xarray import load_one @legacy_source(__file__) diff --git a/src/anemoi/datasets/build/gridded/statistics/__init__.py b/src/anemoi/datasets/build/gridded/statistics/__init__.py index f7ece19bb..e9835bfe2 100644 --- a/src/anemoi/datasets/build/gridded/statistics/__init__.py +++ b/src/anemoi/datasets/build/gridded/statistics/__init__.py @@ -23,8 +23,8 @@ from anemoi.utils.provenance import gather_provenance_info from numpy.typing import NDArray -from anemoi.datasets.build.check import check_data_values -from anemoi.datasets.build.statistics.summary import Summary +from anemoi.datasets.build.gridded.check import check_data_values +from anemoi.datasets.build.gridded.statistics.summary import Summary LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/validate.py b/src/anemoi/datasets/build/gridded/validate.py index a1e168116..9c103f470 100644 --- a/src/anemoi/datasets/build/gridded/validate.py +++ 
b/src/anemoi/datasets/build/gridded/validate.py @@ -14,8 +14,8 @@ import numpy as np -from anemoi.datasets.testing import default_test_indexing -from anemoi.datasets.use.dataset import Dataset +from anemoi.datasets.misc.testing import default_test_indexing +from anemoi.datasets.use.gridded.dataset import Dataset LOG = logging.getLogger(__name__) # List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1 diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index 4ac355515..f74165dd3 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -13,7 +13,7 @@ import yaml -from anemoi.datasets.build.check import DatasetName +from anemoi.datasets.build.gridded.check import DatasetName from anemoi.datasets.commands import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/grib-index.py b/src/anemoi/datasets/commands/grib-index.py index 072099bdd..c77499b72 100644 --- a/src/anemoi/datasets/commands/grib-index.py +++ b/src/anemoi/datasets/commands/grib-index.py @@ -83,7 +83,7 @@ def match(path: str) -> bool: """ return fnmatch.fnmatch(os.path.basename(path), args.match) - from anemoi.datasets.build.sources.grib_index import GribIndex + from anemoi.datasets.build.gridded.sources.grib_index import GribIndex index = GribIndex( args.index, diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 59490bd33..930313de8 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -28,8 +28,8 @@ from anemoi.datasets import open_dataset from anemoi.datasets.commands import Command -from anemoi.datasets.use.stores import open_zarr -from anemoi.datasets.use.stores import zarr_lookup +from anemoi.datasets.use.gridded.stores import open_zarr +from anemoi.datasets.use.gridded.stores import zarr_lookup LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/misc/__init__.py b/src/anemoi/datasets/misc/__init__.py index e69de29bb..9fc775e54 100644 --- a/src/anemoi/datasets/misc/__init__.py +++ b/src/anemoi/datasets/misc/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/src/anemoi/datasets/use/__init__.py b/src/anemoi/datasets/use/__init__.py index e69de29bb..9fc775e54 100644 --- a/src/anemoi/datasets/use/__init__.py +++ b/src/anemoi/datasets/use/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
diff --git a/src/anemoi/datasets/use/gridded/__init__.py b/src/anemoi/datasets/use/gridded/__init__.py index f6f8f5a3d..6af38b2f4 100644 --- a/src/anemoi/datasets/use/gridded/__init__.py +++ b/src/anemoi/datasets/use/gridded/__init__.py @@ -15,13 +15,13 @@ # from .dataset import FullIndex # from .dataset import Shape # from .dataset import TupleIndex -from anemoi.datasets.use.misc import _open_dataset -from anemoi.datasets.use.misc import _save_dataset -from anemoi.datasets.use.misc import add_dataset_path -from anemoi.datasets.use.misc import add_named_dataset +from anemoi.datasets.use.gridded.misc import _open_dataset +from anemoi.datasets.use.gridded.misc import _save_dataset +from anemoi.datasets.use.gridded.misc import add_dataset_path +from anemoi.datasets.use.gridded.misc import add_named_dataset if TYPE_CHECKING: - from anemoi.datasets.use.dataset import Dataset + from anemoi.datasets.use.gridded.dataset import Dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/complement.py b/src/anemoi/datasets/use/gridded/complement.py index df9b5cc86..1881a74fa 100644 --- a/src/anemoi/datasets/use/gridded/complement.py +++ b/src/anemoi/datasets/use/gridded/complement.py @@ -16,18 +16,18 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.grids import nearest_grid_points -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.forwards import Combined -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple -from anemoi.datasets.use.misc import _auto_adjust -from anemoi.datasets.use.misc import _open_dataset +from anemoi.datasets.misc.grids import nearest_grid_points +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open_dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/concat.py b/src/anemoi/datasets/use/gridded/concat.py index 9b9968468..2f3811995 100644 --- a/src/anemoi/datasets/use/gridded/concat.py +++ b/src/anemoi/datasets/use/gridded/concat.py @@ -16,20 +16,20 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Combined -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import 
index_to_slices -from anemoi.datasets.use.indexing import length_to_slices -from anemoi.datasets.use.indexing import update_tuple -from anemoi.datasets.use.misc import _auto_adjust -from anemoi.datasets.use.misc import _open +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import length_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) @@ -229,7 +229,7 @@ def check_dataset_compatibility(cls, datasets: list[Any], fill_missing_gaps: boo s = ranges[i + 1] if r[1] + frequency != s[0]: if fill_missing_gaps: - from anemoi.datasets.use.missing import MissingDataset + from anemoi.datasets.use.gridded.missing import MissingDataset result.append(MissingDataset(datasets[i], r[1] + frequency, s[0] - frequency)) else: diff --git a/src/anemoi/datasets/use/gridded/dataset.py b/src/anemoi/datasets/use/gridded/dataset.py index cbfdfd2b6..3094c1128 100644 --- a/src/anemoi/datasets/use/gridded/dataset.py +++ b/src/anemoi/datasets/use/gridded/dataset.py @@ -34,8 +34,8 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import Source +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import Source if TYPE_CHECKING: import matplotlib @@ -165,7 +165,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": # This one must be first if "fill_missing_dates" in kwargs: - from anemoi.datasets.use.fill_missing import fill_missing_dates_factory + from anemoi.datasets.use.gridded.fill_missing import fill_missing_dates_factory fill_missing_dates = kwargs.pop("fill_missing_dates") ds = fill_missing_dates_factory(self, fill_missing_dates, kwargs) @@ -179,7 +179,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": if padding: if padding != "empty": raise ValueError(f"Only 'empty' padding is supported, got {padding=}") - from anemoi.datasets.use.padded import Padded + from anemoi.datasets.use.gridded.padded import Padded frequency = kwargs.pop("frequency", self.frequency) return ( @@ -188,14 +188,14 @@ def __subset(self, **kwargs: Any) -> "Dataset": .mutate() ) - from anemoi.datasets.use.subset import Subset + from anemoi.datasets.use.gridded.subset import Subset return ( Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs).mutate() ) if "frequency" in kwargs: - from anemoi.datasets.use.subset import Subset + from anemoi.datasets.use.gridded.subset import Subset if "interpolate_frequency" in kwargs: raise ValueError("Cannot use both `frequency` and `interpolate_frequency`") @@ -208,38 +208,38 @@ def __subset(self, **kwargs: Any) -> "Dataset": ) if "select" in kwargs: - from anemoi.datasets.use.select import Select + from anemoi.datasets.use.gridded.select import Select 
select = kwargs.pop("select") return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs).mutate() if "drop" in kwargs: - from anemoi.datasets.use.select import Select + from anemoi.datasets.use.gridded.select import Select drop = kwargs.pop("drop") return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs).mutate() if "reorder" in kwargs: - from anemoi.datasets.use.select import Select + from anemoi.datasets.use.gridded.select import Select reorder = kwargs.pop("reorder") return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate() if "rename" in kwargs: - from anemoi.datasets.use.select import Rename + from anemoi.datasets.use.gridded.select import Rename rename = kwargs.pop("rename") return Rename(self, rename)._subset(**kwargs).mutate() if "rescale" in kwargs: - from anemoi.datasets.use.rescale import Rescale + from anemoi.datasets.use.gridded.rescale import Rescale rescale = kwargs.pop("rescale") return Rescale(self, rescale)._subset(**kwargs).mutate() if "statistics" in kwargs: - from anemoi.datasets.use import open_dataset - from anemoi.datasets.use.statistics import Statistics + from anemoi.datasets.use.gridded import open_dataset + from anemoi.datasets.use.gridded.statistics import Statistics statistics = kwargs.pop("statistics") @@ -247,26 +247,26 @@ def __subset(self, **kwargs: Any) -> "Dataset": # Note: trim_edge should go before thinning if "trim_edge" in kwargs: - from anemoi.datasets.use.masked import TrimEdge + from anemoi.datasets.use.gridded.masked import TrimEdge edge = kwargs.pop("trim_edge") return TrimEdge(self, edge)._subset(**kwargs).mutate() if "thinning" in kwargs: - from anemoi.datasets.use.masked import Thinning + from anemoi.datasets.use.gridded.masked import Thinning thinning = kwargs.pop("thinning") method = kwargs.pop("method", "every-nth") return Thinning(self, thinning, method)._subset(**kwargs).mutate() if "area" in kwargs: - from anemoi.datasets.use.masked import Cropping + from anemoi.datasets.use.gridded.masked import Cropping bbox = kwargs.pop("area") return Cropping(self, bbox)._subset(**kwargs).mutate() if "number" in kwargs or "numbers" in kwargs or "member" in kwargs or "members" in kwargs: - from anemoi.datasets.use.ensemble import Number + from anemoi.datasets.use.gridded.ensemble import Number members = {} for key in ["number", "numbers", "member", "members"]: @@ -276,13 +276,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": return Number(self, **members)._subset(**kwargs).mutate() if "set_missing_dates" in kwargs: - from anemoi.datasets.use.missing import MissingDates + from anemoi.datasets.use.gridded.missing import MissingDates set_missing_dates = kwargs.pop("set_missing_dates") return MissingDates(self, set_missing_dates)._subset(**kwargs).mutate() if "skip_missing_dates" in kwargs: - from anemoi.datasets.use.missing import SkipMissingDates + from anemoi.datasets.use.gridded.missing import SkipMissingDates if "expected_access" not in kwargs: raise ValueError("`expected_access` is required with `skip_missing_dates`") @@ -294,13 +294,13 @@ def __subset(self, **kwargs: Any) -> "Dataset": return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate() if "interpolate_frequency" in kwargs: - from anemoi.datasets.use.interpolate import InterpolateFrequency + from anemoi.datasets.use.gridded.interpolate import InterpolateFrequency interpolate_frequency = kwargs.pop("interpolate_frequency") return InterpolateFrequency(self, 
interpolate_frequency)._subset(**kwargs).mutate() if "interpolate_variables" in kwargs: - from anemoi.datasets.use.interpolate import InterpolateNearest + from anemoi.datasets.use.gridded.interpolate import InterpolateNearest interpolate_variables = kwargs.pop("interpolate_variables") max_distance = kwargs.pop("max_distance", None) @@ -308,7 +308,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": # Keep last if "shuffle" in kwargs: - from anemoi.datasets.use.subset import Subset + from anemoi.datasets.use.gridded.subset import Subset shuffle = kwargs.pop("shuffle") @@ -372,8 +372,8 @@ def _dates_to_indices( list of int The list of indices. """ - from anemoi.datasets.use.misc import as_first_date - from anemoi.datasets.use.misc import as_last_date + from anemoi.datasets.use.gridded.misc import as_first_date + from anemoi.datasets.use.gridded.misc import as_last_date # TODO: optimize diff --git a/src/anemoi/datasets/use/gridded/debug.py b/src/anemoi/datasets/use/gridded/debug.py index 84c6f0b64..25b6649a6 100644 --- a/src/anemoi/datasets/use/gridded/debug.py +++ b/src/anemoi/datasets/use/gridded/debug.py @@ -20,7 +20,7 @@ from numpy.typing import NDArray if TYPE_CHECKING: - from anemoi.datasets.use.dataset import Dataset + from anemoi.datasets.use.gridded.dataset import Dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/ensemble.py b/src/anemoi/datasets/use/gridded/ensemble.py index 1cf4d885b..0d1aa15b2 100644 --- a/src/anemoi/datasets/use/gridded/ensemble.py +++ b/src/anemoi/datasets/use/gridded/ensemble.py @@ -14,17 +14,17 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.forwards import Forwards -from anemoi.datasets.use.forwards import GivenAxis -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple -from anemoi.datasets.use.misc import _auto_adjust -from anemoi.datasets.use.misc import _open +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.forwards import GivenAxis +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/fill_missing.py b/src/anemoi/datasets/use/gridded/fill_missing.py index 649a7e08b..337549cfc 100644 --- a/src/anemoi/datasets/use/gridded/fill_missing.py +++ b/src/anemoi/datasets/use/gridded/fill_missing.py @@ -14,17 +14,17 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use import MissingDateError -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from 
anemoi.datasets.use.forwards import Forwards -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/forwards.py b/src/anemoi/datasets/use/gridded/forwards.py index 058c66e9c..d0b8dedcb 100644 --- a/src/anemoi/datasets/use/gridded/forwards.py +++ b/src/anemoi/datasets/use/gridded/forwards.py @@ -18,16 +18,16 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import length_to_slices -from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import length_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/grids.py b/src/anemoi/datasets/use/gridded/grids.py index 423a57deb..790d93d0d 100644 --- a/src/anemoi/datasets/use/gridded/grids.py +++ b/src/anemoi/datasets/use/gridded/grids.py @@ -16,21 +16,21 @@ from numpy.typing import NDArray from scipy.spatial import cKDTree -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Combined -from anemoi.datasets.use.forwards import GivenAxis -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import length_to_slices -from 
anemoi.datasets.use.indexing import update_tuple -from anemoi.datasets.use.misc import _auto_adjust -from anemoi.datasets.use.misc import _open +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.forwards import GivenAxis +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import length_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) @@ -355,7 +355,7 @@ def _initialize_masks(self) -> None: ValueError If the global mask dimension does not match the global dataset grid points. """ - from anemoi.datasets.grids import cutout_mask + from anemoi.datasets.misc.grids import cutout_mask for i, lam in enumerate(self.lams): assert len(lam.shape) == len( diff --git a/src/anemoi/datasets/use/gridded/indexing.py b/src/anemoi/datasets/use/gridded/indexing.py index f152e907f..b333ae361 100644 --- a/src/anemoi/datasets/use/gridded/indexing.py +++ b/src/anemoi/datasets/use/gridded/indexing.py @@ -15,9 +15,9 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex def _tuple_with_slices(t: TupleIndex, shape: Shape) -> tuple[TupleIndex, tuple[int, ...]]: diff --git a/src/anemoi/datasets/use/gridded/interpolate.py b/src/anemoi/datasets/use/gridded/interpolate.py index 5d8f70bf3..f3c5155f9 100644 --- a/src/anemoi/datasets/use/gridded/interpolate.py +++ b/src/anemoi/datasets/use/gridded/interpolate.py @@ -17,17 +17,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Forwards -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import 
apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -227,7 +227,7 @@ def __init__(self, dataset: Dataset, interpolate_variables: list[str], max_dista max_distance : Optional[float], optional The maximum distance for nearest neighbor search, by default None. """ - from anemoi.datasets.grids import nearest_grid_points + from anemoi.datasets.misc.grids import nearest_grid_points super().__init__(dataset) self.vars = interpolate_variables diff --git a/src/anemoi/datasets/use/gridded/join.py b/src/anemoi/datasets/use/gridded/join.py index b852ab19f..4c146a73d 100644 --- a/src/anemoi/datasets/use/gridded/join.py +++ b/src/anemoi/datasets/use/gridded/join.py @@ -16,20 +16,20 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import Source -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Combined -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple -from anemoi.datasets.use.misc import _auto_adjust -from anemoi.datasets.use.misc import _open +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import Source +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) @@ -173,7 +173,7 @@ def _overlay(self) -> Dataset: if not ok: LOG.warning("Dataset %r completely overridden.", d) - from anemoi.datasets.use.select import Select + from anemoi.datasets.use.gridded.select import Select return Select(self, indices, {"overlay": variables}) diff --git a/src/anemoi/datasets/use/gridded/masked.py b/src/anemoi/datasets/use/gridded/masked.py index f64bb2f59..d12fc54d4 100644 --- a/src/anemoi/datasets/use/gridded/masked.py +++ b/src/anemoi/datasets/use/gridded/masked.py @@ -15,18 +15,18 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.grids import cropping_mask -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Forwards -from 
anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.misc.grids import cropping_mask +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -214,7 +214,7 @@ def __init__(self, forward: Dataset, area: Dataset | tuple[float, float, float, area : Union[Dataset, Tuple[float, float, float, float]] The cropping area. """ - from anemoi.datasets.use import open_dataset + from anemoi.datasets.use.gridded import open_dataset area = area if isinstance(area, (list, tuple)) else open_dataset(area) diff --git a/src/anemoi/datasets/use/gridded/merge.py b/src/anemoi/datasets/use/gridded/merge.py index a2f3a83bd..d6a1943e5 100644 --- a/src/anemoi/datasets/use/gridded/merge.py +++ b/src/anemoi/datasets/use/gridded/merge.py @@ -16,19 +16,19 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use import MissingDateError -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Combined -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple -from anemoi.datasets.use.misc import _auto_adjust -from anemoi.datasets.use.misc import _open +from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/misc.py b/src/anemoi/datasets/use/gridded/misc.py index 4709265be..36afbee3c 100644 --- a/src/anemoi/datasets/use/gridded/misc.py +++ b/src/anemoi/datasets/use/gridded/misc.py @@ -23,7 +23,7 @@ from numpy.typing import NDArray if TYPE_CHECKING: - from anemoi.datasets.use.dataset import Dataset + from 
anemoi.datasets.use.gridded.dataset import Dataset LOG = logging.getLogger(__name__) @@ -323,11 +323,11 @@ def _concat_or_join(datasets: list["Dataset"], kwargs: dict[str, Any]) -> tuple[ ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets] if len(set(ranges)) == 1: - from anemoi.datasets.use.join import Join + from anemoi.datasets.use.gridded.join import Join return Join(datasets)._overlay(), kwargs - from anemoi.datasets.use.concat import Concat + from anemoi.datasets.use.gridded.concat import Concat Concat.check_dataset_compatibility(datasets) @@ -347,9 +347,9 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " Dataset The opened dataset. """ - from anemoi.datasets.use.dataset import Dataset - from anemoi.datasets.use.stores import Zarr - from anemoi.datasets.use.stores import zarr_lookup + from anemoi.datasets.use.gridded.dataset import Dataset + from anemoi.datasets.use.gridded.stores import Zarr + from anemoi.datasets.use.gridded.stores import zarr_lookup if isinstance(a, str) and len(a.split(".")) in [2, 3]: @@ -359,7 +359,7 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " if "backend" not in metadata: raise ValueError(f"Metadata for {a} does not contain 'backend' key") - from anemoi.datasets.use.records import open_records_dataset + from anemoi.datasets.use.gridded.records import open_records_dataset return open_records_dataset(a, backend=metadata["backend"]) @@ -501,7 +501,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": sets.append(_open(a)) if "observations" in kwargs: - from anemoi.datasets.use.observations import observations_factory + from anemoi.datasets.use.gridded.observations import observations_factory assert not sets, sets @@ -509,70 +509,70 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": if "xy" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.use.gridded.xy import xy_factory + from anemoi.datasets.use.gridded.gridded.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "x" in kwargs and "y" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.use.gridded.xy import xy_factory + from anemoi.datasets.use.gridded.gridded.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "zip" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.use.gridded.xy import zip_factory + from anemoi.datasets.use.gridded.gridded.xy import zip_factory assert not sets, sets return zip_factory(args, kwargs).mutate() if "chain" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.use.unchecked import chain_factory + from anemoi.datasets.use.gridded.unchecked import chain_factory assert not sets, sets return chain_factory(args, kwargs).mutate() if "join" in kwargs: - from anemoi.datasets.use.join import join_factory + from anemoi.datasets.use.gridded.join import join_factory assert not sets, sets return join_factory(args, kwargs).mutate() if "concat" in kwargs: - from anemoi.datasets.use.concat import concat_factory + from anemoi.datasets.use.gridded.concat import concat_factory assert not sets, sets return concat_factory(args, kwargs).mutate() if "merge" in kwargs: - from anemoi.datasets.use.merge import merge_factory + from anemoi.datasets.use.gridded.merge import merge_factory assert not sets, sets return merge_factory(args, kwargs).mutate() if "ensemble" in kwargs: - from anemoi.datasets.use.ensemble import 
ensemble_factory + from anemoi.datasets.use.gridded.ensemble import ensemble_factory assert not sets, sets return ensemble_factory(args, kwargs).mutate() if "grids" in kwargs: - from anemoi.datasets.use.grids import grids_factory + from anemoi.datasets.use.gridded.grids import grids_factory assert not sets, sets return grids_factory(args, kwargs).mutate() if "cutout" in kwargs: - from anemoi.datasets.use.grids import cutout_factory + from anemoi.datasets.use.gridded.grids import cutout_factory assert not sets, sets return cutout_factory(args, kwargs).mutate() if "complement" in kwargs: - from anemoi.datasets.use.complement import complement_factory + from anemoi.datasets.use.gridded.complement import complement_factory assert not sets, sets return complement_factory(args, kwargs).mutate() diff --git a/src/anemoi/datasets/use/gridded/missing.py b/src/anemoi/datasets/use/gridded/missing.py index 32a0d5c69..b1e83638d 100644 --- a/src/anemoi/datasets/use/gridded/missing.py +++ b/src/anemoi/datasets/use/gridded/missing.py @@ -16,16 +16,16 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.build.utils import to_datetime -from anemoi.datasets.use import MissingDateError -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Forwards -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.build.gridded.utils import to_datetime +from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/padded.py b/src/anemoi/datasets/use/gridded/padded.py index 53c45071d..1b23fb6fb 100644 --- a/src/anemoi/datasets/use/gridded/padded.py +++ b/src/anemoi/datasets/use/gridded/padded.py @@ -17,16 +17,16 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Forwards -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.misc import as_first_date -from anemoi.datasets.use.misc import as_last_date +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from 
anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.misc import as_first_date +from anemoi.datasets.use.gridded.misc import as_last_date LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/rescale.py b/src/anemoi/datasets/use/gridded/rescale.py index 630199efc..8426bffbe 100644 --- a/src/anemoi/datasets/use/gridded/rescale.py +++ b/src/anemoi/datasets/use/gridded/rescale.py @@ -16,16 +16,16 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Forwards -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/select.py b/src/anemoi/datasets/use/gridded/select.py index 7a57639ce..3cb813bae 100644 --- a/src/anemoi/datasets/use/gridded/select.py +++ b/src/anemoi/datasets/use/gridded/select.py @@ -15,18 +15,18 @@ from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import Source -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Forwards -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import Source +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/statistics.py b/src/anemoi/datasets/use/gridded/statistics.py index 
e6439ecec..236ce1b7a 100644 --- a/src/anemoi/datasets/use/gridded/statistics.py +++ b/src/anemoi/datasets/use/gridded/statistics.py @@ -15,10 +15,10 @@ from numpy.typing import NDArray -from anemoi.datasets.use import open_dataset -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.forwards import Forwards +from anemoi.datasets.use.gridded import open_dataset +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.forwards import Forwards LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/stores.py b/src/anemoi/datasets/use/gridded/stores.py index 4514f06c6..f48ebedf5 100644 --- a/src/anemoi/datasets/use/gridded/stores.py +++ b/src/anemoi/datasets/use/gridded/stores.py @@ -22,17 +22,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use import MissingDateError -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import DEBUG_ZARR_LOADING -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import Source -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.misc import load_config +from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import DEBUG_ZARR_LOADING +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import Source +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.misc import load_config LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/subset.py b/src/anemoi/datasets/use/gridded/subset.py index bf65bb4a2..13b5d71e0 100644 --- a/src/anemoi/datasets/use/gridded/subset.py +++ b/src/anemoi/datasets/use/gridded/subset.py @@ -19,19 +19,19 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.dataset import TupleIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.debug import Source -from anemoi.datasets.use.debug import debug_indexing -from anemoi.datasets.use.forwards import Forwards -from anemoi.datasets.use.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.indexing import expand_list_indexing -from anemoi.datasets.use.indexing import index_to_slices -from anemoi.datasets.use.indexing import make_slice_or_index_from_list_or_tuple -from anemoi.datasets.use.indexing import update_tuple +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import 
Node +from anemoi.datasets.use.gridded.debug import Source +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import make_slice_or_index_from_list_or_tuple +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) @@ -61,7 +61,7 @@ def _start(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the start date. """ - from anemoi.datasets.use.misc import as_first_date + from anemoi.datasets.use.gridded.misc import as_first_date c = as_first_date(a, dates) d = as_first_date(b, dates) @@ -82,7 +82,7 @@ def _end(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the end date. """ - from anemoi.datasets.use.misc import as_last_date + from anemoi.datasets.use.gridded.misc import as_last_date c = as_last_date(a, dates) d = as_last_date(b, dates) diff --git a/src/anemoi/datasets/use/gridded/unchecked.py b/src/anemoi/datasets/use/gridded/unchecked.py index 5f3a9fb4a..96907a651 100644 --- a/src/anemoi/datasets/use/gridded/unchecked.py +++ b/src/anemoi/datasets/use/gridded/unchecked.py @@ -18,14 +18,14 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.concat import ConcatMixin -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.dataset import Shape -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.forwards import Combined -from anemoi.datasets.use.misc import _auto_adjust -from anemoi.datasets.use.misc import _open +from anemoi.datasets.use.gridded.concat import ConcatMixin +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/xy.py b/src/anemoi/datasets/use/gridded/xy.py index 7d65201b3..da51bde61 100644 --- a/src/anemoi/datasets/use/gridded/xy.py +++ b/src/anemoi/datasets/use/gridded/xy.py @@ -12,12 +12,12 @@ from functools import cached_property from typing import Any -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.dataset import FullIndex -from anemoi.datasets.use.debug import Node -from anemoi.datasets.use.forwards import Combined -from anemoi.datasets.use.misc import _auto_adjust -from anemoi.datasets.use.misc import _open +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/tabular/__init__.py b/src/anemoi/datasets/use/tabular/__init__.py new file mode 100644 index 000000000..9fc775e54 --- /dev/null +++ b/src/anemoi/datasets/use/tabular/__init__.py @@ -0,0 +1,8 @@ 
+# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/src/anemoi/datasets/use/tabular/observations/__init__.py b/src/anemoi/datasets/use/tabular/observations/__init__.py index 58e7fa822..7d1c278f9 100644 --- a/src/anemoi/datasets/use/tabular/observations/__init__.py +++ b/src/anemoi/datasets/use/tabular/observations/__init__.py @@ -14,8 +14,8 @@ import numpy as np from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.use.dataset import Dataset -from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.debug import Node LOG = logging.getLogger(__name__) @@ -138,7 +138,7 @@ def __init__(self, dataset, frequency=None, window=None): if isinstance(dataset, zarr.hierarchy.Group): dataset = dataset._store.path - from anemoi.datasets.use.stores import zarr_lookup + from anemoi.datasets.use.gridded.stores import zarr_lookup dataset = zarr_lookup(dataset) self.path = dataset @@ -176,7 +176,7 @@ def __init__(self, dataset, frequency=None, window=None): # last_window_end must be the end of the time window of the last item last_window_end = int(end.strftime("%Y%m%d%H%M%S")) - from anemoi.datasets.use.observations.legacy_obs_dataset import ObsDataset + from anemoi.datasets.use.gridded.observations.legacy_obs_dataset import ObsDataset args = [self.path, first_window_begin, last_window_end] kwargs = dict( diff --git a/src/anemoi/datasets/use/tabular/observations/multi.py b/src/anemoi/datasets/use/tabular/observations/multi.py index a6b6be176..31fc4e1dd 100644 --- a/src/anemoi/datasets/use/tabular/observations/multi.py +++ b/src/anemoi/datasets/use/tabular/observations/multi.py @@ -10,7 +10,7 @@ import logging import os -from anemoi.datasets.use import open_dataset +from anemoi.datasets.use.gridded import open_dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py index efd368606..2e78d22c1 100644 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/__init__.py @@ -16,7 +16,7 @@ import numpy as np from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.use.records.backends import backend_factory +from anemoi.datasets.use.gridded.records.backends import backend_factory LOG = logging.getLogger(__name__) @@ -91,8 +91,8 @@ def _subset(self, **kwargs): if start is not None or end is not None: def _dates_to_indices(start, end): - from anemoi.datasets.use.misc import as_first_date - from anemoi.datasets.use.misc import as_last_date + from anemoi.datasets.use.gridded.misc import as_first_date + from anemoi.datasets.use.gridded.misc import as_last_date start = self.dates[0] if start is None else as_first_date(start, self.dates) end = self.dates[-1] if end is None else as_last_date(end, self.dates) diff --git a/tests/create/utils/compare.py b/tests/create/utils/compare.py index 6da96ae95..8fd118994 100644 --- a/tests/create/utils/compare.py +++ b/tests/create/utils/compare.py @@ -12,7 +12,7 @@ import numpy as np from anemoi.datasets import open_dataset -from anemoi.datasets.use.stores import 
open_zarr +from anemoi.datasets.use.gridded.stores import open_zarr class Comparer: diff --git a/tests/test_data.py b/tests/test_data.py index c19f54b6c..575462685 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -25,20 +25,20 @@ from anemoi.datasets import open_dataset from anemoi.datasets.commands.inspect import InspectZarr from anemoi.datasets.commands.inspect import NoVersion -from anemoi.datasets.testing import default_test_indexing -from anemoi.datasets.use import save_dataset -from anemoi.datasets.use.concat import Concat -from anemoi.datasets.use.ensemble import Ensemble -from anemoi.datasets.use.grids import GridsBase -from anemoi.datasets.use.join import Join -from anemoi.datasets.use.misc import as_first_date -from anemoi.datasets.use.misc import as_last_date -from anemoi.datasets.use.padded import Padded -from anemoi.datasets.use.select import Rename -from anemoi.datasets.use.select import Select -from anemoi.datasets.use.statistics import Statistics -from anemoi.datasets.use.stores import Zarr -from anemoi.datasets.use.subset import Subset +from anemoi.datasets.misc.testing import default_test_indexing +from anemoi.datasets.use.gridded import save_dataset +from anemoi.datasets.use.gridded.concat import Concat +from anemoi.datasets.use.gridded.ensemble import Ensemble +from anemoi.datasets.use.gridded.grids import GridsBase +from anemoi.datasets.use.gridded.join import Join +from anemoi.datasets.use.gridded.misc import as_first_date +from anemoi.datasets.use.gridded.misc import as_last_date +from anemoi.datasets.use.gridded.padded import Padded +from anemoi.datasets.use.gridded.select import Rename +from anemoi.datasets.use.gridded.select import Select +from anemoi.datasets.use.gridded.statistics import Statistics +from anemoi.datasets.use.gridded.stores import Zarr +from anemoi.datasets.use.gridded.subset import Subset VALUES = 10 @@ -60,7 +60,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): with patch("zarr.convenience.open", zarr_from_str): - with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): + with patch("anemoi.datasets.use.gridded.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) return wrapper diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py index d493a50e7..6e655f601 100644 --- a/tests/test_data_gridded.py +++ b/tests/test_data_gridded.py @@ -42,7 +42,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): with patch("zarr.convenience.open", zarr_from_str): - with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): + with patch("anemoi.datasets.use.gridded.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) return wrapper diff --git a/tests/test_indexing.py b/tests/test_indexing.py index cd5c6f25d..494376aa9 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -10,7 +10,7 @@ import numpy as np -from anemoi.datasets.use.indexing import length_to_slices +from anemoi.datasets.use.gridded.indexing import length_to_slices def test_length_to_slices() -> None: diff --git a/tests/test_records.py b/tests/test_records.py index c96041cb7..6fadaf26e 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -11,9 +11,9 @@ import numpy as np import pytest -from anemoi.datasets.use import open_dataset -from anemoi.datasets.use.records import Record -from anemoi.datasets.use.records import Tabular +from anemoi.datasets.use.gridded import open_dataset +from 
anemoi.datasets.use.tabular.records import Record +from anemoi.datasets.use.tabular.records import Tabular def check_numpy(x, y): diff --git a/tests/test_validate.py b/tests/test_validate.py index 21fd250e1..4cd590ac9 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -17,7 +17,7 @@ from anemoi.utils.testing import GetTestData from anemoi.utils.testing import skip_if_offline -from anemoi.datasets.validate import validate_dataset +from anemoi.datasets.misc.validate import validate_dataset @pytest.fixture diff --git a/tests/xarray/test_flavour.py b/tests/xarray/test_flavour.py index cdf093e5f..ab058839e 100644 --- a/tests/xarray/test_flavour.py +++ b/tests/xarray/test_flavour.py @@ -11,18 +11,18 @@ import pytest import xarray as xr -from anemoi.datasets.build.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import EnsembleCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import ScalarCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import UnsupportedCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import XCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import YCoordinate -from anemoi.datasets.build.sources.xarray_support.flavour import DefaultCoordinateGuesser +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.flavour import DefaultCoordinateGuesser def create_ds(var_name, standard_name, long_name, units, coord_length=5): diff --git a/tests/xarray/test_netcdf.py b/tests/xarray/test_netcdf.py index 1619a47ac..7994789f6 100644 --- a/tests/xarray/test_netcdf.py +++ b/tests/xarray/test_netcdf.py @@ -12,7 +12,7 @@ import xarray as xr from multiurl import download -from anemoi.datasets.build.sources.xarray import XarrayFieldList +from anemoi.datasets.build.gridded.sources.xarray import XarrayFieldList URLS = { "https://get.ecmwf.int/repository/test-data/earthkit-data/examples/efas.nc": dict(length=3), diff --git a/tests/xarray/test_opendap.py 
b/tests/xarray/test_opendap.py index b8f4eac9e..1625ef32c 100644 --- a/tests/xarray/test_opendap.py +++ b/tests/xarray/test_opendap.py @@ -12,7 +12,7 @@ import xarray as xr from anemoi.utils.testing import skip_if_offline -from anemoi.datasets.build.sources.xarray import XarrayFieldList +from anemoi.datasets.build.gridded.sources.xarray import XarrayFieldList from anemoi.datasets.testing import assert_field_list diff --git a/tests/xarray/test_variable.py b/tests/xarray/test_variable.py index afb82ecbf..0f060a32e 100644 --- a/tests/xarray/test_variable.py +++ b/tests/xarray/test_variable.py @@ -13,14 +13,14 @@ import pytest import xarray as xr -from anemoi.datasets.build.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.build.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.build.sources.xarray_support.time import ForecastFromValidTimeAndStep -from anemoi.datasets.build.sources.xarray_support.variable import Variable +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.build.gridded.sources.xarray_support.time import ForecastFromValidTimeAndStep +from anemoi.datasets.build.gridded.sources.xarray_support.variable import Variable @pytest.fixture diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 5f166be22..8202ed760 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -12,7 +12,7 @@ from anemoi.utils.testing import skip_if_offline from anemoi.utils.testing import skip_missing_packages -from anemoi.datasets.build.sources.xarray import XarrayFieldList +from anemoi.datasets.build.gridded.sources.xarray import XarrayFieldList from anemoi.datasets.testing import assert_field_list diff --git a/tools/build-obs.py b/tools/build-obs.py index 2ccd1c1a2..5013763cb 100755 --- a/tools/build-obs.py +++ b/tools/build-obs.py @@ -28,7 +28,7 @@ def build(input, output, backend, overwrite=False): print(f"Dataset has {len(ds)} records, from {ds.start_date} to {ds.end_date}") print(f"Converting dataset to {output} using new backend '{backend}'") - from anemoi.datasets.use.records.backends import writer_backend_factory + from anemoi.datasets.use.gridded.records.backends import writer_backend_factory if os.path.exists(output): if overwrite: From 8ba13081163cee016c70213a59200750dc53cb8f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 16:36:55 +0000 Subject: [PATCH 155/212] revert to some relative import when meaningful --- .../datasets/build/gridded/sources/xarray_support/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py 
b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py index 33a057520..ec3a01144 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py @@ -15,7 +15,7 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.sources.patterns import iterate_patterns +from anemoi.datasets.create.gridded.sources.patterns import iterate_patterns from ..legacy import legacy_source from .fieldlist import XarrayFieldList From 17d00b5f4354af5c908519cc16322d8119f308a8 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 16:45:59 +0000 Subject: [PATCH 156/212] update --- src/anemoi/datasets/build/__init__.py | 8 ++++++++ src/anemoi/datasets/build/gridded/__init__.py | 4 ++-- .../datasets/build/gridded/sources/accumulations.py | 2 +- src/anemoi/datasets/build/gridded/sources/fdb.py | 2 +- src/anemoi/datasets/build/gridded/statistics/summary.py | 6 +++--- src/anemoi/datasets/use/gridded/__init__.py | 2 +- src/anemoi/datasets/use/tabular/records/__init__.py | 2 +- 7 files changed, 17 insertions(+), 9 deletions(-) create mode 100644 src/anemoi/datasets/build/__init__.py diff --git a/src/anemoi/datasets/build/__init__.py b/src/anemoi/datasets/build/__init__.py new file mode 100644 index 000000000..9fc775e54 --- /dev/null +++ b/src/anemoi/datasets/build/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
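For orientation only (not part of the patch series): a minimal sketch of what the relocation performed by these commits means for downstream code, assuming the module layout introduced above — gridded-data access under anemoi.datasets.use.gridded, with tabular records/observations under anemoi.datasets.use.tabular. The dataset path used here is hypothetical.

    # Illustrative sketch only -- not part of the patches, assumes the layout introduced by this series.
    #
    # Before this refactoring the flat layout was imported:
    #   from anemoi.datasets.use import open_dataset
    #   from anemoi.datasets.use.stores import Zarr
    #
    # After it, the same entry points live under use.gridded (as the patched tests above show):
    from anemoi.datasets.use.gridded import open_dataset
    from anemoi.datasets.use.gridded.stores import Zarr

    ds = open_dataset("example-dataset.zarr")   # hypothetical dataset path
    print(ds.dates[0], ds.dates[-1])            # Dataset exposes .dates, as used elsewhere in this series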
diff --git a/src/anemoi/datasets/build/gridded/__init__.py b/src/anemoi/datasets/build/gridded/__init__.py index 0773a5e60..edee18a70 100644 --- a/src/anemoi/datasets/build/gridded/__init__.py +++ b/src/anemoi/datasets/build/gridded/__init__.py @@ -48,8 +48,8 @@ from anemoi.datasets.build.input import InputBuilder from anemoi.datasets.build.input.trace import enable_trace from anemoi.datasets.dates.groups import Groups -from anemoi.datasets.misc import as_first_date -from anemoi.datasets.misc import as_last_date +from anemoi.datasets.use.gridded import as_first_date +from anemoi.datasets.use.gridded import as_last_date LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations.py b/src/anemoi/datasets/build/gridded/sources/accumulations.py index c01c0fe54..6acecbf98 100644 --- a/src/anemoi/datasets/build/gridded/sources/accumulations.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations.py @@ -994,7 +994,7 @@ def accumulations( and request.get("stream", "oper") == "oper" and request.get("accumulation_period") == 24 ): - from anemoi.datasets.build.sources.accumulations2 import accumulations as accumulations2 + from .accumulations2 import accumulations as accumulations2 LOG.warning( "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" diff --git a/src/anemoi/datasets/build/gridded/sources/fdb.py b/src/anemoi/datasets/build/gridded/sources/fdb.py index fb918b353..5d678fca7 100644 --- a/src/anemoi/datasets/build/gridded/sources/fdb.py +++ b/src/anemoi/datasets/build/gridded/sources/fdb.py @@ -125,7 +125,7 @@ def _time_request_keys(dt: datetime, offset_from_date: bool | None = None) -> st def _shortname_to_paramid(shortname: list[str], param_id_map: dict[str, int] | None = None) -> list[int]: - from anemoi.datasets.build.sources.mars import use_grib_paramid + from .mars import use_grib_paramid """Convert a shortname to a parameter ID.""" if param_id_map is None: diff --git a/src/anemoi/datasets/build/gridded/statistics/summary.py b/src/anemoi/datasets/build/gridded/statistics/summary.py index 59f3998b4..2f81f4e5b 100644 --- a/src/anemoi/datasets/build/gridded/statistics/summary.py +++ b/src/anemoi/datasets/build/gridded/statistics/summary.py @@ -13,9 +13,9 @@ import numpy as np -from anemoi.datasets.build.check import StatisticsValueError -from anemoi.datasets.build.check import check_data_values -from anemoi.datasets.build.check import check_stats +from anemoi.datasets.build.gridded.check import StatisticsValueError +from anemoi.datasets.build.gridded.check import check_data_values +from anemoi.datasets.build.gridded.check import check_stats class Summary(dict): diff --git a/src/anemoi/datasets/use/gridded/__init__.py b/src/anemoi/datasets/use/gridded/__init__.py index 6af38b2f4..9caa9e053 100644 --- a/src/anemoi/datasets/use/gridded/__init__.py +++ b/src/anemoi/datasets/use/gridded/__init__.py @@ -95,7 +95,7 @@ def open_dataset(*args: Any, **kwargs: Any) -> "Dataset": ds._check() if trace: - from anemoi.datasets.testing import Trace + from anemoi.datasets.misc import Trace ds = Trace(ds) diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py index 2e78d22c1..9093a5845 100644 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/__init__.py @@ -16,7 +16,7 @@ import numpy as np from anemoi.utils.dates import frequency_to_timedelta -from 
anemoi.datasets.use.gridded.records.backends import backend_factory +from anemoi.datasets.use.tabular.records.backends import backend_factory LOG = logging.getLogger(__name__) From c90b10c5c7f0b45473e61b822413b64a35054a6b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 6 Oct 2025 16:51:12 +0000 Subject: [PATCH 157/212] update --- src/anemoi/datasets/commands/check.py | 3 ++- src/anemoi/datasets/commands/cleanup.py | 3 ++- src/anemoi/datasets/commands/compare-lam.py | 3 ++- src/anemoi/datasets/commands/compare.py | 3 ++- src/anemoi/datasets/commands/copy.py | 3 ++- src/anemoi/datasets/commands/create.py | 2 +- src/anemoi/datasets/commands/finalise-additions.py | 3 ++- src/anemoi/datasets/commands/finalise.py | 3 ++- src/anemoi/datasets/commands/grib-index.py | 2 +- src/anemoi/datasets/commands/init-additions.py | 3 ++- src/anemoi/datasets/commands/init.py | 3 ++- src/anemoi/datasets/commands/inspect.py | 3 ++- src/anemoi/datasets/commands/load-additions.py | 3 ++- src/anemoi/datasets/commands/load.py | 3 ++- src/anemoi/datasets/commands/patch.py | 3 ++- src/anemoi/datasets/commands/publish.py | 2 +- src/anemoi/datasets/commands/recipe/__init__.py | 7 ++++--- src/anemoi/datasets/commands/recipe/format.py | 2 +- src/anemoi/datasets/commands/recipe/migrate.py | 2 +- src/anemoi/datasets/commands/scan.py | 2 +- src/anemoi/datasets/commands/validate.py | 3 ++- 21 files changed, 38 insertions(+), 23 deletions(-) diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index f74165dd3..820d73635 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -14,7 +14,8 @@ import yaml from anemoi.datasets.build.gridded.check import DatasetName -from anemoi.datasets.commands import Command + +from . import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/cleanup.py b/src/anemoi/datasets/commands/cleanup.py index 25b5b9ca0..0b3a393bd 100644 --- a/src/anemoi/datasets/commands/cleanup.py +++ b/src/anemoi/datasets/commands/cleanup.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/compare-lam.py b/src/anemoi/datasets/commands/compare-lam.py index 92ea9a6af..74d97bb48 100644 --- a/src/anemoi/datasets/commands/compare-lam.py +++ b/src/anemoi/datasets/commands/compare-lam.py @@ -12,7 +12,8 @@ import os from anemoi.datasets import open_dataset -from anemoi.datasets.commands import Command + +from . import Command RADIUS_EARTH_KM = 6371.0 # Earth's radius in kilometers diff --git a/src/anemoi/datasets/commands/compare.py b/src/anemoi/datasets/commands/compare.py index bbd121bd1..ffe1ec02e 100644 --- a/src/anemoi/datasets/commands/compare.py +++ b/src/anemoi/datasets/commands/compare.py @@ -15,7 +15,8 @@ import zarr from anemoi.datasets import open_dataset -from anemoi.datasets.commands import Command + +from . import Command class Compare(Command): diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 5c5768714..5020a208d 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -20,7 +20,8 @@ from anemoi.utils.remote import TransferMethodNotImplementedError from anemoi.datasets.check import check_zarr -from anemoi.datasets.commands import Command + +from . 
import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 2b92718ae..601468d5c 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -18,7 +18,7 @@ import tqdm from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command +from . import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/finalise-additions.py b/src/anemoi/datasets/commands/finalise-additions.py index 25380fbba..811760ae9 100644 --- a/src/anemoi/datasets/commands/finalise-additions.py +++ b/src/anemoi/datasets/commands/finalise-additions.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/finalise.py b/src/anemoi/datasets/commands/finalise.py index 5197fb73c..53999ad50 100644 --- a/src/anemoi/datasets/commands/finalise.py +++ b/src/anemoi/datasets/commands/finalise.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/grib-index.py b/src/anemoi/datasets/commands/grib-index.py index c77499b72..59c2fba89 100644 --- a/src/anemoi/datasets/commands/grib-index.py +++ b/src/anemoi/datasets/commands/grib-index.py @@ -13,7 +13,7 @@ import tqdm -from anemoi.datasets.commands import Command +from . import Command class GribIndexCmd(Command): diff --git a/src/anemoi/datasets/commands/init-additions.py b/src/anemoi/datasets/commands/init-additions.py index c49bbf76c..09788f0e4 100644 --- a/src/anemoi/datasets/commands/init-additions.py +++ b/src/anemoi/datasets/commands/init-additions.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/init.py b/src/anemoi/datasets/commands/init.py index c5aa515fd..0ca540b86 100644 --- a/src/anemoi/datasets/commands/init.py +++ b/src/anemoi/datasets/commands/init.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 930313de8..50840ccbe 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -27,10 +27,11 @@ from numpy.typing import NDArray from anemoi.datasets import open_dataset -from anemoi.datasets.commands import Command from anemoi.datasets.use.gridded.stores import open_zarr from anemoi.datasets.use.gridded.stores import zarr_lookup +from . 
import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/load-additions.py b/src/anemoi/datasets/commands/load-additions.py index 82dec8b0a..a8cd5d5c9 100644 --- a/src/anemoi/datasets/commands/load-additions.py +++ b/src/anemoi/datasets/commands/load-additions.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/load.py b/src/anemoi/datasets/commands/load.py index 7b1c1f684..3d969f5c3 100644 --- a/src/anemoi/datasets/commands/load.py +++ b/src/anemoi/datasets/commands/load.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/patch.py b/src/anemoi/datasets/commands/patch.py index 1920fc420..dc8356126 100644 --- a/src/anemoi/datasets/commands/patch.py +++ b/src/anemoi/datasets/commands/patch.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/publish.py b/src/anemoi/datasets/commands/publish.py index 47282e65b..7f719543e 100644 --- a/src/anemoi/datasets/commands/publish.py +++ b/src/anemoi/datasets/commands/publish.py @@ -10,7 +10,7 @@ import logging from typing import Any -from anemoi.datasets.commands import Command +from . import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 6a93af8e6..85fd574e3 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -16,9 +16,10 @@ import yaml from anemoi.datasets.build.gridded import validate_config -from anemoi.datasets.commands import Command -from anemoi.datasets.commands.recipe.format import format_recipe -from anemoi.datasets.commands.recipe.migrate import migrate_recipe + +from .. 
import Command +from .format import format_recipe +from .migrate import migrate_recipe LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/format.py b/src/anemoi/datasets/commands/recipe/format.py index 328e6d756..b6993a49a 100644 --- a/src/anemoi/datasets/commands/recipe/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -11,7 +11,7 @@ import datetime import logging -from anemoi.datasets.dumper import yaml_dump +from anemoi.datasets.misc.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index dc337d0ff..7da67b992 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -18,7 +18,7 @@ from glom import glom from anemoi.datasets.build.gridded import validate_config -from anemoi.datasets.dumper import yaml_dump +from anemoi.datasets.misc.dumperdumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/scan.py b/src/anemoi/datasets/commands/scan.py index 37c8d0cfd..8a048125e 100644 --- a/src/anemoi/datasets/commands/scan.py +++ b/src/anemoi/datasets/commands/scan.py @@ -17,7 +17,7 @@ import tqdm import yaml -from anemoi.datasets.commands import Command +from . import Command KEYS = ("class", "type", "stream", "expver", "levtype", "domain") diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py index 03691541c..1382814a7 100644 --- a/src/anemoi/datasets/commands/validate.py +++ b/src/anemoi/datasets/commands/validate.py @@ -10,9 +10,10 @@ import logging from typing import Any -from anemoi.datasets.commands import Command from anemoi.datasets.validate import validate_dataset +from . 
import Command + LOG = logging.getLogger(__name__) DEFAULT_DATASET = "aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8" From e0184e75dcb5d1eaa4b426fcd80c4da5014669d1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 7 Oct 2025 06:16:32 +0000 Subject: [PATCH 158/212] more refactoring --- src/anemoi/datasets/build/gridded/__init__.py | 4 ++-- src/anemoi/datasets/build/gridded/source.py | 2 +- src/anemoi/datasets/build/input/action.py | 4 ++-- src/anemoi/datasets/commands/validate.py | 2 +- src/anemoi/datasets/{build/gridded => misc}/validate.py | 0 5 files changed, 6 insertions(+), 6 deletions(-) rename src/anemoi/datasets/{build/gridded => misc}/validate.py (100%) diff --git a/src/anemoi/datasets/build/gridded/__init__.py b/src/anemoi/datasets/build/gridded/__init__.py index edee18a70..696fc118b 100644 --- a/src/anemoi/datasets/build/gridded/__init__.py +++ b/src/anemoi/datasets/build/gridded/__init__.py @@ -48,8 +48,8 @@ from anemoi.datasets.build.input import InputBuilder from anemoi.datasets.build.input.trace import enable_trace from anemoi.datasets.dates.groups import Groups -from anemoi.datasets.use.gridded import as_first_date -from anemoi.datasets.use.gridded import as_last_date +from anemoi.datasets.use.gridded.misc import as_first_date +from anemoi.datasets.use.gridded.misc import as_last_date LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/source.py b/src/anemoi/datasets/build/gridded/source.py index df4911690..494b29b92 100644 --- a/src/anemoi/datasets/build/gridded/source.py +++ b/src/anemoi/datasets/build/gridded/source.py @@ -12,7 +12,7 @@ import earthkit.data as ekd -from anemoi.datasets.build.typing import DateList +from anemoi.datasets.build.gridded.typing import DateList class Source(ABC): diff --git a/src/anemoi/datasets/build/input/action.py b/src/anemoi/datasets/build/input/action.py index 8a5cab48c..1a37d2f99 100644 --- a/src/anemoi/datasets/build/input/action.py +++ b/src/anemoi/datasets/build/input/action.py @@ -181,7 +181,7 @@ class DatasetSourceMixin: """Mixin class for sources defined in anemoi-datasets""" def create_object(self, context, config): - from anemoi.datasets.build.sources import create_source as create_datasets_source + from anemoi.datasets.build.gridded.sources import create_source as create_datasets_source return create_datasets_source(context, config) @@ -286,7 +286,7 @@ def make(key, config, *path): from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.transform.sources import source_registry as transform_source_registry - from anemoi.datasets.build.sources import source_registry as dataset_source_registry + from anemoi.datasets.build.gridded.sources import source_registry as dataset_source_registry # Register sources, local first for name in dataset_source_registry.registered: diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py index 1382814a7..dfc2d297b 100644 --- a/src/anemoi/datasets/commands/validate.py +++ b/src/anemoi/datasets/commands/validate.py @@ -10,7 +10,7 @@ import logging from typing import Any -from anemoi.datasets.validate import validate_dataset +from anemoi.datasets.misc.validate import validate_dataset from . 
import Command diff --git a/src/anemoi/datasets/build/gridded/validate.py b/src/anemoi/datasets/misc/validate.py similarity index 100% rename from src/anemoi/datasets/build/gridded/validate.py rename to src/anemoi/datasets/misc/validate.py From 042a8f01662f8480bafdfec6f453aaf75f91db77 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 7 Oct 2025 06:42:41 +0000 Subject: [PATCH 159/212] fix test --- .../build/gridded/sources/accumulations.py | 2 +- .../datasets/build/gridded/sources/mars.py | 2 +- .../sources/xarray_support/__init__.py | 2 +- src/anemoi/datasets/use/gridded/misc.py | 10 ++++---- tests/create/test_sources.py | 25 ++++++++++++++----- tests/test_chunks.py | 4 +-- tests/test_dates.py | 2 +- tests/xarray/test_opendap.py | 2 +- tests/xarray/test_zarr.py | 2 +- 9 files changed, 32 insertions(+), 19 deletions(-) diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations.py b/src/anemoi/datasets/build/gridded/sources/accumulations.py index 6acecbf98..ba943897f 100644 --- a/src/anemoi/datasets/build/gridded/sources/accumulations.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations.py @@ -20,7 +20,7 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray -from anemoi.datasets.create.utils import to_datetime_list +from anemoi.datasets.build.gridded.utils import to_datetime_list from .legacy import legacy_source from .mars import mars diff --git a/src/anemoi/datasets/build/gridded/sources/mars.py b/src/anemoi/datasets/build/gridded/sources/mars.py index db075321e..a92205ac1 100644 --- a/src/anemoi/datasets/build/gridded/sources/mars.py +++ b/src/anemoi/datasets/build/gridded/sources/mars.py @@ -16,7 +16,7 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability -from anemoi.datasets.build.utils import to_datetime_list +from anemoi.datasets.build.gridded.utils import to_datetime_list from .legacy import legacy_source diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py index ec3a01144..4bfd4f76b 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py @@ -15,7 +15,7 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.gridded.sources.patterns import iterate_patterns +from anemoi.datasets.build.gridded.sources.patterns import iterate_patterns from ..legacy import legacy_source from .fieldlist import XarrayFieldList diff --git a/src/anemoi/datasets/use/gridded/misc.py b/src/anemoi/datasets/use/gridded/misc.py index 36afbee3c..97549ac24 100644 --- a/src/anemoi/datasets/use/gridded/misc.py +++ b/src/anemoi/datasets/use/gridded/misc.py @@ -359,7 +359,7 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " if "backend" not in metadata: raise ValueError(f"Metadata for {a} does not contain 'backend' key") - from anemoi.datasets.use.gridded.records import open_records_dataset + from anemoi.datasets.use.tabular.records import open_records_dataset return open_records_dataset(a, backend=metadata["backend"]) @@ -501,7 +501,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": sets.append(_open(a)) if "observations" in kwargs: - from anemoi.datasets.use.gridded.observations import observations_factory + from anemoi.datasets.use.tabular.observations import observations_factory assert not sets, sets @@ 
-509,21 +509,21 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": if "xy" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.use.gridded.gridded.xy import xy_factory + from anemoi.datasets.use.gridded.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "x" in kwargs and "y" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.use.gridded.gridded.xy import xy_factory + from anemoi.datasets.use.gridded.xy import xy_factory assert not sets, sets return xy_factory(args, kwargs).mutate() if "zip" in kwargs: # Experimental feature, may be removed - from anemoi.datasets.use.gridded.gridded.xy import zip_factory + from anemoi.datasets.use.gridded.xy import zip_factory assert not sets, sets return zip_factory(args, kwargs).mutate() diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index dbf0d746a..e841744ea 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -96,14 +96,17 @@ def test_grib_gridfile(get_test_data) -> None: ) @skip_if_offline @pytest.mark.parametrize( - "refinement_level_c,shape", + "input_refinement_level_c,output_refinement_level_c,shape", ( - (2, (2, 13, 1, 2880)), - (7, (2, 13, 1, 2949120)), + (7, 2, (2, 13, 1, 2880)), + (7, 7, (2, 13, 1, 2949120)), ), ) def test_grib_gridfile_with_refinement_level( - refinement_level_c: str, shape: tuple[int, int, int, int, int], get_test_data: callable + input_refinement_level_c: str, + output_refinement_level_c: str, + shape: tuple[int, int, int, int, int], + get_test_data: callable, ) -> None: """Test the creation of a dataset from GRIB files with an unstructured grid. @@ -129,11 +132,21 @@ def test_grib_gridfile_with_refinement_level( grib = { "path": os.path.join(path, "{date:strftimedelta(+3h;%Y%m%d%H)}+fc_R03B07_rea_ml.{date:strftime(%Y%m%d%H)}"), - "grid_definition": {"icon": {"path": gridfile, "refinement_level_c": refinement_level_c}}, + "grid_definition": { + "icon": { + "path": gridfile, + "refinement_level_c": input_refinement_level_c, + } + }, "param": param, "level": level, } - refinement_filter = {"icon_refinement_level": {"grid": gridfile, "refinement_level_c": refinement_level_c}} + refinement_filter = { + "icon_refinement_level": { + "grid": gridfile, + "refinement_level_c": output_refinement_level_c, + } + } config = { "dates": { diff --git a/tests/test_chunks.py b/tests/test_chunks.py index b2aa3c789..529c1f0cd 100644 --- a/tests/test_chunks.py +++ b/tests/test_chunks.py @@ -7,11 +7,11 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-"""Test suite for the ChunkFilter class in the anemoi.datasets.create.chunks module.""" +"""Test suite for the ChunkFilter class in the anemoi.datasets.build.gridded.chunks module.""" import pytest -from anemoi.datasets.build.chunks import ChunkFilter +from anemoi.datasets.build.gridded.chunks import ChunkFilter def test_chunk_filter(): diff --git a/tests/test_dates.py b/tests/test_dates.py index d169498bb..abc746d8e 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -14,7 +14,7 @@ import numpy as np import pytest -from anemoi.datasets.build.statistics import default_statistics_dates +from anemoi.datasets.build.gridded.statistics import default_statistics_dates _ = datetime.datetime diff --git a/tests/xarray/test_opendap.py b/tests/xarray/test_opendap.py index 1625ef32c..538630a23 100644 --- a/tests/xarray/test_opendap.py +++ b/tests/xarray/test_opendap.py @@ -13,7 +13,7 @@ from anemoi.utils.testing import skip_if_offline from anemoi.datasets.build.gridded.sources.xarray import XarrayFieldList -from anemoi.datasets.testing import assert_field_list +from anemoi.datasets.misc.testing import assert_field_list @skip_if_offline diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 8202ed760..1c35361c7 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -13,7 +13,7 @@ from anemoi.utils.testing import skip_missing_packages from anemoi.datasets.build.gridded.sources.xarray import XarrayFieldList -from anemoi.datasets.testing import assert_field_list +from anemoi.datasets.misc.testing import assert_field_list @skip_if_offline From f4ad404bc046711d687c160a6ee4f03b0d35c3be Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 7 Oct 2025 06:46:08 +0000 Subject: [PATCH 160/212] update --- src/anemoi/datasets/commands/copy.py | 2 +- src/anemoi/datasets/commands/recipe/migrate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 5020a208d..886726d99 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -19,7 +19,7 @@ from anemoi.utils.remote import Transfer from anemoi.utils.remote import TransferMethodNotImplementedError -from anemoi.datasets.check import check_zarr +from anemoi.datasets.misc.check import check_zarr from . 
import Command diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 7da67b992..8ca2ddd5d 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -18,7 +18,7 @@ from glom import glom from anemoi.datasets.build.gridded import validate_config -from anemoi.datasets.misc.dumperdumper import yaml_dump +from anemoi.datasets.misc.dumper import yaml_dump LOG = logging.getLogger(__name__) From d286f9a0f042da4ace0261e3dedae12b7fee2cd2 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 7 Oct 2025 07:16:51 +0000 Subject: [PATCH 161/212] refactoring --- .../build/{input/context/field.py => gridded/context.py} | 8 ++++---- .../build/{input/result/field.py => gridded/result.py} | 2 +- src/anemoi/datasets/build/input/__init__.py | 4 ++-- .../build/input/{context/__init__.py => context.py} | 0 src/anemoi/datasets/build/input/data_sources.py | 2 +- src/anemoi/datasets/build/input/repeated_dates.py | 2 +- .../build/input/{result/__init__.py => result.py} | 0 7 files changed, 9 insertions(+), 9 deletions(-) rename src/anemoi/datasets/build/{input/context/field.py => gridded/context.py} (88%) rename src/anemoi/datasets/build/{input/result/field.py => gridded/result.py} (99%) rename src/anemoi/datasets/build/input/{context/__init__.py => context.py} (100%) rename src/anemoi/datasets/build/input/{result/__init__.py => result.py} (100%) diff --git a/src/anemoi/datasets/build/input/context/field.py b/src/anemoi/datasets/build/gridded/context.py similarity index 88% rename from src/anemoi/datasets/build/input/context/field.py rename to src/anemoi/datasets/build/gridded/context.py index 1a03a603a..a9298cddf 100644 --- a/src/anemoi/datasets/build/input/context/field.py +++ b/src/anemoi/datasets/build/gridded/context.py @@ -12,11 +12,11 @@ from earthkit.data.core.order import build_remapping -from anemoi.datasets.build.input.context import Context -from anemoi.datasets.build.input.result.field import FieldResult +from anemoi.datasets.build.gridded.result import GriddedResult +from anemoi.datasets.build.input import Context -class FieldContext(Context): +class GriddedContext(Context): def __init__( self, @@ -46,7 +46,7 @@ def filter_argument(self, argument: Any) -> Any: return argument def create_result(self, data): - return FieldResult(self, data) + return GriddedResult(self, data) def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: from anemoi.datasets.dates.groups import GroupOfDates diff --git a/src/anemoi/datasets/build/input/result/field.py b/src/anemoi/datasets/build/gridded/result.py similarity index 99% rename from src/anemoi/datasets/build/input/result/field.py rename to src/anemoi/datasets/build/gridded/result.py index a80fdb3e6..69c560969 100644 --- a/src/anemoi/datasets/build/input/result/field.py +++ b/src/anemoi/datasets/build/gridded/result.py @@ -276,7 +276,7 @@ def sort(old_dic: DefaultDict[str, set]) -> dict[str, list[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class FieldResult(Result): +class GriddedResult(Result): """Class to represent the result of an action in the dataset creation process.""" empty: bool = False diff --git a/src/anemoi/datasets/build/input/__init__.py b/src/anemoi/datasets/build/input/__init__.py index 4fd558242..4acb4de86 100644 --- a/src/anemoi/datasets/build/input/__init__.py +++ b/src/anemoi/datasets/build/input/__init__.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING from 
typing import Any -from anemoi.datasets.build.input.context.field import FieldContext +from anemoi.datasets.build.gridded.context import GriddedContext if TYPE_CHECKING: from anemoi.datasets.build.input.action import Recipe @@ -61,7 +61,7 @@ def select(self, argument) -> Any: Any Selected data. """ - context = FieldContext(argument, **self.kwargs) + context = GriddedContext(argument, **self.kwargs) return context.create_result(self.action(context, argument)) diff --git a/src/anemoi/datasets/build/input/context/__init__.py b/src/anemoi/datasets/build/input/context.py similarity index 100% rename from src/anemoi/datasets/build/input/context/__init__.py rename to src/anemoi/datasets/build/input/context.py diff --git a/src/anemoi/datasets/build/input/data_sources.py b/src/anemoi/datasets/build/input/data_sources.py index ab5cd5d50..6e9bfaa6a 100644 --- a/src/anemoi/datasets/build/input/data_sources.py +++ b/src/anemoi/datasets/build/input/data_sources.py @@ -13,10 +13,10 @@ from earthkit.data import FieldList +from anemoi.datasets.build.gridded.result import Result from anemoi.datasets.build.input.action import Action from anemoi.datasets.build.input.action import action_factory from anemoi.datasets.build.input.misc import _tidy -from anemoi.datasets.build.input.result.field import Result from anemoi.datasets.dates.groups import GroupOfDates LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/input/repeated_dates.py b/src/anemoi/datasets/build/input/repeated_dates.py index 925886c00..f20d764ec 100644 --- a/src/anemoi/datasets/build/input/repeated_dates.py +++ b/src/anemoi/datasets/build/input/repeated_dates.py @@ -19,10 +19,10 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta +from anemoi.datasets.build.gridded.result import Result from anemoi.datasets.build.input.action import Action from anemoi.datasets.build.input.action import action_factory from anemoi.datasets.build.input.join import JoinResult -from anemoi.datasets.build.input.result.field import Result from anemoi.datasets.build.input.trace import trace_select LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/input/result/__init__.py b/src/anemoi/datasets/build/input/result.py similarity index 100% rename from src/anemoi/datasets/build/input/result/__init__.py rename to src/anemoi/datasets/build/input/result.py From 804faa88bbc53ed9b57d0ff1b1a68edf06df0093 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 7 Oct 2025 07:21:27 +0000 Subject: [PATCH 162/212] refactoring --- src/anemoi/datasets/build/gridded/context.py | 2 +- src/anemoi/datasets/build/input/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/build/gridded/context.py b/src/anemoi/datasets/build/gridded/context.py index a9298cddf..91ea80c07 100644 --- a/src/anemoi/datasets/build/gridded/context.py +++ b/src/anemoi/datasets/build/gridded/context.py @@ -13,7 +13,7 @@ from earthkit.data.core.order import build_remapping from anemoi.datasets.build.gridded.result import GriddedResult -from anemoi.datasets.build.input import Context +from anemoi.datasets.build.input.context import Context class GriddedContext(Context): diff --git a/src/anemoi/datasets/build/input/__init__.py b/src/anemoi/datasets/build/input/__init__.py index 4acb4de86..c3d601fd1 100644 --- a/src/anemoi/datasets/build/input/__init__.py +++ b/src/anemoi/datasets/build/input/__init__.py @@ -12,8 +12,6 @@ from typing import TYPE_CHECKING from typing import Any -from 
anemoi.datasets.build.gridded.context import GriddedContext - if TYPE_CHECKING: from anemoi.datasets.build.input.action import Recipe @@ -61,6 +59,8 @@ def select(self, argument) -> Any: Any Selected data. """ + from anemoi.datasets.build.gridded.context import GriddedContext + context = GriddedContext(argument, **self.kwargs) return context.create_result(self.action(context, argument)) From ddfc22a814d89b42bab685cf27fc997a8f16f9e1 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Sat, 18 Oct 2025 23:33:52 +0200 Subject: [PATCH 163/212] wip --- src/anemoi/datasets/data/records/__init__.py | 72 ++++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index ea901c7c8..68e3fb122 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -114,8 +114,7 @@ def _getrecord(self, i: int): return Record(self, i) def _load_data(self, i: int) -> dict: - """ - Load the data for a specific time step or window (i). + """Load the data for a specific time step or window (i). It is expected to return a dict containing keys of the form: - "data:group1" : numpy array @@ -150,8 +149,14 @@ def _subset(self, **kwargs): end = kwargs.pop("end", None) frequency = kwargs.pop("frequency", self.frequency) - if frequency != self.frequency: - raise ValueError(f"Changing the frequency {frequency} (from {self.frequency}) is not implemented yet.") + if frequency: + frequency = frequency_to_timedelta(frequency) + if self.frequency.total_seconds() % frequency.total_seconds() == 0: + return IncreaseFrequency(self, frequency) + elif frequency.total_seconds() % self.frequency.total_seconds() == 0: + raise NotImplementedError("Decreasing frequency not implemented yet") + # return DecreaseFrequency(self, frequency) + assert self.frequency == frequency, (self.frequency, frequency) if start is not None or end is not None: @@ -559,6 +564,65 @@ def ends_after(self, my_dates, other_dates, other_window): return my_end >= other_end +class IncreaseFrequency(RecordsForward): + # change the frequency of a records dataset by splitting the windows to fit the new frequency + def __init__(self, dataset, frequency): + super().__init__(dataset) + self.dataset = dataset + self._frequency = frequency_to_timedelta(frequency) + self.reason = {"frequency": frequency} + + self._n = self.dataset.frequency / self._frequency + if int(self._n) != self._n: + raise ValueError(f"Cannot split frequency {self.dataset.frequency} to {frequency}, not a multiple") + self._n = int(self._n) + + # self.missing = [] + # for i in self.dataset.missing: + # for j in range(self._n): + # self.missing.append(i * self._n + j) + # self.missing = sorted(self.missing) + + def __len__(self): + return len(self.dataset) * self._n + + @property + def frequency(self): + return self._frequency + + def _load_data(self, i): + j = i // self._n + k = i % self._n + + too_much_data = self.dataset._load_data(j) + + out = {} + for group in self.groups: + timedeltas = too_much_data[f"timedeltas:{group}"] + if timedeltas.dtype != "timedelta64[s]": + raise ValueError(f"Wrong type for {group}") + + start_delta = k * self.frequency + end_delta = (k + 1) * self.frequency + start_delta = _to_numpy_timedelta(start_delta) + end_delta = _to_numpy_timedelta(end_delta) + assert isinstance(start_delta, np.timedelta64), (type(start_delta), start_delta) + assert isinstance(timedeltas[0], np.timedelta64), type(timedeltas[0]) + + mask = 
(timedeltas >= start_delta) & (timedeltas < end_delta) + + out[f"data:{group}"] = too_much_data[f"data:{group}"][..., mask] + out[f"latitudes:{group}"] = too_much_data[f"latitudes:{group}"][..., mask] + out[f"longitudes:{group}"] = too_much_data[f"longitudes:{group}"][..., mask] + out[f"timedeltas:{group}"] = too_much_data[f"timedeltas:{group}"][..., mask] + out[f"metadata:{group}"] = too_much_data[f"metadata:{group}"] + + return out + + def tree(self): + return Node(self, [self.dataset.tree()], **self.reason) + + class Rewindowed(RecordsForward): # change the window of a records dataset # similar to changing the frequency of a dataset From 293f9bc10fdfaaeaee3b8e57159126577b9ba39e Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 20 Oct 2025 15:16:48 +0200 Subject: [PATCH 164/212] added change frequency of obs datasets, shrinking window --- src/anemoi/datasets/data/records/__init__.py | 77 +++++++++++++++++++- src/anemoi/datasets/data/records/windows.py | 12 +++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 43b1f3e35..60d1d23cc 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -146,7 +146,7 @@ def _subset(self, **kwargs): if frequency: frequency = frequency_to_timedelta(frequency) if self.frequency.total_seconds() % frequency.total_seconds() == 0: - return self # IncreaseFrequency(self, frequency) + return IncreaseFrequency(self, frequency) elif frequency.total_seconds() % self.frequency.total_seconds() == 0: raise NotImplementedError("Decreasing frequency not implemented yet") # return DecreaseFrequency(self, frequency) @@ -244,6 +244,81 @@ def tree(self): return Node(self, [self.forward.tree()], **self.reason) +class IncreaseFrequency(RecordsForward): + # change the frequency of a records dataset by splitting the windows to fit the new frequency + # the new frequency must be a divisor of the original frequency (e.g. 
6h -> 3h, but not 3h -> 6h) (and not 6h -> 5h) + def __init__(self, dataset, frequency): + super().__init__(dataset) + self.dataset = dataset + self._frequency = frequency_to_timedelta(frequency) + self.reason = {"frequency": frequency} + + self._n = self.dataset.frequency / self._frequency + if int(self._n) != self._n: + raise ValueError(f"Cannot split frequency {self.dataset.frequency} to {frequency}, not a multiple") + self._n = int(self._n) + + @cached_property + def _window(self): + previous = self.dataset._window + if isinstance(previous, int): + previous = window_from_str(previous) + return previous / self._n + + def __len__(self): + return len(self.dataset) * self._n + + @property + def frequency(self): + return self._frequency + + def _load_data(self, i): + j = i // self._n + k = i % self._n + + too_much_data = self.dataset._load_data(j) + + out = {} + for group in self.groups: + timedeltas = too_much_data[f"timedeltas:{group}"] + if timedeltas.dtype != "timedelta64[s]": + raise ValueError(f"Wrong type for {group}") + + start_delta = k * self.frequency + self._window.start + end_delta = k * self.frequency + self._window.end + + def _to_numpy_timedelta(td): + if isinstance(td, np.timedelta64): + assert td.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {td.dtype}" + return td + return np.timedelta64(int(td.total_seconds()), "s") + + start_delta = _to_numpy_timedelta(start_delta) + end_delta = _to_numpy_timedelta(end_delta) + assert isinstance(start_delta, np.timedelta64), (type(start_delta), start_delta) + assert isinstance(timedeltas[0], np.timedelta64), type(timedeltas[0]) + + if self._window.include_start: + mask = timedeltas >= start_delta + else: + mask = timedeltas > start_delta + if self._window.include_end: + mask &= timedeltas <= end_delta + else: + mask &= timedeltas < end_delta + + out[f"data:{group}"] = too_much_data[f"data:{group}"][..., mask] + out[f"latitudes:{group}"] = too_much_data[f"latitudes:{group}"][..., mask] + out[f"longitudes:{group}"] = too_much_data[f"longitudes:{group}"][..., mask] + out[f"timedeltas:{group}"] = too_much_data[f"timedeltas:{group}"][..., mask] + out[f"metadata:{group}"] = too_much_data[f"metadata:{group}"] + + return out + + def tree(self): + return Node(self, [self.dataset.tree()], **self.reason) + + class FieldsRecords(RecordsForward): """A wrapper around a FieldsDataset to provide a consistent interface for records datasets.""" diff --git a/src/anemoi/datasets/data/records/windows.py b/src/anemoi/datasets/data/records/windows.py index 43ee671d7..5f02e3c82 100644 --- a/src/anemoi/datasets/data/records/windows.py +++ b/src/anemoi/datasets/data/records/windows.py @@ -238,3 +238,15 @@ def ends_after(self, my_dates, other_dates, other_window): return (not other_window.include_end) or self.include_end print(my_end >= other_end) return my_end >= other_end + + def __truediv__(self, n: int): + """Divide the window into a smaller windows, shrinked by a factor n.""" + assert isinstance(n, int), f"n must be an int, got {type(n)}" + assert n > 0, f"n must be positive, got {n}" + + return WindowsSpec( + start=self.start / n, + end=self.end / n, + include_start=self.include_start, + include_end=self.include_end, + ) From f1df2aa1c32e4914ca501bc6828f467b5a08b8ce Mon Sep 17 00:00:00 2001 From: Aaron Hopkinson <197336788+aaron-hopkinson@users.noreply.github.com> Date: Tue, 21 Oct 2025 11:14:35 +0100 Subject: [PATCH 165/212] Review suggestions for PR #433 (#436) Resolves merge conflicts and fixes broken imports after reorganisation. 
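For context, the sources touched below all move from the old `legacy_source` function decorator to classes registered on `source_registry` and derived from `LegacySource`. A minimal sketch of the resulting pattern, assuming it lives inside the `anemoi.datasets.build.gridded.sources` package; the `my_source` name and its trivial body are illustrative only and are not part of this change:

    import earthkit.data as ekd

    from . import source_registry
    from .legacy import LegacySource


    @source_registry.register("my_source")  # hypothetical name, for illustration
    class MySource(LegacySource):

        @staticmethod
        def _execute(context, dates, **kwargs):
            # Return an earthkit-data FieldList for the requested dates;
            # a real source would load or retrieve data here.
            return ekd.from_source("empty")

Each class exposes a static `_execute(context, dates, ...)`; the `LegacySource.execute()` wrapper forwards the stored args and kwargs to it, so existing call sites keep working.
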
By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) --------- Co-authored-by: Francesco Zanetta <62377868+frazane@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: b8raoult <53792887+b8raoult@users.noreply.github.com> --- .pre-commit-config.yaml | 8 +- pyproject.toml | 2 +- .../build/gridded/sources/accumulations.py | 168 ++++++------- .../build/gridded/sources/accumulations2.py | 61 ++--- .../build/gridded/sources/anemoi_dataset.py | 88 +++---- .../build/gridded/sources/constants.py | 77 +++--- .../datasets/build/gridded/sources/empty.py | 48 ++-- .../build/gridded/sources/forcings.py | 57 ++--- .../datasets/build/gridded/sources/grib.py | 164 ++++++------ .../build/gridded/sources/grib_index.py | 88 +++---- .../build/gridded/sources/hindcasts.py | 112 +++++---- .../datasets/build/gridded/sources/legacy.py | 72 +----- .../datasets/build/gridded/sources/mars.py | 238 ++++++++---------- .../datasets/build/gridded/sources/netcdf.py | 52 ++-- .../datasets/build/gridded/sources/opendap.py | 52 ++-- .../build/gridded/sources/recentre.py | 83 +++--- .../build/gridded/sources/repeated_dates.py | 4 +- .../datasets/build/gridded/sources/source.py | 74 ++---- .../build/gridded/sources/tendencies.py | 161 +++++------- .../sources/xarray_support/__init__.py | 53 ++-- .../build/gridded/sources/xarray_zarr.py | 52 ++-- .../datasets/build/gridded/sources/zenodo.py | 82 +++--- src/anemoi/datasets/commands/check.py | 2 +- .../use/tabular/observations/__init__.py | 2 +- tools/build-obs.py | 2 +- 25 files changed, 844 insertions(+), 958 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fa1cfb159..d19c99e58 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,12 +27,12 @@ repos: - id: python-check-blanket-noqa # Check for # noqa: all - id: python-no-log-warn # Check for log.warn - repo: https://github.com/psf/black-pre-commit-mirror - rev: 25.1.0 + rev: 25.9.0 hooks: - id: black args: [--line-length=120] - repo: https://github.com/pycqa/isort - rev: 6.0.1 + rev: 6.1.0 hooks: - id: isort args: @@ -41,7 +41,7 @@ repos: - --profile black - --project anemoi - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.11 + rev: v0.13.3 hooks: - id: ruff args: @@ -65,7 +65,7 @@ repos: - id: docconvert args: ["numpy"] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "v2.6.0" + rev: "v2.7.0" hooks: - id: pyproject-fmt - repo: https://github.com/jshwi/docsig # Check docstrings against function sig diff --git a/pyproject.toml b/pyproject.toml index 44df9260c..ad17d6e84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ dependencies = [ ] optional-dependencies.all = [ - "anemoi-datasets[create,remote,xarray,comparelam]", + "anemoi-datasets[comparelam,create,remote,xarray]", ] optional-dependencies.comparelam = [ diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations.py b/src/anemoi/datasets/build/gridded/sources/accumulations.py index ba943897f..2d0124a81 100644 --- a/src/anemoi/datasets/build/gridded/sources/accumulations.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations.py @@ -20,12 +20,13 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray -from anemoi.datasets.build.gridded.utils import to_datetime_list +from anemoi.datasets.build.gridded.sources import 
source_registry -from .legacy import legacy_source +from .legacy import LegacySource from .mars import mars LOG = logging.getLogger(__name__) +MISSING_VALUE = 1e-38 def _member(field: Any) -> int: @@ -168,6 +169,7 @@ def write(self, template: Any) -> None: # are used to store the end step edition = template.metadata("edition") + assert np.all(self.values != MISSING_VALUE) if edition == 1 and self.endStep > 254: self.out.write( @@ -176,6 +178,7 @@ def write(self, template: Any) -> None: stepType="instant", step=self.endStep, check_nans=True, + missing_value=MISSING_VALUE, ) else: self.out.write( @@ -185,6 +188,7 @@ def write(self, template: Any) -> None: startStep=self.startStep, endStep=self.endStep, check_nans=True, + missing_value=MISSING_VALUE, ) self.values = None self.done = True @@ -205,9 +209,6 @@ def add(self, field: Any, values: NDArray[Any]) -> None: if step not in self.steps: return - if not np.all(values >= 0): - warnings.warn(f"Negative values for {field}: {np.nanmin(values)} {np.nanmax(values)}") - assert not self.done, (self.key, step) assert step not in self.seen, (self.key, step) @@ -966,97 +967,76 @@ def _scda(request: dict[str, Any]) -> dict[str, Any]: return request -@legacy_source(__file__) -def accumulations( - context: Any, dates: list[datetime.datetime], use_cdsapi_dataset: str | None = None, **request: Any -) -> Any: - """Computes accumulations based on the provided context, dates, and request parameters. - - Parameters - ---------- - context : Any - Context for the computation. - dates : List[datetime.datetime] - List of dates. - use_cdsapi_dataset : Optional[str], optional - CDSAPI dataset to use. Defaults to None. - **request : Any - Additional request parameters. - - Returns - ------- - Any - The computed accumulations. - """ - - if ( - request.get("class") == "ea" - and request.get("stream", "oper") == "oper" - and request.get("accumulation_period") == 24 - ): - from .accumulations2 import accumulations as accumulations2 - - LOG.warning( - "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" - ) - return accumulations2(context, dates, **request) - - _to_list(request["param"]) - class_ = request.get("class", "od") - stream = request.get("stream", "oper") - - user_accumulation_period = request.pop("accumulation_period", 6) - accumulations_reset_frequency = request.pop("accumulations_reset_frequency", None) - user_date = request.pop("date", None) - - # If `data_accumulation_period` is not set, this means that the accumulations are from the start - # of the forecast. 
- - KWARGS = { - ("od", "oper"): dict(patch=_scda), - ("od", "elda"): dict(base_times=(6, 18)), - ("od", "enfo"): dict(base_times=(0, 6, 12, 18)), - ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), - ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), - ("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)), - ("l5", "oper"): dict(data_accumulation_period=1, base_times=(0,)), - } - - kwargs = KWARGS.get((class_, stream), {}) - - context.trace("🌧️", f"accumulations {request} {user_accumulation_period} {kwargs}") - - return _compute_accumulations( - context, - dates, - request, - user_accumulation_period=user_accumulation_period, - accumulations_reset_frequency=accumulations_reset_frequency, - use_cdsapi_dataset=use_cdsapi_dataset, - user_date=user_date, - **kwargs, - ) +@source_registry.register("accumulations") +class AccumulationsSource(LegacySource): + @staticmethod + def _execute( + context: Any, dates: list[datetime.datetime], use_cdsapi_dataset: str | None = None, **request: Any + ) -> Any: + """Computes accumulations based on the provided context, dates, and request parameters. -execute = accumulations - -if __name__ == "__main__": - import yaml + Parameters + ---------- + context : Any + Context for the computation. + dates : List[datetime.datetime] + List of dates. + use_cdsapi_dataset : Optional[str], optional + CDSAPI dataset to use. Defaults to None. + **request : Any + Additional request parameters. - config = yaml.safe_load( + Returns + ------- + Any + The computed accumulations. """ - class: ea - expver: '0001' - grid: 20./20. - levtype: sfc -# number: [0, 1] -# stream: enda - param: [cp, tp] -# accumulation_period: 6h - """ - ) - dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) - for f in accumulations(None, dates, **config): - print(f, f.to_numpy().mean()) + if ( + request.get("class") == "ea" + and request.get("stream", "oper") == "oper" + and request.get("accumulation_period") == 24 + ): + from .accumulations2 import accumulations as accumulations2 + + LOG.warning( + "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" + ) + return accumulations2(context, dates, **request) + + _to_list(request["param"]) + class_ = request.get("class", "od") + stream = request.get("stream", "oper") + + user_accumulation_period = request.pop("accumulation_period", 6) + accumulations_reset_frequency = request.pop("accumulations_reset_frequency", None) + user_date = request.pop("date", None) + + # If `data_accumulation_period` is not set, this means that the accumulations are from the start + # of the forecast. 
+ + KWARGS = { + ("od", "oper"): dict(patch=_scda), + ("od", "elda"): dict(base_times=(6, 18)), + ("od", "enfo"): dict(base_times=(0, 6, 12, 18)), + ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), + ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), + ("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)), + ("l5", "oper"): dict(data_accumulation_period=1, base_times=(0,)), + } + + kwargs = KWARGS.get((class_, stream), {}) + + context.trace("🌧️", f"accumulations {request} {user_accumulation_period} {kwargs}") + + return _compute_accumulations( + context, + dates, + request, + user_accumulation_period=user_accumulation_period, + accumulations_reset_frequency=accumulations_reset_frequency, + use_cdsapi_dataset=use_cdsapi_dataset, + user_date=user_date, + **kwargs, + ) diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations2.py b/src/anemoi/datasets/build/gridded/sources/accumulations2.py index 1a15badfa..64410164f 100644 --- a/src/anemoi/datasets/build/gridded/sources/accumulations2.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations2.py @@ -18,9 +18,9 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.build.utils import to_datetime_list +from anemoi.datasets.build.gridded.sources import source_registry -from .legacy import legacy_source +from .legacy import LegacySource from .mars import mars LOG = logging.getLogger(__name__) @@ -599,49 +599,20 @@ def _scda(request: dict[str, Any]) -> dict[str, Any]: return request -@legacy_source(__file__) -def accumulations(context, dates, **request): - _to_list(request["param"]) - user_accumulation_period = request.pop("accumulation_period", 6) - user_accumulation_period = datetime.timedelta(hours=user_accumulation_period) +@source_registry.register("accumulations2") +class Accumulations2Source(LegacySource): - context.trace("🌧️", f"accumulations {request} {user_accumulation_period}") + @staticmethod + def _execute(context, dates, **request): + _to_list(request["param"]) + user_accumulation_period = request.pop("accumulation_period", 6) + user_accumulation_period = datetime.timedelta(hours=user_accumulation_period) - return _compute_accumulations( - context, - dates, - request, - user_accumulation_period=user_accumulation_period, - ) - - -execute = accumulations - -if __name__ == "__main__": - import yaml - - config = yaml.safe_load( - """ - class: ea - expver: '0001' - grid: 20./20. 
- levtype: sfc -# number: [0, 1] -# stream: enda - param: [cp, tp] -# accumulation_period: 6h - accumulation_period: 2 - """ - ) - dates = yaml.safe_load("[2022-12-31 00:00, 2022-12-31 06:00]") - # dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) - - class Context: - use_grib_paramid = True + context.trace("🌧️", f"accumulations {request} {user_accumulation_period}") - def trace(self, *args): - print(*args) - - for f in accumulations(Context, dates, **config): - print(f, f.to_numpy().mean()) + return _compute_accumulations( + context, + dates, + request, + user_accumulation_period=user_accumulation_period, + ) diff --git a/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py b/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py index 12d41db23..743605bb9 100644 --- a/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py +++ b/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py @@ -9,65 +9,69 @@ import numpy as np -from .legacy import legacy_source +from . import source_registry +from .legacy import LegacySource -@legacy_source(__file__) -def execute(context, dates, params=None, **kwargs): - import earthkit.data as ekd +@source_registry.register("anemoi_dataset") +class AnemoiDatasetSource(LegacySource): - from anemoi.datasets import open_dataset + @staticmethod + def _execute(context, dates, params=None, **kwargs): + import earthkit.data as ekd - ds = open_dataset(**kwargs) - # dates_to_index = {date: i for i, date in enumerate(ds.dates)} + from anemoi.datasets import open_dataset - indices = [] - for date in dates: - idx = np.where(ds.dates == date)[0] - if len(idx) == 0: - continue - indices.append((int(idx[0]), date)) + ds = open_dataset(**kwargs) + # dates_to_index = {date: i for i, date in enumerate(ds.dates)} - vars = ds.variables - if params is None: - params = vars + indices = [] + for date in dates: + idx = np.where(ds.dates == date)[0] + if len(idx) == 0: + continue + indices.append((int(idx[0]), date)) - if not isinstance(params, (list, tuple, set)): - params = [params] + vars = ds.variables + if params is None: + params = vars - params = set(params) - results = [] + if not isinstance(params, (list, tuple, set)): + params = [params] - ensemble = ds.shape[2] > 1 - latitudes = ds.latitudes - longitudes = ds.longitudes + params = set(params) + results = [] - for idx, date in indices: + ensemble = ds.shape[2] > 1 + latitudes = ds.latitudes + longitudes = ds.longitudes - metadata = dict(valid_datetime=date, latitudes=latitudes, longitudes=longitudes) + for idx, date in indices: - for j, y in enumerate(ds[idx]): + metadata = dict(valid_datetime=date, latitudes=latitudes, longitudes=longitudes) - param = vars[j] - if param not in params: - continue + for j, y in enumerate(ds[idx]): + + param = vars[j] + if param not in params: + continue - # metadata['name'] = param - # metadata['param_level'] = param - metadata["param"] = param + # metadata['name'] = param + # metadata['param_level'] = param + metadata["param"] = param - for k, e in enumerate(y): - if ensemble: - metadata["number"] = k + 1 + for k, e in enumerate(y): + if ensemble: + metadata["number"] = k + 1 - metadata["values"] = e + metadata["values"] = e - results.append(metadata.copy()) + results.append(metadata.copy()) - print(results[0].keys()) + print(results[0].keys()) - # "list-of-dicts" does support resolution - results = ekd.from_source("list-of-dicts", results) + # "list-of-dicts" does support resolution + 
results = ekd.from_source("list-of-dicts", results) - # return new_fieldlist_from_list([new_field_from_latitudes_longitudes(x, latitudes, longitudes) for x in results]) - return results + # return new_fieldlist_from_list([new_field_from_latitudes_longitudes(x, latitudes, longitudes) for x in results]) + return results diff --git a/src/anemoi/datasets/build/gridded/sources/constants.py b/src/anemoi/datasets/build/gridded/sources/constants.py index 104f24863..a805c4b16 100644 --- a/src/anemoi/datasets/build/gridded/sources/constants.py +++ b/src/anemoi/datasets/build/gridded/sources/constants.py @@ -11,41 +11,42 @@ from earthkit.data import from_source -from .legacy import legacy_source - - -@legacy_source(__file__) -def constants(context: Any, dates: list[str], template: dict[str, Any], param: str) -> Any: - """Deprecated function to retrieve constants data. - - Parameters - ---------- - context : Any - The context object for tracing. - dates : list of str - List of dates for which data is required. - template : dict of str to Any - Template dictionary for the data source. - param : str - Parameter to retrieve. - - Returns - ------- - Any - Data retrieved from the source. - """ - from warnings import warn - - warn( - "The source `constants` is deprecated, use `forcings` instead.", - DeprecationWarning, - stacklevel=2, - ) - context.trace("✅", f"from_source(constants, {template}, {param}") - if len(template) == 0: - raise ValueError("Forcings template is empty.") - - return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) - - -execute: Any = constants +from . import source_registry +from .legacy import LegacySource + + +@source_registry.register("constants") +class ConstantsSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], template: dict[str, Any], param: str) -> Any: + """Deprecated function to retrieve constants data. + + Parameters + ---------- + context : Any + The context object for tracing. + dates : list of str + List of dates for which data is required. + template : dict of str to Any + Template dictionary for the data source. + param : str + Parameter to retrieve. + + Returns + ------- + Any + Data retrieved from the source. + """ + from warnings import warn + + warn( + "The source `constants` is deprecated, use `forcings` instead.", + DeprecationWarning, + stacklevel=2, + ) + context.trace("✅", f"from_source(constants, {template}, {param}") + if len(template) == 0: + raise ValueError("Forcings template is empty.") + + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) diff --git a/src/anemoi/datasets/build/gridded/sources/empty.py b/src/anemoi/datasets/build/gridded/sources/empty.py index fb7fcd906..fa8bc8d84 100644 --- a/src/anemoi/datasets/build/gridded/sources/empty.py +++ b/src/anemoi/datasets/build/gridded/sources/empty.py @@ -12,25 +12,29 @@ import earthkit.data as ekd -from .legacy import legacy_source - - -@legacy_source(__file__) -def execute(context: Any, dates: list[str], **kwargs: Any) -> ekd.FieldList: - """Executes the loading of an empty data source. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - Loaded empty data source. - """ - return ekd.from_source("empty") +from . 
import source_registry +from .legacy import LegacySource + + +@source_registry.register("empty") +class EmptySource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], **kwargs: Any) -> ekd.FieldList: + """Executes the loading of an empty data source. + + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + Loaded empty data source. + """ + return ekd.from_source("empty") diff --git a/src/anemoi/datasets/build/gridded/sources/forcings.py b/src/anemoi/datasets/build/gridded/sources/forcings.py index bbafaa465..6070772fc 100644 --- a/src/anemoi/datasets/build/gridded/sources/forcings.py +++ b/src/anemoi/datasets/build/gridded/sources/forcings.py @@ -11,31 +11,32 @@ from earthkit.data import from_source -from .legacy import legacy_source - - -@legacy_source(__file__) -def forcings(context: Any, dates: list[str], template: str, param: str) -> Any: - """Loads forcing data from a specified source. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - template : FieldList - Template for the data source. - param : str - Parameter for the data source. - - Returns - ------- - object - Loaded forcing data. - """ - context.trace("✅", f"from_source(forcings, {template}, {param}") - return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) - - -execute = forcings +from . import source_registry +from .legacy import LegacySource + + +@source_registry.register("forcings") +class ForcingsSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], template: str, param: str) -> Any: + """Loads forcing data from a specified source. + + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + template : FieldList + Template for the data source. + param : str + Parameter for the data source. + + Returns + ------- + object + Loaded forcing data. + """ + context.trace("✅", f"from_source(forcings, {template}, {param}") + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) diff --git a/src/anemoi/datasets/build/gridded/sources/grib.py b/src/anemoi/datasets/build/gridded/sources/grib.py index 03bcda475..d709efc5e 100644 --- a/src/anemoi/datasets/build/gridded/sources/grib.py +++ b/src/anemoi/datasets/build/gridded/sources/grib.py @@ -20,7 +20,8 @@ from earthkit.data import from_source from earthkit.data.utils.patterns import Pattern -from .legacy import legacy_source +from . import source_registry +from .legacy import LegacySource LOG = logging.getLogger(__name__) @@ -47,6 +48,14 @@ def check(ds: Any, paths: list[str], **kwargs: Any) -> None: if isinstance(v, (tuple, list)): count *= len(v) + # in the case of static data (e.g repeated dates) dates might be empty + if len(ds) != count and kwargs.get("dates", []) == []: + LOG.warning( + f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, paths={paths})" + f" Received empty dates - assuming this is static data." 
+ ) + return + if len(ds) != count: raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, paths={paths})") @@ -73,74 +82,85 @@ def _expand(paths: list[str]) -> Any: yield path -@legacy_source(__file__) -def execute( - context: Any, - dates: list[Any], - path: str | list[str], - flavour: str | dict[str, Any] | None = None, - grid_definition: dict[str, Any] | None = None, - *args: Any, - **kwargs: Any, -) -> ekd.FieldList: - """Executes the function to load data from GRIB files. - - Parameters - ---------- - context : Any - The context in which the function is executed. - dates : list of Any - List of dates. - path : str or list of str - Path or list of paths to the GRIB files. - flavour : str or dict of str to Any, optional - Flavour information, by default None. - grid_definition : dict of str to Any, optional - Grid definition configuration to create a Grid object, by default None. - *args : Any - Additional positional arguments. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - Any - The loaded dataset. - """ - given_paths = path if isinstance(path, list) else [path] - if flavour is not None: - flavour = RuleBasedFlavour(flavour) - - if grid_definition is not None: - grid = grid_registry.from_config(grid_definition) - else: - grid = None - - ds = from_source("empty") - dates = [d.isoformat() for d in dates] - - for path in given_paths: - paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs) - - for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): - if name in kwargs: - raise ValueError(f"MARS interpolation parameter '{name}' not supported") - - for path in _expand(paths): - context.trace("📁", "PATH", path) - s = from_source("file", path) - if flavour is not None: - s = flavour.map(s) - s = s.sel(valid_datetime=dates, **kwargs) - ds = ds + s - - if kwargs and not context.partial_ok: - check(ds, given_paths, valid_datetime=dates, **kwargs) - - if grid is not None: - ds = new_fieldlist_from_list([new_field_from_grid(f, grid) for f in ds]) - - if len(ds) == 0: - LOG.warning(f"No fields found for {dates} in {given_paths} (kwargs={kwargs})") - - return ds +@source_registry.register("grib") +class GribSource(LegacySource): + + @staticmethod + def _execute( + context: Any, + dates: list[Any], + path: str | list[str], + flavour: str | dict[str, Any] | None = None, + grid_definition: dict[str, Any] | None = None, + *args: Any, + **kwargs: Any, + ) -> ekd.FieldList: + """Executes the function to load data from GRIB files. + + Parameters + ---------- + context : Any + The context in which the function is executed. + dates : list of Any + List of dates. + path : str or list of str + Path or list of paths to the GRIB files. + flavour : str or dict of str to Any, optional + Flavour information, by default None. + grid_definition : dict of str to Any, optional + Grid definition configuration to create a Grid object, by default None. + *args : Any + Additional positional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + Any + The loaded dataset. 
+ """ + given_paths = path if isinstance(path, list) else [path] + if flavour is not None: + flavour = RuleBasedFlavour(flavour) + + if grid_definition is not None: + grid = grid_registry.from_config(grid_definition) + else: + grid = None + + ds = from_source("empty") + dates = [d.isoformat() for d in dates] + + for path in given_paths: + + # do not substitute if not needed + if "{" not in path: + paths = [path] + else: + paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs) + + for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): + if name in kwargs: + raise ValueError(f"MARS interpolation parameter '{name}' not supported") + + for path in _expand(paths): + context.trace("📁", "PATH", path) + s = from_source("file", path) + if flavour is not None: + s = flavour.map(s) + sel_kwargs = kwargs.copy() + if dates != []: + sel_kwargs["valid_datetime"] = dates + s = s.sel(**sel_kwargs) + ds = ds + s + + if kwargs and not context.partial_ok: + check(ds, given_paths, valid_datetime=dates, **kwargs) + + if grid is not None: + ds = new_fieldlist_from_list([new_field_from_grid(f, grid) for f in ds]) + + if len(ds) == 0: + LOG.warning(f"No fields found for {dates} in {given_paths} (kwargs={kwargs})") + + return ds diff --git a/src/anemoi/datasets/build/gridded/sources/grib_index.py b/src/anemoi/datasets/build/gridded/sources/grib_index.py index ea6878929..0d86732f6 100644 --- a/src/anemoi/datasets/build/gridded/sources/grib_index.py +++ b/src/anemoi/datasets/build/gridded/sources/grib_index.py @@ -19,7 +19,8 @@ from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray -from .legacy import legacy_source +from . import source_registry +from .legacy import LegacySource LOG = logging.getLogger(__name__) @@ -569,44 +570,47 @@ def retrieve(self, dates: list[Any], **kwargs: Any) -> Iterator[Any]: yield data -@legacy_source(__file__) -def execute( - context: Any, - dates: list[Any], - indexdb: str, - flavour: str | None = None, - **kwargs: Any, -) -> FieldArray: - """Execute the GRIB data retrieval process. - - Parameters - ---------- - context : Any - The execution context. - dates : List[Any] - List of dates to retrieve data for. - indexdb : str - Path to the GRIB index database. - flavour : Optional[str], optional - Flavour configuration for mapping fields, by default None. - **kwargs : Any - Additional filtering criteria. - - Returns - ------- - FieldArray - An array of retrieved GRIB fields. - """ - index = GribIndex(indexdb) - result = [] - - if flavour is not None: - flavour = RuleBasedFlavour(flavour) - - for grib in index.retrieve(dates, **kwargs): - field = ekd.from_source("memory", grib)[0] - if flavour: - field = flavour.apply(field) - result.append(field) - - return FieldArray(result) +@source_registry.register("grib_index") +class GribIndexSource(LegacySource): + + @staticmethod + def _execute( + context: Any, + dates: list[Any], + indexdb: str, + flavour: str | None = None, + **kwargs: Any, + ) -> FieldArray: + """Execute the GRIB data retrieval process. + + Parameters + ---------- + context : Any + The execution context. + dates : List[Any] + List of dates to retrieve data for. + indexdb : str + Path to the GRIB index database. + flavour : Optional[str], optional + Flavour configuration for mapping fields, by default None. + **kwargs : Any + Additional filtering criteria. + + Returns + ------- + FieldArray + An array of retrieved GRIB fields. 
+ """ + index = GribIndex(indexdb) + result = [] + + if flavour is not None: + flavour = RuleBasedFlavour(flavour) + + for grib in index.retrieve(dates, **kwargs): + field = ekd.from_source("memory", grib)[0] + if flavour: + field = flavour.apply(field) + result.append(field) + + return FieldArray(result) diff --git a/src/anemoi/datasets/build/gridded/sources/hindcasts.py b/src/anemoi/datasets/build/gridded/sources/hindcasts.py index 3a7f5eac8..a61a00d12 100644 --- a/src/anemoi/datasets/build/gridded/sources/hindcasts.py +++ b/src/anemoi/datasets/build/gridded/sources/hindcasts.py @@ -12,7 +12,9 @@ from earthkit.data.core.fieldlist import MultiFieldList -from .legacy import legacy_source +from anemoi.datasets.build.gridded.sources import source_registry + +from .legacy import LegacySource from .mars import mars LOGGER = logging.getLogger(__name__) @@ -36,57 +38,57 @@ def _to_list(x: list | tuple | Any) -> list[Any]: return [x] -@legacy_source(__file__) -def hindcasts(context: Any, dates: list[Any], **request: dict[str, Any]) -> MultiFieldList: - """Generates hindcast requests based on the provided dates and request parameters. - - Parameters - ---------- - context : Any - The context containing the dates provider and trace method. - dates : List[Any] - A list of dates for which to generate hindcast requests. - request : Dict[str, Any] - Additional request parameters. - - Returns - ------- - MultiFieldList - A MultiFieldList containing the hindcast data. - """ - from anemoi.datasets.dates import HindcastsDates - - provider = context.dates_provider - assert isinstance(provider, HindcastsDates) - - context.trace("H️", f"hindcasts {len(dates)=}") - - request["param"] = _to_list(request["param"]) - request["step"] = _to_list(request.get("step", 0)) - request["step"] = [int(_) for _ in request["step"]] - - context.trace("H️", f"hindcast {request}") - - requests = [] - for d in dates: - r = request.copy() - hindcast = provider.mapping[d] - r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d") - r["date"] = hindcast.refdate.strftime("%Y-%m-%d") - r["time"] = hindcast.refdate.strftime("%H") - r["step"] = hindcast.step - requests.append(r) - - if len(requests) == 0: - return MultiFieldList([]) - - return mars( - context, - dates, - *requests, - date_key="hdate", - request_already_using_valid_datetime=True, - ) - - -execute = hindcasts +@source_registry.register("hindcasts") +class HindcastsSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[Any], **request: dict[str, Any]) -> MultiFieldList: + """Generates hindcast requests based on the provided dates and request parameters. + + Parameters + ---------- + context : Any + The context containing the dates provider and trace method. + dates : List[Any] + A list of dates for which to generate hindcast requests. + request : Dict[str, Any] + Additional request parameters. + + Returns + ------- + MultiFieldList + A MultiFieldList containing the hindcast data. 
+ """ + from anemoi.datasets.dates import HindcastsDates + + provider = context.dates_provider + assert isinstance(provider, HindcastsDates) + + context.trace("H️", f"hindcasts {len(dates)=}") + + request["param"] = _to_list(request["param"]) + request["step"] = _to_list(request.get("step", 0)) + request["step"] = [int(_) for _ in request["step"]] + + context.trace("H️", f"hindcast {request}") + + requests = [] + for d in dates: + r = request.copy() + hindcast = provider.mapping[d] + r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d") + r["date"] = hindcast.refdate.strftime("%Y-%m-%d") + r["time"] = hindcast.refdate.strftime("%H") + r["step"] = hindcast.step + requests.append(r) + + if len(requests) == 0: + return MultiFieldList([]) + + return mars( + context, + dates, + *requests, + date_key="hdate", + request_already_using_valid_datetime=True, + ) diff --git a/src/anemoi/datasets/build/gridded/sources/legacy.py b/src/anemoi/datasets/build/gridded/sources/legacy.py index 4dbd481cd..f9a0288a0 100644 --- a/src/anemoi/datasets/build/gridded/sources/legacy.py +++ b/src/anemoi/datasets/build/gridded/sources/legacy.py @@ -8,14 +8,13 @@ # nor does it submit to any jurisdiction. -import inspect import logging -import os -from collections.abc import Callable +from abc import abstractmethod from typing import Any +from anemoi.datasets.create.input.context import Context + from ..source import Source -from . import source_registry LOG = logging.getLogger(__name__) @@ -25,7 +24,7 @@ class LegacySource(Source): Parameters ---------- - context : Any + context : Context The context in which the source is created. *args : tuple Positional arguments. @@ -33,64 +32,15 @@ class LegacySource(Source): Keyword arguments. """ - def __init__(self, context: Any, *args: Any, **kwargs: Any) -> None: + def __init__(self, context: Context, *args: Any, **kwargs: Any) -> None: super().__init__(context, *args, **kwargs) self.args = args self.kwargs = kwargs + @staticmethod + @abstractmethod + def _execute(context, *args, **kwargs): + pass -class legacy_source: - """A decorator class for legacy sources. - - Parameters - ---------- - name : str - The name of the legacy source. - """ - - def __init__(self, name: str) -> None: - name, _ = os.path.splitext(os.path.basename(name)) - self.name = name - - def __call__(self, execute: Callable) -> Callable: - """Call method to wrap the execute function. - - Parameters - ---------- - execute : function - The execute function to be wrapped. - - Returns - ------- - function - The wrapped execute function. 
- """ - this = self - name = f"Legacy{self.name.title()}Source" - source = ".".join([execute.__module__, execute.__name__]) - - def execute_wrapper(self, dates) -> Any: - """Wrapper method to call the execute function.""" - - args, kwargs = self.args, self.kwargs - - try: - return execute(self.context, dates, *args, **kwargs) - except TypeError: - LOG.error(f"Error executing source {this.name} from {source}") - LOG.error(f"Function signature is: {inspect.signature(execute)}") - LOG.error(f"Arguments are: {args=}, {kwargs=}") - raise - - klass = type( - name, - (LegacySource,), - { - "execute": execute_wrapper, - "_source": source, - }, - ) - - source_registry.register(self.name)(klass) - - return execute + def execute(self, dates: Any) -> Any: + return self._execute(self.context, dates, *self.args, **self.kwargs) diff --git a/src/anemoi/datasets/build/gridded/sources/mars.py b/src/anemoi/datasets/build/gridded/sources/mars.py index a92205ac1..a2804e77a 100644 --- a/src/anemoi/datasets/build/gridded/sources/mars.py +++ b/src/anemoi/datasets/build/gridded/sources/mars.py @@ -16,9 +16,9 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability -from anemoi.datasets.build.gridded.utils import to_datetime_list +from anemoi.datasets.build.gridded.sources import source_registry -from .legacy import legacy_source +from .legacy import LegacySource DEBUG = False @@ -358,135 +358,111 @@ def use_grib_paramid(r: dict[str, Any]) -> dict[str, Any]: ] -@legacy_source(__file__) -def mars( - context: Any, - dates: list[datetime.datetime], - *requests: dict[str, Any], - request_already_using_valid_datetime: bool = False, - date_key: str = "date", - use_cdsapi_dataset: str | None = None, - **kwargs: Any, -) -> Any: - """Executes MARS requests based on the given context, dates, and other parameters. - - Parameters - ---------- - context : Any - The context for the requests. - dates : List[datetime.datetime] - The list of dates to be used in the requests. - requests : Dict[str, Any] - The input requests to be executed. - request_already_using_valid_datetime : bool, optional - Flag indicating if the requests already use valid datetime. - date_key : str, optional - The key for the date in the requests. - use_cdsapi_dataset : Optional[str], optional - The dataset to be used with CDS API. - kwargs : Any - Additional keyword arguments for the requests. - - Returns - ------- - Any - The resulting dataset. - """ - - if not requests: - requests = [kwargs] - - for r in requests: - param = r.get("param", []) - if not isinstance(param, (list, tuple)): - param = [param] - # check for "Norway bug" where yaml transforms 'no' into False, etc. - for p in param: - if p is False: - raise ValueError( - "'param' cannot be 'False'. If you wrote 'param: no' or 'param: off' in yaml, you may want to use quotes?" - ) - if p is None: - raise ValueError( - "'param' cannot be 'None'. If you wrote 'param: no' in yaml, you may want to use quotes?" - ) - if p is True: - raise ValueError( - "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?" 
- ) - - if len(dates) == 0: # When using `repeated_dates` - assert len(requests) == 1, requests - assert "date" in requests[0], requests[0] - if isinstance(requests[0]["date"], datetime.date): - requests[0]["date"] = requests[0]["date"].strftime("%Y%m%d") - else: - requests = factorise_requests( - dates, - *requests, - request_already_using_valid_datetime=request_already_using_valid_datetime, - date_key=date_key, - ) - - requests = list(requests) - - ds = from_source("empty") - context.trace("✅", f"{[str(d) for d in dates]}") - context.trace("✅", f"Will run {len(requests)} requests") - for r in requests: - r = {k: v for k, v in r.items() if v != ("-",)} - context.trace("✅", f"mars {r}") - - for r in requests: - r = {k: v for k, v in r.items() if v != ("-",)} - - if context.use_grib_paramid and "param" in r: - r = use_grib_paramid(r) - - for k, v in r.items(): - if k not in MARS_KEYS: - raise ValueError( - f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?" - ) - try: - if use_cdsapi_dataset: - ds = ds + from_source("cds", use_cdsapi_dataset, r) - else: - ds = ds + from_source("mars", **r) - except Exception as e: - if "File is empty:" not in str(e): - raise - return ds - - -execute = mars - - -if __name__ == "__main__": - import yaml - - config = yaml.safe_load( +@source_registry.register("mars") +class MarsSource(LegacySource): + + @staticmethod + def _execute( + context: Any, + dates: list[datetime.datetime], + *requests: dict[str, Any], + request_already_using_valid_datetime: bool = False, + date_key: str = "date", + use_cdsapi_dataset: str | None = None, + **kwargs: Any, + ) -> Any: + """Executes MARS requests based on the given context, dates, and other parameters. + + Parameters + ---------- + context : Any + The context for the requests. + dates : List[datetime.datetime] + The list of dates to be used in the requests. + requests : Dict[str, Any] + The input requests to be executed. + request_already_using_valid_datetime : bool, optional + Flag indicating if the requests already use valid datetime. + date_key : str, optional + The key for the date in the requests. + use_cdsapi_dataset : Optional[str], optional + The dataset to be used with CDS API. + kwargs : Any + Additional keyword arguments for the requests. + + Returns + ------- + Any + The resulting dataset. """ - - class: ea - expver: '0001' - grid: 20.0/20.0 - levtype: sfc - param: [2t] - # param: [10u, 10v, 2d, 2t, lsm, msl, sdor, skt, slor, sp, tcw, z] - number: [0, 1] - - # - class: ea - # expver: '0001' - # grid: 20.0/20.0 - # levtype: pl - # param: [q] - # levelist: [1000, 850] - """ - ) - dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) + if not requests: + requests = [kwargs] + + for r in requests: + param = r.get("param", []) + if not isinstance(param, (list, tuple)): + param = [param] + # check for "Norway bug" where yaml transforms 'no' into False, etc. + for p in param: + if p is False: + raise ValueError( + "'param' cannot be 'False'. If you wrote 'param: no' or 'param: off' in yaml, you may want to use quotes?" + ) + if p is None: + raise ValueError( + "'param' cannot be 'None'. If you wrote 'param: no' in yaml, you may want to use quotes?" + ) + if p is True: + raise ValueError( + "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?" 
+ ) + + if len(dates) == 0: # When using `repeated_dates` + assert len(requests) == 1, requests + assert "date" in requests[0], requests[0] + if isinstance(requests[0]["date"], datetime.date): + requests[0]["date"] = requests[0]["date"].strftime("%Y%m%d") + else: + requests = factorise_requests( + dates, + *requests, + request_already_using_valid_datetime=request_already_using_valid_datetime, + date_key=date_key, + ) - DEBUG = True - for f in mars(None, dates, *config): - print(f, f.to_numpy().mean()) + requests = list(requests) + + ds = from_source("empty") + context.trace("✅", f"{[str(d) for d in dates]}") + context.trace("✅", f"Will run {len(requests)} requests") + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + context.trace("✅", f"mars {r}") + + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + + if context.use_grib_paramid and "param" in r: + r = use_grib_paramid(r) + + for k, v in r.items(): + if k not in MARS_KEYS: + raise ValueError( + f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?" + ) + try: + if use_cdsapi_dataset: + ds = ds + from_source("cds", use_cdsapi_dataset, r) + else: + ds = ds + from_source("mars", **r) + except Exception as e: + if "File is empty:" not in str(e): + raise + return ds + + +# TODO: make clearer the interface between sources that use mars. +# Currently some sources use mars as a function rather than through the registry, +# e.g. accumulations, accumulations2, hindcasts, recentre, tendencies +mars = MarsSource._execute diff --git a/src/anemoi/datasets/build/gridded/sources/netcdf.py b/src/anemoi/datasets/build/gridded/sources/netcdf.py index a73c095d3..e6f4271a7 100644 --- a/src/anemoi/datasets/build/gridded/sources/netcdf.py +++ b/src/anemoi/datasets/build/gridded/sources/netcdf.py @@ -12,30 +12,34 @@ import earthkit.data as ekd -from .legacy import legacy_source +from . import source_registry +from .legacy import LegacySource from .xarray import load_many -@legacy_source(__file__) -def execute(context: Any, dates: list[str], path: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the loading of multiple NetCDF files. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - path : str - Path to the directory containing the NetCDF files. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - object - The loaded data. - """ - return load_many("📁", context, dates, path, *args, **kwargs) +@source_registry.register("netcdf") +class NetCDFSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], path: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the loading of multiple NetCDF files. + + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + path : str + Path to the directory containing the NetCDF files. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + object + The loaded data. 
+ """ + return load_many("📁", context, dates, path, *args, **kwargs) diff --git a/src/anemoi/datasets/build/gridded/sources/opendap.py b/src/anemoi/datasets/build/gridded/sources/opendap.py index 483295a8b..86cd3e6d2 100644 --- a/src/anemoi/datasets/build/gridded/sources/opendap.py +++ b/src/anemoi/datasets/build/gridded/sources/opendap.py @@ -12,30 +12,34 @@ import earthkit.data as ekd -from .legacy import legacy_source +from . import source_registry +from .legacy import LegacySource from .xarray import load_many -@legacy_source(__file__) -def execute(context: dict[str, Any], dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the data loading process from an OpenDAP source. - - Parameters - ---------- - context : dict - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - url : str - The URL of the OpenDAP source. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - xarray.Dataset - The loaded dataset. - """ - return load_many("🌐", context, dates, url, *args, **kwargs) +@source_registry.register("opendap") +class OpenDAPSource(LegacySource): + + @staticmethod + def _execute(context: dict[str, Any], dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the data loading process from an OpenDAP source. + + Parameters + ---------- + context : dict + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + url : str + The URL of the OpenDAP source. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + xarray.Dataset + The loaded dataset. + """ + return load_many("🌐", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/build/gridded/sources/recentre.py b/src/anemoi/datasets/build/gridded/sources/recentre.py index 53ace8152..2d6c70b1d 100644 --- a/src/anemoi/datasets/build/gridded/sources/recentre.py +++ b/src/anemoi/datasets/build/gridded/sources/recentre.py @@ -12,7 +12,8 @@ from anemoi.datasets.compute.recentre import recentre as _recentre -from .legacy import legacy_source +from . import source_registry +from .legacy import LegacySource from .mars import mars @@ -105,43 +106,43 @@ def load_if_needed(context: Any, dates: Any, dict_or_dataset: dict | Any) -> Any return dict_or_dataset -@legacy_source(__file__) -def recentre( - context: Any, - dates: Any, - members: dict | Any, - centre: dict | Any, - alpha: float = 1.0, - remapping: dict = {}, - patches: dict = {}, -) -> Any: - """Recentres the members dataset using the centre dataset. - - Parameters - ---------- - context : Any - The context for recentering. - dates : Any - The dates for recentering. - members : Union[dict, Any] - The members dataset or request dictionary. - centre : Union[dict, Any] - The centre dataset or request dictionary. - alpha : float, optional - The alpha value for recentering. Defaults to 1.0. - remapping : dict, optional - The remapping dictionary. Defaults to {}. - patches : dict, optional - The patches dictionary. Defaults to {}. - - Returns - ------- - Any - The recentred dataset. 
- """ - members = load_if_needed(context, dates, members) - centre = load_if_needed(context, dates, centre) - return _recentre(members=members, centre=centre, alpha=alpha) - - -execute = recentre +@source_registry.register("recentre") +class RecentreSource(LegacySource): + + @staticmethod + def _execute( + context: Any, + dates: Any, + members: dict | Any, + centre: dict | Any, + alpha: float = 1.0, + remapping: dict = {}, + patches: dict = {}, + ) -> Any: + """Recentres the members dataset using the centre dataset. + + Parameters + ---------- + context : Any + The context for recentering. + dates : Any + The dates for recentering. + members : Union[dict, Any] + The members dataset or request dictionary. + centre : Union[dict, Any] + The centre dataset or request dictionary. + alpha : float, optional + The alpha value for recentering. Defaults to 1.0. + remapping : dict, optional + The remapping dictionary. Defaults to {}. + patches : dict, optional + The patches dictionary. Defaults to {}. + + Returns + ------- + Any + The recentred dataset. + """ + members = load_if_needed(context, dates, members) + centre = load_if_needed(context, dates, centre) + return _recentre(members=members, centre=centre, alpha=alpha) diff --git a/src/anemoi/datasets/build/gridded/sources/repeated_dates.py b/src/anemoi/datasets/build/gridded/sources/repeated_dates.py index d337cead8..2714a0d10 100644 --- a/src/anemoi/datasets/build/gridded/sources/repeated_dates.py +++ b/src/anemoi/datasets/build/gridded/sources/repeated_dates.py @@ -19,8 +19,8 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta -from .source import Source -from .sources import source_registry +from ..source import Source +from ..sources import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/sources/source.py b/src/anemoi/datasets/build/gridded/sources/source.py index 3338daf02..1ad5850a7 100644 --- a/src/anemoi/datasets/build/gridded/sources/source.py +++ b/src/anemoi/datasets/build/gridded/sources/source.py @@ -12,58 +12,36 @@ from earthkit.data import from_source -from anemoi.datasets.build.utils import to_datetime_list +from anemoi.datasets.build.gridded.sources import source_registry -from .legacy import legacy_source +from .legacy import LegacySource -@legacy_source(__file__) -def source(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any: - """Generates a source based on the provided context, dates, and additional keyword arguments. +@source_registry.register("source") +class GenericSource(LegacySource): - Parameters - ---------- - context : Optional[Any] - The context in which the source is generated. - dates : List[datetime] - A list of datetime objects representing the dates. - **kwargs : Any - Additional keyword arguments for the source generation. + @staticmethod + def _execute(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any: + """Generates a source based on the provided context, dates, and additional keyword arguments. - Returns - ------- - Any - The generated source. - """ - name = kwargs.pop("name") - context.trace("✅", f"from_source({name}, {dates}, {kwargs}") - if kwargs["date"] == "$from_dates": - kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates}) - if kwargs["time"] == "$from_dates": - kwargs["time"] = list({d.strftime("%H%M") for d in dates}) - return from_source(name, **kwargs) + Parameters + ---------- + context : Optional[Any] + The context in which the source is generated. 
+ dates : List[datetime] + A list of datetime objects representing the dates. + **kwargs : Any + Additional keyword arguments for the source generation. - -execute = source - -if __name__ == "__main__": - import yaml - - config: dict[str, Any] = yaml.safe_load( + Returns + ------- + Any + The generated source. """ - name: mars - class: ea - expver: '0001' - grid: 20.0/20.0 - levtype: sfc - param: [2t] - number: [0, 1] - date: $from_dates - time: $from_dates - """ - ) - dates: list[str] = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) - - for f in source(None, dates, **config): - print(f, f.to_numpy().mean()) + name = kwargs.pop("name") + context.trace("✅", f"from_source({name}, {dates}, {kwargs}") + if kwargs["date"] == "$from_dates": + kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates}) + if kwargs["time"] == "$from_dates": + kwargs["time"] = list({d.strftime("%H%M") for d in dates}) + return from_source(name, **kwargs) diff --git a/src/anemoi/datasets/build/gridded/sources/tendencies.py b/src/anemoi/datasets/build/gridded/sources/tendencies.py index 2f357b008..69c06a78c 100644 --- a/src/anemoi/datasets/build/gridded/sources/tendencies.py +++ b/src/anemoi/datasets/build/gridded/sources/tendencies.py @@ -14,9 +14,9 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.build.utils import to_datetime_list +from anemoi.datasets.build.gridded.sources import source_registry -from .legacy import legacy_source +from .legacy import LegacySource def _date_to_datetime(d: Any) -> Any: @@ -83,116 +83,89 @@ def group_by_field(ds: Any) -> dict[tuple, list[Any]]: return d -@legacy_source(__file__) -def tendencies(dates: list[datetime.datetime], time_increment: Any, **kwargs: Any) -> Any: - """Computes tendencies for the given dates and time increment. +@source_registry.register("tendencies") +class TendenciesSource(LegacySource): - Parameters - ---------- - dates : List[datetime.datetime] - A list of datetime objects. - time_increment : Any - A time increment string ending with 'h' or a datetime.timedelta object. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - Any - A dataset object with computed tendencies. - """ - print("✅", kwargs) - time_increment = normalise_time_delta(time_increment) - - shifted_dates = [d - time_increment for d in dates] - all_dates = sorted(list(set(dates + shifted_dates))) + @staticmethod + def _execute(dates: list[datetime.datetime], time_increment: Any, **kwargs: Any) -> Any: + """Computes tendencies for the given dates and time increment. - # from .mars import execute as mars - from anemoi.datasets.build.mars import execute as mars + Parameters + ---------- + dates : List[datetime.datetime] + A list of datetime objects. + time_increment : Any + A time increment string ending with 'h' or a datetime.timedelta object. + **kwargs : Any + Additional keyword arguments. - ds = mars(dates=all_dates, **kwargs) - - dates_in_data = ds.unique_values("valid_datetime", progress_bar=False)["valid_datetime"] - for d in all_dates: - assert d.isoformat() in dates_in_data, d - - ds1 = ds.sel(valid_datetime=[d.isoformat() for d in dates]) - ds2 = ds.sel(valid_datetime=[d.isoformat() for d in shifted_dates]) - - assert len(ds1) == len(ds2), (len(ds1), len(ds2)) - - group1 = group_by_field(ds1) - group2 = group_by_field(ds2) + Returns + ------- + Any + A dataset object with computed tendencies. 
+ """ + print("✅", kwargs) + time_increment = normalise_time_delta(time_increment) - assert group1.keys() == group2.keys(), (group1.keys(), group2.keys()) + shifted_dates = [d - time_increment for d in dates] + all_dates = sorted(list(set(dates + shifted_dates))) - # prepare output tmp file so we can read it back - tmp = temp_file() - path = tmp.path - out = new_grib_output(path) + from .mars import mars - for k in group1: - assert len(group1[k]) == len(group2[k]), k - print() - print("❌", k) + ds = mars(dates=all_dates, **kwargs) - for field, b_field in zip(group1[k], group2[k]): - for k in ["param", "level", "number", "grid", "shape"]: - assert field.metadata(k) == b_field.metadata(k), ( - k, - field.metadata(k), - b_field.metadata(k), - ) + dates_in_data = ds.unique_values("valid_datetime", progress_bar=False)["valid_datetime"] + for d in all_dates: + assert d.isoformat() in dates_in_data, d - c = field.to_numpy() - b = b_field.to_numpy() - assert c.shape == b.shape, (c.shape, b.shape) + ds1 = ds.sel(valid_datetime=[d.isoformat() for d in dates]) + ds2 = ds.sel(valid_datetime=[d.isoformat() for d in shifted_dates]) - ################ - # Actual computation happens here - x = c - b - ################ + assert len(ds1) == len(ds2), (len(ds1), len(ds2)) - assert x.shape == c.shape, c.shape - print(f"Computing data for {field.metadata('valid_datetime')}={field}-{b_field}") - out.write(x, template=field) + group1 = group_by_field(ds1) + group2 = group_by_field(ds2) - out.close() + assert group1.keys() == group2.keys(), (group1.keys(), group2.keys()) - from earthkit.data import from_source + # prepare output tmp file so we can read it back + tmp = temp_file() + path = tmp.path + out = new_grib_output(path) - ds = from_source("file", path) - # save a reference to the tmp file so it is deleted - # only when the dataset is not used anymore - ds._tmp = tmp + for k in group1: + assert len(group1[k]) == len(group2[k]), k + print() + print("❌", k) - return ds + for field, b_field in zip(group1[k], group2[k]): + for k in ["param", "level", "number", "grid", "shape"]: + assert field.metadata(k) == b_field.metadata(k), ( + k, + field.metadata(k), + b_field.metadata(k), + ) + c = field.to_numpy() + b = b_field.to_numpy() + assert c.shape == b.shape, (c.shape, b.shape) -execute = tendencies + ################ + # Actual computation happens here + x = c - b + ################ -if __name__ == "__main__": - import yaml + assert x.shape == c.shape, c.shape + print(f"Computing data for {field.metadata('valid_datetime')}={field}-{b_field}") + out.write(x, template=field) - config = yaml.safe_load( - """ + out.close() - config: - time_increment: 12h - database: marser - class: ea - # date: computed automatically - # time: computed automatically - expver: "0001" - grid: 20.0/20.0 - levtype: sfc - param: [2t] - """ - )["config"] + from earthkit.data import from_source - dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) + ds = from_source("file", path) + # save a reference to the tmp file so it is deleted + # only when the dataset is not used anymore + ds._tmp = tmp - DEBUG = True - for f in tendencies(dates, **config): - print(f, f.to_numpy().mean()) + return ds diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py index 4bfd4f76b..e0f4a7e75 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py +++ 
b/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py @@ -17,7 +17,8 @@ from anemoi.datasets.build.gridded.sources.patterns import iterate_patterns -from ..legacy import legacy_source +from .. import source_registry +from ..legacy import LegacySource from .fieldlist import XarrayFieldList LOG = logging.getLogger(__name__) @@ -152,26 +153,30 @@ def load_many(emoji: str, context: Any, dates: list[datetime.datetime], pattern: return MultiFieldList(result) -@legacy_source("xarray") -def execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Executes the loading of datasets. - - Parameters - ---------- - context : Any - Context object. - dates : List[str] - List of dates. - url : str - URL pattern for loading datasets. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - The loaded datasets. - """ - return load_many("🌐", context, dates, url, *args, **kwargs) +@source_registry.register("xarray") +class LegacyXarraySource(LegacySource): + name = "xarray" + + @staticmethod + def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Executes the loading of datasets. + + Parameters + ---------- + context : Any + Context object. + dates : List[str] + List of dates. + url : str + URL pattern for loading datasets. + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + The loaded datasets. + """ + return load_many("🌐", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py b/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py index e91de781e..2e89981bd 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py +++ b/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py @@ -11,30 +11,34 @@ import earthkit.data as ekd -from .legacy import legacy_source +from . import source_registry +from .legacy import LegacySource from .xarray import load_many -@legacy_source(__file__) -def execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the data loading process. - - Parameters - ---------- - context : Any - The context in which the execution occurs. - dates : List[str] - List of dates for which data is to be loaded. - url : str - The URL from which data is to be loaded. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - The loaded data. - """ - return load_many("🇿", context, dates, url, *args, **kwargs) +@source_registry.register("xarray_zarr") +class XarrayZarrSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the data loading process. + + Parameters + ---------- + context : Any + The context in which the execution occurs. + dates : List[str] + List of dates for which data is to be loaded. + url : str + The URL from which data is to be loaded. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + The loaded data. 
+ """ + return load_many("🇿", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/build/gridded/sources/zenodo.py b/src/anemoi/datasets/build/gridded/sources/zenodo.py index 1b746bb42..9f4d68f97 100644 --- a/src/anemoi/datasets/build/gridded/sources/zenodo.py +++ b/src/anemoi/datasets/build/gridded/sources/zenodo.py @@ -14,54 +14,58 @@ from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.sources.url import download_and_cache -from .legacy import legacy_source +from . import source_registry +from .legacy import LegacySource from .patterns import iterate_patterns from .xarray import load_one -@legacy_source(__file__) -def execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Executes the download and processing of files from Zenodo. +@source_registry.register("zenodo") +class ZenodoSource(LegacySource): - Parameters - ---------- - context : Any - The context in which the function is executed. - dates : Any - The dates for which the data is required. - record_id : str - The Zenodo record ID. - file_key : str - The key to identify the file. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. + @staticmethod + def _execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Executes the download and processing of files from Zenodo. - Returns - ------- - MultiFieldList - A list of fields loaded from the downloaded files. - """ - import requests + Parameters + ---------- + context : Any + The context in which the function is executed. + dates : Any + The dates for which the data is required. + record_id : str + The Zenodo record ID. + file_key : str + The key to identify the file. + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. - result: list[Any] = [] + Returns + ------- + MultiFieldList + A list of fields loaded from the downloaded files. 
+ """ + import requests - URLPATTERN = "https://zenodo.org/api/records/{record_id}" - url = URLPATTERN.format(record_id=record_id) - r = requests.get(url) - r.raise_for_status() - record: dict[str, Any] = r.json() + result: list[Any] = [] - urls: dict[str, str] = {} - for file in record["files"]: - urls[file["key"]] = file["links"]["self"] + URLPATTERN = "https://zenodo.org/api/records/{record_id}" + url = URLPATTERN.format(record_id=record_id) + r = requests.get(url) + r.raise_for_status() + record: dict[str, Any] = r.json() - for url, dates in iterate_patterns(file_key, dates, **kwargs): - if url not in urls: - continue + urls: dict[str, str] = {} + for file in record["files"]: + urls[file["key"]] = file["links"]["self"] - path = download_and_cache(urls[url]) - result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs)) + for url, dates in iterate_patterns(file_key, dates, **kwargs): + if url not in urls: + continue - return MultiFieldList(result) + path = download_and_cache(urls[url]) + result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs)) + + return MultiFieldList(result) diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index 820d73635..212987839 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -90,7 +90,7 @@ def _check_name(self, name: str) -> None: def _check_zarr(self, zarr: str) -> None: - from anemoi.datasets.check import check_zarr + from anemoi.datasets.misc.check import check_zarr check_zarr(zarr) diff --git a/src/anemoi/datasets/use/tabular/observations/__init__.py b/src/anemoi/datasets/use/tabular/observations/__init__.py index 7d1c278f9..004d9299c 100644 --- a/src/anemoi/datasets/use/tabular/observations/__init__.py +++ b/src/anemoi/datasets/use/tabular/observations/__init__.py @@ -176,7 +176,7 @@ def __init__(self, dataset, frequency=None, window=None): # last_window_end must be the end of the time window of the last item last_window_end = int(end.strftime("%Y%m%d%H%M%S")) - from anemoi.datasets.use.gridded.observations.legacy_obs_dataset import ObsDataset + from anemoi.datasets.use.tabular.observations.legacy_obs_dataset import ObsDataset args = [self.path, first_window_begin, last_window_end] kwargs = dict( diff --git a/tools/build-obs.py b/tools/build-obs.py index 5013763cb..d29339cda 100755 --- a/tools/build-obs.py +++ b/tools/build-obs.py @@ -28,7 +28,7 @@ def build(input, output, backend, overwrite=False): print(f"Dataset has {len(ds)} records, from {ds.start_date} to {ds.end_date}") print(f"Converting dataset to {output} using new backend '{backend}'") - from anemoi.datasets.use.gridded.records.backends import writer_backend_factory + from anemoi.datasets.use.tabular.records.backends import writer_backend_factory if os.path.exists(output): if overwrite: From 69e7daa60712d53dfbaa81492b872190cb9c9d96 Mon Sep 17 00:00:00 2001 From: Aaron Hopkinson Date: Tue, 21 Oct 2025 13:53:40 +0100 Subject: [PATCH 166/212] Update broken imports --- src/anemoi/datasets/build/gridded/sources/accumulations.py | 4 ++-- src/anemoi/datasets/build/gridded/sources/legacy.py | 2 +- src/anemoi/datasets/build/gridded/sources/repeated_dates.py | 2 +- tools/grids/grids_multilam.ipynb | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations.py b/src/anemoi/datasets/build/gridded/sources/accumulations.py index 2d0124a81..86adea4d1 100644 --- 
a/src/anemoi/datasets/build/gridded/sources/accumulations.py +++ b/src/anemoi/datasets/build/gridded/sources/accumulations.py @@ -998,12 +998,12 @@ def _execute( and request.get("stream", "oper") == "oper" and request.get("accumulation_period") == 24 ): - from .accumulations2 import accumulations as accumulations2 + from .accumulations2 import Accumulations2Source LOG.warning( "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" ) - return accumulations2(context, dates, **request) + return Accumulations2Source._execute(context, dates, **request) _to_list(request["param"]) class_ = request.get("class", "od") diff --git a/src/anemoi/datasets/build/gridded/sources/legacy.py b/src/anemoi/datasets/build/gridded/sources/legacy.py index f9a0288a0..d4110cf5b 100644 --- a/src/anemoi/datasets/build/gridded/sources/legacy.py +++ b/src/anemoi/datasets/build/gridded/sources/legacy.py @@ -12,7 +12,7 @@ from abc import abstractmethod from typing import Any -from anemoi.datasets.create.input.context import Context +from anemoi.datasets.build.input.context import Context from ..source import Source diff --git a/src/anemoi/datasets/build/gridded/sources/repeated_dates.py b/src/anemoi/datasets/build/gridded/sources/repeated_dates.py index 9b297e193..509ee4966 100644 --- a/src/anemoi/datasets/build/gridded/sources/repeated_dates.py +++ b/src/anemoi/datasets/build/gridded/sources/repeated_dates.py @@ -14,7 +14,7 @@ from anemoi.transform.fields import new_field_with_valid_datetime from anemoi.transform.fields import new_fieldlist_from_list -from anemoi.datasets.create.input.repeated_dates import DateMapper +from anemoi.datasets.build.input.repeated_dates import DateMapper from ..source import Source from ..sources import source_registry diff --git a/tools/grids/grids_multilam.ipynb b/tools/grids/grids_multilam.ipynb index bb212bc4a..f6b6f5355 100644 --- a/tools/grids/grids_multilam.ipynb +++ b/tools/grids/grids_multilam.ipynb @@ -8,7 +8,7 @@ "source": [ "import numpy as np\n", "from anemoi.datasets import open_dataset\n", - "from anemoi.datasets.data.grids import Cutout" + "from anemoi.datasets.use.gridded.grids import Cutout" ] }, { From b478d28162b706e823318c8510831388031e0cbe Mon Sep 17 00:00:00 2001 From: Aaron Hopkinson Date: Tue, 21 Oct 2025 17:48:12 +0100 Subject: [PATCH 167/212] Fix imports --- src/anemoi/datasets/use/gridded/__init__.py | 2 +- src/anemoi/datasets/use/gridded/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/use/gridded/__init__.py b/src/anemoi/datasets/use/gridded/__init__.py index 9caa9e053..dbbfcd9a5 100644 --- a/src/anemoi/datasets/use/gridded/__init__.py +++ b/src/anemoi/datasets/use/gridded/__init__.py @@ -95,7 +95,7 @@ def open_dataset(*args: Any, **kwargs: Any) -> "Dataset": ds._check() if trace: - from anemoi.datasets.misc import Trace + from anemoi.datasets.misc.testing import Trace ds = Trace(ds) diff --git a/src/anemoi/datasets/use/gridded/dataset.py b/src/anemoi/datasets/use/gridded/dataset.py index 5a78c7b5b..d52a2753d 100644 --- a/src/anemoi/datasets/use/gridded/dataset.py +++ b/src/anemoi/datasets/use/gridded/dataset.py @@ -294,7 +294,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate() if "rolling_average" in kwargs: - from .rolling_average import RollingAverage + from anemoi.datasets.use.gridded.rolling_average import RollingAverage rolling_average = kwargs.pop("rolling_average") return 
RollingAverage(self, rolling_average)._subset(**kwargs).mutate() From c02ecafded36026317fca23aefb400cd834ec311 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 22 Oct 2025 09:16:28 +0000 Subject: [PATCH 168/212] fix s3 access --- src/anemoi/datasets/data/stores.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index bf1e74cad..4f39f923b 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -101,7 +101,7 @@ def __getitem__(self, key: str) -> bytes: target = self.url + "/" + key try: - return get_object(target).bytes() + return get_object(target) except FileNotFoundError: raise KeyError(target) From 0f56af3636952e49734a066fc360471fb27b17b0 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 23 Oct 2025 10:08:30 +0200 Subject: [PATCH 169/212] fix frequency for observations --- src/anemoi/datasets/data/records/__init__.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 60d1d23cc..eae3a8390 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -268,10 +268,20 @@ def _window(self): def __len__(self): return len(self.dataset) * self._n + @property + def dates(self): + dates = [] + for date in self.dataset.dates: + dates += [date + i * self._frequency for i in range(self._n)] + return dates + @property def frequency(self): return self._frequency + def metadata(self): + return self.dataset.metadata + def _load_data(self, i): j = i // self._n k = i % self._n @@ -284,8 +294,8 @@ def _load_data(self, i): if timedeltas.dtype != "timedelta64[s]": raise ValueError(f"Wrong type for {group}") - start_delta = k * self.frequency + self._window.start - end_delta = k * self.frequency + self._window.end + start_delta = self.dataset._window.start + k * self.frequency + end_delta = start_delta + self._window.end - self._window.start def _to_numpy_timedelta(td): if isinstance(td, np.timedelta64): @@ -295,8 +305,7 @@ def _to_numpy_timedelta(td): start_delta = _to_numpy_timedelta(start_delta) end_delta = _to_numpy_timedelta(end_delta) - assert isinstance(start_delta, np.timedelta64), (type(start_delta), start_delta) - assert isinstance(timedeltas[0], np.timedelta64), type(timedeltas[0]) + assert timedeltas.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {timedeltas.dtype}" if self._window.include_start: mask = timedeltas >= start_delta From 8752edb610e4494d02b31585ddce6d83b31f445c Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 10 Nov 2025 14:58:03 +0000 Subject: [PATCH 170/212] fix bug --- src/anemoi/datasets/data/records/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index eae3a8390..08aead42f 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -144,10 +144,13 @@ def _subset(self, **kwargs): frequency = kwargs.pop("frequency", self.frequency) if frequency: + frequency = frequency_to_timedelta(frequency) - if self.frequency.total_seconds() % frequency.total_seconds() == 0: + current = self.frequency.total_seconds() + new = frequency.total_seconds() + if current != new and current % new == 0: return IncreaseFrequency(self, frequency) - elif frequency.total_seconds() % 
self.frequency.total_seconds() == 0: + elif current != new and new % current == 0: raise NotImplementedError("Decreasing frequency not implemented yet") # return DecreaseFrequency(self, frequency) assert self.frequency == frequency, (self.frequency, frequency) From 6f1a48c19440d3e75b48a0b3e024d2a5ceddebdb Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 10 Nov 2025 15:05:04 +0000 Subject: [PATCH 171/212] fix bug --- src/anemoi/datasets/data/records/__init__.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 08aead42f..8e687f5da 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -139,22 +139,20 @@ def groups(self): raise NotImplementedError("Must be implemented in subclass") def _subset(self, **kwargs): - start = kwargs.pop("start", None) - end = kwargs.pop("end", None) frequency = kwargs.pop("frequency", self.frequency) - if frequency: - frequency = frequency_to_timedelta(frequency) current = self.frequency.total_seconds() new = frequency.total_seconds() if current != new and current % new == 0: - return IncreaseFrequency(self, frequency) + return IncreaseFrequency(self, frequency)._subset(**kwargs) elif current != new and new % current == 0: raise NotImplementedError("Decreasing frequency not implemented yet") - # return DecreaseFrequency(self, frequency) + # return DecreaseFrequency(self, frequency)._subset(**kwargs) assert self.frequency == frequency, (self.frequency, frequency) + start = kwargs.pop("start", None) + end = kwargs.pop("end", None) if start is not None or end is not None: def _dates_to_indices(start, end): From 9d8b93f6313f7faa3922bcf4af22dd6d7fb5e4fc Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 10 Nov 2025 15:09:56 +0000 Subject: [PATCH 172/212] fix bug --- src/anemoi/datasets/data/records/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index 8e687f5da..a96a70812 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -685,7 +685,8 @@ def __init__(self, dataset, indices, reason): @cached_property def dates(self): - return self.dataset.dates[self._indices] + dates = self.dataset.dates + return [dates[i] for i in self._indices] def _load_data(self, i): return self.dataset._load_data(self._indices[i]) From e154dca80798f4020598ef6366eeb18cfd51902b Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 10 Nov 2025 21:05:49 +0000 Subject: [PATCH 173/212] more consistent dates --- src/anemoi/datasets/data/records/__init__.py | 71 +++++++++++++------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/src/anemoi/datasets/data/records/__init__.py b/src/anemoi/datasets/data/records/__init__.py index a96a70812..79150b32a 100644 --- a/src/anemoi/datasets/data/records/__init__.py +++ b/src/anemoi/datasets/data/records/__init__.py @@ -42,6 +42,13 @@ def counter(func): return func +def _to_numpy_timedelta(td): + if isinstance(td, np.timedelta64): + assert td.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {td.dtype}" + return td + return np.timedelta64(int(td.total_seconds()), "s") + + def open_records_dataset(dataset, **kwargs): metadata_path = os.path.join(dataset, "metadata.json") if not os.path.exists(metadata_path): @@ -139,6 +146,10 @@ def groups(self): raise 
NotImplementedError("Must be implemented in subclass") def _subset(self, **kwargs): + window = kwargs.pop("window", None) + if window is not None: + return Rewindowed(self, window)._subset(**kwargs) + frequency = kwargs.pop("frequency", self.frequency) if frequency: frequency = frequency_to_timedelta(frequency) @@ -172,10 +183,6 @@ def _dates_to_indices(start, end): if select is not None: return Select(self, select)._subset(**kwargs) - window = kwargs.pop("window", None) - if window is not None: - return Rewindowed(self, window)._subset(**kwargs) - set_group = kwargs.pop("set_group", None) if set_group is not None: return SetGroup(self, set_group)._subset(**kwargs) @@ -220,7 +227,7 @@ def groups(self): @property def dates(self): - return self.forward.dates + return np.array(self.forward.dates, dtype="datetime64[s]") @property def name_to_index(self): @@ -248,6 +255,7 @@ def tree(self): class IncreaseFrequency(RecordsForward): # change the frequency of a records dataset by splitting the windows to fit the new frequency # the new frequency must be a divisor of the original frequency (e.g. 6h -> 3h, but not 3h -> 6h) (and not 6h -> 5h) + # and the window length should match the frequency def __init__(self, dataset, frequency): super().__init__(dataset) self.dataset = dataset @@ -259,6 +267,11 @@ def __init__(self, dataset, frequency): raise ValueError(f"Cannot split frequency {self.dataset.frequency} to {frequency}, not a multiple") self._n = int(self._n) + if self.dataset._window.end - self.dataset._window.start != self.dataset.frequency: + raise ValueError( + f"Cannot split frequency {self.dataset.frequency} to {frequency}, window {self.dataset._window} does not match frequency" + ) + @cached_property def _window(self): previous = self.dataset._window @@ -272,9 +285,10 @@ def __len__(self): @property def dates(self): dates = [] + freq = _to_numpy_timedelta(self._frequency) for date in self.dataset.dates: - dates += [date + i * self._frequency for i in range(self._n)] - return dates + dates += [date + i * freq for i in range(self._n)] + return np.array(dates, dtype="datetime64[s]") @property def frequency(self): @@ -286,6 +300,26 @@ def metadata(self): def _load_data(self, i): j = i // self._n k = i % self._n + # k = 0 -> shift of (self._n - 1) * self.frequency + # k = ... 
+ # k = self._n - 1 -> shift of 0 (0 * self.frequency) + # so we need to shift by (self._n - 1 - k) * self.frequency + assert k < self._n, (k, self._n) + assert k >= 0 + + s = self._window.start + e = self._window.end + + ref_timedelta = -self.dataset.frequency + (k + 1) * self.frequency + start_delta = ref_timedelta + s + end_delta = ref_timedelta + e + # print( + # f" {i}={j}*{self._n}+{k} ({self.dates[i]}) -> ref_timedelta={ref_timedelta.total_seconds()/3600}, [start, end] = [{start_delta.total_seconds()/3600}, {end_delta.total_seconds()/3600}]" + # ) + + start_delta = _to_numpy_timedelta(start_delta) + end_delta = _to_numpy_timedelta(end_delta) + ref_timedelta = _to_numpy_timedelta(ref_timedelta) too_much_data = self.dataset._load_data(j) @@ -295,19 +329,6 @@ def _load_data(self, i): if timedeltas.dtype != "timedelta64[s]": raise ValueError(f"Wrong type for {group}") - start_delta = self.dataset._window.start + k * self.frequency - end_delta = start_delta + self._window.end - self._window.start - - def _to_numpy_timedelta(td): - if isinstance(td, np.timedelta64): - assert td.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {td.dtype}" - return td - return np.timedelta64(int(td.total_seconds()), "s") - - start_delta = _to_numpy_timedelta(start_delta) - end_delta = _to_numpy_timedelta(end_delta) - assert timedeltas.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {timedeltas.dtype}" - if self._window.include_start: mask = timedeltas >= start_delta else: @@ -320,7 +341,7 @@ def _to_numpy_timedelta(td): out[f"data:{group}"] = too_much_data[f"data:{group}"][..., mask] out[f"latitudes:{group}"] = too_much_data[f"latitudes:{group}"][..., mask] out[f"longitudes:{group}"] = too_much_data[f"longitudes:{group}"][..., mask] - out[f"timedeltas:{group}"] = too_much_data[f"timedeltas:{group}"][..., mask] + out[f"timedeltas:{group}"] = too_much_data[f"timedeltas:{group}"][..., mask] - ref_timedelta out[f"metadata:{group}"] = too_much_data[f"metadata:{group}"] return out @@ -383,7 +404,7 @@ def variables(self): @property def dates(self): - return self.forward.dates + return np.array(self.forward.dates, dtype="datetime64[s]") @property def longitudes(self): @@ -537,7 +558,7 @@ def window(self): @property def dates(self): - return self._dates + return np.array(self._dates, dtype="datetime64[s]") def __len__(self): return len(self.dates) @@ -686,7 +707,7 @@ def __init__(self, dataset, indices, reason): @cached_property def dates(self): dates = self.dataset.dates - return [dates[i] for i in self._indices] + return np.array([dates[i] for i in self._indices], dtype="datetime64[s]") def _load_data(self, i): return self.dataset._load_data(self._indices[i]) @@ -770,7 +791,7 @@ def dates(self): while d <= self.end_date: result.append(d) d += delta - return np.array(result) + return np.array(result, dtype="datetime64[s]") @counter def _load_data(self, i): From 3ab2564a33be7fdd358a767f8fe8277be4459542 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 11 Nov 2025 14:33:12 +0000 Subject: [PATCH 174/212] rename build to create --- src/anemoi/datasets/commands/check.py | 2 +- src/anemoi/datasets/commands/create.py | 2 +- src/anemoi/datasets/commands/grib-index.py | 2 +- .../datasets/commands/recipe/__init__.py | 2 +- .../datasets/commands/recipe/migrate.py | 2 +- .../datasets/{build => create}/__init__.py | 0 .../{build => create}/gridded/__init__.py | 42 +++++++++---------- .../{build => create}/gridded/check.py | 0 .../{build => create}/gridded/chunks.py | 0 .../{build => 
create}/gridded/config.py | 0 .../{build => create}/gridded/context.py | 4 +- .../{build => create}/gridded/patch.py | 0 .../{build => create}/gridded/persistent.py | 0 .../{build => create}/gridded/result.py | 2 +- .../{build => create}/gridded/size.py | 0 .../{build => create}/gridded/source.py | 2 +- .../gridded/sources/__init__.py | 0 .../gridded/sources/accumulations.py | 2 +- .../gridded/sources/accumulations2.py | 2 +- .../gridded/sources/anemoi_dataset.py | 0 .../gridded/sources/constants.py | 0 .../gridded/sources/eccc_fstd.py | 0 .../gridded/sources/empty.py | 0 .../{build => create}/gridded/sources/fdb.py | 2 +- .../gridded/sources/forcings.py | 0 .../{build => create}/gridded/sources/grib.py | 0 .../gridded/sources/grib_index.py | 0 .../gridded/sources/hindcasts.py | 2 +- .../gridded/sources/legacy.py | 2 +- .../{build => create}/gridded/sources/mars.py | 2 +- .../gridded/sources/netcdf.py | 0 .../gridded/sources/opendap.py | 0 .../gridded/sources/patterns.py | 0 .../gridded/sources/planetary_computer.py | 0 .../gridded/sources/recentre.py | 0 .../gridded/sources/repeated_dates.py | 2 +- .../gridded/sources/source.py | 2 +- .../gridded/sources/tendencies.py | 2 +- .../gridded/sources/xarray.py | 2 +- .../gridded/sources/xarray_kerchunk.py | 0 .../gridded/sources/xarray_support/README.md | 0 .../sources/xarray_support/__init__.py | 2 +- .../sources/xarray_support/coordinates.py | 0 .../gridded/sources/xarray_support/field.py | 0 .../sources/xarray_support/fieldlist.py | 0 .../gridded/sources/xarray_support/flavour.py | 0 .../gridded/sources/xarray_support/grid.py | 0 .../sources/xarray_support/metadata.py | 0 .../gridded/sources/xarray_support/patch.py | 0 .../gridded/sources/xarray_support/time.py | 0 .../sources/xarray_support/variable.py | 0 .../gridded/sources/xarray_zarr.py | 0 .../gridded/sources/zenodo.py | 0 .../gridded/statistics/__init__.py | 4 +- .../gridded/statistics/summary.py | 6 +-- .../{build => create}/gridded/testing.py | 0 .../{build => create}/gridded/typing.py | 0 .../{build => create}/gridded/utils.py | 0 .../{build => create}/gridded/writer.py | 0 .../{build => create}/gridded/zarr.py | 0 .../{build => create}/input/__init__.py | 8 ++-- .../{build => create}/input/action.py | 4 +- .../{build => create}/input/context.py | 2 +- .../{build => create}/input/data_sources.py | 8 ++-- .../datasets/{build => create}/input/misc.py | 0 .../{build => create}/input/repeated_dates.py | 10 ++--- .../{build => create}/input/result.py | 0 .../datasets/{build => create}/input/trace.py | 0 src/anemoi/datasets/traits/gridded.py | 2 + src/anemoi/datasets/traits/tabular.py | 2 + src/anemoi/datasets/use/gridded/missing.py | 2 +- .../use/tabular/records/backends/__init__.py | 4 +- tests/create/utils/create.py | 2 +- tests/test_chunks.py | 2 +- tests/test_dates.py | 2 +- tests/xarray/test_flavour.py | 24 +++++------ tests/xarray/test_netcdf.py | 2 +- tests/xarray/test_opendap.py | 2 +- tests/xarray/test_variable.py | 16 +++---- tests/xarray/test_zarr.py | 2 +- 80 files changed, 95 insertions(+), 91 deletions(-) rename src/anemoi/datasets/{build => create}/__init__.py (100%) rename src/anemoi/datasets/{build => create}/gridded/__init__.py (97%) rename src/anemoi/datasets/{build => create}/gridded/check.py (100%) rename src/anemoi/datasets/{build => create}/gridded/chunks.py (100%) rename src/anemoi/datasets/{build => create}/gridded/config.py (100%) rename src/anemoi/datasets/{build => create}/gridded/context.py (92%) rename src/anemoi/datasets/{build => create}/gridded/patch.py 
(100%) rename src/anemoi/datasets/{build => create}/gridded/persistent.py (100%) rename src/anemoi/datasets/{build => create}/gridded/result.py (99%) rename src/anemoi/datasets/{build => create}/gridded/size.py (100%) rename src/anemoi/datasets/{build => create}/gridded/source.py (95%) rename src/anemoi/datasets/{build => create}/gridded/sources/__init__.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/accumulations.py (99%) rename src/anemoi/datasets/{build => create}/gridded/sources/accumulations2.py (99%) rename src/anemoi/datasets/{build => create}/gridded/sources/anemoi_dataset.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/constants.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/eccc_fstd.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/empty.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/fdb.py (98%) rename src/anemoi/datasets/{build => create}/gridded/sources/forcings.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/grib.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/grib_index.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/hindcasts.py (97%) rename src/anemoi/datasets/{build => create}/gridded/sources/legacy.py (95%) rename src/anemoi/datasets/{build => create}/gridded/sources/mars.py (99%) rename src/anemoi/datasets/{build => create}/gridded/sources/netcdf.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/opendap.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/patterns.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/planetary_computer.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/recentre.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/repeated_dates.py (96%) rename src/anemoi/datasets/{build => create}/gridded/sources/source.py (95%) rename src/anemoi/datasets/{build => create}/gridded/sources/tendencies.py (98%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray.py (97%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_kerchunk.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/README.md (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/__init__.py (98%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/coordinates.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/field.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/fieldlist.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/flavour.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/grid.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/metadata.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/patch.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/time.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_support/variable.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/xarray_zarr.py (100%) rename src/anemoi/datasets/{build => create}/gridded/sources/zenodo.py (100%) rename src/anemoi/datasets/{build => create}/gridded/statistics/__init__.py (99%) rename src/anemoi/datasets/{build => 
create}/gridded/statistics/summary.py (95%) rename src/anemoi/datasets/{build => create}/gridded/testing.py (100%) rename src/anemoi/datasets/{build => create}/gridded/typing.py (100%) rename src/anemoi/datasets/{build => create}/gridded/utils.py (100%) rename src/anemoi/datasets/{build => create}/gridded/writer.py (100%) rename src/anemoi/datasets/{build => create}/gridded/zarr.py (100%) rename src/anemoi/datasets/{build => create}/input/__init__.py (89%) rename src/anemoi/datasets/{build => create}/input/action.py (97%) rename src/anemoi/datasets/{build => create}/input/context.py (96%) rename src/anemoi/datasets/{build => create}/input/data_sources.py (94%) rename src/anemoi/datasets/{build => create}/input/misc.py (100%) rename src/anemoi/datasets/{build => create}/input/repeated_dates.py (97%) rename src/anemoi/datasets/{build => create}/input/result.py (100%) rename src/anemoi/datasets/{build => create}/input/trace.py (100%) create mode 100644 src/anemoi/datasets/traits/gridded.py create mode 100644 src/anemoi/datasets/traits/tabular.py diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index 212987839..ff4f852b6 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -13,7 +13,7 @@ import yaml -from anemoi.datasets.build.gridded.check import DatasetName +from anemoi.datasets.create.gridded.check import DatasetName from . import Command diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 601468d5c..b33918141 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -45,7 +45,7 @@ def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") - from anemoi.datasets.build.gridded import creator_factory + from anemoi.datasets.create.gridded import creator_factory options = {k: v for k, v in options.items() if v is not None} diff --git a/src/anemoi/datasets/commands/grib-index.py b/src/anemoi/datasets/commands/grib-index.py index 59c2fba89..a7af0e8f8 100644 --- a/src/anemoi/datasets/commands/grib-index.py +++ b/src/anemoi/datasets/commands/grib-index.py @@ -83,7 +83,7 @@ def match(path: str) -> bool: """ return fnmatch.fnmatch(os.path.basename(path), args.match) - from anemoi.datasets.build.gridded.sources.grib_index import GribIndex + from anemoi.datasets.create.gridded.sources.grib_index import GribIndex index = GribIndex( args.index, diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index 85fd574e3..e93184bf2 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,7 +15,7 @@ import yaml -from anemoi.datasets.build.gridded import validate_config +from anemoi.datasets.create.gridded import validate_config from .. 
import Command from .format import format_recipe diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index 8ca2ddd5d..047c57278 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -17,7 +17,7 @@ from glom import delete from glom import glom -from anemoi.datasets.build.gridded import validate_config +from anemoi.datasets.create.gridded import validate_config from anemoi.datasets.misc.dumper import yaml_dump LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/__init__.py b/src/anemoi/datasets/create/__init__.py similarity index 100% rename from src/anemoi/datasets/build/__init__.py rename to src/anemoi/datasets/create/__init__.py diff --git a/src/anemoi/datasets/build/gridded/__init__.py b/src/anemoi/datasets/create/gridded/__init__.py similarity index 97% rename from src/anemoi/datasets/build/gridded/__init__.py rename to src/anemoi/datasets/create/gridded/__init__.py index 696fc118b..377852420 100644 --- a/src/anemoi/datasets/build/gridded/__init__.py +++ b/src/anemoi/datasets/create/gridded/__init__.py @@ -31,22 +31,22 @@ from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset -from anemoi.datasets.build.gridded.check import DatasetName -from anemoi.datasets.build.gridded.check import check_data_values -from anemoi.datasets.build.gridded.chunks import ChunkFilter -from anemoi.datasets.build.gridded.config import build_output -from anemoi.datasets.build.gridded.config import loader_config -from anemoi.datasets.build.gridded.persistent import build_storage -from anemoi.datasets.build.gridded.statistics import Summary -from anemoi.datasets.build.gridded.statistics import TmpStatistics -from anemoi.datasets.build.gridded.statistics import check_variance -from anemoi.datasets.build.gridded.statistics import compute_statistics -from anemoi.datasets.build.gridded.statistics import default_statistics_dates -from anemoi.datasets.build.gridded.statistics import fix_variance -from anemoi.datasets.build.gridded.utils import normalize_and_check_dates -from anemoi.datasets.build.gridded.writer import ViewCacheArray -from anemoi.datasets.build.input import InputBuilder -from anemoi.datasets.build.input.trace import enable_trace +from anemoi.datasets.create.gridded.check import DatasetName +from anemoi.datasets.create.gridded.check import check_data_values +from anemoi.datasets.create.gridded.chunks import ChunkFilter +from anemoi.datasets.create.gridded.config import build_output +from anemoi.datasets.create.gridded.config import loader_config +from anemoi.datasets.create.gridded.persistent import build_storage +from anemoi.datasets.create.gridded.statistics import Summary +from anemoi.datasets.create.gridded.statistics import TmpStatistics +from anemoi.datasets.create.gridded.statistics import check_variance +from anemoi.datasets.create.gridded.statistics import compute_statistics +from anemoi.datasets.create.gridded.statistics import default_statistics_dates +from anemoi.datasets.create.gridded.statistics import fix_variance +from anemoi.datasets.create.gridded.utils import normalize_and_check_dates +from anemoi.datasets.create.gridded.writer import ViewCacheArray +from anemoi.datasets.create.input import InputBuilder +from anemoi.datasets.create.input.trace import enable_trace from anemoi.datasets.dates.groups import Groups from anemoi.datasets.use.gridded.misc import as_first_date from anemoi.datasets.use.gridded.misc import 
as_last_date @@ -192,7 +192,7 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: import zarr z = zarr.open(self.path, mode=mode) - from anemoi.datasets.build.gridded.zarr import add_zarr_dataset + from anemoi.datasets.create.gridded.zarr import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -396,7 +396,7 @@ def _cache_context(self) -> Any: Any The cache context. """ - from anemoi.datasets.build.gridded.utils import cache_context + from anemoi.datasets.create.gridded.utils import cache_context return cache_context(self.cache) @@ -472,7 +472,7 @@ def __init__(self, path: str, options: dict = None, **kwargs: Any): def run(self) -> None: """Run the patch.""" - from anemoi.datasets.build.gridded.patch import apply_patch + from anemoi.datasets.create.gridded.patch import apply_patch apply_patch(self.path, **self.options) @@ -492,7 +492,7 @@ def __init__(self, path: str, **kwargs: Any): def run(self) -> None: """Run the size computation.""" - from anemoi.datasets.build.gridded.size import compute_directory_sizes + from anemoi.datasets.create.gridded.size import compute_directory_sizes metadata = compute_directory_sizes(self.path) self.update_metadata(**metadata) @@ -514,7 +514,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from anemoi.datasets.build.gridded.zarr import ZarrBuiltRegistry + from anemoi.datasets.create.gridded.zarr import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) diff --git a/src/anemoi/datasets/build/gridded/check.py b/src/anemoi/datasets/create/gridded/check.py similarity index 100% rename from src/anemoi/datasets/build/gridded/check.py rename to src/anemoi/datasets/create/gridded/check.py diff --git a/src/anemoi/datasets/build/gridded/chunks.py b/src/anemoi/datasets/create/gridded/chunks.py similarity index 100% rename from src/anemoi/datasets/build/gridded/chunks.py rename to src/anemoi/datasets/create/gridded/chunks.py diff --git a/src/anemoi/datasets/build/gridded/config.py b/src/anemoi/datasets/create/gridded/config.py similarity index 100% rename from src/anemoi/datasets/build/gridded/config.py rename to src/anemoi/datasets/create/gridded/config.py diff --git a/src/anemoi/datasets/build/gridded/context.py b/src/anemoi/datasets/create/gridded/context.py similarity index 92% rename from src/anemoi/datasets/build/gridded/context.py rename to src/anemoi/datasets/create/gridded/context.py index 91ea80c07..a20e51133 100644 --- a/src/anemoi/datasets/build/gridded/context.py +++ b/src/anemoi/datasets/create/gridded/context.py @@ -12,8 +12,8 @@ from earthkit.data.core.order import build_remapping -from anemoi.datasets.build.gridded.result import GriddedResult -from anemoi.datasets.build.input.context import Context +from anemoi.datasets.create.gridded.result import GriddedResult +from anemoi.datasets.create.input.context import Context class GriddedContext(Context): diff --git a/src/anemoi/datasets/build/gridded/patch.py b/src/anemoi/datasets/create/gridded/patch.py similarity index 100% rename from src/anemoi/datasets/build/gridded/patch.py rename to src/anemoi/datasets/create/gridded/patch.py diff --git a/src/anemoi/datasets/build/gridded/persistent.py b/src/anemoi/datasets/create/gridded/persistent.py similarity index 100% rename from src/anemoi/datasets/build/gridded/persistent.py rename to src/anemoi/datasets/create/gridded/persistent.py diff --git a/src/anemoi/datasets/build/gridded/result.py b/src/anemoi/datasets/create/gridded/result.py 
similarity index 99% rename from src/anemoi/datasets/build/gridded/result.py rename to src/anemoi/datasets/create/gridded/result.py index 69c560969..ed8440c52 100644 --- a/src/anemoi/datasets/build/gridded/result.py +++ b/src/anemoi/datasets/create/gridded/result.py @@ -22,7 +22,7 @@ from anemoi.utils.humanize import shorten_list from earthkit.data.core.order import build_remapping -from anemoi.datasets.build.input.result import Result +from anemoi.datasets.create.input.result import Result LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/size.py b/src/anemoi/datasets/create/gridded/size.py similarity index 100% rename from src/anemoi/datasets/build/gridded/size.py rename to src/anemoi/datasets/create/gridded/size.py diff --git a/src/anemoi/datasets/build/gridded/source.py b/src/anemoi/datasets/create/gridded/source.py similarity index 95% rename from src/anemoi/datasets/build/gridded/source.py rename to src/anemoi/datasets/create/gridded/source.py index 494b29b92..d4d716ac3 100644 --- a/src/anemoi/datasets/build/gridded/source.py +++ b/src/anemoi/datasets/create/gridded/source.py @@ -12,7 +12,7 @@ import earthkit.data as ekd -from anemoi.datasets.build.gridded.typing import DateList +from anemoi.datasets.create.gridded.typing import DateList class Source(ABC): diff --git a/src/anemoi/datasets/build/gridded/sources/__init__.py b/src/anemoi/datasets/create/gridded/sources/__init__.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/__init__.py rename to src/anemoi/datasets/create/gridded/sources/__init__.py diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations.py b/src/anemoi/datasets/create/gridded/sources/accumulations.py similarity index 99% rename from src/anemoi/datasets/build/gridded/sources/accumulations.py rename to src/anemoi/datasets/create/gridded/sources/accumulations.py index 86adea4d1..a704a93b5 100644 --- a/src/anemoi/datasets/build/gridded/sources/accumulations.py +++ b/src/anemoi/datasets/create/gridded/sources/accumulations.py @@ -20,7 +20,7 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray -from anemoi.datasets.build.gridded.sources import source_registry +from anemoi.datasets.create.gridded.sources import source_registry from .legacy import LegacySource from .mars import mars diff --git a/src/anemoi/datasets/build/gridded/sources/accumulations2.py b/src/anemoi/datasets/create/gridded/sources/accumulations2.py similarity index 99% rename from src/anemoi/datasets/build/gridded/sources/accumulations2.py rename to src/anemoi/datasets/create/gridded/sources/accumulations2.py index 64410164f..618d68f27 100644 --- a/src/anemoi/datasets/build/gridded/sources/accumulations2.py +++ b/src/anemoi/datasets/create/gridded/sources/accumulations2.py @@ -18,7 +18,7 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.build.gridded.sources import source_registry +from anemoi.datasets.create.gridded.sources import source_registry from .legacy import LegacySource from .mars import mars diff --git a/src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py b/src/anemoi/datasets/create/gridded/sources/anemoi_dataset.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/anemoi_dataset.py rename to src/anemoi/datasets/create/gridded/sources/anemoi_dataset.py diff --git a/src/anemoi/datasets/build/gridded/sources/constants.py 
b/src/anemoi/datasets/create/gridded/sources/constants.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/constants.py rename to src/anemoi/datasets/create/gridded/sources/constants.py diff --git a/src/anemoi/datasets/build/gridded/sources/eccc_fstd.py b/src/anemoi/datasets/create/gridded/sources/eccc_fstd.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/eccc_fstd.py rename to src/anemoi/datasets/create/gridded/sources/eccc_fstd.py diff --git a/src/anemoi/datasets/build/gridded/sources/empty.py b/src/anemoi/datasets/create/gridded/sources/empty.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/empty.py rename to src/anemoi/datasets/create/gridded/sources/empty.py diff --git a/src/anemoi/datasets/build/gridded/sources/fdb.py b/src/anemoi/datasets/create/gridded/sources/fdb.py similarity index 98% rename from src/anemoi/datasets/build/gridded/sources/fdb.py rename to src/anemoi/datasets/create/gridded/sources/fdb.py index 5d678fca7..67bfe8870 100644 --- a/src/anemoi/datasets/build/gridded/sources/fdb.py +++ b/src/anemoi/datasets/create/gridded/sources/fdb.py @@ -16,7 +16,7 @@ from anemoi.transform.flavour import RuleBasedFlavour from anemoi.transform.grids import grid_registry -from anemoi.datasets.build.gridded.typing import DateList +from anemoi.datasets.create.gridded.typing import DateList from ..source import Source from . import source_registry diff --git a/src/anemoi/datasets/build/gridded/sources/forcings.py b/src/anemoi/datasets/create/gridded/sources/forcings.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/forcings.py rename to src/anemoi/datasets/create/gridded/sources/forcings.py diff --git a/src/anemoi/datasets/build/gridded/sources/grib.py b/src/anemoi/datasets/create/gridded/sources/grib.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/grib.py rename to src/anemoi/datasets/create/gridded/sources/grib.py diff --git a/src/anemoi/datasets/build/gridded/sources/grib_index.py b/src/anemoi/datasets/create/gridded/sources/grib_index.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/grib_index.py rename to src/anemoi/datasets/create/gridded/sources/grib_index.py diff --git a/src/anemoi/datasets/build/gridded/sources/hindcasts.py b/src/anemoi/datasets/create/gridded/sources/hindcasts.py similarity index 97% rename from src/anemoi/datasets/build/gridded/sources/hindcasts.py rename to src/anemoi/datasets/create/gridded/sources/hindcasts.py index a61a00d12..cee33a679 100644 --- a/src/anemoi/datasets/build/gridded/sources/hindcasts.py +++ b/src/anemoi/datasets/create/gridded/sources/hindcasts.py @@ -12,7 +12,7 @@ from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.build.gridded.sources import source_registry +from anemoi.datasets.create.gridded.sources import source_registry from .legacy import LegacySource from .mars import mars diff --git a/src/anemoi/datasets/build/gridded/sources/legacy.py b/src/anemoi/datasets/create/gridded/sources/legacy.py similarity index 95% rename from src/anemoi/datasets/build/gridded/sources/legacy.py rename to src/anemoi/datasets/create/gridded/sources/legacy.py index d4110cf5b..f9a0288a0 100644 --- a/src/anemoi/datasets/build/gridded/sources/legacy.py +++ b/src/anemoi/datasets/create/gridded/sources/legacy.py @@ -12,7 +12,7 @@ from abc import abstractmethod from typing import Any -from anemoi.datasets.build.input.context import Context +from 
anemoi.datasets.create.input.context import Context from ..source import Source diff --git a/src/anemoi/datasets/build/gridded/sources/mars.py b/src/anemoi/datasets/create/gridded/sources/mars.py similarity index 99% rename from src/anemoi/datasets/build/gridded/sources/mars.py rename to src/anemoi/datasets/create/gridded/sources/mars.py index a2804e77a..ee2dd0b90 100644 --- a/src/anemoi/datasets/build/gridded/sources/mars.py +++ b/src/anemoi/datasets/create/gridded/sources/mars.py @@ -16,7 +16,7 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability -from anemoi.datasets.build.gridded.sources import source_registry +from anemoi.datasets.create.gridded.sources import source_registry from .legacy import LegacySource diff --git a/src/anemoi/datasets/build/gridded/sources/netcdf.py b/src/anemoi/datasets/create/gridded/sources/netcdf.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/netcdf.py rename to src/anemoi/datasets/create/gridded/sources/netcdf.py diff --git a/src/anemoi/datasets/build/gridded/sources/opendap.py b/src/anemoi/datasets/create/gridded/sources/opendap.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/opendap.py rename to src/anemoi/datasets/create/gridded/sources/opendap.py diff --git a/src/anemoi/datasets/build/gridded/sources/patterns.py b/src/anemoi/datasets/create/gridded/sources/patterns.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/patterns.py rename to src/anemoi/datasets/create/gridded/sources/patterns.py diff --git a/src/anemoi/datasets/build/gridded/sources/planetary_computer.py b/src/anemoi/datasets/create/gridded/sources/planetary_computer.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/planetary_computer.py rename to src/anemoi/datasets/create/gridded/sources/planetary_computer.py diff --git a/src/anemoi/datasets/build/gridded/sources/recentre.py b/src/anemoi/datasets/create/gridded/sources/recentre.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/recentre.py rename to src/anemoi/datasets/create/gridded/sources/recentre.py diff --git a/src/anemoi/datasets/build/gridded/sources/repeated_dates.py b/src/anemoi/datasets/create/gridded/sources/repeated_dates.py similarity index 96% rename from src/anemoi/datasets/build/gridded/sources/repeated_dates.py rename to src/anemoi/datasets/create/gridded/sources/repeated_dates.py index 509ee4966..9b297e193 100644 --- a/src/anemoi/datasets/build/gridded/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/gridded/sources/repeated_dates.py @@ -14,7 +14,7 @@ from anemoi.transform.fields import new_field_with_valid_datetime from anemoi.transform.fields import new_fieldlist_from_list -from anemoi.datasets.build.input.repeated_dates import DateMapper +from anemoi.datasets.create.input.repeated_dates import DateMapper from ..source import Source from ..sources import source_registry diff --git a/src/anemoi/datasets/build/gridded/sources/source.py b/src/anemoi/datasets/create/gridded/sources/source.py similarity index 95% rename from src/anemoi/datasets/build/gridded/sources/source.py rename to src/anemoi/datasets/create/gridded/sources/source.py index 1ad5850a7..8918c1303 100644 --- a/src/anemoi/datasets/build/gridded/sources/source.py +++ b/src/anemoi/datasets/create/gridded/sources/source.py @@ -12,7 +12,7 @@ from earthkit.data import from_source -from anemoi.datasets.build.gridded.sources import source_registry +from 
anemoi.datasets.create.gridded.sources import source_registry from .legacy import LegacySource diff --git a/src/anemoi/datasets/build/gridded/sources/tendencies.py b/src/anemoi/datasets/create/gridded/sources/tendencies.py similarity index 98% rename from src/anemoi/datasets/build/gridded/sources/tendencies.py rename to src/anemoi/datasets/create/gridded/sources/tendencies.py index 69c06a78c..780bd3832 100644 --- a/src/anemoi/datasets/build/gridded/sources/tendencies.py +++ b/src/anemoi/datasets/create/gridded/sources/tendencies.py @@ -14,7 +14,7 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.build.gridded.sources import source_registry +from anemoi.datasets.create.gridded.sources import source_registry from .legacy import LegacySource diff --git a/src/anemoi/datasets/build/gridded/sources/xarray.py b/src/anemoi/datasets/create/gridded/sources/xarray.py similarity index 97% rename from src/anemoi/datasets/build/gridded/sources/xarray.py rename to src/anemoi/datasets/create/gridded/sources/xarray.py index fb10dab8e..a735e52f6 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray.py +++ b/src/anemoi/datasets/create/gridded/sources/xarray.py @@ -11,7 +11,7 @@ import earthkit.data as ekd -from anemoi.datasets.build.gridded.typing import DateList +from anemoi.datasets.create.gridded.typing import DateList from ..source import Source from .xarray_support import XarrayFieldList diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py b/src/anemoi/datasets/create/gridded/sources/xarray_kerchunk.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_kerchunk.py rename to src/anemoi/datasets/create/gridded/sources/xarray_kerchunk.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/README.md b/src/anemoi/datasets/create/gridded/sources/xarray_support/README.md similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/README.md rename to src/anemoi/datasets/create/gridded/sources/xarray_support/README.md diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/__init__.py similarity index 98% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/__init__.py index e0f4a7e75..cbbd9f0e3 100644 --- a/src/anemoi/datasets/build/gridded/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/gridded/sources/xarray_support/__init__.py @@ -15,7 +15,7 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.build.gridded.sources.patterns import iterate_patterns +from anemoi.datasets.create.gridded.sources.patterns import iterate_patterns from .. 
import source_registry from ..legacy import LegacySource diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/coordinates.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/coordinates.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/coordinates.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/coordinates.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/field.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/field.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/field.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/field.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/fieldlist.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/fieldlist.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/fieldlist.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/flavour.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/flavour.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/flavour.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/grid.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/grid.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/grid.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/grid.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/metadata.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/metadata.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/metadata.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/patch.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/patch.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/patch.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/patch.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/time.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/time.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/time.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/time.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py b/src/anemoi/datasets/create/gridded/sources/xarray_support/variable.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_support/variable.py rename to src/anemoi/datasets/create/gridded/sources/xarray_support/variable.py diff --git a/src/anemoi/datasets/build/gridded/sources/xarray_zarr.py b/src/anemoi/datasets/create/gridded/sources/xarray_zarr.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/xarray_zarr.py rename to src/anemoi/datasets/create/gridded/sources/xarray_zarr.py diff --git a/src/anemoi/datasets/build/gridded/sources/zenodo.py b/src/anemoi/datasets/create/gridded/sources/zenodo.py similarity index 100% rename from src/anemoi/datasets/build/gridded/sources/zenodo.py rename to 
src/anemoi/datasets/create/gridded/sources/zenodo.py diff --git a/src/anemoi/datasets/build/gridded/statistics/__init__.py b/src/anemoi/datasets/create/gridded/statistics/__init__.py similarity index 99% rename from src/anemoi/datasets/build/gridded/statistics/__init__.py rename to src/anemoi/datasets/create/gridded/statistics/__init__.py index e9835bfe2..fb59573c2 100644 --- a/src/anemoi/datasets/build/gridded/statistics/__init__.py +++ b/src/anemoi/datasets/create/gridded/statistics/__init__.py @@ -23,8 +23,8 @@ from anemoi.utils.provenance import gather_provenance_info from numpy.typing import NDArray -from anemoi.datasets.build.gridded.check import check_data_values -from anemoi.datasets.build.gridded.statistics.summary import Summary +from anemoi.datasets.create.gridded.check import check_data_values +from anemoi.datasets.create.gridded.statistics.summary import Summary LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/gridded/statistics/summary.py b/src/anemoi/datasets/create/gridded/statistics/summary.py similarity index 95% rename from src/anemoi/datasets/build/gridded/statistics/summary.py rename to src/anemoi/datasets/create/gridded/statistics/summary.py index 2f81f4e5b..8e88a2d76 100644 --- a/src/anemoi/datasets/build/gridded/statistics/summary.py +++ b/src/anemoi/datasets/create/gridded/statistics/summary.py @@ -13,9 +13,9 @@ import numpy as np -from anemoi.datasets.build.gridded.check import StatisticsValueError -from anemoi.datasets.build.gridded.check import check_data_values -from anemoi.datasets.build.gridded.check import check_stats +from anemoi.datasets.create.gridded.check import StatisticsValueError +from anemoi.datasets.create.gridded.check import check_data_values +from anemoi.datasets.create.gridded.check import check_stats class Summary(dict): diff --git a/src/anemoi/datasets/build/gridded/testing.py b/src/anemoi/datasets/create/gridded/testing.py similarity index 100% rename from src/anemoi/datasets/build/gridded/testing.py rename to src/anemoi/datasets/create/gridded/testing.py diff --git a/src/anemoi/datasets/build/gridded/typing.py b/src/anemoi/datasets/create/gridded/typing.py similarity index 100% rename from src/anemoi/datasets/build/gridded/typing.py rename to src/anemoi/datasets/create/gridded/typing.py diff --git a/src/anemoi/datasets/build/gridded/utils.py b/src/anemoi/datasets/create/gridded/utils.py similarity index 100% rename from src/anemoi/datasets/build/gridded/utils.py rename to src/anemoi/datasets/create/gridded/utils.py diff --git a/src/anemoi/datasets/build/gridded/writer.py b/src/anemoi/datasets/create/gridded/writer.py similarity index 100% rename from src/anemoi/datasets/build/gridded/writer.py rename to src/anemoi/datasets/create/gridded/writer.py diff --git a/src/anemoi/datasets/build/gridded/zarr.py b/src/anemoi/datasets/create/gridded/zarr.py similarity index 100% rename from src/anemoi/datasets/build/gridded/zarr.py rename to src/anemoi/datasets/create/gridded/zarr.py diff --git a/src/anemoi/datasets/build/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py similarity index 89% rename from src/anemoi/datasets/build/input/__init__.py rename to src/anemoi/datasets/create/input/__init__.py index c3d601fd1..e4e312fa8 100644 --- a/src/anemoi/datasets/build/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -13,7 +13,7 @@ from typing import Any if TYPE_CHECKING: - from anemoi.datasets.build.input.action import Recipe + from anemoi.datasets.create.input.action import Recipe class 
InputBuilder: @@ -38,8 +38,8 @@ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> No @cached_property def action(self) -> "Recipe": """Returns the action object based on the configuration.""" - from anemoi.datasets.build.input.action import Recipe - from anemoi.datasets.build.input.action import action_factory + from anemoi.datasets.create.input.action import Recipe + from anemoi.datasets.create.input.action import action_factory sources = action_factory(self.data_sources, "data_sources") input = action_factory(self.config, "input") @@ -59,7 +59,7 @@ def select(self, argument) -> Any: Any Selected data. """ - from anemoi.datasets.build.gridded.context import GriddedContext + from anemoi.datasets.create.gridded.context import GriddedContext context = GriddedContext(argument, **self.kwargs) return context.create_result(self.action(context, argument)) diff --git a/src/anemoi/datasets/build/input/action.py b/src/anemoi/datasets/create/input/action.py similarity index 97% rename from src/anemoi/datasets/build/input/action.py rename to src/anemoi/datasets/create/input/action.py index 1a37d2f99..68945d59d 100644 --- a/src/anemoi/datasets/build/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -181,7 +181,7 @@ class DatasetSourceMixin: """Mixin class for sources defined in anemoi-datasets""" def create_object(self, context, config): - from anemoi.datasets.build.gridded.sources import create_source as create_datasets_source + from anemoi.datasets.create.gridded.sources import create_source as create_datasets_source return create_datasets_source(context, config) @@ -286,7 +286,7 @@ def make(key, config, *path): from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.transform.sources import source_registry as transform_source_registry - from anemoi.datasets.build.gridded.sources import source_registry as dataset_source_registry + from anemoi.datasets.create.gridded.sources import source_registry as dataset_source_registry # Register sources, local first for name in dataset_source_registry.registered: diff --git a/src/anemoi/datasets/build/input/context.py b/src/anemoi/datasets/create/input/context.py similarity index 96% rename from src/anemoi/datasets/build/input/context.py rename to src/anemoi/datasets/create/input/context.py index e8572ba78..89df7a727 100644 --- a/src/anemoi/datasets/build/input/context.py +++ b/src/anemoi/datasets/create/input/context.py @@ -55,7 +55,7 @@ def resolve(self, config): return config def create_source(self, config: Any, *path) -> Any: - from anemoi.datasets.build.input.action import action_factory + from anemoi.datasets.create.input.action import action_factory if not isinstance(config, dict): # It is already a result (e.g. 
ekd.FieldList), loaded from ${a.b.c} diff --git a/src/anemoi/datasets/build/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py similarity index 94% rename from src/anemoi/datasets/build/input/data_sources.py rename to src/anemoi/datasets/create/input/data_sources.py index 6e9bfaa6a..2f776dff9 100644 --- a/src/anemoi/datasets/build/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -13,10 +13,10 @@ from earthkit.data import FieldList -from anemoi.datasets.build.gridded.result import Result -from anemoi.datasets.build.input.action import Action -from anemoi.datasets.build.input.action import action_factory -from anemoi.datasets.build.input.misc import _tidy +from anemoi.datasets.create.gridded.result import Result +from anemoi.datasets.create.input.action import Action +from anemoi.datasets.create.input.action import action_factory +from anemoi.datasets.create.input.misc import _tidy from anemoi.datasets.dates.groups import GroupOfDates LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/input/misc.py b/src/anemoi/datasets/create/input/misc.py similarity index 100% rename from src/anemoi/datasets/build/input/misc.py rename to src/anemoi/datasets/create/input/misc.py diff --git a/src/anemoi/datasets/build/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py similarity index 97% rename from src/anemoi/datasets/build/input/repeated_dates.py rename to src/anemoi/datasets/create/input/repeated_dates.py index f20d764ec..0e9966818 100644 --- a/src/anemoi/datasets/build/input/repeated_dates.py +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -19,11 +19,11 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.build.gridded.result import Result -from anemoi.datasets.build.input.action import Action -from anemoi.datasets.build.input.action import action_factory -from anemoi.datasets.build.input.join import JoinResult -from anemoi.datasets.build.input.trace import trace_select +from anemoi.datasets.create.gridded.result import Result +from anemoi.datasets.create.input.action import Action +from anemoi.datasets.create.input.action import action_factory +from anemoi.datasets.create.input.join import JoinResult +from anemoi.datasets.create.input.trace import trace_select LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/build/input/result.py b/src/anemoi/datasets/create/input/result.py similarity index 100% rename from src/anemoi/datasets/build/input/result.py rename to src/anemoi/datasets/create/input/result.py diff --git a/src/anemoi/datasets/build/input/trace.py b/src/anemoi/datasets/create/input/trace.py similarity index 100% rename from src/anemoi/datasets/build/input/trace.py rename to src/anemoi/datasets/create/input/trace.py diff --git a/src/anemoi/datasets/traits/gridded.py b/src/anemoi/datasets/traits/gridded.py new file mode 100644 index 000000000..62f8c87d3 --- /dev/null +++ b/src/anemoi/datasets/traits/gridded.py @@ -0,0 +1,2 @@ +class Gridded: + pass diff --git a/src/anemoi/datasets/traits/tabular.py b/src/anemoi/datasets/traits/tabular.py new file mode 100644 index 000000000..4ad7058a7 --- /dev/null +++ b/src/anemoi/datasets/traits/tabular.py @@ -0,0 +1,2 @@ +class Tabular: + pass diff --git a/src/anemoi/datasets/use/gridded/missing.py b/src/anemoi/datasets/use/gridded/missing.py index b1e83638d..75bd02137 100644 --- a/src/anemoi/datasets/use/gridded/missing.py +++ 
b/src/anemoi/datasets/use/gridded/missing.py @@ -16,7 +16,7 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.build.gridded.utils import to_datetime +from anemoi.datasets.create.gridded.utils import to_datetime from anemoi.datasets.use.gridded import MissingDateError from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex diff --git a/src/anemoi/datasets/use/tabular/records/backends/__init__.py b/src/anemoi/datasets/use/tabular/records/backends/__init__.py index 786202908..61293e4e6 100644 --- a/src/anemoi/datasets/use/tabular/records/backends/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/backends/__init__.py @@ -100,7 +100,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.build.gridded import json_tidy + from anemoi.datasets.create.gridded import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: @@ -128,7 +128,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.build.gridded import json_tidy + from anemoi.datasets.create.gridded import json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index a10c83132..8d1c719d5 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -12,7 +12,7 @@ import yaml -from anemoi.datasets.build.gridded import creator_factory +from anemoi.datasets.create.gridded import creator_factory class TestingContext: diff --git a/tests/test_chunks.py b/tests/test_chunks.py index 529c1f0cd..132614cd0 100644 --- a/tests/test_chunks.py +++ b/tests/test_chunks.py @@ -11,7 +11,7 @@ import pytest -from anemoi.datasets.build.gridded.chunks import ChunkFilter +from anemoi.datasets.create.gridded.chunks import ChunkFilter def test_chunk_filter(): diff --git a/tests/test_dates.py b/tests/test_dates.py index abc746d8e..32baf7a51 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -14,7 +14,7 @@ import numpy as np import pytest -from anemoi.datasets.build.gridded.statistics import default_statistics_dates +from anemoi.datasets.create.gridded.statistics import default_statistics_dates _ = datetime.datetime diff --git a/tests/xarray/test_flavour.py b/tests/xarray/test_flavour.py index ab058839e..e70e97aa1 100644 --- a/tests/xarray/test_flavour.py +++ b/tests/xarray/test_flavour.py @@ -11,18 +11,18 @@ import pytest import xarray as xr -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import EnsembleCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import ScalarCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import UnsupportedCoordinate -from 
anemoi.datasets.build.gridded.sources.xarray_support.coordinates import XCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import YCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.flavour import DefaultCoordinateGuesser +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.flavour import DefaultCoordinateGuesser def create_ds(var_name, standard_name, long_name, units, coord_length=5): diff --git a/tests/xarray/test_netcdf.py b/tests/xarray/test_netcdf.py index 7994789f6..c4239bf25 100644 --- a/tests/xarray/test_netcdf.py +++ b/tests/xarray/test_netcdf.py @@ -12,7 +12,7 @@ import xarray as xr from multiurl import download -from anemoi.datasets.build.gridded.sources.xarray import XarrayFieldList +from anemoi.datasets.create.gridded.sources.xarray import XarrayFieldList URLS = { "https://get.ecmwf.int/repository/test-data/earthkit-data/examples/efas.nc": dict(length=3), diff --git a/tests/xarray/test_opendap.py b/tests/xarray/test_opendap.py index 538630a23..049965e07 100644 --- a/tests/xarray/test_opendap.py +++ b/tests/xarray/test_opendap.py @@ -12,7 +12,7 @@ import xarray as xr from anemoi.utils.testing import skip_if_offline -from anemoi.datasets.build.gridded.sources.xarray import XarrayFieldList +from anemoi.datasets.create.gridded.sources.xarray import XarrayFieldList from anemoi.datasets.misc.testing import assert_field_list diff --git a/tests/xarray/test_variable.py b/tests/xarray/test_variable.py index 0f060a32e..e597d37a6 100644 --- a/tests/xarray/test_variable.py +++ b/tests/xarray/test_variable.py @@ -13,14 +13,14 @@ import pytest import xarray as xr -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.build.gridded.sources.xarray_support.time import ForecastFromValidTimeAndStep -from anemoi.datasets.build.gridded.sources.xarray_support.variable import Variable +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import DateCoordinate +from 
anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.gridded.sources.xarray_support.time import ForecastFromValidTimeAndStep +from anemoi.datasets.create.gridded.sources.xarray_support.variable import Variable @pytest.fixture diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 1c35361c7..4568754cf 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -12,7 +12,7 @@ from anemoi.utils.testing import skip_if_offline from anemoi.utils.testing import skip_missing_packages -from anemoi.datasets.build.gridded.sources.xarray import XarrayFieldList +from anemoi.datasets.create.gridded.sources.xarray import XarrayFieldList from anemoi.datasets.misc.testing import assert_field_list From 37ca75780c82f9a1b73067b39c7e74d94f91faaa Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 12 Nov 2025 08:07:50 +0000 Subject: [PATCH 175/212] more refactoring --- src/anemoi/datasets/commands/grib-index.py | 2 +- src/anemoi/datasets/create/input/action.py | 4 +-- .../datasets/create/{gridded => }/source.py | 0 .../create/{gridded => }/sources/__init__.py | 0 .../{gridded => }/sources/accumulations.py | 2 +- .../{gridded => }/sources/accumulations2.py | 2 +- .../{gridded => }/sources/anemoi_dataset.py | 0 .../create/{gridded => }/sources/constants.py | 0 src/anemoi/datasets/create/sources/csv.py | 33 +++++++++++++++++++ .../create/{gridded => }/sources/eccc_fstd.py | 0 .../create/{gridded => }/sources/empty.py | 0 .../create/{gridded => }/sources/fdb.py | 0 .../create/{gridded => }/sources/forcings.py | 0 .../sources/source.py => sources/generic.py} | 3 +- .../create/{gridded => }/sources/grib.py | 0 .../{gridded => }/sources/grib_index.py | 0 .../create/{gridded => }/sources/hindcasts.py | 2 +- .../create/{gridded => }/sources/legacy.py | 0 .../create/{gridded => }/sources/mars.py | 2 +- .../create/{gridded => }/sources/netcdf.py | 0 .../create/{gridded => }/sources/opendap.py | 0 .../create/{gridded => }/sources/patterns.py | 0 .../sources/planetary_computer.py | 0 .../create/{gridded => }/sources/recentre.py | 0 .../{gridded => }/sources/repeated_dates.py | 2 +- .../{gridded => }/sources/tendencies.py | 2 +- .../create/{gridded => }/sources/xarray.py | 0 .../{gridded => }/sources/xarray_kerchunk.py | 0 .../sources/xarray_support/README.md | 0 .../sources/xarray_support/__init__.py | 2 +- .../sources/xarray_support/coordinates.py | 0 .../sources/xarray_support/field.py | 0 .../sources/xarray_support/fieldlist.py | 0 .../sources/xarray_support/flavour.py | 0 .../sources/xarray_support/grid.py | 0 .../sources/xarray_support/metadata.py | 0 .../sources/xarray_support/patch.py | 0 .../sources/xarray_support/time.py | 0 .../sources/xarray_support/variable.py | 0 .../{gridded => }/sources/xarray_zarr.py | 0 .../create/{gridded => }/sources/zenodo.py | 0 tests/test_csv.py | 29 ++++++++++++++++ tests/xarray/test_flavour.py | 24 +++++++------- tests/xarray/test_netcdf.py | 2 +- tests/xarray/test_opendap.py | 2 +- tests/xarray/test_variable.py | 16 ++++----- tests/xarray/test_zarr.py | 2 +- 47 files changed, 96 insertions(+), 35 
deletions(-) rename src/anemoi/datasets/create/{gridded => }/source.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/__init__.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/accumulations.py (99%) rename src/anemoi/datasets/create/{gridded => }/sources/accumulations2.py (99%) rename src/anemoi/datasets/create/{gridded => }/sources/anemoi_dataset.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/constants.py (100%) create mode 100644 src/anemoi/datasets/create/sources/csv.py rename src/anemoi/datasets/create/{gridded => }/sources/eccc_fstd.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/empty.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/fdb.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/forcings.py (100%) rename src/anemoi/datasets/create/{gridded/sources/source.py => sources/generic.py} (95%) rename src/anemoi/datasets/create/{gridded => }/sources/grib.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/grib_index.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/hindcasts.py (97%) rename src/anemoi/datasets/create/{gridded => }/sources/legacy.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/mars.py (99%) rename src/anemoi/datasets/create/{gridded => }/sources/netcdf.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/opendap.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/patterns.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/planetary_computer.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/recentre.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/repeated_dates.py (97%) rename src/anemoi/datasets/create/{gridded => }/sources/tendencies.py (98%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_kerchunk.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/README.md (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/__init__.py (98%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/coordinates.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/field.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/fieldlist.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/flavour.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/grid.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/metadata.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/patch.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/time.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_support/variable.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/xarray_zarr.py (100%) rename src/anemoi/datasets/create/{gridded => }/sources/zenodo.py (100%) create mode 100644 tests/test_csv.py diff --git a/src/anemoi/datasets/commands/grib-index.py b/src/anemoi/datasets/commands/grib-index.py index a7af0e8f8..cfd7a08e8 100644 --- a/src/anemoi/datasets/commands/grib-index.py +++ b/src/anemoi/datasets/commands/grib-index.py @@ -83,7 +83,7 @@ def match(path: str) -> bool: """ return fnmatch.fnmatch(os.path.basename(path), args.match) - from anemoi.datasets.create.gridded.sources.grib_index import GribIndex 
+ from anemoi.datasets.create.sources.grib_index import GribIndex index = GribIndex( args.index, diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 68945d59d..7808ae717 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -181,7 +181,7 @@ class DatasetSourceMixin: """Mixin class for sources defined in anemoi-datasets""" def create_object(self, context, config): - from anemoi.datasets.create.gridded.sources import create_source as create_datasets_source + from anemoi.datasets.create.sources import create_source as create_datasets_source return create_datasets_source(context, config) @@ -286,7 +286,7 @@ def make(key, config, *path): from anemoi.transform.filters import filter_registry as transform_filter_registry from anemoi.transform.sources import source_registry as transform_source_registry - from anemoi.datasets.create.gridded.sources import source_registry as dataset_source_registry + from anemoi.datasets.create.sources import source_registry as dataset_source_registry # Register sources, local first for name in dataset_source_registry.registered: diff --git a/src/anemoi/datasets/create/gridded/source.py b/src/anemoi/datasets/create/source.py similarity index 100% rename from src/anemoi/datasets/create/gridded/source.py rename to src/anemoi/datasets/create/source.py diff --git a/src/anemoi/datasets/create/gridded/sources/__init__.py b/src/anemoi/datasets/create/sources/__init__.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/__init__.py rename to src/anemoi/datasets/create/sources/__init__.py diff --git a/src/anemoi/datasets/create/gridded/sources/accumulations.py b/src/anemoi/datasets/create/sources/accumulations.py similarity index 99% rename from src/anemoi/datasets/create/gridded/sources/accumulations.py rename to src/anemoi/datasets/create/sources/accumulations.py index a704a93b5..ce4ff6266 100644 --- a/src/anemoi/datasets/create/gridded/sources/accumulations.py +++ b/src/anemoi/datasets/create/sources/accumulations.py @@ -20,7 +20,7 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray -from anemoi.datasets.create.gridded.sources import source_registry +from anemoi.datasets.create.sources import source_registry from .legacy import LegacySource from .mars import mars diff --git a/src/anemoi/datasets/create/gridded/sources/accumulations2.py b/src/anemoi/datasets/create/sources/accumulations2.py similarity index 99% rename from src/anemoi/datasets/create/gridded/sources/accumulations2.py rename to src/anemoi/datasets/create/sources/accumulations2.py index 618d68f27..2f719e46e 100644 --- a/src/anemoi/datasets/create/gridded/sources/accumulations2.py +++ b/src/anemoi/datasets/create/sources/accumulations2.py @@ -18,7 +18,7 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.create.gridded.sources import source_registry +from anemoi.datasets.create.sources import source_registry from .legacy import LegacySource from .mars import mars diff --git a/src/anemoi/datasets/create/gridded/sources/anemoi_dataset.py b/src/anemoi/datasets/create/sources/anemoi_dataset.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/anemoi_dataset.py rename to src/anemoi/datasets/create/sources/anemoi_dataset.py diff --git a/src/anemoi/datasets/create/gridded/sources/constants.py 
b/src/anemoi/datasets/create/sources/constants.py
similarity index 100%
rename from src/anemoi/datasets/create/gridded/sources/constants.py
rename to src/anemoi/datasets/create/sources/constants.py
diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py
new file mode 100644
index 000000000..66d93d332
--- /dev/null
+++ b/src/anemoi/datasets/create/sources/csv.py
@@ -0,0 +1,33 @@
+# (C) Copyright 2025 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+from typing import Any
+
+from anemoi.datasets.create.gridded.typing import DateList
+
+from ..source import Source
+from . import source_registry
+
+
+@source_registry.register("csv")
+class CsvSource(Source):
+    """CSV data source."""
+
+    emoji = "?"
+
+    def __init__(
+        self,
+        context,
+        **kwargs: dict[str, Any],
+    ):
+
+        super().__init__(context)
+
+    def execute(self, dates: DateList):
+        raise NotImplementedError("To be developed")
diff --git a/src/anemoi/datasets/create/gridded/sources/eccc_fstd.py b/src/anemoi/datasets/create/sources/eccc_fstd.py
similarity index 100%
rename from src/anemoi/datasets/create/gridded/sources/eccc_fstd.py
rename to src/anemoi/datasets/create/sources/eccc_fstd.py
diff --git a/src/anemoi/datasets/create/gridded/sources/empty.py b/src/anemoi/datasets/create/sources/empty.py
similarity index 100%
rename from src/anemoi/datasets/create/gridded/sources/empty.py
rename to src/anemoi/datasets/create/sources/empty.py
diff --git a/src/anemoi/datasets/create/gridded/sources/fdb.py b/src/anemoi/datasets/create/sources/fdb.py
similarity index 100%
rename from src/anemoi/datasets/create/gridded/sources/fdb.py
rename to src/anemoi/datasets/create/sources/fdb.py
diff --git a/src/anemoi/datasets/create/gridded/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py
similarity index 100%
rename from src/anemoi/datasets/create/gridded/sources/forcings.py
rename to src/anemoi/datasets/create/sources/forcings.py
diff --git a/src/anemoi/datasets/create/gridded/sources/source.py b/src/anemoi/datasets/create/sources/generic.py
similarity index 95%
rename from src/anemoi/datasets/create/gridded/sources/source.py
rename to src/anemoi/datasets/create/sources/generic.py
index 8918c1303..a6675449a 100644
--- a/src/anemoi/datasets/create/gridded/sources/source.py
+++ b/src/anemoi/datasets/create/sources/generic.py
@@ -12,8 +12,7 @@ from earthkit.data import from_source
-from anemoi.datasets.create.gridded.sources import source_registry
-
+from . 
import source_registry from .legacy import LegacySource diff --git a/src/anemoi/datasets/create/gridded/sources/grib.py b/src/anemoi/datasets/create/sources/grib.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/grib.py rename to src/anemoi/datasets/create/sources/grib.py diff --git a/src/anemoi/datasets/create/gridded/sources/grib_index.py b/src/anemoi/datasets/create/sources/grib_index.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/grib_index.py rename to src/anemoi/datasets/create/sources/grib_index.py diff --git a/src/anemoi/datasets/create/gridded/sources/hindcasts.py b/src/anemoi/datasets/create/sources/hindcasts.py similarity index 97% rename from src/anemoi/datasets/create/gridded/sources/hindcasts.py rename to src/anemoi/datasets/create/sources/hindcasts.py index cee33a679..b9985ccf1 100644 --- a/src/anemoi/datasets/create/gridded/sources/hindcasts.py +++ b/src/anemoi/datasets/create/sources/hindcasts.py @@ -12,7 +12,7 @@ from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.gridded.sources import source_registry +from anemoi.datasets.create.sources import source_registry from .legacy import LegacySource from .mars import mars diff --git a/src/anemoi/datasets/create/gridded/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/legacy.py rename to src/anemoi/datasets/create/sources/legacy.py diff --git a/src/anemoi/datasets/create/gridded/sources/mars.py b/src/anemoi/datasets/create/sources/mars.py similarity index 99% rename from src/anemoi/datasets/create/gridded/sources/mars.py rename to src/anemoi/datasets/create/sources/mars.py index ee2dd0b90..25e223cb4 100644 --- a/src/anemoi/datasets/create/gridded/sources/mars.py +++ b/src/anemoi/datasets/create/sources/mars.py @@ -16,7 +16,7 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability -from anemoi.datasets.create.gridded.sources import source_registry +from anemoi.datasets.create.sources import source_registry from .legacy import LegacySource diff --git a/src/anemoi/datasets/create/gridded/sources/netcdf.py b/src/anemoi/datasets/create/sources/netcdf.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/netcdf.py rename to src/anemoi/datasets/create/sources/netcdf.py diff --git a/src/anemoi/datasets/create/gridded/sources/opendap.py b/src/anemoi/datasets/create/sources/opendap.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/opendap.py rename to src/anemoi/datasets/create/sources/opendap.py diff --git a/src/anemoi/datasets/create/gridded/sources/patterns.py b/src/anemoi/datasets/create/sources/patterns.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/patterns.py rename to src/anemoi/datasets/create/sources/patterns.py diff --git a/src/anemoi/datasets/create/gridded/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/planetary_computer.py rename to src/anemoi/datasets/create/sources/planetary_computer.py diff --git a/src/anemoi/datasets/create/gridded/sources/recentre.py b/src/anemoi/datasets/create/sources/recentre.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/recentre.py rename to src/anemoi/datasets/create/sources/recentre.py diff --git 
a/src/anemoi/datasets/create/gridded/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py similarity index 97% rename from src/anemoi/datasets/create/gridded/sources/repeated_dates.py rename to src/anemoi/datasets/create/sources/repeated_dates.py index 9b297e193..f43ec8ce0 100644 --- a/src/anemoi/datasets/create/gridded/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -17,7 +17,7 @@ from anemoi.datasets.create.input.repeated_dates import DateMapper from ..source import Source -from ..sources import source_registry +from . import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/gridded/sources/tendencies.py b/src/anemoi/datasets/create/sources/tendencies.py similarity index 98% rename from src/anemoi/datasets/create/gridded/sources/tendencies.py rename to src/anemoi/datasets/create/sources/tendencies.py index 780bd3832..cdf4ce291 100644 --- a/src/anemoi/datasets/create/gridded/sources/tendencies.py +++ b/src/anemoi/datasets/create/sources/tendencies.py @@ -14,7 +14,7 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.create.gridded.sources import source_registry +from anemoi.datasets.create.sources import source_registry from .legacy import LegacySource diff --git a/src/anemoi/datasets/create/gridded/sources/xarray.py b/src/anemoi/datasets/create/sources/xarray.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray.py rename to src/anemoi/datasets/create/sources/xarray.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_kerchunk.py b/src/anemoi/datasets/create/sources/xarray_kerchunk.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_kerchunk.py rename to src/anemoi/datasets/create/sources/xarray_kerchunk.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/README.md b/src/anemoi/datasets/create/sources/xarray_support/README.md similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/README.md rename to src/anemoi/datasets/create/sources/xarray_support/README.md diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py similarity index 98% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/__init__.py rename to src/anemoi/datasets/create/sources/xarray_support/__init__.py index cbbd9f0e3..8e3cebc08 100644 --- a/src/anemoi/datasets/create/gridded/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -15,7 +15,7 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.gridded.sources.patterns import iterate_patterns +from anemoi.datasets.create.sources.patterns import iterate_patterns from .. 
import source_registry from ..legacy import LegacySource diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/coordinates.py b/src/anemoi/datasets/create/sources/xarray_support/coordinates.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/coordinates.py rename to src/anemoi/datasets/create/sources/xarray_support/coordinates.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/field.py rename to src/anemoi/datasets/create/sources/xarray_support/field.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/fieldlist.py rename to src/anemoi/datasets/create/sources/xarray_support/fieldlist.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/flavour.py rename to src/anemoi/datasets/create/sources/xarray_support/flavour.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/grid.py b/src/anemoi/datasets/create/sources/xarray_support/grid.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/grid.py rename to src/anemoi/datasets/create/sources/xarray_support/grid.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/sources/xarray_support/metadata.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/metadata.py rename to src/anemoi/datasets/create/sources/xarray_support/metadata.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/patch.py b/src/anemoi/datasets/create/sources/xarray_support/patch.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/patch.py rename to src/anemoi/datasets/create/sources/xarray_support/patch.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/time.py b/src/anemoi/datasets/create/sources/xarray_support/time.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/time.py rename to src/anemoi/datasets/create/sources/xarray_support/time.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_support/variable.py b/src/anemoi/datasets/create/sources/xarray_support/variable.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_support/variable.py rename to src/anemoi/datasets/create/sources/xarray_support/variable.py diff --git a/src/anemoi/datasets/create/gridded/sources/xarray_zarr.py b/src/anemoi/datasets/create/sources/xarray_zarr.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/xarray_zarr.py rename to src/anemoi/datasets/create/sources/xarray_zarr.py diff --git a/src/anemoi/datasets/create/gridded/sources/zenodo.py b/src/anemoi/datasets/create/sources/zenodo.py similarity index 100% rename from src/anemoi/datasets/create/gridded/sources/zenodo.py rename to src/anemoi/datasets/create/sources/zenodo.py diff --git a/tests/test_csv.py b/tests/test_csv.py new file mode 100644 index 000000000..8127a5d35 --- 
/dev/null +++ b/tests/test_csv.py @@ -0,0 +1,29 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import pytest + +from anemoi.datasets.create.sources import create_source + +LOG = logging.getLogger(__name__) + + +def test_csv_source_registration(): + + source = create_source(context=None, config={"csv": {"path": "data.csv"}}) + + with pytest.raises(NotImplementedError): + source.execute(dates=[]) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + test_csv_source_registration() diff --git a/tests/xarray/test_flavour.py b/tests/xarray/test_flavour.py index e70e97aa1..7b2bb33e5 100644 --- a/tests/xarray/test_flavour.py +++ b/tests/xarray/test_flavour.py @@ -11,18 +11,18 @@ import pytest import xarray as xr -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import EnsembleCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import ScalarCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import UnsupportedCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import XCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import YCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.flavour import DefaultCoordinateGuesser +from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.create.sources.xarray_support.flavour import DefaultCoordinateGuesser def create_ds(var_name, standard_name, long_name, units, coord_length=5): diff --git a/tests/xarray/test_netcdf.py b/tests/xarray/test_netcdf.py index c4239bf25..f25d8c4d7 100644 --- a/tests/xarray/test_netcdf.py +++ 
b/tests/xarray/test_netcdf.py @@ -12,7 +12,7 @@ import xarray as xr from multiurl import download -from anemoi.datasets.create.gridded.sources.xarray import XarrayFieldList +from anemoi.datasets.create.sources.xarray import XarrayFieldList URLS = { "https://get.ecmwf.int/repository/test-data/earthkit-data/examples/efas.nc": dict(length=3), diff --git a/tests/xarray/test_opendap.py b/tests/xarray/test_opendap.py index 049965e07..f4fb9cf4c 100644 --- a/tests/xarray/test_opendap.py +++ b/tests/xarray/test_opendap.py @@ -12,7 +12,7 @@ import xarray as xr from anemoi.utils.testing import skip_if_offline -from anemoi.datasets.create.gridded.sources.xarray import XarrayFieldList +from anemoi.datasets.create.sources.xarray import XarrayFieldList from anemoi.datasets.misc.testing import assert_field_list diff --git a/tests/xarray/test_variable.py b/tests/xarray/test_variable.py index e597d37a6..ff43da389 100644 --- a/tests/xarray/test_variable.py +++ b/tests/xarray/test_variable.py @@ -13,14 +13,14 @@ import pytest import xarray as xr -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.create.gridded.sources.xarray_support.time import ForecastFromValidTimeAndStep -from anemoi.datasets.create.gridded.sources.xarray_support.variable import Variable +from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.sources.xarray_support.time import ForecastFromValidTimeAndStep +from anemoi.datasets.create.sources.xarray_support.variable import Variable @pytest.fixture diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 4568754cf..6ddeec310 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -12,7 +12,7 @@ from anemoi.utils.testing import skip_if_offline from anemoi.utils.testing import skip_missing_packages -from anemoi.datasets.create.gridded.sources.xarray import XarrayFieldList +from anemoi.datasets.create.sources.xarray import XarrayFieldList from anemoi.datasets.misc.testing import assert_field_list From 92c0023b29ac6d7069fb488630dc7e8c4508ce54 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 10:30:50 +0000 Subject: [PATCH 176/212] update --- src/anemoi/datasets/create/sources/csv.py | 34 +++++++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py index 0b293845e..48ba1f9f8 100644 --- a/src/anemoi/datasets/create/sources/csv.py +++ b/src/anemoi/datasets/create/sources/csv.py @@ -18,7 +18,15 @@ class 
CSVSource(ObservationsSource): emoji = "📄" # For tracing - def __init__(self, context: any, path: str, *args: tuple, **kwargs: dict): + def __init__( + self, + context: any, + path: str, + columns: list = None, + flavour: dict = None, + *args, + **kwargs, + ): """Initialise the CSVSource. Parameters @@ -27,16 +35,36 @@ def __init__(self, context: any, path: str, *args: tuple, **kwargs: dict): The context for the data source. filepath : str The path to the CSV file. + columns : list, optional + The list of columns to read from the CSV file. *args : tuple Additional positional arguments. **kwargs : dict Additional keyword arguments. """ super().__init__(context, *args, **kwargs) + self.path = path + self.columns = columns + + self.flavour = { + "latitude": "latitude", + "longitude": "longitude", + "time": "time", + } + + if flavour is not None: + self.flavour.update(flavour) def execute(self, dates): import pandas as pd - frame = pd.read_csv(self.path) - print(frame) + if self.columns is None: + frame = pd.read_csv(self.path) + else: + frame = pd.read_csv(self.path, usecols=self.columns) + + start, end = dates.window.start_date, dates.window.end_date + mask = (frame[self.flavour["time"]] >= start) & (frame[self.flavour["time"]] <= end) + frame = frame.loc[mask] + return frame From 7a0f959a62489d1610adea01c6b42093f6e4569f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 11:07:09 +0000 Subject: [PATCH 177/212] update --- src/anemoi/datasets/commands/create.py | 2 +- .../datasets/create/fields/additions.py | 8 ++--- src/anemoi/datasets/create/fields/init.py | 4 +-- src/anemoi/datasets/create/fields/load.py | 4 +-- src/anemoi/datasets/create/fields/tasks.py | 20 ++++++------- src/anemoi/datasets/create/tasks.py | 4 +-- src/anemoi/datasets/use/gridded/dataset.py | 2 +- src/anemoi/datasets/use/gridded/misc.py | 6 ++-- src/anemoi/datasets/use/gridded/stores.py | 2 +- src/anemoi/datasets/use/gridded/subset.py | 2 +- .../datasets/use/tabular/records/__init__.py | 4 +-- tests/create/test_observations.py | 4 +-- tests/create/test_observations_mars.py | 4 +-- tests/create/test_observations_mars_bufr.py | 4 +-- .../test_observations_mars_bufr_complex.py | 4 +-- .../test_observations_mars_bufr_parallel.py | 4 +-- tests/create/test_sources.py | 21 +++++++++++--- tests/test_csv.py | 29 ------------------- 18 files changed, 56 insertions(+), 72 deletions(-) delete mode 100644 tests/test_csv.py diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 151b175d9..9c7f63cc4 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -32,7 +32,7 @@ def task(what: str, fields: bool, options: dict, *args: Any, **kwargs: Any) -> A options = {k: v for k, v in options.items() if v is not None} - c = task_factory(what.replace("-", "_"), fields, **options) + c = task_factory(what.replace("-", "_"), **options) result = c.run() LOG.info(f"🏁 Task {what}({args},{kwargs}) completed ({datetime.datetime.now()-now})") diff --git a/src/anemoi/datasets/create/fields/additions.py b/src/anemoi/datasets/create/fields/additions.py index 94972e1c4..0b113aeef 100644 --- a/src/anemoi/datasets/create/fields/additions.py +++ b/src/anemoi/datasets/create/fields/additions.py @@ -21,11 +21,11 @@ from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.create.gridded.statistics import Summary +from 
anemoi.datasets.create.gridded.statistics import check_variance +from anemoi.datasets.create.gridded.statistics import compute_statistics +from anemoi.datasets.create.gridded.statistics import fix_variance from anemoi.datasets.create.persistent import build_storage -from anemoi.datasets.create.statistics import Summary -from anemoi.datasets.create.statistics import check_variance -from anemoi.datasets.create.statistics import compute_statistics -from anemoi.datasets.create.statistics import fix_variance from .tasks import FieldTask from .tasks import HasRegistryMixin diff --git a/src/anemoi/datasets/create/fields/init.py b/src/anemoi/datasets/create/fields/init.py index 77e1f36e1..347802c32 100644 --- a/src/anemoi/datasets/create/fields/init.py +++ b/src/anemoi/datasets/create/fields/init.py @@ -15,8 +15,8 @@ import zarr from anemoi.utils.sanitise import sanitise -from anemoi.datasets.create.config import loader_config -from anemoi.datasets.create.utils import normalize_and_check_dates +from anemoi.datasets.create.gridded.config import loader_config +from anemoi.datasets.create.gridded.utils import normalize_and_check_dates from .tasks import FieldTask from .tasks import HasElementForDataMixin diff --git a/src/anemoi/datasets/create/fields/load.py b/src/anemoi/datasets/create/fields/load.py index bab731cb2..813b6b3ea 100644 --- a/src/anemoi/datasets/create/fields/load.py +++ b/src/anemoi/datasets/create/fields/load.py @@ -17,9 +17,9 @@ from anemoi.utils.humanize import compress_dates from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.create.check import check_data_values from anemoi.datasets.create.chunks import ChunkFilter -from anemoi.datasets.create.statistics import compute_statistics +from anemoi.datasets.create.gridded.check import check_data_values +from anemoi.datasets.create.gridded.statistics import compute_statistics from anemoi.datasets.create.writer import ViewCacheArray from .tasks import FieldTask diff --git a/src/anemoi/datasets/create/fields/tasks.py b/src/anemoi/datasets/create/fields/tasks.py index 97beef80f..cafbdd233 100644 --- a/src/anemoi/datasets/create/fields/tasks.py +++ b/src/anemoi/datasets/create/fields/tasks.py @@ -21,16 +21,16 @@ from earthkit.data.core.order import build_remapping from anemoi.datasets import open_dataset -from anemoi.datasets.create.check import DatasetName -from anemoi.datasets.create.config import build_output -from anemoi.datasets.create.config import loader_config from anemoi.datasets.create.fields.context import FieldContext +from anemoi.datasets.create.gridded.check import DatasetName +from anemoi.datasets.create.gridded.config import build_output +from anemoi.datasets.create.gridded.config import loader_config +from anemoi.datasets.create.gridded.statistics import TmpStatistics +from anemoi.datasets.create.gridded.statistics import default_statistics_dates from anemoi.datasets.create.input import InputBuilder -from anemoi.datasets.create.statistics import TmpStatistics -from anemoi.datasets.create.statistics import default_statistics_dates -from anemoi.datasets.data.misc import as_first_date -from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups +from anemoi.datasets.use.gridded.misc import as_first_date +from anemoi.datasets.use.gridded.misc import as_last_date from ..tasks import chain @@ -151,7 +151,7 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: import zarr z = zarr.open(self.path, mode=mode) - from anemoi.datasets.create.zarr import 
add_zarr_dataset + from anemoi.datasets.create.gridded.zarr import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -355,7 +355,7 @@ def _cache_context(self) -> Any: Any The cache context. """ - from anemoi.datasets.create.utils import cache_context + from anemoi.datasets.create.gridded.utils import cache_context return cache_context(self.cache) @@ -419,7 +419,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from anemoi.datasets.create.zarr import ZarrBuiltRegistry + from anemoi.datasets.create.gridded.zarr import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) diff --git a/src/anemoi/datasets/create/tasks.py b/src/anemoi/datasets/create/tasks.py index f06f8fe5d..af249d730 100644 --- a/src/anemoi/datasets/create/tasks.py +++ b/src/anemoi/datasets/create/tasks.py @@ -46,9 +46,9 @@ def run(self) -> None: return Chain -def task_factory(name: str, fields: bool, trace: str | None = None, **kwargs): +def task_factory(name: str, trace: str | None = None, **kwargs): - if fields: + if True: from anemoi.datasets.create.fields.tasks import TaskCreator creator = TaskCreator() diff --git a/src/anemoi/datasets/use/gridded/dataset.py b/src/anemoi/datasets/use/gridded/dataset.py index 4f8657a00..5a4df0052 100644 --- a/src/anemoi/datasets/use/gridded/dataset.py +++ b/src/anemoi/datasets/use/gridded/dataset.py @@ -1026,7 +1026,7 @@ def origins(self) -> Any: print(p.origins()) def components(self) -> Any: - from anemoi.datasets.data.components import Projection + from anemoi.datasets.use.components import Projection slices = tuple(slice(0, i, 1) for i in self.shape) return self.project(Projection(slices)) diff --git a/src/anemoi/datasets/use/gridded/misc.py b/src/anemoi/datasets/use/gridded/misc.py index 6aad8ae1f..58305511d 100644 --- a/src/anemoi/datasets/use/gridded/misc.py +++ b/src/anemoi/datasets/use/gridded/misc.py @@ -386,7 +386,7 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " if "backend" not in load_any_dict_format(metadata_path): raise ValueError(f"Metadata for {path} does not contain 'backend' key") - from anemoi.datasets.data.records import open_records_dataset + from anemoi.datasets.use.records import open_records_dataset return open_records_dataset(path) @@ -608,13 +608,13 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": assert len(sets) > 0, (args, kwargs) if "set_group" in kwargs: - from anemoi.datasets.data.records import FieldsRecords + from anemoi.datasets.use.records import FieldsRecords set_group = kwargs.pop("set_group") assert len(sets) == 1, "set_group can only be used with a single dataset" dataset = sets[0] - from anemoi.datasets.data.dataset import Dataset + from anemoi.datasets.use.dataset import Dataset if isinstance(dataset, Dataset): # Fields dataset return FieldsRecords(dataset, **kwargs, name=set_group).mutate() diff --git a/src/anemoi/datasets/use/gridded/stores.py b/src/anemoi/datasets/use/gridded/stores.py index 444191da0..3538e1040 100644 --- a/src/anemoi/datasets/use/gridded/stores.py +++ b/src/anemoi/datasets/use/gridded/stores.py @@ -581,7 +581,7 @@ def dataset_lookup(name: str, fail: bool = True) -> Optional[str]: tried.append(full) try: - from anemoi.datasets.data.records import open_records_dataset + from anemoi.datasets.use.tabular.records import open_records_dataset z = open_records_dataset(full) if z is not None: diff --git a/src/anemoi/datasets/use/gridded/subset.py b/src/anemoi/datasets/use/gridded/subset.py 
index 3901aaf75..5e8c1cfb7 100644 --- a/src/anemoi/datasets/use/gridded/subset.py +++ b/src/anemoi/datasets/use/gridded/subset.py @@ -82,7 +82,7 @@ def _end(a: int, b: int, dates: NDArray[np.datetime64]) -> int: Returns: int: The index of the end date. """ - from anemoi.datasets.data.misc import as_last_date + from anemoi.datasets.use.gridded.misc import as_last_date c = as_last_date(a, dates) d = as_last_date(b, dates) diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py index 7ac085196..9216bcadc 100644 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/__init__.py @@ -18,7 +18,7 @@ from anemoi.utils.config import load_any_dict_format from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.data.debug import Node +from anemoi.datasets.use.debug import Node from .records.backends import backend_factory from .windows import window_from_str @@ -364,7 +364,7 @@ def __init__(self, fields_dataset, name): . """ self.forward = fields_dataset - from anemoi.datasets.data.dataset import Dataset + from anemoi.datasets.use.dataset import Dataset assert isinstance(fields_dataset, Dataset), f"fields_dataset must be a Dataset, got {type(fields_dataset)}" self._name = name diff --git a/tests/create/test_observations.py b/tests/create/test_observations.py index 01410cf2f..af0f02fe5 100644 --- a/tests/create/test_observations.py +++ b/tests/create/test_observations.py @@ -14,8 +14,8 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import Interval -from anemoi.datasets.data.records import window_from_str +from anemoi.datasets.use.records import Interval +from anemoi.datasets.use.records import window_from_str class DummpySource(ObservationsSource): diff --git a/tests/create/test_observations_mars.py b/tests/create/test_observations_mars.py index b28e00708..1ca686b49 100644 --- a/tests/create/test_observations_mars.py +++ b/tests/create/test_observations_mars.py @@ -16,8 +16,8 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import Interval -from anemoi.datasets.data.records import window_from_str +from anemoi.datasets.use.records import Interval +from anemoi.datasets.use.records import window_from_str log = logging.getLogger(__name__) diff --git a/tests/create/test_observations_mars_bufr.py b/tests/create/test_observations_mars_bufr.py index 747274af5..b916a58c0 100644 --- a/tests/create/test_observations_mars_bufr.py +++ b/tests/create/test_observations_mars_bufr.py @@ -16,8 +16,8 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import Interval -from anemoi.datasets.data.records import window_from_str +from anemoi.datasets.use.records import Interval +from anemoi.datasets.use.records import window_from_str log = logging.getLogger(__name__) diff --git a/tests/create/test_observations_mars_bufr_complex.py b/tests/create/test_observations_mars_bufr_complex.py index 2901e9cf6..d271a43c0 100644 --- a/tests/create/test_observations_mars_bufr_complex.py +++ b/tests/create/test_observations_mars_bufr_complex.py @@ -16,8 +16,8 @@ from 
anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import Interval -from anemoi.datasets.data.records import window_from_str +from anemoi.datasets.use.records import Interval +from anemoi.datasets.use.records import window_from_str log = logging.getLogger(__name__) diff --git a/tests/create/test_observations_mars_bufr_parallel.py b/tests/create/test_observations_mars_bufr_parallel.py index 05c9397b3..d3562191d 100644 --- a/tests/create/test_observations_mars_bufr_parallel.py +++ b/tests/create/test_observations_mars_bufr_parallel.py @@ -16,8 +16,8 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.data.records import Interval -from anemoi.datasets.data.records import window_from_str +from anemoi.datasets.use.records import Interval +from anemoi.datasets.use.records import window_from_str log = logging.getLogger(__name__) diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index e679fb6bf..82dffb264 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -285,8 +285,21 @@ def test_planetary_computer_conus404() -> None: assert ds.shape == (2, 1, 1, 1387505), ds.shape -if __name__ == "__main__": - - from anemoi.utils.testing import run_tests +@skip_if_offline +def test_csv(get_test_data: callable) -> None: + """Test the CSV source.""" + from anemoi.datasets.create.sources import create_source + from anemoi.datasets.dates import DatesProvider + + data = get_test_data("anemoi-datasets/obs/dribu.csv") + + source = create_source(context=None, config={"csv": {"path": data}}) + window = DatesProvider.from_config( + { + "start": "2020-01-01T00:00:00", + "end": "2020-01-02T23:59:59", + "window": "(-3h:+3h]", + } + ) - run_tests(globals()) + source.execute(window) diff --git a/tests/test_csv.py b/tests/test_csv.py deleted file mode 100644 index 8127a5d35..000000000 --- a/tests/test_csv.py +++ /dev/null @@ -1,29 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction.
- -import logging - -import pytest - -from anemoi.datasets.create.sources import create_source - -LOG = logging.getLogger(__name__) - - -def test_csv_source_registration(): - - source = create_source(context=None, config={"csv": {"path": "data.csv"}}) - - with pytest.raises(NotImplementedError): - source.execute(dates=[]) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - test_csv_source_registration() From b8892aec3b8ea46ed93da70052daa05342ee4078 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 11:14:09 +0000 Subject: [PATCH 178/212] update --- .../datasets/create/fields/additions.py | 413 ----------- src/anemoi/datasets/create/fields/cleanup.py | 60 -- src/anemoi/datasets/create/fields/context.py | 78 -- src/anemoi/datasets/create/fields/init.py | 293 -------- src/anemoi/datasets/create/fields/load.py | 260 ------- src/anemoi/datasets/create/fields/patch.py | 38 - src/anemoi/datasets/create/fields/result.py | 668 ------------------ src/anemoi/datasets/create/fields/size.py | 48 -- .../datasets/create/fields/statistics.py | 102 --- src/anemoi/datasets/create/fields/tasks.py | 606 ---------------- src/anemoi/datasets/create/fields/verify.py | 34 - .../datasets/create/observations/__init__.py | 0 .../create/{fields => tabular}/__init__.py | 0 .../create/{observations => tabular}/tasks.py | 0 14 files changed, 2600 deletions(-) delete mode 100644 src/anemoi/datasets/create/fields/additions.py delete mode 100644 src/anemoi/datasets/create/fields/cleanup.py delete mode 100644 src/anemoi/datasets/create/fields/context.py delete mode 100644 src/anemoi/datasets/create/fields/init.py delete mode 100644 src/anemoi/datasets/create/fields/load.py delete mode 100644 src/anemoi/datasets/create/fields/patch.py delete mode 100644 src/anemoi/datasets/create/fields/result.py delete mode 100644 src/anemoi/datasets/create/fields/size.py delete mode 100644 src/anemoi/datasets/create/fields/statistics.py delete mode 100644 src/anemoi/datasets/create/fields/tasks.py delete mode 100644 src/anemoi/datasets/create/fields/verify.py delete mode 100644 src/anemoi/datasets/create/observations/__init__.py rename src/anemoi/datasets/create/{fields => tabular}/__init__.py (100%) rename src/anemoi/datasets/create/{observations => tabular}/tasks.py (100%) diff --git a/src/anemoi/datasets/create/fields/additions.py b/src/anemoi/datasets/create/fields/additions.py deleted file mode 100644 index 0b113aeef..000000000 --- a/src/anemoi/datasets/create/fields/additions.py +++ /dev/null @@ -1,413 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import datetime -import logging -import os -import warnings -from functools import cached_property -from typing import Any - -import numpy as np -from anemoi.utils.dates import frequency_to_string -from anemoi.utils.dates import frequency_to_timedelta - -from anemoi.datasets import MissingDateError -from anemoi.datasets import open_dataset -from anemoi.datasets.create.chunks import ChunkFilter -from anemoi.datasets.create.gridded.statistics import Summary -from anemoi.datasets.create.gridded.statistics import check_variance -from anemoi.datasets.create.gridded.statistics import compute_statistics -from anemoi.datasets.create.gridded.statistics import fix_variance -from anemoi.datasets.create.persistent import build_storage - -from .tasks import FieldTask -from .tasks import HasRegistryMixin - -LOG = logging.getLogger(__name__) - - -class AdditionsMixin: - """A mixin class to handle dataset additions.""" - - def skip(self) -> bool: - """Check if the additions should be skipped. - - Returns - ------- - bool - Whether to skip the additions. - """ - frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - if not self.delta.total_seconds() % frequency.total_seconds() == 0: - LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.") - return True - - if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: - LOG.warning(f"Additions are disabled for {self.path} in the recipe.") - return True - - return False - - @cached_property - def tmp_storage_path(self) -> str: - """Get the path to the temporary storage.""" - name = "storage_for_additions" - if self.delta: - name += frequency_to_string(self.delta) - return os.path.join(f"{self.path}.{name}.tmp") - - def read_from_dataset(self) -> None: - """Read data from the dataset.""" - self.variables = self.dataset.anemoi_dataset.variables - self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - start = self.dataset.zarr_metadata["statistics_start_date"] - end = self.dataset.zarr_metadata["statistics_end_date"] - self.start = datetime.datetime.fromisoformat(start) - self.end = datetime.datetime.fromisoformat(end) - - ds = open_dataset(self.path, start=self.start, end=self.end) - self.dates = ds.dates - self.total = len(self.dates) - - idelta = self.delta.total_seconds() // self.frequency.total_seconds() - assert int(idelta) == idelta, idelta - idelta = int(idelta) - self.ds = DeltaDataset(ds, idelta) - - -class DeltaDataset: - """A class to represent a dataset with delta values.""" - - def __init__(self, ds: Any, idelta: int): - """Initialize a DeltaDataset instance. - - Parameters - ---------- - ds : Any - The dataset. - idelta : int - The delta value. - """ - self.ds = ds - self.idelta = idelta - - def __getitem__(self, i: int) -> Any: - """Get an item from the dataset. - - Parameters - ---------- - i : int - The index. - - Returns - ------- - Any - The item. - """ - j = i - self.idelta - if j < 0: - raise MissingDateError(f"Missing date {j}") - return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] - - -class _InitAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): - """A class to initialize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize an _InitAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. 
- progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - def run(self) -> None: - """Run the additions initialization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) - self.tmp_storage.delete() - self.tmp_storage.create() - LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") - - def cleanup(self) -> None: - """Clean up the temporary storage.""" - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - self.tmp_storage.delete() - LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") - - -class _LoadAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): - """A class to run dataset additions.""" - - def __init__( - self, - path: str, - delta: str, - parts: str | None = None, - use_threads: bool = False, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a _LoadAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - self.parts = parts - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Writing in {self.tmp_storage_path}") - - def run(self) -> None: - """Run the additions.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.read_from_dataset() - - chunk_filter = ChunkFilter(parts=self.parts, total=self.total) - for i in range(0, self.total): - if not chunk_filter(i): - continue - date = self.dates[i] - try: - arr = self.ds[i] - stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) - self.tmp_storage.add([date, i, stats], key=date) - except MissingDateError: - self.tmp_storage.add([date, i, "missing"], key=date) - self.tmp_storage.flush() - LOG.debug(f"Dataset {self.path} additions run.") - - def allow_nans(self) -> bool: - """Check if NaNs are allowed. - - Returns - ------- - bool - Whether NaNs are allowed. - """ - if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): - return True - - variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) - if variables_with_nans is not None: - return variables_with_nans - warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") - return True - - -class _FinaliseAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): - """A class to finalize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize a _FinaliseAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. 
- """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Reading from {self.tmp_storage_path}.") - - def run(self) -> None: - """Run the additions finalization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}.") - return - - self.read_from_dataset() - - shape = (len(self.dates), len(self.variables)) - agg = dict( - minimum=np.full(shape, np.nan, dtype=np.float64), - maximum=np.full(shape, np.nan, dtype=np.float64), - sums=np.full(shape, np.nan, dtype=np.float64), - squares=np.full(shape, np.nan, dtype=np.float64), - count=np.full(shape, -1, dtype=np.int64), - has_nans=np.full(shape, False, dtype=np.bool_), - ) - LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") - - found = set() - ifound = set() - missing = set() - for _date, (date, i, stats) in self.tmp_storage.items(): - assert _date == date - if stats == "missing": - missing.add(date) - continue - - assert date not in found, f"Duplicates found {date}" - found.add(date) - ifound.add(i) - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k][i, ...] = stats[k] - - assert len(found) + len(missing) == len(self.dates), ( - len(found), - len(missing), - len(self.dates), - ) - assert found.union(missing) == set(self.dates), ( - found, - missing, - set(self.dates), - ) - - if len(ifound) < 2: - LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") - self.tmp_storage.delete() - return - - mask = sorted(list(ifound)) - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k] = agg[k][mask, ...] - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - assert agg[k].shape == agg["count"].shape, ( - agg[k].shape, - agg["count"].shape, - ) - - minimum = np.nanmin(agg["minimum"], axis=0) - maximum = np.nanmax(agg["maximum"], axis=0) - sums = np.nansum(agg["sums"], axis=0) - squares = np.nansum(agg["squares"], axis=0) - count = np.nansum(agg["count"], axis=0) - has_nans = np.any(agg["has_nans"], axis=0) - - assert sums.shape == count.shape - assert sums.shape == squares.shape - assert sums.shape == minimum.shape - assert sums.shape == maximum.shape - assert sums.shape == has_nans.shape - - mean = sums / count - assert sums.shape == mean.shape - - x = squares / count - mean * mean - # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 - # remove negative variance due to numerical errors - for i, name in enumerate(self.variables): - x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) - check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) - - stdev = np.sqrt(x) - assert sums.shape == stdev.shape - - self.summary = Summary( - minimum=minimum, - maximum=maximum, - mean=mean, - count=count, - sums=sums, - squares=squares, - stdev=stdev, - variables_names=self.variables, - has_nans=has_nans, - ) - LOG.info(f"Dataset {self.path} additions finalised.") - # self.check_statistics() - self._write(self.summary) - self.tmp_storage.delete() - - def _write(self, summary: Summary) -> None: - """Write the summary to the dataset. - - Parameters - ---------- - summary : Summary - The summary to write. 
- """ - for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: - name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" - self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) - self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") - LOG.debug(f"Wrote additions in {self.path}") - - -def multi_addition(cls: type) -> type: - """Create a class to handle multiple additions. - - Parameters - ---------- - cls : type - The class to handle additions. - - Returns - ------- - type - The class to handle multiple additions. - """ - - class MultiAdditions: - def __init__(self, *args, **kwargs: Any): - self.tasks = [] - - for k in kwargs.pop("delta", []): - self.tasks.append(cls(*args, delta=k, **kwargs)) - - if not self.tasks: - LOG.warning("No delta found in kwargs, no additions will be computed.") - - def run(self) -> None: - """Run the additions.""" - for actor in self.tasks: - actor.run() - - return MultiAdditions - - -InitAdditions = multi_addition(_InitAdditions) -LoadAdditions = multi_addition(_LoadAdditions) -FinaliseAdditions = multi_addition(_FinaliseAdditions) diff --git a/src/anemoi/datasets/create/fields/cleanup.py b/src/anemoi/datasets/create/fields/cleanup.py deleted file mode 100644 index 8b87ba3cc..000000000 --- a/src/anemoi/datasets/create/fields/cleanup.py +++ /dev/null @@ -1,60 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from typing import Any - -from .additions import _InitAdditions -from .tasks import FieldTask -from .tasks import HasRegistryMixin -from .tasks import HasStatisticTempMixin - -LOG = logging.getLogger(__name__) - - -class Cleanup(FieldTask, HasRegistryMixin, HasStatisticTempMixin): - """A class to clean up temporary data and registry entries.""" - - def __init__( - self, - path: str, - statistics_temp_dir: str | None = None, - delta: list = [], - use_threads: bool = False, - **kwargs: Any, - ): - """Initialize a Cleanup instance. - - Parameters - ---------- - path : str - The path to the dataset. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - delta : list, optional - The delta values. - use_threads : bool, optional - Whether to use threads. - """ - super().__init__(path) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.additinon_temp_dir = statistics_temp_dir - self.tasks = [ - _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) - for d in delta - ] - - def run(self) -> None: - """Run the cleanup.""" - - self.tmp_statistics.delete() - self.registry.clean() - for actor in self.tasks: - actor.cleanup() diff --git a/src/anemoi/datasets/create/fields/context.py b/src/anemoi/datasets/create/fields/context.py deleted file mode 100644 index ef3ebeca5..000000000 --- a/src/anemoi/datasets/create/fields/context.py +++ /dev/null @@ -1,78 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
-# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import logging -from typing import Any - -from anemoi.transform.fields import new_field_with_metadata -from anemoi.transform.fields import new_fieldlist_from_list -from earthkit.data.core.order import build_remapping - -from anemoi.datasets.create.input.context import Context - -LOG = logging.getLogger(__name__) - - -class FieldContext(Context): - - def __init__( - self, - /, - order_by: str, - flatten_grid: bool, - remapping: dict[str, Any], - use_grib_paramid: bool, - ) -> None: - - super().__init__() - - self.order_by = order_by - self.flatten_grid = flatten_grid - self.remapping = build_remapping(remapping) - self.use_grib_paramid = use_grib_paramid - self.partial_ok = False - - def empty_result(self) -> Any: - import earthkit.data as ekd - - return ekd.from_source("empty") - - def source_argument(self, argument: Any) -> Any: - return argument # .dates - - def filter_argument(self, argument: Any) -> Any: - return argument - - def create_result(self, argument, data): - from anemoi.datasets.create.fields.result import FieldResult - - return FieldResult(self, argument, data) - - def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: - from anemoi.datasets.dates.groups import GroupOfDates - - return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) - - def origin(self, data: Any, action: Any, action_arguments: Any) -> Any: - - origin = action.origin() - - result = [] - for fs in data: - previous = fs.metadata("anemoi_origin", default=None) - fall_through = fs.metadata("anemoi_fall_through", default=False) - if fall_through: - # The field has pass unchanges in a filter - result.append(fs) - else: - anemoi_origin = origin.combine(previous, action, action_arguments) - result.append(new_field_with_metadata(fs, anemoi_origin=anemoi_origin)) - - return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/fields/init.py b/src/anemoi/datasets/create/fields/init.py deleted file mode 100644 index 347802c32..000000000 --- a/src/anemoi/datasets/create/fields/init.py +++ /dev/null @@ -1,293 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import logging -import uuid -from typing import Any - -import zarr -from anemoi.utils.sanitise import sanitise - -from anemoi.datasets.create.gridded.config import loader_config -from anemoi.datasets.create.gridded.utils import normalize_and_check_dates - -from .tasks import FieldTask -from .tasks import HasElementForDataMixin -from .tasks import HasRegistryMixin -from .tasks import HasStatisticTempMixin -from .tasks import NewDataset -from .tasks import _build_statistics_dates - -LOG = logging.getLogger(__name__) - -VERSION = "0.30" - - -def _path_readable(path: str) -> bool: - """Check if the path is readable. - - Parameters - ---------- - path : str - The path to check. - - Returns - ------- - bool - True if the path is readable, False otherwise. 
- """ - - try: - zarr.open(path, "r") - return True - except zarr.errors.PathNotFoundError: - return False - - -class Init(FieldTask, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to initialize a new dataset.""" - - dataset_class = NewDataset - - def __init__( - self, - path: str, - config: dict, - check_name: bool = False, - overwrite: bool = False, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - test: bool = False, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize an Init instance. - - Parameters - ---------- - path : str - The path to the dataset. - config : dict - The configuration. - check_name : bool, optional - Whether to check the dataset name. - overwrite : bool, optional - Whether to overwrite the existing dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - test : bool, optional - Whether this is a test. - cache : Optional[str], optional - The cache directory. - """ - if _path_readable(path) and not overwrite: - raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") - - super().__init__(path, cache=cache) - self.config = config - self.check_name = check_name - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.test = test - - self.main_config = loader_config(config, is_test=test) - - # self.registry.delete() ?? - self.tmp_statistics.delete() - - assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by - self.create_elements(self.main_config) - - LOG.info(f"Groups: {self.groups}") - - # window = self.main_config.dates.get("window") - - one_date = self.groups.one_date() - - self.minimal_input = self.input.select(self.context, one_date) - - LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") - LOG.info(self.minimal_input) - - def run(self) -> int: - """Run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - with self._cache_context(): - return self._run() - - def _run(self) -> int: - """Internal method to run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - """Create an empty dataset of the right final shape. - - Read a small part of the data to get the shape of the data and the resolution and more metadata. 
- """ - - LOG.info("Config loaded ok:") - # LOG.info(self.main_config) - - dates = self.groups.provider.values - frequency = self.groups.provider.frequency - missing = self.groups.provider.missing - - assert isinstance(frequency, datetime.timedelta), frequency - - LOG.info(f"Found {len(dates)} datetimes.") - LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") - LOG.info(f"Missing dates: {len(missing)}") - lengths = tuple(len(g) for g in self.groups) - - variables = self.minimal_input.variables - LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") - - variables_with_nans = self.main_config.statistics.get("allow_nans", []) - - ensembles = self.minimal_input.ensembles - LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") - - grid_points = self.minimal_input.grid_points - LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") - - resolution = self.minimal_input.resolution - LOG.info(f"{resolution=}") - - coords = self.minimal_input.coords - coords["dates"] = dates - total_shape = self.minimal_input.shape - total_shape[0] = len(dates) - LOG.info(f"total_shape = {total_shape}") - - chunks = self.output.get_chunking(coords) - LOG.info(f"{chunks=}") - dtype = self.output.dtype - - LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") - - metadata = {} - metadata["uuid"] = str(uuid.uuid4()) - - metadata.update(self.main_config.get("add_metadata", {})) - - metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() - - recipe = sanitise(self.main_config.get_serialisable_dict()) - - # Remove stuff added by prepml - for k in [ - "build_dataset", - "config_format_version", - "config_path", - "dataset_status", - "ecflow", - "metadata", - "platform", - "reading_chunks", - "upload", - ]: - recipe.pop(k, None) - - metadata["recipe"] = recipe - - metadata["description"] = self.main_config.description - metadata["licence"] = self.main_config["licence"] - metadata["attribution"] = self.main_config["attribution"] - - metadata["remapping"] = self.output.remapping - metadata["order_by"] = self.output.order_by_as_list - metadata["flatten_grid"] = self.output.flatten_grid - - metadata["ensemble_dimension"] = len(ensembles) - metadata["variables"] = variables - metadata["variables_with_nans"] = variables_with_nans - metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) - metadata["resolution"] = resolution - - metadata["data_request"] = self.minimal_input.data_request - metadata["field_shape"] = self.minimal_input.field_shape - metadata["proj_string"] = self.minimal_input.proj_string - metadata["variables_metadata"] = self.minimal_input.variables_metadata - - metadata["start_date"] = dates[0].isoformat() - metadata["end_date"] = dates[-1].isoformat() - metadata["frequency"] = frequency - metadata["missing_dates"] = [_.isoformat() for _ in missing] - metadata["origins"] = self.minimal_input.origins - - metadata["version"] = VERSION - - self.dataset.check_name( - raise_exception=self.check_name, - is_test=self.test, - resolution=resolution, - dates=dates, - frequency=frequency, - ) - - if len(dates) != total_shape[0]: - raise ValueError( - f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " - f"does not match data shape {total_shape[0]}. 
{total_shape=}" - ) - - dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) - - metadata.update(self.main_config.get("force_metadata", {})) - - ############################################################### - # write metadata - ############################################################### - - self.update_metadata(**metadata) - - self.dataset.add_dataset( - name="data", - chunks=chunks, - dtype=dtype, - shape=total_shape, - dimensions=("time", "variable", "ensemble", "cell"), - ) - self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) - self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) - self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) - - self.registry.create(lengths=lengths) - self.tmp_statistics.create(exist_ok=False) - self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) - - statistics_start, statistics_end = _build_statistics_dates( - dates, - self.main_config.statistics.get("start"), - self.main_config.statistics.get("end"), - ) - self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) - LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") - - self.registry.add_to_history("init finished") - - assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) - - # Return the number of groups to process, so we can show a nice progress bar - return len(lengths) diff --git a/src/anemoi/datasets/create/fields/load.py b/src/anemoi/datasets/create/fields/load.py deleted file mode 100644 index 813b6b3ea..000000000 --- a/src/anemoi/datasets/create/fields/load.py +++ /dev/null @@ -1,260 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import time -from typing import Any - -import numpy as np -import tqdm -from anemoi.utils.dates import as_datetime -from anemoi.utils.humanize import compress_dates -from anemoi.utils.humanize import seconds_to_human - -from anemoi.datasets.create.chunks import ChunkFilter -from anemoi.datasets.create.gridded.check import check_data_values -from anemoi.datasets.create.gridded.statistics import compute_statistics -from anemoi.datasets.create.writer import ViewCacheArray - -from .tasks import FieldTask -from .tasks import HasElementForDataMixin -from .tasks import HasRegistryMixin -from .tasks import HasStatisticTempMixin -from .tasks import WritableDataset - -LOG = logging.getLogger(__name__) - - -class Load(FieldTask, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to load data into a dataset.""" - - def __init__( - self, - path: str, - parts: str | None = None, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize a Load instance. - - Parameters - ---------- - path : str - The path to the dataset. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. 
- statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - cache : Optional[str], optional - The cache directory. - """ - super().__init__(path, cache=cache) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.parts = parts - self.dataset = WritableDataset(self.path) - - self.main_config = self.dataset.get_main_config() - self.create_elements(self.main_config) - self.read_dataset_metadata(self.dataset.path) - - total = len(self.registry.get_flags()) - self.chunk_filter = ChunkFilter(parts=self.parts, total=total) - - self.data_array = self.dataset.data_array - self.n_groups = len(self.groups) - - def run(self) -> None: - """Run the data loading.""" - with self._cache_context(): - self._run() - - def _run(self) -> None: - """Internal method to run the data loading.""" - for igroup, group in enumerate(self.groups): - if not self.chunk_filter(igroup): - continue - if self.registry.get_flag(igroup): - LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") - continue - - # assert isinstance(group[0], datetime.datetime), type(group[0]) - LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - - result = self.input.select(self.context, argument=group) - assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) - - # There are several groups. - # There is one result to load for each group. - self.load_result(result) - self.registry.set_flag(igroup) - - self.registry.add_provenance(name="provenance_load") - self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) - - self.dataset.print_info() - - def load_result(self, result: Any) -> None: - """Load the result into the dataset. - - Parameters - ---------- - result : Any - The result to load. - """ - # There is one cube to load for each result. 
- dates = list(result.group_of_dates) - - LOG.debug(f"Loading cube for {len(dates)} dates") - - cube = result.get_cube() - shape = cube.extended_user_shape - dates_in_data = cube.user_coords["valid_datetime"] - - LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") - - def check_shape(cube, dates, dates_in_data): - if cube.extended_user_shape[0] != len(dates): - print( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - print("Requested dates", compress_dates(dates)) - print("Cube dates", compress_dates(dates_in_data)) - - a = {as_datetime(_) for _ in dates} - b = {as_datetime(_) for _ in dates_in_data} - - print("Missing dates", compress_dates(a - b)) - print("Extra dates", compress_dates(b - a)) - - raise ValueError( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - - check_shape(cube, dates, dates_in_data) - - def check_dates_in_data(dates_in_data, requested_dates): - _requested_dates = [np.datetime64(_) for _ in requested_dates] - _dates_in_data = [np.datetime64(_) for _ in dates_in_data] - if _dates_in_data != _requested_dates: - LOG.error("Dates in data are not the requested ones:") - - dates_in_data = set(dates_in_data) - requested_dates = set(requested_dates) - - missing = sorted(requested_dates - dates_in_data) - extra = sorted(dates_in_data - requested_dates) - - if missing: - LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") - if extra: - LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") - - raise ValueError("Dates in data are not the requested ones") - - check_dates_in_data(dates_in_data, dates) - - def dates_to_indexes(dates, all_dates): - x = np.array(dates, dtype=np.datetime64) - y = np.array(all_dates, dtype=np.datetime64) - bitmap = np.isin(x, y) - return np.where(bitmap)[0] - - indexes = dates_to_indexes(self.dates, dates_in_data) - - array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) - LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") - self.load_cube(cube, array) - - stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) - self.tmp_statistics.write(indexes, stats, dates=dates_in_data) - LOG.info("Flush data array") - array.flush() - LOG.info("Flushed data array") - - def _get_allow_nans(self) -> bool | list: - """Get the allow_nans configuration. - - Returns - ------- - bool | list - The allow_nans configuration. - """ - config = self.main_config - if "allow_nans" in config.build: - return config.build.allow_nans - - return config.statistics.get("allow_nans", []) - - def load_cube(self, cube: Any, array: ViewCacheArray) -> None: - """Load the cube into the array. - - Parameters - ---------- - cube : Any - The cube to load. - array : ViewCacheArray - The array to load into. 
- """ - # There are several cubelets for each cube - start = time.time() - load = 0 - save = 0 - - reading_chunks = None - total = cube.count(reading_chunks) - LOG.debug(f"Loading datacube: {cube}") - - def position(x: Any) -> int | None: - if isinstance(x, str) and "/" in x: - x = x.split("/") - return int(x[0]) - return None - - bar = tqdm.tqdm( - iterable=cube.iterate_cubelets(reading_chunks), - total=total, - desc=f"Loading datacube {cube}", - position=position(self.parts), - ) - for i, cubelet in enumerate(bar): - bar.set_description(f"Loading {i}/{total}") - - now = time.time() - data = cubelet.to_numpy() - local_indexes = cubelet.coords - load += time.time() - now - - name = self.variables_names[local_indexes[1]] - check_data_values( - data[:], - name=name, - log=[i, data.shape, local_indexes], - allow_nans=self._get_allow_nans(), - ) - - now = time.time() - array[local_indexes] = data - save += time.time() - now - - now = time.time() - save += time.time() - now - LOG.debug( - f"Elapsed: {seconds_to_human(time.time() - start)}, " - f"load time: {seconds_to_human(load)}, " - f"write time: {seconds_to_human(save)}." - ) diff --git a/src/anemoi/datasets/create/fields/patch.py b/src/anemoi/datasets/create/fields/patch.py deleted file mode 100644 index 546d53a13..000000000 --- a/src/anemoi/datasets/create/fields/patch.py +++ /dev/null @@ -1,38 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from typing import Any - -from .tasks import FieldTask - -LOG = logging.getLogger(__name__) - - -class Patch(FieldTask): - """A class to apply patches to a dataset.""" - - def __init__(self, path: str, options: dict = None, **kwargs: Any): - """Initialize a Patch instance. - - Parameters - ---------- - path : str - The path to the dataset. - options : dict, optional - The patch options. - """ - self.path = path - self.options = options or {} - - def run(self) -> None: - """Run the patch.""" - from anemoi.datasets.create.patch import apply_patch - - apply_patch(self.path, **self.options) diff --git a/src/anemoi/datasets/create/fields/result.py b/src/anemoi/datasets/create/fields/result.py deleted file mode 100644 index d4bcf58ea..000000000 --- a/src/anemoi/datasets/create/fields/result.py +++ /dev/null @@ -1,668 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import itertools -import logging -import math -import time -from collections import defaultdict -from functools import cached_property -from typing import Any -from typing import DefaultDict - -import numpy as np -from anemoi.utils.dates import as_timedelta -from anemoi.utils.humanize import seconds_to_human -from anemoi.utils.humanize import shorten_list -from earthkit.data.core.order import build_remapping - -from anemoi.datasets.create.input.result import Result - -LOG = logging.getLogger(__name__) - - -def _fields_metatata(variables: tuple[str, ...], cube: Any) -> dict[str, Any]: - """Retrieve metadata for the given variables and cube. - - Parameters - ---------- - variables : tuple of str - The variables to retrieve metadata for. - cube : Any - The data cube. - - Returns - ------- - dict - The metadata dictionary. - """ - assert isinstance(variables, tuple), variables - - KNOWN: dict[str, dict[str, bool]] = { - "cos_julian_day": dict(computed_forcing=True, constant_in_time=False), - "cos_latitude": dict(computed_forcing=True, constant_in_time=True), - "cos_local_time": dict(computed_forcing=True, constant_in_time=False), - "cos_longitude": dict(computed_forcing=True, constant_in_time=True), - "cos_solar_zenith_angle": dict(computed_forcing=True, constant_in_time=False), - "insolation": dict(computed_forcing=True, constant_in_time=False), - "latitude": dict(computed_forcing=True, constant_in_time=True), - "longitude": dict(computed_forcing=True, constant_in_time=True), - "sin_julian_day": dict(computed_forcing=True, constant_in_time=False), - "sin_latitude": dict(computed_forcing=True, constant_in_time=True), - "sin_local_time": dict(computed_forcing=True, constant_in_time=False), - "sin_longitude": dict(computed_forcing=True, constant_in_time=True), - } - - def _merge(md1: dict[str, Any], md2: dict[str, Any]) -> dict[str, Any]: - assert set(md1.keys()) == set(md2.keys()), (set(md1.keys()), set(md2.keys())) - result: dict[str, Any] = {} - for k in md1.keys(): - v1 = md1[k] - v2 = md2[k] - - if v1 == v2: - result[k] = v1 - continue - - if isinstance(v1, list): - assert v2 not in v1, (v1, v2) - result[k] = sorted(v1 + [v2]) - continue - - if isinstance(v2, list): - assert v1 not in v2, (v1, v2) - result[k] = sorted(v2 + [v1]) - continue - - result[k] = sorted([v1, v2]) - - return result - - mars: dict[str, Any] = {} - other: DefaultDict[str, dict[str, Any]] = defaultdict(dict) - i: int = -1 - date: str | None = None - for c in cube.iterate_cubelets(): - - if date is None: - date = c._coords_names[0] - - if date != c._coords_names[0]: - continue - - if i == -1 or c._coords_names[1] != variables[i]: - i += 1 - - f = cube[c.coords] - md = f.metadata(namespace="mars") - if not md: - md = f.metadata(namespace="default") - - if md.get("param") == "~": - md["param"] = f.metadata("param") - assert md["param"] not in ("~", "unknown"), (md, f.metadata("param")) - - if md.get("param") == "unknown": - md["param"] = str(f.metadata("paramId", default="unknown")) - # assert md['param'] != 'unknown', (md, f.metadata('param')) - - startStep = f.metadata("startStep", default=None) - if startStep is not None: - startStep = as_timedelta(startStep) - - endStep = f.metadata("endStep", default=None) - if endStep is not None: - endStep = as_timedelta(endStep) - - stepTypeForConversion = f.metadata("stepTypeForConversion", default=None) - typeOfStatisticalProcessing = f.metadata("typeOfStatisticalProcessing", default=None) - timeRangeIndicator = f.metadata("timeRangeIndicator", default=None) - - # GRIB1 
precipitation accumulations are not correctly encoded - if startStep == endStep and stepTypeForConversion == "accum": - endStep = f.metadata("P1") - startStep = f.metadata("P2") - - if startStep != endStep: - # https://codes.ecmwf.int/grib/format/grib2/ctables/4/10/ - TYPE_OF_STATISTICAL_PROCESSING: dict[int | None, str | None] = { - None: None, - 0: "average", - 1: "accumulation", - 2: "maximum", - 3: "minimum", - 4: "difference(end-start)", - 5: "root_mean_square", - 6: "standard_deviation", - 7: "covariance", - 8: "difference(start-end)", - 9: "ratio", - 10: "standardized_anomaly", - 11: "summation", - 100: "severity", - 101: "mode", - } - - # https://codes.ecmwf.int/grib/format/grib1/ctable/5/ - - TIME_RANGE_INDICATOR: dict[int, str] = { - 4: "accumulation", - 3: "average", - } - - STEP_TYPE_FOR_CONVERSION: dict[str, str] = { - "min": "minimum", - "max": "maximum", - "accum": "accumulation", - } - - # - # A few patches - # - - PATCHES: dict[str, str] = { - "10fg6": "maximum", - "mntpr3": "minimum", # Not in param db - "mntpr6": "minimum", # Not in param db - "mxtpr3": "maximum", # Not in param db - "mxtpr6": "maximum", # Not in param db - } - - process = TYPE_OF_STATISTICAL_PROCESSING.get(typeOfStatisticalProcessing) - if process is None: - process = TIME_RANGE_INDICATOR.get(timeRangeIndicator) - if process is None: - process = STEP_TYPE_FOR_CONVERSION.get(stepTypeForConversion) - if process is None: - process = PATCHES.get(md["param"]) - if process is not None: - LOG.error(f"Unknown process {stepTypeForConversion} for {md['param']}, using {process} instead") - - if process is None: - raise ValueError( - f"Unknown for {md['param']}:" - f" {stepTypeForConversion=} ({STEP_TYPE_FOR_CONVERSION.get('stepTypeForConversion')})," - f" {typeOfStatisticalProcessing=} ({TYPE_OF_STATISTICAL_PROCESSING.get(typeOfStatisticalProcessing)})," - f" {timeRangeIndicator=} ({TIME_RANGE_INDICATOR.get(timeRangeIndicator)})" - ) - - # print(md["param"], "startStep", startStep, "endStep", endStep, "process", process, "typeOfStatisticalProcessing", typeOfStatisticalProcessing) - other[variables[i]]["process"] = process - other[variables[i]]["period"] = (startStep, endStep) - - for k in md.copy().keys(): - if k.startswith("_"): - md.pop(k) - - if variables[i] in mars: - mars[variables[i]] = _merge(md, mars[variables[i]]) - else: - mars[variables[i]] = md - - result: dict[str, dict[str, Any]] = {} - for k, v in mars.items(): - result[k] = dict(mars=v) if v else {} - result[k].update(other[k]) - result[k].update(KNOWN.get(k, {})) - # assert result[k], k - - assert i + 1 == len(variables), (i + 1, len(variables)) - return result - - -def _data_request(data: Any) -> dict[str, Any]: - """Build a data request dictionary from the given data. - - Parameters - ---------- - data : Any - The data to build the request from. - - Returns - ------- - dict - The data request dictionary. 
- """ - date: Any | None = None - params_levels: DefaultDict[str, set] = defaultdict(set) - params_steps: DefaultDict[str, set] = defaultdict(set) - - area: Any | None = None - grid: Any | None = None - - for field in data: - try: - if date is None: - date = field.metadata("valid_datetime") - - if field.metadata("valid_datetime") != date: - continue - - as_mars = field.metadata(namespace="mars") - if not as_mars: - continue - step = as_mars.get("step") - levtype = as_mars.get("levtype", "sfc") - param = as_mars["param"] - levelist = as_mars.get("levelist", None) - area = field.mars_area - grid = field.mars_grid - - if levelist is None: - params_levels[levtype].add(param) - else: - params_levels[levtype].add((param, levelist)) - - if step: - params_steps[levtype].add((param, step)) - except Exception: - LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True) - - def sort(old_dic: DefaultDict[str, set]) -> dict[str, list[Any]]: - new_dic: dict[str, list[Any]] = {} - for k, v in old_dic.items(): - new_dic[k] = sorted(list(v)) - return new_dic - - params_steps = sort(params_steps) - params_levels = sort(params_levels) - - return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) - - -class FieldResult(Result): - """Class to represent the result of an action in the dataset creation process.""" - - empty: bool = False - _coords_already_built: bool = False - - def __init__(self, context: Any, argument: Any, datasource: Any) -> None: - - from anemoi.datasets.dates.groups import GroupOfDates - - self.context: Any = context - self.datasource = datasource - self.group_of_dates = argument - assert isinstance( - self.group_of_dates, GroupOfDates - ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}" - - self._origins = [] - - @property - def data_request(self) -> dict[str, Any]: - """Returns a dictionary with the parameters needed to retrieve the data.""" - return _data_request(self.datasource) - - @property - def origins(self) -> dict[str, Any]: - """Returns a dictionary with the parameters needed to retrieve the data.""" - return {"version": 1, "origins": self._origins} - - def get_cube(self) -> Any: - """Retrieve the data cube for the result. - - Returns - ------- - Any - The data cube. - """ - - ds: Any = self.datasource - - self.remapping: Any = self.context.remapping - self.order_by: Any = self.context.order_by - self.flatten_grid: Any = self.context.flatten_grid - self.start: float = time.time() - LOG.debug("Sorting dataset %s %s", dict(self.order_by), self.remapping) - assert self.order_by, self.order_by - - self.patches: dict[str, dict[Any | None, int]] = {"number": {None: 0}} - - try: - cube: Any = ds.cube( - self.order_by, - remapping=self.remapping, - flatten_values=self.flatten_grid, - patches=self.patches, - ) - cube = cube.squeeze() - LOG.debug(f"Sorting done in {seconds_to_human(time.time()-self.start)}.") - except ValueError: - self.explain(ds, self.order_by, remapping=self.remapping, patches=self.patches) - # raise ValueError(f"Error in {self}") - exit(1) - - if LOG.isEnabledFor(logging.DEBUG): - LOG.debug("Cube shape: %s", cube) - for k, v in cube.user_coords.items(): - LOG.debug(" %s %s", k, shorten_list(v, max_length=10)) - - return cube - - def explain(self, ds: Any, *args: Any, remapping: Any, patches: Any) -> None: - """Explain the data cube creation process. - - Parameters - ---------- - ds : Any - The data source. - args : Any - Additional arguments. 
- remapping : Any - The remapping configuration. - patches : Any - The patches configuration. - """ - METADATA: tuple[str, ...] = ( - "date", - "time", - "step", - "hdate", - "valid_datetime", - "levtype", - "levelist", - "number", - "level", - "shortName", - "paramId", - "variable", - ) - - # We redo the logic here - print() - print("❌" * 40) - print() - if len(args) == 1 and isinstance(args[0], (list, tuple)): - args = args[0] - - # print("Executing", self.action_path) - # print("Dates:", compress_dates(self.dates)) - - names: list[str] = [] - for a in args: - if isinstance(a, str): - names.append(a) - elif isinstance(a, dict): - names += list(a.keys()) - - print(f"Building a {len(names)}D hypercube using", names) - ds = ds.order_by(*args, remapping=remapping, patches=patches) - user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False) - - print() - print("Number of unique values found for each coordinate:") - for k, v in user_coords.items(): - print(f" {k:20}:", len(v)) - for n in sorted(v): - print(" ", n) - - print() - user_shape: tuple[int, ...] = tuple(len(v) for k, v in user_coords.items()) - print("Shape of the hypercube :", user_shape) - print( - "Number of expected fields :", math.prod(user_shape), "=", " x ".join([str(i) for i in user_shape]) - ) - print("Number of fields in the dataset :", len(ds)) - print("Difference :", abs(len(ds) - math.prod(user_shape))) - print() - - remapping = build_remapping(remapping, patches) - expected = set(itertools.product(*user_coords.values())) - extra = set() - - if math.prod(user_shape) > len(ds): - print(f"This means that all the fields in the datasets do not exists for all combinations of {names}.") - - for f in ds: - metadata = remapping(f.metadata) - key = tuple(metadata(n, default=None) for n in names) - if key in expected: - expected.remove(key) - else: - extra.add(key) - - print("Missing fields:") - print() - for i, f in enumerate(sorted(expected)): - print(" ", f) - if i >= 9 and len(expected) > 10: - print("...", len(expected) - i - 1, "more") - break - - print("Extra fields:") - print() - for i, f in enumerate(sorted(extra)): - print(" ", f) - if i >= 9 and len(extra) > 10: - print("...", len(extra) - i - 1, "more") - break - - print() - print("Missing values:") - per_name = defaultdict(set) - for e in expected: - for n, v in zip(names, e): - per_name[n].add(v) - - for n, v in per_name.items(): - print(" ", n, len(v), shorten_list(sorted(v), max_length=10)) - print() - - print("Extra values:") - per_name = defaultdict(set) - for e in extra: - for n, v in zip(names, e): - per_name[n].add(v) - - for n, v in per_name.items(): - print(" ", n, len(v), shorten_list(sorted(v), max_length=10)) - print() - - print("To solve this issue, you can:") - print( - " - Provide a better selection, like 'step: 0' or 'level: 1000' to " - "reduce the number of selected fields." - ) - print( - " - Split the 'input' part in smaller sections using 'join', " - "making sure that each section represent a full hypercube." - ) - - else: - print(f"More fields in dataset that expected for {names}. 
" "This means that some fields are duplicated.") - duplicated = defaultdict(list) - for f in ds: - # print(f.metadata(namespace="default")) - metadata = remapping(f.metadata) - key = tuple(metadata(n, default=None) for n in names) - duplicated[key].append(f) - - print("Duplicated fields:") - print() - duplicated = {k: v for k, v in duplicated.items() if len(v) > 1} - for i, (k, v) in enumerate(sorted(duplicated.items())): - print(" ", k) - for f in v: - x = {k: f.metadata(k, default=None) for k in METADATA if f.metadata(k, default=None) is not None} - print(" ", f, x) - if i >= 9 and len(duplicated) > 10: - print("...", len(duplicated) - i - 1, "more") - break - - print() - print("To solve this issue, you can:") - print(" - Provide a better selection, like 'step: 0' or 'level: 1000'") - print(" - Change the way 'param' is computed using 'variable_naming' " "in the 'build' section.") - - print() - print("❌" * 40) - print() - exit(1) - - def build_coords(self) -> None: - """Build the coordinates for the result.""" - if self._coords_already_built: - return - - cube: Any = self.get_cube() - - from_data: Any = cube.user_coords - from_config: Any = self.context.order_by - - keys_from_config: list = list(from_config.keys()) - keys_from_data: list = list(from_data.keys()) - assert keys_from_data == keys_from_config, f"Critical error: {keys_from_data=} != {keys_from_config=}. {self=}" - - variables_key: str = list(from_config.keys())[1] - ensembles_key: str = list(from_config.keys())[2] - - if isinstance(from_config[variables_key], (list, tuple)): - assert all([v == w for v, w in zip(from_data[variables_key], from_config[variables_key])]), ( - from_data[variables_key], - from_config[variables_key], - ) - - self._variables: Any = from_data[variables_key] # "param_level" - self._ensembles: Any = from_data[ensembles_key] # "number" - - first_field: Any = self.datasource[0] - grid_points: Any = first_field.grid_points() - - lats: Any = grid_points[0] - lons: Any = grid_points[1] - - assert len(lats) == len(lons), (len(lats), len(lons), first_field) - assert len(lats) == math.prod(first_field.shape), (len(lats), first_field.shape, first_field) - - north: float = np.amax(lats) - south: float = np.amin(lats) - east: float = np.amax(lons) - west: float = np.amin(lons) - - assert -90 <= south <= north <= 90, (south, north, first_field) - assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), ( - west, - east, - first_field, - ) - - grid_values: list = list(range(len(grid_points[0]))) - - self._grid_points: Any = grid_points - self._resolution: Any = first_field.resolution - self._grid_values: Any = grid_values - self._field_shape: Any = first_field.shape - self._proj_string: Any = first_field.proj_string if hasattr(first_field, "proj_string") else None - - self._cube: Any = cube - - name_key = list(self.order_by.keys())[1] - - p = None - origins_per_number = defaultdict(lambda: defaultdict(set)) - - for fs in self.datasource: - o = fs.metadata("anemoi_origin", remapping=self.remapping, patches=self.patches) - name = fs.metadata(name_key, remapping=self.remapping, patches=self.patches) - number = fs.metadata("number", remapping=self.remapping, patches=self.patches) - - assert name not in origins_per_number[number][o], name - origins_per_number[number][o].add(name) - - if p is not o: - LOG.info(f"🔥🔥🔥🔥🔥🔥 Source: {name}, {o}") - p = o - - origins_per_variables = defaultdict(lambda: defaultdict(set)) - for number, origins in origins_per_number.items(): - for origin, names in origins.items(): - for 
name in names: - origins_per_variables[name][origin].add(number) - - origins = defaultdict(set) - - # Check if all members of a variable have the same origins - for name, origin_number in origins_per_variables.items(): - # For now we do not support variables with members from different origins - assert len(origin_number) == 1, origin_number - origins[list(origin_number.keys())[0]].add(name) - - self._origins = [] - for k, v in origins.items(): - self._origins.append({"origin": k.as_dict(), "variables": sorted(v)}) - - self._coords_already_built: bool = True - - @property - def variables(self) -> list[str]: - """Retrieve the variables for the result.""" - self.build_coords() - return self._variables - - @property - def variables_metadata(self) -> dict[str, Any]: - """Retrieve the metadata for the variables.""" - return _fields_metatata(self.variables, self._cube) - - @property - def ensembles(self) -> Any: - """Retrieve the ensembles for the result.""" - self.build_coords() - return self._ensembles - - @property - def resolution(self) -> Any: - """Retrieve the resolution for the result.""" - self.build_coords() - return self._resolution - - @property - def grid_values(self) -> Any: - """Retrieve the grid values for the result.""" - self.build_coords() - return self._grid_values - - @property - def grid_points(self) -> Any: - """Retrieve the grid points for the result.""" - self.build_coords() - return self._grid_points - - @property - def field_shape(self) -> Any: - """Retrieve the field shape for the result.""" - self.build_coords() - return self._field_shape - - @property - def proj_string(self) -> Any: - """Retrieve the projection string for the result.""" - self.build_coords() - return self._proj_string - - @cached_property - def shape(self) -> list[int]: - """Retrieve the shape of the result.""" - return [ - len(self.group_of_dates), - len(self.variables), - len(self.ensembles), - len(self.grid_values), - ] - - @cached_property - def coords(self) -> dict[str, Any]: - """Retrieve the coordinates of the result.""" - return { - "dates": list(self.group_of_dates), - "variables": self.variables, - "ensembles": self.ensembles, - "values": self.grid_values, - } diff --git a/src/anemoi/datasets/create/fields/size.py b/src/anemoi/datasets/create/fields/size.py deleted file mode 100644 index 10c64d4d7..000000000 --- a/src/anemoi/datasets/create/fields/size.py +++ /dev/null @@ -1,48 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from typing import Any - -from anemoi.datasets import open_dataset - -from .tasks import FieldTask - -LOG = logging.getLogger(__name__) - - -class Size(FieldTask): - """A class to compute the size of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Size instance. - - Parameters - ---------- - path : str - The path to the dataset. 
- """ - super().__init__(path) - - def run(self) -> None: - """Run the size computation.""" - from anemoi.datasets.create.size import compute_directory_sizes - - metadata = compute_directory_sizes(self.path) - self.update_metadata(**metadata) - - # Look for constant fields - ds = open_dataset(self.path) - constants = ds.computed_constant_fields() - - variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() - for k in constants: - variables_metadata[k]["constant_in_time"] = True - - self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) diff --git a/src/anemoi/datasets/create/fields/statistics.py b/src/anemoi/datasets/create/fields/statistics.py deleted file mode 100644 index b199fd052..000000000 --- a/src/anemoi/datasets/create/fields/statistics.py +++ /dev/null @@ -1,102 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import warnings -from functools import cached_property -from typing import Any - -import numpy as np -import zarr - -from .tasks import FieldTask -from .tasks import HasRegistryMixin -from .tasks import HasStatisticTempMixin - -LOG = logging.getLogger(__name__) - - -class Statistics(FieldTask, HasStatisticTempMixin, HasRegistryMixin): - """A class to compute statistics for a dataset.""" - - def __init__( - self, - path: str, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a Statistics instance. - - Parameters - ---------- - path : str - The path to the dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. 
- """ - super().__init__(path) - self.use_threads = use_threads - self.progress = progress - self.statistics_temp_dir = statistics_temp_dir - - def run(self) -> None: - """Run the statistics computation.""" - start, end = ( - self.dataset.zarr_metadata["statistics_start_date"], - self.dataset.zarr_metadata["statistics_end_date"], - ) - start, end = np.datetime64(start), np.datetime64(end) - dates = self.dataset.anemoi_dataset.dates - - assert type(dates[0]) is type(start), (type(dates[0]), type(start)) - - dates = [d for d in dates if d >= start and d <= end] - dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] - variables = self.dataset.anemoi_dataset.variables - stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) - - LOG.info(stats) - - if not all(self.registry.get_flags(sync=False)): - raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") - - for k in [ - "mean", - "stdev", - "minimum", - "maximum", - "sums", - "squares", - "count", - "has_nans", - ]: - self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) - - self.registry.add_to_history("compute_statistics_end") - LOG.info(f"Wrote statistics in {self.path}") - - @cached_property - def allow_nans(self) -> bool | list: - """Check if NaNs are allowed.""" - - z = zarr.open(self.path, mode="r") - if "allow_nans" in z.attrs: - return z.attrs["allow_nans"] - - if "variables_with_nans" in z.attrs: - return z.attrs["variables_with_nans"] - - warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") - return True diff --git a/src/anemoi/datasets/create/fields/tasks.py b/src/anemoi/datasets/create/fields/tasks.py deleted file mode 100644 index cafbdd233..000000000 --- a/src/anemoi/datasets/create/fields/tasks.py +++ /dev/null @@ -1,606 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import json -import logging -import os -from functools import cached_property -from typing import Any - -import cftime -import numpy as np -import zarr -from anemoi.utils.dates import frequency_to_string -from earthkit.data.core.order import build_remapping - -from anemoi.datasets import open_dataset -from anemoi.datasets.create.fields.context import FieldContext -from anemoi.datasets.create.gridded.check import DatasetName -from anemoi.datasets.create.gridded.config import build_output -from anemoi.datasets.create.gridded.config import loader_config -from anemoi.datasets.create.gridded.statistics import TmpStatistics -from anemoi.datasets.create.gridded.statistics import default_statistics_dates -from anemoi.datasets.create.input import InputBuilder -from anemoi.datasets.dates.groups import Groups -from anemoi.datasets.use.gridded.misc import as_first_date -from anemoi.datasets.use.gridded.misc import as_last_date - -from ..tasks import chain - -LOG = logging.getLogger(__name__) - - -def _json_tidy(o: Any) -> Any: - """Convert various types to JSON serializable format. - - Parameters - ---------- - o : Any - The object to convert. - - Returns - ------- - Any - The JSON serializable object. 
- """ - if isinstance(o, datetime.datetime): - return o.isoformat() - - if isinstance(o, datetime.datetime): - return o.isoformat() - - if isinstance(o, datetime.timedelta): - return frequency_to_string(o) - - if isinstance(o, cftime.DatetimeJulian): - import pandas as pd - - o = pd.Timestamp( - o.year, - o.month, - o.day, - o.hour, - o.minute, - o.second, - ) - return o.isoformat() - - if isinstance(o, (np.float32, np.float64)): - return float(o) - - raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}") - - -def _build_statistics_dates( - dates: list[datetime.datetime], - start: datetime.datetime | None, - end: datetime.datetime | None, -) -> tuple[str, str]: - """Compute the start and end dates for the statistics. - - Parameters - ---------- - dates : list of datetime.datetime - The list of dates. - start : Optional[datetime.datetime] - The start date. - end : Optional[datetime.datetime] - The end date. - - Returns - ------- - tuple of str - The start and end dates in ISO format. - """ - # if not specified, use the default statistics dates - default_start, default_end = default_statistics_dates(dates) - if start is None: - start = default_start - if end is None: - end = default_end - - # in any case, adapt to the actual dates in the dataset - start = as_first_date(start, dates) - end = as_last_date(end, dates) - - # and convert to datetime to isoformat - start = start.astype(datetime.datetime) - end = end.astype(datetime.datetime) - return (start.isoformat(), end.isoformat()) - - -class Dataset: - """A class to represent a dataset.""" - - def __init__(self, path: str): - """Initialize a Dataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - self.path = path - - _, ext = os.path.splitext(self.path) - if ext != ".zarr": - raise ValueError(f"Unsupported extension={ext} for path={self.path}") - - def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: - """Add a dataset to the Zarr store. - - Parameters - ---------- - mode : str, optional - The mode to open the Zarr store. - **kwargs - Additional arguments for the dataset. - - Returns - ------- - zarr.Array - The added dataset. - """ - import zarr - - z = zarr.open(self.path, mode=mode) - from anemoi.datasets.create.gridded.zarr import add_zarr_dataset - - return add_zarr_dataset(zarr_root=z, **kwargs) - - def update_metadata(self, **kwargs: Any) -> None: - """Update the metadata of the dataset. - - Parameters - ---------- - **kwargs - The metadata to update. - """ - import zarr - - LOG.debug(f"Updating metadata {kwargs}") - z = zarr.open(self.path, mode="w+") - for k, v in kwargs.items(): - if isinstance(v, np.datetime64): - v = v.astype(datetime.datetime) - if isinstance(v, datetime.date): - v = v.isoformat() - z.attrs[k] = json.loads(json.dumps(v, default=_json_tidy)) - - @cached_property - def anemoi_dataset(self) -> Any: - """Get the Anemoi dataset.""" - return open_dataset(self.path) - - @cached_property - def zarr_metadata(self) -> dict: - """Get the Zarr metadata.""" - import zarr - - return dict(zarr.open(self.path, mode="r").attrs) - - def print_info(self) -> None: - """Print information about the dataset.""" - import zarr - - z = zarr.open(self.path, mode="r") - try: - LOG.info(z["data"].info) - except Exception as e: - LOG.info(e) - - def get_zarr_chunks(self) -> tuple: - """Get the chunks of the Zarr dataset. - - Returns - ------- - tuple - The chunks of the Zarr dataset. 
- """ - import zarr - - z = zarr.open(self.path, mode="r") - return z["data"].chunks - - def check_name( - self, - resolution: str, - dates: list[datetime.datetime], - frequency: datetime.timedelta, - raise_exception: bool = True, - is_test: bool = False, - ) -> None: - """Check the name of the dataset. - - Parameters - ---------- - resolution : str - The resolution of the dataset. - dates : list of datetime.datetime - The dates of the dataset. - frequency : datetime.timedelta - The frequency of the dataset. - raise_exception : bool, optional - Whether to raise an exception if the name is invalid. - is_test : bool, optional - Whether this is a test. - """ - basename, _ = os.path.splitext(os.path.basename(self.path)) - try: - DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() - except Exception as e: - if raise_exception and not is_test: - raise e - else: - LOG.warning(f"Dataset name error: {e}") - - def get_main_config(self) -> Any: - """Get the main configuration of the dataset. - - Returns - ------- - Any - The main configuration. - """ - import zarr - - z = zarr.open(self.path, mode="r") - config = loader_config(z.attrs.get("_create_yaml_config")) - - if "env" in config: - for k, v in config["env"].items(): - LOG.info(f"Setting env variable {k}={v}") - os.environ[k] = str(v) - - return config - - -class WritableDataset(Dataset): - """A class to represent a writable dataset.""" - - def __init__(self, path: str): - """Initialize a WritableDataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - self.path = path - - import zarr - - self.z = zarr.open(self.path, mode="r+") - - @cached_property - def data_array(self) -> Any: - """Get the data array of the dataset.""" - import zarr - - return zarr.open(self.path, mode="r+")["data"] - - -class NewDataset(Dataset): - """A class to represent a new dataset.""" - - def __init__(self, path: str, overwrite: bool = False): - """Initialize a NewDataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - overwrite : bool, optional - Whether to overwrite the existing dataset. - """ - super().__init__(path) - self.path = path - - import zarr - - self.z = zarr.open(self.path, mode="w") - self.z.create_group("_build") - - -class FieldTask: - """A base class for dataset creation tasks.""" - - dataset_class = WritableDataset - - def __init__(self, path: str, cache: str | None = None): - """Initialize an Actor instance. - - Parameters - ---------- - path : str - The path to the dataset. - cache : Optional[str], optional - The cache directory. - """ - # Catch all floating point errors, including overflow, sqrt(<0), etc - np.seterr(all="raise", under="warn") - - self.path = path - self.cache = cache - self.dataset = self.dataset_class(self.path) - - def run(self) -> None: - """Run the actor.""" - # to be implemented in the sub-classes - raise NotImplementedError() - - def update_metadata(self, **kwargs: Any) -> None: - """Update the metadata of the dataset. - - Parameters - ---------- - **kwargs - The metadata to update. - """ - self.dataset.update_metadata(**kwargs) - - def _cache_context(self) -> Any: - """Get the cache context. - - Returns - ------- - Any - The cache context. - """ - from anemoi.datasets.create.gridded.utils import cache_context - - return cache_context(self.cache) - - def check_unkown_kwargs(self, kwargs: dict) -> None: - """Check for unknown keyword arguments. 
- - Parameters - ---------- - kwargs : dict - The keyword arguments. - """ - # remove this latter - LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}") - - def read_dataset_metadata(self, path: str) -> None: - """Read the metadata of the dataset. - - Parameters - ---------- - path : str - The path to the dataset. - """ - ds = open_dataset(path) - self.dataset_shape = ds.shape - self.variables_names = ds.variables - assert len(self.variables_names) == ds.shape[1], self.dataset_shape - self.dates = ds.dates - - self.missing_dates = sorted(list([self.dates[i] for i in ds.missing])) - - def check_missing_dates(expected: list[np.datetime64]) -> None: - """Check if the missing dates in the dataset match the expected dates. - - Parameters - ---------- - expected : list of np.datetime64 - The expected missing dates. - - Raises - ------ - ValueError - If the missing dates in the dataset do not match the expected dates. - """ - import zarr - - z = zarr.open(path, "r") - missing_dates = z.attrs.get("missing_dates", []) - missing_dates = sorted([np.datetime64(d) for d in missing_dates]) - if missing_dates != expected: - LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.") - LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") - LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") - raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") - - check_missing_dates(self.missing_dates) - - -class HasRegistryMixin: - """A mixin class to provide registry functionality.""" - - @cached_property - def registry(self) -> Any: - """Get the registry.""" - from anemoi.datasets.create.gridded.zarr import ZarrBuiltRegistry - - return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) - - -class HasStatisticTempMixin: - """A mixin class to provide temporary statistics functionality.""" - - @cached_property - def tmp_statistics(self) -> TmpStatistics: - """Get the temporary statistics.""" - directory = self.statistics_temp_dir or os.path.join(self.path + ".storage_for_statistics.tmp") - return TmpStatistics(directory) - - -class HasElementForDataMixin: - """A mixin class to provide element creation functionality for data.""" - - def create_elements(self, config: Any) -> None: - """Create elements for the dataset. - - Parameters - ---------- - config : Any - The configuration. 
- """ - assert self.registry - assert self.tmp_statistics - - LOG.info(dict(config.dates)) - - self.groups = Groups(**config.dates) - LOG.info(self.groups) - - self.output = build_output(config.output, parent=self) - - self.context = FieldContext( - order_by=self.output.order_by, - flatten_grid=self.output.flatten_grid, - remapping=build_remapping(self.output.remapping), - use_grib_paramid=config.build.use_grib_paramid, - ) - - self.input = InputBuilder( - config.input, - data_sources=config.get("data_sources", {}), - ) - LOG.debug("✅ INPUT_BUILDER") - LOG.debug(self.input) - - -def _validate_config(config: Any) -> None: - - import json - - import jsonschema - - def _tidy(d): - if isinstance(d, dict): - return {k: _tidy(v) for k, v in d.items()} - - if isinstance(d, list): - return [_tidy(v) for v in d if v is not None] - - # jsonschema does not support datetime.date - if isinstance(d, datetime.datetime): - return d.isoformat() - - if isinstance(d, datetime.date): - return d.isoformat() - - return d - - # https://json-schema.org - - with open( - os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "schemas", - "recipe.json", - ) - ) as f: - schema = json.load(f) - - try: - jsonschema.validate(instance=_tidy(config), schema=schema) - except jsonschema.exceptions.ValidationError as e: - LOG.error("❌ Config validation failed (jsonschema):") - LOG.error(e.message) - raise - - -def _config_to_python(config: Any) -> Any: - - from anemoi.datasets.create.create.python import PythonScript - - raw_config = config - - config = loader_config(config) - - input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - - code = PythonScript() - x = input.python_code(code) - code = code.source_code(x, raw_config) - - try: - import black - - return black.format_str(code, mode=black.Mode()) - # except ImportError: - except Exception: - LOG.warning("Black not installed, skipping formatting") - return code - - -class TaskCreator: - """A class to create and run dataset creation tasks.""" - - def init(self, *args: Any, **kwargs: Any): - from .init import Init - - return Init(*args, **kwargs) - - def load(self, *args: Any, **kwargs: Any): - from .load import Load - - return Load(*args, **kwargs) - - def size(self, *args: Any, **kwargs: Any): - from .size import Size - - return Size(*args, **kwargs) - - def patch(self, *args: Any, **kwargs: Any): - from .patch import Patch - - return Patch(*args, **kwargs) - - def statistics(self, *args: Any, **kwargs: Any): - from .statistics import Statistics - - return Statistics(*args, **kwargs) - - def finalise(self, *args: Any, **kwargs: Any): - from .cleanup import Cleanup - from .size import Size - from .statistics import Statistics - - return chain([Statistics, Size, Cleanup])(*args, **kwargs) - - def cleanup(self, *args: Any, **kwargs: Any): - from .cleanup import Cleanup - - return Cleanup(*args, **kwargs) - - def verify(self, *args: Any, **kwargs: Any): - from .verify import Verify - - return Verify(*args, **kwargs) - - def init_additions(self, *args: Any, **kwargs: Any): - from .additions import InitAdditions - - return InitAdditions(*args, **kwargs) - - def load_additions(self, *args: Any, **kwargs: Any): - from .additions import LoadAdditions - - return LoadAdditions(*args, **kwargs) - - def finalise_additions(self, *args: Any, **kwargs: Any): - from .additions import FinaliseAdditions - from .size import Size - - return chain([FinaliseAdditions, Size])(*args, **kwargs) - - def additions(self, *args: Any, **kwargs: Any): - from .additions 
import FinaliseAdditions - from .additions import InitAdditions - from .additions import LoadAdditions - from .cleanup import Cleanup - from .size import Size - - return chain([InitAdditions, LoadAdditions, FinaliseAdditions, Size, Cleanup])(*args, **kwargs) diff --git a/src/anemoi/datasets/create/fields/verify.py b/src/anemoi/datasets/create/fields/verify.py deleted file mode 100644 index 27b3e5f24..000000000 --- a/src/anemoi/datasets/create/fields/verify.py +++ /dev/null @@ -1,34 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -from typing import Any - -from .tasks import FieldTask - -LOG = logging.getLogger(__name__) - - -class Verify(FieldTask): - """A class to verify the integrity of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Verify instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - - def run(self) -> None: - """Run the verification.""" - LOG.info(f"Verifying dataset at {self.path}") - LOG.info(str(self.dataset.anemoi_dataset)) diff --git a/src/anemoi/datasets/create/observations/__init__.py b/src/anemoi/datasets/create/observations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/anemoi/datasets/create/fields/__init__.py b/src/anemoi/datasets/create/tabular/__init__.py similarity index 100% rename from src/anemoi/datasets/create/fields/__init__.py rename to src/anemoi/datasets/create/tabular/__init__.py diff --git a/src/anemoi/datasets/create/observations/tasks.py b/src/anemoi/datasets/create/tabular/tasks.py similarity index 100% rename from src/anemoi/datasets/create/observations/tasks.py rename to src/anemoi/datasets/create/tabular/tasks.py From 2a1b980d2adee6261b1c061edc2a7f1367ce68c8 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 11:25:30 +0000 Subject: [PATCH 179/212] update --- src/anemoi/datasets/create/gridded/config.py | 9 +- src/anemoi/datasets/create/input/__init__.py | 2 - src/anemoi/datasets/create/input/context.py | 9 +- .../datasets/create/input/data_sources.py | 5 - src/anemoi/datasets/create/input/origin.py | 159 ------------------ src/anemoi/datasets/create/sources/csv.py | 6 +- .../datasets/create/sources/eccc_fstd.py | 4 +- src/anemoi/datasets/create/sources/fdb.py | 2 +- .../datasets/create/sources/forcings.py | 2 +- .../datasets/create/sources/grib_index.py | 3 +- .../create/sources/planetary_computer.py | 4 +- .../datasets/create/sources/repeated_dates.py | 16 +- src/anemoi/datasets/create/sources/xarray.py | 11 +- .../create/sources/xarray_kerchunk.py | 4 +- .../create/sources/xarray_support/field.py | 6 +- .../sources/xarray_support/fieldlist.py | 12 +- .../create/sources/xarray_support/flavour.py | 38 ++--- .../create/sources/xarray_support/metadata.py | 2 +- .../create/sources/xarray_support/time.py | 4 +- .../create/sources/xarray_support/variable.py | 2 +- src/anemoi/datasets/dates/__init__.py | 14 +- src/anemoi/datasets/misc/grids.py | 6 +- src/anemoi/datasets/use/gridded/complement.py | 8 +- 23 files changed, 69 insertions(+), 259 deletions(-) delete mode 100644 src/anemoi/datasets/create/input/origin.py diff --git 
a/src/anemoi/datasets/create/gridded/config.py b/src/anemoi/datasets/create/gridded/config.py index 1d10d081c..4720ebb6b 100644 --- a/src/anemoi/datasets/create/gridded/config.py +++ b/src/anemoi/datasets/create/gridded/config.py @@ -10,8 +10,6 @@ import datetime import logging import os -import subprocess -import sys from copy import deepcopy from typing import Any @@ -93,7 +91,7 @@ def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None: if dic[key] == value: return raise ValueError(f"Cannot use {key}={dic[key]}. Must use {value}.") - # LOG.info(f"Setting {key}={value} in config") + LOG.info(f"Setting {key}={value} in config") dic[key] = value @@ -403,11 +401,6 @@ def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: LoadersConfig The validated configuration object. """ - - if isinstance(config, str) and config.endswith(".py"): - result = subprocess.run([sys.executable, config], capture_output=True, text=True, check=True) - config = yaml.safe_load(result.stdout) - config = Config(config) if is_test: set_to_test_mode(config) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index ac0222726..e4e312fa8 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -1,5 +1,4 @@ # (C) Copyright 2024-2025 Anemoi contributors. -# (C) Copyright 2024-2025 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -27,7 +26,6 @@ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> No ---------- config : dict Configuration dictionary. - data_sources : dict data_sources : dict Data sources. 
**kwargs : Any diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py index 28c797dd5..89df7a727 100644 --- a/src/anemoi/datasets/create/input/context.py +++ b/src/anemoi/datasets/create/input/context.py @@ -18,9 +18,10 @@ class Context(ABC): """Context for building input data.""" - def __init__(self) -> None: + def __init__(self, /, argument: Any) -> None: self.results = {} self.cache = {} + self.argument = argument def trace(self, emoji, *message) -> None: @@ -33,7 +34,7 @@ def register(self, data: Any, path: list[str]) -> Any: assert path[0] in ("input", "data_sources"), path - LOG.info(f"Registering data at path: {'.'.join(str(x) for x in path)}") + LOG.info(f"Registering data at path: {path}") self.results[tuple(path)] = data return data @@ -46,9 +47,9 @@ def resolve(self, config): if path in self.results: config[key] = self.results[path] else: - print(f"Path not found {path}") + LOG.warning(f"Path not found {path}") for p in sorted(self.results): - print(f" Available paths: {p}") + LOG.info(f" Available paths: {p}") raise KeyError(f"Path {path} not found in results: {self.results.keys()}") return config diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index f53ed3674..2f776dff9 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -84,11 +84,6 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) - def python_code(self, code) -> str: - for n, s in zip(self.names, self.sources): - code.source(n, s.python_code(code)) - return code - class DataSourcesResult(Result): """Class to represent the result of data sources actions in the dataset creation process.""" diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py deleted file mode 100644 index 9f5173afc..000000000 --- a/src/anemoi/datasets/create/input/origin.py +++ /dev/null @@ -1,159 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import logging -from abc import ABC - -LOG = logging.getLogger(__name__) - - -class Origin(ABC): - - def __init__(self, when="dataset-create"): - self.when = when - - def __eq__(self, other): - if not isinstance(other, Origin): - return False - return self is other - - def __hash__(self): - return id(self) - - -def _un_dotdict(x): - if isinstance(x, dict): - return {k: _un_dotdict(v) for k, v in x.items()} - - if isinstance(x, (list, tuple, set)): - return [_un_dotdict(a) for a in x] - - return x - - -class Pipe(Origin): - def __init__(self, s1, s2, when="dataset-create"): - super().__init__(when) - self.steps = [s1, s2] - - assert s1 is not None, (s1, s2) - assert s2 is not None, (s1, s2) - - if isinstance(s1, Pipe): - assert not isinstance(s2, Pipe), (s1, s2) - self.steps = s1.steps + [s2] - - def combine(self, previous, action, action_arguments): - assert False, (self, previous) - - def as_dict(self): - return { - "type": "pipe", - "steps": [s.as_dict() for s in self.steps], - "when": self.when, - } - - def __repr__(self): - return " | ".join(repr(s) for s in self.steps) - - -class Join(Origin): - def __init__(self, origins, when="dataset-create"): - assert isinstance(origins, (list, tuple, set)), origins - super().__init__(when) - self.steps = list(origins) - - assert all(o is not None for o in origins), origins - - def combine(self, previous, action, action_arguments): - assert False, (self, previous) - - def as_dict(self): - return { - "type": "join", - "steps": [s.as_dict() for s in self.steps], - "when": self.when, - } - - def __repr__(self): - return " & ".join(repr(s) for s in self.steps) - - -class Source(Origin): - def __init__(self, name, config, when="dataset-create"): - super().__init__(when) - assert isinstance(config, dict), f"Config must be a dictionary {config}" - self.name = name - self.config = _un_dotdict(config) - - def combine(self, previous, action, action_arguments): - assert previous is None, f"Cannot combine origins, previous already exists: {previous}" - return self - - def as_dict(self): - return { - "type": "source", - "name": self.name, - "config": self.config, - "when": self.when, - } - - def __repr__(self): - return f"{self.name}({id(self)})" - - -class Filter(Origin): - def __init__(self, name, config, when="dataset-create"): - super().__init__(when) - assert isinstance(config, dict), f"Config must be a dictionary {config}" - self.name = name - self.config = _un_dotdict(config) - self._cache = {} - - def combine(self, previous, action, action_arguments): - - if previous is None: - # This can happen if the filter does not tag its output with an origin - # (e.g. a user plugin). In that case we try to get the origin from the action arguments - key = (id(action), id(action_arguments)) - if key not in self._cache: - - LOG.warning(f"No previous origin to combine with: {self}. 
Action: {action}") - LOG.warning(f"Connecting to action arguments {action_arguments}") - origins = set() - for k in action_arguments: - o = k.metadata("anemoi_origin", default=None) - if o is None: - raise ValueError( - f"Cannot combine origins, previous is None and action_arguments {action_arguments} has no origin" - ) - origins.add(o) - if len(origins) == 1: - self._cache[key] = origins.pop() - else: - self._cache[key] = Join(origins) - previous = self._cache[key] - - if previous in self._cache: - # We use a cache to avoid recomputing the same combination - return self._cache[previous] - - self._cache[previous] = Pipe(previous, self) - return self._cache[previous] - - def as_dict(self): - return { - "type": "filter", - "name": self.name, - "config": self.config, - "when": self.when, - } - - def __repr__(self): - return f"{self.name}({id(self)})" diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py index 48ba1f9f8..8e5a329f5 100644 --- a/src/anemoi/datasets/create/sources/csv.py +++ b/src/anemoi/datasets/create/sources/csv.py @@ -8,12 +8,12 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.source import ObservationsSource -from anemoi.datasets.create.sources import source_registry +from ..source import Source +from . import source_registry @source_registry.register("csv") -class CSVSource(ObservationsSource): +class CSVSource(Source): """A source that reads data from a CSV file.""" emoji = "📄" # For tracing diff --git a/src/anemoi/datasets/create/sources/eccc_fstd.py b/src/anemoi/datasets/create/sources/eccc_fstd.py index fdd79af8d..41734e9b6 100644 --- a/src/anemoi/datasets/create/sources/eccc_fstd.py +++ b/src/anemoi/datasets/create/sources/eccc_fstd.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from . import source_registry +from .xarray import XarraySourceBase @source_registry.register("eccc_fstd") diff --git a/src/anemoi/datasets/create/sources/fdb.py b/src/anemoi/datasets/create/sources/fdb.py index 382ccd3a1..67bfe8870 100644 --- a/src/anemoi/datasets/create/sources/fdb.py +++ b/src/anemoi/datasets/create/sources/fdb.py @@ -18,8 +18,8 @@ from anemoi.datasets.create.gridded.typing import DateList +from ..source import Source from . import source_registry -from .source import Source @source_registry.register("fdb") diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index 4750029ed..6070772fc 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -32,11 +32,11 @@ def _execute(context: Any, dates: list[str], template: str, param: str) -> Any: Template for the data source. param : str Parameter for the data source. + Returns ------- object Loaded forcing data. """ - context.trace("✅", f"from_source(forcings, {template}, {param}") return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) diff --git a/src/anemoi/datasets/create/sources/grib_index.py b/src/anemoi/datasets/create/sources/grib_index.py index 0935aff96..0d86732f6 100644 --- a/src/anemoi/datasets/create/sources/grib_index.py +++ b/src/anemoi/datasets/create/sources/grib_index.py @@ -19,9 +19,8 @@ from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray -from anemoi.datasets.create.sources.legacy import LegacySource - from . 
import source_registry +from .legacy import LegacySource LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py index 07e8f0203..b710bcbbe 100644 --- a/src/anemoi/datasets/create/sources/planetary_computer.py +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from . import source_registry +from .xarray import XarraySourceBase @source_registry.register("planetary_computer") diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index 77a06c76c..c484efd82 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -10,8 +10,8 @@ import logging from collections import defaultdict -from collections.abc import Generator from typing import Any +from typing import Generator import numpy as np from anemoi.transform.fields import new_field_with_valid_datetime @@ -19,18 +19,8 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources import source_registry - -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - +from ..source import Source +from . import source_registry LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray.py b/src/anemoi/datasets/create/sources/xarray.py index 5e3cc4c10..a735e52f6 100644 --- a/src/anemoi/datasets/create/sources/xarray.py +++ b/src/anemoi/datasets/create/sources/xarray.py @@ -11,11 +11,12 @@ import earthkit.data as ekd -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources.xarray_support import XarrayFieldList -from anemoi.datasets.create.sources.xarray_support import load_many -from anemoi.datasets.create.sources.xarray_support import load_one -from anemoi.datasets.create.typing import DateList +from anemoi.datasets.create.gridded.typing import DateList + +from ..source import Source +from .xarray_support import XarrayFieldList +from .xarray_support import load_many +from .xarray_support import load_one __all__ = ["load_many", "load_one", "XarrayFieldList"] diff --git a/src/anemoi/datasets/create/sources/xarray_kerchunk.py b/src/anemoi/datasets/create/sources/xarray_kerchunk.py index 632a7cae2..056d756ca 100644 --- a/src/anemoi/datasets/create/sources/xarray_kerchunk.py +++ b/src/anemoi/datasets/create/sources/xarray_kerchunk.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from . 
import source_registry +from .xarray import XarraySourceBase @source_registry.register("xarray_kerchunk") diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 85f9970f8..78f7de041 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -17,9 +17,9 @@ from earthkit.data.core.fieldlist import math from numpy.typing import NDArray -from anemoi.datasets.create.sources.xarray_support.coordinates import extract_single_value -from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar -from anemoi.datasets.create.sources.xarray_support.metadata import XArrayMetadata +from .coordinates import extract_single_value +from .coordinates import is_scalar +from .metadata import XArrayMetadata LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py index 174cb2716..48f9cf0e1 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py +++ b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py @@ -16,12 +16,12 @@ import yaml from earthkit.data import FieldList -from anemoi.datasets.create.sources.xarray_support.field import EmptyFieldList -from anemoi.datasets.create.sources.xarray_support.flavour import CoordinateGuesser -from anemoi.datasets.create.sources.xarray_support.patch import patch_dataset -from anemoi.datasets.create.sources.xarray_support.time import Time -from anemoi.datasets.create.sources.xarray_support.variable import FilteredVariable -from anemoi.datasets.create.sources.xarray_support.variable import Variable +from .field import EmptyFieldList +from .flavour import CoordinateGuesser +from .patch import patch_dataset +from .time import Time +from .variable import FilteredVariable +from .variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 74fcdbd03..80f0b6a62 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -17,25 +17,25 @@ import xarray as xr from anemoi.utils.config import DotDict -from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import PointCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate -from 
anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar -from anemoi.datasets.create.sources.xarray_support.grid import Grid -from anemoi.datasets.create.sources.xarray_support.grid import MeshedGrid -from anemoi.datasets.create.sources.xarray_support.grid import MeshProjectionGrid -from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredGrid -from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredProjectionGrid +from .coordinates import Coordinate +from .coordinates import DateCoordinate +from .coordinates import EnsembleCoordinate +from .coordinates import LatitudeCoordinate +from .coordinates import LevelCoordinate +from .coordinates import LongitudeCoordinate +from .coordinates import PointCoordinate +from .coordinates import ScalarCoordinate +from .coordinates import StepCoordinate +from .coordinates import TimeCoordinate +from .coordinates import UnsupportedCoordinate +from .coordinates import XCoordinate +from .coordinates import YCoordinate +from .coordinates import is_scalar +from .grid import Grid +from .grid import MeshedGrid +from .grid import MeshProjectionGrid +from .grid import UnstructuredGrid +from .grid import UnstructuredProjectionGrid LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/sources/xarray_support/metadata.py index 2230db3ef..23713ae74 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/metadata.py +++ b/src/anemoi/datasets/create/sources/xarray_support/metadata.py @@ -46,7 +46,7 @@ def __init__(self, field: Any) -> None: field : Any The field to extract metadata from. """ - from anemoi.datasets.create.sources.xarray_support.field import XArrayField + from .field import XArrayField assert isinstance(field, XArrayField), type(field) self._field = field diff --git a/src/anemoi/datasets/create/sources/xarray_support/time.py b/src/anemoi/datasets/create/sources/xarray_support/time.py index 7b1f60e58..847b21598 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/time.py +++ b/src/anemoi/datasets/create/sources/xarray_support/time.py @@ -16,8 +16,8 @@ from anemoi.utils.dates import as_datetime -from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate -from anemoi.datasets.create.sources.xarray_support.variable import Variable +from .coordinates import Coordinate +from .variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/variable.py b/src/anemoi/datasets/create/sources/xarray_support/variable.py index 13d6fa4e2..5d2c1c5b1 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/create/sources/xarray_support/variable.py @@ -17,7 +17,7 @@ import numpy as np import xarray as xr -from anemoi.datasets.create.sources.xarray_support.field import XArrayField +from .field import XArrayField LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 223736971..18a09ecfd 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -27,15 +27,13 @@ def extend(x: str | list[Any] | tuple[Any, ...]) -> Iterator[datetime.datetime]: """Extend a date range or list of dates into individual datetime objects. - Parameters - ---------- - x : Union[str, List[Any], Tuple[Any, ...]] - A date range string or list/tuple of dates. 
+    Parameters +    ---------- +    x : Union[str, List[Any], Tuple[Any, ...]] +        A date range string or list/tuple of dates. - Yields - ------ - datetime.datetime - Individual datetime objects. +    Returns +    ------- +    Iterator[datetime.datetime] +        An iterator of datetime objects. """  if isinstance(x, (list, tuple)): diff --git a/src/anemoi/datasets/misc/grids.py b/src/anemoi/datasets/misc/grids.py index 2ec50f69c..075f73495 100644 --- a/src/anemoi/datasets/misc/grids.py +++ b/src/anemoi/datasets/misc/grids.py @@ -477,7 +477,7 @@ def nearest_grid_points( """ # TODO: Use the one from anemoi.utils.grids instead # from anemoi.utils.grids import ... -    from scipy.spatial import KDTree +    from scipy.spatial import cKDTree  source_xyz = latlon_to_xyz(source_latitudes, source_longitudes) source_points = np.array(source_xyz).transpose() @@ -485,7 +485,7 @@ target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) target_points = np.array(target_xyz).transpose() if max_distance is None: -        distances, indices = KDTree(source_points).query(target_points, k=k) +        distances, indices = cKDTree(source_points).query(target_points, k=k) else: -        distances, indices = KDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) +        distances, indices = cKDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) return distances, indices diff --git a/src/anemoi/datasets/use/gridded/complement.py b/src/anemoi/datasets/use/gridded/complement.py index 7fac83c14..1881a74fa 100644 --- a/src/anemoi/datasets/use/gridded/complement.py +++ b/src/anemoi/datasets/use/gridded/complement.py @@ -249,13 +249,7 @@ def __init__(self, target: Any, source: Any, max_distance: float = None, k: int """ super().__init__(target, source) -        if isinstance(k, str): -            assert False -            LOG.warning(f"ComplementNearest: Interpreting k={k} ({type(k)}) as integer") -            k = int(k) -         self.k = k -         self._distances, self._nearest_grid_points = nearest_grid_points( self._source.latitudes, self._source.longitudes, @@ -359,7 +353,7 @@ def complement_factory(args: tuple, kwargs: dict) -> Dataset: }[interpolation]  if interpolation == "nearest": -        k = kwargs.pop("k", 1) +        k = int(kwargs.pop("k", 1))  complement = Class(target=target, source=source, k=k)._subset(**kwargs)  else: From cc6c38437838378716dfeba64e53f6dba47e89fa Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 11:29:36 +0000 Subject: [PATCH 180/212] update --- src/anemoi/datasets/use/gridded/concat.py | 28 +++++++++---------- src/anemoi/datasets/use/gridded/ensemble.py | 6 ---- src/anemoi/datasets/use/gridded/forwards.py | 3 -- .../datasets/use/gridded/interpolate.py | 22 +++++++-------- src/anemoi/datasets/use/gridded/join.py | 28 +++++++++---------- src/anemoi/datasets/use/gridded/masked.py | 6 ---- src/anemoi/datasets/use/gridded/padded.py | 2 +- src/anemoi/datasets/use/gridded/rescale.py | 7 ----- src/anemoi/datasets/use/gridded/select.py | 11 -------- src/anemoi/datasets/use/gridded/subset.py | 26 ++++++++--------- 10 files changed, 53 insertions(+), 86 deletions(-) diff --git a/src/anemoi/datasets/use/gridded/concat.py b/src/anemoi/datasets/use/gridded/concat.py index d9fb2a01b..d6bcb6297 100644 --- a/src/anemoi/datasets/use/gridded/concat.py +++ b/src/anemoi/datasets/use/gridded/concat.py @@ -16,20 +16,20 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import Dataset -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset
import FullIndex -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import Shape -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import Node -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.use.griddedanemoi.datasets.data.forwards import Combined -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import length_to_slices -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import update_tuple -from anemoi.datasets.use.griddedanemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.use.griddedanemoi.datasets.data.misc import _open +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import length_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/ensemble.py b/src/anemoi/datasets/use/gridded/ensemble.py index 1ecff8b97..0d1aa15b2 100644 --- a/src/anemoi/datasets/use/gridded/ensemble.py +++ b/src/anemoi/datasets/use/gridded/ensemble.py @@ -124,12 +124,6 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """Returns metadata specific to the Number object.""" return {} - def origin_transformation(self, variable, origins): - return { - "name": "number", - "config": {"members": self.members}, - } - class Ensemble(GivenAxis): """A class to represent an ensemble of datasets combined along a given axis.""" diff --git a/src/anemoi/datasets/use/gridded/forwards.py b/src/anemoi/datasets/use/gridded/forwards.py index 3966dd34b..0ee6f8ac7 100644 --- a/src/anemoi/datasets/use/gridded/forwards.py +++ b/src/anemoi/datasets/use/gridded/forwards.py @@ -240,9 +240,6 @@ def constant_fields(self) -> list[str]: """Returns the constant fields of the forward dataset.""" return self.forward.constant_fields - def project(self, projection): - return self.forward.project(projection).add_transformation(self) - class Combined(Forwards): """A class to combine multiple datasets into a single dataset.""" diff --git a/src/anemoi/datasets/use/gridded/interpolate.py b/src/anemoi/datasets/use/gridded/interpolate.py index 7c9e45c81..f3c5155f9 100644 --- a/src/anemoi/datasets/use/gridded/interpolate.py +++ b/src/anemoi/datasets/use/gridded/interpolate.py @@ -17,17 +17,17 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import Dataset -from 
anemoi.datasets.use.griddedanemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import Shape -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import Node -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.use.griddedanemoi.datasets.data.forwards import Forwards -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/join.py b/src/anemoi/datasets/use/gridded/join.py index d2d1a9033..7d150f01b 100644 --- a/src/anemoi/datasets/use/gridded/join.py +++ b/src/anemoi/datasets/use/gridded/join.py @@ -16,20 +16,20 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import Dataset -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import Shape -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import Node -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import Source -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.use.griddedanemoi.datasets.data.forwards import Combined -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import update_tuple -from anemoi.datasets.use.griddedanemoi.datasets.data.misc import _auto_adjust -from anemoi.datasets.use.griddedanemoi.datasets.data.misc import _open +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import Source +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Combined +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import 
expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple +from anemoi.datasets.use.gridded.misc import _auto_adjust +from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/use/gridded/masked.py b/src/anemoi/datasets/use/gridded/masked.py index 18c96a37c..675ae8dc2 100644 --- a/src/anemoi/datasets/use/gridded/masked.py +++ b/src/anemoi/datasets/use/gridded/masked.py @@ -200,12 +200,6 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return dict(thinning=self.thinning, method=self.method) - def origin_transformation(self, variable, origins): - return { - "name": "thinning", - "config": dict(thinning=self.thinning, method=self.method), - } - class Cropping(Masked): """A class to represent a cropped dataset.""" diff --git a/src/anemoi/datasets/use/gridded/padded.py b/src/anemoi/datasets/use/gridded/padded.py index 37037ad56..df7b793ec 100644 --- a/src/anemoi/datasets/use/gridded/padded.py +++ b/src/anemoi/datasets/use/gridded/padded.py @@ -18,7 +18,7 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets import MissingDateError from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import Shape diff --git a/src/anemoi/datasets/use/gridded/rescale.py b/src/anemoi/datasets/use/gridded/rescale.py index 4ecc1849d..8426bffbe 100644 --- a/src/anemoi/datasets/use/gridded/rescale.py +++ b/src/anemoi/datasets/use/gridded/rescale.py @@ -242,10 +242,3 @@ def statistics_tendencies(self, delta: datetime.timedelta | None = None) -> dict raise NotImplementedError("rescale tendencies statistics", k) return result - - def origin_transformation(self, variable, origins): - config = {} - for variable, (a, b) in self.rescale.items(): - config[variable] = {"scale": a, "offset": b} - - return {"name": "rescale", "config": config} diff --git a/src/anemoi/datasets/use/gridded/select.py b/src/anemoi/datasets/use/gridded/select.py index 344ae56a4..9a8fcc385 100644 --- a/src/anemoi/datasets/use/gridded/select.py +++ b/src/anemoi/datasets/use/gridded/select.py @@ -224,17 +224,6 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: # return dict(indices=self.indices) return dict(reason=self.reason) - def forward_subclass_origin(self, index): - assert ( - isinstance(index, tuple) and len(index) == 4 and all(a > b >= 0 for a, b in zip(self.shape, index)) - ), tuple - - return self.dataset.origin((index[0], self.indices[index[1]], index[2], index[3])) - - def project(self, projection): - projection = projection.from_indices(axis=1, indices=self.indices) - return self.dataset.project(projection) - class Rename(Forwards): """Class to rename variables in a dataset.""" diff --git a/src/anemoi/datasets/use/gridded/subset.py b/src/anemoi/datasets/use/gridded/subset.py index 5e8c1cfb7..f14501a3e 100644 --- a/src/anemoi/datasets/use/gridded/subset.py +++ b/src/anemoi/datasets/use/gridded/subset.py @@ -19,19 +19,19 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import Dataset -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import FullIndex -from anemoi.datasets.use.griddedanemoi.datasets.data.dataset import Shape -from 
anemoi.datasets.use.griddedanemoi.datasets.data.dataset import TupleIndex -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import Node -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import Source -from anemoi.datasets.use.griddedanemoi.datasets.data.debug import debug_indexing -from anemoi.datasets.use.griddedanemoi.datasets.data.forwards import Forwards -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import apply_index_to_slices_changes -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import expand_list_indexing -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import index_to_slices -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import make_slice_or_index_from_list_or_tuple -from anemoi.datasets.use.griddedanemoi.datasets.data.indexing import update_tuple +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.dataset import FullIndex +from anemoi.datasets.use.gridded.dataset import Shape +from anemoi.datasets.use.gridded.dataset import TupleIndex +from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import Source +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Forwards +from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing +from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import make_slice_or_index_from_list_or_tuple +from anemoi.datasets.use.gridded.indexing import update_tuple LOG = logging.getLogger(__name__) From 6f3abee2dc5b47ac22e1648766e681490dff7e35 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 11:36:45 +0000 Subject: [PATCH 181/212] update --- src/anemoi/datasets/create/tasks.py | 2 +- src/anemoi/datasets/use/gridded/components.py | 258 ------------------ .../datasets/use/tabular/records/__init__.py | 2 +- .../use/tabular/records/backends/__init__.py | 6 +- 4 files changed, 5 insertions(+), 263 deletions(-) delete mode 100644 src/anemoi/datasets/use/gridded/components.py diff --git a/src/anemoi/datasets/create/tasks.py b/src/anemoi/datasets/create/tasks.py index af249d730..23728e6aa 100644 --- a/src/anemoi/datasets/create/tasks.py +++ b/src/anemoi/datasets/create/tasks.py @@ -49,7 +49,7 @@ def run(self) -> None: def task_factory(name: str, trace: str | None = None, **kwargs): if True: - from anemoi.datasets.create.fields.tasks import TaskCreator + from anemoi.datasets.create.gridded.tasks import TaskCreator creator = TaskCreator() else: diff --git a/src/anemoi/datasets/use/gridded/components.py b/src/anemoi/datasets/use/gridded/components.py deleted file mode 100644 index c16c20c6f..000000000 --- a/src/anemoi/datasets/use/gridded/components.py +++ /dev/null @@ -1,258 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- - -from collections import defaultdict - - -def _indices_to_slices(indices: list[int]) -> list[slice]: - indices = sorted(indices) - assert len(indices) == len(set(indices)), "Duplicate indices are not allowed" - - if not indices: - return [] - - slices = [] - n = len(indices) - i = 0 - - while i < n: - start = indices[i] - # default step = 1 - if i + 1 < n: - step = indices[i + 1] - indices[i] - else: - step = 1 - - j = i + 1 - while j < n and indices[j] - indices[j - 1] == step: - j += 1 - - stop = indices[j - 1] + step - slices.append(slice(start, stop, step)) - i = j - - check = list() - for s in slices: - check.extend(range(s.start, s.stop, s.step)) - - assert check == list(indices), slices - - return slices - - -def _combine_slices(length, *slices): - - start, step, current_length = 0, 1, length - - for s in slices: - assert s.stop >= s.start and s.step > 0 - new_start, new_stop, new_step = s.indices(current_length) - new_length = len(range(new_start, new_stop, new_step)) - start = start + new_start * step - step = step * new_step - current_length = new_length - - if current_length == 0: - return slice(0, 0, 1) # canonical empty slice - - if current_length == 0: - return slice(0, 0, 1) - - stop = start + current_length * step - - return slice(start, stop, step) - - -class ProjectionBase: - - def from_store(self, slices, store): - return ProjectionStore(slices, store) - - @classmethod - def from_slices(cls, slices): - return Projection(slices) - - @classmethod - def list_or_single(cls, projections): - if len(projections) == 1: - return projections[0] - return ProjectionList(projections) - - def ensure_list(self): - return ProjectionList([self]) - - def compressed_origins(self): - result = defaultdict(list) - for p in self.ensure_list(): - for k, v in p.origins().items(): - result[k].append(v) - return result - - -class Projection(ProjectionBase): - - def __init__(self, slices): - assert isinstance(slices, (list, tuple)), slices - assert all(isinstance(s, slice) for s in slices), slices - assert len(slices) == 4, slices - self.slices = tuple(slices) - - def from_indices(self, *, axis, indices): - length = max(indices) + 1 - slices = _indices_to_slices(indices) - this_slice = self.slices[axis] - combined = [] - - for s in slices: - combined.append(_combine_slices(max(this_slice.stop, s.stop, length), s, this_slice)) - - projections = [ - Projection([c if i == axis else self.slices[i] for i in range(len(self.slices))]) for c in combined - ] - - return self.list_or_single(projections) - - def __repr__(self): - return f"Projection(slices={self.slices})" - - def offset(self, axis, amount): - return Projection( - [ - ( - slice( - s.start + amount, - s.stop + amount, - s.step, - ) - if i == axis - else s - ) - for i, s in enumerate(self.slices) - ] - ) - - -class ProjectionList(ProjectionBase): - def __init__(self, projections): - assert isinstance(projections, (list, tuple)), projections - assert all(isinstance(p, ProjectionBase) for p in projections), projections - - self.projections = [] - for p in projections: - if isinstance(p, ProjectionList): - self.projections.extend(p.projections) - else: - self.projections.append(p) - - def from_indices(self, *, axis, indices): - return ProjectionList([p.from_indices(axis=axis, indices=indices) for p in self.projections]) - - def __repr__(self): - return "ProjectionList(" + ",".join(repr(p) for p in self.projections) + ")" - - def ensure_list(self): - return self - - def __iter__(self): - return iter(self.projections) - - def 
add_transformation(self, transformation): - return ProjectionList([p.add_transformation(transformation) for p in self.projections]) - - -class ProjectionStore(ProjectionBase): - def __init__(self, slices, store, transformations=None): - assert isinstance(slices, (list, tuple)), slices - assert all(isinstance(s, slice) for s in slices), slices - assert len(slices) == 4, slices - - self.slices = slices - self.store = store - self.transformations = transformations or [] - - def __repr__(self): - return repr((self.slices, self.store.dataset_name)) - - def apply(self, projection): - - projections = projection.ensure_list() - - result = [] - - for projection in projections: - - slices = [] - for a, b in zip(self.slices, projection.slices): - slices.append(_combine_slices(a.stop, a, b)) - result.append(ProjectionStore(slices, self.store)) - - return self.list_or_single(result) - - def variables(self): - return self.store.variables[self.slices[1]] - - def origins(self, compressed=False): - result = {} - for variable in self.variables(): - - origins = self.store.origins[variable] - - pipe = [] - for transformation in self.transformations: - - action = transformation.origin_transformation(variable, origins) - if isinstance(action, tuple): - # Needed to support 'rename' - action, variable = action - - action = action.copy() - action.setdefault("when", "dataset-usage") - action.setdefault("type", "filter") - pipe.append(action) - - if pipe: - origins = { - "type": "pipe", - "when": "dataset-usage", - "steps": [origins] + pipe, - } - - result[variable] = origins - - if compressed: - - def _hashable(v): - if isinstance(v, dict): - return tuple((k, _hashable(vv)) for k, vv in sorted(v.items())) - if isinstance(v, list): - return tuple(_hashable(vv) for vv in v) - return v - - compressed_result = defaultdict(list) - for k, v in result.items(): - compressed_result[_hashable(v)].append((k, v)) - - result = {} - for v in compressed_result.values(): - key = tuple(sorted(k for k, _ in v)) - value = v[0][1] - result[key] = value - - return result - - def add_transformation(self, transformation): - return ProjectionStore(self.slices, self.store, self.transformations + [transformation]) - - def __iter__(self): - return iter([self]) - - @property - def dataset_name(self): - return self.store.dataset_name diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py index 9216bcadc..13b729ef0 100644 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/__init__.py @@ -18,7 +18,7 @@ from anemoi.utils.config import load_any_dict_format from anemoi.utils.dates import frequency_to_timedelta -from anemoi.datasets.use.debug import Node +from anemoi.datasets.use.gridded.debug import Node from .records.backends import backend_factory from .windows import window_from_str diff --git a/src/anemoi/datasets/use/tabular/records/backends/__init__.py b/src/anemoi/datasets/use/tabular/records/backends/__init__.py index bda4de274..5d27203ff 100644 --- a/src/anemoi/datasets/use/tabular/records/backends/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/backends/__init__.py @@ -213,7 +213,7 @@ def write(self, i, data, number_of_files_per_subdirectory=100, **kwargs): os.rename(tmp_path, out_path) def write_metadata(self, metadata): - from anemoi.datasets.create.gridded.fields.tasks import _json_tidy + from anemoi.datasets.create.gridded.tasks import _json_tidy os.makedirs(self.path, exist_ok=True) @@ -257,7 +257,7 @@ def 
write(self, i, data, **kwargs): ds.to_netcdf(out_path) def write_metadata(self, metadata): - from anemoi.datasets.create.fields.tasks import _json_tidy + from anemoi.datasets.create.gridded.tasks import _json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: @@ -295,7 +295,7 @@ def write(self, i, data, **kwargs): np.savez(out_path, **data) def write_metadata(self, metadata): - from anemoi.datasets.create.gridded.fields.tasks import _json_tidy + from anemoi.datasets.create.gridded.tasks import _json_tidy os.makedirs(self.path, exist_ok=True) with open(os.path.join(self.path, "metadata.json"), "w") as f: From 819a860bbb54bc424f3c4da2b4f980dfde68523a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 15:03:40 +0000 Subject: [PATCH 182/212] checkpoint --- .gitignore | 3 + src/anemoi/datasets/check.py | 93 + src/anemoi/datasets/commands/check.py | 7 +- src/anemoi/datasets/commands/copy.py | 5 +- src/anemoi/datasets/commands/create.py | 2 +- src/anemoi/datasets/commands/inspect.py | 15 +- .../datasets/commands/recipe/__init__.py | 11 +- src/anemoi/datasets/commands/recipe/format.py | 2 +- .../datasets/commands/recipe/migrate.py | 555 ++++++ src/anemoi/datasets/commands/validate.py | 5 +- src/anemoi/datasets/create/__init__.py | 8 - src/anemoi/datasets/create/check.py | 328 ++++ src/anemoi/datasets/create/chunks.py | 138 ++ src/anemoi/datasets/create/config.py | 453 +++++ src/anemoi/datasets/create/filter.py | 47 + .../datasets/create/gridded/__init__.py | 1658 ----------------- .../datasets/create/gridded/additions.py | 413 ++++ src/anemoi/datasets/create/gridded/cleanup.py | 60 + src/anemoi/datasets/create/gridded/context.py | 36 +- src/anemoi/datasets/create/gridded/init.py | 293 +++ src/anemoi/datasets/create/gridded/load.py | 260 +++ src/anemoi/datasets/create/gridded/patch.py | 188 +- src/anemoi/datasets/create/gridded/result.py | 74 +- src/anemoi/datasets/create/gridded/size.py | 55 +- .../datasets/create/gridded/statistics.py | 102 + .../gridded/{statistics => stats}/__init__.py | 2 +- .../gridded/{statistics => stats}/summary.py | 0 src/anemoi/datasets/create/gridded/tasks.py | 606 ++++++ src/anemoi/datasets/create/gridded/verify.py | 34 + src/anemoi/datasets/create/input/__init__.py | 24 +- src/anemoi/datasets/create/input/action.py | 162 +- src/anemoi/datasets/create/input/context.py | 9 +- .../datasets/create/input/data_sources.py | 7 +- src/anemoi/datasets/create/input/misc.py | 71 + src/anemoi/datasets/create/input/origin.py | 159 ++ src/anemoi/datasets/create/patch.py | 188 ++ src/anemoi/datasets/create/persistent.py | 269 +++ src/anemoi/datasets/create/size.py | 47 + src/anemoi/datasets/create/source.py | 2 +- .../datasets/create/sources/accumulations.py | 171 +- .../datasets/create/sources/accumulations2.py | 64 +- .../datasets/create/sources/anemoi_dataset.py | 88 +- .../datasets/create/sources/constants.py | 77 +- src/anemoi/datasets/create/sources/csv.py | 40 +- .../datasets/create/sources/eccc_fstd.py | 4 +- src/anemoi/datasets/create/sources/empty.py | 48 +- src/anemoi/datasets/create/sources/fdb.py | 9 +- .../datasets/create/sources/forcings.py | 57 +- src/anemoi/datasets/create/sources/grib.py | 171 +- .../datasets/create/sources/grib_index.py | 88 +- .../datasets/create/sources/hindcasts.py | 114 +- src/anemoi/datasets/create/sources/legacy.py | 75 +- src/anemoi/datasets/create/sources/mars.py | 239 +-- src/anemoi/datasets/create/sources/netcdf.py | 58 +- 
src/anemoi/datasets/create/sources/opendap.py | 58 +- .../create/sources/planetary_computer.py | 4 +- .../datasets/create/sources/recentre.py | 86 +- .../datasets/create/sources/repeated_dates.py | 16 +- src/anemoi/datasets/create/sources/source.py | 68 + .../datasets/create/sources/tendencies.py | 162 +- src/anemoi/datasets/create/sources/xarray.py | 11 +- .../create/sources/xarray_kerchunk.py | 4 +- .../create/sources/xarray_support/__init__.py | 56 +- .../create/sources/xarray_support/field.py | 6 +- .../sources/xarray_support/fieldlist.py | 12 +- .../create/sources/xarray_support/flavour.py | 38 +- .../create/sources/xarray_support/metadata.py | 2 +- .../create/sources/xarray_support/time.py | 4 +- .../create/sources/xarray_support/variable.py | 2 +- .../datasets/create/sources/xarray_zarr.py | 58 +- src/anemoi/datasets/create/sources/zenodo.py | 86 +- .../datasets/create/statistics/__init__.py | 561 ++++++ .../datasets/create/statistics/summary.py | 152 ++ src/anemoi/datasets/create/tasks.py | 4 +- src/anemoi/datasets/create/testing.py | 4 + src/anemoi/datasets/create/typing.py | 14 + src/anemoi/datasets/create/utils.py | 198 ++ src/anemoi/datasets/create/writer.py | 64 + src/anemoi/datasets/create/zarr.py | 331 ++++ src/anemoi/datasets/dates/__init__.py | 14 +- src/anemoi/datasets/grids.py | 668 +++++++ src/anemoi/datasets/testing.py | 173 ++ src/anemoi/datasets/use/__init__.py | 8 - src/anemoi/datasets/use/gridded/__init__.py | 2 +- src/anemoi/datasets/use/gridded/complement.py | 10 +- src/anemoi/datasets/use/gridded/dataset.py | 12 +- src/anemoi/datasets/use/gridded/ensemble.py | 6 + .../datasets/use/gridded/fill_missing.py | 2 +- src/anemoi/datasets/use/gridded/forwards.py | 3 + src/anemoi/datasets/use/gridded/grids.py | 154 +- .../datasets/use/gridded/interpolate.py | 2 +- src/anemoi/datasets/use/gridded/join.py | 3 + src/anemoi/datasets/use/gridded/masked.py | 10 +- src/anemoi/datasets/use/gridded/merge.py | 2 +- src/anemoi/datasets/use/gridded/misc.py | 23 +- src/anemoi/datasets/use/gridded/missing.py | 4 +- .../use/gridded/observations/__init__.py | 313 ++++ .../observations/legacy_obs_dataset.py | 200 ++ .../use/gridded/observations/multi.py | 64 + src/anemoi/datasets/use/gridded/rescale.py | 7 + src/anemoi/datasets/use/gridded/select.py | 11 + src/anemoi/datasets/use/gridded/statistics.py | 2 +- src/anemoi/datasets/use/gridded/stores.py | 12 +- .../use/tabular/observations/__init__.py | 6 +- .../datasets/use/tabular/records/__init__.py | 6 +- .../use/tabular/{records => }/windows.py | 0 src/anemoi/datasets/validate.py | 598 ++++++ ...ervations.py => dont_test_observations.py} | 4 +- ...mars.py => dont_test_observations_mars.py} | 10 +- ...py => dont_test_observations_mars_bufr.py} | 9 +- ...nt_test_observations_mars_bufr_complex.py} | 9 +- ...t_test_observations_mars_bufr_parallel.py} | 9 +- tests/create/test_sources.py | 22 +- tests/create/utils/create.py | 10 +- tests/test_classes.py | 9 +- tests/test_data.py | 2 +- tests/test_data_gridded.py | 2 +- tests/test_dates.py | 2 +- tests/test_records.py | 1 + tools/build-obs.py | 2 +- 120 files changed, 9061 insertions(+), 3040 deletions(-) create mode 100644 src/anemoi/datasets/check.py create mode 100644 src/anemoi/datasets/create/check.py create mode 100644 src/anemoi/datasets/create/chunks.py create mode 100644 src/anemoi/datasets/create/config.py create mode 100644 src/anemoi/datasets/create/filter.py create mode 100644 src/anemoi/datasets/create/gridded/additions.py create mode 100644 
src/anemoi/datasets/create/gridded/cleanup.py create mode 100644 src/anemoi/datasets/create/gridded/init.py create mode 100644 src/anemoi/datasets/create/gridded/load.py mode change 100755 => 100644 src/anemoi/datasets/create/gridded/patch.py create mode 100644 src/anemoi/datasets/create/gridded/statistics.py rename src/anemoi/datasets/create/gridded/{statistics => stats}/__init__.py (99%) rename src/anemoi/datasets/create/gridded/{statistics => stats}/summary.py (100%) create mode 100644 src/anemoi/datasets/create/gridded/tasks.py create mode 100644 src/anemoi/datasets/create/gridded/verify.py create mode 100644 src/anemoi/datasets/create/input/origin.py create mode 100755 src/anemoi/datasets/create/patch.py create mode 100644 src/anemoi/datasets/create/persistent.py create mode 100644 src/anemoi/datasets/create/size.py create mode 100644 src/anemoi/datasets/create/sources/source.py create mode 100644 src/anemoi/datasets/create/statistics/__init__.py create mode 100644 src/anemoi/datasets/create/statistics/summary.py create mode 100644 src/anemoi/datasets/create/testing.py create mode 100644 src/anemoi/datasets/create/typing.py create mode 100644 src/anemoi/datasets/create/utils.py create mode 100644 src/anemoi/datasets/create/writer.py create mode 100644 src/anemoi/datasets/create/zarr.py create mode 100644 src/anemoi/datasets/grids.py create mode 100644 src/anemoi/datasets/testing.py create mode 100644 src/anemoi/datasets/use/gridded/observations/__init__.py create mode 100644 src/anemoi/datasets/use/gridded/observations/legacy_obs_dataset.py create mode 100644 src/anemoi/datasets/use/gridded/observations/multi.py rename src/anemoi/datasets/use/tabular/{records => }/windows.py (100%) create mode 100644 src/anemoi/datasets/validate.py rename tests/create/{test_observations.py => dont_test_observations.py} (94%) rename tests/create/{test_observations_mars.py => dont_test_observations_mars.py} (95%) rename tests/create/{test_observations_mars_bufr.py => dont_test_observations_mars_bufr.py} (95%) rename tests/create/{test_observations_mars_bufr_complex.py => dont_test_observations_mars_bufr_complex.py} (95%) rename tests/create/{test_observations_mars_bufr_parallel.py => dont_test_observations_mars_bufr_parallel.py} (94%) diff --git a/.gitignore b/.gitignore index 031461ff0..9238c6425 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,6 @@ trace.txt *.prof prof/ *.gz +*.odb +*.bufr +*.csv diff --git a/src/anemoi/datasets/check.py b/src/anemoi/datasets/check.py new file mode 100644 index 000000000..d795d13f9 --- /dev/null +++ b/src/anemoi/datasets/check.py @@ -0,0 +1,93 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
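+#
+# A minimal usage sketch for the checker defined below (illustrative only;
+# the path is an example, not a real dataset):
+#
+#     from anemoi.datasets.check import check_zarr
+#
+#     check_zarr("/path/to/dataset.zarr", verbosity=1)  # raises ValueError if chunk files are missing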
+ + +# A collection of functions to support pytest testing + +import logging +import math +import os +import re + +LOG = logging.getLogger(__name__) + + +def _check_group(group, verbosity: int, *path) -> None: + import zarr + + group_keys = sorted(group.keys()) + if not group_keys: + raise ValueError(f"Check group: {group} is empty.") + + for name in sorted(group_keys): + if name.startswith("."): + if verbosity > 1: + LOG.info(f"Check group: skipping {name}") + continue + + if isinstance(group[name], zarr.hierarchy.Group): + _check_group(group[name], verbosity, *path, name) + else: + _check_array(group[name], verbosity, *path, name) + + +def _check_array(array, verbosity: int, *path) -> None: + assert len(array.chunks) == len(array.shape) + assert math.prod(array.shape) % math.prod(array.chunks) == 0 + + file_count = math.prod(array.shape) // math.prod(array.chunks) + + full = os.path.join(*path) + + chunks = array.chunks + + count = 0 + for f in os.listdir(full): + if verbosity > 1: + LOG.info(f"Check array: checking {f}") + + if f.startswith("."): + if verbosity > 1: + LOG.info(f"Check array: skipping {f}") + continue + + bits = f.split(".") + + if len(bits) != len(chunks): + raise ValueError(f"File {f} is not a valid chunk file.") + + if not all(re.match(r"^\d+$", bit) for bit in bits): + raise ValueError(f"File {f} is not a valid chunk file.") + + count += 1 + + if count != file_count: + raise ValueError(f"File count {count} does not match expected {file_count} for {array.name}.") + + +def check_zarr(path: str, verbosity: int = 0) -> None: + """Check if a Zarr archive is valid, that no files are missing, and that the chunking is correct. + + Parameters + ---------- + path : str + Path to the Zarr archive. + verbosity : int, optional + Verbosity level for logging. Default is 0 (no logging). + """ + import zarr + + if verbosity > 0: + LOG.info(f"Checking Zarr archive {path}") + + if not os.path.exists(path) and not os.path.isdir(path): + # This does not work with non-directory Zarr archives + raise ValueError(f"Path {path} does not exist.") + + _check_group(zarr.open(path, mode="r"), verbosity, path) diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index 2c349d1b2..4202ed09f 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -13,9 +13,8 @@ import yaml -from anemoi.datasets.create.gridded.check import DatasetName - -from .. import Command +from anemoi.datasets.commands import Command +from anemoi.datasets.create.check import DatasetName LOG = logging.getLogger(__name__) @@ -90,7 +89,7 @@ def _check_name(self, name: str) -> None: def _check_zarr(self, zarr: str) -> None: - from anemoi.datasets.misc.check import check_zarr + from anemoi.datasets.check import check_zarr check_zarr(zarr) diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 0c4aabffa..9628bae8e 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -19,9 +19,8 @@ from anemoi.utils.remote import Transfer from anemoi.utils.remote import TransferMethodNotImplementedError -from anemoi.datasets.misc.check import check_zarr - -from . 
import Command +from anemoi.datasets.check import check_zarr +from anemoi.datasets.commands import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 9c7f63cc4..151b175d9 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -32,7 +32,7 @@ def task(what: str, fields: bool, options: dict, *args: Any, **kwargs: Any) -> A options = {k: v for k, v in options.items() if v is not None} - c = task_factory(what.replace("-", "_"), **options) + c = task_factory(what.replace("-", "_"), fields, **options) result = c.run() LOG.info(f"🏁 Task {what}({args},{kwargs}) completed ({datetime.datetime.now()-now})") diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 6f1ac3555..257bee122 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -27,10 +27,9 @@ from numpy.typing import NDArray from anemoi.datasets import open_dataset +from anemoi.datasets.commands import Command +from anemoi.datasets.use.gridded.stores import dataset_lookup from anemoi.datasets.use.gridded.stores import open_zarr -from anemoi.datasets.use.gridded.stores import zarr_lookup - -from .. import Command LOG = logging.getLogger(__name__) @@ -396,9 +395,13 @@ def progress(self) -> None: ) return - build_flags = self.build_flags or np.array([], dtype=bool) + if self.build_flags is None: + print("🪫 Dataset not initialised") + return + + build_flags = self.build_flags - build_lengths = self.build_lengths or np.array([], dtype=bool) + build_lengths = self.build_lengths assert build_flags.size == build_lengths.size latest_write_timestamp = self.zarr.attrs.get("latest_write_timestamp") @@ -810,7 +813,7 @@ def _info(self, path: str) -> Version: Version The version object of the dataset. """ - z = open_zarr(zarr_lookup(path)) + z = open_zarr(dataset_lookup(path)) metadata = dict(z.attrs) version = metadata.get("version", "0.0.0") diff --git a/src/anemoi/datasets/commands/recipe/__init__.py b/src/anemoi/datasets/commands/recipe/__init__.py index e93184bf2..3af2a2230 100644 --- a/src/anemoi/datasets/commands/recipe/__init__.py +++ b/src/anemoi/datasets/commands/recipe/__init__.py @@ -15,8 +15,6 @@ import yaml -from anemoi.datasets.create.gridded import validate_config - from .. 
import Command from .format import format_recipe from .migrate import migrate_recipe @@ -37,6 +35,7 @@ def add_arguments(self, command_parser: Any) -> None: command_parser.add_argument("--validate", action="store_true", help="Validate recipe.") command_parser.add_argument("--format", action="store_true", help="Format the recipe.") command_parser.add_argument("--migrate", action="store_true", help="Migrate the recipe to the latest version.") + command_parser.add_argument("--python", action="store_true", help="Convert the recipe to a Python script.") group = command_parser.add_mutually_exclusive_group() group.add_argument("--inplace", action="store_true", help="Overwrite the recipe file in place.") @@ -49,7 +48,7 @@ def add_arguments(self, command_parser: Any) -> None: def run(self, args: Any) -> None: - if not args.validate and not args.format and not args.migrate: + if not args.validate and not args.format and not args.migrate and not args.python: args.validate = True with open(args.path) as file: @@ -58,10 +57,12 @@ def run(self, args: Any) -> None: assert isinstance(config, dict) if args.validate: - if args.inplace and (not args.format and not args.migrate): + from anemoi.datasets.create.gridded.tasks import validate_config + + if args.inplace and (not args.format and not args.migrate and not args.python): argparse.ArgumentError(None, "--inplace is not supported with --validate.") - if args.output and (not args.format and not args.migrate): + if args.output and (not args.format and not args.migrate and not args.python): argparse.ArgumentError(None, "--output is not supported with --validate.") validate_config(config) diff --git a/src/anemoi/datasets/commands/recipe/format.py b/src/anemoi/datasets/commands/recipe/format.py index b6993a49a..3c5f43431 100644 --- a/src/anemoi/datasets/commands/recipe/format.py +++ b/src/anemoi/datasets/commands/recipe/format.py @@ -1,4 +1,4 @@ -# (C) Copyright 2025 Anemoi contributors. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. diff --git a/src/anemoi/datasets/commands/recipe/migrate.py b/src/anemoi/datasets/commands/recipe/migrate.py index e69de29bb..7edbd80e2 100644 --- a/src/anemoi/datasets/commands/recipe/migrate.py +++ b/src/anemoi/datasets/commands/recipe/migrate.py @@ -0,0 +1,555 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
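+#
+# Overview of the helpers defined below: MIGRATE maps old dotted paths to new
+# ones, DELETE lists keys to drop, and SOURCES renames legacy source names.
+# A minimal usage sketch (the recipe file name is an example):
+#
+#     import yaml
+#
+#     with open("old-recipe.yaml") as f:
+#         old = yaml.safe_load(f)
+#     new = migrate(old)  # rewrite the config towards the current schema
+#     check(new)          # validate the result; exits the process on failure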
+ + +import logging +import sys +from collections.abc import Sequence +from typing import Any + +from glom import assign +from glom import delete +from glom import glom + +from anemoi.datasets.create.gridded.tasks import validate_config +from anemoi.datasets.dumper import yaml_dump + +LOG = logging.getLogger(__name__) + + +def find_paths(data, target_key=None, target_value=None, *path): + + matches = [] + + if isinstance(data, dict): + for k, v in data.items(): + if (target_key is not None and k == target_key) or (target_value is not None and v == target_value): + matches.append(list(path) + [k]) + matches.extend(find_paths(v, target_key, target_value, *path, k)) + elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)): + for i, item in enumerate(data): + matches.extend(find_paths(item, target_key, target_value, *path, str(i))) + return matches + + +def find_chevrons(data, *path): + + matches = [] + + if isinstance(data, dict): + for k, v in data.items(): + if k == "<<": + matches.append(list(path) + [k]) + matches.extend(find_chevrons(v, *path, k)) + elif isinstance(data, list): + for i, item in enumerate(data): + matches.extend(find_chevrons(item, *path, str(i))) + return matches + + +def find_paths_in_substrees(path, obj, cur_path=None): + if cur_path is None: + cur_path = [] + matches = [] + try: + glom(obj, path) # just to check existence + matches.append(cur_path + path.split(".")) + except Exception: + pass + + if isinstance(obj, dict): + for k, v in obj.items(): + matches.extend(find_paths_in_substrees(path, v, cur_path + [k])) + elif isinstance(obj, list): + for i, v in enumerate(obj): + matches.extend(find_paths_in_substrees(path, v, cur_path + [str(i)])) + return matches + + +MIGRATE = { + "output.statistics_end": "statistics.end", + "has_nans": "statistics.allow_nans", + "loop.dates.group_by": "build.group_by", + "loop.0.dates.group_by": "build.group_by", + "loop.dates": "dates", + "loop.0.dates": "dates", + "copyright": "attribution", + "dates.<<": "dates", + "options.group_by": "build.group_by", + "loops.0.loop_a.dates": "dates", + "loop.0.loop_a.dates": "dates", + "dates.stop": "dates.end", + "dates.group_by": "build.group_by", + "include.mars": "data_sources.mars.mars", + "ensemble_dimension": "build.ensemble_dimension", + "flatten_grid": "build.flatten_grid", +} + +DELETE = [ + "purpose", + # "input.join.0.label", + "status", + "common", + "config_format_version", + "aliases", + # "platform", + "loops.0.loop_a.applies_to", + "loop.0.loop_a.applies_to", + "dataset_status", + "alias", + "resources", + "input.dates.<<", + "input.dates.join.0.label.name", +] + + +SOURCES = { + "oper-accumulations": "accumulations", + "era5-accumulations": "accumulations", + "ensemble-perturbations": "recentre", + "ensemble_perturbations": "recentre", + "perturbations": "recentre", + "custom-regrid": "regrid", +} + +MARKER = object() + + +def _delete(config, path): + x = glom(config, path, default=MARKER) + if x is MARKER: + return + delete(config, path) + + +def _move(config, path, new_path, result): + x = glom(config, path, default=MARKER) + if x is MARKER: + return + delete(result, path) + assign(result, new_path, x, missing=dict) + + +def _fix_input_0(config): + if isinstance(config["input"], dict): + return + + input = config["input"] + new_input = [] + + blocks = {} + first = None + for block in input: + assert isinstance(block, dict), block + + assert len(block) == 1, block + + block_name, values = list(block.items())[0] + + if "kwargs" in values: + inherit = 
values.pop("inherit", None) + assert len(values) == 1, values + values = values["kwargs"] + values.pop("date", None) + source_name = values.pop("name", None) + + if inherit is not None: + if inherit.startswith("$"): + inherit = inherit[1:] + inherited = blocks[inherit].copy() + inherited.update(values) + values = inherited + + if first is None: + first = source_name + + blocks[block_name] = values.copy() + + new_input.append({SOURCES.get(source_name, source_name): values.copy()}) + else: + assert False, f"Block {block_name} does not have 'kwargs': {values}" + + blocks[block_name] = values.copy() + + config["input"] = dict(join=new_input) + + +def _fix_input_1(result, config): + if isinstance(config["input"], dict): + return + + input = config["input"] + join = [] + for k in input: + assert isinstance(k, dict) + assert len(k) == 1, f"Input key {k} is not a string: {input}" + name, values = list(k.items())[0] + join.append(values) + + result["input"] = {"join": join} + config["input"] = result["input"].copy() + + +def remove_empties(config: dict) -> None: + """Remove empty dictionaries and lists from the config.""" + if isinstance(config, dict): + keys_to_delete = [k for k, v in config.items() if v in (None, {}, [], [{}])] + + for k in keys_to_delete: + del config[k] + + for k, v in config.items(): + remove_empties(v) + + if isinstance(config, list): + for item in config: + remove_empties(item) + + +def _fix_loops(result: dict, config: dict) -> None: + if "loops" not in config: + return + + input = config["input"] + loops = config["loops"] + + assert isinstance(loops, list), loops + assert isinstance(input, list), input + + entries = {} + dates_block = None + for loop in loops: + assert isinstance(loop, dict), loop + assert len(loop) == 1, loop + loop = list(loop.values())[0] + applies_to = loop["applies_to"] + dates = loop["dates"] + assert isinstance(applies_to, list), (applies_to, loop) + for a in applies_to: + entries[a] = dates.copy() + + if "start" in dates: + start = dates["start"] + else: + start = max(dates["values"]) + + if "end" in dates or "stop" in dates: + end = dates.get("end", dates.get("stop")) + else: + end = min(dates["values"]) + + if dates_block is None: + dates_block = { + "start": start, + "end": end, + } + + if "frequency" in dates: + if "frequency" not in dates_block: + dates_block["frequency"] = dates["frequency"] + else: + assert dates_block["frequency"] == dates["frequency"], (dates_block["frequency"], dates["frequency"]) + + dates_block["start"] = min(dates_block["start"], start) + dates_block["end"] = max(dates_block["end"], end) + + concat = [] + result["input"] = {"concat": concat} + + print("Found loops:", entries) + + for block in input: + assert isinstance(block, dict), block + assert len(block) == 1, block + name, values = list(block.items())[0] + assert name in entries, f"Loop {name} not found in loops: {list(entries.keys())}" + dates = entries[name].copy() + + assert "kwargs" not in values + + concat.append(dict(dates=dates, **values)) + + d = concat[0]["dates"] + if all(c["dates"] == d for c in concat): + join = [] + for c in concat: + del c["dates"] + join.append(c) + result["input"] = {"join": join} + + del config["loops"] + config["input"] = result["input"].copy() + config["dates"] = dates_block.copy() + del result["loops"] + result["dates"] = dates_block + + +def _fix_other(result: dict, config: dict) -> None: + paths = find_paths(config, target_key="source_or_dataset", target_value="$previous_data") + for p in paths: + print(f"Fixing 
{'.'.join(p)}") + assign(result, ".".join(p[:-1] + ["template"]), "${input.join.0.mars}", missing=dict) + delete(result, ".".join(p)) + + paths = find_paths(config, target_key="date", target_value="$dates") + for p in paths: + delete(result, ".".join(p)) + + +def _fix_join(result: dict, config: dict) -> None: + print("Fixing join...") + input = config["input"] + if "dates" in input and "join" in input["dates"]: + result["input"]["join"] = input["dates"]["join"] + config["input"]["join"] = input["dates"]["join"].copy() + + if "join" not in input: + return + + join = input["join"] + new_join = [] + for j in join: + assert isinstance(j, dict) + assert len(j) == 1 + + key, values = list(j.items())[0] + + if key not in ("label", "source"): + return + + assert isinstance(values, dict), f"Join values for {key} should be a dict: {values}" + if key == "label": + j = values + j.pop("name") + key, values = list(j.items())[0] + + print(values) + source_name = values.pop("name", "mars") + new_join.append( + { + SOURCES.get(source_name, source_name): values, + } + ) + + result["input"] = {"join": new_join} + config["input"] = result["input"].copy() + + +def _fix_sources(config: dict, what) -> None: + + input = config["input"] + if what not in input: + return + + join = input[what] + new_join = [] + for j in join: + assert isinstance(j, dict) + assert len(j) == 1, j + + key, values = list(j.items())[0] + + key = SOURCES.get(key, key) + + new_join.append( + { + key: values, + } + ) + + config["input"][what] = new_join + config["input"][what] = new_join.copy() + + +def _assign(config, path, value): + print(f"Assign {path} {value}") + assign(config, path, value) + + +def _fix_chevrons(result: dict, config: dict) -> None: + print("Fixing chevrons...") + paths = find_chevrons(config) + for p in paths: + a = glom(config, ".".join(p)) + b = glom(config, ".".join(p[:-1])) + delete(result, ".".join(p)) + a.update(b) + assign(result, ".".join(p[:-1]), a) + + +def _fix_some(config: dict) -> None: + + paths = find_paths_in_substrees("label.function", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + assert node + _assign(config, ".".join(p[:-2]), node) + + paths = find_paths_in_substrees("constants.source_or_dataset", config) + for p in paths: + node = glom(config, ".".join(p[:-1])) + node["template"] = node.pop("source_or_dataset") + if node["template"] == "$previous_data": + node["template"] = "${input.join.0.mars}" + paths = find_paths_in_substrees("constants.template", config) + for p in paths: + node = glom(config, ".".join(p[:-1])) + if node["template"] == "$pl_data": + node["template"] = "${input.join.0.mars}" + for d in ("date", "dates", "time"): + paths = find_paths_in_substrees(d, config) + for p in paths: + if len(p) > 1: + node = glom(config, ".".join(p[:-1])) + if isinstance(node, dict) and isinstance(node[d], str) and node[d].startswith("$"): + del node[d] + + paths = find_paths_in_substrees("source.<<", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + node.update(node.pop("<<")) + parent[node.pop("name")] = node + assert len(parent) == 2, parent + del parent["source"] + + paths = find_paths_in_substrees("label.mars", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + assert node + assign(config, ".".join(p[:-2]), node) + + paths = find_paths_in_substrees("input.dates.join", config) + for p in paths: + node = glom(config, 
".".join(p)) + config["input"]["join"] = node + del config["input"]["dates"] + + paths = find_paths_in_substrees("source.name", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + name = node.pop("name") + assign(config, ".".join(p[:-2]), {name: node}) + + paths = find_paths_in_substrees("function.name", config) + for p in paths: + parent = glom(config, ".".join(p[:-2])) + node = glom(config, ".".join(p[:-1])) + name = node.pop("name") + assert node + assign(config, ".".join(p[:-2]), {name: node}) + + +def _migrate(config: dict, n) -> dict: + + result = config.copy() + + _fix_input_0(result) + # _fix_loops(result, config) + # _fix_input_1(result, config) + # _fix_join(result, config) + # _fix_chevrons(result, config) + # _fix_other(result, config) + + for k, v in MIGRATE.items(): + _move(config, k, v, result) + + _fix_some(result) + _fix_sources(result, "join") + + for k in DELETE: + _delete(result, k) + + remove_empties(result) + + return result + + +def migrate(old: dict) -> dict: + + for i in range(10): + new = _migrate(old, i) + if new == old: + return new + old = new + + return new + + +def has_key(config, key: str) -> bool: + if isinstance(config, dict): + if key in config: + return True + for k, v in config.items(): + if has_key(v, key): + return True + if isinstance(config, list): + for item in config: + if has_key(item, key): + return True + return False + + +def has_value(config, value: str) -> bool: + if isinstance(config, dict): + for k, v in config.items(): + if v == value: + return True + if has_value(v, value): + return True + + if isinstance(config, list): + for item in config: + if item == value: + return True + if has_value(item, value): + return True + return config == value + + +def check(config): + + try: + + validate_config(config) + assert config.get("input", {}) + assert config.get("dates", {}) + assert not has_key(config, "label") + assert not has_key(config, "kwargs") + assert not has_value(config, "$previous_data") + assert not has_value(config, "$pl_data") + assert not has_value(config, "$dates") + assert not has_key(config, "inherit") + assert not has_key(config, "source_or_dataset") + assert not has_key(config, "<<") + + for n in SOURCES.keys(): + assert not has_key(config, n), f"Source {n} found in config. Please update to {SOURCES[n]}." + + except Exception as e: + print("Validation failed:") + print(e) + print(yaml_dump(config)) + sys.exit(1) + + +def migrate_recipe(args: Any, config) -> None: + + print(f"Migrating {args.path}") + + migrated = migrate(config) + + check(migrated) + if migrated == config: + return None + + return migrated diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py index 6af8ee996..03691541c 100644 --- a/src/anemoi/datasets/commands/validate.py +++ b/src/anemoi/datasets/commands/validate.py @@ -10,9 +10,8 @@ import logging from typing import Any -from anemoi.datasets.misc.validate import validate_dataset - -from .. import Command +from anemoi.datasets.commands import Command +from anemoi.datasets.validate import validate_dataset LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 9fc775e54..e69de29bb 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -1,8 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. 
-# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. diff --git a/src/anemoi/datasets/create/check.py b/src/anemoi/datasets/create/check.py new file mode 100644 index 000000000..3c09cc80b --- /dev/null +++ b/src/anemoi/datasets/create/check.py @@ -0,0 +1,328 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import logging +import re +import warnings +from collections.abc import Callable +from typing import Any + +import numpy as np +from anemoi.utils.config import load_config +from anemoi.utils.dates import frequency_to_string +from numpy.typing import NDArray + +LOG = logging.getLogger(__name__) + + +class DatasetName: + """Validate and parse dataset names according to naming conventions.""" + + def __init__( + self, + name: str, + resolution: str | None = None, + start_date: datetime.date | None = None, + end_date: datetime.date | None = None, + frequency: datetime.timedelta | None = None, + ): + """Initialize a DatasetName instance. + + Parameters + ---------- + name : str + The name of the dataset. + resolution : Optional[str], optional + The resolution of the dataset. + start_date : Optional[datetime.date], optional + The start date of the dataset. + end_date : Optional[datetime.date], optional + The end date of the dataset. + frequency : Optional[datetime.timedelta], optional + The frequency of the dataset. + """ + self.name = name + self.parsed = self._parse(name) + print("---------------") + print(self.parsed) + print("---------------") + + self.messages = [] + + config = load_config().get("datasets", {}) + + if config.get("ignore_naming_conventions", False): + # setting the env variable ANEMOI_CONFIG_DATASETS_IGNORE_NAMING_CONVENTIONS=1 + # will ignore the naming conventions + return + + self.check_characters() + self.check_parsed() + self.check_resolution(resolution) + self.check_frequency(frequency) + self.check_start_date(start_date) + self.check_end_date(end_date) + + if self.messages: + self.messages.append(f"{self} is parsed as :" + "/".join(f"{k}={v}" for k, v in self.parsed.items())) + + @property + def error_message(self) -> str: + """Generate an error message based on the collected messages.""" + out = " And ".join(self.messages) + if out: + out[0].upper() + out[1:] + return out + + def raise_if_not_valid(self, print: Callable = print) -> None: + """Raise a ValueError if the dataset name is not valid. + + Parameters + ---------- + print : Callable + The function to use for printing messages. + """ + if self.messages: + for m in self.messages: + print(m) + raise ValueError(self.error_message) + + def _parse(self, name: str) -> dict: + """Parse the dataset name into its components. + + Parameters + ---------- + name : str + The name of the dataset. + + Returns + ------- + dict + The parsed components of the dataset name. 
+ """ + pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$" + match = re.match(pattern, name) + + if not match: + raise ValueError(f"the dataset name '{name}' does not follow naming convention. Does not match {pattern}") + + parsed = {} + if match: + keys = [ + "purpose", + "labelling", + "source", + "resolution", + "start_date", + "end_date", + "frequency", + "version", + "additional", + ] + parsed = {k: v for k, v in zip(keys, match.groups())} + + return parsed + + def __str__(self) -> str: + """Return the string representation of the dataset name.""" + return self.name + + def check_parsed(self) -> None: + """Check if the dataset name was parsed correctly.""" + if not self.parsed: + self.messages.append( + f"the dataset name {self} does not follow naming convention. " + "See here for details: " + "https://anemoi-registry.readthedocs.io/en/latest/naming-conventions.html" + ) + + def check_resolution(self, resolution: str | None) -> None: + """Check if the resolution matches the expected format. + + Parameters + ---------- + resolution : str or None + The expected resolution. + """ + if self.parsed.get("resolution") and self.parsed["resolution"][0] not in "0123456789on": + self.messages.append( + f"the resolution {self.parsed['resolution'] } should start " + f"with a number or 'o' or 'n' in the dataset name {self}." + ) + + if resolution is None: + return + resolution_str = str(resolution).replace(".", "p").lower() + self._check_missing("resolution", resolution_str) + self._check_mismatch("resolution", resolution_str) + + def check_characters(self) -> None: + if not self.name.islower(): + self.messages.append(f"the {self.name} should be in lower case.") + if "_" in self.name: + self.messages.append(f"the {self.name} should use '-' instead of '_'.") + for c in self.name: + if not c.isalnum() and c not in "-": + self.messages.append(f"the {self.name} should only contain alphanumeric characters and '-'.") + + def check_frequency(self, frequency: datetime.timedelta | None) -> None: + """Check if the frequency matches the expected format. + + Parameters + ---------- + frequency : datetime.timedelta or None + The expected frequency. + """ + if frequency is None: + return + frequency_str = frequency_to_string(frequency) + self._check_missing("frequency", frequency_str) + self._check_mismatch("frequency", frequency_str) + + def check_start_date(self, start_date: datetime.date | None) -> None: + """Check if the start date matches the expected format. + + Parameters + ---------- + start_date : datetime.date or None + The expected start date. + """ + if start_date is None: + return + start_date_str = str(start_date.year) + self._check_missing("start_date", start_date_str) + self._check_mismatch("start_date", start_date_str) + + def check_end_date(self, end_date: datetime.date | None) -> None: + """Check if the end date matches the expected format. + + Parameters + ---------- + end_date : datetime.date or None + The expected end date. + """ + if end_date is None: + return + end_date_str = str(end_date.year) + self._check_missing("end_date", end_date_str) + self._check_mismatch("end_date", end_date_str) + + def _check_missing(self, key: str, value: str) -> None: + """Check if a component is missing from the dataset name. + + Parameters + ---------- + key : str + The component key. + value : str + The expected value. 
+ """ + if value not in self.name: + self.messages.append(f"the {key} is {value}, but is missing in {self.name}.") + + def _check_mismatch(self, key: str, value: str) -> None: + """Check if a component value mismatches the expected value. + + Parameters + ---------- + key : str + The component key. + value : str + The expected value. + """ + if self.parsed.get(key) and self.parsed[key] != value: + self.messages.append(f"the {key} is {value}, but is {self.parsed[key]} in {self.name}.") + + +class StatisticsValueError(ValueError): + """Custom error for statistics value issues.""" + + pass + + +def check_data_values( + arr: NDArray[Any], *, name: str, log: list = [], allow_nans: bool | list | set | tuple | dict = False +) -> None: + """Check the values in the data array for validity. + + Parameters + ---------- + arr : NDArray[Any] + The data array to check. + name : str + The name of the data array. + log : list, optional + A list to log messages. + allow_nans : bool or list or set or tuple or dict, optional + Whether to allow NaNs in the data array. + """ + shape = arr.shape + + if (isinstance(allow_nans, (set, list, tuple, dict)) and name in allow_nans) or allow_nans: + arr = arr[~np.isnan(arr)] + + if arr.size == 0: + warnings.warn(f"Empty array for {name} ({shape})") + return + + assert arr.size > 0, (name, *log) + + min, max = arr.min(), arr.max() + assert not (np.isnan(arr).any()), (name, min, max, *log) + + if min == 9999.0: + warnings.warn(f"Min value 9999 for {name}") + + if max == 9999.0: + warnings.warn(f"Max value 9999 for {name}") + + in_minus_1_plus_1 = dict(minimum=-1, maximum=1) + limits = { + "cos_latitude": in_minus_1_plus_1, + "sin_latitude": in_minus_1_plus_1, + "cos_longitude": in_minus_1_plus_1, + "sin_longitude": in_minus_1_plus_1, + } + + if name in limits: + if min < limits[name]["minimum"]: + warnings.warn( + f"For {name}: minimum value in the data is {min}. " + "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" + ) + if max > limits[name]["maximum"]: + warnings.warn( + f"For {name}: maximum value in the data is {max}. " + "Not in acceptable range [{limits[name]['minimum']} ; {limits[name]['maximum']}]" + ) + + +def check_stats(minimum: float, maximum: float, mean: float, msg: str, **kwargs: Any) -> None: + """Check if the mean value is within the min/max interval. + + Parameters + ---------- + minimum : float + The minimum value. + maximum : float + The maximum value. + mean : float + The mean value. + msg : str + The message to include in the error. + **kwargs : Any + Additional keyword arguments. + """ + tolerance = (abs(minimum) + abs(maximum)) * 0.01 + if (mean - minimum < -tolerance) or (mean - minimum < -tolerance): + raise StatisticsValueError( + f"Mean is not in min/max interval{msg} : we should have {minimum} <= {mean} <= {maximum}" + ) diff --git a/src/anemoi/datasets/create/chunks.py b/src/anemoi/datasets/create/chunks.py new file mode 100644 index 000000000..08cc1edfd --- /dev/null +++ b/src/anemoi/datasets/create/chunks.py @@ -0,0 +1,138 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +import warnings + +LOG = logging.getLogger(__name__) + +ALL = object() + + +class ChunkFilter: + """A filter to determine which chunks to process based on the specified parts. + + Attributes + ---------- + total : int + The total number of chunks. + allowed : object or list + The chunks that are allowed to be processed. + """ + + def __init__(self, *, parts: str | list, total: int): + """Initializes the ChunkFilter with the given parts and total number of chunks. + + Parameters + ---------- + parts : str or list + The parts to process, specified as 'i/n' or a list of such strings. + total : int + The total number of chunks. + + Raises + ------ + ValueError + If the parts format is invalid. + AssertionError + If the chunk number is invalid. + Warning + If the number of chunks is larger than the total number of chunks. + """ + self.total = total + + if isinstance(parts, list): + if len(parts) == 1: + parts = parts[0] + elif len(parts) == 0: + parts = None + else: + raise ValueError(f"Invalid parts format: {parts}. Must be in the form 'i/n'.") + + if not parts: + parts = "all" + + assert isinstance(parts, str), f"Argument parts must be a string, got {parts}." + + if parts.lower() == "all" or parts == "*": + self.allowed = ALL + return + + assert "/" in parts, f"Invalid parts format: {parts}. Must be in the form 'i/n'." + + i, n = parts.split("/") + i, n = int(i), int(n) + + assert i > 0, f"Chunk number {i} must be positive." + assert i <= n, f"Chunk number {i} must be less than total chunks {n}." + if n > total: + warnings.warn( + f"Number of chunks {n} is larger than the total number of chunks: {total}. " + "Some chunks will be empty." + ) + + chunk_size = total / n + parts = [x for x in range(total) if x >= (i - 1) * chunk_size and x < i * chunk_size] + + for i in parts: + if i < 0 or i >= total: + raise AssertionError(f"Invalid chunk number {i}. Must be between 0 and {total - 1}.") + if not parts: + warnings.warn(f"Nothing to do for chunk {i}/{n}.") + + LOG.debug(f"Running parts: {parts}") + + self.allowed = parts + + def __call__(self, i: int) -> bool: + """Checks if the given chunk number is allowed to be processed. + + Parameters + ---------- + i : int + The chunk number to check. + + Returns + ------- + bool + True if the chunk is allowed, False otherwise. + + Raises + ------ + AssertionError + If the chunk number is invalid. + """ + if i < 0 or i >= self.total: + raise AssertionError(f"Invalid chunk number {i}. Must be between 0 and {self.total - 1}.") + + if self.allowed == ALL: + return True + return i in self.allowed + + def __iter__(self) -> iter: + """Iterates over the allowed chunks. + + Yields + ------ + int + The next allowed chunk number. + """ + for i in range(self.total): + if self(i): + yield i + + def __len__(self) -> int: + """Returns the number of allowed chunks. + + Returns + ------- + int + The number of allowed chunks. + """ + return len([_ for _ in self]) diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py new file mode 100644 index 000000000..2e5f27de7 --- /dev/null +++ b/src/anemoi/datasets/create/config.py @@ -0,0 +1,453 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
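A short sketch of how the ChunkFilter added above splits work from an "i/n" parts string; the totals are invented, only the import path follows from the patch.

    from anemoi.datasets.create.chunks import ChunkFilter

    cf = ChunkFilter(parts="2/4", total=10)  # this worker handles the second quarter of 10 chunks
    list(cf)     # -> [3, 4] with this split
    cf(3), cf(9) # -> (True, False)
    len(cf)      # -> 2 chunks to process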
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import os +import subprocess +import sys +from copy import deepcopy +from typing import Any + +import yaml +from anemoi.utils.config import DotDict +from anemoi.utils.config import load_any_dict_format +from earthkit.data.core.order import normalize_order_by + +from anemoi.datasets.dates.groups import Groups + +LOG = logging.getLogger(__name__) + + +def _get_first_key_if_dict(x: str | dict) -> str: + """Returns the first key if the input is a dictionary, otherwise returns the input string. + + Parameters + ---------- + x : str or dict + Input string or dictionary. + + Returns + ------- + str + The first key if input is a dictionary, otherwise the input string. + """ + if isinstance(x, str): + return x + return list(x.keys())[0] + + +def ensure_element_in_list(lst: list, elt: str, index: int) -> list: + """Ensures that a specified element is present at a given index in a list. + + Parameters + ---------- + lst : list + The list to check. + elt : str + The element to ensure is in the list. + index : int + The index at which the element should be present. + + Returns + ------- + list + The modified list with the element at the specified index. + """ + if elt in lst: + assert lst[index] == elt + return lst + + _lst = [_get_first_key_if_dict(d) for d in lst] + if elt in _lst: + assert _lst[index] == elt + return lst + + return lst[:index] + [elt] + lst[index:] + + +def check_dict_value_and_set(dic: dict, key: str, value: Any) -> None: + """Checks if a dictionary contains a specific key-value pair and sets it if not present. + + Parameters + ---------- + dic : dict + The dictionary to check. + key : str + The key to check in the dictionary. + value : Any + The value to set if the key is not present. + + Raises + ------ + ValueError + If the key is present but with a different value. + """ + if key in dic: + if dic[key] == value: + return + raise ValueError(f"Cannot use {key}={dic[key]}. Must use {value}.") + # LOG.info(f"Setting {key}={value} in config") + dic[key] = value + + +def resolve_includes(config: dict | list) -> dict | list: + """Resolves '<<' includes in a configuration dictionary or list. + + Parameters + ---------- + config : dict or list + The configuration to resolve includes for. + + Returns + ------- + dict or list + The configuration with includes resolved. + """ + if isinstance(config, list): + return [resolve_includes(c) for c in config] + if isinstance(config, dict): + include = config.pop("<<", {}) + new = deepcopy(include) + new.update(config) + return {k: resolve_includes(v) for k, v in new.items()} + return config + + +class Config(DotDict): + """Configuration class that extends DotDict to handle configuration loading and processing.""" + + def __init__(self, config: str | dict | None = None, **kwargs): + """Initializes the Config object. + + Parameters + ---------- + config : str or dict, optional + Path to the configuration file or a dictionary. Defaults to None. + **kwargs + Additional keyword arguments to update the configuration. 
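As an illustration of the resolve_includes helper above (the keys here are invented): the mapping under the YAML-style "<<" merge key is applied first, and the node's own keys win.

    node = {"<<": {"class": "od", "grid": "0.1/0.1"}, "grid": "0.25/0.25"}
    resolve_includes(node)
    # -> {"class": "od", "grid": "0.25/0.25"}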
+ """ + if isinstance(config, str): + self.config_path = os.path.realpath(config) + config = load_any_dict_format(config) + else: + config = deepcopy(config if config is not None else {}) + config = resolve_includes(config) + config.update(kwargs) + super().__init__(config) + + +class OutputSpecs: + """Class to handle output specifications for datasets.""" + + def __init__(self, config: Config, parent: Any): + """Initializes the OutputSpecs object. + + Parameters + ---------- + config : Config + The configuration object. + parent : Any + The parent object. + """ + self.config = config + if "order_by" in config: + assert isinstance(config.order_by, dict), config.order_by + + self.parent = parent + + @property + def dtype(self) -> str: + """Returns the data type for the output.""" + return self.config.dtype + + @property + def order_by_as_list(self) -> list[dict]: + """Returns the order_by configuration as a list of dictionaries.""" + return [{k: v} for k, v in self.config.order_by.items()] + + def get_chunking(self, coords: dict) -> tuple: + """Returns the chunking configuration based on coordinates. + + Parameters + ---------- + coords : dict + The coordinates dictionary. + + Returns + ------- + tuple + The chunking configuration. + """ + user = deepcopy(self.config.chunking) + chunks = [] + for k, v in coords.items(): + if k in user: + chunks.append(user.pop(k)) + else: + chunks.append(len(v)) + if user: + raise ValueError( + f"Unused chunking keys from config: {list(user.keys())}, not in known keys : {list(coords.keys())}" + ) + return tuple(chunks) + + @property + def order_by(self) -> dict: + """Returns the order_by configuration.""" + return self.config.order_by + + @property + def remapping(self) -> dict: + """Returns the remapping configuration.""" + return self.config.remapping + + @property + def flatten_grid(self) -> bool: + """Returns whether the grid should be flattened.""" + return self.config.flatten_grid + + @property + def statistics(self) -> str: + """Returns the statistics configuration.""" + return self.config.statistics + + +class LoadersConfig(Config): + """Configuration class for dataset loaders.""" + + def __init__(self, config: dict, *args, **kwargs): + """Initializes the LoadersConfig object. + + Parameters + ---------- + config : dict + The configuration dictionary. + *args + Additional positional arguments. + **kwargs + Additional keyword arguments. 
+ """ + super().__init__(config, *args, **kwargs) + + # TODO: should use a json schema to validate the config + + self.setdefault("dataset_status", "experimental") + self.setdefault("description", "No description provided.") + self.setdefault("licence", "unknown") + self.setdefault("attribution", "unknown") + + self.setdefault("build", Config()) + self.build.setdefault("group_by", "monthly") + self.build.setdefault("use_grib_paramid", False) + self.build.setdefault("variable_naming", "default") + variable_naming = dict( + param="{param}", + param_levelist="{param}_{levelist}", + default="{param}_{levelist}", + ).get(self.build.variable_naming, self.build.variable_naming) + + self.setdefault("output", Config()) + self.output.setdefault("order_by", ["valid_datetime", "param_level", "number"]) + self.output.setdefault("remapping", Config(param_level=variable_naming)) + self.output.setdefault("statistics", "param_level") + self.output.setdefault("chunking", Config(dates=1, ensembles=1)) + self.output.setdefault("dtype", "float32") + + if "statistics_start" in self.output: + raise ValueError("statistics_start is not supported anymore. Use 'statistics:start:' instead") + if "statistics_end" in self.output: + raise ValueError("statistics_end is not supported anymore. Use 'statistics:end:' instead") + + self.setdefault("statistics", Config()) + if "allow_nans" not in self.statistics: + self.statistics.allow_nans = [] + + check_dict_value_and_set(self.output, "flatten_grid", True) + check_dict_value_and_set(self.output, "ensemble_dimension", 2) + + assert isinstance(self.output.order_by, (list, tuple)), self.output.order_by + self.output.order_by = ensure_element_in_list(self.output.order_by, "number", self.output.ensemble_dimension) + + order_by = self.output.order_by + assert len(order_by) == 3, order_by + assert _get_first_key_if_dict(order_by[0]) == "valid_datetime", order_by + assert _get_first_key_if_dict(order_by[2]) == "number", order_by + + self.output.order_by = normalize_order_by(self.output.order_by) + + self.setdefault("dates", Config()) + + self.dates["group_by"] = self.build.group_by + + ########### + + self.reading_chunks = self.get("reading_chunks") + + def get_serialisable_dict(self) -> dict: + """Returns a serializable dictionary representation of the configuration. + + Returns + ------- + dict + The serializable dictionary. + """ + return _prepare_serialisation(self) + + +def _prepare_serialisation(o: Any) -> Any: + """Prepares an object for serialization. + + Parameters + ---------- + o : Any + The object to prepare. + + Returns + ------- + Any + The prepared object. + """ + if isinstance(o, dict): + dic = {} + for k, v in o.items(): + v = _prepare_serialisation(v) + if k == "order_by" and isinstance(v, dict): + # zarr attributes are saved with sort_keys=True + # and ordered dict are reordered. + # This is a problem for "order_by" + # We ensure here that the order_by key contains + # a list of dict + v = [{kk: vv} for kk, vv in v.items()] + dic[k] = v + return dic + + if isinstance(o, (list, tuple)): + return [_prepare_serialisation(v) for v in o] + + if o in (None, True, False): + return o + + if isinstance(o, (str, int, float)): + return o + + if isinstance(o, (datetime.date, datetime.datetime)): + return o.isoformat() + + return str(o) + + +def set_to_test_mode(cfg: dict) -> None: + """Modifies the configuration to run in test mode. + + Parameters + ---------- + cfg : dict + The configuration dictionary. + """ + NUMBER_OF_DATES = 4 + + LOG.warning(f"Running in test mode. 
Changing the list of dates to use only {NUMBER_OF_DATES}.") + groups = Groups(**LoadersConfig(cfg).dates) + + dates = groups.provider.values + cfg["dates"] = dict( + start=dates[0], + end=dates[NUMBER_OF_DATES - 1], + frequency=groups.provider.frequency, + group_by=NUMBER_OF_DATES, + ) + + def set_element_to_test(obj): + if isinstance(obj, (list, tuple)): + for v in obj: + set_element_to_test(v) + return + if isinstance(obj, (dict, DotDict)): + if "grid" in obj: + previous = obj["grid"] + obj["grid"] = "20./20." + LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}") + if "number" in obj: + if isinstance(obj["number"], (list, tuple)): + previous = obj["number"] + obj["number"] = previous[0:3] + LOG.warning(f"Running in test mode. Setting number to {obj['number']} instead of {previous}") + for k, v in obj.items(): + set_element_to_test(v) + if "constants" in obj: + constants = obj["constants"] + if "param" in constants and isinstance(constants["param"], list): + constants["param"] = ["cos_latitude"] + + set_element_to_test(cfg) + + +def loader_config(config: dict, is_test: bool = False) -> LoadersConfig: + """Loads and validates the configuration for dataset loaders. + + Parameters + ---------- + config : dict + The configuration dictionary. + is_test : bool, optional + Whether to run in test mode. Defaults to False. + + Returns + ------- + LoadersConfig + The validated configuration object. + """ + + if isinstance(config, str) and config.endswith(".py"): + result = subprocess.run([sys.executable, config], capture_output=True, text=True, check=True) + config = yaml.safe_load(result.stdout) + + config = Config(config) + if is_test: + set_to_test_mode(config) + obj = LoadersConfig(config) + + # yaml round trip to check that serialisation works as expected + copy = obj.get_serialisable_dict() + copy = yaml.load(yaml.dump(copy), Loader=yaml.SafeLoader) + copy = Config(copy) + copy = LoadersConfig(config) + + a = yaml.dump(obj) + b = yaml.dump(copy) + if a != b: + print(a) + print(b) + raise ValueError("Serialisation failed") + + if "env" in copy: + for k, v in copy["env"].items(): + LOG.info(f"Setting env variable {k}={v}") + os.environ[k] = str(v) + + return copy + + +def build_output(*args, **kwargs) -> OutputSpecs: + """Builds the output specifications. + + Parameters + ---------- + *args + Additional positional arguments. + **kwargs + Additional keyword arguments. + + Returns + ------- + OutputSpecs + The output specifications object. + """ + return OutputSpecs(*args, **kwargs) diff --git a/src/anemoi/datasets/create/filter.py b/src/anemoi/datasets/create/filter.py new file mode 100644 index 000000000..4544db8f2 --- /dev/null +++ b/src/anemoi/datasets/create/filter.py @@ -0,0 +1,47 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from typing import Any + +import earthkit.data as ekd + + +class TransformFilter: + """Calls filters from anemoi.transform.filters + + Parameters + ---------- + context : Any + The context in which the filter is created. + name : str + The name of the filter. + config : Dict[str, Any] + The configuration for the filter. 
+ """ + + def __init__(self, context: Any, name: str, config: dict[str, Any]) -> None: + from anemoi.transform.filters import create_filter + + self.name = name + self.transform_filter = create_filter(context, config) + + def execute(self, input: ekd.FieldList) -> ekd.FieldList: + """Execute the transformation filter. + + Parameters + ---------- + input : ekd.FieldList + The input data to be transformed. + + Returns + ------- + ekd.FieldList + The transformed data. + """ + return self.transform_filter.forward(input) diff --git a/src/anemoi/datasets/create/gridded/__init__.py b/src/anemoi/datasets/create/gridded/__init__.py index 377852420..e69de29bb 100644 --- a/src/anemoi/datasets/create/gridded/__init__.py +++ b/src/anemoi/datasets/create/gridded/__init__.py @@ -1,1658 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import json -import logging -import os -import time -import uuid -import warnings -from functools import cached_property -from typing import Any - -import cftime -import numpy as np -import tqdm -import zarr -from anemoi.utils.dates import as_datetime -from anemoi.utils.dates import frequency_to_string -from anemoi.utils.dates import frequency_to_timedelta -from anemoi.utils.humanize import compress_dates -from anemoi.utils.humanize import seconds_to_human -from anemoi.utils.sanitise import sanitise -from earthkit.data.core.order import build_remapping - -from anemoi.datasets import MissingDateError -from anemoi.datasets import open_dataset -from anemoi.datasets.create.gridded.check import DatasetName -from anemoi.datasets.create.gridded.check import check_data_values -from anemoi.datasets.create.gridded.chunks import ChunkFilter -from anemoi.datasets.create.gridded.config import build_output -from anemoi.datasets.create.gridded.config import loader_config -from anemoi.datasets.create.gridded.persistent import build_storage -from anemoi.datasets.create.gridded.statistics import Summary -from anemoi.datasets.create.gridded.statistics import TmpStatistics -from anemoi.datasets.create.gridded.statistics import check_variance -from anemoi.datasets.create.gridded.statistics import compute_statistics -from anemoi.datasets.create.gridded.statistics import default_statistics_dates -from anemoi.datasets.create.gridded.statistics import fix_variance -from anemoi.datasets.create.gridded.utils import normalize_and_check_dates -from anemoi.datasets.create.gridded.writer import ViewCacheArray -from anemoi.datasets.create.input import InputBuilder -from anemoi.datasets.create.input.trace import enable_trace -from anemoi.datasets.dates.groups import Groups -from anemoi.datasets.use.gridded.misc import as_first_date -from anemoi.datasets.use.gridded.misc import as_last_date - -LOG = logging.getLogger(__name__) - -VERSION = "0.30" - - -def json_tidy(o: Any) -> Any: - """Convert various types to JSON serializable format. - - Parameters - ---------- - o : Any - The object to convert. - - Returns - ------- - Any - The JSON serializable object. 
- """ - if isinstance(o, datetime.datetime): - return o.isoformat() - - if isinstance(o, datetime.datetime): - return o.isoformat() - - if isinstance(o, datetime.timedelta): - return frequency_to_string(o) - - if isinstance(o, cftime.DatetimeJulian): - import pandas as pd - - o = pd.Timestamp( - o.year, - o.month, - o.day, - o.hour, - o.minute, - o.second, - ) - return o.isoformat() - - if isinstance(o, (np.float32, np.float64)): - return float(o) - - raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}") - - -def build_statistics_dates( - dates: list[datetime.datetime], - start: datetime.datetime | None, - end: datetime.datetime | None, -) -> tuple[str, str]: - """Compute the start and end dates for the statistics. - - Parameters - ---------- - dates : list of datetime.datetime - The list of dates. - start : Optional[datetime.datetime] - The start date. - end : Optional[datetime.datetime] - The end date. - - Returns - ------- - tuple of str - The start and end dates in ISO format. - """ - # if not specified, use the default statistics dates - default_start, default_end = default_statistics_dates(dates) - if start is None: - start = default_start - if end is None: - end = default_end - - # in any case, adapt to the actual dates in the dataset - start = as_first_date(start, dates) - end = as_last_date(end, dates) - - # and convert to datetime to isoformat - start = start.astype(datetime.datetime) - end = end.astype(datetime.datetime) - return (start.isoformat(), end.isoformat()) - - -def _path_readable(path: str) -> bool: - """Check if the path is readable. - - Parameters - ---------- - path : str - The path to check. - - Returns - ------- - bool - True if the path is readable, False otherwise. - """ - import zarr - - try: - zarr.open(path, "r") - return True - except zarr.errors.PathNotFoundError: - return False - - -class Dataset: - """A class to represent a dataset.""" - - def __init__(self, path: str): - """Initialize a Dataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - self.path = path - - _, ext = os.path.splitext(self.path) - if ext != ".zarr": - raise ValueError(f"Unsupported extension={ext} for path={self.path}") - - def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: - """Add a dataset to the Zarr store. - - Parameters - ---------- - mode : str, optional - The mode to open the Zarr store. - **kwargs - Additional arguments for the dataset. - - Returns - ------- - zarr.Array - The added dataset. - """ - import zarr - - z = zarr.open(self.path, mode=mode) - from anemoi.datasets.create.gridded.zarr import add_zarr_dataset - - return add_zarr_dataset(zarr_root=z, **kwargs) - - def update_metadata(self, **kwargs: Any) -> None: - """Update the metadata of the dataset. - - Parameters - ---------- - **kwargs - The metadata to update. 
- """ - import zarr - - LOG.debug(f"Updating metadata {kwargs}") - z = zarr.open(self.path, mode="w+") - for k, v in kwargs.items(): - if isinstance(v, np.datetime64): - v = v.astype(datetime.datetime) - if isinstance(v, datetime.date): - v = v.isoformat() - z.attrs[k] = json.loads(json.dumps(v, default=json_tidy)) - - @cached_property - def anemoi_dataset(self) -> Any: - """Get the Anemoi dataset.""" - return open_dataset(self.path) - - @cached_property - def zarr_metadata(self) -> dict: - """Get the Zarr metadata.""" - import zarr - - return dict(zarr.open(self.path, mode="r").attrs) - - def print_info(self) -> None: - """Print information about the dataset.""" - import zarr - - z = zarr.open(self.path, mode="r") - try: - LOG.info(z["data"].info) - except Exception as e: - LOG.info(e) - - def get_zarr_chunks(self) -> tuple: - """Get the chunks of the Zarr dataset. - - Returns - ------- - tuple - The chunks of the Zarr dataset. - """ - import zarr - - z = zarr.open(self.path, mode="r") - return z["data"].chunks - - def check_name( - self, - resolution: str, - dates: list[datetime.datetime], - frequency: datetime.timedelta, - raise_exception: bool = True, - is_test: bool = False, - ) -> None: - """Check the name of the dataset. - - Parameters - ---------- - resolution : str - The resolution of the dataset. - dates : list of datetime.datetime - The dates of the dataset. - frequency : datetime.timedelta - The frequency of the dataset. - raise_exception : bool, optional - Whether to raise an exception if the name is invalid. - is_test : bool, optional - Whether this is a test. - """ - basename, _ = os.path.splitext(os.path.basename(self.path)) - try: - DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() - except Exception as e: - if raise_exception and not is_test: - raise e - else: - LOG.warning(f"Dataset name error: {e}") - - def get_main_config(self) -> Any: - """Get the main configuration of the dataset. - - Returns - ------- - Any - The main configuration. - """ - import zarr - - z = zarr.open(self.path, mode="r") - config = loader_config(z.attrs.get("_create_yaml_config")) - - if "env" in config: - for k, v in config["env"].items(): - LOG.info(f"Setting env variable {k}={v}") - os.environ[k] = str(v) - - return config - - -class WritableDataset(Dataset): - """A class to represent a writable dataset.""" - - def __init__(self, path: str): - """Initialize a WritableDataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - """ - super().__init__(path) - self.path = path - - import zarr - - self.z = zarr.open(self.path, mode="r+") - - @cached_property - def data_array(self) -> Any: - """Get the data array of the dataset.""" - import zarr - - return zarr.open(self.path, mode="r+")["data"] - - -class NewDataset(Dataset): - """A class to represent a new dataset.""" - - def __init__(self, path: str, overwrite: bool = False): - """Initialize a NewDataset instance. - - Parameters - ---------- - path : str - The path to the dataset. - overwrite : bool, optional - Whether to overwrite the existing dataset. - """ - super().__init__(path) - self.path = path - - import zarr - - self.z = zarr.open(self.path, mode="w") - self.z.create_group("_build") - - -class Actor: # TODO: rename to Creator - """A base class for dataset creation actors.""" - - dataset_class = WritableDataset - - def __init__(self, path: str, cache: str | None = None): - """Initialize an Actor instance. - - Parameters - ---------- - path : str - The path to the dataset. 
- cache : Optional[str], optional - The cache directory. - """ - # Catch all floating point errors, including overflow, sqrt(<0), etc - np.seterr(all="raise", under="warn") - - self.path = path - self.cache = cache - self.dataset = self.dataset_class(self.path) - - def run(self) -> None: - """Run the actor.""" - # to be implemented in the sub-classes - raise NotImplementedError() - - def update_metadata(self, **kwargs: Any) -> None: - """Update the metadata of the dataset. - - Parameters - ---------- - **kwargs - The metadata to update. - """ - self.dataset.update_metadata(**kwargs) - - def _cache_context(self) -> Any: - """Get the cache context. - - Returns - ------- - Any - The cache context. - """ - from anemoi.datasets.create.gridded.utils import cache_context - - return cache_context(self.cache) - - def check_unkown_kwargs(self, kwargs: dict) -> None: - """Check for unknown keyword arguments. - - Parameters - ---------- - kwargs : dict - The keyword arguments. - """ - # remove this latter - LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}") - - def read_dataset_metadata(self, path: str) -> None: - """Read the metadata of the dataset. - - Parameters - ---------- - path : str - The path to the dataset. - """ - ds = open_dataset(path) - self.dataset_shape = ds.shape - self.variables_names = ds.variables - assert len(self.variables_names) == ds.shape[1], self.dataset_shape - self.dates = ds.dates - - self.missing_dates = sorted(list([self.dates[i] for i in ds.missing])) - - def check_missing_dates(expected: list[np.datetime64]) -> None: - """Check if the missing dates in the dataset match the expected dates. - - Parameters - ---------- - expected : list of np.datetime64 - The expected missing dates. - - Raises - ------ - ValueError - If the missing dates in the dataset do not match the expected dates. - """ - import zarr - - z = zarr.open(path, "r") - missing_dates = z.attrs.get("missing_dates", []) - missing_dates = sorted([np.datetime64(d) for d in missing_dates]) - if missing_dates != expected: - LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.") - LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") - LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") - raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") - - check_missing_dates(self.missing_dates) - - -class Patch(Actor): - """A class to apply patches to a dataset.""" - - def __init__(self, path: str, options: dict = None, **kwargs: Any): - """Initialize a Patch instance. - - Parameters - ---------- - path : str - The path to the dataset. - options : dict, optional - The patch options. - """ - self.path = path - self.options = options or {} - - def run(self) -> None: - """Run the patch.""" - from anemoi.datasets.create.gridded.patch import apply_patch - - apply_patch(self.path, **self.options) - - -class Size(Actor): - """A class to compute the size of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Size instance. - - Parameters - ---------- - path : str - The path to the dataset. 
- """ - super().__init__(path) - - def run(self) -> None: - """Run the size computation.""" - from anemoi.datasets.create.gridded.size import compute_directory_sizes - - metadata = compute_directory_sizes(self.path) - self.update_metadata(**metadata) - - # Look for constant fields - ds = open_dataset(self.path) - constants = ds.computed_constant_fields() - - variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() - for k in constants: - variables_metadata[k]["constant_in_time"] = True - - self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) - - -class HasRegistryMixin: - """A mixin class to provide registry functionality.""" - - @cached_property - def registry(self) -> Any: - """Get the registry.""" - from anemoi.datasets.create.gridded.zarr import ZarrBuiltRegistry - - return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) - - -class HasStatisticTempMixin: - """A mixin class to provide temporary statistics functionality.""" - - @cached_property - def tmp_statistics(self) -> TmpStatistics: - """Get the temporary statistics.""" - directory = self.statistics_temp_dir or os.path.join(self.path + ".storage_for_statistics.tmp") - return TmpStatistics(directory) - - -class HasElementForDataMixin: - """A mixin class to provide element creation functionality for data.""" - - def create_elements(self, config: Any) -> None: - """Create elements for the dataset. - - Parameters - ---------- - config : Any - The configuration. - """ - assert self.registry - assert self.tmp_statistics - - LOG.info(dict(config.dates)) - - self.groups = Groups(**config.dates) - LOG.info(self.groups) - - self.output = build_output(config.output, parent=self) - - self.input = InputBuilder( - config.input, - data_sources=config.get("data_sources", {}), - order_by=self.output.order_by, - flatten_grid=self.output.flatten_grid, - remapping=build_remapping(self.output.remapping), - use_grib_paramid=config.build.use_grib_paramid, - ) - LOG.debug("✅ INPUT_BUILDER") - LOG.debug(self.input) - - -class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to initialize a new dataset.""" - - dataset_class = NewDataset - - def __init__( - self, - path: str, - config: dict, - check_name: bool = False, - overwrite: bool = False, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - test: bool = False, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize an Init instance. - - Parameters - ---------- - path : str - The path to the dataset. - config : dict - The configuration. - check_name : bool, optional - Whether to check the dataset name. - overwrite : bool, optional - Whether to overwrite the existing dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - test : bool, optional - Whether this is a test. - cache : Optional[str], optional - The cache directory. - """ - if _path_readable(path) and not overwrite: - raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") - - super().__init__(path, cache=cache) - self.config = config - self.check_name = check_name - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.test = test - - self.main_config = loader_config(config, is_test=test) - - # self.registry.delete() ?? 
- self.tmp_statistics.delete() - - assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by - self.create_elements(self.main_config) - - LOG.info(f"Groups: {self.groups}") - - one_date = self.groups.one_date() - # assert False, (type(one_date), type(self.groups)) - self.minimal_input = self.input.select(one_date) - LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") - LOG.info(self.minimal_input) - - def run(self) -> int: - """Run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - with self._cache_context(): - return self._run() - - def _run(self) -> int: - """Internal method to run the initialization. - - Returns - ------- - int - The number of groups to process. - """ - """Create an empty dataset of the right final shape. - - Read a small part of the data to get the shape of the data and the resolution and more metadata. - """ - - LOG.info("Config loaded ok:") - # LOG.info(self.main_config) - - dates = self.groups.provider.values - frequency = self.groups.provider.frequency - missing = self.groups.provider.missing - - assert isinstance(frequency, datetime.timedelta), frequency - - LOG.info(f"Found {len(dates)} datetimes.") - LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") - LOG.info(f"Missing dates: {len(missing)}") - lengths = tuple(len(g) for g in self.groups) - - variables = self.minimal_input.variables - LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") - - variables_with_nans = self.main_config.statistics.get("allow_nans", []) - - ensembles = self.minimal_input.ensembles - LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") - - grid_points = self.minimal_input.grid_points - LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") - - resolution = self.minimal_input.resolution - LOG.info(f"{resolution=}") - - coords = self.minimal_input.coords - coords["dates"] = dates - total_shape = self.minimal_input.shape - total_shape[0] = len(dates) - LOG.info(f"total_shape = {total_shape}") - - chunks = self.output.get_chunking(coords) - LOG.info(f"{chunks=}") - dtype = self.output.dtype - - LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") - - metadata = {} - metadata["uuid"] = str(uuid.uuid4()) - - metadata.update(self.main_config.get("add_metadata", {})) - - metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() - - recipe = sanitise(self.main_config.get_serialisable_dict()) - - # Remove stuff added by prepml - for k in [ - "build_dataset", - "config_format_version", - "config_path", - "dataset_status", - "ecflow", - "metadata", - "platform", - "reading_chunks", - "upload", - ]: - recipe.pop(k, None) - - metadata["recipe"] = recipe - - metadata["description"] = self.main_config.description - metadata["licence"] = self.main_config["licence"] - metadata["attribution"] = self.main_config["attribution"] - - metadata["remapping"] = self.output.remapping - metadata["order_by"] = self.output.order_by_as_list - metadata["flatten_grid"] = self.output.flatten_grid - - metadata["ensemble_dimension"] = len(ensembles) - metadata["variables"] = variables - metadata["variables_with_nans"] = variables_with_nans - metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) - metadata["resolution"] = resolution - - metadata["data_request"] = self.minimal_input.data_request - metadata["field_shape"] = self.minimal_input.field_shape - 
metadata["proj_string"] = self.minimal_input.proj_string - metadata["variables_metadata"] = self.minimal_input.variables_metadata - - metadata["start_date"] = dates[0].isoformat() - metadata["end_date"] = dates[-1].isoformat() - metadata["frequency"] = frequency - metadata["missing_dates"] = [_.isoformat() for _ in missing] - - metadata["version"] = VERSION - - self.dataset.check_name( - raise_exception=self.check_name, - is_test=self.test, - resolution=resolution, - dates=dates, - frequency=frequency, - ) - - if len(dates) != total_shape[0]: - raise ValueError( - f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " - f"does not match data shape {total_shape[0]}. {total_shape=}" - ) - - dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) - - metadata.update(self.main_config.get("force_metadata", {})) - - ############################################################### - # write metadata - ############################################################### - - self.update_metadata(**metadata) - - self.dataset.add_dataset( - name="data", - chunks=chunks, - dtype=dtype, - shape=total_shape, - dimensions=("time", "variable", "ensemble", "cell"), - ) - self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) - self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) - self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) - - self.registry.create(lengths=lengths) - self.tmp_statistics.create(exist_ok=False) - self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) - - statistics_start, statistics_end = build_statistics_dates( - dates, - self.main_config.statistics.get("start"), - self.main_config.statistics.get("end"), - ) - self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) - LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") - - self.registry.add_to_history("init finished") - - assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) - - # Return the number of groups to process, so we can show a nice progress bar - return len(lengths) - - -class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - """A class to load data into a dataset.""" - - def __init__( - self, - path: str, - parts: str | None = None, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - cache: str | None = None, - **kwargs: Any, - ): - """Initialize a Load instance. - - Parameters - ---------- - path : str - The path to the dataset. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - cache : Optional[str], optional - The cache directory. 
- """ - super().__init__(path, cache=cache) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.progress = progress - self.parts = parts - self.dataset = WritableDataset(self.path) - - self.main_config = self.dataset.get_main_config() - self.create_elements(self.main_config) - self.read_dataset_metadata(self.dataset.path) - - total = len(self.registry.get_flags()) - self.chunk_filter = ChunkFilter(parts=self.parts, total=total) - - self.data_array = self.dataset.data_array - self.n_groups = len(self.groups) - - def run(self) -> None: - """Run the data loading.""" - with self._cache_context(): - self._run() - - def _run(self) -> None: - """Internal method to run the data loading.""" - for igroup, group in enumerate(self.groups): - if not self.chunk_filter(igroup): - continue - if self.registry.get_flag(igroup): - LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") - continue - - # assert isinstance(group[0], datetime.datetime), type(group[0]) - LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - - result = self.input.select(argument=group) - assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) - - # There are several groups. - # There is one result to load for each group. - self.load_result(result) - self.registry.set_flag(igroup) - - self.registry.add_provenance(name="provenance_load") - self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) - - self.dataset.print_info() - - def load_result(self, result: Any) -> None: - """Load the result into the dataset. - - Parameters - ---------- - result : Any - The result to load. - """ - # There is one cube to load for each result. - dates = list(result.group_of_dates) - - LOG.debug(f"Loading cube for {len(dates)} dates") - - cube = result.get_cube() - shape = cube.extended_user_shape - dates_in_data = cube.user_coords["valid_datetime"] - - LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") - - def check_shape(cube, dates, dates_in_data): - if cube.extended_user_shape[0] != len(dates): - print( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - print("Requested dates", compress_dates(dates)) - print("Cube dates", compress_dates(dates_in_data)) - - a = {as_datetime(_) for _ in dates} - b = {as_datetime(_) for _ in dates_in_data} - - print("Missing dates", compress_dates(a - b)) - print("Extra dates", compress_dates(b - a)) - - raise ValueError( - f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" - ) - - check_shape(cube, dates, dates_in_data) - - def check_dates_in_data(dates_in_data, requested_dates): - _requested_dates = [np.datetime64(_) for _ in requested_dates] - _dates_in_data = [np.datetime64(_) for _ in dates_in_data] - if _dates_in_data != _requested_dates: - LOG.error("Dates in data are not the requested ones:") - - dates_in_data = set(dates_in_data) - requested_dates = set(requested_dates) - - missing = sorted(requested_dates - dates_in_data) - extra = sorted(dates_in_data - requested_dates) - - if missing: - LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") - if extra: - LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") - - raise ValueError("Dates in data are not the requested ones") - - check_dates_in_data(dates_in_data, dates) - - def dates_to_indexes(dates, all_dates): - x = np.array(dates, dtype=np.datetime64) - y = np.array(all_dates, 
dtype=np.datetime64) - bitmap = np.isin(x, y) - return np.where(bitmap)[0] - - indexes = dates_to_indexes(self.dates, dates_in_data) - - array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) - LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") - self.load_cube(cube, array) - - stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) - self.tmp_statistics.write(indexes, stats, dates=dates_in_data) - LOG.info("Flush data array") - array.flush() - LOG.info("Flushed data array") - - def _get_allow_nans(self) -> bool | list: - """Get the allow_nans configuration. - - Returns - ------- - bool | list - The allow_nans configuration. - """ - config = self.main_config - if "allow_nans" in config.build: - return config.build.allow_nans - - return config.statistics.get("allow_nans", []) - - def load_cube(self, cube: Any, array: ViewCacheArray) -> None: - """Load the cube into the array. - - Parameters - ---------- - cube : Any - The cube to load. - array : ViewCacheArray - The array to load into. - """ - # There are several cubelets for each cube - start = time.time() - load = 0 - save = 0 - - reading_chunks = None - total = cube.count(reading_chunks) - LOG.debug(f"Loading datacube: {cube}") - - def position(x: Any) -> int | None: - if isinstance(x, str) and "/" in x: - x = x.split("/") - return int(x[0]) - return None - - bar = tqdm.tqdm( - iterable=cube.iterate_cubelets(reading_chunks), - total=total, - desc=f"Loading datacube {cube}", - position=position(self.parts), - ) - for i, cubelet in enumerate(bar): - bar.set_description(f"Loading {i}/{total}") - - now = time.time() - data = cubelet.to_numpy() - local_indexes = cubelet.coords - load += time.time() - now - - name = self.variables_names[local_indexes[1]] - check_data_values( - data[:], - name=name, - log=[i, data.shape, local_indexes], - allow_nans=self._get_allow_nans(), - ) - - now = time.time() - array[local_indexes] = data - save += time.time() - now - - now = time.time() - save += time.time() - now - LOG.debug( - f"Elapsed: {seconds_to_human(time.time() - start)}, " - f"load time: {seconds_to_human(load)}, " - f"write time: {seconds_to_human(save)}." - ) - - -class Cleanup(Actor, HasRegistryMixin, HasStatisticTempMixin): - """A class to clean up temporary data and registry entries.""" - - def __init__( - self, - path: str, - statistics_temp_dir: str | None = None, - delta: list = [], - use_threads: bool = False, - **kwargs: Any, - ): - """Initialize a Cleanup instance. - - Parameters - ---------- - path : str - The path to the dataset. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - delta : list, optional - The delta values. - use_threads : bool, optional - Whether to use threads. - """ - super().__init__(path) - self.use_threads = use_threads - self.statistics_temp_dir = statistics_temp_dir - self.additinon_temp_dir = statistics_temp_dir - self.actors = [ - _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) - for d in delta - ] - - def run(self) -> None: - """Run the cleanup.""" - - self.tmp_statistics.delete() - self.registry.clean() - for actor in self.actors: - actor.cleanup() - - -class Verify(Actor): - """A class to verify the integrity of a dataset.""" - - def __init__(self, path: str, **kwargs: Any): - """Initialize a Verify instance. - - Parameters - ---------- - path : str - The path to the dataset. 
- """ - super().__init__(path) - - def run(self) -> None: - """Run the verification.""" - LOG.info(f"Verifying dataset at {self.path}") - LOG.info(str(self.dataset.anemoi_dataset)) - - -class AdditionsMixin: - """A mixin class to handle dataset additions.""" - - def skip(self) -> bool: - """Check if the additions should be skipped. - - Returns - ------- - bool - Whether to skip the additions. - """ - frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - if not self.delta.total_seconds() % frequency.total_seconds() == 0: - LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.") - return True - - if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: - LOG.warning(f"Additions are disabled for {self.path} in the recipe.") - return True - - return False - - @cached_property - def tmp_storage_path(self) -> str: - """Get the path to the temporary storage.""" - name = "storage_for_additions" - if self.delta: - name += frequency_to_string(self.delta) - return os.path.join(f"{self.path}.{name}.tmp") - - def read_from_dataset(self) -> None: - """Read data from the dataset.""" - self.variables = self.dataset.anemoi_dataset.variables - self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) - start = self.dataset.zarr_metadata["statistics_start_date"] - end = self.dataset.zarr_metadata["statistics_end_date"] - self.start = datetime.datetime.fromisoformat(start) - self.end = datetime.datetime.fromisoformat(end) - - ds = open_dataset(self.path, start=self.start, end=self.end) - self.dates = ds.dates - self.total = len(self.dates) - - idelta = self.delta.total_seconds() // self.frequency.total_seconds() - assert int(idelta) == idelta, idelta - idelta = int(idelta) - self.ds = DeltaDataset(ds, idelta) - - -class DeltaDataset: - """A class to represent a dataset with delta values.""" - - def __init__(self, ds: Any, idelta: int): - """Initialize a DeltaDataset instance. - - Parameters - ---------- - ds : Any - The dataset. - idelta : int - The delta value. - """ - self.ds = ds - self.idelta = idelta - - def __getitem__(self, i: int) -> Any: - """Get an item from the dataset. - - Parameters - ---------- - i : int - The index. - - Returns - ------- - Any - The item. - """ - j = i - self.idelta - if j < 0: - raise MissingDateError(f"Missing date {j}") - return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] - - -class _InitAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to initialize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize an _InitAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. 
- """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - def run(self) -> None: - """Run the additions initialization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) - self.tmp_storage.delete() - self.tmp_storage.create() - LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") - - def cleanup(self) -> None: - """Clean up the temporary storage.""" - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - self.tmp_storage.delete() - LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") - - -class _LoadAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to run dataset additions.""" - - def __init__( - self, - path: str, - delta: str, - parts: str | None = None, - use_threads: bool = False, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a _LoadAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - parts : Optional[str], optional - The parts to load. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - self.parts = parts - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Writing in {self.tmp_storage_path}") - - def run(self) -> None: - """Run the additions.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}") - return - - self.read_from_dataset() - - chunk_filter = ChunkFilter(parts=self.parts, total=self.total) - for i in range(0, self.total): - if not chunk_filter(i): - continue - date = self.dates[i] - try: - arr = self.ds[i] - stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) - self.tmp_storage.add([date, i, stats], key=date) - except MissingDateError: - self.tmp_storage.add([date, i, "missing"], key=date) - self.tmp_storage.flush() - LOG.debug(f"Dataset {self.path} additions run.") - - def allow_nans(self) -> bool: - """Check if NaNs are allowed. - - Returns - ------- - bool - Whether NaNs are allowed. - """ - if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): - return True - - variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) - if variables_with_nans is not None: - return variables_with_nans - warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") - return True - - -class _FinaliseAdditions(Actor, HasRegistryMixin, AdditionsMixin): - """A class to finalize dataset additions.""" - - def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): - """Initialize a _FinaliseAdditions instance. - - Parameters - ---------- - path : str - The path to the dataset. - delta : str - The delta value. - use_threads : bool, optional - Whether to use threads. - progress : Any, optional - The progress indicator. 
- """ - super().__init__(path) - self.delta = frequency_to_timedelta(delta) - self.use_threads = use_threads - self.progress = progress - - self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) - LOG.info(f"Reading from {self.tmp_storage_path}.") - - def run(self) -> None: - """Run the additions finalization.""" - if self.skip(): - LOG.info(f"Skipping delta={self.delta}.") - return - - self.read_from_dataset() - - shape = (len(self.dates), len(self.variables)) - agg = dict( - minimum=np.full(shape, np.nan, dtype=np.float64), - maximum=np.full(shape, np.nan, dtype=np.float64), - sums=np.full(shape, np.nan, dtype=np.float64), - squares=np.full(shape, np.nan, dtype=np.float64), - count=np.full(shape, -1, dtype=np.int64), - has_nans=np.full(shape, False, dtype=np.bool_), - ) - LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") - - found = set() - ifound = set() - missing = set() - for _date, (date, i, stats) in self.tmp_storage.items(): - assert _date == date - if stats == "missing": - missing.add(date) - continue - - assert date not in found, f"Duplicates found {date}" - found.add(date) - ifound.add(i) - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k][i, ...] = stats[k] - - assert len(found) + len(missing) == len(self.dates), ( - len(found), - len(missing), - len(self.dates), - ) - assert found.union(missing) == set(self.dates), ( - found, - missing, - set(self.dates), - ) - - if len(ifound) < 2: - LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") - self.tmp_storage.delete() - return - - mask = sorted(list(ifound)) - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - agg[k] = agg[k][mask, ...] - - for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: - assert agg[k].shape == agg["count"].shape, ( - agg[k].shape, - agg["count"].shape, - ) - - minimum = np.nanmin(agg["minimum"], axis=0) - maximum = np.nanmax(agg["maximum"], axis=0) - sums = np.nansum(agg["sums"], axis=0) - squares = np.nansum(agg["squares"], axis=0) - count = np.nansum(agg["count"], axis=0) - has_nans = np.any(agg["has_nans"], axis=0) - - assert sums.shape == count.shape - assert sums.shape == squares.shape - assert sums.shape == minimum.shape - assert sums.shape == maximum.shape - assert sums.shape == has_nans.shape - - mean = sums / count - assert sums.shape == mean.shape - - x = squares / count - mean * mean - # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 - # remove negative variance due to numerical errors - for i, name in enumerate(self.variables): - x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) - check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) - - stdev = np.sqrt(x) - assert sums.shape == stdev.shape - - self.summary = Summary( - minimum=minimum, - maximum=maximum, - mean=mean, - count=count, - sums=sums, - squares=squares, - stdev=stdev, - variables_names=self.variables, - has_nans=has_nans, - ) - LOG.info(f"Dataset {self.path} additions finalised.") - # self.check_statistics() - self._write(self.summary) - self.tmp_storage.delete() - - def _write(self, summary: Summary) -> None: - """Write the summary to the dataset. - - Parameters - ---------- - summary : Summary - The summary to write. 
- """ - for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: - name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" - self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) - self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") - LOG.debug(f"Wrote additions in {self.path}") - - -def multi_addition(cls: type) -> type: - """Create a class to handle multiple additions. - - Parameters - ---------- - cls : type - The class to handle additions. - - Returns - ------- - type - The class to handle multiple additions. - """ - - class MultiAdditions: - def __init__(self, *args, **kwargs: Any): - self.actors = [] - - for k in kwargs.pop("delta", []): - self.actors.append(cls(*args, delta=k, **kwargs)) - - if not self.actors: - LOG.warning("No delta found in kwargs, no additions will be computed.") - - def run(self) -> None: - """Run the additions.""" - for actor in self.actors: - actor.run() - - return MultiAdditions - - -InitAdditions = multi_addition(_InitAdditions) -LoadAdditions = multi_addition(_LoadAdditions) -FinaliseAdditions = multi_addition(_FinaliseAdditions) - - -class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin): - """A class to compute statistics for a dataset.""" - - def __init__( - self, - path: str, - use_threads: bool = False, - statistics_temp_dir: str | None = None, - progress: Any = None, - **kwargs: Any, - ): - """Initialize a Statistics instance. - - Parameters - ---------- - path : str - The path to the dataset. - use_threads : bool, optional - Whether to use threads. - statistics_temp_dir : Optional[str], optional - The directory for temporary statistics. - progress : Any, optional - The progress indicator. - """ - super().__init__(path) - self.use_threads = use_threads - self.progress = progress - self.statistics_temp_dir = statistics_temp_dir - - def run(self) -> None: - """Run the statistics computation.""" - start, end = ( - self.dataset.zarr_metadata["statistics_start_date"], - self.dataset.zarr_metadata["statistics_end_date"], - ) - start, end = np.datetime64(start), np.datetime64(end) - dates = self.dataset.anemoi_dataset.dates - - assert type(dates[0]) is type(start), (type(dates[0]), type(start)) - - dates = [d for d in dates if d >= start and d <= end] - dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] - variables = self.dataset.anemoi_dataset.variables - stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) - - LOG.info(stats) - - if not all(self.registry.get_flags(sync=False)): - raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") - - for k in [ - "mean", - "stdev", - "minimum", - "maximum", - "sums", - "squares", - "count", - "has_nans", - ]: - self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) - - self.registry.add_to_history("compute_statistics_end") - LOG.info(f"Wrote statistics in {self.path}") - - @cached_property - def allow_nans(self) -> bool | list: - """Check if NaNs are allowed.""" - import zarr - - z = zarr.open(self.path, mode="r") - if "allow_nans" in z.attrs: - return z.attrs["allow_nans"] - - if "variables_with_nans" in z.attrs: - return z.attrs["variables_with_nans"] - - warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") - return True - - -def chain(tasks: list) -> type: - """Create a class to chain multiple tasks. 
- - Parameters - ---------- - tasks : list - The list of tasks to chain. - - Returns - ------- - type - The class to chain multiple tasks. - """ - - class Chain(Actor): - def __init__(self, **kwargs: Any): - self.kwargs = kwargs - - def run(self) -> None: - """Run the chained tasks.""" - for cls in tasks: - t = cls(**self.kwargs) - t.run() - - return Chain - - -def creator_factory(name: str, trace: str | None = None, **kwargs: Any) -> Any: - """Create a dataset creator. - - Parameters - ---------- - name : str - The name of the creator. - trace : Optional[str], optional - The trace file. - **kwargs - Additional arguments for the creator. - - Returns - ------- - Any - The dataset creator. - """ - if trace: - - enable_trace(trace) - - cls = dict( - init=Init, - load=Load, - size=Size, - patch=Patch, - statistics=Statistics, - finalise=chain([Statistics, Size, Cleanup]), - cleanup=Cleanup, - verify=Verify, - init_additions=InitAdditions, - load_additions=LoadAdditions, - finalise_additions=chain([FinaliseAdditions, Size]), - additions=chain([InitAdditions, LoadAdditions, FinaliseAdditions, Size, Cleanup]), - )[name] - LOG.debug(f"Creating {cls.__name__} with {kwargs}") - return cls(**kwargs) - - -def validate_config(config: Any) -> None: - - import json - - import jsonschema - - def _tidy(d): - if isinstance(d, dict): - return {k: _tidy(v) for k, v in d.items()} - - if isinstance(d, list): - return [_tidy(v) for v in d if v is not None] - - # jsonschema does not support datetime.date - if isinstance(d, datetime.datetime): - return d.isoformat() - - if isinstance(d, datetime.date): - return d.isoformat() - - return d - - # https://json-schema.org - - with open( - os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "schemas", - "recipe.json", - ) - ) as f: - schema = json.load(f) - - try: - jsonschema.validate(instance=_tidy(config), schema=schema) - except jsonschema.exceptions.ValidationError as e: - LOG.error("❌ Config validation failed (jsonschema):") - LOG.error(e.message) - raise diff --git a/src/anemoi/datasets/create/gridded/additions.py b/src/anemoi/datasets/create/gridded/additions.py new file mode 100644 index 000000000..4f949cf31 --- /dev/null +++ b/src/anemoi/datasets/create/gridded/additions.py @@ -0,0 +1,413 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
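+
+# Overview: this module computes "tendency" statistics, i.e. statistics of the
+# differences between fields that are a fixed time delta apart. The work is
+# split into three tasks that communicate through a temporary key/value store:
+#   * _InitAdditions      creates (or resets) the temporary storage,
+#   * _LoadAdditions      computes, for each date, statistics of x(t) - x(t - delta),
+#   * _FinaliseAdditions  aggregates them and writes the
+#                         `statistics_tendencies_<delta>_*` arrays into the dataset.
+# multi_addition() wraps each task so that one instance is run per requested delta.
+#
+# A minimal, illustrative sketch of running the three steps by hand, with a
+# placeholder path (they are normally chained with Size and Cleanup by the
+# dataset creation commands):
+#
+#     for task in (InitAdditions, LoadAdditions, FinaliseAdditions):
+#         task(path="dataset.zarr", delta=["6h", "12h"]).run()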
+ +import datetime +import logging +import os +import warnings +from functools import cached_property +from typing import Any + +import numpy as np +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets import MissingDateError +from anemoi.datasets import open_dataset +from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.create.persistent import build_storage +from anemoi.datasets.create.statistics import Summary +from anemoi.datasets.create.statistics import check_variance +from anemoi.datasets.create.statistics import compute_statistics +from anemoi.datasets.create.statistics import fix_variance + +from ..gridded.tasks import FieldTask +from ..gridded.tasks import HasRegistryMixin + +LOG = logging.getLogger(__name__) + + +class AdditionsMixin: + """A mixin class to handle dataset additions.""" + + def skip(self) -> bool: + """Check if the additions should be skipped. + + Returns + ------- + bool + Whether to skip the additions. + """ + frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) + if not self.delta.total_seconds() % frequency.total_seconds() == 0: + LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.") + return True + + if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False: + LOG.warning(f"Additions are disabled for {self.path} in the recipe.") + return True + + return False + + @cached_property + def tmp_storage_path(self) -> str: + """Get the path to the temporary storage.""" + name = "storage_for_additions" + if self.delta: + name += frequency_to_string(self.delta) + return os.path.join(f"{self.path}.{name}.tmp") + + def read_from_dataset(self) -> None: + """Read data from the dataset.""" + self.variables = self.dataset.anemoi_dataset.variables + self.frequency = frequency_to_timedelta(self.dataset.anemoi_dataset.frequency) + start = self.dataset.zarr_metadata["statistics_start_date"] + end = self.dataset.zarr_metadata["statistics_end_date"] + self.start = datetime.datetime.fromisoformat(start) + self.end = datetime.datetime.fromisoformat(end) + + ds = open_dataset(self.path, start=self.start, end=self.end) + self.dates = ds.dates + self.total = len(self.dates) + + idelta = self.delta.total_seconds() // self.frequency.total_seconds() + assert int(idelta) == idelta, idelta + idelta = int(idelta) + self.ds = DeltaDataset(ds, idelta) + + +class DeltaDataset: + """A class to represent a dataset with delta values.""" + + def __init__(self, ds: Any, idelta: int): + """Initialize a DeltaDataset instance. + + Parameters + ---------- + ds : Any + The dataset. + idelta : int + The delta value. + """ + self.ds = ds + self.idelta = idelta + + def __getitem__(self, i: int) -> Any: + """Get an item from the dataset. + + Parameters + ---------- + i : int + The index. + + Returns + ------- + Any + The item. + """ + j = i - self.idelta + if j < 0: + raise MissingDateError(f"Missing date {j}") + return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] + + +class _InitAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): + """A class to initialize dataset additions.""" + + def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): + """Initialize an _InitAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + use_threads : bool, optional + Whether to use threads. 
+ progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + + def run(self) -> None: + """Run the additions initialization.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}") + return + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True) + self.tmp_storage.delete() + self.tmp_storage.create() + LOG.info(f"Dataset {self.tmp_storage_path} additions initialised.") + + def cleanup(self) -> None: + """Clean up the temporary storage.""" + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + self.tmp_storage.delete() + LOG.info(f"Cleaned temporary storage {self.tmp_storage_path}") + + +class _LoadAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): + """A class to run dataset additions.""" + + def __init__( + self, + path: str, + delta: str, + parts: str | None = None, + use_threads: bool = False, + progress: Any = None, + **kwargs: Any, + ): + """Initialize a _LoadAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + parts : Optional[str], optional + The parts to load. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + self.parts = parts + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + LOG.info(f"Writing in {self.tmp_storage_path}") + + def run(self) -> None: + """Run the additions.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}") + return + + self.read_from_dataset() + + chunk_filter = ChunkFilter(parts=self.parts, total=self.total) + for i in range(0, self.total): + if not chunk_filter(i): + continue + date = self.dates[i] + try: + arr = self.ds[i] + stats = compute_statistics(arr, self.variables, allow_nans=self.allow_nans) + self.tmp_storage.add([date, i, stats], key=date) + except MissingDateError: + self.tmp_storage.add([date, i, "missing"], key=date) + self.tmp_storage.flush() + LOG.debug(f"Dataset {self.path} additions run.") + + def allow_nans(self) -> bool: + """Check if NaNs are allowed. + + Returns + ------- + bool + Whether NaNs are allowed. + """ + if self.dataset.anemoi_dataset.metadata.get("allow_nans", False): + return True + + variables_with_nans = self.dataset.anemoi_dataset.metadata.get("variables_with_nans", None) + if variables_with_nans is not None: + return variables_with_nans + warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, assuming nans allowed.") + return True + + +class _FinaliseAdditions(FieldTask, HasRegistryMixin, AdditionsMixin): + """A class to finalize dataset additions.""" + + def __init__(self, path: str, delta: str, use_threads: bool = False, progress: Any = None, **kwargs: Any): + """Initialize a _FinaliseAdditions instance. + + Parameters + ---------- + path : str + The path to the dataset. + delta : str + The delta value. + use_threads : bool, optional + Whether to use threads. + progress : Any, optional + The progress indicator. 
+ """ + super().__init__(path) + self.delta = frequency_to_timedelta(delta) + self.use_threads = use_threads + self.progress = progress + + self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=False) + LOG.info(f"Reading from {self.tmp_storage_path}.") + + def run(self) -> None: + """Run the additions finalization.""" + if self.skip(): + LOG.info(f"Skipping delta={self.delta}.") + return + + self.read_from_dataset() + + shape = (len(self.dates), len(self.variables)) + agg = dict( + minimum=np.full(shape, np.nan, dtype=np.float64), + maximum=np.full(shape, np.nan, dtype=np.float64), + sums=np.full(shape, np.nan, dtype=np.float64), + squares=np.full(shape, np.nan, dtype=np.float64), + count=np.full(shape, -1, dtype=np.int64), + has_nans=np.full(shape, False, dtype=np.bool_), + ) + LOG.debug(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}") + + found = set() + ifound = set() + missing = set() + for _date, (date, i, stats) in self.tmp_storage.items(): + assert _date == date + if stats == "missing": + missing.add(date) + continue + + assert date not in found, f"Duplicates found {date}" + found.add(date) + ifound.add(i) + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k][i, ...] = stats[k] + + assert len(found) + len(missing) == len(self.dates), ( + len(found), + len(missing), + len(self.dates), + ) + assert found.union(missing) == set(self.dates), ( + found, + missing, + set(self.dates), + ) + + if len(ifound) < 2: + LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") + self.tmp_storage.delete() + return + + mask = sorted(list(ifound)) + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k] = agg[k][mask, ...] + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + assert agg[k].shape == agg["count"].shape, ( + agg[k].shape, + agg["count"].shape, + ) + + minimum = np.nanmin(agg["minimum"], axis=0) + maximum = np.nanmax(agg["maximum"], axis=0) + sums = np.nansum(agg["sums"], axis=0) + squares = np.nansum(agg["squares"], axis=0) + count = np.nansum(agg["count"], axis=0) + has_nans = np.any(agg["has_nans"], axis=0) + + assert sums.shape == count.shape + assert sums.shape == squares.shape + assert sums.shape == minimum.shape + assert sums.shape == maximum.shape + assert sums.shape == has_nans.shape + + mean = sums / count + assert sums.shape == mean.shape + + x = squares / count - mean * mean + # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 + # remove negative variance due to numerical errors + for i, name in enumerate(self.variables): + x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) + check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) + + stdev = np.sqrt(x) + assert sums.shape == stdev.shape + + self.summary = Summary( + minimum=minimum, + maximum=maximum, + mean=mean, + count=count, + sums=sums, + squares=squares, + stdev=stdev, + variables_names=self.variables, + has_nans=has_nans, + ) + LOG.info(f"Dataset {self.path} additions finalised.") + # self.check_statistics() + self._write(self.summary) + self.tmp_storage.delete() + + def _write(self, summary: Summary) -> None: + """Write the summary to the dataset. + + Parameters + ---------- + summary : Summary + The summary to write. 
+ """ + for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: + name = f"statistics_tendencies_{frequency_to_string(self.delta)}_{k}" + self.dataset.add_dataset(name=name, array=summary[k], dimensions=("variable",)) + self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end") + LOG.debug(f"Wrote additions in {self.path}") + + +def multi_addition(cls: type) -> type: + """Create a class to handle multiple additions. + + Parameters + ---------- + cls : type + The class to handle additions. + + Returns + ------- + type + The class to handle multiple additions. + """ + + class MultiAdditions: + def __init__(self, *args, **kwargs: Any): + self.tasks = [] + + for k in kwargs.pop("delta", []): + self.tasks.append(cls(*args, delta=k, **kwargs)) + + if not self.tasks: + LOG.warning("No delta found in kwargs, no additions will be computed.") + + def run(self) -> None: + """Run the additions.""" + for actor in self.tasks: + actor.run() + + return MultiAdditions + + +InitAdditions = multi_addition(_InitAdditions) +LoadAdditions = multi_addition(_LoadAdditions) +FinaliseAdditions = multi_addition(_FinaliseAdditions) diff --git a/src/anemoi/datasets/create/gridded/cleanup.py b/src/anemoi/datasets/create/gridded/cleanup.py new file mode 100644 index 000000000..49f1728ce --- /dev/null +++ b/src/anemoi/datasets/create/gridded/cleanup.py @@ -0,0 +1,60 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any + +from ..gridded.tasks import FieldTask +from ..gridded.tasks import HasRegistryMixin +from ..gridded.tasks import HasStatisticTempMixin +from .additions import _InitAdditions + +LOG = logging.getLogger(__name__) + + +class Cleanup(FieldTask, HasRegistryMixin, HasStatisticTempMixin): + """A class to clean up temporary data and registry entries.""" + + def __init__( + self, + path: str, + statistics_temp_dir: str | None = None, + delta: list = [], + use_threads: bool = False, + **kwargs: Any, + ): + """Initialize a Cleanup instance. + + Parameters + ---------- + path : str + The path to the dataset. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + delta : list, optional + The delta values. + use_threads : bool, optional + Whether to use threads. + """ + super().__init__(path) + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.additinon_temp_dir = statistics_temp_dir + self.tasks = [ + _InitAdditions(path, delta=d, use_threads=use_threads, statistics_temp_dir=statistics_temp_dir) + for d in delta + ] + + def run(self) -> None: + """Run the cleanup.""" + + self.tmp_statistics.delete() + self.registry.clean() + for actor in self.tasks: + actor.cleanup() diff --git a/src/anemoi/datasets/create/gridded/context.py b/src/anemoi/datasets/create/gridded/context.py index a20e51133..f16d84bbc 100644 --- a/src/anemoi/datasets/create/gridded/context.py +++ b/src/anemoi/datasets/create/gridded/context.py @@ -8,26 +8,31 @@ # nor does it submit to any jurisdiction. 
+import logging from typing import Any +from anemoi.transform.fields import new_field_with_metadata +from anemoi.transform.fields import new_fieldlist_from_list from earthkit.data.core.order import build_remapping -from anemoi.datasets.create.gridded.result import GriddedResult from anemoi.datasets.create.input.context import Context +LOG = logging.getLogger(__name__) -class GriddedContext(Context): + +class FieldContext(Context): def __init__( self, /, - argument: Any, order_by: str, flatten_grid: bool, remapping: dict[str, Any], use_grib_paramid: bool, ) -> None: - super().__init__(argument) + + super().__init__() + self.order_by = order_by self.flatten_grid = flatten_grid self.remapping = build_remapping(remapping) @@ -45,10 +50,29 @@ def source_argument(self, argument: Any) -> Any: def filter_argument(self, argument: Any) -> Any: return argument - def create_result(self, data): - return GriddedResult(self, data) + def create_result(self, argument, data): + from anemoi.datasets.create.gridded.result import FieldResult + + return FieldResult(self, argument, data) def matching_dates(self, filtering_dates, group_of_dates: Any) -> Any: from anemoi.datasets.dates.groups import GroupOfDates return GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) + + def origin(self, data: Any, action: Any, action_arguments: Any) -> Any: + + origin = action.origin() + + result = [] + for fs in data: + previous = fs.metadata("anemoi_origin", default=None) + fall_through = fs.metadata("anemoi_fall_through", default=False) + if fall_through: + # The field has pass unchanges in a filter + result.append(fs) + else: + anemoi_origin = origin.combine(previous, action, action_arguments) + result.append(new_field_with_metadata(fs, anemoi_origin=anemoi_origin)) + + return new_fieldlist_from_list(result) diff --git a/src/anemoi/datasets/create/gridded/init.py b/src/anemoi/datasets/create/gridded/init.py new file mode 100644 index 000000000..11f5a22b6 --- /dev/null +++ b/src/anemoi/datasets/create/gridded/init.py @@ -0,0 +1,293 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import uuid +from typing import Any + +import zarr +from anemoi.utils.sanitise import sanitise + +from anemoi.datasets.create.config import loader_config +from anemoi.datasets.create.utils import normalize_and_check_dates + +from ..gridded.tasks import FieldTask +from ..gridded.tasks import HasElementForDataMixin +from ..gridded.tasks import HasRegistryMixin +from ..gridded.tasks import HasStatisticTempMixin +from ..gridded.tasks import NewDataset +from ..gridded.tasks import _build_statistics_dates + +LOG = logging.getLogger(__name__) + +VERSION = "0.30" + + +def _path_readable(path: str) -> bool: + """Check if the path is readable. + + Parameters + ---------- + path : str + The path to check. + + Returns + ------- + bool + True if the path is readable, False otherwise. 
+ """ + + try: + zarr.open(path, "r") + return True + except zarr.errors.PathNotFoundError: + return False + + +class Init(FieldTask, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): + """A class to initialize a new dataset.""" + + dataset_class = NewDataset + + def __init__( + self, + path: str, + config: dict, + check_name: bool = False, + overwrite: bool = False, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + test: bool = False, + cache: str | None = None, + **kwargs: Any, + ): + """Initialize an Init instance. + + Parameters + ---------- + path : str + The path to the dataset. + config : dict + The configuration. + check_name : bool, optional + Whether to check the dataset name. + overwrite : bool, optional + Whether to overwrite the existing dataset. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + test : bool, optional + Whether this is a test. + cache : Optional[str], optional + The cache directory. + """ + if _path_readable(path) and not overwrite: + raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") + + super().__init__(path, cache=cache) + self.config = config + self.check_name = check_name + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.progress = progress + self.test = test + + self.main_config = loader_config(config, is_test=test) + + # self.registry.delete() ?? + self.tmp_statistics.delete() + + assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by + self.create_elements(self.main_config) + + LOG.info(f"Groups: {self.groups}") + + # window = self.main_config.dates.get("window") + + one_date = self.groups.one_date() + + self.minimal_input = self.input.select(self.context, one_date) + + LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") + LOG.info(self.minimal_input) + + def run(self) -> int: + """Run the initialization. + + Returns + ------- + int + The number of groups to process. + """ + with self._cache_context(): + return self._run() + + def _run(self) -> int: + """Internal method to run the initialization. + + Returns + ------- + int + The number of groups to process. + """ + """Create an empty dataset of the right final shape. + + Read a small part of the data to get the shape of the data and the resolution and more metadata. 
+ """ + + LOG.info("Config loaded ok:") + # LOG.info(self.main_config) + + dates = self.groups.provider.values + frequency = self.groups.provider.frequency + missing = self.groups.provider.missing + + assert isinstance(frequency, datetime.timedelta), frequency + + LOG.info(f"Found {len(dates)} datetimes.") + LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") + LOG.info(f"Missing dates: {len(missing)}") + lengths = tuple(len(g) for g in self.groups) + + variables = self.minimal_input.variables + LOG.info(f"Found {len(variables)} variables : {','.join(variables)}.") + + variables_with_nans = self.main_config.statistics.get("allow_nans", []) + + ensembles = self.minimal_input.ensembles + LOG.info(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.") + + grid_points = self.minimal_input.grid_points + LOG.info(f"gridpoints size: {[len(i) for i in grid_points]}") + + resolution = self.minimal_input.resolution + LOG.info(f"{resolution=}") + + coords = self.minimal_input.coords + coords["dates"] = dates + total_shape = self.minimal_input.shape + total_shape[0] = len(dates) + LOG.info(f"total_shape = {total_shape}") + + chunks = self.output.get_chunking(coords) + LOG.info(f"{chunks=}") + dtype = self.output.dtype + + LOG.info(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}") + + metadata = {} + metadata["uuid"] = str(uuid.uuid4()) + + metadata.update(self.main_config.get("add_metadata", {})) + + metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict() + + recipe = sanitise(self.main_config.get_serialisable_dict()) + + # Remove stuff added by prepml + for k in [ + "build_dataset", + "config_format_version", + "config_path", + "dataset_status", + "ecflow", + "metadata", + "platform", + "reading_chunks", + "upload", + ]: + recipe.pop(k, None) + + metadata["recipe"] = recipe + + metadata["description"] = self.main_config.description + metadata["licence"] = self.main_config["licence"] + metadata["attribution"] = self.main_config["attribution"] + + metadata["remapping"] = self.output.remapping + metadata["order_by"] = self.output.order_by_as_list + metadata["flatten_grid"] = self.output.flatten_grid + + metadata["ensemble_dimension"] = len(ensembles) + metadata["variables"] = variables + metadata["variables_with_nans"] = variables_with_nans + metadata["allow_nans"] = self.main_config.build.get("allow_nans", False) + metadata["resolution"] = resolution + + metadata["data_request"] = self.minimal_input.data_request + metadata["field_shape"] = self.minimal_input.field_shape + metadata["proj_string"] = self.minimal_input.proj_string + metadata["variables_metadata"] = self.minimal_input.variables_metadata + + metadata["start_date"] = dates[0].isoformat() + metadata["end_date"] = dates[-1].isoformat() + metadata["frequency"] = frequency + metadata["missing_dates"] = [_.isoformat() for _ in missing] + metadata["origins"] = self.minimal_input.origins + + metadata["version"] = VERSION + + self.dataset.check_name( + raise_exception=self.check_name, + is_test=self.test, + resolution=resolution, + dates=dates, + frequency=frequency, + ) + + if len(dates) != total_shape[0]: + raise ValueError( + f"Final date size {len(dates)} (from {dates[0]} to {dates[-1]}, {frequency=}) " + f"does not match data shape {total_shape[0]}. 
{total_shape=}" + ) + + dates = normalize_and_check_dates(dates, metadata["start_date"], metadata["end_date"], metadata["frequency"]) + + metadata.update(self.main_config.get("force_metadata", {})) + + ############################################################### + # write metadata + ############################################################### + + self.update_metadata(**metadata) + + self.dataset.add_dataset( + name="data", + chunks=chunks, + dtype=dtype, + shape=total_shape, + dimensions=("time", "variable", "ensemble", "cell"), + ) + self.dataset.add_dataset(name="dates", array=dates, dimensions=("time",)) + self.dataset.add_dataset(name="latitudes", array=grid_points[0], dimensions=("cell",)) + self.dataset.add_dataset(name="longitudes", array=grid_points[1], dimensions=("cell",)) + + self.registry.create(lengths=lengths) + self.tmp_statistics.create(exist_ok=False) + self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) + + statistics_start, statistics_end = _build_statistics_dates( + dates, + self.main_config.statistics.get("start"), + self.main_config.statistics.get("end"), + ) + self.update_metadata(statistics_start_date=statistics_start, statistics_end_date=statistics_end) + LOG.info(f"Will compute statistics from {statistics_start} to {statistics_end}") + + self.registry.add_to_history("init finished") + + assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) + + # Return the number of groups to process, so we can show a nice progress bar + return len(lengths) diff --git a/src/anemoi/datasets/create/gridded/load.py b/src/anemoi/datasets/create/gridded/load.py new file mode 100644 index 000000000..53c2481cf --- /dev/null +++ b/src/anemoi/datasets/create/gridded/load.py @@ -0,0 +1,260 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import time +from typing import Any + +import numpy as np +import tqdm +from anemoi.utils.dates import as_datetime +from anemoi.utils.humanize import compress_dates +from anemoi.utils.humanize import seconds_to_human + +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.chunks import ChunkFilter +from anemoi.datasets.create.statistics import compute_statistics +from anemoi.datasets.create.writer import ViewCacheArray + +from ..gridded.tasks import FieldTask +from ..gridded.tasks import HasElementForDataMixin +from ..gridded.tasks import HasRegistryMixin +from ..gridded.tasks import HasStatisticTempMixin +from ..gridded.tasks import WritableDataset + +LOG = logging.getLogger(__name__) + + +class Load(FieldTask, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): + """A class to load data into a dataset.""" + + def __init__( + self, + path: str, + parts: str | None = None, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + cache: str | None = None, + **kwargs: Any, + ): + """Initialize a Load instance. + + Parameters + ---------- + path : str + The path to the dataset. + parts : Optional[str], optional + The parts to load. + use_threads : bool, optional + Whether to use threads. 
+ statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + cache : Optional[str], optional + The cache directory. + """ + super().__init__(path, cache=cache) + self.use_threads = use_threads + self.statistics_temp_dir = statistics_temp_dir + self.progress = progress + self.parts = parts + self.dataset = WritableDataset(self.path) + + self.main_config = self.dataset.get_main_config() + self.create_elements(self.main_config) + self.read_dataset_metadata(self.dataset.path) + + total = len(self.registry.get_flags()) + self.chunk_filter = ChunkFilter(parts=self.parts, total=total) + + self.data_array = self.dataset.data_array + self.n_groups = len(self.groups) + + def run(self) -> None: + """Run the data loading.""" + with self._cache_context(): + self._run() + + def _run(self) -> None: + """Internal method to run the data loading.""" + for igroup, group in enumerate(self.groups): + if not self.chunk_filter(igroup): + continue + if self.registry.get_flag(igroup): + LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") + continue + + # assert isinstance(group[0], datetime.datetime), type(group[0]) + LOG.debug(f"Building data for group {igroup}/{self.n_groups}") + + result = self.input.select(self.context, argument=group) + assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) + + # There are several groups. + # There is one result to load for each group. + self.load_result(result) + self.registry.set_flag(igroup) + + self.registry.add_provenance(name="provenance_load") + self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) + + self.dataset.print_info() + + def load_result(self, result: Any) -> None: + """Load the result into the dataset. + + Parameters + ---------- + result : Any + The result to load. + """ + # There is one cube to load for each result. 
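+        # The steps below:
+        #   1. build the datacube for this group of dates,
+        #   2. check that its shape and its dates match the requested group,
+        #   3. map the dates to indexes along the time dimension of the output array,
+        #   4. write the cube through a ViewCacheArray and record per-date statistics
+        #      in the temporary statistics store.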
+ dates = list(result.group_of_dates) + + LOG.debug(f"Loading cube for {len(dates)} dates") + + cube = result.get_cube() + shape = cube.extended_user_shape + dates_in_data = cube.user_coords["valid_datetime"] + + LOG.debug(f"Loading {shape=} in {self.data_array.shape=}") + + def check_shape(cube, dates, dates_in_data): + if cube.extended_user_shape[0] != len(dates): + print( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) + print("Requested dates", compress_dates(dates)) + print("Cube dates", compress_dates(dates_in_data)) + + a = {as_datetime(_) for _ in dates} + b = {as_datetime(_) for _ in dates_in_data} + + print("Missing dates", compress_dates(a - b)) + print("Extra dates", compress_dates(b - a)) + + raise ValueError( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) + + check_shape(cube, dates, dates_in_data) + + def check_dates_in_data(dates_in_data, requested_dates): + _requested_dates = [np.datetime64(_) for _ in requested_dates] + _dates_in_data = [np.datetime64(_) for _ in dates_in_data] + if _dates_in_data != _requested_dates: + LOG.error("Dates in data are not the requested ones:") + + dates_in_data = set(dates_in_data) + requested_dates = set(requested_dates) + + missing = sorted(requested_dates - dates_in_data) + extra = sorted(dates_in_data - requested_dates) + + if missing: + LOG.error(f"Missing dates: {[_.isoformat() for _ in missing]}") + if extra: + LOG.error(f"Extra dates: {[_.isoformat() for _ in extra]}") + + raise ValueError("Dates in data are not the requested ones") + + check_dates_in_data(dates_in_data, dates) + + def dates_to_indexes(dates, all_dates): + x = np.array(dates, dtype=np.datetime64) + y = np.array(all_dates, dtype=np.datetime64) + bitmap = np.isin(x, y) + return np.where(bitmap)[0] + + indexes = dates_to_indexes(self.dates, dates_in_data) + + array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes) + LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}") + self.load_cube(cube, array) + + stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans()) + self.tmp_statistics.write(indexes, stats, dates=dates_in_data) + LOG.info("Flush data array") + array.flush() + LOG.info("Flushed data array") + + def _get_allow_nans(self) -> bool | list: + """Get the allow_nans configuration. + + Returns + ------- + bool | list + The allow_nans configuration. + """ + config = self.main_config + if "allow_nans" in config.build: + return config.build.allow_nans + + return config.statistics.get("allow_nans", []) + + def load_cube(self, cube: Any, array: ViewCacheArray) -> None: + """Load the cube into the array. + + Parameters + ---------- + cube : Any + The cube to load. + array : ViewCacheArray + The array to load into. 
+ """ + # There are several cubelets for each cube + start = time.time() + load = 0 + save = 0 + + reading_chunks = None + total = cube.count(reading_chunks) + LOG.debug(f"Loading datacube: {cube}") + + def position(x: Any) -> int | None: + if isinstance(x, str) and "/" in x: + x = x.split("/") + return int(x[0]) + return None + + bar = tqdm.tqdm( + iterable=cube.iterate_cubelets(reading_chunks), + total=total, + desc=f"Loading datacube {cube}", + position=position(self.parts), + ) + for i, cubelet in enumerate(bar): + bar.set_description(f"Loading {i}/{total}") + + now = time.time() + data = cubelet.to_numpy() + local_indexes = cubelet.coords + load += time.time() - now + + name = self.variables_names[local_indexes[1]] + check_data_values( + data[:], + name=name, + log=[i, data.shape, local_indexes], + allow_nans=self._get_allow_nans(), + ) + + now = time.time() + array[local_indexes] = data + save += time.time() - now + + now = time.time() + save += time.time() - now + LOG.debug( + f"Elapsed: {seconds_to_human(time.time() - start)}, " + f"load time: {seconds_to_human(load)}, " + f"write time: {seconds_to_human(save)}." + ) diff --git a/src/anemoi/datasets/create/gridded/patch.py b/src/anemoi/datasets/create/gridded/patch.py old mode 100755 new mode 100644 index 5cb08ec82..e4dabb28d --- a/src/anemoi/datasets/create/gridded/patch.py +++ b/src/anemoi/datasets/create/gridded/patch.py @@ -7,182 +7,32 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import json import logging -import os +from typing import Any -import zarr +from ..gridded.tasks import FieldTask LOG = logging.getLogger(__name__) -def fix_order_by(order_by: dict | list) -> list[dict]: - """Fix the order_by attribute to ensure it is a list of dictionaries. +class Patch(FieldTask): + """A class to apply patches to a dataset.""" - Parameters - ---------- - order_by : dict or list - The order_by attribute to fix. + def __init__(self, path: str, options: dict = None, **kwargs: Any): + """Initialize a Patch instance. - Returns - ------- - list[dict] - The fixed order_by attribute. - """ - if isinstance(order_by, list): - return order_by + Parameters + ---------- + path : str + The path to the dataset. + options : dict, optional + The patch options. + """ + self.path = path + self.options = options or {} - assert isinstance(order_by, dict), order_by - assert len(order_by) <= 3, order_by - lst = [] - lst.append({"valid_datetime": order_by["valid_datetime"]}) - lst.append({"param_level": order_by["param_level"]}) - lst.append({"number": order_by["number"]}) - return lst + def run(self) -> None: + """Run the patch.""" + from anemoi.datasets.create.patch import apply_patch - -def fix_history(history: list[dict]) -> list[dict]: - """Fix the history attribute by removing specific actions. - - Parameters - ---------- - history : list[dict] - The history attribute to fix. - - Returns - ------- - list[dict] - The fixed history attribute. - """ - new = history - new = [d for d in new if d.get("action") != "loading_data_start"] - new = [d for d in new if d.get("action") != "loading_data_end"] - return new - - -def fix_provenance(provenance: dict) -> dict: - """Fix the provenance attribute by adding missing fields and removing unnecessary ones. - - Parameters - ---------- - provenance : dict - The provenance attribute to fix. - - Returns - ------- - dict - The fixed provenance attribute. 
- """ - if "python" not in provenance: - provenance["python"] = provenance["platform"]["python_version"] - - for q in ( - "args", - "config_paths", - "executable", - "gpus", - "platform", - "python_path", - "assets", - ): - if q in provenance: - del provenance[q] - - for k, v in list(provenance["module_versions"].items()): - if v.startswith("<"): - del provenance["module_versions"][k] - if v.startswith("/"): - provenance["module_versions"][k] = os.path.join("...", os.path.basename(v)) - - for k, v in list(provenance["git_versions"].items()): - LOG.debug(k, v) - modified_files = v["git"].get("modified_files", []) - untracked_files = v["git"].get("untracked_files", []) - if not isinstance(modified_files, int): - modified_files = len(modified_files) - if not isinstance(untracked_files, int): - untracked_files = len(untracked_files) - provenance["git_versions"][k] = dict( - git={ - "sha1": v["git"]["sha1"], - "modified_files": modified_files, - "untracked_files": untracked_files, - } - ) - - LOG.debug(json.dumps(provenance, indent=2)) - # assert False - return provenance - - -def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None: - """Apply a patch to the dataset at the given path. - - Parameters - ---------- - path : str - The path to the dataset. - verbose : bool, optional - Whether to log detailed information. Defaults to True. - dry_run : bool, optional - If True, do not actually apply the patch. Defaults to False. - """ - LOG.debug("====================") - LOG.debug(f"Patching {path}") - LOG.debug("====================") - - try: - attrs = zarr.open(path, mode="r").attrs.asdict() - except zarr.errors.PathNotFoundError as e: - LOG.error(f"Failed to open {path}") - LOG.error(e) - exit(0) - - FIXES = { - "history": fix_history, - "provenance_load": fix_provenance, - "provenance_statistics": fix_provenance, - "order_by": fix_order_by, - } - REMOVE = ["_create_yaml_config"] - - before = json.dumps(attrs, sort_keys=True) - - fixed_attrs = {} - for k, v in attrs.items(): - v = attrs[k] - if k in REMOVE: - LOG.info(f"✅ Remove {k}") - continue - - if k not in FIXES: - assert not k.startswith("provenance"), f"[{k}]" - LOG.debug(f"✅ Don't fix {k}") - fixed_attrs[k] = v - continue - - new_v = FIXES[k](v) - if json.dumps(new_v, sort_keys=True) != json.dumps(v, sort_keys=True): - LOG.info(f"✅ Fix {k}") - if verbose: - LOG.info(f" Before : {k}= {v}") - LOG.info(f" After : {k}= {new_v}") - else: - LOG.debug(f"✅ Unchanged {k}") - fixed_attrs[k] = new_v - - if dry_run: - return - z = zarr.open(path, mode="r+") - - for k in list(z.attrs.keys()): - if k not in fixed_attrs: - del z.attrs[k] - for k, v in fixed_attrs.items(): - z.attrs[k] = v - - after = json.dumps(z.attrs.asdict(), sort_keys=True) - if before != after: - LOG.info("Dataset changed by patch") - - assert json.dumps(z.attrs.asdict(), sort_keys=True) == json.dumps(fixed_attrs, sort_keys=True) + apply_patch(self.path, **self.options) diff --git a/src/anemoi/datasets/create/gridded/result.py b/src/anemoi/datasets/create/gridded/result.py index ed8440c52..d4bcf58ea 100644 --- a/src/anemoi/datasets/create/gridded/result.py +++ b/src/anemoi/datasets/create/gridded/result.py @@ -276,28 +276,35 @@ def sort(old_dic: DefaultDict[str, set]) -> dict[str, list[Any]]: return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) -class GriddedResult(Result): +class FieldResult(Result): """Class to represent the result of an action in the dataset creation process.""" empty: bool = False _coords_already_built: 
bool = False
 
-    def __init__(self, context: Any, datasource: Any) -> None:
+    def __init__(self, context: Any, argument: Any, datasource: Any) -> None:
         from anemoi.datasets.dates.groups import GroupOfDates
 
         self.context: Any = context
         self.datasource = datasource
-        self.group_of_dates = context.argument
+        self.group_of_dates = argument
 
         assert isinstance(
             self.group_of_dates, GroupOfDates
         ), f"Expected group_of_dates to be a GroupOfDates, got {type(self.group_of_dates)}: {self.group_of_dates}"
 
+        self._origins = []
+
     @property
     def data_request(self) -> dict[str, Any]:
         """Returns a dictionary with the parameters needed to retrieve the data."""
         return _data_request(self.datasource)
 
+    @property
+    def origins(self) -> dict[str, Any]:
+        """Returns a dictionary describing the origin of each variable in the result."""
+        return {"version": 1, "origins": self._origins}
+
     def get_cube(self) -> Any:
         """Retrieve the data cube for the result.
 
@@ -309,26 +316,26 @@ def get_cube(self) -> Any:
 
         ds: Any = self.datasource
 
-        remapping: Any = self.context.remapping
-        order_by: Any = self.context.order_by
-        flatten_grid: Any = self.context.flatten_grid
-        start: float = time.time()
-        LOG.debug("Sorting dataset %s %s", dict(order_by), remapping)
-        assert order_by, order_by
+        self.remapping: Any = self.context.remapping
+        self.order_by: Any = self.context.order_by
+        self.flatten_grid: Any = self.context.flatten_grid
+        self.start: float = time.time()
+        LOG.debug("Sorting dataset %s %s", dict(self.order_by), self.remapping)
+        assert self.order_by, self.order_by
 
-        patches: dict[str, dict[Any | None, int]] = {"number": {None: 0}}
+        self.patches: dict[str, dict[Any | None, int]] = {"number": {None: 0}}
 
         try:
             cube: Any = ds.cube(
-                order_by,
-                remapping=remapping,
-                flatten_values=flatten_grid,
-                patches=patches,
+                self.order_by,
+                remapping=self.remapping,
+                flatten_values=self.flatten_grid,
+                patches=self.patches,
             )
             cube = cube.squeeze()
-            LOG.debug(f"Sorting done in {seconds_to_human(time.time()-start)}.")
+            LOG.debug(f"Sorting done in {seconds_to_human(time.time()-self.start)}.")
 
         except ValueError:
-            self.explain(ds, order_by, remapping=remapping, patches=patches)
+            self.explain(ds, self.order_by, remapping=self.remapping, patches=self.patches)
             # raise ValueError(f"Error in {self}")
             exit(1)
@@ -556,6 +563,41 @@ def build_coords(self) -> None:
 
         self._cube: Any = cube
 
+        name_key = list(self.order_by.keys())[1]
+
+        p = None
+        origins_per_number = defaultdict(lambda: defaultdict(set))
+
+        for fs in self.datasource:
+            o = fs.metadata("anemoi_origin", remapping=self.remapping, patches=self.patches)
+            name = fs.metadata(name_key, remapping=self.remapping, patches=self.patches)
+            number = fs.metadata("number", remapping=self.remapping, patches=self.patches)
+
+            assert name not in origins_per_number[number][o], name
+            origins_per_number[number][o].add(name)
+
+            if p is not o:
+                LOG.info(f"🔥🔥🔥🔥🔥🔥 Source: {name}, {o}")
+                p = o
+
+        origins_per_variables = defaultdict(lambda: defaultdict(set))
+        for number, origins in origins_per_number.items():
+            for origin, names in origins.items():
+                for name in names:
+                    origins_per_variables[name][origin].add(number)
+
+        origins = defaultdict(set)
+
+        # Check if all members of a variable have the same origins
+        for name, origin_number in origins_per_variables.items():
+            # For now we do not support variables with members from different origins
+            assert len(origin_number) == 1, origin_number
+            origins[list(origin_number.keys())[0]].add(name)
+
+        self._origins = []
+        for k, v in origins.items():
self._origins.append({"origin": k.as_dict(), "variables": sorted(v)}) + self._coords_already_built: bool = True @property diff --git a/src/anemoi/datasets/create/gridded/size.py b/src/anemoi/datasets/create/gridded/size.py index 4cffd66d7..b55827b80 100644 --- a/src/anemoi/datasets/create/gridded/size.py +++ b/src/anemoi/datasets/create/gridded/size.py @@ -7,41 +7,42 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. - import logging -import os +from typing import Any + +from anemoi.datasets import open_dataset -import tqdm -from anemoi.utils.humanize import bytes_to_human +from ..gridded.tasks import FieldTask LOG = logging.getLogger(__name__) -def compute_directory_sizes(path: str) -> dict[str, int] | None: - """Computes the total size and number of files in a directory. +class Size(FieldTask): + """A class to compute the size of a dataset.""" + + def __init__(self, path: str, **kwargs: Any): + """Initialize a Size instance. + + Parameters + ---------- + path : str + The path to the dataset. + """ + super().__init__(path) - Parameters - ---------- - path : str - The path to the directory. + def run(self) -> None: + """Run the size computation.""" + from anemoi.datasets.create.size import compute_directory_sizes - Returns - ------- - dict of str to int or None - A dictionary with the total size and number of files, or None if the path is not a directory. - """ - if not os.path.isdir(path): - return None + metadata = compute_directory_sizes(self.path) + self.update_metadata(**metadata) - size, n = 0, 0 - bar = tqdm.tqdm(iterable=os.walk(path), desc=f"Computing size of {path}") - for dirpath, _, filenames in bar: - for filename in filenames: - file_path = os.path.join(dirpath, filename) - size += os.path.getsize(file_path) - n += 1 + # Look for constant fields + ds = open_dataset(self.path) + constants = ds.computed_constant_fields() - LOG.info(f"Total size: {bytes_to_human(size)}") - LOG.info(f"Total number of files: {n}") + variables_metadata = self.dataset.zarr_metadata.get("variables_metadata", {}).copy() + for k in constants: + variables_metadata[k]["constant_in_time"] = True - return dict(total_size=size, total_number_of_files=n) + self.update_metadata(constant_fields=constants, variables_metadata=variables_metadata) diff --git a/src/anemoi/datasets/create/gridded/statistics.py b/src/anemoi/datasets/create/gridded/statistics.py new file mode 100644 index 000000000..1cabcc366 --- /dev/null +++ b/src/anemoi/datasets/create/gridded/statistics.py @@ -0,0 +1,102 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
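+
+# This task turns the per-group statistics accumulated during the `load` step
+# (kept in the temporary statistics store) into the final dataset statistics:
+# it restricts the dates to the configured statistics period, drops missing
+# dates, aggregates, and writes the mean/stdev/minimum/maximum/sums/squares/
+# count/has_nans arrays into the zarr store. It refuses to run unless every
+# group has been flagged as loaded in the registry.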
+ +import logging +import warnings +from functools import cached_property +from typing import Any + +import numpy as np +import zarr + +from ..gridded.tasks import FieldTask +from ..gridded.tasks import HasRegistryMixin +from ..gridded.tasks import HasStatisticTempMixin + +LOG = logging.getLogger(__name__) + + +class Statistics(FieldTask, HasStatisticTempMixin, HasRegistryMixin): + """A class to compute statistics for a dataset.""" + + def __init__( + self, + path: str, + use_threads: bool = False, + statistics_temp_dir: str | None = None, + progress: Any = None, + **kwargs: Any, + ): + """Initialize a Statistics instance. + + Parameters + ---------- + path : str + The path to the dataset. + use_threads : bool, optional + Whether to use threads. + statistics_temp_dir : Optional[str], optional + The directory for temporary statistics. + progress : Any, optional + The progress indicator. + """ + super().__init__(path) + self.use_threads = use_threads + self.progress = progress + self.statistics_temp_dir = statistics_temp_dir + + def run(self) -> None: + """Run the statistics computation.""" + start, end = ( + self.dataset.zarr_metadata["statistics_start_date"], + self.dataset.zarr_metadata["statistics_end_date"], + ) + start, end = np.datetime64(start), np.datetime64(end) + dates = self.dataset.anemoi_dataset.dates + + assert type(dates[0]) is type(start), (type(dates[0]), type(start)) + + dates = [d for d in dates if d >= start and d <= end] + dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] + variables = self.dataset.anemoi_dataset.variables + stats = self.tmp_statistics.get_aggregated(dates, variables, self.allow_nans) + + LOG.info(stats) + + if not all(self.registry.get_flags(sync=False)): + raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") + + for k in [ + "mean", + "stdev", + "minimum", + "maximum", + "sums", + "squares", + "count", + "has_nans", + ]: + self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) + + self.registry.add_to_history("compute_statistics_end") + LOG.info(f"Wrote statistics in {self.path}") + + @cached_property + def allow_nans(self) -> bool | list: + """Check if NaNs are allowed.""" + + z = zarr.open(self.path, mode="r") + if "allow_nans" in z.attrs: + return z.attrs["allow_nans"] + + if "variables_with_nans" in z.attrs: + return z.attrs["variables_with_nans"] + + warnings.warn(f"Cannot find 'variables_with_nans' of 'allow_nans' in {self.path}.") + return True diff --git a/src/anemoi/datasets/create/gridded/statistics/__init__.py b/src/anemoi/datasets/create/gridded/stats/__init__.py similarity index 99% rename from src/anemoi/datasets/create/gridded/statistics/__init__.py rename to src/anemoi/datasets/create/gridded/stats/__init__.py index fb59573c2..ee87e21ce 100644 --- a/src/anemoi/datasets/create/gridded/statistics/__init__.py +++ b/src/anemoi/datasets/create/gridded/stats/__init__.py @@ -24,7 +24,7 @@ from numpy.typing import NDArray from anemoi.datasets.create.gridded.check import check_data_values -from anemoi.datasets.create.gridded.statistics.summary import Summary +from anemoi.datasets.create.gridded.stats.summary import Summary LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/gridded/statistics/summary.py b/src/anemoi/datasets/create/gridded/stats/summary.py similarity index 100% rename from src/anemoi/datasets/create/gridded/statistics/summary.py rename to src/anemoi/datasets/create/gridded/stats/summary.py diff --git 
a/src/anemoi/datasets/create/gridded/tasks.py b/src/anemoi/datasets/create/gridded/tasks.py new file mode 100644 index 000000000..d4cb1f288 --- /dev/null +++ b/src/anemoi/datasets/create/gridded/tasks.py @@ -0,0 +1,606 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import json +import logging +import os +from functools import cached_property +from typing import Any + +import cftime +import numpy as np +import zarr +from anemoi.utils.dates import frequency_to_string +from earthkit.data.core.order import build_remapping + +from anemoi.datasets import open_dataset +from anemoi.datasets.create.check import DatasetName +from anemoi.datasets.create.config import build_output +from anemoi.datasets.create.config import loader_config +from anemoi.datasets.create.gridded.context import FieldContext +from anemoi.datasets.create.input import InputBuilder +from anemoi.datasets.create.statistics import TmpStatistics +from anemoi.datasets.create.statistics import default_statistics_dates +from anemoi.datasets.dates.groups import Groups +from anemoi.datasets.use.gridded.misc import as_first_date +from anemoi.datasets.use.gridded.misc import as_last_date + +from ..tasks import chain + +LOG = logging.getLogger(__name__) + + +def _json_tidy(o: Any) -> Any: + """Convert various types to JSON serializable format. + + Parameters + ---------- + o : Any + The object to convert. + + Returns + ------- + Any + The JSON serializable object. + """ + if isinstance(o, datetime.datetime): + return o.isoformat() + + if isinstance(o, datetime.datetime): + return o.isoformat() + + if isinstance(o, datetime.timedelta): + return frequency_to_string(o) + + if isinstance(o, cftime.DatetimeJulian): + import pandas as pd + + o = pd.Timestamp( + o.year, + o.month, + o.day, + o.hour, + o.minute, + o.second, + ) + return o.isoformat() + + if isinstance(o, (np.float32, np.float64)): + return float(o) + + raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}") + + +def _build_statistics_dates( + dates: list[datetime.datetime], + start: datetime.datetime | None, + end: datetime.datetime | None, +) -> tuple[str, str]: + """Compute the start and end dates for the statistics. + + Parameters + ---------- + dates : list of datetime.datetime + The list of dates. + start : Optional[datetime.datetime] + The start date. + end : Optional[datetime.datetime] + The end date. + + Returns + ------- + tuple of str + The start and end dates in ISO format. + """ + # if not specified, use the default statistics dates + default_start, default_end = default_statistics_dates(dates) + if start is None: + start = default_start + if end is None: + end = default_end + + # in any case, adapt to the actual dates in the dataset + start = as_first_date(start, dates) + end = as_last_date(end, dates) + + # and convert to datetime to isoformat + start = start.astype(datetime.datetime) + end = end.astype(datetime.datetime) + return (start.isoformat(), end.isoformat()) + + +class Dataset: + """A class to represent a dataset.""" + + def __init__(self, path: str): + """Initialize a Dataset instance. + + Parameters + ---------- + path : str + The path to the dataset. 
+ """ + self.path = path + + _, ext = os.path.splitext(self.path) + if ext != ".zarr": + raise ValueError(f"Unsupported extension={ext} for path={self.path}") + + def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: + """Add a dataset to the Zarr store. + + Parameters + ---------- + mode : str, optional + The mode to open the Zarr store. + **kwargs + Additional arguments for the dataset. + + Returns + ------- + zarr.Array + The added dataset. + """ + import zarr + + z = zarr.open(self.path, mode=mode) + from anemoi.datasets.create.zarr import add_zarr_dataset + + return add_zarr_dataset(zarr_root=z, **kwargs) + + def update_metadata(self, **kwargs: Any) -> None: + """Update the metadata of the dataset. + + Parameters + ---------- + **kwargs + The metadata to update. + """ + import zarr + + LOG.debug(f"Updating metadata {kwargs}") + z = zarr.open(self.path, mode="w+") + for k, v in kwargs.items(): + if isinstance(v, np.datetime64): + v = v.astype(datetime.datetime) + if isinstance(v, datetime.date): + v = v.isoformat() + z.attrs[k] = json.loads(json.dumps(v, default=_json_tidy)) + + @cached_property + def anemoi_dataset(self) -> Any: + """Get the Anemoi dataset.""" + return open_dataset(self.path) + + @cached_property + def zarr_metadata(self) -> dict: + """Get the Zarr metadata.""" + import zarr + + return dict(zarr.open(self.path, mode="r").attrs) + + def print_info(self) -> None: + """Print information about the dataset.""" + import zarr + + z = zarr.open(self.path, mode="r") + try: + LOG.info(z["data"].info) + except Exception as e: + LOG.info(e) + + def get_zarr_chunks(self) -> tuple: + """Get the chunks of the Zarr dataset. + + Returns + ------- + tuple + The chunks of the Zarr dataset. + """ + import zarr + + z = zarr.open(self.path, mode="r") + return z["data"].chunks + + def check_name( + self, + resolution: str, + dates: list[datetime.datetime], + frequency: datetime.timedelta, + raise_exception: bool = True, + is_test: bool = False, + ) -> None: + """Check the name of the dataset. + + Parameters + ---------- + resolution : str + The resolution of the dataset. + dates : list of datetime.datetime + The dates of the dataset. + frequency : datetime.timedelta + The frequency of the dataset. + raise_exception : bool, optional + Whether to raise an exception if the name is invalid. + is_test : bool, optional + Whether this is a test. + """ + basename, _ = os.path.splitext(os.path.basename(self.path)) + try: + DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid() + except Exception as e: + if raise_exception and not is_test: + raise e + else: + LOG.warning(f"Dataset name error: {e}") + + def get_main_config(self) -> Any: + """Get the main configuration of the dataset. + + Returns + ------- + Any + The main configuration. + """ + import zarr + + z = zarr.open(self.path, mode="r") + config = loader_config(z.attrs.get("_create_yaml_config")) + + if "env" in config: + for k, v in config["env"].items(): + LOG.info(f"Setting env variable {k}={v}") + os.environ[k] = str(v) + + return config + + +class WritableDataset(Dataset): + """A class to represent a writable dataset.""" + + def __init__(self, path: str): + """Initialize a WritableDataset instance. + + Parameters + ---------- + path : str + The path to the dataset. 
+ """ + super().__init__(path) + self.path = path + + import zarr + + self.z = zarr.open(self.path, mode="r+") + + @cached_property + def data_array(self) -> Any: + """Get the data array of the dataset.""" + import zarr + + return zarr.open(self.path, mode="r+")["data"] + + +class NewDataset(Dataset): + """A class to represent a new dataset.""" + + def __init__(self, path: str, overwrite: bool = False): + """Initialize a NewDataset instance. + + Parameters + ---------- + path : str + The path to the dataset. + overwrite : bool, optional + Whether to overwrite the existing dataset. + """ + super().__init__(path) + self.path = path + + import zarr + + self.z = zarr.open(self.path, mode="w") + self.z.create_group("_build") + + +class FieldTask: + """A base class for dataset creation tasks.""" + + dataset_class = WritableDataset + + def __init__(self, path: str, cache: str | None = None): + """Initialize an Actor instance. + + Parameters + ---------- + path : str + The path to the dataset. + cache : Optional[str], optional + The cache directory. + """ + # Catch all floating point errors, including overflow, sqrt(<0), etc + np.seterr(all="raise", under="warn") + + self.path = path + self.cache = cache + self.dataset = self.dataset_class(self.path) + + def run(self) -> None: + """Run the actor.""" + # to be implemented in the sub-classes + raise NotImplementedError() + + def update_metadata(self, **kwargs: Any) -> None: + """Update the metadata of the dataset. + + Parameters + ---------- + **kwargs + The metadata to update. + """ + self.dataset.update_metadata(**kwargs) + + def _cache_context(self) -> Any: + """Get the cache context. + + Returns + ------- + Any + The cache context. + """ + from anemoi.datasets.create.utils import cache_context + + return cache_context(self.cache) + + def check_unkown_kwargs(self, kwargs: dict) -> None: + """Check for unknown keyword arguments. + + Parameters + ---------- + kwargs : dict + The keyword arguments. + """ + # remove this latter + LOG.warning(f"💬 Unknown kwargs for {self.__class__.__name__}: {kwargs}") + + def read_dataset_metadata(self, path: str) -> None: + """Read the metadata of the dataset. + + Parameters + ---------- + path : str + The path to the dataset. + """ + ds = open_dataset(path) + self.dataset_shape = ds.shape + self.variables_names = ds.variables + assert len(self.variables_names) == ds.shape[1], self.dataset_shape + self.dates = ds.dates + + self.missing_dates = sorted(list([self.dates[i] for i in ds.missing])) + + def check_missing_dates(expected: list[np.datetime64]) -> None: + """Check if the missing dates in the dataset match the expected dates. + + Parameters + ---------- + expected : list of np.datetime64 + The expected missing dates. + + Raises + ------ + ValueError + If the missing dates in the dataset do not match the expected dates. 
+ """ + import zarr + + z = zarr.open(path, "r") + missing_dates = z.attrs.get("missing_dates", []) + missing_dates = sorted([np.datetime64(d) for d in missing_dates]) + if missing_dates != expected: + LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.") + LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") + LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") + raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") + + check_missing_dates(self.missing_dates) + + +class HasRegistryMixin: + """A mixin class to provide registry functionality.""" + + @cached_property + def registry(self) -> Any: + """Get the registry.""" + from anemoi.datasets.create.zarr import ZarrBuiltRegistry + + return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) + + +class HasStatisticTempMixin: + """A mixin class to provide temporary statistics functionality.""" + + @cached_property + def tmp_statistics(self) -> TmpStatistics: + """Get the temporary statistics.""" + directory = self.statistics_temp_dir or os.path.join(self.path + ".storage_for_statistics.tmp") + return TmpStatistics(directory) + + +class HasElementForDataMixin: + """A mixin class to provide element creation functionality for data.""" + + def create_elements(self, config: Any) -> None: + """Create elements for the dataset. + + Parameters + ---------- + config : Any + The configuration. + """ + assert self.registry + assert self.tmp_statistics + + LOG.info(dict(config.dates)) + + self.groups = Groups(**config.dates) + LOG.info(self.groups) + + self.output = build_output(config.output, parent=self) + + self.context = FieldContext( + order_by=self.output.order_by, + flatten_grid=self.output.flatten_grid, + remapping=build_remapping(self.output.remapping), + use_grib_paramid=config.build.use_grib_paramid, + ) + + self.input = InputBuilder( + config.input, + data_sources=config.get("data_sources", {}), + ) + LOG.debug("✅ INPUT_BUILDER") + LOG.debug(self.input) + + +def validate_config(config: Any) -> None: + + import json + + import jsonschema + + def _tidy(d): + if isinstance(d, dict): + return {k: _tidy(v) for k, v in d.items()} + + if isinstance(d, list): + return [_tidy(v) for v in d if v is not None] + + # jsonschema does not support datetime.date + if isinstance(d, datetime.datetime): + return d.isoformat() + + if isinstance(d, datetime.date): + return d.isoformat() + + return d + + # https://json-schema.org + + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "schemas", + "recipe.json", + ) + ) as f: + schema = json.load(f) + + try: + jsonschema.validate(instance=_tidy(config), schema=schema) + except jsonschema.exceptions.ValidationError as e: + LOG.error("❌ Config validation failed (jsonschema):") + LOG.error(e.message) + raise + + +def _config_to_python(config: Any) -> Any: + + from anemoi.datasets.create.create.python import PythonScript + + raw_config = config + + config = loader_config(config) + + input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) + + code = PythonScript() + x = input.python_code(code) + code = code.source_code(x, raw_config) + + try: + import black + + return black.format_str(code, mode=black.Mode()) + # except ImportError: + except Exception: + LOG.warning("Black not installed, skipping formatting") + return code + + +class TaskCreator: + """A class to create and run dataset creation tasks.""" + + def init(self, 
*args: Any, **kwargs: Any): + from .init import Init + + return Init(*args, **kwargs) + + def load(self, *args: Any, **kwargs: Any): + from .load import Load + + return Load(*args, **kwargs) + + def size(self, *args: Any, **kwargs: Any): + from .size import Size + + return Size(*args, **kwargs) + + def patch(self, *args: Any, **kwargs: Any): + from .patch import Patch + + return Patch(*args, **kwargs) + + def statistics(self, *args: Any, **kwargs: Any): + from .statistics import Statistics + + return Statistics(*args, **kwargs) + + def finalise(self, *args: Any, **kwargs: Any): + from .cleanup import Cleanup + from .size import Size + from .statistics import Statistics + + return chain([Statistics, Size, Cleanup])(*args, **kwargs) + + def cleanup(self, *args: Any, **kwargs: Any): + from .cleanup import Cleanup + + return Cleanup(*args, **kwargs) + + def verify(self, *args: Any, **kwargs: Any): + from .verify import Verify + + return Verify(*args, **kwargs) + + def init_additions(self, *args: Any, **kwargs: Any): + from .additions import InitAdditions + + return InitAdditions(*args, **kwargs) + + def load_additions(self, *args: Any, **kwargs: Any): + from .additions import LoadAdditions + + return LoadAdditions(*args, **kwargs) + + def finalise_additions(self, *args: Any, **kwargs: Any): + from .additions import FinaliseAdditions + from .size import Size + + return chain([FinaliseAdditions, Size])(*args, **kwargs) + + def additions(self, *args: Any, **kwargs: Any): + from .additions import FinaliseAdditions + from .additions import InitAdditions + from .additions import LoadAdditions + from .cleanup import Cleanup + from .size import Size + + return chain([InitAdditions, LoadAdditions, FinaliseAdditions, Size, Cleanup])(*args, **kwargs) diff --git a/src/anemoi/datasets/create/gridded/verify.py b/src/anemoi/datasets/create/gridded/verify.py new file mode 100644 index 000000000..3b29578ce --- /dev/null +++ b/src/anemoi/datasets/create/gridded/verify.py @@ -0,0 +1,34 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any + +from ..gridded.tasks import FieldTask + +LOG = logging.getLogger(__name__) + + +class Verify(FieldTask): + """A class to verify the integrity of a dataset.""" + + def __init__(self, path: str, **kwargs: Any): + """Initialize a Verify instance. + + Parameters + ---------- + path : str + The path to the dataset. 
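+        **kwargs : Any
+            Additional keyword arguments (currently unused).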
+ """ + super().__init__(path) + + def run(self) -> None: + """Run the verification.""" + LOG.info(f"Verifying dataset at {self.path}") + LOG.info(str(self.dataset.anemoi_dataset)) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index e4e312fa8..f56bbd067 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -9,17 +9,13 @@ from copy import deepcopy from functools import cached_property -from typing import TYPE_CHECKING from typing import Any -if TYPE_CHECKING: - from anemoi.datasets.create.input.action import Recipe - class InputBuilder: """Builder class for creating input data from configuration and data sources.""" - def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> None: + def __init__(self, config: dict, data_sources: dict | list) -> None: """Initialize the InputBuilder. Parameters @@ -31,12 +27,11 @@ def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> No **kwargs : Any Additional keyword arguments. """ - self.kwargs = kwargs self.config = deepcopy(config) self.data_sources = deepcopy(dict(data_sources=data_sources)) @cached_property - def action(self) -> "Recipe": + def action(self) -> Any: """Returns the action object based on the configuration.""" from anemoi.datasets.create.input.action import Recipe from anemoi.datasets.create.input.action import action_factory @@ -46,11 +41,13 @@ def action(self) -> "Recipe": return Recipe(input, sources) - def select(self, argument) -> Any: + def select(self, context, argument) -> Any: """Select data based on the group of dates. Parameters ---------- + context : Any + The context for the data selection. argument : GroupOfDates Group of dates to select data for. @@ -59,10 +56,15 @@ def select(self, argument) -> Any: Any Selected data. """ - from anemoi.datasets.create.gridded.context import GriddedContext + # TODO: move me elsewhere + + return context.create_result( + argument, + self.action(context, argument), + ) - context = GriddedContext(argument, **self.kwargs) - return context.create_result(self.action(context, argument)) + def python_code(self, code): + return self.action.python_code(code) def build_input(config: dict, data_sources: dict | list, **kwargs: Any) -> InputBuilder: diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 7808ae717..831456435 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -8,20 +8,15 @@ # nor does it submit to any jurisdiction. import logging +from abc import ABC +from abc import abstractmethod from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) -class Action: - """An "Action" represents a single operation described in the yaml configuration, e.g. a source, a filter, - pipe, join, etc. - - See :ref:`operations` for more details. - - """ - +class Action(ABC): def __init__(self, config, *path): self.config = config self.path = path @@ -30,32 +25,19 @@ def __init__(self, config, *path): "data_sources", ), f"{self.__class__.__name__}: path must start with 'input' or 'data_sources': {path}" + @abstractmethod + def __call__(self, context, argument): + pass -class Concat(Action): - """The Concat contruct is used to concat different actions that are responsible - for delivery fields for different dates. - - See :ref:`building-concat` for more details. - - .. 
block-code:: yaml - - input: - concat: - - dates: - start: 2023-01-01 - end: 2023-01-31 - frequency: 1d - action: # some action - ... + @abstractmethod + def python_code(self, code): + pass - - dates: - start: 2023-02-01 - end: 2023-02-28 - frequency: 1d - action: # some action + def __repr__(self): + return f"{self.__class__.__name__}({'.'.join(str(x) for x in self.path)}, {self.config})" - """ +class Concat(Action): def __init__(self, config, *path): super().__init__(config, *path, "concat") @@ -65,6 +47,7 @@ def __init__(self, config, *path): for i, item in enumerate(config): + assert "dates" in item, f"Value must contain the key 'dates' {item}" dates = item["dates"] filtering_dates = DatesProvider.from_config(**dates) action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i)) @@ -85,28 +68,17 @@ def __call__(self, context, argument): return context.register(results, self.path) + def python_code(self, code): + return code.concat( + {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} + ) -class Join(Action): - """Implement the join operation to combine results from multiple actions. - - See :ref:`building-join` for more details. - - .. block-code:: yaml - - input: - join: - - grib: - ... - - - netcdf: # some other action - ... - - """ +class Join(Action): def __init__(self, config, *path): super().__init__(config, *path, "join") - assert isinstance(config, list), f"Value of Join Action must be a list, got: {config}" + assert isinstance(config, list), f"Value must be a list {config}" self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] @@ -121,27 +93,13 @@ def __call__(self, context, argument): return context.register(results, self.path) + def python_code(self, code) -> None: + return code.sum(a.python_code(code) for a in self.actions) -class Pipe(Action): - """Implement the pipe operation to chain results from a - source through multiple filters. - - See :ref:`building-pipe` for more details. - - .. block-code:: yaml - - input: - pipe: - - grib: - ... - - - rename: - ... - - """ +class Pipe(Action): def __init__(self, config, *path): - assert isinstance(config, list), f"Value of Pipe Action must be a list, got {config}" + assert isinstance(config, list), f"Value must be a list {config}" super().__init__(config, *path, "pipe") self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] @@ -159,10 +117,11 @@ def __call__(self, context, argument): return context.register(result, self.path) + def python_code(self, code) -> None: + return code.pipe(a.python_code(code) for a in self.actions) -class Function(Action): - """Base class for sources and filters.""" +class Function(Action): def __init__(self, config, *path): super().__init__(config, *path, self.name) @@ -176,45 +135,63 @@ def __call__(self, context, argument): return context.register(self.call_object(context, source, argument), self.path) + def python_code(self, code) -> str: + # For now... 
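+        # A nested "source" entry is itself an action: it is converted to code
+        # first, so the generated call receives the already-built source expression.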
+ if "source" in self.config: + source = action_factory(self.config["source"], *self.path, "source") + self.config["source"] = source.python_code(code) + return code.call(self.name, self.config) -class DatasetSourceMixin: - """Mixin class for sources defined in anemoi-datasets""" +class DatasetSourceMixin: def create_object(self, context, config): from anemoi.datasets.create.sources import create_source as create_datasets_source return create_datasets_source(context, config) def call_object(self, context, source, argument): - return source.execute(context.source_argument(argument)) + result = source.execute(context.source_argument(argument)) + return context.origin(result, self, argument) + def origin(self): + from anemoi.datasets.create.input.origin import Source + + return Source(self.path[-1], self.config) -class TransformSourceMixin: - """Mixin class for sources defined in anemoi-transform""" +class TransformSourceMixin: def create_object(self, context, config): from anemoi.transform.sources import create_source as create_transform_source return create_transform_source(context, config) + def combine_origins(self, current, previous): + assert previous is None, f"Cannot combine origins, previous already exists: {previous}" + return current -class TransformFilterMixin: - """Mixin class for filters defined in anemoi-transform""" + def origin(self): + from anemoi.datasets.create.input.origin import Source + return Source(self.path[-1], self.config) + + +class TransformFilterMixin: def create_object(self, context, config): from anemoi.transform.filters import create_filter as create_transform_filter return create_transform_filter(context, config) def call_object(self, context, filter, argument): - return filter.forward(context.filter_argument(argument)) + result = filter.forward(context.filter_argument(argument)) + return context.origin(result, self, argument) + def origin(self): + from anemoi.datasets.create.input.origin import Filter -class FilterFunction(Function): - """Action to call a filter on the argument (e.g. rename, regrid, etc.).""" + return Filter(self.path[-1], self.config) - def __call__(self, context, argument): - return self.call(context, argument, context.filter_argument) + def combine_origins(self, current, previous): + return {"_apply": current, **(previous or {})} def _make_name(name, what): @@ -240,8 +217,6 @@ def new_filter(name, mixin): class DataSources(Action): - """Action to call a source (e.g. mars, netcdf, grib, etc.).""" - def __init__(self, config, *path): super().__init__(config, *path) assert isinstance(config, (dict, list)), f"Invalid config type: {type(config)}" @@ -250,18 +225,25 @@ def __init__(self, config, *path): else: self.sources = {i: action_factory(v, *path, str(i)) for i, v in enumerate(config)} + def python_code(self, code): + return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) + def __call__(self, context, argument): for name, source in self.sources.items(): context.register(source(context, argument), self.path + (name,)) class Recipe(Action): - """Action that represent a recipe (i.e. 
a sequence of data_sources and input).""" - def __init__(self, input, data_sources): self.input = input self.data_sources = data_sources + def python_code(self, code): + return code.recipe( + self.input.python_code(code), + self.data_sources.python_code(code), + ) + def __call__(self, context, argument): # Load data_sources self.data_sources(context, argument) @@ -276,6 +258,7 @@ def __call__(self, context, argument): } LEN_KLASS = len(KLASS) +TYPES = {} def make(key, config, *path): @@ -292,17 +275,28 @@ def make(key, config, *path): for name in dataset_source_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) + TYPES[name.replace("_", "-")] = "source" for name in transform_source_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) + TYPES[name.replace("_", "-")] = "source" # Register filters for name in transform_filter_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) + TYPES[name.replace("_", "-")] = "filter" + + key = key.replace("_", "-") + + if key not in KLASS: + LOG.error(f"Unknown action '{key}' in {'.'.join(x for x in path)}") + for available in sorted(KLASS): + LOG.error(f" Available: {available} (type={TYPES.get(available, 'built-in')})") + raise ValueError(f"Unknown action '{key}' in {'.'.join(x for x in path)}") - return KLASS[key.replace("_", "-")](config, *path) + return KLASS[key](config, *path) def action_factory(data, *path): diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py index 89df7a727..28c797dd5 100644 --- a/src/anemoi/datasets/create/input/context.py +++ b/src/anemoi/datasets/create/input/context.py @@ -18,10 +18,9 @@ class Context(ABC): """Context for building input data.""" - def __init__(self, /, argument: Any) -> None: + def __init__(self) -> None: self.results = {} self.cache = {} - self.argument = argument def trace(self, emoji, *message) -> None: @@ -34,7 +33,7 @@ def register(self, data: Any, path: list[str]) -> Any: assert path[0] in ("input", "data_sources"), path - LOG.info(f"Registering data at path: {path}") + LOG.info(f"Registering data at path: {'.'.join(str(x) for x in path)}") self.results[tuple(path)] = data return data @@ -47,9 +46,9 @@ def resolve(self, config): if path in self.results: config[key] = self.results[path] else: - LOG.warning(f"Path not found {path}") + print(f"Path not found {path}") for p in sorted(self.results): - LOG.info(f" Available paths: {p}") + print(f" Available paths: {p}") raise KeyError(f"Path {path} not found in results: {self.results.keys()}") return config diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 2f776dff9..7a706c8ef 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -13,10 +13,10 @@ from earthkit.data import FieldList -from anemoi.datasets.create.gridded.result import Result from anemoi.datasets.create.input.action import Action from anemoi.datasets.create.input.action import action_factory from anemoi.datasets.create.input.misc import _tidy +from anemoi.datasets.create.input.result.field import Result from anemoi.datasets.dates.groups import GroupOfDates LOG = logging.getLogger(__name__) @@ -84,6 +84,11 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) + def python_code(self, 
code) -> str: + for n, s in zip(self.names, self.sources): + code.source(n, s.python_code(code)) + return code + class DataSourcesResult(Result): """Class to represent the result of data sources actions in the dataset creation process.""" diff --git a/src/anemoi/datasets/create/input/misc.py b/src/anemoi/datasets/create/input/misc.py index a4791642c..bd31ca4f0 100644 --- a/src/anemoi/datasets/create/input/misc.py +++ b/src/anemoi/datasets/create/input/misc.py @@ -8,6 +8,9 @@ # nor does it submit to any jurisdiction. import logging +from collections.abc import Callable +from functools import wraps +from typing import Any from earthkit.data import FieldList from earthkit.data.core.fieldlist import MultiFieldList @@ -15,6 +18,74 @@ LOG = logging.getLogger(__name__) +def parse_function_name(name: str) -> tuple[str, int | None]: + """Parses a function name to extract the base name and an optional time delta. + + Parameters + ---------- + name : str + The function name to parse. + + Returns + ------- + tuple of (str, int or None) + The base name and an optional time delta. + """ + if name.endswith("h") and name[:-1].isdigit(): + + if "-" in name: + name, delta = name.split("-") + sign = -1 + + elif "+" in name: + name, delta = name.split("+") + sign = 1 + + else: + return name, None + + assert delta[-1] == "h", (name, delta) + delta = sign * int(delta[:-1]) + return name, delta + + return name, None + + +def assert_fieldlist(method: Callable[..., Any]) -> Callable[..., Any]: + """Decorator to assert that the result of a method is an instance of FieldList. + + Parameters + ---------- + method : Callable[..., Any] + The method to decorate. + + Returns + ------- + Callable[..., Any] + The decorated method. + """ + + @wraps(method) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + + result = method(self, *args, **kwargs) + assert isinstance(result, FieldList), type(result) + return result + + return wrapper + + +def assert_is_fieldlist(obj: object) -> None: + """Asserts that the given object is an instance of FieldList. + + Parameters + ---------- + obj : object + The object to check. + """ + assert isinstance(obj, FieldList), type(obj) + + def _flatten(ds: MultiFieldList | FieldList) -> list: """Flattens a MultiFieldList or FieldList into a list of FieldList objects. diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py new file mode 100644 index 000000000..9f5173afc --- /dev/null +++ b/src/anemoi/datasets/create/input/origin.py @@ -0,0 +1,159 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
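+
+# Note (added for clarity): an Origin records where a field comes from -- a
+# Source, a Filter, or a combination of origins (Pipe, Join). Origins are
+# attached to fields via the "anemoi_origin" metadata and serialised with
+# as_dict() so that provenance can be stored in the dataset. For example,
+# renaming the output of a mars source yields something like (illustrative):
+#   {"type": "pipe", "steps": [{"type": "source", "name": "mars", ...},
+#                              {"type": "filter", "name": "rename", ...}]}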
+ +import logging +from abc import ABC + +LOG = logging.getLogger(__name__) + + +class Origin(ABC): + + def __init__(self, when="dataset-create"): + self.when = when + + def __eq__(self, other): + if not isinstance(other, Origin): + return False + return self is other + + def __hash__(self): + return id(self) + + +def _un_dotdict(x): + if isinstance(x, dict): + return {k: _un_dotdict(v) for k, v in x.items()} + + if isinstance(x, (list, tuple, set)): + return [_un_dotdict(a) for a in x] + + return x + + +class Pipe(Origin): + def __init__(self, s1, s2, when="dataset-create"): + super().__init__(when) + self.steps = [s1, s2] + + assert s1 is not None, (s1, s2) + assert s2 is not None, (s1, s2) + + if isinstance(s1, Pipe): + assert not isinstance(s2, Pipe), (s1, s2) + self.steps = s1.steps + [s2] + + def combine(self, previous, action, action_arguments): + assert False, (self, previous) + + def as_dict(self): + return { + "type": "pipe", + "steps": [s.as_dict() for s in self.steps], + "when": self.when, + } + + def __repr__(self): + return " | ".join(repr(s) for s in self.steps) + + +class Join(Origin): + def __init__(self, origins, when="dataset-create"): + assert isinstance(origins, (list, tuple, set)), origins + super().__init__(when) + self.steps = list(origins) + + assert all(o is not None for o in origins), origins + + def combine(self, previous, action, action_arguments): + assert False, (self, previous) + + def as_dict(self): + return { + "type": "join", + "steps": [s.as_dict() for s in self.steps], + "when": self.when, + } + + def __repr__(self): + return " & ".join(repr(s) for s in self.steps) + + +class Source(Origin): + def __init__(self, name, config, when="dataset-create"): + super().__init__(when) + assert isinstance(config, dict), f"Config must be a dictionary {config}" + self.name = name + self.config = _un_dotdict(config) + + def combine(self, previous, action, action_arguments): + assert previous is None, f"Cannot combine origins, previous already exists: {previous}" + return self + + def as_dict(self): + return { + "type": "source", + "name": self.name, + "config": self.config, + "when": self.when, + } + + def __repr__(self): + return f"{self.name}({id(self)})" + + +class Filter(Origin): + def __init__(self, name, config, when="dataset-create"): + super().__init__(when) + assert isinstance(config, dict), f"Config must be a dictionary {config}" + self.name = name + self.config = _un_dotdict(config) + self._cache = {} + + def combine(self, previous, action, action_arguments): + + if previous is None: + # This can happen if the filter does not tag its output with an origin + # (e.g. a user plugin). In that case we try to get the origin from the action arguments + key = (id(action), id(action_arguments)) + if key not in self._cache: + + LOG.warning(f"No previous origin to combine with: {self}. 
Action: {action}") + LOG.warning(f"Connecting to action arguments {action_arguments}") + origins = set() + for k in action_arguments: + o = k.metadata("anemoi_origin", default=None) + if o is None: + raise ValueError( + f"Cannot combine origins, previous is None and action_arguments {action_arguments} has no origin" + ) + origins.add(o) + if len(origins) == 1: + self._cache[key] = origins.pop() + else: + self._cache[key] = Join(origins) + previous = self._cache[key] + + if previous in self._cache: + # We use a cache to avoid recomputing the same combination + return self._cache[previous] + + self._cache[previous] = Pipe(previous, self) + return self._cache[previous] + + def as_dict(self): + return { + "type": "filter", + "name": self.name, + "config": self.config, + "when": self.when, + } + + def __repr__(self): + return f"{self.name}({id(self)})" diff --git a/src/anemoi/datasets/create/patch.py b/src/anemoi/datasets/create/patch.py new file mode 100755 index 000000000..5cb08ec82 --- /dev/null +++ b/src/anemoi/datasets/create/patch.py @@ -0,0 +1,188 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import json +import logging +import os + +import zarr + +LOG = logging.getLogger(__name__) + + +def fix_order_by(order_by: dict | list) -> list[dict]: + """Fix the order_by attribute to ensure it is a list of dictionaries. + + Parameters + ---------- + order_by : dict or list + The order_by attribute to fix. + + Returns + ------- + list[dict] + The fixed order_by attribute. + """ + if isinstance(order_by, list): + return order_by + + assert isinstance(order_by, dict), order_by + assert len(order_by) <= 3, order_by + lst = [] + lst.append({"valid_datetime": order_by["valid_datetime"]}) + lst.append({"param_level": order_by["param_level"]}) + lst.append({"number": order_by["number"]}) + return lst + + +def fix_history(history: list[dict]) -> list[dict]: + """Fix the history attribute by removing specific actions. + + Parameters + ---------- + history : list[dict] + The history attribute to fix. + + Returns + ------- + list[dict] + The fixed history attribute. + """ + new = history + new = [d for d in new if d.get("action") != "loading_data_start"] + new = [d for d in new if d.get("action") != "loading_data_end"] + return new + + +def fix_provenance(provenance: dict) -> dict: + """Fix the provenance attribute by adding missing fields and removing unnecessary ones. + + Parameters + ---------- + provenance : dict + The provenance attribute to fix. + + Returns + ------- + dict + The fixed provenance attribute. 
+ """ + if "python" not in provenance: + provenance["python"] = provenance["platform"]["python_version"] + + for q in ( + "args", + "config_paths", + "executable", + "gpus", + "platform", + "python_path", + "assets", + ): + if q in provenance: + del provenance[q] + + for k, v in list(provenance["module_versions"].items()): + if v.startswith("<"): + del provenance["module_versions"][k] + if v.startswith("/"): + provenance["module_versions"][k] = os.path.join("...", os.path.basename(v)) + + for k, v in list(provenance["git_versions"].items()): + LOG.debug(k, v) + modified_files = v["git"].get("modified_files", []) + untracked_files = v["git"].get("untracked_files", []) + if not isinstance(modified_files, int): + modified_files = len(modified_files) + if not isinstance(untracked_files, int): + untracked_files = len(untracked_files) + provenance["git_versions"][k] = dict( + git={ + "sha1": v["git"]["sha1"], + "modified_files": modified_files, + "untracked_files": untracked_files, + } + ) + + LOG.debug(json.dumps(provenance, indent=2)) + # assert False + return provenance + + +def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None: + """Apply a patch to the dataset at the given path. + + Parameters + ---------- + path : str + The path to the dataset. + verbose : bool, optional + Whether to log detailed information. Defaults to True. + dry_run : bool, optional + If True, do not actually apply the patch. Defaults to False. + """ + LOG.debug("====================") + LOG.debug(f"Patching {path}") + LOG.debug("====================") + + try: + attrs = zarr.open(path, mode="r").attrs.asdict() + except zarr.errors.PathNotFoundError as e: + LOG.error(f"Failed to open {path}") + LOG.error(e) + exit(0) + + FIXES = { + "history": fix_history, + "provenance_load": fix_provenance, + "provenance_statistics": fix_provenance, + "order_by": fix_order_by, + } + REMOVE = ["_create_yaml_config"] + + before = json.dumps(attrs, sort_keys=True) + + fixed_attrs = {} + for k, v in attrs.items(): + v = attrs[k] + if k in REMOVE: + LOG.info(f"✅ Remove {k}") + continue + + if k not in FIXES: + assert not k.startswith("provenance"), f"[{k}]" + LOG.debug(f"✅ Don't fix {k}") + fixed_attrs[k] = v + continue + + new_v = FIXES[k](v) + if json.dumps(new_v, sort_keys=True) != json.dumps(v, sort_keys=True): + LOG.info(f"✅ Fix {k}") + if verbose: + LOG.info(f" Before : {k}= {v}") + LOG.info(f" After : {k}= {new_v}") + else: + LOG.debug(f"✅ Unchanged {k}") + fixed_attrs[k] = new_v + + if dry_run: + return + z = zarr.open(path, mode="r+") + + for k in list(z.attrs.keys()): + if k not in fixed_attrs: + del z.attrs[k] + for k, v in fixed_attrs.items(): + z.attrs[k] = v + + after = json.dumps(z.attrs.asdict(), sort_keys=True) + if before != after: + LOG.info("Dataset changed by patch") + + assert json.dumps(z.attrs.asdict(), sort_keys=True) == json.dumps(fixed_attrs, sort_keys=True) diff --git a/src/anemoi/datasets/create/persistent.py b/src/anemoi/datasets/create/persistent.py new file mode 100644 index 000000000..e52938507 --- /dev/null +++ b/src/anemoi/datasets/create/persistent.py @@ -0,0 +1,269 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
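+
+# Note (added for clarity): PersistentDict stores each (key, value) pair as a
+# separate pickle file in a directory, so several workers can write
+# concurrently without locking (files are written under a temporary name and
+# then moved into place). BufferedPersistentDict batches additions to limit
+# the number of files created.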
+ + +import glob +import hashlib +import json +import logging +import os +import pickle +import shutil +import socket +from collections.abc import Iterator +from typing import Any + +import numpy as np +from anemoi.utils.provenance import gather_provenance_info + +LOG = logging.getLogger(__name__) + + +class PersistentDict: + """A dictionary-like object that persists its contents to disk using pickle files. + + Attributes + ---------- + version : int + The version of the PersistentDict. + dirname : str + The directory where the data is stored. + name : str + The name of the directory. + ext : str + The extension of the directory. + """ + + version = 3 + + # Used in parrallel, during data loading, + # to write data in pickle files. + def __init__(self, directory: str, create: bool = True): + """Initialize the PersistentDict. + + Parameters + ---------- + directory : str + The directory where the data will be stored. + create : bool, optional + Whether to create the directory if it doesn't exist. + """ + self.dirname = directory + self.name, self.ext = os.path.splitext(os.path.basename(self.dirname)) + if create: + self.create() + + def create(self) -> None: + """Create the directory if it doesn't exist.""" + os.makedirs(self.dirname, exist_ok=True) + + def delete(self) -> None: + """Delete the directory and its contents.""" + try: + shutil.rmtree(self.dirname) + except FileNotFoundError: + pass + + def __str__(self) -> str: + """Return a string representation of the PersistentDict.""" + return f"{self.__class__.__name__}({self.dirname})" + + def items(self) -> Iterator[Any]: + """Yield items stored in the directory. + + Yields + ------ + Iterator[Any] + An iterator over the items. + """ + # use glob to read all pickles + files = glob.glob(self.dirname + "/*.pickle") + LOG.debug(f"Reading {self.name} data, found {len(files)} files in {self.dirname}") + assert len(files) > 0, f"No files found in {self.dirname}" + for f in files: + with open(f, "rb") as f: + yield pickle.load(f) + + def add_provenance(self, **kwargs: Any) -> None: + """Add provenance information to the directory. + + Parameters + ---------- + **kwargs : Any + Additional provenance information. + """ + path = os.path.join(self.dirname, "provenance.json") + if os.path.exists(path): + return + out = dict(provenance=gather_provenance_info(), **kwargs) + with open(path, "w") as f: + json.dump(out, f) + + def add(self, elt: Any, *, key: Any) -> None: + """Add an element to the PersistentDict. + + Parameters + ---------- + elt : Any + The element to add. + key : Any + The key associated with the element. + """ + self[key] = elt + + def __setitem__(self, key: Any, elt: Any) -> None: + """Set an item in the PersistentDict. + + Parameters + ---------- + key : Any + The key associated with the element. + elt : Any + The element to set. + """ + h = hashlib.sha256(str(key).encode("utf-8")).hexdigest() + path = os.path.join(self.dirname, f"{h}.pickle") + + if os.path.exists(path): + LOG.warning(f"{path} already exists") + + tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}" + with open(tmp_path, "wb") as f: + pickle.dump((key, elt), f) + shutil.move(tmp_path, path) + + LOG.debug(f"Written {self.name} data for len {key} in {path}") + + def flush(self) -> None: + """Flush the PersistentDict (no-op).""" + pass + + +class BufferedPersistentDict(PersistentDict): + """A buffered version of PersistentDict that stores elements in memory before persisting them to disk. 
+ + Attributes + ---------- + buffer_size : int + The size of the buffer. + elements : list + The list of elements in the buffer. + keys : list + The list of keys in the buffer. + storage : PersistentDict + The underlying PersistentDict used for storage. + """ + + def __init__(self, buffer_size: int = 1000, **kwargs: Any): + """Initialize the BufferedPersistentDict. + + Parameters + ---------- + buffer_size : int, optional + The size of the buffer. + **kwargs : Any + Additional arguments for PersistentDict. + """ + self.buffer_size = buffer_size + self.elements = [] + self.keys = [] + self.storage = PersistentDict(**kwargs) + + def add(self, elt: Any, *, key: Any) -> None: + """Add an element to the BufferedPersistentDict. + + Parameters + ---------- + elt : Any + The element to add. + key : Any + The key associated with the element. + """ + self.elements.append(elt) + self.keys.append(key) + if len(self.keys) > self.buffer_size: + self.flush() + + def flush(self) -> None: + """Flush the buffer and store the elements in PersistentDict.""" + k = sorted(self.keys) + self.storage.add(self.elements, key=k) + self.elements = [] + self.keys = [] + + def items(self) -> Iterator[tuple[Any, Any]]: + """Yield items stored in the BufferedPersistentDict. + + Yields + ------ + Iterator[Tuple[Any, Any]] + An iterator over the items. + """ + for keys, elements in self.storage.items(): + yield from zip(keys, elements) + + def delete(self) -> None: + """Delete the storage directory and its contents.""" + self.storage.delete() + + def create(self) -> None: + """Create the storage directory if it doesn't exist.""" + self.storage.create() + + +def build_storage(directory: str, create: bool = True) -> BufferedPersistentDict: + """Build a BufferedPersistentDict storage. + + Parameters + ---------- + directory : str + The directory where the data will be stored. + create : bool, optional + Whether to create the directory if it doesn't exist. + + Returns + ------- + BufferedPersistentDict + The created BufferedPersistentDict. + """ + return BufferedPersistentDict(directory=directory, create=create) + + +if __name__ == "__main__": + N = 3 + P = 2 + directory = "h" + p = PersistentDict(directory=directory) + print(p) + assert os.path.exists(directory) + import numpy as np + + arrs = [np.random.randint(1, 101, size=(P,)) for _ in range(N)] + dates = [np.array([np.datetime64(f"2021-01-0{_+1}") + np.timedelta64(i, "h") for i in range(P)]) for _ in range(N)] + + print() + print("Writing the data") + for i in range(N): + _arr = arrs[i] + _dates = dates[i] + print(f"Writing : {i=}, {_arr=} {_dates=}") + p[_dates] = (i, _arr) + + print() + print("Reading the data back") + + p = PersistentDict(directory="h") + for _dates, (i, _arr) in p.items(): + print(f"{i=}, {_arr=}, {_dates=}") + + assert np.allclose(_arr, arrs[i]) + + assert len(_dates) == len(dates[i]) + for a, b in zip(_dates, dates[i]): + assert a == b diff --git a/src/anemoi/datasets/create/size.py b/src/anemoi/datasets/create/size.py new file mode 100644 index 000000000..4cffd66d7 --- /dev/null +++ b/src/anemoi/datasets/create/size.py @@ -0,0 +1,47 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
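+
+# Note (added for clarity): helper used by the Size task to record the
+# on-disk footprint of a dataset (total size and number of files) in its
+# metadata.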
+ + +import logging +import os + +import tqdm +from anemoi.utils.humanize import bytes_to_human + +LOG = logging.getLogger(__name__) + + +def compute_directory_sizes(path: str) -> dict[str, int] | None: + """Computes the total size and number of files in a directory. + + Parameters + ---------- + path : str + The path to the directory. + + Returns + ------- + dict of str to int or None + A dictionary with the total size and number of files, or None if the path is not a directory. + """ + if not os.path.isdir(path): + return None + + size, n = 0, 0 + bar = tqdm.tqdm(iterable=os.walk(path), desc=f"Computing size of {path}") + for dirpath, _, filenames in bar: + for filename in filenames: + file_path = os.path.join(dirpath, filename) + size += os.path.getsize(file_path) + n += 1 + + LOG.info(f"Total size: {bytes_to_human(size)}") + LOG.info(f"Total number of files: {n}") + + return dict(total_size=size, total_number_of_files=n) diff --git a/src/anemoi/datasets/create/source.py b/src/anemoi/datasets/create/source.py index 3bbb52b6a..8c9c3044d 100644 --- a/src/anemoi/datasets/create/source.py +++ b/src/anemoi/datasets/create/source.py @@ -12,7 +12,7 @@ import earthkit.data as ekd -from anemoi.datasets.create.gridded.typing import DateList +from anemoi.datasets.create.typing import DateList class Source(ABC): diff --git a/src/anemoi/datasets/create/sources/accumulations.py b/src/anemoi/datasets/create/sources/accumulations.py index ce4ff6266..40b8749f6 100644 --- a/src/anemoi/datasets/create/sources/accumulations.py +++ b/src/anemoi/datasets/create/sources/accumulations.py @@ -20,13 +20,11 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray -from anemoi.datasets.create.sources import source_registry - -from .legacy import LegacySource -from .mars import mars +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.mars import mars +from anemoi.datasets.create.utils import to_datetime_list LOG = logging.getLogger(__name__) -MISSING_VALUE = 1e-38 def _member(field: Any) -> int: @@ -169,7 +167,6 @@ def write(self, template: Any) -> None: # are used to store the end step edition = template.metadata("edition") - assert np.all(self.values != MISSING_VALUE) if edition == 1 and self.endStep > 254: self.out.write( @@ -178,7 +175,6 @@ def write(self, template: Any) -> None: stepType="instant", step=self.endStep, check_nans=True, - missing_value=MISSING_VALUE, ) else: self.out.write( @@ -188,7 +184,6 @@ def write(self, template: Any) -> None: startStep=self.startStep, endStep=self.endStep, check_nans=True, - missing_value=MISSING_VALUE, ) self.values = None self.done = True @@ -209,6 +204,9 @@ def add(self, field: Any, values: NDArray[Any]) -> None: if step not in self.steps: return + if not np.all(values >= 0): + warnings.warn(f"Negative values for {field}: {np.nanmin(values)} {np.nanmax(values)}") + assert not self.done, (self.key, step) assert step not in self.seen, (self.key, step) @@ -967,76 +965,97 @@ def _scda(request: dict[str, Any]) -> dict[str, Any]: return request -@source_registry.register("accumulations") -class AccumulationsSource(LegacySource): - - @staticmethod - def _execute( - context: Any, dates: list[datetime.datetime], use_cdsapi_dataset: str | None = None, **request: Any - ) -> Any: - """Computes accumulations based on the provided context, dates, and request parameters. 
+@legacy_source(__file__) +def accumulations( + context: Any, dates: list[datetime.datetime], use_cdsapi_dataset: str | None = None, **request: Any +) -> Any: + """Computes accumulations based on the provided context, dates, and request parameters. - Parameters - ---------- - context : Any - Context for the computation. - dates : List[datetime.datetime] - List of dates. - use_cdsapi_dataset : Optional[str], optional - CDSAPI dataset to use. Defaults to None. - **request : Any - Additional request parameters. + Parameters + ---------- + context : Any + Context for the computation. + dates : List[datetime.datetime] + List of dates. + use_cdsapi_dataset : Optional[str], optional + CDSAPI dataset to use. Defaults to None. + **request : Any + Additional request parameters. - Returns - ------- - Any - The computed accumulations. - """ + Returns + ------- + Any + The computed accumulations. + """ - if ( - request.get("class") == "ea" - and request.get("stream", "oper") == "oper" - and request.get("accumulation_period") == 24 - ): - from .accumulations2 import Accumulations2Source + if ( + request.get("class") == "ea" + and request.get("stream", "oper") == "oper" + and request.get("accumulation_period") == 24 + ): + from anemoi.datasets.create.sources.accumulations2 import accumulations as accumulations2 - LOG.warning( - "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" - ) - return Accumulations2Source._execute(context, dates, **request) - - _to_list(request["param"]) - class_ = request.get("class", "od") - stream = request.get("stream", "oper") - - user_accumulation_period = request.pop("accumulation_period", 6) - accumulations_reset_frequency = request.pop("accumulations_reset_frequency", None) - user_date = request.pop("date", None) - - # If `data_accumulation_period` is not set, this means that the accumulations are from the start - # of the forecast. - - KWARGS = { - ("od", "oper"): dict(patch=_scda), - ("od", "elda"): dict(base_times=(6, 18)), - ("od", "enfo"): dict(base_times=(0, 6, 12, 18)), - ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), - ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), - ("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)), - ("l5", "oper"): dict(data_accumulation_period=1, base_times=(0,)), - } - - kwargs = KWARGS.get((class_, stream), {}) - - context.trace("🌧️", f"accumulations {request} {user_accumulation_period} {kwargs}") - - return _compute_accumulations( - context, - dates, - request, - user_accumulation_period=user_accumulation_period, - accumulations_reset_frequency=accumulations_reset_frequency, - use_cdsapi_dataset=use_cdsapi_dataset, - user_date=user_date, - **kwargs, + LOG.warning( + "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" ) + return accumulations2(context, dates, **request) + + _to_list(request["param"]) + class_ = request.get("class", "od") + stream = request.get("stream", "oper") + + user_accumulation_period = request.pop("accumulation_period", 6) + accumulations_reset_frequency = request.pop("accumulations_reset_frequency", None) + user_date = request.pop("date", None) + + # If `data_accumulation_period` is not set, this means that the accumulations are from the start + # of the forecast. 
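+    # Defaults per (class, stream): the forecast base times to use and, where
+    # relevant, the accumulation period of the archived data.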
+ + KWARGS = { + ("od", "oper"): dict(patch=_scda), + ("od", "elda"): dict(base_times=(6, 18)), + ("od", "enfo"): dict(base_times=(0, 6, 12, 18)), + ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), + ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), + ("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)), + ("l5", "oper"): dict(data_accumulation_period=1, base_times=(0,)), + } + + kwargs = KWARGS.get((class_, stream), {}) + + context.trace("🌧️", f"accumulations {request} {user_accumulation_period} {kwargs}") + + return _compute_accumulations( + context, + dates, + request, + user_accumulation_period=user_accumulation_period, + accumulations_reset_frequency=accumulations_reset_frequency, + use_cdsapi_dataset=use_cdsapi_dataset, + user_date=user_date, + **kwargs, + ) + + +execute = accumulations + +if __name__ == "__main__": + import yaml + + config = yaml.safe_load( + """ + class: ea + expver: '0001' + grid: 20./20. + levtype: sfc +# number: [0, 1] +# stream: enda + param: [cp, tp] +# accumulation_period: 6h + """ + ) + dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") + dates = to_datetime_list(dates) + + for f in accumulations(None, dates, **config): + print(f, f.to_numpy().mean()) diff --git a/src/anemoi/datasets/create/sources/accumulations2.py b/src/anemoi/datasets/create/sources/accumulations2.py index 2f719e46e..3c34d392e 100644 --- a/src/anemoi/datasets/create/sources/accumulations2.py +++ b/src/anemoi/datasets/create/sources/accumulations2.py @@ -18,10 +18,9 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.create.sources import source_registry - -from .legacy import LegacySource -from .mars import mars +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.mars import mars +from anemoi.datasets.create.utils import to_datetime_list LOG = logging.getLogger(__name__) @@ -599,20 +598,49 @@ def _scda(request: dict[str, Any]) -> dict[str, Any]: return request -@source_registry.register("accumulations2") -class Accumulations2Source(LegacySource): +@legacy_source(__file__) +def accumulations(context, dates, **request): + _to_list(request["param"]) + user_accumulation_period = request.pop("accumulation_period", 6) + user_accumulation_period = datetime.timedelta(hours=user_accumulation_period) - @staticmethod - def _execute(context, dates, **request): - _to_list(request["param"]) - user_accumulation_period = request.pop("accumulation_period", 6) - user_accumulation_period = datetime.timedelta(hours=user_accumulation_period) + context.trace("🌧️", f"accumulations {request} {user_accumulation_period}") - context.trace("🌧️", f"accumulations {request} {user_accumulation_period}") + return _compute_accumulations( + context, + dates, + request, + user_accumulation_period=user_accumulation_period, + ) - return _compute_accumulations( - context, - dates, - request, - user_accumulation_period=user_accumulation_period, - ) + +execute = accumulations + +if __name__ == "__main__": + import yaml + + config = yaml.safe_load( + """ + class: ea + expver: '0001' + grid: 20./20. 
+ levtype: sfc +# number: [0, 1] +# stream: enda + param: [cp, tp] +# accumulation_period: 6h + accumulation_period: 2 + """ + ) + dates = yaml.safe_load("[2022-12-31 00:00, 2022-12-31 06:00]") + # dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") + dates = to_datetime_list(dates) + + class Context: + use_grib_paramid = True + + def trace(self, *args): + print(*args) + + for f in accumulations(Context, dates, **config): + print(f, f.to_numpy().mean()) diff --git a/src/anemoi/datasets/create/sources/anemoi_dataset.py b/src/anemoi/datasets/create/sources/anemoi_dataset.py index 743605bb9..a05e7df51 100644 --- a/src/anemoi/datasets/create/sources/anemoi_dataset.py +++ b/src/anemoi/datasets/create/sources/anemoi_dataset.py @@ -9,69 +9,65 @@ import numpy as np -from . import source_registry -from .legacy import LegacySource +from anemoi.datasets.create.sources.legacy import legacy_source -@source_registry.register("anemoi_dataset") -class AnemoiDatasetSource(LegacySource): +@legacy_source(__file__) +def execute(context, dates, params=None, **kwargs): + import earthkit.data as ekd - @staticmethod - def _execute(context, dates, params=None, **kwargs): - import earthkit.data as ekd + from anemoi.datasets import open_dataset - from anemoi.datasets import open_dataset + ds = open_dataset(**kwargs) + # dates_to_index = {date: i for i, date in enumerate(ds.dates)} - ds = open_dataset(**kwargs) - # dates_to_index = {date: i for i, date in enumerate(ds.dates)} + indices = [] + for date in dates: + idx = np.where(ds.dates == date)[0] + if len(idx) == 0: + continue + indices.append((int(idx[0]), date)) - indices = [] - for date in dates: - idx = np.where(ds.dates == date)[0] - if len(idx) == 0: - continue - indices.append((int(idx[0]), date)) - - vars = ds.variables - if params is None: - params = vars + vars = ds.variables + if params is None: + params = vars - if not isinstance(params, (list, tuple, set)): - params = [params] + if not isinstance(params, (list, tuple, set)): + params = [params] - params = set(params) - results = [] + params = set(params) + results = [] - ensemble = ds.shape[2] > 1 - latitudes = ds.latitudes - longitudes = ds.longitudes + ensemble = ds.shape[2] > 1 + latitudes = ds.latitudes + longitudes = ds.longitudes - for idx, date in indices: + for idx, date in indices: - metadata = dict(valid_datetime=date, latitudes=latitudes, longitudes=longitudes) + metadata = dict(valid_datetime=date, latitudes=latitudes, longitudes=longitudes) - for j, y in enumerate(ds[idx]): + for j, y in enumerate(ds[idx]): - param = vars[j] - if param not in params: - continue + param = vars[j] + if param not in params: + continue - # metadata['name'] = param - # metadata['param_level'] = param - metadata["param"] = param + # metadata['name'] = param + # metadata['param_level'] = param + metadata["param"] = param - for k, e in enumerate(y): - if ensemble: - metadata["number"] = k + 1 + for k, e in enumerate(y): + if ensemble: + metadata["number"] = k + 1 - metadata["values"] = e + metadata["values"] = e - results.append(metadata.copy()) + results.append(metadata.copy()) - print(results[0].keys()) + print(results[0].keys()) - # "list-of-dicts" does support resolution - results = ekd.from_source("list-of-dicts", results) + # "list-of-dicts" does support resolution + results = ekd.from_source("list-of-dicts", results) - # return new_fieldlist_from_list([new_field_from_latitudes_longitudes(x, latitudes, longitudes) for x in results]) - return results + # return 
new_fieldlist_from_list([new_field_from_latitudes_longitudes(x, latitudes, longitudes) for x in results]) + return results diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index a805c4b16..accde7936 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -11,42 +11,41 @@ from earthkit.data import from_source -from . import source_registry -from .legacy import LegacySource - - -@source_registry.register("constants") -class ConstantsSource(LegacySource): - - @staticmethod - def _execute(context: Any, dates: list[str], template: dict[str, Any], param: str) -> Any: - """Deprecated function to retrieve constants data. - - Parameters - ---------- - context : Any - The context object for tracing. - dates : list of str - List of dates for which data is required. - template : dict of str to Any - Template dictionary for the data source. - param : str - Parameter to retrieve. - - Returns - ------- - Any - Data retrieved from the source. - """ - from warnings import warn - - warn( - "The source `constants` is deprecated, use `forcings` instead.", - DeprecationWarning, - stacklevel=2, - ) - context.trace("✅", f"from_source(constants, {template}, {param}") - if len(template) == 0: - raise ValueError("Forcings template is empty.") - - return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) +from anemoi.datasets.create.sources.legacy import legacy_source + + +@legacy_source(__file__) +def constants(context: Any, dates: list[str], template: dict[str, Any], param: str) -> Any: + """Deprecated function to retrieve constants data. + + Parameters + ---------- + context : Any + The context object for tracing. + dates : list of str + List of dates for which data is required. + template : dict of str to Any + Template dictionary for the data source. + param : str + Parameter to retrieve. + + Returns + ------- + Any + Data retrieved from the source. + """ + from warnings import warn + + warn( + "The source `constants` is deprecated, use `forcings` instead.", + DeprecationWarning, + stacklevel=2, + ) + context.trace("✅", f"from_source(constants, {template}, {param}") + if len(template) == 0: + raise ValueError("Forcings template is empty.") + + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) + + +execute: Any = constants diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py index 8e5a329f5..0b293845e 100644 --- a/src/anemoi/datasets/create/sources/csv.py +++ b/src/anemoi/datasets/create/sources/csv.py @@ -8,25 +8,17 @@ # nor does it submit to any jurisdiction. -from ..source import Source -from . import source_registry +from anemoi.datasets.create.source import ObservationsSource +from anemoi.datasets.create.sources import source_registry @source_registry.register("csv") -class CSVSource(Source): +class CSVSource(ObservationsSource): """A source that reads data from a CSV file.""" emoji = "📄" # For tracing - def __init__( - self, - context: any, - path: str, - columns: list = None, - flavour: dict = None, - *args, - **kwargs, - ): + def __init__(self, context: any, path: str, *args: tuple, **kwargs: dict): """Initialise the CSVSource. Parameters @@ -35,36 +27,16 @@ def __init__( The context for the data source. filepath : str The path to the CSV file. - columns : list, optional - The list of columns to read from the CSV file. *args : tuple Additional positional arguments. 
**kwargs : dict Additional keyword arguments. """ super().__init__(context, *args, **kwargs) - self.path = path - self.columns = columns - - self.flavour = { - "latitude": "latitude", - "longitude": "longitude", - "time": "time", - } - - if flavour is not None: - self.flavour.update(flavour) def execute(self, dates): import pandas as pd - if self.columns is None: - frame = pd.read_csv(self.path) - else: - frame = pd.read_csv(self.path, usecols=self.columns) - - start, end = dates.window.start_date, dates.window.end_date - mask = (frame[self.flavour["time"]] >= start) & (frame[self.flavour["time"]] <= end) - frame = frame.loc[mask] - return frame + frame = pd.read_csv(self.path) + print(frame) diff --git a/src/anemoi/datasets/create/sources/eccc_fstd.py b/src/anemoi/datasets/create/sources/eccc_fstd.py index 41734e9b6..fdd79af8d 100644 --- a/src/anemoi/datasets/create/sources/eccc_fstd.py +++ b/src/anemoi/datasets/create/sources/eccc_fstd.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("eccc_fstd") diff --git a/src/anemoi/datasets/create/sources/empty.py b/src/anemoi/datasets/create/sources/empty.py index fa8bc8d84..f948810f5 100644 --- a/src/anemoi/datasets/create/sources/empty.py +++ b/src/anemoi/datasets/create/sources/empty.py @@ -12,29 +12,25 @@ import earthkit.data as ekd -from . import source_registry -from .legacy import LegacySource - - -@source_registry.register("empty") -class EmptySource(LegacySource): - - @staticmethod - def _execute(context: Any, dates: list[str], **kwargs: Any) -> ekd.FieldList: - """Executes the loading of an empty data source. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - Loaded empty data source. - """ - return ekd.from_source("empty") +from anemoi.datasets.create.sources.legacy import legacy_source + + +@legacy_source(__file__) +def execute(context: Any, dates: list[str], **kwargs: Any) -> ekd.FieldList: + """Executes the loading of an empty data source. + + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + Loaded empty data source. + """ + return ekd.from_source("empty") diff --git a/src/anemoi/datasets/create/sources/fdb.py b/src/anemoi/datasets/create/sources/fdb.py index 67bfe8870..81cdb7e13 100644 --- a/src/anemoi/datasets/create/sources/fdb.py +++ b/src/anemoi/datasets/create/sources/fdb.py @@ -16,10 +16,9 @@ from anemoi.transform.flavour import RuleBasedFlavour from anemoi.transform.grids import grid_registry -from anemoi.datasets.create.gridded.typing import DateList - -from ..source import Source -from . 
import source_registry +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.typing import DateList @source_registry.register("fdb") @@ -125,7 +124,7 @@ def _time_request_keys(dt: datetime, offset_from_date: bool | None = None) -> st def _shortname_to_paramid(shortname: list[str], param_id_map: dict[str, int] | None = None) -> list[int]: - from .mars import use_grib_paramid + from anemoi.datasets.create.sources.mars import use_grib_paramid """Convert a shortname to a parameter ID.""" if param_id_map is None: diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index 6070772fc..88eca92e4 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -11,32 +11,31 @@ from earthkit.data import from_source -from . import source_registry -from .legacy import LegacySource - - -@source_registry.register("forcings") -class ForcingsSource(LegacySource): - - @staticmethod - def _execute(context: Any, dates: list[str], template: str, param: str) -> Any: - """Loads forcing data from a specified source. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - template : FieldList - Template for the data source. - param : str - Parameter for the data source. - - Returns - ------- - object - Loaded forcing data. - """ - context.trace("✅", f"from_source(forcings, {template}, {param}") - return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) +from anemoi.datasets.create.sources.legacy import legacy_source + + +@legacy_source(__file__) +def forcings(context: Any, dates: list[str], template: str, param: str) -> Any: + """Loads forcing data from a specified source. + + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + template : FieldList + Template for the data source. + param : str + Parameter for the data source. + + Returns + ------- + object + Loaded forcing data. + """ + context.trace("✅", f"from_source(forcings, {template}, {param}") + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) + + +execute = forcings diff --git a/src/anemoi/datasets/create/sources/grib.py b/src/anemoi/datasets/create/sources/grib.py index d709efc5e..550709f98 100644 --- a/src/anemoi/datasets/create/sources/grib.py +++ b/src/anemoi/datasets/create/sources/grib.py @@ -20,8 +20,7 @@ from earthkit.data import from_source from earthkit.data.utils.patterns import Pattern -from . import source_registry -from .legacy import LegacySource +from anemoi.datasets.create.sources.legacy import legacy_source LOG = logging.getLogger(__name__) @@ -48,14 +47,6 @@ def check(ds: Any, paths: list[str], **kwargs: Any) -> None: if isinstance(v, (tuple, list)): count *= len(v) - # in the case of static data (e.g repeated dates) dates might be empty - if len(ds) != count and kwargs.get("dates", []) == []: - LOG.warning( - f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, paths={paths})" - f" Received empty dates - assuming this is static data." 
- ) - return - if len(ds) != count: raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, paths={paths})") @@ -82,85 +73,81 @@ def _expand(paths: list[str]) -> Any: yield path -@source_registry.register("grib") -class GribSource(LegacySource): - - @staticmethod - def _execute( - context: Any, - dates: list[Any], - path: str | list[str], - flavour: str | dict[str, Any] | None = None, - grid_definition: dict[str, Any] | None = None, - *args: Any, - **kwargs: Any, - ) -> ekd.FieldList: - """Executes the function to load data from GRIB files. - - Parameters - ---------- - context : Any - The context in which the function is executed. - dates : list of Any - List of dates. - path : str or list of str - Path or list of paths to the GRIB files. - flavour : str or dict of str to Any, optional - Flavour information, by default None. - grid_definition : dict of str to Any, optional - Grid definition configuration to create a Grid object, by default None. - *args : Any - Additional positional arguments. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - Any - The loaded dataset. - """ - given_paths = path if isinstance(path, list) else [path] - if flavour is not None: - flavour = RuleBasedFlavour(flavour) - - if grid_definition is not None: - grid = grid_registry.from_config(grid_definition) - else: - grid = None - - ds = from_source("empty") - dates = [d.isoformat() for d in dates] - - for path in given_paths: - - # do not substitute if not needed - if "{" not in path: - paths = [path] - else: - paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs) - - for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): - if name in kwargs: - raise ValueError(f"MARS interpolation parameter '{name}' not supported") - - for path in _expand(paths): - context.trace("📁", "PATH", path) - s = from_source("file", path) - if flavour is not None: - s = flavour.map(s) - sel_kwargs = kwargs.copy() - if dates != []: - sel_kwargs["valid_datetime"] = dates - s = s.sel(**sel_kwargs) - ds = ds + s - - if kwargs and not context.partial_ok: - check(ds, given_paths, valid_datetime=dates, **kwargs) - - if grid is not None: - ds = new_fieldlist_from_list([new_field_from_grid(f, grid) for f in ds]) - - if len(ds) == 0: - LOG.warning(f"No fields found for {dates} in {given_paths} (kwargs={kwargs})") - - return ds +@legacy_source(__file__) +def execute( + context: Any, + dates: list[Any], + path: str | list[str], + flavour: str | dict[str, Any] | None = None, + grid_definition: dict[str, Any] | None = None, + *args: Any, + **kwargs: Any, +) -> ekd.FieldList: + """Executes the function to load data from GRIB files. + + Parameters + ---------- + context : Any + The context in which the function is executed. + dates : list of Any + List of dates. + path : str or list of str + Path or list of paths to the GRIB files. + flavour : str or dict of str to Any, optional + Flavour information, by default None. + grid_definition : dict of str to Any, optional + Grid definition configuration to create a Grid object, by default None. + *args : Any + Additional positional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + Any + The loaded dataset. 
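+
+    Examples
+    --------
+    A minimal, illustrative call; the path pattern and request values below
+    are hypothetical and only sketch how the keyword arguments are used::
+
+        fields = execute(
+            context,
+            dates,
+            path="/data/{date}/analysis.grib",
+            param=["2t", "msl"],
+        )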
+ """ + given_paths = path if isinstance(path, list) else [path] + if flavour is not None: + flavour = RuleBasedFlavour(flavour) + + if grid_definition is not None: + grid = grid_registry.from_config(grid_definition) + else: + grid = None + + ds = from_source("empty") + dates = [d.isoformat() for d in dates] + + for path in given_paths: + paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs) + + for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): + if name in kwargs: + raise ValueError(f"MARS interpolation parameter '{name}' not supported") + + for path in _expand(paths): + context.trace("📁", "PATH", path) + s = from_source("file", path) + if flavour is not None: + s = flavour.map(s) + s = s.sel(valid_datetime=dates, **kwargs) + ds = ds + s + + if kwargs and not context.partial_ok: + check(ds, given_paths, valid_datetime=dates, **kwargs) + + if grid is not None: + + lat, lon = grid.latlon() + + assert len(lat) == len(lon), (len(lat), len(lon)) + for f in ds: + assert len(f.to_numpy(flatten=True)) == len(lat), (len(f.to_numpy(flatten=True)), len(lat)) + + ds = new_fieldlist_from_list([new_field_from_grid(f, grid) for f in ds]) + + if len(ds) == 0: + LOG.warning(f"No fields found for {dates} in {given_paths} (kwargs={kwargs})") + + return ds diff --git a/src/anemoi/datasets/create/sources/grib_index.py b/src/anemoi/datasets/create/sources/grib_index.py index 0d86732f6..160ff3f3a 100644 --- a/src/anemoi/datasets/create/sources/grib_index.py +++ b/src/anemoi/datasets/create/sources/grib_index.py @@ -19,8 +19,7 @@ from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray -from . import source_registry -from .legacy import LegacySource +from anemoi.datasets.create.sources.legacy import legacy_source LOG = logging.getLogger(__name__) @@ -570,47 +569,44 @@ def retrieve(self, dates: list[Any], **kwargs: Any) -> Iterator[Any]: yield data -@source_registry.register("grib_index") -class GribIndexSource(LegacySource): - - @staticmethod - def _execute( - context: Any, - dates: list[Any], - indexdb: str, - flavour: str | None = None, - **kwargs: Any, - ) -> FieldArray: - """Execute the GRIB data retrieval process. - - Parameters - ---------- - context : Any - The execution context. - dates : List[Any] - List of dates to retrieve data for. - indexdb : str - Path to the GRIB index database. - flavour : Optional[str], optional - Flavour configuration for mapping fields, by default None. - **kwargs : Any - Additional filtering criteria. - - Returns - ------- - FieldArray - An array of retrieved GRIB fields. - """ - index = GribIndex(indexdb) - result = [] - - if flavour is not None: - flavour = RuleBasedFlavour(flavour) - - for grib in index.retrieve(dates, **kwargs): - field = ekd.from_source("memory", grib)[0] - if flavour: - field = flavour.apply(field) - result.append(field) - - return FieldArray(result) +@legacy_source(__file__) +def execute( + context: Any, + dates: list[Any], + indexdb: str, + flavour: str | None = None, + **kwargs: Any, +) -> FieldArray: + """Execute the GRIB data retrieval process. + + Parameters + ---------- + context : Any + The execution context. + dates : List[Any] + List of dates to retrieve data for. + indexdb : str + Path to the GRIB index database. + flavour : Optional[str], optional + Flavour configuration for mapping fields, by default None. + **kwargs : Any + Additional filtering criteria. + + Returns + ------- + FieldArray + An array of retrieved GRIB fields. 
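+
+    Examples
+    --------
+    An illustrative call; the index path and the filter value are
+    hypothetical::
+
+        fields = execute(context, dates, indexdb="/path/to/index.db", param="2t")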
+ """ + index = GribIndex(indexdb) + result = [] + + if flavour is not None: + flavour = RuleBasedFlavour(flavour) + + for grib in index.retrieve(dates, **kwargs): + field = ekd.from_source("memory", grib)[0] + if flavour: + field = flavour.apply(field) + result.append(field) + + return FieldArray(result) diff --git a/src/anemoi/datasets/create/sources/hindcasts.py b/src/anemoi/datasets/create/sources/hindcasts.py index b9985ccf1..d796a74af 100644 --- a/src/anemoi/datasets/create/sources/hindcasts.py +++ b/src/anemoi/datasets/create/sources/hindcasts.py @@ -12,10 +12,8 @@ from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.sources import source_registry - -from .legacy import LegacySource -from .mars import mars +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.mars import mars LOGGER = logging.getLogger(__name__) @@ -38,57 +36,57 @@ def _to_list(x: list | tuple | Any) -> list[Any]: return [x] -@source_registry.register("hindcasts") -class HindcastsSource(LegacySource): - - @staticmethod - def _execute(context: Any, dates: list[Any], **request: dict[str, Any]) -> MultiFieldList: - """Generates hindcast requests based on the provided dates and request parameters. - - Parameters - ---------- - context : Any - The context containing the dates provider and trace method. - dates : List[Any] - A list of dates for which to generate hindcast requests. - request : Dict[str, Any] - Additional request parameters. - - Returns - ------- - MultiFieldList - A MultiFieldList containing the hindcast data. - """ - from anemoi.datasets.dates import HindcastsDates - - provider = context.dates_provider - assert isinstance(provider, HindcastsDates) - - context.trace("H️", f"hindcasts {len(dates)=}") - - request["param"] = _to_list(request["param"]) - request["step"] = _to_list(request.get("step", 0)) - request["step"] = [int(_) for _ in request["step"]] - - context.trace("H️", f"hindcast {request}") - - requests = [] - for d in dates: - r = request.copy() - hindcast = provider.mapping[d] - r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d") - r["date"] = hindcast.refdate.strftime("%Y-%m-%d") - r["time"] = hindcast.refdate.strftime("%H") - r["step"] = hindcast.step - requests.append(r) - - if len(requests) == 0: - return MultiFieldList([]) - - return mars( - context, - dates, - *requests, - date_key="hdate", - request_already_using_valid_datetime=True, - ) +@legacy_source(__file__) +def hindcasts(context: Any, dates: list[Any], **request: dict[str, Any]) -> MultiFieldList: + """Generates hindcast requests based on the provided dates and request parameters. + + Parameters + ---------- + context : Any + The context containing the dates provider and trace method. + dates : List[Any] + A list of dates for which to generate hindcast requests. + request : Dict[str, Any] + Additional request parameters. + + Returns + ------- + MultiFieldList + A MultiFieldList containing the hindcast data. 
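+
+    Examples
+    --------
+    An illustrative request; the values are hypothetical and the context is
+    assumed to carry a ``HindcastsDates`` provider::
+
+        fields = hindcasts(context, dates, param=["2t"], levtype="sfc", step=[0, 6])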
+ """ + from anemoi.datasets.dates import HindcastsDates + + provider = context.dates_provider + assert isinstance(provider, HindcastsDates) + + context.trace("H️", f"hindcasts {len(dates)=}") + + request["param"] = _to_list(request["param"]) + request["step"] = _to_list(request.get("step", 0)) + request["step"] = [int(_) for _ in request["step"]] + + context.trace("H️", f"hindcast {request}") + + requests = [] + for d in dates: + r = request.copy() + hindcast = provider.mapping[d] + r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d") + r["date"] = hindcast.refdate.strftime("%Y-%m-%d") + r["time"] = hindcast.refdate.strftime("%H") + r["step"] = hindcast.step + requests.append(r) + + if len(requests) == 0: + return MultiFieldList([]) + + return mars( + context, + dates, + *requests, + date_key="hdate", + request_already_using_valid_datetime=True, + ) + + +execute = hindcasts diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index f9a0288a0..0de230d29 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -8,13 +8,14 @@ # nor does it submit to any jurisdiction. +import inspect import logging -from abc import abstractmethod +import os +from collections.abc import Callable from typing import Any -from anemoi.datasets.create.input.context import Context - -from ..source import Source +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry LOG = logging.getLogger(__name__) @@ -24,7 +25,7 @@ class LegacySource(Source): Parameters ---------- - context : Context + context : Any The context in which the source is created. *args : tuple Positional arguments. @@ -32,15 +33,65 @@ class LegacySource(Source): Keyword arguments. """ - def __init__(self, context: Context, *args: Any, **kwargs: Any) -> None: + def __init__(self, context: Any, *args: Any, **kwargs: Any) -> None: super().__init__(context, *args, **kwargs) self.args = args self.kwargs = kwargs - @staticmethod - @abstractmethod - def _execute(context, *args, **kwargs): - pass - def execute(self, dates: Any) -> Any: - return self._execute(self.context, dates, *self.args, **self.kwargs) +class legacy_source: + """A decorator class for legacy sources. + + Parameters + ---------- + name : str + The name of the legacy source. + """ + + def __init__(self, name: str) -> None: + name, _ = os.path.splitext(os.path.basename(name)) + self.name = name + + def __call__(self, execute: Callable) -> Callable: + """Call method to wrap the execute function. + + Parameters + ---------- + execute : function + The execute function to be wrapped. + + Returns + ------- + function + The wrapped execute function. 
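+
+        Examples
+        --------
+        A sketch of how the decorator is applied; the function body is a
+        placeholder::
+
+            @legacy_source(__file__)
+            def execute(context, dates, **kwargs):
+                ...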
+ """ + this = self + name = f"Legacy{self.name.title()}Source" + source = ".".join([execute.__module__, execute.__name__]) + + def execute_wrapper(self, dates) -> Any: + """Wrapper method to call the execute function.""" + + # args, kwargs = resolve(context, (self.args, self.kwargs)) + args, kwargs = self.args, self.kwargs + + try: + return execute(self.context, dates, *args, **kwargs) + except TypeError: + LOG.error(f"Error executing source {this.name} from {source}") + LOG.error(f"Function signature is: {inspect.signature(execute)}") + LOG.error(f"Arguments are: {args=}, {kwargs=}") + raise + + klass = type( + name, + (LegacySource,), + { + "execute": execute_wrapper, + "_source": source, + }, + ) + + source_registry.register(self.name)(klass) + + return execute diff --git a/src/anemoi/datasets/create/sources/mars.py b/src/anemoi/datasets/create/sources/mars.py index 25e223cb4..d59f6034d 100644 --- a/src/anemoi/datasets/create/sources/mars.py +++ b/src/anemoi/datasets/create/sources/mars.py @@ -16,9 +16,8 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability -from anemoi.datasets.create.sources import source_registry - -from .legacy import LegacySource +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.utils import to_datetime_list DEBUG = False @@ -358,111 +357,135 @@ def use_grib_paramid(r: dict[str, Any]) -> dict[str, Any]: ] -@source_registry.register("mars") -class MarsSource(LegacySource): - - @staticmethod - def _execute( - context: Any, - dates: list[datetime.datetime], - *requests: dict[str, Any], - request_already_using_valid_datetime: bool = False, - date_key: str = "date", - use_cdsapi_dataset: str | None = None, - **kwargs: Any, - ) -> Any: - """Executes MARS requests based on the given context, dates, and other parameters. - - Parameters - ---------- - context : Any - The context for the requests. - dates : List[datetime.datetime] - The list of dates to be used in the requests. - requests : Dict[str, Any] - The input requests to be executed. - request_already_using_valid_datetime : bool, optional - Flag indicating if the requests already use valid datetime. - date_key : str, optional - The key for the date in the requests. - use_cdsapi_dataset : Optional[str], optional - The dataset to be used with CDS API. - kwargs : Any - Additional keyword arguments for the requests. - - Returns - ------- - Any - The resulting dataset. +@legacy_source(__file__) +def mars( + context: Any, + dates: list[datetime.datetime], + *requests: dict[str, Any], + request_already_using_valid_datetime: bool = False, + date_key: str = "date", + use_cdsapi_dataset: str | None = None, + **kwargs: Any, +) -> Any: + """Executes MARS requests based on the given context, dates, and other parameters. + + Parameters + ---------- + context : Any + The context for the requests. + dates : List[datetime.datetime] + The list of dates to be used in the requests. + requests : Dict[str, Any] + The input requests to be executed. + request_already_using_valid_datetime : bool, optional + Flag indicating if the requests already use valid datetime. + date_key : str, optional + The key for the date in the requests. + use_cdsapi_dataset : Optional[str], optional + The dataset to be used with CDS API. + kwargs : Any + Additional keyword arguments for the requests. + + Returns + ------- + Any + The resulting dataset. 
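+
+    Examples
+    --------
+    An illustrative request; the MARS keywords below are example values
+    borrowed from the test block at the end of this module::
+
+        fields = mars(context, dates, param=["2t"], levtype="sfc", grid="20.0/20.0")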
+ """ + + if not requests: + requests = [kwargs] + + for r in requests: + param = r.get("param", []) + if not isinstance(param, (list, tuple)): + param = [param] + # check for "Norway bug" where yaml transforms 'no' into False, etc. + for p in param: + if p is False: + raise ValueError( + "'param' cannot be 'False'. If you wrote 'param: no' or 'param: off' in yaml, you may want to use quotes?" + ) + if p is None: + raise ValueError( + "'param' cannot be 'None'. If you wrote 'param: no' in yaml, you may want to use quotes?" + ) + if p is True: + raise ValueError( + "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?" + ) + + if len(dates) == 0: # When using `repeated_dates` + assert len(requests) == 1, requests + assert "date" in requests[0], requests[0] + if isinstance(requests[0]["date"], datetime.date): + requests[0]["date"] = requests[0]["date"].strftime("%Y%m%d") + else: + requests = factorise_requests( + dates, + *requests, + request_already_using_valid_datetime=request_already_using_valid_datetime, + date_key=date_key, + ) + + requests = list(requests) + + ds = from_source("empty") + context.trace("✅", f"{[str(d) for d in dates]}") + context.trace("✅", f"Will run {len(requests)} requests") + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + context.trace("✅", f"mars {r}") + + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + + if context.use_grib_paramid and "param" in r: + r = use_grib_paramid(r) + + for k, v in r.items(): + if k not in MARS_KEYS: + raise ValueError( + f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?" + ) + try: + if use_cdsapi_dataset: + ds = ds + from_source("cds", use_cdsapi_dataset, r) + else: + ds = ds + from_source("mars", **r) + except Exception as e: + if "File is empty:" not in str(e): + raise + return ds + + +execute = mars + + +if __name__ == "__main__": + import yaml + + config = yaml.safe_load( """ + - class: ea + expver: '0001' + grid: 20.0/20.0 + levtype: sfc + param: [2t] + # param: [10u, 10v, 2d, 2t, lsm, msl, sdor, skt, slor, sp, tcw, z] + number: [0, 1] + + # - class: ea + # expver: '0001' + # grid: 20.0/20.0 + # levtype: pl + # param: [q] + # levelist: [1000, 850] - if not requests: - requests = [kwargs] - - for r in requests: - param = r.get("param", []) - if not isinstance(param, (list, tuple)): - param = [param] - # check for "Norway bug" where yaml transforms 'no' into False, etc. - for p in param: - if p is False: - raise ValueError( - "'param' cannot be 'False'. If you wrote 'param: no' or 'param: off' in yaml, you may want to use quotes?" - ) - if p is None: - raise ValueError( - "'param' cannot be 'None'. If you wrote 'param: no' in yaml, you may want to use quotes?" - ) - if p is True: - raise ValueError( - "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?" 
- ) - - if len(dates) == 0: # When using `repeated_dates` - assert len(requests) == 1, requests - assert "date" in requests[0], requests[0] - if isinstance(requests[0]["date"], datetime.date): - requests[0]["date"] = requests[0]["date"].strftime("%Y%m%d") - else: - requests = factorise_requests( - dates, - *requests, - request_already_using_valid_datetime=request_already_using_valid_datetime, - date_key=date_key, - ) + """ + ) + dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") + dates = to_datetime_list(dates) - requests = list(requests) - - ds = from_source("empty") - context.trace("✅", f"{[str(d) for d in dates]}") - context.trace("✅", f"Will run {len(requests)} requests") - for r in requests: - r = {k: v for k, v in r.items() if v != ("-",)} - context.trace("✅", f"mars {r}") - - for r in requests: - r = {k: v for k, v in r.items() if v != ("-",)} - - if context.use_grib_paramid and "param" in r: - r = use_grib_paramid(r) - - for k, v in r.items(): - if k not in MARS_KEYS: - raise ValueError( - f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?" - ) - try: - if use_cdsapi_dataset: - ds = ds + from_source("cds", use_cdsapi_dataset, r) - else: - ds = ds + from_source("mars", **r) - except Exception as e: - if "File is empty:" not in str(e): - raise - return ds - - -# TODO: make clearer the interface between sources that use mars. -# Currently some sources use mars as a function rather than through the registry, -# e.g. accumulations, accumulations2, hindcasts, recentre, tendencies -mars = MarsSource._execute + DEBUG = True + for f in mars(None, dates, *config): + print(f, f.to_numpy().mean()) diff --git a/src/anemoi/datasets/create/sources/netcdf.py b/src/anemoi/datasets/create/sources/netcdf.py index e6f4271a7..606a8dd53 100644 --- a/src/anemoi/datasets/create/sources/netcdf.py +++ b/src/anemoi/datasets/create/sources/netcdf.py @@ -12,34 +12,30 @@ import earthkit.data as ekd -from . import source_registry -from .legacy import LegacySource -from .xarray import load_many - - -@source_registry.register("netcdf") -class NetCDFSource(LegacySource): - - @staticmethod - def _execute(context: Any, dates: list[str], path: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the loading of multiple NetCDF files. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - path : str - Path to the directory containing the NetCDF files. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - object - The loaded data. - """ - return load_many("📁", context, dates, path, *args, **kwargs) +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many + + +@legacy_source(__file__) +def execute(context: Any, dates: list[str], path: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the loading of multiple NetCDF files. + + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + path : str + Path to the directory containing the NetCDF files. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + object + The loaded data. 
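+
+    Examples
+    --------
+    An illustrative call; the path pattern is hypothetical::
+
+        fields = execute(context, dates, path="/data/{date}.nc")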
+ """ + return load_many("📁", context, dates, path, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/opendap.py b/src/anemoi/datasets/create/sources/opendap.py index 86cd3e6d2..34e3fe94d 100644 --- a/src/anemoi/datasets/create/sources/opendap.py +++ b/src/anemoi/datasets/create/sources/opendap.py @@ -12,34 +12,30 @@ import earthkit.data as ekd -from . import source_registry -from .legacy import LegacySource -from .xarray import load_many - - -@source_registry.register("opendap") -class OpenDAPSource(LegacySource): - - @staticmethod - def _execute(context: dict[str, Any], dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the data loading process from an OpenDAP source. - - Parameters - ---------- - context : dict - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - url : str - The URL of the OpenDAP source. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - xarray.Dataset - The loaded dataset. - """ - return load_many("🌐", context, dates, url, *args, **kwargs) +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many + + +@legacy_source(__file__) +def execute(context: dict[str, Any], dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the data loading process from an OpenDAP source. + + Parameters + ---------- + context : dict + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + url : str + The URL of the OpenDAP source. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + xarray.Dataset + The loaded dataset. + """ + return load_many("🌐", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py index b710bcbbe..07e8f0203 100644 --- a/src/anemoi/datasets/create/sources/planetary_computer.py +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("planetary_computer") diff --git a/src/anemoi/datasets/create/sources/recentre.py b/src/anemoi/datasets/create/sources/recentre.py index 2d6c70b1d..d0959f664 100644 --- a/src/anemoi/datasets/create/sources/recentre.py +++ b/src/anemoi/datasets/create/sources/recentre.py @@ -11,10 +11,8 @@ from typing import Any from anemoi.datasets.compute.recentre import recentre as _recentre - -from . 
import source_registry -from .legacy import LegacySource -from .mars import mars +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.mars import mars def to_list(x: list | tuple | str) -> list: @@ -106,43 +104,43 @@ def load_if_needed(context: Any, dates: Any, dict_or_dataset: dict | Any) -> Any return dict_or_dataset -@source_registry.register("recentre") -class RecentreSource(LegacySource): - - @staticmethod - def _execute( - context: Any, - dates: Any, - members: dict | Any, - centre: dict | Any, - alpha: float = 1.0, - remapping: dict = {}, - patches: dict = {}, - ) -> Any: - """Recentres the members dataset using the centre dataset. - - Parameters - ---------- - context : Any - The context for recentering. - dates : Any - The dates for recentering. - members : Union[dict, Any] - The members dataset or request dictionary. - centre : Union[dict, Any] - The centre dataset or request dictionary. - alpha : float, optional - The alpha value for recentering. Defaults to 1.0. - remapping : dict, optional - The remapping dictionary. Defaults to {}. - patches : dict, optional - The patches dictionary. Defaults to {}. - - Returns - ------- - Any - The recentred dataset. - """ - members = load_if_needed(context, dates, members) - centre = load_if_needed(context, dates, centre) - return _recentre(members=members, centre=centre, alpha=alpha) +@legacy_source(__file__) +def recentre( + context: Any, + dates: Any, + members: dict | Any, + centre: dict | Any, + alpha: float = 1.0, + remapping: dict = {}, + patches: dict = {}, +) -> Any: + """Recentres the members dataset using the centre dataset. + + Parameters + ---------- + context : Any + The context for recentering. + dates : Any + The dates for recentering. + members : Union[dict, Any] + The members dataset or request dictionary. + centre : Union[dict, Any] + The centre dataset or request dictionary. + alpha : float, optional + The alpha value for recentering. Defaults to 1.0. + remapping : dict, optional + The remapping dictionary. Defaults to {}. + patches : dict, optional + The patches dictionary. Defaults to {}. + + Returns + ------- + Any + The recentred dataset. + """ + members = load_if_needed(context, dates, members) + centre = load_if_needed(context, dates, centre) + return _recentre(members=members, centre=centre, alpha=alpha) + + +execute = recentre diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index c484efd82..77a06c76c 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -10,8 +10,8 @@ import logging from collections import defaultdict +from collections.abc import Generator from typing import Any -from typing import Generator import numpy as np from anemoi.transform.fields import new_field_with_valid_datetime @@ -19,8 +19,18 @@ from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta -from ..source import Source -from . import source_registry +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources import source_registry + +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/source.py b/src/anemoi/datasets/create/sources/source.py new file mode 100644 index 000000000..1bac545d8 --- /dev/null +++ b/src/anemoi/datasets/create/sources/source.py @@ -0,0 +1,68 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from datetime import datetime +from typing import Any + +from earthkit.data import from_source + +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.utils import to_datetime_list + + +@legacy_source(__file__) +def source(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any: + """Generates a source based on the provided context, dates, and additional keyword arguments. + + Parameters + ---------- + context : Optional[Any] + The context in which the source is generated. + dates : List[datetime] + A list of datetime objects representing the dates. + **kwargs : Any + Additional keyword arguments for the source generation. + + Returns + ------- + Any + The generated source. + """ + name = kwargs.pop("name") + context.trace("✅", f"from_source({name}, {dates}, {kwargs}") + if kwargs["date"] == "$from_dates": + kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates}) + if kwargs["time"] == "$from_dates": + kwargs["time"] = list({d.strftime("%H%M") for d in dates}) + return from_source(name, **kwargs) + + +execute = source + +if __name__ == "__main__": + import yaml + + config: dict[str, Any] = yaml.safe_load( + """ + name: mars + class: ea + expver: '0001' + grid: 20.0/20.0 + levtype: sfc + param: [2t] + number: [0, 1] + date: $from_dates + time: $from_dates + """ + ) + dates: list[str] = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") + dates = to_datetime_list(dates) + + for f in source(None, dates, **config): + print(f, f.to_numpy().mean()) diff --git a/src/anemoi/datasets/create/sources/tendencies.py b/src/anemoi/datasets/create/sources/tendencies.py index cdf4ce291..222dca9a4 100644 --- a/src/anemoi/datasets/create/sources/tendencies.py +++ b/src/anemoi/datasets/create/sources/tendencies.py @@ -14,9 +14,8 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.create.sources import source_registry - -from .legacy import LegacySource +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.utils import to_datetime_list def _date_to_datetime(d: Any) -> Any: @@ -83,89 +82,116 @@ def group_by_field(ds: Any) -> dict[tuple, list[Any]]: return d -@source_registry.register("tendencies") -class TendenciesSource(LegacySource): +@legacy_source(__file__) +def tendencies(dates: list[datetime.datetime], time_increment: Any, **kwargs: Any) -> Any: + """Computes tendencies for the given dates and time increment. 
- @staticmethod - def _execute(dates: list[datetime.datetime], time_increment: Any, **kwargs: Any) -> Any: - """Computes tendencies for the given dates and time increment. + Parameters + ---------- + dates : List[datetime.datetime] + A list of datetime objects. + time_increment : Any + A time increment string ending with 'h' or a datetime.timedelta object. + **kwargs : Any + Additional keyword arguments. - Parameters - ---------- - dates : List[datetime.datetime] - A list of datetime objects. - time_increment : Any - A time increment string ending with 'h' or a datetime.timedelta object. - **kwargs : Any - Additional keyword arguments. + Returns + ------- + Any + A dataset object with computed tendencies. + """ + print("✅", kwargs) + time_increment = normalise_time_delta(time_increment) - Returns - ------- - Any - A dataset object with computed tendencies. - """ - print("✅", kwargs) - time_increment = normalise_time_delta(time_increment) + shifted_dates = [d - time_increment for d in dates] + all_dates = sorted(list(set(dates + shifted_dates))) - shifted_dates = [d - time_increment for d in dates] - all_dates = sorted(list(set(dates + shifted_dates))) + # from .mars import execute as mars + from anemoi.datasets.create.mars import execute as mars - from .mars import mars + ds = mars(dates=all_dates, **kwargs) - ds = mars(dates=all_dates, **kwargs) + dates_in_data = ds.unique_values("valid_datetime", progress_bar=False)["valid_datetime"] + for d in all_dates: + assert d.isoformat() in dates_in_data, d - dates_in_data = ds.unique_values("valid_datetime", progress_bar=False)["valid_datetime"] - for d in all_dates: - assert d.isoformat() in dates_in_data, d + ds1 = ds.sel(valid_datetime=[d.isoformat() for d in dates]) + ds2 = ds.sel(valid_datetime=[d.isoformat() for d in shifted_dates]) - ds1 = ds.sel(valid_datetime=[d.isoformat() for d in dates]) - ds2 = ds.sel(valid_datetime=[d.isoformat() for d in shifted_dates]) + assert len(ds1) == len(ds2), (len(ds1), len(ds2)) - assert len(ds1) == len(ds2), (len(ds1), len(ds2)) + group1 = group_by_field(ds1) + group2 = group_by_field(ds2) - group1 = group_by_field(ds1) - group2 = group_by_field(ds2) + assert group1.keys() == group2.keys(), (group1.keys(), group2.keys()) - assert group1.keys() == group2.keys(), (group1.keys(), group2.keys()) + # prepare output tmp file so we can read it back + tmp = temp_file() + path = tmp.path + out = new_grib_output(path) - # prepare output tmp file so we can read it back - tmp = temp_file() - path = tmp.path - out = new_grib_output(path) + for k in group1: + assert len(group1[k]) == len(group2[k]), k + print() + print("❌", k) - for k in group1: - assert len(group1[k]) == len(group2[k]), k - print() - print("❌", k) + for field, b_field in zip(group1[k], group2[k]): + for k in ["param", "level", "number", "grid", "shape"]: + assert field.metadata(k) == b_field.metadata(k), ( + k, + field.metadata(k), + b_field.metadata(k), + ) - for field, b_field in zip(group1[k], group2[k]): - for k in ["param", "level", "number", "grid", "shape"]: - assert field.metadata(k) == b_field.metadata(k), ( - k, - field.metadata(k), - b_field.metadata(k), - ) + c = field.to_numpy() + b = b_field.to_numpy() + assert c.shape == b.shape, (c.shape, b.shape) - c = field.to_numpy() - b = b_field.to_numpy() - assert c.shape == b.shape, (c.shape, b.shape) + ################ + # Actual computation happens here + x = c - b + ################ - ################ - # Actual computation happens here - x = c - b - ################ + assert x.shape == 
c.shape, c.shape + print(f"Computing data for {field.metadata('valid_datetime')}={field}-{b_field}") + out.write(x, template=field) - assert x.shape == c.shape, c.shape - print(f"Computing data for {field.metadata('valid_datetime')}={field}-{b_field}") - out.write(x, template=field) + out.close() - out.close() + from earthkit.data import from_source - from earthkit.data import from_source + ds = from_source("file", path) + # save a reference to the tmp file so it is deleted + # only when the dataset is not used anymore + ds._tmp = tmp + + return ds + + +execute = tendencies + +if __name__ == "__main__": + import yaml + + config = yaml.safe_load( + """ + + config: + time_increment: 12h + database: marser + class: ea + # date: computed automatically + # time: computed automatically + expver: "0001" + grid: 20.0/20.0 + levtype: sfc + param: [2t] + """ + )["config"] - ds = from_source("file", path) - # save a reference to the tmp file so it is deleted - # only when the dataset is not used anymore - ds._tmp = tmp + dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") + dates = to_datetime_list(dates) - return ds + DEBUG = True + for f in tendencies(dates, **config): + print(f, f.to_numpy().mean()) diff --git a/src/anemoi/datasets/create/sources/xarray.py b/src/anemoi/datasets/create/sources/xarray.py index a735e52f6..5e3cc4c10 100644 --- a/src/anemoi/datasets/create/sources/xarray.py +++ b/src/anemoi/datasets/create/sources/xarray.py @@ -11,12 +11,11 @@ import earthkit.data as ekd -from anemoi.datasets.create.gridded.typing import DateList - -from ..source import Source -from .xarray_support import XarrayFieldList -from .xarray_support import load_many -from .xarray_support import load_one +from anemoi.datasets.create.source import Source +from anemoi.datasets.create.sources.xarray_support import XarrayFieldList +from anemoi.datasets.create.sources.xarray_support import load_many +from anemoi.datasets.create.sources.xarray_support import load_one +from anemoi.datasets.create.typing import DateList __all__ = ["load_many", "load_one", "XarrayFieldList"] diff --git a/src/anemoi/datasets/create/sources/xarray_kerchunk.py b/src/anemoi/datasets/create/sources/xarray_kerchunk.py index 056d756ca..632a7cae2 100644 --- a/src/anemoi/datasets/create/sources/xarray_kerchunk.py +++ b/src/anemoi/datasets/create/sources/xarray_kerchunk.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from . import source_registry -from .xarray import XarraySourceBase +from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.sources.xarray import XarraySourceBase @source_registry.register("xarray_kerchunk") diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index 8e3cebc08..c33ce7bfc 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -15,11 +15,9 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList +from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.patterns import iterate_patterns - -from .. 
import source_registry -from ..legacy import LegacySource -from .fieldlist import XarrayFieldList +from anemoi.datasets.create.sources.xarray_support.fieldlist import XarrayFieldList LOG = logging.getLogger(__name__) @@ -153,30 +151,26 @@ def load_many(emoji: str, context: Any, dates: list[datetime.datetime], pattern: return MultiFieldList(result) -@source_registry.register("xarray") -class LegacyXarraySource(LegacySource): - name = "xarray" - - @staticmethod - def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Executes the loading of datasets. - - Parameters - ---------- - context : Any - Context object. - dates : List[str] - List of dates. - url : str - URL pattern for loading datasets. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - The loaded datasets. - """ - return load_many("🌐", context, dates, url, *args, **kwargs) +@legacy_source("xarray") +def execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Executes the loading of datasets. + + Parameters + ---------- + context : Any + Context object. + dates : List[str] + List of dates. + url : str + URL pattern for loading datasets. + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + The loaded datasets. + """ + return load_many("🌐", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 78f7de041..85f9970f8 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -17,9 +17,9 @@ from earthkit.data.core.fieldlist import math from numpy.typing import NDArray -from .coordinates import extract_single_value -from .coordinates import is_scalar -from .metadata import XArrayMetadata +from anemoi.datasets.create.sources.xarray_support.coordinates import extract_single_value +from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar +from anemoi.datasets.create.sources.xarray_support.metadata import XArrayMetadata LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py index 48f9cf0e1..174cb2716 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py +++ b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py @@ -16,12 +16,12 @@ import yaml from earthkit.data import FieldList -from .field import EmptyFieldList -from .flavour import CoordinateGuesser -from .patch import patch_dataset -from .time import Time -from .variable import FilteredVariable -from .variable import Variable +from anemoi.datasets.create.sources.xarray_support.field import EmptyFieldList +from anemoi.datasets.create.sources.xarray_support.flavour import CoordinateGuesser +from anemoi.datasets.create.sources.xarray_support.patch import patch_dataset +from anemoi.datasets.create.sources.xarray_support.time import Time +from anemoi.datasets.create.sources.xarray_support.variable import FilteredVariable +from anemoi.datasets.create.sources.xarray_support.variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 
80f0b6a62..74fcdbd03 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -17,25 +17,25 @@ import xarray as xr from anemoi.utils.config import DotDict -from .coordinates import Coordinate -from .coordinates import DateCoordinate -from .coordinates import EnsembleCoordinate -from .coordinates import LatitudeCoordinate -from .coordinates import LevelCoordinate -from .coordinates import LongitudeCoordinate -from .coordinates import PointCoordinate -from .coordinates import ScalarCoordinate -from .coordinates import StepCoordinate -from .coordinates import TimeCoordinate -from .coordinates import UnsupportedCoordinate -from .coordinates import XCoordinate -from .coordinates import YCoordinate -from .coordinates import is_scalar -from .grid import Grid -from .grid import MeshedGrid -from .grid import MeshProjectionGrid -from .grid import UnstructuredGrid -from .grid import UnstructuredProjectionGrid +from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import PointCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate +from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar +from anemoi.datasets.create.sources.xarray_support.grid import Grid +from anemoi.datasets.create.sources.xarray_support.grid import MeshedGrid +from anemoi.datasets.create.sources.xarray_support.grid import MeshProjectionGrid +from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredGrid +from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredProjectionGrid LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/sources/xarray_support/metadata.py index 23713ae74..2230db3ef 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/metadata.py +++ b/src/anemoi/datasets/create/sources/xarray_support/metadata.py @@ -46,7 +46,7 @@ def __init__(self, field: Any) -> None: field : Any The field to extract metadata from. 
""" - from .field import XArrayField + from anemoi.datasets.create.sources.xarray_support.field import XArrayField assert isinstance(field, XArrayField), type(field) self._field = field diff --git a/src/anemoi/datasets/create/sources/xarray_support/time.py b/src/anemoi/datasets/create/sources/xarray_support/time.py index 847b21598..7b1f60e58 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/time.py +++ b/src/anemoi/datasets/create/sources/xarray_support/time.py @@ -16,8 +16,8 @@ from anemoi.utils.dates import as_datetime -from .coordinates import Coordinate -from .variable import Variable +from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate +from anemoi.datasets.create.sources.xarray_support.variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/variable.py b/src/anemoi/datasets/create/sources/xarray_support/variable.py index 5d2c1c5b1..13d6fa4e2 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/create/sources/xarray_support/variable.py @@ -17,7 +17,7 @@ import numpy as np import xarray as xr -from .field import XArrayField +from anemoi.datasets.create.sources.xarray_support.field import XArrayField LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_zarr.py b/src/anemoi/datasets/create/sources/xarray_zarr.py index 2e89981bd..2f96ab207 100644 --- a/src/anemoi/datasets/create/sources/xarray_zarr.py +++ b/src/anemoi/datasets/create/sources/xarray_zarr.py @@ -11,34 +11,30 @@ import earthkit.data as ekd -from . import source_registry -from .legacy import LegacySource -from .xarray import load_many - - -@source_registry.register("xarray_zarr") -class XarrayZarrSource(LegacySource): - - @staticmethod - def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the data loading process. - - Parameters - ---------- - context : Any - The context in which the execution occurs. - dates : List[str] - List of dates for which data is to be loaded. - url : str - The URL from which data is to be loaded. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - The loaded data. - """ - return load_many("🇿", context, dates, url, *args, **kwargs) +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.xarray import load_many + + +@legacy_source(__file__) +def execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the data loading process. + + Parameters + ---------- + context : Any + The context in which the execution occurs. + dates : List[str] + List of dates for which data is to be loaded. + url : str + The URL from which data is to be loaded. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + The loaded data. + """ + return load_many("🇿", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/zenodo.py b/src/anemoi/datasets/create/sources/zenodo.py index 9f4d68f97..e23b8fa47 100644 --- a/src/anemoi/datasets/create/sources/zenodo.py +++ b/src/anemoi/datasets/create/sources/zenodo.py @@ -14,58 +14,54 @@ from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.sources.url import download_and_cache -from . 
import source_registry -from .legacy import LegacySource -from .patterns import iterate_patterns -from .xarray import load_one +from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources.patterns import iterate_patterns +from anemoi.datasets.create.sources.xarray import load_one -@source_registry.register("zenodo") -class ZenodoSource(LegacySource): +@legacy_source(__file__) +def execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Executes the download and processing of files from Zenodo. - @staticmethod - def _execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Executes the download and processing of files from Zenodo. + Parameters + ---------- + context : Any + The context in which the function is executed. + dates : Any + The dates for which the data is required. + record_id : str + The Zenodo record ID. + file_key : str + The key to identify the file. + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. - Parameters - ---------- - context : Any - The context in which the function is executed. - dates : Any - The dates for which the data is required. - record_id : str - The Zenodo record ID. - file_key : str - The key to identify the file. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. + Returns + ------- + MultiFieldList + A list of fields loaded from the downloaded files. + """ + import requests - Returns - ------- - MultiFieldList - A list of fields loaded from the downloaded files. - """ - import requests + result: list[Any] = [] - result: list[Any] = [] + URLPATTERN = "https://zenodo.org/api/records/{record_id}" + url = URLPATTERN.format(record_id=record_id) + r = requests.get(url) + r.raise_for_status() + record: dict[str, Any] = r.json() - URLPATTERN = "https://zenodo.org/api/records/{record_id}" - url = URLPATTERN.format(record_id=record_id) - r = requests.get(url) - r.raise_for_status() - record: dict[str, Any] = r.json() + urls: dict[str, str] = {} + for file in record["files"]: + urls[file["key"]] = file["links"]["self"] - urls: dict[str, str] = {} - for file in record["files"]: - urls[file["key"]] = file["links"]["self"] + for url, dates in iterate_patterns(file_key, dates, **kwargs): + if url not in urls: + continue - for url, dates in iterate_patterns(file_key, dates, **kwargs): - if url not in urls: - continue + path = download_and_cache(urls[url]) + result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs)) - path = download_and_cache(urls[url]) - result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs)) - - return MultiFieldList(result) + return MultiFieldList(result) diff --git a/src/anemoi/datasets/create/statistics/__init__.py b/src/anemoi/datasets/create/statistics/__init__.py new file mode 100644 index 000000000..e8e71c45a --- /dev/null +++ b/src/anemoi/datasets/create/statistics/__init__.py @@ -0,0 +1,561 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import datetime +import glob +import hashlib +import json +import logging +import os +import pickle +import shutil +import socket +from typing import Any + +import numpy as np +import tqdm +from anemoi.utils.provenance import gather_provenance_info +from numpy.typing import NDArray + +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.statistics.summary import Summary + +LOG = logging.getLogger(__name__) + + +def default_statistics_dates(dates: list[datetime.datetime]) -> tuple[datetime.datetime, datetime.datetime]: + """Calculate default statistics dates based on the given list of dates. + + Parameters + ---------- + dates : list of datetime.datetime + List of datetime objects representing dates. + + Returns + ------- + tuple of datetime.datetime + A tuple containing the default start and end dates. + """ + + def to_datetime(d): + if isinstance(d, np.datetime64): + return d.tolist() + assert isinstance(d, datetime.datetime), d + return d + + first = dates[0] + last = dates[-1] + + first = to_datetime(first) + last = to_datetime(last) + + n_years = round((last - first).total_seconds() / (365.25 * 24 * 60 * 60)) + + if n_years < 10: + # leave out 20% of the data + k = int(len(dates) * 0.8) + end = dates[k - 1] + LOG.info(f"Number of years {n_years} < 10, leaving out 20%. {end=}") + return dates[0], end + + delta = 1 + if n_years >= 20: + delta = 3 + LOG.info(f"Number of years {n_years}, leaving out {delta} years.") + end_year = last.year - delta + + end = max(d for d in dates if to_datetime(d).year == end_year) + return dates[0], end + + +def to_datetime(date: str | datetime.datetime) -> np.datetime64: + """Convert a date to numpy datetime64 format. + + Parameters + ---------- + date : str or datetime.datetime + The date to convert. + + Returns + ------- + numpy.datetime64 + The converted date. + """ + if isinstance(date, str): + return np.datetime64(date) + if isinstance(date, datetime.datetime): + return np.datetime64(date, "s") + return date + + +def to_datetimes(dates: list[str | datetime.datetime]) -> list[np.datetime64]: + """Convert a list of dates to numpy datetime64 format. + + Parameters + ---------- + dates : list of str or datetime.datetime + List of dates to convert. + + Returns + ------- + list of numpy.datetime64 + List of converted dates. + """ + return [to_datetime(d) for d in dates] + + +def fix_variance(x: float, name: str, count: NDArray[Any], sums: NDArray[Any], squares: NDArray[Any]) -> float: + """Fix negative variance values due to numerical errors. + + Parameters + ---------- + x : float + The variance value. + name : str + The variable name. + count : numpy.ndarray + The count array. + sums : numpy.ndarray + The sums array. + squares : numpy.ndarray + The squares array. + + Returns + ------- + float + The fixed variance value. 
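+
+    Notes
+    -----
+    The variance is obtained as ``squares / count - mean * mean``.  When the
+    standard deviation is many orders of magnitude smaller than the mean, the
+    two terms nearly cancel in float64 and the result can collapse to zero or
+    even become slightly negative.  A minimal illustration of the cancellation
+    (the numbers are made up):
+
+    >>> mean, stdev = 1.0e8, 1.0e-3
+    >>> squares_over_count = mean * mean + stdev * stdev  # exact second moment
+    >>> squares_over_count - mean * mean  # true variance is 1e-6, lost to rounding
+    0.0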
+ """ + assert count.shape == sums.shape == squares.shape + assert isinstance(x, float) + + mean = sums / count + assert mean.shape == count.shape + + if x >= 0: + return x + + LOG.warning(f"Negative variance for {name=}, variance={x}") + magnitude = np.sqrt((squares / count + mean * mean) / 2) + LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}") + LOG.warning(f"Variable span order of magnitude is {magnitude}.") + LOG.warning(f"Count is {count}.") + + variances = squares / count - mean * mean + assert variances.shape == squares.shape == mean.shape + if np.all(variances >= 0): + LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.") + return 0 + + # if abs(x) < magnitude * 1e-6 and abs(x) < range * 1e-6: + # LOG.warning("Variance is negative but very small.") + # variances = squares / count - mean * mean + # return 0 + + LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).") + return 0 + + +def check_variance( + x: NDArray[Any], + variables_names: list[str], + minimum: NDArray[Any], + maximum: NDArray[Any], + mean: NDArray[Any], + count: NDArray[Any], + sums: NDArray[Any], + squares: NDArray[Any], +) -> None: + """Check for negative variance values and raise an error if found. + + Parameters + ---------- + x : numpy.ndarray + The variance array. + variables_names : list of str + List of variable names. + minimum : numpy.ndarray + The minimum values array. + maximum : numpy.ndarray + The maximum values array. + mean : numpy.ndarray + The mean values array. + count : numpy.ndarray + The count array. + sums : numpy.ndarray + The sums array. + squares : numpy.ndarray + The squares array. + + Raises + ------ + ValueError + If negative variance is found. + """ + if (x >= 0).all(): + return + print(x) + print(variables_names) + for i, (name, y) in enumerate(zip(variables_names, x)): + if y >= 0: + continue + print("---") + print(f"❗ Negative variance for {name=}, variance={y}") + print(f" min={minimum[i]} max={maximum[i]} mean={mean[i]} count={count[i]} sums={sums[i]} squares={squares[i]}") + print(f" -> sums: min={np.min(sums[i])}, max={np.max(sums[i])}, argmin={np.argmin(sums[i])}") + print(f" -> squares: min={np.min(squares[i])}, max={np.max(squares[i])}, argmin={np.argmin(squares[i])}") + print(f" -> count: min={np.min(count[i])}, max={np.max(count[i])}, argmin={np.argmin(count[i])}") + print( + f" squares / count - mean * mean = {squares[i] / count[i]} - {mean[i] * mean[i]} = {squares[i] / count[i] - mean[i] * mean[i]}" + ) + + raise ValueError("Negative variance") + + +def compute_statistics( + array: NDArray[Any], check_variables_names: list[str] | None = None, allow_nans: bool = False +) -> dict[str, np.ndarray]: + """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary. + + Parameters + ---------- + array : numpy.ndarray + The array to compute statistics for. + check_variables_names : list of str, optional + List of variable names to check. Defaults to None. + allow_nans : bool, optional + Whether to allow NaN values. Defaults to False. + + Returns + ------- + dict of str to numpy.ndarray + A dictionary containing the computed statistics. 
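+
+    Examples
+    --------
+    A minimal sketch of the expected shapes (the random data and variable
+    names are illustrative only): for an input of shape
+    ``(dates, variables, ...)`` each returned array has shape
+    ``(dates, variables)``.
+
+    >>> import numpy as np
+    >>> data = np.random.rand(2, 3, 10)  # 2 dates, 3 variables, 10 grid points
+    >>> stats = compute_statistics(data, ["a", "b", "c"], allow_nans=True)
+    >>> stats["minimum"].shape
+    (2, 3)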
+ """ + LOG.info(f"Computing statistics for {array.shape} array") + nvars = array.shape[1] + + LOG.debug(f"Stats {nvars}, {array.shape}, {check_variables_names}") + if check_variables_names: + assert nvars == len(check_variables_names), (nvars, check_variables_names) + stats_shape = (array.shape[0], nvars) + + count = np.zeros(stats_shape, dtype=np.int64) + sums = np.zeros(stats_shape, dtype=np.float64) + squares = np.zeros(stats_shape, dtype=np.float64) + minimum = np.zeros(stats_shape, dtype=np.float64) + maximum = np.zeros(stats_shape, dtype=np.float64) + has_nans = np.zeros(stats_shape, dtype=np.bool_) + + for i, chunk in tqdm.tqdm(enumerate(array), delay=1, total=array.shape[0], desc="Computing statistics"): + values = chunk.reshape((nvars, -1)) + + for j, name in enumerate(check_variables_names): + check_data_values(values[j, :], name=name, allow_nans=allow_nans) + if np.isnan(values[j, :]).all(): + # LOG.warning(f"All NaN values for {name} ({j}) for date {i}") + LOG.warning(f"All NaN values for {name} ({j}) for date {i}") + + # Ignore NaN values + minimum[i] = np.nanmin(values, axis=1) + maximum[i] = np.nanmax(values, axis=1) + sums[i] = np.nansum(values, axis=1) + squares[i] = np.nansum(values * values, axis=1) + count[i] = np.sum(~np.isnan(values), axis=1) + has_nans[i] = np.isnan(values).any() + + LOG.info(f"Statistics computed for {nvars} variables.") + + return { + "minimum": minimum, + "maximum": maximum, + "sums": sums, + "squares": squares, + "count": count, + "has_nans": has_nans, + } + + +class TmpStatistics: + """Temporary statistics storage class.""" + + version = 3 + # Used in parrallel, during data loading, + # to write statistics in pickled npz files. + # can provide statistics for a subset of dates. + + def __init__(self, dirname: str, overwrite: bool = False) -> None: + """Initialize TmpStatistics. + + Parameters + ---------- + dirname : str + Directory name for storing statistics. + overwrite : bool, optional + Whether to overwrite existing files. Defaults to False. + """ + self.dirname = dirname + self.overwrite = overwrite + + def add_provenance(self, **kwargs: dict) -> None: + """Add provenance information. + + Parameters + ---------- + **kwargs : dict + Additional provenance information. + """ + self.create(exist_ok=True) + path = os.path.join(self.dirname, "provenance.json") + if os.path.exists(path): + return + out = dict(provenance=gather_provenance_info(), **kwargs) + with open(path, "w") as f: + json.dump(out, f) + + def create(self, exist_ok: bool) -> None: + """Create the directory for storing statistics. + + Parameters + ---------- + exist_ok : bool + Whether to ignore if the directory already exists. + """ + os.makedirs(self.dirname, exist_ok=exist_ok) + + def delete(self) -> None: + """Delete the directory for storing statistics.""" + try: + shutil.rmtree(self.dirname) + except FileNotFoundError: + pass + + def write(self, key: str, data: any, dates: list[datetime.datetime]) -> None: + """Write statistics data to a file. + + Parameters + ---------- + key : str + The key for the data. + data : any + The data to write. + dates : list of datetime.datetime + List of dates associated with the data. 
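+
+        Notes
+        -----
+        The data is first written to a per-process temporary file and then
+        moved into place, so concurrent writers do not observe partially
+        written files.  A usage sketch (the directory, key, data and dates
+        are illustrative placeholders):
+
+        >>> tmp = TmpStatistics("stats-tmp")  # doctest: +SKIP
+        >>> tmp.write("chunk-0", stats_dict, dates=chunk_dates)  # doctest: +SKIP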
+ """ + self.create(exist_ok=True) + h = hashlib.sha256(str(dates).encode("utf-8")).hexdigest() + path = os.path.join(self.dirname, f"{h}.npz") + + if not self.overwrite: + assert not os.path.exists(path), f"{path} already exists" + + tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}" + with open(tmp_path, "wb") as f: + pickle.dump((key, dates, data), f) + shutil.move(tmp_path, path) + + LOG.debug(f"Written statistics data for {len(dates)} dates in {path} ({dates})") + + def _gather_data(self) -> tuple[str, list[datetime.datetime], dict]: + """Gather data from stored files. + + Yields + ------ + tuple of str, list of datetime.datetime, dict + A tuple containing key, dates, and data. + """ + # use glob to read all pickles + files = glob.glob(self.dirname + "/*.npz") + LOG.debug(f"Reading stats data, found {len(files)} files in {self.dirname}") + assert len(files) > 0, f"No files found in {self.dirname}" + for f in files: + with open(f, "rb") as f: + yield pickle.load(f) + + def get_aggregated(self, *args: Any, **kwargs: Any) -> Summary: + """Get aggregated statistics. + + Parameters + ---------- + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + Summary + The aggregated statistics summary. + """ + aggregator = StatAggregator(self, *args, **kwargs) + return aggregator.aggregate() + + def __str__(self) -> str: + """String representation of TmpStatistics. + + Returns + ------- + str + The string representation. + """ + return f"TmpStatistics({self.dirname})" + + +class StatAggregator: + """Statistics aggregator class.""" + + NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"] + + def __init__( + self, owner: TmpStatistics, dates: list[datetime.datetime], variables_names: list[str], allow_nans: bool + ) -> None: + """Initialize StatAggregator. + + Parameters + ---------- + owner : TmpStatistics + The owner TmpStatistics instance. + dates : list of datetime.datetime + List of dates to aggregate. + variables_names : list of str + List of variable names. + allow_nans : bool + Whether to allow NaN values. + """ + dates = sorted(dates) + dates = to_datetimes(dates) + assert dates, "No dates selected" + self.owner = owner + self.dates = dates + self._number_of_dates = len(dates) + self._set_of_dates = set(dates) + self.variables_names = variables_names + self.allow_nans = allow_nans + + self.shape = (self._number_of_dates, len(self.variables_names)) + LOG.debug(f"Aggregating statistics on shape={self.shape}. 
Variables : {self.variables_names}") + + self.minimum = np.full(self.shape, np.nan, dtype=np.float64) + self.maximum = np.full(self.shape, np.nan, dtype=np.float64) + self.sums = np.full(self.shape, np.nan, dtype=np.float64) + self.squares = np.full(self.shape, np.nan, dtype=np.float64) + self.count = np.full(self.shape, -1, dtype=np.int64) + self.has_nans = np.full(self.shape, False, dtype=np.bool_) + + self._read() + + def _read(self) -> None: + """Read and aggregate statistics data from files.""" + + def check_type(a, b): + if not isinstance(a, set): + a = set(list(a)) + if not isinstance(b, set): + b = set(list(b)) + a = next(iter(a)) if a else None + b = next(iter(b)) if b else None + assert type(a) is type(b), (type(a), type(b)) + + found = set() + offset = 0 + + for _, _dates, stats in self.owner._gather_data(): + assert isinstance(stats, dict), stats + assert stats["minimum"].shape[0] == len(_dates), (stats["minimum"].shape, len(_dates)) + assert stats["minimum"].shape[1] == len(self.variables_names), ( + stats["minimum"].shape, + len(self.variables_names), + ) + for n in self.NAMES: + assert n in stats, (n, list(stats.keys())) + _dates = to_datetimes(_dates) + check_type(_dates, self._set_of_dates) + if found: + check_type(found, self._set_of_dates) + assert found.isdisjoint(_dates), "Duplicate dates found in precomputed statistics" + + # filter dates + dates = set(_dates) & self._set_of_dates + + if not dates: + # dates have been completely filtered for this chunk + continue + + # filter data + bitmap = np.array([d in self._set_of_dates for d in _dates]) + for k in self.NAMES: + stats[k] = stats[k][bitmap] + + assert stats["minimum"].shape[0] == len(dates), (stats["minimum"].shape, len(dates)) + + # store data in self + found |= set(dates) + for name in self.NAMES: + array = getattr(self, name) + assert stats[name].shape[0] == len(dates), (stats[name].shape, len(dates)) + array[offset : offset + len(dates)] = stats[name] + offset += len(dates) + + for d in self.dates: + assert d in found, f"Statistics for date {d} not precomputed." + assert self._number_of_dates == len(found), "Not all dates found in precomputed statistics" + assert self._number_of_dates == offset, "Not all dates found in precomputed statistics." + LOG.debug(f"Statistics for {len(found)} dates found.") + + def aggregate(self) -> Summary: + """Aggregate the statistics data. + + Returns + ------- + Summary + The aggregated statistics summary. 
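+
+        Notes
+        -----
+        Roughly, per variable the partial results are combined over dates as
+        ``mean = sum(sums) / sum(count)`` and
+        ``variance = sum(squares) / sum(count) - mean ** 2``, with
+        ``stdev = sqrt(variance)`` once negative variances caused by rounding
+        have been fixed and checked.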
+ """ + minimum = np.nanmin(self.minimum, axis=0) + maximum = np.nanmax(self.maximum, axis=0) + + sums = np.nansum(self.sums, axis=0) + squares = np.nansum(self.squares, axis=0) + count = np.nansum(self.count, axis=0) + has_nans = np.any(self.has_nans, axis=0) + assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape + + mean = sums / count + assert mean.shape == minimum.shape + + x = squares / count - mean * mean + assert x.shape == minimum.shape + + for i, name in enumerate(self.variables_names): + # remove negative variance due to numerical errors + x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1]) + + for i, name in enumerate(self.variables_names): + check_variance( + x[i : i + 1], + [name], + minimum[i : i + 1], + maximum[i : i + 1], + mean[i : i + 1], + count[i : i + 1], + sums[i : i + 1], + squares[i : i + 1], + ) + check_data_values(np.array([mean[i]]), name=name, allow_nans=False) + + stdev = np.sqrt(x) + + return Summary( + minimum=minimum, + maximum=maximum, + mean=mean, + count=count, + sums=sums, + squares=squares, + stdev=stdev, + variables_names=self.variables_names, + has_nans=has_nans, + ) diff --git a/src/anemoi/datasets/create/statistics/summary.py b/src/anemoi/datasets/create/statistics/summary.py new file mode 100644 index 000000000..8b6c29eb0 --- /dev/null +++ b/src/anemoi/datasets/create/statistics/summary.py @@ -0,0 +1,152 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import json +from collections import defaultdict +from typing import Any + +import numpy as np + +from anemoi.datasets.create.check import StatisticsValueError +from anemoi.datasets.create.check import check_data_values +from anemoi.datasets.create.check import check_stats + + +class Summary(dict): + """This class is used to store the summary statistics of a dataset. It can be saved and loaded from a json file. And does some basic checks on the data.""" + + STATS_NAMES = [ + "minimum", + "maximum", + "mean", + "stdev", + "has_nans", + ] # order matter for __str__. + + def __init__(self, **kwargs: Any) -> None: + """Initialize the Summary object with given keyword arguments. + + Parameters + ---------- + **kwargs : Any + Arbitrary keyword arguments representing summary statistics. + """ + super().__init__(**kwargs) + self.check() + + @property + def size(self) -> int: + """Get the size of the summary, which is the number of variables.""" + return len(self["variables_names"]) + + def check(self) -> None: + """Perform checks on the summary statistics to ensure they are valid. + + Raises + ------ + AssertionError + If any of the checks fail. + StatisticsValueError + If any of the statistical checks fail. 
+ """ + for k, v in self.items(): + if k == "variables_names": + assert len(v) == self.size + continue + assert v.shape == (self.size,) + if k == "count": + assert (v >= 0).all(), (k, v) + assert v.dtype == np.int64, (k, v) + continue + if k == "has_nans": + assert v.dtype == np.bool_, (k, v) + continue + if k == "stdev": + assert (v >= 0).all(), (k, v) + assert v.dtype == np.float64, (k, v) + + for i, name in enumerate(self["variables_names"]): + try: + check_stats(**{k: v[i] for k, v in self.items()}, msg=f"{i} {name}") + check_data_values(self["minimum"][i], name=name) + check_data_values(self["maximum"][i], name=name) + check_data_values(self["mean"][i], name=name) + except StatisticsValueError as e: + e.args += (i, name) + raise + + def __str__(self) -> str: + """Return a string representation of the summary statistics. + + Returns + ------- + str + A formatted string of the summary statistics. + """ + header = ["Variables"] + self.STATS_NAMES + out = [" ".join(header)] + + out += [ + " ".join([v] + [f"{self[n][i]:.2f}" for n in self.STATS_NAMES]) + for i, v in enumerate(self["variables_names"]) + ] + return "\n".join(out) + + def save(self, filename: str, **metadata: Any) -> None: + """Save the summary statistics to a JSON file. + + Parameters + ---------- + filename : str + The name of the file to save the summary statistics. + **metadata : Any + Additional metadata to include in the JSON file. + """ + assert filename.endswith(".json"), filename + dic = {} + for k in self.STATS_NAMES: + dic[k] = list(self[k]) + + out = dict(data=defaultdict(dict)) + for i, name in enumerate(self["variables_names"]): + for k in self.STATS_NAMES: + out["data"][name][k] = dic[k][i] + + out["metadata"] = metadata + + with open(filename, "w") as f: + json.dump(out, f, indent=2) + + def load(self, filename: str) -> "Summary": + """Load the summary statistics from a JSON file. + + Parameters + ---------- + filename : str + The name of the file to load the summary statistics from. + + Returns + ------- + Summary + The loaded Summary object. + """ + assert filename.endswith(".json"), filename + with open(filename) as f: + dic = json.load(f) + + dic_ = {} + for k, v in dic.items(): + if k == "count": + dic_[k] = np.array(v, dtype=np.int64) + continue + if k == "variables": + dic_[k] = v + continue + dic_[k] = np.array(v, dtype=np.float64) + return Summary(dic_) diff --git a/src/anemoi/datasets/create/tasks.py b/src/anemoi/datasets/create/tasks.py index 23728e6aa..05372d6d7 100644 --- a/src/anemoi/datasets/create/tasks.py +++ b/src/anemoi/datasets/create/tasks.py @@ -46,9 +46,9 @@ def run(self) -> None: return Chain -def task_factory(name: str, trace: str | None = None, **kwargs): +def task_factory(name: str, fields: bool = True, trace: str | None = None, **kwargs): - if True: + if fields: from anemoi.datasets.create.gridded.tasks import TaskCreator creator = TaskCreator() diff --git a/src/anemoi/datasets/create/testing.py b/src/anemoi/datasets/create/testing.py new file mode 100644 index 000000000..5363cd9f7 --- /dev/null +++ b/src/anemoi/datasets/create/testing.py @@ -0,0 +1,4 @@ +class TestingContext: + """A context for testing plugins.""" + + pass diff --git a/src/anemoi/datasets/create/typing.py b/src/anemoi/datasets/create/typing.py new file mode 100644 index 000000000..0eafdb193 --- /dev/null +++ b/src/anemoi/datasets/create/typing.py @@ -0,0 +1,14 @@ +# (C) Copyright 2025- Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime + +Date = datetime.datetime + +DateList = list[Date] diff --git a/src/anemoi/datasets/create/utils.py b/src/anemoi/datasets/create/utils.py new file mode 100644 index 000000000..00ea89e7b --- /dev/null +++ b/src/anemoi/datasets/create/utils.py @@ -0,0 +1,198 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import datetime +import os +import warnings +from contextlib import contextmanager +from typing import Any + +import numpy as np +from earthkit.data import settings +from numpy.typing import NDArray + + +def cache_context(dirname: str) -> contextmanager: + """Context manager for setting a temporary cache directory. + + Parameters + ---------- + dirname : str + The directory name for the cache. + + Returns + ------- + contextmanager + A context manager that sets the cache directory. + """ + + @contextmanager + def no_cache_context(): + yield + + if dirname is None: + return no_cache_context() + + os.makedirs(dirname, exist_ok=True) + # return settings.temporary("cache-directory", dirname) + return settings.temporary({"cache-policy": "user", "user-cache-directory": dirname}) + + +def to_datetime_list(*args: Any, **kwargs: Any) -> list[datetime.datetime]: + """Convert various date formats to a list of datetime objects. + + Parameters + ---------- + *args : Any + Positional arguments for date conversion. + **kwargs : Any + Keyword arguments for date conversion. + + Returns + ------- + list[datetime.datetime] + A list of datetime objects. + """ + from earthkit.data.utils.dates import to_datetime_list as to_datetime_list_ + + warnings.warn( + "to_datetime_list() is deprecated. Call earthkit.data.utils.dates.to_datetime_list() instead.", + DeprecationWarning, + stacklevel=2, + ) + return to_datetime_list_(*args, **kwargs) + + +def to_datetime(*args: Any, **kwargs: Any) -> datetime.datetime: + """Convert various date formats to a single datetime object. + + Parameters + ---------- + *args : Any + Positional arguments for date conversion. + **kwargs : Any + Keyword arguments for date conversion. + + Returns + ------- + datetime.datetime + A datetime object. + """ + from earthkit.data.utils.dates import to_datetime as to_datetime_ + + warnings.warn( + "to_datetime() is deprecated. Call earthkit.data.utils.dates.to_datetime() instead.", + DeprecationWarning, + stacklevel=2, + ) + + return to_datetime_(*args, **kwargs) + + +def make_list_int(value: str | list | tuple | int) -> list[int]: + """Convert a string, list, tuple, or integer to a list of integers. + + Parameters + ---------- + value : str or list or tuple or int + The value to convert. + + Returns + ------- + list[int] + A list of integers. + + Raises + ------ + ValueError + If the value cannot be converted to a list of integers. 
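+
+    Examples
+    --------
+    A few illustrative inputs (using the MARS-like ``to``/``by`` syntax):
+
+    >>> make_list_int("1/to/3")
+    [1, 2, 3]
+    >>> make_list_int("0/to/12/by/6")
+    [0, 6, 12]
+    >>> make_list_int(4)
+    [4]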
+ """ + # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers. + # Moved to anemoi.utils.humanize + # replace with from anemoi.utils.humanize import make_list_int + # when anemoi-utils is released and pyproject.toml is updated + if isinstance(value, str): + if "/" not in value: + return [value] + bits = value.split("/") + if len(bits) == 3 and bits[1].lower() == "to": + value = list(range(int(bits[0]), int(bits[2]) + 1, 1)) + + elif len(bits) == 5 and bits[1].lower() == "to" and bits[3].lower() == "by": + value = list(range(int(bits[0]), int(bits[2]) + int(bits[4]), int(bits[4]))) + + if isinstance(value, list): + return value + if isinstance(value, tuple): + return value + if isinstance(value, int): + return [value] + + raise ValueError(f"Cannot make list from {value}") + + +def normalize_and_check_dates( + dates: list[datetime.datetime], + start: datetime.datetime, + end: datetime.datetime, + frequency: datetime.timedelta, + dtype: str = "datetime64[s]", +) -> NDArray[Any]: + """Normalize and check a list of dates against a specified frequency. + + Parameters + ---------- + dates : list[datetime.datetime] + The list of dates to check. + start : datetime.datetime + The start date. + end : datetime.datetime + The end date. + frequency : datetime.timedelta + The frequency of the dates. + dtype : str, optional + The data type of the dates, by default "datetime64[s]". + + Returns + ------- + NDArray[Any] + An array of normalized dates. + + Raises + ------ + ValueError + If the final date size does not match the data shape. + """ + dates = [d.hdate if hasattr(d, "hdate") else d for d in dates] + + assert isinstance(frequency, datetime.timedelta), frequency + start = np.datetime64(start) + end = np.datetime64(end) + delta = np.timedelta64(frequency) + + res = [] + while start <= end: + res.append(start) + start += delta + dates_ = np.array(res).astype(dtype) + + if len(dates_) != len(dates): + raise ValueError( + f"Final date size {len(dates_)} (from {dates_[0]} to {dates_[-1]}, " + f"{frequency=}) does not match data shape {len(dates)} (from {dates[0]} to " + f"{dates[-1]})." + ) + + for i, (d1, d2) in enumerate(zip(dates, dates_)): + d1 = np.datetime64(d1) + d2 = np.datetime64(d2) + assert d1 == d2, (i, d1, d2) + + return dates_ diff --git a/src/anemoi/datasets/create/writer.py b/src/anemoi/datasets/create/writer.py new file mode 100644 index 000000000..d573c1ca5 --- /dev/null +++ b/src/anemoi/datasets/create/writer.py @@ -0,0 +1,64 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +LOG = logging.getLogger(__name__) + + +class ViewCacheArray: + """A class that provides a caching mechanism for writing to a NumPy-like array. + + The is initialised with a NumPy-like array, a shape and a list to reindex the first + dimension. The array is used to store the final data, while the cache is used to + temporarily store the data before flushing it to the array. + + The `flush` method copies the contents of the cache to the final array. 
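+
+    A minimal sketch of the intended use, with a plain NumPy array standing
+    in for the target array:
+
+    >>> import numpy as np
+    >>> target = np.zeros((10, 3))
+    >>> cache = ViewCacheArray(target, shape=(2, 3), indexes=[5, 7])
+    >>> cache[0] = [1.0, 2.0, 3.0]
+    >>> cache[1] = [4.0, 5.0, 6.0]
+    >>> cache.flush()
+    >>> target[5]
+    array([1., 2., 3.])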
+ """ + + def __init__(self, array: NDArray[Any], *, shape: tuple[int, ...], indexes: list[int]): + """Initialize the ViewCacheArray. + + Parameters + ---------- + array : NDArray[Any] + The NumPy-like array to store the final data. + shape : tuple[int, ...] + The shape of the cache array. + indexes : list[int] + List to reindex the first dimension. + """ + assert len(indexes) == shape[0], (len(indexes), shape[0]) + self.array = array + self.dtype = array.dtype + self.cache = np.full(shape, np.nan, dtype=self.dtype) + self.indexes = indexes + + def __setitem__(self, key: tuple[int, ...], value: NDArray[Any]) -> None: + """Set the value in the cache array at the specified key. + + Parameters + ---------- + key : tuple[int, ...] + The index key to set the value. + value : NDArray[Any] + The value to set in the cache array. + """ + self.cache[key] = value + + def flush(self) -> None: + """Copy the contents of the cache to the final array.""" + for i in range(self.cache.shape[0]): + global_i = self.indexes[i] + self.array[global_i] = self.cache[i] diff --git a/src/anemoi/datasets/create/zarr.py b/src/anemoi/datasets/create/zarr.py new file mode 100644 index 000000000..32b493dd3 --- /dev/null +++ b/src/anemoi/datasets/create/zarr.py @@ -0,0 +1,331 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import shutil +from typing import Any + +import numpy as np +import zarr +from numpy.typing import NDArray + +LOG = logging.getLogger(__name__) + + +def add_zarr_dataset( + *, + name: str, + dtype: np.dtype = None, + fill_value: np.generic = None, + zarr_root: zarr.Group, + shape: tuple[int, ...] = None, + array: NDArray[Any] = None, + overwrite: bool = True, + dimensions: tuple[str, ...] = None, + **kwargs, +) -> zarr.Array: + """Add a dataset to a Zarr group. + + Parameters + ---------- + name : str + Name of the dataset. + dtype : np.dtype, optional + Data type of the dataset. + fill_value : np.generic, optional + Fill value for the dataset. + zarr_root : zarr.Group + Root Zarr group. + shape : tuple[int, ...], optional + Shape of the dataset. + array : NDArray[Any], optional + Array to initialize the dataset with. + overwrite : bool + Whether to overwrite existing dataset. + dimensions : tuple[str, ...] + Dimensions of the dataset. + **kwargs + Additional arguments for Zarr dataset creation. + + Returns + ------- + zarr.Array + The created Zarr array. + """ + assert dimensions is not None, "Please pass dimensions to add_zarr_dataset." + assert isinstance(dimensions, (tuple, list)) + + if dtype is None: + assert array is not None, (name, shape, array, dtype, zarr_root) + dtype = array.dtype + + if shape is None: + assert array is not None, (name, shape, array, dtype, zarr_root) + shape = array.shape + + if array is not None: + assert array.shape == shape, (array.shape, shape) + a = zarr_root.create_dataset( + name, + shape=shape, + dtype=dtype, + overwrite=overwrite, + **kwargs, + ) + a[...] 
= array + a.attrs["_ARRAY_DIMENSIONS"] = dimensions + return a + + if "fill_value" not in kwargs: + if str(dtype).startswith("float") or str(dtype).startswith("numpy.float"): + kwargs["fill_value"] = np.nan + elif str(dtype).startswith("datetime64") or str(dtype).startswith("numpy.datetime64"): + kwargs["fill_value"] = np.datetime64("NaT") + # elif str(dtype).startswith("timedelta64") or str(dtype).startswith( + # "numpy.timedelta64" + # ): + # kwargs["fill_value"] = np.timedelta64("NaT") + elif str(dtype).startswith("int") or str(dtype).startswith("numpy.int"): + kwargs["fill_value"] = 0 + elif str(dtype).startswith("bool") or str(dtype).startswith("numpy.bool"): + kwargs["fill_value"] = False + else: + raise ValueError(f"No fill_value for dtype={dtype}") + + a = zarr_root.create_dataset( + name, + shape=shape, + dtype=dtype, + overwrite=overwrite, + **kwargs, + ) + a.attrs["_ARRAY_DIMENSIONS"] = dimensions + return a + + +class ZarrBuiltRegistry: + """A class to manage the creation and access of Zarr datasets.""" + + name_lengths = "lengths" + name_flags = "flags" + lengths = None + flags = None + z = None + + def __init__(self, path: str, synchronizer_path: str | None = None, use_threads: bool = False): + """Initialize the ZarrBuiltRegistry. + + Parameters + ---------- + path : str + Path to the Zarr store. + synchronizer_path : Optional[str], optional + Path to the synchronizer. + use_threads : bool + Whether to use thread-based synchronization. + """ + import zarr + + assert isinstance(path, str), path + self.zarr_path = path + + if use_threads: + self.synchronizer = zarr.ThreadSynchronizer() + self.synchronizer_path = None + else: + if synchronizer_path is None: + synchronizer_path = self.zarr_path + ".sync" + self.synchronizer_path = synchronizer_path + self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path) + + def clean(self) -> None: + """Clean up the synchronizer path.""" + if self.synchronizer_path is not None: + try: + shutil.rmtree(self.synchronizer_path) + except FileNotFoundError: + pass + + _build = self.zarr_path + "/_build" + try: + shutil.rmtree(_build) + except FileNotFoundError: + pass + + def _open_write(self) -> zarr.Group: + """Open the Zarr store in write mode.""" + import zarr + + return zarr.open(self.zarr_path, mode="r+", synchronizer=self.synchronizer) + + def _open_read(self, sync: bool = True) -> zarr.Group: + """Open the Zarr store in read mode. + + Parameters + ---------- + sync : bool + Whether to use synchronization. + + Returns + ------- + zarr.Group + The opened Zarr group. + """ + import zarr + + if sync: + return zarr.open(self.zarr_path, mode="r", synchronizer=self.synchronizer) + else: + return zarr.open(self.zarr_path, mode="r") + + def new_dataset(self, *args, **kwargs) -> None: + """Create a new dataset in the Zarr store. + + Parameters + ---------- + *args + Positional arguments for dataset creation. + **kwargs + Keyword arguments for dataset creation. + """ + z = self._open_write() + zarr_root = z["_build"] + add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs) + + def add_to_history(self, action: str, **kwargs) -> None: + """Add an action to the history attribute of the Zarr store. + + Parameters + ---------- + action : str + The action to record. + **kwargs + Additional information about the action. 
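+
+        Examples
+        --------
+        A usage sketch (the store path, action name and extra fields are
+        illustrative):
+
+        >>> registry = ZarrBuiltRegistry("dataset.zarr")  # doctest: +SKIP
+        >>> registry.add_to_history("loading_data_end", parts="0/4")  # doctest: +SKIP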
+ """ + new = dict( + action=action, + timestamp=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat(), + ) + new.update(kwargs) + + z = self._open_write() + history = z.attrs.get("history", []) + history.append(new) + z.attrs["history"] = history + + def get_lengths(self) -> list[int]: + """Get the lengths dataset. + + Returns + ------- + list[int] + The lengths dataset. + """ + z = self._open_read() + return list(z["_build"][self.name_lengths][:]) + + def get_flags(self, **kwargs) -> list[bool]: + """Get the flags dataset. + + Parameters + ---------- + **kwargs + Additional arguments for reading the dataset. + + Returns + ------- + list[bool] + The flags dataset. + """ + z = self._open_read(**kwargs) + return list(z["_build"][self.name_flags][:]) + + def get_flag(self, i: int) -> bool: + """Get a specific flag. + + Parameters + ---------- + i : int + Index of the flag. + + Returns + ------- + bool + The flag value. + """ + z = self._open_read() + return z["_build"][self.name_flags][i] + + def set_flag(self, i: int, value: bool = True) -> None: + """Set a specific flag. + + Parameters + ---------- + i : int + Index of the flag. + value : bool + Value to set the flag to. + """ + z = self._open_write() + z.attrs["latest_write_timestamp"] = ( + datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() + ) + z["_build"][self.name_flags][i] = value + + def ready(self) -> bool: + """Check if all flags are set. + + Returns + ------- + bool + True if all flags are set, False otherwise. + """ + return all(self.get_flags()) + + def create(self, lengths: list[int], overwrite: bool = False) -> None: + """Create the lengths and flags datasets. + + Parameters + ---------- + lengths : list[int] + Lengths to initialize the dataset with. + overwrite : bool + Whether to overwrite existing datasets. + """ + self.new_dataset(name=self.name_lengths, array=np.array(lengths, dtype="i4")) + self.new_dataset(name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool)) + self.add_to_history("initialised") + + def reset(self, lengths: list[int]) -> None: + """Reset the lengths and flags datasets. + + Parameters + ---------- + lengths : list[int] + Lengths to initialize the dataset with. + """ + return self.create(lengths, overwrite=True) + + def add_provenance(self, name: str) -> None: + """Add provenance information to the Zarr store. + + Parameters + ---------- + name : str + Name of the provenance attribute. + """ + z = self._open_write() + + if name in z.attrs: + return + + from anemoi.utils.provenance import gather_provenance_info + + z.attrs[name] = gather_provenance_info() diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 18a09ecfd..223736971 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -27,13 +27,15 @@ def extend(x: str | list[Any] | tuple[Any, ...]) -> Iterator[datetime.datetime]: """Extend a date range or list of dates into individual datetime objects. - Args: - x (Union[str, List[Any], Tuple[Any, ...]]): A date range string or list/tuple of dates. + Parameters + ---------- + x : Union[str, List[Any], Tuple[Any, ...]] + A date range string or list/tuple of dates. - Returns - ------- - Iterator[datetime.datetime] - An iterator of datetime objects. + Yields + ------ + datetime.datetime + Individual datetime objects. 
""" if isinstance(x, (list, tuple)): diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py new file mode 100644 index 000000000..ffec5e351 --- /dev/null +++ b/src/anemoi/datasets/grids.py @@ -0,0 +1,668 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import base64 +import logging +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +LOG = logging.getLogger(__name__) + + +def plot_mask( + path: str, + mask: NDArray[Any], + lats: NDArray[Any], + lons: NDArray[Any], + global_lats: NDArray[Any], + global_lons: NDArray[Any], +) -> None: + """Plot and save various visualizations of the mask and coordinates. + + Parameters + ---------- + path : str + The base path for saving the plots. + mask : NDArray[Any] + The mask array. + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + global_lats : NDArray[Any] + Global latitude coordinates. + global_lons : NDArray[Any] + Global longitude coordinates. + """ + import matplotlib.pyplot as plt + + s = 1 + + global_lons[global_lons >= 180] -= 360 + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons, global_lats, s=s, marker="o", c="r") + if isinstance(path, str): + plt.savefig(path + "-global.png") + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="k") + if isinstance(path, str): + plt.savefig(path + "-cutout.png") + + plt.figure(figsize=(10, 5)) + plt.scatter(lons, lats, s=s) + if isinstance(path, str): + plt.savefig(path + "-lam.png") + # plt.scatter(lons, lats, s=0.01) + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") + plt.scatter(lons, lats, s=s) + if isinstance(path, str): + plt.savefig(path + "-both.png") + # plt.scatter(lons, lats, s=0.01) + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") + plt.scatter(lons, lats, s=s) + plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1) + plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1) + if isinstance(path, str): + plt.savefig(path + "-both-zoomed.png") + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") + plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1) + plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1) + if isinstance(path, str): + plt.savefig(path + "-global-zoomed.png") + + +# TODO: Use the one from anemoi.utils.grids instead +# from anemoi.utils.grids import ... +def xyz_to_latlon(x: NDArray[Any], y: NDArray[Any], z: NDArray[Any]) -> tuple[NDArray[Any], NDArray[Any]]: + """Convert Cartesian coordinates to latitude and longitude. + + Parameters + ---------- + x : NDArray[Any] + X coordinates. + y : NDArray[Any] + Y coordinates. + z : NDArray[Any] + Z coordinates. + + Returns + ------- + Tuple[NDArray[Any], NDArray[Any]] + Latitude and longitude coordinates. + """ + return ( + np.rad2deg(np.arcsin(np.minimum(1.0, np.maximum(-1.0, z)))), + np.rad2deg(np.arctan2(y, x)), + ) + + +# TODO: Use the one from anemoi.utils.grids instead +# from anemoi.utils.grids import ... 
+def latlon_to_xyz( + lat: NDArray[Any], lon: NDArray[Any], radius: float = 1.0 +) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any]]: + """Convert latitude and longitude to Cartesian coordinates. + + Parameters + ---------- + lat : NDArray[Any] + Latitude coordinates. + lon : NDArray[Any] + Longitude coordinates. + radius : float, optional + Radius of the sphere. Defaults to 1.0. + + Returns + ------- + Tuple[NDArray[Any], NDArray[Any], NDArray[Any]] + X, Y, and Z coordinates. + """ + # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates + # We assume that the Earth is a sphere of radius 1 so N(phi) = 1 + # We assume h = 0 + # + phi = np.deg2rad(lat) + lda = np.deg2rad(lon) + + cos_phi = np.cos(phi) + cos_lda = np.cos(lda) + sin_phi = np.sin(phi) + sin_lda = np.sin(lda) + + x = cos_phi * cos_lda * radius + y = cos_phi * sin_lda * radius + z = sin_phi * radius + + return x, y, z + + +class Triangle3D: + """A class to represent a 3D triangle and perform intersection tests with rays.""" + + def __init__(self, v0: NDArray[Any], v1: NDArray[Any], v2: NDArray[Any]) -> None: + """Initialize the Triangle3D object. + + Parameters + ---------- + v0 : NDArray[Any] + First vertex of the triangle. + v1 : NDArray[Any] + Second vertex of the triangle. + v2 : NDArray[Any] + Third vertex of the triangle. + """ + self.v0 = v0 + self.v1 = v1 + self.v2 = v2 + + def intersect(self, ray_origin: NDArray[Any], ray_direction: NDArray[Any]) -> bool: + """Check if a ray intersects with the triangle. + + Parameters + ---------- + ray_origin : NDArray[Any] + Origin of the ray. + ray_direction : NDArray[Any] + Direction of the ray. + + Returns + ------- + bool + True if the ray intersects with the triangle, False otherwise. + """ + # Möller–Trumbore intersection algorithm + # https://en.wikipedia.org/wiki/M%C3%B6ller%E2%80%93Trumbore_intersection_algorithm + + epsilon = 0.0000001 + + h = np.cross(ray_direction, self.v2 - self.v0) + a = np.dot(self.v1 - self.v0, h) + + if -epsilon < a < epsilon: + return False + + f = 1.0 / a + s = ray_origin - self.v0 + u = f * np.dot(s, h) + + if u < 0.0 or u > 1.0: + return False + + q = np.cross(s, self.v1 - self.v0) + v = f * np.dot(ray_direction, q) + + if v < 0.0 or u + v > 1.0: + return False + + t = f * np.dot(self.v2 - self.v0, q) + + if t > epsilon: + return True + + return False + + +def cropping_mask( + lats: NDArray[Any], + lons: NDArray[Any], + north: float, + west: float, + south: float, + east: float, +) -> NDArray[Any]: + """Create a mask for the points within the specified latitude and longitude bounds. + + Parameters + ---------- + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + north : float + Northern boundary. + west : float + Western boundary. + south : float + Southern boundary. + east : float + Eastern boundary. + + Returns + ------- + NDArray[Any] + Mask array. + """ + mask = ( + (lats >= south) + & (lats <= north) + & ( + ((lons >= west) & (lons <= east)) + | ((lons >= west + 360) & (lons <= east + 360)) + | ((lons >= west - 360) & (lons <= east - 360)) + ) + ) + return mask + + +def cutout_mask( + lats: NDArray[Any], + lons: NDArray[Any], + global_lats: NDArray[Any], + global_lons: NDArray[Any], + cropping_distance: float = 2.0, + neighbours: int = 5, + min_distance_km: int | float | None = None, + plot: str | None = None, +) -> NDArray[Any]: + """Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]. 
+ + Parameters + ---------- + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + global_lats : NDArray[Any] + Global latitude coordinates. + global_lons : NDArray[Any] + Global longitude coordinates. + cropping_distance : float, optional + Cropping distance. Defaults to 2.0. + neighbours : int, optional + Number of neighbours. Defaults to 5. + min_distance_km : Optional[Union[int, float]], optional + Minimum distance in kilometers. Defaults to None. + plot : Optional[str], optional + Path for saving the plot. Defaults to None. + + Returns + ------- + NDArray[Any] + Mask array. + """ + from scipy.spatial import cKDTree + + # TODO: transform min_distance from lat/lon to xyz + + assert global_lats.ndim == 1 + assert global_lons.ndim == 1 + assert lats.ndim == 1 + assert lons.ndim == 1 + + assert global_lats.shape == global_lons.shape + assert lats.shape == lons.shape + + north = np.amax(lats) + south = np.amin(lats) + east = np.amax(lons) + west = np.amin(lons) + + # Reduce the global grid to the area of interest + + mask = cropping_mask( + global_lats, + global_lons, + np.min([90.0, north + cropping_distance]), + west - cropping_distance, + np.max([-90.0, south - cropping_distance]), + east + cropping_distance, + ) + + # return mask + # mask = np.array([True] * len(global_lats), dtype=bool) + global_lats_masked = global_lats[mask] + global_lons_masked = global_lons[mask] + + global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked) + global_points = np.array(global_xyx).transpose() + + xyx = latlon_to_xyz(lats, lons) + lam_points = np.array(xyx).transpose() + + if isinstance(min_distance_km, (int, float)): + min_distance = min_distance_km / 6371.0 + else: + points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km] + distances, _ = cKDTree(points).query(points, k=2) + min_distance = np.min(distances[:, 1]) + + LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km") + + # Use a cKDTree to find the nearest points + distances, indices = cKDTree(lam_points).query(global_points, k=neighbours) + + # Centre of the Earth + zero = np.array([0.0, 0.0, 0.0]) + + # After the loop, 'inside_lam' will contain a list point to EXCLUDE + inside_lam = [] + + for i, (global_point, distance, index) in enumerate(zip(global_points, distances, indices)): + + # We check more than one triangle in case te global point + # is near the edge of triangle, (the lam point and global points are colinear) + + inside = False + for j in range(neighbours): + t = Triangle3D( + lam_points[index[j]], lam_points[index[(j + 1) % neighbours]], lam_points[index[(j + 2) % neighbours]] + ) + inside = t.intersect(zero, global_point) + if inside: + break + + close = np.min(distance) <= min_distance + + inside_lam.append(inside or close) + + j = 0 + inside_lam_array = np.array(inside_lam) + for i, m in enumerate(mask): + if not m: + continue + + mask[i] = inside_lam_array[j] + j += 1 + + assert j == len(inside_lam_array) + + # Invert the mask, so we have only the points outside the cutout + mask = ~mask + + if plot: + plot_mask(plot, mask, lats, lons, global_lats, global_lons) + + return mask + + +def thinning_mask( + lats: NDArray[Any], + lons: NDArray[Any], + global_lats: NDArray[Any], + global_lons: NDArray[Any], + cropping_distance: float = 2.0, +) -> NDArray[Any]: + """Return the list of points in [lats, lons] closest to [global_lats, global_lons]. + + Parameters + ---------- + lats : NDArray[Any] + Latitude coordinates. 
+ lons : NDArray[Any] + Longitude coordinates. + global_lats : NDArray[Any] + Global latitude coordinates. + global_lons : NDArray[Any] + Global longitude coordinates. + cropping_distance : float, optional + Cropping distance. Defaults to 2.0. + + Returns + ------- + NDArray[Any] + Array of indices of the closest points. + """ + from scipy.spatial import cKDTree + + assert global_lats.ndim == 1 + assert global_lons.ndim == 1 + assert lats.ndim == 1 + assert lons.ndim == 1 + + assert global_lats.shape == global_lons.shape + assert lats.shape == lons.shape + + north = np.amax(lats) + south = np.amin(lats) + east = np.amax(lons) + west = np.amin(lons) + + # Reduce the global grid to the area of interest + + mask = cropping_mask( + global_lats, + global_lons, + np.min([90.0, north + cropping_distance]), + west - cropping_distance, + np.max([-90.0, south - cropping_distance]), + east + cropping_distance, + ) + + # return mask + global_lats_masked = global_lats[mask] + global_lons_masked = global_lons[mask] + + global_xyx = latlon_to_xyz(global_lats_masked, global_lons_masked) + global_points = np.array(global_xyx).transpose() + + xyx = latlon_to_xyz(lats, lons) + points = np.array(xyx).transpose() + + # Use a cKDTree to find the nearest points + _, indices = cKDTree(points).query(global_points, k=1) + + return np.array([i for i in indices]) + + +def outline(lats: NDArray[Any], lons: NDArray[Any], neighbours: int = 5) -> list[int]: + """Find the outline of the grid points. + + Parameters + ---------- + lats : NDArray[Any] + Latitude coordinates. + lons : NDArray[Any] + Longitude coordinates. + neighbours : int, optional + Number of neighbours. Defaults to 5. + + Returns + ------- + List[int] + Indices of the outline points. + """ + from scipy.spatial import cKDTree + + xyx = latlon_to_xyz(lats, lons) + grid_points = np.array(xyx).transpose() + + # Use a cKDTree to find the nearest points + _, indices = cKDTree(grid_points).query(grid_points, k=neighbours) + + # Centre of the Earth + zero = np.array([0.0, 0.0, 0.0]) + + outside = [] + + for i, (point, index) in enumerate(zip(grid_points, indices)): + inside = False + for j in range(1, neighbours): + t = Triangle3D( + grid_points[index[j]], + grid_points[index[(j + 1) % neighbours]], + grid_points[index[(j + 2) % neighbours]], + ) + inside = t.intersect(zero, point) + if inside: + break + + if not inside: + outside.append(i) + + return outside + + +def deserialise_mask(encoded: str) -> NDArray[Any]: + """Deserialise a mask from a base64 encoded string. + + Parameters + ---------- + encoded : str + Base64 encoded string. + + Returns + ------- + NDArray[Any] + Deserialised mask array. + """ + import pickle + import zlib + + packed = pickle.loads(zlib.decompress(base64.b64decode(encoded))) + + mask = [] + value = False + for count in packed: + mask.extend([value] * count) + value = not value + return np.array(mask, dtype=bool) + + +def _serialise_mask(mask: NDArray[Any]) -> str: + """Serialise a mask to a base64 encoded string. + + Parameters + ---------- + mask : NDArray[Any] + Mask array. + + Returns + ------- + str + Base64 encoded string. 
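+
+    Notes
+    -----
+    The mask is run-length encoded as alternating counts of False and True values (always starting
+    with the count of leading False values, possibly zero), then pickled, zlib-compressed and
+    base64-encoded.
+
+    Examples
+    --------
+    Round-trip with illustrative values:
+
+    >>> mask = np.array([False, False, True, True, True, False])
+    >>> deserialise_mask(_serialise_mask(mask)).tolist()
+    [False, False, True, True, True, False]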
+ """ + import pickle + import zlib + + assert len(mask.shape) == 1 + assert len(mask) + + packed = [] + last = mask[0] + count = 1 + + for value in mask[1:]: + if value == last: + count += 1 + else: + packed.append(count) + last = value + count = 1 + + packed.append(count) + + # We always start with an 'off' value + # So if the first value is 'on', we need to add a zero + if mask[0]: + packed.insert(0, 0) + + return base64.b64encode(zlib.compress(pickle.dumps(packed))).decode("utf-8") + + +def serialise_mask(mask: NDArray[Any]) -> str: + """Serialise a mask and ensure it can be deserialised. + + Parameters + ---------- + mask : NDArray[Any] + Mask array. + + Returns + ------- + str + Base64 encoded string. + """ + result = _serialise_mask(mask) + # Make sure we can deserialise it + assert np.all(mask == deserialise_mask(result)) + return result + + +def nearest_grid_points( + source_latitudes: NDArray[Any], + source_longitudes: NDArray[Any], + target_latitudes: NDArray[Any], + target_longitudes: NDArray[Any], + max_distance: float = None, + k: int = 1, +) -> NDArray[Any]: + """Find the nearest grid points from source to target coordinates. + + Parameters + ---------- + source_latitudes : NDArray[Any] + Source latitude coordinates. + source_longitudes : NDArray[Any] + Source longitude coordinates. + target_latitudes : NDArray[Any] + Target latitude coordinates. + target_longitudes : NDArray[Any] + Target longitude coordinates. + max_distance: float, optional + Maximum distance between nearest point and point to interpolate. Defaults to None. + For example, 1e-3 is 1 km. + k : int, optional + The number of k closest neighbors to consider for interpolation + + Returns + ------- + NDArray[Any] + Indices of the nearest grid points. + """ + # TODO: Use the one from anemoi.utils.grids instead + # from anemoi.utils.grids import ... + from scipy.spatial import KDTree + + source_xyz = latlon_to_xyz(source_latitudes, source_longitudes) + source_points = np.array(source_xyz).transpose() + + target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) + target_points = np.array(target_xyz).transpose() + if max_distance is None: + distances, indices = KDTree(source_points).query(target_points, k=k) + else: + distances, indices = KDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) + return distances, indices + + +if __name__ == "__main__": + global_lats, global_lons = np.meshgrid( + np.linspace(90, -90, 90), + np.linspace(-180, 180, 180), + ) + global_lats = global_lats.flatten() + global_lons = global_lons.flatten() + + lats, lons = np.meshgrid( + np.linspace(50, 40, 100), + np.linspace(-10, 15, 100), + ) + lats = lats.flatten() + lons = lons.flatten() + + mask = cutout_mask(lats, lons, global_lats, global_lons, cropping_distance=5.0) + + import matplotlib.pyplot as plt + + fig = plt.figure(figsize=(10, 5)) + plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r") + plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k") + # plt.scatter(lons, lats, s=0.01) + plt.savefig("cutout.png") diff --git a/src/anemoi/datasets/testing.py b/src/anemoi/datasets/testing.py new file mode 100644 index 000000000..a15c7fd7e --- /dev/null +++ b/src/anemoi/datasets/testing.py @@ -0,0 +1,173 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +# A collection of functions to support pytest testing + +import logging +from typing import Any + +LOG = logging.getLogger(__name__) + + +def assert_field_list( + fs: list[Any], + size: int | None = None, + start: Any | None = None, + end: Any | None = None, + constant: bool = False, + skip: Any | None = None, +) -> None: + """Asserts various properties of a list of fields. + + Parameters + ---------- + fs : List[Any] + List of fields to be checked. + size : Optional[int], optional + Expected size of the list. If None, the list must be non-empty. + start : Optional[Any], optional + Expected start metadata value. If None, no check is performed. + end : Optional[Any], optional + Expected end metadata value. If None, no check is performed. + constant : bool, optional + If True, checks that all fields are constant. + skip : Optional[Any], optional + Placeholder for future use. + """ + import numpy as np + + if size is None: + assert len(fs) > 0, fs + else: + assert len(fs) == size, (len(fs), size) + + first = fs[0] + last = fs[-1] + + if constant: + # TODO: add a check for constant fields + pass + else: + assert start is None or first.metadata("valid_datetime") == start, (first.metadata("valid_datetime"), start) + assert end is None or last.metadata("valid_datetime") == end, (last.metadata("valid_datetime"), end) + print(first.datetime()) + + print(last.metadata()) + + first = first + latitudes, longitudes = first.grid_points() + + assert len(latitudes.shape) == 1, latitudes.shape + assert len(longitudes.shape) == 1, longitudes.shape + + assert len(latitudes) == len(longitudes), (len(latitudes), len(longitudes)) + data = first.to_numpy(flatten=True) + + assert len(data) == len(latitudes), (len(data), len(latitudes)) + + north = np.max(latitudes) + south = np.min(latitudes) + east = np.max(longitudes) + west = np.min(longitudes) + + assert north >= south, (north, south) + assert east >= west, (east, west) + assert north <= 90, north + assert south >= -90, south + assert east <= 360, east + assert west >= -180, west + + +class IndexTester: + """Class to test indexing of datasets.""" + + def __init__(self, ds: Any) -> None: + """Initialise the IndexTester. + + Parameters + ---------- + ds : Any + Dataset. + """ + self.ds = ds + self.np = ds[:] # Numpy array + + assert self.ds.shape == self.np.shape, (self.ds.shape, self.np.shape) + assert (self.ds == self.np).all() + + def __getitem__(self, index: Any) -> None: + """Test indexing. + + Parameters + ---------- + index : Any + Index. 
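+
+        Raises
+        ------
+        AssertionError
+            If the values returned by indexing the dataset differ from indexing the equivalent NumPy array.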
+ """ + LOG.info("IndexTester: %s", index) + if self.ds[index] is None: + assert False, (self.ds, index) + + if not (self.ds[index] == self.np[index]).all(): + assert (self.ds[index] == self.np[index]).all() + + +def default_test_indexing(ds): + + t = IndexTester(ds) + + t[0:10, :, 0] + t[:, 0:3, 0] + # t[:, :, 0] + t[0:10, 0:3, 0] + t[:, :, :] + + if ds.shape[1] > 2: # Variable dimension + t[:, (1, 2), :] + t[:, (1, 2)] + + t[0] + t[0, :] + t[0, 0, :] + t[0, 0, 0, :] + + if ds.shape[2] > 1: # Ensemble dimension + t[0:10, :, (0, 1)] + + for i in range(3): + t[i] + start = 5 * i + end = len(ds) - 5 * i + step = len(ds) // 10 + + t[start:end:step] + t[start:end] + t[start:] + t[:end] + t[::step] + + +class Trace: + + def __init__(self, ds): + self.ds = ds + self.f = open("trace.txt", "a") + + def __getattr__(self, name: str) -> Any: + + print(name, file=self.f, flush=True) + return getattr(self.ds, name) + + def __len__(self) -> int: + print("__len__", file=self.f, flush=True) + return len(self.ds) + + def __getitem__(self, index: Any) -> Any: + print("__getitem__", file=self.f, flush=True) + return self.ds[index] diff --git a/src/anemoi/datasets/use/__init__.py b/src/anemoi/datasets/use/__init__.py index 9fc775e54..e69de29bb 100644 --- a/src/anemoi/datasets/use/__init__.py +++ b/src/anemoi/datasets/use/__init__.py @@ -1,8 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. diff --git a/src/anemoi/datasets/use/gridded/__init__.py b/src/anemoi/datasets/use/gridded/__init__.py index dbbfcd9a5..6af38b2f4 100644 --- a/src/anemoi/datasets/use/gridded/__init__.py +++ b/src/anemoi/datasets/use/gridded/__init__.py @@ -95,7 +95,7 @@ def open_dataset(*args: Any, **kwargs: Any) -> "Dataset": ds._check() if trace: - from anemoi.datasets.misc.testing import Trace + from anemoi.datasets.testing import Trace ds = Trace(ds) diff --git a/src/anemoi/datasets/use/gridded/complement.py b/src/anemoi/datasets/use/gridded/complement.py index 1881a74fa..300d79363 100644 --- a/src/anemoi/datasets/use/gridded/complement.py +++ b/src/anemoi/datasets/use/gridded/complement.py @@ -16,7 +16,7 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.misc.grids import nearest_grid_points +from anemoi.datasets.grids import nearest_grid_points from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import Shape @@ -249,7 +249,13 @@ def __init__(self, target: Any, source: Any, max_distance: float = None, k: int """ super().__init__(target, source) + if isinstance(k, str): + assert False + LOG.warning(f"ComplementNearest: Interpreting k={k} ({type(k)}) as integer") + k = int(k) + self.k = k + self._distances, self._nearest_grid_points = nearest_grid_points( self._source.latitudes, self._source.longitudes, @@ -353,7 +359,7 @@ def complement_factory(args: tuple, kwargs: dict) -> Dataset: }[interpolation] if interpolation == "nearest": - k = kwargs.pop("k", "1") + k = kwargs.pop("k", 1) complement = Class(target=target, source=source, k=k)._subset(**kwargs) else: diff --git a/src/anemoi/datasets/use/gridded/dataset.py b/src/anemoi/datasets/use/gridded/dataset.py 
index 5a4df0052..9969ca69c 100644 --- a/src/anemoi/datasets/use/gridded/dataset.py +++ b/src/anemoi/datasets/use/gridded/dataset.py @@ -178,8 +178,6 @@ def __subset(self, **kwargs: Any) -> "Dataset": padding = kwargs.pop("padding", None) if padding: - if padding != "empty": - raise ValueError(f"Only 'empty' padding is supported, got {padding=}") from anemoi.datasets.use.gridded.padded import Padded frequency = kwargs.pop("frequency", self.frequency) @@ -246,7 +244,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": return Rescale(self, rescale)._subset(**kwargs).mutate() if "statistics" in kwargs: - from anemoi.datasets.use.gridded import open_dataset + from anemoi.datasets.use import open_dataset from anemoi.datasets.use.gridded.statistics import Statistics statistics = kwargs.pop("statistics") @@ -301,12 +299,6 @@ def __subset(self, **kwargs: Any) -> "Dataset": if skip_missing_dates: return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate() - if "rolling_average" in kwargs: - from anemoi.datasets.use.gridded.rolling_average import RollingAverage - - rolling_average = kwargs.pop("rolling_average") - return RollingAverage(self, rolling_average)._subset(**kwargs).mutate() - if "interpolate_frequency" in kwargs: from anemoi.datasets.use.gridded.interpolate import InterpolateFrequency @@ -1026,7 +1018,7 @@ def origins(self) -> Any: print(p.origins()) def components(self) -> Any: - from anemoi.datasets.use.components import Projection + from anemoi.datasets.use.gridded.components import Projection slices = tuple(slice(0, i, 1) for i in self.shape) return self.project(Projection(slices)) diff --git a/src/anemoi/datasets/use/gridded/ensemble.py b/src/anemoi/datasets/use/gridded/ensemble.py index 0d1aa15b2..1ecff8b97 100644 --- a/src/anemoi/datasets/use/gridded/ensemble.py +++ b/src/anemoi/datasets/use/gridded/ensemble.py @@ -124,6 +124,12 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """Returns metadata specific to the Number object.""" return {} + def origin_transformation(self, variable, origins): + return { + "name": "number", + "config": {"members": self.members}, + } + class Ensemble(GivenAxis): """A class to represent an ensemble of datasets combined along a given axis.""" diff --git a/src/anemoi/datasets/use/gridded/fill_missing.py b/src/anemoi/datasets/use/gridded/fill_missing.py index 337549cfc..fb0c2f098 100644 --- a/src/anemoi/datasets/use/gridded/fill_missing.py +++ b/src/anemoi/datasets/use/gridded/fill_missing.py @@ -14,7 +14,7 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets.use import MissingDateError from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import TupleIndex diff --git a/src/anemoi/datasets/use/gridded/forwards.py b/src/anemoi/datasets/use/gridded/forwards.py index 0ee6f8ac7..3966dd34b 100644 --- a/src/anemoi/datasets/use/gridded/forwards.py +++ b/src/anemoi/datasets/use/gridded/forwards.py @@ -240,6 +240,9 @@ def constant_fields(self) -> list[str]: """Returns the constant fields of the forward dataset.""" return self.forward.constant_fields + def project(self, projection): + return self.forward.project(projection).add_transformation(self) + class Combined(Forwards): """A class to combine multiple datasets into a single dataset.""" diff --git a/src/anemoi/datasets/use/gridded/grids.py b/src/anemoi/datasets/use/gridded/grids.py index 
8b399a820..2f49d74f8 100644 --- a/src/anemoi/datasets/use/gridded/grids.py +++ b/src/anemoi/datasets/use/gridded/grids.py @@ -21,15 +21,167 @@ from anemoi.datasets.use.gridded.dataset import Shape from anemoi.datasets.use.gridded.dataset import TupleIndex from anemoi.datasets.use.gridded.debug import Node +from anemoi.datasets.use.gridded.debug import debug_indexing +from anemoi.datasets.use.gridded.forwards import Combined from anemoi.datasets.use.gridded.forwards import GivenAxis from anemoi.datasets.use.gridded.indexing import apply_index_to_slices_changes +from anemoi.datasets.use.gridded.indexing import expand_list_indexing from anemoi.datasets.use.gridded.indexing import index_to_slices +from anemoi.datasets.use.gridded.indexing import length_to_slices +from anemoi.datasets.use.gridded.indexing import update_tuple from anemoi.datasets.use.gridded.misc import _auto_adjust from anemoi.datasets.use.gridded.misc import _open LOG = logging.getLogger(__name__) +class Concat(Combined): + """A class to represent concatenated datasets.""" + + def __len__(self) -> int: + """Returns the total length of the concatenated datasets. + + Returns + ------- + int + Total length of the concatenated datasets. + """ + return sum(len(i) for i in self.datasets) + + @debug_indexing + @expand_list_indexing + def _get_tuple(self, index: TupleIndex) -> NDArray[Any]: + """Retrieves a tuple of data from the concatenated datasets based on the given index. + + Parameters + ---------- + index : TupleIndex + Index specifying the data to retrieve. + + Returns + ------- + NDArray[Any] + Concatenated data array from the specified index. + """ + index, changes = index_to_slices(index, self.shape) + # print(index, changes) + lengths = [d.shape[0] for d in self.datasets] + slices = length_to_slices(index[0], lengths) + # print("slies", slices) + result = [d[update_tuple(index, 0, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None] + result = np.concatenate(result, axis=0) + return apply_index_to_slices_changes(result, changes) + + @debug_indexing + def __getitem__(self, n: FullIndex) -> NDArray[Any]: + """Retrieves data from the concatenated datasets based on the given index. + + Parameters + ---------- + n : FullIndex + Index specifying the data to retrieve. + + Returns + ------- + NDArray[Any] + Data array from the concatenated datasets based on the index. + """ + if isinstance(n, tuple): + return self._get_tuple(n) + + if isinstance(n, slice): + return self._get_slice(n) + + # TODO: optimize + k = 0 + while n >= self.datasets[k]._len: + n -= self.datasets[k]._len + k += 1 + return self.datasets[k][n] + + @debug_indexing + def _get_slice(self, s: slice) -> NDArray[Any]: + """Retrieves a slice of data from the concatenated datasets. + + Parameters + ---------- + s : slice + Slice object specifying the range of data to retrieve. + + Returns + ------- + NDArray[Any] + Concatenated data array from the specified slice. + """ + result = [] + + lengths = [d.shape[0] for d in self.datasets] + slices = length_to_slices(s, lengths) + + result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None] + + return np.concatenate(result) + + def check_compatibility(self, d1: Dataset, d2: Dataset) -> None: + """Check the compatibility of two datasets for concatenation. + + Parameters + ---------- + d1 : Dataset + The first dataset. + d2 : Dataset + The second dataset. 
+ """ + super().check_compatibility(d1, d2) + self.check_same_sub_shapes(d1, d2, drop_axis=0) + + def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None: + """Check if the lengths of two datasets are the same. + + Parameters + ---------- + d1 : Dataset + The first dataset. + d2 : Dataset + The second dataset. + """ + # Turned off because we are concatenating along the first axis + pass + + def check_same_dates(self, d1: Dataset, d2: Dataset) -> None: + """Check if the dates of two datasets are the same. + + Parameters + ---------- + d1 : Dataset + The first dataset. + d2 : Dataset + The second dataset. + """ + # Turned off because we are concatenating along the dates axis + pass + + @property + def dates(self) -> NDArray[np.datetime64]: + """Returns the concatenated dates of all datasets.""" + return np.concatenate([d.dates for d in self.datasets]) + + @property + def shape(self) -> Shape: + """Returns the shape of the concatenated datasets.""" + return (len(self),) + self.datasets[0].shape[1:] + + def tree(self) -> Node: + """Generates a hierarchical tree structure for the concatenated datasets. + + Returns + ------- + Node + A Node object representing the concatenated datasets. + """ + return Node(self, [d.tree() for d in self.datasets]) + + class GridsBase(GivenAxis): """A base class for handling grids in datasets.""" @@ -203,7 +355,7 @@ def _initialize_masks(self) -> None: ValueError If the global mask dimension does not match the global dataset grid points. """ - from anemoi.datasets.misc.grids import cutout_mask + from anemoi.datasets.grids import cutout_mask for i, lam in enumerate(self.lams): assert len(lam.shape) == len( diff --git a/src/anemoi/datasets/use/gridded/interpolate.py b/src/anemoi/datasets/use/gridded/interpolate.py index f3c5155f9..31f412de8 100644 --- a/src/anemoi/datasets/use/gridded/interpolate.py +++ b/src/anemoi/datasets/use/gridded/interpolate.py @@ -227,7 +227,7 @@ def __init__(self, dataset: Dataset, interpolate_variables: list[str], max_dista max_distance : Optional[float], optional The maximum distance for nearest neighbor search, by default None. 
""" - from anemoi.datasets.misc.grids import nearest_grid_points + from anemoi.datasets.grids import nearest_grid_points super().__init__(dataset) self.vars = interpolate_variables diff --git a/src/anemoi/datasets/use/gridded/join.py b/src/anemoi/datasets/use/gridded/join.py index 7d150f01b..21a190cb4 100644 --- a/src/anemoi/datasets/use/gridded/join.py +++ b/src/anemoi/datasets/use/gridded/join.py @@ -14,6 +14,7 @@ from typing import Any import numpy as np +import rich from numpy.typing import NDArray from anemoi.datasets.use.gridded.dataset import Dataset @@ -175,6 +176,8 @@ def _overlay(self) -> Dataset: from anemoi.datasets.use.gridded.select import Select + rich.print("Overlaying join with", variables, len(indices), [d.shape for d in self.datasets]) + return Select(self, indices, {"overlay": variables}) @cached_property diff --git a/src/anemoi/datasets/use/gridded/masked.py b/src/anemoi/datasets/use/gridded/masked.py index 675ae8dc2..203046fc0 100644 --- a/src/anemoi/datasets/use/gridded/masked.py +++ b/src/anemoi/datasets/use/gridded/masked.py @@ -15,7 +15,7 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.misc.grids import cropping_mask +from anemoi.datasets.grids import cropping_mask from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import Shape @@ -200,6 +200,12 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: """ return dict(thinning=self.thinning, method=self.method) + def origin_transformation(self, variable, origins): + return { + "name": "thinning", + "config": dict(thinning=self.thinning, method=self.method), + } + class Cropping(Masked): """A class to represent a cropped dataset.""" @@ -214,7 +220,7 @@ def __init__(self, forward: Dataset, area: Dataset | tuple[float, float, float, area : Union[Dataset, Tuple[float, float, float, float]] The cropping area. 
""" - from ..data import open_dataset + from anemoi.datasets.use import open_dataset area = area if isinstance(area, (list, tuple)) else open_dataset(area) diff --git a/src/anemoi/datasets/use/gridded/merge.py b/src/anemoi/datasets/use/gridded/merge.py index d6a1943e5..2fee1122d 100644 --- a/src/anemoi/datasets/use/gridded/merge.py +++ b/src/anemoi/datasets/use/gridded/merge.py @@ -16,7 +16,7 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets import MissingDateError from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import TupleIndex diff --git a/src/anemoi/datasets/use/gridded/misc.py b/src/anemoi/datasets/use/gridded/misc.py index 58305511d..deab32c83 100644 --- a/src/anemoi/datasets/use/gridded/misc.py +++ b/src/anemoi/datasets/use/gridded/misc.py @@ -349,18 +349,7 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " """ from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.stores import Zarr - - if isinstance(a, str) and len(a.split(".")) in [2, 3]: - - metadata_path = os.path.join(a, "metadata.json") - if os.path.exists(metadata_path): - metadata = load_any_dict_format(metadata_path) - if "backend" not in metadata: - raise ValueError(f"Metadata for {a} does not contain 'backend' key") - - from anemoi.datasets.use.tabular.records import open_records_dataset - - return open_records_dataset(a, backend=metadata["backend"]) + from anemoi.datasets.use.gridded.stores import dataset_lookup if isinstance(a, Dataset): return a.mutate() @@ -369,8 +358,6 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " return Zarr(a).mutate() if isinstance(a, str): - from anemoi.datasets.use.gridded.stores import dataset_lookup - path = dataset_lookup(a) if path and path.endswith(".zarr") or path.endswith(".zip"): @@ -386,7 +373,7 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " if "backend" not in load_any_dict_format(metadata_path): raise ValueError(f"Metadata for {path} does not contain 'backend' key") - from anemoi.datasets.use.records import open_records_dataset + from anemoi.datasets.use.gridded.records import open_records_dataset return open_records_dataset(path) @@ -521,7 +508,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": sets.append(_open(a)) if "observations" in kwargs: - from anemoi.datasets.use.tabular.observations import observations_factory + from anemoi.datasets.use.gridded.observations import observations_factory assert not sets, sets @@ -608,13 +595,13 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": assert len(sets) > 0, (args, kwargs) if "set_group" in kwargs: - from anemoi.datasets.use.records import FieldsRecords + from anemoi.datasets.use.gridded.records import FieldsRecords set_group = kwargs.pop("set_group") assert len(sets) == 1, "set_group can only be used with a single dataset" dataset = sets[0] - from anemoi.datasets.use.dataset import Dataset + from anemoi.datasets.use.gridded.dataset import Dataset if isinstance(dataset, Dataset): # Fields dataset return FieldsRecords(dataset, **kwargs, name=set_group).mutate() diff --git a/src/anemoi/datasets/use/gridded/missing.py b/src/anemoi/datasets/use/gridded/missing.py index 6cd97c247..298e0fc52 100644 --- a/src/anemoi/datasets/use/gridded/missing.py +++ 
b/src/anemoi/datasets/use/gridded/missing.py @@ -16,8 +16,8 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.create.gridded.utils import to_datetime -from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets.create.utils import to_datetime +from anemoi.datasets.use import MissingDateError from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import TupleIndex diff --git a/src/anemoi/datasets/use/gridded/observations/__init__.py b/src/anemoi/datasets/use/gridded/observations/__init__.py new file mode 100644 index 000000000..804adddad --- /dev/null +++ b/src/anemoi/datasets/use/gridded/observations/__init__.py @@ -0,0 +1,313 @@ +# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging +import os +from functools import cached_property +from typing import Any + +import numpy as np +from anemoi.utils.dates import frequency_to_timedelta + +from anemoi.datasets.use.gridded.dataset import Dataset +from anemoi.datasets.use.gridded.debug import Node + +LOG = logging.getLogger(__name__) + + +def round_datetime(dt, frequency, up=True): + dt = dt.replace(minute=0, second=0, microsecond=0) + hour = dt.hour + if hour % frequency != 0: + dt = dt.replace(hour=(hour // frequency) * frequency) + dt = dt + datetime.timedelta(hours=frequency) + return dt + + +def make_dates(start, end, frequency): + if isinstance(start, np.datetime64): + start = start.astype(datetime.datetime) + if isinstance(end, np.datetime64): + end = end.astype(datetime.datetime) + + dates = [] + current_date = start + while current_date <= end: + dates.append(current_date) + current_date += frequency + + dates = [np.datetime64(d, "s") for d in dates] + dates = np.array(dates, dtype="datetime64[s]") + return dates + + +class ObservationsBase(Dataset): + resolution = None + + @cached_property + def shape(self): + return (len(self.dates), len(self.variables), "dynamic") + + def empty_item(self): + return np.full(self.shape[1:-1] + (0,), 0.0, dtype=np.float32) + + def metadata(self): + return dict(observations_datasets="obs datasets currenty have no metadata") + + def _check(self): + pass + + def __len__(self): + return len(self.dates) + + def tree(self): + return Node( + self, + [], + ) + + def __getitem__(self, i): + if isinstance(i, int): + return self.getitem(i) + + # The following may would work but is likely to change in the future + # if isinstance(i, slice): + # return [self.getitem(j) for j in range(int(slice.start), int(slice.stop))] + # if isinstance(i, list): + # return [self.getitem(j) for j in i] + + raise ValueError( + f"Expected int, got {i} of type {type(i)}. Only int is supported to index " + "observations datasets. 
Please use a second [] to select part of the data [i][a,b,c]" + ) + + @property + def variables(self): + raise NotImplementedError() + + def collect_input_sources(self): + LOG.warning("collect_input_sources method is not implemented") + return [] + + def constant_fields(self): + LOG.warning("constant_fields method is not implemented") + return [] + + @property + def dates(self): + return self._dates + + @property + def dtype(self): + return np.float32 + + @property + def field_shape(self): + return self.shape[1:] + + @property + def frequency(self): + assert isinstance(self._frequency, datetime.timedelta), f"Expected timedelta, got {type(self._frequency)}" + return self._frequency + + @property + def latitudes(self): + raise NotImplementedError("latitudes property is not implemented") + + @property + def longitudes(self): + raise NotImplementedError("longitudes property is not implemented") + + @property + def missing(self): + return [] + + def statistics_tendencies(self): + raise NotImplementedError("statistics_tendencies method is not implemented") + + def variables_metadata(self): + raise NotImplementedError("variables_metadata method is not implemented") + + +class ObservationsZarr(ObservationsBase): + def __init__(self, dataset, frequency=None, window=None): + import zarr + + if isinstance(dataset, zarr.hierarchy.Group): + dataset = dataset._store.path + + from anemoi.datasets.use.gridded.stores import dataset_lookup + + dataset = dataset_lookup(dataset) + self.path = dataset + assert self._probe_attributes["is_observations"], f"Expected observations dataset, got {self.path}" + + if frequency is None: + frequency = self._probe_attributes.get("frequency") + # LOG.warning(f"Frequency not provided, using the one from the dataset: {frequency}") + if frequency is None: + frequency = "6h" + # LOG.warning(f"Frequency not provided in the dataset, using the default : {frequency}") + self._frequency = frequency_to_timedelta(frequency) + assert self.frequency.total_seconds() % 3600 == 0, f"Expected multiple of 3600, got {self.frequency}" + if self.frequency.total_seconds() != 6 * 3600: + LOG.warning("Frequency is not 6h, this has not been tested, behaviour is unknown") + + frequency_hours = int(self.frequency.total_seconds() // 3600) + assert isinstance(frequency_hours, int), f"Expected int, got {type(frequency_hours)}" + + if window is None: + window = (-frequency_hours, 0) + if window != (-frequency_hours, 0): + raise ValueError("For now, only window = (- frequency, 0) are supported") + + self.window = window + + start, end = self._probe_attributes["start_date"], self._probe_attributes["end_date"] + start, end = datetime.datetime.fromisoformat(start), datetime.datetime.fromisoformat(end) + start, end = round_datetime(start, frequency_hours), round_datetime(end, frequency_hours) + + self._dates = make_dates(start + self.frequency, end, self.frequency) + + first_window_begin = start.strftime("%Y%m%d%H%M%S") + first_window_begin = int(first_window_begin) + # last_window_end must be the end of the time window of the last item + last_window_end = int(end.strftime("%Y%m%d%H%M%S")) + + from anemoi.datasets.use.gridded.observations.legacy_obs_dataset import ObsDataset + + args = [self.path, first_window_begin, last_window_end] + kwargs = dict( + len_hrs=frequency_hours, # length the time windows, i.e. the time span of one item + step_hrs=frequency_hours, # frequency of the dataset, i.e. 
the time shift between two items + ) + self.forward = ObsDataset(*args, **kwargs) + + assert frequency_hours == self.forward.step_hrs, f"Expected {frequency_hours}, got {self.forward.len_hrs}" + assert frequency_hours == self.forward.len_hrs, f"Expected {frequency_hours}, got {self.forward.step_hrs}" + + if len(self.forward) != len(self.dates): + raise ValueError( + f"Dates are not consistent with the number of items in the dataset. " + f"The dataset contains {len(self.forward)} time windows. " + f"This is not compatible with the " + f"{len(self.dates)} requested dates with frequency={frequency_hours}" + f"{self.dates[0]}, {self.dates[1]}, ..., {self.dates[-2]}, {self.dates[-1]} " + ) + + @property + def source(self): + return self.path + + def get_dataset_names(self): + name = os.path.basename(self.path) + if name.endswith(".zarr"): + name = name[:-5] + return [name] + + @cached_property + def _probe_attributes(self): + import zarr + + z = zarr.open(self.path, mode="r") + return dict(z.data.attrs) + + def get_aux(self, i): + data = self.forward[i] + + latitudes = data[:, self.name_to_index["__latitudes"]].numpy() + longitudes = data[:, self.name_to_index["__longitudes"]].numpy() + + reference = self.dates[i] + times = self.forward.get_dates(i) + if str(times.dtype) != "datetime64[s]": + LOG.warning(f"Expected np.datetime64[s], got {times.dtype}. ") + times = times.astype("datetime64[s]") + assert str(reference.dtype) == "datetime64[s]", f"Expected np.datetime64[s], got {type(reference)}" + timedeltas = times - reference + + assert latitudes.shape == longitudes.shape, f"Expected {latitudes.shape}, got {longitudes.shape}" + assert timedeltas.shape == latitudes.shape, f"Expected {timedeltas.shape}, got {latitudes.shape}" + + assert timedeltas.dtype == "timedelta64[s]", f"Expected timedelta64[s], got {timedeltas.dtype}" + return latitudes, longitudes, timedeltas + + def getitem(self, i): + data = self.forward[i] + + data = data.numpy().astype(np.float32) + assert len(data.shape) == 2, f"Expected 2D array, got {data.shape}" + data = data.T + + if not data.size: + data = self.empty_item() + assert ( + data.shape[0] == self.shape[1] + ), f"Data shape {data.shape} does not match {self.shape} : {data.shape[0]} != {self.shape[1]}" + return data + + @cached_property + def variables(self): + colnames = self.forward.colnames + variables = [] + for n in colnames: + if n.startswith("obsvalue_"): + n = n.replace("obsvalue_", "") + if n == "latitude" or n == "lat": + assert "latitudes" not in variables, f"Duplicate latitudes found in {variables}" + variables.append("__latitudes") + continue + if n == "longitude" or n == "lon": + assert "longitudes" not in variables, f"Duplicate longitudes found in {variables}" + variables.append("__longitudes") + continue + assert not n.startswith("__"), f"Invalid name {n} found in {colnames}" + variables.append(n) + return variables + + @property + def name_to_index(self): + return {n: i for i, n in enumerate(self.variables)} + + @property + def statistics(self): + mean = self.forward.properties["means"] + mean = np.array(mean, dtype=np.float32) + + var = self.forward.properties["vars"] + var = np.array(var, dtype=np.float32) + stdev = np.sqrt(var) + + minimum = np.array(self.forward.z.data.attrs["mins"], dtype=np.float32) + maximum = np.array(self.forward.z.data.attrs["maxs"], dtype=np.float32) + + assert isinstance(mean, np.ndarray), f"Expected np.ndarray, got {type(mean)}" + assert isinstance(stdev, np.ndarray), f"Expected np.ndarray, got {type(stdev)}" + assert 
isinstance(minimum, np.ndarray), f"Expected np.ndarray, got {type(minimum)}" + assert isinstance(maximum, np.ndarray), f"Expected np.ndarray, got {type(maximum)}" + return dict(mean=mean, stdev=stdev, minimum=minimum, maximum=maximum) + + def tree(self): + return Node( + self, + [], + path=self.path, + frequency=self.frequency, + ) + + def __repr__(self): + return f"Observations({os.path.basename(self.path)}, {self.dates[0]};{self.dates[-1]}, {len(self)})" + + +def observations_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> ObservationsBase: + observations = kwargs.pop("observations") + + if not isinstance(observations, dict): + observations = dict(dataset=observations) + dataset = ObservationsZarr(**observations) + return dataset._subset(**kwargs) diff --git a/src/anemoi/datasets/use/gridded/observations/legacy_obs_dataset.py b/src/anemoi/datasets/use/gridded/observations/legacy_obs_dataset.py new file mode 100644 index 000000000..85ab51583 --- /dev/null +++ b/src/anemoi/datasets/use/gridded/observations/legacy_obs_dataset.py @@ -0,0 +1,200 @@ +# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import datetime +import logging + +import numpy as np +import pandas as pd +import torch +import zarr +from torch.utils.data import Dataset + +LOG = logging.getLogger(__name__) + + +class ObsDataset(Dataset): + + def __init__( + self, + filename: str, + start: int, + end: int, + len_hrs: int, + step_hrs: int = None, + select: list[str] = None, + drop: list[str] = None, + ) -> None: + + self.filename = filename + self.z = zarr.open(filename, mode="r") + self.data = self.z["data"] + self.dt = self.z["dates"] # datetime only + self.hrly_index = self.z["idx_197001010000_1"] + self.colnames = self.data.attrs["colnames"] + self.selected_colnames = self.colnames + self.selected_cols_idx = np.arange(len(self.colnames)) + self.len_hrs = len_hrs + self.step_hrs = step_hrs if step_hrs else len_hrs + + # Create index for samples + self._setup_sample_index(start, end, self.len_hrs, self.step_hrs) + + self._load_properties() + + if select: + self.select(select) + + if drop: + self.drop(drop) + + def __getitem__( + self, + idx: int, + ) -> torch.tensor: + + start_row = self.indices_start[idx] + end_row = self.indices_end[idx] + + data = self.data.oindex[start_row:end_row, self.selected_cols_idx] + + return torch.from_numpy(data) + + def __len__(self) -> int: + + return len(self.indices_start) + + def get_dates( + self, + idx: int, + ) -> np.ndarray: + + start_row = self.indices_start[idx] + end_row = self.indices_end[idx] + dates = self.dt.oindex[start_row:end_row] + + assert len(dates.shape) == 2, dates.shape + dates = dates[:, 0] + + if len(dates) and dates[0].dtype != np.dtype("datetime64[s]"): + dates = dates.astype("datetime64[s]") + + return dates + + def get_df(self, idx: int) -> pd.DataFrame: + """Convenience function to return data for sample idx packaged in a pandas DataFrame""" + + d = self.__getitem__(idx) + + df = pd.DataFrame(data=d, columns=[self.colnames[i] for i in self.selected_cols_idx]) + + start_row = self.indices_start[idx] + end_row = self.indices_end[idx] + + dts = self.dt[start_row:end_row, :] + 
df["datetime"] = dts + + return df + + def select(self, cols_list: list[str]) -> None: + """Allow user to specify which columns they want to access. + Get functions only returned for these specified columns. + """ + self.selected_colnames = cols_list + self.selected_cols_idx = np.array([self.colnames.index(item) for item in cols_list]) + + def drop(self, cols_to_drop: list[str]) -> None: + """Allow user to drop specific columns from the dataset. + Get functions no longer return data for these columns after being set. + """ + mask = [name not in cols_to_drop for name in self.selected_colnames] + + self.selected_colnames = [name for name, keep in zip(self.selected_colnames, mask) if keep] + self.selected_cols_idx = self.selected_cols_idx[mask] + + def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]: + """Returns a tuple of datetime objects describing the start and end times of the sample at position idx.""" + + if idx < 0: + idx = len(self) + idx + + time_start = self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs), seconds=1) + time_end = min( + self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs + self.len_hrs)), + self.end_dt, + ) + + return (np.datetime64(time_start), np.datetime64(time_end)) + + def first_sample_with_data(self) -> int: + """Returns the position of the first sample which contains data.""" + return int(np.nonzero(self.indices_end)[0][0]) if self.indices_end.max() > 0 else None + + def last_sample_with_data(self) -> int: + """Returns the position of the last sample which contains data.""" + if self.indices_end.max() == 0: + last_sample = None + else: + last_sample = int(np.where(np.diff(np.append(self.indices_end, self.indices_end[-1])) > 0)[0][-1] + 1) + + return last_sample + + def _setup_sample_index(self, start: int, end: int, len_hrs: int, step_hrs: int) -> None: + """Dataset is divided into samples; + - each n_hours long + - sample 0 starts at start (yyyymmddhhmm) + - index array has one entry for each sample; contains the index of the first row + containing data for that sample + """ + + try: + from obsdata.config import config + + assert config.base_index_yyyymmddhhmm == 197001010000, "base_index_yyyymmddhhmm must be 197001010000" + except ImportError: + pass + base_yyyymmddhhmm = 197001010000 + + assert start > base_yyyymmddhhmm, ( + f"Abort: ObsDataset sample start (yyyymmddhhmm) must be greater than {base_yyyymmddhhmm}\n" + f" Current value: {start}" + ) + + format_str = "%Y%m%d%H%M%S" + base_dt = datetime.datetime.strptime(str(base_yyyymmddhhmm), format_str) + self.start_dt = datetime.datetime.strptime(str(start), format_str) + self.end_dt = datetime.datetime.strptime(str(end), format_str) + + # Calculate hours since the base date for the requested dataset ranges + diff_in_hours_start = int((self.start_dt - base_dt).total_seconds() // 3600) + diff_in_hours_end = int((self.end_dt - base_dt).total_seconds() // 3600) + + # Find elements that need to be extracted from the hourly index + # + ensuring that the dataset respects the requested end-hour even if it is mid-way through a sample + sample_starts = np.arange(diff_in_hours_start, diff_in_hours_end, step_hrs) + sample_ends = np.minimum(sample_starts + len_hrs, diff_in_hours_end) + + # Initialize local index arrays + self.indices_start = np.zeros(sample_starts.shape, dtype=int) + self.indices_end = np.zeros(self.indices_start.shape, dtype=int) + + max_hrly_index = self.hrly_index.shape[0] - 1 + valid_start_mask = sample_starts <= max_hrly_index + valid_end_mask = 
(sample_ends > 0) & (sample_ends <= max_hrly_index) + + # Copy elements from the hrly_index into the local index + self.indices_start[valid_start_mask] = self.hrly_index[sample_starts[valid_start_mask]] + self.indices_end[valid_end_mask] = np.maximum(self.hrly_index[sample_ends[valid_end_mask]], 0) + + def _load_properties(self) -> None: + + self.properties = {} + + self.properties["means"] = self.data.attrs["means"] + self.properties["vars"] = self.data.attrs["vars"] + self.properties["data_idxs"] = self.data.attrs["data_idxs"] + self.properties["obs_id"] = self.data.attrs["obs_id"] diff --git a/src/anemoi/datasets/use/gridded/observations/multi.py b/src/anemoi/datasets/use/gridded/observations/multi.py new file mode 100644 index 000000000..a6b6be176 --- /dev/null +++ b/src/anemoi/datasets/use/gridded/observations/multi.py @@ -0,0 +1,64 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import os + +from anemoi.datasets.use import open_dataset + +LOG = logging.getLogger(__name__) + + +class LegacyDatasets: + def __init__(self, paths, start=None, end=None, **kwargs): + self.paths = paths + + if not start or not end: + print( + "❌❌ Warning: start and end not provided, using the minima first and maximal last dates of the datasets" + ) + lst = [self._open_dataset(p, **kwargs) for p in paths] + start = min([d.dates[0] for d in lst]) + end = max([d.dates[-1] for d in lst]) + + self._datasets = { + os.path.basename(p).split(".")[0]: self._open_dataset(p, start=start, end=end, padding="empty") + for p in paths + } + + first = list(self._datasets.values())[0] + for name, dataset in self._datasets.items(): + if dataset.dates[0] != first.dates[0] or dataset.dates[-1] != first.dates[-1]: + raise ValueError("Datasets have different start and end times") + if dataset.frequency != first.frequency: + raise ValueError("Datasets have different frequencies") + + self._keys = self._datasets.keys + + self._first = list(self._datasets.values())[0] + + def _open_dataset(self, p, **kwargs): + if p.startswith("observations-"): + return open_dataset(observations=p, **kwargs) + else: + print("❗ Opening non-observations dataset:", p) + return open_dataset(p, **kwargs) + + def items(self): + return self._datasets.items() + + @property + def dates(self): + return self._first.dates + + def __len__(self): + return len(self._first) + + def __getitem__(self, i): + return {k: d[i] for k, d in self._datasets.items()} diff --git a/src/anemoi/datasets/use/gridded/rescale.py b/src/anemoi/datasets/use/gridded/rescale.py index 8426bffbe..4ecc1849d 100644 --- a/src/anemoi/datasets/use/gridded/rescale.py +++ b/src/anemoi/datasets/use/gridded/rescale.py @@ -242,3 +242,10 @@ def statistics_tendencies(self, delta: datetime.timedelta | None = None) -> dict raise NotImplementedError("rescale tendencies statistics", k) return result + + def origin_transformation(self, variable, origins): + config = {} + for variable, (a, b) in self.rescale.items(): + config[variable] = {"scale": a, "offset": b} + + return {"name": "rescale", "config": config} diff --git a/src/anemoi/datasets/use/gridded/select.py b/src/anemoi/datasets/use/gridded/select.py index 9a8fcc385..344ae56a4 
100644 --- a/src/anemoi/datasets/use/gridded/select.py +++ b/src/anemoi/datasets/use/gridded/select.py @@ -224,6 +224,17 @@ def forwards_subclass_metadata_specific(self) -> dict[str, Any]: # return dict(indices=self.indices) return dict(reason=self.reason) + def forward_subclass_origin(self, index): + assert ( + isinstance(index, tuple) and len(index) == 4 and all(a > b >= 0 for a, b in zip(self.shape, index)) + ), tuple + + return self.dataset.origin((index[0], self.indices[index[1]], index[2], index[3])) + + def project(self, projection): + projection = projection.from_indices(axis=1, indices=self.indices) + return self.dataset.project(projection) + class Rename(Forwards): """Class to rename variables in a dataset.""" diff --git a/src/anemoi/datasets/use/gridded/statistics.py b/src/anemoi/datasets/use/gridded/statistics.py index 236ce1b7a..d56b03c87 100644 --- a/src/anemoi/datasets/use/gridded/statistics.py +++ b/src/anemoi/datasets/use/gridded/statistics.py @@ -15,7 +15,7 @@ from numpy.typing import NDArray -from anemoi.datasets.use.gridded import open_dataset +from anemoi.datasets import open_dataset from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.debug import Node from anemoi.datasets.use.gridded.forwards import Forwards diff --git a/src/anemoi/datasets/use/gridded/stores.py b/src/anemoi/datasets/use/gridded/stores.py index 3538e1040..c319ed9ae 100644 --- a/src/anemoi/datasets/use/gridded/stores.py +++ b/src/anemoi/datasets/use/gridded/stores.py @@ -23,7 +23,7 @@ from anemoi.utils.dates import frequency_to_timedelta from numpy.typing import NDArray -from anemoi.datasets.use.gridded import MissingDateError +from anemoi.datasets import MissingDateError from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import Shape @@ -101,7 +101,7 @@ def __getitem__(self, key: str) -> bytes: target = self.url + "/" + key try: - return get_object(target) + return get_object(target).bytes() except FileNotFoundError: raise KeyError(target) @@ -223,7 +223,7 @@ def from_name(cls, name: str) -> "Zarr": """Create a Zarr dataset from a name.""" if name.endswith(".zip") or name.endswith(".zarr"): return Zarr(name) - return Zarr(zarr_lookup(name)) + return Zarr(dataset_lookup(name)) def __len__(self) -> int: """Return the length of the dataset.""" @@ -540,10 +540,6 @@ def label(self) -> str: QUIET = set() -def zarr_lookup(*args, **kwargs) -> Optional[str]: - return dataset_lookup(*args, **kwargs) - - def dataset_lookup(name: str, fail: bool = True) -> Optional[str]: """Look up a zarr dataset by name.""" @@ -581,7 +577,7 @@ def dataset_lookup(name: str, fail: bool = True) -> Optional[str]: tried.append(full) try: - from anemoi.datasets.use.tabular.records import open_records_dataset + from anemoi.datasets.use.gridded.records import open_records_dataset z = open_records_dataset(full) if z is not None: diff --git a/src/anemoi/datasets/use/tabular/observations/__init__.py b/src/anemoi/datasets/use/tabular/observations/__init__.py index 5156be8a6..b231c2c66 100644 --- a/src/anemoi/datasets/use/tabular/observations/__init__.py +++ b/src/anemoi/datasets/use/tabular/observations/__init__.py @@ -141,9 +141,9 @@ def __init__(self, dataset, frequency=None, window=None): if isinstance(dataset, zarr.hierarchy.Group): dataset = dataset._store.path - from anemoi.datasets.use.gridded.stores import zarr_lookup + from anemoi.datasets.use.gridded.stores import dataset_lookup - dataset = 
zarr_lookup(dataset) + dataset = dataset_lookup(dataset) self.path = dataset assert self._probe_attributes["is_observations"], f"Expected observations dataset, got {self.path}" @@ -179,7 +179,7 @@ def __init__(self, dataset, frequency=None, window=None): # last_window_end must be the end of the time window of the last item last_window_end = int(end.strftime("%Y%m%d%H%M%S")) - from anemoi.datasets.use.tabular.observations.legacy_obs_dataset import ObsDataset + from anemoi.datasets.use.gridded.tabular.observations.legacy_obs_dataset import ObsDataset args = [self.path, first_window_begin, last_window_end] kwargs = dict( diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py index 13b729ef0..c637c6d80 100644 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/__init__.py @@ -20,8 +20,8 @@ from anemoi.datasets.use.gridded.debug import Node -from .records.backends import backend_factory -from .windows import window_from_str +from ..windows import window_from_str +from .backends import backend_factory LOG = logging.getLogger(__name__) @@ -364,7 +364,7 @@ def __init__(self, fields_dataset, name): . """ self.forward = fields_dataset - from anemoi.datasets.use.dataset import Dataset + from anemoi.datasets.use.gridded.dataset import Dataset assert isinstance(fields_dataset, Dataset), f"fields_dataset must be a Dataset, got {type(fields_dataset)}" self._name = name diff --git a/src/anemoi/datasets/use/tabular/records/windows.py b/src/anemoi/datasets/use/tabular/windows.py similarity index 100% rename from src/anemoi/datasets/use/tabular/records/windows.py rename to src/anemoi/datasets/use/tabular/windows.py diff --git a/src/anemoi/datasets/validate.py b/src/anemoi/datasets/validate.py new file mode 100644 index 000000000..945b13833 --- /dev/null +++ b/src/anemoi/datasets/validate.py @@ -0,0 +1,598 @@ +# (C) Copyright 2025- Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +import math +from collections import defaultdict + +import numpy as np + +from anemoi.datasets.testing import default_test_indexing +from anemoi.datasets.use.gridded.dataset import Dataset + +LOG = logging.getLogger(__name__) +# List of methods called during training. 
To update the list, run training with ANEMOI_DATASETS_TRACE=1 + +TRAINING_METHODS = [ + "__getitem__", + "__len__", + "latitudes", + "longitudes", + "metadata", # Accessed when checkpointing + "missing", + "name_to_index", + "shape", + "statistics", + "supporting_arrays", # Accessed when checkpointing + "variables", +] + +EXTRA_TRAINING_METHODS = [ + "statistics_tendencies", +] + +DEBUGGING_METHODS = [ + "plot", + "to_index", + "tree", + "source", +] + +PUBLIC_METADATA_METHODS = [ + "arguments", + "dtype", + "end_date", + "resolution", + "start_date", + "field_shape", + "frequency", + "dates", + "typed_variables", + "variables_metadata", +] + +PRIVATE_METADATA_METHODS = [ + "computed_constant_fields", + "constant_fields", + "dataset_metadata", + "label", + "metadata_specific", + "provenance", +] + +INTERNAL_METHODS = [ + "mutate", + "swap_with_parent", + "dates_interval_to_indices", +] + +EXPERIMENTAL_METHODS = [ + "get_dataset_names", + "name", + "grids", +] + +OTHER_METHODS = [ + "collect_input_sources", + "collect_supporting_arrays", + "sub_shape", +] + + +METHODS_CATEGORIES = {k: v for k, v in list(globals().items()) if k.endswith("_METHODS")} + + +METHODS = set(sum(METHODS_CATEGORIES.values(), [])) + + +KWARGS = { + "__len__": {}, + "__getitem__": {"index": 0}, + "get_dataset_names": {"names": set()}, + "metadata": {}, + "metadata_specific": {}, + "mutate": {}, + "plot": {"date": 0, "variable": 0}, + "provenance": {}, + "source": {"index": 0}, + "statistics_tendencies": {}, + "sub_shape": {}, + "supporting_arrays": {}, + "swap_with_parent": {}, + "to_index": {"date": 0, "variable": 0}, + "tree": {}, +} + + +class Unknown: + emoji = "❓" + + +class Success: + emoji = "✅" + success = True + + def __repr__(self): + return "Success" + + +class Error: + success = False + + def __init__(self, message): + self.message = message + + def __repr__(self): + return str(self.message) or repr(self.message) or "Error" + + +class Failure(Error): + emoji = "💥" + + +class Internal(Error): + emoji = "💣" + + +class Invalid(Error): + emoji = "❌" + + +class Report: + + def __init__(self): + self.report = {} + self.methods = {} + self.warnings = defaultdict(list) + + def method(self, name, method): + self.methods[name] = method + + def success(self, name): + self.report[name] = Success() + + def failure(self, name, message): + self.report[name] = Failure(message) + + def internal(self, name, message): + self.report[name] = Internal(message) + + def invalid(self, name, exception): + self.report[name] = Invalid(exception) + + def warning(self, name, message): + self.warnings[name].append(message) + + def summary(self, detailed=False): + + maxlen = max(len(name) for name in self.report.keys()) + + for name, methods in METHODS_CATEGORIES.items(): + print() + print(f"{name.title().replace('_', ' ')}:") + print("-" * (len(name) + 1)) + print() + + for method in methods: + r = self.report.get(method, Unknown()) + msg = repr(r) + if not msg.endswith("."): + msg += "." 
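+                # One line per method: status emoji, method name padded to the longest name, then the message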
+                print(f"{r.emoji} {method.ljust(maxlen)}: {msg}")
+
+                for w in self.warnings.get(method, []):
+                    print(" " * (maxlen + 4), "⚠️", w)
+
+                if r.success:
+                    continue
+
+                if not detailed:
+                    continue
+
+                if method not in self.methods:
+                    continue
+
+                proc = self.methods[method]
+
+                doc = proc.__doc__
+                if doc:
+                    width = 80
+                    indent = maxlen + 4
+                    doc = "\n".join(["=" * width, "", doc, "=" * width])
+                    indented_doc = "\n".join(" " * indent + line for line in doc.splitlines())
+                    print()
+                    print(indented_doc)
+                    print()
+            print()
+
+        print()
+
+
+def _no_validate(report, dataset, name, result):
+    report.warning(name, f"Validation for {name} not implemented. Result: {type(result)}")
+
+
+def validate_variables(report, dataset, name, result):
+    """Validate the variables of the dataset."""
+
+    if not isinstance(result, (list, tuple)):
+        raise ValueError(f"Result is not a list or tuple {type(result)}")
+
+    if len(result) != dataset.shape[1]:
+        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[1]}")
+
+    for value in result:
+        if not isinstance(value, str):
+            raise ValueError(f"`{value}` is not a string")
+
+
+def validate_latitudes(report, dataset, name, result):
+    """Validate the latitudes of the dataset."""
+
+    if not isinstance(result, np.ndarray):
+        raise ValueError(f"Result is not a np.ndarray {type(result)}")
+
+    if len(result) != dataset.shape[3]:
+        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}")
+
+    if not np.all(np.isfinite(result)):
+        raise ValueError("Result contains non-finite values")
+
+    if np.isnan(result).any():
+        report.invalid(name, ValueError("Result contains NaN values"))
+        return
+
+    if not np.all((result >= -90) & (result <= 90)):
+        raise ValueError("Result contains values outside the range [-90, 90]")
+
+    if np.all((result >= -np.pi) & (result <= np.pi)):
+        report.warning(name, "All latitudes are in the range [-π, π]. Are they in radians?")
+
+
+def validate_longitudes(report, dataset, name, result):
+    """Validate the longitudes of the dataset."""
+
+    if not isinstance(result, np.ndarray):
+        raise ValueError(f"Result is not a np.ndarray {type(result)}")
+
+    if len(result) != dataset.shape[3]:
+        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}")
+
+    if not np.all(np.isfinite(result)):
+        raise ValueError("Result contains non-finite values")
+
+    if np.isnan(result).any():
+        report.invalid(name, ValueError("Result contains NaN values"))
+        return
+
+    if not np.all((result >= -180) & (result <= 360)):
+        raise ValueError("Result contains values outside the range [-180, 360]")
+
+    if np.all((result >= -np.pi) & (result <= 2 * np.pi)):
+        report.warning(name, "All longitudes are in the range [-π, 2π]. 
Are they in radians?") + + +def validate_statistics(report, dataset, name, result): + """Validate the statistics of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + for key in ["mean", "stdev", "minimum", "maximum"]: + + if key not in result: + raise ValueError(f"Result does not contain `{key}`") + + if not isinstance(result[key], np.ndarray): + raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}") + + if len(result[key].shape) != 1: + raise ValueError(f"Result[{key}] has wrong shape: {len(result[key].shape)} != 1") + + if result[key].shape[0] != len(dataset.variables): + raise ValueError(f"Result[{key}] has wrong length: {result[key].shape[0]} != {len(dataset.variables)}") + + if not np.all(np.isfinite(result[key])): + raise ValueError(f"Result[{key}] contains non-finite values") + + if np.isnan(result[key]).any(): + report.invalid(name, ValueError(f"Result[{key}] contains NaN values")) + + +def validate_shape(report, dataset, name, result): + """Validate the shape of the dataset.""" + + if not isinstance(result, tuple): + raise ValueError(f"Result is not a tuple {type(result)}") + + if len(result) != 4: + raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.shape)}") + + if result[0] != len(dataset): + raise ValueError(f"Result[0] has wrong length: {result[0]} != {len(dataset)}") + + if result[1] != len(dataset.variables): + raise ValueError(f"Result[1] has wrong length: {result[1]} != {len(dataset.variables)}") + + if result[2] != 1: # We ignore ensemble dimension for now + pass + + if result[3] != len(dataset.latitudes): + raise ValueError(f"Result[3] has wrong length: {result[3]} != {len(dataset.latitudes)}") + + +def validate_supporting_arrays(report, dataset, name, result): + """Validate the supporting arrays of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + if "latitudes" not in result: + raise ValueError("Result does not contain `latitudes`") + + if "longitudes" not in result: + raise ValueError("Result does not contain `longitudes`") + + if not isinstance(result["latitudes"], np.ndarray): + raise ValueError(f"Result[latitudes] is not a np.ndarray {type(result['latitudes'])}") + + if not isinstance(result["longitudes"], np.ndarray): + raise ValueError(f"Result[longitudes] is not a np.ndarray {type(result['longitudes'])}") + + if np.any(result["latitudes"] != dataset.latitudes): + raise ValueError("Result[latitudes] does not match dataset.latitudes") + + if np.any(result["longitudes"] != dataset.longitudes): + raise ValueError("Result[longitudes] does not match dataset.longitudes") + + +def validate_dates(report, dataset, name, result): + """Validate the dates of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if len(result.shape) != 1: + raise ValueError(f"Result has wrong shape: {len(result.shape)} != 1") + + if result.shape[0] != len(dataset.dates): + raise ValueError(f"Result has wrong length: {result.shape[0]} != {len(dataset.dates)}") + + if not np.issubdtype(result.dtype, np.datetime64): + raise ValueError(f"Result is not a datetime64 array {result.dtype}") + + if len(result) != len(dataset.dates): + raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.dates)}") + + if not np.all(np.isfinite(result)): + raise ValueError("Result contains non-finite values") + + if np.isnan(result).any(): + 
report.invalid(name, ValueError("Result contains NaN values")) + return + + for d1, d2 in zip(result[:-1], result[1:]): + if d1 >= d2: + raise ValueError(f"Result contains non-increasing dates: {d1} >= {d2}") + + frequency = np.diff(result) + if not np.all(frequency == frequency[0]): + raise ValueError("Result contains non-constant frequency") + + +def validate_metadata(report, dataset, name, result): + """Validate the metadata of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + +def validate_missing(report, dataset, name, result): + """Validate the missing values of the dataset.""" + + if not isinstance(result, set): + raise ValueError(f"Result is not a set {type(result)}") + + if not all(isinstance(item, int) for item in result): + raise ValueError("Result contains non-integer values") + + if len(result) > 0: + if min(result) < 0: + raise ValueError("Result contains negative values") + + if max(result) >= len(dataset): + raise ValueError(f"Result contains values greater than {len(dataset)}") + + +def validate_name_to_index(report, dataset, name, result): + """Validate the name to index mapping of the dataset.""" + + if not isinstance(result, dict): + raise ValueError(f"Result is not a dict {type(result)}") + + for key in dataset.variables: + if key not in result: + raise ValueError(f"Result does not contain `{key}`") + + if not isinstance(result[key], int): + raise ValueError(f"Result[{key}] is not an int {type(result[key])}") + + if result[key] < 0 or result[key] >= len(dataset.variables): + raise ValueError(f"Result[{key}] is out of bounds: {result[key]}") + + index_to_name = {v: k for k, v in result.items()} + for i in range(len(dataset.variables)): + if i not in index_to_name: + raise ValueError(f"Result does not contain index `{i}`") + + if not isinstance(index_to_name[i], str): + raise ValueError(f"Result[{i}] is not a string {type(index_to_name[i])}") + + if index_to_name[i] != dataset.variables[i]: + raise ValueError( + f"Result[{i}] does not match dataset.variables[{i}]: {index_to_name[i]} != {dataset.variables[i]}" + ) + + +def validate___getitem__(report, dataset, name, result): + """Validate the __getitem__ method of the dataset.""" + + if not isinstance(result, np.ndarray): + raise ValueError(f"Result is not a np.ndarray {type(result)}") + + if result.shape != dataset.shape[1:]: + raise ValueError(f"Result has wrong shape: {result.shape} != {dataset.shape[1:]}") + + +def validate___len__(report, dataset, name, result): + """Validate the __len__ method of the dataset.""" + + if not isinstance(result, int): + raise ValueError(f"Result is not an int {type(result)}") + + if result != dataset.shape[0]: + raise ValueError(f"Result has wrong length: {result} != {len(dataset)}") + + if result != len(dataset.dates): + raise ValueError(f"Result has wrong length: {result} != {len(dataset.dates)}") + + +def validate_start_date(report, dataset, name, result): + """Validate the start date of the dataset.""" + + if not isinstance(result, np.datetime64): + raise ValueError(f"Result is not a datetime64 {type(result)}") + + if result != dataset.dates[0]: + raise ValueError(f"Result has wrong start date: {result} != {dataset.dates[0]}") + + +def validate_end_date(report, dataset, name, result): + """Validate the end date of the dataset.""" + + if not isinstance(result, np.datetime64): + raise ValueError(f"Result is not a datetime64 {type(result)}") + + if result != dataset.dates[-1]: + raise ValueError(f"Result has wrong end 
date: {result} != {dataset.dates[-1]}") + + +def validate_field_shape(report, dataset, name, result): + """Validate the field shape of the dataset.""" + + if not isinstance(result, tuple): + raise ValueError(f"Result is not a tuple {type(result)}") + + if math.prod(result) != dataset.shape[-1]: + raise ValueError(f"Result has wrong shape: {result} != {dataset.shape[-1]}") + + +def validate(report, dataset, name, kwargs=None): + + try: + + validate_fn = globals().get(f"validate_{name}", _no_validate) + + # Check if the method is still in the Dataset class + try: + report.method(name, getattr(Dataset, name)) + except AttributeError: + report.internal(name, "Attribute not found in Dataset class. Please update the list of methods.") + return + + # Check if the method is supported by the dataset instance + try: + result = getattr(dataset, name) + except AttributeError as e: + report.failure(name, e) + return + + # Check if the method is callable + if callable(result): + if kwargs is None: + report.internal( + name, f"`{name}` is a callable method, not an attribute. Please update KWARGS accordingly." + ) + return + else: + if kwargs is not None: + report.internal(name, f"`{name}` is not callable. Please remove entry from KWARGS.") + return + + if kwargs is not None: + result = result(**kwargs) + + if isinstance(result, np.ndarray) and np.isnan(result).any(): + report.invalid(name, ValueError("Result contains NaN values")) + return + + try: + validate_fn(report, dataset, name, result) + except Exception as e: + report.invalid(name, e) + return + + report.success(name) + + except Exception as e: + report.failure(name, e) + + +def validate_dtype(report, dataset, name, result): + """Validate the dtype of the dataset.""" + + if not isinstance(result, np.dtype): + raise ValueError(f"Result is not a np.dtype {type(result)}") + + +def validate_dataset(dataset, costly_checks=False, detailed=False): + """Validate the dataset.""" + + report = Report() + + if costly_checks: + # This check is expensive as it loads the entire dataset into memory + # so we make it optional + default_test_indexing(dataset) + + for i, x in enumerate(dataset): + y = dataset[i] + assert (x == y).all(), f"Dataset indexing failed at index {i}: {x} != {y}" + + for name in METHODS: + validate(report, dataset, name, kwargs=KWARGS.get(name)) + + report.summary(detailed=detailed) + + +if __name__ == "__main__": + methods = METHODS_CATEGORIES.copy() + methods.pop("OTHER_METHODS") + + o = set(OTHER_METHODS) + overlap = False + for m in methods: + if set(methods[m]).intersection(set(OTHER_METHODS)): + print( + f"WARNING: {m} contains methods from OTHER_METHODS: {set(methods[m]).intersection(set(OTHER_METHODS))}" + ) + o = o - set(methods[m]) + overlap = True + + for m in methods: + for n in methods: + if n is not m: + if set(methods[m]).intersection(set(methods[n])): + print( + f"WARNING: {m} and {n} have methods in common: {set(methods[m]).intersection(set(methods[n]))}" + ) + + if overlap: + print(sorted(o)) diff --git a/tests/create/test_observations.py b/tests/create/dont_test_observations.py similarity index 94% rename from tests/create/test_observations.py rename to tests/create/dont_test_observations.py index af0f02fe5..671d522a1 100644 --- a/tests/create/test_observations.py +++ b/tests/create/dont_test_observations.py @@ -14,8 +14,8 @@ from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.records import Interval 
-from anemoi.datasets.use.records import window_from_str +from anemoi.datasets.use.tabular.windows import Interval +from anemoi.datasets.use.tabular.windows import window_from_str class DummpySource(ObservationsSource): diff --git a/tests/create/test_observations_mars.py b/tests/create/dont_test_observations_mars.py similarity index 95% rename from tests/create/test_observations_mars.py rename to tests/create/dont_test_observations_mars.py index 1ca686b49..91a814490 100644 --- a/tests/create/test_observations_mars.py +++ b/tests/create/dont_test_observations_mars.py @@ -12,12 +12,14 @@ import pandas as pd from earthkit.data import from_source -from odb2df import process_odb from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.records import Interval -from anemoi.datasets.use.records import window_from_str +from anemoi.datasets.use.tabular.windows import Interval +from anemoi.datasets.use.tabular.windows import window_from_str + +# from odb2df import process_odb + log = logging.getLogger(__name__) @@ -115,7 +117,7 @@ def __call__(self, df): "values": ["obsvalue@body"], "drop_na": True, }, - process_func=process_odb, + # process_func=process_odb, ) filter = ColFilter("obsvalue_v10m_0") diff --git a/tests/create/test_observations_mars_bufr.py b/tests/create/dont_test_observations_mars_bufr.py similarity index 95% rename from tests/create/test_observations_mars_bufr.py rename to tests/create/dont_test_observations_mars_bufr.py index b916a58c0..0d8b99fda 100644 --- a/tests/create/test_observations_mars_bufr.py +++ b/tests/create/dont_test_observations_mars_bufr.py @@ -11,13 +11,14 @@ import logging import pandas as pd -from bufr2df import bufr2df + +# from bufr2df import bufr2df from earthkit.data import from_source from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.records import Interval -from anemoi.datasets.use.records import window_from_str +from anemoi.datasets.use.gridded.records import Interval +from anemoi.datasets.use.gridded.records import window_from_str log = logging.getLogger(__name__) @@ -114,7 +115,7 @@ def __call__(self, df): "radarRainfallIntensity": "obsvalue_precip1h_0", }, }, - process_func=bufr2df, + # process_func=bufr2df, ) filter = ColFilter("obsvalue_precip1h_0") diff --git a/tests/create/test_observations_mars_bufr_complex.py b/tests/create/dont_test_observations_mars_bufr_complex.py similarity index 95% rename from tests/create/test_observations_mars_bufr_complex.py rename to tests/create/dont_test_observations_mars_bufr_complex.py index d271a43c0..efb722486 100644 --- a/tests/create/test_observations_mars_bufr_complex.py +++ b/tests/create/dont_test_observations_mars_bufr_complex.py @@ -11,13 +11,14 @@ import logging import pandas as pd -from bufr2df_parallel import bufr2df_parallel + +# from bufr2df_parallel import bufr2df_parallel from earthkit.data import from_source from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.records import Interval -from anemoi.datasets.use.records import window_from_str +from anemoi.datasets.use.gridded.records import Interval +from anemoi.datasets.use.gridded.records import window_from_str log = logging.getLogger(__name__) @@ -133,7 +134,7 @@ def __call__(self, df): 
"latitudes": "lambda x: np.isfinite(x)", }, }, - process_func=bufr2df_parallel, + # process_func=bufr2df_parallel, ) filter = ColFilter("obsvalue_rawbt_9") diff --git a/tests/create/test_observations_mars_bufr_parallel.py b/tests/create/dont_test_observations_mars_bufr_parallel.py similarity index 94% rename from tests/create/test_observations_mars_bufr_parallel.py rename to tests/create/dont_test_observations_mars_bufr_parallel.py index d3562191d..369ee752b 100644 --- a/tests/create/test_observations_mars_bufr_parallel.py +++ b/tests/create/dont_test_observations_mars_bufr_parallel.py @@ -11,13 +11,14 @@ import logging import pandas as pd -from bufr2df_parallel import bufr2df_parallel + +# from bufr2df_parallel import bufr2df_parallel from earthkit.data import from_source from anemoi.datasets.create.sources.observations import ObservationsFilter from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.records import Interval -from anemoi.datasets.use.records import window_from_str +from anemoi.datasets.use.gridded.records import Interval +from anemoi.datasets.use.gridded.records import window_from_str log = logging.getLogger(__name__) @@ -115,7 +116,7 @@ def __call__(self, df): "radarRainfallIntensity": "obsvalue_precip1h_0", }, }, - process_func=bufr2df_parallel, + # process_func=bufr2df_parallel, ) filter = ColFilter("obsvalue_precip1h_0") diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index 82dffb264..e65e03bec 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -289,17 +289,17 @@ def test_planetary_computer_conus404() -> None: def test_csv(get_test_data: callable) -> None: """Test for CSV source registration.""" from anemoi.datasets.create.sources import create_source - from anemoi.datasets.dates import DatesProvider data = get_test_data("anemoi-datasets/obs/dribu.csv") - source = create_source(context=None, config={"csv": {"path": data}}) - window = DatesProvider.from_config( - { - "start": "2020-01-01T00:00:00", - "end": "2020-01-02:23:59:59", - "window": "(-3h:+3h]", - } - ) - - source.execute(window) + # source = + create_source(context=None, config={"csv": {"path": data}}) + # window = DatesProvider.from_config( + # { + # "start": "2020-01-01T00:00:00", + # "end": "2020-01-02:23:59:59", + # "window": "(-3h:+3h]", + # } + # ) + + # source.execute(window) diff --git a/tests/create/utils/create.py b/tests/create/utils/create.py index 573cb2ffe..344c3f0b4 100644 --- a/tests/create/utils/create.py +++ b/tests/create/utils/create.py @@ -12,7 +12,7 @@ import yaml -from anemoi.datasets.create.gridded import creator_factory +from anemoi.datasets.create.tasks import task_factory class TestingContext: @@ -45,8 +45,6 @@ def create_dataset( The path to the created dataset. 
""" - from anemoi.datasets.create.tasks import task_factory - if isinstance(config, dict): temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") yaml.dump(config, temp_file) @@ -61,9 +59,9 @@ def create_dataset( task_factory("patch", path=output).run() if delta is not None: - creator_factory("init_additions", path=output, delta=delta).run() - creator_factory("load_additions", path=output, delta=delta).run() - creator_factory("finalise_additions", path=output, delta=delta).run() + task_factory("init_additions", path=output, delta=delta).run() + task_factory("load_additions", path=output, delta=delta).run() + task_factory("finalise_additions", path=output, delta=delta).run() task_factory("cleanup", path=output).run() diff --git a/tests/test_classes.py b/tests/test_classes.py index 1cedcbf43..5a88ff78c 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -28,7 +28,7 @@ def zarr_tests(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): - with patch("anemoi.datasets.use.gridded.stores.zarr_lookup", _tests_zarrs): + with patch("anemoi.datasets.use.gridded.stores.dataset_lookup", _tests_zarrs): return func(*args, **kwargs) return wrapper @@ -43,9 +43,9 @@ def _test_dataset(ds, variables=None): ds.variables, ) - for p in ds.components(): - print(p) - print(p.origins()) + # for p in ds.components(): + # print(p) + # print(p.origins()) not_ready = pytest.mark.skip(reason="Not ready yet") @@ -66,6 +66,7 @@ def test_class_complement_none(): @skip_if_offline @zarr_tests def test_class_complement_nearest_1(): + ds = open_dataset( complement="cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", source="aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", diff --git a/tests/test_data.py b/tests/test_data.py index 6df8d2920..4232ff234 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -60,7 +60,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): with patch("zarr.open", zarr_from_str): - with patch("anemoi.datasets.use.gridded.stores.zarr_lookup", lambda name: name): + with patch("anemoi.datasets.use.gridded.stores.dataset_lookup", lambda name: name): return func(*args, **kwargs) return wrapper diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py index a20e1668b..1838fef0c 100644 --- a/tests/test_data_gridded.py +++ b/tests/test_data_gridded.py @@ -42,7 +42,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): with patch("zarr.open", zarr_from_str): - with patch("anemoi.datasets.use.gridded.stores.zarr_lookup", lambda name: name): + with patch("anemoi.datasets.use.gridded.stores.dataset_lookup", lambda name: name): return func(*args, **kwargs) return wrapper diff --git a/tests/test_dates.py b/tests/test_dates.py index 32baf7a51..23ed919c8 100644 --- a/tests/test_dates.py +++ b/tests/test_dates.py @@ -14,7 +14,7 @@ import numpy as np import pytest -from anemoi.datasets.create.gridded.statistics import default_statistics_dates +from anemoi.datasets.create.gridded.stats import default_statistics_dates _ = datetime.datetime diff --git a/tests/test_records.py b/tests/test_records.py index 4b2230660..f389a3cdf 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -154,6 +154,7 @@ def test_open_with_window(): _test(ds, nb_dates=8) +@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") def test_open_bad_window(): subset = dict(end="2018-11-30") with pytest.raises(ValueError, match="No dates left after rewindowing"): diff --git 
a/tools/build-obs.py b/tools/build-obs.py index db58cb5b6..bc407564a 100755 --- a/tools/build-obs.py +++ b/tools/build-obs.py @@ -28,7 +28,7 @@ def build(input, output, backend, overwrite=False): print(f"Dataset has {len(ds)} records, from {ds.start_date} to {ds.end_date}") print(f"Converting dataset to {output} using new backend '{backend}'") - from anemoi.datasets.use.tabular.records.backends import writer_backend_factory + from anemoi.datasets.use.gridded.tabular.records.backends import writer_backend_factory if not isinstance(backend, dict): backend = {"name": backend} From 54f12c1044496dbababa5e3ef3c0cb26f015f9b6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 15:46:14 +0000 Subject: [PATCH 183/212] update --- src/anemoi/datasets/use/gridded/masked.py | 2 +- src/anemoi/datasets/use/gridded/stores.py | 4 + tests/test_data.py | 48 +- tests/test_data_gridded.py | 548 ---------------------- 4 files changed, 52 insertions(+), 550 deletions(-) delete mode 100644 tests/test_data_gridded.py diff --git a/src/anemoi/datasets/use/gridded/masked.py b/src/anemoi/datasets/use/gridded/masked.py index 203046fc0..656b72fb4 100644 --- a/src/anemoi/datasets/use/gridded/masked.py +++ b/src/anemoi/datasets/use/gridded/masked.py @@ -220,7 +220,7 @@ def __init__(self, forward: Dataset, area: Dataset | tuple[float, float, float, area : Union[Dataset, Tuple[float, float, float, float]] The cropping area. """ - from anemoi.datasets.use import open_dataset + from anemoi.datasets import open_dataset area = area if isinstance(area, (list, tuple)) else open_dataset(area) diff --git a/src/anemoi/datasets/use/gridded/stores.py b/src/anemoi/datasets/use/gridded/stores.py index c319ed9ae..6b7da317d 100644 --- a/src/anemoi/datasets/use/gridded/stores.py +++ b/src/anemoi/datasets/use/gridded/stores.py @@ -543,6 +543,10 @@ def label(self) -> str: def dataset_lookup(name: str, fail: bool = True) -> Optional[str]: """Look up a zarr dataset by name.""" + parsed = urlparse(name) + if parsed.scheme: + return name + config = load_config()["datasets"] use_search_path_not_found = config.get("use_search_path_not_found", False) diff --git a/tests/test_data.py b/tests/test_data.py index 4232ff234..7290cde17 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -9,6 +9,7 @@ import datetime +import os import tempfile from collections.abc import Callable from functools import cache @@ -60,7 +61,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): with patch("zarr.open", zarr_from_str): - with patch("anemoi.datasets.use.gridded.stores.dataset_lookup", lambda name: name): + with patch("anemoi.datasets.use.gridded.stores.dataset_lookup", lambda name: name + ".zarr"): return func(*args, **kwargs) return wrapper @@ -239,6 +240,8 @@ def zarr_from_str(name: str, mode: str) -> zarr.Group: """ # Format: test-2021-2021-6h-o96-abcd-0 + name, _ = os.path.splitext(name) + args = dict( test="test", start=2021, @@ -386,6 +389,8 @@ def run( regular_shape : bool, optional Whether the dataset has a regular shape, by default True. 
""" + from anemoi.datasets import open_dataset + if isinstance(expected_variables, str): expected_variables = [v for v in expected_variables] @@ -1433,6 +1438,47 @@ def mock_save_dataset(): assert (saved.dates == np.arange("2021-01-01", "2021-01-03", dtype="datetime64[6h]")).all() +@mockup_open_zarr +def test_trim_edge_simple() -> None: + """Test trimming the edges of a dataset.""" + test = DatasetTester( + "test-2021-2021-15,14-6h-o96-abcd", + trim_edge=(2, 3, 4, 5), + ) + + expected_field_shape = (10, 5) + assert test.ds.field_shape == expected_field_shape, test.ds.field_shape + assert test.ds.shape == (365 * 4, 4, 1, np.prod(expected_field_shape)), test.ds.shape + + +@mockup_open_zarr +def test_trim_edge_zeros() -> None: + """Test trimming the edges of a dataset when edges are 0""" + for dim in range(2): + trim_edge = [0, 0, 0, 0] + trim_edge[dim] = 1 + test = DatasetTester( + "test-2021-2021-15,14-6h-o96-abcd", + trim_edge=trim_edge, + ) + + expected_field_shape = (14, 14) + assert test.ds.field_shape == expected_field_shape, test.ds.field_shape + assert test.ds.shape == (365 * 4, 4, 1, np.prod(expected_field_shape)), test.ds.shape + + for dim in range(2, 4): + trim_edge = [0, 0, 0, 0] + trim_edge[dim] = 1 + test = DatasetTester( + "test-2021-2021-15,14-6h-o96-abcd", + trim_edge=trim_edge, + ) + + expected_field_shape = (15, 13) + assert test.ds.field_shape == expected_field_shape, test.ds.field_shape + assert test.ds.shape == (365 * 4, 4, 1, np.prod(expected_field_shape)), test.ds.shape + + if __name__ == "__main__": for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj): diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py deleted file mode 100644 index 1838fef0c..000000000 --- a/tests/test_data_gridded.py +++ /dev/null @@ -1,548 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import datetime -from collections.abc import Callable -from functools import cache -from functools import wraps -from typing import Any -from unittest.mock import patch - -import numpy as np -import zarr -from anemoi.utils.dates import frequency_to_string -from anemoi.utils.dates import frequency_to_timedelta - -from anemoi.datasets import open_dataset - -VALUES = 20 - - -def mockup_open_zarr(func: Callable) -> Callable: - """Decorator to mock the open_zarr function. - - Parameters - ---------- - func : Callable - Function to wrap. - - Returns - ------- - Callable - Wrapped function. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - with patch("zarr.open", zarr_from_str): - with patch("anemoi.datasets.use.gridded.stores.dataset_lookup", lambda name: name): - return func(*args, **kwargs) - - return wrapper - - -@cache -def _(date: datetime.datetime, var: str, k: int = 0, e: int = 0, values: int = VALUES) -> np.ndarray: - """Create a simple array of values based on the date and variable name, ensemble, grid, and other parameters. - - Parameters - ---------- - date : datetime.datetime - Date. - var : str - Variable name. - k : int, optional - Grid index, by default 0. - e : int, optional - Ensemble index, by default 0. - values : int, optional - Number of values, by default VALUES. 
- - Returns - ------- - np.ndarray - Array of values. - """ - d = date.year * 10000 + date.month * 100 + date.day - v = ord(var) - ord("a") + 1 - assert 0 <= k <= 9 - assert 0 <= e <= 9 - - return np.array([d * 100 + v + k / 10.0 + w / 100.0 + e / 1000.0 for w in range(values)]) - - -def create_zarr( - vars: str = "abcd", - start: int = 2021, - end: int = 2021, - field_shape: tuple = [4, 5], - frequency: datetime.timedelta = datetime.timedelta(hours=6), - resolution: str = "o96", - k: int = 0, - ensemble: int | None = None, - grids: int | None = None, - missing: bool = False, -) -> zarr.Group: - """Create a Zarr dataset. - - Parameters - ---------- - vars : str, optional - Variable names, by default "abcd". - start : int, optional - Start year, by default 2021. - end : int, optional - End year, by default 2021. - field_shape : tuple, optional - Field shape of dataset, by default [4, 5]. - frequency : datetime.timedelta, optional - Frequency, by default datetime.timedelta(hours=6). - resolution : str, optional - Resolution, by default "o96". - k : int, optional - Grid index, by default 0. - ensemble : Optional[int], optional - Number of ensembles, by default None. - grids : Optional[int], optional - Number of grids, by default None. - missing : bool, optional - Whether to include missing dates, by default False. - - Returns - ------- - zarr.Group - Zarr dataset. - """ - root = zarr.group() - assert isinstance(frequency, datetime.timedelta) - - dates = [] - date = datetime.datetime(start, 1, 1) - while date.year <= end: - dates.append(date) - date += frequency - - dates = np.array(dates, dtype="datetime64") - - ensembles = ensemble if ensemble is not None else 1 - values = grids if grids is not None else VALUES - - data = np.zeros(shape=(len(dates), len(vars), ensembles, values)) - - for i, date in enumerate(dates): - for j, var in enumerate(vars): - for e in range(ensembles): - data[i, j, e] = _(date.astype(object), var, k, e, values) - - root.create_dataset( - "data", - data=data, - dtype=data.dtype, - chunks=data.shape, - compressor=None, - ) - root.create_dataset( - "dates", - data=dates, - compressor=None, - ) - root.create_dataset( - "latitudes", - data=np.array([x + values for x in range(values)]), - compressor=None, - ) - root.create_dataset( - "longitudes", - data=np.array([x + values for x in range(values)]), - compressor=None, - ) - - root.attrs["frequency"] = frequency_to_string(frequency) - root.attrs["resolution"] = resolution - root.attrs["name_to_index"] = {k: i for i, k in enumerate(vars)} - - root.attrs["data_request"] = {"grid": 1, "area": "g", "param_level": {}} - root.attrs["variables_metadata"] = {v: {} for v in vars} - - if missing: - missing_dates = [] - - last = None - for date in [d.astype(object) for d in dates]: - yyyymmdd = date.strftime("%Y%m") - if yyyymmdd != last: - last = yyyymmdd - missing_dates.append(date) - - root.attrs["missing_dates"] = [d.isoformat() for d in missing_dates] - - root.create_dataset( - "mean", - data=np.mean(data, axis=0), - compressor=None, - ) - root.create_dataset( - "stdev", - data=np.std(data, axis=0), - compressor=None, - ) - root.create_dataset( - "maximum", - data=np.max(data, axis=0), - compressor=None, - ) - root.create_dataset( - "minimum", - data=np.min(data, axis=0), - compressor=None, - ) - - root.attrs["field_shape"] = field_shape - - return root - - -def zarr_from_str(name: str, mode: str) -> zarr.Group: - """Create a Zarr dataset from a string. - - Parameters - ---------- - name : str - Dataset name. 
- mode : str - Mode. - - Returns - ------- - zarr.Group - Zarr dataset. - """ - # Format: test-2021-2021-6h-o96-abcd-0 - - args = dict( - test="test", - start=2021, - end=2021, - field_shape=[4, 5], - frequency=6, - resolution="o96", - vars="abcd", - k=0, - ensemble=None, - grids=None, - ) - - for name, bit in zip(args, name.split("-")): - args[name] = bit - - args["field_shape"] = [int(i) for i in args["field_shape"].split(",")] - - print(args) - - return create_zarr( - start=int(args["start"]), - end=int(args["end"]), - field_shape=args["field_shape"], - frequency=frequency_to_timedelta(args["frequency"]), - resolution=args["resolution"], - vars=[x for x in args["vars"]], - k=int(args["k"]), - ensemble=int(args["ensemble"]) if args["ensemble"] is not None else None, - grids=int(args["grids"]) if args["grids"] is not None else None, - missing=args["test"] == "missing", - ) - - -class IndexTester: - """Class to test indexing of datasets.""" - - def __init__(self, ds: Any) -> None: - """Initialize the IndexTester. - - Parameters - ---------- - ds : Any - Dataset. - """ - self.ds = ds - self.np = ds[:] # Numpy array - - assert self.ds.shape == self.np.shape - assert (self.ds == self.np).all() - - def __getitem__(self, index: Any) -> None: - """Test indexing. - - Parameters - ---------- - index : Any - Index. - """ - print("INDEX", type(self.ds), index) - if self.ds[index] is None: - assert False, (self.ds, index) - - if not (self.ds[index] == self.np[index]).all(): - # print("DS", self.ds[index]) - # print("NP", self.np[index]) - assert (self.ds[index] == self.np[index]).all() - - -def make_missing(x: Any) -> Any: - """Mark data as missing. - - Parameters - ---------- - x : Any - Data. - - Returns - ------- - Any - Data with missing values. - """ - if isinstance(x, tuple): - return (make_missing(a) for a in x) - if isinstance(x, list): - return [make_missing(a) for a in x] - if isinstance(x, dict): - return {k: make_missing(v) for k, v in x.items()} - if isinstance(x, str) and x.startswith("test-"): - return x.replace("test-", "missing-") - return x - - -class DatasetTester: - """Class to test various dataset operations.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize the DatasetTester. - - Parameters - ---------- - *args : Any - Arguments. - **kwargs : Any - Keyword arguments. - """ - self.ds = open_dataset(*args, **kwargs) - - args, kwargs = make_missing((args, kwargs)) - - print(f"ds={self.ds}") - - def run( - self, - *, - expected_class: type, - expected_length: int, - expected_shape: tuple, - expected_variables: str | list, - expected_name_to_index: str | dict, - date_to_row: Callable, - start_date: datetime.datetime, - time_increment: datetime.timedelta, - statistics_reference_dataset: str | list | None, - statistics_reference_variables: str | list | None, - ) -> None: - """Run the dataset tests. - - Parameters - ---------- - expected_class : Type - Expected class. - expected_length : int - Expected length. - expected_shape : tuple - Expected shape. - expected_variables : Union[str, list] - Expected variables. - expected_name_to_index : Union[str, dict] - Expected name to index mapping. - date_to_row : Callable - Function to generate row data. - start_date : datetime.datetime - Start date. - time_increment : datetime.timedelta - Time increment. - statistics_reference_dataset : Optional[Union[str, list]] - Reference dataset for statistics. - statistics_reference_variables : Optional[Union[str, list]] - Reference variables for statistics. 
- """ - if isinstance(expected_variables, str): - expected_variables = [v for v in expected_variables] - - if isinstance(expected_name_to_index, str): - expected_name_to_index = {v: i for i, v in enumerate(expected_name_to_index)} - - assert isinstance(self.ds, expected_class) - assert len(self.ds) == expected_length - assert len([row for row in self.ds]) == len(self.ds) - assert self.ds.shape == expected_shape, (self.ds.shape, expected_shape) - assert self.ds.variables == expected_variables - - assert set(self.ds.variables_metadata.keys()) == set(expected_variables) - - assert self.ds.name_to_index == expected_name_to_index - assert self.ds.dates[0] == start_date - assert self.ds.dates[1] - self.ds.dates[0] == time_increment - - dates = [] - date = start_date - - for row in self.ds: - # print(f"{date=} {row.shape}") - expect = date_to_row(date) - assert (row == expect).all() - dates.append(date) - date += time_increment - - assert (self.ds.dates == np.array(dates, dtype="datetime64")).all() - - if statistics_reference_dataset is not None: - self.same_stats( - self.ds, - open_dataset(statistics_reference_dataset), - statistics_reference_variables, - ) - - self.indexing(self.ds) - self.metadata(self.ds) - - self.ds.tree() - - def metadata(self, ds: Any) -> None: - """Test metadata. - - Parameters - ---------- - ds : Any - Dataset. - """ - metadata = ds.metadata() - assert isinstance(metadata, dict) - - def same_stats(self, ds1: Any, ds2: Any, vars1: list, vars2: list | None = None) -> None: - """Compare statistics between two datasets. - - Parameters - ---------- - ds1 : Any - First dataset. - ds2 : Any - Second dataset. - vars1 : list - Variables in the first dataset. - vars2 : Optional[list], optional - Variables in the second dataset, by default None. - """ - if vars2 is None: - vars2 = vars1 - - vars1 = [v for v in vars1] - vars2 = [v for v in vars2] - for v1, v2 in zip(vars1, vars2): - idx1 = ds1.name_to_index[v1] - idx2 = ds2.name_to_index[v2] - assert (ds1.statistics["mean"][idx1] == ds2.statistics["mean"][idx2]).all() - assert (ds1.statistics["stdev"][idx1] == ds2.statistics["stdev"][idx2]).all() - assert (ds1.statistics["maximum"][idx1] == ds2.statistics["maximum"][idx2]).all() - assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() - - def indexing(self, ds: Any) -> None: - """Test indexing. - - Parameters - ---------- - ds : Any - Dataset. 
- """ - t = IndexTester(ds) - - print("INDEXING", ds.shape) - - t[0:10, :, 0] - t[:, 0:3, 0] - # t[:, :, 0] - t[0:10, 0:3, 0] - t[:, :, :] - - if ds.shape[1] > 2: # Variable dimension - t[:, (1, 2), :] - t[:, (1, 2)] - - t[0] - t[0, :] - t[0, 0, :] - t[0, 0, 0, :] - - if ds.shape[2] > 1: # Ensemble dimension - t[0:10, :, (0, 1)] - - for i in range(3): - t[i] - start = 5 * i - end = len(ds) - 5 * i - step = len(ds) // 10 - - t[start:end:step] - t[start:end] - t[start:] - t[:end] - t[::step] - - -@mockup_open_zarr -def test_trim_edge_simple() -> None: - """Test trimming the edges of a dataset.""" - test = DatasetTester( - "test-2021-2021-15,14-6h-o96-abcd", - trim_edge=(2, 3, 4, 5), - ) - - expected_field_shape = (10, 5) - assert test.ds.field_shape == expected_field_shape, test.ds.field_shape - assert test.ds.shape == (365 * 4, 4, 1, np.prod(expected_field_shape)), test.ds.shape - - -@mockup_open_zarr -def test_trim_edge_zeros() -> None: - """Test trimming the edges of a dataset when edges are 0""" - for dim in range(2): - trim_edge = [0, 0, 0, 0] - trim_edge[dim] = 1 - test = DatasetTester( - "test-2021-2021-15,14-6h-o96-abcd", - trim_edge=trim_edge, - ) - - expected_field_shape = (14, 14) - assert test.ds.field_shape == expected_field_shape, test.ds.field_shape - assert test.ds.shape == (365 * 4, 4, 1, np.prod(expected_field_shape)), test.ds.shape - - for dim in range(2, 4): - trim_edge = [0, 0, 0, 0] - trim_edge[dim] = 1 - test = DatasetTester( - "test-2021-2021-15,14-6h-o96-abcd", - trim_edge=trim_edge, - ) - - expected_field_shape = (15, 13) - assert test.ds.field_shape == expected_field_shape, test.ds.field_shape - assert test.ds.shape == (365 * 4, 4, 1, np.prod(expected_field_shape)), test.ds.shape - - -if __name__ == "__main__": - for name, obj in list(globals().items()): - if name.startswith("test_") and callable(obj): - print(f"Running {name}...") - obj() From 309517e7d9a4ffa220730d7a3799629eebdd9f71 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 15:49:52 +0000 Subject: [PATCH 184/212] update --- src/anemoi/datasets/create/sources/csv.py | 40 +++++++++++++++++++---- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py index 0b293845e..8e5a329f5 100644 --- a/src/anemoi/datasets/create/sources/csv.py +++ b/src/anemoi/datasets/create/sources/csv.py @@ -8,17 +8,25 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.source import ObservationsSource -from anemoi.datasets.create.sources import source_registry +from ..source import Source +from . import source_registry @source_registry.register("csv") -class CSVSource(ObservationsSource): +class CSVSource(Source): """A source that reads data from a CSV file.""" emoji = "📄" # For tracing - def __init__(self, context: any, path: str, *args: tuple, **kwargs: dict): + def __init__( + self, + context: any, + path: str, + columns: list = None, + flavour: dict = None, + *args, + **kwargs, + ): """Initialise the CSVSource. Parameters @@ -27,16 +35,36 @@ def __init__(self, context: any, path: str, *args: tuple, **kwargs: dict): The context for the data source. filepath : str The path to the CSV file. + columns : list, optional + The list of columns to read from the CSV file. *args : tuple Additional positional arguments. **kwargs : dict Additional keyword arguments. 
""" super().__init__(context, *args, **kwargs) + self.path = path + self.columns = columns + + self.flavour = { + "latitude": "latitude", + "longitude": "longitude", + "time": "time", + } + + if flavour is not None: + self.flavour.update(flavour) def execute(self, dates): import pandas as pd - frame = pd.read_csv(self.path) - print(frame) + if self.columns is None: + frame = pd.read_csv(self.path) + else: + frame = pd.read_csv(self.path, usecols=self.columns) + + start, end = dates.window.start_date, dates.window.end_date + mask = (frame[self.flavour["time"]] >= start) & (frame[self.flavour["time"]] <= end) + frame = frame.loc[mask] + return frame From 90aa18e78d8cf387341b6efab081e3f868ef906a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 18:30:03 +0000 Subject: [PATCH 185/212] update --- src/anemoi/datasets/create/sources/csv.py | 7 +- src/anemoi/datasets/dates/__init__.py | 39 ++++++++-- src/anemoi/datasets/use/gridded/dataset.py | 2 +- .../datasets/use/gridded/fill_missing.py | 2 +- src/anemoi/datasets/use/gridded/misc.py | 2 +- src/anemoi/datasets/use/gridded/missing.py | 2 +- .../use/gridded/observations/multi.py | 2 +- tests/create/test_sources.py | 76 ++++++++++++++++--- tests/test_data.py | 27 +++++-- 9 files changed, 126 insertions(+), 33 deletions(-) diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py index 8e5a329f5..ba374fbff 100644 --- a/src/anemoi/datasets/create/sources/csv.py +++ b/src/anemoi/datasets/create/sources/csv.py @@ -64,7 +64,10 @@ def execute(self, dates): else: frame = pd.read_csv(self.path, usecols=self.columns) - start, end = dates.window.start_date, dates.window.end_date - mask = (frame[self.flavour["time"]] >= start) & (frame[self.flavour["time"]] <= end) + print(sorted(frame.columns)) + + mask = (frame[self.flavour["time"]] >= dates.start_date) & (frame[self.flavour["time"]] <= dates.end_date) + frame = frame.loc[mask] + return frame diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 223736971..0ce767418 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -85,7 +85,7 @@ class DatesProvider: 3 """ - def __init__(self, missing: list[str | datetime.datetime] | None = None) -> None: + def __init__(self, missing: list[str | datetime.datetime] | None = None, window: Any = None) -> None: """Initialize the DatesProvider with optional missing dates. Parameters @@ -95,16 +95,21 @@ def __init__(self, missing: list[str | datetime.datetime] | None = None) -> None """ if not missing: missing = [] + self.missing = list(extend(missing)) + if set(self.missing) - set(self.values): diff = set(self.missing) - set(self.values) warnings.warn(f"Missing dates {len(diff)=} not in list.") + self.window = window + @classmethod - def from_config(cls, **kwargs: Any) -> "DatesProvider": + def from_config(cls, *args, **kwargs: Any) -> "DatesProvider": """Create a DatesProvider instance from configuration. Args: + *args (Any): Positional arguments. **kwargs (Any): Configuration parameters. Returns @@ -112,13 +117,22 @@ def from_config(cls, **kwargs: Any) -> "DatesProvider": DatesProvider An instance of DatesProvider. 
""" - if kwargs.pop("hindcasts", False): - return HindcastsDates(**kwargs) - if "values" in kwargs: - return ValuesDates(**kwargs) + options = {} + for a in args: + if not isinstance(a, dict): + raise ValueError(f"Unexpected argument type {type(a)}") + options.update(a) + + options.update(kwargs) + + if options.pop("hindcasts", False): + return HindcastsDates(**options) + + if "values" in options: + return ValuesDates(**options) - return StartEndDates(**kwargs) + return StartEndDates(**options) def __iter__(self) -> Iterator[datetime.datetime]: """Iterate over the dates. @@ -260,11 +274,12 @@ def _(x): self.frequency = frequency missing = kwargs.pop("missing", []) + window = kwargs.pop("window", None) self.values = list(DateTimes(start, end, increment=frequency, **kwargs)) self.kwargs = kwargs - super().__init__(missing=missing) + super().__init__(missing=missing, window=window) def as_dict(self) -> dict[str, Any]: """Convert the StartEndDates instance to a dictionary. @@ -287,6 +302,14 @@ def to_python(self) -> str: else: return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) + @property + def start_date(self) -> datetime.datetime: + return self.start + + @property + def end_date(self) -> datetime.datetime: + return self.end + class Hindcast: """Class representing a single hindcast date. diff --git a/src/anemoi/datasets/use/gridded/dataset.py b/src/anemoi/datasets/use/gridded/dataset.py index 9969ca69c..39dd76078 100644 --- a/src/anemoi/datasets/use/gridded/dataset.py +++ b/src/anemoi/datasets/use/gridded/dataset.py @@ -244,7 +244,7 @@ def __subset(self, **kwargs: Any) -> "Dataset": return Rescale(self, rescale)._subset(**kwargs).mutate() if "statistics" in kwargs: - from anemoi.datasets.use import open_dataset + from anemoi.datasets import open_dataset from anemoi.datasets.use.gridded.statistics import Statistics statistics = kwargs.pop("statistics") diff --git a/src/anemoi/datasets/use/gridded/fill_missing.py b/src/anemoi/datasets/use/gridded/fill_missing.py index fb0c2f098..3aca7e639 100644 --- a/src/anemoi/datasets/use/gridded/fill_missing.py +++ b/src/anemoi/datasets/use/gridded/fill_missing.py @@ -14,7 +14,7 @@ import numpy as np from numpy.typing import NDArray -from anemoi.datasets.use import MissingDateError +from anemoi.datasets import MissingDateError from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import TupleIndex diff --git a/src/anemoi/datasets/use/gridded/misc.py b/src/anemoi/datasets/use/gridded/misc.py index deab32c83..edd3f3c3e 100644 --- a/src/anemoi/datasets/use/gridded/misc.py +++ b/src/anemoi/datasets/use/gridded/misc.py @@ -595,7 +595,7 @@ def _open_dataset(*args: Any, **kwargs: Any) -> "Dataset": assert len(sets) > 0, (args, kwargs) if "set_group" in kwargs: - from anemoi.datasets.use.gridded.records import FieldsRecords + from anemoi.datasets.use.tabular.records import FieldsRecords set_group = kwargs.pop("set_group") assert len(sets) == 1, "set_group can only be used with a single dataset" diff --git a/src/anemoi/datasets/use/gridded/missing.py b/src/anemoi/datasets/use/gridded/missing.py index 298e0fc52..ed8754b5f 100644 --- a/src/anemoi/datasets/use/gridded/missing.py +++ b/src/anemoi/datasets/use/gridded/missing.py @@ -16,8 +16,8 @@ import numpy as np from numpy.typing import NDArray +from anemoi.datasets import MissingDateError from anemoi.datasets.create.utils import to_datetime -from anemoi.datasets.use import 
MissingDateError from anemoi.datasets.use.gridded.dataset import Dataset from anemoi.datasets.use.gridded.dataset import FullIndex from anemoi.datasets.use.gridded.dataset import TupleIndex diff --git a/src/anemoi/datasets/use/gridded/observations/multi.py b/src/anemoi/datasets/use/gridded/observations/multi.py index a6b6be176..5b2ca4967 100644 --- a/src/anemoi/datasets/use/gridded/observations/multi.py +++ b/src/anemoi/datasets/use/gridded/observations/multi.py @@ -10,7 +10,7 @@ import logging import os -from anemoi.datasets.use import open_dataset +from anemoi.datasets import open_dataset LOG = logging.getLogger(__name__) diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index e65e03bec..b9d89169d 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -289,17 +289,71 @@ def test_planetary_computer_conus404() -> None: def test_csv(get_test_data: callable) -> None: """Test for CSV source registration.""" from anemoi.datasets.create.sources import create_source + from anemoi.datasets.dates import DatesProvider data = get_test_data("anemoi-datasets/obs/dribu.csv") - # source = - create_source(context=None, config={"csv": {"path": data}}) - # window = DatesProvider.from_config( - # { - # "start": "2020-01-01T00:00:00", - # "end": "2020-01-02:23:59:59", - # "window": "(-3h:+3h]", - # } - # ) - - # source.execute(window) + source = create_source( + context=None, + config={ + "csv": { + "path": data, + "flavour": { + "time": [ + "typicalDate", + "typicalTime", + ] + }, + } + }, + ) + window = DatesProvider.from_config( + { + "start": "2020-01-01T00:00:00", + "end": "2020-01-02:23:59:59", + "window": "(-3h:+3h]", + } + ) + + source.execute(window) + + +@pytest.mark.skip(reason="ODB source currently not functional") +@skip_if_offline +def test_odb(get_test_data: callable) -> None: + from anemoi.datasets.create.sources import create_source + from anemoi.datasets.dates import DatesProvider + + data = get_test_data("anemoi-datasets/obs/dribu.odb") + + source = create_source(context=None, config={"odb": {"path": data}}) + window = DatesProvider.from_config( + { + "start": "2020-01-01T00:00:00", + "end": "2020-01-02:23:59:59", + "window": "(-3h:+3h]", + } + ) + + source.execute(window) + + +@pytest.mark.skip(reason="BUFR source currently not functional") +@skip_if_offline +def test_bufr(get_test_data: callable) -> None: + + from anemoi.datasets.create.sources import create_source + from anemoi.datasets.dates import DatesProvider + + data = get_test_data("anemoi-datasets/obs/dribu.bufr") + + source = create_source(context=None, config={"bufr": {"path": data}}) + window = DatesProvider.from_config( + { + "start": "2020-01-01T00:00:00", + "end": "2020-01-02:23:59:59", + "window": "(-3h:+3h]", + } + ) + + source.execute(window) diff --git a/tests/test_data.py b/tests/test_data.py index 7290cde17..670b29a88 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -19,6 +19,7 @@ import numpy as np import pytest +import rich import zarr from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta @@ -101,6 +102,7 @@ def create_zarr( vars: str = "abcd", start: int = 2021, end: int = 2021, + field_shape=[2, 5], frequency: datetime.timedelta = datetime.timedelta(hours=6), resolution: str = "o96", k: int = 0, @@ -118,6 +120,8 @@ def create_zarr( Start year, by default 2021. end : int, optional End year, by default 2021. + field_shape : list, optional + Field shape, by default [2, 5]. 
frequency : datetime.timedelta, optional Frequency, by default datetime.timedelta(hours=6). resolution : str, optional @@ -220,6 +224,10 @@ def create_zarr( compressor=None, ) + root.attrs["field_shape"] = field_shape + assert len(field_shape) == 2 + assert data.shape[-1] == field_shape[0] * field_shape[1] + return root @@ -252,12 +260,14 @@ def zarr_from_str(name: str, mode: str) -> zarr.Group: k=0, ensemble=None, grids=None, + field_shape="2,5", ) for name, bit in zip(args, name.split("-")): - args[name] = bit + if bit: + args[name] = bit - print(args) + rich.print(args) return create_zarr( start=int(args["start"]), @@ -269,6 +279,7 @@ def zarr_from_str(name: str, mode: str) -> zarr.Group: ensemble=int(args["ensemble"]) if args["ensemble"] is not None else None, grids=int(args["grids"]) if args["grids"] is not None else None, missing=args["test"] == "missing", + field_shape=list(map(int, args["field_shape"].split(","))), ) @@ -1309,7 +1320,7 @@ def test_grids() -> None: test = DatasetTester( grids=[ "test-2021-2021-6h-o96-abcd-1-1", # Default is 10 gridpoints - "test-2021-2021-6h-o96-abcd-2-1-25", # 25 gridpoints + "test-2021-2021-6h-o96-abcd-2-1-25-5,5", # 25 gridpoints ] ) test.run( @@ -1344,7 +1355,7 @@ def test_grids() -> None: ) ds1 = open_dataset("test-2021-2021-6h-o96-abcd-1-1") - ds2 = open_dataset("test-2021-2021-6h-o96-abcd-2-1-25") + ds2 = open_dataset("test-2021-2021-6h-o96-abcd-2-1-25-5,5") assert (test.ds.longitudes == np.concatenate([ds1.longitudes, ds2.longitudes])).all() assert (test.ds.latitudes == np.concatenate([ds1.latitudes, ds2.latitudes])).all() @@ -1381,6 +1392,7 @@ def test_cropping() -> None: assert test.ds.shape == (365 * 4, 4, 1, 8) +@pytest.mark.skip("Rolling average not yet supported in that branch") @mockup_open_zarr def test_rolling_average() -> None: initial = DatasetTester("test-2021-2021-6h-o96-abcd") @@ -1417,6 +1429,7 @@ def test_fields_to_records() -> None: assert ds.variables == {key: ["a", "b", "c", "d"]} +@pytest.mark.skip("Saving datasets not yet supported in that branch") def test_save_dataset() -> None: """Test save datasets.""" @@ -1442,7 +1455,7 @@ def mock_save_dataset(): def test_trim_edge_simple() -> None: """Test trimming the edges of a dataset.""" test = DatasetTester( - "test-2021-2021-15,14-6h-o96-abcd", + "test-2021-2021-6h-o96-abcd---210-15,14", trim_edge=(2, 3, 4, 5), ) @@ -1458,7 +1471,7 @@ def test_trim_edge_zeros() -> None: trim_edge = [0, 0, 0, 0] trim_edge[dim] = 1 test = DatasetTester( - "test-2021-2021-15,14-6h-o96-abcd", + "test-2021-2021-6h-o96-abcd---210-15,14", trim_edge=trim_edge, ) @@ -1470,7 +1483,7 @@ def test_trim_edge_zeros() -> None: trim_edge = [0, 0, 0, 0] trim_edge[dim] = 1 test = DatasetTester( - "test-2021-2021-15,14-6h-o96-abcd", + "test-2021-2021-6h-o96-abcd---210-15,14", trim_edge=trim_edge, ) From 4fb65cfea809325d8894028bcc04b92beba66052 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 19:08:01 +0000 Subject: [PATCH 186/212] update --- src/anemoi/datasets/create/sources/csv.py | 52 ++++++++++++++++++++++- tests/create/test_sources.py | 3 +- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py index ba374fbff..f399f97c5 100644 --- a/src/anemoi/datasets/create/sources/csv.py +++ b/src/anemoi/datasets/create/sources/csv.py @@ -8,9 +8,13 @@ # nor does it submit to any jurisdiction. +import logging + from ..source import Source from . 
import source_registry +LOG = logging.getLogger(__name__) + @source_registry.register("csv") class CSVSource(Source): @@ -56,18 +60,64 @@ def __init__( if flavour is not None: self.flavour.update(flavour) + if not isinstance(self.flavour["time"], (list, tuple)): + self.flavour["time"] = self.flavour["time"].split(",") + def execute(self, dates): import pandas as pd + to_drop = [] + if self.columns is None: frame = pd.read_csv(self.path) else: frame = pd.read_csv(self.path, usecols=self.columns) + match len(self.flavour["time"]): + case 1: + self.make_time_column_1(frame, self.flavour["time"][0], to_drop) + case 2: + self.make_time_column_2(frame, self.flavour["time"][0], self.flavour["time"][1], to_drop) + case _: + raise ValueError(f"Invalid number of time columns specified in flavour. ({len(self.flavour['time'])=})") + + self.make_lat_lon_columns(frame, "latitude", to_drop) + self.make_lat_lon_columns(frame, "longitude", to_drop) + + if to_drop: + frame.drop(columns=to_drop, inplace=True) + print(sorted(frame.columns)) - mask = (frame[self.flavour["time"]] >= dates.start_date) & (frame[self.flavour["time"]] <= dates.end_date) + mask = (frame["time"] >= dates.start_date) & (frame["time"] <= dates.end_date) frame = frame.loc[mask] return frame + + def make_lat_lon_columns(self, frame, name, to_drop): + frame[name] = frame[self.flavour[name]].astype(float) + to_drop.append(self.flavour[name]) + + def make_time_column_1(self, frame, time_col, to_drop): + import pandas as pd + + if "time" in frame.columns and time_col != "time": + LOG.warning(f"Column 'time' already exists in data frame. Overwriting with '{time_col}'.") + to_drop.append(time_col) + + frame["time"] = pd.to_datetime(frame[time_col]) + + def make_time_column_2(self, frame, date_col, time_col, to_drop): + import pandas as pd + + if "time" in frame.columns: + LOG.warning(f"Column 'time' already exists in data frame. 
Overwriting with '{date_col}' and '{time_col}'.") + + # TODO: Read from format from flavour + frame[time_col] = frame[time_col].astype(str).str.zfill(6) + + frame["time"] = pd.to_datetime(frame[date_col].astype(str) + " " + frame[time_col].astype(str)) + + to_drop.append(date_col) + to_drop.append(time_col) diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index b9d89169d..ba55631b3 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -315,7 +315,8 @@ def test_csv(get_test_data: callable) -> None: } ) - source.execute(window) + frame = source.execute(window) + print(frame) @pytest.mark.skip(reason="ODB source currently not functional") From a915fb3337f8f5ede62e511c3aaeaad3e6f19087 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 13 Nov 2025 19:16:17 +0000 Subject: [PATCH 187/212] update --- src/anemoi/datasets/create/sources/csv.py | 3 ++- tests/create/test_sources.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/anemoi/datasets/create/sources/csv.py b/src/anemoi/datasets/create/sources/csv.py index f399f97c5..bd421283f 100644 --- a/src/anemoi/datasets/create/sources/csv.py +++ b/src/anemoi/datasets/create/sources/csv.py @@ -97,7 +97,8 @@ def execute(self, dates): def make_lat_lon_columns(self, frame, name, to_drop): frame[name] = frame[self.flavour[name]].astype(float) - to_drop.append(self.flavour[name]) + if self.flavour[name] != name: + to_drop.append(self.flavour[name]) def make_time_column_1(self, frame, time_col, to_drop): import pandas as pd diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index ba55631b3..143f1b737 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -309,14 +309,22 @@ def test_csv(get_test_data: callable) -> None: ) window = DatesProvider.from_config( { - "start": "2020-01-01T00:00:00", - "end": "2020-01-02:23:59:59", + "start": "2025-01-01T00:00:00", + "end": "2025-12-21T23:59:59", "window": "(-3h:+3h]", } ) frame = source.execute(window) - print(frame) + assert len(frame) == 2526 + + assert "latitude" in frame.columns, frame.columns + assert "longitude" in frame.columns, frame.columns + assert "time" in frame.columns, frame.columns + + assert frame["latitude"].dtype == float or np.issubdtype(frame["latitude"].dtype, np.floating) + assert frame["longitude"].dtype == float or np.issubdtype(frame["longitude"].dtype, np.floating) + assert frame["time"].dtype == "datetime64[ns]" or np.issubdtype(frame["time"].dtype, np.datetime64) @pytest.mark.skip(reason="ODB source currently not functional") From eecf7eed5527787c3e3aaa337b631833daea5c45 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Fri, 14 Nov 2025 08:38:43 +0100 Subject: [PATCH 188/212] added missing property, finish merging --- src/anemoi/datasets/use/gridded/misc.py | 2 +- src/anemoi/datasets/use/gridded/stores.py | 2 +- src/anemoi/datasets/use/tabular/records/__init__.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/anemoi/datasets/use/gridded/misc.py b/src/anemoi/datasets/use/gridded/misc.py index edd3f3c3e..91cf60919 100644 --- a/src/anemoi/datasets/use/gridded/misc.py +++ b/src/anemoi/datasets/use/gridded/misc.py @@ -373,7 +373,7 @@ def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> " if "backend" not in load_any_dict_format(metadata_path): raise ValueError(f"Metadata for {path} does not contain 'backend' key") - from anemoi.datasets.use.gridded.records import open_records_dataset + from 
anemoi.datasets.use.tabular.records import open_records_dataset return open_records_dataset(path) diff --git a/src/anemoi/datasets/use/gridded/stores.py b/src/anemoi/datasets/use/gridded/stores.py index 6b7da317d..7593714c7 100644 --- a/src/anemoi/datasets/use/gridded/stores.py +++ b/src/anemoi/datasets/use/gridded/stores.py @@ -581,7 +581,7 @@ def dataset_lookup(name: str, fail: bool = True) -> Optional[str]: tried.append(full) try: - from anemoi.datasets.use.gridded.records import open_records_dataset + from anemoi.datasets.use.tabular.records import open_records_dataset z = open_records_dataset(full) if z is not None: diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py index c637c6d80..de511f87c 100644 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/__init__.py @@ -238,6 +238,10 @@ def name_to_index(self): def frequency(self): return self.forward.frequency + @property + def metadata(self): + return self.forward.metadata + @property def _window(self): return self.forward._window From 4ae97b193673e65eb976b75a240384e935c2f817 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Fri, 14 Nov 2025 14:30:25 +0000 Subject: [PATCH 189/212] fix --- .../datasets/use/tabular/records/__init__.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py index de511f87c..11c2b2565 100644 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/__init__.py @@ -151,18 +151,6 @@ def _subset(self, **kwargs): if window is not None: return Rewindowed(self, window)._subset(**kwargs) - frequency = kwargs.pop("frequency", self.frequency) - if frequency: - frequency = frequency_to_timedelta(frequency) - current = self.frequency.total_seconds() - new = frequency.total_seconds() - if current != new and current % new == 0: - return IncreaseFrequency(self, frequency)._subset(**kwargs) - elif current != new and new % current == 0: - raise NotImplementedError("Decreasing frequency not implemented yet") - # return DecreaseFrequency(self, frequency)._subset(**kwargs) - assert self.frequency == frequency, (self.frequency, frequency) - start = kwargs.pop("start", None) end = kwargs.pop("end", None) if start is not None or end is not None: @@ -176,9 +164,19 @@ def _dates_to_indices(start, end): return [i for i, date in enumerate(self.dates) if start <= date <= end] - return RecordsSubset( - self, _dates_to_indices(start, end), {"start": start, "end": end, "frequency": frequency} - )._subset(**kwargs) + return RecordsSubset(self, _dates_to_indices(start, end), {"start": start, "end": end})._subset(**kwargs) + + frequency = kwargs.pop("frequency", self.frequency) + if frequency: + frequency = frequency_to_timedelta(frequency) + current = self.frequency.total_seconds() + new = frequency.total_seconds() + if current != new and current % new == 0: + return IncreaseFrequency(self, frequency)._subset(**kwargs) + elif current != new and new % current == 0: + raise NotImplementedError("Decreasing frequency not implemented yet") + # return DecreaseFrequency(self, frequency)._subset(**kwargs) + assert self.frequency == frequency, (self.frequency, frequency) select = kwargs.pop("select", None) if select is not None: @@ -287,7 +285,7 @@ def _window(self): def __len__(self): return len(self.dataset) * self._n - @property + 
@cached_property def dates(self): dates = [] freq = _to_numpy_timedelta(self._frequency) @@ -932,3 +930,7 @@ def name_to_index(self): @property def statistics(self): return self.dataset.statistics[self.name] + + @property + def metadata(self): + return self.dataset.metadata From dda1a2045f09905d0cb3f71af7ebd4b66617f51f Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Fri, 14 Nov 2025 14:30:25 +0000 Subject: [PATCH 190/212] fix missing property for observations --- .../datasets/use/tabular/records/__init__.py | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py index de511f87c..11c2b2565 100644 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/__init__.py @@ -151,18 +151,6 @@ def _subset(self, **kwargs): if window is not None: return Rewindowed(self, window)._subset(**kwargs) - frequency = kwargs.pop("frequency", self.frequency) - if frequency: - frequency = frequency_to_timedelta(frequency) - current = self.frequency.total_seconds() - new = frequency.total_seconds() - if current != new and current % new == 0: - return IncreaseFrequency(self, frequency)._subset(**kwargs) - elif current != new and new % current == 0: - raise NotImplementedError("Decreasing frequency not implemented yet") - # return DecreaseFrequency(self, frequency)._subset(**kwargs) - assert self.frequency == frequency, (self.frequency, frequency) - start = kwargs.pop("start", None) end = kwargs.pop("end", None) if start is not None or end is not None: @@ -176,9 +164,19 @@ def _dates_to_indices(start, end): return [i for i, date in enumerate(self.dates) if start <= date <= end] - return RecordsSubset( - self, _dates_to_indices(start, end), {"start": start, "end": end, "frequency": frequency} - )._subset(**kwargs) + return RecordsSubset(self, _dates_to_indices(start, end), {"start": start, "end": end})._subset(**kwargs) + + frequency = kwargs.pop("frequency", self.frequency) + if frequency: + frequency = frequency_to_timedelta(frequency) + current = self.frequency.total_seconds() + new = frequency.total_seconds() + if current != new and current % new == 0: + return IncreaseFrequency(self, frequency)._subset(**kwargs) + elif current != new and new % current == 0: + raise NotImplementedError("Decreasing frequency not implemented yet") + # return DecreaseFrequency(self, frequency)._subset(**kwargs) + assert self.frequency == frequency, (self.frequency, frequency) select = kwargs.pop("select", None) if select is not None: @@ -287,7 +285,7 @@ def _window(self): def __len__(self): return len(self.dataset) * self._n - @property + @cached_property def dates(self): dates = [] freq = _to_numpy_timedelta(self._frequency) @@ -932,3 +930,7 @@ def name_to_index(self): @property def statistics(self): return self.dataset.statistics[self.name] + + @property + def metadata(self): + return self.dataset.metadata From 9ce559615cc42a485f6d926c907ecd7d7d62c273 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 17 Nov 2025 11:43:56 +0100 Subject: [PATCH 191/212] remove unused observation npz2 backend + slight optimisation in speed for npz1 --- .../use/tabular/records/backends/__init__.py | 67 ++----------------- 1 file changed, 6 insertions(+), 61 deletions(-) diff --git a/src/anemoi/datasets/use/tabular/records/backends/__init__.py b/src/anemoi/datasets/use/tabular/records/backends/__init__.py index 5d27203ff..5ca924e5b 100644 --- 
a/src/anemoi/datasets/use/tabular/records/backends/__init__.py +++ b/src/anemoi/datasets/use/tabular/records/backends/__init__.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import io import json import logging import os @@ -51,20 +52,18 @@ class Npz1Backend(Backend): def __init__(self, *args, number_of_files_per_subdirectory=100, **kwargs): super().__init__(*args, **kwargs) self.number_of_files_per_subdirectory = number_of_files_per_subdirectory - self._cache = None + self._cache = LRUCache(maxsize=5) def read(self, i, **kwargs): - if self._cache is None: - self._cache = LRUCache(maxsize=5) if i in self._cache: return self._cache[i] d = str(int(i / self.number_of_files_per_subdirectory)) path = os.path.join(self.path, "data", d, f"{i}.npz") - with open(path, "rb") as f: - data = dict(np.load(f)) - self._cache[i] = data - return data + raw = open(path, "rb").read() + buffer = io.BytesIO(raw) + self._cache[i] = dict(np.load(buffer)) + return self._cache[i] def read_metadata(self): with open(os.path.join(self.path, "metadata.json")) as f: @@ -81,27 +80,6 @@ def read_statistics(self): return dic -class Npz2Backend(Backend): - def read(self, i, **kwargs): - path = os.path.join(self.path, "data_", str(int(i / 10)), f"{i}_.npz") - with open(path, "rb") as f: - return dict(np.load(f)) - - def read_metadata(self): - with open(os.path.join(self.path, "metadata.json")) as f: - return json.load(f) - - def read_statistics(self): - path = os.path.join(self.path, "statistics_.npz") - dic = {} - for k, v in dict(np.load(path)).items(): - key, group = k.split(":") - if group not in dic: - dic[group] = {} - dic[group][key] = v - return dic - - class Nc1Backend(Backend): number_of_files_per_subdirectory = 100 @@ -135,7 +113,6 @@ def read_statistics(self): def backend_factory(name, *args, **kwargs): BACKENDS = dict( npz1=Npz1Backend, - npz2=Npz2Backend, nc1=Nc1Backend, ) cls = BACKENDS[name] @@ -286,43 +263,11 @@ def write_statistics(self, statistics): np.savez(path, **flatten) -class Npz2WriteBackend(WriteBackend): - def write(self, i, data, **kwargs): - self._check_data(data) - path = os.path.join(self.path, "data_", str(int(i / 10))) - os.makedirs(path, exist_ok=True) - out_path = os.path.join(path, f"{i}_.npz") - np.savez(out_path, **data) - - def write_metadata(self, metadata): - from anemoi.datasets.create.gridded.tasks import _json_tidy - - os.makedirs(self.path, exist_ok=True) - with open(os.path.join(self.path, "metadata.json"), "w") as f: - json.dump(metadata, f, indent=2, default=_json_tidy) - - def write_statistics(self, statistics): - flatten = {} - for name, d in statistics.items(): - assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" - assert "mean" in d, f"Statistics for {name} must contain 'mean' key but got {d.keys()}" - for k, v in d.items(): - assert isinstance( - v, (int, float, np.ndarray) - ), f"Statistics value for {k} in {name} must be int, float or ndarray, got {type(v)}" - flatten[k + ":" + name] = v - - os.makedirs(self.path, exist_ok=True) - path = os.path.join(self.path, "statistics_.npz") - np.savez(path, **flatten) - - def writer_backend_factory(name, **kwargs): # choose the right backend for writing # this is intended to make benchmarking easier WRITE_BACKENDS = dict( npz1=Npz1WriteBackend, - npz2=Npz2WriteBackend, nc1=Nc1WriteBackend, ) return WRITE_BACKENDS[name](**kwargs) From d83468972245e8e8f1799b95084fdd6b12624ba2 Mon Sep 17 00:00:00 2001 From: 
Baudouin Raoult Date: Mon, 17 Nov 2025 14:03:45 +0100 Subject: [PATCH 192/212] update --- docs/building/code/using-python-1.py | 3 --- docs/building/code/using-python-2.py | 8 -------- docs/building/code/using-python-3.py | 12 ------------ docs/building/code/using-python-4.py | 0 docs/building/using-python.rst | 26 -------------------------- 5 files changed, 49 deletions(-) delete mode 100644 docs/building/code/using-python-1.py delete mode 100644 docs/building/code/using-python-2.py delete mode 100644 docs/building/code/using-python-3.py delete mode 100644 docs/building/code/using-python-4.py delete mode 100644 docs/building/using-python.rst diff --git a/docs/building/code/using-python-1.py b/docs/building/code/using-python-1.py deleted file mode 100644 index 196a25f42..000000000 --- a/docs/building/code/using-python-1.py +++ /dev/null @@ -1,3 +0,0 @@ -from anemoi.datasets.recipe import Recipe - -r = Recipe() diff --git a/docs/building/code/using-python-2.py b/docs/building/code/using-python-2.py deleted file mode 100644 index 717129592..000000000 --- a/docs/building/code/using-python-2.py +++ /dev/null @@ -1,8 +0,0 @@ -from anemoi.datasets.recipe import Recipe - -r = Recipe( - description="Example dataset recipe", - name="example-dataset", - licence="CC-BY-4.0", - attribution="my-organisation", -) diff --git a/docs/building/code/using-python-3.py b/docs/building/code/using-python-3.py deleted file mode 100644 index f21dc3947..000000000 --- a/docs/building/code/using-python-3.py +++ /dev/null @@ -1,12 +0,0 @@ -from anemoi.datasets.recipe import Recipe - -r = Recipe() - -r.description = """ -Example dataset recipe using Python, with attributes set one by one -and a multiline description. -""" - -r.name = "example-dataset" -r.licence = "CC-BY-4.0" -r.attribution = "my-organisation" diff --git a/docs/building/code/using-python-4.py b/docs/building/code/using-python-4.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/docs/building/using-python.rst b/docs/building/using-python.rst deleted file mode 100644 index fbf2892cf..000000000 --- a/docs/building/using-python.rst +++ /dev/null @@ -1,26 +0,0 @@ -############################# - Using Python define recipes -############################# - -You can use Python to define recipes for building datasets. This allows -for more complex logic and flexibility compared to using static -configuration files. - -When executed, the Python code will generate a YAML configuration that -can be used by the dataset building tool. - -Here is an example of how to define a dataset recipe using Python. - -First create a ``Recipe`` object, which will hold the configuration: - -.. literalinclude:: code/using-python-1.py - :language: python - -.. literalinclude:: code/using-python-2.py - :language: python - -.. literalinclude:: code/using-python-3.py - :language: python - -.. 
literalinclude:: code/using-python-4.py - :language: python From 8ba4057d2fb539ca5d910677dbd090e7b6150b16 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 14:05:02 +0100 Subject: [PATCH 193/212] update --- docs/building/introduction.rst | 11 -- tests/create/dont_test_observations.py | 73 --------- tests/create/dont_test_observations_mars.py | 130 --------------- .../dont_test_observations_mars_bufr.py | 128 --------------- ...ont_test_observations_mars_bufr_complex.py | 148 ------------------ ...nt_test_observations_mars_bufr_parallel.py | 129 --------------- 6 files changed, 619 deletions(-) delete mode 100644 tests/create/dont_test_observations.py delete mode 100644 tests/create/dont_test_observations_mars.py delete mode 100644 tests/create/dont_test_observations_mars_bufr.py delete mode 100644 tests/create/dont_test_observations_mars_bufr_complex.py delete mode 100644 tests/create/dont_test_observations_mars_bufr_parallel.py diff --git a/docs/building/introduction.rst b/docs/building/introduction.rst index 5ace4044e..e85a796c9 100644 --- a/docs/building/introduction.rst +++ b/docs/building/introduction.rst @@ -105,14 +105,3 @@ operations can be combined to build complex datasets. :caption: Naming Conventions naming-conventions - -**************** - Python recipes -**************** - -.. toctree:: - :maxdepth: 1 - :hidden: - :caption: Python recipes - - using-python diff --git a/tests/create/dont_test_observations.py b/tests/create/dont_test_observations.py deleted file mode 100644 index 671d522a1..000000000 --- a/tests/create/dont_test_observations.py +++ /dev/null @@ -1,73 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import datetime - -import numpy as np -import pandas as pd - -from anemoi.datasets.create.sources.observations import ObservationsFilter -from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.tabular.windows import Interval -from anemoi.datasets.use.tabular.windows import window_from_str - - -class DummpySource(ObservationsSource): - def __init__(self, data): - assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" - self.data = data - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - if window.include_start: - mask = self.data["times"] > window.start - else: - mask = self.data["times"] >= window.start - if window.include_end: - mask &= self.data["times"] <= window.end - else: - mask &= self.data["times"] < window.end - - df = self.data[mask] - - return self._check(df) - - -class DummyFilter(ObservationsFilter): - def __call__(self, df): - """Filter the data based on the given window.""" - self._check(df) - # Here we can add any filtering logic if needed - df.loc[:, "a1"] = df["a1"] + 0.42 - return self._check(df) - - -dates = [datetime.datetime(2023, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] - -N = 100 -source = DummpySource( - pd.DataFrame( - { - "times": np.arange(N) * datetime.timedelta(hours=1) + dates[0], - "latitudes": -0.1 * np.arange(N), - "longitudes": -0.2 * np.arange(N), - "a1": np.arange(N) * 1.0, - "a2": np.arange(N) * 2.0, - } - ) -) -filter = DummyFilter() - -for d in dates: - window = window_from_str("(-5h, 1h]").to_interval(d) - d = source(window) - d = filter(d) - print(window) - print(d) diff --git a/tests/create/dont_test_observations_mars.py b/tests/create/dont_test_observations_mars.py deleted file mode 100644 index 91a814490..000000000 --- a/tests/create/dont_test_observations_mars.py +++ /dev/null @@ -1,130 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import datetime -import logging - -import pandas as pd -from earthkit.data import from_source - -from anemoi.datasets.create.sources.observations import ObservationsFilter -from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.tabular.windows import Interval -from anemoi.datasets.use.tabular.windows import window_from_str - -# from odb2df import process_odb - - -log = logging.getLogger(__name__) - - -class DummpySource(ObservationsSource): - def __init__(self, data): - assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" - self.data = data - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - if window.include_start: - mask = self.data["times"] > window.start - else: - mask = self.data["times"] >= window.start - if window.include_end: - mask &= self.data["times"] <= window.end - else: - mask &= self.data["times"] < window.end - - df = self.data[mask] - - return self._check(df) - - -class MarsObsSource(ObservationsSource): - def __init__(self, request_dict, pre_process_dict, process_func): - assert isinstance(request_dict, dict), "request_dict must be a dictionary" - self.request_dict = request_dict - self.pre_process_dict = pre_process_dict - self.process_func = process_func - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - request_dict = self.request_dict - request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" - try: - ekd_ds = from_source("mars", request_dict) - except Exception as e: - if "File is empty" in str(e): - log.warning( - f"Empty file for period {window.start.strftime('%Y%m%d')} to {window.end.strftime('%Y%m%d')}. Skipping." - ) - return - else: - raise # Re-raise if it's a different error - - data = self.process_func(ekd_ds, **self.pre_process_dict) - - if window.include_start: - mask = data["times"] > window.start - else: - mask = data["times"] >= window.start - if window.include_end: - mask &= data["times"] <= window.end - else: - mask &= data["times"] < window.end - - df = data[mask] - - return self._check(df) - - -class ColFilter(ObservationsFilter): - def __init__(self, col_name): - self.col_name = col_name - - def __call__(self, df): - """Filter the data based on the given window.""" - self._check(df) - # Here we can add any filtering logic if needed - df.loc[:, self.col_name] = df[self.col_name] + 0.42 - return self._check(df) - - -dates = [datetime.datetime(2025, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] - -source = MarsObsSource( - request_dict={ - "class": "ea", - "expver": "0001", - "stream": "oper", - "obsgroup": "conv", - "reportype": "16001/16002/16004/16065/16076", - "type": "ofb", - "time": "00/12", - "filter": "'select seqno,reportype,date,time,lat,lon,report_status,report_event1,entryno,varno,statid,stalt,obsvalue,lsm@modsurf,biascorr_fg,final_obs_error,datum_status@body,datum_event1@body,vertco_reference_1,vertco_type where ((varno==39 and abs(fg_depar@body)<20) or (varno in (41,42) and abs(fg_depar@body)<15) or (varno==58 and abs(fg_depar@body)<0.4) or (varno == 110 and entryno == 1 and abs(fg_depar@body)<10000) or (varno == 91)) and time in (000000,030000,060000,090000,120000,150000,180000,210000);'", - }, - pre_process_dict={ - # "target": odb2df.process_odb, - "index": ["seqno@hdr", "lat@hdr", "lon@hdr", "date@hdr", "time@hdr", "stalt@hdr", "lsm@modsurf"], - "pivot": ["varno@body"], - "values": ["obsvalue@body"], - 
"drop_na": True, - }, - # process_func=process_odb, -) -filter = ColFilter("obsvalue_v10m_0") - -for d in dates: - window = window_from_str("(-5h, 1h]").to_interval(d) - print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) - d = source(window) - d = filter(d) - print(window) - print(d) diff --git a/tests/create/dont_test_observations_mars_bufr.py b/tests/create/dont_test_observations_mars_bufr.py deleted file mode 100644 index 0d8b99fda..000000000 --- a/tests/create/dont_test_observations_mars_bufr.py +++ /dev/null @@ -1,128 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import logging - -import pandas as pd - -# from bufr2df import bufr2df -from earthkit.data import from_source - -from anemoi.datasets.create.sources.observations import ObservationsFilter -from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.gridded.records import Interval -from anemoi.datasets.use.gridded.records import window_from_str - -log = logging.getLogger(__name__) - - -class DummpySource(ObservationsSource): - def __init__(self, data): - assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" - self.data = data - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - if window.include_start: - mask = self.data["times"] > window.start - else: - mask = self.data["times"] >= window.start - if window.include_end: - mask &= self.data["times"] <= window.end - else: - mask &= self.data["times"] < window.end - - df = self.data[mask] - - return self._check(df) - - -class MarsObsSource(ObservationsSource): - def __init__(self, request_dict, pre_process_dict, process_func): - assert isinstance(request_dict, dict), "request_dict must be a dictionary" - self.request_dict = request_dict - self.pre_process_dict = pre_process_dict - self.process_func = process_func - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - request_dict = self.request_dict - request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" - try: - ekd_ds = from_source("mars", request_dict) - except Exception as e: - if "File is empty" in str(e): - log.warning( - f"Empty file for period {window.start.strftime('%Y%m%d')} to {window.end.strftime('%Y%m%d')}. Skipping." 
- ) - return - else: - raise # Re-raise if it's a different error - - data = self.process_func(ekd_ds, **self.pre_process_dict) - - if window.include_start: - mask = data["times"] > window.start - else: - mask = data["times"] >= window.start - if window.include_end: - mask &= data["times"] <= window.end - else: - mask &= data["times"] < window.end - - df = data[mask] - - return self._check(df) - - -class ColFilter(ObservationsFilter): - def __init__(self, col_name): - self.col_name = col_name - - def __call__(self, df): - """Filter the data based on the given window.""" - self._check(df) - # Here we can add any filtering logic if needed - df.loc[:, self.col_name] = df[self.col_name] + 0.42 - return self._check(df) - - -dates = [datetime.datetime(2025, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] - -source = MarsObsSource( - request_dict={ - "class": "od", - "expver": "0001", - "stream": "LWDA", - "type": "ai", - "obstype": "nexrad_rr", - "times": "00/06/12/18", - }, - pre_process_dict={ - # "target": odb2df.process_odb, - "per_report": { - "latitude": "latitudes", - "longitude": "longitudes", - "radarRainfallIntensity": "obsvalue_precip1h_0", - }, - }, - # process_func=bufr2df, -) -filter = ColFilter("obsvalue_precip1h_0") - -for d in dates: - window = window_from_str("(-5h, 1h]").to_interval(d) - print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) - d = source(window) - d = filter(d) - print(window) - print(d) diff --git a/tests/create/dont_test_observations_mars_bufr_complex.py b/tests/create/dont_test_observations_mars_bufr_complex.py deleted file mode 100644 index efb722486..000000000 --- a/tests/create/dont_test_observations_mars_bufr_complex.py +++ /dev/null @@ -1,148 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import datetime -import logging - -import pandas as pd - -# from bufr2df_parallel import bufr2df_parallel -from earthkit.data import from_source - -from anemoi.datasets.create.sources.observations import ObservationsFilter -from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.gridded.records import Interval -from anemoi.datasets.use.gridded.records import window_from_str - -log = logging.getLogger(__name__) - - -class DummpySource(ObservationsSource): - def __init__(self, data): - assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" - self.data = data - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - if window.include_start: - mask = self.data["times"] > window.start - else: - mask = self.data["times"] >= window.start - if window.include_end: - mask &= self.data["times"] <= window.end - else: - mask &= self.data["times"] < window.end - - df = self.data[mask] - - return self._check(df) - - -class MarsObsSource(ObservationsSource): - def __init__(self, request_dict, pre_process_dict, process_func): - assert isinstance(request_dict, dict), "request_dict must be a dictionary" - self.request_dict = request_dict - self.pre_process_dict = pre_process_dict - self.process_func = process_func - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - request_dict = self.request_dict - request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" - try: - ekd_ds = from_source("mars", request_dict) - except Exception as e: - if "File is empty" in str(e): - log.warning( - f"Empty file for period {window.start.strftime('%Y%m%d')} to {window.end.strftime('%Y%m%d')}. Skipping." 
- ) - return - else: - raise # Re-raise if it's a different error - - data = self.process_func(ekd_ds, **self.pre_process_dict) - - if window.include_start: - mask = data["times"] > window.start - else: - mask = data["times"] >= window.start - if window.include_end: - mask &= data["times"] <= window.end - else: - mask &= data["times"] < window.end - - df = data[mask] - - return self._check(df) - - -class ColFilter(ObservationsFilter): - def __init__(self, col_name): - self.col_name = col_name - - def __call__(self, df): - """Filter the data based on the given window.""" - self._check(df) - # Here we can add any filtering logic if needed - df.loc[:, self.col_name] = df[self.col_name] + 0.42 - return self._check(df) - - -dates = [datetime.datetime(2015, 10, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] - -source = MarsObsSource( - request_dict={ - "class": "od", - "expver": "0001", - "stream": "DCDA/LWDA", - "type": "ai", - "obstype": "ssmis", - "times": "00/06/12/18", - }, - pre_process_dict={ - # "target": odb2df.process_odb, - "nproc": 12, - "prefilter_msg_header": {"satelliteID": 286.0}, - "datetime_position_prefix": "#1#", - "per_report": { - "satelliteID": "satelliteID", - "#1#latitude": "latitudes", - "#1#longitude": "longitudes", - # bearingOrAzimuth: azimuth - "fieldOfViewNumber": "fov_num", - "#9#brightnessTemperature": "obsvalue_rawbt_9", - "#10#brightnessTemperature": "obsvalue_rawbt_10", - "#11#brightnessTemperature": "obsvalue_rawbt_11", - "#12#brightnessTemperature": "obsvalue_rawbt_12", - "#13#brightnessTemperature": "obsvalue_rawbt_13", - "#14#brightnessTemperature": "obsvalue_rawbt_14", - "#15#brightnessTemperature": "obsvalue_rawbt_15", - "#16#brightnessTemperature": "obsvalue_rawbt_16", - "#17#brightnessTemperature": "obsvalue_rawbt_17", - "#18#brightnessTemperature": "obsvalue_rawbt_18", - }, - "filters": { - "longitudes": "lambda x: np.isfinite(x)", - "latitudes": "lambda x: np.isfinite(x)", - }, - }, - # process_func=bufr2df_parallel, -) -filter = ColFilter("obsvalue_rawbt_9") - -for d in dates: - window = window_from_str("(-5h, 1h]").to_interval(d) - print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) - d = source(window) - d = filter(d) - print(window) - print(d) - print(d["satelliteID"].unique()) diff --git a/tests/create/dont_test_observations_mars_bufr_parallel.py b/tests/create/dont_test_observations_mars_bufr_parallel.py deleted file mode 100644 index 369ee752b..000000000 --- a/tests/create/dont_test_observations_mars_bufr_parallel.py +++ /dev/null @@ -1,129 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import datetime -import logging - -import pandas as pd - -# from bufr2df_parallel import bufr2df_parallel -from earthkit.data import from_source - -from anemoi.datasets.create.sources.observations import ObservationsFilter -from anemoi.datasets.create.sources.observations import ObservationsSource -from anemoi.datasets.use.gridded.records import Interval -from anemoi.datasets.use.gridded.records import window_from_str - -log = logging.getLogger(__name__) - - -class DummpySource(ObservationsSource): - def __init__(self, data): - assert isinstance(data, pd.DataFrame), "Data must be a pandas DataFrame" - self.data = data - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - if window.include_start: - mask = self.data["times"] > window.start - else: - mask = self.data["times"] >= window.start - if window.include_end: - mask &= self.data["times"] <= window.end - else: - mask &= self.data["times"] < window.end - - df = self.data[mask] - - return self._check(df) - - -class MarsObsSource(ObservationsSource): - def __init__(self, request_dict, pre_process_dict, process_func): - assert isinstance(request_dict, dict), "request_dict must be a dictionary" - self.request_dict = request_dict - self.pre_process_dict = pre_process_dict - self.process_func = process_func - - def __call__(self, window): - assert isinstance(window, Interval), "window must be an Interval" - - request_dict = self.request_dict - request_dict["date"] = f"{window.start.strftime('%Y%m%d')}/to/{window.end.strftime('%Y%m%d')}" - try: - ekd_ds = from_source("mars", request_dict) - except Exception as e: - if "File is empty" in str(e): - log.warning( - f"Empty file for period {window.start.strftime('%Y%m%d')} to {window.end.strftime('%Y%m%d')}. Skipping." 
- ) - return - else: - raise # Re-raise if it's a different error - - data = self.process_func(ekd_ds, **self.pre_process_dict) - - if window.include_start: - mask = data["times"] > window.start - else: - mask = data["times"] >= window.start - if window.include_end: - mask &= data["times"] <= window.end - else: - mask &= data["times"] < window.end - - df = data[mask] - - return self._check(df) - - -class ColFilter(ObservationsFilter): - def __init__(self, col_name): - self.col_name = col_name - - def __call__(self, df): - """Filter the data based on the given window.""" - self._check(df) - # Here we can add any filtering logic if needed - df.loc[:, self.col_name] = df[self.col_name] + 0.42 - return self._check(df) - - -dates = [datetime.datetime(2025, 1, 1, 0, 0) + datetime.timedelta(hours=i * 8) for i in range(3)] - -source = MarsObsSource( - request_dict={ - "class": "od", - "expver": "0001", - "stream": "LWDA", - "type": "ai", - "obstype": "nexrad_rr", - "times": "00/06/12/18", - }, - pre_process_dict={ - # "target": odb2df.process_odb, - "nproc": 12, - "per_report": { - "latitude": "latitudes", - "longitude": "longitudes", - "radarRainfallIntensity": "obsvalue_precip1h_0", - }, - }, - # process_func=bufr2df_parallel, -) -filter = ColFilter("obsvalue_precip1h_0") - -for d in dates: - window = window_from_str("(-5h, 1h]").to_interval(d) - print(window.start.strftime("%Y-%m-%d"), window.end.strftime("%Y-%m-%d")) - d = source(window) - d = filter(d) - print(window) - print(d) From 7de1bde7daee8b3bcb06d7967e99e4abc0d458d3 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 14:10:49 +0100 Subject: [PATCH 194/212] update --- tools/build-obs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/build-obs.py b/tools/build-obs.py index bc407564a..db58cb5b6 100755 --- a/tools/build-obs.py +++ b/tools/build-obs.py @@ -28,7 +28,7 @@ def build(input, output, backend, overwrite=False): print(f"Dataset has {len(ds)} records, from {ds.start_date} to {ds.end_date}") print(f"Converting dataset to {output} using new backend '{backend}'") - from anemoi.datasets.use.gridded.tabular.records.backends import writer_backend_factory + from anemoi.datasets.use.tabular.records.backends import writer_backend_factory if not isinstance(backend, dict): backend = {"name": backend} From 44daea0564bfb4cc47a98996cf2f04aa20cad6bc Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 14:38:42 +0100 Subject: [PATCH 195/212] remove python --- src/anemoi/datasets/create/python.py | 578 --------------------------- src/anemoi/datasets/recipe.py | 532 ------------------------ 2 files changed, 1110 deletions(-) delete mode 100644 src/anemoi/datasets/create/python.py delete mode 100644 src/anemoi/datasets/recipe.py diff --git a/src/anemoi/datasets/create/python.py b/src/anemoi/datasets/create/python.py deleted file mode 100644 index 29b8c611d..000000000 --- a/src/anemoi/datasets/create/python.py +++ /dev/null @@ -1,578 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
-# - -import logging -import re -from collections import defaultdict -from functools import cached_property - -LOG = logging.getLogger(__name__) - -RESERVED_KEYWORDS = ( - "and", - "or", - "not", - "is", - "in", - "if", - "else", - "elif", - "for", - "while", - "return", - "class", - "def", - "with", - "as", - "import", - "from", - "try", - "except", - "finally", - "raise", - "assert", - "break", - "continue", - "pass", -) - - -def _sanitize_name(name): - name = name.replace("-", "_") - if name in RESERVED_KEYWORDS: - name = f"{name}_" - return name - - -def _un_dotdict(x): - if isinstance(x, dict): - return {k: _un_dotdict(v) for k, v in x.items()} - - if isinstance(x, (list, tuple, set)): - return [_un_dotdict(a) for a in x] - - return x - - -class PythonCode: - - def __init__(self, top): - self.top = top - self.top.register(self) - self.key = str(id(self)) - - def call(self, name, argument): - return PythonCall(self.top, name, argument) - - def sum(self, actions): - return PythonChain(self.top, "join", "&", actions) - - def pipe(self, actions): - return PythonChain(self.top, "pipe", "|", actions) - - def concat(self, argument): - return PythonConcat(self.top, argument) - - def source_code(self): - return self.top.source_code(self) - - def combine(self, nodes): - return None - - def recipe(self, input, data_sources): - return PythonRecipe(self.top, input, data_sources) - - def prelude(self): - return None - - def sources(self, sources): - return PythonSources(self.top, sources) - - def update_anchor(self): - pass - - -class Variable(PythonCode): - def __init__(self, name, node): - super().__init__(top=node.top) - self.name = name - self.node = node - - def __repr__(self): - return "" - - def replace_node(self, old, new): - pass - - def prelude(self): - return [f"{self.name} = {repr(self.node)}", ""] - - -class InLine(PythonCode): - def __init__(self, node): - super().__init__(top=node.top) - self.node = node - - @cached_property - def name(self): - n = self.top.counter["_anchor"] - self.top.counter["_anchor"] += 1 - return f"_a{n}" - - def __repr__(self): - return f"({self.name} := {repr(self.node)})" - - def replace_node(self, old, new): - pass - - -class PythonRecipe(PythonCode): - def __init__(self, top, input, data_sources): - super().__init__(top) - self.input = input - self.data_sources = data_sources - - def apply_references(self, *path): - self.data_sources.apply_references(*path, "data_sources") - self.input.apply_references(*path, "input") - - def replace_node(self, old, new): - if self.input is old: - self.input = new - return - - if self.data_sources is old: - self.data_sources = new - return - - self.input.replace_node(old, new) - self.data_sources.replace_node(old, new) - - def __repr__(self): - return repr(self.input) - - def prelude(self): - return self.data_sources.prelude() - - -class Argument(PythonCode): - - def __init__(self, top, name): - super().__init__(top=top) - self.name = _sanitize_name(name) - - def __repr__(self): - return self.name - - def replace_node(self, old, new): - pass - - -class Anchor(PythonCode): - - def __init__(self, identifier): - super().__init__(top=identifier.node.top) - self.identifier = identifier - - @property - def name(self): - return self.identifier.name - - def __repr__(self): - # assert False - return repr(self.identifier) - - def replace_node(self, old, new): - pass - - -class Reference(PythonCode): - - def __init__(self, top, path): - super().__init__(top) - self.path = tuple(path) - self.anchor = None - - def update_anchor(self): 
- - node = self.top.by_reference.get(self.path, None) - if node is None: - LOG.warning(f"Reference {self.path} not found") - for p in sorted(self.top.by_reference): - LOG.warning(f" - {p}") - else: - self.anchor = Anchor(node) - self.top.replace_nodes([(node.node, self.anchor)]) - - def __repr__(self): - if self.anchor is not None: - return self.anchor.name - - return f"'${{{'.'.join(self.path)}}}'" - - def replace_node(self, old, new): - pass - - -class Function(PythonCode): - def __init__(self, name, node, counter): - super().__init__(top=node.top) - self._name = name - self.node = node - self.used = False - self.counter = counter - - def __repr__(self): - return self.name - - def prelude(self): - if self.used: - return None - - self.used = True - - node_prelude = self.node.prelude() - - arguments = self.node.free_arguments() - - return [ - *(node_prelude if node_prelude else []), - f"def {self.name}({','.join(repr(p) for p in arguments)}):", - f" return {self.node}", - ] - - def free_arguments(self): - return self.node.free_arguments() - - @cached_property - def name(self): - n = self.counter[self._name] - self.counter[self._name] += 1 - if n == 0: - return _sanitize_name(self._name) - return _sanitize_name(f"{self._name}_{n}") - - def replace_node(self, old, new): - if self.node is old: - self.node = new - - -class PythonSources(PythonCode): - def __init__(self, top, sources): - super().__init__(top) - self.sources = sources - - def __repr__(self): - return "" - - def prelude(self): - pass - - def replace_node(self, old, new): - for k, v in list(self.sources.items()): - if v is old: - self.sources[k] = new - else: - v.replace_node(old, new) - - def apply_references(self, *path): - for k, v in self.sources.items(): - self.top.by_reference[path + (k,)] = Variable(k, v) - - -class PythonConcat(PythonCode): - def __init__(self, top, argument): - super().__init__(top=top) - self.argument = _un_dotdict(argument) - - def __repr__(self): - return f"r.concat({self.argument})" - - def replace_node(self, old, new): - for k, v in list(self.argument.items()): - if v is old: - self.argument[k] = new - else: - v.replace_node(old, new) - - def apply_references(self, *path): - assert "concat" not in path, path - self.top.by_reference[path + ("concat",)] = InLine(self) - for i, node in enumerate(self.argument.values()): - node.apply_references(*path, "concat", str(i)) - - -class PythonChain(PythonCode): - def __init__(self, top, kind, op, actions): - super().__init__(top=top) - self.op = op - self.kind = kind - self.actions = list(actions) - self.key = op - - def __repr__(self): - return "(" + self.op.join(repr(x) for x in self.actions) + ")" - - def replace_node(self, old, new): - - for i, node in enumerate(self.actions): - - if node is old: - self.actions[i] = new - else: - node.replace_node(old, new) - - def apply_references(self, *path): - self.top.by_reference[path + (self.kind,)] = InLine(self) - for i, node in enumerate(self.actions): - node.apply_references(*path, self.kind, str(i)) - - -class PythonCall(PythonCode): - def __init__(self, top, name, argument): - super().__init__(top=top) - self.name = name - self.argument = _un_dotdict(argument) - self.key = name - - def free_arguments(self): - result = [] - for k, v in self.argument.items(): - if isinstance(v, Argument): - result.append(v) - return result - - def __repr__(self): - name = self.name.replace("-", "_") - config = dict(**self.argument) - - params = [] - - for k, v in config.items(): - k = _sanitize_name(k) - - if not 
k.isidentifier(): - return f"r.{name}({config})" - - params.append(f"{k}={repr(v)}") - - if params: - params.append("") # For a trailing comma - - params = ",".join(params) - return f"r.{name}({params})" - - def replace_node(self, old, new): - pass - - def combine(self, nodes): - - # Exact similarity - - changes = self._combine0(nodes) - if changes: - return changes - - # On key difference - changes = self._combine1(nodes) - if changes: - return changes - - def _combine0(self, nodes): - - x = defaultdict(list) - for node in nodes: - key = {k2: v2 for k2, v2 in sorted(node.argument.items())} - x[str(key)].append(node) - - for i in sorted(x.values(), key=len, reverse=True): - node = i[0] - if len(i) < 2: - return - - call = PythonCall(self.top, self.name, node.argument) - - func = self.top.function(call) - changes = [] - for node in i: - - new = PythonFunction(top=self.top, func=func) - - changes.append((node, new)) - - return changes - - def _combine1(self, nodes): - - x = defaultdict(list) - for node in nodes: - argument = node.argument - for k, v in argument.items(): - rest = {k2: v2 for k2, v2 in sorted(argument.items()) if k2 != k} - x[str(rest)].append((k, v, node)) - - for i in sorted(x.values(), key=len, reverse=True): - key, value, node = i[0] - if len(i) < 2: - return - - rest = {k: v for k, v in node.argument.items() if k != key} - rest[key] = Argument(self.top, key) - call = PythonCall(self.top, self.name, rest) - - func = self.top.function(call) - changes = [] - for key, value, node in i: - - new = PythonFunction( - top=self.top, - func=func, - **{key: value}, - ) - - changes.append((node, new)) - - return changes - - def apply_references(self, *path): - self.top.by_reference[path + (self.name,)] = InLine(self) - - for k, v in self.argument.items(): - if isinstance(v, str) and (m := re.match(r"^\${(\w+(?:\.\w+)+)}$", v)): - path = m.group(1).split(".") - self.argument[k] = Reference(self.top, path) - - -class PythonFunction(PythonCode): - def __init__(self, top, func, **kwargs): - super().__init__(top=top) - self.func = func - self.kwargs = kwargs - - def __repr__(self): - - params = [] - for a in self.func.free_arguments(): - name = _sanitize_name(a.name) - if a.name in self.kwargs: - v = self.kwargs[a.name] - params.append(f"{name}={repr(v)}") - else: - params.append(f"{name}={name}") - - return f"{self.func}({', '.join(params)})" - - def replace_node(self, old, new): - self.func.replace_node(old, new) - - def prelude(self): - return self.func.prelude() - - def free_arguments(self): - return [a for a in self.func.free_arguments() if a.name not in self.kwargs] - - def apply_references(self, *path): - pass - - -class PythonScript(PythonCode): - - def __init__(self): - self.nodes = [] - self.counter = defaultdict(int) - self.by_reference = {} - super().__init__(top=self) - - def register(self, child): - if child is not self: - self.nodes.append(child) - - def prelude(self, config): - - from anemoi.datasets.recipe import Recipe - - SKIP = ( - "input", - "data_sources", - "common", - "aliases", - ) - - result = [] - - for k, v in config.items(): - - if k in SKIP: - continue - - if not hasattr(Recipe, k): - LOG.warning(f"Unknown key in recipe: {k}") - assert False, f"Unknown key in recipe: {k}" - continue - - result.append(f"r.{k} = {repr(v)}") - - for node in self.nodes: - prelude = node.prelude() - if prelude: - if not isinstance(prelude, (list, tuple)): - prelude = list(prelude) - result.extend(prelude) - return "\n".join(result) - - def source_code(self, first, config): - - 
which = self.nodes.index(first) - first.apply_references() - for node in self.nodes: - node.update_anchor() - - more = True - while more: - more = False - - by_class = defaultdict(list) - for node in self.nodes: - by_class[(node.__class__, node.key)].append(node) - - for nodes in by_class.values(): - if len(nodes) > 1: - changes = nodes[0].combine(nodes) - if changes: - self.replace_nodes(changes) - more = True - - first = self.nodes[which] - - return "\n\n".join( - [ - "# Generated Python code for Anemoi dataset creation", - "import datetime", - "from anemoi.datasets.recipe import Recipe", - "r = Recipe()", - self.prelude(config), - f"r.input = {repr(first)}", - "r.dump()", - ] - ) - - def function(self, node): - return Function(node.name, node, self.counter) - - def replace_nodes(self, changes): - - for old, new in changes: - assert old in self.nodes, f"Node {old} not found in {self.nodes}" - for i, node in enumerate(self.nodes): - - if node is old: - self.nodes[i] = new - else: - node.replace_node(old, new) diff --git a/src/anemoi/datasets/recipe.py b/src/anemoi/datasets/recipe.py deleted file mode 100644 index 134f1cc27..000000000 --- a/src/anemoi/datasets/recipe.py +++ /dev/null @@ -1,532 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import os -import sys -from collections import defaultdict -from tempfile import TemporaryDirectory - -from anemoi.transform.filters import filter_registry as transform_filter_registry -from anemoi.utils.config import DotDict -from anemoi.utils.dates import as_datetime -from anemoi.utils.dates import frequency_to_string -from anemoi.utils.dates import frequency_to_timedelta - -# from anemoi.datasets.create.filters import filter_registry as datasets_filter_registry -from anemoi.datasets.create.sources import source_registry - -LOG = logging.getLogger(__name__) - - -def _un_dotdict(x): - if isinstance(x, dict): - return {k: _un_dotdict(v) for k, v in x.items()} - - if isinstance(x, (list, tuple, set)): - return [_un_dotdict(a) for a in x] - - return x - - -class Index: - def __init__(self, index): - self.name = str(index) - - def __repr__(self): - return f"Index({self.name})" - - def same(self, other): - if not isinstance(other, Index): - return False - return self.name == other.name - - -class Step: - - def __or__(self, other): - return Pipe(self, other) - - def __and__(self, other): - return Join(self, other) - - def same(self, other): - return self is other - - -class Chain(Step): - def __init__(self, *args): - if len(args) > 0 and isinstance(args[0], self.__class__): - args = args[0].steps + args[1:] - - self.steps = args - self.index = [Index(i) for i in range(len(self.steps))] - - def as_dict(self, recipe): - if len(self.steps) == 1: - return self.steps[0].as_dict(recipe) - return {self.name: [s.as_dict(recipe) for s in self.steps]} - - def path(self, target, result, *path): - - if target is self: - result.append([*path, self]) - return - - for i, s in enumerate(self.steps): - s.path(target, result, *path, self, self.index[i]) - - def collocated(self, a, b): - return True - - -class Pipe(Chain): - name = "pipe" - - -class Join(Chain): - name = "join" - - -class Concat(Step): - 
name = "concat" - - def __init__(self, args): - assert isinstance(args, dict), f"Invalid argument {args}" - self.params = args - - def __setitem__(self, key, value): - self.params[key] = value - - def as_dict(self, recipe): - - result = [] - - for k, v in sorted(self.params.items()): - - key = dict(start=as_datetime(k[0]), end=as_datetime(k[1])) - if len(k) == 3: - key["frequency"] = k[2] - - result.append({"dates": key, **v.as_dict(recipe)}) - - return {"concat": result} - - def collocated(self, a, b): - return a[0].same(b[0]) - - def path(self, target, result, *path): - if target is self: - result.append([*path, self]) - return - for i, (k, v) in enumerate(sorted(self.params.items())): - v.path(target, result, *path, self, Index(i)) - - -class Base(Step): - def __init__(self, owner, *args, **kwargs): - self.owner = owner - self.name = owner.name - self.params = {} - for a in args: - assert isinstance(a, dict), f"Invalid argument {a}" - self.params.update(a) - self.params.update(kwargs) - - def as_dict(self, recipe): - - def resolve(params, recipe, name=None): - if isinstance(params, dict): - - def _(k): - if isinstance(k, str) and k.endswith("_"): - return k[:-1] - return k - - return {_(k): resolve(v, recipe, name=_(k)) for k, v in params.items()} - - if isinstance(params, (list, tuple)): - return [resolve(v, recipe) for v in params] - - if isinstance(params, list): - return [resolve(v, recipe) for v in params] - - if isinstance(params, Step): - return recipe.resolve(self, params, name=name) - - return params - - return {self.owner.name: resolve(self.params, recipe)} - - def path(self, target, result, *path): - if self is target: - result.append([*path, self]) - - -class Source(Base): - pass - - -class Filter(Base): - pass - - -class SourceMaker: - def __init__(self, name, factory): - self.name = name - self.factory = factory - - def __call__(self, *args, **kwargs): - return Source(self, *args, **kwargs) - - -class FilterMaker: - def __init__(self, name, factory): - self.name = name - self.factory = factory - - def __call__(self, *args, **kwargs): - if len(args) > 0 and isinstance(args[0], Step): - prev = args[0] - args = args[1:] - return Pipe(prev, Filter(self, *args, **kwargs)) - return Filter(self, *args, **kwargs) - - -class Recipe: - - def __init__(self, name=None, description=None, attribution=None, licence=None): - - self._description = description - self._attribution = attribution - self._licence = licence - self._name = name - self._dates = None - self._statistics = None - self._build = None - self._env = None - self._dataset_status = None - self._output = None - self._platform = None - - self.input = Join() - self.output = DotDict() - self.statistics = DotDict() - self.build = DotDict() - - self._data_sources = {} - self._counter = defaultdict(int) - - sources = source_registry.factories.copy() - filters = transform_filter_registry.factories.copy() - - for key, factory in sources.items(): - if key in filters: - LOG.warning( - f"Source `{key}` is registered in anemoi.datasets source registry and in anemoi.transform filter registry" - ) - del filters[key] - - for key, factory in sources.items(): - key = key.replace("-", "_") - assert not hasattr(self, key) - setattr(self, key, SourceMaker(key, factory)) - - for key, factory in filters.items(): - key = key.replace("-", "_") - assert not hasattr(self, key) - setattr(self, key, FilterMaker(key, factory)) - - self.repeated_dates = SourceMaker("repeated_dates", None) - - def as_dict(self): - result = { - "name": self.name, - 
"description": self.description, - "attribution": self.attribution, - "licence": self.licence, - "dates": self.dates, - "statistics": self.statistics, - "build": self.build, - } - - if self._data_sources: - result["data_sources"] = self._data_sources - - for k, v in list(result.items()): - if v is None: - del result[k] - - return result - - def concat(self, *args, **kwargs): - return Concat(*args, **kwargs) - - def make_data_source(self, name, target): - - target = target.as_dict(self) - - name = name or "source" - if name in self._data_sources: - if self._data_sources[name] == target: - return f"${{data_sources.{name}}}" - - n = self._counter[name] - self._counter[name] += 1 - - name = f"{name}_{n}" if n > 0 else name - - self._data_sources[name] = target.copy() - return f"${{data_sources.{name}}}" - - def resolve(self, source, target, name=None): - - top = Index("input") # So we have 'input' first in the path - - path_to_source = [] - self.input.path(source, path_to_source, top) - if len(path_to_source) == 0: - raise ValueError(f"Source {source} not found in recipe") - if len(path_to_source) > 1: - raise ValueError(f"Source {source} found in multiple locations {path_to_source}") - path_to_source = path_to_source[0] - - path_to_target = [] - self.input.path(target, path_to_target, top) - if len(path_to_target) > 1: - raise ValueError(f"Target {target} found in multiple locations {path_to_target}") - - if len(path_to_target) == 0: - # Add a `data_sources` entry - return self.make_data_source(name, target) - - path_to_target = path_to_target[0] - - a = [s for s in path_to_target] - b = [s for s in path_to_source] - common_ancestor = None - while a[0] is b[0]: - common_ancestor = a[0] - a = a[1:] - b = b[1:] - - assert common_ancestor is not None, f"Common ancestor not found between {source} and {target}" - - if not common_ancestor.collocated(a, b): - source = ".".join(s.name for s in path_to_source) - target = ".".join(s.name for s in path_to_target) - raise ValueError( - f"Source ${{{source}}} and target ${{{target}}} are not collocated (i.e. 
they are not branch of a 'concat')" - ) - - target = ".".join(s.name for s in path_to_target) - return f"${{{target}}}" - - @property - def description(self): - return self._description - - @description.setter - def description(self, value): - self._description = value.strip() - - @property - def attribution(self): - return self._attribution - - @attribution.setter - def attribution(self, value): - self._attribution = value.strip() - - @property - def licence(self): - return self._licence - - @licence.setter - def licence(self, value): - self._licence = value.strip() - - @property - def name(self): - return self._name - - @name.setter - def name(self, value): - self._name = value.strip() - - @property - def dates(self): - return self._dates - - def _parse_dates(self, value): - - if isinstance(value, dict): - return value - - start = None - end = None - frequency = 1 - - if isinstance(value, (list, tuple)): - if len(value) in [2, 3]: - start = value[0] - end = value[1] - - if len(value) == 3: - frequency = frequency_to_string(frequency_to_timedelta(value[2])) - if isinstance(frequency, int): - frequency = f"{frequency}h" - - if start is None or end is None: - raise ValueError(f"Invalid dates {value}") - - if isinstance(frequency, int): - frequency = f"{frequency}h" - - return dict( - start=as_datetime(start), - end=as_datetime(end), - frequency=frequency, - ) - - @dates.setter - def dates(self, value): - self._dates = self._parse_dates(value) - - @property - def output(self): - return self._output - - @output.setter - def output(self, value): - self._output = value - - @property - def statistics(self): - return self._statistics - - @statistics.setter - def statistics(self, value): - self._statistics = value - - @property - def build(self): - return self._build - - @build.setter - def build(self, value): - self._build = value - - @property - def env(self): - return self._env - - @env.setter - def env(self, value): - self._env = value - - @property - def dataset_status(self): - return self._dataset_status - - @dataset_status.setter - def dataset_status(self, value): - self._dataset_status = value - - @property - def platform(self): - return self._platform - - @platform.setter - def platform(self, value): - self._platform = value - - def dump(self, file=sys.stdout): - input = self.input.as_dict(self) # First so we get the data_sources - - result = self.as_dict() - - result["input"] = input - - if self.output: - result["output"] = self.output - - if self.statistics: - result["statistics"] = self.statistics - - if self.build: - result["build"] = self.build - - if self.env: - result["env"] = self.env - - if self.dataset_status: - result["dataset_status"] = self.dataset_status - - if self.platform: - result["platform"] = self.platform - - from anemoi.datasets.dumper import yaml_dump - - yaml_dump(_un_dotdict(result), stream=file) - - def test(self, output="recipe.zarr"): - from argparse import ArgumentParser - - from anemoi.datasets.commands.create import command - - parser = ArgumentParser() - parser.add_argument("command", help="Command to run") - - cmd = command() - cmd.add_arguments(parser) - - with TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "recipe.yaml") - with open(path, "w") as file: - self.dump(file) - - args = parser.parse_args(["create", path, output, "--overwrite", "--test"]) - cmd.run(args) - - -if __name__ == "__main__": - - r = Recipe() - r.description = "test" - - r.dates = ("2023-01-01 00:00:00", "2023-12-31 18:00:00", "6h") - - m1 = r.mars(expver="0001", 
grid=[20, 20]) - m2 = r.mars(expver="0002") - m3 = r.mars(expver="0003") - - r.input = m1 - - r.input += r.forcings(template=m1, param=["cos_latitude", "sin_latitude"]) - - # m0 = r.mars(expver="0000") - # c = r.concat( - # { - # ("190", "2000"): m0, - # ("2001", "2020"): r.mars(expver="0002"), - # ("2021", "2023"): (r.mars(expver="0003") + r.forcings(template=m1, param=["cos_lat", "sin_lat"])), - # }, - # ) - - # c[("2031", "2033")] = r.mars(expver="0005") - - # r.input += c - - r.output.group_by = "day" - r.build.additions = True - r.statistics.end = "80%" - - r.dump() - r.test() From 4928e148526fe1adba0b1d8ec74a4467e632c029 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 14:42:03 +0100 Subject: [PATCH 196/212] remove python --- src/anemoi/datasets/create/input/__init__.py | 3 -- src/anemoi/datasets/create/input/action.py | 31 ------------------- .../datasets/create/input/data_sources.py | 5 --- 3 files changed, 39 deletions(-) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index f56bbd067..62f94b8cf 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -63,9 +63,6 @@ def select(self, context, argument) -> Any: self.action(context, argument), ) - def python_code(self, code): - return self.action.python_code(code) - def build_input(config: dict, data_sources: dict | list, **kwargs: Any) -> InputBuilder: """Build an InputBuilder instance. diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 831456435..62015a4ac 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -29,10 +29,6 @@ def __init__(self, config, *path): def __call__(self, context, argument): pass - @abstractmethod - def python_code(self, code): - pass - def __repr__(self): return f"{self.__class__.__name__}({'.'.join(str(x) for x in self.path)}, {self.config})" @@ -68,11 +64,6 @@ def __call__(self, context, argument): return context.register(results, self.path) - def python_code(self, code): - return code.concat( - {filtering_dates.to_python(): action.python_code(code) for filtering_dates, action in self.choices} - ) - class Join(Action): def __init__(self, config, *path): @@ -93,9 +84,6 @@ def __call__(self, context, argument): return context.register(results, self.path) - def python_code(self, code) -> None: - return code.sum(a.python_code(code) for a in self.actions) - class Pipe(Action): def __init__(self, config, *path): @@ -117,9 +105,6 @@ def __call__(self, context, argument): return context.register(result, self.path) - def python_code(self, code) -> None: - return code.pipe(a.python_code(code) for a in self.actions) - class Function(Action): def __init__(self, config, *path): @@ -135,13 +120,6 @@ def __call__(self, context, argument): return context.register(self.call_object(context, source, argument), self.path) - def python_code(self, code) -> str: - # For now... 
- if "source" in self.config: - source = action_factory(self.config["source"], *self.path, "source") - self.config["source"] = source.python_code(code) - return code.call(self.name, self.config) - class DatasetSourceMixin: def create_object(self, context, config): @@ -225,9 +203,6 @@ def __init__(self, config, *path): else: self.sources = {i: action_factory(v, *path, str(i)) for i, v in enumerate(config)} - def python_code(self, code): - return code.sources({k: v.python_code(code) for k, v in self.sources.items()}) - def __call__(self, context, argument): for name, source in self.sources.items(): context.register(source(context, argument), self.path + (name,)) @@ -238,12 +213,6 @@ def __init__(self, input, data_sources): self.input = input self.data_sources = data_sources - def python_code(self, code): - return code.recipe( - self.input.python_code(code), - self.data_sources.python_code(code), - ) - def __call__(self, context, argument): # Load data_sources self.data_sources(context, argument) diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 7a706c8ef..31956d602 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -84,11 +84,6 @@ def __repr__(self) -> str: content = "\n".join([str(i) for i in self.sources]) return self._repr(content) - def python_code(self, code) -> str: - for n, s in zip(self.names, self.sources): - code.source(n, s.python_code(code)) - return code - class DataSourcesResult(Result): """Class to represent the result of data sources actions in the dataset creation process.""" From 941c500ae6eb41b732baabaa36ab4a86b6e875f8 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 14:45:34 +0100 Subject: [PATCH 197/212] remove python --- src/anemoi/datasets/commands/check.py | 3 ++- src/anemoi/datasets/commands/cleanup.py | 3 ++- src/anemoi/datasets/commands/compare-lam.py | 3 ++- src/anemoi/datasets/commands/compare.py | 3 ++- src/anemoi/datasets/commands/copy.py | 3 ++- src/anemoi/datasets/commands/create.py | 25 ++++++++++++++++--- .../datasets/commands/finalise-additions.py | 3 ++- src/anemoi/datasets/commands/finalise.py | 3 ++- src/anemoi/datasets/commands/grib-index.py | 2 +- .../datasets/commands/init-additions.py | 3 ++- src/anemoi/datasets/commands/init.py | 3 ++- src/anemoi/datasets/commands/inspect.py | 17 ++++++------- .../datasets/commands/load-additions.py | 3 ++- src/anemoi/datasets/commands/load.py | 3 ++- src/anemoi/datasets/commands/patch.py | 3 ++- src/anemoi/datasets/commands/publish.py | 2 +- src/anemoi/datasets/commands/scan.py | 2 +- src/anemoi/datasets/commands/validate.py | 3 ++- 18 files changed, 57 insertions(+), 30 deletions(-) diff --git a/src/anemoi/datasets/commands/check.py b/src/anemoi/datasets/commands/check.py index 4202ed09f..61b29bf23 100644 --- a/src/anemoi/datasets/commands/check.py +++ b/src/anemoi/datasets/commands/check.py @@ -13,9 +13,10 @@ import yaml -from anemoi.datasets.commands import Command from anemoi.datasets.create.check import DatasetName +from . 
import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/cleanup.py b/src/anemoi/datasets/commands/cleanup.py index 25b5b9ca0..0b3a393bd 100644 --- a/src/anemoi/datasets/commands/cleanup.py +++ b/src/anemoi/datasets/commands/cleanup.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/compare-lam.py b/src/anemoi/datasets/commands/compare-lam.py index 92ea9a6af..74d97bb48 100644 --- a/src/anemoi/datasets/commands/compare-lam.py +++ b/src/anemoi/datasets/commands/compare-lam.py @@ -12,7 +12,8 @@ import os from anemoi.datasets import open_dataset -from anemoi.datasets.commands import Command + +from . import Command RADIUS_EARTH_KM = 6371.0 # Earth's radius in kilometers diff --git a/src/anemoi/datasets/commands/compare.py b/src/anemoi/datasets/commands/compare.py index bbd121bd1..ffe1ec02e 100644 --- a/src/anemoi/datasets/commands/compare.py +++ b/src/anemoi/datasets/commands/compare.py @@ -15,7 +15,8 @@ import zarr from anemoi.datasets import open_dataset -from anemoi.datasets.commands import Command + +from . import Command class Compare(Command): diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 9628bae8e..406c13de7 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -20,7 +20,8 @@ from anemoi.utils.remote import TransferMethodNotImplementedError from anemoi.datasets.check import check_zarr -from anemoi.datasets.commands import Command + +from . import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 151b175d9..f0df9762d 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -18,13 +18,30 @@ import tqdm from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command +from . import Command LOG = logging.getLogger(__name__) -def task(what: str, fields: bool, options: dict, *args: Any, **kwargs: Any) -> Any: - """Make sure `import Creator` is done in the sub-processes, and not in the main one.""" +def task(what: str, options: dict, *args: Any, **kwargs: Any) -> Any: + """Make sure `import Creator` is done in the sub-processes, and not in the main one. + + Parameters + ---------- + what : str + The task to be executed. + options : dict + Options for the task. + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + Any + The result of the task. 
+ """ now = datetime.datetime.now() LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") @@ -32,7 +49,7 @@ def task(what: str, fields: bool, options: dict, *args: Any, **kwargs: Any) -> A options = {k: v for k, v in options.items() if v is not None} - c = task_factory(what.replace("-", "_"), fields, **options) + c = task_factory(what.replace("-", "_"), **options) result = c.run() LOG.info(f"🏁 Task {what}({args},{kwargs}) completed ({datetime.datetime.now()-now})") diff --git a/src/anemoi/datasets/commands/finalise-additions.py b/src/anemoi/datasets/commands/finalise-additions.py index 25380fbba..811760ae9 100644 --- a/src/anemoi/datasets/commands/finalise-additions.py +++ b/src/anemoi/datasets/commands/finalise-additions.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/finalise.py b/src/anemoi/datasets/commands/finalise.py index 5197fb73c..53999ad50 100644 --- a/src/anemoi/datasets/commands/finalise.py +++ b/src/anemoi/datasets/commands/finalise.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/grib-index.py b/src/anemoi/datasets/commands/grib-index.py index b5cc910d2..cfd7a08e8 100644 --- a/src/anemoi/datasets/commands/grib-index.py +++ b/src/anemoi/datasets/commands/grib-index.py @@ -13,7 +13,7 @@ import tqdm -from anemoi.datasets.commands import Command +from . import Command class GribIndexCmd(Command): diff --git a/src/anemoi/datasets/commands/init-additions.py b/src/anemoi/datasets/commands/init-additions.py index c49bbf76c..09788f0e4 100644 --- a/src/anemoi/datasets/commands/init-additions.py +++ b/src/anemoi/datasets/commands/init-additions.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/init.py b/src/anemoi/datasets/commands/init.py index c5aa515fd..0ca540b86 100644 --- a/src/anemoi/datasets/commands/init.py +++ b/src/anemoi/datasets/commands/init.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 257bee122..384ee7d34 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -27,9 +27,10 @@ from numpy.typing import NDArray from anemoi.datasets import open_dataset -from anemoi.datasets.commands import Command -from anemoi.datasets.use.gridded.stores import dataset_lookup -from anemoi.datasets.use.gridded.stores import open_zarr +from anemoi.datasets.data.stores import open_zarr +from anemoi.datasets.data.stores import zarr_lookup + +from . 
import Command LOG = logging.getLogger(__name__) @@ -395,13 +396,9 @@ def progress(self) -> None: ) return - if self.build_flags is None: - print("🪫 Dataset not initialised") - return - - build_flags = self.build_flags + build_flags = self.build_flags or np.array([], dtype=bool) - build_lengths = self.build_lengths + build_lengths = self.build_lengths or np.array([], dtype=bool) assert build_flags.size == build_lengths.size latest_write_timestamp = self.zarr.attrs.get("latest_write_timestamp") @@ -813,7 +810,7 @@ def _info(self, path: str) -> Version: Version The version object of the dataset. """ - z = open_zarr(dataset_lookup(path)) + z = open_zarr(zarr_lookup(path)) metadata = dict(z.attrs) version = metadata.get("version", "0.0.0") diff --git a/src/anemoi/datasets/commands/load-additions.py b/src/anemoi/datasets/commands/load-additions.py index 82dec8b0a..a8cd5d5c9 100644 --- a/src/anemoi/datasets/commands/load-additions.py +++ b/src/anemoi/datasets/commands/load-additions.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/load.py b/src/anemoi/datasets/commands/load.py index 7b1c1f684..3d969f5c3 100644 --- a/src/anemoi/datasets/commands/load.py +++ b/src/anemoi/datasets/commands/load.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/patch.py b/src/anemoi/datasets/commands/patch.py index 1920fc420..dc8356126 100644 --- a/src/anemoi/datasets/commands/patch.py +++ b/src/anemoi/datasets/commands/patch.py @@ -13,9 +13,10 @@ from anemoi.utils.humanize import seconds_to_human -from anemoi.datasets.commands import Command from anemoi.datasets.commands.create import task +from . import Command + LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/publish.py b/src/anemoi/datasets/commands/publish.py index 47282e65b..7f719543e 100644 --- a/src/anemoi/datasets/commands/publish.py +++ b/src/anemoi/datasets/commands/publish.py @@ -10,7 +10,7 @@ import logging from typing import Any -from anemoi.datasets.commands import Command +from . import Command LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/commands/scan.py b/src/anemoi/datasets/commands/scan.py index 37c8d0cfd..8a048125e 100644 --- a/src/anemoi/datasets/commands/scan.py +++ b/src/anemoi/datasets/commands/scan.py @@ -17,7 +17,7 @@ import tqdm import yaml -from anemoi.datasets.commands import Command +from . import Command KEYS = ("class", "type", "stream", "expver", "levtype", "domain") diff --git a/src/anemoi/datasets/commands/validate.py b/src/anemoi/datasets/commands/validate.py index 03691541c..1382814a7 100644 --- a/src/anemoi/datasets/commands/validate.py +++ b/src/anemoi/datasets/commands/validate.py @@ -10,9 +10,10 @@ import logging from typing import Any -from anemoi.datasets.commands import Command from anemoi.datasets.validate import validate_dataset +from . 
import Command + LOG = logging.getLogger(__name__) DEFAULT_DATASET = "aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8" From 740d3ee46116ef4cd6d5a26eed821c86a5521dbe Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 16:29:12 +0100 Subject: [PATCH 198/212] update --- src/anemoi/datasets/commands/copy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 406c13de7..5020a208d 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -504,7 +504,6 @@ def add_arguments(self, command_parser: Any) -> None: default=100, help="For optimisation purposes, data is transfered by blocks. Default is 100.", ) - command_parser.add_argument("--workdir", help="Working directory for the copy operation.", default=".") command_parser.add_argument("source", help="Source location.") command_parser.add_argument("target", help="Target location.") @@ -534,7 +533,6 @@ def run(self, args: Any) -> None: resume=args.resume, verbosity=args.verbosity, threads=args.transfers, - workdir=args.workdir, ) copier.run() return From 99fa8d32ea10b905b62144ef2caf7a9d7772062a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 16:30:47 +0100 Subject: [PATCH 199/212] update --- src/anemoi/datasets/create/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/datasets/create/tasks.py b/src/anemoi/datasets/create/tasks.py index 05372d6d7..e803058be 100644 --- a/src/anemoi/datasets/create/tasks.py +++ b/src/anemoi/datasets/create/tasks.py @@ -53,7 +53,7 @@ def task_factory(name: str, fields: bool = True, trace: str | None = None, **kwa creator = TaskCreator() else: - from anemoi.datasets.create.observations.tasks import TaskCreator + from anemoi.datasets.create.tabular.tasks import TaskCreator creator = TaskCreator() From 895b9ade7fc49837aef0a57dc8a27047905e38a8 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:07:14 +0100 Subject: [PATCH 200/212] update --- src/anemoi/datasets/commands/create.py | 35 +- src/anemoi/datasets/commands/inspect.py | 6 +- src/anemoi/datasets/create/gridded/result.py | 35 -- src/anemoi/datasets/create/gridded/tasks.py | 24 -- .../create/sources/xarray_support/__init__.py | 56 ++-- .../create/sources/xarray_support/field.py | 6 +- .../sources/xarray_support/fieldlist.py | 12 +- .../create/sources/xarray_support/flavour.py | 38 +-- .../create/sources/xarray_support/metadata.py | 2 +- .../create/sources/xarray_support/time.py | 4 +- .../create/sources/xarray_support/variable.py | 2 +- tests/create/bufr2df.py | 122 ------- tests/create/bufr2df_parallel.py | 314 ------------------ tests/create/odb2df.py | 124 ------- 14 files changed, 82 insertions(+), 698 deletions(-) delete mode 100644 tests/create/bufr2df.py delete mode 100644 tests/create/bufr2df_parallel.py delete mode 100644 tests/create/odb2df.py diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index f0df9762d..b9ba80029 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -115,20 +115,18 @@ def serial_create(self, args: Any) -> None: options.pop("threads") options.pop("processes") - fields = args.path.endswith(".zarr") or args.path.endswith(".zarr/") + task("init", options) + task("load", options) + task("finalise", options) - task("init", fields, options) - task("load", fields, options) - task("finalise", fields, options) + task("init_additions", options) + 
task("load_additions", options) + task("finalise_additions", options) - task("init_additions", fields, options) - task("load_additions", fields, options) - task("finalise_additions", fields, options) + task("patch", options) - task("patch", fields, options) - - task("cleanup", fields, options) - task("verify", fields, options) + task("cleanup", options) + task("verify", options) def parallel_create(self, args: Any) -> None: """Create the dataset in parallel mode. @@ -149,7 +147,6 @@ def parallel_create(self, args: Any) -> None: threads = options.pop("threads") processes = options.pop("processes") - fields = args.path.endswith(".zarr") or args.path.endswith(".zarr/") use_threads = threads > 0 options["use_threads"] = use_threads @@ -160,7 +157,7 @@ def parallel_create(self, args: Any) -> None: ExecutorClass = ProcessPoolExecutor with ExecutorClass(max_workers=1) as executor: - total = executor.submit(task, "init", fields, options).result() + total = executor.submit(task, "init", options).result() futures = [] @@ -169,7 +166,7 @@ def parallel_create(self, args: Any) -> None: for n in range(total): opt = options.copy() opt["parts"] = f"{n+1}/{total}" - futures.append(executor.submit(task, "load", fields, opt)) + futures.append(executor.submit(task, "load", opt)) for future in tqdm.tqdm( as_completed(futures), desc="Loading", total=len(futures), colour="green", position=parallel + 1 @@ -180,7 +177,7 @@ def parallel_create(self, args: Any) -> None: executor.submit(task, "finalise", options).result() with ExecutorClass(max_workers=1) as executor: - executor.submit(task, "init-additions", fields, options).result() + executor.submit(task, "init-additions", options).result() with ExecutorClass(max_workers=parallel) as executor: for n in range(total): @@ -198,10 +195,10 @@ def parallel_create(self, args: Any) -> None: future.result() with ExecutorClass(max_workers=1) as executor: - executor.submit(task, "finalise-additions", fields, options).result() - executor.submit(task, "patch", fields, options).result() - executor.submit(task, "cleanup", fields, options).result() - executor.submit(task, "verify", fields, options).result() + executor.submit(task, "finalise-additions", options).result() + executor.submit(task, "patch", options).result() + executor.submit(task, "cleanup", options).result() + executor.submit(task, "verify", options).result() command = Create diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 384ee7d34..728714a68 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -27,8 +27,8 @@ from numpy.typing import NDArray from anemoi.datasets import open_dataset -from anemoi.datasets.data.stores import open_zarr -from anemoi.datasets.data.stores import zarr_lookup +from anemoi.datasets.use.gridded.stores import dataset_lookup +from anemoi.datasets.use.gridded.stores import open_zarr from . import Command @@ -810,7 +810,7 @@ def _info(self, path: str) -> Version: Version The version object of the dataset. 
""" - z = open_zarr(zarr_lookup(path)) + z = open_zarr(dataset_lookup(path)) metadata = dict(z.attrs) version = metadata.get("version", "0.0.0") diff --git a/src/anemoi/datasets/create/gridded/result.py b/src/anemoi/datasets/create/gridded/result.py index d4bcf58ea..787c4b6c3 100644 --- a/src/anemoi/datasets/create/gridded/result.py +++ b/src/anemoi/datasets/create/gridded/result.py @@ -563,41 +563,6 @@ def build_coords(self) -> None: self._cube: Any = cube - name_key = list(self.order_by.keys())[1] - - p = None - origins_per_number = defaultdict(lambda: defaultdict(set)) - - for fs in self.datasource: - o = fs.metadata("anemoi_origin", remapping=self.remapping, patches=self.patches) - name = fs.metadata(name_key, remapping=self.remapping, patches=self.patches) - number = fs.metadata("number", remapping=self.remapping, patches=self.patches) - - assert name not in origins_per_number[number][o], name - origins_per_number[number][o].add(name) - - if p is not o: - LOG.info(f"🔥🔥🔥🔥🔥🔥 Source: {name}, {o}") - p = o - - origins_per_variables = defaultdict(lambda: defaultdict(set)) - for number, origins in origins_per_number.items(): - for origin, names in origins.items(): - for name in names: - origins_per_variables[name][origin].add(number) - - origins = defaultdict(set) - - # Check if all members of a variable have the same origins - for name, origin_number in origins_per_variables.items(): - # For now we do not support variables with members from different origins - assert len(origin_number) == 1, origin_number - origins[list(origin_number.keys())[0]].add(name) - - self._origins = [] - for k, v in origins.items(): - self._origins.append({"origin": k.as_dict(), "variables": sorted(v)}) - self._coords_already_built: bool = True @property diff --git a/src/anemoi/datasets/create/gridded/tasks.py b/src/anemoi/datasets/create/gridded/tasks.py index d4cb1f288..7cb5f01cb 100644 --- a/src/anemoi/datasets/create/gridded/tasks.py +++ b/src/anemoi/datasets/create/gridded/tasks.py @@ -511,30 +511,6 @@ def _tidy(d): raise -def _config_to_python(config: Any) -> Any: - - from anemoi.datasets.create.create.python import PythonScript - - raw_config = config - - config = loader_config(config) - - input = InputBuilder(config.input, data_sources=config.get("data_sources", {})) - - code = PythonScript() - x = input.python_code(code) - code = code.source_code(x, raw_config) - - try: - import black - - return black.format_str(code, mode=black.Mode()) - # except ImportError: - except Exception: - LOG.warning("Black not installed, skipping formatting") - return code - - class TaskCreator: """A class to create and run dataset creation tasks.""" diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index c33ce7bfc..8e3cebc08 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -15,9 +15,11 @@ import xarray as xr from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.patterns import iterate_patterns -from anemoi.datasets.create.sources.xarray_support.fieldlist import XarrayFieldList + +from .. 
import source_registry +from ..legacy import LegacySource +from .fieldlist import XarrayFieldList LOG = logging.getLogger(__name__) @@ -151,26 +153,30 @@ def load_many(emoji: str, context: Any, dates: list[datetime.datetime], pattern: return MultiFieldList(result) -@legacy_source("xarray") -def execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Executes the loading of datasets. - - Parameters - ---------- - context : Any - Context object. - dates : List[str] - List of dates. - url : str - URL pattern for loading datasets. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - The loaded datasets. - """ - return load_many("🌐", context, dates, url, *args, **kwargs) +@source_registry.register("xarray") +class LegacyXarraySource(LegacySource): + name = "xarray" + + @staticmethod + def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Executes the loading of datasets. + + Parameters + ---------- + context : Any + Context object. + dates : List[str] + List of dates. + url : str + URL pattern for loading datasets. + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + The loaded datasets. + """ + return load_many("🌐", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 09cd6679c..7f6bb4fb3 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -17,9 +17,9 @@ from earthkit.data.core.fieldlist import math from numpy.typing import NDArray -from anemoi.datasets.create.sources.xarray_support.coordinates import extract_single_value -from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar -from anemoi.datasets.create.sources.xarray_support.metadata import XArrayMetadata +from .coordinates import extract_single_value +from .coordinates import is_scalar +from .metadata import XArrayMetadata LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py index 174cb2716..48f9cf0e1 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py +++ b/src/anemoi/datasets/create/sources/xarray_support/fieldlist.py @@ -16,12 +16,12 @@ import yaml from earthkit.data import FieldList -from anemoi.datasets.create.sources.xarray_support.field import EmptyFieldList -from anemoi.datasets.create.sources.xarray_support.flavour import CoordinateGuesser -from anemoi.datasets.create.sources.xarray_support.patch import patch_dataset -from anemoi.datasets.create.sources.xarray_support.time import Time -from anemoi.datasets.create.sources.xarray_support.variable import FilteredVariable -from anemoi.datasets.create.sources.xarray_support.variable import Variable +from .field import EmptyFieldList +from .flavour import CoordinateGuesser +from .patch import patch_dataset +from .time import Time +from .variable import FilteredVariable +from .variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 74fcdbd03..80f0b6a62 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ 
b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -17,25 +17,25 @@ import xarray as xr from anemoi.utils.config import DotDict -from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import DateCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import EnsembleCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LatitudeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LevelCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import LongitudeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import PointCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import ScalarCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import StepCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import TimeCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import UnsupportedCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import XCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import YCoordinate -from anemoi.datasets.create.sources.xarray_support.coordinates import is_scalar -from anemoi.datasets.create.sources.xarray_support.grid import Grid -from anemoi.datasets.create.sources.xarray_support.grid import MeshedGrid -from anemoi.datasets.create.sources.xarray_support.grid import MeshProjectionGrid -from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredGrid -from anemoi.datasets.create.sources.xarray_support.grid import UnstructuredProjectionGrid +from .coordinates import Coordinate +from .coordinates import DateCoordinate +from .coordinates import EnsembleCoordinate +from .coordinates import LatitudeCoordinate +from .coordinates import LevelCoordinate +from .coordinates import LongitudeCoordinate +from .coordinates import PointCoordinate +from .coordinates import ScalarCoordinate +from .coordinates import StepCoordinate +from .coordinates import TimeCoordinate +from .coordinates import UnsupportedCoordinate +from .coordinates import XCoordinate +from .coordinates import YCoordinate +from .coordinates import is_scalar +from .grid import Grid +from .grid import MeshedGrid +from .grid import MeshProjectionGrid +from .grid import UnstructuredGrid +from .grid import UnstructuredProjectionGrid LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/metadata.py b/src/anemoi/datasets/create/sources/xarray_support/metadata.py index 2230db3ef..23713ae74 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/metadata.py +++ b/src/anemoi/datasets/create/sources/xarray_support/metadata.py @@ -46,7 +46,7 @@ def __init__(self, field: Any) -> None: field : Any The field to extract metadata from. 
""" - from anemoi.datasets.create.sources.xarray_support.field import XArrayField + from .field import XArrayField assert isinstance(field, XArrayField), type(field) self._field = field diff --git a/src/anemoi/datasets/create/sources/xarray_support/time.py b/src/anemoi/datasets/create/sources/xarray_support/time.py index 7b1f60e58..847b21598 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/time.py +++ b/src/anemoi/datasets/create/sources/xarray_support/time.py @@ -16,8 +16,8 @@ from anemoi.utils.dates import as_datetime -from anemoi.datasets.create.sources.xarray_support.coordinates import Coordinate -from anemoi.datasets.create.sources.xarray_support.variable import Variable +from .coordinates import Coordinate +from .variable import Variable LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/sources/xarray_support/variable.py b/src/anemoi/datasets/create/sources/xarray_support/variable.py index 13d6fa4e2..5d2c1c5b1 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/variable.py +++ b/src/anemoi/datasets/create/sources/xarray_support/variable.py @@ -17,7 +17,7 @@ import numpy as np import xarray as xr -from anemoi.datasets.create.sources.xarray_support.field import XArrayField +from .field import XArrayField LOG = logging.getLogger(__name__) diff --git a/tests/create/bufr2df.py b/tests/create/bufr2df.py deleted file mode 100644 index 892141041..000000000 --- a/tests/create/bufr2df.py +++ /dev/null @@ -1,122 +0,0 @@ -import eccodes -import numpy as np -import pandas as pd -import tqdm -from earthkit.data.readers.bufr.bufr import BUFRReader -from gribapi.errors import KeyValueNotFoundError - - -def filter_values(df: pd.DataFrame, filters: dict) -> pd.DataFrame: - """Filter the DataFrame based on the specified conditions""" - for col, condition in filters.items(): - if isinstance(condition, str): - condition = eval(condition) - if callable(condition): - df = df[df[col].apply(condition)] - elif isinstance(condition, slice): - start, stop = condition.start, condition.stop - query_str = f"({start} <= {col}) & ({col} < {stop})" - df = df.query(query_str) - elif isinstance(condition, (list, set)): - df = df[df[col].isin(condition)] - else: - raise ValueError(f"Invalid condition for column '{col}': {condition}") - return df - - -def bufr_get_array(bid: int, element: str, typ: type, nsubsets: int, missing_val=np.nan) -> np.ndarray: - """Wrapper for codes_get_array to work around the inconsistent handling of arrays in eccodes when data is constant""" - try: - arr = eccodes.codes_get_array(bid, element, typ) - if len(arr) == 1: - arr = np.ones(nsubsets, dtype=typ) * arr - except KeyValueNotFoundError: - arr = np.ones(nsubsets, dtype=typ) * missing_val - return arr - - -def extract_datetimes(bid: int, nreports: int) -> pd.DataFrame: - """Extracts and parses the date/time info from a bufr message - and returns as an array of datetime objects - """ - df = pd.DataFrame( - dict( - years=bufr_get_array(bid, "year", int, nreports), - months=bufr_get_array(bid, "month", int, nreports), - days=bufr_get_array(bid, "day", int, nreports), - hours=bufr_get_array(bid, "hour", int, nreports), - minutes=bufr_get_array(bid, "minute", int, nreports), - seconds=bufr_get_array(bid, "second", int, nreports, missing_val=0), - ) - ) - # Create the datetime series using pandas - datetimes = pd.to_datetime(df) - return datetimes - - -def get_msg(f, i, per_report_dict, per_datum_dict=None, filters=None) -> pd.DataFrame: - bid = eccodes.codes_bufr_new_from_file(f) - 
eccodes.codes_set(bid, "unpack", 1) - nreports = eccodes.codes_get(bid, "numberOfSubsets") - - data_dict = { - item: bufr_get_array(bid, col, float, nreports).astype(np.float32) for col, item in per_report_dict.items() - } - data_dict["times"] = extract_datetimes(bid, nreports) - - if per_datum_dict: - for col, sub_dict in per_datum_dict.items(): - ndatum = eccodes.codes_get_size(bid, next(iter(per_datum_dict))) // nreports - vals = bufr_get_array(bid, col, float, nreports * ndatum).astype(np.float32) - try: - vals_2d = vals.reshape(ndatum, nreports).T - except ValueError as e: - if "cannot reshape array" in str(e): - import warnings - - warnings.warn( - f"Reshape error in file {f}, message {i}: Cannot reshape array of size {len(vals)} " - f"into shape ({ndatum}, {nreports}). Skipping this message.", - RuntimeWarning, - ) - eccodes.codes_release(bid) - return None - else: - raise # Re-raise if it's a different ValueError - - for col_rename, slice_str in sub_dict.items(): - vals_col = vals_2d[:, eval(slice_str)] - for i in range(vals_col.shape[1]): - data_dict[f"{col_rename}_{i+1}"] = vals_col[:, i] - - df = pd.DataFrame(data_dict) - - if filters: - df = filter_values(df, filters) - - eccodes.codes_release(bid) - return df - - -def bufr2df( - ekd_ds: BUFRReader, - per_report: dict, - per_datum: dict = None, - filter: dict = None, -) -> pd.DataFrame: - """Extracts data from a BUFR file into a pandas DataFrame - -info on what to extract (and how it should be named in the dataframe) are - provided by input dictionaries; one at the per-report level and another for the per-datum - """ - fname = ekd_ds.path - with open(fname, "rb") as f: - nmessages = eccodes.codes_count_in_file(f) - bar = tqdm.tqdm( - iterable=range(nmessages), - desc="Processing bufr messages...", - mininterval=20.0, - ) - df_lst = [get_msg(f, i, per_report, per_datum, filter) for i in bar] - df = pd.concat(df_lst) - df = df.sort_values(by=["times"]).reset_index(drop=True) - return df diff --git a/tests/create/bufr2df_parallel.py b/tests/create/bufr2df_parallel.py deleted file mode 100644 index 04dd20e7a..000000000 --- a/tests/create/bufr2df_parallel.py +++ /dev/null @@ -1,314 +0,0 @@ -import logging -import mmap -import os -from multiprocessing import Pool - -import eccodes -import numpy as np -import pandas as pd -from earthkit.data.readers.bufr.bufr import BUFRReader -from gribapi.errors import KeyValueNotFoundError - - -def filter_values(df: pd.DataFrame, filters: dict) -> pd.DataFrame: - """Filter the DataFrame based on the specified conditions""" - for col, condition in filters.items(): - if isinstance(condition, str): - condition = eval(condition) - if callable(condition): - df = df[df[col].apply(condition)] - elif isinstance(condition, slice): - start, stop = condition.start, condition.stop - query_str = f"({start} <= {col}) & ({col} < {stop})" - df = df.query(query_str) - elif isinstance(condition, (list, set)): - df = df[df[col].isin(condition)] - else: - raise ValueError(f"Invalid condition for column '{col}': {condition}") - return df - - -log = logging.getLogger(__name__) -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(processName)s - %(levelname)s - %(message)s", - force=True, -) - - -def filter_bufr_message(bid: int, filter_config: dict) -> bool: - """Check if BUFR message meets filtering conditions specified in filter_config - Returns True if message should be kept, False if it should be filtered out - """ - namespace = {"inf": float("inf")} - - for key, condition in filter_config.items(): 
- try: - # Get the value from BUFR - value = eccodes.codes_get(bid, key) - - if isinstance(condition, str) and condition.startswith("lambda"): - # Lambda expression case - filter_condition = eval(condition, namespace) - if not filter_condition(value): - return False - else: - - # Direct value comparison case - if value != condition: - return False - - except eccodes.KeyValueNotFoundError: - logging.warning(f"Key {key} not found in BUFR message") - return False - except Exception as e: - logging.error(f"Error evaluating condition for {key}: {e}") - return False - - return True - - -def bufr_get_array(bid: int, element: str, typ: type, nsubsets: int, missing_val=np.nan) -> np.ndarray: - """Wrapper for codes_get_array to work around the inconsistent handling of arrays in eccodes when data is constant""" - try: - arr = eccodes.codes_get_array(bid, element, typ) - if len(arr) == 1: - arr = np.ones(nsubsets, dtype=typ) * arr - except KeyValueNotFoundError: - arr = np.ones(nsubsets, dtype=typ) * missing_val - return arr - - -def extract_datetimes(bid: int, nreports: int, position_prefix: str = "") -> pd.DataFrame: - """Extracts and parses the date/time info from a bufr message - and returns as an array of datetime objects - """ - df = pd.DataFrame( - dict( - years=bufr_get_array(bid, position_prefix + "year", int, nreports), - months=bufr_get_array(bid, position_prefix + "month", int, nreports), - days=bufr_get_array(bid, position_prefix + "day", int, nreports), - hours=bufr_get_array(bid, position_prefix + "hour", int, nreports), - minutes=bufr_get_array(bid, position_prefix + "minute", int, nreports), - seconds=bufr_get_array(bid, position_prefix + "second", int, nreports, missing_val=0), - ) - ) - # Create the datetime series using pandas - datetimes = pd.to_datetime(df) - return datetimes - - -def get_msg( - bufr_msg, - per_report: dict, - prefilter_msg_header: dict = {}, - prefilter_msg_data: dict = {}, - datetime_position_prefix: str = "", - per_datum: dict = None, - filters: dict = None, -) -> pd.DataFrame: - try: - bid = eccodes.codes_new_from_message(bufr_msg) - nreports = eccodes.codes_get(bid, "numberOfSubsets") - eccodes.codes_set(bid, "skipExtraKeyAttributes", 1) - - # Optionally filter messages based on header section entries - if prefilter_msg_header and not filter_bufr_message(bid, prefilter_msg_header): - eccodes.codes_release(bid) - return pd.DataFrame() - - eccodes.codes_set(bid, "unpack", 1) - - # Optionally filter messages based on data section entries - if prefilter_msg_data and not filter_bufr_message(bid, prefilter_msg_data): - eccodes.codes_release(bid) - return pd.DataFrame() - - data_dict = { - item: bufr_get_array(bid, col, float, nreports).astype(np.float32) for col, item in per_report.items() - } - - data_dict["times"] = extract_datetimes(bid, nreports, datetime_position_prefix) - - if per_datum: - for col, sub_dict in per_datum.items(): - ndatum = eccodes.codes_get_size(bid, next(iter(per_datum))) // nreports - vals = bufr_get_array(bid, col, float, nreports * ndatum).astype(np.float32) - try: - vals_2d = vals.reshape(ndatum, nreports).T - except ValueError as e: - if "cannot reshape array" in str(e): - import warnings - - warnings.warn( - f"Reshape error in bufr message {bufr_msg}: Cannot reshape array of size {len(vals)} " - f"into shape ({ndatum}, {nreports}). 
Skipping this message.", - RuntimeWarning, - ) - eccodes.codes_release(bid) - return None - else: - raise # Re-raise if it's a different ValueError - - for col_rename, slice_str in sub_dict.items(): - vals_col = vals_2d[:, eval(slice_str)] - for k in range(vals_col.shape[1]): - data_dict[f"{col_rename}_{k+1}"] = vals_col[:, k] - - df = pd.DataFrame(data_dict) - - if filters: - df = filter_values(df, filters) - - eccodes.codes_release(bid) - return df - except Exception as e: - import warnings - - warnings.warn( - f"Unexpected error in message: {str(e)}. Skipping this message.", - RuntimeWarning, - ) - if "bid" in locals(): - eccodes.codes_release(bid) - return None - - -class BufrData(object): - def __init__(self, BufrFileName): - self._filename = BufrFileName - self._fobj = open(self._filename, "rb") - self._fileno = self._fobj.fileno() - self._nmsg = eccodes.codes_count_in_file(self._fobj) - self._dataBlock = self.get_datablock() - self._lstOffsets = self.get_list_offsets() - - @property - def dataBlock(self): - return self._dataBlock - - @property - def nmsg(self): - return self._nmsg - - @property - def lstOffsets(self): - return self._lstOffsets - - def get_datablock(self): - with mmap.mmap(self._fileno, length=0, access=mmap.ACCESS_READ) as mobj: - data = mobj.read() - return data - - def get_list_offsets(self): - lstOffsets = [] - for _ in range(0, self._nmsg): - bid = eccodes.codes_bufr_new_from_file(self._fobj) - offset = eccodes.codes_get_message_offset(bid) - size = eccodes.codes_get_message_size(bid) - lstOffsets.append((offset, size)) - eccodes.codes_release(bid) - return lstOffsets - - def __del__(self): - self._fobj.close() - - -def read_block( - sublist, - dataBlock, - per_report: dict, - prefilter_msg_header: dict = None, - prefilter_msg_data: dict = None, - datetime_position_prefix: str = "", - per_datum: dict = None, - filters: dict = None, -): - log.info(f"PID : {os.getpid()} in read block sublist has {len(sublist)} elements") - try: - df_lst = [ - get_msg( - dataBlock[offset : offset + ch_size], - per_report, - prefilter_msg_header, - prefilter_msg_data, - datetime_position_prefix, - per_datum, - filters, - ) - for offset, ch_size in sublist - ] - return pd.concat(df_lst) - except Exception as e: - log.error(f"Error in read_block: {str(e)}") - raise - - -def split_list(alist, nparts): - nelem = len(alist) - chunkSize = nelem // (nparts) - sublists = [] - for i in range(0, nelem, chunkSize): - slist = alist[i : i + chunkSize] - sublists.append(slist) - return sublists - - -def bufr2df_parallel( - ekd_ds: BUFRReader, - per_report: dict, - nproc: int = 1, - prefilter_msg_header: dict = None, - prefilter_msg_data: dict = None, - datetime_position_prefix: str = "", - per_datum: dict = None, - filters: dict = None, -) -> pd.DataFrame: - fname = ekd_ds.path - mbfo = BufrData(fname) - fullDataBlock = mbfo.dataBlock - log.info(f"number of messages {mbfo.nmsg}") - sublists = split_list(mbfo.lstOffsets, nproc) - - nSubLists = len(sublists) - - pool = Pool(processes=nproc) - try: - results = [ - pool.apply_async( - read_block, - args=( - sublists[i], - fullDataBlock, - per_report, - prefilter_msg_header, - prefilter_msg_data, - datetime_position_prefix, - per_datum, - filters, - ), - ) - for i in range(0, nSubLists) - ] - all_lst = [] - for r in results: - try: - df = r.get() - all_lst.append(df) - except Exception as e: - log.error(f"Error getting result from worker process: {str(e)}") - continue - if not all_lst: - raise ValueError("No valid results were returned from any worker 
process") - finally: - pool.close() # Stop accepting new tasks - pool.join() # Wait for workers to finish with timeout - pool.terminate() # Force terminate if still running - - df = pd.concat(all_lst) - if len(df) > 0: - df = df.sort_values(by=["times"]).reset_index(drop=True) - - log.info(f"Number of rows in the dataframe {len(df)}") - - return df diff --git a/tests/create/odb2df.py b/tests/create/odb2df.py deleted file mode 100644 index 9ad31f1df..000000000 --- a/tests/create/odb2df.py +++ /dev/null @@ -1,124 +0,0 @@ -import json -import logging -from typing import Dict -from typing import List -from typing import Optional -from typing import Union - -import pandas as pd -from earthkit.data.readers.odb import ODBReader - -logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s") - - -def load_varno_dict(path: Optional[str] = None) -> Dict: - """Load varno mapping, return empty dict if not found.""" - try: - with open(path or "varno.json") as f: - return json.load(f) - except (ValueError, Exception): - return {"data": []} - - -def get_varno_name(varno: Union[int, str], varno_dict: Dict) -> str: - """Get varno name or return original if not found.""" - try: - v = int(varno) - for entry in varno_dict.get("data", []): - if v in entry: - return str(entry[0]) - except (ValueError, Exception): - pass - return str(varno) - - -def rename_cols(cols: List, extra_obs: List[str] = None, varno_path: str = None) -> List[str]: - """Rename columns: base_name_varno_level""" - varno_dict = load_varno_dict(varno_path) - extra_obs = extra_obs or [] - - result = [] - for col in cols: - if isinstance(col, tuple): - parts = col + ("", "") - name, varno = parts[:2] - level = parts[2] if len(parts) > 2 else "" - else: - name, varno, level = col, "", "" - - base = name.split("@")[0] - if base in extra_obs: - base = f"obsvalue_{base}" - - if varno: - varno_name = get_varno_name(varno, varno_dict) - level_str = str(int(level)) if level and not isinstance(level, (list, tuple)) else "0" - result.append(f"{base}_{varno_name}_{level_str}") - else: - result.append(base) - - return result - - -def process_odb( - reader: ODBReader, - index: List[str], - pivot: List[str], - values: List[str], - sort: List[str] = None, - extra_obs: List[str] = None, - drop_na: bool = False, - datetime_cols: tuple = ("date@hdr", "time@hdr"), - varno_path: str = None, -) -> pd.DataFrame: - """Process ODB data: convert to pandas, pivot, rename columns.""" - - try: - df = reader.to_pandas() - except (ValueError, Exception) as e: - logging.error(f"ODB conversion failed: {e}") - return pd.DataFrame() - - if df.empty: - return df - - # Remove duplicates and pivot - df = df.drop_duplicates(subset=index + pivot, keep="first") - df = df.pivot(index=index, columns=pivot, values=values) - - # Sort and reset - if sort and all(c in df.index.names for c in sort): - df = df.sort_values(by=sort, kind="stable") - df = df.reset_index() - - # Reorganize columns - meta = df[index] - obs = df.drop(columns=index, level=0).sort_index(axis=1) - df = pd.concat([meta, obs], axis=1) - - if drop_na: - df = df.dropna() - - # Create datetime if both columns exist - date_col, time_col = datetime_cols - if date_col in df.columns and time_col in df.columns: - try: - df["times"] = pd.to_datetime( - df[date_col].astype(int).astype(str) + df[time_col].astype(int).astype(str).str.zfill(6), - format="%Y%m%d%H%M%S", - ) - df = df.drop(columns=[date_col, time_col], level=0) - except (ValueError, Exception): - logging.warning("Could not create datetime column") - 
- # Rename columns - df.columns = rename_cols(df.columns.tolist(), extra_obs, varno_path) - - # Rename lat/lon columns to match expected format - df = df.rename(columns={"lat": "latitudes", "lon": "longitudes"}).sort_values(by="times") - - return df - - -# Example usage: -# df = process_odb(reader, ["seqno@hdr", "lat@hdr", "lon@hdr"], ["varno@body"], ["obsvalue@body"]) From 0cfabd5d7d4eb03db9ad91491e41a91f588caf0e Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:23:48 +0100 Subject: [PATCH 201/212] update --- .../datasets/create/gridded/__init__.py | 8 + .../datasets/create/sources/accumulations.py | 171 +++++------ .../datasets/create/sources/accumulations2.py | 62 ++-- .../datasets/create/sources/anemoi_dataset.py | 88 +++--- .../datasets/create/sources/constants.py | 77 ++--- .../datasets/create/sources/eccc_fstd.py | 4 +- src/anemoi/datasets/create/sources/empty.py | 48 +-- src/anemoi/datasets/create/sources/fdb.py | 5 +- .../datasets/create/sources/forcings.py | 57 ++-- src/anemoi/datasets/create/sources/generic.py | 3 +- src/anemoi/datasets/create/sources/grib.py | 171 ++++++----- .../datasets/create/sources/grib_index.py | 88 +++--- .../datasets/create/sources/hindcasts.py | 112 +++---- src/anemoi/datasets/create/sources/legacy.py | 75 +---- src/anemoi/datasets/create/sources/mars.py | 239 +++++++-------- src/anemoi/datasets/create/sources/netcdf.py | 58 ++-- src/anemoi/datasets/create/sources/opendap.py | 58 ++-- .../create/sources/planetary_computer.py | 4 +- .../datasets/create/sources/recentre.py | 86 +++--- .../datasets/create/sources/repeated_dates.py | 285 +----------------- src/anemoi/datasets/create/sources/source.py | 68 ----- src/anemoi/datasets/create/sources/xarray.py | 9 +- .../create/sources/xarray_kerchunk.py | 4 +- .../datasets/create/sources/xarray_zarr.py | 58 ++-- src/anemoi/datasets/create/sources/zenodo.py | 86 +++--- 25 files changed, 761 insertions(+), 1163 deletions(-) delete mode 100644 src/anemoi/datasets/create/sources/source.py diff --git a/src/anemoi/datasets/create/gridded/__init__.py b/src/anemoi/datasets/create/gridded/__init__.py index e69de29bb..c167afa25 100644 --- a/src/anemoi/datasets/create/gridded/__init__.py +++ b/src/anemoi/datasets/create/gridded/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
diff --git a/src/anemoi/datasets/create/sources/accumulations.py b/src/anemoi/datasets/create/sources/accumulations.py index 40b8749f6..ce4ff6266 100644 --- a/src/anemoi/datasets/create/sources/accumulations.py +++ b/src/anemoi/datasets/create/sources/accumulations.py @@ -20,11 +20,13 @@ from earthkit.data.readers.grib.output import new_grib_output from numpy.typing import NDArray -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.mars import mars -from anemoi.datasets.create.utils import to_datetime_list +from anemoi.datasets.create.sources import source_registry + +from .legacy import LegacySource +from .mars import mars LOG = logging.getLogger(__name__) +MISSING_VALUE = 1e-38 def _member(field: Any) -> int: @@ -167,6 +169,7 @@ def write(self, template: Any) -> None: # are used to store the end step edition = template.metadata("edition") + assert np.all(self.values != MISSING_VALUE) if edition == 1 and self.endStep > 254: self.out.write( @@ -175,6 +178,7 @@ def write(self, template: Any) -> None: stepType="instant", step=self.endStep, check_nans=True, + missing_value=MISSING_VALUE, ) else: self.out.write( @@ -184,6 +188,7 @@ def write(self, template: Any) -> None: startStep=self.startStep, endStep=self.endStep, check_nans=True, + missing_value=MISSING_VALUE, ) self.values = None self.done = True @@ -204,9 +209,6 @@ def add(self, field: Any, values: NDArray[Any]) -> None: if step not in self.steps: return - if not np.all(values >= 0): - warnings.warn(f"Negative values for {field}: {np.nanmin(values)} {np.nanmax(values)}") - assert not self.done, (self.key, step) assert step not in self.seen, (self.key, step) @@ -965,97 +967,76 @@ def _scda(request: dict[str, Any]) -> dict[str, Any]: return request -@legacy_source(__file__) -def accumulations( - context: Any, dates: list[datetime.datetime], use_cdsapi_dataset: str | None = None, **request: Any -) -> Any: - """Computes accumulations based on the provided context, dates, and request parameters. +@source_registry.register("accumulations") +class AccumulationsSource(LegacySource): - Parameters - ---------- - context : Any - Context for the computation. - dates : List[datetime.datetime] - List of dates. - use_cdsapi_dataset : Optional[str], optional - CDSAPI dataset to use. Defaults to None. - **request : Any - Additional request parameters. + @staticmethod + def _execute( + context: Any, dates: list[datetime.datetime], use_cdsapi_dataset: str | None = None, **request: Any + ) -> Any: + """Computes accumulations based on the provided context, dates, and request parameters. - Returns - ------- - Any - The computed accumulations. - """ - - if ( - request.get("class") == "ea" - and request.get("stream", "oper") == "oper" - and request.get("accumulation_period") == 24 - ): - from anemoi.datasets.create.sources.accumulations2 import accumulations as accumulations2 - - LOG.warning( - "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" - ) - return accumulations2(context, dates, **request) - - _to_list(request["param"]) - class_ = request.get("class", "od") - stream = request.get("stream", "oper") - - user_accumulation_period = request.pop("accumulation_period", 6) - accumulations_reset_frequency = request.pop("accumulations_reset_frequency", None) - user_date = request.pop("date", None) - - # If `data_accumulation_period` is not set, this means that the accumulations are from the start - # of the forecast. 
- - KWARGS = { - ("od", "oper"): dict(patch=_scda), - ("od", "elda"): dict(base_times=(6, 18)), - ("od", "enfo"): dict(base_times=(0, 6, 12, 18)), - ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), - ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), - ("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)), - ("l5", "oper"): dict(data_accumulation_period=1, base_times=(0,)), - } - - kwargs = KWARGS.get((class_, stream), {}) - - context.trace("🌧️", f"accumulations {request} {user_accumulation_period} {kwargs}") - - return _compute_accumulations( - context, - dates, - request, - user_accumulation_period=user_accumulation_period, - accumulations_reset_frequency=accumulations_reset_frequency, - use_cdsapi_dataset=use_cdsapi_dataset, - user_date=user_date, - **kwargs, - ) - - -execute = accumulations - -if __name__ == "__main__": - import yaml + Parameters + ---------- + context : Any + Context for the computation. + dates : List[datetime.datetime] + List of dates. + use_cdsapi_dataset : Optional[str], optional + CDSAPI dataset to use. Defaults to None. + **request : Any + Additional request parameters. - config = yaml.safe_load( + Returns + ------- + Any + The computed accumulations. """ - class: ea - expver: '0001' - grid: 20./20. - levtype: sfc -# number: [0, 1] -# stream: enda - param: [cp, tp] -# accumulation_period: 6h - """ - ) - dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) - for f in accumulations(None, dates, **config): - print(f, f.to_numpy().mean()) + if ( + request.get("class") == "ea" + and request.get("stream", "oper") == "oper" + and request.get("accumulation_period") == 24 + ): + from .accumulations2 import Accumulations2Source + + LOG.warning( + "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24" + ) + return Accumulations2Source._execute(context, dates, **request) + + _to_list(request["param"]) + class_ = request.get("class", "od") + stream = request.get("stream", "oper") + + user_accumulation_period = request.pop("accumulation_period", 6) + accumulations_reset_frequency = request.pop("accumulations_reset_frequency", None) + user_date = request.pop("date", None) + + # If `data_accumulation_period` is not set, this means that the accumulations are from the start + # of the forecast. 
+ + KWARGS = { + ("od", "oper"): dict(patch=_scda), + ("od", "elda"): dict(base_times=(6, 18)), + ("od", "enfo"): dict(base_times=(0, 6, 12, 18)), + ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), + ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), + ("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)), + ("l5", "oper"): dict(data_accumulation_period=1, base_times=(0,)), + } + + kwargs = KWARGS.get((class_, stream), {}) + + context.trace("🌧️", f"accumulations {request} {user_accumulation_period} {kwargs}") + + return _compute_accumulations( + context, + dates, + request, + user_accumulation_period=user_accumulation_period, + accumulations_reset_frequency=accumulations_reset_frequency, + use_cdsapi_dataset=use_cdsapi_dataset, + user_date=user_date, + **kwargs, + ) diff --git a/src/anemoi/datasets/create/sources/accumulations2.py b/src/anemoi/datasets/create/sources/accumulations2.py index 3c34d392e..c6bf98843 100644 --- a/src/anemoi/datasets/create/sources/accumulations2.py +++ b/src/anemoi/datasets/create/sources/accumulations2.py @@ -18,9 +18,10 @@ from earthkit.data.core.temporary import temp_file from earthkit.data.readers.grib.output import new_grib_output -from anemoi.datasets.create.sources.legacy import legacy_source +from anemoi.datasets.create.sources import source_registry from anemoi.datasets.create.sources.mars import mars -from anemoi.datasets.create.utils import to_datetime_list + +from .legacy import LegacySource LOG = logging.getLogger(__name__) @@ -598,49 +599,20 @@ def _scda(request: dict[str, Any]) -> dict[str, Any]: return request -@legacy_source(__file__) -def accumulations(context, dates, **request): - _to_list(request["param"]) - user_accumulation_period = request.pop("accumulation_period", 6) - user_accumulation_period = datetime.timedelta(hours=user_accumulation_period) - - context.trace("🌧️", f"accumulations {request} {user_accumulation_period}") - - return _compute_accumulations( - context, - dates, - request, - user_accumulation_period=user_accumulation_period, - ) - +@source_registry.register("accumulations2") +class Accumulations2Source(LegacySource): -execute = accumulations + @staticmethod + def _execute(context, dates, **request): + _to_list(request["param"]) + user_accumulation_period = request.pop("accumulation_period", 6) + user_accumulation_period = datetime.timedelta(hours=user_accumulation_period) -if __name__ == "__main__": - import yaml + context.trace("🌧️", f"accumulations {request} {user_accumulation_period}") - config = yaml.safe_load( - """ - class: ea - expver: '0001' - grid: 20./20. 
- levtype: sfc -# number: [0, 1] -# stream: enda - param: [cp, tp] -# accumulation_period: 6h - accumulation_period: 2 - """ - ) - dates = yaml.safe_load("[2022-12-31 00:00, 2022-12-31 06:00]") - # dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) - - class Context: - use_grib_paramid = True - - def trace(self, *args): - print(*args) - - for f in accumulations(Context, dates, **config): - print(f, f.to_numpy().mean()) + return _compute_accumulations( + context, + dates, + request, + user_accumulation_period=user_accumulation_period, + ) diff --git a/src/anemoi/datasets/create/sources/anemoi_dataset.py b/src/anemoi/datasets/create/sources/anemoi_dataset.py index a05e7df51..743605bb9 100644 --- a/src/anemoi/datasets/create/sources/anemoi_dataset.py +++ b/src/anemoi/datasets/create/sources/anemoi_dataset.py @@ -9,65 +9,69 @@ import numpy as np -from anemoi.datasets.create.sources.legacy import legacy_source +from . import source_registry +from .legacy import LegacySource -@legacy_source(__file__) -def execute(context, dates, params=None, **kwargs): - import earthkit.data as ekd +@source_registry.register("anemoi_dataset") +class AnemoiDatasetSource(LegacySource): - from anemoi.datasets import open_dataset + @staticmethod + def _execute(context, dates, params=None, **kwargs): + import earthkit.data as ekd - ds = open_dataset(**kwargs) - # dates_to_index = {date: i for i, date in enumerate(ds.dates)} + from anemoi.datasets import open_dataset - indices = [] - for date in dates: - idx = np.where(ds.dates == date)[0] - if len(idx) == 0: - continue - indices.append((int(idx[0]), date)) + ds = open_dataset(**kwargs) + # dates_to_index = {date: i for i, date in enumerate(ds.dates)} - vars = ds.variables - if params is None: - params = vars + indices = [] + for date in dates: + idx = np.where(ds.dates == date)[0] + if len(idx) == 0: + continue + indices.append((int(idx[0]), date)) - if not isinstance(params, (list, tuple, set)): - params = [params] + vars = ds.variables + if params is None: + params = vars - params = set(params) - results = [] + if not isinstance(params, (list, tuple, set)): + params = [params] - ensemble = ds.shape[2] > 1 - latitudes = ds.latitudes - longitudes = ds.longitudes + params = set(params) + results = [] - for idx, date in indices: + ensemble = ds.shape[2] > 1 + latitudes = ds.latitudes + longitudes = ds.longitudes - metadata = dict(valid_datetime=date, latitudes=latitudes, longitudes=longitudes) + for idx, date in indices: - for j, y in enumerate(ds[idx]): + metadata = dict(valid_datetime=date, latitudes=latitudes, longitudes=longitudes) - param = vars[j] - if param not in params: - continue + for j, y in enumerate(ds[idx]): + + param = vars[j] + if param not in params: + continue - # metadata['name'] = param - # metadata['param_level'] = param - metadata["param"] = param + # metadata['name'] = param + # metadata['param_level'] = param + metadata["param"] = param - for k, e in enumerate(y): - if ensemble: - metadata["number"] = k + 1 + for k, e in enumerate(y): + if ensemble: + metadata["number"] = k + 1 - metadata["values"] = e + metadata["values"] = e - results.append(metadata.copy()) + results.append(metadata.copy()) - print(results[0].keys()) + print(results[0].keys()) - # "list-of-dicts" does support resolution - results = ekd.from_source("list-of-dicts", results) + # "list-of-dicts" does support resolution + results = ekd.from_source("list-of-dicts", results) - # return 
new_fieldlist_from_list([new_field_from_latitudes_longitudes(x, latitudes, longitudes) for x in results]) - return results + # return new_fieldlist_from_list([new_field_from_latitudes_longitudes(x, latitudes, longitudes) for x in results]) + return results diff --git a/src/anemoi/datasets/create/sources/constants.py b/src/anemoi/datasets/create/sources/constants.py index accde7936..a805c4b16 100644 --- a/src/anemoi/datasets/create/sources/constants.py +++ b/src/anemoi/datasets/create/sources/constants.py @@ -11,41 +11,42 @@ from earthkit.data import from_source -from anemoi.datasets.create.sources.legacy import legacy_source - - -@legacy_source(__file__) -def constants(context: Any, dates: list[str], template: dict[str, Any], param: str) -> Any: - """Deprecated function to retrieve constants data. - - Parameters - ---------- - context : Any - The context object for tracing. - dates : list of str - List of dates for which data is required. - template : dict of str to Any - Template dictionary for the data source. - param : str - Parameter to retrieve. - - Returns - ------- - Any - Data retrieved from the source. - """ - from warnings import warn - - warn( - "The source `constants` is deprecated, use `forcings` instead.", - DeprecationWarning, - stacklevel=2, - ) - context.trace("✅", f"from_source(constants, {template}, {param}") - if len(template) == 0: - raise ValueError("Forcings template is empty.") - - return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) - - -execute: Any = constants +from . import source_registry +from .legacy import LegacySource + + +@source_registry.register("constants") +class ConstantsSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], template: dict[str, Any], param: str) -> Any: + """Deprecated function to retrieve constants data. + + Parameters + ---------- + context : Any + The context object for tracing. + dates : list of str + List of dates for which data is required. + template : dict of str to Any + Template dictionary for the data source. + param : str + Parameter to retrieve. + + Returns + ------- + Any + Data retrieved from the source. + """ + from warnings import warn + + warn( + "The source `constants` is deprecated, use `forcings` instead.", + DeprecationWarning, + stacklevel=2, + ) + context.trace("✅", f"from_source(constants, {template}, {param}") + if len(template) == 0: + raise ValueError("Forcings template is empty.") + + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) diff --git a/src/anemoi/datasets/create/sources/eccc_fstd.py b/src/anemoi/datasets/create/sources/eccc_fstd.py index fdd79af8d..41734e9b6 100644 --- a/src/anemoi/datasets/create/sources/eccc_fstd.py +++ b/src/anemoi/datasets/create/sources/eccc_fstd.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from . 
import source_registry +from .xarray import XarraySourceBase @source_registry.register("eccc_fstd") diff --git a/src/anemoi/datasets/create/sources/empty.py b/src/anemoi/datasets/create/sources/empty.py index f948810f5..fa8bc8d84 100644 --- a/src/anemoi/datasets/create/sources/empty.py +++ b/src/anemoi/datasets/create/sources/empty.py @@ -12,25 +12,29 @@ import earthkit.data as ekd -from anemoi.datasets.create.sources.legacy import legacy_source - - -@legacy_source(__file__) -def execute(context: Any, dates: list[str], **kwargs: Any) -> ekd.FieldList: - """Executes the loading of an empty data source. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - Loaded empty data source. - """ - return ekd.from_source("empty") +from . import source_registry +from .legacy import LegacySource + + +@source_registry.register("empty") +class EmptySource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], **kwargs: Any) -> ekd.FieldList: + """Executes the loading of an empty data source. + + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + Loaded empty data source. + """ + return ekd.from_source("empty") diff --git a/src/anemoi/datasets/create/sources/fdb.py b/src/anemoi/datasets/create/sources/fdb.py index 81cdb7e13..bb33f7d50 100644 --- a/src/anemoi/datasets/create/sources/fdb.py +++ b/src/anemoi/datasets/create/sources/fdb.py @@ -16,10 +16,11 @@ from anemoi.transform.flavour import RuleBasedFlavour from anemoi.transform.grids import grid_registry -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources import source_registry from anemoi.datasets.create.typing import DateList +from ..source import Source +from . import source_registry + @source_registry.register("fdb") class FdbSource(Source): diff --git a/src/anemoi/datasets/create/sources/forcings.py b/src/anemoi/datasets/create/sources/forcings.py index 88eca92e4..6070772fc 100644 --- a/src/anemoi/datasets/create/sources/forcings.py +++ b/src/anemoi/datasets/create/sources/forcings.py @@ -11,31 +11,32 @@ from earthkit.data import from_source -from anemoi.datasets.create.sources.legacy import legacy_source - - -@legacy_source(__file__) -def forcings(context: Any, dates: list[str], template: str, param: str) -> Any: - """Loads forcing data from a specified source. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - template : FieldList - Template for the data source. - param : str - Parameter for the data source. - - Returns - ------- - object - Loaded forcing data. - """ - context.trace("✅", f"from_source(forcings, {template}, {param}") - return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) - - -execute = forcings +from . import source_registry +from .legacy import LegacySource + + +@source_registry.register("forcings") +class ForcingsSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], template: str, param: str) -> Any: + """Loads forcing data from a specified source. 
+ + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + template : FieldList + Template for the data source. + param : str + Parameter for the data source. + + Returns + ------- + object + Loaded forcing data. + """ + context.trace("✅", f"from_source(forcings, {template}, {param}") + return from_source("forcings", source_or_dataset=template, date=list(dates), param=param) diff --git a/src/anemoi/datasets/create/sources/generic.py b/src/anemoi/datasets/create/sources/generic.py index a6675449a..0c6b23853 100644 --- a/src/anemoi/datasets/create/sources/generic.py +++ b/src/anemoi/datasets/create/sources/generic.py @@ -12,7 +12,8 @@ from earthkit.data import from_source -from . import source_registry +from anemoi.datasets.create.sources import source_registry + from .legacy import LegacySource diff --git a/src/anemoi/datasets/create/sources/grib.py b/src/anemoi/datasets/create/sources/grib.py index 550709f98..d709efc5e 100644 --- a/src/anemoi/datasets/create/sources/grib.py +++ b/src/anemoi/datasets/create/sources/grib.py @@ -20,7 +20,8 @@ from earthkit.data import from_source from earthkit.data.utils.patterns import Pattern -from anemoi.datasets.create.sources.legacy import legacy_source +from . import source_registry +from .legacy import LegacySource LOG = logging.getLogger(__name__) @@ -47,6 +48,14 @@ def check(ds: Any, paths: list[str], **kwargs: Any) -> None: if isinstance(v, (tuple, list)): count *= len(v) + # in the case of static data (e.g repeated dates) dates might be empty + if len(ds) != count and kwargs.get("dates", []) == []: + LOG.warning( + f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, paths={paths})" + f" Received empty dates - assuming this is static data." + ) + return + if len(ds) != count: raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, paths={paths})") @@ -73,81 +82,85 @@ def _expand(paths: list[str]) -> Any: yield path -@legacy_source(__file__) -def execute( - context: Any, - dates: list[Any], - path: str | list[str], - flavour: str | dict[str, Any] | None = None, - grid_definition: dict[str, Any] | None = None, - *args: Any, - **kwargs: Any, -) -> ekd.FieldList: - """Executes the function to load data from GRIB files. - - Parameters - ---------- - context : Any - The context in which the function is executed. - dates : list of Any - List of dates. - path : str or list of str - Path or list of paths to the GRIB files. - flavour : str or dict of str to Any, optional - Flavour information, by default None. - grid_definition : dict of str to Any, optional - Grid definition configuration to create a Grid object, by default None. - *args : Any - Additional positional arguments. - **kwargs : Any - Additional keyword arguments. - - Returns - ------- - Any - The loaded dataset. 
- """ - given_paths = path if isinstance(path, list) else [path] - if flavour is not None: - flavour = RuleBasedFlavour(flavour) - - if grid_definition is not None: - grid = grid_registry.from_config(grid_definition) - else: - grid = None - - ds = from_source("empty") - dates = [d.isoformat() for d in dates] - - for path in given_paths: - paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs) - - for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): - if name in kwargs: - raise ValueError(f"MARS interpolation parameter '{name}' not supported") - - for path in _expand(paths): - context.trace("📁", "PATH", path) - s = from_source("file", path) - if flavour is not None: - s = flavour.map(s) - s = s.sel(valid_datetime=dates, **kwargs) - ds = ds + s - - if kwargs and not context.partial_ok: - check(ds, given_paths, valid_datetime=dates, **kwargs) - - if grid is not None: - - lat, lon = grid.latlon() - - assert len(lat) == len(lon), (len(lat), len(lon)) - for f in ds: - assert len(f.to_numpy(flatten=True)) == len(lat), (len(f.to_numpy(flatten=True)), len(lat)) - - ds = new_fieldlist_from_list([new_field_from_grid(f, grid) for f in ds]) - - if len(ds) == 0: - LOG.warning(f"No fields found for {dates} in {given_paths} (kwargs={kwargs})") - - return ds +@source_registry.register("grib") +class GribSource(LegacySource): + + @staticmethod + def _execute( + context: Any, + dates: list[Any], + path: str | list[str], + flavour: str | dict[str, Any] | None = None, + grid_definition: dict[str, Any] | None = None, + *args: Any, + **kwargs: Any, + ) -> ekd.FieldList: + """Executes the function to load data from GRIB files. + + Parameters + ---------- + context : Any + The context in which the function is executed. + dates : list of Any + List of dates. + path : str or list of str + Path or list of paths to the GRIB files. + flavour : str or dict of str to Any, optional + Flavour information, by default None. + grid_definition : dict of str to Any, optional + Grid definition configuration to create a Grid object, by default None. + *args : Any + Additional positional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + Any + The loaded dataset. 
+ """ + given_paths = path if isinstance(path, list) else [path] + if flavour is not None: + flavour = RuleBasedFlavour(flavour) + + if grid_definition is not None: + grid = grid_registry.from_config(grid_definition) + else: + grid = None + + ds = from_source("empty") + dates = [d.isoformat() for d in dates] + + for path in given_paths: + + # do not substitute if not needed + if "{" not in path: + paths = [path] + else: + paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs) + + for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"): + if name in kwargs: + raise ValueError(f"MARS interpolation parameter '{name}' not supported") + + for path in _expand(paths): + context.trace("📁", "PATH", path) + s = from_source("file", path) + if flavour is not None: + s = flavour.map(s) + sel_kwargs = kwargs.copy() + if dates != []: + sel_kwargs["valid_datetime"] = dates + s = s.sel(**sel_kwargs) + ds = ds + s + + if kwargs and not context.partial_ok: + check(ds, given_paths, valid_datetime=dates, **kwargs) + + if grid is not None: + ds = new_fieldlist_from_list([new_field_from_grid(f, grid) for f in ds]) + + if len(ds) == 0: + LOG.warning(f"No fields found for {dates} in {given_paths} (kwargs={kwargs})") + + return ds diff --git a/src/anemoi/datasets/create/sources/grib_index.py b/src/anemoi/datasets/create/sources/grib_index.py index 160ff3f3a..0d86732f6 100644 --- a/src/anemoi/datasets/create/sources/grib_index.py +++ b/src/anemoi/datasets/create/sources/grib_index.py @@ -19,7 +19,8 @@ from cachetools import LRUCache from earthkit.data.indexing.fieldlist import FieldArray -from anemoi.datasets.create.sources.legacy import legacy_source +from . import source_registry +from .legacy import LegacySource LOG = logging.getLogger(__name__) @@ -569,44 +570,47 @@ def retrieve(self, dates: list[Any], **kwargs: Any) -> Iterator[Any]: yield data -@legacy_source(__file__) -def execute( - context: Any, - dates: list[Any], - indexdb: str, - flavour: str | None = None, - **kwargs: Any, -) -> FieldArray: - """Execute the GRIB data retrieval process. - - Parameters - ---------- - context : Any - The execution context. - dates : List[Any] - List of dates to retrieve data for. - indexdb : str - Path to the GRIB index database. - flavour : Optional[str], optional - Flavour configuration for mapping fields, by default None. - **kwargs : Any - Additional filtering criteria. - - Returns - ------- - FieldArray - An array of retrieved GRIB fields. - """ - index = GribIndex(indexdb) - result = [] - - if flavour is not None: - flavour = RuleBasedFlavour(flavour) - - for grib in index.retrieve(dates, **kwargs): - field = ekd.from_source("memory", grib)[0] - if flavour: - field = flavour.apply(field) - result.append(field) - - return FieldArray(result) +@source_registry.register("grib_index") +class GribIndexSource(LegacySource): + + @staticmethod + def _execute( + context: Any, + dates: list[Any], + indexdb: str, + flavour: str | None = None, + **kwargs: Any, + ) -> FieldArray: + """Execute the GRIB data retrieval process. + + Parameters + ---------- + context : Any + The execution context. + dates : List[Any] + List of dates to retrieve data for. + indexdb : str + Path to the GRIB index database. + flavour : Optional[str], optional + Flavour configuration for mapping fields, by default None. + **kwargs : Any + Additional filtering criteria. + + Returns + ------- + FieldArray + An array of retrieved GRIB fields. 
+ """ + index = GribIndex(indexdb) + result = [] + + if flavour is not None: + flavour = RuleBasedFlavour(flavour) + + for grib in index.retrieve(dates, **kwargs): + field = ekd.from_source("memory", grib)[0] + if flavour: + field = flavour.apply(field) + result.append(field) + + return FieldArray(result) diff --git a/src/anemoi/datasets/create/sources/hindcasts.py b/src/anemoi/datasets/create/sources/hindcasts.py index d796a74af..ad1df38a5 100644 --- a/src/anemoi/datasets/create/sources/hindcasts.py +++ b/src/anemoi/datasets/create/sources/hindcasts.py @@ -12,9 +12,11 @@ from earthkit.data.core.fieldlist import MultiFieldList -from anemoi.datasets.create.sources.legacy import legacy_source from anemoi.datasets.create.sources.mars import mars +from . import source_registry +from .legacy import LegacySource + LOGGER = logging.getLogger(__name__) @@ -36,57 +38,57 @@ def _to_list(x: list | tuple | Any) -> list[Any]: return [x] -@legacy_source(__file__) -def hindcasts(context: Any, dates: list[Any], **request: dict[str, Any]) -> MultiFieldList: - """Generates hindcast requests based on the provided dates and request parameters. - - Parameters - ---------- - context : Any - The context containing the dates provider and trace method. - dates : List[Any] - A list of dates for which to generate hindcast requests. - request : Dict[str, Any] - Additional request parameters. - - Returns - ------- - MultiFieldList - A MultiFieldList containing the hindcast data. - """ - from anemoi.datasets.dates import HindcastsDates - - provider = context.dates_provider - assert isinstance(provider, HindcastsDates) - - context.trace("H️", f"hindcasts {len(dates)=}") - - request["param"] = _to_list(request["param"]) - request["step"] = _to_list(request.get("step", 0)) - request["step"] = [int(_) for _ in request["step"]] - - context.trace("H️", f"hindcast {request}") - - requests = [] - for d in dates: - r = request.copy() - hindcast = provider.mapping[d] - r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d") - r["date"] = hindcast.refdate.strftime("%Y-%m-%d") - r["time"] = hindcast.refdate.strftime("%H") - r["step"] = hindcast.step - requests.append(r) - - if len(requests) == 0: - return MultiFieldList([]) - - return mars( - context, - dates, - *requests, - date_key="hdate", - request_already_using_valid_datetime=True, - ) - - -execute = hindcasts +@source_registry.register("hindcasts") +class HindcastsSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[Any], **request: dict[str, Any]) -> MultiFieldList: + """Generates hindcast requests based on the provided dates and request parameters. + + Parameters + ---------- + context : Any + The context containing the dates provider and trace method. + dates : List[Any] + A list of dates for which to generate hindcast requests. + request : Dict[str, Any] + Additional request parameters. + + Returns + ------- + MultiFieldList + A MultiFieldList containing the hindcast data. 
+ """ + from anemoi.datasets.dates import HindcastsDates + + provider = context.dates_provider + assert isinstance(provider, HindcastsDates) + + context.trace("H️", f"hindcasts {len(dates)=}") + + request["param"] = _to_list(request["param"]) + request["step"] = _to_list(request.get("step", 0)) + request["step"] = [int(_) for _ in request["step"]] + + context.trace("H️", f"hindcast {request}") + + requests = [] + for d in dates: + r = request.copy() + hindcast = provider.mapping[d] + r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d") + r["date"] = hindcast.refdate.strftime("%Y-%m-%d") + r["time"] = hindcast.refdate.strftime("%H") + r["step"] = hindcast.step + requests.append(r) + + if len(requests) == 0: + return MultiFieldList([]) + + return mars( + context, + dates, + *requests, + date_key="hdate", + request_already_using_valid_datetime=True, + ) diff --git a/src/anemoi/datasets/create/sources/legacy.py b/src/anemoi/datasets/create/sources/legacy.py index 0de230d29..f9a0288a0 100644 --- a/src/anemoi/datasets/create/sources/legacy.py +++ b/src/anemoi/datasets/create/sources/legacy.py @@ -8,14 +8,13 @@ # nor does it submit to any jurisdiction. -import inspect import logging -import os -from collections.abc import Callable +from abc import abstractmethod from typing import Any -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources import source_registry +from anemoi.datasets.create.input.context import Context + +from ..source import Source LOG = logging.getLogger(__name__) @@ -25,7 +24,7 @@ class LegacySource(Source): Parameters ---------- - context : Any + context : Context The context in which the source is created. *args : tuple Positional arguments. @@ -33,65 +32,15 @@ class LegacySource(Source): Keyword arguments. """ - def __init__(self, context: Any, *args: Any, **kwargs: Any) -> None: + def __init__(self, context: Context, *args: Any, **kwargs: Any) -> None: super().__init__(context, *args, **kwargs) self.args = args self.kwargs = kwargs + @staticmethod + @abstractmethod + def _execute(context, *args, **kwargs): + pass -class legacy_source: - """A decorator class for legacy sources. - - Parameters - ---------- - name : str - The name of the legacy source. - """ - - def __init__(self, name: str) -> None: - name, _ = os.path.splitext(os.path.basename(name)) - self.name = name - - def __call__(self, execute: Callable) -> Callable: - """Call method to wrap the execute function. - - Parameters - ---------- - execute : function - The execute function to be wrapped. - - Returns - ------- - function - The wrapped execute function. 
- """ - this = self - name = f"Legacy{self.name.title()}Source" - source = ".".join([execute.__module__, execute.__name__]) - - def execute_wrapper(self, dates) -> Any: - """Wrapper method to call the execute function.""" - - # args, kwargs = resolve(context, (self.args, self.kwargs)) - args, kwargs = self.args, self.kwargs - - try: - return execute(self.context, dates, *args, **kwargs) - except TypeError: - LOG.error(f"Error executing source {this.name} from {source}") - LOG.error(f"Function signature is: {inspect.signature(execute)}") - LOG.error(f"Arguments are: {args=}, {kwargs=}") - raise - - klass = type( - name, - (LegacySource,), - { - "execute": execute_wrapper, - "_source": source, - }, - ) - - source_registry.register(self.name)(klass) - - return execute + def execute(self, dates: Any) -> Any: + return self._execute(self.context, dates, *self.args, **self.kwargs) diff --git a/src/anemoi/datasets/create/sources/mars.py b/src/anemoi/datasets/create/sources/mars.py index d59f6034d..25e223cb4 100644 --- a/src/anemoi/datasets/create/sources/mars.py +++ b/src/anemoi/datasets/create/sources/mars.py @@ -16,8 +16,9 @@ from earthkit.data import from_source from earthkit.data.utils.availability import Availability -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.utils import to_datetime_list +from anemoi.datasets.create.sources import source_registry + +from .legacy import LegacySource DEBUG = False @@ -357,135 +358,111 @@ def use_grib_paramid(r: dict[str, Any]) -> dict[str, Any]: ] -@legacy_source(__file__) -def mars( - context: Any, - dates: list[datetime.datetime], - *requests: dict[str, Any], - request_already_using_valid_datetime: bool = False, - date_key: str = "date", - use_cdsapi_dataset: str | None = None, - **kwargs: Any, -) -> Any: - """Executes MARS requests based on the given context, dates, and other parameters. - - Parameters - ---------- - context : Any - The context for the requests. - dates : List[datetime.datetime] - The list of dates to be used in the requests. - requests : Dict[str, Any] - The input requests to be executed. - request_already_using_valid_datetime : bool, optional - Flag indicating if the requests already use valid datetime. - date_key : str, optional - The key for the date in the requests. - use_cdsapi_dataset : Optional[str], optional - The dataset to be used with CDS API. - kwargs : Any - Additional keyword arguments for the requests. - - Returns - ------- - Any - The resulting dataset. - """ - - if not requests: - requests = [kwargs] - - for r in requests: - param = r.get("param", []) - if not isinstance(param, (list, tuple)): - param = [param] - # check for "Norway bug" where yaml transforms 'no' into False, etc. - for p in param: - if p is False: - raise ValueError( - "'param' cannot be 'False'. If you wrote 'param: no' or 'param: off' in yaml, you may want to use quotes?" - ) - if p is None: - raise ValueError( - "'param' cannot be 'None'. If you wrote 'param: no' in yaml, you may want to use quotes?" - ) - if p is True: - raise ValueError( - "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?" 
- ) - - if len(dates) == 0: # When using `repeated_dates` - assert len(requests) == 1, requests - assert "date" in requests[0], requests[0] - if isinstance(requests[0]["date"], datetime.date): - requests[0]["date"] = requests[0]["date"].strftime("%Y%m%d") - else: - requests = factorise_requests( - dates, - *requests, - request_already_using_valid_datetime=request_already_using_valid_datetime, - date_key=date_key, - ) - - requests = list(requests) - - ds = from_source("empty") - context.trace("✅", f"{[str(d) for d in dates]}") - context.trace("✅", f"Will run {len(requests)} requests") - for r in requests: - r = {k: v for k, v in r.items() if v != ("-",)} - context.trace("✅", f"mars {r}") - - for r in requests: - r = {k: v for k, v in r.items() if v != ("-",)} - - if context.use_grib_paramid and "param" in r: - r = use_grib_paramid(r) - - for k, v in r.items(): - if k not in MARS_KEYS: - raise ValueError( - f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?" - ) - try: - if use_cdsapi_dataset: - ds = ds + from_source("cds", use_cdsapi_dataset, r) - else: - ds = ds + from_source("mars", **r) - except Exception as e: - if "File is empty:" not in str(e): - raise - return ds - - -execute = mars - - -if __name__ == "__main__": - import yaml - - config = yaml.safe_load( +@source_registry.register("mars") +class MarsSource(LegacySource): + + @staticmethod + def _execute( + context: Any, + dates: list[datetime.datetime], + *requests: dict[str, Any], + request_already_using_valid_datetime: bool = False, + date_key: str = "date", + use_cdsapi_dataset: str | None = None, + **kwargs: Any, + ) -> Any: + """Executes MARS requests based on the given context, dates, and other parameters. + + Parameters + ---------- + context : Any + The context for the requests. + dates : List[datetime.datetime] + The list of dates to be used in the requests. + requests : Dict[str, Any] + The input requests to be executed. + request_already_using_valid_datetime : bool, optional + Flag indicating if the requests already use valid datetime. + date_key : str, optional + The key for the date in the requests. + use_cdsapi_dataset : Optional[str], optional + The dataset to be used with CDS API. + kwargs : Any + Additional keyword arguments for the requests. + + Returns + ------- + Any + The resulting dataset. """ - - class: ea - expver: '0001' - grid: 20.0/20.0 - levtype: sfc - param: [2t] - # param: [10u, 10v, 2d, 2t, lsm, msl, sdor, skt, slor, sp, tcw, z] - number: [0, 1] - - # - class: ea - # expver: '0001' - # grid: 20.0/20.0 - # levtype: pl - # param: [q] - # levelist: [1000, 850] - """ - ) - dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) + if not requests: + requests = [kwargs] + + for r in requests: + param = r.get("param", []) + if not isinstance(param, (list, tuple)): + param = [param] + # check for "Norway bug" where yaml transforms 'no' into False, etc. + for p in param: + if p is False: + raise ValueError( + "'param' cannot be 'False'. If you wrote 'param: no' or 'param: off' in yaml, you may want to use quotes?" + ) + if p is None: + raise ValueError( + "'param' cannot be 'None'. If you wrote 'param: no' in yaml, you may want to use quotes?" + ) + if p is True: + raise ValueError( + "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?" 
+ ) + + if len(dates) == 0: # When using `repeated_dates` + assert len(requests) == 1, requests + assert "date" in requests[0], requests[0] + if isinstance(requests[0]["date"], datetime.date): + requests[0]["date"] = requests[0]["date"].strftime("%Y%m%d") + else: + requests = factorise_requests( + dates, + *requests, + request_already_using_valid_datetime=request_already_using_valid_datetime, + date_key=date_key, + ) - DEBUG = True - for f in mars(None, dates, *config): - print(f, f.to_numpy().mean()) + requests = list(requests) + + ds = from_source("empty") + context.trace("✅", f"{[str(d) for d in dates]}") + context.trace("✅", f"Will run {len(requests)} requests") + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + context.trace("✅", f"mars {r}") + + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + + if context.use_grib_paramid and "param" in r: + r = use_grib_paramid(r) + + for k, v in r.items(): + if k not in MARS_KEYS: + raise ValueError( + f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?" + ) + try: + if use_cdsapi_dataset: + ds = ds + from_source("cds", use_cdsapi_dataset, r) + else: + ds = ds + from_source("mars", **r) + except Exception as e: + if "File is empty:" not in str(e): + raise + return ds + + +# TODO: make clearer the interface between sources that use mars. +# Currently some sources use mars as a function rather than through the registry, +# e.g. accumulations, accumulations2, hindcasts, recentre, tendencies +mars = MarsSource._execute diff --git a/src/anemoi/datasets/create/sources/netcdf.py b/src/anemoi/datasets/create/sources/netcdf.py index 606a8dd53..e6f4271a7 100644 --- a/src/anemoi/datasets/create/sources/netcdf.py +++ b/src/anemoi/datasets/create/sources/netcdf.py @@ -12,30 +12,34 @@ import earthkit.data as ekd -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.xarray import load_many - - -@legacy_source(__file__) -def execute(context: Any, dates: list[str], path: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the loading of multiple NetCDF files. - - Parameters - ---------- - context : object - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - path : str - Path to the directory containing the NetCDF files. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - object - The loaded data. - """ - return load_many("📁", context, dates, path, *args, **kwargs) +from . import source_registry +from .legacy import LegacySource +from .xarray import load_many + + +@source_registry.register("netcdf") +class NetCDFSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], path: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the loading of multiple NetCDF files. + + Parameters + ---------- + context : object + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + path : str + Path to the directory containing the NetCDF files. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + object + The loaded data. 
+ """ + return load_many("📁", context, dates, path, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/opendap.py b/src/anemoi/datasets/create/sources/opendap.py index 34e3fe94d..86cd3e6d2 100644 --- a/src/anemoi/datasets/create/sources/opendap.py +++ b/src/anemoi/datasets/create/sources/opendap.py @@ -12,30 +12,34 @@ import earthkit.data as ekd -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.xarray import load_many - - -@legacy_source(__file__) -def execute(context: dict[str, Any], dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the data loading process from an OpenDAP source. - - Parameters - ---------- - context : dict - The context in which the function is executed. - dates : list - List of dates for which data is to be loaded. - url : str - The URL of the OpenDAP source. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - xarray.Dataset - The loaded dataset. - """ - return load_many("🌐", context, dates, url, *args, **kwargs) +from . import source_registry +from .legacy import LegacySource +from .xarray import load_many + + +@source_registry.register("opendap") +class OpenDAPSource(LegacySource): + + @staticmethod + def _execute(context: dict[str, Any], dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the data loading process from an OpenDAP source. + + Parameters + ---------- + context : dict + The context in which the function is executed. + dates : list + List of dates for which data is to be loaded. + url : str + The URL of the OpenDAP source. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + xarray.Dataset + The loaded dataset. + """ + return load_many("🌐", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py index 07e8f0203..b710bcbbe 100644 --- a/src/anemoi/datasets/create/sources/planetary_computer.py +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from . import source_registry +from .xarray import XarraySourceBase @source_registry.register("planetary_computer") diff --git a/src/anemoi/datasets/create/sources/recentre.py b/src/anemoi/datasets/create/sources/recentre.py index d0959f664..2d6c70b1d 100644 --- a/src/anemoi/datasets/create/sources/recentre.py +++ b/src/anemoi/datasets/create/sources/recentre.py @@ -11,8 +11,10 @@ from typing import Any from anemoi.datasets.compute.recentre import recentre as _recentre -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.mars import mars + +from . import source_registry +from .legacy import LegacySource +from .mars import mars def to_list(x: list | tuple | str) -> list: @@ -104,43 +106,43 @@ def load_if_needed(context: Any, dates: Any, dict_or_dataset: dict | Any) -> Any return dict_or_dataset -@legacy_source(__file__) -def recentre( - context: Any, - dates: Any, - members: dict | Any, - centre: dict | Any, - alpha: float = 1.0, - remapping: dict = {}, - patches: dict = {}, -) -> Any: - """Recentres the members dataset using the centre dataset. 
- - Parameters - ---------- - context : Any - The context for recentering. - dates : Any - The dates for recentering. - members : Union[dict, Any] - The members dataset or request dictionary. - centre : Union[dict, Any] - The centre dataset or request dictionary. - alpha : float, optional - The alpha value for recentering. Defaults to 1.0. - remapping : dict, optional - The remapping dictionary. Defaults to {}. - patches : dict, optional - The patches dictionary. Defaults to {}. - - Returns - ------- - Any - The recentred dataset. - """ - members = load_if_needed(context, dates, members) - centre = load_if_needed(context, dates, centre) - return _recentre(members=members, centre=centre, alpha=alpha) - - -execute = recentre +@source_registry.register("recentre") +class RecentreSource(LegacySource): + + @staticmethod + def _execute( + context: Any, + dates: Any, + members: dict | Any, + centre: dict | Any, + alpha: float = 1.0, + remapping: dict = {}, + patches: dict = {}, + ) -> Any: + """Recentres the members dataset using the centre dataset. + + Parameters + ---------- + context : Any + The context for recentering. + dates : Any + The dates for recentering. + members : Union[dict, Any] + The members dataset or request dictionary. + centre : Union[dict, Any] + The centre dataset or request dictionary. + alpha : float, optional + The alpha value for recentering. Defaults to 1.0. + remapping : dict, optional + The remapping dictionary. Defaults to {}. + patches : dict, optional + The patches dictionary. Defaults to {}. + + Returns + ------- + Any + The recentred dataset. + """ + members = load_if_needed(context, dates, members) + centre = load_if_needed(context, dates, centre) + return _recentre(members=members, centre=centre, alpha=alpha) diff --git a/src/anemoi/datasets/create/sources/repeated_dates.py b/src/anemoi/datasets/create/sources/repeated_dates.py index 77a06c76c..f1f86eb78 100644 --- a/src/anemoi/datasets/create/sources/repeated_dates.py +++ b/src/anemoi/datasets/create/sources/repeated_dates.py @@ -9,305 +9,34 @@ import logging -from collections import defaultdict -from collections.abc import Generator from typing import Any -import numpy as np from anemoi.transform.fields import new_field_with_valid_datetime from anemoi.transform.fields import new_fieldlist_from_list -from anemoi.utils.dates import as_datetime -from anemoi.utils.dates import frequency_to_timedelta +from anemoi.datasets.create.input.repeated_dates import DateMapper from anemoi.datasets.create.source import Source from anemoi.datasets.create.sources import source_registry -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - LOG = logging.getLogger(__name__) -class Action: - pass - - -class Result: - pass - - -class DateMapper: - """A factory class to create DateMapper instances based on the given mode.""" - - @staticmethod - def from_mode(mode: str, source: Any, config: dict[str, Any]) -> "DateMapper": - """Create a DateMapper instance based on the given mode. - - Parameters - ---------- - mode : str - The mode to use for the DateMapper. - source : Any - The data source. - config : dict - Configuration parameters. 
- - Returns - ------- - DateMapper - An instance of DateMapper. - """ - MODES: dict = dict( - closest=DateMapperClosest, - climatology=DateMapperClimatology, - constant=DateMapperConstant, - ) - - if mode not in MODES: - raise ValueError(f"Invalid mode for DateMapper: {mode}") - - return MODES[mode](source, **config) - - -class DateMapperClosest(DateMapper): - """A DateMapper implementation that maps dates to the closest available dates.""" - - def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", skip_all_nans: bool = False) -> None: - """Initialize DateMapperClosest. - - Parameters - ---------- - source : Any - The data source. - frequency : str - Frequency of the dates. - maximum : str - Maximum time delta. - skip_all_nans : bool - Whether to skip all NaN values. - """ - self.source: Any = source - self.maximum: Any = frequency_to_timedelta(maximum) - self.frequency: Any = frequency_to_timedelta(frequency) - self.skip_all_nans: bool = skip_all_nans - self.tried: set[Any] = set() - self.found: set[Any] = set() - - def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: - """Transform the group of dates to the closest available dates. - - Parameters - ---------- - group_of_dates : Any - The group of dates to transform. - - Returns - ------- - Generator[Tuple[Any, Any], None, None] - Transformed dates. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - asked_dates = list(group_of_dates) - if not asked_dates: - return [] - - to_try = set() - for date in asked_dates: - start = date - while start >= date - self.maximum: - to_try.add(start) - start -= self.frequency - - end = date - while end <= date + self.maximum: - to_try.add(end) - end += self.frequency - - to_try = sorted(to_try - self.tried) - info = {k: "no-data" for k in to_try} - - if not to_try: - LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}") - # return [] - - if to_try: - result = self.source.select( - GroupOfDates( - sorted(to_try), - group_of_dates.provider, - partial_ok=True, - ) - ) - - cnt = 0 - for f in result.datasource: - cnt += 1 - # We could keep the fields in a dictionary, but we don't want to keep the fields in memory - date = as_datetime(f.metadata("valid_datetime")) - - if self.skip_all_nans: - if np.isnan(f.to_numpy()).all(): - LOG.warning(f"Skipping {date} because all values are NaN") - info[date] = "all-nans" - continue - - info[date] = "ok" - self.found.add(date) - - if cnt == 0: - raise ValueError(f"No data found for {group_of_dates} in {self.source}") - - self.tried.update(to_try) - - if not self.found: - for k, v in info.items(): - LOG.warning(f"{k}: {v}") - - raise ValueError(f"No matching data found for {asked_dates} in {self.source}") - - new_dates = defaultdict(list) - - for date in asked_dates: - best = None - for found_date in sorted(self.found): - delta = abs(date - found_date) - # With < we prefer the first date - # With <= we prefer the last date - if best is None or delta <= best[0]: - best = delta, found_date - new_dates[best[1]].append(date) - - for date, dates in new_dates.items(): - yield ( - GroupOfDates([date], group_of_dates.provider), - GroupOfDates(dates, group_of_dates.provider), - ) - - -class DateMapperClimatology(DateMapper): - """A DateMapper implementation that maps dates to specified climatology dates.""" - - def __init__(self, source: Any, year: int, day: int, hour: int | None = None) -> None: - """Initialize DateMapperClimatology. - - Parameters - ---------- - source : Any - The data source. 
- year : int - The year to map to. - day : int - The day to map to. - hour : Optional[int] - The hour to map to. - """ - self.year: int = year - self.day: int = day - self.hour: int | None = hour - - def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: - """Transform the group of dates to the specified climatology dates. - - Parameters - ---------- - group_of_dates : Any - The group of dates to transform. - - Returns - ------- - Generator[Tuple[Any, Any], None, None] - Transformed dates. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - dates = list(group_of_dates) - if not dates: - return [] - - new_dates = defaultdict(list) - for date in dates: - new_date = date.replace(year=self.year, day=self.day) - if self.hour is not None: - new_date = new_date.replace(hour=self.hour, minute=0, second=0) - new_dates[new_date].append(date) - - for date, dates in new_dates.items(): - yield ( - GroupOfDates([date], group_of_dates.provider), - GroupOfDates(dates, group_of_dates.provider), - ) - - -class DateMapperConstant(DateMapper): - """A DateMapper implementation that maps dates to a constant date.""" - - def __init__(self, source: Any, date: Any | None = None) -> None: - """Initialize DateMapperConstant. - - Parameters - ---------- - source : Any - The data source. - date : Optional[Any] - The constant date to map to. - """ - self.source: Any = source - self.date: Any | None = date - - def transform(self, group_of_dates: Any) -> tuple[Any, Any]: - """Transform the group of dates to a constant date. - - Parameters - ---------- - group_of_dates : Any - The group of dates to transform. - - Returns - ------- - Tuple[Any, Any] - Transformed dates. - """ - from anemoi.datasets.dates.groups import GroupOfDates - - if self.date is None: - return [ - ( - GroupOfDates([], group_of_dates.provider), - group_of_dates, - ) - ] - - return [ - ( - GroupOfDates([self.date], group_of_dates.provider), - group_of_dates, - ) - ] - - @source_registry.register("repeated_dates") class RepeatedDatesSource(Source): - def __init__(self, owner, source: Any, mode: str, **kwargs) -> None: - self.owner = owner + def __init__(self, context, source: Any, mode: str, **kwargs) -> None: + # assert False, (context, source, mode, kwargs) + super().__init__(context, **kwargs) self.mapper = DateMapper.from_mode(mode, source, kwargs) self.source = source - def execute(self, context, group_of_dates): - source = context.create_source(self.source, *self.owner.path, "source") + def execute(self, group_of_dates): + source = self.context.create_source(self.source, "data_sources", str(id(self))) result = [] for one_date_group, many_dates_group in self.mapper.transform(group_of_dates): print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}") - source_results = source(context, one_date_group) + source_results = source(self.context, one_date_group) for field in source_results: for date in many_dates_group: result.append(new_field_with_valid_datetime(field, date)) diff --git a/src/anemoi/datasets/create/sources/source.py b/src/anemoi/datasets/create/sources/source.py deleted file mode 100644 index 1bac545d8..000000000 --- a/src/anemoi/datasets/create/sources/source.py +++ /dev/null @@ -1,68 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
-# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -from datetime import datetime -from typing import Any - -from earthkit.data import from_source - -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.utils import to_datetime_list - - -@legacy_source(__file__) -def source(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any: - """Generates a source based on the provided context, dates, and additional keyword arguments. - - Parameters - ---------- - context : Optional[Any] - The context in which the source is generated. - dates : List[datetime] - A list of datetime objects representing the dates. - **kwargs : Any - Additional keyword arguments for the source generation. - - Returns - ------- - Any - The generated source. - """ - name = kwargs.pop("name") - context.trace("✅", f"from_source({name}, {dates}, {kwargs}") - if kwargs["date"] == "$from_dates": - kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates}) - if kwargs["time"] == "$from_dates": - kwargs["time"] = list({d.strftime("%H%M") for d in dates}) - return from_source(name, **kwargs) - - -execute = source - -if __name__ == "__main__": - import yaml - - config: dict[str, Any] = yaml.safe_load( - """ - name: mars - class: ea - expver: '0001' - grid: 20.0/20.0 - levtype: sfc - param: [2t] - number: [0, 1] - date: $from_dates - time: $from_dates - """ - ) - dates: list[str] = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]") - dates = to_datetime_list(dates) - - for f in source(None, dates, **config): - print(f, f.to_numpy().mean()) diff --git a/src/anemoi/datasets/create/sources/xarray.py b/src/anemoi/datasets/create/sources/xarray.py index 5e3cc4c10..d63b708d6 100644 --- a/src/anemoi/datasets/create/sources/xarray.py +++ b/src/anemoi/datasets/create/sources/xarray.py @@ -11,12 +11,13 @@ import earthkit.data as ekd -from anemoi.datasets.create.source import Source -from anemoi.datasets.create.sources.xarray_support import XarrayFieldList -from anemoi.datasets.create.sources.xarray_support import load_many -from anemoi.datasets.create.sources.xarray_support import load_one from anemoi.datasets.create.typing import DateList +from ..source import Source +from .xarray_support import XarrayFieldList +from .xarray_support import load_many +from .xarray_support import load_one + __all__ = ["load_many", "load_one", "XarrayFieldList"] diff --git a/src/anemoi/datasets/create/sources/xarray_kerchunk.py b/src/anemoi/datasets/create/sources/xarray_kerchunk.py index 632a7cae2..056d756ca 100644 --- a/src/anemoi/datasets/create/sources/xarray_kerchunk.py +++ b/src/anemoi/datasets/create/sources/xarray_kerchunk.py @@ -8,8 +8,8 @@ # nor does it submit to any jurisdiction. -from anemoi.datasets.create.sources import source_registry -from anemoi.datasets.create.sources.xarray import XarraySourceBase +from . 
import source_registry +from .xarray import XarraySourceBase @source_registry.register("xarray_kerchunk") diff --git a/src/anemoi/datasets/create/sources/xarray_zarr.py b/src/anemoi/datasets/create/sources/xarray_zarr.py index 2f96ab207..2e89981bd 100644 --- a/src/anemoi/datasets/create/sources/xarray_zarr.py +++ b/src/anemoi/datasets/create/sources/xarray_zarr.py @@ -11,30 +11,34 @@ import earthkit.data as ekd -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.xarray import load_many - - -@legacy_source(__file__) -def execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Execute the data loading process. - - Parameters - ---------- - context : Any - The context in which the execution occurs. - dates : List[str] - List of dates for which data is to be loaded. - url : str - The URL from which data is to be loaded. - *args : tuple - Additional positional arguments. - **kwargs : dict - Additional keyword arguments. - - Returns - ------- - ekd.FieldList - The loaded data. - """ - return load_many("🇿", context, dates, url, *args, **kwargs) +from . import source_registry +from .legacy import LegacySource +from .xarray import load_many + + +@source_registry.register("xarray_zarr") +class XarrayZarrSource(LegacySource): + + @staticmethod + def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Execute the data loading process. + + Parameters + ---------- + context : Any + The context in which the execution occurs. + dates : List[str] + List of dates for which data is to be loaded. + url : str + The URL from which data is to be loaded. + *args : tuple + Additional positional arguments. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + ekd.FieldList + The loaded data. + """ + return load_many("🇿", context, dates, url, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/zenodo.py b/src/anemoi/datasets/create/sources/zenodo.py index e23b8fa47..9f4d68f97 100644 --- a/src/anemoi/datasets/create/sources/zenodo.py +++ b/src/anemoi/datasets/create/sources/zenodo.py @@ -14,54 +14,58 @@ from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.sources.url import download_and_cache -from anemoi.datasets.create.sources.legacy import legacy_source -from anemoi.datasets.create.sources.patterns import iterate_patterns -from anemoi.datasets.create.sources.xarray import load_one +from . import source_registry +from .legacy import LegacySource +from .patterns import iterate_patterns +from .xarray import load_one -@legacy_source(__file__) -def execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList: - """Executes the download and processing of files from Zenodo. +@source_registry.register("zenodo") +class ZenodoSource(LegacySource): - Parameters - ---------- - context : Any - The context in which the function is executed. - dates : Any - The dates for which the data is required. - record_id : str - The Zenodo record ID. - file_key : str - The key to identify the file. - *args : Any - Additional arguments. - **kwargs : Any - Additional keyword arguments. + @staticmethod + def _execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList: + """Executes the download and processing of files from Zenodo. - Returns - ------- - MultiFieldList - A list of fields loaded from the downloaded files. 
- """ - import requests + Parameters + ---------- + context : Any + The context in which the function is executed. + dates : Any + The dates for which the data is required. + record_id : str + The Zenodo record ID. + file_key : str + The key to identify the file. + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. - result: list[Any] = [] + Returns + ------- + MultiFieldList + A list of fields loaded from the downloaded files. + """ + import requests - URLPATTERN = "https://zenodo.org/api/records/{record_id}" - url = URLPATTERN.format(record_id=record_id) - r = requests.get(url) - r.raise_for_status() - record: dict[str, Any] = r.json() + result: list[Any] = [] - urls: dict[str, str] = {} - for file in record["files"]: - urls[file["key"]] = file["links"]["self"] + URLPATTERN = "https://zenodo.org/api/records/{record_id}" + url = URLPATTERN.format(record_id=record_id) + r = requests.get(url) + r.raise_for_status() + record: dict[str, Any] = r.json() - for url, dates in iterate_patterns(file_key, dates, **kwargs): - if url not in urls: - continue + urls: dict[str, str] = {} + for file in record["files"]: + urls[file["key"]] = file["links"]["self"] - path = download_and_cache(urls[url]) - result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs)) + for url, dates in iterate_patterns(file_key, dates, **kwargs): + if url not in urls: + continue - return MultiFieldList(result) + path = download_and_cache(urls[url]) + result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs)) + + return MultiFieldList(result) From 242a421da8f1d31c6a498f6d57aa9d5e81e89f67 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:28:06 +0100 Subject: [PATCH 202/212] tidy up --- src/anemoi/datasets/.gitignore | 1 - src/anemoi/datasets/__main__.py | 4 +- src/anemoi/datasets/grids.py | 185 +------------------------- src/anemoi/datasets/traits/gridded.py | 2 - src/anemoi/datasets/traits/tabular.py | 2 - 5 files changed, 6 insertions(+), 188 deletions(-) delete mode 100644 src/anemoi/datasets/.gitignore delete mode 100644 src/anemoi/datasets/traits/gridded.py delete mode 100644 src/anemoi/datasets/traits/tabular.py diff --git a/src/anemoi/datasets/.gitignore b/src/anemoi/datasets/.gitignore deleted file mode 100644 index 0aba28e9b..000000000 --- a/src/anemoi/datasets/.gitignore +++ /dev/null @@ -1 +0,0 @@ -!build/ diff --git a/src/anemoi/datasets/__main__.py b/src/anemoi/datasets/__main__.py index f47c46050..62b7d7c73 100644 --- a/src/anemoi/datasets/__main__.py +++ b/src/anemoi/datasets/__main__.py @@ -12,8 +12,8 @@ from anemoi.utils.cli import cli_main from anemoi.utils.cli import make_parser -from anemoi.datasets import __version__ -from anemoi.datasets.commands import COMMANDS +from . import __version__ +from .commands import COMMANDS # For read-the-docs diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index ffec5e351..075f73495 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -8,11 +8,11 @@ # nor does it submit to any jurisdiction. -import base64 import logging from typing import Any import numpy as np +from anemoi.utils.grids import latlon_to_xyz from numpy.typing import NDArray LOG = logging.getLogger(__name__) @@ -88,71 +88,6 @@ def plot_mask( plt.savefig(path + "-global-zoomed.png") -# TODO: Use the one from anemoi.utils.grids instead -# from anemoi.utils.grids import ... 
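
The helpers removed just below converted between latitude/longitude and Cartesian coordinates on a unit sphere; this patch switches the module to the equivalent `latlon_to_xyz` imported from `anemoi.utils.grids`. As a rough illustration of what that conversion computes (a sketch only, assuming a spherical Earth of radius 1, as the removed code also did), the mapping can be written as::

    import numpy as np

    def latlon_to_xyz_sketch(lat_deg, lon_deg, radius=1.0):
        # Geographic to Cartesian on a sphere: phi is latitude, lam is longitude.
        phi = np.deg2rad(lat_deg)
        lam = np.deg2rad(lon_deg)
        x = np.cos(phi) * np.cos(lam) * radius
        y = np.cos(phi) * np.sin(lam) * radius
        z = np.sin(phi) * radius
        return x, y, z

    # The North Pole maps (up to rounding) to (0, 0, 1) on the unit sphere.
    print(latlon_to_xyz_sketch(90.0, 0.0))
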
-def xyz_to_latlon(x: NDArray[Any], y: NDArray[Any], z: NDArray[Any]) -> tuple[NDArray[Any], NDArray[Any]]: - """Convert Cartesian coordinates to latitude and longitude. - - Parameters - ---------- - x : NDArray[Any] - X coordinates. - y : NDArray[Any] - Y coordinates. - z : NDArray[Any] - Z coordinates. - - Returns - ------- - Tuple[NDArray[Any], NDArray[Any]] - Latitude and longitude coordinates. - """ - return ( - np.rad2deg(np.arcsin(np.minimum(1.0, np.maximum(-1.0, z)))), - np.rad2deg(np.arctan2(y, x)), - ) - - -# TODO: Use the one from anemoi.utils.grids instead -# from anemoi.utils.grids import ... -def latlon_to_xyz( - lat: NDArray[Any], lon: NDArray[Any], radius: float = 1.0 -) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any]]: - """Convert latitude and longitude to Cartesian coordinates. - - Parameters - ---------- - lat : NDArray[Any] - Latitude coordinates. - lon : NDArray[Any] - Longitude coordinates. - radius : float, optional - Radius of the sphere. Defaults to 1.0. - - Returns - ------- - Tuple[NDArray[Any], NDArray[Any], NDArray[Any]] - X, Y, and Z coordinates. - """ - # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates - # We assume that the Earth is a sphere of radius 1 so N(phi) = 1 - # We assume h = 0 - # - phi = np.deg2rad(lat) - lda = np.deg2rad(lon) - - cos_phi = np.cos(phi) - cos_lda = np.cos(lda) - sin_phi = np.sin(phi) - sin_lda = np.sin(lda) - - x = cos_phi * cos_lda * radius - y = cos_phi * sin_lda * radius - z = sin_phi * radius - - return x, y, z - - class Triangle3D: """A class to represent a 3D triangle and perform intersection tests with rays.""" @@ -509,92 +444,6 @@ def outline(lats: NDArray[Any], lons: NDArray[Any], neighbours: int = 5) -> list return outside -def deserialise_mask(encoded: str) -> NDArray[Any]: - """Deserialise a mask from a base64 encoded string. - - Parameters - ---------- - encoded : str - Base64 encoded string. - - Returns - ------- - NDArray[Any] - Deserialised mask array. - """ - import pickle - import zlib - - packed = pickle.loads(zlib.decompress(base64.b64decode(encoded))) - - mask = [] - value = False - for count in packed: - mask.extend([value] * count) - value = not value - return np.array(mask, dtype=bool) - - -def _serialise_mask(mask: NDArray[Any]) -> str: - """Serialise a mask to a base64 encoded string. - - Parameters - ---------- - mask : NDArray[Any] - Mask array. - - Returns - ------- - str - Base64 encoded string. - """ - import pickle - import zlib - - assert len(mask.shape) == 1 - assert len(mask) - - packed = [] - last = mask[0] - count = 1 - - for value in mask[1:]: - if value == last: - count += 1 - else: - packed.append(count) - last = value - count = 1 - - packed.append(count) - - # We always start with an 'off' value - # So if the first value is 'on', we need to add a zero - if mask[0]: - packed.insert(0, 0) - - return base64.b64encode(zlib.compress(pickle.dumps(packed))).decode("utf-8") - - -def serialise_mask(mask: NDArray[Any]) -> str: - """Serialise a mask and ensure it can be deserialised. - - Parameters - ---------- - mask : NDArray[Any] - Mask array. - - Returns - ------- - str - Base64 encoded string. 
- """ - result = _serialise_mask(mask) - # Make sure we can deserialise it - assert np.all(mask == deserialise_mask(result)) - return result - - def nearest_grid_points( source_latitudes: NDArray[Any], source_longitudes: NDArray[Any], @@ -628,7 +477,7 @@ def nearest_grid_points( """ # TODO: Use the one from anemoi.utils.grids instead # from anemoi.utils.grids import ... - from scipy.spatial import KDTree + from scipy.spatial import cKDTree source_xyz = latlon_to_xyz(source_latitudes, source_longitudes) source_points = np.array(source_xyz).transpose() @@ -636,33 +485,7 @@ def nearest_grid_points( target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) target_points = np.array(target_xyz).transpose() if max_distance is None: - distances, indices = KDTree(source_points).query(target_points, k=k) + distances, indices = cKDTree(source_points).query(target_points, k=k) else: - distances, indices = KDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) + distances, indices = cKDTree(source_points).query(target_points, k=k, distance_upper_bound=max_distance) return distances, indices - - -if __name__ == "__main__": - global_lats, global_lons = np.meshgrid( - np.linspace(90, -90, 90), - np.linspace(-180, 180, 180), - ) - global_lats = global_lats.flatten() - global_lons = global_lons.flatten() - - lats, lons = np.meshgrid( - np.linspace(50, 40, 100), - np.linspace(-10, 15, 100), - ) - lats = lats.flatten() - lons = lons.flatten() - - mask = cutout_mask(lats, lons, global_lats, global_lons, cropping_distance=5.0) - - import matplotlib.pyplot as plt - - fig = plt.figure(figsize=(10, 5)) - plt.scatter(global_lons, global_lats, s=0.01, marker="o", c="r") - plt.scatter(global_lons[mask], global_lats[mask], s=0.1, c="k") - # plt.scatter(lons, lats, s=0.01) - plt.savefig("cutout.png") diff --git a/src/anemoi/datasets/traits/gridded.py b/src/anemoi/datasets/traits/gridded.py deleted file mode 100644 index 62f8c87d3..000000000 --- a/src/anemoi/datasets/traits/gridded.py +++ /dev/null @@ -1,2 +0,0 @@ -class Gridded: - pass diff --git a/src/anemoi/datasets/traits/tabular.py b/src/anemoi/datasets/traits/tabular.py deleted file mode 100644 index 4ad7058a7..000000000 --- a/src/anemoi/datasets/traits/tabular.py +++ /dev/null @@ -1,2 +0,0 @@ -class Tabular: - pass From 7f593fb27d8e28724b391f27f858896541324ae4 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:35:01 +0100 Subject: [PATCH 203/212] tidy up --- src/anemoi/datasets/create/input/action.py | 133 ++++++++++------- src/anemoi/datasets/create/input/origin.py | 159 --------------------- 2 files changed, 85 insertions(+), 207 deletions(-) delete mode 100644 src/anemoi/datasets/create/input/origin.py diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py index 62015a4ac..7808ae717 100644 --- a/src/anemoi/datasets/create/input/action.py +++ b/src/anemoi/datasets/create/input/action.py @@ -8,15 +8,20 @@ # nor does it submit to any jurisdiction. import logging -from abc import ABC -from abc import abstractmethod from anemoi.datasets.dates import DatesProvider LOG = logging.getLogger(__name__) -class Action(ABC): +class Action: + """An "Action" represents a single operation described in the yaml configuration, e.g. a source, a filter, + pipe, join, etc. + + See :ref:`operations` for more details. 
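
For illustration, a recipe fragment of the kind these actions consume can be written as a plain dictionary and handed to `action_factory` (defined further down in this module). This is only a sketch: the `mars` and `rename` entries are placeholders for whatever sources and filters happen to be registered, and actually building the tree requires anemoi-datasets and its registries to be available::

    # Hypothetical recipe fragment: a join of a source and a pipe
    # (a source followed by a filter). The keys are illustrative only.
    config = {
        "join": [
            {"mars": {"param": ["2t"], "levtype": "sfc"}},
            {"pipe": [
                {"mars": {"param": ["10u", "10v"]}},
                {"rename": {"10u": "u10", "10v": "v10"}},
            ]},
        ]
    }

    # Building the action tree (needs the referenced sources/filters registered):
    # from anemoi.datasets.create.input.action import action_factory
    # action = action_factory(config, "input")   # -> a Join wrapping a source and a Pipe
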
+ + """ + def __init__(self, config, *path): self.config = config self.path = path @@ -25,15 +30,32 @@ def __init__(self, config, *path): "data_sources", ), f"{self.__class__.__name__}: path must start with 'input' or 'data_sources': {path}" - @abstractmethod - def __call__(self, context, argument): - pass - def __repr__(self): - return f"{self.__class__.__name__}({'.'.join(str(x) for x in self.path)}, {self.config})" +class Concat(Action): + """The Concat contruct is used to concat different actions that are responsible + for delivery fields for different dates. + See :ref:`building-concat` for more details. + + .. block-code:: yaml + + input: + concat: + - dates: + start: 2023-01-01 + end: 2023-01-31 + frequency: 1d + action: # some action + ... + + - dates: + start: 2023-02-01 + end: 2023-02-28 + frequency: 1d + action: # some action + + """ -class Concat(Action): def __init__(self, config, *path): super().__init__(config, *path, "concat") @@ -43,7 +65,6 @@ def __init__(self, config, *path): for i, item in enumerate(config): - assert "dates" in item, f"Value must contain the key 'dates' {item}" dates = item["dates"] filtering_dates = DatesProvider.from_config(**dates) action = action_factory({k: v for k, v in item.items() if k != "dates"}, *self.path, str(i)) @@ -66,10 +87,26 @@ def __call__(self, context, argument): class Join(Action): + """Implement the join operation to combine results from multiple actions. + + See :ref:`building-join` for more details. + + .. block-code:: yaml + + input: + join: + - grib: + ... + + - netcdf: # some other action + ... + + """ + def __init__(self, config, *path): super().__init__(config, *path, "join") - assert isinstance(config, list), f"Value must be a list {config}" + assert isinstance(config, list), f"Value of Join Action must be a list, got: {config}" self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] @@ -86,8 +123,25 @@ def __call__(self, context, argument): class Pipe(Action): + """Implement the pipe operation to chain results from a + source through multiple filters. + + See :ref:`building-pipe` for more details. + + .. block-code:: yaml + + input: + pipe: + - grib: + ... + + - rename: + ... 
+ + """ + def __init__(self, config, *path): - assert isinstance(config, list), f"Value must be a list {config}" + assert isinstance(config, list), f"Value of Pipe Action must be a list, got {config}" super().__init__(config, *path, "pipe") self.actions = [action_factory(item, *self.path, str(i)) for i, item in enumerate(config)] @@ -107,6 +161,8 @@ def __call__(self, context, argument): class Function(Action): + """Base class for sources and filters.""" + def __init__(self, config, *path): super().__init__(config, *path, self.name) @@ -122,54 +178,43 @@ def __call__(self, context, argument): class DatasetSourceMixin: + """Mixin class for sources defined in anemoi-datasets""" + def create_object(self, context, config): from anemoi.datasets.create.sources import create_source as create_datasets_source return create_datasets_source(context, config) def call_object(self, context, source, argument): - result = source.execute(context.source_argument(argument)) - return context.origin(result, self, argument) - - def origin(self): - from anemoi.datasets.create.input.origin import Source - - return Source(self.path[-1], self.config) + return source.execute(context.source_argument(argument)) class TransformSourceMixin: + """Mixin class for sources defined in anemoi-transform""" + def create_object(self, context, config): from anemoi.transform.sources import create_source as create_transform_source return create_transform_source(context, config) - def combine_origins(self, current, previous): - assert previous is None, f"Cannot combine origins, previous already exists: {previous}" - return current - - def origin(self): - from anemoi.datasets.create.input.origin import Source - - return Source(self.path[-1], self.config) - class TransformFilterMixin: + """Mixin class for filters defined in anemoi-transform""" + def create_object(self, context, config): from anemoi.transform.filters import create_filter as create_transform_filter return create_transform_filter(context, config) def call_object(self, context, filter, argument): - result = filter.forward(context.filter_argument(argument)) - return context.origin(result, self, argument) + return filter.forward(context.filter_argument(argument)) - def origin(self): - from anemoi.datasets.create.input.origin import Filter - return Filter(self.path[-1], self.config) +class FilterFunction(Function): + """Action to call a filter on the argument (e.g. rename, regrid, etc.).""" - def combine_origins(self, current, previous): - return {"_apply": current, **(previous or {})} + def __call__(self, context, argument): + return self.call(context, argument, context.filter_argument) def _make_name(name, what): @@ -195,6 +240,8 @@ def new_filter(name, mixin): class DataSources(Action): + """Action to call a source (e.g. mars, netcdf, grib, etc.).""" + def __init__(self, config, *path): super().__init__(config, *path) assert isinstance(config, (dict, list)), f"Invalid config type: {type(config)}" @@ -209,6 +256,8 @@ def __call__(self, context, argument): class Recipe(Action): + """Action that represent a recipe (i.e. 
a sequence of data_sources and input).""" + def __init__(self, input, data_sources): self.input = input self.data_sources = data_sources @@ -227,7 +276,6 @@ def __call__(self, context, argument): } LEN_KLASS = len(KLASS) -TYPES = {} def make(key, config, *path): @@ -244,28 +292,17 @@ def make(key, config, *path): for name in dataset_source_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_source(name, DatasetSourceMixin) - TYPES[name.replace("_", "-")] = "source" for name in transform_source_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_source(name, TransformSourceMixin) - TYPES[name.replace("_", "-")] = "source" # Register filters for name in transform_filter_registry.registered: if name not in KLASS: KLASS[name.replace("_", "-")] = new_filter(name, TransformFilterMixin) - TYPES[name.replace("_", "-")] = "filter" - - key = key.replace("_", "-") - - if key not in KLASS: - LOG.error(f"Unknown action '{key}' in {'.'.join(x for x in path)}") - for available in sorted(KLASS): - LOG.error(f" Available: {available} (type={TYPES.get(available, 'built-in')})") - raise ValueError(f"Unknown action '{key}' in {'.'.join(x for x in path)}") - return KLASS[key](config, *path) + return KLASS[key.replace("_", "-")](config, *path) def action_factory(data, *path): diff --git a/src/anemoi/datasets/create/input/origin.py b/src/anemoi/datasets/create/input/origin.py deleted file mode 100644 index 9f5173afc..000000000 --- a/src/anemoi/datasets/create/input/origin.py +++ /dev/null @@ -1,159 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
- -import logging -from abc import ABC - -LOG = logging.getLogger(__name__) - - -class Origin(ABC): - - def __init__(self, when="dataset-create"): - self.when = when - - def __eq__(self, other): - if not isinstance(other, Origin): - return False - return self is other - - def __hash__(self): - return id(self) - - -def _un_dotdict(x): - if isinstance(x, dict): - return {k: _un_dotdict(v) for k, v in x.items()} - - if isinstance(x, (list, tuple, set)): - return [_un_dotdict(a) for a in x] - - return x - - -class Pipe(Origin): - def __init__(self, s1, s2, when="dataset-create"): - super().__init__(when) - self.steps = [s1, s2] - - assert s1 is not None, (s1, s2) - assert s2 is not None, (s1, s2) - - if isinstance(s1, Pipe): - assert not isinstance(s2, Pipe), (s1, s2) - self.steps = s1.steps + [s2] - - def combine(self, previous, action, action_arguments): - assert False, (self, previous) - - def as_dict(self): - return { - "type": "pipe", - "steps": [s.as_dict() for s in self.steps], - "when": self.when, - } - - def __repr__(self): - return " | ".join(repr(s) for s in self.steps) - - -class Join(Origin): - def __init__(self, origins, when="dataset-create"): - assert isinstance(origins, (list, tuple, set)), origins - super().__init__(when) - self.steps = list(origins) - - assert all(o is not None for o in origins), origins - - def combine(self, previous, action, action_arguments): - assert False, (self, previous) - - def as_dict(self): - return { - "type": "join", - "steps": [s.as_dict() for s in self.steps], - "when": self.when, - } - - def __repr__(self): - return " & ".join(repr(s) for s in self.steps) - - -class Source(Origin): - def __init__(self, name, config, when="dataset-create"): - super().__init__(when) - assert isinstance(config, dict), f"Config must be a dictionary {config}" - self.name = name - self.config = _un_dotdict(config) - - def combine(self, previous, action, action_arguments): - assert previous is None, f"Cannot combine origins, previous already exists: {previous}" - return self - - def as_dict(self): - return { - "type": "source", - "name": self.name, - "config": self.config, - "when": self.when, - } - - def __repr__(self): - return f"{self.name}({id(self)})" - - -class Filter(Origin): - def __init__(self, name, config, when="dataset-create"): - super().__init__(when) - assert isinstance(config, dict), f"Config must be a dictionary {config}" - self.name = name - self.config = _un_dotdict(config) - self._cache = {} - - def combine(self, previous, action, action_arguments): - - if previous is None: - # This can happen if the filter does not tag its output with an origin - # (e.g. a user plugin). In that case we try to get the origin from the action arguments - key = (id(action), id(action_arguments)) - if key not in self._cache: - - LOG.warning(f"No previous origin to combine with: {self}. 
Action: {action}") - LOG.warning(f"Connecting to action arguments {action_arguments}") - origins = set() - for k in action_arguments: - o = k.metadata("anemoi_origin", default=None) - if o is None: - raise ValueError( - f"Cannot combine origins, previous is None and action_arguments {action_arguments} has no origin" - ) - origins.add(o) - if len(origins) == 1: - self._cache[key] = origins.pop() - else: - self._cache[key] = Join(origins) - previous = self._cache[key] - - if previous in self._cache: - # We use a cache to avoid recomputing the same combination - return self._cache[previous] - - self._cache[previous] = Pipe(previous, self) - return self._cache[previous] - - def as_dict(self): - return { - "type": "filter", - "name": self.name, - "config": self.config, - "when": self.when, - } - - def __repr__(self): - return f"{self.name}({id(self)})" From fd7a9e778696a6a0b1cb749cc12a6ef6f0b43567 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:36:54 +0100 Subject: [PATCH 204/212] tidy up --- src/anemoi/datasets/create/input/__init__.py | 13 +++++++++---- src/anemoi/datasets/create/input/context.py | 9 +++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py index 62f94b8cf..e5746c610 100644 --- a/src/anemoi/datasets/create/input/__init__.py +++ b/src/anemoi/datasets/create/input/__init__.py @@ -9,13 +9,17 @@ from copy import deepcopy from functools import cached_property +from typing import TYPE_CHECKING from typing import Any +if TYPE_CHECKING: + from anemoi.datasets.create.input.action import Recipe + class InputBuilder: """Builder class for creating input data from configuration and data sources.""" - def __init__(self, config: dict, data_sources: dict | list) -> None: + def __init__(self, config: dict, data_sources: dict | list, **kwargs: Any) -> None: """Initialize the InputBuilder. Parameters @@ -27,14 +31,15 @@ def __init__(self, config: dict, data_sources: dict | list) -> None: **kwargs : Any Additional keyword arguments. 
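
As a usage sketch (the configuration keys below are hypothetical placeholders, and the import path follows the layout used in this patch), an `InputBuilder` is constructed from a recipe's `input` section plus its `data_sources`, and the action tree is only built when the `action` property is accessed::

    from anemoi.datasets.create.input import InputBuilder

    builder = InputBuilder(
        config={"mars": {"param": ["2t"], "levtype": "sfc"}},  # recipe "input" section (placeholder keys)
        data_sources={},                                       # optional named data_sources section
    )
    # Accessing `builder.action` would lazily build the Recipe action tree via
    # action_factory; it needs the referenced source ("mars" here) to be registered.
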
""" + self.kwargs = kwargs self.config = deepcopy(config) self.data_sources = deepcopy(dict(data_sources=data_sources)) @cached_property - def action(self) -> Any: + def action(self) -> "Recipe": """Returns the action object based on the configuration.""" - from anemoi.datasets.create.input.action import Recipe - from anemoi.datasets.create.input.action import action_factory + from .action import Recipe + from .action import action_factory sources = action_factory(self.data_sources, "data_sources") input = action_factory(self.config, "input") diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py index 28c797dd5..89df7a727 100644 --- a/src/anemoi/datasets/create/input/context.py +++ b/src/anemoi/datasets/create/input/context.py @@ -18,9 +18,10 @@ class Context(ABC): """Context for building input data.""" - def __init__(self) -> None: + def __init__(self, /, argument: Any) -> None: self.results = {} self.cache = {} + self.argument = argument def trace(self, emoji, *message) -> None: @@ -33,7 +34,7 @@ def register(self, data: Any, path: list[str]) -> Any: assert path[0] in ("input", "data_sources"), path - LOG.info(f"Registering data at path: {'.'.join(str(x) for x in path)}") + LOG.info(f"Registering data at path: {path}") self.results[tuple(path)] = data return data @@ -46,9 +47,9 @@ def resolve(self, config): if path in self.results: config[key] = self.results[path] else: - print(f"Path not found {path}") + LOG.warning(f"Path not found {path}") for p in sorted(self.results): - print(f" Available paths: {p}") + LOG.info(f" Available paths: {p}") raise KeyError(f"Path {path} not found in results: {self.results.keys()}") return config From 29625a7b35d20b4988c09fdff8c6cdc7208a87c9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:38:40 +0100 Subject: [PATCH 205/212] tidy up --- src/anemoi/datasets/create/input/context.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py index 89df7a727..a3077fc18 100644 --- a/src/anemoi/datasets/create/input/context.py +++ b/src/anemoi/datasets/create/input/context.py @@ -18,10 +18,9 @@ class Context(ABC): """Context for building input data.""" - def __init__(self, /, argument: Any) -> None: + def __init__(self, /) -> None: self.results = {} self.cache = {} - self.argument = argument def trace(self, emoji, *message) -> None: From 8ce9ddbb25bd0e9ed290404a74cafd52d9c582f3 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:41:07 +0100 Subject: [PATCH 206/212] tidy up --- .../datasets/create/input/data_sources.py | 10 +- .../datasets/create/input/repeated_dates.py | 270 ++++++++++++++++++ 2 files changed, 275 insertions(+), 5 deletions(-) create mode 100644 src/anemoi/datasets/create/input/repeated_dates.py diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py index 31956d602..31bf3d8cc 100644 --- a/src/anemoi/datasets/create/input/data_sources.py +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -13,11 +13,11 @@ from earthkit.data import FieldList -from anemoi.datasets.create.input.action import Action -from anemoi.datasets.create.input.action import action_factory -from anemoi.datasets.create.input.misc import _tidy -from anemoi.datasets.create.input.result.field import Result -from anemoi.datasets.dates.groups import GroupOfDates +from ...dates.groups import GroupOfDates +from .action import 
Action +from .action import action_factory +from .misc import _tidy +from .result.field import Result LOG = logging.getLogger(__name__) diff --git a/src/anemoi/datasets/create/input/repeated_dates.py b/src/anemoi/datasets/create/input/repeated_dates.py new file mode 100644 index 000000000..8262b2d13 --- /dev/null +++ b/src/anemoi/datasets/create/input/repeated_dates.py @@ -0,0 +1,270 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from collections import defaultdict +from collections.abc import Generator +from typing import Any + +import numpy as np +from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_timedelta + +LOG = logging.getLogger(__name__) + + +class DateMapper: + """A factory class to create DateMapper instances based on the given mode.""" + + @staticmethod + def from_mode(mode: str, source: Any, config: dict[str, Any]) -> "DateMapper": + """Create a DateMapper instance based on the given mode. + + Parameters + ---------- + mode : str + The mode to use for the DateMapper. + source : Any + The data source. + config : dict + Configuration parameters. + + Returns + ------- + DateMapper + An instance of DateMapper. + """ + MODES: dict = dict( + closest=DateMapperClosest, + climatology=DateMapperClimatology, + constant=DateMapperConstant, + ) + + if mode not in MODES: + raise ValueError(f"Invalid mode for DateMapper: {mode}") + + return MODES[mode](source, **config) + + +class DateMapperClosest(DateMapper): + """A DateMapper implementation that maps dates to the closest available dates.""" + + def __init__(self, source: Any, frequency: str = "1h", maximum: str = "30d", skip_all_nans: bool = False) -> None: + """Initialize DateMapperClosest. + + Parameters + ---------- + source : Any + The data source. + frequency : str + Frequency of the dates. + maximum : str + Maximum time delta. + skip_all_nans : bool + Whether to skip all NaN values. + """ + self.source: Any = source + self.maximum: Any = frequency_to_timedelta(maximum) + self.frequency: Any = frequency_to_timedelta(frequency) + self.skip_all_nans: bool = skip_all_nans + self.tried: set[Any] = set() + self.found: set[Any] = set() + + def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: + """Transform the group of dates to the closest available dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. 
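
A small worked example of the candidate window that the `closest` mode scans, mirroring the two `while` loops in the method body below (the dates, frequency and maximum here are arbitrary)::

    import datetime

    requested = datetime.datetime(2023, 1, 1, 12)
    frequency = datetime.timedelta(hours=1)   # step between candidate dates
    maximum = datetime.timedelta(hours=3)     # how far to search on either side

    candidates = set()
    d = requested
    while d >= requested - maximum:   # walk backwards from the requested date
        candidates.add(d)
        d -= frequency
    d = requested
    while d <= requested + maximum:   # walk forwards from the requested date
        candidates.add(d)
        d += frequency

    # 7 candidates, 09:00 to 15:00 hourly; the closest one with usable data is kept.
    print(sorted(candidates))
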
+ """ + from anemoi.datasets.dates.groups import GroupOfDates + + asked_dates = list(group_of_dates) + if not asked_dates: + return [] + + to_try = set() + for date in asked_dates: + start = date + while start >= date - self.maximum: + to_try.add(start) + start -= self.frequency + + end = date + while end <= date + self.maximum: + to_try.add(end) + end += self.frequency + + to_try = sorted(to_try - self.tried) + info = {k: "no-data" for k in to_try} + + if not to_try: + LOG.warning(f"No new dates to try for {group_of_dates} in {self.source}") + # return [] + + if to_try: + result = self.source.select( + GroupOfDates( + sorted(to_try), + group_of_dates.provider, + partial_ok=True, + ) + ) + + cnt = 0 + for f in result.datasource: + cnt += 1 + # We could keep the fields in a dictionary, but we don't want to keep the fields in memory + date = as_datetime(f.metadata("valid_datetime")) + + if self.skip_all_nans: + if np.isnan(f.to_numpy()).all(): + LOG.warning(f"Skipping {date} because all values are NaN") + info[date] = "all-nans" + continue + + info[date] = "ok" + self.found.add(date) + + if cnt == 0: + raise ValueError(f"No data found for {group_of_dates} in {self.source}") + + self.tried.update(to_try) + + if not self.found: + for k, v in info.items(): + LOG.warning(f"{k}: {v}") + + raise ValueError(f"No matching data found for {asked_dates} in {self.source}") + + new_dates = defaultdict(list) + + for date in asked_dates: + best = None + for found_date in sorted(self.found): + delta = abs(date - found_date) + # With < we prefer the first date + # With <= we prefer the last date + if best is None or delta <= best[0]: + best = delta, found_date + new_dates[best[1]].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperClimatology(DateMapper): + """A DateMapper implementation that maps dates to specified climatology dates.""" + + def __init__(self, source: Any, year: int, day: int, hour: int | None = None) -> None: + """Initialize DateMapperClimatology. + + Parameters + ---------- + source : Any + The data source. + year : int + The year to map to. + day : int + The day to map to. + hour : Optional[int] + The hour to map to. + """ + self.year: int = year + self.day: int = day + self.hour: int | None = hour + + def transform(self, group_of_dates: Any) -> Generator[tuple[Any, Any], None, None]: + """Transform the group of dates to the specified climatology dates. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Generator[Tuple[Any, Any], None, None] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + dates = list(group_of_dates) + if not dates: + return [] + + new_dates = defaultdict(list) + for date in dates: + new_date = date.replace(year=self.year, day=self.day) + if self.hour is not None: + new_date = new_date.replace(hour=self.hour, minute=0, second=0) + new_dates[new_date].append(date) + + for date, dates in new_dates.items(): + yield ( + GroupOfDates([date], group_of_dates.provider), + GroupOfDates(dates, group_of_dates.provider), + ) + + +class DateMapperConstant(DateMapper): + """A DateMapper implementation that maps dates to a constant date.""" + + def __init__(self, source: Any, date: Any | None = None) -> None: + """Initialize DateMapperConstant. + + Parameters + ---------- + source : Any + The data source. 
+ date : Optional[Any] + The constant date to map to. + """ + self.source: Any = source + self.date: Any | None = date + + def transform(self, group_of_dates: Any) -> tuple[Any, Any]: + """Transform the group of dates to a constant date. + + Parameters + ---------- + group_of_dates : Any + The group of dates to transform. + + Returns + ------- + Tuple[Any, Any] + Transformed dates. + """ + from anemoi.datasets.dates.groups import GroupOfDates + + if self.date is None: + return [ + ( + GroupOfDates([], group_of_dates.provider), + group_of_dates, + ) + ] + + return [ + ( + GroupOfDates([self.date], group_of_dates.provider), + group_of_dates, + ) + ] From 82b4fe04af84fad00ba081ac35067b4b1a50791d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:46:01 +0100 Subject: [PATCH 207/212] tidy up --- src/anemoi/datasets/dates/__init__.py | 80 +++++-------- src/anemoi/datasets/dates/groups.py | 155 ++++++++------------------ 2 files changed, 71 insertions(+), 164 deletions(-) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index 0ce767418..04b5177d1 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -27,15 +27,13 @@ def extend(x: str | list[Any] | tuple[Any, ...]) -> Iterator[datetime.datetime]: """Extend a date range or list of dates into individual datetime objects. - Parameters - ---------- - x : Union[str, List[Any], Tuple[Any, ...]] - A date range string or list/tuple of dates. - - Yields - ------ - datetime.datetime - Individual datetime objects. + Args: + x (Union[str, List[Any], Tuple[Any, ...]]): A date range string or list/tuple of dates. + + Returns + ------- + Iterator[datetime.datetime] + An iterator of datetime objects. """ if isinstance(x, (list, tuple)): @@ -176,12 +174,9 @@ def summary(self) -> str: class ValuesDates(DatesProvider): """Class for handling a list of date values. - Parameters - ---------- - values : List[Union[str, datetime.datetime]] - List of date values. - **kwargs : Any - Additional arguments. + Args: + values (List[Union[str, datetime.datetime]]): List of date values. + **kwargs (Any): Additional arguments. """ def __init__(self, values: list[str | datetime.datetime], **kwargs: Any) -> None: @@ -218,16 +213,11 @@ def as_dict(self) -> dict[str, Any]: class StartEndDates(DatesProvider): """Class for generating dates between a start and end date with a specified frequency. - Parameters - ---------- - start : Union[str, datetime.datetime] - Start date. - end : Union[str, datetime.datetime] - End date. - frequency : Union[int, str] - Frequency of dates. - **kwargs : Any - Additional arguments. + Args: + start (Union[str, datetime.datetime]): Start date. + end (Union[str, datetime.datetime]): End date. + frequency (Union[int, str]): Frequency of dates. + **kwargs (Any): Additional arguments. """ def __repr__(self) -> str: @@ -295,13 +285,6 @@ def as_dict(self) -> dict[str, Any]: "frequency": frequency_to_string(self.frequency), }.update(self.kwargs) - def to_python(self) -> str: - """Convert the StartEndDates instance to a tuple of ISO-formatted date strings.""" - if self.frequency == datetime.timedelta(hours=1): - return (self.start.isoformat(), self.end.isoformat()) - else: - return (self.start.isoformat(), self.end.isoformat(), frequency_to_string(self.frequency)) - @property def start_date(self) -> datetime.datetime: return self.start @@ -314,16 +297,11 @@ def end_date(self) -> datetime.datetime: class Hindcast: """Class representing a single hindcast date. 
- Parameters - ---------- - date : datetime.datetime - The date of the hindcast. - refdate : datetime.datetime - The reference date. - hdate : datetime.datetime - The hindcast date. - step : int - The step value. + Args: + date (datetime.datetime): The date of the hindcast. + refdate (datetime.datetime): The reference date. + hdate (datetime.datetime): The hindcast date. + step (int): The step value. """ def __init__( @@ -346,18 +324,12 @@ def __init__( class HindcastsDates(DatesProvider): """Class for generating hindcast dates over a range of years. - Parameters - ---------- - start : Union[str, List[str]] - Start date(s). - end : Union[str, List[str]] - End date(s). - steps : List[int] - List of step values. - years : int - Number of years. - **kwargs : Any - Additional arguments. + Args: + start (Union[str, List[str]]): Start date(s). + end (Union[str, List[str]]): End date(s). + steps (List[int]): List of step values. + years (int): Number of years. + **kwargs (Any): Additional arguments. """ def __init__( diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index 547e99892..a72fdaa74 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -24,15 +24,11 @@ def _shorten(dates: list[datetime.datetime] | tuple[datetime.datetime, ...]) -> str | list[str]: """Shorten the list of dates for display. - Parameters - ---------- - dates : Union[List[datetime.datetime], Tuple[datetime.datetime, ...]] - The list of dates. - - Returns - ------- - Union[str, List[str]] - The shortened list of dates. + Args: + dates (Union[List[datetime.datetime], Tuple[datetime.datetime, ...]]): The list of dates. + + Returns: + Union[str, List[str]]: The shortened list of dates. """ if isinstance(dates, (list, tuple)): dates = [d.isoformat() for d in dates] @@ -45,17 +41,6 @@ class GroupOfDates: """A class to represent a group of dates.""" def __init__(self, dates: list[datetime.datetime], provider: DatesProvider, partial_ok: bool = False) -> None: - """Initialise a GroupOfDates instance. - - Parameters - ---------- - dates : List[datetime.datetime] - List of dates. - provider : DatesProvider - The dates provider. - partial_ok : bool, optional - Whether partial groups are allowed (default is False). - """ assert isinstance(provider, DatesProvider), type(provider) assert isinstance(dates, list) @@ -66,45 +51,35 @@ def __init__(self, dates: list[datetime.datetime], provider: DatesProvider, part def __len__(self) -> int: """Return the number of dates in the group. - Returns - ------- - int - The number of dates. + Returns: + int: The number of dates. """ return len(self.dates) def __iter__(self) -> Iterator[datetime.datetime]: """Return an iterator over the dates in the group. - Returns - ------- - Iterator[datetime.datetime] - The iterator over the dates. + Returns: + Iterator[datetime.datetime]: The iterator over the dates. """ return iter(self.dates) def __repr__(self) -> str: """Return a string representation of the group of dates. - Returns - ------- - str - The string representation. + Returns: + str: The string representation. """ return f"GroupOfDates(dates={_shorten(self.dates)})" def __eq__(self, other: object) -> bool: """Check if two groups of dates are equal. - Parameters - ---------- - other : object - The other group of dates. + Args: + other (object): The other group of dates. - Returns - ------- - bool - True if the groups are equal, False otherwise. + Returns: + bool: True if the groups are equal, False otherwise. 
""" return isinstance(other, GroupOfDates) and self.dates == other.dates @@ -112,8 +87,7 @@ def __eq__(self, other: object) -> bool: class Groups: """A collection of groups of dates. - Examples - -------- + Examples: >>> list(Groups(group_by="daily", start="2023-01-01 00:00", end="2023-01-05 00:00", frequency=12))[0] [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 1, 12, 0)] @@ -145,8 +119,7 @@ def __init__(self, **kwargs: Any) -> None: Parameters ---------- - **kwargs : Any - Arbitrary keyword arguments. Expected keys include: + **kwargs : Any : Arbitrary keyword arguments. Expected keys include: - group_by: Configuration for the Grouper. - Other keys for DatesProvider configuration. """ @@ -164,10 +137,8 @@ def provider(self) -> DatesProvider: def __iter__(self) -> Iterator[GroupOfDates]: """Return an iterator over the groups of dates. - Returns - ------- - Iterator[GroupOfDates] - The iterator over the groups of dates. + Returns: + Iterator[GroupOfDates]: The iterator over the groups of dates. """ for go in self._grouper(self._dates): dates = self._filter(go.dates) @@ -178,10 +149,8 @@ def __iter__(self) -> Iterator[GroupOfDates]: def __len__(self) -> int: """Return the number of groups of dates. - Returns - ------- - int - The number of groups. + Returns: + int: The number of groups. """ return self._len @@ -199,30 +168,24 @@ def _len(self) -> int: def __repr__(self) -> str: """Return a string representation of the groups of dates. - Returns - ------- - str - The string representation. + Returns: + str: The string representation. """ return f"{self.__class__.__name__}(dates={len(self)},{_shorten(self._dates)})" def describe(self) -> str: """Return a summary description of the dates. - Returns - ------- - str - The summary description. + Returns: + str: The summary description. """ return self._dates.summary def one_date(self) -> GroupOfDates: """Return a group containing only one date. - Returns - ------- - GroupOfDates - The group containing only one date. + Returns: + GroupOfDates: The group containing only one date. """ go = next(iter(self)) return GroupOfDates([go.dates[0]], go.provider) @@ -237,24 +200,22 @@ def __init__(self, missing: list[datetime.datetime]) -> None: def __call__(self, dates: list[datetime.datetime]) -> list[datetime.datetime]: """Filter out missing dates from the list of dates. - Parameters - ---------- - dates : List[datetime.datetime] - The list of dates. + Args: + dates (List[datetime.datetime]): The list of dates. - Returns - ------- - List[datetime.datetime] - The filtered list of dates. + Returns: + List[datetime.datetime]: The filtered list of dates. """ return [d for d in dates if d not in self.missing] class Grouper(ABC): + """Abstract base class for grouping dates.""" @classmethod def from_config(cls, group_by: Any) -> "Grouper": """Create a grouper based on the configuration.""" + if isinstance(group_by, int) and group_by > 0: return GrouperByFixedSize(group_by) @@ -314,15 +275,11 @@ class GrouperOneGroup(Grouper): def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group all dates into a single group. - Parameters - ---------- - dates : DatesProvider - The dates provider. + Args: + dates (DatesProvider): The dates provider. - Returns - ------- - Iterator[GroupOfDates] - The iterator over the groups of dates. + Returns: + Iterator[GroupOfDates]: The iterator over the groups of dates. 
""" assert isinstance(dates, DatesProvider), type(dates) @@ -333,27 +290,16 @@ class GrouperByKey(Grouper): """Group dates by a key.""" def __init__(self, key: Callable[[datetime.datetime], Any]) -> None: - """Initialise GrouperByKey with a key function. - - Parameters - ---------- - key : Callable[[datetime.datetime], Any] - Function to extract grouping key from a datetime. - """ self.key = key def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group dates based on the provided key. - Parameters - ---------- - dates : DatesProvider - The dates provider. + Args: + dates (DatesProvider): The dates provider. - Returns - ------- - Iterator[GroupOfDates] - The iterator over the groups of dates. + Returns: + Iterator[GroupOfDates]: The iterator over the groups of dates. """ for _, g in itertools.groupby(sorted(dates, key=self.key), key=self.key): yield GroupOfDates(list(g), dates) @@ -363,27 +309,16 @@ class GrouperByFixedSize(Grouper): """Group dates by a fixed size.""" def __init__(self, size: int) -> None: - """Initialise GrouperByFixedSize with batch size. - - Parameters - ---------- - size : int - Number of dates per group. - """ self.size = size def __call__(self, dates: DatesProvider) -> Iterator[GroupOfDates]: """Group dates into fixed-size batches. - Parameters - ---------- - dates : DatesProvider - The dates provider. + Args: + dates (DatesProvider): The dates provider. - Returns - ------- - Iterator[GroupOfDates] - The iterator over the groups of dates. + Returns: + Iterator[GroupOfDates]: The iterator over the groups of dates. """ batch = [] From de8d99111813bc1446bf45c5eaff31d2feb354fa Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 17:47:51 +0100 Subject: [PATCH 208/212] tidy up --- tests/create/varno.json | 1 - 1 file changed, 1 deletion(-) delete mode 100644 tests/create/varno.json diff --git a/tests/create/varno.json b/tests/create/varno.json deleted file mode 100644 index 7b54097c7..000000000 --- a/tests/create/varno.json +++ /dev/null @@ -1 +0,0 @@ -{"fields": ["name", "code", "description"], "data": [["u", 3, "upper air u component"], ["v", 4, "upper air v component"], ["z", 1, "geopotential"], ["dz", 57, "thickness"], ["rh", 29, "upper air rel. humidity"], ["pwc", 9, "precipitable water content"], ["rh2m", 58, "2m rel. humidity"], ["t", 2, "upper air temperature (K)"], ["td", 59, "upper air dew point (K)"], ["t2m", 39, "2m temperature (K)"], ["td2m", 40, "2m dew point (K)"], ["ts", 11, "surface temperature (K)"], ["ptend", 30, "pressure tendency"], ["w", 60, "past weather (w)"], ["ww", 61, "present weather (ww)"], ["vv", 62, "visibility"], ["ch", 63, "type of high clouds (ch)"], ["cm", 64, "type of middle clouds (cm)"], ["cl", 65, "type of low clouds (cl)"], ["nh", 66, "cloud base height (nh) (meter)"], ["nn", 67, "low cloud amount (n)"], ["hshs", 68, "additional cloud group height (hh)"], ["c", 69, "additional cloud group type (c)"], ["ns", 70, "additional cloud group amount (ns)"], ["sdepth", 71, "snow depth"], ["e", 72, "state of ground (e)"], ["tgtg", 73, "ground temperature (tgtg)"], ["spsp1", 74, "special phenomena (spsp)#1"], ["spsp2", 75, "special phenomena (spsp)#2"], ["rs", 76, "ice code type (rs)"], ["eses", 77, "ice thickness (eses)"], ["is", 78, "ice (is)"], ["trtr", 79, "original time period of rain obs. (trtr)"], ["rr", 80, "6hr rain (liquid part)"], ["jj", 81, "max. 
temperature (jj)"], ["vs", 82, "ship speed (vs)"], ["ds", 83, "ship direction (ds)"], ["hwhw", 84, "wave height"], ["pwpw", 85, "wave period"], ["dwdw", 86, "wave direction"], ["gclg", 87, "general cloud group"], ["rhlc", 88, "rel. humidity from low clouds"], ["rhmc", 89, "rel. humidity from middle clouds"], ["rhhc", 90, "rel. humidity from high clouds"], ["n", 91, "total amount of clouds"], ["sfall", 92, "6hr snowfall (solid part of rain)"], ["ps", 110, "surface pressure"], ["dd", 111, "wind direction"], ["ff", 112, "wind force"], ["rawbt", 119, "brightness temperature (K)"], ["rawra", 120, "raw radiance"], ["satcl", 121, "cloud amount from satellite"], ["scatss", 122, "sigma 0"], ["du", 5, "wind shear (du)"], ["dv", 6, "wind shear (dv)"], ["u10m", 41, "10m u component (m/s)"], ["v10m", 42, "10m v component (m/s)"], ["rhlay", 19, "layer rel. humidity"], ["cllqw", 123, "cloud liquid water"], ["scatdd", 124, "ambiguous v component"], ["scatff", 125, "ambiguous u component"], ["q", 7, "specific humidity (q)"], ["scatwd", 126, "ambiguous wind direction"], ["scatws", 127, "ambiguous wind speed"], ["vsp", 8, "vertical speed"], ["vt", 56, "virtual temperature"], ["o3lay", 206, "layer ozone"], ["height", 156, "height"], ["1dvar", 215, "1d-var model level (pseudo)-variable"], ["w2", 160, "past weather 2 (used in synoptic maps)"], ["cpt", 130, "characteristic of pressure tendency (used in synoptic maps)"], ["tsts", 12, "sea water temperature (used in synoptic maps)"], ["refl", 192, "radar reflectivity"], ["apdss", 128, "atmospheric path delay in satellite signal"], ["bend_angle", 162, "radio occultation bending angle"], ["los", 187, "horizontal line-of-sight wind component"], ["aerod", 174, "aerosol optical depth at 0.55 microns"], ["limb_radiance", 163, "Limb Radiances"], ["chem3", 183, "chem3: co"], ["chem2", 182, "chem2: so2"], ["chem1", 181, "chem1: no2/nox"], ["cod", 175, "cloud optical depth"], ["rao", 176, "Ratio of fine mode to total aerosol optical depth at 0.55 microns"], ["od", 177, "optical depth"], ["rfltnc", 178, "Aerosol reflectance multi-channel"], ["nsoilm", 179, "normalized soil moisture (0-100%)"], ["soilm", 180, "soil moisture"], ["flgt_phase", 201, "phase of aircraft flight"], ["height_assignment_method", 211, "Height assignment method"], ["dopp", 195, "radar doppler wind"], ["ghg1", 186, "ghg1: carbon dioxide"], ["ghg2", 188, "ghg2: methane"], ["ghg3", 189, "ghg3: nitrous oxide"], ["bt_real", 190, "brightness temperature real part"], ["bt_imaginary", 191, "brightness temperature imaginary part"], ["prc", 202, "radar rain rate"], ["lnprc", 203, "log(radar rain rate mm/h + epsilon)"], ["libksc", 222, "lidar backscattering"], ["ralt_swh", 220, "significant wave height (m)"], ["ralt_sws", 221, "surface wind speed (m/s)"], ["rawbt_clear", 193, "brightness temperature for clear (K)"], ["rawbt_cloudy", 194, "brightness temperature for cloudy (K)"], ["binary_snow_cover", 223, "binary snow cover (0: no snow / 1: presence of snow)"], ["salinity", 224, "ocean salinity (PSU)"], ["potential_temp", 225, "potential temperature (Kelvin)"], ["humidity_mixing_ratio", 226, "humidity mixing ratio (kg/kg)"], ["airframe_icing", 227, "airframe icing"], ["turbulence_index", 228, "turbulence index"], ["pstation", 107, "Station pressure (Pa)"], ["pmsl", 108, "Mean sea-level pressure (Pa)"], ["pstandard", 109, "Standard level pressure (Pa)"], ["vert_vv", 218, "Vertical visibility (m)"], ["max_wind_shear1", 219, "Wind shear above and below 1st maximum wind in sonde profile (s-1)"], ["tot_zen_delay", 229, 
"Total zenith delay (GPS)"], ["tot_zen_delay_err", 230, "Total zenith delay error (GPS)"], ["cloud_top_temp", 231, "Cloud top temperature (K)"], ["rawsca", 233, "Scaled radiance"], ["cloud_top_press", 235, "Cloud top pressure (Pa)"], ["mean_freq", 241, "GPSRO mean frequency"], ["u_amb", 242, "Ambiguous u-wind component (m/s)"], ["v_amb", 243, "Ambiguous v-wind component (m/s)"], ["lwp", 244, "Liquid water path"], ["tcwv", 245, "Total column water vapour"], ["cloud_frac_clear", 247, "Cloud clear fraction"], ["rawbt_hirs", 248, "Raw brightness temperature specific to HIRS (K)"], ["rawbt_amsu", 249, "Raw brightness temperature specific to AMSU (K)"], ["rawbt_hirs20", 250, "Raw brightness temperature specific to HIRS (K)"], ["sea_ice", 253, "Sea ice fraction"], ["cloud_frac_covered", 257, "Cloud covered fraction"], ["level_mixing_ratio", 258, "humidity_mixing_ratio]"], ["radial_velocity", 259, "Radial velocity from doppler radar"], ["cloud_ice_water", 260, "Cloud ice water"], ["wind_gust", 261, "Maximum wind gust (m/s)"], ["mass_density", 262, "Mass density"], ["atmosphere_number", 263, "SFERICS number of atmospheres"], ["lightning", 265, "Lightning strike observation (ATDNET)"], ["level_cloud", 266, "Cloud fraction (multi-level)"], ["rawbt_amsr_89ghz", 267, "Raw brightness temperature specific to AMSR 89GHz channels (K)"], ["max_wind_shear2", 268, "Wind shear above and below 2nd maximum wind in sonde profile"], ["lower_layer_p", 269, "Pressure at bottom of layer SBUV (Pa)"], ["upper_layer_p", 270, "Pressure at top of later SBUV (Pa)"], ["cloud_cover", 271, "Total cloud cover"], ["depth", 272, "Depth (m)"], ["ssh", 273, "Sea surface height (m)"], ["rawbt_mwts", 274, "Raw brightness temperature specific to MWTS (K)"], ["rawbt_mwhs", 275, "Raw brightness temperature specific to MWHS (K)"], ["tot_lightning_flash_dens", 196, "total (cloud-to-ground plus intra-cloud) lightning flash density (fl/km2/day)"], ["cg_lightning_flash_dens", 197, "cloud-to-ground lightning flash density ( fl/km2/day)"], ["lidar_aerosol_extinction", 236, "lidar aerosol extinction (1/m)"], ["lidar_cloud_backscatter", 237, "lidar cloud backscatter"], ["lidar_cloud_extinction", 238, "lidar cloud extinction"], ["cloud_radar_reflectivity", 239, "cloud radar reflectivity"], ["lidar_aerosol_attenuated_backscatter", 280, "lidar aerosol attenuated backscatter (1/m*sr)"], ["q2m", 281, "specific humidity at 2m (kg/kg)"], ["chem6", 284, "volcanic SO2"], ["sla", 287, "sea level anomaly"], ["ice_freeboard", 286, "Height of sea ice above open water"], ["snow_freeboard", 285, "Height of snow and sea ice above open water"], ["visible_spectral_reflectance", 240, "Visible Spectral Reflectance"], ["od10", 288, "optical depth at 10 microns"], ["chem4", 184, "chem4: hcho"], ["chem5", 185, "chem5: go3"], ["frac_snow_cover", 282, "fractional snow cover"], ["cloud_doppler_velocity", 251, "vertical radar doppler velocity"], ["lidar_rayleigh_backscatter", 252, "lidar Rayleigh backscatter"], ["sigma0_sm", 283, "backscatter coefficient normalized at 40 degree (db)"], ["t2m_min", 37, "minimum 2m temperature (K)"], ["t2m_max", 38, "maximum 2m temperature (K)"], ["ssrd", 25, "downward surface solar radiation (J/m2)"]]} From cb5991e72985cdc04df9bec766be446a6941cdee Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 18:09:58 +0100 Subject: [PATCH 209/212] update --- .../use/gridded/observations/__init__.py | 313 ------------------ .../observations/legacy_obs_dataset.py | 200 ----------- .../use/gridded/observations/multi.py | 64 ---- 3 
files changed, 577 deletions(-) delete mode 100644 src/anemoi/datasets/use/gridded/observations/__init__.py delete mode 100644 src/anemoi/datasets/use/gridded/observations/legacy_obs_dataset.py delete mode 100644 src/anemoi/datasets/use/gridded/observations/multi.py diff --git a/src/anemoi/datasets/use/gridded/observations/__init__.py b/src/anemoi/datasets/use/gridded/observations/__init__.py deleted file mode 100644 index 804adddad..000000000 --- a/src/anemoi/datasets/use/gridded/observations/__init__.py +++ /dev/null @@ -1,313 +0,0 @@ -# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import logging -import os -from functools import cached_property -from typing import Any - -import numpy as np -from anemoi.utils.dates import frequency_to_timedelta - -from anemoi.datasets.use.gridded.dataset import Dataset -from anemoi.datasets.use.gridded.debug import Node - -LOG = logging.getLogger(__name__) - - -def round_datetime(dt, frequency, up=True): - dt = dt.replace(minute=0, second=0, microsecond=0) - hour = dt.hour - if hour % frequency != 0: - dt = dt.replace(hour=(hour // frequency) * frequency) - dt = dt + datetime.timedelta(hours=frequency) - return dt - - -def make_dates(start, end, frequency): - if isinstance(start, np.datetime64): - start = start.astype(datetime.datetime) - if isinstance(end, np.datetime64): - end = end.astype(datetime.datetime) - - dates = [] - current_date = start - while current_date <= end: - dates.append(current_date) - current_date += frequency - - dates = [np.datetime64(d, "s") for d in dates] - dates = np.array(dates, dtype="datetime64[s]") - return dates - - -class ObservationsBase(Dataset): - resolution = None - - @cached_property - def shape(self): - return (len(self.dates), len(self.variables), "dynamic") - - def empty_item(self): - return np.full(self.shape[1:-1] + (0,), 0.0, dtype=np.float32) - - def metadata(self): - return dict(observations_datasets="obs datasets currenty have no metadata") - - def _check(self): - pass - - def __len__(self): - return len(self.dates) - - def tree(self): - return Node( - self, - [], - ) - - def __getitem__(self, i): - if isinstance(i, int): - return self.getitem(i) - - # The following may would work but is likely to change in the future - # if isinstance(i, slice): - # return [self.getitem(j) for j in range(int(slice.start), int(slice.stop))] - # if isinstance(i, list): - # return [self.getitem(j) for j in i] - - raise ValueError( - f"Expected int, got {i} of type {type(i)}. Only int is supported to index " - "observations datasets. 
Please use a second [] to select part of the data [i][a,b,c]" - ) - - @property - def variables(self): - raise NotImplementedError() - - def collect_input_sources(self): - LOG.warning("collect_input_sources method is not implemented") - return [] - - def constant_fields(self): - LOG.warning("constant_fields method is not implemented") - return [] - - @property - def dates(self): - return self._dates - - @property - def dtype(self): - return np.float32 - - @property - def field_shape(self): - return self.shape[1:] - - @property - def frequency(self): - assert isinstance(self._frequency, datetime.timedelta), f"Expected timedelta, got {type(self._frequency)}" - return self._frequency - - @property - def latitudes(self): - raise NotImplementedError("latitudes property is not implemented") - - @property - def longitudes(self): - raise NotImplementedError("longitudes property is not implemented") - - @property - def missing(self): - return [] - - def statistics_tendencies(self): - raise NotImplementedError("statistics_tendencies method is not implemented") - - def variables_metadata(self): - raise NotImplementedError("variables_metadata method is not implemented") - - -class ObservationsZarr(ObservationsBase): - def __init__(self, dataset, frequency=None, window=None): - import zarr - - if isinstance(dataset, zarr.hierarchy.Group): - dataset = dataset._store.path - - from anemoi.datasets.use.gridded.stores import dataset_lookup - - dataset = dataset_lookup(dataset) - self.path = dataset - assert self._probe_attributes["is_observations"], f"Expected observations dataset, got {self.path}" - - if frequency is None: - frequency = self._probe_attributes.get("frequency") - # LOG.warning(f"Frequency not provided, using the one from the dataset: {frequency}") - if frequency is None: - frequency = "6h" - # LOG.warning(f"Frequency not provided in the dataset, using the default : {frequency}") - self._frequency = frequency_to_timedelta(frequency) - assert self.frequency.total_seconds() % 3600 == 0, f"Expected multiple of 3600, got {self.frequency}" - if self.frequency.total_seconds() != 6 * 3600: - LOG.warning("Frequency is not 6h, this has not been tested, behaviour is unknown") - - frequency_hours = int(self.frequency.total_seconds() // 3600) - assert isinstance(frequency_hours, int), f"Expected int, got {type(frequency_hours)}" - - if window is None: - window = (-frequency_hours, 0) - if window != (-frequency_hours, 0): - raise ValueError("For now, only window = (- frequency, 0) are supported") - - self.window = window - - start, end = self._probe_attributes["start_date"], self._probe_attributes["end_date"] - start, end = datetime.datetime.fromisoformat(start), datetime.datetime.fromisoformat(end) - start, end = round_datetime(start, frequency_hours), round_datetime(end, frequency_hours) - - self._dates = make_dates(start + self.frequency, end, self.frequency) - - first_window_begin = start.strftime("%Y%m%d%H%M%S") - first_window_begin = int(first_window_begin) - # last_window_end must be the end of the time window of the last item - last_window_end = int(end.strftime("%Y%m%d%H%M%S")) - - from anemoi.datasets.use.gridded.observations.legacy_obs_dataset import ObsDataset - - args = [self.path, first_window_begin, last_window_end] - kwargs = dict( - len_hrs=frequency_hours, # length the time windows, i.e. the time span of one item - step_hrs=frequency_hours, # frequency of the dataset, i.e. 
the time shift between two items - ) - self.forward = ObsDataset(*args, **kwargs) - - assert frequency_hours == self.forward.step_hrs, f"Expected {frequency_hours}, got {self.forward.len_hrs}" - assert frequency_hours == self.forward.len_hrs, f"Expected {frequency_hours}, got {self.forward.step_hrs}" - - if len(self.forward) != len(self.dates): - raise ValueError( - f"Dates are not consistent with the number of items in the dataset. " - f"The dataset contains {len(self.forward)} time windows. " - f"This is not compatible with the " - f"{len(self.dates)} requested dates with frequency={frequency_hours}" - f"{self.dates[0]}, {self.dates[1]}, ..., {self.dates[-2]}, {self.dates[-1]} " - ) - - @property - def source(self): - return self.path - - def get_dataset_names(self): - name = os.path.basename(self.path) - if name.endswith(".zarr"): - name = name[:-5] - return [name] - - @cached_property - def _probe_attributes(self): - import zarr - - z = zarr.open(self.path, mode="r") - return dict(z.data.attrs) - - def get_aux(self, i): - data = self.forward[i] - - latitudes = data[:, self.name_to_index["__latitudes"]].numpy() - longitudes = data[:, self.name_to_index["__longitudes"]].numpy() - - reference = self.dates[i] - times = self.forward.get_dates(i) - if str(times.dtype) != "datetime64[s]": - LOG.warning(f"Expected np.datetime64[s], got {times.dtype}. ") - times = times.astype("datetime64[s]") - assert str(reference.dtype) == "datetime64[s]", f"Expected np.datetime64[s], got {type(reference)}" - timedeltas = times - reference - - assert latitudes.shape == longitudes.shape, f"Expected {latitudes.shape}, got {longitudes.shape}" - assert timedeltas.shape == latitudes.shape, f"Expected {timedeltas.shape}, got {latitudes.shape}" - - assert timedeltas.dtype == "timedelta64[s]", f"Expected timedelta64[s], got {timedeltas.dtype}" - return latitudes, longitudes, timedeltas - - def getitem(self, i): - data = self.forward[i] - - data = data.numpy().astype(np.float32) - assert len(data.shape) == 2, f"Expected 2D array, got {data.shape}" - data = data.T - - if not data.size: - data = self.empty_item() - assert ( - data.shape[0] == self.shape[1] - ), f"Data shape {data.shape} does not match {self.shape} : {data.shape[0]} != {self.shape[1]}" - return data - - @cached_property - def variables(self): - colnames = self.forward.colnames - variables = [] - for n in colnames: - if n.startswith("obsvalue_"): - n = n.replace("obsvalue_", "") - if n == "latitude" or n == "lat": - assert "latitudes" not in variables, f"Duplicate latitudes found in {variables}" - variables.append("__latitudes") - continue - if n == "longitude" or n == "lon": - assert "longitudes" not in variables, f"Duplicate longitudes found in {variables}" - variables.append("__longitudes") - continue - assert not n.startswith("__"), f"Invalid name {n} found in {colnames}" - variables.append(n) - return variables - - @property - def name_to_index(self): - return {n: i for i, n in enumerate(self.variables)} - - @property - def statistics(self): - mean = self.forward.properties["means"] - mean = np.array(mean, dtype=np.float32) - - var = self.forward.properties["vars"] - var = np.array(var, dtype=np.float32) - stdev = np.sqrt(var) - - minimum = np.array(self.forward.z.data.attrs["mins"], dtype=np.float32) - maximum = np.array(self.forward.z.data.attrs["maxs"], dtype=np.float32) - - assert isinstance(mean, np.ndarray), f"Expected np.ndarray, got {type(mean)}" - assert isinstance(stdev, np.ndarray), f"Expected np.ndarray, got {type(stdev)}" - assert 
isinstance(minimum, np.ndarray), f"Expected np.ndarray, got {type(minimum)}" - assert isinstance(maximum, np.ndarray), f"Expected np.ndarray, got {type(maximum)}" - return dict(mean=mean, stdev=stdev, minimum=minimum, maximum=maximum) - - def tree(self): - return Node( - self, - [], - path=self.path, - frequency=self.frequency, - ) - - def __repr__(self): - return f"Observations({os.path.basename(self.path)}, {self.dates[0]};{self.dates[-1]}, {len(self)})" - - -def observations_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> ObservationsBase: - observations = kwargs.pop("observations") - - if not isinstance(observations, dict): - observations = dict(dataset=observations) - dataset = ObservationsZarr(**observations) - return dataset._subset(**kwargs) diff --git a/src/anemoi/datasets/use/gridded/observations/legacy_obs_dataset.py b/src/anemoi/datasets/use/gridded/observations/legacy_obs_dataset.py deleted file mode 100644 index 85ab51583..000000000 --- a/src/anemoi/datasets/use/gridded/observations/legacy_obs_dataset.py +++ /dev/null @@ -1,200 +0,0 @@ -# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import logging - -import numpy as np -import pandas as pd -import torch -import zarr -from torch.utils.data import Dataset - -LOG = logging.getLogger(__name__) - - -class ObsDataset(Dataset): - - def __init__( - self, - filename: str, - start: int, - end: int, - len_hrs: int, - step_hrs: int = None, - select: list[str] = None, - drop: list[str] = None, - ) -> None: - - self.filename = filename - self.z = zarr.open(filename, mode="r") - self.data = self.z["data"] - self.dt = self.z["dates"] # datetime only - self.hrly_index = self.z["idx_197001010000_1"] - self.colnames = self.data.attrs["colnames"] - self.selected_colnames = self.colnames - self.selected_cols_idx = np.arange(len(self.colnames)) - self.len_hrs = len_hrs - self.step_hrs = step_hrs if step_hrs else len_hrs - - # Create index for samples - self._setup_sample_index(start, end, self.len_hrs, self.step_hrs) - - self._load_properties() - - if select: - self.select(select) - - if drop: - self.drop(drop) - - def __getitem__( - self, - idx: int, - ) -> torch.tensor: - - start_row = self.indices_start[idx] - end_row = self.indices_end[idx] - - data = self.data.oindex[start_row:end_row, self.selected_cols_idx] - - return torch.from_numpy(data) - - def __len__(self) -> int: - - return len(self.indices_start) - - def get_dates( - self, - idx: int, - ) -> np.ndarray: - - start_row = self.indices_start[idx] - end_row = self.indices_end[idx] - dates = self.dt.oindex[start_row:end_row] - - assert len(dates.shape) == 2, dates.shape - dates = dates[:, 0] - - if len(dates) and dates[0].dtype != np.dtype("datetime64[s]"): - dates = dates.astype("datetime64[s]") - - return dates - - def get_df(self, idx: int) -> pd.DataFrame: - """Convenience function to return data for sample idx packaged in a pandas DataFrame""" - - d = self.__getitem__(idx) - - df = pd.DataFrame(data=d, columns=[self.colnames[i] for i in self.selected_cols_idx]) - - start_row = self.indices_start[idx] - end_row = self.indices_end[idx] - - dts = self.dt[start_row:end_row, :] - 
df["datetime"] = dts - - return df - - def select(self, cols_list: list[str]) -> None: - """Allow user to specify which columns they want to access. - Get functions only returned for these specified columns. - """ - self.selected_colnames = cols_list - self.selected_cols_idx = np.array([self.colnames.index(item) for item in cols_list]) - - def drop(self, cols_to_drop: list[str]) -> None: - """Allow user to drop specific columns from the dataset. - Get functions no longer return data for these columns after being set. - """ - mask = [name not in cols_to_drop for name in self.selected_colnames] - - self.selected_colnames = [name for name, keep in zip(self.selected_colnames, mask) if keep] - self.selected_cols_idx = self.selected_cols_idx[mask] - - def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]: - """Returns a tuple of datetime objects describing the start and end times of the sample at position idx.""" - - if idx < 0: - idx = len(self) + idx - - time_start = self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs), seconds=1) - time_end = min( - self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs + self.len_hrs)), - self.end_dt, - ) - - return (np.datetime64(time_start), np.datetime64(time_end)) - - def first_sample_with_data(self) -> int: - """Returns the position of the first sample which contains data.""" - return int(np.nonzero(self.indices_end)[0][0]) if self.indices_end.max() > 0 else None - - def last_sample_with_data(self) -> int: - """Returns the position of the last sample which contains data.""" - if self.indices_end.max() == 0: - last_sample = None - else: - last_sample = int(np.where(np.diff(np.append(self.indices_end, self.indices_end[-1])) > 0)[0][-1] + 1) - - return last_sample - - def _setup_sample_index(self, start: int, end: int, len_hrs: int, step_hrs: int) -> None: - """Dataset is divided into samples; - - each n_hours long - - sample 0 starts at start (yyyymmddhhmm) - - index array has one entry for each sample; contains the index of the first row - containing data for that sample - """ - - try: - from obsdata.config import config - - assert config.base_index_yyyymmddhhmm == 197001010000, "base_index_yyyymmddhhmm must be 197001010000" - except ImportError: - pass - base_yyyymmddhhmm = 197001010000 - - assert start > base_yyyymmddhhmm, ( - f"Abort: ObsDataset sample start (yyyymmddhhmm) must be greater than {base_yyyymmddhhmm}\n" - f" Current value: {start}" - ) - - format_str = "%Y%m%d%H%M%S" - base_dt = datetime.datetime.strptime(str(base_yyyymmddhhmm), format_str) - self.start_dt = datetime.datetime.strptime(str(start), format_str) - self.end_dt = datetime.datetime.strptime(str(end), format_str) - - # Calculate hours since the base date for the requested dataset ranges - diff_in_hours_start = int((self.start_dt - base_dt).total_seconds() // 3600) - diff_in_hours_end = int((self.end_dt - base_dt).total_seconds() // 3600) - - # Find elements that need to be extracted from the hourly index - # + ensuring that the dataset respects the requested end-hour even if it is mid-way through a sample - sample_starts = np.arange(diff_in_hours_start, diff_in_hours_end, step_hrs) - sample_ends = np.minimum(sample_starts + len_hrs, diff_in_hours_end) - - # Initialize local index arrays - self.indices_start = np.zeros(sample_starts.shape, dtype=int) - self.indices_end = np.zeros(self.indices_start.shape, dtype=int) - - max_hrly_index = self.hrly_index.shape[0] - 1 - valid_start_mask = sample_starts <= max_hrly_index - valid_end_mask = 
(sample_ends > 0) & (sample_ends <= max_hrly_index) - - # Copy elements from the hrly_index into the local index - self.indices_start[valid_start_mask] = self.hrly_index[sample_starts[valid_start_mask]] - self.indices_end[valid_end_mask] = np.maximum(self.hrly_index[sample_ends[valid_end_mask]], 0) - - def _load_properties(self) -> None: - - self.properties = {} - - self.properties["means"] = self.data.attrs["means"] - self.properties["vars"] = self.data.attrs["vars"] - self.properties["data_idxs"] = self.data.attrs["data_idxs"] - self.properties["obs_id"] = self.data.attrs["obs_id"] diff --git a/src/anemoi/datasets/use/gridded/observations/multi.py b/src/anemoi/datasets/use/gridded/observations/multi.py deleted file mode 100644 index 5b2ca4967..000000000 --- a/src/anemoi/datasets/use/gridded/observations/multi.py +++ /dev/null @@ -1,64 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import os - -from anemoi.datasets import open_dataset - -LOG = logging.getLogger(__name__) - - -class LegacyDatasets: - def __init__(self, paths, start=None, end=None, **kwargs): - self.paths = paths - - if not start or not end: - print( - "❌❌ Warning: start and end not provided, using the minima first and maximal last dates of the datasets" - ) - lst = [self._open_dataset(p, **kwargs) for p in paths] - start = min([d.dates[0] for d in lst]) - end = max([d.dates[-1] for d in lst]) - - self._datasets = { - os.path.basename(p).split(".")[0]: self._open_dataset(p, start=start, end=end, padding="empty") - for p in paths - } - - first = list(self._datasets.values())[0] - for name, dataset in self._datasets.items(): - if dataset.dates[0] != first.dates[0] or dataset.dates[-1] != first.dates[-1]: - raise ValueError("Datasets have different start and end times") - if dataset.frequency != first.frequency: - raise ValueError("Datasets have different frequencies") - - self._keys = self._datasets.keys - - self._first = list(self._datasets.values())[0] - - def _open_dataset(self, p, **kwargs): - if p.startswith("observations-"): - return open_dataset(observations=p, **kwargs) - else: - print("❗ Opening non-observations dataset:", p) - return open_dataset(p, **kwargs) - - def items(self): - return self._datasets.items() - - @property - def dates(self): - return self._first.dates - - def __len__(self): - return len(self._first) - - def __getitem__(self, i): - return {k: d[i] for k, d in self._datasets.items()} From 543645342e124d14d5a3bbd6c9962552cc659091 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 18:18:45 +0100 Subject: [PATCH 210/212] update --- src/anemoi/datasets/dates/groups.py | 8 +- src/anemoi/datasets/use/tabular/__init__.py | 8 - .../use/tabular/observations/__init__.py | 313 ------ .../observations/legacy_obs_dataset.py | 200 ---- .../use/tabular/observations/multi.py | 64 -- .../datasets/use/tabular/records/__init__.py | 936 ------------------ .../use/tabular/records/backends/__init__.py | 273 ----- src/anemoi/datasets/use/tabular/windows.py | 252 ----- 8 files changed, 4 insertions(+), 2050 deletions(-) delete mode 100644 src/anemoi/datasets/use/tabular/__init__.py delete mode 100644 
src/anemoi/datasets/use/tabular/observations/__init__.py delete mode 100644 src/anemoi/datasets/use/tabular/observations/legacy_obs_dataset.py delete mode 100644 src/anemoi/datasets/use/tabular/observations/multi.py delete mode 100644 src/anemoi/datasets/use/tabular/records/__init__.py delete mode 100644 src/anemoi/datasets/use/tabular/records/backends/__init__.py delete mode 100644 src/anemoi/datasets/use/tabular/windows.py diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index a72fdaa74..ceea33981 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -24,11 +24,11 @@ def _shorten(dates: list[datetime.datetime] | tuple[datetime.datetime, ...]) -> str | list[str]: """Shorten the list of dates for display. - Args: - dates (Union[List[datetime.datetime], Tuple[datetime.datetime, ...]]): The list of dates. + backen Args: + dates (Union[List[datetime.datetime], Tuple[datetime.datetime, ...]]): The list of dates. - Returns: - Union[str, List[str]]: The shortened list of dates. + Returns: + Union[str, List[str]]: The shortened list of dates. """ if isinstance(dates, (list, tuple)): dates = [d.isoformat() for d in dates] diff --git a/src/anemoi/datasets/use/tabular/__init__.py b/src/anemoi/datasets/use/tabular/__init__.py deleted file mode 100644 index 9fc775e54..000000000 --- a/src/anemoi/datasets/use/tabular/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. diff --git a/src/anemoi/datasets/use/tabular/observations/__init__.py b/src/anemoi/datasets/use/tabular/observations/__init__.py deleted file mode 100644 index b231c2c66..000000000 --- a/src/anemoi/datasets/use/tabular/observations/__init__.py +++ /dev/null @@ -1,313 +0,0 @@ -# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
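# --- Usage sketch for the observations wrapper deleted below ----------------
# The wrapper is normally reached through open_dataset(observations=...), as
# done in multi.py later in this patch. A minimal sketch, assuming a local
# observations Zarr at a hypothetical path and no further subsetting:
#
#     from anemoi.datasets import open_dataset
#
#     obs = open_dataset(observations="observations-test.zarr")
#     # frequency/window can be forced with a dict instead of a plain path:
#     # obs = open_dataset(observations={"dataset": "observations-test.zarr",
#     #                                  "frequency": "6h", "window": (-6, 0)})
#     print(len(obs), obs.variables)        # one entry per time window
#     sample = obs[0]                       # float32 array, one row per variable,
#                                           # one column per observation in the window
#     lats, lons, deltas = obs.get_aux(0)   # per-point coordinates and time offsets
# -----------------------------------------------------------------------------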
- -import datetime -import logging -import os -from functools import cached_property -from typing import Any - -import numpy as np -from anemoi.utils.dates import frequency_to_timedelta - -from anemoi.datasets.use.gridded.dataset import Dataset -from anemoi.datasets.use.gridded.debug import Node - -LOG = logging.getLogger(__name__) - - -def round_datetime(dt, frequency, up=True): - dt = dt.replace(minute=0, second=0, microsecond=0) - hour = dt.hour - if hour % frequency != 0: - dt = dt.replace(hour=(hour // frequency) * frequency) - dt = dt + datetime.timedelta(hours=frequency) - return dt - - -def make_dates(start, end, frequency): - if isinstance(start, np.datetime64): - start = start.astype(datetime.datetime) - if isinstance(end, np.datetime64): - end = end.astype(datetime.datetime) - - dates = [] - current_date = start - while current_date <= end: - dates.append(current_date) - current_date += frequency - - dates = [np.datetime64(d, "s") for d in dates] - dates = np.array(dates, dtype="datetime64[s]") - return dates - - -class ObservationsBase(Dataset): - resolution = None - - @cached_property - def shape(self): - return (len(self.dates), len(self.variables), "dynamic") - - def empty_item(self): - return np.full(self.shape[1:-1] + (0,), 0.0, dtype=np.float32) - - def metadata(self): - return dict(observations_datasets="obs datasets currenty have no metadata") - - def _check(self): - pass - - def __len__(self): - return len(self.dates) - - def tree(self): - return Node( - self, - [], - ) - - def __getitem__(self, i): - if isinstance(i, int): - return self.getitem(i) - - # The following may would work but is likely to change in the future - # if isinstance(i, slice): - # return [self.getitem(j) for j in range(int(slice.start), int(slice.stop))] - # if isinstance(i, list): - # return [self.getitem(j) for j in i] - - raise ValueError( - f"Expected int, got {i} of type {type(i)}. Only int is supported to index " - "observations datasets. 
Please use a second [] to select part of the data [i][a,b,c]" - ) - - @property - def variables(self): - raise NotImplementedError() - - def collect_input_sources(self): - LOG.warning("collect_input_sources method is not implemented") - return [] - - def constant_fields(self): - LOG.warning("constant_fields method is not implemented") - return [] - - @property - def dates(self): - return self._dates - - @property - def dtype(self): - return np.float32 - - @property - def field_shape(self): - return self.shape[1:] - - @property - def frequency(self): - assert isinstance(self._frequency, datetime.timedelta), f"Expected timedelta, got {type(self._frequency)}" - return self._frequency - - @property - def latitudes(self): - raise NotImplementedError("latitudes property is not implemented") - - @property - def longitudes(self): - raise NotImplementedError("longitudes property is not implemented") - - @property - def missing(self): - return [] - - def statistics_tendencies(self): - raise NotImplementedError("statistics_tendencies method is not implemented") - - def variables_metadata(self): - raise NotImplementedError("variables_metadata method is not implemented") - - -class ObservationsZarr(ObservationsBase): - def __init__(self, dataset, frequency=None, window=None): - import zarr - - if isinstance(dataset, zarr.hierarchy.Group): - dataset = dataset._store.path - - from anemoi.datasets.use.gridded.stores import dataset_lookup - - dataset = dataset_lookup(dataset) - self.path = dataset - assert self._probe_attributes["is_observations"], f"Expected observations dataset, got {self.path}" - - if frequency is None: - frequency = self._probe_attributes.get("frequency") - # LOG.warning(f"Frequency not provided, using the one from the dataset: {frequency}") - if frequency is None: - frequency = "6h" - # LOG.warning(f"Frequency not provided in the dataset, using the default : {frequency}") - self._frequency = frequency_to_timedelta(frequency) - assert self.frequency.total_seconds() % 3600 == 0, f"Expected multiple of 3600, got {self.frequency}" - if self.frequency.total_seconds() != 6 * 3600: - LOG.warning("Frequency is not 6h, this has not been tested, behaviour is unknown") - - frequency_hours = int(self.frequency.total_seconds() // 3600) - assert isinstance(frequency_hours, int), f"Expected int, got {type(frequency_hours)}" - - if window is None: - window = (-frequency_hours, 0) - if window != (-frequency_hours, 0): - raise ValueError("For now, only window = (- frequency, 0) are supported") - - self.window = window - - start, end = self._probe_attributes["start_date"], self._probe_attributes["end_date"] - start, end = datetime.datetime.fromisoformat(start), datetime.datetime.fromisoformat(end) - start, end = round_datetime(start, frequency_hours), round_datetime(end, frequency_hours) - - self._dates = make_dates(start + self.frequency, end, self.frequency) - - first_window_begin = start.strftime("%Y%m%d%H%M%S") - first_window_begin = int(first_window_begin) - # last_window_end must be the end of the time window of the last item - last_window_end = int(end.strftime("%Y%m%d%H%M%S")) - - from anemoi.datasets.use.gridded.tabular.observations.legacy_obs_dataset import ObsDataset - - args = [self.path, first_window_begin, last_window_end] - kwargs = dict( - len_hrs=frequency_hours, # length the time windows, i.e. the time span of one item - step_hrs=frequency_hours, # frequency of the dataset, i.e. 
the time shift between two items - ) - self.forward = ObsDataset(*args, **kwargs) - - assert frequency_hours == self.forward.step_hrs, f"Expected {frequency_hours}, got {self.forward.len_hrs}" - assert frequency_hours == self.forward.len_hrs, f"Expected {frequency_hours}, got {self.forward.step_hrs}" - - if len(self.forward) != len(self.dates): - raise ValueError( - f"Dates are not consistent with the number of items in the dataset. " - f"The dataset contains {len(self.forward)} time windows. " - f"This is not compatible with the " - f"{len(self.dates)} requested dates with frequency={frequency_hours}" - f"{self.dates[0]}, {self.dates[1]}, ..., {self.dates[-2]}, {self.dates[-1]} " - ) - - @property - def source(self): - return self.path - - def get_dataset_names(self): - name = os.path.basename(self.path) - if name.endswith(".zarr"): - name = name[:-5] - return [name] - - @cached_property - def _probe_attributes(self): - import zarr - - z = zarr.open(self.path, mode="r") - return dict(z.data.attrs) - - def get_aux(self, i): - data = self.forward[i] - - latitudes = data[:, self.name_to_index["__latitudes"]].numpy() - longitudes = data[:, self.name_to_index["__longitudes"]].numpy() - - reference = self.dates[i] - times = self.forward.get_dates(i) - if str(times.dtype) != "datetime64[s]": - LOG.warning(f"Expected np.datetime64[s], got {times.dtype}. ") - times = times.astype("datetime64[s]") - assert str(reference.dtype) == "datetime64[s]", f"Expected np.datetime64[s], got {type(reference)}" - timedeltas = times - reference - - assert latitudes.shape == longitudes.shape, f"Expected {latitudes.shape}, got {longitudes.shape}" - assert timedeltas.shape == latitudes.shape, f"Expected {timedeltas.shape}, got {latitudes.shape}" - - assert timedeltas.dtype == "timedelta64[s]", f"Expected timedelta64[s], got {timedeltas.dtype}" - return latitudes, longitudes, timedeltas - - def getitem(self, i): - data = self.forward[i] - - data = data.numpy().astype(np.float32) - assert len(data.shape) == 2, f"Expected 2D array, got {data.shape}" - data = data.T - - if not data.size: - data = self.empty_item() - assert ( - data.shape[0] == self.shape[1] - ), f"Data shape {data.shape} does not match {self.shape} : {data.shape[0]} != {self.shape[1]}" - return data - - @cached_property - def variables(self): - colnames = self.forward.colnames - variables = [] - for n in colnames: - if n.startswith("obsvalue_"): - n = n.replace("obsvalue_", "") - if n == "latitude" or n == "lat": - assert "latitudes" not in variables, f"Duplicate latitudes found in {variables}" - variables.append("__latitudes") - continue - if n == "longitude" or n == "lon": - assert "longitudes" not in variables, f"Duplicate longitudes found in {variables}" - variables.append("__longitudes") - continue - assert not n.startswith("__"), f"Invalid name {n} found in {colnames}" - variables.append(n) - return variables - - @property - def name_to_index(self): - return {n: i for i, n in enumerate(self.variables)} - - @property - def statistics(self): - mean = self.forward.properties["means"] - mean = np.array(mean, dtype=np.float32) - - var = self.forward.properties["vars"] - var = np.array(var, dtype=np.float32) - stdev = np.sqrt(var) - - minimum = np.array(self.forward.z.data.attrs["mins"], dtype=np.float32) - maximum = np.array(self.forward.z.data.attrs["maxs"], dtype=np.float32) - - assert isinstance(mean, np.ndarray), f"Expected np.ndarray, got {type(mean)}" - assert isinstance(stdev, np.ndarray), f"Expected np.ndarray, got {type(stdev)}" - assert 
isinstance(minimum, np.ndarray), f"Expected np.ndarray, got {type(minimum)}" - assert isinstance(maximum, np.ndarray), f"Expected np.ndarray, got {type(maximum)}" - return dict(mean=mean, stdev=stdev, minimum=minimum, maximum=maximum) - - def tree(self): - return Node( - self, - [], - path=self.path, - frequency=self.frequency, - ) - - def __repr__(self): - return f"Observations({os.path.basename(self.path)}, {self.dates[0]};{self.dates[-1]}, {len(self)})" - - -def observations_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> ObservationsBase: - observations = kwargs.pop("observations") - - if not isinstance(observations, dict): - observations = dict(dataset=observations) - dataset = ObservationsZarr(**observations) - return dataset._subset(**kwargs) diff --git a/src/anemoi/datasets/use/tabular/observations/legacy_obs_dataset.py b/src/anemoi/datasets/use/tabular/observations/legacy_obs_dataset.py deleted file mode 100644 index 85ab51583..000000000 --- a/src/anemoi/datasets/use/tabular/observations/legacy_obs_dataset.py +++ /dev/null @@ -1,200 +0,0 @@ -# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime -import logging - -import numpy as np -import pandas as pd -import torch -import zarr -from torch.utils.data import Dataset - -LOG = logging.getLogger(__name__) - - -class ObsDataset(Dataset): - - def __init__( - self, - filename: str, - start: int, - end: int, - len_hrs: int, - step_hrs: int = None, - select: list[str] = None, - drop: list[str] = None, - ) -> None: - - self.filename = filename - self.z = zarr.open(filename, mode="r") - self.data = self.z["data"] - self.dt = self.z["dates"] # datetime only - self.hrly_index = self.z["idx_197001010000_1"] - self.colnames = self.data.attrs["colnames"] - self.selected_colnames = self.colnames - self.selected_cols_idx = np.arange(len(self.colnames)) - self.len_hrs = len_hrs - self.step_hrs = step_hrs if step_hrs else len_hrs - - # Create index for samples - self._setup_sample_index(start, end, self.len_hrs, self.step_hrs) - - self._load_properties() - - if select: - self.select(select) - - if drop: - self.drop(drop) - - def __getitem__( - self, - idx: int, - ) -> torch.tensor: - - start_row = self.indices_start[idx] - end_row = self.indices_end[idx] - - data = self.data.oindex[start_row:end_row, self.selected_cols_idx] - - return torch.from_numpy(data) - - def __len__(self) -> int: - - return len(self.indices_start) - - def get_dates( - self, - idx: int, - ) -> np.ndarray: - - start_row = self.indices_start[idx] - end_row = self.indices_end[idx] - dates = self.dt.oindex[start_row:end_row] - - assert len(dates.shape) == 2, dates.shape - dates = dates[:, 0] - - if len(dates) and dates[0].dtype != np.dtype("datetime64[s]"): - dates = dates.astype("datetime64[s]") - - return dates - - def get_df(self, idx: int) -> pd.DataFrame: - """Convenience function to return data for sample idx packaged in a pandas DataFrame""" - - d = self.__getitem__(idx) - - df = pd.DataFrame(data=d, columns=[self.colnames[i] for i in self.selected_cols_idx]) - - start_row = self.indices_start[idx] - end_row = self.indices_end[idx] - - dts = self.dt[start_row:end_row, :] - 
df["datetime"] = dts - - return df - - def select(self, cols_list: list[str]) -> None: - """Allow user to specify which columns they want to access. - Get functions only returned for these specified columns. - """ - self.selected_colnames = cols_list - self.selected_cols_idx = np.array([self.colnames.index(item) for item in cols_list]) - - def drop(self, cols_to_drop: list[str]) -> None: - """Allow user to drop specific columns from the dataset. - Get functions no longer return data for these columns after being set. - """ - mask = [name not in cols_to_drop for name in self.selected_colnames] - - self.selected_colnames = [name for name, keep in zip(self.selected_colnames, mask) if keep] - self.selected_cols_idx = self.selected_cols_idx[mask] - - def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]: - """Returns a tuple of datetime objects describing the start and end times of the sample at position idx.""" - - if idx < 0: - idx = len(self) + idx - - time_start = self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs), seconds=1) - time_end = min( - self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs + self.len_hrs)), - self.end_dt, - ) - - return (np.datetime64(time_start), np.datetime64(time_end)) - - def first_sample_with_data(self) -> int: - """Returns the position of the first sample which contains data.""" - return int(np.nonzero(self.indices_end)[0][0]) if self.indices_end.max() > 0 else None - - def last_sample_with_data(self) -> int: - """Returns the position of the last sample which contains data.""" - if self.indices_end.max() == 0: - last_sample = None - else: - last_sample = int(np.where(np.diff(np.append(self.indices_end, self.indices_end[-1])) > 0)[0][-1] + 1) - - return last_sample - - def _setup_sample_index(self, start: int, end: int, len_hrs: int, step_hrs: int) -> None: - """Dataset is divided into samples; - - each n_hours long - - sample 0 starts at start (yyyymmddhhmm) - - index array has one entry for each sample; contains the index of the first row - containing data for that sample - """ - - try: - from obsdata.config import config - - assert config.base_index_yyyymmddhhmm == 197001010000, "base_index_yyyymmddhhmm must be 197001010000" - except ImportError: - pass - base_yyyymmddhhmm = 197001010000 - - assert start > base_yyyymmddhhmm, ( - f"Abort: ObsDataset sample start (yyyymmddhhmm) must be greater than {base_yyyymmddhhmm}\n" - f" Current value: {start}" - ) - - format_str = "%Y%m%d%H%M%S" - base_dt = datetime.datetime.strptime(str(base_yyyymmddhhmm), format_str) - self.start_dt = datetime.datetime.strptime(str(start), format_str) - self.end_dt = datetime.datetime.strptime(str(end), format_str) - - # Calculate hours since the base date for the requested dataset ranges - diff_in_hours_start = int((self.start_dt - base_dt).total_seconds() // 3600) - diff_in_hours_end = int((self.end_dt - base_dt).total_seconds() // 3600) - - # Find elements that need to be extracted from the hourly index - # + ensuring that the dataset respects the requested end-hour even if it is mid-way through a sample - sample_starts = np.arange(diff_in_hours_start, diff_in_hours_end, step_hrs) - sample_ends = np.minimum(sample_starts + len_hrs, diff_in_hours_end) - - # Initialize local index arrays - self.indices_start = np.zeros(sample_starts.shape, dtype=int) - self.indices_end = np.zeros(self.indices_start.shape, dtype=int) - - max_hrly_index = self.hrly_index.shape[0] - 1 - valid_start_mask = sample_starts <= max_hrly_index - valid_end_mask = 
(sample_ends > 0) & (sample_ends <= max_hrly_index) - - # Copy elements from the hrly_index into the local index - self.indices_start[valid_start_mask] = self.hrly_index[sample_starts[valid_start_mask]] - self.indices_end[valid_end_mask] = np.maximum(self.hrly_index[sample_ends[valid_end_mask]], 0) - - def _load_properties(self) -> None: - - self.properties = {} - - self.properties["means"] = self.data.attrs["means"] - self.properties["vars"] = self.data.attrs["vars"] - self.properties["data_idxs"] = self.data.attrs["data_idxs"] - self.properties["obs_id"] = self.data.attrs["obs_id"] diff --git a/src/anemoi/datasets/use/tabular/observations/multi.py b/src/anemoi/datasets/use/tabular/observations/multi.py deleted file mode 100644 index 31fc4e1dd..000000000 --- a/src/anemoi/datasets/use/tabular/observations/multi.py +++ /dev/null @@ -1,64 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import logging -import os - -from anemoi.datasets.use.gridded import open_dataset - -LOG = logging.getLogger(__name__) - - -class LegacyDatasets: - def __init__(self, paths, start=None, end=None, **kwargs): - self.paths = paths - - if not start or not end: - print( - "❌❌ Warning: start and end not provided, using the minima first and maximal last dates of the datasets" - ) - lst = [self._open_dataset(p, **kwargs) for p in paths] - start = min([d.dates[0] for d in lst]) - end = max([d.dates[-1] for d in lst]) - - self._datasets = { - os.path.basename(p).split(".")[0]: self._open_dataset(p, start=start, end=end, padding="empty") - for p in paths - } - - first = list(self._datasets.values())[0] - for name, dataset in self._datasets.items(): - if dataset.dates[0] != first.dates[0] or dataset.dates[-1] != first.dates[-1]: - raise ValueError("Datasets have different start and end times") - if dataset.frequency != first.frequency: - raise ValueError("Datasets have different frequencies") - - self._keys = self._datasets.keys - - self._first = list(self._datasets.values())[0] - - def _open_dataset(self, p, **kwargs): - if p.startswith("observations-"): - return open_dataset(observations=p, **kwargs) - else: - print("❗ Opening non-observations dataset:", p) - return open_dataset(p, **kwargs) - - def items(self): - return self._datasets.items() - - @property - def dates(self): - return self._first.dates - - def __len__(self): - return len(self._first) - - def __getitem__(self, i): - return {k: d[i] for k, d in self._datasets.items()} diff --git a/src/anemoi/datasets/use/tabular/records/__init__.py b/src/anemoi/datasets/use/tabular/records/__init__.py deleted file mode 100644 index 11c2b2565..000000000 --- a/src/anemoi/datasets/use/tabular/records/__init__.py +++ /dev/null @@ -1,936 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
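# --- Sketch of the records-dataset conventions implemented below ------------
# Group names, sizes and the path in this sketch are made up for illustration.
#
# A records dataset is indexed either by time step (int) or by group (str),
# and variables are addressed as "group.name" (wildcards "group.*", "*.name"
# and "*" are accepted by select):
#
#     ds = open_records_dataset("/path/to/some-records-dataset")
#     rec = ds[0]         # Record: all groups for the first window
#     tab = ds["metar"]   # Tabular: one group across all windows
#     # e.g. select=["metar.t2m", "amsua.*"] keeps only the matching variables
#
# Each window i is backed by a dict of per-group arrays, following the
# _load_data contract documented below, e.g. for a hypothetical group "metar"
# with 17 observations of 4 variables inside the window:
#
#     import numpy as np
#     n = 17
#     window = {
#         "data:metar": np.zeros((4, n), dtype=np.float32),
#         "latitudes:metar": np.zeros(n, dtype=np.float32),
#         "longitudes:metar": np.zeros(n, dtype=np.float32),
#         "timedeltas:metar": np.zeros(n, dtype="timedelta64[s]"),
#         "metadata:metar": {},
#     }
# -----------------------------------------------------------------------------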
- -import datetime -import logging -import os -from collections import defaultdict -from collections.abc import Mapping -from functools import cached_property - -import numpy as np -from anemoi.utils.config import load_any_dict_format -from anemoi.utils.dates import frequency_to_timedelta - -from anemoi.datasets.use.gridded.debug import Node - -from ..windows import window_from_str -from .backends import backend_factory - -LOG = logging.getLogger(__name__) - -if os.environ.get("ANEMOI_DATASET_COUNTER", "0") == "1": - - def counter(func): - def wrapper(*args, **kwargs): - count = 0 - for i in range(len(args[0])): - count += 1 - yield func(*args, **kwargs) - print(f"Counter: {count} calls to {func.__name__}") - - return wrapper - -else: - - def counter(func): - return func - - -def _to_numpy_timedelta(td): - if isinstance(td, np.timedelta64): - assert td.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {td.dtype}" - return td - return np.timedelta64(int(td.total_seconds()), "s") - - -def open_records_dataset(dataset, **kwargs): - metadata_path = os.path.join(dataset, "metadata.json") - if not os.path.exists(metadata_path): - return None - metadata = load_any_dict_format(metadata_path) - kwargs["backend"] = kwargs.get("backend", metadata["backend"]) - return RecordsDataset(dataset, **kwargs) - - -def merge_data(list_of_dicts): - merged = defaultdict(list) - for d in list_of_dicts: - for key, value in d.items(): - merged[key].append(value) - return {k: np.hstack(v) for k, v in merged.items()} - - -def _to_numpy_date(d): - if isinstance(d, np.datetime64): - assert d.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {d.dtype}" - return d - assert isinstance(d, datetime.datetime), f"date must be a datetime.datetime, got {type(d)}" - return _to_numpy_dates([d])[0] - - -def _to_numpy_dates(d): - return np.array(d, dtype="datetime64[s]") - - -class BaseRecordsDataset: - """This is the base class for all datasets based on records. - Records datasets are datasets that can be indexed by time (int) or by group (str). - A record dataset is designed for observations, where multiple array of difference shapes need to be stored for each date. - They have the same concept or start_date, end_date, frequency as fields datasets, but each date correspond to a window. - All windows have the same size (the window span can be different from the dataset frequency) - - variables in a record datasets are identified by a group and a name. - """ - - # Depending on the context, a variable is identified by "group.name", - # or using a dict with keys as groups and values as list of names. - # most of the code should be agnostic and transform one format to the other when needed. - - def __getitem__(self, i: int | str): - if isinstance(i, str): - return self._getgroup(i) - - if isinstance(i, (int, np.integer)): - return self._getrecord(i) - - raise ValueError(f"Invalid index {i}, must be int or str") - - @cached_property - def window(self): - """Returns a string representation of the relative window of the dataset, such as '(-3h, 3h]'.""" - return str(self._window) - - def _getgroup(self, group: str): - """Returns a Tabular object for the group. As a partial function when argument group is given but i is not.""" - return Tabular(self, group) - - def _getrecord(self, i: int): - """Returns a Record object for the time step i. 
As a partial function when argument i is given but group is not.""" - return Record(self, i) - - def _load_data(self, i: int) -> dict: - """Load the data for a specific time step or window (i). - It is expected to return a dict containing keys of the form: - - - "data:group1" : numpy array - - "latitudes:group1" : numpy array - - "longitudes:group1" : numpy array - - "metadata:group1" : - - ... - - "data:group2" : numpy array - - "latitudes:group2" : numpy array - - ... - """ - raise NotImplementedError("Must be implemented in subclass") - - @property - def start_date(self): - return self.dates[0] - - @property - def end_date(self): - if len(self.dates) == 0: - return None - if len(self.dates) == 1: - return self.dates[0] - return self.dates[-1] - - @property - def groups(self): - raise NotImplementedError("Must be implemented in subclass") - - def _subset(self, **kwargs): - window = kwargs.pop("window", None) - if window is not None: - return Rewindowed(self, window)._subset(**kwargs) - - start = kwargs.pop("start", None) - end = kwargs.pop("end", None) - if start is not None or end is not None: - - def _dates_to_indices(start, end): - from anemoi.datasets.use.gridded.misc import as_first_date - from anemoi.datasets.use.gridded.misc import as_last_date - - start = self.dates[0] if start is None else as_first_date(start, self.dates) - end = self.dates[-1] if end is None else as_last_date(end, self.dates) - - return [i for i, date in enumerate(self.dates) if start <= date <= end] - - return RecordsSubset(self, _dates_to_indices(start, end), {"start": start, "end": end})._subset(**kwargs) - - frequency = kwargs.pop("frequency", self.frequency) - if frequency: - frequency = frequency_to_timedelta(frequency) - current = self.frequency.total_seconds() - new = frequency.total_seconds() - if current != new and current % new == 0: - return IncreaseFrequency(self, frequency)._subset(**kwargs) - elif current != new and new % current == 0: - raise NotImplementedError("Decreasing frequency not implemented yet") - # return DecreaseFrequency(self, frequency)._subset(**kwargs) - assert self.frequency == frequency, (self.frequency, frequency) - - select = kwargs.pop("select", None) - if select is not None: - return Select(self, select)._subset(**kwargs) - - set_group = kwargs.pop("set_group", None) - if set_group is not None: - return SetGroup(self, set_group)._subset(**kwargs) - - rename = kwargs.pop("rename", None) - if rename is not None: - return Rename(self, rename)._subset(**kwargs) - - for k in kwargs: - if k in ["backend"]: - continue - raise ValueError(f"Invalid kwargs {kwargs}, must be 'start', 'end', 'frequency' or 'select'") - - return self - - def mutate(self): - return self - - def _check(self): - pass - - @property - def name_to_index(self): - raise NotImplementedError("Must be implemented in subclass") - - -class RecordsForward(BaseRecordsDataset): - def __init__(self, dataset): - self.forward = dataset - - @property - def statistics(self): - return self.forward.statistics - - @property - def variables(self): - return self.forward.variables - - @property - def groups(self): - return self.forward.groups - - @property - def dates(self): - return np.array(self.forward.dates, dtype="datetime64[s]") - - @property - def name_to_index(self): - return self.forward.name_to_index - - @property - def frequency(self): - return self.forward.frequency - - @property - def metadata(self): - return self.forward.metadata - - @property - def _window(self): - return self.forward._window - - @property - def 
shapes(self): - return self.forward.shapes - - def __len__(self): - return len(self.dates) - - def tree(self): - return Node(self, [self.forward.tree()], **self.reason) - - -class IncreaseFrequency(RecordsForward): - # change the frequency of a records dataset by splitting the windows to fit the new frequency - # the new frequency must be a divisor of the original frequency (e.g. 6h -> 3h, but not 3h -> 6h) (and not 6h -> 5h) - # and the window length should match the frequency - def __init__(self, dataset, frequency): - super().__init__(dataset) - self.dataset = dataset - self._frequency = frequency_to_timedelta(frequency) - self.reason = {"frequency": frequency} - - self._n = self.dataset.frequency / self._frequency - if int(self._n) != self._n: - raise ValueError(f"Cannot split frequency {self.dataset.frequency} to {frequency}, not a multiple") - self._n = int(self._n) - - if self.dataset._window.end - self.dataset._window.start != self.dataset.frequency: - raise ValueError( - f"Cannot split frequency {self.dataset.frequency} to {frequency}, window {self.dataset._window} does not match frequency" - ) - - @cached_property - def _window(self): - previous = self.dataset._window - if isinstance(previous, int): - previous = window_from_str(previous) - return previous / self._n - - def __len__(self): - return len(self.dataset) * self._n - - @cached_property - def dates(self): - dates = [] - freq = _to_numpy_timedelta(self._frequency) - for date in self.dataset.dates: - dates += [date + i * freq for i in range(self._n)] - return np.array(dates, dtype="datetime64[s]") - - @property - def frequency(self): - return self._frequency - - def metadata(self): - return self.dataset.metadata - - def _load_data(self, i): - j = i // self._n - k = i % self._n - # k = 0 -> shift of (self._n - 1) * self.frequency - # k = ... 
- # k = self._n - 1 -> shift of 0 (0 * self.frequency) - # so we need to shift by (self._n - 1 - k) * self.frequency - assert k < self._n, (k, self._n) - assert k >= 0 - - s = self._window.start - e = self._window.end - - ref_timedelta = -self.dataset.frequency + (k + 1) * self.frequency - start_delta = ref_timedelta + s - end_delta = ref_timedelta + e - # print( - # f" {i}={j}*{self._n}+{k} ({self.dates[i]}) -> ref_timedelta={ref_timedelta.total_seconds()/3600}, [start, end] = [{start_delta.total_seconds()/3600}, {end_delta.total_seconds()/3600}]" - # ) - - start_delta = _to_numpy_timedelta(start_delta) - end_delta = _to_numpy_timedelta(end_delta) - ref_timedelta = _to_numpy_timedelta(ref_timedelta) - - too_much_data = self.dataset._load_data(j) - - out = {} - for group in self.groups: - timedeltas = too_much_data[f"timedeltas:{group}"] - if timedeltas.dtype != "timedelta64[s]": - raise ValueError(f"Wrong type for {group}") - - if self._window.include_start: - mask = timedeltas >= start_delta - else: - mask = timedeltas > start_delta - if self._window.include_end: - mask &= timedeltas <= end_delta - else: - mask &= timedeltas < end_delta - - out[f"data:{group}"] = too_much_data[f"data:{group}"][..., mask] - out[f"latitudes:{group}"] = too_much_data[f"latitudes:{group}"][..., mask] - out[f"longitudes:{group}"] = too_much_data[f"longitudes:{group}"][..., mask] - out[f"timedeltas:{group}"] = too_much_data[f"timedeltas:{group}"][..., mask] - ref_timedelta - out[f"metadata:{group}"] = too_much_data[f"metadata:{group}"] - - return out - - def tree(self): - return Node(self, [self.dataset.tree()], **self.reason) - - -class FieldsRecords(RecordsForward): - """A wrapper around a FieldsDataset to provide a consistent interface for records datasets.""" - - def __init__(self, fields_dataset, name): - """wrapper around a fields dataset to provide a consistent interface for records datasets. - A FieldsRecords appears as a RecordsDataset with a single group. - This allows merging fields datasets with other records datasets. - Parameters: - fields_dataset: must be a regular fields dataset - name: the name of the group - . 
- """ - self.forward = fields_dataset - from anemoi.datasets.use.gridded.dataset import Dataset - - assert isinstance(fields_dataset, Dataset), f"fields_dataset must be a Dataset, got {type(fields_dataset)}" - self._name = name - self._groups = [name] - self.reason = {"name": name} - - @property - def metadata(self): - return self.forward.metadata - - def _nest_in_dict(self, obj): - """Helper to nest the object in a dict with the name as key.""" - return {self._name: obj} - - def _load_data(self, i): - data = self.forward[i] - out = {} - out[f"data:{self._name}"] = data - out[f"latitudes:{self._name}"] = self.forward.latitudes - out[f"longitudes:{self._name}"] = self.forward.longitudes - out[f"timedeltas:{self._name}"] = np.zeros(data.shape[-1], dtype="timedelta64[s]") # + _to_numpy_date( - # self.forward.dates[i] - # ) - out[f"metadata:{self._name}"] = self.forward.metadata() - return out - - @property - def groups(self): - return self._groups - - @property - def statistics(self): - return self._nest_in_dict(self.forward.statistics) - - @property - def variables(self): - return self._nest_in_dict(self.forward.variables) - - @property - def dates(self): - return np.array(self.forward.dates, dtype="datetime64[s]") - - @property - def longitudes(self): - return self._nest_in_dict(self.forward.longitudes) - - @property - def latitudes(self): - return self._nest_in_dict(self.forward.latitudes) - - @property - def name_to_index(self): - return self._nest_in_dict(self.forward.name_to_index) - - @property - def frequency(self): - return self.forward.frequency - - @property - def _window(self): - return self.forward._window - - @property - def shapes(self): - return self._nest_in_dict(self.forward.shape) - - def __len__(self): - return len(self.forward.dates) - - -class BaseRename(RecordsForward): - """Renames variables in a records dataset.""" - - def __init__(self, dataset, rename): - self.forward = dataset - assert isinstance(rename, dict) - for k, v in rename.items(): - assert isinstance(k, str), k - assert isinstance(v, str), v - self.rename = rename - self.reason = {"rename": rename} - - @property - def statistics(self): - return {self.rename.get(k, k): v for k, v in self.forward.statistics.items()} - - @property - def variables(self): - return {self.rename.get(k, k): v for k, v in self.forward.variables.items()} - - @property - def name_to_index(self): - return {self.rename.get(k, k): v for k, v in self.forward.name_to_index.items()} - - @property - def groups(self): - return [self.rename.get(k, k) for k in self.forward.groups] - - -class Rename(BaseRename): - pass - - -class SetGroup(BaseRename): - def __init__(self, dataset, set_group): - if len(dataset.groups) != 1: - raise ValueError(f"{self.__class__.__name__} can only be used with datasets containing a single group.") - - super().__init__(dataset, {dataset.groups[0]: set_group}) - - def _load_data(self, i): - return self.dataset._load_data(i) - - -def match_variable(lst, group, name): - # lst must be a list of strings with dots (if there is no dot, it is automatically added at the end) - # - a dict with keys as group and values as list of strings - - if name == "__latitudes" or name == "__longitudes": - # This should disappear in the future, when we stop saving a duplicate of lat/lon in the data - return False - - lst = [k if "." 
in k else f"{k}.*" for k in lst] - - key = f"{group}.{name}" - if key in lst: - return True - if f"{group}.*" in lst: - return True - if f"*.{name}" in lst: - return True - if "*" in lst: - return True - return False - - -class Rewindowed(RecordsForward): - # change the window of a records dataset - # similar to changing the frequency of a dataset - - def __init__(self, dataset, window): - super().__init__(dataset) - self.dataset = dataset - - # in this class anything with 1 refers to the original window/dataset - # and anything with 2 refers to the new window/dataset - - self._window1 = self.forward._window - self._window2 = window_from_str(window) - self.reason = {"window": self.window} - - self._dates1 = _to_numpy_dates(self.forward.dates) - dates = self._dates1 - self.dates_offset = 0 - while len(dates) > 0 and not self._window1.starts_before(self._dates1, dates, self._window2): - LOG.warning(f"Removing first date {dates[0]} because it is to early") - self.dates_offset += 1 - dates = dates[1:] - while len(dates) > 0 and not self._window1.ends_after(self._dates1, dates, self._window2): - LOG.warning(f"Removing last date {dates[-1]} because it is to late") - dates = dates[:-1] - - if len(dates) == 0: - raise ValueError( - f"No dates left after rewindowing {self._window1} -> {self._window2} (frequency={self.frequency}), check your window" - ) - self._dates = dates - - before_span1 = self._window1.start / self.frequency - before_span2 = self._window2.start / self.frequency - delta_before_span = before_span2 - before_span1 - if delta_before_span == int(delta_before_span): - if not self._window1.include_start and self._window2.include_start: - # if the start of the window is not included, we need to read one more index - delta_before_span -= 1 - delta_before_span = int(delta_before_span) - self.delta_before_span = delta_before_span - - after_span1 = self._window1.end / self.frequency - after_span2 = self._window2.end / self.frequency - delta_after_span = after_span2 - after_span1 - if delta_after_span == int(delta_after_span): - if not self._window1.include_end and self._window2.include_end: - # if the end of the window is not included, we need to read one more index - delta_after_span += 1 - delta_after_span = int(delta_after_span) - self.delta_after_span = delta_after_span - - @property - def window(self): - return self._window2 - - @property - def dates(self): - return np.array(self._dates, dtype="datetime64[s]") - - def __len__(self): - return len(self.dates) - - @property - def frequency(self): - return self.forward.frequency - - def _load_data(self, i): - print(f"Rewindowing data for i={i} (date={self.dates[i]}) : {self._window1} -> {self._window2}") - - first_j = i + self.delta_before_span - last_j = i + self.delta_after_span - - first_j = first_j + self.dates_offset - last_j = last_j + self.dates_offset - print(f"Requested ds({i}) : need to read {list(range(first_j, last_j + 1))} indices") - - # _load_data could support a list of indices, but for now we merge the data ourselves - # we merge the windows that we need, and then remove unnecessary data - too_much_data = merge_data(self.forward._load_data(j) for j in range(first_j, last_j + 1)) - - out = {} - for group in self.groups: - timedeltas = too_much_data[f"timedeltas:{group}"] - if timedeltas.dtype != "timedelta64[s]": - raise ValueError(f"Wrong type for {group}") - mask = self._window.compute_mask(timedeltas) - - out[f"data:{group}"] = too_much_data[f"data:{group}"][..., mask] - out[f"latitudes:{group}"] = 
too_much_data[f"latitudes:{group}"][..., mask] - out[f"longitudes:{group}"] = too_much_data[f"longitudes:{group}"][..., mask] - out[f"timedeltas:{group}"] = too_much_data[f"timedeltas:{group}"][..., mask] - out[f"metadata:{group}"] = too_much_data[f"metadata:{group}"] - - return out - - -class Select(RecordsForward): - # Select a subset of variables from a records dataset - # select can be a list of strings with dots (or a dict with keys as groups and values as list of strings) - # - # the selection is a filter, not a reordering, which is different from fields datasets and should be documented/fixed - # - # Drop should be implemented - - def __init__(self, dataset, select): - super().__init__(dataset) - - self.dataset = dataset - - if isinstance(select, dict): - # if a dict is provided, make it a list of strings with '.' - sel = [] - for group, d in select.items(): - for name in d: - sel.append(f"{group}.{name}") - select = sel - - self._select = select - - self.reason = {"select": select} - self._build_indices_and_name_to_index() - - @property - def metadata(self): - return dict(select=self._select, forward=self.dataset.metadata) - - def _build_indices_and_name_to_index(self): - indices = {} - name_to_index = {} - variables = {} - - # this should be revisited to take into account the order requested by the user - # see what is done in the fields datasets - for group, names in self.dataset.variables.items(): - ind = np.zeros(len(names), dtype=bool) - count = 0 - for j, name in enumerate(names): - if self.match_variable(group, name): - assert j == names.index(name), f"Invalid index {j} for {name} in {group}" - ind[j] = True - indices[group] = ind - if group not in name_to_index: - name_to_index[group] = {} - assert group not in variables, (group, j, name, variables, name_to_index) - variables[group] = [] - name_to_index[group][name] = count - variables[group].append(name) - count += 1 - assert np.sum(ind) == count, f"Mismatch in {group}: {names}, {ind}" - if not variables: - raise ValueError( - f"No variables matched in {self._select} for dataset {self.dataset}. 
Available groups: {self.dataset.groups} Available variables: {self.dataset.variables} " - ) - self._indices = indices - self._name_to_index = name_to_index - self._variables = variables - - def match_variable(self, *args, **kwargs): - return match_variable(self._select, *args, **kwargs) - - @property - def groups(self): - return list(self._indices.keys()) - - def _load_data(self, i): - forward = self.dataset._load_data(i) - data = {} - for k, v in self._indices.items(): - data[f"latitudes:{k}"] = forward[f"latitudes:{k}"] - data[f"longitudes:{k}"] = forward[f"longitudes:{k}"] - data[f"timedeltas:{k}"] = forward[f"timedeltas:{k}"] - data[f"metadata:{k}"] = forward[f"metadata:{k}"] - for k, v in self._indices.items(): - data[f"data:{k}"] = forward[f"data:{k}"][v] # notice the [v] here - return data - - @property - def name_to_index(self): - return self._name_to_index - - @property - def variables(self): - return self._variables - - @property - def statistics(self): - dic = {} - for group, v in self._indices.items(): - stats = self.dataset.statistics[group] - dic[group] = {key: stats[key][v] for key in stats.keys()} - assert "mean" in dic[group], f"Missing mean in {dic[group]}" - return dic - - -class RecordsSubset(RecordsForward): - """Subset of a records dataset based on a list of integer indices.""" - - def __init__(self, dataset, indices, reason): - super().__init__(dataset) - self.dataset = dataset - self.reason = reason - self._indices = indices - - @cached_property - def dates(self): - dates = self.dataset.dates - return np.array([dates[i] for i in self._indices], dtype="datetime64[s]") - - def _load_data(self, i): - return self.dataset._load_data(self._indices[i]) - - def __len__(self): - return len(self._indices) - - -class RecordsDataset(BaseRecordsDataset): - """This is the base class for all datasets based on records stored on disk.""" - - def __init__(self, path, backend=None, **kwargs): - if kwargs: - print("Warning: ignoring additional kwargs", kwargs) - self.path = path - self.backend = backend_factory(**backend, path=path) - self._groups = list(self.metadata["sources"].keys()) - for k in self.groups: - assert k == self.normalise_key(k), k - - @property - def groups(self): - return self._groups - - @classmethod - def normalise_key(cls, k): - return "".join([x.lower() if x.isalnum() else "_" for x in k]) - - @property - def frequency(self): - frequency = self.metadata["frequency"] - frequency = frequency_to_timedelta(frequency) - return frequency - - @property - def name_to_index(self): - return self.metadata["name_to_index"] - - @property - def variables(self): - return self.metadata["variables"] - - @cached_property - def _window(self): - window = self.metadata["window"] - return window_from_str(window) - - @cached_property - def metadata(self): - return self.backend.read_metadata() - - @property - def shapes(self): - return self.metadata["shapes"] - - def items(self, *args, **kwargs): - return {k: Tabular(self, k) for k in self.groups}.items(*args, **kwargs) - - @cached_property - def statistics(self): - return self.backend.read_statistics() - - def __len__(self): - return len(self.dates) - - @property - def start_date(self): - date = self.metadata["start_date"] - return datetime.datetime.fromisoformat(date) - - @property - def end_date(self): - date = self.metadata["end_date"] - return datetime.datetime.fromisoformat(date) - - @cached_property - def dates(self): - result = [] - delta = self.frequency - d = self.start_date - while d <= self.end_date: - result.append(d) - d 
+= delta - return np.array(result, dtype="datetime64[s]") - - @counter - def _load_data(self, i): - data = self.backend.read(i) - self.backend._check_data(data) - return data - - def check(self, i=None): - if i is not None: - dict_of_sets = defaultdict(set) - for key in self._load_data(i).keys(): - kind, group = key.split(":") - dict_of_sets[group].add(kind) - for group, s in dict_of_sets.items(): - assert s == {"latitudes", "longitudes", "timedeltas", "metadata", "data"}, f"Invalid keys {s}" - - def tree(self): - return Node(self, [], path=self.path) - - -class Record(Mapping): - """A record corresponds to a single time step in a record dataset.""" - - def __init__(self, dataset: RecordsDataset, n: int): - """A record corresponds to a single time step in a record dataset. - n : int, the index of the time step in the dataset. - dataset : RecordsDataset, the dataset this record belongs to. - """ - self.dataset = dataset - self.n = n - - def __repr__(self): - d = {group: "" for group in self.dataset.groups} - return str(d) - - def items(self): - return self._payload.items() - - def __iter__(self): - return iter(self.groups) - - def __len__(self): - return len(self.groups) - - def __contains__(self, group): - return group in self.groups - - @property - def name_to_index(self): - return self.dataset.name_to_index - - @cached_property - def _payload(self) -> dict: - payload = self.dataset._load_data(self.n) - for k in payload.keys(): - assert len(k.split(":")) == 2, f"Invalid key {k}" - return payload - - @cached_property - def groups(self) -> list[str]: - return self.dataset.groups - - def __getitem__(self, group): - k = f"data:{group}" - if k not in self._payload: - raise KeyError(f"Group {group} not found in record {self.n}. Available groups are {self.groups}") - return self._payload[k] - - def _get_aux(self, name): - try: - return {k: self._payload[name + ":" + k] for k in self.groups} - except KeyError as e: - e.add_note(f"Available keys are {self._payload.keys()}") - raise - - @property - def latitudes(self): - return self._get_aux("latitudes") - - @property - def longitudes(self): - return self._get_aux("longitudes") - - @property - def timedeltas(self): - return self._get_aux("timedeltas") - - @property - def statistics(self): - return self.dataset.statistics - - def as_dict(self) -> dict: - """Returns the record as a dictionary with group names as keys. - - Returns - ------- - dict - Dictionary mapping group names to their data. 
- """ - return {group: self[group] for group in self.groups} - - -class Tabular: - """A RecordsDataset for a single group, similar to a fields dataset, but allowing different shapes for each date.""" - - def __init__(self, dataset, name): - self.dataset = dataset - self.name = name - - @property - def group(self): - return self.name - - def __getitem__(self, i): - return self.__get(i, "data") - - def __get(self, i, k): - payload = self.dataset._load_data(i) - try: - return payload[k + ":" + self.name] - except KeyError: - print(f"KeyError to retrieve {self.name} available groups are", payload.keys()) - raise - - @property - def variables(self): - return self.dataset.variables[self.name] - - @property - def name_to_index(self): - return self.dataset.name_to_index[self.name] - - @property - def statistics(self): - return self.dataset.statistics[self.name] - - @property - def metadata(self): - return self.dataset.metadata diff --git a/src/anemoi/datasets/use/tabular/records/backends/__init__.py b/src/anemoi/datasets/use/tabular/records/backends/__init__.py deleted file mode 100644 index 5ca924e5b..000000000 --- a/src/anemoi/datasets/use/tabular/records/backends/__init__.py +++ /dev/null @@ -1,273 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import io -import json -import logging -import os - -import numpy as np -from cachetools import LRUCache - -LOG = logging.getLogger(__name__) - - -def normalise_key(k): - return "".join([x.lower() if x.isalnum() else "_" for x in k]) - - -class Backend: - def __init__(self, path, **kwargs): - self.path = path - self.kwargs = kwargs - - def read(self, i, **kwargs): - """Read the i-th record and return a dictionary of numpy arrays.""" - raise NotImplementedError("Must be implemented in subclass") - - def read_metadata(self): - """Read the metadata of a record dataset. The metadata does not depend on the record index.""" - raise NotImplementedError("Must be implemented in subclass") - - def read_statistics(self): - """Read the statistics of a record dataset. 
The statistics does not depend on the record index.""" - raise NotImplementedError("Must be implemented in subclass") - - def _check_data(self, data): - # Check that all keys are normalised - for k in list(data.keys()): - k = k.split(":")[-1] - if k != normalise_key(k): - raise ValueError(f"{k} must be alphanumerical and '-' only.") - - -class Npz1Backend(Backend): - - def __init__(self, *args, number_of_files_per_subdirectory=100, **kwargs): - super().__init__(*args, **kwargs) - self.number_of_files_per_subdirectory = number_of_files_per_subdirectory - self._cache = LRUCache(maxsize=5) - - def read(self, i, **kwargs): - if i in self._cache: - return self._cache[i] - - d = str(int(i / self.number_of_files_per_subdirectory)) - path = os.path.join(self.path, "data", d, f"{i}.npz") - raw = open(path, "rb").read() - buffer = io.BytesIO(raw) - self._cache[i] = dict(np.load(buffer)) - return self._cache[i] - - def read_metadata(self): - with open(os.path.join(self.path, "metadata.json")) as f: - return json.load(f) - - def read_statistics(self): - path = os.path.join(self.path, "statistics.npz") - dic = {} - for k, v in dict(np.load(path)).items(): - key, group = k.split(":") - if group not in dic: - dic[group] = {} - dic[group][key] = v - return dic - - -class Nc1Backend(Backend): - number_of_files_per_subdirectory = 100 - - def read(self, i, **kwargs): - d = str(int(i / self.number_of_files_per_subdirectory)) - path = os.path.join(self.path, "data", d, f"{i}.nc") - import xarray as xr - - ds = xr.open_dataset(path) - return {var: ds[var].values for var in ds.data_vars} - - def read_metadata(self): - with open(os.path.join(self.path, "metadata.json"), "r") as f: - return json.load(f) - - def read_statistics(self): - path = os.path.join(self.path, "statistics.nc") - import xarray as xr - - ds = xr.open_dataset(path) - flatten = {var: ds[var].values for var in ds.data_vars} - dic = {} - for k, v in flatten.items(): - key, group = k.split(":") - if group not in dic: - dic[group] = {} - dic[group][key] = v - return dic - - -def backend_factory(name, *args, **kwargs): - BACKENDS = dict( - npz1=Npz1Backend, - nc1=Nc1Backend, - ) - cls = BACKENDS[name] - return cls(*args, **kwargs) - - -class WriteBackend(Backend): - # Write backend base class, not used for reading - # provides implementation to write data - def __init__(self, *, target, **kwargs): - super().__init__(target, **kwargs) - - def write(self, i, data, **kwargs): - # expects data to be a dict of numpy arrays - raise NotImplementedError("Must be implemented in subclass") - - def write_metadata(self, metadata): - # expects metadata to be a dict - raise NotImplementedError("Must be implemented in subclass") - - def write_statistics(self, statistics): - # expects statistics to be a dict of dicts with the right keys: - # {group: {mean:..., std:..., min:..., max:...}} - raise NotImplementedError("Must be implemented in subclass") - - def _check_data(self, data): - for k in list(data.keys()): - k = k.split(":")[-1] - if k != normalise_key(k): - raise ValueError(f"{k} must be alphanumerical and '_' only.") - - def _dataframes_to_record(self, i, data, variables, **kwargs): - # Convert data from pandas DataFrames to a record format - # will be used for writing, building obs datasets - - assert isinstance(data, (dict)), type(data) - if not data: - LOG.warning(f"Empty data for index {i}.") - return data - first = data[list(data.keys())[0]] - import pandas as pd - - if isinstance(first, pd.DataFrame): - data = {name: self._dataframe_to_dict(name, df, 
**kwargs) for name, df in data.items()} - else: - assert False - - return data - - def _dataframe_to_dict(self, name, df, **kwargs): - # will be used for writing, building obs datasets - - d = {} - d["timedeltas:" + name] = df["timedeltas"] - d["latitudes:" + name] = df["latitudes"] - d["longitudes:" + name] = df["longitudes"] - d["data:" + name] = df["data"] - d["metadata:" + name] = df["metadata"] - return d - - -class Npz1WriteBackend(WriteBackend): - - def write(self, i, data, number_of_files_per_subdirectory=100, **kwargs): - self.number_of_files_per_subdirectory = number_of_files_per_subdirectory - self._check_data(data) - d = str(int(i / self.number_of_files_per_subdirectory)) - dir_path = os.path.join(self.path, "data", d) - - out_path = os.path.join(dir_path, f"{i}.npz") - tmp_path = os.path.join(dir_path, f"{i}.tmp.npz") - - os.makedirs(os.path.dirname(tmp_path), exist_ok=True) - np.savez(tmp_path, **data) - os.rename(tmp_path, out_path) - - def write_metadata(self, metadata): - from anemoi.datasets.create.gridded.tasks import _json_tidy - - os.makedirs(self.path, exist_ok=True) - - path = os.path.join(self.path, "metadata.json") - tmp_path = path + ".tmp" - with open(tmp_path, "w") as f: - json.dump(metadata, f, indent=2, default=_json_tidy) - os.rename(tmp_path, path) - - def write_statistics(self, statistics): - os.makedirs(self.path, exist_ok=True) - flatten = {} - for name, d in statistics.items(): - assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" - assert "mean" in d, f"Statistics for {name} must contain 'mean' key but got {d.keys()}" - for k, v in d.items(): - assert isinstance( - v, (int, float, np.ndarray) - ), f"Statistics value for {k} in {name} must be int, float or ndarray, got {type(v)}" - flatten[k + ":" + name] = v - - path = os.path.join(self.path, "statistics.npz") - np.savez(path, **flatten) - - -class Nc1WriteBackend(WriteBackend): - number_of_files_per_subdirectory = 100 - - def write(self, i, data, **kwargs): - self._check_data(data) - d = str(int(i / self.number_of_files_per_subdirectory)) - path = os.path.join(self.path, "data", d) - os.makedirs(path, exist_ok=True) - out_path = os.path.join(path, f"{i}.nc") - - import xarray as xr - - ds = xr.Dataset( - {key: ([f"dim_{key}" + str(i) for i in range(value.ndim)], value) for key, value in data.items()} - ) - ds.to_netcdf(out_path) - - def write_metadata(self, metadata): - from anemoi.datasets.create.gridded.tasks import _json_tidy - - os.makedirs(self.path, exist_ok=True) - with open(os.path.join(self.path, "metadata.json"), "w") as f: - json.dump(metadata, f, indent=2, default=_json_tidy) - - def write_statistics(self, statistics): - os.makedirs(self.path, exist_ok=True) - flatten = {} - for name, d in statistics.items(): - assert isinstance(d, dict), f"Statistics for {name} must be a dict, got {type(d)}" - assert "mean" in d, f"Statistics for {name} must contain 'mean' key but got {d.keys()}" - for k, v in d.items(): - assert isinstance( - v, (int, float, np.ndarray) - ), f"Statistics value for {k} in {name} must be int, float or ndarray, got {type(v)}" - flatten[k + ":" + name] = v - - path = os.path.join(self.path, "statistics.nc") - - import xarray as xr - - ds = xr.Dataset( - {key: ([f"dim_{key}" + str(i) for i in range(value.ndim)], value) for key, value in flatten.items()} - ) - ds.to_netcdf(path) - np.savez(path, **flatten) - - -def writer_backend_factory(name, **kwargs): - # choose the right backend for writing - # this is intended to make benchmarking easier - 
WRITE_BACKENDS = dict( - npz1=Npz1WriteBackend, - nc1=Nc1WriteBackend, - ) - return WRITE_BACKENDS[name](**kwargs) diff --git a/src/anemoi/datasets/use/tabular/windows.py b/src/anemoi/datasets/use/tabular/windows.py deleted file mode 100644 index 5f02e3c82..000000000 --- a/src/anemoi/datasets/use/tabular/windows.py +++ /dev/null @@ -1,252 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -import datetime - -import numpy as np -from anemoi.utils.dates import frequency_to_string - - -def _to_numpy_timedelta(td): - if isinstance(td, np.timedelta64): - assert td.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {td.dtype}" - return td - return np.timedelta64(int(td.total_seconds()), "s") - - -def window_from_str(txt): - """Parses a window string of the form '(-6h, 0h]' and returns a WindowsSpec object.""" - if txt.startswith("["): - include_start = True - elif txt.startswith("("): - include_start = False - else: - raise ValueError(f"Invalid window {txt}, must start with '(' or '['") - txt = txt[1:] - - if txt.endswith("]"): - include_end = True - elif txt.endswith(")"): - include_end = False - else: - raise ValueError(f"Invalid window {txt}, must end with ')' or ']'") - txt = txt[:-1] - - txt = txt.strip() - if ";" in txt: - txt = txt.replace(";", ",") - lst = txt.split(",") - if len(lst) != 2: - raise ValueError( - f"Invalid window {txt}, must be of the form '(start, end)' or '[start, end]' or '[start, end)' or '(start, end]'" - ) - start, end = lst - start = start.strip() - end = end.strip() - - def _to_timedelta(t): - # This part should go into utils - from anemoi.utils.dates import as_timedelta - - if t.startswith(" ") or t.endswith(" "): - t = t.strip() - if t.startswith("-"): - return -as_timedelta(t[1:]) - if t.startswith("+"): - return as_timedelta(t[1:]) - # end of : This part should go into utils - return as_timedelta(t) - - start = _to_timedelta(start) - end = _to_timedelta(end) - return WindowsSpec( - start=start, - end=end, - include_start=include_start, - include_end=include_end, - ) - - -class Interval: - # not used but expected to be useful when building datasets. 
And used in tests - def __init__(self, start, end, include_start=True, include_end=True): - assert isinstance(start, datetime.datetime), f"start must be a datetime.datetime, got {type(start)}" - assert isinstance(end, datetime.datetime), f"end must be a datetime.datetime, got {type(end)}" - assert isinstance(include_start, bool), f"include_start must be a bool, got {type(include_start)}" - assert isinstance(include_end, bool), f"include_end must be a bool, got {type(include_end)}" - if start >= end: - raise ValueError(f"start {start} must be less than end {end}") - self.start = start - self.end = end - self.include_start = include_start - self.include_end = include_end - - def __repr__(self): - return f"{'[' if self.include_start else '('}{self.start.isoformat()},{self.end.isoformat()}{']' if self.include_end else ')'}" - - def intersection(self, other): - assert isinstance(other, Interval), f"`other` must be a Interval, got {type(other)}" - - if self._start_np > other._end_np or other._start_np > self._end_np: - return None # no intersection - - if self._start_np < other._start_np: - start = other._start_np - include_start = other.include_start - elif self._start_np > other._start_np: - start = self._start_np - include_start = self.include_start - else: # equal - start = self._start_np - include_start = self.include_start and other.include_start - - if self._end_np < other._end_np: - end = self._end_np - include_end = self.include_end - elif self._end_np > other._end_np: - end = other._end_np - include_end = other.include_end - else: # equal - end = self._end_np - include_end = self.include_end and other.include_end - - return Interval(start=start, end=end, include_start=include_start, include_end=include_end) - - def union(self, other): - assert isinstance(other, Interval), f"`other` must be a Interval, got {type(other)}" - - if self._start_np < other._start_np: - start = self._start_np - include_start = self.include_start - elif self._start_np > other._start_np: - start = other._start_np - include_start = other.include_start - else: # equal - start = self._start_np - include_start = self.include_start or other.include_start - - if self._end_np > other._end_np: - end = self._end_np - include_end = self.include_end - elif self._end_np < other._end_np: - end = other._end_np - include_end = other.include_end - else: # equal - end = self._end_np - include_end = self.include_end or other.include_end - - return Interval(start=start, end=end, include_start=include_start, include_end=include_end) - - -class WindowsSpec: - # A window specified by relative timedeltas, such as (-6h, 0h] - # - # the term "WindowSpec" is used here to avoid confusion between - # - a relative window, such as (-6h, 0h] which this class represents (WindowsSpec) - # - an actual time interval, such as [2023-01-01 00:00, 2023-01-01 06:00] which is an (Interval) - # - # but is is more confusing, it should be renamed as Window. 
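-    # Illustrative example (values assumed), using the helpers defined in this
-    # module (`window_from_str`, `to_interval`, `compute_mask`):
-    #
-    #   spec = window_from_str("(-6h, 0h]")
-    #   spec.to_interval(datetime.datetime(2018, 1, 1, 12))
-    #       # absolute interval (2018-01-01 06:00, 2018-01-01 12:00]
-    #   deltas = np.array([-21600, -10800, 0, 3600], dtype="timedelta64[s]")
-    #   spec.compute_mask(deltas)
-    #       # array([False, True, True, False]): -6h excluded, 0h included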
- - def __init__(self, *, start, end, include_start=False, include_end=True): - assert isinstance(start, (str, datetime.timedelta)), f"start must be a str or timedelta, got {type(start)}" - assert isinstance(end, (str, datetime.timedelta)), f"end must be a str or timedelta, got {type(end)}" - assert isinstance(include_start, bool), f"include_start must be a bool, got {type(include_start)}" - assert isinstance(include_end, bool), f"include_end must be a bool, got {type(include_end)}" - assert include_start in (True, False), f"Invalid include_start {include_start}" # None is not allowed - assert include_end in (True, False), f"Invalid include_end {include_end}" # None is not allowed - - if start >= end: - raise ValueError(f"start {start} must be less than end {end}") - - self.start = start - self.end = end - self.include_start = include_start - self.include_end = include_end - - self._start_np = _to_numpy_timedelta(start) - self._end_np = _to_numpy_timedelta(end) - - def to_interval(self, date): - """Convert the window to an absolute window based on a date.""" - # not used but expected to be useful when building datasets. And used in tests - assert isinstance(date, datetime.datetime), f"date must be a datetime.datetime, got {type(date)}" - start = date + self.start - end = date + self.end - return Interval(start=start, end=end, include_start=self.include_start, include_end=self.include_end) - - def __repr__(self): - first = "[" if self.include_start else "(" - last = "]" if self.include_end else ")" - - def _frequency_to_string(t): - if t < datetime.timedelta(0): - return f"-{frequency_to_string(-t)}" - elif t == datetime.timedelta(0): - return "0" - return frequency_to_string(t) - - return f"{first}{_frequency_to_string(self.start)},{_frequency_to_string(self.end)}{last}" - - def compute_mask(self, timedeltas): - """Returns a boolean numpy array of the same shape as timedeltas.""" - - assert timedeltas.dtype == "timedelta64[s]", f"expecting np.timedelta64[s], got {timedeltas.dtype}" - if self.include_start: - lower_mask = timedeltas >= self._start_np - else: - lower_mask = timedeltas > self._start_np - - if self.include_end: - upper_mask = timedeltas <= self._end_np - else: - upper_mask = timedeltas < self._end_np - - return lower_mask & upper_mask - - def starts_before(self, my_dates, other_dates, other_window): - # apply this window to my_dates[0] and the other_window to other_dates[0] - # return True if this window starts before the other window - - assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" - assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" - assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" - - my_start = my_dates[0] + self._start_np - other_start = other_dates[0] + other_window._start_np - - if my_start == other_start: - return (not other_window.include_start) or self.include_start - return my_start <= other_start - - def ends_after(self, my_dates, other_dates, other_window): - # same as starts_before - assert my_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {my_dates.dtype}" - assert other_dates.dtype == "datetime64[s]", f"expecting np.datetime64[s], got {other_dates.dtype}" - assert isinstance(other_window, WindowsSpec), f"other_window must be a WindowsSpec, got {type(other_window)}" - - my_end = my_dates[-1] + self._end_np - other_end = other_dates[-1] + other_window._end_np - - if my_end == other_end: 
- print(".", (not other_window.include_end) or self.include_end) - return (not other_window.include_end) or self.include_end - print(my_end >= other_end) - return my_end >= other_end - - def __truediv__(self, n: int): - """Divide the window into a smaller windows, shrinked by a factor n.""" - assert isinstance(n, int), f"n must be an int, got {type(n)}" - assert n > 0, f"n must be positive, got {n}" - - return WindowsSpec( - start=self.start / n, - end=self.end / n, - include_start=self.include_start, - include_end=self.include_end, - ) From 88930b94f8b5696e603b5b176dc9748ec58fc038 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 17 Nov 2025 18:21:08 +0100 Subject: [PATCH 211/212] update --- tests/test_classes.py | 538 ------------------------------------------ tests/test_records.py | 209 ---------------- 2 files changed, 747 deletions(-) delete mode 100644 tests/test_classes.py delete mode 100644 tests/test_records.py diff --git a/tests/test_classes.py b/tests/test_classes.py deleted file mode 100644 index 5a88ff78c..000000000 --- a/tests/test_classes.py +++ /dev/null @@ -1,538 +0,0 @@ -# (C) Copyright 2025 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import os -from collections.abc import Callable -from functools import wraps -from unittest.mock import patch - -import pytest -from anemoi.utils.testing import TEST_DATA_URL -from anemoi.utils.testing import skip_if_offline - -from anemoi.datasets import open_dataset - - -def _tests_zarrs(name: str) -> str: - return os.path.join(TEST_DATA_URL, "anemoi-datasets", f"{name}.zarr") - - -def zarr_tests(func: Callable) -> Callable: - - @wraps(func) - def wrapper(*args, **kwargs): - with patch("anemoi.datasets.use.gridded.stores.dataset_lookup", _tests_zarrs): - return func(*args, **kwargs) - - return wrapper - - -def _test_dataset(ds, variables=None): - - if variables is not None: - assert ds.variables == variables, ( - set(ds.variables) - set(variables), - set(variables) - set(ds.variables), - ds.variables, - ) - - # for p in ds.components(): - # print(p) - # print(p.origins()) - - -not_ready = pytest.mark.skip(reason="Not ready yet") - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_complement_none(): - pass - # ds = open_dataset( - # source="cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", - # complement="aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - # # adjust="all", - # ) - - -@skip_if_offline -@zarr_tests -def test_class_complement_nearest_1(): - - ds = open_dataset( - complement="cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", - source="aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - interpolation="nearest", - ) - _test_dataset( - ds, - variables=[ - "2t", - "cos_latitude", - "cp", - "insolation", - "lsm", - "msl", - "orog", - "sf", - "t_500", - "t_850", - "tp", - "z", - "z_500", - "z_850", - ], - ) - - -@skip_if_offline -@zarr_tests -def test_class_complement_nearest_2(): - ds = open_dataset( - source="cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", - complement="aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - interpolation="nearest", - ) - _test_dataset( - ds, - variables=[ - "2t", - "cos_latitude", - "cp", - "insolation", - "lsm", - "msl", - "orog", - "sf", - "t_500", - 
"t_850", - "tp", - "z", - "z_500", - "z_850", - ], - ) - - -@skip_if_offline -@zarr_tests -def test_class_concat(): - ds = open_dataset( - [ - "aifs-ea-an-oper-0001-mars-20p0-2016-2016-6h-v1", - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - ] - ) - _test_dataset( - ds, - variables=[ - "2t", - "cos_latitude", - "cp", - "insolation", - "lsm", - "msl", - "t_500", - "t_850", - "tp", - "z", - "z_500", - "z_850", - ], - ) - - -@skip_if_offline -@zarr_tests -def test_class_number(): - ds = open_dataset( - "aifs-ea-an-enda-0001-mars-20p0-2017-2017-6h-v1", - members=[0, 2], - ) - _test_dataset( - ds, - variables=[ - "2t", - "cos_latitude", - "cp", - "insolation", - "lsm", - "msl", - "t_500", - "t_850", - "tp", - "z", - "z_500", - "z_850", - ], - ) - - -@skip_if_offline -@zarr_tests -def test_class_ensemble(): - ds = open_dataset( - ensemble=[ - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - "aifs-ea-em-enda-0001-mars-20p0-2017-2017-6h-v1", - ] - ) - _test_dataset( - ds, - variables=[ - "2t", - "cos_latitude", - "cp", - "insolation", - "lsm", - "msl", - "t_500", - "t_850", - "tp", - "z", - "z_500", - "z_850", - ], - ) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_missing_dates_fill(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_missing_dates_closest(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_missing_dates_interpolate(): - pass - - -@skip_if_offline -@zarr_tests -def test_class_grids(): - ds = open_dataset( - grids=[ - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", - ], - adjust="all", - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_cutout() -> None: - ds = open_dataset( - cutout=[ - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", - ], - adjust="all", - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_missing_date_error(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_interpolate_frequency(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_interpolate_nearest(): - pass - - -@skip_if_offline -@zarr_tests -def test_class_join_1(): - ds = open_dataset( - [ - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1-sfc", - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1-pl", - ], - ) - _test_dataset(ds, ["2t", "lsm", "msl", "z", "t_500", "t_850", "z_500", "z_850"]) - - -@skip_if_offline -@zarr_tests -def test_class_join_2(): - ds = open_dataset( - [ - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1-pl", - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1-sfc", - ], - ) - _test_dataset(ds, ["t_500", "t_850", "z_500", "z_850", "2t", "lsm", "msl", "z"]) - - -@skip_if_offline -@zarr_tests -def test_class_thinning(): - ds = open_dataset( - "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", - thinning=4, - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -def test_class_cropping(): - ds = open_dataset( - "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", - area=[80, -10, 30, 40], - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_trim_edge(): - ds = open_dataset( - "cerra-rr-an-oper-0001-mars-5p0-2017-2017-6h-v1", - trim_edge=(1, 2, 3, 4), - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_merge(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_missing_dates(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready 
-def test_class_skip_missing_dates(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_missing_dataset(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_padded(): - pass - - -@skip_if_offline -@zarr_tests -def test_class_rescale_1(): - ds = open_dataset( - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - rescale={"2t": (1.0, -273.15)}, - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -def test_class_rescale_2(): - try: - import cfunits # noqa: F401 - except FileNotFoundError: - # cfunits requires the library udunits2 to be installed - raise pytest.skip("udunits2 library not installed") - - ds = open_dataset( - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - rescale={"2t": ("K", "degC")}, - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -def test_class_rescale_3(): - ds = open_dataset( - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - rescale={ - "2t": {"scale": 1.0, "offset": -273.15}, - }, - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -def test_class_select_select_1(): - ds = open_dataset( - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - select=["msl", "2t"], - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -def test_class_select_select_2(): - ds = open_dataset( - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - select={"msl", "2t"}, - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -def test_class_select_drop(): - ds = open_dataset( - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - drop=["2t", "msl"], - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -def test_class_rename() -> None: - ds = open_dataset( - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - rename={"2t": "temperature", "msl": "pressure"}, - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_rename_with_overlap() -> None: - ds = open_dataset( - [ - { - "dataset": "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - "select": ["cp", "tp"], - "end": 2023, - "frequency": "6h", - }, - { - "dataset": "aifs-od-an-oper-0001-mars-o96-2016-2023-6h-v1-precipitations", - "end": 2023, - "frequency": "6h", - "rename": {"tp_0h_12h": "tp"}, - "select": ["tp_0h_12h"], - }, - ], - end=2022, - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_statistics(): - pass - - -@skip_if_offline -@zarr_tests -def test_class_zarr(): - ds = open_dataset("aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1") - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_zarr_with_missing_dates(): - ds = open_dataset("rodeo-opera-files-o96-2013-2023-6h-v5") - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -def test_class_subset(): - ds = open_dataset( - "aifs-ea-an-oper-0001-mars-20p0-2017-2017-6h-v1", - frequency="12h", - start=2017, - end=2018, - ) - _test_dataset(ds) - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_chain(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_zipbase(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_zip(): - pass - - -@skip_if_offline -@zarr_tests -@not_ready -def test_class_xy(): - pass - - -if __name__ == "__main__": - test_class_complement_nearest_1() - test_class_complement_nearest_2() - exit(0) - for name, obj in list(globals().items()): - if name.startswith("test_") and callable(obj): - print(f"Running {name}...") - obj() diff --git a/tests/test_records.py b/tests/test_records.py deleted file mode 100644 index f389a3cdf..000000000 --- 
a/tests/test_records.py +++ /dev/null @@ -1,209 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -import os - -import numpy as np -import pytest - -from anemoi.datasets.use.gridded import open_dataset -from anemoi.datasets.use.tabular.records import Record -from anemoi.datasets.use.tabular.records import Tabular - -TEST_DATASET = "../../data/vz/observations-testing-2018-2018-6h-v0.vz" - - -def check_numpy(x, y): - assert x.shape == y.shape, f"Expected {x.shape} == {y.shape}" - assert type(x) == type(y), f"Expected {type(x)} == {type(y)}" # noqa: E721 - assert np.all(np.isnan(x) == np.isnan(y)) and np.all( - np.nan_to_num(x) == np.nan_to_num(y) - ), f"Expected {x} == {y} (ignoring NaNs)" - - -def _test(ds, nb_dates=None): - print(f"💬 Testing {type(ds)} with {len(ds)} dates") - print(ds.tree()) - grp = "metop_a" - index_i = 0 - - if nb_dates is not None: - assert len(ds) == nb_dates, f"Expected {nb_dates} dates, got {len(ds)}" - - ################################# - # Order does not matter too much [i] and [grp] are exchangeable - - elt = ds[index_i] - assert isinstance(elt, Record), (type(ds), type(elt)) - assert ds[index_i].dataset == ds, (type(ds[index_i].dataset), type(ds)) - - group = ds[grp] - assert isinstance(group, Tabular), type(group) - - x = ds[grp][index_i] - y = ds[index_i][grp] - check_numpy(x, y) - - ############################################### - # lat and lon and timedelta are not the same for all elements - # but they have the same size - - lat = ds[index_i].latitudes[grp] - assert isinstance(lat, np.ndarray), type(lat) - - # Not implemented yet - # lat = ds[grp].latitudes[index_i] - # assert isinstance(lat, np.ndarray), type(lat) - - # Not implemented yet : do not need ? - # lat = ds.latitudes[grp][index_i] - # assert isinstance(lat, np.ndarray), type(lat) - - # Not implemented yet : do not need ? 
- # lat = ds.latitudes[index_i][grp] - # assert isinstance(lat, np.ndarray), type(lat) - - lon = ds[index_i].longitudes[grp] - assert isinstance(lon, np.ndarray), type(lon) - assert len(lat) == len(lon), f"Expected same size for lat and lon {len(lat)} == {len(lon)}" - - timedeltas = ds[index_i].timedeltas[grp] - assert isinstance(timedeltas, np.ndarray), type(timedeltas) - assert len(timedeltas) == len(lat), f"Expected same size for lat and timedeltas {len(lat)} == {len(timedeltas)}" - - ############################################# - # name_to_index is must be the same for all elements - # name_to_index is a dict of dict (key is the group name) - - name_to_index = ds.name_to_index - assert isinstance(name_to_index, dict), type(name_to_index) - assert len(name_to_index) > 0, "name_to_index is empty" - assert all(isinstance(k, str) for k in name_to_index.keys()), name_to_index - assert all(isinstance(v, dict) for v in name_to_index.values()), name_to_index - - _name_to_index = ds[index_i].name_to_index - assert list(name_to_index.keys()) == list(_name_to_index.keys()), ( - list(name_to_index.keys()), - list(_name_to_index.keys()), - ) - assert name_to_index == _name_to_index, "name_to_index is not the same for all elements" - - ############################################### - # statistics is not the same for all elements - # statistics is a dict of dict (first key is the group name) - - statistics = ds.statistics - assert isinstance(statistics, dict), type(statistics) - assert len(statistics) > index_i, "statistics is empty" - assert all(isinstance(k, str) for k in statistics.keys()), statistics - assert all(isinstance(v, dict) for v in statistics.values()), statistics - assert grp in statistics, f"statistics does not contain {grp}" - - statistics_ = ds[grp].statistics - assert isinstance(statistics_, dict), type(statistics_) - assert "mean" in statistics_, "statistics does not contain mean" - - # ! here, the meaning could be ambigous, this is the statistics of the whole dataset. - # Do not document this, and maybe remove it. 
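-    # Expected layout (as produced by the write backends): ds.statistics maps each
-    # group name to per-variable arrays, e.g. {"metop_a": {"mean": ..., "std": ...,
-    # "min": ..., "max": ...}}; the checks below assert that the per-record view
-    # agrees with the dataset-wide statistics.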
- _statistics = ds[index_i].statistics - assert isinstance(_statistics, dict), type(_statistics) - assert grp in _statistics, f"statistics does not contain {grp}" - assert list(_statistics.keys()) == ds.groups, (_statistics.keys(), ds.groups) - for group_name, stats in _statistics.items(): - assert "mean" in stats, f"statistics does not contain mean for {group_name}" - for key, v in stats.items(): - assert np.all(statistics[group_name][key] == v), (key, statistics[group_name][key], v) - - assert statistics[grp].keys() == statistics_.keys(), (statistics[grp].keys(), statistics_.keys()) - for key, v in statistics[grp].items(): - assert np.all(statistics[grp][key] == v), (key, statistics[grp][key], v) - - -@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") -def test_open(): - ds = open_dataset(TEST_DATASET) - _test(ds) - - -@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") -def test_open_with_subset_dates(): - ds = open_dataset( - TEST_DATASET, - end="2018-11-30", - select=[ - "metop_a.*", - "amsr2_h180.rawbt_4", - "amsr2_h180.rawbt_3", - ], - ) - _test(ds, nb_dates=8) - - -@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") -def test_open_with_window(): - dates = dict(end="2018-11-30") - ds = open_dataset(TEST_DATASET, window="(-6h, 0h]", **dates) - _test(ds, nb_dates=8) - - ds = open_dataset(TEST_DATASET, window="(-1h, 0)", **dates) - _test(ds, nb_dates=8) - - -@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") -def test_open_bad_window(): - subset = dict(end="2018-11-30") - with pytest.raises(ValueError, match="No dates left after rewindowing"): - open_dataset(TEST_DATASET, window="(-48h, +48h)", **subset) - - -@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") -@pytest.mark.parametrize( - "window, missing_dates", - [ - ("(-12h, 0)", -1), # first window is incomplete - ("[-12h, 0)", -2), # first two windows are incomplete - ("(-3h, +3h)", -1), # last date is incomplete - ("[-6h, 0h)", -1), # incomplete due to rounding - ("(-6h, 0h)", 0), - ("(-1h, 0h]", 0), - ("(-1h, 0)", 0), - ("(-6h, +6h)", -1), - ("(-6h, +5h)", -1), - ("(-12h, +12h)", -3), - ("(-1h, +15h]", -3), - ], -) -def test_open_with_window_parametrized(window, missing_dates): - subset = dict(end="2018-11-30") - - ds = open_dataset(TEST_DATASET, **subset) - assert len(ds) == 8 - nb_dates = len(ds) + missing_dates - - ds = open_dataset(TEST_DATASET, window=window, **subset) - _test(ds, nb_dates=nb_dates) - - -@pytest.mark.skipif(not os.path.exists(TEST_DATASET), reason="File not found") -def test_open_with_subset_select(): - ds = open_dataset( - TEST_DATASET, - select=[ - "amsr2_h180.rawbt_4", - "amsr2_h180.rawbt_3", - "metop_a.*", - ], - ) - _test(ds) - - -if __name__ == "__main__": - - test_open() - test_open_with_subset_select() - test_open_with_subset_dates() From f3f0c343ad5943b7ba813bccef02ca6d45d40a39 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 18 Nov 2025 07:36:22 +0100 Subject: [PATCH 212/212] update --- tests/test_data.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/test_data.py b/tests/test_data.py index 670b29a88..34da035a6 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1417,18 +1417,6 @@ def test_invalid_trim_edge() -> None: ) -@mockup_open_zarr -def test_fields_to_records() -> None: - """Test joining datasets (case 2).""" - - key = "grp" - ds = open_dataset(dataset="test-2021-2021-6h-o96-abcd-1", set_group=key) - # unwrapped = 
open_dataset(dataset="test-2021-2021-6h-o96-abcd-2") - - assert ds.groups == [key] - assert ds.variables == {key: ["a", "b", "c", "d"]} - - @pytest.mark.skip("Saving datasets not yet supported in that branch") def test_save_dataset() -> None: """Test save datasets."""