Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/cudf_polars/cudf_polars/containers/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ def id(self) -> plc.TypeId:
@property
def children(self) -> list[DataType]:
"""The children types of this DataType."""
# these type ignores are needed because the type checker doesn't
# see that these equality checks passing imply a specific type for each child field.
if self.plc_type.id() == plc.TypeId.STRUCT:
return [DataType(field.dtype) for field in self.polars_type.fields]
return [DataType(field.dtype) for field in self.polars_type.fields] # type: ignore[attr-defined]
elif self.plc_type.id() == plc.TypeId.LIST:
return [DataType(self.polars_type.inner)]
return [DataType(self.polars_type.inner)] # type: ignore[attr-defined]
return []

def __eq__(self, other: object) -> bool:
Expand Down
5 changes: 4 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@

def _dtypes_for_json_decode(dtype: DataType) -> JsonDecodeType:
"""Get the dtypes for json decode."""
# the type checker doesn't know that this equality check implies a struct dtype.
if dtype.id() == plc.TypeId.STRUCT:
return [
(field.name, child.plc_type, _dtypes_for_json_decode(child))
for field, child in zip(
dtype.polars_type.fields, dtype.children, strict=True
dtype.polars_type.fields, # type: ignore[attr-defined]
dtype.children,
strict=True, # type: ignore[attr-defined]
)
]
else:
Expand Down
6 changes: 4 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,13 @@ def do_evaluate(
"""Evaluate this expression given a dataframe for context."""
columns = [child.evaluate(df, context=context) for child in self.children]
(column,) = columns
# these type ignores are needed because the type checker doesn't
# know that polars only calls StructFunction with struct types.
if self.name == StructFunction.Name.FieldByName:
field_index = next(
(
i
for i, field in enumerate(self.children[0].dtype.polars_type.fields)
for i, field in enumerate(self.children[0].dtype.polars_type.fields) # type: ignore[attr-defined]
if field.name == self.options[0]
),
None,
Expand All @@ -111,7 +113,7 @@ def do_evaluate(
table,
[
(field.name, [])
for field in self.children[0].dtype.polars_type.fields
for field in self.children[0].dtype.polars_type.fields # type: ignore[attr-defined]
],
)
options = (
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/utils/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ def decompose_single_agg(
tid = agg.dtype.plc_type.id()
if tid in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}:
cast_to = (
DataType(pl.Float64)
DataType(pl.Float64())
if tid == plc.TypeId.FLOAT64
else DataType(pl.Float32)
else DataType(pl.Float32())
)
child = expr.Cast(cast_to, child)
child_dtype = child.dtype.plc_type
Expand Down
5 changes: 4 additions & 1 deletion python/cudf_polars/cudf_polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,14 @@ def assert_gpu_result_equal(
else:
tol_kwargs = {"rel_tol": rtol, "abs_tol": atol}

# the type checker errors with:
# Argument 4 to "assert_frame_equal" has incompatible type "**dict[str, float]"; expected "bool" [arg-type]
# which seems to be a bug in the type checker / type annotations.
assert_frame_equal(
expect,
got,
**assert_kwargs_bool,
**tol_kwargs,
**tol_kwargs, # type: ignore[arg-type]
)


Expand Down
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/testing/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import polars as pl

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal

__all__: list[str] = ["make_partitioned_source"]
Expand Down Expand Up @@ -110,7 +111,7 @@ def make_lazy_frame(
assert path is not None, f"path is required for fmt={fmt}."
row_group_size: int | None = None
if fmt == "parquet":
read = pl.scan_parquet
read: Callable[..., pl.LazyFrame] = pl.scan_parquet
row_group_size = 10
elif fmt == "csv":
read = pl.scan_csv
Expand Down
5 changes: 5 additions & 0 deletions python/cudf_polars/cudf_polars/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"OptimizationArgs",
"PolarsExpr",
"PolarsIR",
"RankMethod",
"Schema",
"Slice",
]
Expand Down Expand Up @@ -217,3 +218,7 @@ class DataFrameHeader(TypedDict):

columns_kwargs: list[ColumnOptions]
frame_count: int


# Not public in polars yet
RankMethod = Literal["ordinal", "dense", "min", "max", "average"]
12 changes: 9 additions & 3 deletions python/cudf_polars/tests/expressions/test_numeric_unaryops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from datetime import timedelta
from typing import TYPE_CHECKING

import pytest

Expand All @@ -14,6 +15,9 @@
)
from cudf_polars.utils.versions import POLARS_VERSION_LT_132

if TYPE_CHECKING:
from cudf_polars.typing import RankMethod


@pytest.fixture(
params=[
Expand Down Expand Up @@ -120,7 +124,9 @@ def test_null_count():

@pytest.mark.parametrize("method", ["ordinal", "dense", "min", "max", "average"])
@pytest.mark.parametrize("descending", [False, True])
def test_rank_supported(request, ldf: pl.LazyFrame, method: str, *, descending: bool):
def test_rank_supported(
request, ldf: pl.LazyFrame, method: RankMethod, *, descending: bool
):
request.applymarker(
pytest.mark.xfail(condition=POLARS_VERSION_LT_132, reason="rank unsupported")
)
Expand All @@ -133,7 +139,7 @@ def test_rank_supported(request, ldf: pl.LazyFrame, method: str, *, descending:
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("test", ["with_nulls", "with_ties"])
def test_rank_methods_with_nulls_or_ties(
request, ldf: pl.LazyFrame, method: str, *, descending: bool, test: str
request, ldf: pl.LazyFrame, method: RankMethod, *, descending: bool, test: str
) -> None:
request.applymarker(
pytest.mark.xfail(condition=POLARS_VERSION_LT_132, reason="rank unsupported")
Expand All @@ -151,7 +157,7 @@ def test_rank_methods_with_nulls_or_ties(

@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("method", ["random"])
def test_rank_unsupported(ldf: pl.LazyFrame, method: str, seed: int) -> None:
def test_rank_unsupported(ldf: pl.LazyFrame, method: RankMethod, seed: int) -> None:
expr = pl.col("a").rank(method=method, seed=seed)
q = ldf.select(expr)
assert_ir_translation_raises(q, NotImplementedError)
13 changes: 9 additions & 4 deletions python/cudf_polars/tests/expressions/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import polars as pl
Expand All @@ -13,6 +15,9 @@
)
from cudf_polars.utils.versions import POLARS_VERSION_LT_130, POLARS_VERSION_LT_132

if TYPE_CHECKING:
from cudf_polars.typing import RankMethod


@pytest.fixture
def df():
Expand Down Expand Up @@ -273,7 +278,7 @@ def test_over_broadcast_input_row_group_indices_aligned():
def test_rank_over(
request,
df: pl.LazyFrame,
method: str,
method: RankMethod,
*,
descending: bool,
order_by: None | list[str | pl.Expr],
Expand All @@ -295,7 +300,7 @@ def test_rank_over(
def test_rank_over_with_ties(
request,
df: pl.LazyFrame,
method: str,
method: RankMethod,
*,
descending: bool,
order_by: None | list[str | pl.Expr],
Expand All @@ -319,7 +324,7 @@ def test_rank_over_with_ties(
def test_rank_over_with_null_values(
request,
df: pl.LazyFrame,
method: str,
method: RankMethod,
*,
descending: bool,
order_by: None | list[str | pl.Expr],
Expand All @@ -343,7 +348,7 @@ def test_rank_over_with_null_values(
def test_rank_over_with_null_group_keys(
request,
df: pl.LazyFrame,
method: str,
method: RankMethod,
*,
descending: bool,
order_by: None | list[str | pl.Expr],
Expand Down
5 changes: 4 additions & 1 deletion python/cudf_polars/tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,10 @@ def get_handler(req: Request) -> Response:


def test_scan_ndjson_remote(
request, tmp_path: Path, df: pl.LazyFrame, httpserver: HTTPServer
request: pytest.FixtureRequest,
tmp_path: Path,
df: pl.DataFrame,
httpserver: HTTPServer,
) -> None:
request.applymarker(
pytest.mark.xfail(
Expand Down