Skip to content

Commit 2b34d81

Browse files
committed
Avoid scatter operation in expr_or_expr
1 parent 2a9a495 commit 2b34d81

File tree

1 file changed

+157
-66
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+157
-66
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 157 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use arrow::array::*;
2323
use arrow::compute::kernels::zip::zip;
2424
use arrow::compute::{
2525
is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate,
26+
SlicesIterator,
2627
};
2728
use arrow::datatypes::{DataType, Schema, UInt32Type};
2829
use arrow::error::ArrowError;
@@ -246,10 +247,12 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
246247
}
247248

248249
/// Creates a [FilterPredicate] from a boolean array.
249-
fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
250+
fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
250251
let mut filter_builder = FilterBuilder::new(predicate);
251-
// Always optimize the filter since we use them multiple times.
252-
filter_builder = filter_builder.optimize();
252+
if optimize {
253+
// Optimize only on request: it pays off when the same filter is applied multiple times.
254+
filter_builder = filter_builder.optimize();
255+
}
253256
filter_builder.build()
254257
}
255258

@@ -290,6 +293,84 @@ fn filter_array(
290293
filter.filter(array)
291294
}
292295

296+
fn merge(
297+
mask: &BooleanArray,
298+
truthy: ColumnarValue,
299+
falsy: ColumnarValue,
300+
) -> std::result::Result<ArrayRef, ArrowError> {
301+
let (truthy, truthy_is_scalar) = match truthy {
302+
ColumnarValue::Array(a) => (a, false),
303+
ColumnarValue::Scalar(s) => (s.to_array()?, true),
304+
};
305+
let (falsy, falsy_is_scalar) = match falsy {
306+
ColumnarValue::Array(a) => (a, false),
307+
ColumnarValue::Scalar(s) => (s.to_array()?, true),
308+
};
309+
310+
if truthy_is_scalar && falsy_is_scalar {
311+
return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy));
312+
}
313+
314+
let falsy = falsy.to_data();
315+
let truthy = truthy.to_data();
316+
317+
let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len());
318+
319+
// the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
320+
// fill with falsy values
321+
322+
// keep track of how much is filled
323+
let mut filled = 0;
324+
let mut falsy_offset = 0;
325+
let mut truthy_offset = 0;
326+
327+
SlicesIterator::new(mask).for_each(|(start, end)| {
328+
// the gap needs to be filled with falsy values
329+
if start > filled {
330+
if falsy_is_scalar {
331+
for _ in filled..start {
332+
// Copy the first item from the 'falsy' array into the output buffer.
333+
mutable.extend(1, 0, 1);
334+
}
335+
} else {
336+
let falsy_length = start - filled;
337+
let falsy_end = falsy_offset + falsy_length;
338+
mutable.extend(1, falsy_offset, falsy_end);
339+
falsy_offset = falsy_end;
340+
}
341+
}
342+
// fill with truthy values
343+
if truthy_is_scalar {
344+
for _ in start..end {
345+
// Copy the first item from the 'truthy' array into the output buffer.
346+
mutable.extend(0, 0, 1);
347+
}
348+
} else {
349+
let truthy_length = end - start;
350+
let truthy_end = truthy_offset + truthy_length;
351+
mutable.extend(0, truthy_offset, truthy_end);
352+
truthy_offset = truthy_end;
353+
}
354+
filled = end;
355+
});
356+
// the remaining part is falsy
357+
if filled < mask.len() {
358+
if falsy_is_scalar {
359+
for _ in filled..mask.len() {
360+
// Copy the first item from the 'falsy' array into the output buffer.
361+
mutable.extend(1, 0, 1);
362+
}
363+
} else {
364+
let falsy_length = mask.len() - filled;
365+
let falsy_end = falsy_offset + falsy_length;
366+
mutable.extend(1, falsy_offset, falsy_end);
367+
}
368+
}
369+
370+
let data = mutable.freeze();
371+
Ok(make_array(data))
372+
}
373+
293374
/// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from
294375
/// those values.
295376
///
@@ -342,7 +423,7 @@ fn filter_array(
342423
/// └───────────┘ └─────────┘ └─────────┘
343424
/// values indices result
344425
/// ```
345-
fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> {
426+
fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> {
346427
#[cfg(debug_assertions)]
347428
for ix in indices {
348429
if let Some(index) = ix.index() {
@@ -647,7 +728,7 @@ impl ResultBuilder {
647728
}
648729
Partial { arrays, indices } => {
649730
// Merge partial results into a single array.
650-
Ok(ColumnarValue::Array(merge(&arrays, &indices)?))
731+
Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?))
651732
}
652733
Complete(v) => {
653734
// If we have a complete result, we can just return it.
@@ -723,6 +804,26 @@ impl CaseExpr {
723804
}
724805

725806
impl CaseBody {
807+
/// Resolves the data type produced by this CASE body for the given input schema.
///
/// The type of the first THEN expression that is not [`DataType::Null`] is used. If every
/// THEN expression is null-typed, the ELSE expression's type is used when an ELSE exists;
/// otherwise the result is [`DataType::Null`].
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
    // All THEN results share one data type, so any non-null one can serve as the return
    // type; stop at the first such branch.
    let mut resolved = DataType::Null;
    for (_, then_expr) in self.when_then_expr.iter() {
        resolved = then_expr.data_type(input_schema)?;
        if !resolved.equals_datatype(&DataType::Null) {
            break;
        }
    }

    match &self.else_expr {
        // Every THEN result is null-typed: fall back to the ELSE expression's type.
        Some(else_expr) if resolved.equals_datatype(&DataType::Null) => {
            else_expr.data_type(input_schema)
        }
        _ => Ok(resolved),
    }
}
826+
726827
/// See [CaseExpr::case_when_with_expr].
727828
fn case_when_with_expr(
728829
&self,
@@ -767,7 +868,7 @@ impl CaseBody {
767868
result_builder.add_branch_result(&remainder_rows, nulls_value)?;
768869
} else {
769870
// Filter out the null rows and evaluate the else expression for those
770-
let nulls_filter = create_filter(&not(&base_not_nulls)?);
871+
let nulls_filter = create_filter(&not(&base_not_nulls)?, true);
771872
let nulls_batch =
772873
filter_record_batch(&remainder_batch, &nulls_filter)?;
773874
let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
@@ -782,7 +883,7 @@ impl CaseBody {
782883
}
783884

784885
// Remove the null rows from the remainder batch
785-
let not_null_filter = create_filter(&base_not_nulls);
886+
let not_null_filter = create_filter(&base_not_nulls, true);
786887
remainder_batch =
787888
Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?);
788889
remainder_rows = filter_array(&remainder_rows, &not_null_filter)?;
@@ -802,8 +903,7 @@ impl CaseBody {
802903
compare_with_eq(&a, &base_values, base_value_is_nested)
803904
}
804905
ColumnarValue::Scalar(s) => {
805-
let scalar = Scalar::new(s.to_array()?);
806-
compare_with_eq(&scalar, &base_values, base_value_is_nested)
906+
compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
807907
}
808908
}?;
809909

@@ -829,7 +929,7 @@ impl CaseBody {
829929
// for the current branch
830930
// Still no need to call `prep_null_mask_filter` since `create_filter` will already do
831931
// this unconditionally.
832-
let then_filter = create_filter(&when_value);
932+
let then_filter = create_filter(&when_value, true);
833933
let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
834934
let then_rows = filter_array(&remainder_rows, &then_filter)?;
835935

@@ -852,7 +952,7 @@ impl CaseBody {
852952
not(&prep_null_mask_filter(&when_value))
853953
}
854954
}?;
855-
let next_filter = create_filter(&next_selection);
955+
let next_filter = create_filter(&next_selection, true);
856956
remainder_batch =
857957
Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
858958
remainder_rows = filter_array(&remainder_rows, &next_filter)?;
@@ -918,7 +1018,7 @@ impl CaseBody {
9181018
// for the current branch
9191019
// Still no need to call `prep_null_mask_filter` since `create_filter` will already do
9201020
// this unconditionally.
921-
let then_filter = create_filter(when_value);
1021+
let then_filter = create_filter(when_value, true);
9221022
let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
9231023
let then_rows = filter_array(&remainder_rows, &then_filter)?;
9241024

@@ -941,7 +1041,7 @@ impl CaseBody {
9411041
not(&prep_null_mask_filter(when_value))
9421042
}
9431043
}?;
944-
let next_filter = create_filter(&next_selection);
1044+
let next_filter = create_filter(&next_selection, true);
9451045
remainder_batch =
9461046
Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
9471047
remainder_rows = filter_array(&remainder_rows, &next_filter)?;
@@ -964,24 +1064,38 @@ impl CaseBody {
9641064
&self,
9651065
batch: &RecordBatch,
9661066
when_value: &BooleanArray,
967-
return_type: &DataType,
9681067
) -> Result<ColumnarValue> {
969-
let then_value = self.when_then_expr[0]
970-
.1
971-
.evaluate_selection(batch, when_value)?
972-
.into_array(batch.num_rows())?;
1068+
let when_value = match when_value.null_count() {
1069+
0 => Cow::Borrowed(when_value),
1070+
_ => {
1071+
// `prep_null_mask_filter` is required to ensure null is treated as false
1072+
Cow::Owned(prep_null_mask_filter(when_value))
1073+
}
1074+
};
1075+
1076+
let optimize_filter = batch.num_columns() > 1;
1077+
1078+
let when_filter = create_filter(&when_value, optimize_filter);
1079+
let then_batch = filter_record_batch(batch, &when_filter)?;
1080+
let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
1081+
1082+
let else_selection = not(&when_value)?;
1083+
let else_filter = create_filter(&else_selection, optimize_filter);
1084+
let else_batch = filter_record_batch(batch, &else_filter)?;
9731085

974-
// evaluate else expression on the values not covered by when_value
975-
let remainder = not(when_value)?;
976-
let e = self.else_expr.as_ref().unwrap();
9771086
// keep `else_expr`'s data type and return type consistent
978-
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
1087+
let e = self.else_expr.as_ref().unwrap();
1088+
let return_type = self.data_type(&batch.schema())?;
1089+
let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
9791090
.unwrap_or_else(|_| Arc::clone(e));
980-
let else_ = expr
981-
.evaluate_selection(batch, &remainder)?
982-
.into_array(batch.num_rows())?;
9831091

984-
Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
1092+
let else_value = else_expr.evaluate(&else_batch)?;
1093+
1094+
Ok(ColumnarValue::Array(merge(
1095+
&when_value,
1096+
then_value,
1097+
else_value,
1098+
)?))
9851099
}
9861100
}
9871101

@@ -1113,41 +1227,34 @@ impl CaseExpr {
11131227
batch: &RecordBatch,
11141228
projected: &ProjectedCaseBody,
11151229
) -> Result<ColumnarValue> {
1116-
let return_type = self.data_type(&batch.schema())?;
1117-
11181230
// evaluate when condition on batch
11191231
let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
1120-
let when_value = when_value.into_array(batch.num_rows())?;
1232+
// `num_rows == 1` is intentional to avoid expanding scalars.
1233+
// If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks
1234+
// below will avoid incorrectly using the scalar as a merge/zip mask.
1235+
let when_value = when_value.into_array(1)?;
11211236
let when_value = as_boolean_array(&when_value).map_err(|e| {
11221237
DataFusionError::Context(
11231238
"WHEN expression did not return a BooleanArray".to_string(),
11241239
Box::new(e),
11251240
)
11261241
})?;
11271242

1128-
// For the true and false/null selection vectors, bypass `evaluate_selection` and merging
1129-
// results. This avoids materializing the array for the other branch which we will discard
1130-
// entirely anyway.
11311243
let true_count = when_value.true_count();
1132-
if true_count == batch.num_rows() {
1133-
return self.body.when_then_expr[0].1.evaluate(batch);
1244+
if true_count == when_value.len() {
1245+
// All input rows are true, just call the 'then' expression
1246+
self.body.when_then_expr[0].1.evaluate(batch)
11341247
} else if true_count == 0 {
1135-
return self.body.else_expr.as_ref().unwrap().evaluate(batch);
1136-
}
1137-
1138-
// Treat 'NULL' as false value
1139-
let when_value = match when_value.null_count() {
1140-
0 => Cow::Borrowed(when_value),
1141-
_ => Cow::Owned(prep_null_mask_filter(when_value)),
1142-
};
1143-
1144-
if projected.projection.len() < batch.num_columns() {
1248+
// All input rows are false/null, just call the 'else' expression
1249+
self.body.else_expr.as_ref().unwrap().evaluate(batch)
1250+
} else if projected.projection.len() < batch.num_columns() {
1251+
// The case expressions do not use all the columns of the input batch.
1252+
// Project first to reduce time spent filtering.
11451253
let projected_batch = batch.project(&projected.projection)?;
1146-
projected
1147-
.body
1148-
.expr_or_expr(&projected_batch, &when_value, &return_type)
1254+
projected.body.expr_or_expr(&projected_batch, when_value)
11491255
} else {
1150-
self.body.expr_or_expr(batch, &when_value, &return_type)
1256+
// All columns are used in the case expressions, so there is no need to project.
1257+
self.body.expr_or_expr(batch, when_value)
11511258
}
11521259
}
11531260
}
@@ -1159,23 +1266,7 @@ impl PhysicalExpr for CaseExpr {
11591266
}
11601267

11611268
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
1162-
// since all then results have the same data type, we can choose any one as the
1163-
// return data type except for the null.
1164-
let mut data_type = DataType::Null;
1165-
for i in 0..self.body.when_then_expr.len() {
1166-
data_type = self.body.when_then_expr[i].1.data_type(input_schema)?;
1167-
if !data_type.equals_datatype(&DataType::Null) {
1168-
break;
1169-
}
1170-
}
1171-
// if all then results are null, we use data type of else expr instead if possible.
1172-
if data_type.equals_datatype(&DataType::Null) {
1173-
if let Some(e) = &self.body.else_expr {
1174-
data_type = e.data_type(input_schema)?;
1175-
}
1176-
}
1177-
1178-
Ok(data_type)
1269+
self.body.data_type(input_schema)
11791270
}
11801271

11811272
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
@@ -2154,7 +2245,7 @@ mod tests {
21542245
PartialResultIndex::try_new(2).unwrap(),
21552246
];
21562247

2157-
let merged = merge(&[a1, a2, a3], &indices).unwrap();
2248+
let merged = merge_n(&[a1, a2, a3], &indices).unwrap();
21582249
let merged = merged.as_string::<i32>();
21592250

21602251
assert_eq!(merged.len(), indices.len());

0 commit comments

Comments
 (0)