
Add chunks='auto' support for cftime datasets #10527


Draft · wants to merge 9 commits into main
49 changes: 49 additions & 0 deletions xarray/namedarray/daskmanager.py
@@ -5,6 +5,7 @@

import numpy as np

from xarray.core.common import _contains_cftime_datetimes
from xarray.core.indexing import ImplicitToExplicitIndexingAdapter
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray
from xarray.namedarray.utils import is_duck_dask_array, module_available
@@ -16,6 +17,7 @@
_NormalizedChunks,
duckarray,
)
from xarray.namedarray.parallelcompat import _Chunks

try:
from dask.array import Array as DaskArray
@@ -264,3 +266,50 @@ def shuffle(
        if chunks != "auto":
            raise NotImplementedError("Only chunks='auto' is supported at present.")
        return dask.array.shuffle(x, indexer, axis, chunks="auto")

    def rechunk(  # type: ignore[override]
        self,
        data: T_ChunkedArray,
        chunks: _NormalizedChunks | tuple[int, ...] | _Chunks,
        **kwargs: Any,
    ) -> Any:
        """
        Changes the chunking pattern of the given array.

        Called when the .chunk method is called on an xarray object that is already chunked.

        Parameters
        ----------
        data : dask array
            Array to be rechunked.
        chunks : int, tuple, dict or str, optional
            The new block dimensions to create. -1 indicates the full size of the
            corresponding dimension. Default is "auto" which automatically
            determines chunk sizes.

        Returns
        -------
        chunked array

        See Also
        --------
        dask.array.Array.rechunk
        cubed.Array.rechunk
        """

        if _contains_cftime_datetimes(data):
            # The data holds cftime objects, whose in-memory size dask cannot
            # estimate for "auto" rechunking, so build an explicit chunk spec
            # instead of passing "auto" through to dask.
            from dask import config as dask_config
            from dask.utils import parse_bytes

            from xarray.namedarray.utils import build_chunkspec

            target_chunksize = parse_bytes(dask_config.get("array.chunk-size"))

            chunks = build_chunkspec(
                data,
                target_chunksize=target_chunksize,
            )

        return data.rechunk(chunks, **kwargs)
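
Note (not part of the diff): a minimal sketch of the code path this rechunk override enables, assuming dask and cftime are installed. It mirrors the new test added below; the exact chunk sizes depend on dask's "array.chunk-size" setting.

import numpy as np
import xarray as xr

# Build an object-dtype (cftime) DataArray, chunk it, then re-chunk with "auto".
times = xr.date_range("2000-01-01", periods=120, freq="YE", use_cftime=True)
da = xr.DataArray(
    np.tile(times.values, (10, 1)),
    dims=["x", "t"],
    coords={"t": times},
).chunk({"x": 4, "t": 5})

# Dask cannot estimate the byte size of object-dtype data for "auto" rechunking;
# with this change the chunk spec is precomputed by build_chunkspec.
print(da.chunk("auto").chunks)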
27 changes: 27 additions & 0 deletions xarray/namedarray/utils.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import importlib
import sys
import warnings
from collections.abc import Hashable, Iterable, Iterator, Mapping
from functools import lru_cache
@@ -16,6 +17,8 @@

from numpy.typing import NDArray

from xarray.namedarray.parallelcompat import T_ChunkedArray

try:
from dask.array.core import Array as DaskArray
from dask.typing import DaskCollection
@@ -195,6 +198,30 @@ def either_dict_or_kwargs(
    return pos_kwargs


def build_chunkspec(
    data: T_ChunkedArray,
    target_chunksize: int,
) -> tuple[int, ...]:
    """
    Try to make chunks roughly cubic. This needs to be a bit smarter: it
    really ought to account for xr.structure.chunks._get_chunk and try to
    use the default encoding to set the chunk size.
    """
    from xarray.core.formatting import first_n_items

    # Estimate the in-memory size of one cftime element and derive how many
    # elements fit into a chunk of the requested byte size.
    cftime_nbytes_approx: int = sys.getsizeof(first_n_items(data, 1))  # type: ignore[no-untyped-call]
    elements_per_chunk = target_chunksize // cftime_nbytes_approx
    ndim = data.ndim  # type: ignore[attr-defined]
    shape = data.shape  # type: ignore[attr-defined]
    if ndim > 0:
        chunk_size_per_dim = int(elements_per_chunk ** (1.0 / ndim))
        chunks = tuple(min(chunk_size_per_dim, dim_size) for dim_size in shape)
    else:
        chunks = ()

    return chunks
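
Note (not part of the diff): a rough, illustrative walk-through of the arithmetic above for the 10 x 120 cftime array used in the new test. The 64-byte per-element figure is a hypothetical stand-in for sys.getsizeof of one cftime object.

target_chunksize = 128 * 1024**2       # dask's default "array.chunk-size" (128 MiB)
cftime_nbytes_approx = 64              # hypothetical per-element size estimate

elements_per_chunk = target_chunksize // cftime_nbytes_approx   # 2_097_152
ndim, shape = 2, (10, 120)
chunk_size_per_dim = int(elements_per_chunk ** (1.0 / ndim))    # ~1448
chunks = tuple(min(chunk_size_per_dim, dim_size) for dim_size in shape)
assert chunks == (10, 120)  # each dimension fits comfortably in one chunk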


class ReprObject:
"""Object that prints as the given value, for use with sentinel values."""

Expand Down
31 changes: 31 additions & 0 deletions xarray/tests/test_dask.py
@@ -1060,6 +1060,29 @@ def make_da():
    return da


def make_da_cftime():
    yrs = np.arange(2000, 2120)
    cftime_dates = xr.date_range(
        start=f"{yrs[0]}-01-01",
        end=f"{yrs[-1]}-12-31",
        freq="1YE",
        use_cftime=True,
    )
    yr_array = np.tile(cftime_dates.values, (10, 1))
    da = xr.DataArray(
        yr_array,
        dims=["x", "t"],
        coords={"x": np.arange(10), "t": cftime_dates},
        name="a",
    ).chunk({"x": 4, "t": 5})
    da.x.attrs["long_name"] = "x"
    da.attrs["test"] = "test"
    da.coords["c2"] = 0.5
    da.coords["ndcoord"] = da.x * 2

    return da


def make_ds():
    map_ds = xr.Dataset()
    map_ds["a"] = make_da()
@@ -1141,6 +1164,14 @@ def test_auto_chunk_da(obj):
    assert actual.chunks == expected.chunks


@pytest.mark.parametrize("obj", [make_da_cftime()])
def test_auto_chunk_da_cftime(obj):
    actual = obj.chunk("auto").data
    expected = obj.data.rechunk({0: 10, 1: 120})
    np.testing.assert_array_equal(actual, expected)
    assert actual.chunks == expected.chunks


def test_map_blocks_error(map_da, map_ds):
    def bad_func(darray):
        return (darray * darray.x + 5 * darray.y)[:1, :1]