Source code for igp2.planning.node

import copy
from collections import defaultdict
from typing import Dict, List, Tuple

import logging
import numpy as np

from igp2.planning.mctsaction import MCTSAction
from igp2.planning.reward import Reward
from igp2.core.results import RunResult
from igp2.core.agentstate import AgentState
from igp2.core.util import copy_agents_dict

logger = logging.getLogger(__name__)


class Node:
    """ Represents a search node in the MCTS tree. Stores all relevant information for computation of Q-values
    and action selection. Keys must be hashable.

    During search, a Node must be expanded before it can be added to a Tree. Children of the node are stored
    in a dictionary, with the keys being the children's keys and the values the child nodes themselves.
    """

    def __init__(self, key: Tuple, state: Dict[int, AgentState], actions: List[MCTSAction]):
        if key is None or not isinstance(key, Tuple):
            raise TypeError("Node key must be a tuple.")

        self._key = key
        self._state = state
        self._actions = actions
        self._children = {}

        self._state_visits = 0
        self._q_values = None
        self._action_visits = None

        self._run_result = None
        self._reward_results = defaultdict(list)

    def __repr__(self):
        return str(self.key)

    def __deepcopy__(self, memodict={}):
        """ Overwrite standard deepcopy to avoid infinite recursion with run results. """
        cls = self.__class__
        result = cls.__new__(cls)
        memodict[id(self)] = result
        for k, v in self.__dict__.items():
            if k == "_run_result" and isinstance(v, RunResult):
                run_result = RunResult.__new__(RunResult)
                memodict[id(run_result)] = run_result
                for rk, rv in self._run_result.__dict__.items():
                    if rk == "agents":
                        setattr(run_result, rk, copy_agents_dict(rv, memodict))
                    else:
                        setattr(run_result, rk, copy.deepcopy(rv, memodict))
                setattr(result, k, run_result)
            else:
                setattr(result, k, copy.deepcopy(v, memodict))
        return result
    def expand(self):
        """ Expand the node by initialising the Q-values and visit counts for each available action. """
        if self._actions is None:
            raise TypeError("Cannot expand node without actions")
        self._q_values = np.zeros(len(self._actions))
        self._action_visits = np.zeros(len(self._actions), dtype=np.int32)
    def add_child(self, child: "Node"):
        """ Add a new child to the dictionary of children. """
        self._children[child.key] = child
    def add_reward_result(self, key: Tuple[str], reward_results: Reward):
        """ Add a new reward outcome to the node if the search has ended here. """
        action = key[-1]
        assert action in self.actions_names, f"Action {action} not in Node {self._key}"
        self._reward_results[action].append(reward_results)
    def store_q_values(self):
        """ Save a copy of the current Q-values into the node's run result. """
        if self._run_result is not None:
            self._run_result.q_values = copy.copy(self.q_values)
    @property
    def q_values(self) -> np.ndarray:
        """ Return the Q-values corresponding to each action. """
        return self._q_values

    @q_values.setter
    def q_values(self, value: np.ndarray):
        self._q_values = value

    @property
    def key(self) -> Tuple:
        """ Unique hashable key identifying the node and the sequence of actions that leads to it. """
        return self._key

    @property
    def state(self) -> Dict[int, AgentState]:
        """ Return the state corresponding to this node. """
        return self._state

    @property
    def actions(self) -> List[MCTSAction]:
        """ Return the possible actions in the state of this node. """
        return self._actions

    @property
    def actions_names(self) -> List[str]:
        """ Return the human-readable names of the actions in the node. """
        return [action.__repr__() for action in self._actions]

    @property
    def state_visits(self) -> int:
        """ Return the number of times this state has been selected. """
        return self._state_visits

    @state_visits.setter
    def state_visits(self, value: int):
        self._state_visits = value

    @property
    def action_visits(self) -> np.ndarray:
        """ Return the number of times each action has been selected in this node. """
        return self._action_visits

    @property
    def children(self) -> Dict[Tuple, "Node"]:
        """ Return the dictionary of children. """
        return self._children

    @property
    def is_leaf(self) -> bool:
        """ Return true if the node has no children. """
        return len(self._children) == 0

    @property
    def run_result(self) -> RunResult:
        """ Return the result of the simulated run for this node. """
        return self._run_result

    @run_result.setter
    def run_result(self, value: RunResult):
        self._run_result = value

    @property
    def reward_results(self) -> Dict[str, List[Reward]]:
        """ Return a dictionary of reward outcomes where the keys are all possible actions in the node. """
        return self._reward_results

    @property
    def descendants(self):
        """ Return all descendants of this node. """
        descendants = []
        for key, child in self.children.items():
            descendants.append((key, child))
            descendants.extend(child.descendants)
        return descendants
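
For context, the sketch below illustrates how a search loop might drive a Node through the usual MCTS phases of selection and backup. It is a hypothetical example, not part of this module: the UCB1 exploration policy and the running-average update shown here are the textbook MCTS formulas, while IGP2's actual selection and backpropagation logic lives elsewhere in igp2.planning. It assumes the node has already been constructed and that expand() has been called, so q_values and action_visits are allocated.

def ucb1_select(node: Node, c: float = 2 ** 0.5) -> int:
    """ Pick the index of the action maximising the UCB1 criterion
    computed from the node's Q-values and visit counts. """
    total_visits = node.action_visits.sum()
    exploration = np.sqrt(np.log(total_visits + 1) / (node.action_visits + 1))
    return int(np.argmax(node.q_values + c * exploration))


def backup(node: Node, action_idx: int, reward: float):
    """ Update the node's statistics after a simulated rollout: increment
    the visit counts, then move the Q-value of the selected action towards
    the observed reward by a running average. """
    node.state_visits += 1
    node.action_visits[action_idx] += 1
    node.q_values[action_idx] += (
        reward - node.q_values[action_idx]
    ) / node.action_visits[action_idx]

On a freshly expanded node all visit counts are zero, so the `+ 1` terms guard both the logarithm and the division; this is one common convention, not necessarily the one IGP2 uses.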