Skip to content

Commit f70ded5

Browse files
authored
fix: return ALL constants in EquivalenceProperties::constants (#17404)
* test: regression test for #17372 * test: add more direct regression for #17372 * fix: return ALL constants in `EquivalenceProperties::constants`
1 parent 3b3a5fe commit f70ded5

File tree

3 files changed

+142
-7
lines changed

3 files changed

+142
-7
lines changed

datafusion/core/tests/physical_optimizer/sanity_checker.rs

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,17 @@ use std::sync::Arc;
2020

2121
use crate::physical_optimizer::test_utils::{
2222
bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec,
23-
repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec,
23+
projection_exec, repartition_exec, sort_exec, sort_expr, sort_expr_options,
24+
sort_merge_join_exec, sort_preserving_merge_exec, union_exec,
2425
};
2526

2627
use arrow::compute::SortOptions;
2728
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
2829
use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable};
2930
use datafusion::prelude::{CsvReadOptions, SessionContext};
3031
use datafusion_common::config::ConfigOptions;
31-
use datafusion_common::{JoinType, Result};
32-
use datafusion_physical_expr::expressions::col;
32+
use datafusion_common::{JoinType, Result, ScalarValue};
33+
use datafusion_physical_expr::expressions::{col, Literal};
3334
use datafusion_physical_expr::Partitioning;
3435
use datafusion_physical_expr_common::sort_expr::LexOrdering;
3536
use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan;
@@ -665,3 +666,77 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> {
665666
assert_sanity_check(&smj, false);
666667
Ok(())
667668
}
669+
670+
/// A particular edge case.
671+
///
672+
/// See <https://github.com/apache/datafusion/issues/17372>.
673+
#[tokio::test]
674+
async fn test_union_with_sorts_and_constants() -> Result<()> {
675+
let schema_in = create_test_schema2();
676+
677+
let proj_exprs_1 = vec![
678+
(
679+
Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _,
680+
"const_1".to_owned(),
681+
),
682+
(
683+
Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _,
684+
"const_2".to_owned(),
685+
),
686+
(col("a", &schema_in).unwrap(), "a".to_owned()),
687+
];
688+
let proj_exprs_2 = vec![
689+
(
690+
Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _,
691+
"const_1".to_owned(),
692+
),
693+
(
694+
Arc::new(Literal::new(ScalarValue::Utf8(Some("bar".to_owned())))) as _,
695+
"const_2".to_owned(),
696+
),
697+
(col("a", &schema_in).unwrap(), "a".to_owned()),
698+
];
699+
700+
let source_1 = memory_exec(&schema_in);
701+
let source_1 = projection_exec(proj_exprs_1.clone(), source_1).unwrap();
702+
let schema_sources = source_1.schema();
703+
let ordering_sources: LexOrdering =
704+
[sort_expr("a", &schema_sources).nulls_last()].into();
705+
let source_1 = sort_exec(ordering_sources.clone(), source_1);
706+
707+
let source_2 = memory_exec(&schema_in);
708+
let source_2 = projection_exec(proj_exprs_2, source_2).unwrap();
709+
let source_2 = sort_exec(ordering_sources.clone(), source_2);
710+
711+
let plan = union_exec(vec![source_1, source_2]);
712+
713+
let schema_out = plan.schema();
714+
let ordering_out: LexOrdering = [
715+
sort_expr("const_1", &schema_out).nulls_last(),
716+
sort_expr("const_2", &schema_out).nulls_last(),
717+
sort_expr("a", &schema_out).nulls_last(),
718+
]
719+
.into();
720+
721+
let plan = sort_preserving_merge_exec(ordering_out, plan);
722+
723+
let plan_str = displayable(plan.as_ref()).indent(true).to_string();
724+
let plan_str = plan_str.trim();
725+
assert_snapshot!(
726+
plan_str,
727+
@r"
728+
SortPreservingMergeExec: [const_1@0 ASC NULLS LAST, const_2@1 ASC NULLS LAST, a@2 ASC NULLS LAST]
729+
UnionExec
730+
SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false]
731+
ProjectionExec: expr=[foo as const_1, foo as const_2, a@0 as a]
732+
DataSourceExec: partitions=1, partition_sizes=[0]
733+
SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false]
734+
ProjectionExec: expr=[foo as const_1, bar as const_2, a@0 as a]
735+
DataSourceExec: partitions=1, partition_sizes=[0]
736+
"
737+
);
738+
739+
assert_sanity_check(&plan, true);
740+
741+
Ok(())
742+
}

datafusion/physical-expr/src/equivalence/properties/mod.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,11 @@ impl EquivalenceProperties {
255255
pub fn constants(&self) -> Vec<ConstExpr> {
256256
self.eq_group
257257
.iter()
258-
.filter_map(|c| {
259-
c.constant.as_ref().and_then(|across| {
260-
c.canonical_expr()
261-
.map(|expr| ConstExpr::new(Arc::clone(expr), across.clone()))
258+
.flat_map(|c| {
259+
c.iter().filter_map(|expr| {
260+
c.constant
261+
.as_ref()
262+
.map(|across| ConstExpr::new(Arc::clone(expr), across.clone()))
262263
})
263264
})
264265
.collect()

datafusion/physical-expr/src/equivalence/properties/union.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,4 +921,63 @@ mod tests {
921921
.collect::<Vec<_>>(),
922922
))
923923
}
924+
925+
#[test]
926+
fn test_constants_share_values() -> Result<()> {
927+
let schema = Arc::new(Schema::new(vec![
928+
Field::new("const_1", DataType::Utf8, false),
929+
Field::new("const_2", DataType::Utf8, false),
930+
]));
931+
932+
let col_const_1 = col("const_1", &schema)?;
933+
let col_const_2 = col("const_2", &schema)?;
934+
935+
let literal_foo = ScalarValue::Utf8(Some("foo".to_owned()));
936+
let literal_bar = ScalarValue::Utf8(Some("bar".to_owned()));
937+
938+
let const_expr_1_foo = ConstExpr::new(
939+
Arc::clone(&col_const_1),
940+
AcrossPartitions::Uniform(Some(literal_foo.clone())),
941+
);
942+
let const_expr_2_foo = ConstExpr::new(
943+
Arc::clone(&col_const_2),
944+
AcrossPartitions::Uniform(Some(literal_foo.clone())),
945+
);
946+
let const_expr_2_bar = ConstExpr::new(
947+
Arc::clone(&col_const_2),
948+
AcrossPartitions::Uniform(Some(literal_bar.clone())),
949+
);
950+
951+
let mut input1 = EquivalenceProperties::new(Arc::clone(&schema));
952+
let mut input2 = EquivalenceProperties::new(Arc::clone(&schema));
953+
954+
// | Input | Const_1 | Const_2 |
955+
// | ----- | ------- | ------- |
956+
// | 1 | foo | foo |
957+
// | 2 | foo | bar |
958+
input1.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_foo.clone()])?;
959+
input2.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_bar.clone()])?;
960+
961+
// Calculate union properties
962+
let union_props = calculate_union(vec![input1, input2], schema)?;
963+
964+
// This should result in:
965+
// const_1 = Uniform("foo")
966+
// const_2 = Heterogeneous
967+
assert_eq!(union_props.constants().len(), 2);
968+
let union_const_1 = &union_props.constants()[0];
969+
assert!(union_const_1.expr.eq(&col_const_1));
970+
assert_eq!(
971+
union_const_1.across_partitions,
972+
AcrossPartitions::Uniform(Some(literal_foo)),
973+
);
974+
let union_const_2 = &union_props.constants()[1];
975+
assert!(union_const_2.expr.eq(&col_const_2));
976+
assert_eq!(
977+
union_const_2.across_partitions,
978+
AcrossPartitions::Heterogeneous,
979+
);
980+
981+
Ok(())
982+
}
924983
}

0 commit comments

Comments
 (0)