Skip to content

Commit 9856f17

Browse files
committed
Add count_asserts function in audit.py
1 parent d8ad0af commit 9856f17

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- `redirect_stdin` class in common.py (#10)
1414
- `assert_no_functional` function in audit.py
1515
- `assert_not_imported` function in audit.py
16+
- `count_asserts` function in audit.py
1617
- `remove_docstrings` function in audit.py
1718

1819
### Fixed

jmu_pytest_utils/audit.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,39 @@ def assert_not_imported(filename, modules):
161161
pytest.fail(f"Importing from {node.module} is not allowed")
162162

163163

164+
def count_asserts(filename, required=1):
165+
"""Verify that each test function has assert statements.
166+
167+
Args:
168+
filename (str): The source file to parse.
169+
required (int): Minimum number of asserts.
170+
"""
171+
172+
# Parse the module and find all test functions
173+
source = get_source_code(filename)
174+
tree = ast.parse(source, filename)
175+
test_functions = [
176+
node for node in tree.body
177+
if isinstance(node, ast.FunctionDef) and node.name.startswith("test_")
178+
]
179+
180+
# Count assert statements and build error messages
181+
errors = []
182+
for func in test_functions:
183+
count = sum(1 for node in ast.walk(func) if isinstance(node, ast.Assert))
184+
if count < required:
185+
if count == 0:
186+
errors.append(f"{func.name} has no assert statements")
187+
elif count == 1:
188+
errors.append(f"{func.name} has only 1 assert statement")
189+
else:
190+
errors.append(f"{func.name} has only {count} assert statements")
191+
192+
# Fail the current test if applicable
193+
if errors:
194+
pytest.fail(", ".join(errors))
195+
196+
164197
def count_calls(filename, func_id):
165198
"""Count how many times a function is called.
166199

0 commit comments

Comments
 (0)