From 50a74b3c73ad62c139497f15ff2ccc692d113d93 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 3 Jul 2025 11:45:01 -0700 Subject: [PATCH 01/29] Add Expression Support & with_column API Signed-off-by: Goutam V --- python/ray/data/dataset.py | 83 ++++++++++++++++++ python/ray/data/expressions.py | 137 ++++++++++++++++++++++++++++++ python/ray/data/tests/test_map.py | 46 ++++++++++ 3 files changed, 266 insertions(+) create mode 100644 python/ray/data/expressions.py diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index efd293a907cd..29287a39743d 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -134,6 +134,7 @@ from ray.data._internal.execution.interfaces import Executor, NodeIdStr from ray.data.grouped_data import GroupedData +from ray.data.expressions import AliasExpr, Expr, eval_expr logger = logging.getLogger(__name__) @@ -154,6 +155,7 @@ IOC_API_GROUP = "I/O and Conversion" IM_API_GROUP = "Inspecting Metadata" E_API_GROUP = "Execution" +EXPRESSION_API_GROUP = "Expressions" @PublicAPI @@ -776,6 +778,87 @@ def _map_batches_without_batch_size_validation( logical_plan = LogicalPlan(map_batches_op, self.context) return Dataset(plan, logical_plan) + @PublicAPI(api_group=EXPRESSION_API_GROUP) + def with_column( + self, + *exprs: Expr, + batch_format: Optional[str] = "pandas", + compute: Optional[str] = None, + concurrency: Optional[int] = None, + **ray_remote_args, + ) -> "Dataset": + """ + Add a new column to the dataset. + + Examples: + + >>> import ray + >>> ds = ray.data.range(100) + >>> ds.with_column((col("id") * 2).alias("new_id")).schema() + Column Type + ------ ---- + id int64 + new_id int64 + Args: + exprs: The expressions to evaluate to produce the new column values. + batch_format: If ``"numpy"``, batches are + ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are + ``pandas.DataFrame``. If ``"pyarrow"``, batches are + ``pyarrow.Table``. + + Returns: + A new dataset with the added column. + """ + if not exprs: + raise ValueError("at least one expression is required") + + accepted_batch_formats = ["pandas", "pyarrow", "numpy"] + if batch_format not in accepted_batch_formats: + raise ValueError( + f"batch_format argument must be on of {accepted_batch_formats}, " + f"got: {batch_format}" + ) + + projections = {} + for expr in exprs: + if not isinstance(expr, Expr): + raise TypeError(f"Expected Expr, got: {type(expr)}") + if isinstance(expr, AliasExpr): + projections[expr.name] = expr.expr + else: + raise ValueError("Each expression must be `.alias()`-ed.") + + def _project(batch): + if isinstance(batch, dict): + for name, ex in projections.items(): + batch[name] = eval_expr(ex, batch) + return batch + + import pandas as pd + + if isinstance(batch, pd.DataFrame): + for name, ex in projections.items(): + batch[name] = eval_expr(ex, batch) + return batch + + import pyarrow as pa + + if isinstance(batch, pa.Table): + tbl = batch + for name, ex in projections.items(): + arr = eval_expr(ex, batch) + tbl = tbl.append_column(name, arr) + return tbl + + return self.map_batches( + _project, + batch_format=batch_format, + compute=compute, + concurrency=concurrency, + zero_copy_batch=False, + **ray_remote_args, + ) + @PublicAPI(api_group=BT_API_GROUP) def add_column( self, diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py new file mode 100644 index 000000000000..41d2b112358e --- /dev/null +++ b/python/ray/data/expressions.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import operator +from dataclasses import dataclass +from typing import Any, Dict, Callable +import numpy as np +import pyarrow.compute as pc +import pandas as pd +import pyarrow as pa + +# ────────────────────────────────────── +# Basic expression node definitions +# ────────────────────────────────────── +class Expr: # Base class – all expression nodes inherit from this + # Binary/boolean operator overloads + def _bin(self, other: Any, op: str) -> "Expr": + other = other if isinstance(other, Expr) else LiteralExpr(other) + return BinaryExpr(op, self, other) + + # arithmetic + def __add__(self, other): return self._bin(other, "add") + def __sub__(self, other): return self._bin(other, "sub") + def __mul__(self, other): return self._bin(other, "mul") + def __truediv__(self, other): return self._bin(other, "div") + # comparison + def __gt__(self, other): return self._bin(other, "gt") + def __lt__(self, other): return self._bin(other, "lt") + def __ge__(self, other): return self._bin(other, "ge") + def __le__(self, other): return self._bin(other, "le") + def __eq__(self, other): return self._bin(other, "eq") + # boolean + def __and__(self, other): return self._bin(other, "and") + def __or__(self, other): return self._bin(other, "or") + + # Rename the output column + def alias(self, name: str) -> "AliasExpr": + return AliasExpr(self, name) + +@dataclass(frozen=True, eq=False) +class ColumnExpr(Expr): + name: str + +@dataclass(frozen=True, eq=False) +class LiteralExpr(Expr): + value: Any + +@dataclass(frozen=True, eq=False) +class BinaryExpr(Expr): + op: str + left: Expr + right: Expr + +@dataclass(frozen=True, eq=False) +class AliasExpr(Expr): + expr: Expr + name: str + +# ────────────────────────────────────── +# User helpers +# ────────────────────────────────────── +def col(name: str) -> ColumnExpr: + """Reference an existing column.""" + return ColumnExpr(name) + +def lit(value: Any) -> LiteralExpr: + """Create a scalar literal expression (e.g. lit(1)).""" + return LiteralExpr(value) + +# ────────────────────────────────────── +# Local evaluator (pandas batches) +# ────────────────────────────────────── +# This is used by Dataset.with_columns – kept here so it can be re-used by +# future optimised executors. +_PANDAS_OPS: Dict[str, Callable[[Any, Any], Any]] = { + "add": operator.add, + "sub": operator.sub, + "mul": operator.mul, + "div": operator.truediv, + "gt": operator.gt, + "lt": operator.lt, + "ge": operator.ge, + "le": operator.le, + "eq": operator.eq, + "and": operator.and_, + "or": operator.or_, +} + +_NUMPY_OPS: Dict[str, Callable[[Any, Any], Any]] = { + "add": np.add, + "sub": np.subtract, + "mul": np.multiply, + "div": np.divide, + "gt": np.greater, + "lt": np.less, + "ge": np.greater_equal, + "le": np.less_equal, + "eq": np.equal, + "and": np.logical_and, + "or": np.logical_or, +} + +_ARROW_OPS: Dict[str, Callable[[Any, Any], Any]] = { + "add": pc.add, + "sub": pc.subtract, + "mul": pc.multiply, + "div": pc.divide, + "gt": pc.greater, + "lt": pc.less, + "ge": pc.greater_equal, + "le": pc.less_equal, + "eq": pc.equal, + "and": pc.and_, + "or": pc.or_, +} + +def _eval_expr_recursive(expr: Expr, batch, ops: Dict[str, Callable]) -> Any: + """Generic recursive expression evaluator.""" + if isinstance(expr, ColumnExpr): + return batch[expr.name] + if isinstance(expr, LiteralExpr): + return expr.value + if isinstance(expr, BinaryExpr): + return ops[expr.op]( + _eval_expr_recursive(expr.left, batch, ops), + _eval_expr_recursive(expr.right, batch, ops) + ) + raise TypeError(f"Unsupported expression node: {type(expr).__name__}") + +def eval_expr(expr: Expr, batch) -> Any: + """Recursively evaluate *expr* against a batch of the appropriate type.""" + if isinstance(batch, pd.DataFrame): + return _eval_expr_recursive(expr, batch, _PANDAS_OPS) + elif isinstance(batch, (np.ndarray, dict)): + return _eval_expr_recursive(expr, batch, _NUMPY_OPS) + elif isinstance(batch, pa.Table): + return _eval_expr_recursive(expr, batch, _ARROW_OPS) + raise TypeError(f"Unsupported batch type: {type(batch).__name__}") diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index ccb62fe353c5..dd068c4400ed 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -28,6 +28,7 @@ from ray.data._internal.execution.operators.actor_pool_map_operator import _MapWorker from ray.data.context import DataContext from ray.data.exceptions import UserCodeException +from ray.data.expressions import col, lit from ray.data.tests.conftest import * # noqa from ray.data.tests.test_util import ConcurrencyCounter # noqa from ray.data.tests.util import column_udf, column_udf_class, extract_values @@ -1982,6 +1983,51 @@ def func(x, y): assert r.startswith("OneHotEncoder"), r +@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) +@pytest.mark.parametrize( + "expr, expected_value", + [ + # Arithmetic operations + ((col("id") + 1).alias("result"), 1), # 0 + 1 = 1 + ((col("id") + 5).alias("result"), 5), # 0 + 5 = 5 + ((col("id") - 1).alias("result"), -1), # 0 - 1 = -1 + ((col("id") * 2).alias("result"), 0), # 0 * 2 = 0 + ((col("id") * 3).alias("result"), 0), # 0 * 3 = 0 + ((col("id") / 2).alias("result"), 0.0), # 0 / 2 = 0.0 + # More complex arithmetic + (((col("id") + 1) * 2).alias("result"), 2), # (0 + 1) * 2 = 2 + (((col("id") * 2) + 3).alias("result"), 3), # 0 * 2 + 3 = 3 + # Comparison operations + ((col("id") > 0).alias("result"), False), # 0 > 0 = False + ((col("id") >= 0).alias("result"), True), # 0 >= 0 = True + ((col("id") < 1).alias("result"), True), # 0 < 1 = True + ((col("id") <= 0).alias("result"), True), # 0 <= 0 = True + ((col("id") == 0).alias("result"), True), # 0 == 0 = True + # Operations with literals + ((col("id") + lit(10)).alias("result"), 10), # 0 + 10 = 10 + ((col("id") * lit(5)).alias("result"), 0), # 0 * 5 = 0 + ((lit(2) + col("id")).alias("result"), 2), # 2 + 0 = 2 + ((lit(10) / (col("id") + 1)).alias("result"), 10.0), # 10 / (0 + 1) = 10.0 + ], +) +def test_with_column(ray_start_regular_shared, batch_format, expr, expected_value): + """Verify that `with_column` works for pandas, numpy, and pyarrow batch formats with various operations.""" + ds = ray.data.range(5).with_column(expr, batch_format=batch_format) + result = ds.take(1)[0] + assert result["id"] == 0 + assert result["result"] == expected_value + + +def test_with_column_nonexistent_column(ray_start_regular_shared): + """Verify that referencing a non-existent column with col() raises an exception.""" + # Create a dataset with known column "id" + ds = ray.data.range(5) + + # Try to reference a non-existent column - this should raise an exception + with pytest.raises(UserCodeException): + ds.with_column((col("nonexistent_column") + 1).alias("result")).materialize() + + if __name__ == "__main__": import sys From 9a5086fb96f32a4977ef312bc8f43ac254e4061d Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 3 Jul 2025 11:56:47 -0700 Subject: [PATCH 02/29] Rename to use_columns, use list[expr] Signed-off-by: Goutam V --- python/ray/data/dataset.py | 10 +++-- python/ray/data/expressions.py | 72 ++++++++++++++++++++++--------- python/ray/data/tests/test_map.py | 29 +++++++++++-- 3 files changed, 83 insertions(+), 28 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 29287a39743d..00b4379fb184 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -779,26 +779,28 @@ def _map_batches_without_batch_size_validation( return Dataset(plan, logical_plan) @PublicAPI(api_group=EXPRESSION_API_GROUP) - def with_column( + def with_columns( self, - *exprs: Expr, + exprs: List[Expr], batch_format: Optional[str] = "pandas", compute: Optional[str] = None, concurrency: Optional[int] = None, **ray_remote_args, ) -> "Dataset": """ - Add a new column to the dataset. + Add new columns to the dataset. Examples: >>> import ray >>> ds = ray.data.range(100) - >>> ds.with_column((col("id") * 2).alias("new_id")).schema() + >>> ds.with_columns([(col("id") * 2).alias("new_id"), (col("id") * 3).alias("new_id_2")]).schema() Column Type ------ ---- id int64 new_id int64 + new_id_2 int64 + Args: exprs: The expressions to evaluate to produce the new column values. batch_format: If ``"numpy"``, batches are diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 41d2b112358e..21c3ec06e8e1 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -2,11 +2,13 @@ import operator from dataclasses import dataclass -from typing import Any, Dict, Callable +from typing import Any, Callable, Dict + import numpy as np -import pyarrow.compute as pc import pandas as pd import pyarrow as pa +import pyarrow.compute as pc + # ────────────────────────────────────── # Basic expression node definitions @@ -18,43 +20,69 @@ def _bin(self, other: Any, op: str) -> "Expr": return BinaryExpr(op, self, other) # arithmetic - def __add__(self, other): return self._bin(other, "add") - def __sub__(self, other): return self._bin(other, "sub") - def __mul__(self, other): return self._bin(other, "mul") - def __truediv__(self, other): return self._bin(other, "div") + def __add__(self, other): + return self._bin(other, "add") + + def __sub__(self, other): + return self._bin(other, "sub") + + def __mul__(self, other): + return self._bin(other, "mul") + + def __truediv__(self, other): + return self._bin(other, "div") + # comparison - def __gt__(self, other): return self._bin(other, "gt") - def __lt__(self, other): return self._bin(other, "lt") - def __ge__(self, other): return self._bin(other, "ge") - def __le__(self, other): return self._bin(other, "le") - def __eq__(self, other): return self._bin(other, "eq") + def __gt__(self, other): + return self._bin(other, "gt") + + def __lt__(self, other): + return self._bin(other, "lt") + + def __ge__(self, other): + return self._bin(other, "ge") + + def __le__(self, other): + return self._bin(other, "le") + + def __eq__(self, other): + return self._bin(other, "eq") + # boolean - def __and__(self, other): return self._bin(other, "and") - def __or__(self, other): return self._bin(other, "or") + def __and__(self, other): + return self._bin(other, "and") + + def __or__(self, other): + return self._bin(other, "or") # Rename the output column def alias(self, name: str) -> "AliasExpr": return AliasExpr(self, name) + @dataclass(frozen=True, eq=False) class ColumnExpr(Expr): name: str + @dataclass(frozen=True, eq=False) class LiteralExpr(Expr): value: Any + @dataclass(frozen=True, eq=False) class BinaryExpr(Expr): op: str left: Expr right: Expr + @dataclass(frozen=True, eq=False) class AliasExpr(Expr): expr: Expr name: str + # ────────────────────────────────────── # User helpers # ────────────────────────────────────── @@ -62,10 +90,12 @@ def col(name: str) -> ColumnExpr: """Reference an existing column.""" return ColumnExpr(name) + def lit(value: Any) -> LiteralExpr: """Create a scalar literal expression (e.g. lit(1)).""" return LiteralExpr(value) + # ────────────────────────────────────── # Local evaluator (pandas batches) # ────────────────────────────────────── @@ -76,13 +106,13 @@ def lit(value: Any) -> LiteralExpr: "sub": operator.sub, "mul": operator.mul, "div": operator.truediv, - "gt": operator.gt, - "lt": operator.lt, - "ge": operator.ge, - "le": operator.le, - "eq": operator.eq, + "gt": operator.gt, + "lt": operator.lt, + "ge": operator.ge, + "le": operator.le, + "eq": operator.eq, "and": operator.and_, - "or": operator.or_, + "or": operator.or_, } _NUMPY_OPS: Dict[str, Callable[[Any, Any], Any]] = { @@ -113,6 +143,7 @@ def lit(value: Any) -> LiteralExpr: "or": pc.or_, } + def _eval_expr_recursive(expr: Expr, batch, ops: Dict[str, Callable]) -> Any: """Generic recursive expression evaluator.""" if isinstance(expr, ColumnExpr): @@ -122,10 +153,11 @@ def _eval_expr_recursive(expr: Expr, batch, ops: Dict[str, Callable]) -> Any: if isinstance(expr, BinaryExpr): return ops[expr.op]( _eval_expr_recursive(expr.left, batch, ops), - _eval_expr_recursive(expr.right, batch, ops) + _eval_expr_recursive(expr.right, batch, ops), ) raise TypeError(f"Unsupported expression node: {type(expr).__name__}") + def eval_expr(expr: Expr, batch) -> Any: """Recursively evaluate *expr* against a batch of the appropriate type.""" if isinstance(batch, pd.DataFrame): diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index dd068c4400ed..61c8689b2436 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -2010,22 +2010,43 @@ def func(x, y): ((lit(10) / (col("id") + 1)).alias("result"), 10.0), # 10 / (0 + 1) = 10.0 ], ) -def test_with_column(ray_start_regular_shared, batch_format, expr, expected_value): +def test_with_columns(ray_start_regular_shared, batch_format, expr, expected_value): """Verify that `with_column` works for pandas, numpy, and pyarrow batch formats with various operations.""" - ds = ray.data.range(5).with_column(expr, batch_format=batch_format) + ds = ray.data.range(5).with_columns([expr], batch_format=batch_format) result = ds.take(1)[0] assert result["id"] == 0 assert result["result"] == expected_value -def test_with_column_nonexistent_column(ray_start_regular_shared): +def test_with_columns_nonexistent_column(ray_start_regular_shared): """Verify that referencing a non-existent column with col() raises an exception.""" # Create a dataset with known column "id" ds = ray.data.range(5) # Try to reference a non-existent column - this should raise an exception with pytest.raises(UserCodeException): - ds.with_column((col("nonexistent_column") + 1).alias("result")).materialize() + ds.with_columns([(col("nonexistent_column") + 1).alias("result")]).materialize() + + +@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) +def test_with_columns_multiple_expressions(ray_start_regular_shared, batch_format): + """Verify that `with_column` correctly handles multiple expressions at once.""" + ds = ray.data.range(5) + + expr1 = (col("id") + 1).alias("plus_one") + expr2 = (col("id") * 2).alias("times_two") + expr3 = (lit(10) - col("id")).alias("ten_minus_id") + + ds = ds.with_columns([expr1, expr2, expr3], batch_format=batch_format) + + first_row = ds.take(1)[0] + assert first_row["id"] == 0 + assert first_row["plus_one"] == 1 + assert first_row["times_two"] == 0 + assert first_row["ten_minus_id"] == 10 + + # Ensure all new columns exist in the schema. + assert set(ds.schema().names) == {"id", "plus_one", "times_two", "ten_minus_id"} if __name__ == "__main__": From 9804716d61b16ef6c092d5a076e004381cbe201d Mon Sep 17 00:00:00 2001 From: Goutam V Date: Mon, 7 Jul 2025 00:03:51 -0700 Subject: [PATCH 03/29] Use project operator & update doc Signed-off-by: Goutam V --- .../logical/operators/map_operator.py | 7 ++ .../data/_internal/planner/plan_udf_map_op.py | 26 +++++-- python/ray/data/dataset.py | 69 +++++++------------ 3 files changed, 54 insertions(+), 48 deletions(-) diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index ca5683aaf544..ddd6bd157fc2 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -7,6 +7,7 @@ from ray.data._internal.logical.interfaces import LogicalOperator from ray.data._internal.logical.operators.one_to_one_operator import AbstractOneToOne from ray.data.block import UserDefinedFunction +from ray.data.expressions import Expr from ray.data.preprocessor import Preprocessor if TYPE_CHECKING: @@ -263,6 +264,7 @@ def __init__( input_op: LogicalOperator, cols: Optional[List[str]] = None, cols_rename: Optional[Dict[str, str]] = None, + exprs: Optional[Dict[str, "Expr"]] = None, compute: Optional[ComputeStrategy] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): @@ -275,6 +277,7 @@ def __init__( self._batch_size = None self._cols = cols self._cols_rename = cols_rename + self._exprs = exprs self._batch_format = "pyarrow" self._zero_copy_batch = True @@ -286,6 +289,10 @@ def cols(self) -> Optional[List[str]]: def cols_rename(self) -> Optional[Dict[str, str]]: return self._cols_rename + @property + def exprs(self) -> Optional[Dict[str, "Expr"]]: + return self._exprs + def can_modify_num_rows(self) -> bool: return False diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 3561282bf2b6..c73bc813e928 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -49,6 +49,7 @@ ) from ray.data.context import DataContext from ray.data.exceptions import UserCodeException +from ray.data.expressions import eval_expr from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled logger = logging.getLogger(__name__) @@ -94,18 +95,35 @@ def plan_project_op( columns = op.cols columns_rename = op.cols_rename + exprs = op.exprs def fn(block: Block) -> Block: try: if not BlockAccessor.for_block(block).num_rows(): return block + tbl = BlockAccessor.for_block(block).to_arrow() + + # 1. evaluate / add expressions + if exprs: + for name, ex in exprs.items(): + arr = eval_expr(ex, tbl) + if name in tbl.column_names: + tbl = tbl.set_column( + tbl.schema.get_field_index(name), name, arr + ) + else: + tbl = tbl.append_column(name, arr) + + # 2. (optional) column projection if columns: - block = BlockAccessor.for_block(block).select(columns) + tbl = tbl.select(columns) + + # 3. (optional) rename if columns_rename: - block = block.rename_columns( - [columns_rename.get(col, col) for col in block.schema.names] + tbl = tbl.rename_columns( + [columns_rename.get(col, col) for col in tbl.schema.names] ) - return block + return tbl except Exception as e: _handle_debugger_exception(e, block) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 00b4379fb184..d0dee6fd9c90 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -134,7 +134,7 @@ from ray.data._internal.execution.interfaces import Executor, NodeIdStr from ray.data.grouped_data import GroupedData -from ray.data.expressions import AliasExpr, Expr, eval_expr +from ray.data.expressions import AliasExpr, Expr logger = logging.getLogger(__name__) @@ -793,6 +793,7 @@ def with_columns( Examples: >>> import ray + >>> import ray.data.expressions as col >>> ds = ray.data.range(100) >>> ds.with_columns([(col("id") * 2).alias("new_id"), (col("id") * 3).alias("new_id_2")]).schema() Column Type @@ -809,57 +810,37 @@ def with_columns( ``pyarrow.Table``. Returns: - A new dataset with the added column. + A new dataset with the added columns evaluated via expressions. """ if not exprs: raise ValueError("at least one expression is required") - accepted_batch_formats = ["pandas", "pyarrow", "numpy"] - if batch_format not in accepted_batch_formats: - raise ValueError( - f"batch_format argument must be on of {accepted_batch_formats}, " - f"got: {batch_format}" - ) - - projections = {} - for expr in exprs: - if not isinstance(expr, Expr): - raise TypeError(f"Expected Expr, got: {type(expr)}") - if isinstance(expr, AliasExpr): - projections[expr.name] = expr.expr + # Build mapping {new_col_name: expression} + projections: Dict[str, Expr] = {} + for e in exprs: + if not isinstance(e, Expr): + raise TypeError(f"Expected Expr, got {type(e)}") + if isinstance(e, AliasExpr): + projections[e.name] = e.expr else: - raise ValueError("Each expression must be `.alias()`-ed.") + raise ValueError("Each expression must be `.alias()`-ed.") - def _project(batch): - if isinstance(batch, dict): - for name, ex in projections.items(): - batch[name] = eval_expr(ex, batch) - return batch - - import pandas as pd - - if isinstance(batch, pd.DataFrame): - for name, ex in projections.items(): - batch[name] = eval_expr(ex, batch) - return batch + from ray.data._internal.compute import TaskPoolStrategy + from ray.data._internal.logical.operators.map_operator import Project - import pyarrow as pa + compute_strategy = TaskPoolStrategy(size=concurrency) - if isinstance(batch, pa.Table): - tbl = batch - for name, ex in projections.items(): - arr = eval_expr(ex, batch) - tbl = tbl.append_column(name, arr) - return tbl - - return self.map_batches( - _project, - batch_format=batch_format, - compute=compute, - concurrency=concurrency, - zero_copy_batch=False, - **ray_remote_args, + plan = self._plan.copy() + project_op = Project( + self._logical_plan.dag, + cols=None, + cols_rename=None, + exprs=projections, # << pass expressions + compute=compute_strategy, + ray_remote_args=ray_remote_args, ) + logical_plan = LogicalPlan(project_op, self.context) + return Dataset(plan, logical_plan) @PublicAPI(api_group=BT_API_GROUP) def add_column( @@ -4546,7 +4527,7 @@ def write_clickhouse( * order_by: Sets the `ORDER BY` clause in the `CREATE TABLE` statement, iff not provided. When overwriting an existing table, its previous `ORDER BY` (if any) is reused. - Otherwise, a “best” column is selected automatically (favoring a timestamp column, + Otherwise, a "best" column is selected automatically (favoring a timestamp column, then a non-string column, and lastly the first column). * partition_by: From 16cccffad2c8a0b6ce81a04f2e0cb53ffa549b9e Mon Sep 17 00:00:00 2001 From: Goutam V Date: Mon, 7 Jul 2025 00:53:31 -0700 Subject: [PATCH 04/29] Fix linting issue Signed-off-by: Goutam V --- python/ray/data/dataset.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index d0dee6fd9c90..14c2d29f3601 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -804,10 +804,13 @@ def with_columns( Args: exprs: The expressions to evaluate to produce the new column values. - batch_format: If ``"numpy"``, batches are - ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are - ``pandas.DataFrame``. If ``"pyarrow"``, batches are - ``pyarrow.Table``. + batch_format: This argument is deprecated and ignored. The operation + is performed using PyArrow format internally for efficiency. + compute: This argument is deprecated. Use ``concurrency`` argument. + concurrency: The maximum number of Ray workers to use concurrently. + **ray_remote_args: Additional resource requirements to request from + Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See + :func:`ray.remote` for details. Returns: A new dataset with the added columns evaluated via expressions. From a309598d43009148b5c9e05dc808679a574bac12 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Mon, 7 Jul 2025 09:18:39 -0700 Subject: [PATCH 05/29] Doc linter Signed-off-by: Goutam V --- python/ray/data/dataset.py | 4 ++-- python/ray/data/expressions.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 14c2d29f3601..c3d0283fd5ef 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -778,7 +778,7 @@ def _map_batches_without_batch_size_validation( logical_plan = LogicalPlan(map_batches_op, self.context) return Dataset(plan, logical_plan) - @PublicAPI(api_group=EXPRESSION_API_GROUP) + @PublicAPI(api_group=EXPRESSION_API_GROUP, stability="alpha") def with_columns( self, exprs: List[Expr], @@ -793,7 +793,7 @@ def with_columns( Examples: >>> import ray - >>> import ray.data.expressions as col + >>> from ray.data.expressions import col >>> ds = ray.data.range(100) >>> ds.with_columns([(col("id") * 2).alias("new_id"), (col("id") * 3).alias("new_id_2")]).schema() Column Type diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 21c3ec06e8e1..8d8fd7114494 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -9,10 +9,13 @@ import pyarrow as pa import pyarrow.compute as pc +from ray.util.annotations import DeveloperAPI + # ────────────────────────────────────── # Basic expression node definitions # ────────────────────────────────────── +@DeveloperAPI class Expr: # Base class – all expression nodes inherit from this # Binary/boolean operator overloads def _bin(self, other: Any, op: str) -> "Expr": @@ -60,16 +63,19 @@ def alias(self, name: str) -> "AliasExpr": return AliasExpr(self, name) +@DeveloperAPI @dataclass(frozen=True, eq=False) class ColumnExpr(Expr): name: str +@DeveloperAPI @dataclass(frozen=True, eq=False) class LiteralExpr(Expr): value: Any +@DeveloperAPI @dataclass(frozen=True, eq=False) class BinaryExpr(Expr): op: str @@ -77,6 +83,7 @@ class BinaryExpr(Expr): right: Expr +@DeveloperAPI @dataclass(frozen=True, eq=False) class AliasExpr(Expr): expr: Expr @@ -86,11 +93,15 @@ class AliasExpr(Expr): # ────────────────────────────────────── # User helpers # ────────────────────────────────────── + + +@DeveloperAPI def col(name: str) -> ColumnExpr: """Reference an existing column.""" return ColumnExpr(name) +@DeveloperAPI def lit(value: Any) -> LiteralExpr: """Create a scalar literal expression (e.g. lit(1)).""" return LiteralExpr(value) @@ -158,6 +169,7 @@ def _eval_expr_recursive(expr: Expr, batch, ops: Dict[str, Callable]) -> Any: raise TypeError(f"Unsupported expression node: {type(expr).__name__}") +@DeveloperAPI def eval_expr(expr: Expr, batch) -> Any: """Recursively evaluate *expr* against a batch of the appropriate type.""" if isinstance(batch, pd.DataFrame): From 176b444b5275692bed7586b4aeddc45911ae79df Mon Sep 17 00:00:00 2001 From: Goutam V Date: Mon, 7 Jul 2025 11:11:17 -0700 Subject: [PATCH 06/29] doctest Signed-off-by: Goutam V --- python/ray/data/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index c3d0283fd5ef..8ac48f608431 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -796,10 +796,10 @@ def with_columns( >>> from ray.data.expressions import col >>> ds = ray.data.range(100) >>> ds.with_columns([(col("id") * 2).alias("new_id"), (col("id") * 3).alias("new_id_2")]).schema() - Column Type - ------ ---- - id int64 - new_id int64 + Column Type + ------ ---- + id int64 + new_id int64 new_id_2 int64 Args: From c17570ff459ba8d5843f8fe2f4f40e8f736316cf Mon Sep 17 00:00:00 2001 From: Goutam V Date: Mon, 7 Jul 2025 13:21:38 -0700 Subject: [PATCH 07/29] Address comment Signed-off-by: Goutam V --- python/ray/data/dataset.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 8ac48f608431..73e873c2bbc2 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -782,9 +782,7 @@ def _map_batches_without_batch_size_validation( def with_columns( self, exprs: List[Expr], - batch_format: Optional[str] = "pandas", - compute: Optional[str] = None, - concurrency: Optional[int] = None, + batch_format: Optional[str] = "pyarrow", **ray_remote_args, ) -> "Dataset": """ @@ -804,10 +802,11 @@ def with_columns( Args: exprs: The expressions to evaluate to produce the new column values. - batch_format: This argument is deprecated and ignored. The operation - is performed using PyArrow format internally for efficiency. - compute: This argument is deprecated. Use ``concurrency`` argument. - concurrency: The maximum number of Ray workers to use concurrently. + batch_format: If ``"numpy"``, batches are + ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are + ``pandas.DataFrame``. If ``"pyarrow"``, batches are + ``pyarrow.Table``. If ``"numpy"``, batches are + ``Dict[str, numpy.ndarray]``. **ray_remote_args: Additional resource requirements to request from Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See :func:`ray.remote` for details. @@ -828,18 +827,14 @@ def with_columns( else: raise ValueError("Each expression must be `.alias()`-ed.") - from ray.data._internal.compute import TaskPoolStrategy from ray.data._internal.logical.operators.map_operator import Project - compute_strategy = TaskPoolStrategy(size=concurrency) - plan = self._plan.copy() project_op = Project( self._logical_plan.dag, cols=None, cols_rename=None, exprs=projections, # << pass expressions - compute=compute_strategy, ray_remote_args=ray_remote_args, ) logical_plan = LogicalPlan(project_op, self.context) From eea0d558733974c322bd0ab072934b0f6ad38b93 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Wed, 9 Jul 2025 13:39:08 -0700 Subject: [PATCH 08/29] Address comments Signed-off-by: Goutam V --- python/ray/data/_expression_evaluator.py | 94 ++++++ python/ray/data/dataset.py | 6 - python/ray/data/expressions.py | 405 ++++++++++++++++------- python/ray/data/tests/test_map.py | 14 +- 4 files changed, 386 insertions(+), 133 deletions(-) create mode 100644 python/ray/data/_expression_evaluator.py diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py new file mode 100644 index 000000000000..490385b00e19 --- /dev/null +++ b/python/ray/data/_expression_evaluator.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import operator +from typing import Any, Callable, Dict, TYPE_CHECKING + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.compute as pc + +# Use TYPE_CHECKING imports to avoid circular imports +if TYPE_CHECKING: + from ray.data.expressions import BinaryExpr, ColumnExpr, Expr, LiteralExpr, Operation # noqa: F401 + +def _get_operation_maps(): + """Get operation maps, importing Operation enum at runtime to avoid circular imports.""" + from ray.data.expressions import Operation + + pandas_ops = { + Operation.ADD: operator.add, + Operation.SUB: operator.sub, + Operation.MUL: operator.mul, + Operation.DIV: operator.truediv, + Operation.GT: operator.gt, + Operation.LT: operator.lt, + Operation.GE: operator.ge, + Operation.LE: operator.le, + Operation.EQ: operator.eq, + Operation.AND: operator.and_, + Operation.OR: operator.or_, + } + + numpy_ops = { + Operation.ADD: np.add, + Operation.SUB: np.subtract, + Operation.MUL: np.multiply, + Operation.DIV: np.divide, + Operation.GT: np.greater, + Operation.LT: np.less, + Operation.GE: np.greater_equal, + Operation.LE: np.less_equal, + Operation.EQ: np.equal, + Operation.AND: np.logical_and, + Operation.OR: np.logical_or, + } + + arrow_ops = { + Operation.ADD: pc.add, + Operation.SUB: pc.subtract, + Operation.MUL: pc.multiply, + Operation.DIV: pc.divide, + Operation.GT: pc.greater, + Operation.LT: pc.less, + Operation.GE: pc.greater_equal, + Operation.LE: pc.less_equal, + Operation.EQ: pc.equal, + Operation.AND: pc.and_, + Operation.OR: pc.or_, + } + + return pandas_ops, numpy_ops, arrow_ops + + +def _eval_expr_recursive(expr: "Expr", batch, ops: Dict["Operation", Callable]) -> Any: + """Generic recursive expression evaluator.""" + # Import classes at runtime to avoid circular imports + from ray.data.expressions import BinaryExpr, ColumnExpr, LiteralExpr + + # TODO: Separate unresolved expressions (arbitrary AST with unresolved refs) + # and resolved expressions (bound to a schema) for better error handling + + if isinstance(expr, ColumnExpr): + return batch[expr.name] + if isinstance(expr, LiteralExpr): + return expr.value + if isinstance(expr, BinaryExpr): + return ops[expr.op]( + _eval_expr_recursive(expr.left, batch, ops), + _eval_expr_recursive(expr.right, batch, ops), + ) + raise TypeError(f"Unsupported expression node: {type(expr).__name__}") + + +def eval_expr(expr: "Expr", batch) -> Any: + """Recursively evaluate *expr* against a batch of the appropriate type.""" + pandas_ops, numpy_ops, arrow_ops = _get_operation_maps() + + if isinstance(batch, pd.DataFrame): + return _eval_expr_recursive(expr, batch, pandas_ops) + elif isinstance(batch, (np.ndarray, dict)): + return _eval_expr_recursive(expr, batch, numpy_ops) + elif isinstance(batch, pa.Table): + return _eval_expr_recursive(expr, batch, arrow_ops) + raise TypeError(f"Unsupported batch type: {type(batch).__name__}") diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 73e873c2bbc2..47dae059329f 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -782,7 +782,6 @@ def _map_batches_without_batch_size_validation( def with_columns( self, exprs: List[Expr], - batch_format: Optional[str] = "pyarrow", **ray_remote_args, ) -> "Dataset": """ @@ -802,11 +801,6 @@ def with_columns( Args: exprs: The expressions to evaluate to produce the new column values. - batch_format: If ``"numpy"``, batches are - ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are - ``pandas.DataFrame``. If ``"pyarrow"``, batches are - ``pyarrow.Table``. If ``"numpy"``, batches are - ``Dict[str, numpy.ndarray]``. **ray_remote_args: Additional resource requirements to request from Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See :func:`ray.remote` for details. diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 8d8fd7114494..4220000f6b6c 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -1,181 +1,348 @@ from __future__ import annotations -import operator from dataclasses import dataclass -from typing import Any, Callable, Dict - -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.compute as pc +from enum import Enum +from typing import Any +from ray.data._expression_evaluator import eval_expr from ray.util.annotations import DeveloperAPI -# ────────────────────────────────────── -# Basic expression node definitions -# ────────────────────────────────────── @DeveloperAPI -class Expr: # Base class – all expression nodes inherit from this - # Binary/boolean operator overloads - def _bin(self, other: Any, op: str) -> "Expr": - other = other if isinstance(other, Expr) else LiteralExpr(other) +@dataclass(frozen=True, eq=False) +class Operation(Enum): + """Enumeration of supported operations in expressions. + + This enum defines all the binary operations that can be performed + between expressions, including arithmetic, comparison, and boolean operations. + + Attributes: + ADD: Addition operation (+) + SUB: Subtraction operation (-) + MUL: Multiplication operation (*) + DIV: Division operation (/) + GT: Greater than comparison (>) + LT: Less than comparison (<) + GE: Greater than or equal comparison (>=) + LE: Less than or equal comparison (<=) + EQ: Equality comparison (==) + AND: Logical AND operation (&) + OR: Logical OR operation (|) + """ + + ADD = "add" + SUB = "sub" + MUL = "mul" + DIV = "div" + GT = "gt" + LT = "lt" + GE = "ge" + LE = "le" + EQ = "eq" + AND = "and" + OR = "or" + + +@DeveloperAPI +@dataclass(frozen=True) +class Expr: + """Base class for all expression nodes. + + This is the abstract base class that all expression types inherit from. + It provides operator overloads for building complex expressions using + standard Python operators. + + Expressions form a tree structure where each node represents an operation + or value. The tree can be evaluated against data batches to compute results. + + Example: + >>> from ray.data.expressions import col, lit + >>> # Create an expression tree: (col("x") + 5) * col("y") + >>> expr = (col("x") + lit(5)) * col("y") + >>> # This creates a BinaryExpr with operation=MUL + >>> # left=BinaryExpr(op=ADD, left=ColumnExpr("x"), right=LiteralExpr(5)) + >>> # right=ColumnExpr("y") + + Note: + This class should not be instantiated directly. Use the concrete + subclasses like ColumnExpr, LiteralExpr, etc. + """ + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # Override dataclass __eq__ with expression __eq__ + cls.__eq__ = cls._expr_eq + + def _expr_eq(self, other: Any) -> "Expr": + """Expression equality operator.""" + return self._bin(other, Operation.EQ) + + def _bin(self, other: Any, op: Operation) -> "Expr": + """Create a binary expression with the given operation. + + Args: + other: The right operand expression or literal value + op: The operation to perform + + Returns: + A new BinaryExpr representing the operation + + Note: + If other is not an Expr, it will be automatically converted to a LiteralExpr. + """ + if not isinstance(other, Expr): + other = LiteralExpr(other) return BinaryExpr(op, self, other) # arithmetic - def __add__(self, other): - return self._bin(other, "add") + def __add__(self, other: Any) -> "Expr": + """Addition operator (+).""" + return self._bin(other, Operation.ADD) + + def __radd__(self, other: Any) -> "Expr": + """Reverse addition operator (for literal + expr).""" + return LiteralExpr(other)._bin(self, Operation.ADD) + + def __sub__(self, other: Any) -> "Expr": + """Subtraction operator (-).""" + return self._bin(other, Operation.SUB) + + def __rsub__(self, other: Any) -> "Expr": + """Reverse subtraction operator (for literal - expr).""" + return LiteralExpr(other)._bin(self, Operation.SUB) + + def __mul__(self, other: Any) -> "Expr": + """Multiplication operator (*).""" + return self._bin(other, Operation.MUL) - def __sub__(self, other): - return self._bin(other, "sub") + def __rmul__(self, other: Any) -> "Expr": + """Reverse multiplication operator (for literal * expr).""" + return LiteralExpr(other)._bin(self, Operation.MUL) - def __mul__(self, other): - return self._bin(other, "mul") + def __truediv__(self, other: Any) -> "Expr": + """Division operator (/).""" + return self._bin(other, Operation.DIV) - def __truediv__(self, other): - return self._bin(other, "div") + def __rtruediv__(self, other: Any) -> "Expr": + """Reverse division operator (for literal / expr).""" + return LiteralExpr(other)._bin(self, Operation.DIV) # comparison - def __gt__(self, other): - return self._bin(other, "gt") + def __gt__(self, other: Any) -> "Expr": + """Greater than operator (>).""" + return self._bin(other, Operation.GT) - def __lt__(self, other): - return self._bin(other, "lt") + def __lt__(self, other: Any) -> "Expr": + """Less than operator (<).""" + return self._bin(other, Operation.LT) - def __ge__(self, other): - return self._bin(other, "ge") + def __ge__(self, other: Any) -> "Expr": + """Greater than or equal operator (>=).""" + return self._bin(other, Operation.GE) - def __le__(self, other): - return self._bin(other, "le") + def __le__(self, other: Any) -> "Expr": + """Less than or equal operator (<=).""" + return self._bin(other, Operation.LE) - def __eq__(self, other): - return self._bin(other, "eq") + def __eq__(self, other: Any) -> "Expr": + """Equality operator (==).""" + return self._bin(other, Operation.EQ) # boolean - def __and__(self, other): - return self._bin(other, "and") + def __and__(self, other: Any) -> "Expr": + """Logical AND operator (&).""" + return self._bin(other, Operation.AND) - def __or__(self, other): - return self._bin(other, "or") + def __or__(self, other: Any) -> "Expr": + """Logical OR operator (|).""" + return self._bin(other, Operation.OR) - # Rename the output column def alias(self, name: str) -> "AliasExpr": + """Give this expression a new name. + + Args: + name: The new name for the expression result + + Returns: + An AliasExpr that evaluates this expression and assigns the result + to the given name + + Example: + >>> from ray.data.expressions import col + >>> # Create a column named "sum" from adding x and y + >>> expr = (col("x") + col("y")).alias("sum") + """ return AliasExpr(self, name) @DeveloperAPI -@dataclass(frozen=True, eq=False) +@dataclass(frozen=True) class ColumnExpr(Expr): + """Expression that references a column by name. + + This expression type represents a reference to an existing column + in the dataset. When evaluated, it returns the values from the + specified column. + + Args: + name: The name of the column to reference + + Example: + >>> from ray.data.expressions import col + >>> # Reference the "age" column + >>> age_expr = col("age") # Creates ColumnExpr(name="age") + """ + name: str @DeveloperAPI -@dataclass(frozen=True, eq=False) +@dataclass(frozen=True) class LiteralExpr(Expr): + """Expression that represents a constant scalar value. + + This expression type represents a literal value that will be broadcast + to all rows when evaluated. The value can be any Python object. + + Args: + value: The constant value to represent + + Example: + >>> from ray.data.expressions import lit + >>> # Create a literal value + >>> five = lit(5) # Creates LiteralExpr(value=5) + >>> name = lit("John") # Creates LiteralExpr(value="John") + """ + value: Any @DeveloperAPI -@dataclass(frozen=True, eq=False) +@dataclass(frozen=True) class BinaryExpr(Expr): - op: str + """Expression that represents a binary operation between two expressions. + + This expression type represents an operation with two operands (left and right). + The operation is specified by the `op` field, which must be one of the + supported operations from the Operation enum. + + Args: + op: The operation to perform (from Operation enum) + left: The left operand expression + right: The right operand expression + + Example: + >>> from ray.data.expressions import col, lit, Operation + >>> # Manually create a binary expression (usually done via operators) + >>> expr = BinaryExpr(Operation.ADD, col("x"), lit(5)) + >>> # This is equivalent to: col("x") + lit(5) + """ + + op: Operation left: Expr right: Expr @DeveloperAPI -@dataclass(frozen=True, eq=False) +@dataclass(frozen=True) class AliasExpr(Expr): - expr: Expr - name: str + """Expression that assigns a name to another expression's result. + This expression type wraps another expression and assigns a specific + name to its result. This is useful for creating new columns with + meaningful names. -# ────────────────────────────────────── -# User helpers -# ────────────────────────────────────── + Args: + expr: The expression to evaluate + name: The name to assign to the result + + Example: + >>> from ray.data.expressions import col + >>> # Create an alias for a computed column + >>> expr = AliasExpr(col("x") + col("y"), "sum") + >>> # This is equivalent to: (col("x") + col("y")).alias("sum") + """ + + expr: Expr + name: str @DeveloperAPI def col(name: str) -> ColumnExpr: - """Reference an existing column.""" + """Reference an existing column by name. + + This is the primary way to reference columns in expressions. + The returned expression will extract values from the specified + column when evaluated. + + Args: + name: The name of the column to reference + + Returns: + A ColumnExpr that references the specified column + + Example: + >>> from ray.data.expressions import col + >>> # Reference columns in an expression + >>> expr = col("price") * col("quantity") + >>> + >>> # Use with Dataset.with_columns() + >>> import ray + >>> ds = ray.data.from_items([{"price": 10, "quantity": 2}]) + >>> ds = ds.with_columns(total=col("price") * col("quantity")) + """ return ColumnExpr(name) @DeveloperAPI def lit(value: Any) -> LiteralExpr: - """Create a scalar literal expression (e.g. lit(1)).""" + """Create a literal expression from a constant value. + + This creates an expression that represents a constant scalar value. + The value will be broadcast to all rows when the expression is evaluated. + + Args: + value: The constant value to represent. Can be any Python object + (int, float, str, bool, etc.) + + Returns: + A LiteralExpr containing the specified value + + Example: + >>> from ray.data.expressions import col, lit + >>> # Create literals of different types + >>> five = lit(5) + >>> pi = lit(3.14159) + >>> name = lit("Alice") + >>> flag = lit(True) + >>> + >>> # Use in expressions + >>> expr = col("age") + lit(1) # Add 1 to age column + >>> + >>> # Use with Dataset.with_columns() + >>> import ray + >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) + >>> ds = ds.with_columns(age_plus_one=col("age") + lit(1)) + """ return LiteralExpr(value) # ────────────────────────────────────── -# Local evaluator (pandas batches) +# Public API for evaluation # ────────────────────────────────────── -# This is used by Dataset.with_columns – kept here so it can be re-used by -# future optimised executors. -_PANDAS_OPS: Dict[str, Callable[[Any, Any], Any]] = { - "add": operator.add, - "sub": operator.sub, - "mul": operator.mul, - "div": operator.truediv, - "gt": operator.gt, - "lt": operator.lt, - "ge": operator.ge, - "le": operator.le, - "eq": operator.eq, - "and": operator.and_, - "or": operator.or_, -} - -_NUMPY_OPS: Dict[str, Callable[[Any, Any], Any]] = { - "add": np.add, - "sub": np.subtract, - "mul": np.multiply, - "div": np.divide, - "gt": np.greater, - "lt": np.less, - "ge": np.greater_equal, - "le": np.less_equal, - "eq": np.equal, - "and": np.logical_and, - "or": np.logical_or, -} - -_ARROW_OPS: Dict[str, Callable[[Any, Any], Any]] = { - "add": pc.add, - "sub": pc.subtract, - "mul": pc.multiply, - "div": pc.divide, - "gt": pc.greater, - "lt": pc.less, - "ge": pc.greater_equal, - "le": pc.less_equal, - "eq": pc.equal, - "and": pc.and_, - "or": pc.or_, -} - - -def _eval_expr_recursive(expr: Expr, batch, ops: Dict[str, Callable]) -> Any: - """Generic recursive expression evaluator.""" - if isinstance(expr, ColumnExpr): - return batch[expr.name] - if isinstance(expr, LiteralExpr): - return expr.value - if isinstance(expr, BinaryExpr): - return ops[expr.op]( - _eval_expr_recursive(expr.left, batch, ops), - _eval_expr_recursive(expr.right, batch, ops), - ) - raise TypeError(f"Unsupported expression node: {type(expr).__name__}") - - -@DeveloperAPI -def eval_expr(expr: Expr, batch) -> Any: - """Recursively evaluate *expr* against a batch of the appropriate type.""" - if isinstance(batch, pd.DataFrame): - return _eval_expr_recursive(expr, batch, _PANDAS_OPS) - elif isinstance(batch, (np.ndarray, dict)): - return _eval_expr_recursive(expr, batch, _NUMPY_OPS) - elif isinstance(batch, pa.Table): - return _eval_expr_recursive(expr, batch, _ARROW_OPS) - raise TypeError(f"Unsupported batch type: {type(batch).__name__}") +# Note: Implementation details are in _expression_evaluator.py + +# Re-export eval_expr for public use + + +__all__ = [ + "Operation", + "Expr", + "ColumnExpr", + "LiteralExpr", + "BinaryExpr", + "AliasExpr", + "col", + "lit", + "eval_expr", +] diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 61c8689b2436..745e49fb28db 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -1983,7 +1983,6 @@ def func(x, y): assert r.startswith("OneHotEncoder"), r -@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) @pytest.mark.parametrize( "expr, expected_value", [ @@ -2010,9 +2009,9 @@ def func(x, y): ((lit(10) / (col("id") + 1)).alias("result"), 10.0), # 10 / (0 + 1) = 10.0 ], ) -def test_with_columns(ray_start_regular_shared, batch_format, expr, expected_value): - """Verify that `with_column` works for pandas, numpy, and pyarrow batch formats with various operations.""" - ds = ray.data.range(5).with_columns([expr], batch_format=batch_format) +def test_with_columns(ray_start_regular_shared, expr, expected_value): + """Verify that `with_column` works with various operations.""" + ds = ray.data.range(5).with_columns([expr]) result = ds.take(1)[0] assert result["id"] == 0 assert result["result"] == expected_value @@ -2028,16 +2027,15 @@ def test_with_columns_nonexistent_column(ray_start_regular_shared): ds.with_columns([(col("nonexistent_column") + 1).alias("result")]).materialize() -@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) -def test_with_columns_multiple_expressions(ray_start_regular_shared, batch_format): +def test_with_columns_multiple_expressions(ray_start_regular_shared): """Verify that `with_column` correctly handles multiple expressions at once.""" ds = ray.data.range(5) expr1 = (col("id") + 1).alias("plus_one") expr2 = (col("id") * 2).alias("times_two") - expr3 = (lit(10) - col("id")).alias("ten_minus_id") + expr3 = (10 - col("id")).alias("ten_minus_id") - ds = ds.with_columns([expr1, expr2, expr3], batch_format=batch_format) + ds = ds.with_columns([expr1, expr2, expr3]) first_row = ds.take(1)[0] assert first_row["id"] == 0 From ed251a1736b4c132b41fd55dd4709dea254632ff Mon Sep 17 00:00:00 2001 From: Goutam V Date: Wed, 9 Jul 2025 13:45:30 -0700 Subject: [PATCH 09/29] Linter & remove dataclass for operations Signed-off-by: Goutam V --- python/ray/data/_expression_evaluator.py | 8 ++++++-- python/ray/data/expressions.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py index 490385b00e19..486a169fe736 100644 --- a/python/ray/data/_expression_evaluator.py +++ b/python/ray/data/_expression_evaluator.py @@ -1,7 +1,7 @@ from __future__ import annotations import operator -from typing import Any, Callable, Dict, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict import numpy as np import pandas as pd @@ -10,7 +10,11 @@ # Use TYPE_CHECKING imports to avoid circular imports if TYPE_CHECKING: - from ray.data.expressions import BinaryExpr, ColumnExpr, Expr, LiteralExpr, Operation # noqa: F401 + from ray.data.expressions import ( + Expr, + Operation, + ) # noqa: F401 + def _get_operation_maps(): """Get operation maps, importing Operation enum at runtime to avoid circular imports.""" diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 4220000f6b6c..1372298c2b87 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -9,7 +9,6 @@ @DeveloperAPI -@dataclass(frozen=True, eq=False) class Operation(Enum): """Enumeration of supported operations in expressions. From e36a87b79590ca70c781e111ce5dbe16277e86b3 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Wed, 9 Jul 2025 15:27:51 -0700 Subject: [PATCH 10/29] Address comments Signed-off-by: Goutam V --- python/ray/data/_expression_evaluator.py | 21 +------ .../logical/operators/map_operator.py | 12 +++- .../data/_internal/planner/plan_udf_map_op.py | 35 +++++++----- python/ray/data/dataset.py | 22 ++------ python/ray/data/expressions.py | 41 -------------- python/ray/data/tests/test_map.py | 56 ++++++++++--------- 6 files changed, 69 insertions(+), 118 deletions(-) diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py index 486a169fe736..5c20216de88b 100644 --- a/python/ray/data/_expression_evaluator.py +++ b/python/ray/data/_expression_evaluator.py @@ -3,7 +3,6 @@ import operator from typing import TYPE_CHECKING, Any, Callable, Dict -import numpy as np import pandas as pd import pyarrow as pa import pyarrow.compute as pc @@ -34,20 +33,6 @@ def _get_operation_maps(): Operation.OR: operator.or_, } - numpy_ops = { - Operation.ADD: np.add, - Operation.SUB: np.subtract, - Operation.MUL: np.multiply, - Operation.DIV: np.divide, - Operation.GT: np.greater, - Operation.LT: np.less, - Operation.GE: np.greater_equal, - Operation.LE: np.less_equal, - Operation.EQ: np.equal, - Operation.AND: np.logical_and, - Operation.OR: np.logical_or, - } - arrow_ops = { Operation.ADD: pc.add, Operation.SUB: pc.subtract, @@ -62,7 +47,7 @@ def _get_operation_maps(): Operation.OR: pc.or_, } - return pandas_ops, numpy_ops, arrow_ops + return pandas_ops, arrow_ops def _eval_expr_recursive(expr: "Expr", batch, ops: Dict["Operation", Callable]) -> Any: @@ -87,12 +72,10 @@ def _eval_expr_recursive(expr: "Expr", batch, ops: Dict["Operation", Callable]) def eval_expr(expr: "Expr", batch) -> Any: """Recursively evaluate *expr* against a batch of the appropriate type.""" - pandas_ops, numpy_ops, arrow_ops = _get_operation_maps() + pandas_ops, arrow_ops = _get_operation_maps() if isinstance(batch, pd.DataFrame): return _eval_expr_recursive(expr, batch, pandas_ops) - elif isinstance(batch, (np.ndarray, dict)): - return _eval_expr_recursive(expr, batch, numpy_ops) elif isinstance(batch, pa.Table): return _eval_expr_recursive(expr, batch, arrow_ops) raise TypeError(f"Unsupported batch type: {type(batch).__name__}") diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index ddd6bd157fc2..6b1bcefba5ad 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -264,7 +264,9 @@ def __init__( input_op: LogicalOperator, cols: Optional[List[str]] = None, cols_rename: Optional[Dict[str, str]] = None, - exprs: Optional[Dict[str, "Expr"]] = None, + exprs: Optional[ + Dict[str, "Expr"] + ] = None, # TODO Remove cols and cols_rename and replace them with corresponding exprs compute: Optional[ComputeStrategy] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): @@ -281,6 +283,14 @@ def __init__( self._batch_format = "pyarrow" self._zero_copy_batch = True + if exprs is not None: + # Validate that all values are expressions + for name, expr in exprs.items(): + if not isinstance(expr, Expr): + raise TypeError( + f"Expected Expr for column '{name}', got {type(expr)}" + ) + @property def cols(self) -> Optional[List[str]]: return self._cols diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index c73bc813e928..da5488523f66 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -99,31 +99,38 @@ def plan_project_op( def fn(block: Block) -> Block: try: - if not BlockAccessor.for_block(block).num_rows(): + block_accessor = BlockAccessor.for_block(block) + if not block_accessor.num_rows(): return block - tbl = BlockAccessor.for_block(block).to_arrow() # 1. evaluate / add expressions if exprs: + # Extract existing columns directly from the block + new_columns = {} + for col_name in block_accessor.column_names(): + # For Arrow blocks, block[col_name] gives us a ChunkedArray + # For Pandas blocks, block[col_name] gives us a Series + new_columns[col_name] = block[col_name] + + # Add/update with expression results for name, ex in exprs.items(): - arr = eval_expr(ex, tbl) - if name in tbl.column_names: - tbl = tbl.set_column( - tbl.schema.get_field_index(name), name, arr - ) - else: - tbl = tbl.append_column(name, arr) + result = eval_expr(ex, block) + new_columns[name] = result + + # Create new block from updated columns + block = BlockAccessor.batch_to_block(new_columns) + block_accessor = BlockAccessor.for_block(block) # 2. (optional) column projection if columns: - tbl = tbl.select(columns) + block = block_accessor.select(columns) + block_accessor = BlockAccessor.for_block(block) # 3. (optional) rename if columns_rename: - tbl = tbl.rename_columns( - [columns_rename.get(col, col) for col in tbl.schema.names] - ) - return tbl + block = block_accessor.rename_columns(columns_rename) + + return block except Exception as e: _handle_debugger_exception(e, block) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 47dae059329f..831ae82c5f2e 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -134,7 +134,7 @@ from ray.data._internal.execution.interfaces import Executor, NodeIdStr from ray.data.grouped_data import GroupedData -from ray.data.expressions import AliasExpr, Expr +from ray.data.expressions import Expr logger = logging.getLogger(__name__) @@ -781,7 +781,7 @@ def _map_batches_without_batch_size_validation( @PublicAPI(api_group=EXPRESSION_API_GROUP, stability="alpha") def with_columns( self, - exprs: List[Expr], + exprs: Dict[str, Expr], **ray_remote_args, ) -> "Dataset": """ @@ -792,7 +792,7 @@ def with_columns( >>> import ray >>> from ray.data.expressions import col >>> ds = ray.data.range(100) - >>> ds.with_columns([(col("id") * 2).alias("new_id"), (col("id") * 3).alias("new_id_2")]).schema() + >>> ds.with_columns({"new_id": col("id") * 2, "new_id_2": col("id") * 3}).schema() Column Type ------ ---- id int64 @@ -800,7 +800,7 @@ def with_columns( new_id_2 int64 Args: - exprs: The expressions to evaluate to produce the new column values. + exprs: A dictionary mapping column names to expressions that define the new column values. **ray_remote_args: Additional resource requirements to request from Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See :func:`ray.remote` for details. @@ -811,16 +811,6 @@ def with_columns( if not exprs: raise ValueError("at least one expression is required") - # Build mapping {new_col_name: expression} - projections: Dict[str, Expr] = {} - for e in exprs: - if not isinstance(e, Expr): - raise TypeError(f"Expected Expr, got {type(e)}") - if isinstance(e, AliasExpr): - projections[e.name] = e.expr - else: - raise ValueError("Each expression must be `.alias()`-ed.") - from ray.data._internal.logical.operators.map_operator import Project plan = self._plan.copy() @@ -828,7 +818,7 @@ def with_columns( self._logical_plan.dag, cols=None, cols_rename=None, - exprs=projections, # << pass expressions + exprs=exprs, ray_remote_args=ray_remote_args, ) logical_plan = LogicalPlan(project_op, self.context) @@ -5077,7 +5067,7 @@ def to_torch( using a local in-memory shuffle buffer, and this value will serve as the minimum number of rows that must be in the local in-memory shuffle buffer in order to yield a batch. When there are no more rows to add to - the buffer, the remaining rows in the buffer is drained. This + the buffer, the remaining rows in the buffer are drained. This buffer size must be greater than or equal to ``batch_size``, and therefore ``batch_size`` must also be specified when using local shuffling. diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 1372298c2b87..1193f0c9e56a 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -156,23 +156,6 @@ def __or__(self, other: Any) -> "Expr": """Logical OR operator (|).""" return self._bin(other, Operation.OR) - def alias(self, name: str) -> "AliasExpr": - """Give this expression a new name. - - Args: - name: The new name for the expression result - - Returns: - An AliasExpr that evaluates this expression and assigns the result - to the given name - - Example: - >>> from ray.data.expressions import col - >>> # Create a column named "sum" from adding x and y - >>> expr = (col("x") + col("y")).alias("sum") - """ - return AliasExpr(self, name) - @DeveloperAPI @dataclass(frozen=True) @@ -242,30 +225,6 @@ class BinaryExpr(Expr): right: Expr -@DeveloperAPI -@dataclass(frozen=True) -class AliasExpr(Expr): - """Expression that assigns a name to another expression's result. - - This expression type wraps another expression and assigns a specific - name to its result. This is useful for creating new columns with - meaningful names. - - Args: - expr: The expression to evaluate - name: The name to assign to the result - - Example: - >>> from ray.data.expressions import col - >>> # Create an alias for a computed column - >>> expr = AliasExpr(col("x") + col("y"), "sum") - >>> # This is equivalent to: (col("x") + col("y")).alias("sum") - """ - - expr: Expr - name: str - - @DeveloperAPI def col(name: str) -> ColumnExpr: """Reference an existing column by name. diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 745e49fb28db..91f1af6c77ac 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -1984,34 +1984,34 @@ def func(x, y): @pytest.mark.parametrize( - "expr, expected_value", + "exprs, expected_value", [ # Arithmetic operations - ((col("id") + 1).alias("result"), 1), # 0 + 1 = 1 - ((col("id") + 5).alias("result"), 5), # 0 + 5 = 5 - ((col("id") - 1).alias("result"), -1), # 0 - 1 = -1 - ((col("id") * 2).alias("result"), 0), # 0 * 2 = 0 - ((col("id") * 3).alias("result"), 0), # 0 * 3 = 0 - ((col("id") / 2).alias("result"), 0.0), # 0 / 2 = 0.0 + ({"result": col("id") + 1}, 1), # 0 + 1 = 1 + ({"result": col("id") + 5}, 5), # 0 + 5 = 5 + ({"result": col("id") - 1}, -1), # 0 - 1 = -1 + ({"result": col("id") * 2}, 0), # 0 * 2 = 0 + ({"result": col("id") * 3}, 0), # 0 * 3 = 0 + ({"result": col("id") / 2}, 0.0), # 0 / 2 = 0.0 # More complex arithmetic - (((col("id") + 1) * 2).alias("result"), 2), # (0 + 1) * 2 = 2 - (((col("id") * 2) + 3).alias("result"), 3), # 0 * 2 + 3 = 3 + ({"result": (col("id") + 1) * 2}, 2), # (0 + 1) * 2 = 2 + ({"result": (col("id") * 2) + 3}, 3), # 0 * 2 + 3 = 3 # Comparison operations - ((col("id") > 0).alias("result"), False), # 0 > 0 = False - ((col("id") >= 0).alias("result"), True), # 0 >= 0 = True - ((col("id") < 1).alias("result"), True), # 0 < 1 = True - ((col("id") <= 0).alias("result"), True), # 0 <= 0 = True - ((col("id") == 0).alias("result"), True), # 0 == 0 = True + ({"result": col("id") > 0}, False), # 0 > 0 = False + ({"result": col("id") >= 0}, True), # 0 >= 0 = True + ({"result": col("id") < 1}, True), # 0 < 1 = True + ({"result": col("id") <= 0}, True), # 0 <= 0 = True + ({"result": col("id") == 0}, True), # 0 == 0 = True # Operations with literals - ((col("id") + lit(10)).alias("result"), 10), # 0 + 10 = 10 - ((col("id") * lit(5)).alias("result"), 0), # 0 * 5 = 0 - ((lit(2) + col("id")).alias("result"), 2), # 2 + 0 = 2 - ((lit(10) / (col("id") + 1)).alias("result"), 10.0), # 10 / (0 + 1) = 10.0 + ({"result": col("id") + lit(10)}, 10), # 0 + 10 = 10 + ({"result": col("id") * lit(5)}, 0), # 0 * 5 = 0 + ({"result": lit(2) + col("id")}, 2), # 2 + 0 = 2 + ({"result": lit(10) / (col("id") + 1)}, 10.0), # 10 / (0 + 1) = 10.0 ], ) -def test_with_columns(ray_start_regular_shared, expr, expected_value): - """Verify that `with_column` works with various operations.""" - ds = ray.data.range(5).with_columns([expr]) +def test_with_columns(ray_start_regular_shared, exprs, expected_value): + """Verify that `with_columns` works with various operations.""" + ds = ray.data.range(5).with_columns(exprs) result = ds.take(1)[0] assert result["id"] == 0 assert result["result"] == expected_value @@ -2024,18 +2024,20 @@ def test_with_columns_nonexistent_column(ray_start_regular_shared): # Try to reference a non-existent column - this should raise an exception with pytest.raises(UserCodeException): - ds.with_columns([(col("nonexistent_column") + 1).alias("result")]).materialize() + ds.with_columns({"result": col("nonexistent_column") + 1}).materialize() def test_with_columns_multiple_expressions(ray_start_regular_shared): - """Verify that `with_column` correctly handles multiple expressions at once.""" + """Verify that `with_columns` correctly handles multiple expressions at once.""" ds = ray.data.range(5) - expr1 = (col("id") + 1).alias("plus_one") - expr2 = (col("id") * 2).alias("times_two") - expr3 = (10 - col("id")).alias("ten_minus_id") + exprs = { + "plus_one": col("id") + 1, + "times_two": col("id") * 2, + "ten_minus_id": 10 - col("id"), + } - ds = ds.with_columns([expr1, expr2, expr3]) + ds = ds.with_columns(exprs) first_row = ds.take(1)[0] assert first_row["id"] == 0 From 86bc9fbf5f88b07bf147d3ef5ed6362adf5b23f5 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Wed, 9 Jul 2025 15:31:21 -0700 Subject: [PATCH 11/29] revert old change Signed-off-by: Goutam V --- python/ray/data/_internal/planner/plan_udf_map_op.py | 7 ++++--- python/ray/data/expressions.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index da5488523f66..db833a54ddf5 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -123,12 +123,13 @@ def fn(block: Block) -> Block: # 2. (optional) column projection if columns: - block = block_accessor.select(columns) - block_accessor = BlockAccessor.for_block(block) + block = BlockAccessor.for_block(block).select(columns) # 3. (optional) rename if columns_rename: - block = block_accessor.rename_columns(columns_rename) + block = block.rename_columns( + [columns_rename.get(col, col) for col in block.schema.names] + ) return block except Exception as e: diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 1193f0c9e56a..47735eb41b3e 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -299,7 +299,6 @@ def lit(value: Any) -> LiteralExpr: "ColumnExpr", "LiteralExpr", "BinaryExpr", - "AliasExpr", "col", "lit", "eval_expr", From 3053b95f543bc298c08be08631460385c2c72ef3 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Wed, 9 Jul 2025 15:32:51 -0700 Subject: [PATCH 12/29] Remove unnecessary arg Signed-off-by: Goutam V --- python/ray/data/dataset.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 831ae82c5f2e..9877449ff733 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -779,11 +779,7 @@ def _map_batches_without_batch_size_validation( return Dataset(plan, logical_plan) @PublicAPI(api_group=EXPRESSION_API_GROUP, stability="alpha") - def with_columns( - self, - exprs: Dict[str, Expr], - **ray_remote_args, - ) -> "Dataset": + def with_columns(self, exprs: Dict[str, Expr]) -> "Dataset": """ Add new columns to the dataset. @@ -801,9 +797,6 @@ def with_columns( Args: exprs: A dictionary mapping column names to expressions that define the new column values. - **ray_remote_args: Additional resource requirements to request from - Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See - :func:`ray.remote` for details. Returns: A new dataset with the added columns evaluated via expressions. @@ -819,7 +812,6 @@ def with_columns( cols=None, cols_rename=None, exprs=exprs, - ray_remote_args=ray_remote_args, ) logical_plan = LogicalPlan(project_op, self.context) return Dataset(plan, logical_plan) From fb3c6a16d5e0e08c99c8c571cde669daddaaf2e0 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 09:16:53 -0700 Subject: [PATCH 13/29] doctest + pytest skip if version is not met Signed-off-by: Goutam V --- python/ray/data/expressions.py | 4 ++-- python/ray/data/tests/test_map.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 47735eb41b3e..8974b35e4fa0 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -247,7 +247,7 @@ def col(name: str) -> ColumnExpr: >>> # Use with Dataset.with_columns() >>> import ray >>> ds = ray.data.from_items([{"price": 10, "quantity": 2}]) - >>> ds = ds.with_columns(total=col("price") * col("quantity")) + >>> ds = ds.with_columns({"total": col("price") * col("quantity")}) """ return ColumnExpr(name) @@ -280,7 +280,7 @@ def lit(value: Any) -> LiteralExpr: >>> # Use with Dataset.with_columns() >>> import ray >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) - >>> ds = ds.with_columns(age_plus_one=col("age") + lit(1)) + >>> ds = ds.with_columns({"age_plus_one": col("age") + lit(1)}) """ return LiteralExpr(value) diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 31ca455866dc..896cd78de907 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -16,6 +16,7 @@ import pyarrow.compute as pc import pyarrow.parquet as pq import pytest +from pkg_resources import parse_version import ray from ray._common.test_utils import wait_for_condition @@ -2207,6 +2208,10 @@ def func(x, y): assert r.startswith("OneHotEncoder"), r +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_columns requires PyArrow >= 20.0.0", +) @pytest.mark.parametrize( "exprs, expected_value", [ @@ -2241,6 +2246,10 @@ def test_with_columns(ray_start_regular_shared, exprs, expected_value): assert result["result"] == expected_value +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_columns requires PyArrow >= 20.0.0", +) def test_with_columns_nonexistent_column(ray_start_regular_shared): """Verify that referencing a non-existent column with col() raises an exception.""" # Create a dataset with known column "id" @@ -2251,6 +2260,10 @@ def test_with_columns_nonexistent_column(ray_start_regular_shared): ds.with_columns({"result": col("nonexistent_column") + 1}).materialize() +@pytest.mark.skipif( + get_pyarrow_version() < parse_version("20.0.0"), + reason="with_columns requires PyArrow >= 20.0.0", +) def test_with_columns_multiple_expressions(ray_start_regular_shared): """Verify that `with_columns` correctly handles multiple expressions at once.""" ds = ray.data.range(5) From bd7bc77dd2ca0df5a3a63f5be7b802901bc1f190 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 10:18:09 -0700 Subject: [PATCH 14/29] Remove circular dep Signed-off-by: Goutam V --- python/ray/data/_expression_evaluator.py | 12 +++++------- python/ray/data/_internal/planner/plan_udf_map_op.py | 2 +- python/ray/data/expressions.py | 2 -- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py index 5c20216de88b..6b4508e84cfc 100644 --- a/python/ray/data/_expression_evaluator.py +++ b/python/ray/data/_expression_evaluator.py @@ -1,18 +1,16 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, Dict +from typing import Any, Callable, Dict import pandas as pd import pyarrow as pa import pyarrow.compute as pc -# Use TYPE_CHECKING imports to avoid circular imports -if TYPE_CHECKING: - from ray.data.expressions import ( - Expr, - Operation, - ) # noqa: F401 +from ray.data.expressions import ( + Expr, + Operation, +) def _get_operation_maps(): diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index b8ad9898d04e..597ffe5145c5 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -14,6 +14,7 @@ import ray from ray._common.utils import get_or_create_event_loop from ray._private.ray_constants import env_integer +from ray.data._expression_evaluator import eval_expr from ray.data._internal.compute import get_compute from ray.data._internal.execution.interfaces import PhysicalOperator from ray.data._internal.execution.interfaces.task_context import TaskContext @@ -50,7 +51,6 @@ ) from ray.data.context import DataContext from ray.data.exceptions import UserCodeException -from ray.data.expressions import eval_expr from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled logger = logging.getLogger(__name__) diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 8974b35e4fa0..8bc001df3203 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -4,7 +4,6 @@ from enum import Enum from typing import Any -from ray.data._expression_evaluator import eval_expr from ray.util.annotations import DeveloperAPI @@ -301,5 +300,4 @@ def lit(value: Any) -> LiteralExpr: "BinaryExpr", "col", "lit", - "eval_expr", ] From 6d443f86ced7e7ef0854a7450d3f280c4d1d6ce0 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 10:47:27 -0700 Subject: [PATCH 15/29] Address comments Signed-off-by: Goutam V --- python/ray/data/_internal/block_builder.py | 34 ++++++++++++++++++- .../data/_internal/planner/plan_udf_map_op.py | 17 ++-------- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/python/ray/data/_internal/block_builder.py b/python/ray/data/_internal/block_builder.py index d3c232200a9d..b79fbcee6693 100644 --- a/python/ray/data/_internal/block_builder.py +++ b/python/ray/data/_internal/block_builder.py @@ -1,6 +1,7 @@ -from typing import Generic +from typing import Dict, Generic from ray.data.block import Block, BlockAccessor, BlockType, T +from ray.data.expressions import Expr class BlockBuilder(Generic[T]): @@ -18,6 +19,37 @@ def add_block(self, block: Block) -> None: """Append an entire block to the block being built.""" raise NotImplementedError + def append_columns(self, block: Block, exprs: Dict[str, Expr]) -> Block: + """Add columns from evaluated expressions to a new builder. + + Args: + block: The source block to copy existing columns from + exprs: A dictionary mapping new column names to expressions that + define the column values. + + Returns: + A new block with existing columns from block and new columns from expressions. + """ + from ray.data._expression_evaluator import eval_expr + + # Extract existing columns directly from the block + block_accessor = BlockAccessor.for_block(block) + new_columns = {} + for col_name in block_accessor.column_names(): + # For Arrow blocks, block[col_name] gives us a ChunkedArray + # For Pandas blocks, block[col_name] gives us a Series + new_columns[col_name] = block[col_name] + + # Add/update with expression results + for name, expr in exprs.items(): + result = eval_expr(expr, block) + new_columns[name] = result + + # Create a new block from the combined columns and add it + new_block = BlockAccessor.batch_to_block(new_columns) + + return new_block + def will_build_yield_copy(self) -> bool: """Whether building this block will yield a new block copy.""" raise NotImplementedError diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 597ffe5145c5..810a3ff1daf6 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -14,7 +14,6 @@ import ray from ray._common.utils import get_or_create_event_loop from ray._private.ray_constants import env_integer -from ray.data._expression_evaluator import eval_expr from ray.data._internal.compute import get_compute from ray.data._internal.execution.interfaces import PhysicalOperator from ray.data._internal.execution.interfaces.task_context import TaskContext @@ -117,20 +116,8 @@ def fn(block: Block) -> Block: # 1. evaluate / add expressions if exprs: - # Extract existing columns directly from the block - new_columns = {} - for col_name in block_accessor.column_names(): - # For Arrow blocks, block[col_name] gives us a ChunkedArray - # For Pandas blocks, block[col_name] gives us a Series - new_columns[col_name] = block[col_name] - - # Add/update with expression results - for name, ex in exprs.items(): - result = eval_expr(ex, block) - new_columns[name] = result - - # Create new block from updated columns - block = BlockAccessor.batch_to_block(new_columns) + builder = block_accessor.builder() + block = builder.append_columns(block, exprs) block_accessor = BlockAccessor.for_block(block) # 2. (optional) column projection From fc034ec0cb5f2f71aaf0b8c6fdd0a47a2c89b237 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 11:20:08 -0700 Subject: [PATCH 16/29] remove change in block builder Signed-off-by: Goutam V --- python/ray/data/_internal/block_builder.py | 31 ------------------- .../data/_internal/planner/plan_udf_map_op.py | 16 ++++++++-- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/python/ray/data/_internal/block_builder.py b/python/ray/data/_internal/block_builder.py index b79fbcee6693..98bb8e40cc94 100644 --- a/python/ray/data/_internal/block_builder.py +++ b/python/ray/data/_internal/block_builder.py @@ -19,37 +19,6 @@ def add_block(self, block: Block) -> None: """Append an entire block to the block being built.""" raise NotImplementedError - def append_columns(self, block: Block, exprs: Dict[str, Expr]) -> Block: - """Add columns from evaluated expressions to a new builder. - - Args: - block: The source block to copy existing columns from - exprs: A dictionary mapping new column names to expressions that - define the column values. - - Returns: - A new block with existing columns from block and new columns from expressions. - """ - from ray.data._expression_evaluator import eval_expr - - # Extract existing columns directly from the block - block_accessor = BlockAccessor.for_block(block) - new_columns = {} - for col_name in block_accessor.column_names(): - # For Arrow blocks, block[col_name] gives us a ChunkedArray - # For Pandas blocks, block[col_name] gives us a Series - new_columns[col_name] = block[col_name] - - # Add/update with expression results - for name, expr in exprs.items(): - result = eval_expr(expr, block) - new_columns[name] = result - - # Create a new block from the combined columns and add it - new_block = BlockAccessor.batch_to_block(new_columns) - - return new_block - def will_build_yield_copy(self) -> bool: """Whether building this block will yield a new block copy.""" raise NotImplementedError diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 810a3ff1daf6..77895f61359a 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -51,6 +51,7 @@ from ray.data.context import DataContext from ray.data.exceptions import UserCodeException from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled +from ray.data._expression_evaluator import eval_expr logger = logging.getLogger(__name__) @@ -116,9 +117,20 @@ def fn(block: Block) -> Block: # 1. evaluate / add expressions if exprs: - builder = block_accessor.builder() - block = builder.append_columns(block, exprs) block_accessor = BlockAccessor.for_block(block) + new_columns = {} + for col_name in block_accessor.column_names(): + # For Arrow blocks, block[col_name] gives us a ChunkedArray + # For Pandas blocks, block[col_name] gives us a Series + new_columns[col_name] = block[col_name] + + # Add/update with expression results + for name, expr in exprs.items(): + result = eval_expr(expr, block) + new_columns[name] = result + + # Create a new block from the combined columns and add it + block = BlockAccessor.batch_to_block(new_columns) # 2. (optional) column projection if columns: From 1a3941fc5270530719dd72c2eb25a6dabfcb159c Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 11:21:44 -0700 Subject: [PATCH 17/29] Remove block builder change Signed-off-by: Goutam V --- python/ray/data/_internal/block_builder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/data/_internal/block_builder.py b/python/ray/data/_internal/block_builder.py index 98bb8e40cc94..d3c232200a9d 100644 --- a/python/ray/data/_internal/block_builder.py +++ b/python/ray/data/_internal/block_builder.py @@ -1,7 +1,6 @@ -from typing import Dict, Generic +from typing import Generic from ray.data.block import Block, BlockAccessor, BlockType, T -from ray.data.expressions import Expr class BlockBuilder(Generic[T]): From 3f30cbbaf3c0c1a1ea25822fee47a14cdff7c48a Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 11:40:31 -0700 Subject: [PATCH 18/29] Make pre-commit happy Signed-off-by: Goutam V --- python/ray/data/_internal/planner/plan_udf_map_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 77895f61359a..5f11e0a0bf01 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -14,6 +14,7 @@ import ray from ray._common.utils import get_or_create_event_loop from ray._private.ray_constants import env_integer +from ray.data._expression_evaluator import eval_expr from ray.data._internal.compute import get_compute from ray.data._internal.execution.interfaces import PhysicalOperator from ray.data._internal.execution.interfaces.task_context import TaskContext @@ -51,7 +52,6 @@ from ray.data.context import DataContext from ray.data.exceptions import UserCodeException from ray.util.rpdb import _is_ray_debugger_post_mortem_enabled -from ray.data._expression_evaluator import eval_expr logger = logging.getLogger(__name__) From 9b8de872ba8a369edfbb2feb88277d5f4791a943 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 13:26:03 -0700 Subject: [PATCH 19/29] Address comment on Expr AST comparison Signed-off-by: Goutam V --- python/ray/data/expressions.py | 39 ++++++++++---- python/ray/data/tests/test_expressions.py | 64 +++++++++++++++++++++++ 2 files changed, 93 insertions(+), 10 deletions(-) create mode 100644 python/ray/data/tests/test_expressions.py diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 8bc001df3203..1436927f1f8a 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -66,14 +66,33 @@ class Expr: subclasses like ColumnExpr, LiteralExpr, etc. """ - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - # Override dataclass __eq__ with expression __eq__ - cls.__eq__ = cls._expr_eq + def structurally_equals(self, other: Any) -> bool: + """Compare two expression ASTs for structural equality. - def _expr_eq(self, other: Any) -> "Expr": - """Expression equality operator.""" - return self._bin(other, Operation.EQ) + Note: + This is different from the `==` operator, which is + used for building expression trees (e.g., `col("a") == 5`). + """ + if type(self) is not type(other): + return False + + match (self, other): + case (ColumnExpr(name=n1), ColumnExpr(name=n2)): + return n1 == n2 + case (LiteralExpr(value=v1), LiteralExpr(value=v2)): + return v1 == v2 and type(v1) is type(v2) + case ( + BinaryExpr(op=o1, left=l1, right=r1), + BinaryExpr(op=o2, left=l2, right=r2), + ): + return ( + o1 is o2 + and l1.structurally_equals(l2) + and r1.structurally_equals(r2) + ) + case _: + # This case should not be reachable for known Expr types. + return False def _bin(self, other: Any, op: Operation) -> "Expr": """Create a binary expression with the given operation. @@ -157,7 +176,7 @@ def __or__(self, other: Any) -> "Expr": @DeveloperAPI -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class ColumnExpr(Expr): """Expression that references a column by name. @@ -178,7 +197,7 @@ class ColumnExpr(Expr): @DeveloperAPI -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class LiteralExpr(Expr): """Expression that represents a constant scalar value. @@ -199,7 +218,7 @@ class LiteralExpr(Expr): @DeveloperAPI -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class BinaryExpr(Expr): """Expression that represents a binary operation between two expressions. diff --git a/python/ray/data/tests/test_expressions.py b/python/ray/data/tests/test_expressions.py new file mode 100644 index 000000000000..ac4783bbb8a5 --- /dev/null +++ b/python/ray/data/tests/test_expressions.py @@ -0,0 +1,64 @@ +import pytest + +from ray.data.expressions import Expr, col, lit + +# Tuples of (expr1, expr2, expected_result) +STRUCTURAL_EQUALITY_TEST_CASES = [ + # Base cases: ColumnExpr + (col("a"), col("a"), True), + (col("a"), col("b"), False), + # Base cases: LiteralExpr + (lit(1), lit(1), True), + (lit(1), lit(2), False), + (lit("x"), lit("y"), False), + # Different expression types + (col("a"), lit("a"), False), + (lit(1), lit(1.0), False), + # Simple binary expressions + (col("a") + 1, col("a") + 1, True), + (col("a") + 1, col("a") + 2, False), # Different literal + (col("a") + 1, col("b") + 1, False), # Different column + (col("a") + 1, col("a") - 1, False), # Different operator + # Complex, nested binary expressions + ((col("a") * 2) + (col("b") / 3), (col("a") * 2) + (col("b") / 3), True), + ((col("a") * 2) + (col("b") / 3), (col("a") * 2) - (col("b") / 3), False), + ((col("a") * 2) + (col("b") / 3), (col("c") * 2) + (col("b") / 3), False), + ((col("a") * 2) + (col("b") / 3), (col("a") * 2) + (col("b") / 4), False), + # Commutative operations are not structurally equal + (col("a") + col("b"), col("b") + col("a"), False), + (lit(1) * col("c"), col("c") * lit(1), False), +] + + +@pytest.mark.parametrize( + "expr1, expr2, expected", + STRUCTURAL_EQUALITY_TEST_CASES, + ids=[f"{i}" for i in range(len(STRUCTURAL_EQUALITY_TEST_CASES))], +) +def test_structural_equality(expr1, expr2, expected): + """Tests `structurally_equals` for various expression trees.""" + assert expr1.structurally_equals(expr2) is expected + # Test for symmetry + assert expr2.structurally_equals(expr1) is expected + + +def test_operator_eq_is_not_structural_eq(): + """ + Confirms that `__eq__` (==) builds an expression, while + `structurally_equals` compares two existing expressions. + """ + # `==` returns a BinaryExpr, not a boolean + op_eq_expr = col("a") == col("a") + assert isinstance(op_eq_expr, Expr) + assert not isinstance(op_eq_expr, bool) + + # `structurally_equals` returns a boolean + struct_eq_result = col("a").structurally_equals(col("a")) + assert isinstance(struct_eq_result, bool) + assert struct_eq_result is True + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) From c13f679b94dcaa795623ea7eaaab3b898f915da8 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 13:42:12 -0700 Subject: [PATCH 20/29] Add expressions test to bazel build Signed-off-by: Goutam V --- python/ray/data/BUILD | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/ray/data/BUILD b/python/ray/data/BUILD index 4565846055df..adcc0eaae038 100644 --- a/python/ray/data/BUILD +++ b/python/ray/data/BUILD @@ -297,6 +297,20 @@ py_test( ], ) +py_test( + name = "test_expressions", + size = "small", + srcs = ["tests/test_expressions.py"], + tags = [ + "exclusive", + "team:data", + ], + deps = [ + ":conftest", + "//:ray_lib", + ], +) + py_test( name = "test_avro", size = "small", From 8d61562208e00bd7a91310d16795f811ed505595 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 15:15:47 -0700 Subject: [PATCH 21/29] Remove match expression Signed-off-by: Goutam V --- python/ray/data/expressions.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 1436927f1f8a..ef2ce5fae15f 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -76,23 +76,22 @@ def structurally_equals(self, other: Any) -> bool: if type(self) is not type(other): return False - match (self, other): - case (ColumnExpr(name=n1), ColumnExpr(name=n2)): - return n1 == n2 - case (LiteralExpr(value=v1), LiteralExpr(value=v2)): - return v1 == v2 and type(v1) is type(v2) - case ( - BinaryExpr(op=o1, left=l1, right=r1), - BinaryExpr(op=o2, left=l2, right=r2), - ): - return ( - o1 is o2 - and l1.structurally_equals(l2) - and r1.structurally_equals(r2) - ) - case _: - # This case should not be reachable for known Expr types. - return False + if isinstance(self, ColumnExpr): + # `other` is also a ColumnExpr due to the type check above. + return self.name == other.name + elif isinstance(self, LiteralExpr): + # `other` is also a LiteralExpr. + return self.value == other.value and type(self.value) is type(other.value) + elif isinstance(self, BinaryExpr): + # `other` is also a BinaryExpr. + return ( + self.op is other.op + and self.left.structurally_equals(other.left) + and self.right.structurally_equals(other.right) + ) + + # This case should not be reachable for known Expr types. + return False def _bin(self, other: Any, op: Operation) -> "Expr": """Create a binary expression with the given operation. From d8890fda6e623476162d3ddebdbfd5c23c00e37e Mon Sep 17 00:00:00 2001 From: Goutam V Date: Thu, 10 Jul 2025 15:20:39 -0700 Subject: [PATCH 22/29] Comments Signed-off-by: Goutam V --- python/ray/data/_expression_evaluator.py | 67 ++++++++++-------------- 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py index 6b4508e84cfc..8326548bfc18 100644 --- a/python/ray/data/_expression_evaluator.py +++ b/python/ray/data/_expression_evaluator.py @@ -8,51 +8,44 @@ import pyarrow.compute as pc from ray.data.expressions import ( + BinaryExpr, + ColumnExpr, Expr, + LiteralExpr, Operation, ) +pandas_ops = { + Operation.ADD: operator.add, + Operation.SUB: operator.sub, + Operation.MUL: operator.mul, + Operation.DIV: operator.truediv, + Operation.GT: operator.gt, + Operation.LT: operator.lt, + Operation.GE: operator.ge, + Operation.LE: operator.le, + Operation.EQ: operator.eq, + Operation.AND: operator.and_, + Operation.OR: operator.or_, +} -def _get_operation_maps(): - """Get operation maps, importing Operation enum at runtime to avoid circular imports.""" - from ray.data.expressions import Operation - - pandas_ops = { - Operation.ADD: operator.add, - Operation.SUB: operator.sub, - Operation.MUL: operator.mul, - Operation.DIV: operator.truediv, - Operation.GT: operator.gt, - Operation.LT: operator.lt, - Operation.GE: operator.ge, - Operation.LE: operator.le, - Operation.EQ: operator.eq, - Operation.AND: operator.and_, - Operation.OR: operator.or_, - } - - arrow_ops = { - Operation.ADD: pc.add, - Operation.SUB: pc.subtract, - Operation.MUL: pc.multiply, - Operation.DIV: pc.divide, - Operation.GT: pc.greater, - Operation.LT: pc.less, - Operation.GE: pc.greater_equal, - Operation.LE: pc.less_equal, - Operation.EQ: pc.equal, - Operation.AND: pc.and_, - Operation.OR: pc.or_, - } - - return pandas_ops, arrow_ops +arrow_ops = { + Operation.ADD: pc.add, + Operation.SUB: pc.subtract, + Operation.MUL: pc.multiply, + Operation.DIV: pc.divide, + Operation.GT: pc.greater, + Operation.LT: pc.less, + Operation.GE: pc.greater_equal, + Operation.LE: pc.less_equal, + Operation.EQ: pc.equal, + Operation.AND: pc.and_, + Operation.OR: pc.or_, +} def _eval_expr_recursive(expr: "Expr", batch, ops: Dict["Operation", Callable]) -> Any: """Generic recursive expression evaluator.""" - # Import classes at runtime to avoid circular imports - from ray.data.expressions import BinaryExpr, ColumnExpr, LiteralExpr - # TODO: Separate unresolved expressions (arbitrary AST with unresolved refs) # and resolved expressions (bound to a schema) for better error handling @@ -70,8 +63,6 @@ def _eval_expr_recursive(expr: "Expr", batch, ops: Dict["Operation", Callable]) def eval_expr(expr: "Expr", batch) -> Any: """Recursively evaluate *expr* against a batch of the appropriate type.""" - pandas_ops, arrow_ops = _get_operation_maps() - if isinstance(batch, pd.DataFrame): return _eval_expr_recursive(expr, batch, pandas_ops) elif isinstance(batch, pa.Table): From 164cbd3930f126ffac86727928b5af8a762d4d79 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Fri, 11 Jul 2025 10:01:53 -0700 Subject: [PATCH 23/29] Address comments Signed-off-by: Goutam V --- python/ray/data/_expression_evaluator.py | 13 +- python/ray/data/expressions.py | 248 ++++------------------- 2 files changed, 49 insertions(+), 212 deletions(-) diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py index 8326548bfc18..370ef90b83bc 100644 --- a/python/ray/data/_expression_evaluator.py +++ b/python/ray/data/_expression_evaluator.py @@ -15,7 +15,7 @@ Operation, ) -pandas_ops = { +_PANDAS_EXPR_OPS_MAP = { Operation.ADD: operator.add, Operation.SUB: operator.sub, Operation.MUL: operator.mul, @@ -29,7 +29,7 @@ Operation.OR: operator.or_, } -arrow_ops = { +_ARROW_EXPR_OPS_MAP = { Operation.ADD: pc.add, Operation.SUB: pc.subtract, Operation.MUL: pc.multiply, @@ -58,13 +58,14 @@ def _eval_expr_recursive(expr: "Expr", batch, ops: Dict["Operation", Callable]) _eval_expr_recursive(expr.left, batch, ops), _eval_expr_recursive(expr.right, batch, ops), ) - raise TypeError(f"Unsupported expression node: {type(expr).__name__}") + raise TypeError(f"Unsupported expression node: {type(expr).__name__}") def eval_expr(expr: "Expr", batch) -> Any: """Recursively evaluate *expr* against a batch of the appropriate type.""" if isinstance(batch, pd.DataFrame): - return _eval_expr_recursive(expr, batch, pandas_ops) + return _eval_expr_recursive(expr, batch, _PANDAS_EXPR_OPS_MAP) elif isinstance(batch, pa.Table): - return _eval_expr_recursive(expr, batch, arrow_ops) - raise TypeError(f"Unsupported batch type: {type(batch).__name__}") + return _eval_expr_recursive(expr, batch, _ARROW_EXPR_OPS_MAP) + else: + raise TypeError(f"Unsupported batch type: {type(batch).__name__}") diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index ef2ce5fae15f..5b5c3d0e8935 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -1,32 +1,16 @@ from __future__ import annotations +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from typing import Any -from ray.util.annotations import DeveloperAPI +from ray.util.annotations import PublicAPI -@DeveloperAPI +@PublicAPI(stability="alpha") class Operation(Enum): - """Enumeration of supported operations in expressions. - - This enum defines all the binary operations that can be performed - between expressions, including arithmetic, comparison, and boolean operations. - - Attributes: - ADD: Addition operation (+) - SUB: Subtraction operation (-) - MUL: Multiplication operation (*) - DIV: Division operation (/) - GT: Greater than comparison (>) - LT: Less than comparison (<) - GE: Greater than or equal comparison (>=) - LE: Less than or equal comparison (<=) - EQ: Equality comparison (==) - AND: Logical AND operation (&) - OR: Logical OR operation (|) - """ + """Enumeration of supported operations in expressions.""" ADD = "add" SUB = "sub" @@ -41,275 +25,127 @@ class Operation(Enum): OR = "or" -@DeveloperAPI +@PublicAPI(stability="alpha") @dataclass(frozen=True) -class Expr: - """Base class for all expression nodes. - - This is the abstract base class that all expression types inherit from. - It provides operator overloads for building complex expressions using - standard Python operators. - - Expressions form a tree structure where each node represents an operation - or value. The tree can be evaluated against data batches to compute results. - - Example: - >>> from ray.data.expressions import col, lit - >>> # Create an expression tree: (col("x") + 5) * col("y") - >>> expr = (col("x") + lit(5)) * col("y") - >>> # This creates a BinaryExpr with operation=MUL - >>> # left=BinaryExpr(op=ADD, left=ColumnExpr("x"), right=LiteralExpr(5)) - >>> # right=ColumnExpr("y") - - Note: - This class should not be instantiated directly. Use the concrete - subclasses like ColumnExpr, LiteralExpr, etc. - """ +class Expr(ABC): + """Base class for all expression nodes.""" + @abstractmethod def structurally_equals(self, other: Any) -> bool: - """Compare two expression ASTs for structural equality. - - Note: - This is different from the `==` operator, which is - used for building expression trees (e.g., `col("a") == 5`). - """ - if type(self) is not type(other): - return False - - if isinstance(self, ColumnExpr): - # `other` is also a ColumnExpr due to the type check above. - return self.name == other.name - elif isinstance(self, LiteralExpr): - # `other` is also a LiteralExpr. - return self.value == other.value and type(self.value) is type(other.value) - elif isinstance(self, BinaryExpr): - # `other` is also a BinaryExpr. - return ( - self.op is other.op - and self.left.structurally_equals(other.left) - and self.right.structurally_equals(other.right) - ) - - # This case should not be reachable for known Expr types. - return False + """Compare two expression ASTs for structural equality.""" + raise NotImplementedError def _bin(self, other: Any, op: Operation) -> "Expr": - """Create a binary expression with the given operation. - - Args: - other: The right operand expression or literal value - op: The operation to perform - - Returns: - A new BinaryExpr representing the operation - - Note: - If other is not an Expr, it will be automatically converted to a LiteralExpr. - """ + """Create a binary expression with the given operation.""" if not isinstance(other, Expr): other = LiteralExpr(other) return BinaryExpr(op, self, other) - # arithmetic + # Arithmetic operators def __add__(self, other: Any) -> "Expr": - """Addition operator (+).""" return self._bin(other, Operation.ADD) def __radd__(self, other: Any) -> "Expr": - """Reverse addition operator (for literal + expr).""" return LiteralExpr(other)._bin(self, Operation.ADD) def __sub__(self, other: Any) -> "Expr": - """Subtraction operator (-).""" return self._bin(other, Operation.SUB) def __rsub__(self, other: Any) -> "Expr": - """Reverse subtraction operator (for literal - expr).""" return LiteralExpr(other)._bin(self, Operation.SUB) def __mul__(self, other: Any) -> "Expr": - """Multiplication operator (*).""" return self._bin(other, Operation.MUL) def __rmul__(self, other: Any) -> "Expr": - """Reverse multiplication operator (for literal * expr).""" return LiteralExpr(other)._bin(self, Operation.MUL) def __truediv__(self, other: Any) -> "Expr": - """Division operator (/).""" return self._bin(other, Operation.DIV) def __rtruediv__(self, other: Any) -> "Expr": - """Reverse division operator (for literal / expr).""" return LiteralExpr(other)._bin(self, Operation.DIV) - # comparison + # Comparison operators def __gt__(self, other: Any) -> "Expr": - """Greater than operator (>).""" return self._bin(other, Operation.GT) def __lt__(self, other: Any) -> "Expr": - """Less than operator (<).""" return self._bin(other, Operation.LT) def __ge__(self, other: Any) -> "Expr": - """Greater than or equal operator (>=).""" return self._bin(other, Operation.GE) def __le__(self, other: Any) -> "Expr": - """Less than or equal operator (<=).""" return self._bin(other, Operation.LE) def __eq__(self, other: Any) -> "Expr": - """Equality operator (==).""" return self._bin(other, Operation.EQ) - # boolean + # Boolean operators def __and__(self, other: Any) -> "Expr": - """Logical AND operator (&).""" return self._bin(other, Operation.AND) def __or__(self, other: Any) -> "Expr": - """Logical OR operator (|).""" return self._bin(other, Operation.OR) -@DeveloperAPI +@PublicAPI(stability="alpha") @dataclass(frozen=True, eq=False) class ColumnExpr(Expr): - """Expression that references a column by name. - - This expression type represents a reference to an existing column - in the dataset. When evaluated, it returns the values from the - specified column. - - Args: - name: The name of the column to reference - - Example: - >>> from ray.data.expressions import col - >>> # Reference the "age" column - >>> age_expr = col("age") # Creates ColumnExpr(name="age") - """ + """Expression that references a column by name.""" name: str + def structurally_equals(self, other: Any) -> bool: + return isinstance(other, ColumnExpr) and self.name == other.name + -@DeveloperAPI +@PublicAPI(stability="alpha") @dataclass(frozen=True, eq=False) class LiteralExpr(Expr): - """Expression that represents a constant scalar value. - - This expression type represents a literal value that will be broadcast - to all rows when evaluated. The value can be any Python object. - - Args: - value: The constant value to represent - - Example: - >>> from ray.data.expressions import lit - >>> # Create a literal value - >>> five = lit(5) # Creates LiteralExpr(value=5) - >>> name = lit("John") # Creates LiteralExpr(value="John") - """ + """Expression that represents a constant scalar value.""" value: Any + def structurally_equals(self, other: Any) -> bool: + return ( + isinstance(other, LiteralExpr) + and self.value == other.value + and type(self.value) is type(other.value) + ) + -@DeveloperAPI +@PublicAPI(stability="alpha") @dataclass(frozen=True, eq=False) class BinaryExpr(Expr): - """Expression that represents a binary operation between two expressions. - - This expression type represents an operation with two operands (left and right). - The operation is specified by the `op` field, which must be one of the - supported operations from the Operation enum. - - Args: - op: The operation to perform (from Operation enum) - left: The left operand expression - right: The right operand expression - - Example: - >>> from ray.data.expressions import col, lit, Operation - >>> # Manually create a binary expression (usually done via operators) - >>> expr = BinaryExpr(Operation.ADD, col("x"), lit(5)) - >>> # This is equivalent to: col("x") + lit(5) - """ + """Expression that represents a binary operation between two expressions.""" op: Operation left: Expr right: Expr + def structurally_equals(self, other: Any) -> bool: + return ( + isinstance(other, BinaryExpr) + and self.op is other.op + and self.left.structurally_equals(other.left) + and self.right.structurally_equals(other.right) + ) + -@DeveloperAPI +@PublicAPI(stability="beta") def col(name: str) -> ColumnExpr: - """Reference an existing column by name. - - This is the primary way to reference columns in expressions. - The returned expression will extract values from the specified - column when evaluated. - - Args: - name: The name of the column to reference - - Returns: - A ColumnExpr that references the specified column - - Example: - >>> from ray.data.expressions import col - >>> # Reference columns in an expression - >>> expr = col("price") * col("quantity") - >>> - >>> # Use with Dataset.with_columns() - >>> import ray - >>> ds = ray.data.from_items([{"price": 10, "quantity": 2}]) - >>> ds = ds.with_columns({"total": col("price") * col("quantity")}) - """ + """Reference an existing column by name.""" return ColumnExpr(name) -@DeveloperAPI +@PublicAPI(stability="beta") def lit(value: Any) -> LiteralExpr: - """Create a literal expression from a constant value. - - This creates an expression that represents a constant scalar value. - The value will be broadcast to all rows when the expression is evaluated. - - Args: - value: The constant value to represent. Can be any Python object - (int, float, str, bool, etc.) - - Returns: - A LiteralExpr containing the specified value - - Example: - >>> from ray.data.expressions import col, lit - >>> # Create literals of different types - >>> five = lit(5) - >>> pi = lit(3.14159) - >>> name = lit("Alice") - >>> flag = lit(True) - >>> - >>> # Use in expressions - >>> expr = col("age") + lit(1) # Add 1 to age column - >>> - >>> # Use with Dataset.with_columns() - >>> import ray - >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) - >>> ds = ds.with_columns({"age_plus_one": col("age") + lit(1)}) - """ + """Create a literal expression from a constant value.""" return LiteralExpr(value) -# ────────────────────────────────────── -# Public API for evaluation -# ────────────────────────────────────── -# Note: Implementation details are in _expression_evaluator.py - -# Re-export eval_expr for public use - - __all__ = [ "Operation", "Expr", From b64beefe199aacb9a66722b5cba263f26d5e1331 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Fri, 11 Jul 2025 10:05:25 -0700 Subject: [PATCH 24/29] Add comments back Signed-off-by: Goutam V --- python/ray/data/expressions.py | 182 +++++++++++++++++++++++++++++++-- 1 file changed, 171 insertions(+), 11 deletions(-) diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 5b5c3d0e8935..f85dd8c03a80 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -10,7 +10,24 @@ @PublicAPI(stability="alpha") class Operation(Enum): - """Enumeration of supported operations in expressions.""" + """Enumeration of supported operations in expressions. + + This enum defines all the binary operations that can be performed + between expressions, including arithmetic, comparison, and boolean operations. + + Attributes: + ADD: Addition operation (+) + SUB: Subtraction operation (-) + MUL: Multiplication operation (*) + DIV: Division operation (/) + GT: Greater than comparison (>) + LT: Less than comparison (<) + GE: Greater than or equal comparison (>=) + LE: Less than or equal comparison (<=) + EQ: Equality comparison (==) + AND: Logical AND operation (&) + OR: Logical OR operation (|) + """ ADD = "add" SUB = "sub" @@ -28,7 +45,27 @@ class Operation(Enum): @PublicAPI(stability="alpha") @dataclass(frozen=True) class Expr(ABC): - """Base class for all expression nodes.""" + """Base class for all expression nodes. + + This is the abstract base class that all expression types inherit from. + It provides operator overloads for building complex expressions using + standard Python operators. + + Expressions form a tree structure where each node represents an operation + or value. The tree can be evaluated against data batches to compute results. + + Example: + >>> from ray.data.expressions import col, lit + >>> # Create an expression tree: (col("x") + 5) * col("y") + >>> expr = (col("x") + lit(5)) * col("y") + >>> # This creates a BinaryExpr with operation=MUL + >>> # left=BinaryExpr(op=ADD, left=ColumnExpr("x"), right=LiteralExpr(5)) + >>> # right=ColumnExpr("y") + + Note: + This class should not be instantiated directly. Use the concrete + subclasses like ColumnExpr, LiteralExpr, etc. + """ @abstractmethod def structurally_equals(self, other: Any) -> bool: @@ -36,64 +73,103 @@ def structurally_equals(self, other: Any) -> bool: raise NotImplementedError def _bin(self, other: Any, op: Operation) -> "Expr": - """Create a binary expression with the given operation.""" + """Create a binary expression with the given operation. + + Args: + other: The right operand expression or literal value + op: The operation to perform + + Returns: + A new BinaryExpr representing the operation + + Note: + If other is not an Expr, it will be automatically converted to a LiteralExpr. + """ if not isinstance(other, Expr): other = LiteralExpr(other) return BinaryExpr(op, self, other) - # Arithmetic operators + # arithmetic def __add__(self, other: Any) -> "Expr": + """Addition operator (+).""" return self._bin(other, Operation.ADD) def __radd__(self, other: Any) -> "Expr": + """Reverse addition operator (for literal + expr).""" return LiteralExpr(other)._bin(self, Operation.ADD) def __sub__(self, other: Any) -> "Expr": + """Subtraction operator (-).""" return self._bin(other, Operation.SUB) def __rsub__(self, other: Any) -> "Expr": + """Reverse subtraction operator (for literal - expr).""" return LiteralExpr(other)._bin(self, Operation.SUB) def __mul__(self, other: Any) -> "Expr": + """Multiplication operator (*).""" return self._bin(other, Operation.MUL) def __rmul__(self, other: Any) -> "Expr": + """Reverse multiplication operator (for literal * expr).""" return LiteralExpr(other)._bin(self, Operation.MUL) def __truediv__(self, other: Any) -> "Expr": + """Division operator (/).""" return self._bin(other, Operation.DIV) def __rtruediv__(self, other: Any) -> "Expr": + """Reverse division operator (for literal / expr).""" return LiteralExpr(other)._bin(self, Operation.DIV) - # Comparison operators + # comparison def __gt__(self, other: Any) -> "Expr": + """Greater than operator (>).""" return self._bin(other, Operation.GT) def __lt__(self, other: Any) -> "Expr": + """Less than operator (<).""" return self._bin(other, Operation.LT) def __ge__(self, other: Any) -> "Expr": + """Greater than or equal operator (>=).""" return self._bin(other, Operation.GE) def __le__(self, other: Any) -> "Expr": + """Less than or equal operator (<=).""" return self._bin(other, Operation.LE) def __eq__(self, other: Any) -> "Expr": + """Equality operator (==).""" return self._bin(other, Operation.EQ) - # Boolean operators + # boolean def __and__(self, other: Any) -> "Expr": + """Logical AND operator (&).""" return self._bin(other, Operation.AND) def __or__(self, other: Any) -> "Expr": + """Logical OR operator (|).""" return self._bin(other, Operation.OR) @PublicAPI(stability="alpha") @dataclass(frozen=True, eq=False) class ColumnExpr(Expr): - """Expression that references a column by name.""" + """Expression that references a column by name. + + This expression type represents a reference to an existing column + in the dataset. When evaluated, it returns the values from the + specified column. + + Args: + name: The name of the column to reference + + Example: + >>> from ray.data.expressions import col + >>> # Reference the "age" column + >>> age_expr = col("age") # Creates ColumnExpr(name="age") + """ name: str @@ -104,7 +180,20 @@ def structurally_equals(self, other: Any) -> bool: @PublicAPI(stability="alpha") @dataclass(frozen=True, eq=False) class LiteralExpr(Expr): - """Expression that represents a constant scalar value.""" + """Expression that represents a constant scalar value. + + This expression type represents a literal value that will be broadcast + to all rows when evaluated. The value can be any Python object. + + Args: + value: The constant value to represent + + Example: + >>> from ray.data.expressions import lit + >>> # Create a literal value + >>> five = lit(5) # Creates LiteralExpr(value=5) + >>> name = lit("John") # Creates LiteralExpr(value="John") + """ value: Any @@ -119,7 +208,23 @@ def structurally_equals(self, other: Any) -> bool: @PublicAPI(stability="alpha") @dataclass(frozen=True, eq=False) class BinaryExpr(Expr): - """Expression that represents a binary operation between two expressions.""" + """Expression that represents a binary operation between two expressions. + + This expression type represents an operation with two operands (left and right). + The operation is specified by the `op` field, which must be one of the + supported operations from the Operation enum. + + Args: + op: The operation to perform (from Operation enum) + left: The left operand expression + right: The right operand expression + + Example: + >>> from ray.data.expressions import col, lit, Operation + >>> # Manually create a binary expression (usually done via operators) + >>> expr = BinaryExpr(Operation.ADD, col("x"), lit(5)) + >>> # This is equivalent to: col("x") + lit(5) + """ op: Operation left: Expr @@ -136,16 +241,71 @@ def structurally_equals(self, other: Any) -> bool: @PublicAPI(stability="beta") def col(name: str) -> ColumnExpr: - """Reference an existing column by name.""" + """Reference an existing column by name. + + This is the primary way to reference columns in expressions. + The returned expression will extract values from the specified + column when evaluated. + + Args: + name: The name of the column to reference + + Returns: + A ColumnExpr that references the specified column + + Example: + >>> from ray.data.expressions import col + >>> # Reference columns in an expression + >>> expr = col("price") * col("quantity") + >>> + >>> # Use with Dataset.with_columns() + >>> import ray + >>> ds = ray.data.from_items([{"price": 10, "quantity": 2}]) + >>> ds = ds.with_columns({"total": col("price") * col("quantity")}) + """ return ColumnExpr(name) @PublicAPI(stability="beta") def lit(value: Any) -> LiteralExpr: - """Create a literal expression from a constant value.""" + """Create a literal expression from a constant value. + + This creates an expression that represents a constant scalar value. + The value will be broadcast to all rows when the expression is evaluated. + + Args: + value: The constant value to represent. Can be any Python object + (int, float, str, bool, etc.) + + Returns: + A LiteralExpr containing the specified value + + Example: + >>> from ray.data.expressions import col, lit + >>> # Create literals of different types + >>> five = lit(5) + >>> pi = lit(3.14159) + >>> name = lit("Alice") + >>> flag = lit(True) + >>> + >>> # Use in expressions + >>> expr = col("age") + lit(1) # Add 1 to age column + >>> + >>> # Use with Dataset.with_columns() + >>> import ray + >>> ds = ray.data.from_items([{"age": 25}, {"age": 30}]) + >>> ds = ds.with_columns({"age_plus_one": col("age") + lit(1)}) + """ return LiteralExpr(value) +# ────────────────────────────────────── +# Public API for evaluation +# ────────────────────────────────────── +# Note: Implementation details are in _expression_evaluator.py + +# Re-export eval_expr for public use + __all__ = [ "Operation", "Expr", From a3f30500402dfe23274494ac946f6baa56e7404d Mon Sep 17 00:00:00 2001 From: Goutam V Date: Fri, 11 Jul 2025 10:30:14 -0700 Subject: [PATCH 25/29] Make expression classes dev api Signed-off-by: Goutam V --- python/ray/data/expressions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index f85dd8c03a80..534ea766ee3e 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -5,10 +5,10 @@ from enum import Enum from typing import Any -from ray.util.annotations import PublicAPI +from ray.util.annotations import DeveloperAPI, PublicAPI -@PublicAPI(stability="alpha") +@DeveloperAPI class Operation(Enum): """Enumeration of supported operations in expressions. @@ -42,7 +42,7 @@ class Operation(Enum): OR = "or" -@PublicAPI(stability="alpha") +@DeveloperAPI @dataclass(frozen=True) class Expr(ABC): """Base class for all expression nodes. @@ -153,7 +153,7 @@ def __or__(self, other: Any) -> "Expr": return self._bin(other, Operation.OR) -@PublicAPI(stability="alpha") +@DeveloperAPI @dataclass(frozen=True, eq=False) class ColumnExpr(Expr): """Expression that references a column by name. @@ -177,7 +177,7 @@ def structurally_equals(self, other: Any) -> bool: return isinstance(other, ColumnExpr) and self.name == other.name -@PublicAPI(stability="alpha") +@DeveloperAPI @dataclass(frozen=True, eq=False) class LiteralExpr(Expr): """Expression that represents a constant scalar value. @@ -205,7 +205,7 @@ def structurally_equals(self, other: Any) -> bool: ) -@PublicAPI(stability="alpha") +@DeveloperAPI @dataclass(frozen=True, eq=False) class BinaryExpr(Expr): """Expression that represents a binary operation between two expressions. From 821b73e041029911f5c293c772980991877289f4 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Fri, 11 Jul 2025 11:16:33 -0700 Subject: [PATCH 26/29] Add stability to DeveloperAPIs Signed-off-by: Goutam V --- python/ray/data/expressions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 534ea766ee3e..2bcd038b45df 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -8,7 +8,7 @@ from ray.util.annotations import DeveloperAPI, PublicAPI -@DeveloperAPI +@DeveloperAPI(stability="alpha") class Operation(Enum): """Enumeration of supported operations in expressions. @@ -42,7 +42,7 @@ class Operation(Enum): OR = "or" -@DeveloperAPI +@DeveloperAPI(stability="alpha") @dataclass(frozen=True) class Expr(ABC): """Base class for all expression nodes. @@ -153,7 +153,7 @@ def __or__(self, other: Any) -> "Expr": return self._bin(other, Operation.OR) -@DeveloperAPI +@DeveloperAPI(stability="alpha") @dataclass(frozen=True, eq=False) class ColumnExpr(Expr): """Expression that references a column by name. @@ -177,7 +177,7 @@ def structurally_equals(self, other: Any) -> bool: return isinstance(other, ColumnExpr) and self.name == other.name -@DeveloperAPI +@DeveloperAPI(stability="alpha") @dataclass(frozen=True, eq=False) class LiteralExpr(Expr): """Expression that represents a constant scalar value. @@ -205,7 +205,7 @@ def structurally_equals(self, other: Any) -> bool: ) -@DeveloperAPI +@DeveloperAPI(stability="alpha") @dataclass(frozen=True, eq=False) class BinaryExpr(Expr): """Expression that represents a binary operation between two expressions. From f5b08eb32804914f32567051ac89d640eb930934 Mon Sep 17 00:00:00 2001 From: Goutam V Date: Fri, 11 Jul 2025 16:23:03 -0700 Subject: [PATCH 27/29] Add .rst files Signed-off-by: Goutam V --- doc/source/data/api/api.rst | 1 + doc/source/data/api/expressions.rst | 60 +++++++++++++++++++++++++++++ python/ray/data/expressions.py | 6 ++- 3 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 doc/source/data/api/expressions.rst diff --git a/doc/source/data/api/api.rst b/doc/source/data/api/api.rst index d1643c96b08e..e0d0e94d9480 100644 --- a/doc/source/data/api/api.rst +++ b/doc/source/data/api/api.rst @@ -12,6 +12,7 @@ Ray Data API execution_options.rst aggregate.rst grouped_data.rst + expressions.rst data_context.rst preprocessor.rst llm.rst diff --git a/doc/source/data/api/expressions.rst b/doc/source/data/api/expressions.rst new file mode 100644 index 000000000000..f082654a7e11 --- /dev/null +++ b/doc/source/data/api/expressions.rst @@ -0,0 +1,60 @@ +.. _expressions-api: + +Expressions API +=============== + +Expressions provide a way to specify column-based operations on datasets. +Use :func:`col` to reference columns and :func:`lit` to create literal values. +These can be combined with operators to create complex expressions for filtering, +transformations, and computations. + +Examples: + +.. code-block:: python + + import ray + from ray.data.expressions import col, lit + + # Create a dataset + ds = ray.data.from_items([ + {"name": "Alice", "age": 30, "score": 85.5}, + {"name": "Bob", "age": 25, "score": 92.0}, + {"name": "Charlie", "age": 35, "score": 78.5} + ]) + + # Use expressions in transformations + ds = ds.with_columns({ + "age_plus_one": col("age") + lit(1), + "high_score": col("score") > lit(80) + }) + + # Use expressions in filtering + ds = ds.filter(col("age") >= lit(30)) + +.. currentmodule:: ray.data.expressions + +Public API +---------- + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + col + lit + +Expression Classes +------------------ + +These classes represent the structure of expressions. You typically don't need to +instantiate them directly, but you may encounter them when working with expressions. + +.. autosummary:: + :nosignatures: + :toctree: doc/ + + Expr + ColumnExpr + LiteralExpr + BinaryExpr + Operation \ No newline at end of file diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index 2bcd038b45df..cf59aa30b5e0 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -241,7 +241,8 @@ def structurally_equals(self, other: Any) -> bool: @PublicAPI(stability="beta") def col(name: str) -> ColumnExpr: - """Reference an existing column by name. + """ + Reference an existing column by name. This is the primary way to reference columns in expressions. The returned expression will extract values from the specified @@ -268,7 +269,8 @@ def col(name: str) -> ColumnExpr: @PublicAPI(stability="beta") def lit(value: Any) -> LiteralExpr: - """Create a literal expression from a constant value. + """ + Create a literal expression from a constant value. This creates an expression that represents a constant scalar value. The value will be broadcast to all rows when the expression is evaluated. From 2741b1220fb01568c7713918e24332ae64ddb3cc Mon Sep 17 00:00:00 2001 From: Goutam V Date: Fri, 11 Jul 2025 17:06:25 -0700 Subject: [PATCH 28/29] idk rst Signed-off-by: Goutam V --- doc/source/data/api/expressions.rst | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/doc/source/data/api/expressions.rst b/doc/source/data/api/expressions.rst index f082654a7e11..17d14fff1ecc 100644 --- a/doc/source/data/api/expressions.rst +++ b/doc/source/data/api/expressions.rst @@ -3,6 +3,8 @@ Expressions API =============== +.. currentmodule:: ray.data.expressions + Expressions provide a way to specify column-based operations on datasets. Use :func:`col` to reference columns and :func:`lit` to create literal values. These can be combined with operators to create complex expressions for filtering, @@ -31,8 +33,6 @@ Examples: # Use expressions in filtering ds = ds.filter(col("age") >= lit(30)) -.. currentmodule:: ray.data.expressions - Public API ---------- @@ -56,5 +56,4 @@ instantiate them directly, but you may encounter them when working with expressi Expr ColumnExpr LiteralExpr - BinaryExpr - Operation \ No newline at end of file + BinaryExpr \ No newline at end of file From c7f042473806f7cb79fb86e2a429e153e9c2743b Mon Sep 17 00:00:00 2001 From: Goutam V Date: Fri, 11 Jul 2025 17:27:35 -0700 Subject: [PATCH 29/29] Remove code snippet Signed-off-by: Goutam V --- doc/source/data/api/expressions.rst | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/doc/source/data/api/expressions.rst b/doc/source/data/api/expressions.rst index 17d14fff1ecc..69c1a50c93d0 100644 --- a/doc/source/data/api/expressions.rst +++ b/doc/source/data/api/expressions.rst @@ -10,29 +10,6 @@ Use :func:`col` to reference columns and :func:`lit` to create literal values. These can be combined with operators to create complex expressions for filtering, transformations, and computations. -Examples: - -.. code-block:: python - - import ray - from ray.data.expressions import col, lit - - # Create a dataset - ds = ray.data.from_items([ - {"name": "Alice", "age": 30, "score": 85.5}, - {"name": "Bob", "age": 25, "score": 92.0}, - {"name": "Charlie", "age": 35, "score": 78.5} - ]) - - # Use expressions in transformations - ds = ds.with_columns({ - "age_plus_one": col("age") + lit(1), - "high_score": col("score") > lit(80) - }) - - # Use expressions in filtering - ds = ds.filter(col("age") >= lit(30)) - Public API ----------