diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8182e4fd47d4..b6ae571e6024 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1193,6 +1193,127 @@ impl ScalarValue { }) } + /// Creates a non-null placeholder value of the given type. + /// + /// Supports all Arrow data types. The `val` parameter allows generating + /// distinct placeholder values + pub fn try_new_placeholder(data_type: &DataType, val: i64) -> Result { + let placeholder = match data_type { + DataType::Int8 => ScalarValue::Int8(Some(val as i8)), + DataType::Int16 => ScalarValue::Int16(Some(val as i16)), + DataType::Int32 => ScalarValue::Int32(Some(val as i32)), + DataType::Int64 => ScalarValue::Int64(Some(val)), + DataType::UInt8 => ScalarValue::UInt8(Some(val as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(val as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(val as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(val as u64)), + DataType::Float32 => ScalarValue::Float32(Some(val as f32)), + DataType::Float64 => ScalarValue::Float64(Some(val as f64)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(val as f32))), + DataType::Decimal128(precision, scale) => { + ScalarValue::Decimal128(Some(val as i128), *precision, *scale) + } + DataType::Decimal256(precision, scale) => ScalarValue::Decimal256( + Some(i256::from_i128(val as i128)), + *precision, + *scale, + ), + DataType::Decimal32(precision, scale) => { + ScalarValue::Decimal32(Some(val as i32), *precision, *scale) + } + DataType::Decimal64(precision, scale) => { + ScalarValue::Decimal64(Some(val), *precision, *scale) + } + DataType::Utf8 => ScalarValue::Utf8(Some(val.to_string())), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(Some(val.to_string())), + DataType::Utf8View => ScalarValue::Utf8View(Some(val.to_string())), + DataType::Binary => ScalarValue::Binary(Some(val.to_string().into_bytes())), + DataType::LargeBinary => { + ScalarValue::LargeBinary(Some(val.to_string().into_bytes())) + } + DataType::BinaryView => { + ScalarValue::BinaryView(Some(val.to_string().into_bytes())) + } + DataType::FixedSizeBinary(size) => { + let mut bytes = val.to_string().into_bytes(); + bytes.resize(*size as usize, 0); + ScalarValue::FixedSizeBinary(*size, Some(bytes)) + } + DataType::Boolean => ScalarValue::Boolean(Some(val % 2 == 1)), + DataType::Date32 => ScalarValue::Date32(Some(val as i32)), + DataType::Date64 => ScalarValue::Date64(Some(val)), + DataType::Time32(_) => ScalarValue::Time32Second(Some(val as i32)), + DataType::Time64(_) => ScalarValue::Time64Nanosecond(Some(val)), + DataType::Timestamp(unit, tz) => match unit { + TimeUnit::Second => ScalarValue::TimestampSecond(Some(val), tz.clone()), + TimeUnit::Millisecond => { + ScalarValue::TimestampMillisecond(Some(val), tz.clone()) + } + TimeUnit::Microsecond => { + ScalarValue::TimestampMicrosecond(Some(val), tz.clone()) + } + TimeUnit::Nanosecond => { + ScalarValue::TimestampNanosecond(Some(val), tz.clone()) + } + }, + DataType::Interval(_) => ScalarValue::IntervalYearMonth(Some(val as i32)), + DataType::Duration(_) => ScalarValue::DurationNanosecond(Some(val)), + DataType::Null => ScalarValue::Null, + DataType::List(field) => { + ScalarValue::List(ScalarValue::new_list_nullable(&[], field.data_type())) + } + DataType::LargeList(field) => ScalarValue::LargeList( + ScalarValue::new_large_list(&[], field.data_type()), + ), + DataType::FixedSizeList(field, size) => { + let empty_arr = new_empty_array(field.data_type()); + let values = Arc::new( + SingleRowListArrayBuilder::new(empty_arr) + .with_nullable(field.is_nullable()) + .build_fixed_size_list_array(*size as usize), + ); + ScalarValue::FixedSizeList(values) + } + DataType::ListView(field) => { + // ListView is not supported as a ScalarValue variant yet - use List as workaround + ScalarValue::List(ScalarValue::new_list_nullable(&[], field.data_type())) + } + DataType::LargeListView(field) => { + // LargeListView is not supported as a ScalarValue variant yet - use LargeList as workaround + ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + field.data_type(), + )) + } + DataType::Struct(fields) => ScalarValue::Struct( + new_null_array(&DataType::Struct(fields.clone()), 1) + .as_struct() + .to_owned() + .into(), + ), + DataType::Union(fields, mode) => { + ScalarValue::Union(None, fields.clone(), *mode) + } + DataType::Map(field, sorted) => ScalarValue::Map( + new_null_array(&DataType::Map(Arc::clone(field), *sorted), 1) + .as_map() + .to_owned() + .into(), + ), + DataType::Dictionary(key_type, value_type) => { + // Dictionaries: create placeholder of value type, encoding happens at execution time + let value = ScalarValue::try_new_placeholder(value_type, val)?; + ScalarValue::Dictionary(key_type.clone(), Box::new(value)) + } + DataType::RunEndEncoded(_run_ends_type, _values_type) => { + return _internal_err!( + "Run-end encoded columns are not yet supported for placeholder creation" + ); + } + }; + Ok(placeholder) + } + /// Returns a [`ScalarValue::Utf8`] representing `val` pub fn new_utf8(val: impl Into) -> Self { ScalarValue::from(val.into()) diff --git a/datafusion/core/src/execution/lateral_table_function.rs b/datafusion/core/src/execution/lateral_table_function.rs new file mode 100644 index 000000000000..9a6f46cbdc44 --- /dev/null +++ b/datafusion/core/src/execution/lateral_table_function.rs @@ -0,0 +1,450 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for LATERAL table functions + +use std::any::Any; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::ArrayRef; +use arrow::compute::concat_batches; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_catalog::TableFunction; +use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, +}; +use futures::future::BoxFuture; +use futures::stream::Stream; +use futures::{ready, FutureExt, StreamExt}; + +use crate::execution::session_state::SessionState; + +/// Execution plan for LATERAL table functions +/// +/// This operator evaluates a table function for each row from the input, +/// allowing the table function to reference columns from the outer query. +#[derive(Debug, Clone)] +pub struct LateralTableFunctionExec { + /// Input execution plan (the "outer" table) + input: Arc, + /// Name of the table function to call + function_name: String, + /// The table function instance + table_function: Arc, + /// Physical expressions for table function arguments + args: Vec>, + /// Complete output schema (input columns + table function columns). + /// Used when creating the final output batches that combine both. + schema: SchemaRef, + /// Table function output schema only (excludes input columns). + /// Used when concatenating batches from the table function before + /// combining them with input columns. + table_function_schema: SchemaRef, + /// Session state for accessing catalog and scanning + session_state: Arc, + /// Cached plan properties (partitioning, ordering, equivalences, etc.). + /// Computed once during construction to avoid expensive recalculation + /// during query optimization. + cache: PlanProperties, +} + +impl LateralTableFunctionExec { + pub fn new( + input: Arc, + function_name: String, + table_function: Arc, + args: Vec>, + schema: SchemaRef, + table_function_schema: SchemaRef, + session_state: Arc, + ) -> Self { + let cache = Self::compute_properties(&input, Arc::clone(&schema)); + Self { + input, + function_name, + table_function, + args, + schema, + table_function_schema, + session_state, + cache, + } + } + + fn compute_properties( + input: &Arc, + schema: SchemaRef, + ) -> PlanProperties { + let eq_properties = EquivalenceProperties::new(schema); + PlanProperties::new( + eq_properties, + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + ) + } + + pub fn function_name(&self) -> &str { + &self.function_name + } + + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for LateralTableFunctionExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "LateralTableFunctionExec: function={}(..)", + self.function_name + ) + } + DisplayFormatType::TreeRender => { + write!(f, "LateralTableFunction({})", self.function_name) + } + } + } +} + +impl ExecutionPlan for LateralTableFunctionExec { + fn name(&self) -> &str { + "LateralTableFunctionExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != 1 { + return internal_err!( + "LateralTableFunctionExec expects exactly one child, got {}", + children.len() + ); + } + + Ok(Arc::new(LateralTableFunctionExec::new( + Arc::clone(&children[0]), + self.function_name.clone(), + Arc::clone(&self.table_function), + self.args.clone(), + Arc::clone(&self.schema), + Arc::clone(&self.table_function_schema), + Arc::clone(&self.session_state), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input_stream = self.input.execute(partition, Arc::clone(&context))?; + + Ok(Box::pin(LateralTableFunctionStream { + input: input_stream, + table_function: Arc::clone(&self.table_function), + args: self.args.clone(), + table_function_schema: Arc::clone(&self.table_function_schema), + schema: Arc::clone(&self.schema), + session_state: Arc::clone(&self.session_state), + context, + state: ProcessingState::ReadingInput, + })) + } +} + +enum ProcessingState { + ReadingInput, + ProcessingRow { + input_batch: RecordBatch, + row_idx: usize, + output_batches: Vec, + }, + ScanningTableFunction { + input_batch: RecordBatch, + row_idx: usize, + output_batches: Vec, + scan_future: BoxFuture<'static, Result>>, + }, + ReadingTableStream { + input_batch: RecordBatch, + row_idx: usize, + output_batches: Vec, + table_stream: SendableRecordBatchStream, + table_batches: Vec, + }, +} + +struct LateralTableFunctionStream { + input: SendableRecordBatchStream, + table_function: Arc, + args: Vec>, + table_function_schema: SchemaRef, + schema: SchemaRef, + session_state: Arc, + context: Arc, + state: ProcessingState, +} + +impl Stream for LateralTableFunctionStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + let state = std::mem::replace( + &mut self.state, + ProcessingState::ReadingInput, + ); + + match state { + ProcessingState::ReadingInput => { + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Some(Ok(RecordBatch::new_empty( + Arc::clone(&self.schema), + )))); + } + self.state = ProcessingState::ProcessingRow { + input_batch: batch, + row_idx: 0, + output_batches: Vec::new(), + }; + } + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => return Poll::Ready(None), + } + } + + ProcessingState::ProcessingRow { + input_batch, + row_idx, + output_batches, + } => { + if row_idx >= input_batch.num_rows() { + let result = self.combine_output_batches(&output_batches)?; + return Poll::Ready(Some(Ok(result))); + } + + let arg_values = match self.evaluate_args(&input_batch, row_idx) { + Ok(args) => args, + Err(e) => { + return Poll::Ready(Some(Err(e))); + } + }; + + let table_provider = + match self.table_function.function().call(&arg_values) { + Ok(provider) => provider, + Err(e) => { + return Poll::Ready(Some(Err(e))); + } + }; + + let session_state = Arc::clone(&self.session_state); + let scan_future = async move { + table_provider + .scan(session_state.as_ref(), None, &[], None) + .await + } + .boxed(); + + self.state = ProcessingState::ScanningTableFunction { + input_batch, + row_idx, + output_batches, + scan_future, + }; + } + + ProcessingState::ScanningTableFunction { + input_batch, + row_idx, + output_batches, + mut scan_future, + } => { + let table_exec = ready!(scan_future.poll_unpin(cx)); + let table_exec = match table_exec { + Ok(exec) => exec, + Err(e) => { + return Poll::Ready(Some(Err(e))); + } + }; + + let table_stream = match table_exec.execute(0, Arc::clone(&self.context)) + { + Ok(stream) => stream, + Err(e) => { + return Poll::Ready(Some(Err(e))); + } + }; + + self.state = ProcessingState::ReadingTableStream { + input_batch, + row_idx, + output_batches, + table_stream, + table_batches: Vec::new(), + }; + } + + ProcessingState::ReadingTableStream { + input_batch, + row_idx, + mut output_batches, + mut table_stream, + mut table_batches, + } => { + match ready!(table_stream.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + table_batches.push(batch); + self.state = ProcessingState::ReadingTableStream { + input_batch, + row_idx, + output_batches, + table_stream, + table_batches, + }; + } + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + None => { + if !table_batches.is_empty() { + match self.combine_row_with_table_output( + &input_batch, + row_idx, + &table_batches, + ) { + Ok(output_batch) => { + output_batches.push(output_batch); + } + Err(e) => { + return Poll::Ready(Some(Err(e))); + } + } + } + + self.state = ProcessingState::ProcessingRow { + input_batch, + row_idx: row_idx + 1, + output_batches, + }; + } + } + } + } + } + } +} + +impl LateralTableFunctionStream { + fn evaluate_args( + &self, + input_batch: &RecordBatch, + row_idx: usize, + ) -> Result> { + self.args + .iter() + .map(|arg_expr| { + let columnar_value = arg_expr.evaluate(input_batch)?; + match columnar_value { + ColumnarValue::Scalar(scalar) => Ok(Expr::Literal(scalar, None)), + ColumnarValue::Array(array) => { + let scalar = ScalarValue::try_from_array(&array, row_idx)?; + Ok(Expr::Literal(scalar, None)) + } + } + }) + .collect::>>() + } + + fn combine_row_with_table_output( + &self, + input_batch: &RecordBatch, + row_idx: usize, + table_batches: &[RecordBatch], + ) -> Result { + let combined_table_batch = if table_batches.len() == 1 { + table_batches[0].clone() + } else { + concat_batches(&self.table_function_schema, table_batches)? + }; + + let input_row_arrays: Vec = input_batch + .columns() + .iter() + .map(|col| { + let scalar = ScalarValue::try_from_array(col, row_idx)?; + scalar.to_array_of_size(combined_table_batch.num_rows()) + }) + .collect::>>()?; + + let mut output_columns = input_row_arrays; + output_columns.extend(combined_table_batch.columns().iter().cloned()); + + RecordBatch::try_new(Arc::clone(&self.schema), output_columns) + .map_err(Into::into) + } + + fn combine_output_batches( + &self, + output_batches: &[RecordBatch], + ) -> Result { + if output_batches.is_empty() { + Ok(RecordBatch::new_empty(Arc::clone(&self.schema))) + } else if output_batches.len() == 1 { + Ok(output_batches[0].clone()) + } else { + concat_batches(&self.schema, output_batches).map_err(Into::into) + } + } +} + +impl RecordBatchStream for LateralTableFunctionStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} diff --git a/datafusion/core/src/execution/mod.rs b/datafusion/core/src/execution/mod.rs index 2e3e09685bcc..700d27cda6a0 100644 --- a/datafusion/core/src/execution/mod.rs +++ b/datafusion/core/src/execution/mod.rs @@ -18,6 +18,7 @@ //! Shared state for query planning and execution. pub mod context; +pub mod lateral_table_function; pub mod session_state; pub use session_state::{SessionState, SessionStateBuilder}; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c28e56790e66..ac10f09e9d11 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -27,6 +27,7 @@ use crate::datasource::physical_plan::FileSinkConfig; use crate::datasource::{source_as_provider, DefaultTableSource}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; +use crate::execution::lateral_table_function::LateralTableFunctionExec; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window, @@ -65,8 +66,8 @@ use datafusion_common::display::ToStringifiedPlan; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::TableReference; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, - ScalarValue, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, + plan_err, DFSchema, ScalarValue, }; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; @@ -1360,6 +1361,43 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: Analyze must be root of the plan" ) } + LogicalPlan::LateralTableFunction(lateral_tf) => { + let input = children.one()?; + + let physical_args = lateral_tf + .args + .iter() + .map(|expr| { + self.create_physical_expr( + expr, + lateral_tf.input.schema(), + session_state, + ) + }) + .collect::>>()?; + + let table_function = Arc::clone( + session_state + .table_functions() + .get(&lateral_tf.function_name) + .ok_or_else(|| { + plan_datafusion_err!( + "Table function '{}' not found", + lateral_tf.function_name + ) + })?, + ); + + Arc::new(LateralTableFunctionExec::new( + input, + lateral_tf.function_name.clone(), + table_function, + physical_args, + Arc::clone(lateral_tf.schema.inner()), + Arc::clone(lateral_tf.table_function_schema.inner()), + Arc::new(session_state.clone()), + )) + } }; Ok(exec_node) } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index ea08c223e8f4..0efd73199523 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -22,9 +22,9 @@ use std::fmt; use crate::{ expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, - Filter, Join, Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery, - Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan, - Unnest, Values, Window, + Filter, Join, LateralTableFunction, Limit, LogicalPlan, Partitioning, Projection, + RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, + TableProviderFilterPushDown, TableScan, Unnest, Values, Window, }; use crate::dml::CopyTo; @@ -319,6 +319,17 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Is Distinct": is_distinct, }) } + LogicalPlan::LateralTableFunction(LateralTableFunction { + ref function_name, + ref args, + .. + }) => { + json!({ + "Node Type": "LateralTableFunction", + "Function": function_name, + "Args": expr_vec_fmt!(args) + }) + } LogicalPlan::Values(Values { ref values, .. }) => { let str_values = values .iter() diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 7de2fd117487..9568ac4073e4 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -40,9 +40,10 @@ pub use dml::{DmlStatement, WriteOp}; pub use plan::{ projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, ExplainOption, Extension, FetchType, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, - Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + Join, JoinConstraint, JoinType, LateralTableFunction, Limit, LogicalPlan, + Partitioning, PlanType, Projection, RecursiveQuery, Repartition, SkipType, Sort, + StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, + Unnest, Values, Window, }; pub use statement::{ Deallocate, Execute, Prepare, SetVariable, Statement, TransactionAccessMode, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index b8200ab8a48c..aebfd8f62ab2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -291,6 +291,10 @@ pub enum LogicalPlan { Unnest(Unnest), /// A variadic query (e.g. "Recursive CTEs") RecursiveQuery(RecursiveQuery), + /// Table function invocation with arguments that reference outer query columns. + /// Used to implement LATERAL table functions where arguments cannot be evaluated + /// at planning time because they depend on runtime row values. + LateralTableFunction(LateralTableFunction), } impl Default for LogicalPlan { @@ -355,6 +359,9 @@ impl LogicalPlan { // we take the schema of the static term as the schema of the entire recursive query static_term.schema() } + LogicalPlan::LateralTableFunction(LateralTableFunction { + schema, .. + }) => schema, } } @@ -477,6 +484,9 @@ impl LogicalPlan { recursive_term, .. }) => vec![static_term, recursive_term], + LogicalPlan::LateralTableFunction(LateralTableFunction { input, .. }) => { + vec![input] + } LogicalPlan::Statement(stmt) => stmt.inputs(), // plans without inputs LogicalPlan::TableScan { .. } @@ -585,6 +595,9 @@ impl LogicalPlan { .map_or(Ok(None), |v| v.map(Some)) } LogicalPlan::Subquery(_) => Ok(None), + LogicalPlan::LateralTableFunction(lateral) => Ok(Some(Expr::Column( + Column::from(lateral.schema.qualified_field(0)), + ))), LogicalPlan::EmptyRelation(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) @@ -737,6 +750,7 @@ impl LogicalPlan { Ok(LogicalPlan::Distinct(distinct)) } LogicalPlan::RecursiveQuery(_) => Ok(self), + LogicalPlan::LateralTableFunction(_) => Ok(self), LogicalPlan::Analyze(_) => Ok(self), LogicalPlan::Explain(_) => Ok(self), LogicalPlan::TableScan(_) => Ok(self), @@ -1133,6 +1147,22 @@ impl LogicalPlan { self.assert_no_inputs(inputs)?; Ok(self.clone()) } + LogicalPlan::LateralTableFunction(LateralTableFunction { + function_name, + args: _, + schema, + table_function_schema, + .. + }) => { + let input = self.only_input(inputs)?; + Ok(LogicalPlan::LateralTableFunction(LateralTableFunction { + input: Arc::new(input), + function_name: function_name.clone(), + args: expr, + schema: Arc::clone(schema), + table_function_schema: Arc::clone(table_function_schema), + })) + } LogicalPlan::Unnest(Unnest { exec_columns: columns, options, @@ -1372,6 +1402,7 @@ impl LogicalPlan { ) => input.max_rows(), LogicalPlan::Values(v) => Some(v.values.len()), LogicalPlan::Unnest(_) => None, + LogicalPlan::LateralTableFunction(_) => None, LogicalPlan::Ddl(_) | LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) @@ -1732,6 +1763,16 @@ impl LogicalPlan { }) => { write!(f, "RecursiveQuery: is_distinct={is_distinct}") } + LogicalPlan::LateralTableFunction(LateralTableFunction { + ref function_name, + ref args, + .. + }) => { + write!(f, "LateralTableFunction: {}({})", + function_name, + expr_vec_fmt!(args) + ) + } LogicalPlan::Values(Values { ref values, .. }) => { let str_values: Vec<_> = values .iter() @@ -2103,6 +2144,44 @@ pub struct RecursiveQuery { pub is_distinct: bool, } +/// Table function call with arguments that reference outer query columns. +/// Used in LATERAL joins where table function arguments depend on values +/// from the outer query that are only available at execution time. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct LateralTableFunction { + /// Input from outer query + pub input: Arc, + /// Name of the table function + pub function_name: String, + /// Arguments to the table function (may contain OuterReferenceColumn) + pub args: Vec, + /// Complete output schema (input columns + table function columns) + pub schema: DFSchemaRef, + /// Table function output schema only (excludes input columns) + pub table_function_schema: DFSchemaRef, +} + +impl Debug for LateralTableFunction { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("LateralTableFunction") + .field("function_name", &self.function_name) + .field("args", &self.args) + .field("schema", &self.schema) + .finish() + } +} + +impl PartialOrd for LateralTableFunction { + fn partial_cmp(&self, other: &Self) -> Option { + // Compare by function name and args, skip schema comparison + match self.function_name.partial_cmp(&other.function_name) { + Some(Ordering::Equal) => {} + other => return other, + } + self.args.partial_cmp(&other.args) + } +} + /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 47088370a1d9..3cafa1339ffb 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -40,9 +40,9 @@ use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, - Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, - Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, - UserDefinedLogicalNode, Values, Window, + LateralTableFunction, Limit, LogicalPlan, Partitioning, Prepare, Projection, + RecursiveQuery, Repartition, Sort, Statement, Subquery, SubqueryAlias, TableScan, + Union, Unnest, UserDefinedLogicalNode, Values, Window, }; use datafusion_common::tree_node::TreeNodeRefContainer; @@ -337,6 +337,21 @@ impl TreeNode for LogicalPlan { }) }, ), + LogicalPlan::LateralTableFunction(LateralTableFunction { + input, + function_name, + args, + schema, + table_function_schema, + }) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::LateralTableFunction(LateralTableFunction { + input, + function_name, + args, + schema, + table_function_schema, + }) + }), LogicalPlan::Statement(stmt) => match stmt { Statement::Prepare(p) => p .input @@ -455,6 +470,9 @@ impl LogicalPlan { LogicalPlan::Limit(Limit { skip, fetch, .. }) => { (skip, fetch).apply_ref_elements(f) } + LogicalPlan::LateralTableFunction(LateralTableFunction { args, .. }) => { + args.apply_elements(f) + } LogicalPlan::Statement(stmt) => match stmt { Statement::Execute(Execute { parameters, .. }) => { parameters.apply_elements(f) @@ -631,6 +649,21 @@ impl LogicalPlan { LogicalPlan::Limit(Limit { skip, fetch, input }) }) } + LogicalPlan::LateralTableFunction(LateralTableFunction { + input, + function_name, + args, + schema, + table_function_schema, + }) => args.map_elements(f)?.update_data(|args| { + LogicalPlan::LateralTableFunction(LateralTableFunction { + input, + function_name, + args, + schema, + table_function_schema, + }) + }), LogicalPlan::Statement(stmt) => match stmt { Statement::Execute(e) => { e.parameters.map_elements(f)?.update_data(|parameters| { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ec1f8f991a8e..86b58f90380f 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -586,6 +586,7 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) + | LogicalPlan::LateralTableFunction(_) | LogicalPlan::RecursiveQuery(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 5db71417bc8f..be9e3d8e6215 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -322,6 +322,7 @@ fn optimize_projections( | LogicalPlan::Analyze(_) | LogicalPlan::Subquery(_) | LogicalPlan::Statement(_) + | LogicalPlan::LateralTableFunction(_) | LogicalPlan::Distinct(Distinct::All(_)) => { // These plans require all their fields, and their children should // be treated as final plans -- otherwise, we may have schema a diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index fd9e07914b07..c58d22143bae 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1788,6 +1788,9 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::LateralTableFunction(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for LateralTableFunction", + )), } } } diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 754ded1514a6..3048c6c278e4 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -16,12 +16,17 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, plan_datafusion_err, Column, Result}; -use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder}; +use datafusion_common::{ + not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, ScalarValue, +}; +use datafusion_expr::logical_plan::LateralTableFunction; +use datafusion_expr::{Expr, ExprSchemable, JoinType, LogicalPlan, LogicalPlanBuilder}; use sqlparser::ast::{ - Join, JoinConstraint, JoinOperator, ObjectName, TableFactor, TableWithJoins, + FunctionArg, FunctionArgExpr, Join, JoinConstraint, JoinOperator, ObjectName, + TableFactor, TableWithJoins, }; use std::collections::HashSet; +use std::sync::Arc; impl SqlToRel<'_, S> { pub(crate) fn plan_table_with_joins( @@ -49,7 +54,22 @@ impl SqlToRel<'_, S> { join: Join, planner_context: &mut PlannerContext, ) -> Result { - let right = if is_lateral_join(&join)? { + let is_lateral = is_lateral_join(&join)?; + + // For LATERAL joins, check if it's a table function before planning + if is_lateral { + if let Some(lateral_tf) = self.try_create_lateral_table_function( + &left, + &join.relation, + planner_context, + )? { + // LateralTableFunction already represents the complete join result + return Ok(lateral_tf); + } + } + + // Normal join planning + let right = if is_lateral { self.create_relation_subquery(join.relation, planner_context)? } else { self.create_relation(join.relation, planner_context)? @@ -177,6 +197,122 @@ impl SqlToRel<'_, S> { .build(), } } + + /// Try to create a LateralTableFunction node if the relation is a table function + /// with outer references. Returns None if this is not a lateral table function case. + fn try_create_lateral_table_function( + &self, + left: &LogicalPlan, + relation: &TableFactor, + planner_context: &mut PlannerContext, + ) -> Result> { + // Check if this is a table function call (either Table or Function variant) + let (tbl_func_name, func_args_vec, alias) = match relation { + TableFactor::Table { + name, + args: Some(func_args), + alias, + .. + } => { + let name_str = name + .0 + .first() + .and_then(|ident| ident.as_ident()) + .ok_or_else(|| plan_datafusion_err!("Invalid table function name"))? + .to_string(); + (name_str, func_args.args.clone(), alias.as_ref()) + } + TableFactor::Function { + name, args, alias, .. + } => { + let name_str = name + .0 + .first() + .and_then(|ident| ident.as_ident()) + .ok_or_else(|| plan_datafusion_err!("Invalid function name"))? + .to_string(); + (name_str, args.clone(), alias.as_ref()) + } + _ => return Ok(None), + }; + + // Parse arguments to expressions + // Use the outer from schema so that column references are properly recognized + let schema_for_args = planner_context + .outer_from_schema() + .unwrap_or_else(|| Arc::new(DFSchema::empty())); + + let args = func_args_vec + .iter() + .map(|arg| { + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg { + self.sql_expr_to_logical_expr( + expr.clone(), + &schema_for_args, + planner_context, + ) + } else { + plan_err!("Unsupported function argument type") + } + }) + .collect::>>()?; + + // LATERAL functions evaluate row-by-row when referencing outer columns + let has_column_refs = args.iter().any(|expr| { + matches!(expr, Expr::Column(_)) || expr.contains_outer() + }); + + if has_column_refs { + // For table functions with outer references, we need to get the schema + // but can't actually call the function yet (outer refs not resolved). + // We'll replace outer references with placeholder literals to get the schema. + let placeholder_args: Vec = args + .iter() + .enumerate() + .map(|(idx, arg)| { + if matches!(arg, Expr::Column(_)) || arg.contains_outer() { + let data_type = arg.get_type(&schema_for_args)?; + + // Use incrementing values (1, 2, 3...) to ensure valid ranges for functions + // like generate_series(start, end) where start < end + let val = (idx + 1) as i64; + + let placeholder = + ScalarValue::try_new_placeholder(&data_type, val)?; + + Ok(Expr::Literal(placeholder, None)) + } else { + Ok(arg.clone()) + } + }) + .collect::>>()?; + + let provider = self + .context_provider + .get_table_function_source(&tbl_func_name, placeholder_args)?; + let tf_schema = provider.schema(); + + let qualifier = alias + .map(|a| self.ident_normalizer.normalize(a.name.clone())) + .unwrap_or_else(|| format!("{tbl_func_name}()")); + + let tf_df_schema = + DFSchema::try_from_qualified_schema(qualifier.as_str(), &tf_schema)?; + + let combined_schema = left.schema().join(&tf_df_schema)?; + + let lateral_tf = LateralTableFunction { + input: Arc::new(left.clone()), + function_name: tbl_func_name.clone(), + args, + schema: Arc::new(combined_schema), + table_function_schema: Arc::new(tf_df_schema), + }; + return Ok(Some(LogicalPlan::LateralTableFunction(lateral_tf))); + } + + Ok(None) + } } /// Return `true` iff the given [`TableFactor`] is lateral. diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 9dfa078701d3..44522f748e26 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -61,17 +61,28 @@ impl SqlToRel<'_, S> { } }) .collect::>(); - let provider = self - .context_provider - .get_table_function_source(&tbl_func_name, args)?; - let plan = LogicalPlanBuilder::scan( - TableReference::Bare { - table: format!("{tbl_func_name}()").into(), - }, - provider, - None, - )? - .build()?; + + let has_outer_references = + args.iter().any(|expr| expr.contains_outer()); + + let plan = if has_outer_references { + return not_impl_err!( + "Table function arguments cannot reference columns without LATERAL keyword. \ + Use: FROM other_table, LATERAL table_func(other_table.column)" + ); + } else { + let provider = self + .context_provider + .get_table_function_source(&tbl_func_name, args)?; + LogicalPlanBuilder::scan( + TableReference::Bare { + table: format!("{tbl_func_name}()").into(), + }, + provider, + None, + )? + .build()? + }; (plan, alias) } else { // Normalize name and alias diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index b6c65614995a..3baae56e92a4 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -124,6 +124,7 @@ impl Unparser<'_> { | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::LateralTableFunction(_) | LogicalPlan::Unnest(_) => not_impl_err!("Unsupported plan: {plan:?}"), } } diff --git a/datafusion/sqllogictest/test_files/lateral_table_functions.slt b/datafusion/sqllogictest/test_files/lateral_table_functions.slt new file mode 100644 index 000000000000..4c656b5a293a --- /dev/null +++ b/datafusion/sqllogictest/test_files/lateral_table_functions.slt @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for LATERAL table functions +# LATERAL allows table-valued functions to reference columns from outer queries + +#### +# Setup test data +#### + +statement ok +CREATE TABLE orders(order_id BIGINT, quantity BIGINT, customer_id BIGINT) AS VALUES +(1, 3, 100), +(2, 2, 101), +(3, 5, 100), +(4, 1, 102); + +#### +# Basic LATERAL tests with generate_series +#### + +# Basic LATERAL with generate_series +query III rowsort +SELECT order_id, quantity, s.value +FROM orders o +CROSS JOIN LATERAL generate_series(1, o.quantity) AS s(value); +---- +1 3 1 +1 3 2 +1 3 3 +2 2 1 +2 2 2 +3 5 1 +3 5 2 +3 5 3 +3 5 4 +3 5 5 +4 1 1 + +# LATERAL with WHERE clause on outer table +query III rowsort +SELECT order_id, quantity, s.value +FROM orders o +CROSS JOIN LATERAL generate_series(1, o.quantity) AS s(value) +WHERE o.customer_id = 100; +---- +1 3 1 +1 3 2 +1 3 3 +3 5 1 +3 5 2 +3 5 3 +3 5 4 +3 5 5 + +# LATERAL with range function (exclusive end) +query III rowsort +SELECT order_id, quantity, s.value +FROM orders o +CROSS JOIN LATERAL range(0, o.quantity) AS s(value); +---- +1 3 0 +1 3 1 +1 3 2 +2 2 0 +2 2 1 +3 5 0 +3 5 1 +3 5 2 +3 5 3 +3 5 4 +4 1 0 + +# LATERAL with aggregation +query II +SELECT customer_id, SUM(s.value) as total +FROM orders o +CROSS JOIN LATERAL generate_series(1, o.quantity) AS s(value) +GROUP BY customer_id +ORDER BY customer_id; +---- +100 21 +101 3 +102 1 + +# LATERAL with multiple outer references +statement ok +CREATE TABLE products(product_id BIGINT, min_qty BIGINT, max_qty BIGINT) AS VALUES +(1, 1, 3), +(2, 2, 4); + +query III rowsort +SELECT product_id, min_qty, s.value +FROM products p +CROSS JOIN LATERAL generate_series(p.min_qty, p.max_qty) AS s(value); +---- +1 1 1 +1 1 2 +1 1 3 +2 2 2 +2 2 3 +2 2 4 + +# LATERAL with projected columns only +query I rowsort +SELECT s.value +FROM orders o +CROSS JOIN LATERAL generate_series(1, o.quantity) AS s(value) +WHERE order_id = 2; +---- +1 +2 + +# Multiple LATERAL joins - sequential application +query IIII rowsort +SELECT o1.order_id, o1.quantity, s1.value as series1, s2.value as series2 +FROM orders o1 +CROSS JOIN LATERAL generate_series(1, o1.quantity) AS s1(value) +CROSS JOIN LATERAL generate_series(1, s1.value) AS s2(value) +WHERE o1.order_id = 2; +---- +2 2 1 1 +2 2 2 1 +2 2 2 2 + +# LATERAL with ORDER BY +query III +SELECT order_id, quantity, s.value +FROM orders o +CROSS JOIN LATERAL generate_series(1, o.quantity) AS s(value) +ORDER BY order_id, s.value +LIMIT 5; +---- +1 3 1 +1 3 2 +1 3 3 +2 2 1 +2 2 2 + +#### +# Cleanup +#### + +statement ok +DROP TABLE orders; + +statement ok +DROP TABLE products; diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs index c3599a2635ff..c71b090e1574 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -74,5 +74,8 @@ pub fn to_substrait_rel( LogicalPlan::RecursiveQuery(plan) => { not_impl_err!("Unsupported plan type: {plan:?}")? } + LogicalPlan::LateralTableFunction(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } } }