import copy
import traceback
import logging
from typing import List, Dict, Tuple
from igp2.opendrive.map import Map
from igp2.recognition.goalprobabilities import GoalsProbabilities
from igp2.planning.tree import Tree
from igp2.planning.rollout import Rollout
from igp2.planning.node import Node
from igp2.planning.mctsaction import MCTSAction
from igp2.planning.reward import Reward
from igp2.planlibrary.macro_action import MacroActionFactory
from igp2.core.results import MCTSResult, AllMCTSResult, RunResult
from igp2.core.util import copy_agents_dict
from igp2.core.goal import Goal
from igp2.core.agentstate import AgentState, AgentMetadata
logger = logging.getLogger(__name__)
class MCTS:
""" Class implementing single-threaded MCTS search over environment states with macro actions. """
def __init__(self,
scenario_map: Map,
n_simulations: int = 30,
max_depth: int = 5,
reward: Reward = None,
open_loop_rollout: bool = False,
trajectory_agents: bool = True,
fps: int = 10,
env_fps: int = 20,
store_results: str = None,
**kwargs):
""" Initialise a new MCTS planner over states and macro-actions.
Args:
n_simulations: number of rollout simulations to run.
max_depth: maximum search depth.
scenario_map: current road layout.
reward: class to calculate trajectory reward for ego.
open_loop_rollout: Whether to use open-loop predictions directly instead of closed-loop control.
            trajectory_agents: Whether non-ego agents follow sampled trajectories (True) or execute macro-action plans (False) during simulation.
            fps: Rollout simulation frequency.
            env_fps: Environment simulation frequency.
            store_results: Whether to store search results; one of None, "final", or "all".
Keyword Args:
tree_type: Type of Tree to use for the search. Allows overwriting standard behaviour.
node_type: Type of Node to use in the Tree. Allows overwriting standard behaviour.
action_type: Type of MCTSAction to use for the search. Allows overwriting standard behaviour.
rollout_type: Type of Rollout simulator to use for the search. Allows overwriting standard behaviour.
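
        Example:
            A minimal construction sketch; the OpenDrive file path below is a
            placeholder and not part of this module.

            >>> scenario_map = Map.parse_from_opendrive("scenarios/maps/example.xodr")
            >>> mcts = MCTS(scenario_map, n_simulations=30, max_depth=5, store_results="final")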
"""
self.n = n_simulations
self.d_max = max_depth
self.scenario_map = scenario_map
self.reward = reward if reward is not None else Reward()
self.open_loop_rollout = open_loop_rollout
self.trajectory_agents = trajectory_agents
self.fps = fps
self.env_fps = env_fps
self.tree_type = kwargs.get("tree_type", Tree)
self.node_type = kwargs.get("node_type", Node)
self.action_type = kwargs.get("action_type", MCTSAction)
self.rollout_type = kwargs.get("rollout_type", Rollout)
self.store_results = store_results
self.results = None
self.reset()
def search(self,
agent_id: int,
goal: Goal,
frame: Dict[int, AgentState],
meta: Dict[int, AgentMetadata],
predictions: Dict[int, GoalsProbabilities],
debug: bool = False) -> Tuple[List[MCTSAction], Tree]:
""" Run MCTS search for the given agent
Args:
agent_id: agent to plan for
goal: end goal of the vehicle
frame: current (observed) state of the environment
meta: metadata of agents present in frame
predictions: dictionary of goal predictions for agents in frame
debug: Whether to plot rollouts.
Returns:
            A tuple of the list of macro actions encoding the optimal plan for the ego agent
            (given the current goal predictions for other agents) and the search tree.
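
        Example:
            An illustrative call; `frame`, `meta` and `predictions` are assumed to
            have been produced upstream (e.g. by observation and goal recognition).

            >>> plan, tree = mcts.search(agent_id=0, goal=goal, frame=frame,
            ...                          meta=meta, predictions=predictions)
            >>> str(plan[0])  # first macro action of the optimal plan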
"""
self.reset()
simulator = self.rollout_type(ego_id=agent_id,
initial_frame=frame,
metadata=meta,
scenario_map=self.scenario_map,
fps=self.fps,
open_loop_agents=self.open_loop_rollout,
trajectory_agents=self.trajectory_agents)
simulator.update_ego_goal(goal)
# 1. Create tree root from current frame
tree = self._create_tree(agent_id, frame, goal, predictions)
for k in range(self.n):
logger.info(f"MCTS Iteration {k + 1}/{self.n}")
self._rollout(k, agent_id, goal, tree, simulator, debug, predictions)
simulator.reset()
self.reward.reset()
tree.on_finish()
final_plan, optimal_trace = tree.select_plan()
logger.info(f"Final plan: {final_plan}")
tree.print()
if self.store_results == "final":
self.results.tree = tree
elif self.store_results == "all":
self.results.final_plan = final_plan
self.results.predictions = predictions
self.results.optimal_trace = optimal_trace
return final_plan, tree
def _sample_agents(self, aid: int, predictions: Dict[int, GoalsProbabilities]):
""" Perform sampling of goals and agent trajectories. """
goal = predictions[aid].sample_goals()[0]
trajectory, plan = predictions[aid].sample_trajectories_to_goal(goal)
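        # sample_trajectories_to_goal may yield no feasible trajectory (None);
        # otherwise unwrap the first sampled trajectory-plan pair below.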
if trajectory is not None:
trajectory, plan = trajectory[0], plan[0]
return goal, trajectory, plan
def _reset_results(self):
""" Resets the stored results in the MCTS instance."""
if self.store_results is None:
self.results = None
elif self.store_results == 'final':
self.results = MCTSResult()
elif self.store_results == 'all':
self.results = AllMCTSResult()
def reset(self):
""" Reset the MCTS planner. """
self._reset_results()
self.reward.reset()
def _create_tree(self,
agent_id: int,
frame: Dict[int, AgentState],
goal: Goal,
predictions: Dict[int, GoalsProbabilities]):
""" Creates a new MCTS tree to store results. """
root = self._create_node(self.to_key(None), agent_id, frame, goal)
tree = self.tree_type(root)
return tree
def _rollout(self, k: int, agent_id: int, goal: Goal, tree: Tree,
simulator: Rollout, debug: bool, predictions: Dict[int, GoalsProbabilities]):
""" Perform a single rollout of the MCTS search and store results."""
# 3-6. Sample goal and trajectory
samples = {}
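        # Map each non-ego agent to its sampled (goal, trajectory) pair so the
        # samples can be stored alongside the search results for this rollout.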
for aid, agent in simulator.agents.items():
if aid == simulator.ego_id:
continue
agent_goal, trajectory, plan = self._sample_agents(aid, predictions)
simulator.update_trajectory(aid, trajectory, plan)
samples[aid] = (agent_goal, trajectory)
logger.debug(f" Agent {aid} sample: {plan}")
final_key = self._run_simulation(agent_id, goal, tree, simulator, debug)
logger.debug(f" Final key: {final_key}")
if self.store_results == "all":
logger.debug(f" Storing MCTS search results for iteration {k}.")
mcts_result = MCTSResult(copy.deepcopy(tree), samples, final_key)
self.results.add_data(mcts_result)
    def _run_simulation(self, agent_id: int, goal: Goal, tree: Tree, simulator: Rollout, debug: bool) -> tuple:
        """ Run a single forward simulation from the root, selecting macro actions with UCB1
        until a terminal reward is assigned or the maximum search depth is reached, then
        back-propagate the reward through the tree and return the key of the final node. """
depth = 0
node = tree.root
key = node.key
current_frame = node.state
actions = []
while depth < self.d_max:
logger.debug(f" Rollout {depth + 1}/{self.d_max}")
node.state_visits += 1
final_frame = None
force_reward = False
try:
# 8. Select applicable macro action with UCB1
action = tree.select_action(node)
actions.append(action)
simulator.update_ego_action(action.macro_action_type, action.ma_args, current_frame)
logger.debug(f" Action selection: {action} from {node.actions_names} in {key}")
# 9. Forward simulate environment
trajectory, final_frame, goal_reached, alive, collisions = \
simulator.run(current_frame, debug)
collided_agents_ids = [col.agent_id for col in collisions]
if self.store_results is not None:
agents_copy = copy_agents_dict(simulator.agents, agent_id)
node.run_result = RunResult(
agents_copy,
simulator.ego_id,
trajectory,
collided_agents_ids,
goal_reached,
action)
# 10-16. Reward computation
r = self.reward(collisions=collisions,
alive=alive,
ego_trajectory=simulator.agents[agent_id].trajectory_cl if goal_reached else None,
goal=goal,
depth_reached=depth == self.d_max - 1)
if r is not None:
logger.debug(f" Reward components: {self.reward.reward_components}")
force_reward = len(collisions) > 0
except Exception as e:
logger.debug(f" Rollout failed due to error: {str(e)}")
logger.debug(traceback.format_exc())
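                # Treat a failed rollout as maximally undesirable so this action
                # sequence is penalised during back-propagation.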
r = -float("inf")
# Create new node at the end of rollout
key = self.to_key(actions)
# 17-19. Back-propagation
if r is not None:
logger.info(f" Rollout finished: r={r}; d={depth + 1}")
node.add_reward_result(key, copy.deepcopy(self.reward))
tree.backprop(r, key, force_reward)
break
# 20. Update state variables
current_frame = final_frame
if key not in tree:
child = self._create_node(key, agent_id, current_frame, goal)
tree.add_child(node, child)
node = tree[key]
depth += 1
return key
def _create_node(self,
key: Tuple,
agent_id: int,
frame: Dict[int, AgentState],
goal: Goal) -> Node:
""" Create a new node and expand it.
Args:
key: Key to assign to the node
agent_id: Agent we are searching for
frame: Current state of the environment
goal: Goal of the agent with agent_id
"""
actions = []
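        # Enumerate every macro action applicable in the agent's current state and
        # every admissible argument set for it, each becoming an action of the node.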
for macro_action in MacroActionFactory.get_applicable_actions(frame[agent_id], self.scenario_map):
for ma_args in macro_action.get_possible_args(frame[agent_id], self.scenario_map, goal):
actions.append(self.action_type(macro_action, ma_args))
node = self.node_type(key, frame, actions[::-1])
node.expand()
return node
def to_key(self, plan: List[MCTSAction] = None) -> Tuple[str, ...]:
""" Convert a list of MCTS actions to an MCTS key. """
        if plan is None:
            return ("Root",)
        return ("Root",) + tuple(str(action) for action in plan)