Support compute=False from DataTree.to_netcdf #10625


Merged (11 commits, Aug 18, 2025)
7 changes: 7 additions & 0 deletions doc/whats-new.rst
@@ -13,6 +13,9 @@ v2025.08.1 (unreleased)
New Features
~~~~~~~~~~~~

- ``compute=False`` is now supported by :py:meth:`DataTree.to_netcdf` and
:py:meth:`DataTree.to_zarr`.
By `Stephan Hoyer <https://github.com/shoyer>`_.
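
  A minimal sketch of the new behavior (hypothetical file name; assumes
  dask-backed variables)::

      import dask
      import xarray as xr

      tree = xr.DataTree.from_dict(
          {"/group": xr.Dataset({"a": ("x", [1, 2])}).chunk()}
      )
      delayed = tree.to_netcdf("tree.nc", compute=False)  # array writes deferred
      dask.compute(delayed)  # perform the writes, then close the file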

Breaking changes
~~~~~~~~~~~~~~~~
@@ -25,6 +28,10 @@ Deprecations
Bug fixes
~~~~~~~~~

- :py:meth:`DataTree.to_netcdf` and :py:meth:`DataTree.to_zarr` now avoid
redundant computation of Dask arrays with cross-group dependencies
(:issue:`10637`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
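
  A hedged illustration of the case this fixes (hypothetical names)::

      import xarray as xr

      base = xr.Dataset({"a": ("x", [1, 2, 3])}).chunk()
      tree = xr.DataTree.from_dict({"/raw": base, "/scaled": base * 2})
      # Both groups reference one dask graph, which is now evaluated in a
      # single sync across groups rather than once per group.
      tree.to_zarr("tree.zarr")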

Documentation
~~~~~~~~~~~~~
204 changes: 136 additions & 68 deletions xarray/backends/api.py
@@ -31,13 +31,14 @@
from xarray.backends import plugins
from xarray.backends.common import (
AbstractDataStore,
AbstractWritableDataStore,
ArrayWriter,
BytesIOProxy,
T_PathFileOrDataStore,
_find_absolute_paths,
_normalize_path,
)
-from xarray.backends.locks import _get_scheduler
+from xarray.backends.locks import get_dask_scheduler
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
from xarray.core import dtypes, indexing
from xarray.core.coordinates import Coordinates
@@ -307,12 +308,18 @@ def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None:
_protect_dataset_variables_inplace(node.dataset, cache)


-def _finalize_store(write, store):
+def _finalize_store(writes, store):
     """Finalize this store by explicitly syncing and closing"""
-    del write  # ensure writing is done first
+    del writes  # ensure writing is done first
     store.close()


def delayed_close_after_writes(writes, store):
import dask

return dask.delayed(_finalize_store)(writes, store)
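
The ``del writes`` pattern above works because ``dask.delayed(_finalize_store)(writes, store)``
records ``writes`` as an input: dask will not run the close task until every
write task has finished. A toy standalone illustration of that ordering (not
xarray code):

    import dask

    @dask.delayed
    def write_chunk(i):
        return f"wrote chunk {i}"

    def close_store(writes):
        del writes  # unused; present only to order this task after the writes
        print("store closed")

    writes = [write_chunk(i) for i in range(3)]
    dask.delayed(close_store)(writes).compute()  # all writes run, then close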


def _multi_file_closer(closers):
for closer in closers:
closer()
@@ -1855,6 +1862,39 @@ def open_mfdataset(
}


def get_writable_netcdf_store(
target,
engine: T_NetcdfEngine,
*,
format: T_NetcdfTypes | None,
mode: NetcdfWriteModes,
autoclose: bool,
invalid_netcdf: bool,
auto_complex: bool | None,
) -> AbstractWritableDataStore:
"""Create a store for writing to a netCDF file."""
try:
store_open = WRITEABLE_STORES[engine]
except KeyError as err:
raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err

if format is not None:
format = format.upper() # type: ignore[assignment]

kwargs = dict(autoclose=True) if autoclose else {}
if invalid_netcdf:
if engine == "h5netcdf":
kwargs["invalid_netcdf"] = invalid_netcdf
else:
raise ValueError(
f"unrecognized option 'invalid_netcdf' for engine {engine}"
)
if auto_complex is not None:
kwargs["auto_complex"] = auto_complex

return store_open(target, mode=mode, format=format, **kwargs)
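
Presumably the DataTree writer pairs this factory with ``get_child_store``
along these lines (a sketch only; the real call sites may differ):

    store = get_writable_netcdf_store(
        "tree.nc",
        engine="h5netcdf",
        format=None,
        mode="w",
        autoclose=False,
        invalid_netcdf=False,
        auto_complex=None,
    )
    child = store.get_child_store("child")  # same open file, nested group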


# multifile=True returns writer and datastore
@overload
def to_netcdf(
Expand Down Expand Up @@ -2040,16 +2080,8 @@ def to_netcdf(
# sanitize unlimited_dims
unlimited_dims = _sanitize_unlimited_dims(dataset, unlimited_dims)

-    try:
-        store_open = WRITEABLE_STORES[engine]
-    except KeyError as err:
-        raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err
-
-    if format is not None:
-        format = format.upper()  # type: ignore[assignment]
-
     # handle scheduler specific logic
-    scheduler = _get_scheduler()
+    scheduler = get_dask_scheduler()
have_chunks = any(v.chunks is not None for v in dataset.variables.values())

autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
@@ -2064,18 +2096,17 @@
else:
target = path_or_file # type: ignore[assignment]

-    kwargs = dict(autoclose=True) if autoclose else {}
-    if invalid_netcdf:
-        if engine == "h5netcdf":
-            kwargs["invalid_netcdf"] = invalid_netcdf
-        else:
-            raise ValueError(
-                f"unrecognized option 'invalid_netcdf' for engine {engine}"
-            )
-    if auto_complex is not None:
-        kwargs["auto_complex"] = auto_complex
-
-    store = store_open(target, mode, format, group, **kwargs)
+    store = get_writable_netcdf_store(
+        target,
+        engine,
+        mode=mode,
+        format=format,
+        autoclose=autoclose,
+        invalid_netcdf=invalid_netcdf,
+        auto_complex=auto_complex,
+    )
+    if group is not None:
+        store = store.get_child_store(group)

writer = ArrayWriter()

@@ -2096,17 +2127,18 @@
writes = writer.sync(compute=compute)

     finally:
-        if not multifile and compute:  # type: ignore[redundant-expr]
-            store.close()
+        if not multifile:
+            if compute:
+                store.close()
+            else:
+                store.sync()

if path_or_file is None:
assert isinstance(target, BytesIOProxy) # created in this function
return target.getvalue_or_getbuffer()

     if not compute:
-        import dask
-
-        return dask.delayed(_finalize_store)(writes, store)
+        return delayed_close_after_writes(writes, store)

return None
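
From the caller's side, the sync-versus-close split above behaves like this
(sketch; assumes ``ds`` holds dask-backed variables):

    delayed = ds.to_netcdf("out.nc", compute=False)
    # The file now exists with metadata flushed (store.sync()), but the store
    # stays open so the pending array writes can still target it.
    delayed.compute()  # run the writes, then _finalize_store closes the store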

@@ -2262,20 +2294,71 @@ def save_mfdataset(
try:
writes = [w.sync(compute=compute) for w in writers]
     finally:
-        if compute:
-            for store in stores:
-                store.close()
+        for store in stores:
+            if compute:
+                store.close()
+            else:
+                store.sync()

     if not compute:
         import dask

         return dask.delayed(
-            list(
-                starmap(dask.delayed(_finalize_store), zip(writes, stores, strict=True))
-            )
+            list(starmap(delayed_close_after_writes, zip(writes, stores, strict=True)))
         )
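
For reference, the ``compute=False`` path of ``save_mfdataset`` now resolves
to one delayed close per store (sketch with hypothetical datasets and paths):

    import dask
    import xarray as xr

    datasets = [xr.Dataset({"a": ("x", [i])}).chunk() for i in range(2)]
    delayed = xr.save_mfdataset(datasets, ["f0.nc", "f1.nc"], compute=False)
    dask.compute(delayed)  # write both files, then close both stores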


def get_writable_zarr_store(
store: ZarrStoreLike | None = None,
*,
chunk_store: MutableMapping | str | os.PathLike | None = None,
mode: ZarrWriteModes | None = None,
synchronizer=None,
group: str | None = None,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
safe_chunks: bool = True,
align_chunks: bool = False,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
zarr_format: int | None = None,
write_empty_chunks: bool | None = None,
) -> backends.ZarrStore:
"""Create a store for writing to Zarr."""
from xarray.backends.zarr import _choose_default_mode, _get_mappers

kwargs, mapper, chunk_mapper = _get_mappers(
storage_options=storage_options, store=store, chunk_store=chunk_store
)
mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region)

if mode == "r+":
already_consolidated = consolidated
consolidate_on_close = False
else:
already_consolidated = False
consolidate_on_close = consolidated or consolidated is None

return backends.ZarrStore.open_group(
store=mapper,
mode=mode,
synchronizer=synchronizer,
group=group,
consolidated=already_consolidated,
consolidate_on_close=consolidate_on_close,
chunk_store=chunk_mapper,
append_dim=append_dim,
write_region=region,
safe_chunks=safe_chunks,
align_chunks=align_chunks,
zarr_version=zarr_version,
zarr_format=zarr_format,
write_empty=write_empty_chunks,
**kwargs,
)
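
The consolidation branch above reduces to: in ``r+`` mode (overwriting values
in an existing store), trust whatever consolidated metadata is already there
and never rewrite it; otherwise consolidate on close unless the caller passed
``consolidated=False``. An equivalent restatement (illustrative helper, not
xarray code):

    def consolidation_plan(mode, consolidated):
        """Return (already_consolidated, consolidate_on_close)."""
        if mode == "r+":
            return consolidated, False  # leave existing metadata untouched
        return False, consolidated in (None, True)  # consolidate by default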


# compute=True returns ZarrStore
@overload
def to_zarr(
@@ -2350,7 +2433,6 @@ def to_zarr(

See `Dataset.to_zarr` for full API docs.
"""
-    from xarray.backends.zarr import _choose_default_mode, _get_mappers

# validate Dataset keys, DataArray names
_validate_dataset_names(dataset)
@@ -2365,53 +2447,39 @@
if encoding is None:
encoding = {}

-    kwargs, mapper, chunk_mapper = _get_mappers(
-        storage_options=storage_options, store=store, chunk_store=chunk_store
-    )
-    mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region)
-
-    if mode == "r+":
-        already_consolidated = consolidated
-        consolidate_on_close = False
-    else:
-        already_consolidated = False
-        consolidate_on_close = consolidated or consolidated is None
-
-    zstore = backends.ZarrStore.open_group(
-        store=mapper,
+    zstore = get_writable_zarr_store(
+        store,
+        chunk_store=chunk_store,
         mode=mode,
         synchronizer=synchronizer,
         group=group,
-        consolidated=already_consolidated,
-        consolidate_on_close=consolidate_on_close,
-        chunk_store=chunk_mapper,
+        consolidated=consolidated,
         append_dim=append_dim,
-        write_region=region,
+        region=region,
         safe_chunks=safe_chunks,
         align_chunks=align_chunks,
+        storage_options=storage_options,
         zarr_version=zarr_version,
         zarr_format=zarr_format,
-        write_empty=write_empty_chunks,
-        **kwargs,
+        write_empty_chunks=write_empty_chunks,
     )

-    dataset = zstore._validate_and_autodetect_region(
-        dataset,
-    )
+    dataset = zstore._validate_and_autodetect_region(dataset)
     zstore._validate_encoding(encoding)

     writer = ArrayWriter()
     # TODO: figure out how to properly handle unlimited_dims
-    dump_to_store(dataset, zstore, writer, encoding=encoding)
-    writes = writer.sync(
-        compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
-    )
-
-    if compute:
-        _finalize_store(writes, zstore)
-    else:
-        import dask
-
-        return dask.delayed(_finalize_store)(writes, zstore)
+    try:
+        dump_to_store(dataset, zstore, writer, encoding=encoding)
+        writes = writer.sync(
+            compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
+        )
+    finally:
+        if compute:
+            zstore.close()
+
+    if not compute:
+        return delayed_close_after_writes(writes, zstore)

     return zstore
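
End to end, this gives ``to_zarr(..., compute=False)`` the same contract as
the netCDF path, while ``compute=True`` still returns the ``ZarrStore``
(sketch; assumes dask-backed ``ds``):

    delayed = ds.to_zarr("out.zarr", compute=False)  # store stays open
    delayed.compute()  # chunk writes run, then the store is closed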
9 changes: 9 additions & 0 deletions xarray/backends/common.py
@@ -12,6 +12,7 @@
Any,
ClassVar,
Generic,
Self,
TypeVar,
Union,
overload,
@@ -326,6 +327,10 @@ async def async_get_duck_array(self, dtype: np.typing.DTypeLike = None):
class AbstractDataStore:
__slots__ = ()

def get_child_store(self, group: str) -> Self: # pragma: no cover
"""Get a store corresponding to the indicated child group."""
raise NotImplementedError()
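
Loosely, per-group DataTree writing can then be expressed against this
interface (names and call shapes here are assumptions, not the actual
implementation):

    for node in tree.subtree:
        child = store.get_child_store(node.path)
        dump_to_store(node.dataset, child, writer, encoding=encoding)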

def get_dimensions(self): # pragma: no cover
raise NotImplementedError()

@@ -606,6 +611,10 @@ def set_dimensions(self, variables, unlimited_dims=None):
is_unlimited = dim in unlimited_dims
self.set_dimension(dim, length, is_unlimited)

def sync(self):
"""Write all buffered data to disk."""
raise NotImplementedError()


def _infer_dtype(array, name=None):
"""Given an object array with no missing values, infer its dtype from all elements."""
13 changes: 12 additions & 1 deletion xarray/backends/h5netcdf_.py
@@ -4,7 +4,7 @@
import io
import os
from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Self

import numpy as np

@@ -150,6 +150,17 @@ def __init__(
self.lock = ensure_lock(lock)
self.autoclose = autoclose

def get_child_store(self, group: str) -> Self:
if self._group is not None:
group = os.path.join(self._group, group)
return type(self)(
self._manager,
group=group,
mode=self._mode,
lock=self.lock,
autoclose=self.autoclose,
)
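
A short usage sketch (hypothetical file and group names); the child stores
share the root's file manager and lock instead of reopening the file:

    root = H5NetCDFStore.open("tree.nc", mode="w")
    child = root.get_child_store("run1")      # group "run1"
    nested = child.get_child_store("step1")   # group "run1/step1" (joined above)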

@classmethod
def open(
cls,