diff --git a/python/cudf_polars/cudf_polars/containers/datatype.py b/python/cudf_polars/cudf_polars/containers/datatype.py
index f1960b552c1..8290efccbf9 100644
--- a/python/cudf_polars/cudf_polars/containers/datatype.py
+++ b/python/cudf_polars/cudf_polars/containers/datatype.py
@@ -124,6 +124,20 @@ def children(self) -> list[DataType]:
             return [DataType(self.polars_type.inner)]  # type: ignore[attr-defined]
         return []
 
+    def scale(self) -> int:
+        """The scale of this DataType."""
+        return self.plc_type.scale()
+
+    @staticmethod
+    def common_decimal_dtype(left: DataType, right: DataType) -> DataType:
+        """Return a common decimal DataType for the two inputs."""
+        if not (
+            plc.traits.is_fixed_point(left.plc_type)
+            and plc.traits.is_fixed_point(right.plc_type)
+        ):
+            raise ValueError("Both inputs required to be decimal types.")
+        return DataType(pl.Decimal(38, abs(min(left.scale(), right.scale()))))
+
     def __eq__(self, other: object) -> bool:
         """Equality of DataTypes."""
         if not isinstance(other, DataType):
diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py b/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
index 04bb5121e52..b7c08656c48 100644
--- a/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
+++ b/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
@@ -138,10 +138,15 @@ def agg_request(self) -> plc.aggregation.Aggregation:  # noqa: D102
     def _reduce(
         self, column: Column, *, request: plc.aggregation.Aggregation
     ) -> Column:
+        if (
+            self.name in {"mean", "median"}
+            and plc.traits.is_fixed_point(column.dtype.plc_type)
+            and self.dtype.plc_type.id() in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
+        ):
+            column = column.astype(self.dtype)
         return Column(
             plc.Column.from_scalar(
-                plc.reduce.reduce(column.obj, request, self.dtype.plc_type),
-                1,
+                plc.reduce.reduce(column.obj, request, self.dtype.plc_type), 1
             ),
             name=column.name,
             dtype=self.dtype,
diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index c440a8f8577..1f9eea70bfe 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -26,9 +26,11 @@
 import polars as pl
 
 import pylibcudf as plc
+from pylibcudf import expressions as plc_expr
 
 import cudf_polars.dsl.expr as expr
 from cudf_polars.containers import Column, DataFrame, DataType
+from cudf_polars.containers.dataframe import NamedColumn
 from cudf_polars.dsl.expressions import rolling, unary
 from cudf_polars.dsl.expressions.base import ExecutionContext
 from cudf_polars.dsl.nodebase import Node
@@ -81,6 +83,23 @@
 ]
 
 
+_BINOPS = {
+    plc.binaryop.BinaryOperator.EQUAL,
+    plc.binaryop.BinaryOperator.NOT_EQUAL,
+    plc.binaryop.BinaryOperator.LESS,
+    plc.binaryop.BinaryOperator.LESS_EQUAL,
+    plc.binaryop.BinaryOperator.GREATER,
+    plc.binaryop.BinaryOperator.GREATER_EQUAL,
+    # TODO: Handle other binary operations as needed
+}
+
+
+_DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
+
+
+_FLOAT_TYPES = {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}
+
+
 class IR(Node["IR"]):
     """Abstract plan node, representing an unevaluated dataframe."""
 
@@ -214,20 +233,16 @@ def __init__(self, schema: Schema, options: Any, predicate: expr.NamedExpr | Non
 
 
 def _align_parquet_schema(df: DataFrame, schema: Schema) -> DataFrame:
     # TODO: Alternatively set the schema of the parquet reader to decimal128
-    plc_decimals_ids = {
-        plc.TypeId.DECIMAL32,
-        plc.TypeId.DECIMAL64,
-        plc.TypeId.DECIMAL128,
-    }
     cast_list = []
     for name, col in df.column_map.items():
         src = col.obj.type()
         dst = schema[name].plc_type
+
         if (
-            src.id() in plc_decimals_ids
-            and dst.id() in plc_decimals_ids
-            and ((src.id() != dst.id()) or (src.scale != dst.scale))
+            plc.traits.is_fixed_point(src)
+            and plc.traits.is_fixed_point(dst)
+            and ((src.id() != dst.id()) or (src.scale() != dst.scale()))
         ):
             cast_list.append(
                 Column(plc.unary.cast(col.obj, dst), name=name, dtype=schema[name])
@@ -1586,6 +1601,108 @@ def do_evaluate(
         return DataFrame(broadcasted).slice(zlice)
 
 
+def _strip_predicate_casts(node: expr.Expr) -> expr.Expr:
+    if isinstance(node, expr.Cast):
+        (child,) = node.children
+        child = _strip_predicate_casts(child)
+
+        src = child.dtype
+        dst = node.dtype
+
+        if plc.traits.is_fixed_point(src.plc_type) or plc.traits.is_fixed_point(
+            dst.plc_type
+        ):
+            return child
+
+    if not node.children:
+        return node
+    return node.reconstruct([_strip_predicate_casts(child) for child in node.children])
+
+
+def _add_cast(
+    target: DataType,
+    side: expr.ColRef,
+    left_casts: dict[str, DataType],
+    right_casts: dict[str, DataType],
+) -> None:
+    (col,) = side.children
+    assert isinstance(col, expr.Col)
+    casts = (
+        left_casts if side.table_ref == plc_expr.TableReference.LEFT else right_casts
+    )
+    casts[col.name] = target
+
+
+def _align_decimal_binop_types(
+    left_expr: expr.ColRef,
+    right_expr: expr.ColRef,
+    left_casts: dict[str, DataType],
+    right_casts: dict[str, DataType],
+) -> None:
+    left_type, right_type = left_expr.dtype, right_expr.dtype
+
+    if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
+        right_type.plc_type
+    ):
+        target = DataType.common_decimal_dtype(left_type, right_type)
+
+        if left_type.id() != target.id() or left_type.scale() != target.scale():
+            _add_cast(target, left_expr, left_casts, right_casts)
+
+        if right_type.id() != target.id() or right_type.scale() != target.scale():
+            _add_cast(target, right_expr, left_casts, right_casts)
+
+    elif (
+        plc.traits.is_fixed_point(left_type.plc_type)
+        and plc.traits.is_floating_point(right_type.plc_type)
+    ) or (
+        plc.traits.is_fixed_point(right_type.plc_type)
+        and plc.traits.is_floating_point(left_type.plc_type)
+    ):
+        is_decimal_left = plc.traits.is_fixed_point(left_type.plc_type)
+        decimal_expr, float_expr = (
+            (left_expr, right_expr) if is_decimal_left else (right_expr, left_expr)
+        )
+        _add_cast(decimal_expr.dtype, float_expr, left_casts, right_casts)
+
+
+def _collect_decimal_binop_casts(
+    predicate: expr.Expr,
+) -> tuple[dict[str, DataType], dict[str, DataType]]:
+    left_casts: dict[str, DataType] = {}
+    right_casts: dict[str, DataType] = {}
+
+    def _walk(node: expr.Expr) -> None:
+        if isinstance(node, expr.BinOp) and node.op in _BINOPS:
+            left_expr, right_expr = node.children
+            if isinstance(left_expr, expr.ColRef) and isinstance(
+                right_expr, expr.ColRef
+            ):
+                _align_decimal_binop_types(
+                    left_expr, right_expr, left_casts, right_casts
+                )
+        for child in node.children:
+            _walk(child)
+
+    _walk(predicate)
+    return left_casts, right_casts
+
+
+def _apply_casts(df: DataFrame, casts: dict[str, DataType]) -> DataFrame:
+    if not casts:
+        return df
+
+    columns = []
+    for col in df.columns:
+        target = casts.get(col.name)
+        if target is None:
+            columns.append(Column(col.obj, dtype=col.dtype, name=col.name))
+        else:
+            casted = col.astype(target)
+            columns.append(Column(casted.obj, dtype=casted.dtype, name=col.name))
+    return DataFrame(columns)
+
+
 class ConditionalJoin(IR):
     """A conditional inner join of two dataframes on a predicate."""
 
@@ -1633,6 +1750,7 @@ def __init__(
         self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
     ) -> None:
         self.schema = schema
+        predicate = _strip_predicate_casts(predicate)
        self.predicate = predicate
         # options[0] is a tuple[str, Operator, ...]
         # The Operator class can't be pickled, but we don't use it anyway so
@@ -1665,11 +1783,14 @@ def do_evaluate(
         right: DataFrame,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
+        left_casts, right_casts = _collect_decimal_binop_casts(
+            predicate_wrapper.predicate
+        )
         _, _, zlice, suffix, _, _ = options
         lg, rg = plc.join.conditional_inner_join(
-            left.table,
-            right.table,
+            _apply_casts(left, left_casts).table,
+            _apply_casts(right, right_casts).table,
             predicate_wrapper.ast,
         )
         left = DataFrame.from_table(
diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py
index f533e958ae8..710f576c66d 100644
--- a/python/cudf_polars/cudf_polars/dsl/translate.py
+++ b/python/cudf_polars/cudf_polars/dsl/translate.py
@@ -333,6 +333,32 @@ def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR:
     return rewrite_groupby(node, schema, keys, original_aggs, inp)
 
 
+_DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128}
+
+
+def _align_decimal_scales(
+    left: expr.Expr, right: expr.Expr
+) -> tuple[expr.Expr, expr.Expr]:
+    left_type, right_type = left.dtype, right.dtype
+
+    if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point(
+        right_type.plc_type
+    ):
+        target = DataType.common_decimal_dtype(left_type, right_type)
+
+        if (
+            left_type.id() != target.id() or left_type.scale() != target.scale()
+        ):  # pragma: no cover; no test yet
+            left = expr.Cast(target, left)
+
+        if (
+            right_type.id() != target.id() or right_type.scale() != target.scale()
+        ):  # pragma: no cover; no test yet
+            right = expr.Cast(target, right)
+
+    return left, right
+
+
 @_translate_ir.register
 def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
     # Join key dtypes are dependent on the schema of the left and
@@ -388,22 +414,24 @@ def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR:
                 expr.BinOp(
                     dtype,
                     expr.BinOp._MAPPING[op],
-                    insert_colrefs(
-                        left.value,
-                        table_ref=plc.expressions.TableReference.LEFT,
-                        name_to_index={
-                            name: i for i, name in enumerate(inp_left.schema)
-                        },
-                    ),
-                    insert_colrefs(
-                        right.value,
-                        table_ref=plc.expressions.TableReference.RIGHT,
-                        name_to_index={
-                            name: i for i, name in enumerate(inp_right.schema)
-                        },
+                    *_align_decimal_scales(
+                        insert_colrefs(
+                            left_ne.value,
+                            table_ref=plc.expressions.TableReference.LEFT,
+                            name_to_index={
+                                name: i for i, name in enumerate(inp_left.schema)
+                            },
+                        ),
+                        insert_colrefs(
+                            right_ne.value,
+                            table_ref=plc.expressions.TableReference.RIGHT,
+                            name_to_index={
+                                name: i for i, name in enumerate(inp_right.schema)
+                            },
+                        ),
                     ),
                 )
-                for op, left, right in zip(ops, left_on, right_on, strict=True)
+                for op, left_ne, right_ne in zip(ops, left_on, right_on, strict=True)
             ),
         )
 
@@ -889,13 +917,21 @@ def _(
 def _(
     node: pl_expr.Agg, translator: Translator, dtype: DataType, schema: Schema
 ) -> expr.Expr:
-    value = expr.Agg(
-        dtype,
-        node.name,
-        node.options,
-        *(translator.translate_expr(n=n, schema=schema) for n in node.arguments),
-    )
-    if value.name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
+    agg_name = node.name
+    args = [translator.translate_expr(n=arg, schema=schema) for arg in node.arguments]
+
+    if agg_name not in ("count", "n_unique", "mean", "median", "quantile"):
+        args = [
+            expr.Cast(dtype, arg)
+            if plc.traits.is_fixed_point(arg.dtype.plc_type)
+            and arg.dtype.plc_type != dtype.plc_type
+            else arg
+            for arg in args
+        ]
+
+    value = expr.Agg(dtype, agg_name, node.options, *args)
+
+    if agg_name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
         return expr.Cast(value.dtype, value)
     return value
 
diff --git a/python/cudf_polars/tests/containers/test_datatype.py b/python/cudf_polars/tests/containers/test_datatype.py
index ff6120edf3b..3eb50aeefbb 100644
--- a/python/cudf_polars/tests/containers/test_datatype.py
+++ b/python/cudf_polars/tests/containers/test_datatype.py
@@ -46,3 +46,11 @@ def test_repr():
 )
 def test_children(dtype, expected):
     assert DataType(dtype).children == expected
+
+
+def test_common_decimal_type_raises():
+    with pytest.raises(ValueError, match="Both inputs required to be decimal types."):
+        DataType.common_decimal_dtype(
+            DataType(pl.Float64()),
+            DataType(pl.Float64()),
+        )
diff --git a/python/cudf_polars/tests/expressions/test_agg.py b/python/cudf_polars/tests/expressions/test_agg.py
index 9bcd4bd5c4b..2792f6f9d73 100644
--- a/python/cudf_polars/tests/expressions/test_agg.py
+++ b/python/cudf_polars/tests/expressions/test_agg.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
+from decimal import Decimal
+
 import pytest
 
 import polars as pl
@@ -68,6 +70,19 @@ def df(dtype, with_nulls, is_sorted):
     return df
 
 
+@pytest.fixture
+def decimal_df() -> pl.LazyFrame:
+    return pl.LazyFrame(
+        {
+            "a": pl.Series(
+                "a",
+                [Decimal("0.10"), Decimal("1.10"), Decimal("100.10")],
+                dtype=pl.Decimal(precision=9, scale=2),
+            ),
+        }
+    )
+
+
 def test_agg(df, agg):
     expr = getattr(pl.col("a"), agg)()
     q = df.select(expr)
@@ -162,3 +177,14 @@ def test_implode_agg_unsupported():
     )
     q = df.select(pl.col("b").implode())
     assert_ir_translation_raises(q, NotImplementedError)
+
+
+def test_decimal_aggs(decimal_df: pl.LazyFrame) -> None:
+    q = decimal_df.with_columns(
+        sum=pl.col("a").sum(),
+        min=pl.col("a").min(),
+        max=pl.col("a").max(),
+        mean=pl.col("a").mean(),
+        median=pl.col("a").median(),
+    )
+    assert_gpu_result_equal(q)
diff --git a/python/cudf_polars/tests/test_join.py b/python/cudf_polars/tests/test_join.py
index 7cc69a91aae..078abae9dab 100644
--- a/python/cudf_polars/tests/test_join.py
+++ b/python/cudf_polars/tests/test_join.py
@@ -2,6 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
+from decimal import Decimal
+
 import pytest
 
 import polars as pl
@@ -249,3 +251,64 @@ def test_join_maintain_order_with_slice(left, right, maintain_order, how, zlice)
         if POLARS_VERSION_LT_130
         else {"optimizations": pl.QueryOptFlags(slice_pushdown=False)},
     )
+
+
+@pytest.mark.parametrize(
+    "expr",
+    [
+        pl.col("foo") > pl.col("bar"),
+        pl.col("foo") >= pl.col("bar"),
+        pl.col("foo") < pl.col("bar"),
+        pl.col("foo") <= pl.col("bar"),
+        pl.col("foo") == pl.col("bar"),
+        pytest.param(
+            pl.col("foo") != pl.col("bar"),
+            marks=pytest.mark.xfail(reason="nested loop join"),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "left_dtype,right_dtype",
+    [
+        (pl.Decimal(15, 2), pl.Decimal(15, 2)),
+        (pl.Decimal(15, 4), pl.Decimal(15, 2)),
+        (pl.Decimal(15, 2), pl.Decimal(15, 4)),
+        (pl.Decimal(15, 2), pl.Float32),
+        (pl.Decimal(15, 2), pl.Float64),
+    ],
+)
+def test_cross_join_filter_with_decimals(request, expr, left_dtype, right_dtype):
+    request.applymarker(
+        pytest.mark.xfail(
+            POLARS_VERSION_LT_132
+            and isinstance(left_dtype, pl.Decimal)
+            and isinstance(right_dtype, pl.Decimal)
+            and "==" in repr(expr),
+            reason="Hash Inner Join between i128 and i128",
+        )
+    )
+    left = pl.LazyFrame(
+        {
+            "foo": [Decimal("1.00"), Decimal("2.50"), Decimal("3.00")],
+            "foo1": [10, 20, 30],
+        },
+        schema={"foo": left_dtype, "foo1": pl.Int64},
+    )
+
+    if isinstance(right_dtype, pl.Decimal):
+        right = pl.LazyFrame(
+            {
+                "bar": [Decimal("2").scaleb(-right_dtype.scale)],
+                "foo1": ["x"],
+            },
+            schema={"bar": right_dtype, "foo1": pl.String},
+        )
+    else:
+        right = pl.LazyFrame(
+            {"bar": [2.0], "foo1": ["x"]},
+            schema={"bar": right_dtype, "foo1": pl.String},
+        )
+
+    q = left.join(right, how="cross").filter(expr)
+
+    assert_gpu_result_equal(q, check_row_order=False)
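
Reviewer note (not part of the patch): a minimal sketch of how the new
DataType.common_decimal_dtype added in datatype.py is expected to behave, assuming
pylibcudf's convention that decimal scales are stored as negative exponents (so
pl.Decimal(15, 2) maps to a plc scale of -2). The example values and the final
assertion are illustrative only, not taken from the patch's tests.

    import polars as pl

    from cudf_polars.containers import DataType

    left = DataType(pl.Decimal(15, 2))   # assumed plc scale: -2
    right = DataType(pl.Decimal(15, 4))  # assumed plc scale: -4

    # abs(min(-2, -4)) == 4: keep the finer fractional precision and widen
    # the total precision to 38, matching common_decimal_dtype in this patch.
    common = DataType.common_decimal_dtype(left, right)
    assert common.polars_type == pl.Decimal(38, 4)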