diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py
index 6485ba375f5..f28bef7e930 100644
--- a/xarray/namedarray/daskmanager.py
+++ b/xarray/namedarray/daskmanager.py
@@ -5,17 +5,21 @@
 
 import numpy as np
 
+from xarray.core.common import _contains_cftime_datetimes
 from xarray.core.indexing import ImplicitToExplicitIndexingAdapter
 from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray
 from xarray.namedarray.utils import is_duck_dask_array, module_available
 
 if TYPE_CHECKING:
+    from xarray.core.variable import Variable
     from xarray.namedarray._typing import (
         T_Chunks,
+        _DType,
         _DType_co,
         _NormalizedChunks,
         duckarray,
     )
+    from xarray.namedarray.parallelcompat import _Chunks
 
 try:
     from dask.array import Array as DaskArray
@@ -264,3 +268,63 @@ def shuffle(
         if chunks != "auto":
             raise NotImplementedError("Only chunks='auto' is supported at present.")
         return dask.array.shuffle(x, indexer, axis, chunks="auto")
+
+    def rechunk(  # type: ignore[override]
+        self,
+        data: T_ChunkedArray,
+        chunks: _NormalizedChunks | tuple[int, ...] | _Chunks,
+        **kwargs: Any,
+    ) -> Any:
+        """
+        Changes the chunking pattern of the given array.
+
+        Called when the .chunk method is called on an xarray object that is already chunked.
+
+        Parameters
+        ----------
+        data : dask array
+            Array to be rechunked.
+        chunks : int, tuple, dict or str, optional
+            The new block dimensions to create. -1 indicates the full size of the
+            corresponding dimension. Default is "auto" which automatically
+            determines chunk sizes.
+
+        Returns
+        -------
+        chunked array
+
+        See Also
+        --------
+        dask.array.Array.rechunk
+        cubed.Array.rechunk
+        """
+
+        if _contains_cftime_datetimes(data):
+            from dask import config as dask_config
+            from dask.array.core import normalize_chunks
+            from dask.utils import parse_bytes
+
+            from xarray.namedarray.utils import fake_target_chunksize
+
+            target_chunksize = parse_bytes(dask_config.get("array.chunk-size"))
+            limit, var_dtype = fake_target_chunksize(  # type: ignore[var-annotated]
+                data, target_chunksize=target_chunksize
+            )
+
+            chunks = normalize_chunks(
+                chunks,
+                shape=data.shape,  # type: ignore[attr-defined]
+                dtype=var_dtype,
+                limit=limit,
+            )  # type: ignore[no-untyped-call]
+
+        return data.rechunk(chunks, **kwargs)
+
+    def get_auto_chunk_size(self, var: Variable) -> tuple[int, _DType]:
+        from dask import config as dask_config
+        from dask.utils import parse_bytes
+
+        from xarray.namedarray.utils import fake_target_chunksize
+
+        target_chunksize = parse_bytes(dask_config.get("array.chunk-size"))
+        return fake_target_chunksize(var, target_chunksize=target_chunksize)
diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py
index c1fe5999ecb..e2fb5e5682f 100644
--- a/xarray/namedarray/parallelcompat.py
+++ b/xarray/namedarray/parallelcompat.py
@@ -746,3 +746,32 @@ def store(
             cubed.store
         """
         raise NotImplementedError()
+
+    def get_auto_chunk_size(
+        self,
+        var,
+    ) -> tuple[int, _DType]:
+        """
+        Get the default chunk size for a variable.
+
+        This is used to determine the chunk size when opening a dataset with
+        ``chunks="auto"`` or when rechunking an array with ``chunks="auto"``.
+
+        Parameters
+        ----------
+        var : xarray.Variable
+            The variable for which to get the chunk size. Implementations may
+            consult their own configuration for the target chunk size in bytes
+            (for example dask's ``array.chunk-size`` setting).
+
+        Returns
+        -------
+        chunk_size : int
+            The target chunk size (limit) in bytes.
+        dtype : np.dtype
+            The dtype to pass to the chunking algorithm (not the variable's dtype).
+ """ + + raise NotImplementedError( + "get_auto_chunk_size must be implemented by the chunk manager." + ) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index 96060730345..f40b77371d5 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import importlib +import sys import warnings from collections.abc import Hashable, Iterable, Iterator, Mapping from functools import lru_cache @@ -23,7 +24,9 @@ DaskArray = NDArray # type: ignore[assignment, misc] DaskCollection: Any = NDArray # type: ignore[no-redef] - from xarray.namedarray._typing import _Dim, duckarray + from xarray.core.variable import Variable + from xarray.namedarray._typing import DuckArray, _Dim, _DType, duckarray + from xarray.namedarray.parallelcompat import T_ChunkedArray K = TypeVar("K") @@ -195,6 +198,37 @@ def either_dict_or_kwargs( return pos_kwargs +def fake_target_chunksize( + data: DuckArray[Any] | T_ChunkedArray | Variable, + target_chunksize: int, +) -> tuple[int, _DType]: + """ + Naughty trick - let's get the ratio of our cftime_nbytes, and then compute + the ratio of that size to a np.float64. Then we can just adjust our target_chunksize + and use the default dask chunking algorithm to get a reasonable chunk size. + + ? I don't think T_chunkedArray or Variable should be necessary, but the calls + ? to this in daskmanager.py requires it to be that. I still need to wrap my head + ? around the typing here a bit more. + """ + import numpy as np + + from xarray.core.formatting import first_n_items + + output_dtype: _DType = np.dtype(np.float64) # type: ignore[assignment] + + if data.dtype == object: + nbytes_approx: int = sys.getsizeof(first_n_items(data, 1)) # type: ignore[no-untyped-call] + else: + nbytes_approx = data[0].itemsize + + f64_nbytes = output_dtype.itemsize # Should be 8 bytes + + target_chunksize = int(target_chunksize * (f64_nbytes / nbytes_approx)) + + return target_chunksize, output_dtype + + class ReprObject: """Object that prints as the given value, for use with sentinel values.""" diff --git a/xarray/structure/chunks.py b/xarray/structure/chunks.py index 281cfe278f1..150635f54d5 100644 --- a/xarray/structure/chunks.py +++ b/xarray/structure/chunks.py @@ -11,7 +11,8 @@ from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload from xarray.core import utils -from xarray.core.utils import emit_user_level_warning +from xarray.core.common import _contains_cftime_datetimes +from xarray.core.utils import emit_user_level_warning, is_dict_like from xarray.core.variable import IndexVariable, Variable from xarray.namedarray.parallelcompat import ( ChunkManagerEntrypoint, @@ -83,8 +84,23 @@ def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint): for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape, strict=True) ) + # Chunks can be either dict-like or tuple-like (according to type annotations) + # at this point, so check for # this before we manually construct our chunk + # spec- if we've set chunks to auto + _chunks = list(chunks.values()) if is_dict_like(chunks) else chunks + auto_chunks = all(_chunk == "auto" for _chunk in _chunks) + + if _contains_cftime_datetimes(var) and auto_chunks: + limit, var_dtype = chunkmanager.get_auto_chunk_size(var) + else: + limit, var_dtype = None, var.dtype + chunk_shape = chunkmanager.normalize_chunks( - chunk_shape, shape=shape, dtype=var.dtype, previous_chunks=preferred_chunk_shape + chunk_shape, + shape=shape, + dtype=var_dtype, + 
+        limit=limit,
+        previous_chunks=preferred_chunk_shape,
     )
 
     # Warn where requested chunks break preferred chunks, provided that the variable
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 93329b2297d..fafafd190f3 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -5427,6 +5427,35 @@ def test_open_multi_dataset(self) -> None:
         ) as actual:
             assert_identical(expected, actual)
 
+    def test_open_dataset_cftime_autochunk(self) -> None:
+        """Create a dataset with cftime datetime objects and
+        ensure that auto-chunking works correctly."""
+        import cftime
+
+        from xarray.core.common import _contains_cftime_datetimes
+
+        original = xr.Dataset(
+            {
+                "foo": ("time", [0.0]),
+                "time_bnds": (
+                    ("time", "bnds"),
+                    [
+                        [
+                            cftime.Datetime360Day(2005, 12, 1, 0, 0, 0, 0),
+                            cftime.Datetime360Day(2005, 12, 2, 0, 0, 0, 0),
+                        ]
+                    ],
+                ),
+            },
+            {"time": [cftime.Datetime360Day(2005, 12, 1, 12, 0, 0, 0)]},
+        )
+        with create_tmp_file() as tmp:
+            original.to_netcdf(tmp)
+            with open_dataset(tmp, chunks="auto") as actual:
+                assert isinstance(actual.time_bnds.variable.data, da.Array)
+                assert _contains_cftime_datetimes(actual.time)
+                assert_identical(original, actual)
+
     # Flaky test. Very open to contributions on fixing this
     @pytest.mark.flaky
     def test_dask_roundtrip(self) -> None:
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 9024f2ae677..68a93dfc9e2 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1060,6 +1060,29 @@ def make_da():
     return da
 
 
+def make_da_cftime():
+    yrs = np.arange(2000, 2120)
+    cftime_dates = xr.date_range(
+        start=f"{yrs[0]}-01-01",
+        end=f"{yrs[-1]}-12-31",
+        freq="1YE",
+        use_cftime=True,
+    )
+    yr_array = np.tile(cftime_dates.values, (10, 1))
+    da = xr.DataArray(
+        yr_array,
+        dims=["x", "t"],
+        coords={"x": np.arange(10), "t": cftime_dates},
+        name="a",
+    ).chunk({"x": 4, "t": 5})
+    da.x.attrs["long_name"] = "x"
+    da.attrs["test"] = "test"
+    da.coords["c2"] = 0.5
+    da.coords["ndcoord"] = da.x * 2
+
+    return da
+
+
 def make_ds():
     map_ds = xr.Dataset()
     map_ds["a"] = make_da()
@@ -1141,6 +1164,14 @@ def test_auto_chunk_da(obj):
     assert actual.chunks == expected.chunks
 
 
+@pytest.mark.parametrize("obj", [make_da_cftime()])
+def test_auto_chunk_da_cftime(obj):
+    actual = obj.chunk("auto").data
+    expected = obj.data.rechunk({0: 10, 1: 120})
+    np.testing.assert_array_equal(actual, expected)
+    assert actual.chunks == expected.chunks
+
+
 def test_map_blocks_error(map_da, map_ds):
     def bad_func(darray):
         return (darray * darray.x + 5 * darray.y)[:1, :1]
diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py
index 537cd824767..d35545d0603 100644
--- a/xarray/tests/test_namedarray.py
+++ b/xarray/tests/test_namedarray.py
@@ -18,6 +18,7 @@
     _ShapeType_co,
 )
 from xarray.namedarray.core import NamedArray, from_array
+from xarray.namedarray.utils import fake_target_chunksize
 
 if TYPE_CHECKING:
     from types import ModuleType
@@ -26,6 +27,7 @@
 
     from xarray.namedarray._typing import (
         Default,
+        DuckArray,
         _AttrsLike,
         _Dim,
         _DimsLike,
@@ -37,6 +39,13 @@
         duckarray,
     )
 
+try:
+    import cftime
+
+    cftime_available = True
+except ModuleNotFoundError:
+    cftime_available = False
+
 
 class CustomArrayBase(Generic[_ShapeType_co, _DType_co]):
     def __init__(self, array: duckarray[Any, _DType_co]) -> None:
@@ -609,3 +618,46 @@ def test_repr() -> None:
 
     # Basic comparison:
     assert r == " Size: 8B\narray([0], dtype=uint64)"
+
+
+@pytest.mark.parametrize(
+    "input_array, expected_chunksize_faked",
+    [
+        (np.arange(100).reshape(10, 10), 1024),
+        (np.arange(100).reshape(10, 10).astype(np.float32), 2048),
+        (
+            pytest.param(
+                np.array(
+                    [
+                        cftime.Datetime360Day(2000, month, day, 0, 0, 0, 0)
+                        for month in range(1, 11)
+                        for day in range(1, 11)
+                    ],
+                    dtype=object,
+                ).reshape(10, 10),
+                73,
+                marks=pytest.mark.xfail(
+                    not cftime_available,
+                    reason="cftime not available, cannot test object dtype with cftime dates",
+                ),
+            )
+        ),
+    ],
+)
+def test_fake_target_chunksize(
+    input_array: DuckArray[Any], expected_chunksize_faked: int
+) -> None:
+    """
+    Check that `fake_target_chunksize` returns the expected chunk size and dtype.
+
+    The helper pretends to dask that we are chunking an 8-byte (float64) array, so
+    for a 4-byte dtype such as float32 it *doubles* the target, letting dask use the
+    intended amount of memory per chunk. For object dtypes, whose elements are
+    generally larger, it shrinks the effective target so chunks stay within budget.
+    """
+    target_chunksize = 1024
+
+    faked_chunksize, dtype = fake_target_chunksize(input_array, target_chunksize)  # type: ignore[var-annotated]
+
+    assert faked_chunksize == expected_chunksize_faked
+    assert dtype == np.float64
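
For reference, a minimal usage sketch of the behaviour these changes target. It is illustrative only, not part of the diff: the exact chunk shapes and the ~73-byte figure depend on dask's ``array.chunk-size`` setting and on ``sys.getsizeof`` for cftime objects on the running platform.

import numpy as np
import xarray as xr
from xarray.namedarray.utils import fake_target_chunksize

# An object-dtype DataArray backed by cftime datetimes.
times = xr.date_range("2000-01-01", periods=120, freq="1YE", use_cftime=True)
da = xr.DataArray(
    np.tile(times.values, (10, 1)),
    dims=["x", "t"],
    coords={"x": np.arange(10), "t": times},
)

# With this change, chunks="auto" routes through get_auto_chunk_size /
# fake_target_chunksize, so object elements are sized via sys.getsizeof rather
# than assumed to be 8 bytes, and the resulting chunks respect dask's memory target.
chunked = da.chunk("auto")
print(chunked.chunks)

# The helper can also be called directly: it returns an adjusted byte limit and
# the float64 stand-in dtype to feed to dask's chunk normalization.
limit, dtype = fake_target_chunksize(da.variable, target_chunksize=1024)
print(limit, dtype)  # e.g. 73 (platform dependent) and float64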