Skip to content

Commit 1ee8c8a

Browse files
Merge pull request #273 from codeflash-ai/better-stdout-capture
More precise stdout capture
2 parents 47f6c02 + 0e5f79f commit 1ee8c8a

File tree

7 files changed

+295
-243
lines changed

7 files changed

+295
-243
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 51 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -462,50 +462,53 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
462462
),
463463
*(
464464
[
465+
ast.Assign(
466+
targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
467+
value=ast.JoinedStr(
468+
values=[
469+
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
470+
ast.Constant(value=":"),
471+
ast.FormattedValue(
472+
value=ast.IfExp(
473+
test=ast.Name(id="test_class_name", ctx=ast.Load()),
474+
body=ast.BinOp(
475+
left=ast.Name(id="test_class_name", ctx=ast.Load()),
476+
op=ast.Add(),
477+
right=ast.Constant(value="."),
478+
),
479+
orelse=ast.Constant(value=""),
480+
),
481+
conversion=-1,
482+
),
483+
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
484+
ast.Constant(value=":"),
485+
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
486+
ast.Constant(value=":"),
487+
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
488+
ast.Constant(value=":"),
489+
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
490+
]
491+
),
492+
lineno=lineno + 9,
493+
),
465494
ast.Expr(
466495
value=ast.Call(
467496
func=ast.Name(id="print", ctx=ast.Load()),
468497
args=[
469498
ast.JoinedStr(
470499
values=[
471-
ast.Constant(value="!######"),
500+
ast.Constant(value="!$######"),
472501
ast.FormattedValue(
473-
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
502+
value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1
474503
),
475-
ast.Constant(value=":"),
476-
ast.FormattedValue(
477-
value=ast.IfExp(
478-
test=ast.Name(id="test_class_name", ctx=ast.Load()),
479-
body=ast.BinOp(
480-
left=ast.Name(id="test_class_name", ctx=ast.Load()),
481-
op=ast.Add(),
482-
right=ast.Constant(value="."),
483-
),
484-
orelse=ast.Constant(value=""),
485-
),
486-
conversion=-1,
487-
),
488-
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
489-
ast.Constant(value=":"),
490-
ast.FormattedValue(
491-
value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1
492-
),
493-
ast.Constant(value=":"),
494-
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
495-
ast.Constant(value=":"),
496-
ast.FormattedValue(
497-
value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1
498-
),
499-
ast.Constant(value="######!"),
504+
ast.Constant(value="######$!"),
500505
]
501506
)
502507
],
503508
keywords=[],
504509
)
505-
)
510+
),
506511
]
507-
if mode == TestingMode.BEHAVIOR
508-
else []
509512
),
510513
ast.Assign(
511514
targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10
@@ -598,56 +601,30 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
598601
keywords=[],
599602
)
600603
),
601-
*(
602-
[
603-
ast.Expr(
604-
value=ast.Call(
605-
func=ast.Name(id="print", ctx=ast.Load()),
606-
args=[
607-
ast.JoinedStr(
608-
values=[
609-
ast.Constant(value="!######"),
610-
ast.FormattedValue(
611-
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
612-
),
613-
ast.Constant(value=":"),
614-
ast.FormattedValue(
615-
value=ast.IfExp(
616-
test=ast.Name(id="test_class_name", ctx=ast.Load()),
617-
body=ast.BinOp(
618-
left=ast.Name(id="test_class_name", ctx=ast.Load()),
619-
op=ast.Add(),
620-
right=ast.Constant(value="."),
621-
),
622-
orelse=ast.Constant(value=""),
623-
),
624-
conversion=-1,
625-
),
626-
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
627-
ast.Constant(value=":"),
628-
ast.FormattedValue(
629-
value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1
630-
),
631-
ast.Constant(value=":"),
632-
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
633-
ast.Constant(value=":"),
634-
ast.FormattedValue(
635-
value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1
636-
),
604+
ast.Expr(
605+
value=ast.Call(
606+
func=ast.Name(id="print", ctx=ast.Load()),
607+
args=[
608+
ast.JoinedStr(
609+
values=[
610+
ast.Constant(value="!######"),
611+
ast.FormattedValue(value=ast.Name(id="test_stdout_tag", ctx=ast.Load()), conversion=-1),
612+
*(
613+
[
637614
ast.Constant(value=":"),
638615
ast.FormattedValue(
639616
value=ast.Name(id="codeflash_duration", ctx=ast.Load()), conversion=-1
640617
),
641-
ast.Constant(value="######!"),
642618
]
643-
)
644-
],
645-
keywords=[],
619+
if mode == TestingMode.PERFORMANCE
620+
else []
621+
),
622+
ast.Constant(value="######!"),
623+
]
646624
)
647-
)
648-
]
649-
if mode == TestingMode.PERFORMANCE
650-
else []
625+
],
626+
keywords=[],
627+
)
651628
),
652629
*(
653630
[

codeflash/verification/codeflash_capture.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -112,9 +112,8 @@ def wrapper(*args, **kwargs) -> None: # noqa: ANN002, ANN003
112112

113113
# Generate invocation id
114114
invocation_id = f"{line_id}_{codeflash_test_index}"
115-
print(
116-
f"!######{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}######!"
117-
)
115+
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
116+
print(f"!$######{test_stdout_tag}######$!")
118117
# Connect to sqlite
119118
codeflash_con = sqlite3.connect(f"{tmp_dir_path}_{codeflash_iteration}.sqlite")
120119
codeflash_cur = codeflash_con.cursor()
@@ -131,6 +130,7 @@ def wrapper(*args, **kwargs) -> None: # noqa: ANN002, ANN003
131130
exception = e
132131
finally:
133132
gc.enable()
133+
print(f"!######{test_stdout_tag}######!")
134134

135135
# Capture instance state after initialization
136136
if hasattr(args[0], "__dict__"):

codeflash/verification/parse_test_output.py

Lines changed: 35 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,8 @@ def parse_func(file_path: Path) -> XMLParser:
3636
return parse(file_path, xml_parser)
3737

3838

39-
matches_re = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
40-
cleaner_re = re.compile(r"!######.*?######!|-+\s*Captured\s+(Log|Out)\s*-+\n?")
39+
matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######\$!\n")
40+
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
4141

4242

4343
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
@@ -265,12 +265,16 @@ def parse_test_xml(
265265
timed_out = True
266266

267267
sys_stdout = testcase.system_out or ""
268-
matches = matches_re.findall(sys_stdout)
269-
270-
if sys_stdout:
271-
sys_stdout = cleaner_re.sub("", sys_stdout).strip()
272-
273-
if not matches or not len(matches):
268+
begin_matches = list(matches_re_start.finditer(sys_stdout))
269+
end_matches = {}
270+
for match in matches_re_end.finditer(sys_stdout):
271+
groups = match.groups()
272+
if len(groups[5].split(":")) > 1:
273+
iteration_id = groups[5].split(":")[0]
274+
groups = groups[:5] + (iteration_id,)
275+
end_matches[groups] = match
276+
277+
if not begin_matches or not begin_matches:
274278
test_results.add(
275279
FunctionTestInvocation(
276280
loop_index=loop_index,
@@ -288,26 +292,36 @@ def parse_test_xml(
288292
test_type=test_type,
289293
return_value=None,
290294
timed_out=timed_out,
291-
stdout=sys_stdout,
295+
stdout="",
292296
)
293297
)
294298

295299
else:
296-
for match in matches:
297-
split_val = match[5].split(":")
298-
if len(split_val) > 1:
299-
iteration_id = split_val[0]
300-
runtime = int(split_val[1])
300+
for match_index, match in enumerate(begin_matches):
301+
groups = match.groups()
302+
end_match = end_matches.get(groups)
303+
iteration_id, runtime = groups[5], None
304+
if end_match:
305+
stdout = sys_stdout[match.end() : end_match.start()]
306+
split_val = end_match.groups()[5].split(":")
307+
if len(split_val) > 1:
308+
iteration_id = split_val[0]
309+
runtime = int(split_val[1])
310+
else:
311+
iteration_id, runtime = split_val[0], None
312+
elif match_index == len(begin_matches) - 1:
313+
stdout = sys_stdout[match.end() :]
301314
else:
302-
iteration_id, runtime = split_val[0], None
315+
stdout = sys_stdout[match.end() : begin_matches[match_index + 1].start()]
316+
303317
test_results.add(
304318
FunctionTestInvocation(
305-
loop_index=int(match[4]),
319+
loop_index=int(groups[4]),
306320
id=InvocationId(
307-
test_module_path=match[0],
308-
test_class_name=None if match[1] == "" else match[1][:-1],
309-
test_function_name=match[2],
310-
function_getting_tested=match[3],
321+
test_module_path=groups[0],
322+
test_class_name=None if groups[1] == "" else groups[1][:-1],
323+
test_function_name=groups[2],
324+
function_getting_tested=groups[3],
311325
iteration_id=iteration_id,
312326
),
313327
file_name=test_file_path,
@@ -317,7 +331,7 @@ def parse_test_xml(
317331
test_type=test_type,
318332
return_value=None,
319333
timed_out=timed_out,
320-
stdout=sys_stdout,
334+
stdout=stdout,
321335
)
322336
)
323337

0 commit comments

Comments
 (0)