Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions examples/dataframe_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from substrait.builders.plan import read_named_table
from substrait.builders.type import i64, boolean, struct, named_struct
from substrait.extension_registry import ExtensionRegistry
import substrait.dataframe as sdf

registry = ExtensionRegistry(load_default_extensions=True)

ns = named_struct(
names=["id", "is_applicable"],
struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
)

table = read_named_table("example_table", ns)

frame = sdf.DataFrame(read_named_table("example_table", ns))
frame = frame.select(sdf.col("id"))
print(frame.to_substrait(registry))
36 changes: 36 additions & 0 deletions examples/narwhals_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Install duckdb and pyarrow before running this example
# /// script
# dependencies = [
# "narwhals==2.9.0",
# "substrait[extensions] @ file:///${PROJECT_ROOT}/"
# ]
# ///

from substrait.builders.plan import read_named_table
from substrait.builders.type import i64, boolean, struct, named_struct
from substrait.extension_registry import ExtensionRegistry

from narwhals.typing import FrameT
import narwhals as nw
import substrait.dataframe as sdf


registry = ExtensionRegistry(load_default_extensions=True)

ns = named_struct(
names=["id", "is_applicable"],
struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
)

table = read_named_table("example_table", ns)


lazy_frame: FrameT = nw.from_native(
sdf.DataFrame(read_named_table("example_table", ns))
)

lazy_frame = lazy_frame.select(nw.col("id").abs(), new_id=nw.col("id"))

df: sdf.DataFrame = lazy_frame.to_native()

print(df.to_substrait(registry))
16 changes: 16 additions & 0 deletions src/substrait/dataframe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import substrait.dataframe
from substrait.builders.extended_expression import column

from substrait.dataframe.dataframe import DataFrame
from substrait.dataframe.expression import Expression

__all__ = [DataFrame, Expression]


def col(name: str) -> Expression:
"""Column selection."""
return Expression(column(name))

# TODO handle str_as_lit argument
def parse_into_expr(expr, str_as_lit: bool):
return expr._to_compliant_expr(substrait.dataframe)
36 changes: 36 additions & 0 deletions src/substrait/dataframe/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Union, Iterable
import substrait.dataframe
from substrait.builders.plan import project
from substrait.dataframe.expression import Expression


class DataFrame:
def __init__(self, plan):
self.plan = plan
self._native_frame = self

def to_substrait(self, registry):
return self.plan(registry)

def __narwhals_lazyframe__(self) -> "DataFrame":
"""Return object implementing CompliantDataFrame protocol."""
return self

def __narwhals_namespace__(self):
"""
Return the namespace object that contains functions like col, lit, etc.
This is how Narwhals knows which backend's functions to use.
"""
return substrait.dataframe

def select(
self, *exprs: Union[Expression, Iterable[Expression]], **named_exprs: Expression
) -> "DataFrame":
expressions = [e.expr for e in exprs] + [
expr.alias(alias).expr for alias, expr in named_exprs.items()
]
return DataFrame(project(self.plan, expressions=expressions))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make sure to update this to use the new select builder once your work in #125 is merged in.


# TODO handle version
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's needs to exist to handle this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no clue at this point, I need to dig deeper to understand if we can disregard it or not. that's what I meant by "handle version".

def _with_version(self, version):
return self
36 changes: 36 additions & 0 deletions src/substrait/dataframe/expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from substrait.builders.extended_expression import (
UnboundExtendedExpression,
ExtendedExpressionOrUnbound,
resolve_expression,
scalar_function
)
import substrait.gen.proto.type_pb2 as stp
import substrait.gen.proto.extended_expression_pb2 as stee
from substrait.extension_registry import ExtensionRegistry


def _alias(
expr: ExtendedExpressionOrUnbound,
alias: str = None,
):
def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_expression = resolve_expression(expr, base_schema, registry)
bound_expression.referred_expr[0].output_names[0] = alias
return bound_expression

return resolve


class Expression:
def __init__(self, expr: UnboundExtendedExpression):
self.expr = expr

def alias(self, alias: str):
self.expr = _alias(self.expr, alias)
return self

def abs(self):
self.expr = scalar_function("functions_arithmetic.yaml", "abs", expressions=[self.expr])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this PR will get merged before your work in #107. We'll need to make sure to update this.

return self
54 changes: 54 additions & 0 deletions tests/dataframe/test_df_project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import substrait.gen.proto.type_pb2 as stt
import substrait.gen.proto.plan_pb2 as stp
import substrait.gen.proto.algebra_pb2 as stalg
from substrait.builders.type import boolean, i64
from substrait.builders.plan import read_named_table
from substrait.extension_registry import ExtensionRegistry
import substrait.dataframe as sdf


registry = ExtensionRegistry(load_default_extensions=False)

struct = stt.Type.Struct(
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
)

named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)


def test_project():
df = sdf.DataFrame(read_named_table("table", named_struct))

actual = df.select(id=sdf.col("id")).to_substrait(registry)

expected = stp.Plan(
relations=[
stp.PlanRel(
root=stalg.RelRoot(
input=stalg.Rel(
project=stalg.ProjectRel(
common=stalg.RelCommon(
emit=stalg.RelCommon.Emit(output_mapping=[2])
),
input=df.to_substrait(None).relations[-1].root.input,
expressions=[
stalg.Expression(
selection=stalg.Expression.FieldReference(
direct_reference=stalg.Expression.ReferenceSegment(
struct_field=stalg.Expression.ReferenceSegment.StructField(
field=0
)
),
root_reference=stalg.Expression.FieldReference.RootReference(),
)
)
],
)
),
names=["id"],
)
)
]
)

assert actual == expected