Skip to content

Commit f02ef93

Browse files
committed
feat(datafusion-proto): allow TableSource to be serialized
Currently, only instances of `TableProvider` are considered by `LogicalExtensionCodec`, and are automatically wrapped in a `DefaultTableSource` when deserializing. That doesn't work with custom extensions. Fixes #16749
1 parent f38f52f commit f02ef93

File tree

6 files changed

+93
-69
lines changed

6 files changed

+93
-69
lines changed

datafusion/catalog/src/default_table_source.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,26 @@ impl DefaultTableSource {
4545
pub fn new(table_provider: Arc<dyn TableProvider>) -> Self {
4646
Self { table_provider }
4747
}
48+
49+
/// Attempt to downcast a TableSource to DefaultTableSource and access the
50+
/// TableProvider. This will only work with a TableSource created by DataFusion.
51+
pub fn unwrap_provider<T: TableProvider + 'static>(
52+
source: &Arc<dyn TableSource>,
53+
) -> datafusion_common::Result<&T> {
54+
if let Some(source) = source
55+
.as_ref()
56+
.as_any()
57+
.downcast_ref::<DefaultTableSource>()
58+
{
59+
if let Some(provider) =
60+
source.table_provider.as_ref().as_any().downcast_ref::<T>()
61+
{
62+
return Ok(provider);
63+
}
64+
}
65+
66+
internal_err!("TableSource was not expected type")
67+
}
4868
}
4969

5070
impl TableSource for DefaultTableSource {
@@ -87,7 +107,7 @@ impl TableSource for DefaultTableSource {
87107
}
88108
}
89109

90-
/// Wrap TableProvider in TableSource
110+
/// Wrap a TableProvider as a TableSource.
91111
pub fn provider_as_source(
92112
table_provider: Arc<dyn TableProvider>,
93113
) -> Arc<dyn TableSource> {

datafusion/core/src/datasource/mod.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ pub mod provider;
3030
mod view_test;
3131

3232
// backwards compatibility
33-
pub use self::default_table_source::{
34-
provider_as_source, source_as_provider, DefaultTableSource,
35-
};
33+
pub use self::default_table_source::{provider_as_source, DefaultTableSource};
3634
pub use self::memory::MemTable;
3735
pub use self::view::ViewTable;
3836
pub use crate::catalog::TableProvider;

datafusion/core/src/physical_planner.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use std::sync::Arc;
2424
use crate::datasource::file_format::file_type_to_format;
2525
use crate::datasource::listing::ListingTableUrl;
2626
use crate::datasource::physical_plan::FileSinkConfig;
27-
use crate::datasource::{source_as_provider, DefaultTableSource};
27+
use crate::datasource::DefaultTableSource;
2828
use crate::error::{DataFusionError, Result};
2929
use crate::execution::context::{ExecutionProps, SessionState};
3030
use crate::logical_expr::utils::generate_sort_key;
@@ -60,6 +60,7 @@ use crate::schema_equivalence::schema_satisfied_by;
6060
use arrow::array::{builder::StringBuilder, RecordBatch};
6161
use arrow::compute::SortOptions;
6262
use arrow::datatypes::{Schema, SchemaRef};
63+
use datafusion_catalog::default_table_source::source_as_provider;
6364
use datafusion_common::display::ToStringifiedPlan;
6465
use datafusion_common::tree_node::{
6566
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,

datafusion/proto/src/logical_plan/file_formats.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -180,20 +180,20 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec {
180180
not_impl_err!("Method not implemented")
181181
}
182182

183-
fn try_decode_table_provider(
183+
fn try_decode_table_source(
184184
&self,
185185
_buf: &[u8],
186186
_table_ref: &TableReference,
187187
_schema: arrow::datatypes::SchemaRef,
188188
_ctx: &SessionContext,
189-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
189+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
190190
not_impl_err!("Method not implemented")
191191
}
192192

193-
fn try_encode_table_provider(
193+
fn try_encode_table_source(
194194
&self,
195195
_table_ref: &TableReference,
196-
_node: Arc<dyn datafusion::datasource::TableProvider>,
196+
_node: Arc<dyn datafusion_expr::TableSource>,
197197
_buf: &mut Vec<u8>,
198198
) -> datafusion_common::Result<()> {
199199
not_impl_err!("Method not implemented")
@@ -287,20 +287,20 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec {
287287
not_impl_err!("Method not implemented")
288288
}
289289

290-
fn try_decode_table_provider(
290+
fn try_decode_table_source(
291291
&self,
292292
_buf: &[u8],
293293
_table_ref: &TableReference,
294294
_schema: arrow::datatypes::SchemaRef,
295295
_ctx: &SessionContext,
296-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
296+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
297297
not_impl_err!("Method not implemented")
298298
}
299299

300-
fn try_encode_table_provider(
300+
fn try_encode_table_source(
301301
&self,
302302
_table_ref: &TableReference,
303-
_node: Arc<dyn datafusion::datasource::TableProvider>,
303+
_node: Arc<dyn datafusion_expr::TableSource>,
304304
_buf: &mut Vec<u8>,
305305
) -> datafusion_common::Result<()> {
306306
not_impl_err!("Method not implemented")
@@ -603,20 +603,20 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec {
603603
not_impl_err!("Method not implemented")
604604
}
605605

606-
fn try_decode_table_provider(
606+
fn try_decode_table_source(
607607
&self,
608608
_buf: &[u8],
609609
_table_ref: &TableReference,
610610
_schema: arrow::datatypes::SchemaRef,
611611
_ctx: &SessionContext,
612-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
612+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
613613
not_impl_err!("Method not implemented")
614614
}
615615

616-
fn try_encode_table_provider(
616+
fn try_encode_table_source(
617617
&self,
618618
_table_ref: &TableReference,
619-
_node: Arc<dyn datafusion::datasource::TableProvider>,
619+
_node: Arc<dyn datafusion_expr::TableSource>,
620620
_buf: &mut Vec<u8>,
621621
) -> datafusion_common::Result<()> {
622622
not_impl_err!("Method not implemented")
@@ -689,20 +689,20 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec {
689689
not_impl_err!("Method not implemented")
690690
}
691691

692-
fn try_decode_table_provider(
692+
fn try_decode_table_source(
693693
&self,
694694
_buf: &[u8],
695695
_table_ref: &TableReference,
696696
_schema: arrow::datatypes::SchemaRef,
697697
_ctx: &SessionContext,
698-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
698+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
699699
not_impl_err!("Method not implemented")
700700
}
701701

702-
fn try_encode_table_provider(
702+
fn try_encode_table_source(
703703
&self,
704704
_table_ref: &TableReference,
705-
_node: Arc<dyn datafusion::datasource::TableProvider>,
705+
_node: Arc<dyn datafusion_expr::TableSource>,
706706
_buf: &mut Vec<u8>,
707707
) -> datafusion_common::Result<()> {
708708
not_impl_err!("Method not implemented")
@@ -747,20 +747,20 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec {
747747
not_impl_err!("Method not implemented")
748748
}
749749

750-
fn try_decode_table_provider(
750+
fn try_decode_table_source(
751751
&self,
752752
_buf: &[u8],
753753
_table_ref: &TableReference,
754754
_schema: arrow::datatypes::SchemaRef,
755755
_cts: &SessionContext,
756-
) -> datafusion_common::Result<Arc<dyn datafusion::datasource::TableProvider>> {
756+
) -> datafusion_common::Result<Arc<dyn datafusion_expr::TableSource>> {
757757
not_impl_err!("Method not implemented")
758758
}
759759

760-
fn try_encode_table_provider(
760+
fn try_encode_table_source(
761761
&self,
762762
_table_ref: &TableReference,
763-
_node: Arc<dyn datafusion::datasource::TableProvider>,
763+
_node: Arc<dyn datafusion_expr::TableSource>,
764764
_buf: &mut Vec<u8>,
765765
) -> datafusion_common::Result<()> {
766766
not_impl_err!("Method not implemented")

datafusion/proto/src/logical_plan/mod.rs

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ use datafusion::datasource::file_format::parquet::ParquetFormat;
4242
use datafusion::datasource::file_format::{
4343
file_type_to_format, format_as_file_type, FileFormatFactory,
4444
};
45+
use datafusion::datasource::DefaultTableSource;
4546
use datafusion::{
47+
datasource::provider_as_source,
4648
datasource::{
4749
file_format::{
4850
csv::CsvFormat, json::JsonFormat as OtherNdJsonFormat, FileFormat,
4951
},
5052
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
5153
view::ViewTable,
52-
TableProvider,
5354
},
54-
datasource::{provider_as_source, source_as_provider},
5555
prelude::SessionContext,
5656
};
5757
use datafusion_common::file_options::file_type::FileType;
@@ -118,18 +118,18 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync {
118118

119119
fn try_encode(&self, node: &Extension, buf: &mut Vec<u8>) -> Result<()>;
120120

121-
fn try_decode_table_provider(
121+
fn try_decode_table_source(
122122
&self,
123123
buf: &[u8],
124124
table_ref: &TableReference,
125125
schema: SchemaRef,
126126
ctx: &SessionContext,
127-
) -> Result<Arc<dyn TableProvider>>;
127+
) -> Result<Arc<dyn TableSource>>;
128128

129-
fn try_encode_table_provider(
129+
fn try_encode_table_source(
130130
&self,
131131
table_ref: &TableReference,
132-
node: Arc<dyn TableProvider>,
132+
node: Arc<dyn TableSource>,
133133
buf: &mut Vec<u8>,
134134
) -> Result<()>;
135135

@@ -193,20 +193,20 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec {
193193
not_impl_err!("LogicalExtensionCodec is not provided")
194194
}
195195

196-
fn try_decode_table_provider(
196+
fn try_decode_table_source(
197197
&self,
198198
_buf: &[u8],
199199
_table_ref: &TableReference,
200200
_schema: SchemaRef,
201201
_ctx: &SessionContext,
202-
) -> Result<Arc<dyn TableProvider>> {
202+
) -> Result<Arc<dyn TableSource>> {
203203
not_impl_err!("LogicalExtensionCodec is not provided")
204204
}
205205

206-
fn try_encode_table_provider(
206+
fn try_encode_table_source(
207207
&self,
208208
_table_ref: &TableReference,
209-
_node: Arc<dyn TableProvider>,
209+
_node: Arc<dyn TableSource>,
210210
_buf: &mut Vec<u8>,
211211
) -> Result<()> {
212212
not_impl_err!("LogicalExtensionCodec is not provided")
@@ -440,7 +440,7 @@ impl AsLogicalPlan for LogicalPlanNode {
440440
}
441441
#[cfg_attr(not(feature = "avro"), allow(unused_variables))]
442442
FileFormatType::Avro(..) => {
443-
#[cfg(feature = "avro")]
443+
#[cfg(feature = "avro")]
444444
{
445445
Arc::new(AvroFormat)
446446
}
@@ -520,18 +520,15 @@ impl AsLogicalPlan for LogicalPlanNode {
520520
let table_name =
521521
from_table_reference(scan.table_name.as_ref(), "CustomScan")?;
522522

523-
let provider = extension_codec.try_decode_table_provider(
523+
let source = extension_codec.try_decode_table_source(
524524
&scan.custom_table_data,
525525
&table_name,
526526
schema,
527527
ctx,
528528
)?;
529529

530530
LogicalPlanBuilder::scan_with_filters(
531-
table_name,
532-
provider_as_source(provider),
533-
projection,
534-
filters,
531+
table_name, source, projection, filters,
535532
)?
536533
.build()
537534
}
@@ -1029,9 +1026,7 @@ impl AsLogicalPlan for LogicalPlanNode {
10291026
projection,
10301027
..
10311028
}) => {
1032-
let provider = source_as_provider(source)?;
1033-
let schema = provider.schema();
1034-
let source = provider.as_any();
1029+
let schema = source.schema();
10351030

10361031
let projection = match projection {
10371032
None => None,
@@ -1049,7 +1044,9 @@ impl AsLogicalPlan for LogicalPlanNode {
10491044
let filters: Vec<protobuf::LogicalExprNode> =
10501045
serialize_exprs(filters, extension_codec)?;
10511046

1052-
if let Some(listing_table) = source.downcast_ref::<ListingTable>() {
1047+
if let Ok(listing_table) =
1048+
DefaultTableSource::unwrap_provider::<ListingTable>(source)
1049+
{
10531050
let any = listing_table.options().format.as_any();
10541051
let file_format_type = {
10551052
let mut maybe_some_type = None;
@@ -1158,7 +1155,9 @@ impl AsLogicalPlan for LogicalPlanNode {
11581155
},
11591156
)),
11601157
})
1161-
} else if let Some(view_table) = source.downcast_ref::<ViewTable>() {
1158+
} else if let Ok(view_table) =
1159+
DefaultTableSource::unwrap_provider::<ViewTable>(source)
1160+
{
11621161
let schema: protobuf::Schema = schema.as_ref().try_into()?;
11631162
Ok(LogicalPlanNode {
11641163
logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new(
@@ -1179,7 +1178,8 @@ impl AsLogicalPlan for LogicalPlanNode {
11791178
},
11801179
))),
11811180
})
1182-
} else if let Some(cte_work_table) = source.downcast_ref::<CteWorkTable>()
1181+
} else if let Ok(cte_work_table) =
1182+
DefaultTableSource::unwrap_provider::<CteWorkTable>(source)
11831183
{
11841184
let name = cte_work_table.name().to_string();
11851185
let schema = cte_work_table.schema();
@@ -1197,7 +1197,11 @@ impl AsLogicalPlan for LogicalPlanNode {
11971197
let schema: protobuf::Schema = schema.as_ref().try_into()?;
11981198
let mut bytes = vec![];
11991199
extension_codec
1200-
.try_encode_table_provider(table_name, provider, &mut bytes)
1200+
.try_encode_table_source(
1201+
table_name,
1202+
Arc::clone(source),
1203+
&mut bytes,
1204+
)
12011205
.map_err(|e| context!("Error serializing custom table", e))?;
12021206
let scan = CustomScan(CustomTableScanNode {
12031207
table_name: Some(table_name.clone().into()),

0 commit comments

Comments
 (0)