diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index 5283fd12a..72b1fa512 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -88,9 +88,9 @@ class Flight: Maximum absolute error tolerance to be tolerated in the integration scheme. Flight.time_overshoot : bool, optional - If True, decouples ODE time step from parachute trigger functions - sampling rate. The time steps can overshoot the necessary trigger - function evaluation points and then interpolation is used to + If True, decouples ODE time step from parachute and controller trigger + functions sampling rate. The time steps can overshoot the necessary + trigger function evaluation points and then interpolation is used to calculate them and feed the triggers. Can greatly improve run time in some cases. Flight.terminate_on_apogee : bool @@ -550,11 +550,11 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements integration scheme. Can be given as array for each state space variable. Default is 6*[1e-3] + 4*[1e-6] + 3*[1e-3]. time_overshoot : bool, optional - If True, decouples ODE time step from parachute trigger functions - sampling rate. The time steps can overshoot the necessary trigger - function evaluation points and then interpolation is used to - calculate them and feed the triggers. Can greatly improve run - time in some cases. Default is True. + If True, decouples ODE time step from parachute and controller + trigger functions sampling rate. The time steps can overshoot the + necessary trigger function evaluation points and then interpolation + is used to calculate them and feed the triggers. Can greatly improve + run time in some cases. Default is True. verbose : bool, optional If true, verbose mode is activated. Default is False. name : str, optional @@ -647,7 +647,7 @@ def __simulate(self, verbose): for callback in phase.callbacks: callback(self) - # Create solver for this flight phase # TODO: allow different integrators + # Create solver for this flight phase self.function_evaluations.append(0) phase.solver = self._solver( @@ -662,29 +662,7 @@ def __simulate(self, verbose): ) # Initialize phase time nodes - phase.time_nodes = self.TimeNodes() - # Add first time node to the time_nodes list - phase.time_nodes.add_node(phase.t, [], [], []) - # Add non-overshootable parachute time nodes - if self.time_overshoot is False: - phase.time_nodes.add_parachutes( - self.parachutes, phase.t, phase.time_bound - ) - phase.time_nodes.add_sensors( - self.rocket.sensors, phase.t, phase.time_bound - ) - phase.time_nodes.add_controllers( - self._controllers, phase.t, phase.time_bound - ) - # Add last time node to the time_nodes list - phase.time_nodes.add_node(phase.time_bound, [], [], []) - # Organize time nodes with sort() and merge() - phase.time_nodes.sort() - phase.time_nodes.merge() - # Clear triggers from first time node if necessary - if phase.clear: - phase.time_nodes[0].parachutes = [] - phase.time_nodes[0].callbacks = [] + self.__setup_phase_time_nodes(phase, phase_index) # Iterate through time nodes for node_index, node in self.time_iterator(phase.time_nodes): @@ -705,81 +683,12 @@ def __simulate(self, verbose): for callback in node.callbacks: callback(self) - if self.sensors: - # u_dot for all sensors - u_dot = phase.derivative(self.t, self.y_sol) - for sensor, position in node._component_sensors: - relative_position = position - self.rocket._csys * Vector( - [0, 0, self.rocket.center_of_dry_mass_position] - ) - sensor.measure( - self.t, - u=self.y_sol, - u_dot=u_dot, - relative_position=relative_position, - environment=self.env, - gravity=self.env.gravity.get_value_opt( - self.solution[-1][3] - ), - pressure=self.env.pressure, - earth_radius=self.env.earth_radius, - initial_coordinates=(self.env.latitude, self.env.longitude), - ) - - for controller in node._controllers: - controller( - self.t, - self.y_sol, - self.solution, - self.sensors, - ) + self.__process_sensors_and_controllers_at_current_node(node, phase) - for parachute in node.parachutes: - # Calculate and save pressure signal - ( - noisy_pressure, - height_above_ground_level, - ) = self.__calculate_and_save_pressure_signals( - parachute, node.t, self.y_sol[2] - ) - if parachute.triggerfunc( - noisy_pressure, - height_above_ground_level, - self.y_sol, - self.sensors, - ): - # Remove parachute from flight parachutes - self.parachutes.remove(parachute) - # Create phase for time after detection and before inflation - # Must only be created if parachute has any lag - i = 1 - if parachute.lag != 0: - self.flight_phases.add_phase( - node.t, - phase.derivative, - clear=True, - index=phase_index + i, - ) - i += 1 - # Create flight phase for time after inflation - callbacks = [ - lambda self, parachute_cd_s=parachute.cd_s: setattr( - self, "parachute_cd_s", parachute_cd_s - ) - ] - self.flight_phases.add_phase( - node.t + parachute.lag, - self.u_dot_parachute, - callbacks, - clear=False, - index=phase_index + i, - ) - # Prepare to leave loops and start new flight phase - phase.time_nodes.flush_after(node_index) - phase.time_nodes.add_node(self.t, [], [], []) - phase.solver.status = "finished" - # Save parachute event - self.parachute_events.append([self.t, parachute]) + if self.__check_and_handle_parachute_triggers( + node, phase, phase_index, node_index + ): + break # Stop simulation if parachute is deployed # Step through simulation while phase.solver.status == "running": @@ -794,252 +703,16 @@ def __simulate(self, verbose): if verbose: print(f"Current Simulation Time: {self.t:3.4f} s", end="\r") - # Check for first out of rail event - if len(self.out_of_rail_state) == 1 and ( - self.y_sol[0] ** 2 - + self.y_sol[1] ** 2 - + (self.y_sol[2] - self.env.elevation) ** 2 - >= self.effective_1rl**2 + if self.__check_simulation_events(phase, phase_index, node_index): + break # Stop if + + # Process overshootable time nodes if enabled + if self.time_overshoot and self.__process_overshootable_nodes( + phase, phase_index, node_index ): - # Check exactly when it went out using root finding - # Disconsider elevation - self.solution[-2][3] -= self.env.elevation - self.solution[-1][3] -= self.env.elevation - # Get points - y0 = ( - sum(self.solution[-2][i] ** 2 for i in [1, 2, 3]) - - self.effective_1rl**2 - ) - yp0 = 2 * sum( - self.solution[-2][i] * self.solution[-2][i + 3] - for i in [1, 2, 3] - ) - t1 = self.solution[-1][0] - self.solution[-2][0] - y1 = ( - sum(self.solution[-1][i] ** 2 for i in [1, 2, 3]) - - self.effective_1rl**2 - ) - yp1 = 2 * sum( - self.solution[-1][i] * self.solution[-1][i + 3] - for i in [1, 2, 3] - ) - # Put elevation back - self.solution[-2][3] += self.env.elevation - self.solution[-1][3] += self.env.elevation - # Cubic Hermite interpolation (ax**3 + bx**2 + cx + d) - a, b, c, d = calculate_cubic_hermite_coefficients( - 0, - float(phase.solver.step_size), - y0, - yp0, - y1, - yp1, - ) - a += 1e-5 # TODO: why?? - # Find roots - t_roots = find_roots_cubic_function(a, b, c, d) - # Find correct root - valid_t_root = [ - t_root.real - for t_root in t_roots - if 0 < t_root.real < t1 and abs(t_root.imag) < 0.001 - ] - if len(valid_t_root) > 1: # pragma: no cover - raise ValueError( - "Multiple roots found when solving for rail exit time." - ) - if len(valid_t_root) == 0: # pragma: no cover - raise ValueError( - "No valid roots found when solving for rail exit time." - ) - # Determine final state when upper button is going out of rail - self.t = valid_t_root[0] + self.solution[-2][0] - interpolator = phase.solver.dense_output() - self.y_sol = interpolator(self.t) - self.solution[-1] = [self.t, *self.y_sol] - self.out_of_rail_time = self.t - self.out_of_rail_time_index = len(self.solution) - 1 - self.out_of_rail_state = self.y_sol - # Create new flight phase - self.flight_phases.add_phase( - self.t, - self.u_dot_generalized, - index=phase_index + 1, - ) - # Prepare to leave loops and start new flight phase - phase.time_nodes.flush_after(node_index) - phase.time_nodes.add_node(self.t, [], [], []) - phase.solver.status = "finished" - - # Check for apogee event - # TODO: negative vz doesn't really mean apogee. Improve this. - if len(self.apogee_state) == 1 and self.y_sol[5] < 0: - # Assume linear vz(t) to detect when vz = 0 - t0, vz0 = self.solution[-2][0], self.solution[-2][6] - t1, vz1 = self.solution[-1][0], self.solution[-1][6] - t_root = find_root_linear_interpolation(t0, t1, vz0, vz1, 0) - # Fetch state at t_root - interpolator = phase.solver.dense_output() - self.apogee_state = interpolator(t_root) - # Store apogee data - self.apogee_time = t_root - self.apogee_x = self.apogee_state[0] - self.apogee_y = self.apogee_state[1] - self.apogee = self.apogee_state[2] - - if self.terminate_on_apogee: - self.t = self.t_final = t_root - # Roll back solution - self.solution[-1] = [self.t, *self.apogee_state] - # Set last flight phase - self.flight_phases.flush_after(phase_index) - self.flight_phases.add_phase(self.t) - # Prepare to leave loops and start new flight phase - phase.time_nodes.flush_after(node_index) - phase.time_nodes.add_node(self.t, [], [], []) - phase.solver.status = "finished" - elif len(self.solution) > 2: - # adding the apogee state to solution increases accuracy - # we can only do this if the apogee is not the first state - self.solution.insert(-1, [t_root, *self.apogee_state]) - # Check for impact event - if self.y_sol[2] < self.env.elevation: - # Check exactly when it happened using root finding - # Cubic Hermite interpolation (ax**3 + bx**2 + cx + d) - a, b, c, d = calculate_cubic_hermite_coefficients( - x0=0, # t0 - x1=float(phase.solver.step_size), # t1 - t0 - y0=float(self.solution[-2][3] - self.env.elevation), # z0 - yp0=float(self.solution[-2][6]), # vz0 - y1=float(self.solution[-1][3] - self.env.elevation), # z1 - yp1=float(self.solution[-1][6]), # vz1 - ) - # Find roots - t_roots = find_roots_cubic_function(a, b, c, d) - # Find correct root - t1 = self.solution[-1][0] - self.solution[-2][0] - valid_t_root = [ - t_root.real - for t_root in t_roots - if abs(t_root.imag) < 0.001 and 0 < t_root.real < t1 - ] - if len(valid_t_root) > 1: # pragma: no cover - raise ValueError( - "Multiple roots found when solving for impact time." - ) - # Determine impact state at t_root - self.t = self.t_final = valid_t_root[0] + self.solution[-2][0] - interpolator = phase.solver.dense_output() - self.y_sol = self.impact_state = interpolator(self.t) - # Roll back solution - self.solution[-1] = [self.t, *self.y_sol] - # Save impact state - self.x_impact = self.impact_state[0] - self.y_impact = self.impact_state[1] - self.z_impact = self.impact_state[2] - self.impact_velocity = self.impact_state[5] - # Set last flight phase - self.flight_phases.flush_after(phase_index) - self.flight_phases.add_phase(self.t) - # Prepare to leave loops and start new flight phase - phase.time_nodes.flush_after(node_index) - phase.time_nodes.add_node(self.t, [], [], []) - phase.solver.status = "finished" - - # List and feed overshootable time nodes - if self.time_overshoot: - # Initialize phase overshootable time nodes - overshootable_nodes = self.TimeNodes() - # Add overshootable parachute time nodes - overshootable_nodes.add_parachutes( - self.parachutes, self.solution[-2][0], self.t - ) - # Add last time node (always skipped) - overshootable_nodes.add_node(self.t, [], [], []) - if len(overshootable_nodes) > 1: - # Sort and merge equal overshootable time nodes - overshootable_nodes.sort() - overshootable_nodes.merge() - # Clear if necessary - if overshootable_nodes[0].t == phase.t and phase.clear: - overshootable_nodes[0].parachutes = [] - overshootable_nodes[0].callbacks = [] - # Feed overshootable time nodes trigger - interpolator = phase.solver.dense_output() - for ( - overshootable_index, - overshootable_node, - ) in self.time_iterator(overshootable_nodes): - # Calculate state at node time - overshootable_node.y_sol = interpolator( - overshootable_node.t - ) - for parachute in overshootable_node.parachutes: - # Calculate and save pressure signal - ( - noisy_pressure, - height_above_ground_level, - ) = self.__calculate_and_save_pressure_signals( - parachute, - overshootable_node.t, - overshootable_node.y_sol[2], - ) - - # Check for parachute trigger - if parachute.triggerfunc( - noisy_pressure, - height_above_ground_level, - overshootable_node.y_sol, - self.sensors, - ): - # Remove parachute from flight parachutes - self.parachutes.remove(parachute) - # Create phase for time after detection and - # before inflation - # Must only be created if parachute has any lag - i = 1 - if parachute.lag != 0: - self.flight_phases.add_phase( - overshootable_node.t, - phase.derivative, - clear=True, - index=phase_index + i, - ) - i += 1 - # Create flight phase for time after inflation - callbacks = [ - lambda self, - parachute_cd_s=parachute.cd_s: setattr( - self, "parachute_cd_s", parachute_cd_s - ) - ] - self.flight_phases.add_phase( - overshootable_node.t + parachute.lag, - self.u_dot_parachute, - callbacks, - clear=False, - index=phase_index + i, - ) - # Rollback history - self.t = overshootable_node.t - self.y_sol = overshootable_node.y_sol - self.solution[-1] = [ - overshootable_node.t, - *overshootable_node.y_sol, - ] - # Prepare to leave loops and start new flight phase - overshootable_nodes.flush_after( - overshootable_index - ) - phase.time_nodes.flush_after(node_index) - phase.time_nodes.add_node(self.t, [], [], []) - phase.solver.status = "finished" - # Save parachute event - self.parachute_events.append( - [self.t, parachute] - ) - - # If controlled flight, post process must be done on sim time + break + + # Post-process controllers if needed if self._controllers: phase.derivative(self.t, self.y_sol, post_processing=True) @@ -1053,6 +726,573 @@ def __simulate(self, verbose): if verbose: print(f"\n>>> Simulation Completed at Time: {self.t:3.4f} s") + def __setup_phase_time_nodes(self, phase, phase_index): + """Set up time nodes for the current phase. + + Parameters + ---------- + phase : FlightPhase + The current flight phase. + phase_index : int + The index of the current phase. + """ + phase.time_nodes = self.TimeNodes() + + # Add first time node + phase.time_nodes.add_node(phase.t, [], [], []) + + if self.time_overshoot is False: + phase.time_nodes.add_parachutes(self.parachutes, phase.t, phase.time_bound) + phase.time_nodes.add_sensors(self.rocket.sensors, phase.t, phase.time_bound) + phase.time_nodes.add_controllers( + self._controllers, phase.t, phase.time_bound + ) + + # Add last time node + phase.time_nodes.add_node(phase.time_bound, [], [], []) + + # Organize time nodes + phase.time_nodes.sort() + phase.time_nodes.merge() + + # Clear triggers from first time node if necessary + if phase.clear: + phase.time_nodes[0].parachutes = [] + phase.time_nodes[0].callbacks = [] + + def __process_sensors_and_controllers_at_current_node(self, node, phase): + """Process sensors and controllers at the current node. + + Parameters + ---------- + node : TimeNode + The current time node. + phase : FlightPhase + The current flight phase. + """ + if self.sensors: + u_dot = phase.derivative(self.t, self.y_sol) + self.__measure_sensors(node._component_sensors, u_dot) + + for controller in node._controllers: + controller( + self.t, + self.y_sol, + self.solution, + self.sensors, + ) + + def __measure_sensors(self, component_sensors, u_dot, t=None, y_sol=None): + """Measure sensors with the given state and derivative. + + Parameters + ---------- + component_sensors : list + List of (sensor, position) tuples. + u_dot : array_like + State derivative vector. + t : float, optional + Time for measurement. If None, uses self.t. + y_sol : array_like, optional + State vector. If None, uses self.y_sol. + """ + if t is None: + t = self.t + if y_sol is None: + y_sol = self.y_sol + + for sensor, position in component_sensors: + relative_position = position - self.rocket._csys * Vector( + [0, 0, self.rocket.center_of_dry_mass_position] + ) + sensor.measure( + t, + u=y_sol, + u_dot=u_dot, + relative_position=relative_position, + environment=self.env, + gravity=self.env.gravity.get_value_opt( + y_sol[2] if len(y_sol) > 2 else self.solution[-1][3] + ), + pressure=self.env.pressure, + earth_radius=self.env.earth_radius, + initial_coordinates=(self.env.latitude, self.env.longitude), + ) + + def __check_and_handle_parachute_triggers( + self, node, phase, phase_index, node_index + ): + """Check for parachute triggers and handle deployment. + + Parameters + ---------- + node : TimeNode + The current time node. + phase : FlightPhase + The current flight phase. + phase_index : int + The index of the current phase. + node_index : int + The index of the current node. + + Returns + ------- + bool + True if a parachute was triggered and the phase should break. + """ + for parachute in node.parachutes: + # Calculate and save pressure signal + ( + noisy_pressure, + height_above_ground_level, + ) = self.__calculate_and_save_pressure_signals( + parachute, node.t, self.y_sol[2] + ) + if not parachute.triggerfunc( + noisy_pressure, + height_above_ground_level, + self.y_sol, + self.sensors, + ): + return False # Early exit: parachute not deployed + + # Remove parachute from flight parachutes + self.parachutes.remove(parachute) + + # Create phase for time after detection and before inflation + # Must only be created if parachute has any lag + i = 1 + if parachute.lag != 0: + self.flight_phases.add_phase( + node.t, + phase.derivative, + clear=True, + index=phase_index + i, + ) + i += 1 + + # Create flight phase for time after inflation + callbacks = [ + lambda self, parachute_cd_s=parachute.cd_s: setattr( + self, "parachute_cd_s", parachute_cd_s + ) + ] + self.flight_phases.add_phase( + node.t + parachute.lag, + self.u_dot_parachute, + callbacks, + clear=False, + index=phase_index + i, + ) + + # Prepare to leave loops and start new flight phase + phase.time_nodes.flush_after(node_index) + phase.time_nodes.add_node(self.t, [], [], []) + phase.solver.status = "finished" + self.parachute_events.append([self.t, parachute]) + return True + + return False + + def __check_simulation_events(self, phase, phase_index, node_index): + """Check for simulation events like out of rail, apogee, and impact. + + Parameters + ---------- + phase : FlightPhase + The current flight phase. + phase_index : int + The index of the current phase. + node_index : int + The index of the current node. + + Returns + ------- + bool + True if an event occurred and the simulation should break. + """ + # Check for first out of rail event + if len(self.out_of_rail_state) == 1 and ( + self.y_sol[0] ** 2 + + self.y_sol[1] ** 2 + + (self.y_sol[2] - self.env.elevation) ** 2 + >= self.effective_1rl**2 + ): + return self.__handle_out_of_rail_event(phase, phase_index, node_index) + + # Check for apogee event + # TODO: negative vz doesn't really mean apogee. Improve this. + if len(self.apogee_state) == 1 and self.y_sol[5] < 0: + return self.__handle_apogee_event(phase, phase_index, node_index) + + # Check for impact event + if self.y_sol[2] < self.env.elevation: + return self.__handle_impact_event(phase, phase_index, node_index) + + return False + + def __handle_out_of_rail_event(self, phase, phase_index, node_index): + """Handle the out of rail event. + + Parameters + ---------- + phase : FlightPhase + The current flight phase. + phase_index : int + The index of the current phase. + node_index : int + The index of the current node. + + Returns + ------- + bool + True to indicate the simulation should break. + """ + # Check exactly when it went out using root finding + # Disconsider elevation + self.solution[-2][3] -= self.env.elevation + self.solution[-1][3] -= self.env.elevation + # Get points + y0 = sum(self.solution[-2][i] ** 2 for i in [1, 2, 3]) - self.effective_1rl**2 + yp0 = 2 * sum( + self.solution[-2][i] * self.solution[-2][i + 3] for i in [1, 2, 3] + ) + t1 = self.solution[-1][0] - self.solution[-2][0] + y1 = sum(self.solution[-1][i] ** 2 for i in [1, 2, 3]) - self.effective_1rl**2 + yp1 = 2 * sum( + self.solution[-1][i] * self.solution[-1][i + 3] for i in [1, 2, 3] + ) + # Put elevation back + self.solution[-2][3] += self.env.elevation + self.solution[-1][3] += self.env.elevation + # Cubic Hermite interpolation (ax**3 + bx**2 + cx + d) + a, b, c, d = calculate_cubic_hermite_coefficients( + 0, + float(phase.solver.step_size), + y0, + yp0, + y1, + yp1, + ) + a += 1e-5 # TODO: why?? + # Find roots + t_roots = find_roots_cubic_function(a, b, c, d) + # Find correct root + valid_t_root = [ + t_root.real + for t_root in t_roots + if 0 < t_root.real < t1 and abs(t_root.imag) < 0.001 + ] + if len(valid_t_root) > 1: # pragma: no cover + raise ValueError("Multiple roots found when solving for rail exit time.") + if len(valid_t_root) == 0: # pragma: no cover + raise ValueError("No valid roots found when solving for rail exit time.") + # Determine final state when upper button is going out of rail + self.t = valid_t_root[0] + self.solution[-2][0] + interpolator = phase.solver.dense_output() + self.y_sol = interpolator(self.t) + self.solution[-1] = [self.t, *self.y_sol] + self.out_of_rail_time = self.t + self.out_of_rail_time_index = len(self.solution) - 1 + self.out_of_rail_state = self.y_sol + # Create new flight phase + self.flight_phases.add_phase( + self.t, + self.u_dot_generalized, + index=phase_index + 1, + ) + # Prepare to leave loops and start new flight phase + phase.time_nodes.flush_after(node_index) + phase.time_nodes.add_node(self.t, [], [], []) + phase.solver.status = "finished" + return True + + def __handle_apogee_event(self, phase, phase_index, node_index): + """Handle the apogee event. + + Parameters + ---------- + phase : FlightPhase + The current flight phase. + phase_index : int + The index of the current phase. + node_index : int + The index of the current node. + + Returns + ------- + bool + True if simulation should break, False otherwise. + """ + # Assume linear vz(t) to detect when vz = 0 + t0, vz0 = self.solution[-2][0], self.solution[-2][6] + t1, vz1 = self.solution[-1][0], self.solution[-1][6] + t_root = find_root_linear_interpolation(t0, t1, vz0, vz1, 0) + # Fetch state at t_root + interpolator = phase.solver.dense_output() + self.apogee_state = interpolator(t_root) + # Store apogee data + self.apogee_time = t_root + self.apogee_x = self.apogee_state[0] + self.apogee_y = self.apogee_state[1] + self.apogee = self.apogee_state[2] + + if self.terminate_on_apogee: + self.t = self.t_final = t_root + # Roll back solution + self.solution[-1] = [self.t, *self.apogee_state] + # Set last flight phase + self.flight_phases.flush_after(phase_index) + self.flight_phases.add_phase(self.t) + # Prepare to leave loops and start new flight phase + phase.time_nodes.flush_after(node_index) + phase.time_nodes.add_node(self.t, [], [], []) + phase.solver.status = "finished" + return True + elif len(self.solution) > 2: + # adding the apogee state to solution increases accuracy + # we can only do this if the apogee is not the first state + self.solution.insert(-1, [t_root, *self.apogee_state]) + return False + + def __handle_impact_event(self, phase, phase_index, node_index): + """Handle the impact event. + + Parameters + ---------- + phase : FlightPhase + The current flight phase. + phase_index : int + The index of the current phase. + node_index : int + The index of the current node. + + Returns + ------- + bool + True to indicate the simulation should break. + """ + # Check exactly when it happened using root finding + # Cubic Hermite interpolation (ax**3 + bx**2 + cx + d) + a, b, c, d = calculate_cubic_hermite_coefficients( + x0=0, # t0 + x1=float(phase.solver.step_size), # t1 - t0 + y0=float(self.solution[-2][3] - self.env.elevation), # z0 + yp0=float(self.solution[-2][6]), # vz0 + y1=float(self.solution[-1][3] - self.env.elevation), # z1 + yp1=float(self.solution[-1][6]), # vz1 + ) + # Find roots + t_roots = find_roots_cubic_function(a, b, c, d) + # Find correct root + t1 = self.solution[-1][0] - self.solution[-2][0] + valid_t_root = [ + t_root.real + for t_root in t_roots + if abs(t_root.imag) < 0.001 and 0 < t_root.real < t1 + ] + if len(valid_t_root) > 1: # pragma: no cover + raise ValueError("Multiple roots found when solving for impact time.") + # Determine impact state at t_root + self.t = self.t_final = valid_t_root[0] + self.solution[-2][0] + interpolator = phase.solver.dense_output() + self.y_sol = self.impact_state = interpolator(self.t) + # Roll back solution + self.solution[-1] = [self.t, *self.y_sol] + # Save impact state + self.x_impact = self.impact_state[0] + self.y_impact = self.impact_state[1] + self.z_impact = self.impact_state[2] + self.impact_velocity = self.impact_state[5] + # Set last flight phase + self.flight_phases.flush_after(phase_index) + self.flight_phases.add_phase(self.t) + # Prepare to leave loops and start new flight phase + phase.time_nodes.flush_after(node_index) + phase.time_nodes.add_node(self.t, [], [], []) + phase.solver.status = "finished" + return True + + def __process_overshootable_nodes(self, phase, phase_index, node_index): + """Process overshootable time nodes for parachutes, controllers, and sensors. + + Parameters + ---------- + phase : FlightPhase + The current flight phase. + phase_index : int + The index of the current phase. + node_index : int + The index of the current node. + + Returns + ------- + bool + True if a parachute was triggered and the simulation should break. + """ + overshootable_nodes = self.TimeNodes() + + overshootable_nodes.add_parachutes( + self.parachutes, self.solution[-2][0], self.t + ) + overshootable_nodes.add_controllers( + self._controllers, self.solution[-2][0], self.t + ) + overshootable_nodes.add_sensors( + self.rocket.sensors, self.solution[-2][0], self.t + ) + + # Add last time node (always skipped) + overshootable_nodes.add_node(self.t, [], [], []) + + if len(overshootable_nodes) < 1: + return False # Early exit + + overshootable_nodes.sort() + overshootable_nodes.merge() + + # Clear if necessary + if overshootable_nodes[0].t == phase.t and phase.clear: + overshootable_nodes[0].parachutes = [] + overshootable_nodes[0].callbacks = [] + + # Feed overshootable time nodes trigger + interpolator = phase.solver.dense_output() + for overshootable_index, overshootable_node in self.time_iterator( + overshootable_nodes + ): + # Calculate state at node time + overshootable_node.y_sol = interpolator(overshootable_node.t) + + # Check for parachute triggers + if self.__check_overshootable_parachute_triggers( + overshootable_node, + overshootable_nodes, + overshootable_index, + phase, + phase_index, + node_index, + ): + return True + + # Process controllers at overshootable node + for controller in overshootable_node._controllers: + controller( + overshootable_node.t, + overshootable_node.y_sol, + self.solution, + self.sensors, + ) + + # Process sensors at overshootable node + if overshootable_node._component_sensors: + # Calculate u_dot for sensors at interpolated state + u_dot = phase.derivative(overshootable_node.t, overshootable_node.y_sol) + self.__measure_sensors( + overshootable_node._component_sensors, + u_dot, + overshootable_node.t, + overshootable_node.y_sol, + ) + return False + + def __check_overshootable_parachute_triggers( + self, + overshootable_node, + overshootable_nodes, + overshootable_index, + phase, + phase_index, + node_index, + ): + """Check for parachute triggers in overshootable nodes. + + Parameters + ---------- + overshootable_node : TimeNode + The current overshootable node. + overshootable_nodes : TimeNodes + The overshootable nodes collection. + overshootable_index : int + Index of the current overshootable node. + phase : FlightPhase + The current flight phase. + phase_index : int + The index of the current phase. + node_index : int + The index of the current node. + + Returns + ------- + bool + True if a parachute was triggered and the simulation should break. + """ + for parachute in overshootable_node.parachutes: + # Calculate and save pressure signal + ( + noisy_pressure, + height_above_ground_level, + ) = self.__calculate_and_save_pressure_signals( + parachute, + overshootable_node.t, + overshootable_node.y_sol[2], + ) + + # Check for parachute trigger + if not parachute.triggerfunc( + noisy_pressure, + height_above_ground_level, + overshootable_node.y_sol, + self.sensors, + ): + return False # Early exit, parachute not triggred + + # Remove parachute from flight parachutes + self.parachutes.remove(parachute) + + # Create phase for time after detection and before inflation + # Must only be created if parachute has any lag + i = 1 + if parachute.lag != 0: + self.flight_phases.add_phase( + overshootable_node.t, + phase.derivative, + clear=True, + index=phase_index + i, + ) + i += 1 + + # Create flight phase for time after inflation + callbacks = [ + lambda self, parachute_cd_s=parachute.cd_s: setattr( + self, "parachute_cd_s", parachute_cd_s + ) + ] + self.flight_phases.add_phase( + overshootable_node.t + parachute.lag, + self.u_dot_parachute, + callbacks, + clear=False, + index=phase_index + i, + ) + + # Rollback history + self.t = overshootable_node.t + self.y_sol = overshootable_node.y_sol + self.solution[-1] = [overshootable_node.t, *overshootable_node.y_sol] + + # Prepare to leave loops and start new flight phase + overshootable_nodes.flush_after(overshootable_index) + phase.time_nodes.flush_after(node_index) + phase.time_nodes.add_node(self.t, [], [], []) + phase.solver.status = "finished" + + # Save parachute event + self.parachute_events.append([self.t, parachute]) + return True + + return False + def __calculate_and_save_pressure_signals(self, parachute, t, z): """Gets noise and pressure signals and saves them in the parachute object given the current time and altitude. @@ -1199,16 +1439,11 @@ def __init_controllers(self): """Initialize controllers and sensors""" self._controllers = self.rocket._controllers[:] self.sensors = self.rocket.sensors.get_components() - if self._controllers or self.sensors: - if self.time_overshoot: # pragma: no cover - self.time_overshoot = False - warnings.warn( - "time_overshoot has been set to False due to the presence " - "of controllers or sensors. " - ) - # reset controllable object to initial state (only airbrakes for now) - for air_brakes in self.rocket.air_brakes: - air_brakes._reset() + # Note: time_overshoot now supports both controllers and sensors + + # reset controllable object to initial state (only airbrakes for now) + for air_brakes in self.rocket.air_brakes: + air_brakes._reset() self.sensor_data = {} for sensor in self.sensors: diff --git a/tests/fixtures/flight/flight_fixtures.py b/tests/fixtures/flight/flight_fixtures.py index da47e6cdb..a83599f37 100644 --- a/tests/fixtures/flight/flight_fixtures.py +++ b/tests/fixtures/flight/flight_fixtures.py @@ -245,6 +245,22 @@ def flight_calisto_air_brakes(calisto_air_brakes_clamp_on, example_plain_env): ) +@pytest.fixture +def flight_calisto_air_brakes_time_overshoot( + calisto_air_brakes_clamp_on, example_plain_env +): + """Same as flight_calisto_air_brakes but with time_overshoot=True.""" + return Flight( + rocket=calisto_air_brakes_clamp_on, + environment=example_plain_env, + rail_length=5.2, + inclination=85, + heading=0, + time_overshoot=True, + terminate_on_apogee=True, + ) + + @pytest.fixture def flight_calisto_with_sensors(calisto_with_sensors, example_plain_env): """A rocketpy.Flight object of the Calisto rocket. This uses the calisto diff --git a/tests/integration/test_environment.py b/tests/integration/test_environment.py index e4c6b07f5..60a401007 100644 --- a/tests/integration/test_environment.py +++ b/tests/integration/test_environment.py @@ -160,9 +160,8 @@ def test_nam_atmosphere(mock_show, example_spaceport_env): # pylint: disable=un @pytest.mark.slow @patch("matplotlib.pyplot.show") def test_rap_atmosphere(mock_show, example_spaceport_env): # pylint: disable=unused-argument - today = date.today() now = datetime.now(timezone.utc) - example_spaceport_env.set_date((today.year, today.month, today.day, now.hour)) + example_spaceport_env.set_date((now.year, now.month, now.day, now.hour)) example_spaceport_env.set_atmospheric_model(type="Forecast", file="RAP") assert example_spaceport_env.all_info() is None diff --git a/tests/integration/test_flight.py b/tests/integration/test_flight.py index c47b9b124..11f25927b 100644 --- a/tests/integration/test_flight.py +++ b/tests/integration/test_flight.py @@ -228,7 +228,7 @@ def test_liquid_motor_flight(mock_show, flight_calisto_liquid_modded): # pylint @pytest.mark.slow @patch("matplotlib.pyplot.show") -def test_time_overshoot(mock_show, calisto_robust, example_spaceport_env): # pylint: disable=unused-argument +def test_time_overshoot_false(mock_show, calisto_robust, example_spaceport_env): # pylint: disable=unused-argument """Test the time_overshoot parameter of the Flight class. This basically calls the all_info() method for a simulation without time_overshoot and checks if it returns None. It is not testing if the values are correct, @@ -418,6 +418,19 @@ def test_air_brakes_flight(mock_show, flight_calisto_air_brakes): # pylint: dis assert air_brakes.prints.all() is None +@patch("matplotlib.pyplot.show") +def test_air_brakes_flight_with_overshoot( + mock_show, flight_calisto_air_brakes_time_overshoot +): # pylint: disable=unused-argument + """ + Same as test_air_brakes_flight but with time_overshoot=True. + """ + test_flight = flight_calisto_air_brakes_time_overshoot + air_brakes = test_flight.rocket.air_brakes[0] + assert air_brakes.plots.all() is None + assert air_brakes.prints.all() is None + + @patch("matplotlib.pyplot.show") def test_initial_solution(mock_show, example_plain_env, calisto_robust): # pylint: disable=unused-argument """Tests the initial_solution option of the Flight class. This test simply diff --git a/tests/integration/test_plots.py b/tests/integration/test_plots.py index 7595855e4..933737c4b 100644 --- a/tests/integration/test_plots.py +++ b/tests/integration/test_plots.py @@ -74,10 +74,10 @@ def test_compare_flights(mock_show, mock_figure_show, calisto, example_plain_env ) calisto.set_rail_buttons(-0.5, 0.2) - inclinations = [60, 70, 80, 90] - headings = [0, 45, 90, 180] + inclinations = [60, 90] + headings = [0, 180] flights = [] - # Create (4 * 4) = 16 different flights to be compared + # Create (2 * 2) = 4 different flights to be compared for heading in headings: for inclination in inclinations: flight = Flight(