Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b91b61f
wip
aseembits93 Oct 31, 2025
ffff5f1
all tests fixed
Oct 31, 2025
ecbceec
tests modified now
Oct 31, 2025
b3c3ca8
Optimize InjectPerfOnly.find_and_update_line_node
codeflash-ai[bot] Nov 1, 2025
c2817f9
tests work now
Nov 1, 2025
40e82e2
Merge pull request #870 from codeflash-ai/codeflash/optimize-pr867-20…
misrasaurabh1 Nov 1, 2025
92da986
Update codeflash/discovery/discover_unit_tests.py
aseembits93 Nov 1, 2025
bab2752
Merge branch 'main' into inspect-signature-issue
aseembits93 Nov 4, 2025
f29e33a
Merge remote-tracking branch 'origin/main' into inspect-signature-issue
Nov 5, 2025
cb6df90
potential fix
Nov 5, 2025
f305633
potential fix
Nov 5, 2025
b7225e7
Optimize ImportAnalyzer.visit_Attribute
codeflash-ai[bot] Nov 5, 2025
0f2c747
Merge pull request #877 from codeflash-ai/codeflash/optimize-pr867-20…
aseembits93 Nov 5, 2025
1ec1005
Optimize ImportAnalyzer.visit_Call
codeflash-ai[bot] Nov 5, 2025
5ff60e2
Merge pull request #878 from codeflash-ai/codeflash/optimize-pr867-20…
aseembits93 Nov 5, 2025
ccf9bda
linter fix
Nov 5, 2025
add3ddd
Optimize ImportAnalyzer._fast_generic_visit
codeflash-ai[bot] Nov 5, 2025
13d3e6b
Merge pull request #880 from codeflash-ai/codeflash/optimize-pr867-20…
aseembits93 Nov 5, 2025
1c61999
linter fix
Nov 5, 2025
0b18bac
Merge branch 'main' into inspect-signature-issue
aseembits93 Nov 5, 2025
f595de4
Merge branch 'main' into inspect-signature-issue
aseembits93 Nov 6, 2025
18a260c
classmethod and staticmethod for testing
Nov 6, 2025
dfd5128
Apply suggestion from @aseembits93
aseembits93 Nov 6, 2025
f6302d0
newline
Nov 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 159 additions & 7 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import platform
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -20,6 +21,16 @@
from codeflash.models.models import CodePosition


@dataclass(frozen=True)
class FunctionCallNodeArguments:
args: list[ast.expr]
keywords: list[ast.keyword]


def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments:
return FunctionCallNodeArguments(call_node.args, call_node.keywords)


def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
for pos in call_positions:
Expand Down Expand Up @@ -73,16 +84,54 @@ def __init__(
def find_and_update_line_node(
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
return_statement = [test_node]
call_node = None
for node in ast.walk(test_node):
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
call_node = node
all_args = get_call_arguments(call_node)
if isinstance(node.func, ast.Name):
function_name = node.func.id

if self.function_object.is_async:
return [test_node]

# Create the signature binding statements
bind_call = ast.Assign(
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="inspect", ctx=ast.Load()), attr="signature", ctx=ast.Load()
),
args=[ast.Name(id=function_name, ctx=ast.Load())],
keywords=[],
),
attr="bind",
ctx=ast.Load(),
),
args=all_args.args,
keywords=all_args.keywords,
),
lineno=test_node.lineno,
col_offset=test_node.col_offset,
)

apply_defaults = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="apply_defaults",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
lineno=test_node.lineno + 1,
col_offset=test_node.col_offset,
)

node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
Expand All @@ -97,9 +146,39 @@ def find_and_update_line_node(
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
*(
call_node.args
if self.mode == TestingMode.PERFORMANCE
else [
ast.Starred(
value=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="args",
ctx=ast.Load(),
),
ctx=ast.Load(),
)
]
),
]
node.keywords = call_node.keywords
node.keywords = (
[
ast.keyword(
value=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="kwargs",
ctx=ast.Load(),
)
)
]
if self.mode == TestingMode.BEHAVIOR
else call_node.keywords
)

# Return the signature binding statements along with the test_node
return_statement = (
[bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node]
)
break
if isinstance(node.func, ast.Attribute):
function_to_test = node.func.attr
Expand All @@ -108,9 +187,48 @@ def find_and_update_line_node(
return [test_node]

function_name = ast.unparse(node.func)

# Create the signature binding statements
bind_call = ast.Assign(
targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="inspect", ctx=ast.Load()),
attr="signature",
ctx=ast.Load(),
),
args=[ast.parse(function_name, mode="eval").body],
keywords=[],
),
attr="bind",
ctx=ast.Load(),
),
args=all_args.args,
keywords=all_args.keywords,
),
lineno=test_node.lineno,
col_offset=test_node.col_offset,
)

apply_defaults = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="apply_defaults",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
lineno=test_node.lineno + 1,
col_offset=test_node.col_offset,
)

node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.parse(function_name, mode="eval").body,
ast.Constant(value=self.module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
Expand All @@ -125,14 +243,44 @@ def find_and_update_line_node(
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
*(
call_node.args
if self.mode == TestingMode.PERFORMANCE
else [
ast.Starred(
value=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="args",
ctx=ast.Load(),
),
ctx=ast.Load(),
)
]
),
]
node.keywords = call_node.keywords
node.keywords = (
[
ast.keyword(
value=ast.Attribute(
value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
attr="kwargs",
ctx=ast.Load(),
)
)
]
if self.mode == TestingMode.BEHAVIOR
else call_node.keywords
)

# Return the signature binding statements along with the test_node
return_statement = (
[bind_call, apply_defaults, test_node] if self.mode == TestingMode.BEHAVIOR else [test_node]
)
break

if call_node is None:
return None
return [test_node]
return return_statement

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
Expand Down Expand Up @@ -593,7 +741,11 @@ def inject_profiling_into_existing_test(
]
if mode == TestingMode.BEHAVIOR:
new_imports.extend(
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
[
ast.Import(names=[ast.alias(name="inspect")]),
ast.Import(names=[ast.alias(name="sqlite3")]),
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
]
)
if test_framework == "unittest" and platform.system() != "Windows":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
Expand Down
96 changes: 74 additions & 22 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,25 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self.wildcard_modules: set[str] = set()
# Track aliases: alias_name -> original_name
self.alias_mapping: dict[str, str] = {}
# Track instances: variable_name -> class_name
self.instance_mapping: dict[str, str] = {}

# Precompute function_names for prefix search
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
self._exact_names = function_names_to_find
self._prefix_roots: dict[str, list[str]] = {}
# Precompute sets for faster lookup during visit_Attribute()
self._dot_names: set[str] = set()
self._dot_methods: dict[str, set[str]] = {}
self._class_method_to_target: dict[tuple[str, str], str] = {}
for name in function_names_to_find:
if "." in name:
root = name.split(".", 1)[0]
self._prefix_roots.setdefault(root, []).append(name)
root, method = name.rsplit(".", 1)
self._dot_names.add(name)
self._dot_methods.setdefault(method, set()).add(root)
self._class_method_to_target[(root, method)] = name
root_prefix = name.split(".", 1)[0]
self._prefix_roots.setdefault(root_prefix, []).append(name)

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
Expand All @@ -247,6 +257,24 @@ def visit_Import(self, node: ast.Import) -> None:
self.found_qualified_name = target_func
return

def visit_Assign(self, node: ast.Assign) -> None:
"""Track variable assignments, especially class instantiations."""
if self.found_any_target_function:
return

# Check if the assignment is a class instantiation
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
class_name = node.value.func.id
if class_name in self.imported_modules:
# Track all target variables as instances of the imported class
for target in node.targets:
if isinstance(target, ast.Name):
# Map the variable to the actual class name (handling aliases)
original_class = self.alias_mapping.get(class_name, class_name)
self.instance_mapping[target.id] = original_class

self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
"""Handle 'from module import name' statements."""
if self.found_any_target_function:
Expand Down Expand Up @@ -287,6 +315,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.found_qualified_name = qname
return

# Check if any target function is a method of the imported class/module
# Be conservative except when an alias is used (which requires exact method matching)
for target_func in fnames:
if "." in target_func:
class_name, method_name = target_func.split(".", 1)
if aname == class_name and not alias.asname:
# If an alias is used, don't match conservatively
# The actual method usage should be detected in visit_Attribute
self.found_any_target_function = True
self.found_qualified_name = target_func
return

prefix = qname + "."
# Only bother if one of the targets startswith the prefix-root
candidates = proots.get(qname, ())
Expand All @@ -301,33 +341,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
if self.found_any_target_function:
return

# Check if this is accessing a target function through an imported module

node_value = node.value
node_attr = node.attr

# Check if this is accessing a target function through an imported module
if (
isinstance(node.value, ast.Name)
and node.value.id in self.imported_modules
and node.attr in self.function_names_to_find
isinstance(node_value, ast.Name)
and node_value.id in self.imported_modules
and node_attr in self.function_names_to_find
):
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
for target_func in self.function_names_to_find:
if "." in target_func:
class_name, method_name = target_func.rsplit(".", 1)
if node.attr == method_name:
imported_name = node.value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name == class_name:
self.found_any_target_function = True
self.found_qualified_name = target_func
return

# Check if this is accessing a target function through a dynamically imported module
# Only if we've detected dynamic imports are being used
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
roots_possible = self._dot_methods.get(node_attr)
if roots_possible:
imported_name = node_value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)]
return

# Check if this is accessing a method on an instance variable
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
class_name = self.instance_mapping[node_value.id]
roots_possible = self._dot_methods.get(node_attr)
if roots_possible and class_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)]
return

# Check for dynamic import match
if self.has_dynamic_imports and node_attr in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

self.generic_visit(node)
Expand Down
Loading
Loading