Skip to content

BUG: Fix concat dtype preservation through concat #61893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
22 changes: 22 additions & 0 deletions pandas/core/dtypes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def concat_compat(
-------
a single array, preserving the combined dtypes
"""

if len(to_concat) and lib.dtypes_all_equal([obj.dtype for obj in to_concat]):
# fastpath!
obj = to_concat[0]
Expand All @@ -92,6 +93,27 @@ def concat_compat(
to_concat_eas,
axis=axis, # type: ignore[call-arg]
)
# Special handling for categorical arrays solves #51362
if (
len(to_concat)
and all(isinstance(arr.dtype, CategoricalDtype) for arr in to_concat)
and axis == 0
):
# Filter out empty arrays before union, similar to non_empties logic
non_empty_categoricals = [x for x in to_concat if _is_nonempty(x, axis)]

if len(non_empty_categoricals) == 0:
# All arrays are empty, return the first one (they're all categorical)
return to_concat[0]
elif len(non_empty_categoricals) == 1:
# Only one non-empty array, return it directly
return non_empty_categoricals[0]
else:
# Multiple non-empty arrays, use union_categoricals
return union_categoricals(
non_empty_categoricals, sort_categories=True
) # Performance cost, but necessary to keep tests passing.
# see pandas/tests/reshape/concat/test_append_common.py:498

# If all arrays are empty, there's nothing to convert, just short-cut to
# the concatenation, #3121.
Expand Down
29 changes: 25 additions & 4 deletions pandas/tests/dtypes/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import pandas.core.dtypes.concat as _concat

import pandas as pd
from pandas import Series
from pandas import (
DataFrame,
Series,
)
import pandas._testing as tm


Expand All @@ -14,12 +17,12 @@ def test_concat_mismatched_categoricals_with_empty():

result = _concat.concat_compat([ser1._values, ser2._values])
expected = pd.concat([ser1, ser2])._values
tm.assert_numpy_array_equal(result, expected)
tm.assert_categorical_equal(result, expected)


def test_concat_single_dataframe_tz_aware():
# https://github.com/pandas-dev/pandas/issues/25257
df = pd.DataFrame(
df = DataFrame(
{"timestamp": [pd.Timestamp("2020-04-08 09:00:00.709949+0000", tz="UTC")]}
)
expected = df.copy()
Expand Down Expand Up @@ -53,7 +56,7 @@ def test_concat_series_between_empty_and_tzaware_series(using_infer_string):
ser2 = Series(dtype=float)

result = pd.concat([ser1, ser2], axis=1)
expected = pd.DataFrame(
expected = DataFrame(
data=[
(0.0, None),
],
Expand All @@ -64,3 +67,21 @@ def test_concat_series_between_empty_and_tzaware_series(using_infer_string):
dtype=float,
)
tm.assert_frame_equal(result, expected)


def test_concat_categorical_dataframes():
df = DataFrame({"a": [0, 1]}, dtype="category")
df2 = DataFrame({"a": [2, 3]}, dtype="category")

result = pd.concat([df, df2], axis=0)

assert result["a"].dtype.name == "category"


def test_concat_categorical_series():
ser = Series([0, 1], dtype="category")
ser2 = Series([2, 3], dtype="category")

result = pd.concat([ser, ser2], axis=0)

assert result.dtype.name == "category"
Loading