Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d64eca5
Align decimal dtypes to schema after parquet IO scan
Matt711 Sep 15, 2025
42e5e43
clean up
Matt711 Sep 15, 2025
2d0f114
Merge branch 'branch-25.10' into bug/polars/parquet-schema-mismatch
Matt711 Sep 16, 2025
d9e6a96
Merge branch 'branch-25.10' into bug/polars/parquet-schema-mismatch
Matt711 Sep 17, 2025
7b35f1c
Merge branch 'branch-25.10' into bug/polars/parquet-schema-mismatch
Matt711 Sep 19, 2025
7b768fe
casts for decimal pdsh support
Matt711 Sep 22, 2025
f83d127
clean up
Matt711 Sep 22, 2025
a753f4f
clean up
Matt711 Sep 22, 2025
1cb19ad
clean up
Matt711 Sep 22, 2025
67fcd7e
clean up
Matt711 Sep 22, 2025
6647326
Merge branch 'branch-25.12' into fea/polars/pdsh-decimals
Matt711 Sep 22, 2025
d559636
Merge branch 'branch-25.12' into fea/polars/pdsh-decimals
Matt711 Sep 22, 2025
dd2286b
Merge branch 'branch-25.12' into fea/polars/pdsh-decimals
Matt711 Sep 25, 2025
51ea86f
align types in predicate
Matt711 Sep 29, 2025
8ff793a
Merge branch 'branch-25.12' into fea/polars/pdsh-decimals
Matt711 Sep 29, 2025
d152f5c
Merge branch 'branch-25.12' into fea/polars/pdsh-decimals
Matt711 Sep 30, 2025
dcdb068
insert cast ahead of time
Matt711 Sep 30, 2025
99b5ba2
add helper methods to datatype
Matt711 Sep 30, 2025
e811eec
clean up
Matt711 Sep 30, 2025
e5452fb
clean up
Matt711 Sep 30, 2025
d8b24c9
Merge branch 'branch-25.12' into fea/polars/pdsh-decimals
Matt711 Oct 1, 2025
b9eda7d
address review
Matt711 Oct 2, 2025
4b9df48
add datatype test
Matt711 Oct 2, 2025
66b334c
pass more tests
Matt711 Oct 2, 2025
623ab7c
merge conflict
Matt711 Oct 2, 2025
176654a
merge conflict
Matt711 Oct 7, 2025
9f83cb4
code coverage
Matt711 Oct 8, 2025
3752341
Merge branch 'branch-25.12' into fea/polars/pdsh-decimals
Matt711 Oct 8, 2025
555c8ee
address review
Matt711 Oct 8, 2025
6136d0d
Merge branch 'branch-25.12' into fea/polars/pdsh-decimals
Matt711 Oct 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/cudf_polars/cudf_polars/containers/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@ def children(self) -> list[DataType]:
return [DataType(self.polars_type.inner)] # type: ignore[attr-defined]
return []

def scale(self) -> int:
    """
    Return the scale reported by the underlying pylibcudf type.

    Returns
    -------
    int
        The value of ``self.plc_type.scale()``.
    """
    plc_type = self.plc_type
    return plc_type.scale()

@staticmethod
def common_decimal_dtype(left: DataType, right: DataType) -> DataType:
    """
    Return a decimal DataType that both inputs can be aligned to.

    Parameters
    ----------
    left, right
        The two DataTypes to reconcile. Both must be fixed-point.

    Returns
    -------
    DataType
        A ``pl.Decimal`` with precision 38 whose scale is the absolute
        value of the smaller of the two input scales.

    Raises
    ------
    ValueError
        If either input is not a fixed-point type.
    """
    both_decimal = all(
        plc.traits.is_fixed_point(dtype.plc_type) for dtype in (left, right)
    )
    if not both_decimal:
        raise ValueError("Both inputs required to be decimal types.")
    # NOTE(review): abs() presumably converts a non-positive pylibcudf scale
    # into the non-negative scale polars expects -- confirm the sign convention.
    smallest_scale = min(left.scale(), right.scale())
    return DataType(pl.Decimal(38, abs(smallest_scale)))

def __eq__(self, other: object) -> bool:
"""Equality of DataTypes."""
if not isinstance(other, DataType):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,15 @@ def agg_request(self) -> plc.aggregation.Aggregation: # noqa: D102
def _reduce(
self, column: Column, *, request: plc.aggregation.Aggregation
) -> Column:
if (
self.name in {"mean", "median"}
and plc.traits.is_fixed_point(column.dtype.plc_type)
and self.dtype.plc_type.id() in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
):
column = column.astype(self.dtype)
return Column(
plc.Column.from_scalar(
plc.reduce.reduce(column.obj, request, self.dtype.plc_type),
1,
plc.reduce.reduce(column.obj, request, self.dtype.plc_type), 1
),
name=column.name,
dtype=self.dtype,
Expand Down
141 changes: 131 additions & 10 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
import polars as pl

import pylibcudf as plc
from pylibcudf import expressions as plc_expr

import cudf_polars.dsl.expr as expr
from cudf_polars.containers import Column, DataFrame, DataType
from cudf_polars.containers.dataframe import NamedColumn
from cudf_polars.dsl.expressions import rolling, unary
from cudf_polars.dsl.expressions.base import ExecutionContext
from cudf_polars.dsl.nodebase import Node
Expand Down Expand Up @@ -81,6 +83,23 @@
]


# Comparison operators for which `_collect_decimal_binop_casts` aligns the
# operand dtypes of a predicate. Only comparisons are listed so far.
_BINOPS = {
    plc.binaryop.BinaryOperator.EQUAL,
    plc.binaryop.BinaryOperator.NOT_EQUAL,
    plc.binaryop.BinaryOperator.LESS,
    plc.binaryop.BinaryOperator.LESS_EQUAL,
    plc.binaryop.BinaryOperator.GREATER,
    plc.binaryop.BinaryOperator.GREATER_EQUAL,
    # TODO: Handle other binary operations as needed
}


# pylibcudf type ids of the decimal types.
_DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}


# pylibcudf type ids of the floating-point types.
_FLOAT_TYPES = {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}


class IR(Node["IR"]):
"""Abstract plan node, representing an unevaluated dataframe."""

Expand Down Expand Up @@ -214,20 +233,16 @@ def __init__(self, schema: Schema, options: Any, predicate: expr.NamedExpr | Non

def _align_parquet_schema(df: DataFrame, schema: Schema) -> DataFrame:
# TODO: Alternatively set the schema of the parquet reader to decimal128
plc_decimals_ids = {
plc.TypeId.DECIMAL32,
plc.TypeId.DECIMAL64,
plc.TypeId.DECIMAL128,
}
cast_list = []

for name, col in df.column_map.items():
src = col.obj.type()
dst = schema[name].plc_type

if (
src.id() in plc_decimals_ids
and dst.id() in plc_decimals_ids
and ((src.id() != dst.id()) or (src.scale != dst.scale))
plc.traits.is_fixed_point(src)
and plc.traits.is_fixed_point(dst)
and ((src.id() != dst.id()) or (src.scale() != dst.scale()))
):
cast_list.append(
Column(plc.unary.cast(col.obj, dst), name=name, dtype=schema[name])
Expand Down Expand Up @@ -1586,6 +1601,108 @@ def do_evaluate(
return DataFrame(broadcasted).slice(zlice)


def _strip_predicate_casts(node: expr.Expr) -> expr.Expr:
    """
    Remove casts to or from decimal types from a predicate expression.

    Parameters
    ----------
    node
        Root of the expression tree to process.

    Returns
    -------
    expr.Expr
        The expression with every ``Cast`` whose source or destination
        dtype is fixed-point replaced by its (recursively stripped) child.
        All other nodes with children are rebuilt via ``reconstruct`` from
        their stripped children; leaf nodes are returned as-is.
    """
    if isinstance(node, expr.Cast):
        (child,) = node.children
        child = _strip_predicate_casts(child)

        if plc.traits.is_fixed_point(child.dtype.plc_type) or plc.traits.is_fixed_point(
            node.dtype.plc_type
        ):
            return child
        # The child is already stripped: reuse it instead of falling through
        # and recursing into it a second time, which doubled the work at
        # every level of nested non-decimal casts.
        return node.reconstruct([child])

    if not node.children:
        return node
    return node.reconstruct([_strip_predicate_casts(child) for child in node.children])


def _add_cast(
    target: DataType,
    side: expr.ColRef,
    left_casts: dict[str, DataType],
    right_casts: dict[str, DataType],
) -> None:
    """
    Record that the column referenced by ``side`` must be cast to ``target``.

    Parameters
    ----------
    target
        The DataType to cast the column to.
    side
        Column reference whose single child is the column to cast.
    left_casts, right_casts
        Per-table mappings from column name to target DataType. The one
        matching ``side.table_ref`` is updated in place.
    """
    (column,) = side.children
    assert isinstance(column, expr.Col)
    if side.table_ref == plc_expr.TableReference.LEFT:
        left_casts[column.name] = target
    else:
        right_casts[column.name] = target


def _align_decimal_binop_types(
    left_expr: expr.ColRef,
    right_expr: expr.ColRef,
    left_casts: dict[str, DataType],
    right_casts: dict[str, DataType],
) -> None:
    """
    Register the casts needed to compare two column references.

    If both sides are decimal, each side whose type id or scale differs
    from their common decimal dtype is registered for a cast to it. If one
    side is decimal and the other floating point, the floating-point side
    is registered for a cast to the decimal side's dtype. Any other
    combination registers nothing.

    Parameters
    ----------
    left_expr, right_expr
        The two operands of the binary operation.
    left_casts, right_casts
        Per-table mappings from column name to target DataType, updated
        in place.
    """
    left_type, right_type = left_expr.dtype, right_expr.dtype
    left_is_decimal = plc.traits.is_fixed_point(left_type.plc_type)
    right_is_decimal = plc.traits.is_fixed_point(right_type.plc_type)

    if left_is_decimal and right_is_decimal:
        target = DataType.common_decimal_dtype(left_type, right_type)
        for operand, dtype in ((left_expr, left_type), (right_expr, right_type)):
            if dtype.id() != target.id() or dtype.scale() != target.scale():
                _add_cast(target, operand, left_casts, right_casts)
        return

    # Mixed decimal/float comparison: the float side takes the decimal dtype.
    if left_is_decimal and plc.traits.is_floating_point(right_type.plc_type):
        _add_cast(left_type, right_expr, left_casts, right_casts)
    elif right_is_decimal and plc.traits.is_floating_point(left_type.plc_type):
        _add_cast(right_type, left_expr, left_casts, right_casts)


def _collect_decimal_binop_casts(
    predicate: expr.Expr,
) -> tuple[dict[str, DataType], dict[str, DataType]]:
    """
    Collect the column casts a predicate needs for decimal comparisons.

    Parameters
    ----------
    predicate
        Root of the predicate expression tree.

    Returns
    -------
    tuple
        ``(left_casts, right_casts)``: mappings from column name to target
        DataType for the left and right tables respectively.
    """
    left_casts: dict[str, DataType] = {}
    right_casts: dict[str, DataType] = {}

    # Explicit stack; children are pushed reversed so nodes are visited in
    # the same pre-order (parent, then children left to right).
    pending: list[expr.Expr] = [predicate]
    while pending:
        node = pending.pop()
        if isinstance(node, expr.BinOp) and node.op in _BINOPS:
            lhs, rhs = node.children
            if isinstance(lhs, expr.ColRef) and isinstance(rhs, expr.ColRef):
                _align_decimal_binop_types(lhs, rhs, left_casts, right_casts)
        pending.extend(reversed(node.children))

    return left_casts, right_casts


def _apply_casts(df: DataFrame, casts: dict[str, DataType]) -> DataFrame:
if not casts:
return df

columns = []
for col in df.columns:
target = casts.get(col.name)
if target is None:
columns.append(Column(col.obj, dtype=col.dtype, name=col.name))
else:
casted = col.astype(target)
columns.append(Column(casted.obj, dtype=casted.dtype, name=col.name))
return DataFrame(columns)


class ConditionalJoin(IR):
"""A conditional inner join of two dataframes on a predicate."""

Expand Down Expand Up @@ -1633,6 +1750,7 @@ def __init__(
self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
) -> None:
self.schema = schema
predicate = _strip_predicate_casts(predicate)
self.predicate = predicate
# options[0] is a tuple[str, Operator, ...]
# The Operator class can't be pickled, but we don't use it anyway so
Expand Down Expand Up @@ -1665,11 +1783,14 @@ def do_evaluate(
right: DataFrame,
) -> DataFrame:
"""Evaluate and return a dataframe."""
left_casts, right_casts = _collect_decimal_binop_casts(
predicate_wrapper.predicate
)
_, _, zlice, suffix, _, _ = options

lg, rg = plc.join.conditional_inner_join(
left.table,
right.table,
_apply_casts(left, left_casts).table,
_apply_casts(right, right_casts).table,
predicate_wrapper.ast,
)
left = DataFrame.from_table(
Expand Down
78 changes: 57 additions & 21 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,32 @@ def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
return rewrite_groupby(node, schema, keys, original_aggs, inp)


_DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}


def _align_decimal_scales(
left: expr.Expr, right: expr.Expr
) -> tuple[expr.Expr, expr.Expr]:
left_type, right_type = left.dtype, right.dtype

if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
right_type.plc_type
):
target = DataType.common_decimal_dtype(left_type, right_type)

if (
left_type.id() != target.id() or left_type.scale() != target.scale()
): # pragma: no cover; no test yet
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the reviewer: We need this cast for Q11 to work, but I haven't been able to reproduce the failure outside of Q11 (in an actual test). I'm planning on leaving this for now and following up with a test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xref #20213

left = expr.Cast(target, left)

if (
right_type.id() != target.id() or right_type.scale() != target.scale()
): # pragma: no cover; no test yet
right = expr.Cast(target, right)

return left, right


@_translate_ir.register
def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
# Join key dtypes are dependent on the schema of the left and
Expand Down Expand Up @@ -388,22 +414,24 @@ def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
expr.BinOp(
dtype,
expr.BinOp._MAPPING[op],
insert_colrefs(
left.value,
table_ref=plc.expressions.TableReference.LEFT,
name_to_index={
name: i for i, name in enumerate(inp_left.schema)
},
),
insert_colrefs(
right.value,
table_ref=plc.expressions.TableReference.RIGHT,
name_to_index={
name: i for i, name in enumerate(inp_right.schema)
},
*_align_decimal_scales(
insert_colrefs(
left_ne.value,
table_ref=plc.expressions.TableReference.LEFT,
name_to_index={
name: i for i, name in enumerate(inp_left.schema)
},
),
insert_colrefs(
right_ne.value,
table_ref=plc.expressions.TableReference.RIGHT,
name_to_index={
name: i for i, name in enumerate(inp_right.schema)
},
),
),
)
for op, left, right in zip(ops, left_on, right_on, strict=True)
for op, left_ne, right_ne in zip(ops, left_on, right_on, strict=True)
),
)

Expand Down Expand Up @@ -889,13 +917,21 @@ def _(
def _(
node: pl_expr.Agg, translator: Translator, dtype: DataType, schema: Schema
) -> expr.Expr:
value = expr.Agg(
dtype,
node.name,
node.options,
*(translator.translate_expr(n=n, schema=schema) for n in node.arguments),
)
if value.name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
agg_name = node.name
args = [translator.translate_expr(n=arg, schema=schema) for arg in node.arguments]

if agg_name not in ("count", "n_unique", "mean", "median", "quantile"):
args = [
expr.Cast(dtype, arg)
if plc.traits.is_fixed_point(arg.dtype.plc_type)
and arg.dtype.plc_type != dtype.plc_type
else arg
for arg in args
]

value = expr.Agg(dtype, agg_name, node.options, *args)

if agg_name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
return expr.Cast(value.dtype, value)
return value

Expand Down
8 changes: 8 additions & 0 deletions python/cudf_polars/tests/containers/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,11 @@ def test_repr():
)
def test_children(dtype, expected):
assert DataType(dtype).children == expected


def test_common_decimal_type_raises():
    """Check that non-decimal inputs to ``common_decimal_dtype`` raise."""
    # ``match`` is a regular expression, so the trailing period is escaped to
    # require the literal message rather than any character in its place.
    with pytest.raises(
        ValueError, match=r"Both inputs required to be decimal types\."
    ):
        DataType.common_decimal_dtype(
            DataType(pl.Float64()),
            DataType(pl.Float64()),
        )
26 changes: 26 additions & 0 deletions python/cudf_polars/tests/expressions/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from decimal import Decimal

import pytest

import polars as pl
Expand Down Expand Up @@ -68,6 +70,19 @@ def df(dtype, with_nulls, is_sorted):
return df


@pytest.fixture
def decimal_df() -> pl.LazyFrame:
    """Provide a LazyFrame with one ``Decimal(9, 2)`` column named ``a``."""
    values = [Decimal("0.10"), Decimal("1.10"), Decimal("100.10")]
    series = pl.Series("a", values, dtype=pl.Decimal(precision=9, scale=2))
    return pl.LazyFrame({"a": series})


def test_agg(df, agg):
expr = getattr(pl.col("a"), agg)()
q = df.select(expr)
Expand Down Expand Up @@ -162,3 +177,14 @@ def test_implode_agg_unsupported():
)
q = df.select(pl.col("b").implode())
assert_ir_translation_raises(q, NotImplementedError)


def test_decimal_aggs(decimal_df: pl.LazyFrame) -> None:
    """Compare GPU and CPU results for aggregations over a decimal column."""
    agg_names = ("sum", "min", "max", "mean", "median")
    aggregations = {name: getattr(pl.col("a"), name)() for name in agg_names}
    q = decimal_df.with_columns(**aggregations)
    assert_gpu_result_equal(q)
assert_gpu_result_equal(q)
Loading