-
Notifications
You must be signed in to change notification settings - Fork 978
Align decimal dtypes in predicate before conditional join #20060
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d64eca5
42e5e43
2d0f114
d9e6a96
7b35f1c
7b768fe
f83d127
a753f4f
1cb19ad
67fcd7e
6647326
d559636
dd2286b
51ea86f
8ff793a
d152f5c
dcdb068
99b5ba2
e811eec
e5452fb
d8b24c9
b9eda7d
4b9df48
66b334c
623ab7c
176654a
9f83cb4
3752341
555c8ee
6136d0d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -333,6 +333,32 @@ def _(node: pl_ir.GroupBy, translator: Translator, schema: Schema) -> ir.IR: | |
return rewrite_groupby(node, schema, keys, original_aggs, inp) | ||
|
||
|
||
_DECIMAL_TYPES = {plc.TypeId.DECIMAL32, plc.TypeId.DECIMAL64, plc.TypeId.DECIMAL128} | ||
|
||
|
||
def _align_decimal_scales( | ||
left: expr.Expr, right: expr.Expr | ||
) -> tuple[expr.Expr, expr.Expr]: | ||
left_type, right_type = left.dtype, right.dtype | ||
|
||
if plc.traits.is_fixed_point(left_type.plc_type) and plc.traits.is_fixed_point( | ||
right_type.plc_type | ||
): | ||
target = DataType.common_decimal_dtype(left_type, right_type) | ||
|
||
if ( | ||
left_type.id() != target.id() or left_type.scale() != target.scale() | ||
): # pragma: no cover; no test yet | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the reviewer: We need this cast for Q11, to work. But I haven't been able to reproduce the failure outside of Q11 (in an actual test). I'm planning on leaving this for now and following up with a test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. xref #20213 |
||
left = expr.Cast(target, left) | ||
|
||
if ( | ||
right_type.id() != target.id() or right_type.scale() != target.scale() | ||
): # pragma: no cover; no test yet | ||
right = expr.Cast(target, right) | ||
|
||
return left, right | ||
|
||
|
||
@_translate_ir.register | ||
def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR: | ||
# Join key dtypes are dependent on the schema of the left and | ||
|
@@ -388,22 +414,24 @@ def _(node: pl_ir.Join, translator: Translator, schema: Schema) -> ir.IR: | |
expr.BinOp( | ||
dtype, | ||
expr.BinOp._MAPPING[op], | ||
insert_colrefs( | ||
left.value, | ||
table_ref=plc.expressions.TableReference.LEFT, | ||
name_to_index={ | ||
name: i for i, name in enumerate(inp_left.schema) | ||
}, | ||
), | ||
insert_colrefs( | ||
right.value, | ||
table_ref=plc.expressions.TableReference.RIGHT, | ||
name_to_index={ | ||
name: i for i, name in enumerate(inp_right.schema) | ||
}, | ||
*_align_decimal_scales( | ||
insert_colrefs( | ||
left_ne.value, | ||
table_ref=plc.expressions.TableReference.LEFT, | ||
name_to_index={ | ||
name: i for i, name in enumerate(inp_left.schema) | ||
}, | ||
), | ||
insert_colrefs( | ||
right_ne.value, | ||
table_ref=plc.expressions.TableReference.RIGHT, | ||
name_to_index={ | ||
name: i for i, name in enumerate(inp_right.schema) | ||
}, | ||
), | ||
), | ||
) | ||
for op, left, right in zip(ops, left_on, right_on, strict=True) | ||
for op, left_ne, right_ne in zip(ops, left_on, right_on, strict=True) | ||
), | ||
) | ||
|
||
|
@@ -889,13 +917,21 @@ def _( | |
def _( | ||
node: pl_expr.Agg, translator: Translator, dtype: DataType, schema: Schema | ||
) -> expr.Expr: | ||
value = expr.Agg( | ||
dtype, | ||
node.name, | ||
node.options, | ||
*(translator.translate_expr(n=n, schema=schema) for n in node.arguments), | ||
) | ||
if value.name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32: | ||
agg_name = node.name | ||
args = [translator.translate_expr(n=arg, schema=schema) for arg in node.arguments] | ||
|
||
if agg_name not in ("count", "n_unique", "mean", "median", "quantile"): | ||
args = [ | ||
expr.Cast(dtype, arg) | ||
if plc.traits.is_fixed_point(arg.dtype.plc_type) | ||
and arg.dtype.plc_type != dtype.plc_type | ||
else arg | ||
for arg in args | ||
] | ||
|
||
value = expr.Agg(dtype, agg_name, node.options, *args) | ||
|
||
if agg_name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32: | ||
return expr.Cast(value.dtype, value) | ||
return value | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.