diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 7084bc440e86..56ef54a1d450 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -18,7 +18,10 @@ use arrow::array::BooleanArray; use arrow::array::{make_comparator, ArrayRef, Datum}; use arrow::buffer::NullBuffer; -use arrow::compute::SortOptions; +use arrow::compute::kernels::cmp::{ + distinct, eq, gt, gt_eq, lt, lt_eq, neq, not_distinct, +}; +use arrow::compute::{ilike, like, nilike, nlike, SortOptions}; use arrow::error::ArrowError; use datafusion_common::DataFusionError; use datafusion_common::{arrow_datafusion_err, internal_err}; @@ -53,22 +56,49 @@ pub fn apply( } } -/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` +/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs` pub fn apply_cmp( + op: Operator, lhs: &ColumnarValue, rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, ) -> Result { - apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) + if lhs.data_type().is_nested() { + apply_cmp_for_nested(op, lhs, rhs) + } else { + let f = match op { + Operator::Eq => eq, + Operator::NotEq => neq, + Operator::Lt => lt, + Operator::LtEq => lt_eq, + Operator::Gt => gt, + Operator::GtEq => gt_eq, + Operator::IsDistinctFrom => distinct, + Operator::IsNotDistinctFrom => not_distinct, + + Operator::LikeMatch => like, + Operator::ILikeMatch => ilike, + Operator::NotLikeMatch => nlike, + Operator::NotILikeMatch => nilike, + + _ => { + return internal_err!("Invalid compare operator: {}", op); + } + }; + + apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) + } } -/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like +/// Applies a binary [`Datum`] comparison operator `op` to `lhs` and `rhs` for nested type like /// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type pub fn apply_cmp_for_nested( op: Operator, lhs: &ColumnarValue, rhs: &ColumnarValue, ) -> Result { + let left_data_type = lhs.data_type(); + let right_data_type = rhs.data_type(); + if matches!( op, Operator::Eq @@ -79,12 +109,18 @@ pub fn apply_cmp_for_nested( | Operator::GtEq | Operator::IsDistinctFrom | Operator::IsNotDistinctFrom - ) { + ) && left_data_type.equals_datatype(&right_data_type) + { apply(lhs, rhs, |l, r| { Ok(Arc::new(compare_op_for_nested(op, l, r)?)) }) } else { - internal_err!("invalid operator for nested") + internal_err!( + "invalid operator or data type mismatch for nested data, op {} left {}, right {}", + op, + left_data_type, + right_data_type + ) } } @@ -97,7 +133,7 @@ pub fn compare_with_eq( if is_nested { compare_op_for_nested(Operator::Eq, lhs, rhs) } else { - arrow::compute::kernels::cmp::eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e)) + eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e)) } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index ce3d4ced4e3a..b09d57f02d58 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -24,11 +24,8 @@ use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, or_kleene}; -use arrow::compute::kernels::cmp::*; use arrow::compute::kernels::concat_elements::concat_elements_utf8; -use arrow::compute::{ - cast, filter_record_batch, ilike, like, nilike, nlike, SlicesIterator, -}; +use arrow::compute::{cast, filter_record_batch, SlicesIterator}; use arrow::datatypes::*; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; @@ -42,7 +39,7 @@ use datafusion_expr::statistics::{ new_generic_from_binary_op, Distribution, }; use datafusion_expr::{ColumnarValue, Operator}; -use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; +use datafusion_physical_expr_common::datum::{apply, apply_cmp}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, @@ -251,13 +248,6 @@ impl PhysicalExpr for BinaryExpr { let schema = batch.schema(); let input_schema = schema.as_ref(); - if left_data_type.is_nested() { - if !left_data_type.equals_datatype(&right_data_type) { - return internal_err!("Cannot evaluate binary expression because of type mismatch: left {}, right {} ", left_data_type, right_data_type); - } - return apply_cmp_for_nested(self.op, &lhs, &rhs); - } - match self.op { Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add), Operator::Plus => return apply(&lhs, &rhs, add_wrapping), @@ -267,18 +257,21 @@ impl PhysicalExpr for BinaryExpr { Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping), Operator::Divide => return apply(&lhs, &rhs, div), Operator::Modulo => return apply(&lhs, &rhs, rem), - Operator::Eq => return apply_cmp(&lhs, &rhs, eq), - Operator::NotEq => return apply_cmp(&lhs, &rhs, neq), - Operator::Lt => return apply_cmp(&lhs, &rhs, lt), - Operator::Gt => return apply_cmp(&lhs, &rhs, gt), - Operator::LtEq => return apply_cmp(&lhs, &rhs, lt_eq), - Operator::GtEq => return apply_cmp(&lhs, &rhs, gt_eq), - Operator::IsDistinctFrom => return apply_cmp(&lhs, &rhs, distinct), - Operator::IsNotDistinctFrom => return apply_cmp(&lhs, &rhs, not_distinct), - Operator::LikeMatch => return apply_cmp(&lhs, &rhs, like), - Operator::ILikeMatch => return apply_cmp(&lhs, &rhs, ilike), - Operator::NotLikeMatch => return apply_cmp(&lhs, &rhs, nlike), - Operator::NotILikeMatch => return apply_cmp(&lhs, &rhs, nilike), + + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::Gt + | Operator::LtEq + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + | Operator::LikeMatch + | Operator::ILikeMatch + | Operator::NotLikeMatch + | Operator::NotILikeMatch => { + return apply_cmp(self.op, &lhs, &rhs); + } _ => {} } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index e86c778d5161..1c9ae530f500 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -19,7 +19,7 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr_common::datum::apply_cmp; use std::hash::Hash; use std::{any::Any, sync::Arc}; @@ -118,14 +118,13 @@ impl PhysicalExpr for LikeExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - use arrow::compute::*; let lhs = self.expr.evaluate(batch)?; let rhs = self.pattern.evaluate(batch)?; match (self.negated, self.case_insensitive) { - (false, false) => apply_cmp(&lhs, &rhs, like), - (false, true) => apply_cmp(&lhs, &rhs, ilike), - (true, false) => apply_cmp(&lhs, &rhs, nlike), - (true, true) => apply_cmp(&lhs, &rhs, nilike), + (false, false) => apply_cmp(Operator::LikeMatch, &lhs, &rhs), + (false, true) => apply_cmp(Operator::ILikeMatch, &lhs, &rhs), + (true, false) => apply_cmp(Operator::NotLikeMatch, &lhs, &rhs), + (true, true) => apply_cmp(Operator::NotILikeMatch, &lhs, &rhs), } }