Skip to content

Commit ceeab4f

Browse files
committed
feat(datafusion-proto): allow TableSource to be serialized
Currently, only instances of `TableProvider` are considered by `LogicalExtensionCodec`, and are automatically wrapped in a `DefaultTableSource` when deserializing. That doesn't work with custom extensions. Fixes #16749
1 parent 8d7b11b commit ceeab4f

File tree

6 files changed

+93
-69
lines changed

6 files changed

+93
-69
lines changed

datafusion/catalog/src/default_table_source.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,26 @@ impl DefaultTableSource {
4545
pub fn new(table_provider: Arc<dyn TableProvider>) -> Self {
4646
Self { table_provider }
4747
}
48+
49+
/// Attempt to downcast a TableSource to DefaultTableSource and access the
50+
/// TableProvider. This will only work with a TableSource created by DataFusion.
51+
pub fn unwrap_provider<T: TableProvider + 'static>(
52+
source: &Arc<dyn TableSource>,
53+
) -> datafusion_common::Result<&T> {
54+
if let Some(source) = source
55+
.as_ref()
56+
.as_any()
57+
.downcast_ref::<DefaultTableSource>()
58+
{
59+
if let Some(provider) =
60+
source.table_provider.as_ref().as_any().downcast_ref::<T>()
61+
{
62+
return Ok(provider);
63+
}
64+
}
65+
66+
internal_err!("TableSource was not expected type")
67+
}
4868
}
4969

5070
impl TableSource for DefaultTableSource {
@@ -87,7 +107,7 @@ impl TableSource for DefaultTableSource {
87107
}
88108
}
89109

90-
/// Wrap TableProvider in TableSource
110+
/// Wrap a TableProvider as a TableSource.
91111
pub fn provider_as_source(
92112
table_provider: Arc<dyn TableProvider>,
93113
) -> Arc<dyn TableSource> {

datafusion/core/src/datasource/mod.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ pub mod provider;
3030
mod view_test;
3131

3232
// backwards compatibility
33-
pub use self::default_table_source::{
34-
provider_as_source, source_as_provider, DefaultTableSource,
35-
};
33+
pub use self::default_table_source::{provider_as_source, DefaultTableSource};
3634
pub use self::memory::MemTable;
3735
pub use self::view::ViewTable;
3836
pub use crate::catalog::TableProvider;

datafusion/core/src/physical_planner.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use std::sync::Arc;
2424
use crate::datasource::file_format::file_type_to_format;
2525
use crate::datasource::listing::ListingTableUrl;
2626
use crate::datasource::physical_plan::FileSinkConfig;
27-
use crate::datasource::{source_as_provider, DefaultTableSource};
27+
use crate::datasource::DefaultTableSource;
2828
use crate::error::{DataFusionError, Result};
2929
use crate::execution::context::{ExecutionProps, SessionState};
3030
use crate::logical_expr::utils::generate_sort_key;
@@ -60,6 +60,7 @@ use crate::schema_equivalence::schema_satisfied_by;
6060
use arrow::array::{builder::StringBuilder, RecordBatch};
6161
use arrow::compute::SortOptions;
6262
use arrow::datatypes::{Schema, SchemaRef};
63+
use datafusion_catalog::default_table_source::source_as_provider;
6364
use datafusion_common::display::ToStringifiedPlan;
6465
use datafusion_common::tree_node::{
6566
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,

datafusion/proto/src/logical_plan/file_formats.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -180,20 +180,20 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
180180
not_impl_err!("Method not implemented")
181181
}
182182

183-
fn try_decode_table_provider(
183+
fn try_decode_table_source(
184184
&self,
185185
_buf: &[u8],
186186
_table_ref: &TableReference,
187187
_schema: arrow::datatypes::SchemaRef,
188188
_ctx: &SessionContext,
189-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
189+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
190190
not_impl_err!("Method not implemented")
191191
}
192192

193-
fn try_encode_table_provider(
193+
fn try_encode_table_source(
194194
&self,
195195
_table_ref: &TableReference,
196-
_node: Arc<dyn datafusion::datasource::TableProvider>,
196+
_node: Arc<dyn datafusion_expr::TableSource>,
197197
_buf: &mut Vec<u8>,
198198
) -> datafusion_common::Result<()> {
199199
not_impl_err!("Method not implemented")
@@ -287,20 +287,20 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {
287287
not_impl_err!("Method not implemented")
288288
}
289289

290-
fn try_decode_table_provider(
290+
fn try_decode_table_source(
291291
&self,
292292
_buf: &[u8],
293293
_table_ref: &TableReference,
294294
_schema: arrow::datatypes::SchemaRef,
295295
_ctx: &SessionContext,
296-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
296+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
297297
not_impl_err!("Method not implemented")
298298
}
299299

300-
fn try_encode_table_provider(
300+
fn try_encode_table_source(
301301
&self,
302302
_table_ref: &TableReference,
303-
_node: Arc<dyn datafusion::datasource::TableProvider>,
303+
_node: Arc<dyn datafusion_expr::TableSource>,
304304
_buf: &mut Vec<u8>,
305305
) -> datafusion_common::Result<()> {
306306
not_impl_err!("Method not implemented")
@@ -603,20 +603,20 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec {
603603
not_impl_err!("Method not implemented")
604604
}
605605

606-
fn try_decode_table_provider(
606+
fn try_decode_table_source(
607607
&self,
608608
_buf: &[u8],
609609
_table_ref: &TableReference,
610610
_schema: arrow::datatypes::SchemaRef,
611611
_ctx: &SessionContext,
612-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
612+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
613613
not_impl_err!("Method not implemented")
614614
}
615615

616-
fn try_encode_table_provider(
616+
fn try_encode_table_source(
617617
&self,
618618
_table_ref: &TableReference,
619-
_node: Arc<dyn datafusion::datasource::TableProvider>,
619+
_node: Arc<dyn datafusion_expr::TableSource>,
620620
_buf: &mut Vec<u8>,
621621
) -> datafusion_common::Result<()> {
622622
not_impl_err!("Method not implemented")
@@ -689,20 +689,20 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {
689689
not_impl_err!("Method not implemented")
690690
}
691691

692-
fn try_decode_table_provider(
692+
fn try_decode_table_source(
693693
&self,
694694
_buf: &[u8],
695695
_table_ref: &TableReference,
696696
_schema: arrow::datatypes::SchemaRef,
697697
_ctx: &SessionContext,
698-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
698+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
699699
not_impl_err!("Method not implemented")
700700
}
701701

702-
fn try_encode_table_provider(
702+
fn try_encode_table_source(
703703
&self,
704704
_table_ref: &TableReference,
705-
_node: Arc<dyn datafusion::datasource::TableProvider>,
705+
_node: Arc<dyn datafusion_expr::TableSource>,
706706
_buf: &mut Vec<u8>,
707707
) -> datafusion_common::Result<()> {
708708
not_impl_err!("Method not implemented")
@@ -747,20 +747,20 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {
747747
not_impl_err!("Method not implemented")
748748
}
749749

750-
fn try_decode_table_provider(
750+
fn try_decode_table_source(
751751
&self,
752752
_buf: &[u8],
753753
_table_ref: &TableReference,
754754
_schema: arrow::datatypes::SchemaRef,
755755
_cts: &SessionContext,
756-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
756+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
757757
not_impl_err!("Method not implemented")
758758
}
759759

760-
fn try_encode_table_provider(
760+
fn try_encode_table_source(
761761
&self,
762762
_table_ref: &TableReference,
763-
_node: Arc<dyn datafusion::datasource::TableProvider>,
763+
_node: Arc<dyn datafusion_expr::TableSource>,
764764
_buf: &mut Vec<u8>,
765765
) -> datafusion_common::Result<()> {
766766
not_impl_err!("Method not implemented")

datafusion/proto/src/logical_plan/mod.rs

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ use datafusion::datasource::file_format::parquet::ParquetFormat;
4242
use datafusion::datasource::file_format::{
4343
file_type_to_format, format_as_file_type, FileFormatFactory,
4444
};
45+
use datafusion::datasource::DefaultTableSource;
4546
use datafusion::{
47+
datasource::provider_as_source,
4648
datasource::{
4749
file_format::{
4850
csv::CsvFormat, json::JsonFormat as OtherNdJsonFormat, FileFormat,
4951
},
5052
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
5153
view::ViewTable,
52-
TableProvider,
5354
},
54-
datasource::{provider_as_source, source_as_provider},
5555
prelude::SessionContext,
5656
};
5757
use datafusion_common::file_options::file_type::FileType;
@@ -117,18 +117,18 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
117117

118118
fn try_encode(&self, node: &Extension, buf: &mut Vec<u8>) -> Result<()>;
119119

120-
fn try_decode_table_provider(
120+
fn try_decode_table_source(
121121
&self,
122122
buf: &[u8],
123123
table_ref: &TableReference,
124124
schema: SchemaRef,
125125
ctx: &SessionContext,
126-
) -> Result<Arc<dyn TableProvider>>;
126+
) -> Result<Arc<dyn TableSource>>;
127127

128-
fn try_encode_table_provider(
128+
fn try_encode_table_source(
129129
&self,
130130
table_ref: &TableReference,
131-
node: Arc<dyn TableProvider>,
131+
node: Arc<dyn TableSource>,
132132
buf: &mut Vec<u8>,
133133
) -> Result<()>;
134134

@@ -192,20 +192,20 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec {
192192
not_impl_err!("LogicalExtensionCodec is not provided")
193193
}
194194

195-
fn try_decode_table_provider(
195+
fn try_decode_table_source(
196196
&self,
197197
_buf: &[u8],
198198
_table_ref: &TableReference,
199199
_schema: SchemaRef,
200200
_ctx: &SessionContext,
201-
) -> Result<Arc<dyn TableProvider>> {
201+
) -> Result<Arc<dyn TableSource>> {
202202
not_impl_err!("LogicalExtensionCodec is not provided")
203203
}
204204

205-
fn try_encode_table_provider(
205+
fn try_encode_table_source(
206206
&self,
207207
_table_ref: &TableReference,
208-
_node: Arc<dyn TableProvider>,
208+
_node: Arc<dyn TableSource>,
209209
_buf: &mut Vec<u8>,
210210
) -> Result<()> {
211211
not_impl_err!("LogicalExtensionCodec is not provided")
@@ -439,7 +439,7 @@ impl AsLogicalPlan for LogicalPlanNode {
439439
}
440440
#[cfg_attr(not(feature = "avro"), allow(unused_variables))]
441441
FileFormatType::Avro(..) => {
442-
#[cfg(feature = "avro")]
442+
#[cfg(feature = "avro")]
443443
{
444444
Arc::new(AvroFormat)
445445
}
@@ -519,18 +519,15 @@ impl AsLogicalPlan for LogicalPlanNode {
519519
let table_name =
520520
from_table_reference(scan.table_name.as_ref(), "CustomScan")?;
521521

522-
let provider = extension_codec.try_decode_table_provider(
522+
let source = extension_codec.try_decode_table_source(
523523
&scan.custom_table_data,
524524
&table_name,
525525
schema,
526526
ctx,
527527
)?;
528528

529529
LogicalPlanBuilder::scan_with_filters(
530-
table_name,
531-
provider_as_source(provider),
532-
projection,
533-
filters,
530+
table_name, source, projection, filters,
534531
)?
535532
.build()
536533
}
@@ -1001,9 +998,7 @@ impl AsLogicalPlan for LogicalPlanNode {
1001998
projection,
1002999
..
10031000
}) => {
1004-
let provider = source_as_provider(source)?;
1005-
let schema = provider.schema();
1006-
let source = provider.as_any();
1001+
let schema = source.schema();
10071002

10081003
let projection = match projection {
10091004
None => None,
@@ -1021,7 +1016,9 @@ impl AsLogicalPlan for LogicalPlanNode {
10211016
let filters: Vec<protobuf::LogicalExprNode> =
10221017
serialize_exprs(filters, extension_codec)?;
10231018

1024-
if let Some(listing_table) = source.downcast_ref::<ListingTable>() {
1019+
if let Ok(listing_table) =
1020+
DefaultTableSource::unwrap_provider::<ListingTable>(source)
1021+
{
10251022
let any = listing_table.options().format.as_any();
10261023
let file_format_type = {
10271024
let mut maybe_some_type = None;
@@ -1130,7 +1127,9 @@ impl AsLogicalPlan for LogicalPlanNode {
11301127
},
11311128
)),
11321129
})
1133-
} else if let Some(view_table) = source.downcast_ref::<ViewTable>() {
1130+
} else if let Ok(view_table) =
1131+
DefaultTableSource::unwrap_provider::<ViewTable>(source)
1132+
{
11341133
let schema: protobuf::Schema = schema.as_ref().try_into()?;
11351134
Ok(LogicalPlanNode {
11361135
logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new(
@@ -1151,7 +1150,8 @@ impl AsLogicalPlan for LogicalPlanNode {
11511150
},
11521151
))),
11531152
})
1154-
} else if let Some(cte_work_table) = source.downcast_ref::<CteWorkTable>()
1153+
} else if let Ok(cte_work_table) =
1154+
DefaultTableSource::unwrap_provider::<CteWorkTable>(source)
11551155
{
11561156
let name = cte_work_table.name().to_string();
11571157
let schema = cte_work_table.schema();
@@ -1169,7 +1169,11 @@ impl AsLogicalPlan for LogicalPlanNode {
11691169
let schema: protobuf::Schema = schema.as_ref().try_into()?;
11701170
let mut bytes = vec![];
11711171
extension_codec
1172-
.try_encode_table_provider(table_name, provider, &mut bytes)
1172+
.try_encode_table_source(
1173+
table_name,
1174+
Arc::clone(source),
1175+
&mut bytes,
1176+
)
11731177
.map_err(|e| context!("Error serializing custom table", e))?;
11741178
let scan = CustomScan(CustomTableScanNode {
11751179
table_name: Some(table_name.clone().into()),

0 commit comments

Comments
 (0)