@@ -23,6 +23,7 @@ use arrow::array::*;
2323use arrow:: compute:: kernels:: zip:: zip;
2424use arrow:: compute:: {
2525 is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder , FilterPredicate ,
26+ SlicesIterator ,
2627} ;
2728use arrow:: datatypes:: { DataType , Schema , UInt32Type } ;
2829use arrow:: error:: ArrowError ;
@@ -246,10 +247,12 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
246247}
247248
248249/// Creates a [FilterPredicate] from a boolean array.
249- fn create_filter ( predicate : & BooleanArray ) -> FilterPredicate {
250+ fn create_filter ( predicate : & BooleanArray , optimize : bool ) -> FilterPredicate {
250251 let mut filter_builder = FilterBuilder :: new ( predicate) ;
251- // Always optimize the filter since we use them multiple times.
252- filter_builder = filter_builder. optimize ( ) ;
252+ if optimize {
253+ // Optimizing pays off when the same filter is applied multiple times.
254+ filter_builder = filter_builder. optimize ( ) ;
255+ }
253256 filter_builder. build ( )
254257}
255258
@@ -290,6 +293,84 @@ fn filter_array(
290293 filter. filter ( array)
291294}
292295
296+ fn merge (
297+ mask : & BooleanArray ,
298+ truthy : ColumnarValue ,
299+ falsy : ColumnarValue ,
300+ ) -> std:: result:: Result < ArrayRef , ArrowError > {
301+ let ( truthy, truthy_is_scalar) = match truthy {
302+ ColumnarValue :: Array ( a) => ( a, false ) ,
303+ ColumnarValue :: Scalar ( s) => ( s. to_array ( ) ?, true ) ,
304+ } ;
305+ let ( falsy, falsy_is_scalar) = match falsy {
306+ ColumnarValue :: Array ( a) => ( a, false ) ,
307+ ColumnarValue :: Scalar ( s) => ( s. to_array ( ) ?, true ) ,
308+ } ;
309+
310+ if truthy_is_scalar && falsy_is_scalar {
311+ return zip ( mask, & Scalar :: new ( truthy) , & Scalar :: new ( falsy) ) ;
312+ }
313+
314+ let falsy = falsy. to_data ( ) ;
315+ let truthy = truthy. to_data ( ) ;
316+
317+ let mut mutable = MutableArrayData :: new ( vec ! [ & truthy, & falsy] , false , truthy. len ( ) ) ;
318+
319+ // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
320+ // fill with falsy values
321+
322+ // keep track of how much is filled
323+ let mut filled = 0 ;
324+ let mut falsy_offset = 0 ;
325+ let mut truthy_offset = 0 ;
326+
327+ SlicesIterator :: new ( mask) . for_each ( |( start, end) | {
328+ // the gap needs to be filled with falsy values
329+ if start > filled {
330+ if falsy_is_scalar {
331+ for _ in filled..start {
332+ // Copy the first item from the 'falsy' array into the output buffer.
333+ mutable. extend ( 1 , 0 , 1 ) ;
334+ }
335+ } else {
336+ let falsy_length = start - filled;
337+ let falsy_end = falsy_offset + falsy_length;
338+ mutable. extend ( 1 , falsy_offset, falsy_end) ;
339+ falsy_offset = falsy_end;
340+ }
341+ }
342+ // fill with truthy values
343+ if truthy_is_scalar {
344+ for _ in start..end {
345+ // Copy the first item from the 'truthy' array into the output buffer.
346+ mutable. extend ( 0 , 0 , 1 ) ;
347+ }
348+ } else {
349+ let truthy_length = end - start;
350+ let truthy_end = truthy_offset + truthy_length;
351+ mutable. extend ( 0 , truthy_offset, truthy_end) ;
352+ truthy_offset = truthy_end;
353+ }
354+ filled = end;
355+ } ) ;
356+ // the remaining part is falsy
357+ if filled < mask. len ( ) {
358+ if falsy_is_scalar {
359+ for _ in filled..mask. len ( ) {
360+ // Copy the first item from the 'falsy' array into the output buffer.
361+ mutable. extend ( 1 , 0 , 1 ) ;
362+ }
363+ } else {
364+ let falsy_length = mask. len ( ) - filled;
365+ let falsy_end = falsy_offset + falsy_length;
366+ mutable. extend ( 1 , falsy_offset, falsy_end) ;
367+ }
368+ }
369+
370+ let data = mutable. freeze ( ) ;
371+ Ok ( make_array ( data) )
372+ }
373+
293374/// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from
294375/// those values.
295376///
@@ -342,7 +423,7 @@ fn filter_array(
342423/// └───────────┘ └─────────┘ └─────────┘
343424/// values indices result
344425/// ```
345- fn merge ( values : & [ ArrayData ] , indices : & [ PartialResultIndex ] ) -> Result < ArrayRef > {
426+ fn merge_n ( values : & [ ArrayData ] , indices : & [ PartialResultIndex ] ) -> Result < ArrayRef > {
346427 #[ cfg( debug_assertions) ]
347428 for ix in indices {
348429 if let Some ( index) = ix. index ( ) {
@@ -647,7 +728,7 @@ impl ResultBuilder {
647728 }
648729 Partial { arrays, indices } => {
649730 // Merge partial results into a single array.
650- Ok ( ColumnarValue :: Array ( merge ( & arrays, & indices) ?) )
731+ Ok ( ColumnarValue :: Array ( merge_n ( & arrays, & indices) ?) )
651732 }
652733 Complete ( v) => {
653734 // If we have a complete result, we can just return it.
@@ -723,6 +804,26 @@ impl CaseExpr {
723804}
724805
725806impl CaseBody {
807+ fn data_type ( & self , input_schema : & Schema ) -> Result < DataType > {
808+ // since all then results have the same data type, we can choose any one as the
809+ // return data type except for the null.
810+ let mut data_type = DataType :: Null ;
811+ for i in 0 ..self . when_then_expr . len ( ) {
812+ data_type = self . when_then_expr [ i] . 1 . data_type ( input_schema) ?;
813+ if !data_type. equals_datatype ( & DataType :: Null ) {
814+ break ;
815+ }
816+ }
817+ // if all then results are null, we use data type of else expr instead if possible.
818+ if data_type. equals_datatype ( & DataType :: Null ) {
819+ if let Some ( e) = & self . else_expr {
820+ data_type = e. data_type ( input_schema) ?;
821+ }
822+ }
823+
824+ Ok ( data_type)
825+ }
826+
726827 /// See [CaseExpr::case_when_with_expr].
727828 fn case_when_with_expr (
728829 & self ,
@@ -767,7 +868,7 @@ impl CaseBody {
767868 result_builder. add_branch_result ( & remainder_rows, nulls_value) ?;
768869 } else {
769870 // Filter out the null rows and evaluate the else expression for those
770- let nulls_filter = create_filter ( & not ( & base_not_nulls) ?) ;
871+ let nulls_filter = create_filter ( & not ( & base_not_nulls) ?, true ) ;
771872 let nulls_batch =
772873 filter_record_batch ( & remainder_batch, & nulls_filter) ?;
773874 let nulls_rows = filter_array ( & remainder_rows, & nulls_filter) ?;
@@ -782,7 +883,7 @@ impl CaseBody {
782883 }
783884
784885 // Remove the null rows from the remainder batch
785- let not_null_filter = create_filter ( & base_not_nulls) ;
886+ let not_null_filter = create_filter ( & base_not_nulls, true ) ;
786887 remainder_batch =
787888 Cow :: Owned ( filter_record_batch ( & remainder_batch, & not_null_filter) ?) ;
788889 remainder_rows = filter_array ( & remainder_rows, & not_null_filter) ?;
@@ -802,8 +903,7 @@ impl CaseBody {
802903 compare_with_eq ( & a, & base_values, base_value_is_nested)
803904 }
804905 ColumnarValue :: Scalar ( s) => {
805- let scalar = Scalar :: new ( s. to_array ( ) ?) ;
806- compare_with_eq ( & scalar, & base_values, base_value_is_nested)
906+ compare_with_eq ( & s. to_scalar ( ) ?, & base_values, base_value_is_nested)
807907 }
808908 } ?;
809909
@@ -829,7 +929,7 @@ impl CaseBody {
829929 // for the current branch
830930 // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
831931 // this unconditionally.
832- let then_filter = create_filter ( & when_value) ;
932+ let then_filter = create_filter ( & when_value, true ) ;
833933 let then_batch = filter_record_batch ( & remainder_batch, & then_filter) ?;
834934 let then_rows = filter_array ( & remainder_rows, & then_filter) ?;
835935
@@ -852,7 +952,7 @@ impl CaseBody {
852952 not ( & prep_null_mask_filter ( & when_value) )
853953 }
854954 } ?;
855- let next_filter = create_filter ( & next_selection) ;
955+ let next_filter = create_filter ( & next_selection, true ) ;
856956 remainder_batch =
857957 Cow :: Owned ( filter_record_batch ( & remainder_batch, & next_filter) ?) ;
858958 remainder_rows = filter_array ( & remainder_rows, & next_filter) ?;
@@ -918,7 +1018,7 @@ impl CaseBody {
9181018 // for the current branch
9191019 // Still no need to call `prep_null_mask_filter` since `create_filter` will already do
9201020 // this unconditionally.
921- let then_filter = create_filter ( when_value) ;
1021+ let then_filter = create_filter ( when_value, true ) ;
9221022 let then_batch = filter_record_batch ( & remainder_batch, & then_filter) ?;
9231023 let then_rows = filter_array ( & remainder_rows, & then_filter) ?;
9241024
@@ -941,7 +1041,7 @@ impl CaseBody {
9411041 not ( & prep_null_mask_filter ( when_value) )
9421042 }
9431043 } ?;
944- let next_filter = create_filter ( & next_selection) ;
1044+ let next_filter = create_filter ( & next_selection, true ) ;
9451045 remainder_batch =
9461046 Cow :: Owned ( filter_record_batch ( & remainder_batch, & next_filter) ?) ;
9471047 remainder_rows = filter_array ( & remainder_rows, & next_filter) ?;
@@ -964,24 +1064,38 @@ impl CaseBody {
9641064 & self ,
9651065 batch : & RecordBatch ,
9661066 when_value : & BooleanArray ,
967- return_type : & DataType ,
9681067 ) -> Result < ColumnarValue > {
969- let then_value = self . when_then_expr [ 0 ]
970- . 1
971- . evaluate_selection ( batch, when_value) ?
972- . into_array ( batch. num_rows ( ) ) ?;
1068+ let when_value = match when_value. null_count ( ) {
1069+ 0 => Cow :: Borrowed ( when_value) ,
1070+ _ => {
1071+ // `prep_null_mask_filter` is required to ensure null is treated as false
1072+ Cow :: Owned ( prep_null_mask_filter ( when_value) )
1073+ }
1074+ } ;
1075+
1076+ let optimize_filter = batch. num_columns ( ) > 1 ;
1077+
1078+ let when_filter = create_filter ( & when_value, optimize_filter) ;
1079+ let then_batch = filter_record_batch ( batch, & when_filter) ?;
1080+ let then_value = self . when_then_expr [ 0 ] . 1 . evaluate ( & then_batch) ?;
1081+
1082+ let else_selection = not ( & when_value) ?;
1083+ let else_filter = create_filter ( & else_selection, optimize_filter) ;
1084+ let else_batch = filter_record_batch ( batch, & else_filter) ?;
9731085
974- // evaluate else expression on the values not covered by when_value
975- let remainder = not ( when_value) ?;
976- let e = self . else_expr . as_ref ( ) . unwrap ( ) ;
9771086 // keep `else_expr`'s data type and return type consistent
978- let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) )
1087+ let e = self . else_expr . as_ref ( ) . unwrap ( ) ;
1088+ let return_type = self . data_type ( & batch. schema ( ) ) ?;
1089+ let else_expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) )
9791090 . unwrap_or_else ( |_| Arc :: clone ( e) ) ;
980- let else_ = expr
981- . evaluate_selection ( batch, & remainder) ?
982- . into_array ( batch. num_rows ( ) ) ?;
9831091
984- Ok ( ColumnarValue :: Array ( zip ( & remainder, & else_, & then_value) ?) )
1092+ let else_value = else_expr. evaluate ( & else_batch) ?;
1093+
1094+ Ok ( ColumnarValue :: Array ( merge (
1095+ & when_value,
1096+ then_value,
1097+ else_value,
1098+ ) ?) )
9851099 }
9861100}
9871101
@@ -1113,41 +1227,34 @@ impl CaseExpr {
11131227 batch : & RecordBatch ,
11141228 projected : & ProjectedCaseBody ,
11151229 ) -> Result < ColumnarValue > {
1116- let return_type = self . data_type ( & batch. schema ( ) ) ?;
1117-
11181230 // evaluate when condition on batch
11191231 let when_value = self . body . when_then_expr [ 0 ] . 0 . evaluate ( batch) ?;
1120- let when_value = when_value. into_array ( batch. num_rows ( ) ) ?;
1232+ // `num_rows == 1` is intentional to avoid expanding scalars.
1233+ // If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
1234+ // below will avoid incorrectly using the scalar as a merge/zip mask.
1235+ let when_value = when_value. into_array ( 1 ) ?;
11211236 let when_value = as_boolean_array ( & when_value) . map_err ( |e| {
11221237 DataFusionError :: Context (
11231238 "WHEN expression did not return a BooleanArray" . to_string ( ) ,
11241239 Box :: new ( e) ,
11251240 )
11261241 } ) ?;
11271242
1128- // For the true and false/null selection vectors, bypass `evaluate_selection` and merging
1129- // results. This avoids materializing the array for the other branch which we will discard
1130- // entirely anyway.
11311243 let true_count = when_value. true_count ( ) ;
1132- if true_count == batch. num_rows ( ) {
1133- return self . body . when_then_expr [ 0 ] . 1 . evaluate ( batch) ;
1244+ if true_count == when_value. len ( ) {
1245+ // All input rows are true, just call the 'then' expression
1246+ self . body . when_then_expr [ 0 ] . 1 . evaluate ( batch)
11341247 } else if true_count == 0 {
1135- return self . body . else_expr . as_ref ( ) . unwrap ( ) . evaluate ( batch) ;
1136- }
1137-
1138- // Treat 'NULL' as false value
1139- let when_value = match when_value. null_count ( ) {
1140- 0 => Cow :: Borrowed ( when_value) ,
1141- _ => Cow :: Owned ( prep_null_mask_filter ( when_value) ) ,
1142- } ;
1143-
1144- if projected. projection . len ( ) < batch. num_columns ( ) {
1248+ // All input rows are false/null, just call the 'else' expression
1249+ self . body . else_expr . as_ref ( ) . unwrap ( ) . evaluate ( batch)
1250+ } else if projected. projection . len ( ) < batch. num_columns ( ) {
1251+ // The case expressions do not use all the columns of the input batch.
1252+ // Project first to reduce time spent filtering.
11451253 let projected_batch = batch. project ( & projected. projection ) ?;
1146- projected
1147- . body
1148- . expr_or_expr ( & projected_batch, & when_value, & return_type)
1254+ projected. body . expr_or_expr ( & projected_batch, when_value)
11491255 } else {
1150- self . body . expr_or_expr ( batch, & when_value, & return_type)
1256+ // All columns are used in the case expressions, so there is no need to project.
1257+ self . body . expr_or_expr ( batch, when_value)
11511258 }
11521259 }
11531260}
@@ -1159,23 +1266,7 @@ impl PhysicalExpr for CaseExpr {
11591266 }
11601267
11611268 fn data_type ( & self , input_schema : & Schema ) -> Result < DataType > {
1162- // since all then results have the same data type, we can choose any one as the
1163- // return data type except for the null.
1164- let mut data_type = DataType :: Null ;
1165- for i in 0 ..self . body . when_then_expr . len ( ) {
1166- data_type = self . body . when_then_expr [ i] . 1 . data_type ( input_schema) ?;
1167- if !data_type. equals_datatype ( & DataType :: Null ) {
1168- break ;
1169- }
1170- }
1171- // if all then results are null, we use data type of else expr instead if possible.
1172- if data_type. equals_datatype ( & DataType :: Null ) {
1173- if let Some ( e) = & self . body . else_expr {
1174- data_type = e. data_type ( input_schema) ?;
1175- }
1176- }
1177-
1178- Ok ( data_type)
1269+ self . body . data_type ( input_schema)
11791270 }
11801271
11811272 fn nullable ( & self , input_schema : & Schema ) -> Result < bool > {
@@ -2154,7 +2245,7 @@ mod tests {
21542245 PartialResultIndex :: try_new( 2 ) . unwrap( ) ,
21552246 ] ;
21562247
2157- let merged = merge ( & [ a1, a2, a3] , & indices) . unwrap ( ) ;
2248+ let merged = merge_n ( & [ a1, a2, a3] , & indices) . unwrap ( ) ;
21582249 let merged = merged. as_string :: < i32 > ( ) ;
21592250
21602251 assert_eq ! ( merged. len( ) , indices. len( ) ) ;
0 commit comments