Skip to content

Commit d39cd4a

Browse files
authored
Add test cases, support annotated assignment, fix list assignments (#59)
1 parent c1dd2cb commit d39cd4a

File tree

5 files changed

+101
-30
lines changed

5 files changed

+101
-30
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
[pypi]: https://pypi.org/project/polarify
1616
[pypi-badge]: https://img.shields.io/pypi/v/polarify.svg?style=flat-square&logo=pypi&logoColor=white
1717
[python-version-badge]: https://img.shields.io/pypi/pyversions/polarify?style=flat-square&logoColor=white&logo=python
18-
[codecov-badge]: https://codecov.io/gh/quantco/polarify/branch/main/graph/badge.svg
18+
[codecov-badge]: https://img.shields.io/codecov/c/github/quantco/polarify?style=flat-square&logo=codecov
1919
[codecov]: https://codecov.io/gh/quantco/polarify
2020

2121
Welcome to **polarIFy**, a Python function decorator that simplifies the way you write logical statements for Polars. With polarIFy, you can use Python's language structures like `if / elif / else` statements and transform them into `pl.when(..).then(..).otherwise(..)` statements. This makes your code more readable and less cumbersome to write. 🎉

polarify/main.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def generic_visit(self, node):
8282
@dataclass
8383
class UnresolvedState:
8484
"""
85-
When a execution flow is not finished (i.e., not returned) in a function, we need to keep track
85+
When an execution flow is not finished (i.e., not returned) in a function, we need to keep track
8686
of the assignments.
8787
"""
8888

@@ -101,8 +101,7 @@ def _handle_assign(stmt: ast.Assign, assignments: dict[str, ast.expr]):
101101
)
102102
assert len(t.elts) == len(stmt.value.elts)
103103
for sub_t, sub_v in zip(t.elts, stmt.value.elts):
104-
diff = _handle_assign(ast.Assign(targets=[sub_t], value=sub_v), assignments)
105-
assignments.update(diff)
104+
_handle_assign(ast.Assign(targets=[sub_t], value=sub_v), assignments)
106105
else:
107106
raise ValueError(
108107
f"Unsupported expression type inside assignment target: {type(t)}"
@@ -140,7 +139,10 @@ class State:
140139

141140
node: UnresolvedState | ReturnState | ConditionalState
142141

143-
def handle_assign(self, expr: ast.Assign):
142+
def handle_assign(self, expr: ast.Assign | ast.AnnAssign):
143+
if isinstance(expr, ast.AnnAssign):
144+
expr = ast.Assign(targets=[expr.target], value=expr.value)
145+
144146
if isinstance(self.node, UnresolvedState):
145147
self.node.handle_assign(expr)
146148
elif isinstance(self.node, ConditionalState):
@@ -167,33 +169,13 @@ def handle_return(self, value: ast.expr):
167169
self.node.then.handle_return(value)
168170
self.node.orelse.handle_return(value)
169171

170-
def check_all_branches_return(self):
171-
if isinstance(self.node, UnresolvedState):
172-
return False
173-
elif isinstance(self.node, ReturnState):
174-
return True
175-
else:
176-
return (
177-
self.node.then.check_all_branches_return()
178-
and self.node.orelse.check_all_branches_return()
179-
)
180-
181-
182-
def is_returning_body(stmts: list[ast.stmt]) -> bool:
183-
for s in stmts:
184-
if isinstance(s, ast.Return):
185-
return True
186-
elif isinstance(s, ast.If):
187-
return bool(is_returning_body(s.body) and is_returning_body(s.orelse))
188-
return False
189-
190172

191173
def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | None = None) -> State:
192174
if assignments is None:
193175
assignments = {}
194176
state = State(UnresolvedState(assignments))
195177
for stmt in full_body:
196-
if isinstance(stmt, ast.Assign):
178+
if isinstance(stmt, (ast.Assign, ast.AnnAssign)):
197179
state.handle_assign(stmt)
198180
elif isinstance(stmt, ast.If):
199181
state.handle_if(stmt)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
55
[project]
66
name = "polarify"
77
description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️"
8-
version = "0.1.4"
8+
version = "0.1.5"
99
readme = "README.md"
1010
license = "MIT"
1111
requires-python = ">=3.9"

tests/functions.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,71 @@ def walrus_expr(x):
107107
return s * y
108108

109109

110+
def return_nothing(x):
111+
if x > 0:
112+
return
113+
else:
114+
return 1
115+
116+
117+
def no_return(x):
118+
s = x
119+
120+
121+
def return_end(x):
122+
s = x
123+
return
124+
125+
126+
def annotated_assign(x):
127+
s: int = 15
128+
return s + x
129+
130+
131+
def conditional_assign(x):
132+
s = 1
133+
if x > 0:
134+
s = 2
135+
b = 3
136+
return b
137+
138+
139+
def return_constant(x):
140+
return 1
141+
142+
143+
def return_constant_2(x):
144+
return 1 + 2
145+
146+
147+
def return_unconditional_constant(x):
148+
if x > 0:
149+
s = 1
150+
else:
151+
s = 2
152+
return 1
153+
154+
155+
def return_constant_additional_assignments(x):
156+
s = 2
157+
return 1
158+
159+
160+
def return_conditional_constant(x):
161+
if x > 0:
162+
return 1
163+
return 0
164+
165+
166+
def multiple_if(x):
167+
s = 1
168+
if x > 0:
169+
s = 2
170+
if x > 1:
171+
s = 3
172+
return s
173+
174+
110175
def multiple_if_else(x):
111176
if x > 0:
112177
s = 1
@@ -179,6 +244,16 @@ def multiple_equals(x):
179244
return x + a + b
180245

181246

247+
def tuple_assignments(x):
248+
a, b = 1, x
249+
return x + a + b
250+
251+
252+
def list_assignments(x):
253+
[a, b] = 1, x
254+
return x + a + b
255+
256+
182257
functions = [
183258
signum,
184259
early_return,
@@ -199,14 +274,28 @@ def multiple_equals(x):
199274
signum_no_default,
200275
nested_partial_return_with_assignments,
201276
multiple_equals,
277+
tuple_assignments,
278+
list_assignments,
279+
annotated_assign,
280+
conditional_assign,
281+
multiple_if,
282+
return_unconditional_constant,
283+
return_conditional_constant,
202284
]
203285

204286
xfail_functions = [
205287
walrus_expr,
288+
# our test setup does not work with literal expressions
289+
return_constant,
290+
return_constant_2,
291+
return_constant_additional_assignments,
206292
]
207293

208294
unsupported_functions = [
209295
# function, match string in error message
210296
(chained_compare_expr, "Polars can't handle chained comparisons"),
211297
(bool_op, "ast.BoolOp"), # TODO: make error message more specific
298+
(return_end, "return needs a value"),
299+
(no_return, "Not all branches return"),
300+
(return_nothing, "return needs a value"),
212301
]

tests/test_parse_body.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
params=functions
2222
+ [pytest.param(f, marks=pytest.mark.xfail(reason="not implemented")) for f in xfail_functions],
2323
)
24-
def test_funcs(request):
24+
def funcs(request):
2525
original_func = request.param
2626
transformed_func = polarify(original_func)
2727
original_func_unparsed = inspect.getsource(original_func)
@@ -41,9 +41,9 @@ def test_funcs(request):
4141
chunked=False if pl_version < Version("0.18.1") else None,
4242
)
4343
)
44-
def test_transform_function(df: polars.DataFrame, test_funcs):
44+
def test_transform_function(df: polars.DataFrame, funcs):
4545
x = polars.col("x")
46-
transformed_func, original_func = test_funcs
46+
transformed_func, original_func = funcs
4747

4848
if pl_version < Version("0.19.0"):
4949
df_with_transformed_func = df.select(transformed_func(x).alias("apply"))

0 commit comments

Comments
 (0)