Skip to content

Commit ebe5697

Browse files
authored
feat: configurable registry for SQL conversion (#115)
Enables configuring the extension registry used during SQL conversion.
1 parent: fcfbb60; commit: ebe5697

File tree

2 files changed: +13 -4 lines changed

2 files changed

+13
-4
lines changed

src/substrait/sql/sql_to_substrait.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,13 @@ def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionReg
333333
raise Exception(f"Unknown op {op}")
334334

335335

336-
def convert(query: str, dialect: str, schema_resolver: SchemaResolver):
336+
def convert(
337+
query: str,
338+
dialect: str,
339+
schema_resolver: SchemaResolver,
340+
registry: ExtensionRegistry = None,
341+
):
337342
ast = parse_sql(sql=query, dialect=dialect)[0]
338-
registry = ExtensionRegistry(load_default_extensions=True)
343+
if not registry:
344+
registry = ExtensionRegistry(load_default_extensions=True)
339345
return translate(ast, schema_resolver=schema_resolver, registry=registry)

tests/sql/test_sql_to_substrait.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from substrait.extension_registry import ExtensionRegistry
12
from substrait.sql.sql_to_substrait import convert
23
import pyarrow
34
from google.protobuf import json_format
@@ -30,6 +31,8 @@
3031
]
3132
)
3233

34+
registry = ExtensionRegistry(load_default_extensions=True)
35+
3336

3437
def sort_arrow(table: pyarrow.Table):
3538
import pyarrow.compute as pc
@@ -52,7 +55,7 @@ def df_schema_resolver(name: str):
5255
pa_schema = ctx.sql(f"SELECT * FROM {name} LIMIT 0").schema()
5356
return pa_substrait.serialize_schema(pa_schema).to_pysubstrait().base_schema
5457

55-
plan = convert(query, "generic", df_schema_resolver)
58+
plan = convert(query, "generic", df_schema_resolver, registry)
5659

5760
sql_arrow = ctx.sql(query).to_arrow_table()
5861

@@ -86,7 +89,7 @@ def duckdb_schema_resolver(name: str):
8689
conn.register("stores", data)
8790
conn.register("sales", sales_data)
8891

89-
plan = convert(query, "duckdb", duckdb_schema_resolver)
92+
plan = convert(query, "duckdb", duckdb_schema_resolver, registry)
9093

9194
conn.install_extension("substrait", repository="community")
9295
conn.load_extension("substrait")

0 commit comments

Comments
 (0)