Commit eabf3b7

Add example of custom file schema casting rules (#16803)

1 parent 50e6114 commit eabf3b7

File tree

4 files changed: +221 -18 lines changed

datafusion-examples/Cargo.toml
datafusion-examples/examples/custom_file_casts.rs
datafusion/core/src/datasource/listing/table.rs
datafusion/core/tests/parquet/schema_adapter.rs

datafusion-examples/Cargo.toml

Lines changed: 4 additions & 0 deletions

@@ -52,6 +52,10 @@ path = "examples/external_dependency/dataframe-to-s3.rs"
 name = "query_aws_s3"
 path = "examples/external_dependency/query-aws-s3.rs"
 
+[[example]]
+name = "custom_file_casts"
+path = "examples/custom_file_casts.rs"
+
 [dev-dependencies]
 arrow = { workspace = true }
 # arrow_schema is required for record_batch! macro :sad:
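With this registration in place, the new example should be runnable the same way as the other examples in this crate, e.g. `cargo run --example custom_file_casts` from the `datafusion-examples` directory (this is the standard Cargo examples workflow, not something added by this diff).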
datafusion-examples/examples/custom_file_casts.rs

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::{record_batch, RecordBatch};
use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};

use datafusion::assert_batches_eq;
use datafusion::common::not_impl_err;
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::{Result, ScalarValue};
use datafusion::datasource::listing::{
    ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::execution::context::SessionContext;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::physical_expr::expressions::CastExpr;
use datafusion::physical_expr::schema_rewriter::{
    DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::prelude::SessionConfig;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, PutPayload};

// Example showing how to implement custom casting rules to adapt file schemas.
// This example enforces that casts must be strictly widening: if the file type
// is Int64 and the table type is Int32, it will error before even reading the data.
// Without this custom cast rule DataFusion would happily do the narrowing cast,
// potentially erroring only if it found a row with data it could not cast.

#[tokio::main]
async fn main() -> Result<()> {
    println!("=== Creating example data ===");

    // Create a logical / table schema with an Int32 column
    let logical_schema =
        Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

    // Create some data that can be cast (Int16 -> Int32 is widening)
    // and some that cannot (Int64 -> Int32 is narrowing)
    let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
    let path = Path::from("good.parquet");
    let batch = record_batch!(("id", Int16, [1, 2, 3]))?;
    write_data(&store, &path, &batch).await?;
    let path = Path::from("bad.parquet");
    let batch = record_batch!(("id", Int64, [1, 2, 3]))?;
    write_data(&store, &path, &batch).await?;

    // Set up query execution
    let mut cfg = SessionConfig::new();
    // Turn on filter pushdown so that the PhysicalExprAdapter is used
    cfg.options_mut().execution.parquet.pushdown_filters = true;
    let ctx = SessionContext::new_with_config(cfg);
    ctx.runtime_env()
        .register_object_store(ObjectStoreUrl::parse("memory://")?.as_ref(), store);

    // Register our good and bad files via ListingTable
    let listing_table_config =
        ListingTableConfig::new(ListingTableUrl::parse("memory:///good.parquet")?)
            .infer_options(&ctx.state())
            .await?
            .with_schema(Arc::clone(&logical_schema))
            .with_expr_adapter_factory(Arc::new(
                CustomCastPhysicalExprAdapterFactory::new(Arc::new(
                    DefaultPhysicalExprAdapterFactory,
                )),
            ));
    let table = ListingTable::try_new(listing_table_config).unwrap();
    ctx.register_table("good_table", Arc::new(table))?;
    let listing_table_config =
        ListingTableConfig::new(ListingTableUrl::parse("memory:///bad.parquet")?)
            .infer_options(&ctx.state())
            .await?
            .with_schema(Arc::clone(&logical_schema))
            .with_expr_adapter_factory(Arc::new(
                CustomCastPhysicalExprAdapterFactory::new(Arc::new(
                    DefaultPhysicalExprAdapterFactory,
                )),
            ));
    let table = ListingTable::try_new(listing_table_config).unwrap();
    ctx.register_table("bad_table", Arc::new(table))?;

    println!("\n=== File with narrower schema is cast ===");
    let query = "SELECT id FROM good_table WHERE id > 1";
    println!("Query: {query}");
    let batches = ctx.sql(query).await?.collect().await?;
    #[rustfmt::skip]
    let expected = [
        "+----+",
        "| id |",
        "+----+",
        "| 2  |",
        "| 3  |",
        "+----+",
    ];
    arrow::util::pretty::print_batches(&batches)?;
    assert_batches_eq!(expected, &batches);

    println!("\n=== File with wider schema errors ===");
    let query = "SELECT id FROM bad_table WHERE id > 1";
    println!("Query: {query}");
    match ctx.sql(query).await?.collect().await {
        Ok(_) => panic!("Expected error for narrowing cast, but query succeeded"),
        Err(e) => {
            println!("Caught expected error: {e}");
        }
    }
    Ok(())
}

async fn write_data(
    store: &dyn ObjectStore,
    path: &Path,
    batch: &RecordBatch,
) -> Result<()> {
    let mut buf = vec![];
    let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?;
    writer.write(batch)?;
    writer.close()?;

    let payload = PutPayload::from_bytes(buf.into());
    store.put(path, payload).await?;
    Ok(())
}

/// Factory for creating CustomCastsPhysicalExprAdapter instances
#[derive(Debug)]
struct CustomCastPhysicalExprAdapterFactory {
    inner: Arc<dyn PhysicalExprAdapterFactory>,
}

impl CustomCastPhysicalExprAdapterFactory {
    fn new(inner: Arc<dyn PhysicalExprAdapterFactory>) -> Self {
        Self { inner }
    }
}

impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory {
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter> {
        let inner = self
            .inner
            .create(logical_file_schema, Arc::clone(&physical_file_schema));
        Arc::new(CustomCastsPhysicalExprAdapter {
            physical_file_schema,
            inner,
        })
    }
}

/// Custom PhysicalExprAdapter that rejects non-widening casts
/// and wraps the default adapter for standard schema adaptation
#[derive(Debug, Clone)]
struct CustomCastsPhysicalExprAdapter {
    physical_file_schema: SchemaRef,
    inner: Arc<dyn PhysicalExprAdapter>,
}

impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter {
    fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
        // First delegate to the inner adapter to handle missing columns and
        // discover any necessary casts
        expr = self.inner.rewrite(expr)?;
        // Now we can apply custom casting rules or even swap out all CastExprs
        // for a custom cast kernel / expression.
        // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet)
        // has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138).
        expr.transform(|expr| {
            if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
                let input_data_type = cast.expr().data_type(&self.physical_file_schema)?;
                let output_data_type = cast.data_type(&self.physical_file_schema)?;
                if !cast.is_bigger_cast(&input_data_type) {
                    return not_impl_err!(
                        "Unsupported CAST from {input_data_type:?} to {output_data_type:?}"
                    );
                }
            }
            Ok(Transformed::no(expr))
        })
        .data()
    }

    fn with_partition_values(
        &self,
        partition_values: Vec<(FieldRef, ScalarValue)>,
    ) -> Arc<dyn PhysicalExprAdapter> {
        Arc::new(Self {
            inner: self.inner.with_partition_values(partition_values),
            ..self.clone()
        })
    }
}
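The comment in `rewrite` above notes that an adapter could also swap every `CastExpr` for a custom cast expression rather than rejecting it. A minimal sketch of that variant, reusing the imports from the example; `MyCastExpr` is a hypothetical type (not part of this commit or DataFusion) that would need its own `PhysicalExpr` implementation:

fn replace_casts(expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
    expr.transform(|e| {
        if let Some(cast) = e.as_any().downcast_ref::<CastExpr>() {
            // Swap the built-in cast for a custom cast expression
            // (`MyCastExpr` is assumed here for illustration only)
            let replacement: Arc<dyn PhysicalExpr> = Arc::new(MyCastExpr::new(
                Arc::clone(cast.expr()),
                cast.cast_type().clone(),
            ));
            return Ok(Transformed::yes(replacement));
        }
        Ok(Transformed::no(e))
    })
    .data()
}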

datafusion/core/src/datasource/listing/table.rs

Lines changed: 9 additions & 9 deletions

@@ -101,7 +101,7 @@ pub struct ListingTableConfig {
     /// Optional [`SchemaAdapterFactory`] for creating schema adapters
     schema_adapter_factory: Option<Arc<dyn SchemaAdapterFactory>>,
     /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters
-    physical_expr_adapter_factory: Option<Arc<dyn PhysicalExprAdapterFactory>>,
+    expr_adapter_factory: Option<Arc<dyn PhysicalExprAdapterFactory>>,
 }
 
 impl ListingTableConfig {
@@ -284,7 +284,7 @@
             options: Some(listing_options),
             schema_source: self.schema_source,
             schema_adapter_factory: self.schema_adapter_factory,
-            physical_expr_adapter_factory: self.physical_expr_adapter_factory,
+            expr_adapter_factory: self.expr_adapter_factory,
         })
     }
 
@@ -304,7 +304,7 @@
             options: _,
             schema_source,
             schema_adapter_factory,
-            physical_expr_adapter_factory,
+            expr_adapter_factory: physical_expr_adapter_factory,
         } = self;
 
         let (schema, new_schema_source) = match file_schema {
@@ -327,7 +327,7 @@
                     options: Some(options),
                     schema_source: new_schema_source,
                     schema_adapter_factory,
-                    physical_expr_adapter_factory,
+                    expr_adapter_factory: physical_expr_adapter_factory,
                 })
             }
             None => internal_err!("No `ListingOptions` set for inferring schema"),
@@ -370,7 +370,7 @@
                 options: Some(options),
                 schema_source: self.schema_source,
                 schema_adapter_factory: self.schema_adapter_factory,
-                physical_expr_adapter_factory: self.physical_expr_adapter_factory,
+                expr_adapter_factory: self.expr_adapter_factory,
             })
         }
         None => config_err!("No `ListingOptions` set for inferring schema"),
@@ -433,12 +433,12 @@
     /// `SchemaAdapterFactory` is set, in which case only the `SchemaAdapterFactory` will be used.
     ///
    /// See <https://github.com/apache/datafusion/issues/16800> for details on this transition.
-    pub fn with_physical_expr_adapter_factory(
+    pub fn with_expr_adapter_factory(
         self,
-        physical_expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
+        expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
     ) -> Self {
         Self {
-            physical_expr_adapter_factory: Some(physical_expr_adapter_factory),
+            expr_adapter_factory: Some(expr_adapter_factory),
             ..self
         }
     }
@@ -981,7 +981,7 @@
             constraints: Constraints::default(),
             column_defaults: HashMap::new(),
             schema_adapter_factory: config.schema_adapter_factory,
-            expr_adapter_factory: config.physical_expr_adapter_factory,
+            expr_adapter_factory: config.expr_adapter_factory,
         };
 
         Ok(table)
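For downstream code, the rename is mechanical. A sketch of updating a call site, assuming an existing `listing_table_config` builder chain (the variable name is hypothetical; the method and factory names are from this diff):

// Before this commit:
// .with_physical_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory))
// After this commit:
let listing_table_config = listing_table_config
    .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory));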

datafusion/core/tests/parquet/schema_adapter.rs

Lines changed: 3 additions & 9 deletions

@@ -264,9 +264,7 @@ async fn test_custom_schema_adapter_and_custom_expression_adapter() {
         .unwrap()
         .with_schema(table_schema.clone())
         .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory))
-        .with_physical_expr_adapter_factory(Arc::new(
-            DefaultPhysicalExprAdapterFactory,
-        ));
+        .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory));
 
     let table = ListingTable::try_new(listing_table_config).unwrap();
     ctx.register_table("t", Arc::new(table)).unwrap();
@@ -324,9 +322,7 @@
         .await
         .unwrap()
         .with_schema(table_schema.clone())
-        .with_physical_expr_adapter_factory(Arc::new(
-            CustomPhysicalExprAdapterFactory,
-        ));
+        .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory));
     let table = ListingTable::try_new(listing_table_config).unwrap();
     ctx.deregister_table("t").unwrap();
     ctx.register_table("t", Arc::new(table)).unwrap();
@@ -354,9 +350,7 @@
         .unwrap()
         .with_schema(table_schema.clone())
         .with_schema_adapter_factory(Arc::new(CustomSchemaAdapterFactory))
-        .with_physical_expr_adapter_factory(Arc::new(
-            CustomPhysicalExprAdapterFactory,
-        ));
+        .with_expr_adapter_factory(Arc::new(CustomPhysicalExprAdapterFactory));
     let table = ListingTable::try_new(listing_table_config).unwrap();
     ctx.deregister_table("t").unwrap();
     ctx.register_table("t", Arc::new(table)).unwrap();
