3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -34,6 +34,9 @@ New Features
By `Benoit Bovy <https://github.com/benbovy>`_.
- Support reading to `GPU memory with Zarr <https://zarr.readthedocs.io/en/stable/user-guide/gpu.html>`_ (:pull:`10078`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray`, enabling
support for CF boundaries coordinate (e.g., ``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10116`).
By `Benoit Bovy <https://github.com/benbovy>`_.

Breaking changes
~~~~~~~~~~~~~~~~
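For context, a minimal sketch of the usage described in the whats-new entry above, modeled closely on the tests added in this PR (the AnyIndex class and the x/x_bounds/x_bnds names are illustrative stand-ins for a real CF setup such as time/time_bnds; the xarray.indexes.Index import path is an assumption):

import xarray as xr
from xarray.indexes import Index

class AnyIndex(Index):
    # A do-nothing index is enough: the feature only requires that the bounds
    # coordinate shares an index with a coordinate along one of the
    # DataArray's dimensions.
    pass

idx = AnyIndex()
coords = xr.Coordinates(
    coords={
        "x": ("x", [1, 2]),
        "x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
    },
    indexes={"x": idx, "x_bounds": idx},
)

da = xr.DataArray([1.0, 2.0], coords=coords, dims="x")
assert "x_bounds" in da.coords  # the boundaries coordinate is kept
assert "x_bnds" not in da.dims  # but its extra dimension is not a DataArray dimension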
38 changes: 24 additions & 14 deletions xarray/core/coordinates.py
@@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
return self.to_dataset().identical(other.to_dataset())

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
# redirect to DatasetCoordinates._update_coords
self._data.coords._update_coords(coords, indexes)
@@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
return self._data._copy_listed(names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
variables = self._data._variables.copy()
variables.update(coords)
@@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
return self._data.dataset._copy_listed(self._names)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
from xarray.core.datatree import check_alignment

@@ -964,22 +964,32 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)

def _update_coords(
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
coords_plus_data = coords.copy()
coords_plus_data[_THIS_ARRAY] = self._data.variable
dims = calculate_dimensions(coords_plus_data)
if not set(dims) <= set(self.dims):
raise ValueError(
"cannot add coordinates with new dimensions to a DataArray"
)
self._data._coords = coords

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = dict(self._data.xindexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes
if set(dims) > set(self.dims):
for k, v in coords.items():
if any(d not in self.dims for d in v.dims):
# allow any coordinate associated with an index that shares at least
# one of dataarray's dimensions
temp_indexes = Indexes(
indexes, {k: v for k, v in coords.items() if k in indexes}
)
if k in indexes:
index_dims = temp_indexes.get_all_dims(k)
if any(d in self.dims for d in index_dims):
continue
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {self.dims}"
)

self._data._coords = coords
self._data._indexes = indexes

def _drop_coords(self, coord_names):
# should drop indexed coordinates only
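A quick illustration of what the _update_coords check above still rejects (hypothetical example; the error message in the comment is paraphrased): a coordinate whose dimensions, and whose index, share nothing with the DataArray's dimensions is refused as before:

import xarray as xr

da = xr.DataArray([1.0, 2.0], dims="x")
try:
    da.assign_coords(y=("y", [0, 1, 2]))  # "y" is unrelated to the DataArray's dims
except ValueError as err:
    print(err)  # e.g. "coordinate y has dimensions ('y',), but these are not a subset ..."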
19 changes: 15 additions & 4 deletions xarray/core/dataarray.py
@@ -130,18 +130,30 @@
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)


def _check_coords_dims(shape, coords, dim):
def _check_coords_dims(
shape: tuple[int, ...], coords: Coordinates, dim: tuple[Hashable, ...]
):
sizes = dict(zip(dim, shape, strict=True))
extra_index_dims: set[str] = set()

for k, v in coords.items():
if any(d not in dim for d in v.dims):
# allow any coordinate associated with an index that shares at least
# one of dataarray's dimensions
indexes = coords.xindexes
if k in indexes:
index_dims = indexes.get_all_dims(k)
if any(d in dim for d in index_dims):
extra_index_dims.update(d for d in v.dims if d not in dim)
continue
raise ValueError(
f"coordinate {k} has dimensions {v.dims}, but these "
"are not a subset of the DataArray "
f"dimensions {dim}"
)

for d, s in v.sizes.items():
if s != sizes[d]:
if d not in extra_index_dims and s != sizes[d]:
Contributor:

Suggested change:
-    if d not in extra_index_dims and s != sizes[d]:
+    if d not in extra_index_dims or s != sizes[d]:

?

Member Author:

This should be `and`.

`extra_index_dims` corresponds to all non-array dimensions of index coordinates to include in the DataArray (`sizes[d]` would raise a KeyError for those dimensions).

I renamed it and added a comment in 695fb86.

raise ValueError(
f"conflicting sizes for dimension {d!r}: "
f"length {sizes[d]} on the data but length {s} on "
@@ -212,8 +224,6 @@ def _infer_coords_and_dims(
var.dims = (dim,)
new_coords[dim] = var.to_index_variable()

_check_coords_dims(shape, new_coords, dims_tuple)

return new_coords, dims_tuple


@@ -487,6 +497,7 @@ def __init__(

if not isinstance(coords, Coordinates):
coords = create_coords_with_default_indexes(coords)
_check_coords_dims(data.shape, coords, dims)
indexes = dict(coords.xindexes)
coords = {k: v.copy() for k, v in coords.variables.items()}

14 changes: 12 additions & 2 deletions xarray/core/dataset.py
@@ -1210,10 +1210,20 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
needed_dims = set(variable.dims)

coords: dict[Hashable, Variable] = {}
temp_indexes = self.xindexes
# preserve ordering
for k in self._variables:
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
coords[k] = self._variables[k]
if k in self._coord_names:
if (
k not in coords
and k in temp_indexes
and set(temp_indexes.get_all_dims(k)) & needed_dims
):
# add all coordinates of each index that shares at least one dimension
# with the dimensions of the extracted variable
coords.update(temp_indexes.get_all_coords(k))
Comment on lines +1217 to +1224

Member:

Could this use a separate loop after the existing loop instead? e.g.,

for k in self._indexes:
    if k in coords:
        coords.update(self.xindexes.get_all_coords(k))

Or if we allow indexes without a coordinate of the same name:

for k in self._indexes:
    if set(self.xindexes.get_all_dims(k)) & needed_dims:
        coords.update(self.xindexes.get_all_coords(k))

Ideally, I would like the logic here to be just as simple as the words describing how it works, so a comment is not necessary!

elif set(self._variables[k].dims) <= needed_dims:
coords[k] = self._variables[k]

indexes = filter_indexes_from_coords(self._indexes, set(coords))

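A condensed sketch of the _construct_dataarray behaviour this hunk enables, reusing the coords object from the earlier example (illustrative names; the same pattern is exercised by the new test in test_dataset.py below):

ds = xr.Dataset({"foo": ("x", [1.0, 2.0])}, coords=coords)

da = ds["foo"]
assert "x_bounds" in da.coords  # extracted along with "foo" because its index shares "x"
assert "x_bnds" not in da.dims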
11 changes: 10 additions & 1 deletion xarray/core/indexes.py
@@ -1568,6 +1568,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
"""

_index_type: type[Index] | type[pd.Index]
_index_dims: dict[Hashable, Mapping[Hashable, int]]
_indexes: dict[Any, T_PandasOrXarrayIndex]
_variables: dict[Any, Variable]

@@ -1576,6 +1577,7 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
"__id_coord_names",
"__id_index",
"_dims",
"_index_dims",
"_index_type",
"_indexes",
"_variables",
@@ -1619,6 +1621,7 @@ def __init__(
)

self._index_type = index_type
self._index_dims = {}
self._indexes = dict(**indexes)
self._variables = dict(**variables)

@@ -1737,7 +1740,13 @@ def get_all_dims(
"""
from xarray.core.variable import calculate_dimensions

return calculate_dimensions(self.get_all_coords(key, errors=errors))
if key in self._index_dims:
return self._index_dims[key]
else:
dims = calculate_dimensions(self.get_all_coords(key, errors=errors))
if dims:
self._index_dims[key] = dims
return dims

def group_by_index(
self,
49 changes: 49 additions & 0 deletions xarray/tests/test_dataarray.py
@@ -529,6 +529,30 @@ class CustomIndex(Index): ...
# test coordinate variables copied
assert da.coords["x"] is not coords.variables["x"]

def test_constructor_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
# This test only requires that the coordinates to assign have an
# index, whatever its type.
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

actual = DataArray([1.0, 2.0], coords=coords, dims="x")

# cannot use `assert_identical()` test utility function here yet
# (indexes invariant check is still based on IndexVariable, which
# doesn't work with AnyIndex coordinate variables here)
assert actual.coords.to_dataset().equals(coords.to_dataset())
assert list(actual.coords.xindexes) == list(coords.xindexes)
assert "x_bnds" not in actual.dims

def test_equals_and_identical(self) -> None:
orig = DataArray(np.arange(5.0), {"a": 42}, dims="x")

@@ -1634,6 +1658,31 @@ def test_assign_coords_no_default_index(self) -> None:
assert_identical(actual.coords, coords, check_default_indexes=False)
assert "y" not in actual.xindexes

def test_assign_coords_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
# This test only requires that the coordinates to assign have an
# index, whatever its type.
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

da = DataArray([1.0, 2.0], dims="x")
actual = da.assign_coords(coords)

# cannot use `assert_identical()` test utility function here yet
# (indexes invariant check is still based on IndexVariable, which
# doesn't work with AnyIndex coordinate variables here)
assert actual.coords.to_dataset().equals(coords.to_dataset())
assert list(actual.coords.xindexes) == list(coords.xindexes)
assert "x_bnds" not in actual.dims

def test_coords_alignment(self) -> None:
lhs = DataArray([1, 2, 3], [("x", [0, 1, 2])])
rhs = DataArray([2, 3, 4], [("x", [1, 2, 3])])
25 changes: 25 additions & 0 deletions xarray/tests/test_dataset.py
@@ -4206,6 +4206,31 @@ def test_getitem_multiple_dtype(self) -> None:
dataset = Dataset({key: ("dim0", range(1)) for key in keys})
assert_identical(dataset, dataset[keys])

def test_getitem_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
# This test only requires that the coordinates to assign have an
# index, whatever its type.
pass

idx = AnyIndex()
coords = Coordinates(
coords={
"x": ("x", [1, 2]),
"x_bounds": (("x", "x_bnds"), [(0.5, 1.5), (1.5, 2.5)]),
},
indexes={"x": idx, "x_bounds": idx},
)

ds = Dataset({"foo": (("x"), [1.0, 2.0])}, coords=coords)
actual = ds["foo"]

# cannot use `assert_identical()` test utility function here yet
# (indexes invariant check is still based on IndexVariable, which
# doesn't work with AnyIndex coordinate variables here)
assert actual.coords.to_dataset().equals(coords.to_dataset())
assert list(actual.coords.xindexes) == list(coords.xindexes)
assert "x_bnds" not in actual.dims

def test_virtual_variables_default_coords(self) -> None:
dataset = Dataset({"foo": ("x", range(10))})
expected1 = DataArray(range(10), dims="x", name="x")