diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index b0190dadf3c3..d4efee6c1f0c 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -45,7 +45,7 @@ bytes = { workspace = true } dashmap = { workspace = true } # note only use main datafusion crate for examples base64 = "0.22.1" -datafusion = { workspace = true, default-features = true, features = ["parquet_encryption"] } +datafusion = { workspace = true, default-features = true, features = ["parquet_encryption", "sql"] } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr-adapter = { workspace = true } diff --git a/datafusion-examples/examples/relation_planner/table_sample.rs b/datafusion-examples/examples/relation_planner/table_sample.rs index a3ae5396eda1..7a9704688a15 100644 --- a/datafusion-examples/examples/relation_planner/table_sample.rs +++ b/datafusion-examples/examples/relation_planner/table_sample.rs @@ -83,13 +83,12 @@ use std::{ any::Any, fmt::{self, Debug, Formatter}, hash::{Hash, Hasher}, - ops::{Add, Div, Mul, Sub}, pin::Pin, - str::FromStr, sync::Arc, task::{Context, Poll}, }; +use arrow::datatypes::{Float64Type, Int64Type}; use arrow::{ array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array}, compute, @@ -102,6 +101,7 @@ use futures::{ use rand::{rngs::StdRng, Rng, SeedableRng}; use tonic::async_trait; +use datafusion::optimizer::simplify_expressions::simplify_sql_literal::parse_sql_literal; use datafusion::{ execution::{ context::QueryPlanner, RecordBatchStream, SendableRecordBatchStream, @@ -415,7 +415,7 @@ impl RelationPlanner for TableSamplePlanner { match quantity.unit { // TABLESAMPLE (N ROWS) - exact row limit Some(TableSampleUnit::Rows) => { - let rows = parse_quantity::(&quantity.value)?; + let rows: i64 = parse_sql_literal::(&quantity.value, context)?; if rows < 0 { return plan_err!("row count must be non-negative, got {}", rows); } @@ -427,7 +427,8 @@ impl 
RelationPlanner for TableSamplePlanner { // TABLESAMPLE (N PERCENT) - percentage sampling Some(TableSampleUnit::Percent) => { - let percent = parse_quantity::(&quantity.value)?; + let percent: f64 = + parse_sql_literal::(&quantity.value, context)?; let fraction = percent / 100.0; let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) @@ -435,7 +436,7 @@ impl RelationPlanner for TableSamplePlanner { // TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0 None => { - let value = parse_quantity::(&quantity.value)?; + let value = parse_sql_literal::(&quantity.value, context)?; if value < 0.0 { return plan_err!("sample value must be non-negative, got {}", value); } @@ -454,40 +455,6 @@ impl RelationPlanner for TableSamplePlanner { } } -/// Parse a SQL expression as a numeric value (supports basic arithmetic). -fn parse_quantity(expr: &ast::Expr) -> Result -where - T: FromStr + Add + Sub + Mul + Div, -{ - eval_numeric_expr(expr) - .ok_or_else(|| plan_datafusion_err!("invalid numeric expression: {:?}", expr)) -} - -/// Recursively evaluate numeric SQL expressions. -fn eval_numeric_expr(expr: &ast::Expr) -> Option -where - T: FromStr + Add + Sub + Mul + Div, -{ - match expr { - ast::Expr::Value(v) => match &v.value { - ast::Value::Number(n, _) => n.to_string().parse().ok(), - _ => None, - }, - ast::Expr::BinaryOp { left, op, right } => { - let l = eval_numeric_expr::(left)?; - let r = eval_numeric_expr::(right)?; - match op { - ast::BinaryOperator::Plus => Some(l + r), - ast::BinaryOperator::Minus => Some(l - r), - ast::BinaryOperator::Multiply => Some(l * r), - ast::BinaryOperator::Divide => Some(l / r), - _ => None, - } - } - _ => None, - } -} - /// Custom logical plan node representing a TABLESAMPLE operation. /// /// Stores sampling parameters (bounds, seed) and wraps the input plan. 
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index d2ecd34886de..ced6c21277de 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -98,6 +98,7 @@ serde = [ ] sql = [ "datafusion-common/sql", + "datafusion-optimizer/sql", "datafusion-functions-nested?/sql", "datafusion-sql", "sqlparser", diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 0fb08684cd14..221c4a2da229 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -41,7 +41,9 @@ workspace = true name = "datafusion_optimizer" [features] +default = ["sql"] recursive_protection = ["dep:recursive"] +sql = ["datafusion-expr/sql"] [dependencies] arrow = { workspace = true } diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index e238fca32689..a7410c476da5 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -23,6 +23,8 @@ mod inlist_simplifier; mod regex; pub mod simplify_exprs; mod simplify_predicates; +#[cfg(feature = "sql")] +pub mod simplify_sql_literal; mod unwrap_cast; mod utils; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_sql_literal.rs b/datafusion/optimizer/src/simplify_expressions/simplify_sql_literal.rs new file mode 100644 index 000000000000..34c980607396 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/simplify_sql_literal.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parses and simplifies a SQL expression to a literal of a given type. +//! +//! This module provides functionality to parse and simplify static SQL expressions +//! used in SQL constructs like `TABLESAMPLE (10 + 50 * 2)`. When such expressions are needed +//! during planning (rather than execution), they must be reduced to literals of a given type. + +use crate::simplify_expressions::ExprSimplifier; +use arrow::datatypes::ArrowPrimitiveType; +use datafusion_common::{ + DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, plan_datafusion_err, + plan_err, +}; +use datafusion_expr::Expr; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::planner::RelationPlannerContext; +use datafusion_expr::simplify::SimplifyContext; +use datafusion_expr::sqlparser::ast; +use std::sync::Arc; + +/// Parses and simplifies a SQL expression to a numeric literal, +/// corresponding to an arrow primitive type `T` (for example, `Float64Type`). +/// +/// This function simplifies and coerces the expression, then extracts the underlying +/// native type using `TryFrom`.
+/// +/// # Arguments +/// * `expr` - The SQL AST expression to reduce to a literal +/// * `context` - `RelationPlannerContext` used to convert `expr` into a logical expression; +/// the conversion runs against an empty schema +/// +/// # Returns +/// The native value of `T`, or an error if `expr` cannot be reduced to a literal of that type +/// +/// # Example +/// ```ignore +/// let value: f64 = parse_sql_literal::<Float64Type>(&expr, context)?; +/// ``` +pub fn parse_sql_literal( + expr: &ast::Expr, + context: &mut dyn RelationPlannerContext, +) -> Result +where + T: ArrowPrimitiveType, + ::Native: TryFrom, +{ + // Empty schema is sufficient because it parses only literal expressions + let schema = DFSchemaRef::new(DFSchema::empty()); + + match context.sql_to_expr(expr.clone(), &schema) { + Ok(logical_expr) => { + log::debug!("Parsing expr {:?} to type {}", logical_expr, T::DATA_TYPE); + + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)), + ); + + // Simplify and coerce expression in case of constant arithmetic operations (e.g., 10 + 5) + let simplified_expr: Expr = simplifier + .simplify(logical_expr.clone()) + .map_err(|err| plan_datafusion_err!("Cannot simplify {expr:?}: {err}"))?; + let coerced_expr: Expr = + simplifier.coerce(simplified_expr, schema.as_ref())?; + log::debug!("Coerced expression: {:?}", &coerced_expr); + + match coerced_expr { + Expr::Literal(scalar_value, _) => { + // It is a literal - proceed to the underlying value + // Cast to the target type if needed + let casted_scalar = scalar_value.cast_to(&T::DATA_TYPE)?; + + // Extract the native type + T::Native::try_from(casted_scalar).map_err(|err| { + plan_datafusion_err!( + "Cannot extract {} from scalar value: {err}", + std::any::type_name::() + ) + }) + } + actual => { + plan_err!( + "Cannot extract literal from coerced {actual:?} expression given {expr:?} expression" + ) + } + } + } + Err(err) => { + plan_err!("Cannot construct logical expression from {expr:?}: {err}") + } + }
+} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Float64Type, Int64Type}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{TableReference, not_impl_err}; + use datafusion_expr::planner::ContextProvider; + use datafusion_expr::sqlparser::parser::Parser; + use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use datafusion_sql::planner::{PlannerContext, SqlToRel}; + use datafusion_sql::relation::SqlToRelRelationContext; + use datafusion_sql::sqlparser::dialect::GenericDialect; + use std::sync::Arc; + + // Simple mock context provider for testing + struct MockContextProvider { + options: ConfigOptions, + } + + impl ContextProvider for MockContextProvider { + fn get_table_source(&self, _: TableReference) -> Result> { + not_impl_err!("mock") + } + + fn get_function_meta(&self, _name: &str) -> Option> { + None + } + + fn get_aggregate_meta(&self, _name: &str) -> Option> { + None + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + + fn options(&self) -> &ConfigOptions { + &self.options + } + + fn udf_names(&self) -> Vec { + vec![] + } + + fn udaf_names(&self) -> Vec { + vec![] + } + + fn udwf_names(&self) -> Vec { + vec![] + } + } + + #[test] + fn test_parse_sql_float_literal() { + let test_cases = vec![ + ("0.0", 0.0), + ("1.0", 1.0), + ("0", 0.0), + ("1", 1.0), + ("0.5", 0.5), + ("100.0", 100.0), + ("0.001", 0.001), + ("999.999", 999.999), + ("1.0 + 2.0", 3.0), + ("10.0 * 0.5", 5.0), + ("100.0 / 4.0", 25.0), + ("(80.0 + 2.0*10.0) / 4.0", 25.0), + ("50.0 - 10.0", 40.0), + ("1e2", 100.0), + ("1.5e1", 15.0), + ("2.5e-1", 0.25), + ]; + + let context = MockContextProvider { + options: ConfigOptions::default(), + }; + let sql_to_rel = SqlToRel::new(&context); + let mut planner_context = PlannerContext::new(); + let mut sql_context = + SqlToRelRelationContext::new(&sql_to_rel, &mut 
planner_context); + let dialect = GenericDialect {}; + + for (sql_expr, expected) in test_cases { + let ast_expr = Parser::new(&dialect) + .try_with_sql(sql_expr) + .unwrap() + .parse_expr() + .unwrap(); + + let result: Result = + parse_sql_literal::(&ast_expr, &mut sql_context); + + match result { + Ok(value) => { + assert!( + (value - expected).abs() < 1e-10, + "For expression '{sql_expr}': expected {expected}, got {value}", + ); + } + Err(e) => panic!("Failed to parse expression '{sql_expr}': {e}"), + } + } + } + + #[test] + fn test_parse_sql_integer_literal() { + let context = MockContextProvider { + options: ConfigOptions::default(), + }; + let sql_to_rel = SqlToRel::new(&context); + let mut planner_context = PlannerContext::new(); + let mut sql_context = + SqlToRelRelationContext::new(&sql_to_rel, &mut planner_context); + let dialect = GenericDialect {}; + + // Integer + let ast_expr = Parser::new(&dialect) + .try_with_sql("2 + 4") + .unwrap() + .parse_expr() + .unwrap(); + + let result: Result = + parse_sql_literal::(&ast_expr, &mut sql_context); + + match result { + Ok(value) => { + assert_eq!(6, value); + } + Err(e) => panic!("Failed to parse expression: {e}"), + } + } +} diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 7fef670933f9..3eaba01e97eb 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -47,7 +47,7 @@ mod expr; pub mod parser; pub mod planner; mod query; -mod relation; +pub mod relation; pub mod resolve; mod select; mod set_expr; diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 3115d8dfffbd..5f96779d1a1f 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -33,11 +33,23 @@ use sqlparser::ast::{FunctionArg, FunctionArgExpr, Spanned, TableFactor}; mod join; -struct SqlToRelRelationContext<'a, 'b, S: ContextProvider> { +pub struct SqlToRelRelationContext<'a, 'b, S: ContextProvider> { planner: &'a SqlToRel<'b, S>, 
planner_context: &'a mut PlannerContext, } +impl<'a, 'b, S: ContextProvider> SqlToRelRelationContext<'a, 'b, S> { + pub fn new( + planner: &'a SqlToRel<'b, S>, + planner_context: &'a mut PlannerContext, + ) -> Self { + Self { + planner, + planner_context, + } + } +} + // Implement RelationPlannerContext impl<'a, 'b, S: ContextProvider> RelationPlannerContext for SqlToRelRelationContext<'a, 'b, S> @@ -117,11 +129,7 @@ impl SqlToRel<'_, S> { let mut current_relation = relation; for planner in planners.iter() { - let mut context = SqlToRelRelationContext { - planner: self, - planner_context, - }; - + let mut context = SqlToRelRelationContext::new(self, planner_context); match planner.plan_relation(current_relation, &mut context)? { RelationPlanning::Planned(planned) => { return Ok(RelationPlanning::Planned(planned));