Source code for igp2.planning.rollout

import logging
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

from gui.tracks_import import calculate_rotated_bboxes
from igp2.core.agentstate import AgentState, AgentMetadata
from igp2.core.trajectory import Trajectory, StateTrajectory
from igp2.core.goal import Goal
from igp2.core.vehicle import Observation
from igp2.agents.agent import Agent
from igp2.agents.macro_agent import MacroAgent
from igp2.agents.traffic_agent import TrafficAgent
from igp2.agents.trajectory_agent import TrajectoryAgent
from igp2.opendrive.map import Map
from igp2.opendrive.plot_map import plot_map
from igp2.planlibrary.macro_action import MacroAction

logger = logging.getLogger(__name__)

class Rollout:
    """Lightweight environment simulator useful for rolling out scenarios in MCTS.

    One agent is designated as the ego vehicle, while the other agents follow
    predefined trajectories calculated during goal recognition. Simulation is
    performed at a given frequency with collision checking.
    """

    def __init__(self,
                 ego_id: int,
                 initial_frame: Dict[int, AgentState],
                 metadata: Dict[int, AgentMetadata],
                 scenario_map: Map,
                 fps: int = 10,
                 open_loop_agents: bool = False,
                 trajectory_agents: bool = False,
                 t_max: int = 300):
        """Initialise new light-weight simulator with the given params.

        Args:
            ego_id: ID of the ego vehicle
            initial_frame: initial state of the environment
            metadata: metadata describing the agents in the environment
            scenario_map: current road layout
            fps: frame rate of simulation
            open_loop_agents: Whether non-ego agents follow open-loop control
            trajectory_agents: Whether to use predicted trajectories directly
                or CL macro actions for non-egos
            t_max: Maximum rollout time step length

        Raises:
            AssertionError: If ego_id is missing from initial_frame or metadata.
        """
        assert ego_id in initial_frame, f"Ego ID {ego_id} is not in the initial frame!"
        assert ego_id in metadata, f"Ego ID {ego_id} not among given metadata!"

        # Open-loop control is only passed on to TrajectoryAgents (see
        # _create_agents), so the flag has no effect without trajectory_agents.
        if open_loop_agents and not trajectory_agents:
            logger.warning("Set trajectory_agents to True, if using open loop rollouts.")

        self._scenario_map = scenario_map
        self._ego_id = ego_id
        self._initial_frame = initial_frame
        self._metadata = metadata
        self._fps = fps
        self._open_loop = open_loop_agents
        self._trajectory_agents = trajectory_agents
        self._t_max = t_max

        # Must come last: agent creation reads the attributes assigned above.
        self._agents = self._create_agents()

    def update_trajectory(self, agent_id: int, new_trajectory: Trajectory, new_plan: List[MacroAction]):
        """Update the predicted trajectory of the non-ego agent.

        Has no effect for ego or if agent_id not in agents.

        Args:
            agent_id: ID of agent to update
            new_trajectory: new trajectory for agent; used when the rollout
                was created with trajectory_agents=True
            new_plan: The macro action plan that generated new_trajectory;
                used when the rollout was created with trajectory_agents=False
        """
        if agent_id in self._agents and agent_id != self._ego_id:
            if self._trajectory_agents:
                self._agents[agent_id].set_trajectory(new_trajectory)
            else:
                self._agents[agent_id].set_macro_actions(new_plan)

    def update_ego_action(self, action: MacroAction, args: Dict, frame: Dict[int, AgentState]) -> MacroAction:
        """Update the current macro action of the ego vehicle.

        Args:
            action: new macro action to execute
            args: MA initialisation arguments
            frame: Current state of the environment

        Returns:
            The currently executed MA of the ego
        """
        observation = Observation(frame, self._scenario_map)
        return self._agents[self._ego_id].update_macro_action(action, args, observation)

    def update_ego_goal(self, goal: Goal):
        """Update the final goal of the ego vehicle.

        Args:
            goal: new goal to reach
        """
        self._agents[self._ego_id].update_goal(goal)

    def reset(self):
        """Reset the environment by calling reset() on every agent."""
        for agent in self._agents.values():
            agent.reset()

    def run(self, start_frame: Dict[int, AgentState], plot_rollout: bool = False) \
            -> Tuple[StateTrajectory, Dict[int, AgentState], bool, bool, List[Agent]]:
        """Execute current macro action of ego and forward the state of the
        environment with collision checking.

        The loop stops when t_max steps have been simulated, the ego is no
        longer alive, the ego reached its goal, or the ego reports being done.

        Args:
            start_frame: State of the environment the rollout starts from
            plot_rollout: If True, call plot() on every fifth time step

        Returns:
            A 5-tuple giving the slice of the ego's closed-loop trajectory
            driven during this call, the final frame of the simulation, whether
            the ego has reached its goal, whether the ego is still alive, and
            the list of agents the ego collided with in the last simulated step
            (empty if there was no collision).
        """
        ego = self._agents[self._ego_id]
        current_observation = Observation(start_frame, self._scenario_map)

        goal_reached = False
        collisions = []
        # Remember where this rollout begins in the ego's closed-loop
        # trajectory so that only the newly driven part is returned.
        start_time = len(ego.trajectory_cl.states)
        t = 0

        while t < self._t_max and ego.alive and not goal_reached and not ego.done(current_observation):
            new_frame = {}

            # Update agent states
            for agent_id, agent in self._agents.items():
                if not agent.alive:
                    continue

                if agent.done(current_observation):
                    agent.alive = False
                    continue

                new_state = agent.next_state(current_observation)
                agent.trajectory_cl.add_state(new_state, reload_path=False)
                new_frame[agent_id] = new_state

                # An agent whose new position is not on any road is retired.
                agent.alive = len(self._scenario_map.roads_at(new_state.position)) > 0

            current_observation = Observation(new_frame, self._scenario_map)

            collisions = self._check_collisions(ego)
            if collisions:
                ego.alive = False
            else:
                goal_reached = ego.goal.reached(ego.state.position)

            if plot_rollout and t % 5 == 0:
                self.plot(t)
            t += 1

        # States were added with reload_path=False, so path and velocity are
        # computed once here instead of on every step.
        ego.trajectory_cl.calculate_path_and_velocity()
        driven_trajectory = ego.trajectory_cl.slice(start_time, start_time + t)
        return driven_trajectory, current_observation.frame, goal_reached, ego.alive, collisions

    def _create_agents(self) -> Dict[int, Agent]:
        """Initialise new agents from the initial frame.

        The ego is a MacroAgent. Each non-ego is a TrajectoryAgent if the
        rollout uses trajectory agents, and a TrafficAgent otherwise.
        """
        agents = {}
        for aid, state in self._initial_frame.items():
            if aid == self._ego_id:
                agents[aid] = MacroAgent(aid, state, fps=self._fps)
            elif self._trajectory_agents:
                agents[aid] = TrajectoryAgent(aid, state, fps=self._fps, open_loop=self._open_loop)
            else:
                agents[aid] = TrafficAgent(aid, state, fps=self._fps)
        return agents

    def _check_collisions(self, ego: Agent) -> List[Agent]:
        """Check for collisions with the given vehicle in the environment.

        Every living agent whose vehicle overlaps the ego's vehicle is marked
        as not alive and included in the result.

        Args:
            ego: The agent to check against all other living agents

        Returns:
            The agents colliding with ego; empty if there are none.
        """
        colliding_agents = []
        for agent_id, agent in self._agents.items():
            if agent_id == ego.agent_id or not agent.alive:
                continue

            if agent.vehicle.overlaps(ego.vehicle):
                agent.alive = False
                colliding_agents.append(agent)
        return colliding_agents

    def plot(self, t: int, axis: plt.Axes = None) -> plt.Axes:
        """Plot the current agents and the road layout for visualisation purposes.

        Args:
            t: Time step shown in the plot title
            axis: Axis to draw on. A new figure is created if None.

        Returns:
            The axis that was drawn on.
        """
        if axis is None:
            _, axis = plt.subplots()

        color_map_ego = plt.get_cmap("Reds")
        color_map_non_ego = plt.get_cmap("Blues")
        color_ego = "r"
        color_non_ego = "b"
        color_bar_non_ego = None

        plot_map(self._scenario_map, markings=True, ax=axis)

        for agent_id, agent in self._agents.items():
            if not agent.alive:
                continue

            if isinstance(agent, TrajectoryAgent):
                color = color_non_ego
                color_map = color_map_non_ego
                trajectory = agent.trajectory
            elif isinstance(agent, TrafficAgent):
                color = color_non_ego
                color_map = color_map_non_ego
                trajectory = agent.current_macro.current_maneuver.trajectory
            elif isinstance(agent, MacroAgent):
                color = color_ego
                color_map = color_map_ego
                trajectory = agent.current_macro.current_maneuver.trajectory
            else:
                # No known way to obtain a trajectory for this agent type;
                # skip it rather than reuse values from a previous iteration.
                continue
            path = trajectory.path
            velocity = trajectory.velocity

            vehicle = agent.vehicle
            # NOTE(review): the box-centre arguments were lost in the extracted
            # source. The agent's current position is used, consistent with the
            # label placement below - confirm against upstream.
            center_x, center_y = agent.state.position
            bounding_box = calculate_rotated_bboxes(center_x, center_y,
                                                    vehicle.length, vehicle.width,
                                                    vehicle.heading)
            pol = plt.Polygon(bounding_box[0], color=color)
            axis.add_patch(pol)

            agent_plot = axis.scatter(path[:, 0], path[:, 1], c=velocity, cmap=color_map, vmin=-4, vmax=20, s=8)

            # Draw on the given axis explicitly: the pyplot state-machine calls
            # target the current axes, which may differ from a caller's axis.
            if isinstance(agent, MacroAgent) and not isinstance(agent, TrafficAgent):
                plt.colorbar(agent_plot, ax=axis)
                axis.text(0, 0.1, 'Current Velocity: ' + str(agent.state.speed),
                          horizontalalignment='left', verticalalignment='bottom',
                          transform=axis.transAxes)
                axis.text(0, 0.05, 'Current Macro Action: ' + repr(agent.current_macro),
                          horizontalalignment='left', verticalalignment='bottom',
                          transform=axis.transAxes)
                axis.text(0, 0, 'Current Maneuver: ' + repr(agent.current_macro.current_maneuver),
                          horizontalalignment='left', verticalalignment='bottom',
                          transform=axis.transAxes)
            elif isinstance(agent, TrajectoryAgent) and color_bar_non_ego is None:
                # Only one colour bar is drawn for all trajectory agents.
                color_bar_non_ego = plt.colorbar(agent_plot, ax=axis)

            axis.text(*agent.state.position, agent_id)

        axis.set_title(f"T={t}")
        return axis

    @property
    def ego_id(self) -> int:
        """ ID of the ego vehicle"""
        return self._ego_id

    @property
    def agents(self) -> Dict[int, Agent]:
        """ Return current agents of the environment """
        return self._agents

    @property
    def initial_frame(self) -> Dict[int, AgentState]:
        """ Return the initial state of the environment """
        return self._initial_frame

    @property
    def fps(self) -> int:
        """ Executing frame rate of the simulator"""
        return self._fps

    @property
    def metadata(self) -> Dict[int, AgentMetadata]:
        """ Metadata of agents in the current frame """
        return self._metadata