From cef989189ad4fd571ecc89b90124e61426c33788 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Mon, 10 Nov 2025 14:15:37 -0800 Subject: [PATCH] squash --- codeflash/optimization/function_optimizer.py | 56 ++++++++++++++++++++ uv.lock | 11 ++++ 2 files changed, 67 insertions(+) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5f4ab8767..e1b9d821d 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -104,6 +104,7 @@ if TYPE_CHECKING: from argparse import Namespace + from re import Pattern from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result @@ -413,6 +414,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: function_references, ) = test_setup_result.unwrap() + self.remove_failing_tests() + console.rule() + baseline_setup_result = self.setup_and_establish_baseline( code_context=code_context, original_helper_code=original_helper_code, @@ -1557,6 +1561,58 @@ def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) + def remove_failing_tests(self) -> None: + from codeflash.code_utils.edit_generated_tests import _compile_function_patterns + + test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1) + behavioral_results, _ = self.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=self.test_files, + optimization_iteration=0, + testing_time=TOTAL_LOOPING_TIME_EFFECTIVE, + enable_coverage=False, + ) + + failing_test_names = [ + result.id.test_function_name + for result in behavioral_results + if (result.test_type == TestType.GENERATED_REGRESSION and not result.did_pass) + ] + + if not failing_test_names: + logger.info("All generated tests pass ✅") + + function_patterns = _compile_function_patterns(failing_test_names) + + for test_file in self.test_files.test_files: + if test_file.test_type != TestType.GENERATED_REGRESSION: + continue + + for file_path in [test_file.instrumented_behavior_file_path, test_file.benchmarking_file_path]: + if file_path and file_path.exists(): + source = file_path.read_text(encoding="utf-8") + source = self.remove_tests_from_source(source, function_patterns) + file_path.write_text(source, encoding="utf-8") + + if test_file.original_source: + test_file.original_source = self.remove_tests_from_source(test_file.original_source, function_patterns) + + if failing_test_names: + logger.info(f"Removed {len(failing_test_names)} failing generated test(s) ❌") + + def remove_tests_from_source(self, source: str, function_patterns: list[Pattern]) -> str: + for pattern in function_patterns: + match = pattern.search(source) + while match: + if "@pytest.mark.parametrize" in match.group(0): + match = pattern.search(source, match.end()) + continue + start, end = match.span() + source = source[:start] + source[end:] + match = pattern.search(source, start) + return source + def establish_original_code_baseline( self, code_context: CodeOptimizationContext, diff --git a/uv.lock b/uv.lock index 0d99bdf15..96c6b9f98 100644 --- a/uv.lock +++ b/uv.lock @@ -372,6 +372,7 @@ dev = [ ] tests = [ { name = "black" }, + { name = "eval-type-backport" }, { name = "jax", version = "0.4.30", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "jax", version = "0.6.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" }, { name = "jax", version = "0.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -451,6 +452,7 @@ dev = [ ] tests = [ { name = "black", specifier = ">=25.9.0" }, + { name = "eval-type-backport" }, { name = "jax", specifier = ">=0.4.30" }, { name = "numpy", specifier = ">=2.0.2" }, { name = "pandas", specifier = ">=2.3.3" }, @@ -699,6 +701,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/c2/4bc8cd09b14e28ce3f406a8b05761bed0d785d1ca8c2a5c6684d884c66a2/editor-1.6.6-py3-none-any.whl", hash = "sha256:e818e6913f26c2a81eadef503a2741d7cca7f235d20e217274a009ecd5a74abf", size = 4017, upload-time = "2024-01-25T10:44:58.66Z" }, ] +[[package]] +name = "eval-type-backport" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/ea/8b0ac4469d4c347c6a385ff09dc3c048c2d021696664e26c7ee6791631b5/eval_type_backport-0.2.2.tar.gz", hash = "sha256:f0576b4cf01ebb5bd358d02314d31846af5e07678387486e2c798af0e7d849c1", size = 9079, upload-time = "2024-12-21T20:09:46.005Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/31/55cd413eaccd39125368be33c46de24a1f639f2e12349b0361b4678f3915/eval_type_backport-0.2.2-py3-none-any.whl", hash = "sha256:cb6ad7c393517f476f96d456d0412ea80f0a8cf96f6892834cd9340149111b0a", size = 5830, upload-time = "2024-12-21T20:09:44.175Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.0"