Merged
4 changes: 4 additions & 0 deletions datafusion-examples/Cargo.toml
@@ -52,6 +52,10 @@ path = "examples/external_dependency/dataframe-to-s3.rs"
name = "query_aws_s3"
path = "examples/external_dependency/query-aws-s3.rs"

[[example]]
name = "custom_file_casts"
path = "examples/custom_file_casts.rs"

[dev-dependencies]
arrow = { workspace = true }
# arrow_schema is required for record_batch! macro :sad:
205 changes: 205 additions & 0 deletions datafusion-examples/examples/custom_file_casts.rs
@@ -0,0 +1,205 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::{record_batch, RecordBatch};
use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};

use datafusion::assert_batches_eq;
use datafusion::common::not_impl_err;
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::{Result, ScalarValue};
use datafusion::datasource::listing::{
ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::execution::context::SessionContext;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::physical_expr::expressions::CastExpr;
use datafusion::physical_expr::schema_rewriter::{
DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::prelude::SessionConfig;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, PutPayload};

// Example showing how to implement custom casting rules to adapt file schemas.
// This example enforces that casts must be strictly widening: if the file type is Int64 and the table type is Int32, it will error
// before even reading the data.
// Without this custom cast rule DataFusion would happily do the narrowing cast, potentially erroring only if it found a row with data it could not cast.

#[tokio::main]
async fn main() -> Result<()> {
println!("=== Creating example data ===");

// Create a logical / table schema with an Int32 column
let logical_schema =
Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

// Create some data that can be cast (Int16 -> Int32 is widening) and some that cannot (Int64 -> Int32 is narrowing)
let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
let path = Path::from("good.parquet");
let batch = record_batch!(("id", Int16, [1, 2, 3]))?;
write_data(&store, &path, &batch).await?;
let path = Path::from("bad.parquet");
let batch = record_batch!(("id", Int64, [1, 2, 3]))?;
write_data(&store, &path, &batch).await?;

// Set up query execution
let mut cfg = SessionConfig::new();
// Turn on filter pushdown so that the PhysicalExprAdapter is used
cfg.options_mut().execution.parquet.pushdown_filters = true;
let ctx = SessionContext::new_with_config(cfg);
ctx.runtime_env()
.register_object_store(ObjectStoreUrl::parse("memory://")?.as_ref(), store);

// Register our good and bad files via ListingTable
let listing_table_config =
ListingTableConfig::new(ListingTableUrl::parse("memory:///good.parquet")?)
.infer_options(&ctx.state())
.await?
.with_schema(Arc::clone(&logical_schema))
.with_expr_adapter_factory(Arc::new(
CustomCastPhysicalExprAdapterFactory::new(Arc::new(
DefaultPhysicalExprAdapterFactory,
)),
));
let table = ListingTable::try_new(listing_table_config).unwrap();
ctx.register_table("good_table", Arc::new(table))?;
let listing_table_config =
ListingTableConfig::new(ListingTableUrl::parse("memory:///bad.parquet")?)
.infer_options(&ctx.state())
.await?
.with_schema(Arc::clone(&logical_schema))
.with_expr_adapter_factory(Arc::new(
CustomCastPhysicalExprAdapterFactory::new(Arc::new(
DefaultPhysicalExprAdapterFactory,
)),
));
let table = ListingTable::try_new(listing_table_config).unwrap();
ctx.register_table("bad_table", Arc::new(table))?;

println!("\n=== File with narrower schema is cast ===");
let query = "SELECT id FROM good_table WHERE id > 1";
println!("Query: {query}");
let batches = ctx.sql(query).await?.collect().await?;
#[rustfmt::skip]
let expected = [
"+----+",
"| id |",
"+----+",
"| 2 |",
"| 3 |",
"+----+",
];
arrow::util::pretty::print_batches(&batches)?;
assert_batches_eq!(expected, &batches);

println!("\n=== File with wider schema errors ===");
let query = "SELECT id FROM bad_table WHERE id > 1";
println!("Query: {query}");
match ctx.sql(query).await?.collect().await {
Ok(_) => panic!("Expected error for narrowing cast, but query succeeded"),
Err(e) => {
println!("Caught expected error: {e}");
}
}
Ok(())
}

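/// Write a single RecordBatch to the given object store path as a Parquet file.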
async fn write_data(
store: &dyn ObjectStore,
path: &Path,
batch: &RecordBatch,
) -> Result<()> {
let mut buf = vec![];
let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?;
writer.write(batch)?;
writer.close()?;

let payload = PutPayload::from_bytes(buf.into());
store.put(path, payload).await?;
Ok(())
}

/// Factory for creating CustomCastsPhysicalExprAdapter instances
#[derive(Debug)]
struct CustomCastPhysicalExprAdapterFactory {
inner: Arc<dyn PhysicalExprAdapterFactory>,
}

impl CustomCastPhysicalExprAdapterFactory {
fn new(inner: Arc<dyn PhysicalExprAdapterFactory>) -> Self {
Self { inner }
}
}

impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory {
fn create(
&self,
logical_file_schema: SchemaRef,
physical_file_schema: SchemaRef,
) -> Arc<dyn PhysicalExprAdapter> {
let inner = self
.inner
.create(logical_file_schema, Arc::clone(&physical_file_schema));
Arc::new(CustomCastsPhysicalExprAdapter {
physical_file_schema,
inner,
})
}
}

/// Custom PhysicalExprAdapter that enforces strictly widening casts and wraps an inner
/// adapter (here the DefaultPhysicalExprAdapter) for standard schema adaptation
#[derive(Debug, Clone)]
struct CustomCastsPhysicalExprAdapter {
physical_file_schema: SchemaRef,
inner: Arc<dyn PhysicalExprAdapter>,
}

impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter {
fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
// First delegate to the inner adapter to handle missing columns and discover any necessary casts
expr = self.inner.rewrite(expr)?;
// Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression
// For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138).
expr.transform(|expr| {
if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
let input_data_type = cast.expr().data_type(&self.physical_file_schema)?;
let output_data_type = cast.data_type(&self.physical_file_schema)?;
if !cast.is_bigger_cast(&input_data_type) {
return not_impl_err!("Unsupported CAST from {input_data_type:?} to {output_data_type:?}")
}
}
Ok(Transformed::no(expr))
}).data()
}

fn with_partition_values(
&self,
partition_values: Vec<(FieldRef, ScalarValue)>,
) -> Arc<dyn PhysicalExprAdapter> {
Arc::new(Self {
inner: self.inner.with_partition_values(partition_values),
..self.clone()
})
}
}
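Assuming the `[[example]]` entry added to `datafusion-examples/Cargo.toml` above, the example should be runnable with something like `cargo run -p datafusion-examples --example custom_file_casts`. At a high level, the inner (default) adapter first inserts a `CastExpr` from the file type to the table type, and the custom pass then rejects any cast that is not strictly widening.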
18 changes: 9 additions & 9 deletions datafusion/core/src/datasource/listing/table.rs
@@ -101,7 +101,7 @@ pub struct ListingTableConfig {
/// Optional [`SchemaAdapterFactory`] for creating schema adapters
schema_adapter_factory: Option<Arc<dyn SchemaAdapterFactory>>,
/// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters
physical_expr_adapter_factory: Option<Arc<dyn PhysicalExprAdapterFactory>>,
expr_adapter_factory: Option<Arc<dyn PhysicalExprAdapterFactory>>,
Contributor: this rename might cause api_change label along with Upgrade Guide changes, I would stick to just example in this PR?

Contributor Author: This API was added yesterday in #16791

}

impl ListingTableConfig {
@@ -284,7 +284,7 @@ impl ListingTableConfig {
options: Some(listing_options),
schema_source: self.schema_source,
schema_adapter_factory: self.schema_adapter_factory,
physical_expr_adapter_factory: self.physical_expr_adapter_factory,
expr_adapter_factory: self.expr_adapter_factory,
})
}

@@ -304,7 +304,7 @@ impl ListingTableConfig {
options: _,
schema_source,
schema_adapter_factory,
physical_expr_adapter_factory,
expr_adapter_factory: physical_expr_adapter_factory,
} = self;

let (schema, new_schema_source) = match file_schema {
@@ -327,7 +327,7 @@ impl ListingTableConfig {
options: Some(options),
schema_source: new_schema_source,
schema_adapter_factory,
physical_expr_adapter_factory,
expr_adapter_factory: physical_expr_adapter_factory,
})
}
None => internal_err!("No `ListingOptions` set for inferring schema"),
@@ -370,7 +370,7 @@ impl ListingTableConfig {
options: Some(options),
schema_source: self.schema_source,
schema_adapter_factory: self.schema_adapter_factory,
physical_expr_adapter_factory: self.physical_expr_adapter_factory,
expr_adapter_factory: self.expr_adapter_factory,
})
}
None => config_err!("No `ListingOptions` set for inferring schema"),
@@ -433,12 +433,12 @@ impl ListingTableConfig {
/// `SchemaAdapterFactory` is set, in which case only the `SchemaAdapterFactory` will be used.
///
/// See <https://github.com/apache/datafusion/issues/16800> for details on this transition.
pub fn with_physical_expr_adapter_factory(
pub fn with_expr_adapter_factory(
Contributor Author: In writing the example I noticed that this name is very long and probably too verbose. This tones it down and brings it in line with the field name.

Contributor: I love this explanation -- I have found trying to write out examples of using APIs almost always leads to discoveries of ways to improve them for the better. Thank you!

self,
physical_expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
) -> Self {
Self {
physical_expr_adapter_factory: Some(physical_expr_adapter_factory),
expr_adapter_factory: Some(expr_adapter_factory),
..self
}
}
@@ -981,7 +981,7 @@ impl ListingTable {
constraints: Constraints::default(),
column_defaults: HashMap::new(),
schema_adapter_factory: config.schema_adapter_factory,
expr_adapter_factory: config.physical_expr_adapter_factory,
expr_adapter_factory: config.expr_adapter_factory,
};

Ok(table)
12 changes: 3 additions & 9 deletions datafusion/core/tests/parquet/schema_adapter.rs
@@ -264,9 +264,7 @@ async fn test_custom_schema_adapter_and_custom_expression_adapter() {
.unwrap()
.with_schema(table_schema.clone())
.with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory))
.with_physical_expr_adapter_factory(Arc::new(
DefaultPhysicalExprAdapterFactory,
));
.with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory));

let table = ListingTable::try_new(listing_table_config).unwrap();
ctx.register_table("t", Arc::new(table)).unwrap();
@@ -324,9 +322,7 @@ async fn test_custom_schema_adapter_and_custom_expression_adapter() {
.await
.unwrap()
.with_schema(table_schema.clone())
.with_physical_expr_adapter_factory(Arc::new(
CustomPhysicalExprAdapterFactory,
));
.with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory));
let table = ListingTable::try_new(listing_table_config).unwrap();
ctx.deregister_table("t").unwrap();
ctx.register_table("t", Arc::new(table)).unwrap();
@@ -354,9 +350,7 @@ async fn test_custom_schema_adapter_and_custom_expression_adapter() {
.unwrap()
.with_schema(table_schema.clone())
.with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory))
.with_physical_expr_adapter_factory(Arc::new(
CustomPhysicalExprAdapterFactory,
));
.with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory));
let table = ListingTable::try_new(listing_table_config).unwrap();
ctx.deregister_table("t").unwrap();
ctx.register_table("t", Arc::new(table)).unwrap();