-import gc
 import multiprocessing
 import shutil
 import sqlite3
-import time
 from pathlib import Path

 import pytest
 from codeflash.benchmarking.replay_test import generate_replay_test
 from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
 from codeflash.benchmarking.utils import validate_and_format_benchmark_table
-
-
-def safe_unlink(file_path: Path, max_retries: int = 5, retry_delay: float = 0.5) -> None:
-    """Safely delete a file with retries, handling Windows file locking issues."""
-    for attempt in range(max_retries):
-        try:
-            file_path.unlink(missing_ok=True)
-            return
-        except PermissionError:
-            if attempt < max_retries - 1:
-                time.sleep(retry_delay)
-            else:
-                # Last attempt: force garbage collection to close any lingering SQLite connections
-                gc.collect()
-                time.sleep(retry_delay * 2)
-                try:
-                    file_path.unlink(missing_ok=True)
-                except PermissionError:
-                    # Silently fail on final attempt to avoid test failures from cleanup issues
-                    pass
+import time


 def test_trace_benchmarks() -> None:
+    # Test the trace_benchmarks function
     project_root = Path(__file__).parent.parent / "code_to_optimize"
     benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
     replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
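A note on the helper this revision deletes: `safe_unlink` existed because unlinking a file that still has an open SQLite handle raises `PermissionError` on Windows. For reference, a minimal sketch of that retry idea (the function name, attempt count, and delays here are illustrative, not part of this diff):

```python
import time
from pathlib import Path


def unlink_with_retries(path: Path, attempts: int = 5, delay: float = 0.5) -> None:
    """Retry deletion to ride out transient file locks (common on Windows)."""
    for attempt in range(attempts):
        try:
            path.unlink(missing_ok=True)  # no error if the file is already gone
            return
        except PermissionError:
            if attempt == attempts - 1:
                raise  # give up after the final attempt
            time.sleep(delay)  # wait for the lock holder to release the file
```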
@@ -42,65 +22,66 @@ def test_trace_benchmarks() -> None:
     trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
     assert output_file.exists()
     try:
-        # Query the trace database to verify recorded function calls
-        with sqlite3.connect(output_file.as_posix()) as conn:
-            cursor = conn.cursor()
-            cursor.execute(
-                "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
-            function_calls = cursor.fetchall()
-
-            # Accept platform-dependent run multipliers: function calls should come in complete groups of the base set (8)
-            base_count = 8
-            assert len(function_calls) >= base_count and len(function_calls) % base_count == 0, (
-                f"Expected count to be a multiple of {base_count}, but got {len(function_calls)}"
-            )
-
-            bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
-            process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
-            # Expected function calls (each appears twice due to benchmark execution pattern)
-            base_expected_calls = [
-                ("sorter", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
-                 f"{bubble_sort_path}",
-                 "test_class_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 17),
-
-                ("sort_class", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
-                 f"{bubble_sort_path}",
-                 "test_class_sort2", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 20),
-
-                ("sort_static", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
-                 f"{bubble_sort_path}",
-                 "test_class_sort3", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 23),
-
-                ("__init__", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
-                 f"{bubble_sort_path}",
-                 "test_class_sort4", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 26),
-
-                ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
-                 f"{bubble_sort_path}",
-                 "test_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 7),
-
-                ("compute_and_sort", "", "code_to_optimize.process_and_bubble_sort_codeflash_trace",
-                 f"{process_and_bubble_sort_path}",
-                 "test_compute_and_sort", "tests.pytest.benchmarks_test.test_process_and_sort_example", 4),
-
-                ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
-                 f"{bubble_sort_path}",
-                 "test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
-
-                ("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
-                 f"{bubble_sort_path}",
-                 "test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
-            ]
-            expected_calls = base_expected_calls * 3
-            # Order-agnostic validation: ensure at least one instance of each base expected call exists
-            normalized_calls = [(a[0], a[1], a[2], Path(a[3]).name, a[4], a[5], a[6]) for a in function_calls]
-            normalized_expected = [(e[0], e[1], e[2], Path(e[3]).name, e[4], e[5], e[6]) for e in base_expected_calls]
-            for expected in normalized_expected:
-                assert expected in normalized_calls, f"Missing expected call: {expected}"
-
-        # Close database connection and ensure cleanup before opening new connections
-        gc.collect()
-        time.sleep(0.1)
+        # Check the contents of the trace file
+        # Connect to the trace database
+        conn = sqlite3.connect(output_file.as_posix())
+        cursor = conn.cursor()
+
+        # Query all recorded function calls,
+        # ordered for a deterministic comparison
+        cursor.execute(
+            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
+        function_calls = cursor.fetchall()
+
+        # Assert the number of recorded function calls
+        assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
+
+        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
+        process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
+        # Expected function calls
+        expected_calls = [
+            ("sorter", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_class_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 17),
+
+            ("sort_class", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_class_sort2", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 20),
+
+            ("sort_static", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_class_sort3", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 23),
+
+            ("__init__", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_class_sort4", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 26),
+
+            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 7),
+
+            ("compute_and_sort", "", "code_to_optimize.process_and_bubble_sort_codeflash_trace",
+             f"{process_and_bubble_sort_path}",
+             "test_compute_and_sort", "tests.pytest.benchmarks_test.test_process_and_sort_example", 4),
+
+            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
+
+            ("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
+        ]
+        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
+            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
+            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
+            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
+            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
+            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
+            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
+            assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
+        # Close the connection
+        conn.close()
     generate_replay_test(output_file, replay_tests_dir)
     test_class_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py")
     assert test_class_sort_path.exists()
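One caveat worth noting about the reverted code above: `sqlite3.Connection` used as a context manager only commits or rolls back the transaction on exit; it does not close the connection, which is why the removed version still needed `gc.collect()` before deleting the trace file. If close-on-exit is wanted, `contextlib.closing` provides it. A minimal sketch, with an illustrative file name:

```python
import sqlite3
from contextlib import closing

# closing() guarantees conn.close() on exit, unlike sqlite3's own
# context manager, which only ends the transaction.
with closing(sqlite3.connect("benchmarks.trace")) as conn:
    rows = conn.execute("SELECT 1").fetchall()
# The connection is closed here, so the file can be unlinked even on Windows.
```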
@@ -190,13 +171,10 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():

 """
         assert test_sort_path.read_text("utf-8").strip() == test_sort_code.strip()
-        # Ensure database connections are closed before cleanup
-        gc.collect()
-        time.sleep(0.1)
     finally:
-        # Cleanup with retry mechanism to handle Windows file locking issues
-        safe_unlink(output_file)
-        shutil.rmtree(replay_tests_dir, ignore_errors=True)
+        # Clean up the trace file and generated replay tests
+        output_file.unlink(missing_ok=True)
+        shutil.rmtree(replay_tests_dir)

 # Skip the test in CI as the machine may not be multithreaded
 @pytest.mark.ci_skip
@@ -208,15 +186,21 @@ def test_trace_multithreaded_benchmark() -> None:
     trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
     assert output_file.exists()
     try:
-        # Query the trace database to verify recorded function calls
-        with sqlite3.connect(output_file.as_posix()) as conn:
-            cursor = conn.cursor()
-            cursor.execute(
-                "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
-            function_calls = cursor.fetchall()
-
-        # Accept platform-dependent run multipliers; any positive count is fine for multithread case
-        assert len(function_calls) >= 1, f"Expected at least 1 function call, got {len(function_calls)}"
+        # Check the contents of the trace file
+        # Connect to the trace database
+        conn = sqlite3.connect(output_file.as_posix())
+        cursor = conn.cursor()
+
+        # Query all recorded function calls,
+        # ordered for a deterministic comparison
+        cursor.execute(
+            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
+        function_calls = cursor.fetchall()
+
+        conn.close()
+
+        # Assert the number of recorded function calls
+        assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
     function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
     total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
     function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
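For readers unfamiliar with the benchmark table: judging by how the assertions around this hunk unpack it, `validate_and_format_benchmark_table` maps a qualified function name to rows of `(benchmark_test_name, total_time, function_time, percent)`. A hedged sketch of consuming it (the loop is illustrative, not part of the diff):

```python
# function_to_results: dict mapping "module.function" to a list of
# (test_name, total_time, function_time, percent) rows — shape inferred
# from the tuple unpacking in the assertions of this test file.
for function_name, rows in function_to_results.items():
    for test_name, total_time, function_time, percent in rows:
        print(f"{function_name}: {function_time}/{total_time} ({percent:.1f}%) in {test_name}")
```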
@@ -228,26 +212,26 @@ def test_trace_multithreaded_benchmark() -> None:
         assert percent >= 0.0

         bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
-        # Expected function calls (each appears multiple times due to benchmark execution pattern)
+        # Expected function calls
         expected_calls = [
             ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
              f"{bubble_sort_path}",
              "test_benchmark_sort", "tests.pytest.benchmarks_multithread.test_multithread_sort", 4),
-        ] * 30
+        ]
         for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
             assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
             assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
             assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
             assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
             assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
+            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
             assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
-
-        # Ensure database connections are closed before cleanup
-        gc.collect()
-        time.sleep(0.1)
+        # Close the connection (a second close is a harmless no-op)
+        conn.close()
+
     finally:
-        # Cleanup with retry mechanism to handle Windows file locking issues
-        safe_unlink(output_file)
+        # Clean up the trace file
+        output_file.unlink(missing_ok=True)

 def test_trace_benchmark_decorator() -> None:
     project_root = Path(__file__).parent.parent / "code_to_optimize"
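One behavioral note on the restored validation loop above: with the `] * 30` multiplier dropped, `expected_calls` holds a single row while ten calls are recorded, and `zip()` silently stops at the shorter sequence, so only the first traced call is compared. On Python 3.10+, `strict=True` would surface such length mismatches; a small illustration with made-up data:

```python
# zip() truncates to the shorter input: only index 0 is checked here.
function_calls = [("sorter",), ("sorter",), ("sorter",)]  # pretend traced rows
expected_calls = [("sorter",)]

for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
    assert actual == expected, f"Mismatch at index {idx}"

# zip(function_calls, expected_calls, strict=True) would instead raise
# ValueError (Python 3.10+) rather than silently validating one row.
```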
@@ -257,35 +241,31 @@ def test_trace_benchmark_decorator() -> None:
     trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
     assert output_file.exists()
     try:
-        # Query the trace database to verify recorded function calls
-        with sqlite3.connect(output_file.as_posix()) as conn:
-            cursor = conn.cursor()
-            cursor.execute(
-                "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
-            function_calls = cursor.fetchall()
-
-            # Accept platform-dependent run multipliers: should be a multiple of base set (2)
-            base_count = 2
-            assert len(function_calls) >= base_count and len(function_calls) % base_count == 0, (
-                f"Expected count to be a multiple of {base_count}, but got {len(function_calls)}"
-            )
-
-        # Close database connection and ensure cleanup before opening new connections
-        gc.collect()
-        time.sleep(0.1)
-
+        # Check the contents of the trace file
+        # Connect to the trace database
+        conn = sqlite3.connect(output_file.as_posix())
+        cursor = conn.cursor()
+
+        # Query all recorded function calls,
+        # ordered for a deterministic comparison
+        cursor.execute(
+            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
+        function_calls = cursor.fetchall()
+
+        # Assert the number of recorded function calls
+        assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
         function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
         total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
         function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
         assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results

         test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
-        assert total_time >= 0.0
-        assert function_time >= 0.0
-        assert percent >= 0.0
+        assert total_time > 0.0
+        assert function_time > 0.0
+        assert percent > 0.0

         bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
-        # Expected function calls (each appears twice due to benchmark execution pattern)
+        # Expected function calls
         expected_calls = [
             ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
              f"{bubble_sort_path}",
@@ -294,15 +274,18 @@ def test_trace_benchmark_decorator() -> None:
294274 f"{ bubble_sort_path } " ,
295275 "test_pytest_mark" , "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator" , 11 ),
296276 ]
297- # Order-agnostic validation for decorator case as well
298- normalized_calls = [(a [0 ], a [1 ], a [2 ], Path (a [3 ]).name , a [4 ], a [5 ], a [6 ]) for a in function_calls ]
299- normalized_expected = [(e [0 ], e [1 ], e [2 ], Path (e [3 ]).name , e [4 ], e [5 ], e [6 ]) for e in expected_calls ]
300- for expected in normalized_expected :
301- assert expected in normalized_calls , f"Missing expected call: { expected } "
302-
303- # Ensure database connections are closed before cleanup
304- gc .collect ()
305- time .sleep (0.1 )
277+ for idx , (actual , expected ) in enumerate (zip (function_calls , expected_calls )):
278+ assert actual [0 ] == expected [0 ], f"Mismatch at index { idx } for function_name"
279+ assert actual [1 ] == expected [1 ], f"Mismatch at index { idx } for class_name"
280+ assert actual [2 ] == expected [2 ], f"Mismatch at index { idx } for module_name"
281+ assert Path (actual [3 ]).name == Path (expected [3 ]).name , f"Mismatch at index { idx } for file_path"
282+ assert actual [4 ] == expected [4 ], f"Mismatch at index { idx } for benchmark_function_name"
283+ assert actual [5 ] == expected [5 ], f"Mismatch at index { idx } for benchmark_module_path"
284+ # Close connection
285+ cursor .close ()
286+ conn .close ()
287+ time .sleep (2 )
306288 finally :
307- # Cleanup with retry mechanism to handle Windows file locking issues
308- safe_unlink (output_file )
289+ # cleanup
290+ output_file .unlink (missing_ok = True )
291+ time .sleep (1 )
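For ad-hoc debugging, the table these tests assert against can be inspected directly from a trace file produced by `trace_benchmarks_pytest`. A minimal sketch (the file name is assumed; the table and column names come from the queries above):

```python
import sqlite3
from contextlib import closing

# Dump the recorded benchmark timings, ordered the same way the tests do.
with closing(sqlite3.connect("test_trace_benchmarks.trace")) as conn:
    query = (
        "SELECT benchmark_module_path, benchmark_function_name, function_name "
        "FROM benchmark_function_timings "
        "ORDER BY benchmark_module_path, benchmark_function_name, function_name"
    )
    for row in conn.execute(query):
        print(row)
```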