From 8d5cb791539a13475382bd9874f56a43d6b39837 Mon Sep 17 00:00:00 2001 From: Stuart Carnie Date: Sun, 7 Sep 2025 02:05:17 +1000 Subject: [PATCH] feat: Support binary data types for `SortMergeJoin` `on` clause (#17431) * feat: Support binary data types for `SortMergeJoin` `on` clause * Add sql level tests for merge join on binary keys --------- Co-authored-by: Andrew Lamb --- .../src/joins/sort_merge_join.rs | 155 +++++++++++++++++- .../test_files/sort_merge_join.slt | 58 ++++++- 2 files changed, 209 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 9a6832283486..3708ec4900a0 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -2503,6 +2503,10 @@ fn compare_join_arrays( DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Decimal128(..) => compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -2571,6 +2575,10 @@ fn is_join_arrays_equal( DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Decimal128(..) 
=> compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -2600,7 +2608,8 @@ mod tests { use arrow::array::{ builder::{BooleanBuilder, UInt64Builder}, - BooleanArray, Date32Array, Date64Array, Int32Array, RecordBatch, UInt64Array, + BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray, + Int32Array, RecordBatch, UInt64Array, }; use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; @@ -2694,6 +2703,56 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } + fn build_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec<i32>), + c: (&str, &Vec<i32>), + ) -> Arc<dyn ExecutionPlan> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Binary, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(BinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_fixed_size_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec<i32>), + c: (&str, &Vec<i32>), + ) -> Arc<dyn ExecutionPlan> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::FixedSizeBinary(3), false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(FixedSizeBinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + /// returns a table with 3 columns of i32 in memory pub fn build_table_i32_nullable( a: (&str, 
&Vec<Option<i32>>), @@ -3932,6 +3991,100 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_binary() -> Result<()> { + let left = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), // all values are distinct + ("c1", &vec![7, 8, 9]), + ); + let right = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_fixed_size_binary() -> Result<()> { + let left = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), // all values are distinct + ("c1", &vec![7, 8, 9]), + ); + let right = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) 
as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) + } + #[tokio::test] async fn join_left_sort_order() -> Result<()> { let left = build_table( diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index c17fe8dfc7e6..ed463333217a 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -833,9 +833,61 @@ t2 as ( 11 14 12 15 -# return sql params back to default values statement ok -set datafusion.optimizer.prefer_hash_join = true; +set datafusion.execution.batch_size = 8192; + +###### +## Tests for Binary, LargeBinary, BinaryView, FixedSizeBinary join keys +###### statement ok -set datafusion.execution.batch_size = 8192; +create table t1(x varchar, id1 int) as values ('aa', 1), ('bb', 2), ('aa', 3), (null, 4), ('ee', 5); + +statement ok +create table t2(y varchar, id2 int) as values ('ee', 10), ('bb', 20), ('cc', 30), ('cc', 40), (null, 50); + +# Binary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'Binary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'Binary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# LargeBinary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'LargeBinary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'LargeBinary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# BinaryView join keys 
+query ?I?I +with t1 as (select arrow_cast(x, 'BinaryView') as x, id1 from t1), + t2 as (select arrow_cast(y, 'BinaryView') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# FixedSizeBinary join keys +query ?I?I +with t1 as (select arrow_cast(arrow_cast(x, 'Binary'), 'FixedSizeBinary(2)') as x, id1 from t1), + t2 as (select arrow_cast(arrow_cast(y, 'Binary'), 'FixedSizeBinary(2)') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +statement ok +drop table t1; + +statement ok +drop table t2; + +# return sql params back to default values +statement ok +set datafusion.optimizer.prefer_hash_join = true;