
ENH: EA._cast_pointwise_result #62105

Open
wants to merge 8 commits into base: main
67 changes: 67 additions & 0 deletions pandas/core/arrays/arrow/array.py
@@ -392,6 +392,73 @@ def _from_sequence_of_strings(
)
return cls._from_sequence(scalars, dtype=pa_type, copy=copy)

def _cast_pointwise_result(self, values) -> ArrayLike:
if len(values) == 0:
# Retain our dtype
return self[:0].copy()

try:
arr = pa.array(values, from_pandas=True)
except (ValueError, TypeError):
# e.g. test_by_column_values_with_same_starting_value with nested
# values, one entry of which is an ArrowStringArray
# or test_agg_lambda_complex128_dtype_conversion for complex values
return super()._cast_pointwise_result(values)

if pa.types.is_duration(arr.type):
# workaround for https://github.com/apache/arrow/issues/40620
result = ArrowExtensionArray._from_sequence(values)
if pa.types.is_duration(self._pa_array.type):
result = result.astype(self.dtype) # type: ignore[assignment]
elif pa.types.is_timestamp(self._pa_array.type):
# Try to retain original unit
new_dtype = ArrowDtype(pa.duration(self._pa_array.type.unit))
try:
result = result.astype(new_dtype) # type: ignore[assignment]
except ValueError:
pass
elif pa.types.is_date64(self._pa_array.type):
# Try to match unit we get on non-pointwise op
dtype = ArrowDtype(pa.duration("ms"))
result = result.astype(dtype) # type: ignore[assignment]
elif pa.types.is_date(self._pa_array.type):
# Try to match unit we get on non-pointwise op
dtype = ArrowDtype(pa.duration("s"))
result = result.astype(dtype) # type: ignore[assignment]
return result

elif pa.types.is_date(arr.type) and pa.types.is_date(self._pa_array.type):
arr = arr.cast(self._pa_array.type)
elif pa.types.is_time(arr.type) and pa.types.is_time(self._pa_array.type):
arr = arr.cast(self._pa_array.type)
elif pa.types.is_decimal(arr.type) and pa.types.is_decimal(self._pa_array.type):
arr = arr.cast(self._pa_array.type)
elif pa.types.is_integer(arr.type) and pa.types.is_integer(self._pa_array.type):
try:
arr = arr.cast(self._pa_array.type)
except pa.lib.ArrowInvalid:
# e.g. test_combine_add if we can't cast
pass
elif pa.types.is_floating(arr.type) and pa.types.is_floating(
self._pa_array.type
):
try:
arr = arr.cast(self._pa_array.type)
except pa.lib.ArrowInvalid:
# e.g. test_combine_add if we can't cast
pass

if isinstance(self.dtype, StringDtype):
if pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type):
# ArrowStringArrayNumpySemantics
return type(self)(arr).astype(self.dtype)
if self.dtype.na_value is np.nan:
# ArrowEA has different semantics, so we return numpy-based
# result instead
return super()._cast_pointwise_result(values)
return ArrowExtensionArray(arr)
return type(self)(arr)

@classmethod
def _box_pa(
cls, value, pa_type: pa.DataType | None = None
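Rough usage sketch for the new Arrow hook, calling the private method directly; the inputs and expected dtypes are illustrative assumptions about this branch, not documented behavior:

```python
# Illustrative sketch only: exercises ArrowExtensionArray._cast_pointwise_result
# directly. Assumes pyarrow is installed and this branch is checked out.
import pandas as pd

arr = pd.array([1, 2, None], dtype="int64[pyarrow]")

# Values that pyarrow infers as integers are cast back to the original
# Arrow-backed dtype.
result = arr._cast_pointwise_result([2, 4, None])
print(result.dtype)  # expected: int64[pyarrow]

# An empty result retains the original dtype via self[:0].copy().
empty = arr._cast_pointwise_result([])
print(len(empty), empty.dtype)  # expected: 0 int64[pyarrow]
```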
49 changes: 9 additions & 40 deletions pandas/core/arrays/base.py
@@ -19,7 +19,6 @@
cast,
overload,
)
import warnings

import numpy as np

@@ -35,13 +34,11 @@
Substitution,
cache_readonly,
)
from pandas.util._exceptions import find_stack_level
from pandas.util._validators import (
validate_bool_kwarg,
validate_insert_loc,
)

from pandas.core.dtypes.cast import maybe_cast_pointwise_result
from pandas.core.dtypes.common import (
is_list_like,
is_scalar,
@@ -89,7 +86,6 @@
AstypeArg,
AxisInt,
Dtype,
DtypeObj,
FillnaOptions,
InterpolateOptions,
NumpySorter,
@@ -311,38 +307,6 @@ def _from_sequence(
"""
raise AbstractMethodError(cls)

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
"""
Strict analogue to _from_sequence, allowing only sequences of scalars
that should be specifically inferred to the given dtype.

Parameters
----------
scalars : sequence
dtype : ExtensionDtype

Raises
------
TypeError or ValueError

Notes
-----
This is called in a try/except block when casting the result of a
pointwise operation.
"""
try:
return cls._from_sequence(scalars, dtype=dtype, copy=False)
except (ValueError, TypeError):
raise
except Exception:
warnings.warn(
"_from_scalars should only raise ValueError or TypeError. "
"Consider overriding _from_scalars where appropriate.",
stacklevel=find_stack_level(),
)
raise

@classmethod
def _from_sequence_of_strings(
cls, strings, *, dtype: ExtensionDtype, copy: bool = False
@@ -371,9 +335,6 @@ def _from_sequence_of_strings(
from a sequence of scalars.
api.extensions.ExtensionArray._from_factorized : Reconstruct an ExtensionArray
after factorization.
api.extensions.ExtensionArray._from_scalars : Strict analogue to _from_sequence,
allowing only sequences of scalars that should be specifically inferred to
the given dtype.

Examples
--------
@@ -416,6 +377,14 @@ def _from_factorized(cls, values, original):
"""
raise AbstractMethodError(cls)

def _cast_pointwise_result(self, values) -> ArrayLike:
"""
Cast the result of a pointwise operation (e.g. Series.map) to an
array, preserving the dtype backend where possible.
"""
values = np.asarray(values, dtype=object)
return lib.maybe_convert_objects(values, convert_non_numeric=True)

# ------------------------------------------------------------------------
# Must be a Sequence
# ------------------------------------------------------------------------
@@ -2842,7 +2811,7 @@ def _maybe_convert(arr):
# https://github.com/pandas-dev/pandas/issues/22850
# We catch all regular exceptions here, and fall back
# to an ndarray.
res = maybe_cast_pointwise_result(arr, self.dtype, same_dtype=False)
res = self._cast_pointwise_result(arr)
if not isinstance(res, type(self)):
# exception raised in _from_sequence; ensure we have ndarray
res = np.asarray(arr)
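For an ExtensionArray that does not override the hook, the base implementation above just infers a NumPy result. A minimal sketch of that default path, re-deriving what the base method returns (inputs are illustrative):

```python
# Sketch of the default ExtensionArray._cast_pointwise_result fallback:
# object values are inferred to a plain NumPy array, so EAs that do not
# override the hook get an ndarray back rather than their own type.
import numpy as np

from pandas._libs import lib

values = np.asarray([1, 2, 3], dtype=object)
out = lib.maybe_convert_objects(values, convert_non_numeric=True)
print(type(out), out.dtype)  # <class 'numpy.ndarray'> int64
```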
21 changes: 6 additions & 15 deletions pandas/core/arrays/categorical.py
@@ -103,7 +103,6 @@
AstypeArg,
AxisInt,
Dtype,
DtypeObj,
NpDtype,
Ordered,
Shape,
@@ -529,20 +528,12 @@ def _from_sequence(
) -> Self:
return cls(scalars, dtype=dtype, copy=copy)

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
if dtype is None:
# The _from_scalars strictness doesn't make much sense in this case.
raise NotImplementedError

res = cls._from_sequence(scalars, dtype=dtype)

# if there are any non-category elements in scalars, these will be
# converted to NAs in res.
mask = isna(scalars)
if not (mask == res.isna()).all():
# Some non-category element in scalars got converted to NA in res.
raise ValueError
def _cast_pointwise_result(self, values) -> ArrayLike:
res = super()._cast_pointwise_result(values)
cat = type(self)._from_sequence(res, dtype=self.dtype)
if (cat.isna() == isna(res)).all():
# i.e. the conversion was non-lossy
return cat
return res

@overload
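Sketch of the intended Categorical behavior: the categorical dtype is kept only when the reconstruction is lossless, i.e. no value outside the categories gets silently turned into NA (values are illustrative):

```python
# Illustrative sketch: the override keeps the categorical dtype only when
# every pointwise result is still a valid category.
import pandas as pd

cat = pd.Categorical(["a", "b", "a"], categories=["a", "b"])

lossless = cat._cast_pointwise_result(["b", "a", "b"])
print(lossless.dtype)  # expected: category (same categories as `cat`)

lossy = cat._cast_pointwise_result(["x", "y", "z"])
# expected: not category; values outside the categories fall back to the
# inferred array (object or str depending on pandas' string inference)
print(lossy.dtype)
```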
9 changes: 0 additions & 9 deletions pandas/core/arrays/datetimes.py
@@ -83,7 +83,6 @@
from pandas._typing import (
ArrayLike,
DateTimeErrorChoices,
DtypeObj,
IntervalClosedType,
TimeAmbiguous,
TimeNonexistent,
@@ -293,14 +292,6 @@ def _scalar_type(self) -> type[Timestamp]:
_dtype: np.dtype[np.datetime64] | DatetimeTZDtype
_freq: BaseOffset | None = None

@classmethod
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
# TODO: require any NAs be valid-for-DTA
# TODO: if dtype is passed, check for tzawareness compat?
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)

@classmethod
def _validate_dtype(cls, values, dtype):
# used in TimeLikeOps.__init__
14 changes: 14 additions & 0 deletions pandas/core/arrays/masked.py
@@ -26,6 +26,7 @@
from pandas.util._decorators import doc

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.common import (
is_bool,
is_integer_dtype,
@@ -147,6 +148,19 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self:
values, mask = cls._coerce_to_array(scalars, dtype=dtype, copy=copy)
return cls(values, mask)

def _cast_pointwise_result(self, values) -> ArrayLike:
values = np.asarray(values, dtype=object)
result = lib.maybe_convert_objects(values, convert_to_nullable_dtype=True)
lkind = self.dtype.kind
rkind = result.dtype.kind
if (lkind in "iu" and rkind in "iu") or (lkind == rkind == "f"):
result = cast(BaseMaskedArray, result)
new_data = maybe_downcast_to_dtype(
result._data, dtype=self.dtype.numpy_dtype
)
result = type(result)(new_data, result._mask)
return result

@classmethod
@doc(ExtensionArray._empty)
def _empty(cls, shape: Shape, dtype: ExtensionDtype) -> Self:
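Sketch of the intent for masked (nullable) arrays: when both sides are integer (or both float), the inferred nullable result is downcast back to the original width. The input and expected dtype are assumptions about this branch:

```python
# Illustrative sketch: masked arrays try to keep their original width,
# e.g. an Int8 array should not silently become Int64.
import pandas as pd

arr = pd.array([1, 2, None], dtype="Int8")
res = arr._cast_pointwise_result([2, 4, None])
print(res.dtype)  # expected: Int8 (downcast from the inferred Int64)
```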
24 changes: 23 additions & 1 deletion pandas/core/arrays/numpy_.py
@@ -14,7 +14,10 @@
from pandas.compat.numpy import function as nv

from pandas.core.dtypes.astype import astype_array
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
from pandas.core.dtypes.cast import (
construct_1d_object_array_from_listlike,
maybe_downcast_to_dtype,
)
from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.dtypes import NumpyEADtype
from pandas.core.dtypes.missing import isna
@@ -34,6 +37,7 @@
from collections.abc import Callable

from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
FillnaOptions,
@@ -145,6 +149,24 @@ def _from_sequence(
result = result.copy()
return cls(result)

def _cast_pointwise_result(self, values) -> ArrayLike:
result = super()._cast_pointwise_result(values)
lkind = self.dtype.kind
rkind = result.dtype.kind
if (
(lkind in "iu" and rkind in "iu")
or (lkind == "f" and rkind == "f")
or (lkind == rkind == "c")
):
result = maybe_downcast_to_dtype(result, self.dtype.numpy_dtype)
elif rkind == "M":
# Ensure potential subsequent .astype(object) doesn't incorrectly
# convert Timestamps to ints
from pandas import array as pd_array

result = pd_array(result, copy=False)
return result

# ------------------------------------------------------------------------
# Data

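Similarly for NumpyExtensionArray, a sketch of the width-preserving intent (the float32 input and expected output are illustrative assumptions):

```python
# Illustrative sketch: a float32-backed NumpyExtensionArray tries not to
# silently upcast its pointwise results to float64.
import numpy as np
import pandas as pd

arr = pd.arrays.NumpyExtensionArray(np.array([1.0, 2.0], dtype=np.float32))
res = arr._cast_pointwise_result([2.0, 4.0])
print(res.dtype)  # expected: float32 (downcast from the inferred float64)
```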
17 changes: 17 additions & 0 deletions pandas/core/arrays/sparse/array.py
@@ -607,6 +607,23 @@ def _from_sequence(
def _from_factorized(cls, values, original) -> Self:
return cls(values, dtype=original.dtype)

def _cast_pointwise_result(self, values):
result = super()._cast_pointwise_result(values)
if result.dtype.kind == self.dtype.kind:
try:
# e.g. test_groupby_agg_extension
res = type(self)._from_sequence(result, dtype=self.dtype)
if ((res == result) | (isna(result) & res.isna())).all():
# This does not hold for e.g.
# test_arith_frame_with_scalar[0-__truediv__]
return res
return type(self)._from_sequence(result)
except (ValueError, TypeError):
return type(self)._from_sequence(result)
else:
# e.g. test_combine_le avoid casting bools to Sparse[float64, nan]
return type(self)._from_sequence(result)

# ------------------------------------------------------------------------
# Data
# ------------------------------------------------------------------------
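Sketch of the sparse round-trip the override checks for; whether the original sparse dtype is kept depends on the losslessness comparison above (values are illustrative):

```python
# Illustrative sketch: SparseArray keeps its own sparse dtype when the
# reconstructed sparse values match the inferred result exactly.
import pandas as pd

sp = pd.arrays.SparseArray([1, 0, 2], fill_value=0)
res = sp._cast_pointwise_result([2, 0, 4])
print(res.dtype)  # expected: Sparse[int64, 0]
```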
14 changes: 7 additions & 7 deletions pandas/core/arrays/string_.py
@@ -412,13 +412,6 @@ def tolist(self) -> list:
return [x.tolist() for x in self]
return list(self.to_numpy())

@classmethod
def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
if lib.infer_dtype(scalars, skipna=True) not in ["string", "empty"]:
# TODO: require any NAs be valid-for-string
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)

def _formatter(self, boxed: bool = False):
formatter = partial(
printing.pprint_thing,
@@ -732,6 +725,13 @@ def _from_sequence_of_strings(
) -> Self:
return cls._from_sequence(strings, dtype=dtype, copy=copy)

def _cast_pointwise_result(self, values) -> ArrayLike:
result = super()._cast_pointwise_result(values)
if isinstance(result.dtype, StringDtype):
# Ensure we retain our same na_value/storage
result = result.astype(self.dtype) # type: ignore[call-overload]
return result

@classmethod
def _empty(cls, shape, dtype) -> StringArray:
values = np.empty(shape, dtype=object)
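Finally, a sketch of the string case: the original StringDtype (storage and na_value) is re-applied to the result. Whether super() returns a StringDtype here depends on pandas' string inference, so the expected output is an assumption:

```python
# Illustrative sketch: the pointwise result is cast back to the exact
# StringDtype of the calling array (same storage and na_value).
import pandas as pd

arr = pd.array(["a", "b", None], dtype="string")
res = arr._cast_pointwise_result(["A", "B", None])
print(res.dtype == arr.dtype)  # expected: True on this branch
```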