Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
50a74b3
Add Expression Support & with_column API
goutamvenkat-anyscale Jul 3, 2025
9a5086f
Rename to use_columns, use list[expr]
goutamvenkat-anyscale Jul 3, 2025
9804716
Use project operator & update doc
goutamvenkat-anyscale Jul 7, 2025
16cccff
Fix linting issue
goutamvenkat-anyscale Jul 7, 2025
a309598
Doc linter
goutamvenkat-anyscale Jul 7, 2025
176b444
doctest
goutamvenkat-anyscale Jul 7, 2025
c17570f
Address comment
goutamvenkat-anyscale Jul 7, 2025
eea0d55
Address comments
goutamvenkat-anyscale Jul 9, 2025
ed251a1
Linter & remove dataclass for operations
goutamvenkat-anyscale Jul 9, 2025
e36a87b
Address comments
goutamvenkat-anyscale Jul 9, 2025
86bc9fb
revert old change
goutamvenkat-anyscale Jul 9, 2025
3053b95
Remove unnecessary arg
goutamvenkat-anyscale Jul 9, 2025
b09ae8f
Merge branch 'master' into goutam/expressions
goutamvenkat-anyscale Jul 10, 2025
fb3c6a1
doctest + pytest skip if version is not met
goutamvenkat-anyscale Jul 10, 2025
bd7bc77
Remove circular dep
goutamvenkat-anyscale Jul 10, 2025
6d443f8
Address comments
goutamvenkat-anyscale Jul 10, 2025
fc034ec
remove change in block builder
goutamvenkat-anyscale Jul 10, 2025
1a3941f
Remove block builder change
goutamvenkat-anyscale Jul 10, 2025
3f30cbb
Make pre-commit happy
goutamvenkat-anyscale Jul 10, 2025
9b8de87
Address comment on Expr AST comparison
goutamvenkat-anyscale Jul 10, 2025
c13f679
Add expressions test to bazel build
goutamvenkat-anyscale Jul 10, 2025
8d61562
Remove match expression
goutamvenkat-anyscale Jul 10, 2025
d8890fd
Comments
goutamvenkat-anyscale Jul 10, 2025
49e3ccb
Merge branch 'master' into goutam/expressions
goutamvenkat-anyscale Jul 11, 2025
164cbd3
Address comments
goutamvenkat-anyscale Jul 11, 2025
b64beef
Add comments back
goutamvenkat-anyscale Jul 11, 2025
a3f3050
Make expression classes dev api
goutamvenkat-anyscale Jul 11, 2025
821b73e
Add stability to DeveloperAPIs
goutamvenkat-anyscale Jul 11, 2025
f5b08eb
Add .rst files
goutamvenkat-anyscale Jul 11, 2025
2741b12
idk rst
goutamvenkat-anyscale Jul 12, 2025
f4f620c
Merge branch 'master' into goutam/expressions
goutamvenkat-anyscale Jul 12, 2025
c7f0424
Remove code snippet
goutamvenkat-anyscale Jul 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/ray/data/_internal/logical/operators/map_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.logical.operators.one_to_one_operator import AbstractOneToOne
from ray.data.block import UserDefinedFunction
from ray.data.expressions import Expr
from ray.data.preprocessor import Preprocessor

if TYPE_CHECKING:
Expand Down Expand Up @@ -263,6 +264,7 @@ def __init__(
input_op: LogicalOperator,
cols: Optional[List[str]] = None,
cols_rename: Optional[Dict[str, str]] = None,
exprs: Optional[Dict[str, "Expr"]] = None,
compute: Optional[ComputeStrategy] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
):
Expand All @@ -275,6 +277,7 @@ def __init__(
self._batch_size = None
self._cols = cols
self._cols_rename = cols_rename
self._exprs = exprs
self._batch_format = "pyarrow"
self._zero_copy_batch = True

Expand All @@ -286,6 +289,10 @@ def cols(self) -> Optional[List[str]]:
def cols_rename(self) -> Optional[Dict[str, str]]:
return self._cols_rename

@property
def exprs(self) -> Optional[Dict[str, "Expr"]]:
return self._exprs

def can_modify_num_rows(self) -> bool:
return False

Expand Down
26 changes: 22 additions & 4 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
)
from ray.data.context import DataContext
from ray.data.exceptions import UserCodeException
from ray.data.expressions import eval_expr
from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -94,18 +95,35 @@ def plan_project_op(

columns = op.cols
columns_rename = op.cols_rename
exprs = op.exprs

def fn(block: Block) -> Block:
try:
if not BlockAccessor.for_block(block).num_rows():
return block
tbl = BlockAccessor.for_block(block).to_arrow()

# 1. evaluate / add expressions
if exprs:
for name, ex in exprs.items():
arr = eval_expr(ex, tbl)
if name in tbl.column_names:
tbl = tbl.set_column(
tbl.schema.get_field_index(name), name, arr
)
else:
tbl = tbl.append_column(name, arr)

# 2. (optional) column projection
if columns:
block = BlockAccessor.for_block(block).select(columns)
tbl = tbl.select(columns)

# 3. (optional) rename
if columns_rename:
block = block.rename_columns(
[columns_rename.get(col, col) for col in block.schema.names]
tbl = tbl.rename_columns(
[columns_rename.get(col, col) for col in tbl.schema.names]
)
return block
return tbl
except Exception as e:
_handle_debugger_exception(e, block)

Expand Down
66 changes: 65 additions & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
from ray.data._internal.execution.interfaces import Executor, NodeIdStr
from ray.data.grouped_data import GroupedData

from ray.data.expressions import AliasExpr, Expr

logger = logging.getLogger(__name__)

Expand All @@ -154,6 +155,7 @@
IOC_API_GROUP = "I/O and Conversion"
IM_API_GROUP = "Inspecting Metadata"
E_API_GROUP = "Execution"
EXPRESSION_API_GROUP = "Expressions"


@PublicAPI
Expand Down Expand Up @@ -776,6 +778,68 @@ def _map_batches_without_batch_size_validation(
logical_plan = LogicalPlan(map_batches_op, self.context)
return Dataset(plan, logical_plan)

@PublicAPI(api_group=EXPRESSION_API_GROUP, stability="alpha")
def with_columns(
self,
exprs: List[Expr],
batch_format: Optional[str] = "pyarrow",
**ray_remote_args,
) -> "Dataset":
"""
Add new columns to the dataset.

Examples:

>>> import ray
>>> from ray.data.expressions import col
>>> ds = ray.data.range(100)
>>> ds.with_columns([(col("id") * 2).alias("new_id"), (col("id") * 3).alias("new_id_2")]).schema()
Column Type
------ ----
id int64
new_id int64
new_id_2 int64

Args:
exprs: The expressions to evaluate to produce the new column values.
batch_format: If ``"numpy"``, batches are
``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are
``pandas.DataFrame``. If ``"pyarrow"``, batches are
``pyarrow.Table``. If ``"numpy"``, batches are
``Dict[str, numpy.ndarray]``.
**ray_remote_args: Additional resource requirements to request from
Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See
:func:`ray.remote` for details.

Returns:
A new dataset with the added columns evaluated via expressions.
"""
if not exprs:
raise ValueError("at least one expression is required")

# Build mapping {new_col_name: expression}
projections: Dict[str, Expr] = {}
for e in exprs:
if not isinstance(e, Expr):
raise TypeError(f"Expected Expr, got {type(e)}")
if isinstance(e, AliasExpr):
projections[e.name] = e.expr
else:
raise ValueError("Each expression must be `.alias(<output>)`-ed.")

from ray.data._internal.logical.operators.map_operator import Project

plan = self._plan.copy()
project_op = Project(
self._logical_plan.dag,
cols=None,
cols_rename=None,
exprs=projections, # << pass expressions
ray_remote_args=ray_remote_args,
)
logical_plan = LogicalPlan(project_op, self.context)
return Dataset(plan, logical_plan)

@PublicAPI(api_group=BT_API_GROUP)
def add_column(
self,
Expand Down Expand Up @@ -4461,7 +4525,7 @@ def write_clickhouse(
* order_by:
Sets the `ORDER BY` clause in the `CREATE TABLE` statement, iff not provided.
When overwriting an existing table, its previous `ORDER BY` (if any) is reused.
Otherwise, a best column is selected automatically (favoring a timestamp column,
Otherwise, a "best" column is selected automatically (favoring a timestamp column,
then a non-string column, and lastly the first column).

* partition_by:
Expand Down
181 changes: 181 additions & 0 deletions python/ray/data/expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from __future__ import annotations

import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc

from ray.util.annotations import DeveloperAPI


# ──────────────────────────────────────
# Basic expression node definitions
# ──────────────────────────────────────
@DeveloperAPI
class Expr: # Base class – all expression nodes inherit from this
# Binary/boolean operator overloads
def _bin(self, other: Any, op: str) -> "Expr":
other = other if isinstance(other, Expr) else LiteralExpr(other)
return BinaryExpr(op, self, other)

# arithmetic
def __add__(self, other):
return self._bin(other, "add")

def __sub__(self, other):
return self._bin(other, "sub")

def __mul__(self, other):
return self._bin(other, "mul")

def __truediv__(self, other):
return self._bin(other, "div")

# comparison
def __gt__(self, other):
return self._bin(other, "gt")

def __lt__(self, other):
return self._bin(other, "lt")

def __ge__(self, other):
return self._bin(other, "ge")

def __le__(self, other):
return self._bin(other, "le")

def __eq__(self, other):
return self._bin(other, "eq")

# boolean
def __and__(self, other):
return self._bin(other, "and")

def __or__(self, other):
return self._bin(other, "or")

# Rename the output column
def alias(self, name: str) -> "AliasExpr":
return AliasExpr(self, name)


@DeveloperAPI
@dataclass(frozen=True, eq=False)
class ColumnExpr(Expr):
name: str


@DeveloperAPI
@dataclass(frozen=True, eq=False)
class LiteralExpr(Expr):
value: Any


@DeveloperAPI
@dataclass(frozen=True, eq=False)
class BinaryExpr(Expr):
op: str
left: Expr
right: Expr


@DeveloperAPI
@dataclass(frozen=True, eq=False)
class AliasExpr(Expr):
expr: Expr
name: str


# ──────────────────────────────────────
# User helpers
# ──────────────────────────────────────


@DeveloperAPI
def col(name: str) -> ColumnExpr:
"""Reference an existing column."""
return ColumnExpr(name)


@DeveloperAPI
def lit(value: Any) -> LiteralExpr:
"""Create a scalar literal expression (e.g. lit(1))."""
return LiteralExpr(value)


# ──────────────────────────────────────
# Local evaluator (pandas batches)
# ──────────────────────────────────────
# This is used by Dataset.with_columns – kept here so it can be re-used by
# future optimised executors.
_PANDAS_OPS: Dict[str, Callable[[Any, Any], Any]] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if these pandas, numpy, and pyarrow all need to have keys for "add", "sub", etc.. maybe better to define a base class for each key, and then implement __call__ for each subclass (pandas, numpy, pyarrow)?

"add": operator.add,
"sub": operator.sub,
"mul": operator.mul,
"div": operator.truediv,
"gt": operator.gt,
"lt": operator.lt,
"ge": operator.ge,
"le": operator.le,
"eq": operator.eq,
"and": operator.and_,
"or": operator.or_,
}

_NUMPY_OPS: Dict[str, Callable[[Any, Any], Any]] = {
"add": np.add,
"sub": np.subtract,
"mul": np.multiply,
"div": np.divide,
"gt": np.greater,
"lt": np.less,
"ge": np.greater_equal,
"le": np.less_equal,
"eq": np.equal,
"and": np.logical_and,
"or": np.logical_or,
}

_ARROW_OPS: Dict[str, Callable[[Any, Any], Any]] = {
"add": pc.add,
"sub": pc.subtract,
"mul": pc.multiply,
"div": pc.divide,
"gt": pc.greater,
"lt": pc.less,
"ge": pc.greater_equal,
"le": pc.less_equal,
"eq": pc.equal,
"and": pc.and_,
"or": pc.or_,
}


def _eval_expr_recursive(expr: Expr, batch, ops: Dict[str, Callable]) -> Any:
"""Generic recursive expression evaluator."""
if isinstance(expr, ColumnExpr):
return batch[expr.name]
if isinstance(expr, LiteralExpr):
return expr.value
if isinstance(expr, BinaryExpr):
return ops[expr.op](
_eval_expr_recursive(expr.left, batch, ops),
_eval_expr_recursive(expr.right, batch, ops),
)
raise TypeError(f"Unsupported expression node: {type(expr).__name__}")


@DeveloperAPI
def eval_expr(expr: Expr, batch) -> Any:
"""Recursively evaluate *expr* against a batch of the appropriate type."""
if isinstance(batch, pd.DataFrame):
return _eval_expr_recursive(expr, batch, _PANDAS_OPS)
elif isinstance(batch, (np.ndarray, dict)):
return _eval_expr_recursive(expr, batch, _NUMPY_OPS)
elif isinstance(batch, pa.Table):
return _eval_expr_recursive(expr, batch, _ARROW_OPS)
raise TypeError(f"Unsupported batch type: {type(batch).__name__}")
Loading