Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4135,6 +4135,33 @@ def isin(self, values, level=None) -> npt.NDArray[np.bool_]:
# base class "Index" defined the type as "Callable[[Index, Any, bool], Any]")
rename = Index.set_names # type: ignore[assignment]

def difference(self, other, sort=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems really unlikely to be the right place to handle this. if you step through the existing implementation, where is the first step that goes wrong?

"""
Return a new MultiIndex with elements in self that are not in other.
Fixed to work with pyarrow-backed Timestamps.
"""
if isinstance(other, type(self)):
# Convert pyarrow-backed Timestamps to pandas Timestamps for comparison
self_arrays = [level.to_pandas() if hasattr(level, "to_pandas") else level
for level in self.levels]
other_arrays = [level.to_pandas() if hasattr(level, "to_pandas") else level
for level in other.levels]
self_conv = pd.MultiIndex.from_arrays(self_arrays, names=self.names)
other_conv = pd.MultiIndex.from_arrays(other_arrays, names=other.names)
result = self_conv.difference(other_conv, sort=sort)
# Preserve pyarrow dtypes if present
for i, level in enumerate(self.levels):
if getattr(level, "dtype", None) == "timestamp[ns][pyarrow]":
result = pd.MultiIndex.from_arrays(
[pd.Series(arr, dtype="timestamp[ns][pyarrow]") if i==idx else arr
for idx, arr in enumerate(result.levels)],
names=result.names
)
return result
else:
return super(type(self), self).difference(other, sort=sort)


# ---------------------------------------------------------------
# Arithmetic/Numeric Methods - Disabled

Expand Down
25 changes: 25 additions & 0 deletions pandas/tests/indexes/multi/test_timestamp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pandas as pd
import pytest

pytest.importorskip("pyarrow")

def test_difference_with_pyarrow_timestamp():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would go in tests/indexes/multi/test_setops.py

dates = pd.Series(
["2024-01-01", "2024-01-02"], dtype="timestamp[ns][pyarrow]"
)
ids = [1, 2]

mi = pd.MultiIndex.from_arrays([ids, dates], names=["id", "date"])
to_remove = mi[:1]

result = mi.difference(to_remove)

expected_dates = pd.Series(
["2024-01-02"], dtype="timestamp[ns][pyarrow]"
)
expected_ids = [2]
expected = pd.MultiIndex.from_arrays(
[expected_ids, expected_dates], names=["id", "date"]
)

pd.testing.assert_index_equal(result, expected)
Loading