Source code for igp2.planning.mcts

import igp2 as ip
import copy
import traceback
import logging
from typing import List, Dict, Tuple, Type

from igp2.planning.tree import Tree
from igp2.planning.rollout import Rollout
from igp2.planning.node import Node
from igp2.planning.mctsaction import MCTSAction
from igp2.planning.reward import Reward
from igp2.core.util import copy_agents_dict

logger = logging.getLogger(__name__)


class MCTS:
    """ Class implementing single-threaded MCTS search over environment states with macro actions. """

    def __init__(self,
                 scenario_map: ip.Map,
                 n_simulations: int = 30,
                 max_depth: int = 5,
                 reward: Reward = None,
                 open_loop_rollout: bool = False,
                 trajectory_agents: bool = True,
                 fps: int = 10,
                 store_results: str = None,
                 tree_type: Type[Tree] = None,
                 node_type: Type[Node] = None,
                 action_type: Type[MCTSAction] = None):
        """ Initialise a new MCTS planner over states and macro-actions.

        Args:
            scenario_map: current road layout
            n_simulations: number of rollout simulations to run
            max_depth: maximum search depth
            reward: class used to calculate the trajectory reward for the ego
            open_loop_rollout: whether to use open-loop predictions directly instead of closed-loop control
            trajectory_agents: whether to use trajectories or plans for non-ego agents in simulation
            fps: rollout simulation frequency
            store_results: which rollout results to store; None, "final" or "all"
            tree_type: type of Tree to use for the search; allows overwriting standard behaviour
            node_type: type of Node to use in the Tree; allows overwriting standard behaviour
            action_type: type of MCTSAction to use in the search; allows overwriting standard behaviour
        """
        self.n = n_simulations
        self.d_max = max_depth
        self.scenario_map = scenario_map
        self.reward = reward if reward is not None else Reward()
        self.open_loop_rollout = open_loop_rollout
        self.trajectory_agents = trajectory_agents
        self.fps = fps
        self.tree_type = tree_type if tree_type is not None else Tree
        self.node_type = node_type if node_type is not None else Node
        self.action_type = action_type if action_type is not None else MCTSAction

        self.store_results = store_results
        self.results = None
        self.reset_results()
    def reset_results(self):
        """ Reset the stored results in the MCTS instance. """
        if self.store_results is None:
            self.results = None
        elif self.store_results == "final":
            self.results = ip.MCTSResult()
        elif self.store_results == "all":
            self.results = ip.AllMCTSResult()
    def search(self,
               agent_id: int,
               goal: ip.Goal,
               frame: Dict[int, ip.AgentState],
               meta: Dict[int, ip.AgentMetadata],
               predictions: Dict[int, ip.GoalsProbabilities],
               debug: bool = False) -> List[MCTSAction]:
        """ Run MCTS search for the given agent.

        Args:
            agent_id: agent to plan for
            goal: end goal of the vehicle
            frame: current (observed) state of the environment
            meta: metadata of agents present in frame
            predictions: dictionary of goal predictions for agents in frame
            debug: whether to plot rollouts

        Returns:
            A list of macro actions encoding the optimal plan for the ego agent
            given the current goal predictions for other agents.
        """
        self.reset_results()
        self.reward.reset()

        simulator = Rollout(ego_id=agent_id,
                            initial_frame=frame,
                            metadata=meta,
                            scenario_map=self.scenario_map,
                            fps=self.fps,
                            open_loop_agents=self.open_loop_rollout,
                            trajectory_agents=self.trajectory_agents)
        simulator.update_ego_goal(goal)

        # 1. Create tree root from current frame
        root = self.create_node(MCTS.to_key(None), agent_id, frame, goal)
        tree = self.tree_type(root, predictions=predictions)

        for k in range(self.n):
            logger.info(f"MCTS Iteration {k + 1}/{self.n}")

            # 3-6. Sample goal and trajectory
            samples = {}
            for aid, agent in simulator.agents.items():
                if aid == simulator.ego_id:
                    continue

                agent_goal = predictions[aid].sample_goals()[0]
                trajectory, plan = predictions[aid].sample_trajectories_to_goal(agent_goal)
                if trajectory is not None:
                    trajectory, plan = trajectory[0], plan[0]
                    simulator.update_trajectory(aid, trajectory, plan)
                samples[aid] = (agent_goal, trajectory)
                logger.debug(f"Agent {aid} sample: {plan}")
            tree.set_samples(samples)

            final_key = self._run_simulation(agent_id, goal, tree, simulator, debug)

            if self.store_results == "all":
                logger.debug(f"Storing MCTS search results for iteration {k}.")
                mcts_result = ip.MCTSResult(copy.deepcopy(tree), samples, final_key)
                self.results.add_data(mcts_result)

            simulator.reset()
            self.reward.reset()

        tree.on_finish()

        final_plan = tree.select_plan()
        logger.info(f"Final plan: {final_plan}")
        tree.print()

        if self.store_results == "final":
            self.results.tree = tree
        elif self.store_results == "all":
            self.results.final_plan = final_plan

        return final_plan
    def _run_simulation(self, agent_id: int, goal: ip.Goal, tree: Tree, simulator: Rollout, debug: bool) -> tuple:
        depth = 0
        node = tree.root
        key = node.key
        current_frame = node.state

        actions = []
        while depth < self.d_max:
            logger.debug(f"Rollout {depth + 1}/{self.d_max}")
            node.state_visits += 1
            final_frame = None

            # 8. Select applicable macro action with UCB1
            action = tree.select_action(node)
            actions.append(action)
            simulator.update_ego_action(action.macro_action_type, action.ma_args, current_frame)
            logger.debug(f"Action selection: {key} -> {action} from {node.actions_names}")

            # 9. Forward simulate environment
            try:
                trajectory, final_frame, goal_reached, alive, collisions = \
                    simulator.run(current_frame, debug)
                collided_agents_ids = [col.agent_id for col in collisions]

                if self.store_results is not None:
                    agents_copy = copy_agents_dict(simulator.agents, agent_id)
                    node.run_result = ip.RunResult(
                        agents_copy, simulator.ego_id, trajectory,
                        collided_agents_ids, goal_reached, action)

                # 10-16. Reward computation
                r = self.reward(
                    collisions=collisions,
                    alive=alive,
                    ego_trajectory=simulator.agents[agent_id].trajectory_cl if goal_reached else None,
                    goal=goal,
                    depth_reached=depth == self.d_max - 1)
                if r is not None:
                    logger.debug(f"Reward components: {self.reward.reward_components}")
            except Exception as e:
                logger.debug(f"Rollout failed due to error: {str(e)}")
                logger.debug(traceback.format_exc())
                r = -float("inf")

            # Create new node at the end of rollout
            key = MCTS.to_key(actions)

            # 17-19. Back-propagation
            if r is not None:
                logger.info(f"Rollout finished: r={r}; d={depth + 1}")
                node.add_reward_result(key, copy.deepcopy(self.reward))
                tree.backprop(r, key)
                break

            # 20. Update state variables
            current_frame = final_frame
            if key not in tree:
                child = self.create_node(key, agent_id, current_frame, goal)
                tree.add_child(node, child)
            node = tree[key]
            depth += 1

        return key
    def create_node(self, key: Tuple, agent_id: int, frame: Dict[int, ip.AgentState], goal: ip.Goal) -> Node:
        """ Create a new node and expand it.

        Args:
            key: Key to assign to the node
            agent_id: Agent we are searching for
            frame: Current state of the environment
            goal: Goal of the agent with agent_id
        """
        actions = []
        for macro_action in ip.MacroActionFactory.get_applicable_actions(frame[agent_id], self.scenario_map):
            for ma_args in macro_action.get_possible_args(frame[agent_id], self.scenario_map, goal):
                actions.append(self.action_type(macro_action, ma_args))

        node = self.node_type(key, frame, actions)
        node.expand()
        return node
    @staticmethod
    def to_key(plan: List[MCTSAction] = None) -> Tuple[str, ...]:
        """ Convert a list of MCTS actions to an MCTS key. """
        if plan is None:
            return tuple(["Root"])
        return ("Root",) + tuple([action.__repr__() for action in plan])
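
The snippet below is a minimal usage sketch, not part of the module above: it only illustrates how the planner defined here might be constructed and queried. The map path is illustrative, the frame, metadata and prediction objects are placeholders assumed to be produced elsewhere (for example by igp2's goal recognition step), and ip.Map.parse_from_opendrive is assumed to be the map-loading entry point.

from typing import Dict

import igp2 as ip
from igp2.planning.mcts import MCTS

# Placeholder inputs; in a full pipeline these come from the observed scene
# and a goal-recognition step run over the non-ego agents.
scenario_map = ip.Map.parse_from_opendrive("scenarios/maps/town01.xodr")  # assumed loader; path is illustrative
frame: Dict[int, ip.AgentState] = ...                 # observed agent states, keyed by agent ID
meta: Dict[int, ip.AgentMetadata] = ...               # agent metadata, keyed by agent ID
predictions: Dict[int, ip.GoalsProbabilities] = ...   # goal predictions for non-ego agents
ego_goal: ip.Goal = ...                               # goal of the ego agent

# Plan for ego agent 0; store_results="final" keeps the final search tree in mcts.results.
mcts = MCTS(scenario_map, n_simulations=30, max_depth=5, store_results="final")
plan = mcts.search(agent_id=0, goal=ego_goal, frame=frame, meta=meta, predictions=predictions)
for action in plan:
    print(action)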