Skip to content

Commit 1d153bb

Browse files
authored
TST(string dtype): Resolve xfails in test_from_dummies (#60694)
1 parent bc6ad14 commit 1d153bb

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

pandas/core/reshape/encoding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,9 @@ def from_dummies(
390390
The default category is the implied category when a value has none of the
391391
listed categories specified with a one, i.e. if all dummies in a row are
392392
zero. Can be a single value for all variables or a dict directly mapping
393-
the default categories to a prefix of a variable.
393+
the default categories to a prefix of a variable. The default category
394+
will be coerced to the dtype of ``data.columns`` if such coercion is
395+
lossless, and will raise otherwise.
394396
395397
Returns
396398
-------

pandas/tests/reshape/test_from_dummies.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,7 @@ def test_no_prefix_string_cats_default_category(
333333
):
334334
dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]})
335335
result = from_dummies(dummies, default_category=default_category)
336-
expected = DataFrame(expected)
337-
if using_infer_string:
338-
expected[""] = expected[""].astype("str")
336+
expected = DataFrame(expected, dtype=dummies.columns.dtype)
339337
tm.assert_frame_equal(result, expected)
340338

341339

@@ -449,3 +447,31 @@ def test_maintain_original_index():
449447
result = from_dummies(df)
450448
expected = DataFrame({"": list("abca")}, index=list("abcd"))
451449
tm.assert_frame_equal(result, expected)
450+
451+
452+
def test_int_columns_with_float_default():
    """A lossy float default category must be rejected for integer column labels."""
    # https://github.com/pandas-dev/pandas/pull/60694
    df = DataFrame(
        {
            3: [1, 0, 0],
            4: [0, 1, 0],
        },
    )
    # The default category is coerced to the dtype of ``df.columns``; 0.5
    # cannot become an integer without losing information, so this must raise.
    with pytest.raises(ValueError, match="Trying to coerce float values to integers"):
        from_dummies(df, default_category=0.5)
462+
463+
464+
def test_object_dtype_preserved():
    """Object-dtype column labels must yield an object-dtype result."""
    # https://github.com/pandas-dev/pandas/pull/60694
    # When the input columns have object dtype, the result should have
    # object dtype as well, even when infer_string is True.
    df = DataFrame(
        {
            "x": [1, 0, 0],
            "y": [0, 1, 0],
        },
    )
    df.columns = df.columns.astype("object")
    result = from_dummies(df, default_category="z")
    expected = DataFrame({"": ["x", "y", "z"]}, dtype="object")
    tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)