Skip to content

Commit f8f579f

Browse files
committed
added preser_root param
1 parent 2fa5d5d commit f8f579f

File tree

5 files changed

+30
-7
lines changed

5 files changed

+30
-7
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,6 @@ doc/source/savefig/
141141
# Pyodide/WASM related files #
142142
##############################
143143
/.pyodide-xbuildenv-*
144+
145+
local.py
146+
.venv/

doc/source/reference/extensions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ objects.
5858
api.extensions.ExtensionArray.isin
5959
api.extensions.ExtensionArray.isna
6060
api.extensions.ExtensionArray.ravel
61+
api.extensions.ExtensionArray.map
6162
api.extensions.ExtensionArray.repeat
6263
api.extensions.ExtensionArray.searchsorted
6364
api.extensions.ExtensionArray.shift

pandas/core/arrays/arrow/array.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,12 +1592,16 @@ def to_numpy(
15921592
result[~mask] = data[~mask]._pa_array.to_numpy()
15931593
return result
15941594

1595-
def map(self, mapper, na_action: Literal["ignore"] | None = None):
1595+
def map(self, mapper,
1596+
na_action: Literal["ignore"] | None = None,
1597+
preserve_dtype: bool = False):
15961598
if is_numeric_dtype(self.dtype):
15971599
result = map_array(self.to_numpy(), mapper, na_action=na_action)
1598-
return self._cast_pointwise_result(result)
1600+
if preserve_dtype:
1601+
result = self._cast_pointwise_result(result)
1602+
return result
15991603
else:
1600-
return super().map(mapper, na_action)
1604+
return super().map(mapper, na_action, preserve_dtype=preserve_dtype)
16011605

16021606
@doc(ExtensionArray.duplicated)
16031607
def duplicated(

pandas/core/arrays/base.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,7 +2510,9 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
25102510

25112511
return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)
25122512

2513-
def map(self, mapper, na_action: Literal["ignore"] | None = None):
2513+
def map(self, mapper,
2514+
na_action: Literal["ignore"] | None = None,
2515+
preserve_dtype: bool = False):
25142516
"""
25152517
Map values using an input mapping or function.
25162518
@@ -2522,6 +2524,12 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None):
25222524
If 'ignore', propagate NA values, without passing them to the
25232525
mapping correspondence. If 'ignore' is not supported, a
25242526
``NotImplementedError`` should be raised.
2527+
preserve_dtype : bool, default False
2528+
If True, attempt to cast the elementwise result back to the
2529+
original ExtensionArray type (and dtype) when possible. This is
2530+
primarily intended for identity or dtype-preserving mappings.
2531+
If False, the result of the mapping is returned as produced by
2532+
the underlying implementation (typically a NumPy ndarray).
25252533
25262534
Returns
25272535
-------
@@ -2531,7 +2539,9 @@ def map(self, mapper, na_action: Literal["ignore"] | None = None):
25312539
a MultiIndex will be returned.
25322540
"""
25332541
results = map_array(self, mapper, na_action=na_action)
2534-
return self._cast_pointwise_result(results)
2542+
if preserve_dtype:
2543+
results = self._cast_pointwise_result(results)
2544+
return results
25352545

25362546
# ------------------------------------------------------------------------
25372547
# GroupBy Methods

pandas/core/arrays/masked.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,9 +1394,14 @@ def max(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):
13941394
)
13951395
return self._wrap_reduction_result("max", result, skipna=skipna, axis=axis)
13961396

1397-
def map(self, mapper, na_action: Literal["ignore"] | None = None):
1397+
def map(self, mapper,
1398+
na_action: Literal["ignore"] | None = None,
1399+
preserve_dtype: bool = False):
1400+
"""See ExtensionArray.map."""
13981401
result = map_array(self.to_numpy(), mapper, na_action=na_action)
1399-
return self._cast_pointwise_result(result)
1402+
if preserve_dtype:
1403+
result = self._cast_pointwise_result(result)
1404+
return result
14001405

14011406
@overload
14021407
def any(

0 commit comments

Comments
 (0)