Skip to content

Commit e99e7c8

Browse files
committed
revert tests/test_trace_benchmarks.py
1 parent 91b7c6f commit e99e7c8

File tree

1 file changed

+119
-136
lines changed

1 file changed

+119
-136
lines changed

tests/test_trace_benchmarks.py

Lines changed: 119 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import gc
21
import multiprocessing
32
import shutil
43
import sqlite3
5-
import time
64
from pathlib import Path
75

86
import pytest
@@ -11,29 +9,11 @@
119
from codeflash.benchmarking.replay_test import generate_replay_test
1210
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
1311
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
14-
15-
16-
def safe_unlink(file_path: Path, max_retries: int = 5, retry_delay: float = 0.5) -> None:
17-
"""Safely delete a file with retries, handling Windows file locking issues."""
18-
for attempt in range(max_retries):
19-
try:
20-
file_path.unlink(missing_ok=True)
21-
return
22-
except PermissionError:
23-
if attempt < max_retries - 1:
24-
time.sleep(retry_delay)
25-
else:
26-
# Last attempt: force garbage collection to close any lingering SQLite connections
27-
gc.collect()
28-
time.sleep(retry_delay * 2)
29-
try:
30-
file_path.unlink(missing_ok=True)
31-
except PermissionError:
32-
# Silently fail on final attempt to avoid test failures from cleanup issues
33-
pass
12+
import time
3413

3514

3615
def test_trace_benchmarks() -> None:
16+
# Test the trace_benchmarks function
3717
project_root = Path(__file__).parent.parent / "code_to_optimize"
3818
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
3919
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
@@ -42,65 +22,66 @@ def test_trace_benchmarks() -> None:
4222
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
4323
assert output_file.exists()
4424
try:
45-
# Query the trace database to verify recorded function calls
46-
with sqlite3.connect(output_file.as_posix()) as conn:
47-
cursor = conn.cursor()
48-
cursor.execute(
49-
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
50-
function_calls = cursor.fetchall()
51-
52-
# Accept platform-dependent run multipliers: function calls should come in complete groups of the base set (8)
53-
base_count = 8
54-
assert len(function_calls) >= base_count and len(function_calls) % base_count == 0, (
55-
f"Expected count to be a multiple of {base_count}, but got {len(function_calls)}"
56-
)
57-
58-
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
59-
process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
60-
# Expected function calls (each appears twice due to benchmark execution pattern)
61-
base_expected_calls = [
62-
("sorter", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
63-
f"{bubble_sort_path}",
64-
"test_class_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 17),
65-
66-
("sort_class", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
67-
f"{bubble_sort_path}",
68-
"test_class_sort2", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 20),
69-
70-
("sort_static", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
71-
f"{bubble_sort_path}",
72-
"test_class_sort3", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 23),
73-
74-
("__init__", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
75-
f"{bubble_sort_path}",
76-
"test_class_sort4", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 26),
77-
78-
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
79-
f"{bubble_sort_path}",
80-
"test_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 7),
81-
82-
("compute_and_sort", "", "code_to_optimize.process_and_bubble_sort_codeflash_trace",
83-
f"{process_and_bubble_sort_path}",
84-
"test_compute_and_sort", "tests.pytest.benchmarks_test.test_process_and_sort_example", 4),
85-
86-
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
87-
f"{bubble_sort_path}",
88-
"test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
89-
90-
("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
91-
f"{bubble_sort_path}",
92-
"test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
93-
]
94-
expected_calls = base_expected_calls * 3
95-
# Order-agnostic validation: ensure at least one instance of each base expected call exists
96-
normalized_calls = [(a[0], a[1], a[2], Path(a[3]).name, a[4], a[5], a[6]) for a in function_calls]
97-
normalized_expected = [(e[0], e[1], e[2], Path(e[3]).name, e[4], e[5], e[6]) for e in base_expected_calls]
98-
for expected in normalized_expected:
99-
assert expected in normalized_calls, f"Missing expected call: {expected}"
100-
101-
# Close database connection and ensure cleanup before opening new connections
102-
gc.collect()
103-
time.sleep(0.1)
25+
# check contents of trace file
26+
# connect to database
27+
conn = sqlite3.connect(output_file.as_posix())
28+
cursor = conn.cursor()
29+
30+
# Get the count of records
31+
# Get all records
32+
cursor.execute(
33+
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
34+
function_calls = cursor.fetchall()
35+
36+
# Assert the length of function calls
37+
assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
38+
39+
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
40+
process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
41+
# Expected function calls
42+
expected_calls = [
43+
("sorter", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
44+
f"{bubble_sort_path}",
45+
"test_class_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 17),
46+
47+
("sort_class", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
48+
f"{bubble_sort_path}",
49+
"test_class_sort2", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 20),
50+
51+
("sort_static", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
52+
f"{bubble_sort_path}",
53+
"test_class_sort3", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 23),
54+
55+
("__init__", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
56+
f"{bubble_sort_path}",
57+
"test_class_sort4", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 26),
58+
59+
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
60+
f"{bubble_sort_path}",
61+
"test_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 7),
62+
63+
("compute_and_sort", "", "code_to_optimize.process_and_bubble_sort_codeflash_trace",
64+
f"{process_and_bubble_sort_path}",
65+
"test_compute_and_sort", "tests.pytest.benchmarks_test.test_process_and_sort_example", 4),
66+
67+
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
68+
f"{bubble_sort_path}",
69+
"test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
70+
71+
("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
72+
f"{bubble_sort_path}",
73+
"test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
74+
]
75+
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
76+
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
77+
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
78+
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
79+
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
80+
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
81+
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
82+
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
83+
# Close connection
84+
conn.close()
10485
generate_replay_test(output_file, replay_tests_dir)
10586
test_class_sort_path = replay_tests_dir/ Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py")
10687
assert test_class_sort_path.exists()
@@ -190,13 +171,10 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
190171
191172
"""
192173
assert test_sort_path.read_text("utf-8").strip()==test_sort_code.strip()
193-
# Ensure database connections are closed before cleanup
194-
gc.collect()
195-
time.sleep(0.1)
196174
finally:
197-
# Cleanup with retry mechanism to handle Windows file locking issues
198-
safe_unlink(output_file)
199-
shutil.rmtree(replay_tests_dir, ignore_errors=True)
175+
# cleanup
176+
output_file.unlink(missing_ok=True)
177+
shutil.rmtree(replay_tests_dir)
200178

201179
# Skip the test in CI as the machine may not be multithreaded
202180
@pytest.mark.ci_skip
@@ -208,15 +186,21 @@ def test_trace_multithreaded_benchmark() -> None:
208186
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
209187
assert output_file.exists()
210188
try:
211-
# Query the trace database to verify recorded function calls
212-
with sqlite3.connect(output_file.as_posix()) as conn:
213-
cursor = conn.cursor()
214-
cursor.execute(
215-
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
216-
function_calls = cursor.fetchall()
217-
218-
# Accept platform-dependent run multipliers; any positive count is fine for multithread case
219-
assert len(function_calls) >= 1, f"Expected at least 1 function call, got {len(function_calls)}"
189+
# check contents of trace file
190+
# connect to database
191+
conn = sqlite3.connect(output_file.as_posix())
192+
cursor = conn.cursor()
193+
194+
# Get the count of records
195+
# Get all records
196+
cursor.execute(
197+
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
198+
function_calls = cursor.fetchall()
199+
200+
conn.close()
201+
202+
# Assert the length of function calls
203+
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
220204
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
221205
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
222206
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
@@ -228,26 +212,26 @@ def test_trace_multithreaded_benchmark() -> None:
228212
assert percent >= 0.0
229213

230214
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
231-
# Expected function calls (each appears multiple times due to benchmark execution pattern)
215+
# Expected function calls
232216
expected_calls = [
233217
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
234218
f"{bubble_sort_path}",
235219
"test_benchmark_sort", "tests.pytest.benchmarks_multithread.test_multithread_sort", 4),
236-
] * 30
220+
]
237221
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
238222
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
239223
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
240224
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
241225
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
242226
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
227+
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
243228
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
244-
245-
# Ensure database connections are closed before cleanup
246-
gc.collect()
247-
time.sleep(0.1)
229+
# Close connection
230+
conn.close()
231+
248232
finally:
249-
# Cleanup with retry mechanism to handle Windows file locking issues
250-
safe_unlink(output_file)
233+
# cleanup
234+
output_file.unlink(missing_ok=True)
251235

252236
def test_trace_benchmark_decorator() -> None:
253237
project_root = Path(__file__).parent.parent / "code_to_optimize"
@@ -257,35 +241,31 @@ def test_trace_benchmark_decorator() -> None:
257241
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
258242
assert output_file.exists()
259243
try:
260-
# Query the trace database to verify recorded function calls
261-
with sqlite3.connect(output_file.as_posix()) as conn:
262-
cursor = conn.cursor()
263-
cursor.execute(
264-
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
265-
function_calls = cursor.fetchall()
266-
267-
# Accept platform-dependent run multipliers: should be a multiple of base set (2)
268-
base_count = 2
269-
assert len(function_calls) >= base_count and len(function_calls) % base_count == 0, (
270-
f"Expected count to be a multiple of {base_count}, but got {len(function_calls)}"
271-
)
272-
273-
# Close database connection and ensure cleanup before opening new connections
274-
gc.collect()
275-
time.sleep(0.1)
276-
244+
# check contents of trace file
245+
# connect to database
246+
conn = sqlite3.connect(output_file.as_posix())
247+
cursor = conn.cursor()
248+
249+
# Get the count of records
250+
# Get all records
251+
cursor.execute(
252+
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
253+
function_calls = cursor.fetchall()
254+
255+
# Assert the length of function calls
256+
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
277257
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
278258
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
279259
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
280260
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
281261

282262
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
283-
assert total_time >= 0.0
284-
assert function_time >= 0.0
285-
assert percent >= 0.0
263+
assert total_time > 0.0
264+
assert function_time > 0.0
265+
assert percent > 0.0
286266

287267
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
288-
# Expected function calls (each appears twice due to benchmark execution pattern)
268+
# Expected function calls
289269
expected_calls = [
290270
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
291271
f"{bubble_sort_path}",
@@ -294,15 +274,18 @@ def test_trace_benchmark_decorator() -> None:
294274
f"{bubble_sort_path}",
295275
"test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11),
296276
]
297-
# Order-agnostic validation for decorator case as well
298-
normalized_calls = [(a[0], a[1], a[2], Path(a[3]).name, a[4], a[5], a[6]) for a in function_calls]
299-
normalized_expected = [(e[0], e[1], e[2], Path(e[3]).name, e[4], e[5], e[6]) for e in expected_calls]
300-
for expected in normalized_expected:
301-
assert expected in normalized_calls, f"Missing expected call: {expected}"
302-
303-
# Ensure database connections are closed before cleanup
304-
gc.collect()
305-
time.sleep(0.1)
277+
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
278+
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
279+
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
280+
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
281+
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
282+
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
283+
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
284+
# Close connection
285+
cursor.close()
286+
conn.close()
287+
time.sleep(2)
306288
finally:
307-
# Cleanup with retry mechanism to handle Windows file locking issues
308-
safe_unlink(output_file)
289+
# cleanup
290+
output_file.unlink(missing_ok=True)
291+
time.sleep(1)

0 commit comments

Comments
 (0)