Skip to content

Commit 64caa90

Browse files
committed
Add get_caller() function to simplify test coverage
1 parent 40d0419 commit 64caa90

File tree

3 files changed

+60
-12
lines changed

3 files changed

+60
-12
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88

99
## [Unreleased]
1010

11+
### Added
12+
13+
- `get_caller()` function in coverage.py
14+
- `half_credit()` function in coverage.py
15+
- `full_credit()` function in coverage.py
16+
1117
### Changed
1218

1319
- improve error when username is not found
20+
- `assert_fail()` no longer takes a function
21+
- `assert_pass()` no longer takes a function
22+
- `assert_cover()` no longer takes a function
1423

1524

1625
## [1.6.1] - 2025-10-31

examples/8_test_coverage/test_triangles_cov.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@ def test_pep8():
1717
@weight(2)
1818
def test_fail():
1919
"""All tests should fail when given random return values"""
20-
assert_fail(test_fail, "triangles.py", "test_triangles.py")
20+
assert_fail("triangles.py", "test_triangles.py")
2121

2222

2323
@required()
2424
@weight(2)
2525
def test_pass():
2626
"""All tests should pass when given actual return values"""
27-
assert_pass(test_pass, "triangles.py", "test_triangles.py")
27+
assert_pass("triangles.py", "test_triangles.py")
2828

2929

3030
@weight(5)
3131
def test_cover():
3232
"""Code coverage: all statements should run during tests"""
33-
assert_cover(test_cover, "triangles.py", "test_triangles.py")
33+
assert_cover("triangles.py", "test_triangles.py")

jmu_pytest_utils/coverage.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Run test coverage and analyze the results."""
22

33
import importlib
4+
import inspect
45
import json
56
import os
67
import pytest
@@ -34,6 +35,46 @@ def inject_random(main_filename: str) -> None:
3435
obj.__code__ = _return_random.__code__
3536

3637

38+
def get_caller() -> TestFunction:
39+
"""Get a reference to the function that called the current function.
40+
41+
This inspects the call stack to identify the immediate caller and
42+
retrieves its function object from the caller's global namespace.
43+
44+
Returns:
45+
The function that called the current function.
46+
"""
47+
frame = inspect.currentframe()
48+
assert frame is not None, "Cannot get current frame"
49+
caller_frame = frame.f_back
50+
assert caller_frame is not None, "Cannot get caller's frame"
51+
caller_frame = caller_frame.f_back
52+
assert caller_frame is not None, "Cannot get caller's caller's frame"
53+
caller = caller_frame.f_globals.get(caller_frame.f_code.co_name)
54+
assert caller is not None, "Calling function not found in globals"
55+
return caller
56+
57+
58+
def half_credit() -> None:
59+
"""Give half credit for passing the sample tests.
60+
61+
Sets the `score` and `output` attributes of the test function.
62+
"""
63+
test_func = get_caller()
64+
test_func.score = test_func.weight / 2
65+
test_func.output = "Half credit for passing the sample tests.\n"
66+
67+
68+
def full_credit() -> None:
69+
"""Give full credit for passing all of the tests.
70+
71+
Deletes the `score` and `output` attributes set by half_credit().
72+
"""
73+
test_func = get_caller()
74+
del test_func.score
75+
del test_func.output
76+
77+
3778
def _process_results_json(function: TestFunction, status: str, penalty: float) -> None:
3879
"""Verify correctness in the results.json file.
3980
@@ -71,15 +112,14 @@ def _process_results_json(function: TestFunction, status: str, penalty: float) -
71112
pytest.fail(output)
72113

73114

74-
def assert_fail(function: TestFunction, main_filename: str, test_filename: str,
115+
def assert_fail(main_filename: str, test_filename: str,
75116
penalty: float = 1) -> None:
76117
"""Run pytest and assert that all tests fail.
77118
78119
Note: The --jmu option of the jmu_pytest_utils plugin
79120
patches all functions in main_filename to return random.
80121
81122
Args:
82-
function: Test function for score/weight.
83123
main_filename: Name of the main file to test.
84124
test_filename: Name of the test file to run.
85125
penalty: Points per incorrect test function.
@@ -89,15 +129,14 @@ def assert_fail(function: TestFunction, main_filename: str, test_filename: str,
89129
"--jmu=" + main_filename,
90130
test_filename
91131
])
92-
_process_results_json(function, "fail", penalty)
132+
_process_results_json(get_caller(), "fail", penalty)
93133

94134

95-
def assert_pass(function: TestFunction, main_filename: str, test_filename: str,
135+
def assert_pass(main_filename: str, test_filename: str,
96136
penalty: float = 1) -> None:
97137
"""Run pytest and assert that all tests pass.
98138
99139
Args:
100-
function: Test function for score/weight.
101140
main_filename: Name of the main file to test.
102141
test_filename: Name of the test file to run.
103142
penalty: Points per incorrect test function.
@@ -107,15 +146,14 @@ def assert_pass(function: TestFunction, main_filename: str, test_filename: str,
107146
"--jmu=assert_pass",
108147
test_filename
109148
])
110-
_process_results_json(function, "pass", penalty)
149+
_process_results_json(get_caller(), "pass", penalty)
111150

112151

113-
def assert_cover(function: TestFunction, main_filename: str, test_filename: str,
114-
branches: bool = False, line_penalty: float = 1, branch_penalty: float = 1) -> None:
152+
def assert_cover(main_filename: str, test_filename: str, branches: bool = False,
153+
line_penalty: float = 1, branch_penalty: float = 1) -> None:
115154
"""Run pytest and analyze coverage results.
116155
117156
Args:
118-
function: Test function for score/weight.
119157
main_filename: Name of the main file to test.
120158
test_filename: Name of the test file to run.
121159
branches: Whether to report branch coverage.
@@ -159,6 +197,7 @@ def assert_cover(function: TestFunction, main_filename: str, test_filename: str,
159197

160198
# If the test didn't pass, set the score and show output
161199
if points:
200+
function = get_caller()
162201
weight = getattr(function, "weight", 0)
163202
if weight:
164203
setattr(function, "score", max(weight - points, 0))

0 commit comments

Comments
 (0)