import abc
from typing import Any, Tuple

import numpy as np
from math import sqrt

from igp2.planning.node import Node

class Policy(abc.ABC):
    """ Abstract class for implementing various selection policies """

    @abc.abstractmethod
    def select(self, node: "Node") -> Tuple[Any, int]:
        """ Select an action from the node's list of actions using its Q-values.

        Concrete policies must override this method; the class cannot be
        instantiated until they do.

        Args:
            node: the search-tree node whose actions are chosen from

        Returns:
            the action and its index in the list of actions of the node

        Raises:
            NotImplementedError: if the base implementation is invoked
                (e.g. through ``super().select(node)``)
        """
        raise NotImplementedError


class MaxPolicy(Policy):
    """ Policy selecting the action with highest Q-value at a node. """

    def select(self, node: "Node") -> Tuple[Any, int]:
        """ Select the action with the highest Q-value.

        Ties are broken in favour of the lowest index, which is how
        ``np.argmax`` behaves.

        Args:
            node: node whose ``q_values`` are index-aligned with ``actions``

        Returns:
            the greedy action and its index as a plain Python int
        """
        # np.argmax yields a NumPy integer; cast it so the index matches the
        # documented Tuple[Any, int] contract of Policy.select.
        idx = int(np.argmax(node.q_values))
        return node.actions[idx], idx


class UCB1(Policy):
    """ Policy implementing the UCB1 selection policy.

    Ref: Auer, Cesa-Bianchi & Fischer (2002), "Finite-time Analysis of the
    Multiarmed Bandit Problem", Machine Learning 47, 235-256.
    """

    def __init__(self, c: float = sqrt(2)):
        """ Initialise new UCB1 policy

        Args:
            c: the exploration parameter
        """
        self.c = c

    def select(self, node: "Node") -> Tuple[Any, int]:
        """ Select the action maximising Q + c * sqrt(ln(N) / n).

        N is ``node.state_visits`` and n is the per-action entry of
        ``node.action_visits``. Actions that have never been tried (n == 0)
        receive an infinite score, so the first unvisited action is always
        chosen before any visited one.

        Args:
            node: node exposing ``q_values``, ``state_visits``,
                ``action_visits`` and index-aligned ``actions``

        Returns:
            the selected action and its index as a plain Python int
        """
        q_values = np.asarray(node.q_values, dtype=float)
        action_visits = np.asarray(node.action_visits, dtype=float)
        # Unvisited actions divide by zero and can yield inf or NaN (0/0,
        # 0 * inf). Those entries are overwritten by np.where below, so
        # silence both the divide and invalid-value warnings here.
        with np.errstate(divide="ignore", invalid="ignore"):
            bonus = self.c * np.sqrt(np.log(node.state_visits) / action_visits)
        # Explicit +inf instead of relying on np.argmax picking the first NaN.
        values = np.where(action_visits == 0, np.inf, q_values + bonus)
        idx = int(np.argmax(values))
        return node.actions[idx], idx