diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 979ada2bc6bb..e44e77831264 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -5160,7 +5160,7 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> { assert_snapshot!( actual, @r" - Projection: Int32(3) AS $1 [$1:Null;N] + Projection: Int32(3) AS $1 [$1:Int32] EmptyRelation: rows=1 [] " ); diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 2eb3ba36dd90..4b6d3241d64b 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -18,7 +18,6 @@ use std::collections::HashMap; use super::*; -use datafusion::assert_batches_eq; use datafusion_common::{metadata::ScalarAndMetadata, ParamValues, ScalarValue}; use insta::assert_snapshot; @@ -343,26 +342,53 @@ async fn test_query_parameters_with_metadata() -> Result<()> { ])) .unwrap(); - // df_with_params_replaced.schema() is not correct here - // https://github.com/apache/datafusion/issues/18102 - let batches = df_with_params_replaced.clone().collect().await.unwrap(); - let schema = batches[0].schema(); - + let schema = df_with_params_replaced.schema(); assert_eq!(schema.field(0).data_type(), &DataType::UInt32); assert_eq!(schema.field(0).metadata(), &metadata1); assert_eq!(schema.field(1).data_type(), &DataType::Utf8); assert_eq!(schema.field(1).metadata(), &metadata2); - assert_batches_eq!( - [ - "+----+-----+", - "| $1 | $2 |", - "+----+-----+", - "| 1 | two |", - "+----+-----+", - ], - &batches - ); + let batches = df_with_params_replaced.collect().await.unwrap(); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-----+ + | $1 | $2 | + +----+-----+ + | 1 | two | + +----+-----+ + "); + + Ok(()) +} + +/// Test for https://github.com/apache/datafusion/issues/18102 +#[tokio::test] +async fn test_query_parameters_in_values_list_relation() -> Result<()> { + let ctx = 
SessionContext::new(); + + let df = ctx + .sql("SELECT a, b FROM (VALUES ($1, $2)) AS t(a, b)") + .await + .unwrap(); + + let df_with_params_replaced = df + .with_param_values(ParamValues::List(vec![ + ScalarAndMetadata::new(ScalarValue::UInt32(Some(1)), None), + ScalarAndMetadata::new(ScalarValue::Utf8(Some("two".to_string())), None), + ])) + .unwrap(); + + let schema = df_with_params_replaced.schema(); + assert_eq!(schema.field(0).data_type(), &DataType::UInt32); + assert_eq!(schema.field(1).data_type(), &DataType::Utf8); + + let batches = df_with_params_replaced.collect().await.unwrap(); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +---+-----+ + | a | b | + +---+-----+ + | 1 | two | + +---+-----+ + "); Ok(()) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9541f35e3062..1c12e0645eae 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -633,9 +633,9 @@ impl LogicalPlan { }) => Projection::try_new(expr, input).map(LogicalPlan::Projection), LogicalPlan::Dml(_) => Ok(self), LogicalPlan::Copy(_) => Ok(self), - LogicalPlan::Values(Values { schema, values }) => { - // todo it isn't clear why the schema is not recomputed here - Ok(LogicalPlan::Values(Values { schema, values })) + LogicalPlan::Values(Values { values, schema: _ }) => { + // Drop the old schema and rebuild via the builder so it is derived from the current `values` + LogicalPlanBuilder::values(values)?.build() } LogicalPlan::Filter(Filter { predicate, input }) => { Filter::try_new(predicate, input).map(LogicalPlan::Filter) @@ -1451,7 +1451,7 @@ impl LogicalPlan { self.transform_up_with_subqueries(|plan| { let schema = Arc::clone(plan.schema()); let name_preserver = NamePreserver::new(&plan); - plan.map_expressions(|e| { + let transformed_plan = plan.map_expressions(|e| { let (e, has_placeholder) = e.infer_placeholder_types(&schema)?; if !has_placeholder { // Performance optimization: @@ -1473,7 +1473,16 @@ impl LogicalPlan { // Preserve name to 
avoid breaking column references to this expression Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) } - }) + }); + + // Expressions may have been rewritten above, so recompute this node's schema to match them. + // TODO: only recompute when the plan was actually transformed. + // TODO: confirm whether `recompute_schema` also works for children. + if let Ok(transformed_plan) = transformed_plan { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + transformed_plan + } }) .map(|res| res.data) } @@ -4246,6 +4255,7 @@ mod tests { use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; + use crate::select_expr::SelectExpr; use crate::test::function_stub::{count, count_udaf}; use crate::{ binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery, @@ -4824,6 +4834,71 @@ mod tests { .expect_err("prepared field metadata mismatch unexpectedly succeeded"); } + #[test] + fn test_replace_placeholder_values_relation_valid_schema() { + // SELECT a, b FROM (VALUES ($1, $2)) AS t(a, b); + let plan = + LogicalPlanBuilder::values(vec![vec![placeholder("$1"), placeholder("$2")]]) + .unwrap() + .project(vec![col("column1").alias("a"), col("column2").alias("b")]) + .unwrap() + .alias("t") + .unwrap() + .project(vec![col("a"), col("b")]) + .unwrap() + .build() + .unwrap(); + + // original: unbound placeholders give every column the Null type + assert_snapshot!(plan.display_indent_schema(), @r#" + Projection: t.a, t.b [a:Null;N, b:Null;N] + SubqueryAlias: t [a:Null;N, b:Null;N] + Projection: column1 AS a, column2 AS b [a:Null;N, b:Null;N] + Values: ($1, $2) [column1:Null;N, column2:Null;N] + "#); + + let plan = plan + .with_param_values(vec![ScalarValue::from(1i32), ScalarValue::from("s")]) + .unwrap(); + + // replaced: column types now follow the bound values (Int32, Utf8) + assert_snapshot!(plan.display_indent_schema(), @r#" + Projection: t.a, t.b [a:Int32;N, b:Utf8;N] + SubqueryAlias: t [a:Int32;N, b:Utf8;N] + Projection: column1 AS a, column2 AS b [a:Int32;N, b:Utf8;N] + Values: (Int32(1) AS $1, Utf8("s") AS $2) [column1:Int32;N, column2:Utf8;N] + "#); + } + + #[test] + fn 
test_replace_placeholder_empty_relation_valid_schema() { + // SELECT $1, $2; + let plan = LogicalPlanBuilder::empty(false) + .project(vec![ + SelectExpr::from(placeholder("$1")), + SelectExpr::from(placeholder("$2")), + ]) + .unwrap() + .build() + .unwrap(); + + // original: unbound placeholders give every column the Null type + assert_snapshot!(plan.display_indent_schema(), @r" + Projection: $1, $2 [$1:Null;N, $2:Null;N] + EmptyRelation: rows=0 [] + "); + + let plan = plan + .with_param_values(vec![ScalarValue::from(1i32), ScalarValue::from("s")]) + .unwrap(); + + // replaced: column types now follow the bound values (Int32, Utf8) + assert_snapshot!(plan.display_indent_schema(), @r#" + Projection: Int32(1) AS $1, Utf8("s") AS $2 [$1:Int32, $2:Utf8] + EmptyRelation: rows=0 [] + "#); + } + #[test] fn test_nullable_schema_after_grouping_set() { let schema = Schema::new(vec![