diff --git a/pyproject.toml b/pyproject.toml index 5e5fd00328b..d4901f4d78b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ accel = [ "numba>=0.59", "flox>=0.9", "opt_einsum", + "numpy<2.3", # numba has not updated yet: https://github.com/numba/numba/issues/10105 ] complete = ["xarray[accel,etc,io,parallel,viz]"] io = [ @@ -324,6 +325,8 @@ known-first-party = ["xarray"] [tool.ruff.lint.flake8-tidy-imports] # Disallow all relative imports. ban-relative-imports = "all" +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"pandas.api.types.is_extension_array_dtype".msg = "Use xarray.core.utils.is_allowed_extension_array{_dtype} instead. Only use the banend API if the incoming data has already been sanitized by xarray" [tool.pytest.ini_options] addopts = [ diff --git a/xarray/computation/ops.py b/xarray/computation/ops.py index 61834a85acf..1514f1694ca 100644 --- a/xarray/computation/ops.py +++ b/xarray/computation/ops.py @@ -8,12 +8,15 @@ from __future__ import annotations import operator -from typing import Literal +from typing import TYPE_CHECKING, Literal import numpy as np from xarray.core import dtypes, duck_array_ops +if TYPE_CHECKING: + pass + try: import bottleneck as bn @@ -158,8 +161,8 @@ def fillna(data, other, join="left", dataset_join="left"): ) -# Unsure why we get a mypy error here -def where_method(self, cond, other=dtypes.NA): # type: ignore[has-type] +# TODO: type this properly +def where_method(self, cond, other=dtypes.NA): # type: ignore[unused-ignore,has-type] """Return elements from `self` or `other` depending on `cond`. Parameters diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f79df3da7c2..5355d81c6c0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -26,7 +26,6 @@ import numpy as np import pandas as pd -from pandas.api.types import is_extension_array_dtype from xarray.coding.calendar_ops import convert_calendar, interp_calendar from xarray.coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings @@ -91,6 +90,7 @@ either_dict_or_kwargs, emit_user_level_warning, infix_dims, + is_allowed_extension_array, is_dict_like, is_duck_array, is_duck_dask_array, @@ -6780,7 +6780,7 @@ def reduce( elif ( # Some reduction functions (e.g. std, var) need to run on variables # that don't have the reduce dims: PR5393 - not is_extension_array_dtype(var.dtype) + not pd.api.types.is_extension_array_dtype(var.dtype) # noqa: TID251 and ( not reduce_dims or not numeric_only @@ -7105,12 +7105,12 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]): non_extension_array_columns = [ k for k in columns_in_order - if not is_extension_array_dtype(self.variables[k].data) + if not pd.api.types.is_extension_array_dtype(self.variables[k].data) # noqa: TID251 ] extension_array_columns = [ k for k in columns_in_order - if is_extension_array_dtype(self.variables[k].data) + if pd.api.types.is_extension_array_dtype(self.variables[k].data) # noqa: TID251 ] extension_array_columns_different_index = [ k @@ -7302,7 +7302,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: arrays = [] extension_arrays = [] for k, v in dataframe.items(): - if not is_extension_array_dtype(v) or isinstance( + if not is_allowed_extension_array(v) or isinstance( v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES ): arrays.append((k, np.asarray(v))) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index c959a7f2536..0a7b1722877 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -4,7 +4,7 @@ from typing import Any import numpy as np -from pandas.api.types import is_extension_array_dtype +import pandas as pd from xarray.compat import array_api_compat, npcompat from xarray.compat.npcompat import HAS_STRING_DTYPE @@ -213,7 +213,7 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool: if isinstance(dtype, np.dtype): return npcompat.isdtype(dtype, kind) - elif is_extension_array_dtype(dtype): + elif pd.api.types.is_extension_array_dtype(dtype): # noqa: TID251 # we never want to match pandas extension array dtypes return False else: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 0c7d40113d6..b8a4011a72e 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -23,7 +23,6 @@ take, unravel_index, # noqa: F401 ) -from pandas.api.types import is_extension_array_dtype from xarray.compat import dask_array_compat, dask_array_ops from xarray.compat.array_api_compat import get_array_namespace @@ -184,7 +183,7 @@ def isnull(data): dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool return full_like(data, dtype=dtype, fill_value=False) # at this point, array should have dtype=object - elif isinstance(data, np.ndarray) or is_extension_array_dtype(data): + elif isinstance(data, np.ndarray) or pd.api.types.is_extension_array_dtype(data): # noqa: TID251 return pandas_isnull(data) else: # Not reachable yet, but intended for use with other duck array @@ -266,10 +265,12 @@ def asarray(data, xp=np, dtype=None): def as_shared_dtype(scalars_or_arrays, xp=None): """Cast arrays to a shared dtype using xarray's type promotion rules.""" - if any(is_extension_array_dtype(x) for x in scalars_or_arrays): - extension_array_types = [ - x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) - ] + extension_array_types = [ + x.dtype + for x in scalars_or_arrays + if pd.api.types.is_extension_array_dtype(x) # noqa: TID251 + ] + if len(extension_array_types) >= 1: non_nans = [x for x in scalars_or_arrays if not isna(x)] if len(extension_array_types) == len(non_nans) and all( isinstance(x, type(extension_array_types[0])) for x in extension_array_types diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index d85f7e66b55..9262982d4cb 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -8,10 +8,9 @@ import numpy as np import pandas as pd from packaging.version import Version -from pandas.api.types import is_extension_array_dtype from xarray.core.types import DTypeLikeSave, T_ExtensionArray -from xarray.core.utils import NDArrayMixin +from xarray.core.utils import NDArrayMixin, is_allowed_extension_array HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} @@ -100,10 +99,11 @@ def __post_init__(self): raise TypeError(f"{self.array} is not an pandas ExtensionArray.") # This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because # we do support extension arrays from datetime, for example, that need - # duck array support internally via this class. - if isinstance(self.array, pd.arrays.NumpyExtensionArray): + # duck array support internally via this class. These can appear from `DatetimeIndex` + # wrapped by `PandasIndex` internally, for example. + if not is_allowed_extension_array(self.array): raise TypeError( - "`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally." + f"{self.array.dtype!r} should be converted to a numpy array in `xarray` internally." ) def __array_function__(self, func, types, args, kwargs): @@ -126,7 +126,7 @@ def replace_duck_with_extension_array(args) -> list: if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: raise KeyError("Function not registered for pandas extension arrays.") res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) - if is_extension_array_dtype(res): + if is_allowed_extension_array(res): return PandasExtensionArray(res) return res @@ -135,7 +135,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: item = self.array[key] - if is_extension_array_dtype(item): + if is_allowed_extension_array(item): return PandasExtensionArray(item) if np.isscalar(item) or isinstance(key, int): return PandasExtensionArray(type(self.array)._from_sequence([item])) # type: ignore[call-arg,attr-defined,unused-ignore] diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index c233c6911e4..d22fc37aa4f 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -23,6 +23,7 @@ Frozen, emit_user_level_warning, get_valid_numpy_dtype, + is_allowed_extension_array_dtype, is_dict_like, is_scalar, ) @@ -666,9 +667,8 @@ def __init__( self.index = index self.dim = dim - if coord_dtype is None: - if pd.api.types.is_extension_array_dtype(index.dtype): + if is_allowed_extension_array_dtype(index.dtype): cast(pd.api.extensions.ExtensionDtype, index.dtype) coord_dtype = index.dtype else: diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 8e4458fb88f..c98175578f8 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -24,6 +24,8 @@ NDArrayMixin, either_dict_or_kwargs, get_valid_numpy_dtype, + is_allowed_extension_array, + is_allowed_extension_array_dtype, is_duck_array, is_duck_dask_array, is_scalar, @@ -1763,12 +1765,12 @@ def __init__( self.array = safe_cast_to_index(array) if dtype is None: - if pd.api.types.is_extension_array_dtype(array.dtype): + if is_allowed_extension_array(array): cast(pd.api.extensions.ExtensionDtype, array.dtype) self._dtype = array.dtype else: self._dtype = get_valid_numpy_dtype(array) - elif pd.api.types.is_extension_array_dtype(dtype): + elif is_allowed_extension_array_dtype(dtype): self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype) else: self._dtype = np.dtype(cast(DTypeLike, dtype)) @@ -1816,10 +1818,7 @@ def get_duck_array(self) -> np.ndarray | PandasExtensionArray: # We return an PandasExtensionArray wrapper type that satisfies # duck array protocols. # `NumpyExtensionArray` is excluded - if pd.api.types.is_extension_array_dtype(self.array) and not isinstance( - self.array.array, - pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined] - ): + if is_allowed_extension_array(self.array): from xarray.core.extension_array import PandasExtensionArray return PandasExtensionArray(self.array.array) @@ -1916,7 +1915,7 @@ def copy(self, deep: bool = True) -> Self: @property def nbytes(self) -> int: - if pd.api.types.is_extension_array_dtype(self.dtype): + if is_allowed_extension_array(self.array): return self.array.nbytes dtype = self._get_numpy_dtype() diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 562706a1ac0..386f1e346de 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -104,6 +104,20 @@ T = TypeVar("T") +def is_allowed_extension_array_dtype(dtype: Any): + return pd.api.types.is_extension_array_dtype(dtype) and not isinstance( # noqa: TID251 + dtype, pd.StringDtype + ) + + +def is_allowed_extension_array(array: Any) -> bool: + return ( + hasattr(array, "dtype") + and is_allowed_extension_array_dtype(array.dtype) + and not isinstance(array, pd.arrays.NumpyExtensionArray) # type: ignore[attr-defined] + ) + + def alias_message(old_name: str, new_name: str) -> str: return f"{old_name} has been deprecated. Use {new_name} instead." diff --git a/xarray/core/variable.py b/xarray/core/variable.py index bcc2ca4e460..325dde57e1d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -40,6 +40,7 @@ emit_user_level_warning, ensure_us_time_resolution, infix_dims, + is_allowed_extension_array, is_dict_like, is_duck_array, is_duck_dask_array, @@ -198,7 +199,9 @@ def _maybe_wrap_data(data): return PandasIndexingAdapter(data) if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES): return data.to_numpy() - if isinstance(data, pd.api.extensions.ExtensionArray): + if isinstance( + data, pd.api.extensions.ExtensionArray + ) and is_allowed_extension_array(data): return PandasExtensionArray(data) return data @@ -261,7 +264,8 @@ def convert_non_numpy_type(data): if isinstance(data, pd.Series | pd.DataFrame): if ( isinstance(data, pd.Series) - and pd.api.types.is_extension_array_dtype(data) + and is_allowed_extension_array(data.array) + # Some datetime types are not allowed as well as backing Variable types and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES) ): pandas_data = data.array diff --git a/xarray/tests/test_pandas_to_xarray.py b/xarray/tests/test_pandas_to_xarray.py index 111866541eb..8346f5abe21 100644 --- a/xarray/tests/test_pandas_to_xarray.py +++ b/xarray/tests/test_pandas_to_xarray.py @@ -37,6 +37,7 @@ import pandas as pd import pandas._testing as tm import pytest +from packaging.version import Version from pandas import ( Categorical, CategoricalIndex, @@ -171,7 +172,9 @@ def test_to_xarray_with_multiindex(self, df): result = result.to_dataframe() expected = df.copy() - expected["f"] = expected["f"].astype(object) + expected["f"] = expected["f"].astype( + object if Version(pd.__version__) < Version("3.0.0dev0") else str + ) expected.columns.name = None tm.assert_frame_equal(result, expected) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 2f67e97522c..e2f4a3154f3 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1594,7 +1594,7 @@ def test_pandas_categorical_dtype(self): data = pd.Categorical(np.arange(10, dtype="int64")) v = self.cls("x", data) print(v) # should not error - assert pd.api.types.is_extension_array_dtype(v.dtype) + assert isinstance(v.dtype, pd.CategoricalDtype) def test_squeeze(self): v = Variable(["x", "y"], [[1]])