diff --git a/Cargo.lock b/Cargo.lock index 8ffdb8c6403c..1012587c0c27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4235,12 +4235,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -4434,12 +4433,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "owo-colors" version = "4.2.1" @@ -6675,9 +6668,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "nu-ansi-term", "sharded-slab", diff --git a/Cargo.toml b/Cargo.toml index 601d11f12dd8..fe6667b7a8b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,7 +99,6 @@ arrow-flight = { version = "55.2.0", features = [ ] } arrow-ipc = { version = "55.2.0", default-features = false, features = [ "lz4", - "zstd", ] } arrow-ord = { version = "55.2.0", default-features = false } arrow-schema = { version = "55.2.0", default-features = false } @@ -198,6 +197,7 @@ rpath = false strip = false # Retain debug info for flamegraphs [profile.ci] +debug = false inherits = "dev" incremental = false diff --git a/ci.test b/ci.test new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index f7316ddc1bec..e0fab7ee9f31 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -199,6 +199,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { order_by: window_function.params.order_by, window_frame: window_function.params.window_frame, null_treatment: window_function.params.null_treatment, + distinct: window_function.params.distinct, }, })) }; diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index c4455e271c84..1a6a66923e55 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -47,6 +47,7 @@ compression = [ "bzip2", "flate2", "zstd", + "arrow-ipc/zstd", "datafusion-datasource/compression", ] crypto_expressions = ["datafusion-functions/crypto_expressions"] diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ab123dcceada..df24c19f7841 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1646,6 +1646,7 @@ pub fn create_window_expr_with_name( order_by, window_frame, null_treatment, + distinct, }, } = window_fun.as_ref(); let physical_args = @@ -1674,6 +1675,7 @@ pub fn create_window_expr_with_name( window_frame, physical_schema, ignore_nulls, + *distinct, ) } other => plan_err!("Invalid window expression '{other:?}'"), diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 
316d3ba5a926..23e3281cf386 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -288,6 +288,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { Arc::new(window_frame), &extended_schema, false, + false, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![window_expr], @@ -660,6 +661,7 @@ async fn run_window_test( Arc::new(window_frame.clone()), &extended_schema, false, + false, )?], exec1, false, @@ -678,6 +680,7 @@ async fn run_window_test( Arc::new(window_frame.clone()), &extended_schema, false, + false, )?], exec2, search_mode.clone(), diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index fd847763124a..2dce87de00ed 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -23,7 +23,7 @@ use crate::physical_optimizer::test_utils::{ check_integrity, coalesce_partitions_exec, parquet_exec_with_sort, parquet_exec_with_stats, repartition_exec, schema, sort_exec, sort_exec_with_preserve_partitioning, sort_merge_join_exec, - sort_preserving_merge_exec, union_exec, + sort_preserving_merge_exec, trim_plan_display, union_exec, }; use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; @@ -39,10 +39,12 @@ use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::ScalarValue; +use datafusion_common::{assert_contains, ScalarValue}; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_expr::{JoinType, Operator}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{AggregateUDF, JoinType, Operator}; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ @@ -51,6 +53,7 @@ use datafusion_physical_expr_common::sort_expr::{ use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; +use datafusion_physical_optimizer::sanity_checker::check_plan_sanity; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -66,7 +69,7 @@ use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ - get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, + displayable, get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, Statistics, }; @@ -162,8 +165,8 @@ impl ExecutionPlan for SortRequiredExec { fn execute( &self, _partition: usize, - _context: Arc, - ) -> Result { + _context: Arc, + ) -> Result { unreachable!(); } @@ -237,7 +240,7 @@ fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc, alias_pairs: Vec<(String, String)>, ) -> Arc { @@ -251,6 +254,15 @@ fn 
projection_exec_with_alias( fn aggregate_exec_with_alias( input: Arc<dyn ExecutionPlan>, alias_pairs: Vec<(String, String)>, +) -> Arc<dyn ExecutionPlan> { + aggregate_exec_with_aggr_expr_and_alias(input, vec![], alias_pairs) +} + +#[expect(clippy::type_complexity)] +fn aggregate_exec_with_aggr_expr_and_alias( + input: Arc<dyn ExecutionPlan>, + aggr_expr: Vec<(Arc<AggregateUDF>, Vec<Arc<dyn PhysicalExpr>>)>, + alias_pairs: Vec<(String, String)>, ) -> Arc<dyn ExecutionPlan> { let schema = schema(); let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![]; @@ -271,18 +283,31 @@ fn aggregate_exec_with_alias( .collect::<Vec<_>>(); let final_grouping = PhysicalGroupBy::new_single(final_group_by_expr); + let aggr_expr = aggr_expr + .into_iter() + .map(|(udaf, exprs)| { + AggregateExprBuilder::new(udaf.clone(), exprs) + .alias(udaf.name()) + .schema(Arc::clone(&schema)) + .build() + .map(Arc::new) + .unwrap() + }) + .collect::<Vec<_>>(); + let filter_exprs = std::iter::repeat_n(None, aggr_expr.len()).collect::<Vec<_>>(); + Arc::new( AggregateExec::try_new( AggregateMode::FinalPartitioned, final_grouping, - vec![], - vec![], + aggr_expr.clone(), + filter_exprs.clone(), Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, - vec![], - vec![], + aggr_expr, + filter_exprs, input, schema.clone(), ) @@ -439,6 +464,12 @@ impl TestConfig { self } + /// Set batch size. + fn with_batch_size(mut self, batch_size: usize) -> Self { + self.config.execution.batch_size = batch_size; + self + } + /// Perform a series of runs using the current [`TestConfig`], /// assert the expected plan result, /// and return the result plan (for potentional subsequent runs). @@ -2027,6 +2058,285 @@ fn repartition_ignores_union() -> Result<()> { Ok(()) } +fn aggregate_over_union(input: Vec<Arc<dyn ExecutionPlan>>) -> Arc<dyn ExecutionPlan> { + let union = union_exec(input); + let plan = + aggregate_exec_with_alias(union, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +// Aggregate over a union, +// with current testing setup. +// +// It will repartition twice for an aggregate over a union. +// * repartitions before the partial aggregate. +// * repartitions before the final aggregation. +#[test] +fn repartitions_twice_for_aggregate_after_union() -> Result<()> { + let plan = aggregate_over_union(vec![parquet_exec(); 2]); + + // We get a distribution error without repartitioning. + let err = check_plan_sanity(plan.clone(), &Default::default()).unwrap_err(); + assert_contains!( + err.message(), + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet\"] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" + ); + + // Updated plan (post optimization) will have added RepartitionExecs (btwn union and aggregation).
+ let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default(); + test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +// Aggregate over a union, +// but make the test setup more realistic. +// +// It will repartition once for an aggregate over a union. +// * repartitions btwn partial & final aggregations. +#[test] +fn repartitions_once_for_aggregate_after_union() -> Result<()> { + // use parquet exec with stats + let plan: Arc<dyn ExecutionPlan> = + aggregate_over_union(vec![parquet_exec_with_stats(10000); 2]); + + // We get a distribution error without repartitioning. + let err = check_plan_sanity(plan.clone(), &Default::default()).unwrap_err(); + assert_contains!( + err.message(), + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet\"] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" + ); + + // This removes the forced round-robin repartitioning, + // by no longer hard-coding batch_size=1. + // + // Updated plan (post optimization) will have added only 1 RepartitionExec. + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +/// Same as [`aggregate_over_union`], but with a sort btwn the union and aggregation. +fn aggregate_over_sorted_union( + input: Vec<Arc<dyn ExecutionPlan>>, +) -> Arc<dyn ExecutionPlan> { + let union = union_exec(input); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let sort = sort_exec(sort_key, union); + let plan = aggregate_exec_with_alias(sort, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + // Notice the `ordering_mode=Sorted` on the aggregations.
+ let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +/// Same as [`repartitions_once_for_aggregate_after_union`], but adds a sort btwn +/// the union and the aggregate. This changes the outcome: +/// +/// * we no longer get a distribution error. +/// * but we still get repartitioning? +#[test] +fn repartitions_for_aggregate_after_sorted_union() -> Result<()> { + let plan = aggregate_over_sorted_union(vec![parquet_exec_with_stats(10000); 2]); + + // With the sort, there is no distribution error. + let checker = check_plan_sanity(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // It does not repartition on the first run + let expected_after_first_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " UnionExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run( + expected_after_first_run, + plan.clone(), + &DISTRIB_DISTRIB_SORT, + )?; + + // But does repartition on the second run. + let expected_after_second_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " UnionExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected_after_second_run, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +/// Same as [`aggregate_over_sorted_union`], but with a sort btwn the union and aggregation. 
+fn aggregate_over_sorted_union_projection( + input: Vec<Arc<dyn ExecutionPlan>>, +) -> Arc<dyn ExecutionPlan> { + let union = union_exec(input); + let union_projection = projection_exec_with_alias( + union, + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "value".to_string()), + ], + ); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let sort = sort_exec(sort_key, union_projection); + let plan = aggregate_exec_with_alias(sort, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + // Notice the `ordering_mode=Sorted` on the aggregations. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + "ProjectionExec: expr=[a@0 as a, b@1 as value]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +/// Same as [`repartitions_for_aggregate_after_sorted_union`], but adds a projection +/// as well between the union and aggregate. This changes the outcome: +/// +/// * we no longer get repartitioning, and instead get coalescing. +#[test] +fn coalesces_for_aggregate_after_sorted_union_projection() -> Result<()> { + let plan = + aggregate_over_sorted_union_projection(vec![parquet_exec_with_stats(10000); 2]); + + // Same as `repartitions_for_aggregate_after_sorted_union`. No error. + let checker = check_plan_sanity(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // It no longer does a repartition on the first run. + // Instead it adds an SPM. + let expected_after_first_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run( + expected_after_first_run, + plan.clone(), + &DISTRIB_DISTRIB_SORT, + )?; + + // Then it removes the SPM, and inserts a coalesce on the second run.
+ let expected_after_second_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected_after_second_run, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + #[test] fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index e31a30cc0883..ef29d51e5d37 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -17,15 +17,16 @@ use std::sync::Arc; +use crate::physical_optimizer::enforce_distribution::projection_exec_with_alias; use crate::physical_optimizer::test_utils::{ aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, - projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, - sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, - sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, - union_exec, RequirementsTestExec, + parquet_exec_with_stats, projection_exec, repartition_exec, schema, sort_exec, + sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, + spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, }; use arrow::compute::SortOptions; @@ -47,6 +48,9 @@ use datafusion_physical_expr_common::sort_expr::{ }; use datafusion_physical_expr::{Distribution, Partitioning}; use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, NotExpr}; +use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; +use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -2292,6 +2296,93 @@ async fn test_commutativity() -> Result<()> { Ok(()) } +fn single_partition_aggregate( + input: Arc, + alias_pairs: Vec<(String, String)>, +) -> Arc { + let schema = schema(); + let group_by = alias_pairs + .iter() + .map(|(column, alias)| (col(column, &input.schema()).unwrap(), alias.to_string())) + .collect::>(); + let 
group_by = PhysicalGroupBy::new_single(group_by); + + Arc::new( + AggregateExec::try_new( + AggregateMode::SinglePartitioned, + group_by, + vec![], + vec![], + input, + schema, + ) + .unwrap(), + ) +} + +#[tokio::test] +async fn test_preserve_needed_coalesce() -> Result<()> { + // Input to EnforceSorting, from our test case. + let plan = projection_exec_with_alias( + union_exec(vec![parquet_exec_with_stats(10000); 2]), + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "value".to_string()), + ], + ); + let plan = Arc::new(CoalescePartitionsExec::new(plan)); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let plan: Arc = + single_partition_aggregate(plan, vec![("a".to_string(), "a1".to_string())]); + let plan = sort_exec(sort_key, plan); + + // Starting plan: as in our test case. + assert_eq!( + get_plan_string(&plan), + vec![ + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " AggregateExec: mode=SinglePartitioned, gby=[a@0 as a1], aggr=[]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + let checker = SanityCheckPlan::new().optimize(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // EnforceSorting will remove the coalesce, and add an SPM further up (above the aggregate). + let optimizer = EnforceSorting::new(); + let optimized = optimizer.optimize(plan, &Default::default())?; + assert_eq!( + get_plan_string(&optimized), + vec![ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " AggregateExec: mode=SinglePartitioned, gby=[a@0 as a1], aggr=[]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + // Plan is valid. 
+ let checker = SanityCheckPlan::new(); + let checker = checker.optimize(optimized, &Default::default()); + assert!(checker.is_ok()); + + Ok(()) +} + #[tokio::test] async fn test_coalesce_propagate() -> Result<()> { let schema = create_test_schema()?; @@ -3675,6 +3766,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { case.window_frame, input_schema.as_ref(), false, + false, )?; let window_exec = if window_expr.uses_bounded_memory() { Arc::new(BoundedWindowAggExec::try_new( diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index 6233f5d09c56..ce6eb13c86c4 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -20,7 +20,8 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, - repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, + projection_exec, repartition_exec, sort_exec, sort_expr, sort_expr_options, + sort_merge_join_exec, sort_preserving_merge_exec, union_exec, }; use arrow::compute::SortOptions; @@ -28,8 +29,8 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{JoinType, Result}; -use datafusion_physical_expr::expressions::col; +use datafusion_common::{JoinType, Result, ScalarValue}; +use datafusion_physical_expr::expressions::{col, Literal}; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; @@ -665,3 +666,77 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { assert_sanity_check(&smj, false); Ok(()) } + +/// A particular edge case. +/// +/// See . 
+#[tokio::test] +async fn test_union_with_sorts_and_constants() -> Result<()> { + let schema_in = create_test_schema2(); + + let proj_exprs_1 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + let proj_exprs_2 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("bar".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + + let source_1 = memory_exec(&schema_in); + let source_1 = projection_exec(proj_exprs_1.clone(), source_1).unwrap(); + let schema_sources = source_1.schema(); + let ordering_sources: LexOrdering = + [sort_expr("a", &schema_sources).nulls_last()].into(); + let source_1 = sort_exec(ordering_sources.clone(), source_1); + + let source_2 = memory_exec(&schema_in); + let source_2 = projection_exec(proj_exprs_2, source_2).unwrap(); + let source_2 = sort_exec(ordering_sources.clone(), source_2); + + let plan = union_exec(vec![source_1, source_2]); + + let schema_out = plan.schema(); + let ordering_out: LexOrdering = [ + sort_expr("const_1", &schema_out).nulls_last(), + sort_expr("const_2", &schema_out).nulls_last(), + sort_expr("a", &schema_out).nulls_last(), + ] + .into(); + + let plan = sort_preserving_merge_exec(ordering_out, plan); + + let plan_str = displayable(plan.as_ref()).indent(true).to_string(); + let plan_str = plan_str.trim(); + assert_snapshot!( + plan_str, + @r" + SortPreservingMergeExec: [const_1@0 ASC NULLS LAST, const_2@1 ASC NULLS LAST, a@2 ASC NULLS LAST] + UnionExec + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, foo as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, bar as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + " + ); + + assert_sanity_check(&plan, true); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 7fb0f795f294..5e2d61e68f8d 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -265,6 +265,7 @@ pub fn bounded_window_exec_with_partition( Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), false, + false, ) .unwrap(); @@ -509,6 +510,13 @@ pub fn check_integrity(context: PlanContext) -> Result Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() +} + // construct a stream partition for test purposes #[derive(Debug)] pub struct TestStreamPartition { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0749ff0e98b7..efe8a639087a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1131,6 +1131,8 @@ pub struct WindowFunctionParams { pub window_frame: WindowFrame, /// Specifies how NULL value is treated: ignore or respect pub null_treatment: Option, + /// Distinct flag + pub distinct: bool, } impl WindowFunction { @@ -1145,6 +1147,7 @@ impl WindowFunction { order_by: Vec::default(), window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, }, } } @@ -2291,6 
+2294,7 @@ impl NormalizeEq for Expr { partition_by: self_partition_by, order_by: self_order_by, null_treatment: self_null_treatment, + distinct: self_distinct, }, } = left.as_ref(); let WindowFunction { @@ -2302,6 +2306,7 @@ impl NormalizeEq for Expr { partition_by: other_partition_by, order_by: other_order_by, null_treatment: other_null_treatment, + distinct: other_distinct, }, } = other.as_ref(); @@ -2325,6 +2330,7 @@ impl NormalizeEq for Expr { && a.nulls_first == b.nulls_first && a.expr.normalize_eq(&b.expr) }) + && self_distinct == other_distinct } ( Expr::Exists(Exists { @@ -2558,11 +2564,13 @@ impl HashNode for Expr { order_by: _, window_frame, null_treatment, + distinct, }, } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); + distinct.hash(state); } Expr::InList(InList { expr: _expr, @@ -2865,15 +2873,27 @@ impl Display for SchemaDisplay<'_> { order_by, window_frame, null_treatment, + distinct, } = params; + // Write function name and open parenthesis + write!(f, "{fun}(")?; + + // If DISTINCT, emit the keyword + if *distinct { + write!(f, "DISTINCT ")?; + } + + // Write the comma‑separated argument list write!( f, - "{}({})", - fun, + "{}", schema_name_from_exprs_comma_separated_without_space(args)? )?; + // **Close the argument parenthesis** + write!(f, ")")?; + if let Some(null_treatment) = null_treatment { write!(f, " {null_treatment}")?; } @@ -3260,9 +3280,10 @@ impl Display for Expr { order_by, window_frame, null_treatment, + distinct, } = params; - fmt_function(f, &fun.to_string(), false, args, true)?; + fmt_function(f, &fun.to_string(), *distinct, args, true)?; if let Some(nt) = null_treatment { write!(f, "{nt}")?; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c0351a9dcaca..fab86fe7663d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -945,6 +945,7 @@ impl ExprFuncBuilder { window_frame: window_frame .unwrap_or_else(|| WindowFrame::new(has_order_by)), null_treatment, + distinct, }, }) } diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 067c7a94279f..b04fe32d376e 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -308,6 +308,7 @@ pub struct RawWindowExpr { pub order_by: Vec, pub window_frame: WindowFrame, pub null_treatment: Option, + pub distinct: bool, } /// Result of planning a raw expr with [`ExprPlanner`] diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f953aec5a1e3..b6f583ca4c74 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -242,10 +242,22 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; (args, partition_by, order_by).map_elements(f)?.update_data( |(new_args, new_partition_by, new_order_by)| { + if distinct { + return Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build() + .unwrap(); + } + Expr::from(WindowFunction::new(fun, new_args)) .partition_by(new_partition_by) .order_by(new_order_by) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b6c8eb627c77..a0d2b6a96a48 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -459,7 +459,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { // exclude the first function argument(= column) in ordered set aggregate function, // 
because it is duplicated with the WITHIN GROUP clause in schema name. - let args = if self.is_ordered_set_aggregate() { + let args = if self.is_ordered_set_aggregate() && !order_by.is_empty() { &args[1..] } else { &args[..] @@ -554,14 +554,25 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { order_by, window_frame, null_treatment, + distinct, } = params; let mut schema_name = String::new(); - schema_name.write_fmt(format_args!( - "{}({})", - self.name(), - schema_name_from_exprs(args)? - ))?; + + // Inject DISTINCT into the schema name when requested + if *distinct { + schema_name.write_fmt(format_args!( + "{}(DISTINCT {})", + self.name(), + schema_name_from_exprs(args)? + ))?; + } else { + schema_name.write_fmt(format_args!( + "{}({})", + self.name(), + schema_name_from_exprs(args)? + ))?; + } if let Some(null_treatment) = null_treatment { schema_name.write_fmt(format_args!(" {null_treatment}"))?; @@ -579,7 +590,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { " ORDER BY [{}]", schema_name_from_sorts(order_by)? ))?; - }; + } schema_name.write_fmt(format_args!(" {window_frame}"))?; @@ -648,15 +659,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { order_by, window_frame, null_treatment, + distinct, } = params; let mut display_name = String::new(); - display_name.write_fmt(format_args!( - "{}({})", - self.name(), - expr_vec_fmt!(args) - ))?; + if *distinct { + display_name.write_fmt(format_args!( + "{}(DISTINCT {})", + self.name(), + expr_vec_fmt!(args) + ))?; + } else { + display_name.write_fmt(format_args!( + "{}({})", + self.name(), + expr_vec_fmt!(args) + ))?; + } if let Some(null_treatment) = null_treatment { display_name.write_fmt(format_args!(" {null_treatment}"))?; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 55c8c847ad0a..fce300e79bea 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -77,15 +77,38 @@ pub fn approx_percentile_cont( #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)", + syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; +-----------------------------------------------------------------------+ | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | +-----------------------------------------------------------------------+ | 65.0 | +-----------------------------------------------------------------------+ -```"#, +``` +An alternate syntax is also supported: +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | 
++-----------------------------------------------+ + +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` +"#, standard_argument(name = "expression",), argument( name = "percentile", @@ -313,7 +336,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { } if arg_types.len() == 3 && !arg_types[2].is_integer() { return plan_err!( - "approx_percentile_cont requires integer max_size input types" + "approx_percentile_cont requires integer centroids input types" ); } Ok(arg_types[0].clone()) @@ -360,6 +383,11 @@ impl ApproxPercentileAccumulator { } } + // public for approx_percentile_cont_with_weight + pub(crate) fn max_size(&self) -> usize { + self.digest.max_size() + } + // public for approx_percentile_cont_with_weight pub fn merge_digests(&mut self, digests: &[TDigest]) { let digests = digests.iter().chain(std::iter::once(&self.digest)); diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index ab847e838869..f70d751a8cb9 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -25,32 +25,53 @@ use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, -}; -use datafusion_functions_aggregate_common::tdigest::{ - Centroid, TDigest, DEFAULT_MAX_SIZE, + Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, }; +use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest}; use datafusion_macros::user_doc; use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; -make_udaf_expr_and_func!( +create_func!( ApproxPercentileContWithWeight, - approx_percentile_cont_with_weight, - expression weight percentile, - "Computes the approximate percentile continuous with weight of a set of numbers", approx_percentile_cont_with_weight_udaf ); +/// Computes the approximate percentile continuous with weight of a set of numbers +pub fn approx_percentile_cont_with_weight( + order_by: Sort, + weight: Expr, + percentile: Expr, + centroids: Option, +) -> Expr { + let expr = order_by.expr.clone(); + + let args = if let Some(centroids) = centroids { + vec![expr, weight, percentile, centroids] + } else { + vec![expr, weight, percentile] + }; + + Expr::AggregateFunction(AggregateFunction::new_udf( + approx_percentile_cont_with_weight_udaf(), + args, + false, + None, + vec![order_by], + None, + )) +} + /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.", - syntax_example = 
"approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)", + syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql > SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; +---------------------------------------------------------------------------------------------+ @@ -58,6 +79,22 @@ make_udaf_expr_and_func!( +---------------------------------------------------------------------------------------------+ | 78.5 | +---------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ +``` +An alternative syntax is also supported: + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++--------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), argument( @@ -67,6 +104,10 @@ make_udaf_expr_and_func!( argument( name = "percentile", description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)." + ), + argument( + name = "centroids", + description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] pub struct ApproxPercentileContWithWeight { @@ -91,21 +132,26 @@ impl Default for ApproxPercentileContWithWeight { impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. 
pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); + // Accept any numeric value paired with weight and float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + ])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + int.clone(), + ])); + } + } Self { - signature: Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Immutable, - ), + signature: Signature::one_of(variants, Immutable), approx_percentile_cont: ApproxPercentileCont::new(), } } @@ -138,6 +184,11 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { if arg_types[2] != DataType::Float64 { return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); } + if arg_types.len() == 4 && !arg_types[3].is_integer() { + return plan_err!( + "approx_percentile_cont_with_weight requires integer centroids input types" + ); + } Ok(arg_types[0].clone()) } @@ -148,17 +199,25 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { ); } - if acc_args.exprs.len() != 3 { + if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 { return plan_err!( - "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]" ); } let sub_args = AccumulatorArgs { - exprs: &[ - Arc::clone(&acc_args.exprs[0]), - Arc::clone(&acc_args.exprs[2]), - ], + exprs: if acc_args.exprs.len() == 4 { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + Arc::clone(&acc_args.exprs[3]), // centroids + ] + } else { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + ] + }, ..acc_args }; let approx_percentile_cont_accumulator = @@ -244,7 +303,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { let mut digests: Vec = vec![]; for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) { digests.push(TDigest::new_with_centroid( - DEFAULT_MAX_SIZE, + self.approx_percentile_cont_accumulator.max_size(), Centroid::new(*mean, *weight), )) } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 09904bbad6ec..7a7c2879aa79 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -31,7 +31,7 @@ use arrow::{ }; use datafusion_common::{ downcast_value, internal_err, not_impl_err, stats::Precision, - utils::expr::COUNT_STAR_EXPANSION, Result, ScalarValue, + utils::expr::COUNT_STAR_EXPANSION, HashMap, Result, ScalarValue, }; use datafusion_expr::{ expr::WindowFunction, @@ -59,6 +59,7 @@ use std::{ ops::BitAnd, sync::Arc, }; + make_udaf_expr_and_func!( Count, count, @@ -406,6 +407,98 @@ impl AggregateUDFImpl for Count { // the same as new values are seen. 
SetMonotonicity::Increasing } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + if args.is_distinct { + let acc = + SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?; + Ok(Box::new(acc)) + } else { + let acc = CountAccumulator::new(); + Ok(Box::new(acc)) + } + } +} + +// DistinctCountAccumulator does not support retract_batch and sliding window +// this is a specialized accumulator for distinct count that supports retract_batch +// and sliding window. +#[derive(Debug)] +pub struct SlidingDistinctCountAccumulator { + counts: HashMap, + data_type: DataType, +} + +impl SlidingDistinctCountAccumulator { + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + counts: HashMap::default(), + data_type: data_type.clone(), + }) + } +} + +impl Accumulator for SlidingDistinctCountAccumulator { + fn state(&mut self) -> Result> { + let keys = self.counts.keys().cloned().collect::>(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + keys.as_slice(), + &self.data_type, + ))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + if let Some(cnt) = self.counts.get_mut(&v) { + *cnt -= 1; + if *cnt == 0 { + self.counts.remove(&v); + } + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let list_arr = states[0].as_list::(); + for inner in list_arr.iter().flatten() { + for j in 0..inner.len() { + let v = ScalarValue::try_from_array(&*inner, j)?; + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.counts.len() as i64))) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + size_of_val(self) + } } #[derive(Debug)] @@ -878,4 +971,72 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn sliding_distinct_count_accumulator_basic() -> Result<()> { + // Basic update_batch + evaluate functionality + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // Create an Int32Array: [1, 2, 2, 3, null] + let values: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(2), + Some(3), + None, + ])); + acc.update_batch(&[values])?; + // Expect distinct values {1,2,3} → count = 3 + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_retract() -> Result<()> { + // Test that retract_batch properly decrements counts + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?; + // Initial batch: ["a", "b", "a"] + let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")])) + as ArrayRef; + acc.update_batch(&[arr1])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2))); // {"a","b"} + + // Retract batch: ["a", null, "b"] + let arr2 = + Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef; + acc.retract_batch(&[arr2])?; + // Before: a→2, b→1; after retract a→1, b→0 → b removed; remaining {"a"} + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1))); + Ok(()) + } + + #[test] + fn 
sliding_distinct_count_accumulator_merge_states() -> Result<()> { + // Test merging multiple accumulator states with merge_batch + let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // acc1 sees [1, 2] + acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?; + // acc2 sees [2, 3] + acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?; + // Extract their states as Vec + let state_sv1 = acc1.state()?; + let state_sv2 = acc2.state()?; + // Convert ScalarValue states into Vec, propagating errors + // NOTE we pass `1` because each ScalarValue.to_array produces a 1‑row ListArray + let state_arr1: Vec = state_sv1 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + let state_arr2: Vec = state_sv2 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + // Merge both states into a fresh accumulator + let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + merged.merge_batch(&state_arr1)?; + merged.merge_batch(&state_arr2)?; + // Expect distinct {1,2,3} → count = 3 + assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } } diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 091737bb9c15..5e3a6bc6336c 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -41,6 +41,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, } = raw_expr; let origin_expr = Expr::from(WindowFunction { @@ -51,6 +52,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }, }); @@ -68,6 +70,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; let raw_expr = RawWindowExpr { @@ -77,6 +80,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }; // TODO: remove the next line after `Expr::Wildcard` is removed @@ -93,18 +97,23 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, } = raw_expr; - let new_expr = Expr::from(WindowFunction::new( + let mut new_expr_before_build = Expr::from(WindowFunction::new( func_def, vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?; + .null_treatment(null_treatment); + if distinct { + new_expr_before_build = new_expr_before_build.distinct(); + } + + let new_expr = new_expr_before_build.build()?; let new_expr = saved_name.restore(new_expr); return Ok(PlannerResult::Planned(new_expr)); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a98b0fdcc3d3..e6fc006cb2ff 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -549,6 +549,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; let window_frame = @@ -565,14 +566,26 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { _ => args, }; - Ok(Transformed::yes( - Expr::from(WindowFunction::new(fun, args)) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?, - )) + if distinct { + 
Ok(Transformed::yes( + Expr::from(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build()?, + )) + } else { + Ok(Transformed::yes( + Expr::from(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, + )) + } } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 6d18d34ca4de..fed3b78de801 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -255,10 +255,11 @@ impl EquivalenceProperties { pub fn constants(&self) -> Vec { self.eq_group .iter() - .filter_map(|c| { - c.constant.as_ref().and_then(|across| { - c.canonical_expr() - .map(|expr| ConstExpr::new(Arc::clone(expr), across.clone())) + .flat_map(|c| { + c.iter().filter_map(|expr| { + c.constant + .as_ref() + .map(|across| ConstExpr::new(Arc::clone(expr), across.clone())) }) }) .collect() diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index 4f44b9b0c9d4..8ec2464068ef 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -67,16 +67,43 @@ fn calculate_union_binary( }) .collect::>(); + // TEMP HACK WORKAROUND + // Revert code from https://github.com/apache/datafusion/pull/12562 + // Context: https://github.com/apache/datafusion/issues/13748 + // Context: https://github.com/influxdata/influxdb_iox/issues/13038 + // Next, calculate valid orderings for the union by searching for prefixes // in both sides. - let mut orderings = UnionEquivalentOrderingBuilder::new(); - orderings.add_satisfied_orderings(&lhs, &rhs)?; - orderings.add_satisfied_orderings(&rhs, &lhs)?; - let orderings = orderings.build(); + let mut orderings = vec![]; + for ordering in lhs.normalized_oeq_class().into_iter() { + let mut ordering: Vec = ordering.into(); + + // Progressively shorten the ordering to search for a satisfied prefix: + while !rhs.ordering_satisfy(ordering.clone())? { + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } + for ordering in rhs.normalized_oeq_class().into_iter() { + let mut ordering: Vec = ordering.into(); + + // Progressively shorten the ordering to search for a satisfied prefix: + while !lhs.ordering_satisfy(ordering.clone())? 
{ + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } let mut eq_properties = EquivalenceProperties::new(lhs.schema); eq_properties.add_constants(constants)?; eq_properties.add_orderings(orderings); + Ok(eq_properties) } @@ -122,6 +149,7 @@ struct UnionEquivalentOrderingBuilder { orderings: Vec, } +#[expect(unused)] impl UnionEquivalentOrderingBuilder { fn new() -> Self { Self { orderings: vec![] } @@ -504,6 +532,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) @@ -579,6 +608,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); @@ -607,6 +637,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_gap_fill_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) @@ -658,6 +689,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_middle_desc() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) @@ -921,4 +953,63 @@ mod tests { .collect::>(), )) } + + #[test] + fn test_constants_share_values() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("const_1", DataType::Utf8, false), + Field::new("const_2", DataType::Utf8, false), + ])); + + let col_const_1 = col("const_1", &schema)?; + let col_const_2 = col("const_2", &schema)?; + + let literal_foo = ScalarValue::Utf8(Some("foo".to_owned())); + let literal_bar = ScalarValue::Utf8(Some("bar".to_owned())); + + let const_expr_1_foo = ConstExpr::new( + Arc::clone(&col_const_1), + AcrossPartitions::Uniform(Some(literal_foo.clone())), + ); + let const_expr_2_foo = ConstExpr::new( + Arc::clone(&col_const_2), + AcrossPartitions::Uniform(Some(literal_foo.clone())), + ); + let const_expr_2_bar = ConstExpr::new( + Arc::clone(&col_const_2), + AcrossPartitions::Uniform(Some(literal_bar.clone())), + ); + + let mut input1 = EquivalenceProperties::new(Arc::clone(&schema)); + let mut input2 = EquivalenceProperties::new(Arc::clone(&schema)); + + // | Input | Const_1 | Const_2 | + // | ----- | ------- | ------- | + // | 1 | foo | foo | + // | 2 | foo | bar | + input1.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_foo.clone()])?; + input2.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_bar.clone()])?; + + // Calculate union properties + let union_props = calculate_union(vec![input1, input2], schema)?; + + // This should result in: + // const_1 = Uniform("foo") + // const_2 = Heterogeneous + assert_eq!(union_props.constants().len(), 2); + let union_const_1 = &union_props.constants()[0]; + assert!(union_const_1.expr.eq(&col_const_1)); + assert_eq!( + union_const_1.across_partitions, + AcrossPartitions::Uniform(Some(literal_foo)), + ); + let union_const_2 = &union_props.constants()[1]; + assert!(union_const_2.expr.eq(&col_const_2)); + assert_eq!( + union_const_2.across_partitions, + AcrossPartitions::Heterogeneous, + 
); + + Ok(()) + } } diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 8a71b28486a2..dae0edcfb171 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -48,8 +48,8 @@ use crate::enforce_sorting::sort_pushdown::{ }; use crate::output_requirements::OutputRequirementExec; use crate::utils::{ - add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, - is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, + add_sort_above, add_sort_above_with_check, is_aggregation, is_coalesce_partitions, + is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, }; use crate::PhysicalOptimizerRule; @@ -678,7 +678,7 @@ fn remove_bottleneck_in_subplan( ) -> Result { let plan = &requirements.plan; let children = &mut requirements.children; - if is_coalesce_partitions(&children[0].plan) { + if is_coalesce_partitions(&children[0].plan) && !is_aggregation(plan) { // We can safely use the 0th index since we have a `CoalescePartitionsExec`. let mut new_child_node = children[0].children.swap_remove(0); while new_child_node.plan.output_partitioning() == plan.output_partitioning() diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs index acc70d39f057..3cc5319f9e10 100644 --- a/datafusion/physical-optimizer/src/sanity_checker.rs +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -32,6 +32,8 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; use crate::PhysicalOptimizerRule; @@ -135,6 +137,14 @@ pub fn check_plan_sanity( plan.required_input_ordering(), plan.required_input_distribution(), ) { + // TEMP HACK WORKAROUND https://github.com/apache/datafusion/issues/11492 + if child.as_any().downcast_ref::<SortExec>().is_some() { + continue; + } + if child.as_any().downcast_ref::<UnionExec>().is_some() { + continue; + } + let child_eq_props = child.equivalence_properties(); if let Some(sort_req) = sort_req { let sort_req = sort_req.into_single(); diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 3655e555a744..d3207d4880a7 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use datafusion_common::Result; use datafusion_physical_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -113,3 +114,8 @@ pub fn is_repartition(plan: &Arc<dyn ExecutionPlan>) -> bool { pub fn is_limit(plan: &Arc<dyn ExecutionPlan>) -> bool { plan.as_any().is::<GlobalLimitExec>() || plan.as_any().is::<LocalLimitExec>() } + +/// Checks whether the given operator is an [`AggregateExec`].
+pub fn is_aggregation(plan: &Arc<dyn ExecutionPlan>) -> bool { + plan.as_any().is::<AggregateExec>() +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 9a6832283486..3708ec4900a0 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -2503,6 +2503,10 @@ fn compare_join_arrays( DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Decimal128(..) => compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -2571,6 +2575,10 @@ fn is_join_arrays_equal( DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Decimal128(..) => compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -2600,7 +2608,8 @@ mod tests { use arrow::array::{ builder::{BooleanBuilder, UInt64Builder}, - BooleanArray, Date32Array, Date64Array, Int32Array, RecordBatch, UInt64Array, + BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray, + Int32Array, RecordBatch, UInt64Array, }; use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; @@ -2694,6 +2703,56 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } + fn build_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec<i32>), + c: (&str, &Vec<i32>), + ) -> Arc<dyn ExecutionPlan> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Binary, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(BinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_fixed_size_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec<i32>), + c: (&str, &Vec<i32>), + ) -> Arc<dyn ExecutionPlan> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::FixedSizeBinary(3), false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(FixedSizeBinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + /// returns a table with 3 columns of i32 in memory pub fn build_table_i32_nullable( a: (&str, &Vec<Option<i32>>), @@ -3932,6 +3991,100 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_binary() ->
Result<()> { + let left = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1     | b1 | c1 | a1     | b2  | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5  | 7  | c0ffee | 105 | 70 | + | decade | 10 | 8  | decade | 110 | 80 | + | facade | 15 | 9  | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_fixed_size_binary() -> Result<()> { + let left = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1     | b1 | c1 | a1     | b2  | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5  | 7  | c0ffee | 105 | 70 | + | decade | 10 | 8  | decade | 110 | 80 | + | facade | 15 | 9  | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) + } + #[tokio::test] async fn join_left_sort_order() -> Result<()> { let left = build_table( diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index d3335c0e7fe1..4c991544f877 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1377,6 +1377,7 @@ mod tests { Arc::new(window_frame), &input.schema(), false, + false, )?], input, input_order_mode, diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 5583abfd72a2..085b17cab9bc 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -103,21 +103,38 @@ pub fn create_window_expr( window_frame: Arc<WindowFrame>, input_schema: &Schema, ignore_nulls: bool, + distinct: bool, ) -> Result<Arc<dyn WindowExpr>> { Ok(match fun { WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) - .schema(Arc::new(input_schema.clone())) - .alias(name) - .with_ignore_nulls(ignore_nulls) - .build() - .map(Arc::new)?; - window_expr_from_aggregate_expr( - partition_by, - order_by, - window_frame, - aggregate, - ) + if distinct { + let aggregate =
AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .distinct() + .build() + .map(Arc::new)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } else { + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .build() + .map(Arc::new)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } } WindowFunctionDefinition::WindowUDF(fun) => Arc::new(StandardWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, @@ -805,6 +822,7 @@ mod tests { Arc::new(WindowFrame::new(None)), schema.as_ref(), false, + false, )?], blocking_exec, false, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 43afaa0fbe65..f59e97df0d46 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -316,6 +316,7 @@ pub fn serialize_expr( ref window_frame, // TODO: support null treatment in proto null_treatment: _, + distinct: _, }, } = window_fun.as_ref(); let mut buf = Vec::new(); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 1c60470b2218..2ed6ec037fc8 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -179,6 +179,7 @@ pub fn parse_physical_window_expr( Arc::new(window_frame), &extended_schema, false, + false, ) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6c51d553fe16..b56fdc0fede6 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -981,7 +981,18 @@ async fn roundtrip_expr_api() -> Result<()> { approx_median(lit(2)), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))), - approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + None, + ), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + Some(lit(50)), + ), grouping(lit(1)), bit_and(lit(2)), bit_or(lit(2)), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e63ca75d019d..fd0e7dc6e3b9 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -352,6 +352,7 @@ impl SqlToRel<'_, S> { order_by, window_frame, null_treatment, + distinct: function_args.distinct, }; for planner in self.context_provider.get_expr_planners().iter() { @@ -368,8 +369,19 @@ impl SqlToRel<'_, S> { order_by, window_frame, null_treatment, + distinct, } = window_expr; + if distinct { + return Expr::from(expr::WindowFunction::new(func_def, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build(); + } + return Expr::from(expr::WindowFunction::new(func_def, args)) .partition_by(partition_by) .order_by(order_by) @@ -380,10 +392,6 @@ impl SqlToRel<'_, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in 
function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { - if fm.is_ordered_set_aggregate() && within_group.is_empty() { - return plan_err!("WITHIN GROUP clause is required when calling ordered set aggregate function({})", fm.name()); - } - if null_treatment.is_some() && !fm.supports_null_handling_clause() { return plan_err!( "[IGNORE | RESPECT] NULLS are not permitted for {}", @@ -403,7 +411,8 @@ impl SqlToRel<'_, S> { None, )?; - // add target column expression in within group clause to function arguments + // Add the WITHIN GROUP ordering expressions to the front of the argument list + // So function(arg) WITHIN GROUP (ORDER BY x) becomes function(x, arg) if !within_group.is_empty() { args = within_group .iter() diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 4ddd5ccccbbd..4c0dc316615c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -18,8 +18,9 @@ use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, Expr as AstExpr, Function, Ident, Interval, - ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, ValueWithSpan, + self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, + Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, + ValueWithSpan, }; use std::sync::Arc; use std::vec; @@ -198,6 +199,7 @@ impl Unparser<'_> { partition_by, order_by, window_frame, + distinct, .. }, } = window_fun.as_ref(); @@ -256,7 +258,8 @@ impl Unparser<'_> { span: Span::empty(), }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, + duplicate_treatment: distinct + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), @@ -339,7 +342,7 @@ impl Unparser<'_> { }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { duplicate_treatment: distinct - .then_some(ast::DuplicateTreatment::Distinct), + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), @@ -2051,6 +2054,7 @@ mod tests { order_by: vec![], window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, }, }), r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, @@ -2076,6 +2080,7 @@ mod tests { ), ), null_treatment: None, + distinct: false, }, }), r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 753820b6b619..ab31a87b9e35 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1287,7 +1287,7 @@ SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_10 ## Column `c12` is omitted due to a large relative error (~10%) due to the small ## float values. 
-#csv_query_approx_percentile_cont (c2) +# csv_query_approx_percentile_cont (c2) query B SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 ---- @@ -1303,6 +1303,23 @@ SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c2) AS D ---- true + +# csv_query_approx_percentile_cont (c2, alternate syntax, should be the same as above) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.9) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + # csv_query_approx_percentile_cont (c3) query B SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 @@ -1743,6 +1760,40 @@ c 122 d 124 e 115 + +# csv_query_approx_percentile_cont_with_weight (should be the same as above) +query TI +SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + + +# using approx_percentile_cont on 2 columns with same signature +query TII +SELECT c1, approx_percentile_cont(c2, 0.95) AS c2, approx_percentile_cont(c3, 0.95) AS c3 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 5 73 +b 5 68 +c 5 122 +d 5 124 +e 5 115 + +# error is unique to this UDAF +query TRR +SELECT c1, avg(c2) AS c2, avg(c3) AS c3 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 2.857142857143 -18.333333333333 +b 3.263157894737 -5.842105263158 +c 2.666666666667 -1.333333333333 +d 2.444444444444 25.444444444444 +e 3 40.333333333333 + + + query TI SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- @@ -1762,6 +1813,17 @@ c 122 d 124 e 115 +# csv_query_approx_percentile_cont_with_weight alternate syntax +query TI +SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + + query TI SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- @@ -1790,6 +1852,16 @@ c 123 d 124 e 115 +# approx_percentile_cont_with_weight with centroids +query TI +SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 74 +b 68 +c 123 +d 124 +e 115 + # csv_query_sum_crossjoin query TTI SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1 diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index c17fe8dfc7e6..ed463333217a 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -833,9 +833,61 @@ t2 as ( 11 14 12 15 -# return sql params back to default values statement ok -set datafusion.optimizer.prefer_hash_join = true; +set datafusion.execution.batch_size = 8192; + +###### +## Tests for Binary, LargeBinary, BinaryView, FixedSizeBinary join keys +###### statement ok -set datafusion.execution.batch_size = 8192; +create table t1(x varchar, id1 int) as values 
('aa', 1), ('bb', 2), ('aa', 3), (null, 4), ('ee', 5); + +statement ok +create table t2(y varchar, id2 int) as values ('ee', 10), ('bb', 20), ('cc', 30), ('cc', 40), (null, 50); + +# Binary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'Binary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'Binary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# LargeBinary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'LargeBinary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'LargeBinary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# BinaryView join keys +query ?I?I +with t1 as (select arrow_cast(x, 'BinaryView') as x, id1 from t1), + t2 as (select arrow_cast(y, 'BinaryView') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# FixedSizeBinary join keys +query ?I?I +with t1 as (select arrow_cast(arrow_cast(x, 'Binary'), 'FixedSizeBinary(2)') as x, id1 from t1), + t2 as (select arrow_cast(arrow_cast(y, 'Binary'), 'FixedSizeBinary(2)') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +statement ok +drop table t1; + +statement ok +drop table t2; + +# return sql params back to default values +statement ok +set datafusion.optimizer.prefer_hash_join = true; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 82de11302857..bed9121eec3f 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5650,3 +5650,82 @@ WINDOW 3 7 4 11 5 16 + + +# window with distinct operation +statement ok +CREATE TABLE table_test_distinct_count ( + k VARCHAR, + v Int, + time TIMESTAMP WITH TIME ZONE +); + +statement ok +INSERT INTO table_test_distinct_count (k, v, time) VALUES + ('a', 1, '1970-01-01T00:01:00.00Z'), + ('a', 1, '1970-01-01T00:02:00.00Z'), + ('a', 1, '1970-01-01T00:03:00.00Z'), + ('a', 2, '1970-01-01T00:03:00.00Z'), + ('a', 1, '1970-01-01T00:04:00.00Z'), + ('b', 3, '1970-01-01T00:01:00.00Z'), + ('b', 3, '1970-01-01T00:02:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'); + +query TPII +SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ORDER BY k, time; +---- +a 1970-01-01T00:01:00Z 1 1 +a 1970-01-01T00:02:00Z 2 1 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:04:00Z 4 2 +b 1970-01-01T00:01:00Z 1 1 +b 1970-01-01T00:02:00Z 2 1 +b 1970-01-01T00:03:00Z 4 2 +b 1970-01-01T00:03:00Z 4 2 + + +query TT +EXPLAIN SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ODER BY k, time; +---- +logical_plan +01)Projection: oder.k, oder.time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS normal_count, count(DISTINCT 
oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS distinct_count +02)--WindowAggr: windowExpr=[[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW]] +03)----SubqueryAlias: oder +04)------TableScan: table_test_distinct_count projection=[k, v, time] +physical_plan +01)ProjectionExec: expr=[k@0 as k, time@2 as time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as distinct_count] +02)--BoundedWindowAggExec: wdw=[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW], mode=[Sorted] +03)----SortExec: expr=[k@0 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] +04)------CoalesceBatchesExec: target_batch_size=1 +05)--------RepartitionExec: partitioning=Hash([k@0], 2), input_partitions=2 +06)----------DataSourceExec: partitions=2, partition_sizes=[5, 4] diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs index 80b643a547ee..27f0de84b7a0 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -112,6 +112,7 @@ pub async fn from_window_function( order_by, window_frame, null_treatment: None, + distinct: false, }, })) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs index 17e71f2d7c14..94a39e930f1c 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs @@ -42,6 +42,7 @@ pub fn from_window_function( order_by, window_frame, null_treatment: _, + distinct: _, }, } = window_fn; // function reference diff --git 
a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 03ab86eeb813..abf0286fa85b 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -285,27 +285,27 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Functions -| Syntax | Description | -| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- | -| avg(expr) | Сalculates the average value for `expr`. | -| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | -| approx_median(expr) | Calculates an approximation of the median for `expr`. | -| approx_percentile_cont(expr, percentile) | Calculates an approximation of the specified `percentile` for `expr`. | -| approx_percentile_cont_with_weight(expr, weight_expr, percentile) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. | -| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | -| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | -| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | -| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | -| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | -| count(expr) | Returns the number of rows for `expr`. | -| count_distinct | Creates an expression to represent the count(distinct) aggregate function | -| cube(exprs) | Creates a grouping set for all combination of `exprs` | -| grouping_set(exprs) | Create a grouping set. | -| max(expr) | Finds the maximum value of `expr`. | -| median(expr) | Сalculates the median of `expr`. | -| min(expr) | Finds the minimum value of `expr`. | -| rollup(exprs) | Creates a grouping set for rollup sets. | -| sum(expr) | Сalculates the sum of `expr`. | +| Syntax | Description | +| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| avg(expr) | Сalculates the average value for `expr`. | +| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | +| approx_median(expr) | Calculates an approximation of the median for `expr`. | +| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| approx_percentile_cont_with_weight(expr, weight_expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | +| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | +| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | +| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | +| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | +| count(expr) | Returns the number of rows for `expr`. 
| +| count_distinct | Creates an expression to represent the count(distinct) aggregate function | +| cube(exprs) | Creates a grouping set for all combination of `exprs` | +| grouping_set(exprs) | Create a grouping set. | +| max(expr) | Finds the maximum value of `expr`. | +| median(expr) | Сalculates the median of `expr`. | +| min(expr) | Finds the minimum value of `expr`. | +| rollup(exprs) | Creates a grouping set for rollup sets. | +| sum(expr) | Сalculates the sum of `expr`. | ## Aggregate Function Builder diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 774a4fae6bf3..4f2f0abe55c9 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -834,7 +834,7 @@ approx_median(expression) Returns the approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) +approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -846,6 +846,12 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) #### Example ```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; +-----------------------------------------------------------------------+ | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | @@ -854,12 +860,30 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) +-----------------------------------------------------------------------+ ``` +An alternate syntax is also supported: + +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` + ### `approx_percentile_cont_with_weight` Returns the weighted approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression) +approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -867,6 +891,7 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex - **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators. - **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). 
+- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory. #### Example @@ -877,4 +902,21 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex +-----------------------------------------------------------------------------------------------+ | 78.5 | +-----------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++----------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++----------------------------------------------------------------------------------------------------+ +| 78.5 | ++----------------------------------------------------------------------------------------------------+ +``` + +An alternative syntax is also supported: + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++-----------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++-----------------------------------------------------------------------+ +| 78.5                                                                  | ++-----------------------------------------------------------------------+ ```
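
Editor's aside (not part of the patch): the alternate SQL syntax documented above also has a DataFrame/expr-API counterpart. The `roundtrip_expr_api` changes earlier in this diff exercise a four-argument `approx_percentile_cont_with_weight(order_expr, weight, percentile, Option<centroids>)`. The sketch below mirrors only that argument shape; the `use` paths and the `column_name`/`weight_column` identifiers are assumptions, not taken from the patch.

```rust
// Sketch only, assuming the `datafusion::functions_aggregate::expr_fn` re-export;
// the argument shape mirrors the roundtrip_expr_api() test changed in this patch.
use datafusion::functions_aggregate::expr_fn::approx_percentile_cont_with_weight;
use datafusion::prelude::*;

fn percentile_with_weight_exprs() -> (Expr, Expr) {
    // (ORDER BY expression as a sort expr, weight, percentile, optional centroid count)
    let default_centroids = approx_percentile_cont_with_weight(
        col("column_name").sort(true, false),
        col("weight_column"),
        lit(0.90),
        None,
    );
    let explicit_centroids = approx_percentile_cont_with_weight(
        col("column_name").sort(true, false),
        col("weight_column"),
        lit(0.90),
        Some(lit(100)),
    );
    (default_centroids, explicit_centroids)
}
```

Either expression should plan to the same aggregate as the corresponding SQL forms above.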