Skip to content

Commit e19f3a2

Browse files
Merge pull request #217 from codeflash-ai/proper-cleanup
cleanup concolic dirs properly, add precommit
2 parents 1879e0e + 84a4f8d commit e19f3a2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+3619
-909
lines changed

.github/workflows/pre-commit.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name: Lint
2+
on:
3+
pull_request:
4+
push:
5+
branches:
6+
- main
7+
8+
concurrency:
9+
group: ${{ github.workflow }}-${{ github.ref }}
10+
cancel-in-progress: true
11+
12+
jobs:
13+
lint:
14+
name: Run pre-commit hooks
15+
runs-on: ubuntu-latest
16+
steps:
17+
- uses: actions/checkout@v4
18+
- uses: actions/setup-python@v5
19+
- uses: pre-commit/[email protected]

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
repos:
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
rev: "v0.11.0"
4+
hooks:
5+
- id: ruff
6+
args: [--fix, --exit-non-zero-on-fix, --config=pyproject.toml]
7+
- id: ruff-format

codeflash/api/aiservice.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def make_ai_service_request(
7373
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
7474
return response
7575

76-
def optimize_python_code(
76+
def optimize_python_code( # noqa: D417
7777
self,
7878
source_code: str,
7979
dependency_code: str,
@@ -139,7 +139,7 @@ def optimize_python_code(
139139
console.rule()
140140
return []
141141

142-
def optimize_python_code_line_profiler(
142+
def optimize_python_code_line_profiler( # noqa: D417
143143
self,
144144
source_code: str,
145145
dependency_code: str,
@@ -208,7 +208,7 @@ def optimize_python_code_line_profiler(
208208
console.rule()
209209
return []
210210

211-
def log_results(
211+
def log_results( # noqa: D417
212212
self,
213213
function_trace_id: str,
214214
speedup_ratio: dict[str, float | None] | None,
@@ -240,7 +240,7 @@ def log_results(
240240
except requests.exceptions.RequestException as e:
241241
logger.exception(f"Error logging features: {e}")
242242

243-
def generate_regression_tests(
243+
def generate_regression_tests( # noqa: D417
244244
self,
245245
source_code_being_tested: str,
246246
function_to_optimize: FunctionToOptimize,
@@ -270,10 +270,9 @@ def generate_regression_tests(
270270
- Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.
271271
272272
"""
273-
assert test_framework in [
274-
"pytest",
275-
"unittest",
276-
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
273+
assert test_framework in ["pytest", "unittest"], (
274+
f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
275+
)
277276
payload = {
278277
"source_code_being_tested": source_code_being_tested,
279278
"function_to_optimize": function_to_optimize,
@@ -308,7 +307,7 @@ def generate_regression_tests(
308307
error = response.json()["error"]
309308
logger.error(f"Error generating tests: {response.status_code} - {error}")
310309
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
311-
return None
310+
return None # noqa: TRY300
312311
except Exception:
313312
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
314313
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})

codeflash/benchmarking/codeflash_trace.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sqlite3
55
import threading
66
import time
7-
from typing import Callable
7+
from typing import Any, Callable
88

99
from codeflash.picklepatch.pickle_patcher import PicklePatcher
1010

@@ -69,7 +69,7 @@ def write_function_timings(self) -> None:
6969
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
7070
"benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
7171
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
72-
self.function_calls_data
72+
self.function_calls_data,
7373
)
7474
self._connection.commit()
7575
self.function_calls_data = []
@@ -100,9 +100,10 @@ def __call__(self, func: Callable) -> Callable:
100100
The wrapped function
101101
102102
"""
103-
func_id = (func.__module__,func.__name__)
103+
func_id = (func.__module__, func.__name__)
104+
104105
@functools.wraps(func)
105-
def wrapper(*args, **kwargs):
106+
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
106107
# Initialize thread-local active functions set if it doesn't exist
107108
if not hasattr(self._thread_local, "active_functions"):
108109
self._thread_local.active_functions = set()
@@ -139,9 +140,19 @@ def wrapper(*args, **kwargs):
139140
self._thread_local.active_functions.remove(func_id)
140141
overhead_time = time.thread_time_ns() - end_time
141142
self.function_calls_data.append(
142-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
143-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
144-
overhead_time, None, None)
143+
(
144+
func.__name__,
145+
class_name,
146+
func.__module__,
147+
func.__code__.co_filename,
148+
benchmark_function_name,
149+
benchmark_module_path,
150+
benchmark_line_number,
151+
execution_time,
152+
overhead_time,
153+
None,
154+
None,
155+
)
145156
)
146157
return result
147158

@@ -155,9 +166,19 @@ def wrapper(*args, **kwargs):
155166
self._thread_local.active_functions.remove(func_id)
156167
overhead_time = time.thread_time_ns() - end_time
157168
self.function_calls_data.append(
158-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
159-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
160-
overhead_time, None, None)
169+
(
170+
func.__name__,
171+
class_name,
172+
func.__module__,
173+
func.__code__.co_filename,
174+
benchmark_function_name,
175+
benchmark_module_path,
176+
benchmark_line_number,
177+
execution_time,
178+
overhead_time,
179+
None,
180+
None,
181+
)
161182
)
162183
return result
163184
# Flush to database every 100 calls
@@ -168,12 +189,24 @@ def wrapper(*args, **kwargs):
168189
self._thread_local.active_functions.remove(func_id)
169190
overhead_time = time.thread_time_ns() - end_time
170191
self.function_calls_data.append(
171-
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
172-
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
173-
overhead_time, pickled_args, pickled_kwargs)
192+
(
193+
func.__name__,
194+
class_name,
195+
func.__module__,
196+
func.__code__.co_filename,
197+
benchmark_function_name,
198+
benchmark_module_path,
199+
benchmark_line_number,
200+
execution_time,
201+
overhead_time,
202+
pickled_args,
203+
pickled_kwargs,
204+
)
174205
)
175206
return result
207+
176208
return wrapper
177209

210+
178211
# Create a singleton instance
179212
codeflash_trace = CodeflashTrace()

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
from pathlib import Path
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Optional, Union
24

35
import isort
46
import libcst as cst
57

6-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
8+
if TYPE_CHECKING:
9+
from pathlib import Path
10+
11+
from libcst import BaseStatement, ClassDef, FlattenSentinel, FunctionDef, RemovalSentinel
12+
13+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
714

815

916
class AddDecoratorTransformer(cst.CSTTransformer):
@@ -13,57 +20,48 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
1320
self.added_codeflash_trace = False
1421
self.class_name = ""
1522
self.function_name = ""
16-
self.decorator = cst.Decorator(
17-
decorator=cst.Name(value="codeflash_trace")
18-
)
23+
self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))
1924

20-
def leave_ClassDef(self, original_node, updated_node):
25+
def leave_ClassDef(
26+
self, original_node: ClassDef, updated_node: ClassDef
27+
) -> Union[BaseStatement, FlattenSentinel[BaseStatement], RemovalSentinel]:
2128
if self.class_name == original_node.name.value:
22-
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
29+
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
2330
return updated_node
2431

25-
def visit_ClassDef(self, node):
26-
if self.class_name: # Don't go into nested class
32+
def visit_ClassDef(self, node: ClassDef) -> Optional[bool]:
33+
if self.class_name: # Don't go into nested class
2734
return False
28-
self.class_name = node.name.value
35+
self.class_name = node.name.value # noqa: RET503
2936

30-
def visit_FunctionDef(self, node):
31-
if self.function_name: # Don't go into nested function
37+
def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
38+
if self.function_name: # Don't go into nested function
3239
return False
33-
self.function_name = node.name.value
40+
self.function_name = node.name.value # noqa: RET503
3441

35-
def leave_FunctionDef(self, original_node, updated_node):
42+
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
3643
if self.function_name == original_node.name.value:
3744
self.function_name = ""
3845
if (self.class_name, original_node.name.value) in self.target_functions:
3946
# Add the new decorator after any existing decorators, so it gets executed first
40-
updated_decorators = list(updated_node.decorators) + [self.decorator]
47+
updated_decorators = [*list(updated_node.decorators), self.decorator]
4148
self.added_codeflash_trace = True
42-
return updated_node.with_changes(
43-
decorators=updated_decorators
44-
)
49+
return updated_node.with_changes(decorators=updated_decorators)
4550

4651
return updated_node
4752

48-
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
53+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
4954
# Create import statement for codeflash_trace
5055
if not self.added_codeflash_trace:
5156
return updated_node
5257
import_stmt = cst.SimpleStatementLine(
5358
body=[
5459
cst.ImportFrom(
5560
module=cst.Attribute(
56-
value=cst.Attribute(
57-
value=cst.Name(value="codeflash"),
58-
attr=cst.Name(value="benchmarking")
59-
),
60-
attr=cst.Name(value="codeflash_trace")
61+
value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")),
62+
attr=cst.Name(value="codeflash_trace"),
6163
),
62-
names=[
63-
cst.ImportAlias(
64-
name=cst.Name(value="codeflash_trace")
65-
)
66-
]
64+
names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))],
6765
)
6866
]
6967
)
@@ -73,12 +71,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
7371

7472
return updated_node.with_changes(body=new_body)
7573

74+
7675
def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
7776
"""Add codeflash_trace to a function.
7877
7978
Args:
8079
code: The source code as a string
81-
function_to_optimize: The FunctionToOptimize instance containing function details
80+
functions_to_optimize: List of FunctionToOptimize instances containing function details
8281
8382
Returns:
8483
The modified source code as a string
@@ -91,25 +90,18 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
9190
class_name = function_to_optimize.parents[0].name
9291
target_functions.add((class_name, function_to_optimize.function_name))
9392

94-
transformer = AddDecoratorTransformer(
95-
target_functions = target_functions,
96-
)
93+
transformer = AddDecoratorTransformer(target_functions=target_functions)
9794

9895
module = cst.parse_module(code)
9996
modified_module = module.visit(transformer)
10097
return modified_module.code
10198

10299

103-
def instrument_codeflash_trace_decorator(
104-
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
105-
) -> None:
100+
def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None:
106101
"""Instrument codeflash_trace decorator to functions to optimize."""
107102
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
108103
original_code = file_path.read_text(encoding="utf-8")
109-
new_code = add_codeflash_decorator_to_code(
110-
original_code,
111-
functions_to_optimize
112-
)
104+
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
113105
# Modify the code
114106
modified_code = isort.code(code=new_code, float_to_top=True)
115107

0 commit comments

Comments
 (0)