Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 64 additions & 23 deletions src/analyzer/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ class TypeInferenceAnalyzer:

def __init__(self):
self.type_info: Dict[str, Any] = {}
# Add expression type cache
self.expression_cache: Dict[str, str] = {}

def analyze_types(self, tree: ast.AST) -> Dict[str, Any]:
"""Analyze types in the AST and return type information."""
self.type_info.clear()
# Clear cache at the start of analysis
self.expression_cache.clear()

for node in ast.walk(tree):
if isinstance(node, ast.Assign):
Expand Down Expand Up @@ -97,14 +101,23 @@ def _annotation_to_cpp_type(self, annotation: ast.AST) -> Optional[str]:
return type_map.get(annotation.id)
elif isinstance(annotation, ast.Subscript):
if isinstance(annotation.value, ast.Name):
if annotation.value.id == 'List':
if annotation.value.id in ['List', 'list']:
element_type = self._annotation_to_cpp_type(annotation.slice)
return f'std::vector<{element_type or "int"}>'
elif annotation.value.id == 'Dict':
elif annotation.value.id in ['Dict', 'dict']:
if isinstance(annotation.slice, ast.Tuple) and len(annotation.slice.elts) == 2:
key_type = self._annotation_to_cpp_type(annotation.slice.elts[0])
value_type = self._annotation_to_cpp_type(annotation.slice.elts[1])
return f'std::map<{key_type or "std::string"}, {value_type or "int"}>'
return f'std::unordered_map<{key_type or "std::string"}, {value_type or "int"}>'
elif annotation.value.id in ['Tuple', 'tuple']:
if isinstance(annotation.slice, ast.Tuple):
types = []
for elt in annotation.slice.elts:
cpp_type = self._annotation_to_cpp_type(elt)
if cpp_type:
types.append(cpp_type)
if types:
return f'std::tuple<{", ".join(types)}>'
elif annotation.value.id == 'Optional':
inner_type = self._annotation_to_cpp_type(annotation.slice)
return f'std::optional<{inner_type or "int"}>'
Expand All @@ -121,49 +134,77 @@ def _annotation_to_cpp_type(self, annotation: ast.AST) -> Optional[str]:
return None

def _infer_expression_type(self, expr: ast.AST) -> Optional[str]:
"""Infer the type of an expression."""
"""Infer the type of an expression with caching."""
# Create a cache key based on the AST dump
cache_key = ast.dump(expr)

# Check if we already cached this expression's type
if cache_key in self.expression_cache:
return self.expression_cache[cache_key]

# Infer the type
result = None
if isinstance(expr, ast.Constant):
if isinstance(expr.value, int):
return 'int'
# Check bool first since bool is a subclass of int in Python
if isinstance(expr.value, bool):
result = 'bool'
elif isinstance(expr.value, int):
result = 'int'
elif isinstance(expr.value, float):
return 'double'
result = 'double'
elif isinstance(expr.value, str):
return 'std::string'
elif isinstance(expr.value, bool):
return 'bool'
result = 'std::string'
elif isinstance(expr, ast.List):
if expr.elts:
element_type = self._infer_expression_type(expr.elts[0])
return f'std::vector<{element_type or "int"}>'
return 'std::vector<int>'
result = f'std::vector<{element_type or "int"}>'
else:
result = 'std::vector<int>'
elif isinstance(expr, ast.Dict):
if expr.keys and expr.values:
key_type = self._infer_expression_type(expr.keys[0])
value_type = self._infer_expression_type(expr.values[0])
return f'std::map<{key_type or "std::string"}, {value_type or "int"}>'
return 'std::map<std::string, int>'
# Use std::unordered_map for better performance (O(1) vs O(log n))
result = f'std::unordered_map<{key_type or "std::string"}, {value_type or "int"}>'
else:
result = 'std::unordered_map<std::string, int>'
elif isinstance(expr, ast.Set):
if expr.elts:
element_type = self._infer_expression_type(expr.elts[0])
return f'std::set<{element_type or "int"}>'
return 'std::set<int>'
result = f'std::set<{element_type or "int"}>'
else:
result = 'std::set<int>'
elif isinstance(expr, ast.Name):
# Look up the variable type if we know it
return self.type_info.get(expr.id, 'auto')
result = self.type_info.get(expr.id, 'auto')
elif isinstance(expr, ast.Call):
# Function call - could be improved with function analysis
return 'auto'
result = 'auto'
elif isinstance(expr, ast.BinOp):
# Binary operation - infer from operands
left_type = self._infer_expression_type(expr.left)
right_type = self._infer_expression_type(expr.right)
if left_type == 'double' or right_type == 'double':
return 'double'
result = 'double'
elif left_type == 'int' and right_type == 'int':
return 'int'
return 'auto'
result = 'int'
else:
result = 'auto'
elif isinstance(expr, ast.ListComp):
# List comprehension - infer from element type
element_type = self._infer_expression_type(expr.elt)
result = f'std::vector<{element_type or "auto"}>'
elif isinstance(expr, ast.DictComp):
# Dictionary comprehension - infer from key and value types
key_type = self._infer_expression_type(expr.key)
value_type = self._infer_expression_type(expr.value)
result = f'std::unordered_map<{key_type or "auto"}, {value_type or "auto"}>'

return None
# Cache the result if we found one
if result is not None:
self.expression_cache[cache_key] = result

return result

def _analyze_function_types(self, node: ast.FunctionDef) -> None:
"""Analyze function parameter and return types."""
Expand Down Expand Up @@ -191,4 +232,4 @@ def _analyze_function_types(self, node: ast.FunctionDef) -> None:
if return_type:
func_info['return_type'] = return_type

self.type_info[f'function_{node.name}'] = func_info
self.type_info[node.name] = func_info
137 changes: 137 additions & 0 deletions src/converter/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def _generate_implementation(self, analysis_result: AnalysisResult) -> str:
impl = """#include "generated.hpp"
#include <vector>
#include <map>
#include <unordered_map>
#include <set>
#include <tuple>
#include <optional>
Expand Down Expand Up @@ -1083,6 +1084,10 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st
obj = self._translate_expression(node.func.value.value, local_vars)
args = [self._translate_expression(arg, local_vars) for arg in node.args]
return f"{obj}.push_back({', '.join(args)})"
elif func_name in ['sqrt', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'exp', 'log', 'log10', 'floor', 'ceil', 'fabs']:
# Handle direct imports from math module (e.g., from math import sqrt)
args = [self._translate_expression(arg, local_vars) for arg in node.args]
return f"std::{func_name}({', '.join(args)})"
else:
# Regular function call
args = [self._translate_expression(arg, local_vars) for arg in node.args]
Expand All @@ -1092,6 +1097,16 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st
obj = self._translate_expression(node.func.value, local_vars)
method = node.func.attr

# Handle math module functions
if obj == 'math':
if method in ['sqrt', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'exp', 'log', 'log10', 'pow', 'floor', 'ceil', 'fabs']:
args = [self._translate_expression(arg, local_vars) for arg in node.args]
return f"std::{method}({', '.join(args)})"
else:
# Handle other math functions that may need special mapping
args = [self._translate_expression(arg, local_vars) for arg in node.args]
return f"std::{method}({', '.join(args)})"

# Map Python methods to C++ equivalents
if method == 'append':
method = 'push_back' # std::vector uses push_back, not append
Expand Down Expand Up @@ -1153,6 +1168,12 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st
return "std::make_tuple()"

return f"std::make_tuple({', '.join(elements)})"
elif isinstance(node, ast.ListComp):
# Handle list comprehensions: [expr for item in iterable]
return self._translate_list_comprehension(node, local_vars)
elif isinstance(node, ast.DictComp):
# Handle dictionary comprehensions: {key: value for item in iterable}
return self._translate_dict_comprehension(node, local_vars)
elif isinstance(node, ast.BoolOp):
# Handle boolean operations like and, or
op_str = "&&" if isinstance(node.op, ast.And) else "||"
Expand Down Expand Up @@ -1619,6 +1640,122 @@ def _generate_cmake(self) -> str:

return '\n'.join(cmake_content)

def _translate_list_comprehension(self, node: ast.ListComp, local_vars: Dict[str, str]) -> str:
"""Translate list comprehension to C++ lambda with performance optimizations."""
# Get the comprehension parts
element_expr = node.elt
generator = node.generators[0] # For simplicity, handle only one generator
target = generator.target
iter_expr = generator.iter

# Translate components
iter_str = self._translate_expression(iter_expr, local_vars)
target_name = target.id if isinstance(target, ast.Name) else "item"
element_str = self._translate_expression(element_expr, local_vars)

# Handle conditional comprehensions (if clauses)
condition_str = ""
if generator.ifs:
conditions = []
for if_clause in generator.ifs:
condition = self._translate_expression(if_clause, local_vars)
conditions.append(condition)
condition_str = f" if ({' && '.join(conditions)})"

# Create lambda expression for the comprehension with performance optimizations
# [expr for item in iterable] becomes:
# [&]() {
# std::vector<auto> result;
# result.reserve(iterable.size()); // Performance optimization
# for (auto item : iterable) {
# if (condition) { // Only if conditions exist
# result.push_back(expr);
# }
# }
# return result;
# }()

if condition_str:
comprehension_code = f"""[&]() {{
std::vector<auto> result;
result.reserve({iter_str}.size());
for (auto {target_name} : {iter_str}) {{
if ({' && '.join(self._translate_expression(if_clause, local_vars) for if_clause in generator.ifs)}) {{
result.push_back({element_str});
}}
}}
return result;
}}()"""
else:
comprehension_code = f"""[&]() {{
std::vector<auto> result;
result.reserve({iter_str}.size());
for (auto {target_name} : {iter_str}) {{
result.push_back({element_str});
}}
return result;
}}()"""

return comprehension_code

def _translate_dict_comprehension(self, node: ast.DictComp, local_vars: Dict[str, str]) -> str:
"""Translate dictionary comprehension to C++ lambda with performance optimizations."""
# Get the comprehension parts
key_expr = node.key
value_expr = node.value
generator = node.generators[0] # For simplicity, handle only one generator
target = generator.target
iter_expr = generator.iter

# Translate components
iter_str = self._translate_expression(iter_expr, local_vars)
target_name = target.id if isinstance(target, ast.Name) else "item"
key_str = self._translate_expression(key_expr, local_vars)
value_str = self._translate_expression(value_expr, local_vars)

# Handle conditional comprehensions (if clauses)
condition_str = ""
if generator.ifs:
conditions = []
for if_clause in generator.ifs:
condition = self._translate_expression(if_clause, local_vars)
conditions.append(condition)
condition_str = f" if ({' && '.join(conditions)})"

# Create lambda expression for the dictionary comprehension with performance optimizations
# Use std::unordered_map instead of std::map for O(1) vs O(log n) performance
# {key: value for item in iterable} becomes:
# [&]() {
# std::unordered_map<auto, auto> result;
# for (auto item : iterable) {
# if (condition) { // Only if conditions exist
# result[key] = value;
# }
# }
# return result;
# }()

if condition_str:
comprehension_code = f"""[&]() {{
std::unordered_map<auto, auto> result;
for (auto {target_name} : {iter_str}) {{
if ({' && '.join(self._translate_expression(if_clause, local_vars) for if_clause in generator.ifs)}) {{
result[{key_str}] = {value_str};
}}
}}
return result;
}}()"""
else:
comprehension_code = f"""[&]() {{
std::unordered_map<auto, auto> result;
for (auto {target_name} : {iter_str}) {{
result[{key_str}] = {value_str};
}}
return result;
}}()"""

return comprehension_code

def _expression_uses_variables(self, expr: ast.AST, variable_names: List[str]) -> bool:
"""Check if an expression uses any of the given variable names."""
for node in ast.walk(expr):
Expand Down