DataTree.to_zarr() performs redundant computations with cross-group dependencies #10637

@jdemaria

Description

What happened?

When saving a DataTree with cross-group dependencies (a variable in one group computed from a variable in another) using to_zarr(), the shared Dask computations are executed more than once, leading to significant performance overhead compared with saving the same data as a single Dataset.

A workaround is to use DataTree.persist().to_zarr(), but this has the disadvantage of computing and loading all data into memory before writing. Is there a more memory-efficient workaround that avoids the redundant computations without this RAM overhead?
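One possible middle ground, sketched here at the Dask level only and not verified against xarray's DataTree writer: persist just the intermediate that is shared across groups (here `var_a`), build the dependent arrays from the persisted copy, and leave everything else lazy, so only the shared intermediate is held in memory rather than the whole tree. As assumptions for illustration, plain NumPy arrays written with `da.store` stand in for the per-group Zarr writes (one store call per group), the helper `count_evals_with_shared_persist` is hypothetical, and the synchronous scheduler is used so the global counter is deterministic.

```python
import dask
import dask.array as da
import numpy as np

eval_count = 0


def expensive_func(x):
    global eval_count
    eval_count += 1
    return x + 1


def count_evals_with_shared_persist():
    """Hypothetical helper: return (evals during persist, evals during stores)."""
    global eval_count
    with dask.config.set(scheduler="synchronous"):
        base = da.random.random((100, 50), chunks=(50, 25))  # 4 blocks
        var_a = da.map_blocks(expensive_func, base, dtype=float)

        # Persist only the intermediate shared by both groups.
        eval_count = 0
        var_a = var_a.persist()
        persist_evals = eval_count

        # Build the dependent array from the persisted copy of var_a.
        var_b = da.map_blocks(expensive_func, var_a, dtype=float)

        # One store call per "group", mimicking separate per-group writes.
        eval_count = 0  # ignore any calls made while building var_b's graph
        target_a = np.empty(var_a.shape)
        target_b = np.empty(var_b.shape)
        da.store(var_a, target_a)  # stand-in for writing group1
        da.store(var_b, target_b)  # stand-in for writing group2
        store_evals = eval_count
    return persist_evals, store_evals


print(count_evals_with_shared_persist())
```

Here `var_a` is evaluated once per block during `persist()` and the two separate stores only evaluate `var_b`, so the total matches the single-Dataset count. Whether this carries over to DataTree.to_zarr() depends on how xarray builds its per-group graphs, which this sketch does not check.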

What did you expect to happen?

DataTree.to_zarr() should optimize the computation graph globally and execute each Dask operation only once, similar to Dataset.to_zarr() or DataTree.compute().
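The expected behaviour can be illustrated with Dask alone. In this sketch, plain NumPy arrays stand in for the Zarr targets (an assumption for illustration) and `build_arrays` / `count_store_evals` are hypothetical helpers. Writing `var_a` and `var_b` in two separate `da.store` calls builds two graphs, so the shared `var_a` blocks are evaluated again inside the second one; a single `da.store` call over both arrays optimizes them together and evaluates each block once. With 2 × 2 chunks one would expect 12 versus 8 evaluations, matching the counts reported in this issue's log output.

```python
import dask
import dask.array as da
import numpy as np

eval_count = 0


def expensive_func(x):
    global eval_count
    eval_count += 1
    return x + 1


def build_arrays():
    base = da.random.random((100, 50), chunks=(50, 25))  # 2 x 2 = 4 blocks
    var_a = da.map_blocks(expensive_func, base, dtype=float)
    var_b = da.map_blocks(expensive_func, var_a, dtype=float)  # B depends on A
    return var_a, var_b


def count_store_evals(single_call):
    """Count expensive_func calls made while storing var_a and var_b."""
    global eval_count
    var_a, var_b = build_arrays()
    target_a = np.empty(var_a.shape)
    target_b = np.empty(var_b.shape)
    eval_count = 0  # ignore any calls made while building the graph
    with dask.config.set(scheduler="synchronous"):
        if single_call:
            # One graph: the shared var_a blocks are computed once.
            da.store([var_a, var_b], [target_a, target_b])
        else:
            # Two graphs, as with one write per group: var_a is recomputed.
            da.store(var_a, target_a)
            da.store(var_b, target_b)
    return eval_count


print("separate stores:", count_store_evals(single_call=False))
print("single store:   ", count_store_evals(single_call=True))
```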

Minimal Complete Verifiable Example

#!/usr/bin/env python3
"""
Minimal DataTree.to_zarr() bug reproduction: DataTree.to_zarr() performs
redundant computations with cross-group dependencies
"""

import tempfile

import dask.array as da
import xarray as xr

eval_count = 0


def expensive_func(x):
    global eval_count
    eval_count += 1
    return x + 1


def test_bug():
    global eval_count

    # Create data with cross-group dependency
    base = da.random.random((100, 50), chunks=(50, 25))
    var_a = da.map_blocks(expensive_func, base, dtype=float)
    var_b = da.map_blocks(expensive_func, var_a, dtype=float)  # B depends on A

    base_da = xr.DataArray(base, dims=["x", "y"])
    var_a_da = xr.DataArray(var_a, dims=["x", "y"])
    var_b_da = xr.DataArray(var_b, dims=["x", "y"])

    # Test 1: Single Dataset (optimal)
    eval_count = 0
    ds = xr.Dataset({"base": base_da, "var_a": var_a_da, "var_b": var_b_da})
    with tempfile.TemporaryDirectory() as tmpdir:
        ds.to_zarr(f"{tmpdir}/single.zarr")
    single_evals = eval_count

    # Test 2: DataTree with cross-group dependency (inefficient)
    eval_count = 0
    group1 = xr.DataTree(xr.Dataset({"base": base_da, "var_a": var_a_da}), name="group1")
    group2 = xr.DataTree(xr.Dataset({"var_b": var_b_da}), name="group2")  # depends on var_a

    dt = xr.DataTree(dataset=xr.Dataset(), children={"group1": group1, "group2": group2}, name="root")

    with tempfile.TemporaryDirectory() as tmpdir:
        dt.to_zarr(f"{tmpdir}/datatree.zarr")
    datatree_evals = eval_count

    # Test 3: DataTree with workaround
    eval_count = 0
    with tempfile.TemporaryDirectory() as tmpdir:
        dt.persist().to_zarr(f"{tmpdir}/persist.zarr")
    persist_evals = eval_count

    print(f"Single Dataset:     {single_evals} evaluations")
    print(
        f"DataTree (BUG):     {datatree_evals} evaluations (+{((datatree_evals-single_evals)/single_evals)*100:.0f}%)"
    )
    print(f"DataTree + persist: {persist_evals} evaluations")
    print("\nWORKAROUND: Use dt.persist().to_zarr(), which loads all data into memory")


if __name__ == "__main__":
    test_bug()

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

Single Dataset:     8 evaluations
DataTree (BUG):     12 evaluations (+50%)
DataTree + persist: 8 evaluations
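These numbers are consistent with each group being computed in its own Dask computation; this is an inference from the counts, not something checked against xarray's source. With `chunks=(50, 25)` on a `(100, 50)` array, each variable has 2 × 2 = 4 blocks, and `base` never calls `expensive_func`. A small sketch of the arithmetic, where `n_blocks` is a hypothetical helper:

```python
import math


def n_blocks(shape, chunks):
    """Number of blocks for a regular chunking of `shape` by `chunks`."""
    return math.prod(math.ceil(s / c) for s, c in zip(shape, chunks))


blocks = n_blocks((100, 50), (50, 25))  # 2 x 2 = 4

# One graph (Dataset.to_zarr): var_a and var_b are each evaluated once per block.
single = blocks + blocks

# Assumed per-group computes: group1 evaluates var_a; group2 evaluates var_b
# and, since nothing is shared between the two computes, var_a again.
per_group = blocks + (blocks + blocks)

overhead_pct = (per_group - single) / single * 100
print(single, per_group, f"+{overhead_pct:.0f}%")
```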

Anything else we need to know?

No response

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.11.13 | packaged by conda-forge | (main, Jun 4 2025, 14:48:23) [GCC 13.3.0]
python-bits: 64
OS: Linux
OS-release: 5.14.0-570.30.1.el9_6.x86_64
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: fr_FR.UTF-8
LOCALE: ('fr_FR', 'UTF-8')
libhdf5: 1.14.2
libnetcdf: 4.9.4-development

xarray: 2024.11.0
pandas: 2.3.1
numpy: 2.1.3
scipy: 1.16.1
netCDF4: 1.7.2
pydap: None
h5netcdf: 1.6.4
h5py: 3.13.0
zarr: 2.18.4
cftime: 1.6.4.post1
nc_time_axis: None
iris: None
bottleneck: None
dask: 2024.5.2
distributed: 2024.5.2
matplotlib: 3.10.1
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.5.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: None
setuptools: 80.9.0
pip: 25.2
conda: None
pytest: 7.4.3
mypy: 1.10.0
IPython: 9.4.0
sphinx: 6.2.1

Metadata

Assignees

No one assigned

    Labels

    bug, topic-DataTree (Related to the implementation of a DataTree class), topic-dask, topic-zarr (Related to zarr storage library)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests