Description
What happened?
When saving a DataTree with cross-group dependencies using to_zarr(), Dask computations are executed multiple times unnecessarily, leading to significant performance overhead compared to saving the equivalent data as a single Dataset.
A workaround is DataTree.persist().to_zarr(), but this computes and loads all data into memory before writing. Is there a more memory-efficient workaround that avoids the redundant computations without the RAM overhead?
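One direction that might avoid both problems (an untested sketch, not an existing xarray API): write each group lazily with compute=False and execute all of the resulting delayed writes in a single dask.compute() call, so that Dask can optimize the combined graph and evaluate tasks shared across groups only once. This assumes Dataset.to_zarr(..., compute=False) returns a Delayed for each group and that appending groups into the same store with mode="a" behaves as expected; to_zarr_shared_graph below is a hypothetical helper name.

import dask
import xarray as xr

def to_zarr_shared_graph(dt: xr.DataTree, store: str) -> None:
    # Hypothetical helper, not an xarray API: queue one lazy write per group,
    # then compute them together so shared upstream tasks run only once.
    delayed_writes = []
    for node in dt.subtree:  # visits the root and every child group
        if node.has_data:
            delayed_writes.append(
                node.dataset.to_zarr(
                    store,
                    group=node.path,
                    mode="a",  # append each group into the same store
                    compute=False,  # defer: returns a dask Delayed object
                )
            )
    dask.compute(*delayed_writes)  # one combined graph, one evaluation per task

I have not verified that the delayed writes actually share their upstream tasks after graph optimization, so treat this as a sketch rather than a confirmed workaround.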
What did you expect to happen?
DataTree.to_zarr() should optimize the computation graph globally and execute each Dask operation only once, similar to Dataset.to_zarr() or DataTree.compute().
Minimal Complete Verifiable Example
#!/usr/bin/env python3
"""
Minimal DataTree.to_zarr() bug reproduction: DataTree.to_zarr() performs
redundant computations with cross-group dependencies.
"""
import tempfile

import dask.array as da
import xarray as xr

# Counts how many times expensive_func runs (once per chunk per evaluation).
eval_count = 0


def expensive_func(x):
    global eval_count
    eval_count += 1
    return x + 1


def test_bug():
    global eval_count

    # Create data with a cross-group dependency
    base = da.random.random((100, 50), chunks=(50, 25))
    var_a = da.map_blocks(expensive_func, base, dtype=float)
    var_b = da.map_blocks(expensive_func, var_a, dtype=float)  # B depends on A

    base_da = xr.DataArray(base, dims=["x", "y"])
    var_a_da = xr.DataArray(var_a, dims=["x", "y"])
    var_b_da = xr.DataArray(var_b, dims=["x", "y"])

    # Test 1: single Dataset (optimal)
    eval_count = 0
    ds = xr.Dataset({"base": base_da, "var_a": var_a_da, "var_b": var_b_da})
    with tempfile.TemporaryDirectory() as tmpdir:
        ds.to_zarr(f"{tmpdir}/single.zarr")
    single_evals = eval_count

    # Test 2: DataTree with cross-group dependency (inefficient)
    eval_count = 0
    group1 = xr.DataTree(xr.Dataset({"base": base_da, "var_a": var_a_da}), name="group1")
    group2 = xr.DataTree(xr.Dataset({"var_b": var_b_da}), name="group2")  # depends on var_a
    dt = xr.DataTree(
        dataset=xr.Dataset(), children={"group1": group1, "group2": group2}, name="root"
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        dt.to_zarr(f"{tmpdir}/datatree.zarr")
    datatree_evals = eval_count

    # Test 3: DataTree with the persist() workaround
    eval_count = 0
    with tempfile.TemporaryDirectory() as tmpdir:
        dt.persist().to_zarr(f"{tmpdir}/persist.zarr")
    persist_evals = eval_count

    print(f"Single Dataset: {single_evals} evaluations")
    print(
        f"DataTree (BUG): {datatree_evals} evaluations "
        f"(+{((datatree_evals - single_evals) / single_evals) * 100:.0f}%)"
    )
    print(f"DataTree + persist: {persist_evals} evaluations")
    print("\nWORKAROUND: Use dt.persist().to_zarr() but load all data in memory")


if __name__ == "__main__":
    test_bug()
MVCE confirmation
- Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- Complete example — the example is self-contained, including all data and the text of any traceback.
- Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- New issue — a search of GitHub Issues suggests this is not a duplicate.
- Recent environment — the issue occurs with the latest version of xarray and its dependencies.
Relevant log output
Single Dataset: 8 evaluations
DataTree (BUG): 12 evaluations (+50%)
DataTree + persist: 8 evaluations
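For what it's worth, the numbers are consistent with each group being computed independently: every array has four chunks, so the single-Dataset write runs expensive_func four times for var_a and four times for var_b (8 total), whereas the DataTree write appears to evaluate var_a once for group1 (4) and again as an upstream dependency of var_b in group2 (4 + 4), giving 12, i.e. +50%.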
Anything else we need to know?
No response
Environment
xarray: 2024.11.0
pandas: 2.3.1
numpy: 2.1.3
scipy: 1.16.1
netCDF4: 1.7.2
pydap: None
h5netcdf: 1.6.4
h5py: 3.13.0
zarr: 2.18.4
cftime: 1.6.4.post1
nc_time_axis: None
iris: None
bottleneck: None
dask: 2024.5.2
distributed: 2024.5.2
matplotlib: 3.10.1
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.5.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: None
setuptools: 80.9.0
pip: 25.2
conda: None
pytest: 7.4.3
mypy: 1.10.0
IPython: 9.4.0
sphinx: 6.2.1