Skip to content

Commit cd9231b

Browse files
committed
fix save_zarr code and test
1 parent e16e255 commit cd9231b

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

src/anemoi/datasets/data/misc.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ def initialize_zarr_store(root: Any, big_dataset: "Dataset") -> None:
664664
# Create or append to "dates" dataset.
665665
if "dates" not in root:
666666
full_length = len(big_dataset.dates)
667-
root.create_dataset("dates", data=np.array([], dtype="datetime64[s]"), chunks=(full_length,))
667+
root.create_dataset("dates", data=np.array([], dtype="datetime64[s]"), chunks=(full_length,), shape=(0,))
668668

669669
if "data" not in root:
670670
dims = (1, len(big_dataset.variables), ensembles, big_dataset.shape[-1])
@@ -681,12 +681,15 @@ def initialize_zarr_store(root: Any, big_dataset: "Dataset") -> None:
681681
k,
682682
data=v,
683683
compressor=None,
684+
shape=v.shape,
684685
)
685686

686687
# Create spatial coordinate datasets if missing.
687688
if "latitudes" not in root or "longitudes" not in root:
688-
root.create_dataset("latitudes", data=big_dataset.latitudes, compressor=None)
689-
root.create_dataset("longitudes", data=big_dataset.longitudes, compressor=None)
689+
root.create_dataset("latitudes", data=big_dataset.latitudes, compressor=None, shape=big_dataset.latitudes.shape)
690+
root.create_dataset(
691+
"longitudes", data=big_dataset.longitudes, compressor=None, shape=big_dataset.longitudes.shape
692+
)
690693
for k, v in big_dataset.metadata().items():
691694
if k not in root.attrs:
692695
root.attrs[k] = v

tests/test_data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242

4343
VALUES = 10
4444

45+
true_zarr_open = zarr.open
46+
4547

4648
def mockup_open_zarr(func: Callable) -> Callable:
4749
"""Decorator to mock the open_zarr function.
@@ -237,6 +239,8 @@ def zarr_from_str(name: str, mode: str) -> zarr.Group:
237239
Zarr dataset.
238240
"""
239241
# Format: test-2021-2021-6h-o96-abcd-0
242+
if "/" in name:
243+
return true_zarr_open(name)
240244

241245
args = dict(
242246
test="test",

0 commit comments

Comments
 (0)