Source code for igp2.data.data_loaders

"""
Modified version of code from https://github.com/cbrewitt/av-goal-recognition/blob/master/core/
based on https://github.com/ika-rwth-aachen/drone-dataset-tools
"""

import abc
import logging
from typing import Optional, List
from itertools import compress

from igp2.data.episode import Episode
from igp2.data.scenario import Scenario, InDScenario

logger = logging.getLogger(__name__)


[docs]class DataLoader(abc.ABC): """ Abstract class that is implemented by every DataLoader that IGP2 can use. A set of recordings are collected into an Episode. Episodes and the corresponding Map and configuration are managed by a Scenario object. The created DataLoader is iterable. """ def __init__(self, config_path: str, splits: List[str] = None): """ Create a new data loader object Args: config_path: The path under which the configuration JSON file is located splits: Optional parameter to specify which data split(s) to iterate over. """ self.config_path = config_path self.splits = splits self._scenario = None def __iter__(self): raise NotImplementedError def __next__(self) -> Episode: raise NotImplementedError @property def scenario(self) -> Optional[Scenario]: """ Return the Scenario object""" return self._scenario
[docs] def load(self): """ Load the Scenario object with the configuration file.""" raise NotImplementedError
[docs] def train(self) -> List[Episode]: """ Return the training data portion of the Scenario """ raise NotImplementedError
[docs] def valid(self) -> List[Episode]: """ Return the validation data portion of the Scenario """ raise NotImplementedError
[docs] def test(self) -> List[Episode]: """ Return the test data portion of the Scenario """ raise NotImplementedError
[docs]class InDDataLoader(DataLoader):
[docs] def load(self): """ Load all episodes of the scenario """ self._scenario = InDScenario.load(self.config_path, self.splits)
def __iter__(self): if self._scenario is None: raise RuntimeError("The scenario has not been loaded yet. Try calling the load() method!") self._iter_idx = 0 return self def __next__(self) -> Episode: if self._iter_idx < len(self._scenario.episodes): episode = self._scenario.episodes[self._iter_idx] self._iter_idx += 1 return episode else: raise StopIteration
[docs] def get_split(self, splits: List[str] = None) -> List[Episode]: if self._scenario is None: raise RuntimeError("The scenario has not been loaded yet. Try calling the load() method!") if splits is None: return self._scenario.episodes else: indices = [] for s in splits: if s not in self.splits: raise ValueError(f"Split type {s} is not in the valid loaded splits: {self.splits}!") indices.extend(self._scenario.config.dataset_split[s]) return list(compress(self._scenario.episodes, indices))
[docs] def train(self) -> List[Episode]: return self.get_split(["train"])
[docs] def valid(self) -> List[Episode]: return self.get_split(["valid"])
[docs] def test(self) -> List[Episode]: return self.get_split(["test"])