Skip to content

Commit d8879b5

Browse files
CopilotCrazyDubya
andauthored
[WIP] there are numerous PR with merge conflcits work through and reoslve them all... test code at end make sure fucntional (#12)
* Initial plan * Implement PRs #4, #6, #9, #10: Math functions, comprehensions, caching, performance optimizations Co-authored-by: CrazyDubya <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: CrazyDubya <[email protected]>
1 parent 9af1409 commit d8879b5

File tree

2 files changed

+201
-23
lines changed

2 files changed

+201
-23
lines changed

src/analyzer/type_inference.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@ class TypeInferenceAnalyzer:
1111

1212
def __init__(self):
1313
self.type_info: Dict[str, Any] = {}
14+
# Add expression type cache
15+
self.expression_cache: Dict[str, str] = {}
1416

1517
def analyze_types(self, tree: ast.AST) -> Dict[str, Any]:
1618
"""Analyze types in the AST and return type information."""
1719
self.type_info.clear()
20+
# Clear cache at the start of analysis
21+
self.expression_cache.clear()
1822

1923
for node in ast.walk(tree):
2024
if isinstance(node, ast.Assign):
@@ -97,14 +101,23 @@ def _annotation_to_cpp_type(self, annotation: ast.AST) -> Optional[str]:
97101
return type_map.get(annotation.id)
98102
elif isinstance(annotation, ast.Subscript):
99103
if isinstance(annotation.value, ast.Name):
100-
if annotation.value.id == 'List':
104+
if annotation.value.id in ['List', 'list']:
101105
element_type = self._annotation_to_cpp_type(annotation.slice)
102106
return f'std::vector<{element_type or "int"}>'
103-
elif annotation.value.id == 'Dict':
107+
elif annotation.value.id in ['Dict', 'dict']:
104108
if isinstance(annotation.slice, ast.Tuple) and len(annotation.slice.elts) == 2:
105109
key_type = self._annotation_to_cpp_type(annotation.slice.elts[0])
106110
value_type = self._annotation_to_cpp_type(annotation.slice.elts[1])
107-
return f'std::map<{key_type or "std::string"}, {value_type or "int"}>'
111+
return f'std::unordered_map<{key_type or "std::string"}, {value_type or "int"}>'
112+
elif annotation.value.id in ['Tuple', 'tuple']:
113+
if isinstance(annotation.slice, ast.Tuple):
114+
types = []
115+
for elt in annotation.slice.elts:
116+
cpp_type = self._annotation_to_cpp_type(elt)
117+
if cpp_type:
118+
types.append(cpp_type)
119+
if types:
120+
return f'std::tuple<{", ".join(types)}>'
108121
elif annotation.value.id == 'Optional':
109122
inner_type = self._annotation_to_cpp_type(annotation.slice)
110123
return f'std::optional<{inner_type or "int"}>'
@@ -121,49 +134,77 @@ def _annotation_to_cpp_type(self, annotation: ast.AST) -> Optional[str]:
121134
return None
122135

123136
def _infer_expression_type(self, expr: ast.AST) -> Optional[str]:
124-
"""Infer the type of an expression."""
137+
"""Infer the type of an expression with caching."""
138+
# Create a cache key based on the AST dump
139+
cache_key = ast.dump(expr)
140+
141+
# Check if we already cached this expression's type
142+
if cache_key in self.expression_cache:
143+
return self.expression_cache[cache_key]
144+
145+
# Infer the type
146+
result = None
125147
if isinstance(expr, ast.Constant):
126-
if isinstance(expr.value, int):
127-
return 'int'
148+
# Check bool first since bool is a subclass of int in Python
149+
if isinstance(expr.value, bool):
150+
result = 'bool'
151+
elif isinstance(expr.value, int):
152+
result = 'int'
128153
elif isinstance(expr.value, float):
129-
return 'double'
154+
result = 'double'
130155
elif isinstance(expr.value, str):
131-
return 'std::string'
132-
elif isinstance(expr.value, bool):
133-
return 'bool'
156+
result = 'std::string'
134157
elif isinstance(expr, ast.List):
135158
if expr.elts:
136159
element_type = self._infer_expression_type(expr.elts[0])
137-
return f'std::vector<{element_type or "int"}>'
138-
return 'std::vector<int>'
160+
result = f'std::vector<{element_type or "int"}>'
161+
else:
162+
result = 'std::vector<int>'
139163
elif isinstance(expr, ast.Dict):
140164
if expr.keys and expr.values:
141165
key_type = self._infer_expression_type(expr.keys[0])
142166
value_type = self._infer_expression_type(expr.values[0])
143-
return f'std::map<{key_type or "std::string"}, {value_type or "int"}>'
144-
return 'std::map<std::string, int>'
167+
# Use std::unordered_map for better performance (O(1) vs O(log n))
168+
result = f'std::unordered_map<{key_type or "std::string"}, {value_type or "int"}>'
169+
else:
170+
result = 'std::unordered_map<std::string, int>'
145171
elif isinstance(expr, ast.Set):
146172
if expr.elts:
147173
element_type = self._infer_expression_type(expr.elts[0])
148-
return f'std::set<{element_type or "int"}>'
149-
return 'std::set<int>'
174+
result = f'std::set<{element_type or "int"}>'
175+
else:
176+
result = 'std::set<int>'
150177
elif isinstance(expr, ast.Name):
151178
# Look up the variable type if we know it
152-
return self.type_info.get(expr.id, 'auto')
179+
result = self.type_info.get(expr.id, 'auto')
153180
elif isinstance(expr, ast.Call):
154181
# Function call - could be improved with function analysis
155-
return 'auto'
182+
result = 'auto'
156183
elif isinstance(expr, ast.BinOp):
157184
# Binary operation - infer from operands
158185
left_type = self._infer_expression_type(expr.left)
159186
right_type = self._infer_expression_type(expr.right)
160187
if left_type == 'double' or right_type == 'double':
161-
return 'double'
188+
result = 'double'
162189
elif left_type == 'int' and right_type == 'int':
163-
return 'int'
164-
return 'auto'
190+
result = 'int'
191+
else:
192+
result = 'auto'
193+
elif isinstance(expr, ast.ListComp):
194+
# List comprehension - infer from element type
195+
element_type = self._infer_expression_type(expr.elt)
196+
result = f'std::vector<{element_type or "auto"}>'
197+
elif isinstance(expr, ast.DictComp):
198+
# Dictionary comprehension - infer from key and value types
199+
key_type = self._infer_expression_type(expr.key)
200+
value_type = self._infer_expression_type(expr.value)
201+
result = f'std::unordered_map<{key_type or "auto"}, {value_type or "auto"}>'
165202

166-
return None
203+
# Cache the result if we found one
204+
if result is not None:
205+
self.expression_cache[cache_key] = result
206+
207+
return result
167208

168209
def _analyze_function_types(self, node: ast.FunctionDef) -> None:
169210
"""Analyze function parameter and return types."""
@@ -191,4 +232,4 @@ def _analyze_function_types(self, node: ast.FunctionDef) -> None:
191232
if return_type:
192233
func_info['return_type'] = return_type
193234

194-
self.type_info[f'function_{node.name}'] = func_info
235+
self.type_info[node.name] = func_info

src/converter/code_generator.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def _generate_implementation(self, analysis_result: AnalysisResult) -> str:
277277
impl = """#include "generated.hpp"
278278
#include <vector>
279279
#include <map>
280+
#include <unordered_map>
280281
#include <set>
281282
#include <tuple>
282283
#include <optional>
@@ -1083,6 +1084,10 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st
10831084
obj = self._translate_expression(node.func.value.value, local_vars)
10841085
args = [self._translate_expression(arg, local_vars) for arg in node.args]
10851086
return f"{obj}.push_back({', '.join(args)})"
1087+
elif func_name in ['sqrt', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'exp', 'log', 'log10', 'floor', 'ceil', 'fabs']:
1088+
# Handle direct imports from math module (e.g., from math import sqrt)
1089+
args = [self._translate_expression(arg, local_vars) for arg in node.args]
1090+
return f"std::{func_name}({', '.join(args)})"
10861091
else:
10871092
# Regular function call
10881093
args = [self._translate_expression(arg, local_vars) for arg in node.args]
@@ -1092,6 +1097,16 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st
10921097
obj = self._translate_expression(node.func.value, local_vars)
10931098
method = node.func.attr
10941099

1100+
# Handle math module functions
1101+
if obj == 'math':
1102+
if method in ['sqrt', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'exp', 'log', 'log10', 'pow', 'floor', 'ceil', 'fabs']:
1103+
args = [self._translate_expression(arg, local_vars) for arg in node.args]
1104+
return f"std::{method}({', '.join(args)})"
1105+
else:
1106+
# Handle other math functions that may need special mapping
1107+
args = [self._translate_expression(arg, local_vars) for arg in node.args]
1108+
return f"std::{method}({', '.join(args)})"
1109+
10951110
# Map Python methods to C++ equivalents
10961111
if method == 'append':
10971112
method = 'push_back' # std::vector uses push_back, not append
@@ -1153,6 +1168,12 @@ def _translate_expression(self, node: ast.AST, local_vars: Dict[str, str]) -> st
11531168
return "std::make_tuple()"
11541169

11551170
return f"std::make_tuple({', '.join(elements)})"
1171+
elif isinstance(node, ast.ListComp):
1172+
# Handle list comprehensions: [expr for item in iterable]
1173+
return self._translate_list_comprehension(node, local_vars)
1174+
elif isinstance(node, ast.DictComp):
1175+
# Handle dictionary comprehensions: {key: value for item in iterable}
1176+
return self._translate_dict_comprehension(node, local_vars)
11561177
elif isinstance(node, ast.BoolOp):
11571178
# Handle boolean operations like and, or
11581179
op_str = "&&" if isinstance(node.op, ast.And) else "||"
@@ -1619,6 +1640,122 @@ def _generate_cmake(self) -> str:
16191640

16201641
return '\n'.join(cmake_content)
16211642

1643+
def _translate_list_comprehension(self, node: ast.ListComp, local_vars: Dict[str, str]) -> str:
1644+
"""Translate list comprehension to C++ lambda with performance optimizations."""
1645+
# Get the comprehension parts
1646+
element_expr = node.elt
1647+
generator = node.generators[0] # For simplicity, handle only one generator
1648+
target = generator.target
1649+
iter_expr = generator.iter
1650+
1651+
# Translate components
1652+
iter_str = self._translate_expression(iter_expr, local_vars)
1653+
target_name = target.id if isinstance(target, ast.Name) else "item"
1654+
element_str = self._translate_expression(element_expr, local_vars)
1655+
1656+
# Handle conditional comprehensions (if clauses)
1657+
condition_str = ""
1658+
if generator.ifs:
1659+
conditions = []
1660+
for if_clause in generator.ifs:
1661+
condition = self._translate_expression(if_clause, local_vars)
1662+
conditions.append(condition)
1663+
condition_str = f" if ({' && '.join(conditions)})"
1664+
1665+
# Create lambda expression for the comprehension with performance optimizations
1666+
# [expr for item in iterable] becomes:
1667+
# [&]() {
1668+
# std::vector<auto> result;
1669+
# result.reserve(iterable.size()); // Performance optimization
1670+
# for (auto item : iterable) {
1671+
# if (condition) { // Only if conditions exist
1672+
# result.push_back(expr);
1673+
# }
1674+
# }
1675+
# return result;
1676+
# }()
1677+
1678+
if condition_str:
1679+
comprehension_code = f"""[&]() {{
1680+
std::vector<auto> result;
1681+
result.reserve({iter_str}.size());
1682+
for (auto {target_name} : {iter_str}) {{
1683+
if ({' && '.join(self._translate_expression(if_clause, local_vars) for if_clause in generator.ifs)}) {{
1684+
result.push_back({element_str});
1685+
}}
1686+
}}
1687+
return result;
1688+
}}()"""
1689+
else:
1690+
comprehension_code = f"""[&]() {{
1691+
std::vector<auto> result;
1692+
result.reserve({iter_str}.size());
1693+
for (auto {target_name} : {iter_str}) {{
1694+
result.push_back({element_str});
1695+
}}
1696+
return result;
1697+
}}()"""
1698+
1699+
return comprehension_code
1700+
1701+
def _translate_dict_comprehension(self, node: ast.DictComp, local_vars: Dict[str, str]) -> str:
1702+
"""Translate dictionary comprehension to C++ lambda with performance optimizations."""
1703+
# Get the comprehension parts
1704+
key_expr = node.key
1705+
value_expr = node.value
1706+
generator = node.generators[0] # For simplicity, handle only one generator
1707+
target = generator.target
1708+
iter_expr = generator.iter
1709+
1710+
# Translate components
1711+
iter_str = self._translate_expression(iter_expr, local_vars)
1712+
target_name = target.id if isinstance(target, ast.Name) else "item"
1713+
key_str = self._translate_expression(key_expr, local_vars)
1714+
value_str = self._translate_expression(value_expr, local_vars)
1715+
1716+
# Handle conditional comprehensions (if clauses)
1717+
condition_str = ""
1718+
if generator.ifs:
1719+
conditions = []
1720+
for if_clause in generator.ifs:
1721+
condition = self._translate_expression(if_clause, local_vars)
1722+
conditions.append(condition)
1723+
condition_str = f" if ({' && '.join(conditions)})"
1724+
1725+
# Create lambda expression for the dictionary comprehension with performance optimizations
1726+
# Use std::unordered_map instead of std::map for O(1) vs O(log n) performance
1727+
# {key: value for item in iterable} becomes:
1728+
# [&]() {
1729+
# std::unordered_map<auto, auto> result;
1730+
# for (auto item : iterable) {
1731+
# if (condition) { // Only if conditions exist
1732+
# result[key] = value;
1733+
# }
1734+
# }
1735+
# return result;
1736+
# }()
1737+
1738+
if condition_str:
1739+
comprehension_code = f"""[&]() {{
1740+
std::unordered_map<auto, auto> result;
1741+
for (auto {target_name} : {iter_str}) {{
1742+
if ({' && '.join(self._translate_expression(if_clause, local_vars) for if_clause in generator.ifs)}) {{
1743+
result[{key_str}] = {value_str};
1744+
}}
1745+
}}
1746+
return result;
1747+
}}()"""
1748+
else:
1749+
comprehension_code = f"""[&]() {{
1750+
std::unordered_map<auto, auto> result;
1751+
for (auto {target_name} : {iter_str}) {{
1752+
result[{key_str}] = {value_str};
1753+
}}
1754+
return result;
1755+
}}()"""
1756+
1757+
return comprehension_code
1758+
16221759
def _expression_uses_variables(self, expr: ast.AST, variable_names: List[str]) -> bool:
16231760
"""Check if an expression uses any of the given variable names."""
16241761
for node in ast.walk(expr):

0 commit comments

Comments
 (0)