Skip to content

Commit 25b10d4

Browse files
committed
type hints PEP484 style and code formatting
1 parent 20cc62d commit 25b10d4

File tree

4 files changed

+77
-49
lines changed

4 files changed

+77
-49
lines changed

examples/BeamNG/IGHAStarMP.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import time
44
import traceback
55
from queue import Empty
6+
from typing import Any, Optional, Tuple, Dict
7+
68

79
class IGHAStarMP:
810
"""
911
Multiprocessing wrapper for the IGHA* planner. Runs the planner in a separate process and communicates via queues.
1012
"""
1113

12-
def __init__(self, configs):
14+
def __init__(self, configs: Dict[str, Any]) -> None:
1315
mp.set_start_method("spawn", force=True) # Safe for CUDA
1416
self.query_queue = mp.Queue(5)
1517
self.result_queue = mp.Queue(5)
@@ -23,7 +25,9 @@ def __init__(self, configs):
2325
self.completed = True
2426
self.expansion_counter = 0
2527

26-
def _planner_process(self, configs, query_queue, result_queue):
28+
def _planner_process(
29+
self, configs: Dict[str, Any], query_queue: Any, result_queue: Any
30+
) -> None:
2731
"""
2832
Planner process: loads the CUDA/C++ kernel and runs the IGHA* planner in response to queries.
2933
"""
@@ -113,16 +117,16 @@ def _planner_process(self, configs, query_queue, result_queue):
113117

114118
def set_query(
115119
self,
116-
map_center,
117-
start_state,
118-
goal_,
119-
costmap,
120-
heightmap,
121-
hysteresis,
122-
expansion_limit,
123-
stop=False,
124-
disable=False,
125-
):
120+
map_center: np.ndarray,
121+
start_state: np.ndarray,
122+
goal_: np.ndarray,
123+
costmap: np.ndarray,
124+
heightmap: np.ndarray,
125+
hysteresis: float,
126+
expansion_limit: int,
127+
stop: bool = False,
128+
disable: bool = False,
129+
) -> None:
126130
"""
127131
Submit a new planning query. Returns immediately; results are available via update().
128132
"""
@@ -148,14 +152,14 @@ def set_query(
148152
self.completed = False
149153
self.success = False
150154

151-
def reset(self):
155+
def reset(self) -> None:
152156
"""Reset planner state."""
153157
self.path = None
154158
self.success = False
155159
self.completed = True
156160
self.expansion_counter = 0
157161

158-
def update(self):
162+
def update(self) -> Tuple[bool, Optional[np.ndarray], int]:
159163
"""
160164
Call this periodically in your main loop to check for planner results.
161165
Returns (success, path, expansion_counter).
@@ -176,7 +180,7 @@ def update(self):
176180
except Empty:
177181
return False, self.path, self.expansion_counter
178182

179-
def shutdown(self):
183+
def shutdown(self) -> None:
180184
"""Shut down the planner process."""
181185
self.query_queue.put(None)
182186
self.process.terminate()

examples/BeamNG/TrackingCost.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from typing import Any, Optional, Dict
23

34

45
class SimpleCarCost(torch.nn.Module):
@@ -8,11 +9,11 @@ class SimpleCarCost(torch.nn.Module):
89

910
def __init__(
1011
self,
11-
Cost_config,
12-
Map_config,
13-
dtype=torch.float32,
14-
device=torch.device("cuda"),
15-
):
12+
Cost_config: Dict[str, Any],
13+
Map_config: Dict[str, Any],
14+
dtype: torch.dtype = torch.float32,
15+
device: torch.device = torch.device("cuda"),
16+
) -> None:
1617
super(SimpleCarCost, self).__init__()
1718
self.dtype = dtype
1819
self.d = device
@@ -67,7 +68,12 @@ def __init__(
6768
self.constraint_violation = False
6869

6970
@torch.jit.export
70-
def set_BEV(self, BEVmap_height, BEVmap_normal, BEVmap_cost):
71+
def set_BEV(
72+
self,
73+
BEVmap_height: torch.Tensor,
74+
BEVmap_normal: torch.Tensor,
75+
BEVmap_cost: torch.Tensor,
76+
) -> None:
7177
"""
7278
Set BEV (bird's-eye view) map data for cost calculation.
7379
"""
@@ -76,25 +82,25 @@ def set_BEV(self, BEVmap_height, BEVmap_normal, BEVmap_cost):
7682
self.BEVmap_cost = (255 - BEVmap_cost) / 255
7783

7884
@torch.jit.export
79-
def set_goal(self, goal_state):
85+
def set_goal(self, goal_state: torch.Tensor) -> None:
8086
self.goal_state = goal_state[:2]
8187

82-
def set_path(self, path):
88+
def set_path(self, path: torch.Tensor) -> None:
8389
self.path = torch.tensor(path, dtype=self.dtype, device=self.d)
8490

8591
@torch.jit.export
86-
def set_speed_limit(self, speed_lim):
92+
def set_speed_limit(self, speed_lim: float) -> None:
8793
self.speed_target = torch.tensor(speed_lim, dtype=self.dtype, device=self.d)
8894

89-
def meters_to_px(self, meters):
95+
def meters_to_px(self, meters: torch.Tensor) -> torch.Tensor:
9096
px = ((meters + self.BEVmap_size * 0.5) / self.BEVmap_res).to(
9197
dtype=torch.long, device=self.d
9298
)
9399
px = torch.maximum(px, torch.zeros_like(px))
94100
px = torch.minimum(px, self.BEVmap_size_px - 1)
95101
return px
96102

97-
def forward(self, state, controls):
103+
def forward(self, state: torch.Tensor, controls: torch.Tensor) -> torch.Tensor:
98104
# Unpack state
99105
x = state[..., 0]
100106
y = state[..., 1]

examples/BeamNG/example.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import argparse
1212
import sys
1313
import pathlib
14+
import os
15+
from typing import Any, Optional, Dict, List
1416

1517
BASE_DIR = pathlib.Path(__file__).resolve().parent.parent.parent
1618
sys.path.append(str(BASE_DIR / "scripts"))
@@ -28,12 +30,12 @@
2830

2931

3032
def main(
31-
config_path=None,
32-
hal_config_path=None,
33-
waypoint_folder=None,
34-
output_folder=None,
35-
args=None,
36-
):
33+
config_path: Optional[str] = None,
34+
hal_config_path: Optional[str] = None,
35+
waypoint_folder: Optional[str] = None,
36+
output_folder: Optional[str] = None,
37+
args: Optional[Any] = None,
38+
) -> None:
3739
if config_path is None:
3840
print("no config file provided!")
3941
exit()

examples/BeamNG/utils.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
import time
44
import threading
55
import torch
6+
from typing import Any, Optional, List, Tuple, Dict
67

78

89
def generate_costmap_from_BEVmap(
9-
BEV_lethal, BEV_normal, costmap_cosine_thresh=np.cos(np.radians(60))
10-
):
10+
BEV_lethal: np.ndarray,
11+
BEV_normal: np.ndarray,
12+
costmap_cosine_thresh: float = np.cos(np.radians(60)),
13+
) -> np.ndarray:
1114
"""Generate a costmap from BEV lethal and normal maps."""
1215
dot_product = BEV_normal[:, :, 2]
1316
angle_cost = np.where(dot_product >= costmap_cosine_thresh, 255, 0).astype(
@@ -20,7 +23,11 @@ def generate_costmap_from_BEVmap(
2023
return costmap
2124

2225

23-
def convert_global_path_to_bng(bng_interface=None, path=None, Map_config=None):
26+
def convert_global_path_to_bng(
27+
bng_interface: Optional[Any] = None,
28+
path: Optional[np.ndarray] = None,
29+
Map_config: Optional[Dict[str, Any]] = None,
30+
) -> np.ndarray:
2431
"""Convert a global path to BeamNG waypoints with elevation and quaternion."""
2532
target_wp = []
2633
map_res = Map_config["map_res"]
@@ -44,7 +51,14 @@ def convert_global_path_to_bng(bng_interface=None, path=None, Map_config=None):
4451
return np.array(target_wp)
4552

4653

47-
def update_goal(goal, pos, target_WP, current_wp_index, lookahead, wp_radius=1.0):
54+
def update_goal(
55+
goal: Optional[np.ndarray],
56+
pos: np.ndarray,
57+
target_WP: np.ndarray,
58+
current_wp_index: int,
59+
lookahead: float,
60+
wp_radius: float = 1.0,
61+
) -> Tuple[np.ndarray, bool, int]:
4862
"""Update the goal position based on lookahead and proximity to final waypoint."""
4963
final_wp = target_WP[-1]
5064
dx = final_wp[0] - pos[0]
@@ -65,7 +79,9 @@ def update_goal(goal, pos, target_WP, current_wp_index, lookahead, wp_radius=1.0
6579
return goal, success, current_wp_index
6680

6781

68-
def steering_limiter(steer, state, RPS_config):
82+
def steering_limiter(
83+
steer: float, state: np.ndarray, RPS_config: Dict[str, Any]
84+
) -> float:
6985
"""Limit steering to prevent rollovers and respect physical constraints."""
7086
steering_setpoint = steer * RPS_config["steering_max"]
7187
whspd2 = max(1.0, np.linalg.norm(state[6:8])) ** 2 # speed squared in world frame
@@ -125,7 +141,7 @@ def steering_limiter(steer, state, RPS_config):
125141

126142

127143
class PlannerVis:
128-
def __init__(self, map_size, resolution_inv):
144+
def __init__(self, map_size: int, resolution_inv: float) -> None:
129145
self.map_size = map_size
130146
self.resolution_inv = resolution_inv
131147
self.costmap = None
@@ -140,15 +156,15 @@ def __init__(self, map_size, resolution_inv):
140156

141157
def update_vis(
142158
self,
143-
states,
144-
path,
145-
costmap,
146-
elevation_map,
147-
resolution_inv,
148-
goal,
149-
expansion_counter,
150-
hysteresis,
151-
):
159+
states: np.ndarray,
160+
path: Optional[np.ndarray],
161+
costmap: np.ndarray,
162+
elevation_map: np.ndarray,
163+
resolution_inv: float,
164+
goal: np.ndarray,
165+
expansion_counter: int,
166+
hysteresis: float,
167+
) -> None:
152168
with self.lock:
153169
if isinstance(states, torch.Tensor):
154170
self.states = states.cpu().numpy()
@@ -162,12 +178,12 @@ def update_vis(
162178
self.hysteresis = hysteresis
163179
self.expansion_counter = expansion_counter
164180

165-
def generate_costmap_from_BEVmap(self, normal):
181+
def generate_costmap_from_BEVmap(self, normal: np.ndarray) -> np.ndarray:
166182
dot_product = normal[:, :, 2]
167183
costmap = np.where(dot_product >= self.cosine_thresh, 255, 0).astype(np.float32)
168184
return costmap
169185

170-
def costmap_vis(self):
186+
def costmap_vis(self) -> None:
171187
while True:
172188
if (
173189
self.states is not None

0 commit comments

Comments
 (0)