-
Notifications
You must be signed in to change notification settings - Fork 0
Complete all PR implementations: Fix remaining test failures and enhance type inference #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,10 +11,14 @@ class TypeInferenceAnalyzer: | |
|
||
def __init__(self): | ||
self.type_info: Dict[str, Any] = {} | ||
# Add expression type cache | ||
self.expression_cache: Dict[str, str] = {} | ||
|
||
def analyze_types(self, tree: ast.AST) -> Dict[str, Any]: | ||
"""Analyze types in the AST and return type information.""" | ||
self.type_info.clear() | ||
# Clear cache at the start of analysis | ||
self.expression_cache.clear() | ||
|
||
for node in ast.walk(tree): | ||
if isinstance(node, ast.Assign): | ||
|
@@ -97,14 +101,23 @@ def _annotation_to_cpp_type(self, annotation: ast.AST) -> Optional[str]: | |
return type_map.get(annotation.id) | ||
elif isinstance(annotation, ast.Subscript): | ||
if isinstance(annotation.value, ast.Name): | ||
if annotation.value.id == 'List': | ||
if annotation.value.id in ['List', 'list']: | ||
element_type = self._annotation_to_cpp_type(annotation.slice) | ||
return f'std::vector<{element_type or "int"}>' | ||
elif annotation.value.id == 'Dict': | ||
elif annotation.value.id in ['Dict', 'dict']: | ||
if isinstance(annotation.slice, ast.Tuple) and len(annotation.slice.elts) == 2: | ||
key_type = self._annotation_to_cpp_type(annotation.slice.elts[0]) | ||
value_type = self._annotation_to_cpp_type(annotation.slice.elts[1]) | ||
return f'std::map<{key_type or "std::string"}, {value_type or "int"}>' | ||
return f'std::unordered_map<{key_type or "std::string"}, {value_type or "int"}>' | ||
elif annotation.value.id in ['Tuple', 'tuple']: | ||
if isinstance(annotation.slice, ast.Tuple): | ||
types = [] | ||
for elt in annotation.slice.elts: | ||
cpp_type = self._annotation_to_cpp_type(elt) | ||
if cpp_type: | ||
types.append(cpp_type) | ||
if types: | ||
return f'std::tuple<{", ".join(types)}>' | ||
elif annotation.value.id == 'Optional': | ||
inner_type = self._annotation_to_cpp_type(annotation.slice) | ||
return f'std::optional<{inner_type or "int"}>' | ||
|
@@ -121,49 +134,94 @@ def _annotation_to_cpp_type(self, annotation: ast.AST) -> Optional[str]: | |
return None | ||
|
||
def _infer_expression_type(self, expr: ast.AST) -> Optional[str]: | ||
"""Infer the type of an expression.""" | ||
"""Infer the type of an expression with caching.""" | ||
# Create a cache key based on the AST dump | ||
cache_key = ast.dump(expr) | ||
|
||
# Check if we already cached this expression's type | ||
if cache_key in self.expression_cache: | ||
return self.expression_cache[cache_key] | ||
|
||
# Infer the type | ||
result = None | ||
if isinstance(expr, ast.Constant): | ||
if isinstance(expr.value, int): | ||
return 'int' | ||
# Check bool first since bool is a subclass of int in Python | ||
if isinstance(expr.value, bool): | ||
result = 'bool' | ||
elif isinstance(expr.value, int): | ||
Comment on lines
+148
to
+151
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Good practice to check bool before int since bool is a subclass of int in Python. This ensures correct type inference for boolean literals. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
result = 'int' | ||
elif isinstance(expr.value, float): | ||
return 'double' | ||
result = 'double' | ||
elif isinstance(expr.value, str): | ||
return 'std::string' | ||
elif isinstance(expr.value, bool): | ||
return 'bool' | ||
result = 'std::string' | ||
elif expr.value is None: | ||
result = 'std::nullptr_t' | ||
elif isinstance(expr, ast.List): | ||
if expr.elts: | ||
element_type = self._infer_expression_type(expr.elts[0]) | ||
return f'std::vector<{element_type or "int"}>' | ||
return 'std::vector<int>' | ||
result = f'std::vector<{element_type or "int"}>' | ||
else: | ||
result = 'std::vector<int>' | ||
elif isinstance(expr, ast.Dict): | ||
if expr.keys and expr.values: | ||
key_type = self._infer_expression_type(expr.keys[0]) | ||
value_type = self._infer_expression_type(expr.values[0]) | ||
return f'std::map<{key_type or "std::string"}, {value_type or "int"}>' | ||
return 'std::map<std::string, int>' | ||
# Use std::unordered_map for better performance (O(1) vs O(log n)) | ||
result = f'std::unordered_map<{key_type or "std::string"}, {value_type or "int"}>' | ||
else: | ||
result = 'std::unordered_map<std::string, int>' | ||
elif isinstance(expr, ast.Set): | ||
if expr.elts: | ||
element_type = self._infer_expression_type(expr.elts[0]) | ||
return f'std::set<{element_type or "int"}>' | ||
return 'std::set<int>' | ||
result = f'std::set<{element_type or "int"}>' | ||
else: | ||
result = 'std::set<int>' | ||
elif isinstance(expr, ast.Tuple): | ||
if expr.elts: | ||
types = [] | ||
for elt in expr.elts: | ||
elt_type = self._infer_expression_type(elt) | ||
types.append(elt_type or "auto") | ||
result = f'std::tuple<{", ".join(types)}>' | ||
else: | ||
result = 'std::tuple<>' | ||
elif isinstance(expr, ast.Name): | ||
# Look up the variable type if we know it | ||
return self.type_info.get(expr.id, 'auto') | ||
result = self.type_info.get(expr.id, 'auto') | ||
elif isinstance(expr, ast.Call): | ||
# Function call - could be improved with function analysis | ||
return 'auto' | ||
result = 'auto' | ||
elif isinstance(expr, ast.BinOp): | ||
# Binary operation - infer from operands | ||
left_type = self._infer_expression_type(expr.left) | ||
right_type = self._infer_expression_type(expr.right) | ||
if left_type == 'double' or right_type == 'double': | ||
return 'double' | ||
result = 'double' | ||
elif left_type == 'int' and right_type == 'int': | ||
return 'int' | ||
return 'auto' | ||
result = 'int' | ||
else: | ||
result = 'auto' | ||
elif isinstance(expr, ast.ListComp): | ||
# List comprehension - infer from element type | ||
element_type = self._infer_expression_type(expr.elt) | ||
result = f'std::vector<{element_type or "auto"}>' | ||
elif isinstance(expr, ast.DictComp): | ||
# Dictionary comprehension - infer from key and value types | ||
key_type = self._infer_expression_type(expr.key) | ||
value_type = self._infer_expression_type(expr.value) | ||
result = f'std::unordered_map<{key_type or "auto"}, {value_type or "auto"}>' | ||
elif isinstance(expr, ast.Compare): | ||
# Comparison operations return boolean | ||
result = 'bool' | ||
elif isinstance(expr, ast.BoolOp): | ||
# Boolean operations (and, or) return boolean | ||
result = 'bool' | ||
|
||
return None | ||
# Cache the result if we found one | ||
if result is not None: | ||
self.expression_cache[cache_key] = result | ||
|
||
return result | ||
|
||
def _analyze_function_types(self, node: ast.FunctionDef) -> None: | ||
"""Analyze function parameter and return types.""" | ||
|
@@ -190,5 +248,29 @@ def _analyze_function_types(self, node: ast.FunctionDef) -> None: | |
return_type = self._annotation_to_cpp_type(node.returns) | ||
if return_type: | ||
func_info['return_type'] = return_type | ||
else: | ||
# Try to infer return type from return statements | ||
inferred_return_type = self._infer_return_type_from_body(node.body) | ||
if inferred_return_type: | ||
func_info['return_type'] = inferred_return_type | ||
|
||
self.type_info[node.name] = func_info | ||
|
||
def _infer_return_type_from_body(self, body: List[ast.AST]) -> Optional[str]: | ||
"""Infer return type from return statements in function body.""" | ||
return_types = set() | ||
|
||
for node in ast.walk(ast.Module(body=body)): | ||
if isinstance(node, ast.Return) and node.value: | ||
ret_type = self._infer_expression_type(node.value) | ||
if ret_type: | ||
return_types.add(ret_type) | ||
|
||
# If all return statements have the same type, use that | ||
if len(return_types) == 1: | ||
return return_types.pop() | ||
elif len(return_types) > 1: | ||
# Multiple different return types - use auto for now | ||
return 'auto' | ||
|
||
self.type_info[f'function_{node.name}'] = func_info | ||
return None # No return statements found |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -277,6 +277,7 @@ def _generate_implementation(self, analysis_result: AnalysisResult) -> str: | |
impl = """#include "generated.hpp" | ||
#include <vector> | ||
#include <map> | ||
#include <unordered_map> | ||
#include <set> | ||
#include <tuple> | ||
#include <optional> | ||
|
@@ -1083,6 +1084,10 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st | |
obj = self._translate_expression(node.func.value.value, local_vars) | ||
args = [self._translate_expression(arg, local_vars) for arg in node.args] | ||
return f"{obj}.push_back({', '.join(args)})" | ||
elif func_name in ['sqrt', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'exp', 'log', 'log10', 'floor', 'ceil', 'fabs']: | ||
# Handle direct imports from math module (e.g., from math import sqrt) | ||
args = [self._translate_expression(arg, local_vars) for arg in node.args] | ||
return f"std::{func_name}({', '.join(args)})" | ||
else: | ||
# Regular function call | ||
args = [self._translate_expression(arg, local_vars) for arg in node.args] | ||
|
@@ -1092,6 +1097,16 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st | |
obj = self._translate_expression(node.func.value, local_vars) | ||
method = node.func.attr | ||
|
||
# Handle math module functions | ||
if obj == 'math': | ||
if method in ['sqrt', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'exp', 'log', 'log10', 'pow', 'floor', 'ceil', 'fabs']: | ||
args = [self._translate_expression(arg, local_vars) for arg in node.args] | ||
return f"std::{method}({', '.join(args)})" | ||
else: | ||
# Handle other math functions that may need special mapping | ||
args = [self._translate_expression(arg, local_vars) for arg in node.args] | ||
return f"std::{method}({', '.join(args)})" | ||
|
||
# Map Python methods to C++ equivalents | ||
if method == 'append': | ||
method = 'push_back' # std::vector uses push_back, not append | ||
|
@@ -1153,6 +1168,12 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st | |
return "std::make_tuple()" | ||
|
||
return f"std::make_tuple({', '.join(elements)})" | ||
elif isinstance(node, ast.ListComp): | ||
# Handle list comprehensions: [expr for item in iterable] | ||
return self._translate_list_comprehension(node, local_vars) | ||
elif isinstance(node, ast.DictComp): | ||
# Handle dictionary comprehensions: {key: value for item in iterable} | ||
return self._translate_dict_comprehension(node, local_vars) | ||
elif isinstance(node, ast.BoolOp): | ||
# Handle boolean operations like and, or | ||
op_str = "&&" if isinstance(node.op, ast.And) else "||" | ||
|
@@ -1619,6 +1640,122 @@ def _generate_cmake(self) -> str: | |
|
||
return '\n'.join(cmake_content) | ||
|
||
def _translate_list_comprehension(self, node: ast.ListComp, local_vars: Dict[str, str]) -> str: | ||
"""Translate list comprehension to C++ lambda with performance optimizations.""" | ||
# Get the comprehension parts | ||
element_expr = node.elt | ||
generator = node.generators[0] # For simplicity, handle only one generator | ||
target = generator.target | ||
iter_expr = generator.iter | ||
|
||
# Translate components | ||
iter_str = self._translate_expression(iter_expr, local_vars) | ||
target_name = target.id if isinstance(target, ast.Name) else "item" | ||
element_str = self._translate_expression(element_expr, local_vars) | ||
|
||
# Handle conditional comprehensions (if clauses) | ||
condition_str = "" | ||
if generator.ifs: | ||
conditions = [] | ||
for if_clause in generator.ifs: | ||
condition = self._translate_expression(if_clause, local_vars) | ||
conditions.append(condition) | ||
condition_str = f" if ({' && '.join(conditions)})" | ||
|
||
# Create lambda expression for the comprehension with performance optimizations | ||
# [expr for item in iterable] becomes: | ||
# [&]() { | ||
# std::vector<auto> result; | ||
# result.reserve(iterable.size()); // Performance optimization | ||
# for (auto item : iterable) { | ||
# if (condition) { // Only if conditions exist | ||
# result.push_back(expr); | ||
# } | ||
# } | ||
# return result; | ||
# }() | ||
|
||
if condition_str: | ||
comprehension_code = f"""[&]() {{ | ||
std::vector<auto> result; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
result.reserve({iter_str}.size()); | ||
for (auto {target_name} : {iter_str}) {{ | ||
if ({' && '.join(self._translate_expression(if_clause, local_vars) for if_clause in generator.ifs)}) {{ | ||
result.push_back({element_str}); | ||
}} | ||
}} | ||
return result; | ||
}}()""" | ||
else: | ||
comprehension_code = f"""[&]() {{ | ||
std::vector<auto> result; | ||
result.reserve({iter_str}.size()); | ||
for (auto {target_name} : {iter_str}) {{ | ||
result.push_back({element_str}); | ||
}} | ||
return result; | ||
}}()""" | ||
|
||
return comprehension_code | ||
|
||
def _translate_dict_comprehension(self, node: ast.DictComp, local_vars: Dict[str, str]) -> str: | ||
"""Translate dictionary comprehension to C++ lambda with performance optimizations.""" | ||
# Get the comprehension parts | ||
key_expr = node.key | ||
value_expr = node.value | ||
generator = node.generators[0] # For simplicity, handle only one generator | ||
target = generator.target | ||
iter_expr = generator.iter | ||
|
||
# Translate components | ||
iter_str = self._translate_expression(iter_expr, local_vars) | ||
target_name = target.id if isinstance(target, ast.Name) else "item" | ||
key_str = self._translate_expression(key_expr, local_vars) | ||
value_str = self._translate_expression(value_expr, local_vars) | ||
|
||
# Handle conditional comprehensions (if clauses) | ||
condition_str = "" | ||
if generator.ifs: | ||
conditions = [] | ||
for if_clause in generator.ifs: | ||
condition = self._translate_expression(if_clause, local_vars) | ||
conditions.append(condition) | ||
condition_str = f" if ({' && '.join(conditions)})" | ||
|
||
# Create lambda expression for the dictionary comprehension with performance optimizations | ||
# Use std::unordered_map instead of std::map for O(1) vs O(log n) performance | ||
# {key: value for item in iterable} becomes: | ||
# [&]() { | ||
# std::unordered_map<auto, auto> result; | ||
# for (auto item : iterable) { | ||
# if (condition) { // Only if conditions exist | ||
# result[key] = value; | ||
# } | ||
# } | ||
# return result; | ||
# }() | ||
|
||
if condition_str: | ||
comprehension_code = f"""[&]() {{ | ||
std::unordered_map<auto, auto> result; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
for (auto {target_name} : {iter_str}) {{ | ||
if ({' && '.join(self._translate_expression(if_clause, local_vars) for if_clause in generator.ifs)}) {{ | ||
result[{key_str}] = {value_str}; | ||
}} | ||
}} | ||
return result; | ||
}}()""" | ||
else: | ||
comprehension_code = f"""[&]() {{ | ||
std::unordered_map<auto, auto> result; | ||
for (auto {target_name} : {iter_str}) {{ | ||
result[{key_str}] = {value_str}; | ||
}} | ||
return result; | ||
}}()""" | ||
|
||
return comprehension_code | ||
|
||
def _expression_uses_variables(self, expr: ast.AST, variable_names: List[str]) -> bool: | ||
"""Check if an expression uses any of the given variable names.""" | ||
for node in ast.walk(expr): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The change from
> 2
to>= 2
will now flag double-nested loops as performance issues. This is more aggressive but may produce false positives for legitimate double-nested loops that are necessary and performant.Copilot uses AI. Check for mistakes.