Skip to content

Commit 0fc3cc7

Browse files
authored
feat: add select builder (#125)
The project builder in plan.py now builds a plan that outputs all inputs field followed by all expressions. The select builders captures the old behavior of project and outputs only the expressions. BREAKING CHANGE: the project builder in plan.py has been renamed to select
1 parent 02a65f4 commit 0fc3cc7

File tree

9 files changed

+350
-249
lines changed

9 files changed

+350
-249
lines changed

examples/builder_example.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from substrait.builders.plan import (
22
read_named_table,
3-
project,
3+
select,
44
filter,
55
sort,
66
fetch,
@@ -34,7 +34,7 @@ def basic_example():
3434
expressions=[column("id"), literal(100, i64(nullable=False))],
3535
),
3636
)
37-
table = project(table, expressions=[column("id")])
37+
table = select(table, expressions=[column("id")])
3838

3939
print(table(registry))
4040
pretty_print_plan(table(registry), use_colors=True)
@@ -177,7 +177,7 @@ def advanced_example():
177177
expressions=[column("id"), literal(100, i64(nullable=False))],
178178
),
179179
)
180-
table = project(table, expressions=[column("id")])
180+
table = select(table, expressions=[column("id")])
181181

182182
print("Simple filtered table:")
183183
pretty_print_plan(table(registry), use_colors=True)
@@ -212,7 +212,7 @@ def advanced_example():
212212
)
213213

214214
# Project with calculated fields
215-
enriched_users = project(
215+
enriched_users = select(
216216
adult_users,
217217
expressions=[
218218
column("user_id"),
@@ -322,7 +322,7 @@ def expression_only_example():
322322
struct=struct(types=[fp64(nullable=False)], nullable=False),
323323
)
324324
dummy_table = read_named_table("dummy", dummy_schema)
325-
dummy_plan = project(dummy_table, expressions=[complex_expr])
325+
dummy_plan = select(dummy_table, expressions=[complex_expr])
326326
pretty_print_plan(dummy_plan(registry), use_colors=True)
327327

328328
print("\n" + "=" * 50 + "\n")

examples/duckdb_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
import duckdb
12-
from substrait.builders.plan import read_named_table, project, filter
12+
from substrait.builders.plan import read_named_table, select, filter
1313
from substrait.builders.extended_expression import column, scalar_function, literal
1414
from substrait.builders.type import i32
1515
from substrait.extension_registry import ExtensionRegistry
@@ -46,7 +46,7 @@ def read_duckdb_named_table(name: str, conn):
4646
expressions=[column("c_nationkey"), literal(3, i32())],
4747
),
4848
)
49-
table = project(
49+
table = select(
5050
table, expressions=[column("c_name"), column("c_address"), column("c_nationkey")]
5151
)
5252
sql = "CALL from_substrait(?)"

examples/pyarrow_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pyarrow.compute as pc
1010
import pyarrow.substrait as pa_substrait
1111
import substrait
12-
from substrait.builders.plan import project, read_named_table
12+
from substrait.builders.plan import select, read_named_table
1313

1414
arrow_schema = pa.schema([pa.field("x", pa.int32()), pa.field("y", pa.int32())])
1515

@@ -24,5 +24,5 @@
2424
pysubstrait_expr = substrait.proto.ExtendedExpression.FromString(bytes(substrait_expr))
2525

2626
table = read_named_table("example", substrait_schema)
27-
table = project(table, expressions=[pysubstrait_expr])(None)
27+
table = select(table, expressions=[pysubstrait_expr])(None)
2828
print(table)

src/substrait/builders/plan.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,69 @@ def project(
9595
expressions: Iterable[ExtendedExpressionOrUnbound],
9696
extension: Optional[AdvancedExtension] = None,
9797
) -> UnboundPlan:
98+
"""
99+
Builds an UnboundPlan with ProjectRel as the root node. Expressions are appended to the parent relation fields to produce an output.
100+
Semantically similar to a withColumn transformation.
101+
102+
:param plan: Parent plan
103+
:type plan: PlanOrUnbound
104+
:param expressions: Expressions to project
105+
:type expressions: Iterable[ExtendedExpressionOrUnbound]
106+
:param extension: Optional user-defined extension
107+
:type extension: Optional[AdvancedExtension]
108+
:return: UnboundPlan with ProjectRel as the root node
109+
:rtype: UnboundPlan
110+
"""
111+
112+
def resolve(registry: ExtensionRegistry) -> stp.Plan:
113+
_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
114+
ns = infer_plan_schema(_plan)
115+
bound_expressions: Iterable[stee.ExtendedExpression] = [
116+
resolve_expression(e, ns, registry) for e in expressions
117+
]
118+
119+
names = list(_plan.relations[-1].root.names) + [
120+
e.output_names[0] for ee in bound_expressions for e in ee.referred_expr
121+
]
122+
123+
rel = stalg.Rel(
124+
project=stalg.ProjectRel(
125+
input=_plan.relations[-1].root.input,
126+
expressions=[
127+
e.expression for ee in bound_expressions for e in ee.referred_expr
128+
],
129+
advanced_extension=extension,
130+
)
131+
)
132+
133+
return stp.Plan(
134+
version=default_version,
135+
relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))],
136+
**_merge_extensions(_plan, *bound_expressions),
137+
)
138+
139+
return resolve
140+
141+
142+
def select(
143+
plan: PlanOrUnbound,
144+
expressions: Iterable[ExtendedExpressionOrUnbound],
145+
extension: Optional[AdvancedExtension] = None,
146+
) -> UnboundPlan:
147+
"""
148+
Builds an UnboundPlan with ProjectRel as the root node. Expressions make up the fields of an output relation.
149+
Semantically similar to a select transformation.
150+
151+
:param plan: Parent plan
152+
:type plan: PlanOrUnbound
153+
:param expressions: Expressions to project
154+
:type expressions: Iterable[ExtendedExpressionOrUnbound]
155+
:param extension: Optional user-defined extension
156+
:type extension: Optional[AdvancedExtension]
157+
:return: UnboundPlan with ProjectRel as the root node
158+
:rtype: UnboundPlan
159+
"""
160+
98161
def resolve(registry: ExtensionRegistry) -> stp.Plan:
99162
_plan = plan if isinstance(plan, stp.Plan) else plan(registry)
100163
ns = infer_plan_schema(_plan)

src/substrait/extension_registry.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,12 +339,14 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:
339339
output_type = evaluate(self.impl.return_, parameters)
340340

341341
if self.nullability == se.NullabilityHandling.MIRROR:
342-
sig_contains_nullable = any([
343-
p.__getattribute__(p.WhichOneof("kind")).nullability
344-
== Type.NULLABILITY_NULLABLE
345-
for p in signature
346-
if isinstance(p, Type)
347-
])
342+
sig_contains_nullable = any(
343+
[
344+
p.__getattribute__(p.WhichOneof("kind")).nullability
345+
== Type.NULLABILITY_NULLABLE
346+
for p in signature
347+
if isinstance(p, Type)
348+
]
349+
)
348350
output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = (
349351
Type.NULLABILITY_NULLABLE
350352
if sig_contains_nullable

src/substrait/sql/sql_to_substrait.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212
from substrait.builders.plan import (
1313
read_named_table,
14-
project,
14+
select,
1515
filter,
1616
sort,
1717
fetch,
@@ -309,7 +309,7 @@ def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionReg
309309
if having_predicate:
310310
relation = filter(relation, having_predicate)(registry)
311311

312-
return project(relation, expressions=projection)(registry)
312+
return select(relation, expressions=projection)(registry)
313313
elif op == "Table":
314314
name = ast["name"][0]["Identifier"]["value"]
315315
return read_named_table(name, schema_resolver(name))

tests/builders/plan/test_project.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import substrait.gen.proto.plan_pb2 as stp
33
import substrait.gen.proto.algebra_pb2 as stalg
44
from substrait.builders.type import boolean, i64
5-
from substrait.builders.plan import read_named_table, project, default_version
5+
from substrait.builders.plan import read_named_table, select, project, default_version
66
from substrait.builders.extended_expression import column
77
from substrait.extension_registry import ExtensionRegistry
88

@@ -22,6 +22,41 @@ def test_project():
2222

2323
expected = stp.Plan(
2424
version=default_version,
25+
relations=[
26+
stp.PlanRel(
27+
root=stalg.RelRoot(
28+
input=stalg.Rel(
29+
project=stalg.ProjectRel(
30+
input=table(None).relations[-1].root.input,
31+
expressions=[
32+
stalg.Expression(
33+
selection=stalg.Expression.FieldReference(
34+
direct_reference=stalg.Expression.ReferenceSegment(
35+
struct_field=stalg.Expression.ReferenceSegment.StructField(
36+
field=0
37+
)
38+
),
39+
root_reference=stalg.Expression.FieldReference.RootReference(),
40+
)
41+
)
42+
],
43+
)
44+
),
45+
names=["id", "is_applicable", "id"],
46+
)
47+
)
48+
],
49+
)
50+
51+
assert actual == expected
52+
53+
54+
def test_select():
55+
table = read_named_table("table", named_struct)
56+
57+
actual = select(table, [column("id")])(registry)
58+
59+
expected = stp.Plan(
2560
relations=[
2661
stp.PlanRel(
2762
root=stalg.RelRoot(
@@ -49,6 +84,7 @@ def test_project():
4984
)
5085
)
5186
],
87+
version=default_version,
5288
)
5389

5490
assert actual == expected

tests/test_uri_urn_migration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from substrait.builders.plan import (
2525
read_named_table,
2626
aggregate,
27-
project,
27+
select,
2828
filter,
2929
default_version,
3030
)
@@ -149,7 +149,7 @@ def test_project_outputs_both_uri_and_urn():
149149
alias=["add"],
150150
)
151151

152-
actual = project(table, [add_expr])(registry)
152+
actual = select(table, [add_expr])(registry)
153153

154154
ns = infer_plan_schema(table(None))
155155

0 commit comments

Comments
 (0)