Skip to content

Commit aea80ab

Browse files
committed
Add ensure_path decorator
1 parent 69be265 commit aea80ab

File tree

6 files changed

+82
-15
lines changed

6 files changed

+82
-15
lines changed

src/anemoi/inference/decorators.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
# nor does it submit to any jurisdiction.
99

1010

11+
import logging
12+
from pathlib import Path
1113
from typing import Any
1214
from typing import TypeVar
1315

1416
from anemoi.inference.context import Context
1517

18+
LOG = logging.getLogger("anemoi.inference")
19+
1620
MARKER = object()
1721
F = TypeVar("F", bound=type)
1822

@@ -64,3 +68,67 @@ def __init__(wrapped_cls, context: Context, main: object = MARKER, *args: Any, *
6468
super().__init__(context, *args, **kwargs)
6569

6670
return type(cls.__name__, (WrappedClass,), {})
71+
72+
73+
class ensure_path:
74+
"""Decorator to ensure a path argument is a Path object and optionally exists.
75+
76+
If `is_dir` is True, the path is treated as a directory, if not for files, the parent directory is treated as a directory.
77+
If `must_exist` is True, the directory must exist.
78+
If `create` is True, the directory will be created if it doesn't exist.
79+
80+
For example:
81+
```
82+
@ensure_path("dir", create=True)
83+
class GribOutput
84+
def __init__(context, dir=None, archive_requests=None):
85+
...
86+
"""
87+
88+
def __init__(self, arg: str, is_dir: bool = False, create: bool = True, must_exist: bool = False):
89+
self.arg = arg
90+
self.is_dir = is_dir
91+
self.create = create
92+
self.must_exist = must_exist
93+
94+
def __call__(self, cls: F) -> F:
95+
"""Decorate the object to ensure the path argument is a Path object."""
96+
97+
class WrappedClass(cls):
98+
def __init__(wrapped_cls, context: Context, *args: Any, **kwargs: Any) -> None:
99+
if self.arg not in kwargs:
100+
LOG.debug(f"Argument '{self.arg}' not found in kwargs, cannot ensure path.")
101+
super().__init__(context, *args, **kwargs)
102+
return
103+
104+
path = kwargs[self.arg] = Path(kwargs[self.arg])
105+
if not self.is_dir:
106+
path = path.parent
107+
108+
if self.must_exist:
109+
if not path.exists():
110+
raise FileNotFoundError(f"Path '{path}' does not exist.")
111+
if self.create:
112+
path.mkdir(parents=True, exist_ok=True)
113+
114+
super().__init__(context, *args, **kwargs)
115+
116+
return type(cls.__name__, (WrappedClass,), {})
117+
118+
119+
class ensure_dir(ensure_path):
120+
"""Decorator to ensure a directory path argument is a Path object and optionally exists.
121+
122+
If `must_exist` is True, the directory must exist.
123+
If `create` is True, the directory will be created if it doesn't exist.
124+
125+
For example:
126+
```
127+
@ensure_dir("dir", create=True)
128+
class PlotOutput
129+
def __init__(context, dir=None, ...):
130+
...
131+
"""
132+
133+
def __init__(self, arg: str, create: bool = True, must_exist: bool = False):
134+
super().__init__(arg, is_dir=True, create=create, must_exist=must_exist)

src/anemoi/inference/outputs/gribfile.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from anemoi.inference.types import FloatArray
2424
from anemoi.inference.types import ProcessorConfig
2525

26+
from ..decorators import ensure_path
2627
from ..decorators import main_argument
2728
from ..grib.encoding import GribWriter
2829
from ..grib.encoding import check_encoding
@@ -312,6 +313,7 @@ def _patch(r: DataRequest) -> DataRequest:
312313

313314
@output_registry.register("grib")
314315
@main_argument("path")
316+
@ensure_path("path")
315317
class GribFileOutput(GribIoOutput):
316318
"""Handles grib files."""
317319

@@ -367,10 +369,6 @@ def __init__(
367369
split_output : bool, optional
368370
Whether to split the output, by default True.
369371
"""
370-
path = Path(path)
371-
if not path.parent.exists():
372-
path.parent.mkdir(parents=True, exist_ok=True)
373-
374372
super().__init__(
375373
context,
376374
out=path,

src/anemoi/inference/outputs/netcdf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from anemoi.inference.types import ProcessorConfig
1919
from anemoi.inference.types import State
2020

21+
from ..decorators import ensure_path
2122
from ..decorators import main_argument
2223
from ..output import Output
2324
from . import output_registry
@@ -31,6 +32,7 @@
3132

3233
@output_registry.register("netcdf")
3334
@main_argument("path")
35+
@ensure_path("path")
3436
class NetCDFOutput(Output):
3537
"""NetCDF output class."""
3638

@@ -76,7 +78,7 @@ def __init__(
7678

7779
from netCDF4 import Dataset
7880

79-
self.path = Path(path)
81+
self.path = path
8082
self.ncfile: Dataset | None = None
8183
self.float_size = float_size
8284
self.missing_value = missing_value
@@ -103,7 +105,6 @@ def open(self, state: State) -> None:
103105
if self.ncfile is not None:
104106
return
105107

106-
self.path.parent.mkdir(parents=True, exist_ok=True)
107108
# If the file exists, we may get a 'Permission denied' error
108109
if os.path.exists(self.path):
109110
os.remove(self.path)

src/anemoi/inference/outputs/plot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from anemoi.utils.grib import units
1515

1616
from anemoi.inference.context import Context
17+
from anemoi.inference.decorators import ensure_dir
1718
from anemoi.inference.decorators import main_argument
1819
from anemoi.inference.types import FloatArray
1920
from anemoi.inference.types import ProcessorConfig
@@ -44,6 +45,7 @@ def fix(lons: FloatArray) -> FloatArray:
4445

4546
@output_registry.register("plot")
4647
@main_argument("dir")
48+
@ensure_dir("dir")
4749
class PlotOutput(Output):
4850
"""Use `earthkit-plots` to plot the outputs."""
4951

@@ -105,7 +107,7 @@ def __init__(
105107
write_initial_state=write_initial_state,
106108
)
107109

108-
self.dir = Path(dir)
110+
self.dir = dir
109111
self.format = format
110112
self.variables = variables
111113
self.template = template
@@ -166,7 +168,6 @@ def write_step(self, state: State) -> None:
166168
"variables": "_".join(self.variables or []),
167169
},
168170
)
169-
self.dir.mkdir(parents=True, exist_ok=True)
170171
fname = self.dir / fname
171172

172173
fig.save(fname)

src/anemoi/inference/outputs/printer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from anemoi.inference.context import Context
2222
from anemoi.inference.types import State
2323

24+
from ..decorators import ensure_path
2425
from ..decorators import main_argument
2526
from ..output import Output
2627
from . import output_registry
@@ -105,6 +106,7 @@ def print_state(
105106

106107
@output_registry.register("printer")
107108
@main_argument("max_lines")
109+
@ensure_path("path")
108110
class PrinterOutput(Output):
109111
"""Printer output class."""
110112

@@ -142,9 +144,6 @@ def __init__(
142144
self.f = None
143145

144146
if path is not None:
145-
path = Path(path)
146-
path.parent.mkdir(parents=True, exist_ok=True)
147-
148147
self.f = open(path, "w")
149148
self.print = partial(print, file=self.f)
150149

src/anemoi/inference/outputs/raw.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from anemoi.inference.types import State
1717
from anemoi.inference.utils.templating import render_template
1818

19+
from ..decorators import ensure_dir
1920
from ..decorators import main_argument
2021
from ..output import Output
2122
from . import output_registry
@@ -25,6 +26,7 @@
2526

2627
@output_registry.register("raw")
2728
@main_argument("path")
29+
@ensure_dir("dir")
2830
class RawOutput(Output):
2931
"""Raw output class."""
3032

@@ -53,7 +55,7 @@ def __init__(
5355
The date format string, by default "%Y%m%d%H%M%S".
5456
"""
5557
super().__init__(context, variables=variables, **kwargs)
56-
self.dir = Path(dir)
58+
self.dir = dir
5759
self.template = template
5860
self.strftime = strftime
5961

@@ -65,7 +67,7 @@ def __repr__(self) -> str:
6567
str
6668
String representation of the RawOutput object.
6769
"""
68-
return f"RawOutput({self.path})"
70+
return f"RawOutput({self.dir})"
6971

7072
def write_step(self, state: State) -> None:
7173
"""Write the state to a compressed .npz file.
@@ -78,8 +80,6 @@ def write_step(self, state: State) -> None:
7880
date = state["date"]
7981
basetime = date - state["step"]
8082

81-
self.dir.mkdir(parents=True, exist_ok=True)
82-
8383
format_info = {
8484
"date": date.strftime(self.strftime),
8585
"step": state["step"],

0 commit comments

Comments
 (0)