Source code for zensols.dataset.interface
"""Interfaces used for dealing with dataset splits.
"""
__author__ = 'Paul Landes'
from typing import Dict, Set, Tuple
from dataclasses import dataclass
from abc import abstractmethod, ABCMeta
import logging
import sys
from io import TextIOBase
from zensols.util import APIError
from zensols.config import Writable
from zensols.persist import Stash, PrimeableStash
logger = logging.getLogger(__name__)
[docs]
class DatasetError(APIError):
"""Thrown when any dataset related is raised."""
[docs]
@dataclass
class SplitKeyContainer(Writable, metaclass=ABCMeta):
"""An interface defining a container that partitions data sets
(i.e. ``train`` vs ``test``). For instances of this class, that data are
the unique keys that point at the data.
"""
def _get_split_names(self) -> Set[str]:
return self._get_keys_by_split().keys()
def _get_counts_by_key(self) -> Dict[str, int]:
ks = self._get_keys_by_split()
return {k: len(ks[k]) for k in ks.keys()}
@abstractmethod
def _get_keys_by_split(self) -> Dict[str, Tuple[str, ...]]:
pass
@property
def split_names(self) -> Set[str]:
"""Return the names of each split in the dataset.
"""
return self._get_split_names()
@property
def counts_by_key(self) -> Dict[str, int]:
"""Return data set splits name to count for that respective split.
"""
return self._get_counts_by_key()
@property
def keys_by_split(self) -> Dict[str, Tuple[str, ...]]:
"""Generate a dictionary of split name to keys for that split. It is
expected this method will be very expensive.
"""
return self._get_keys_by_split()
[docs]
def clear(self):
"""Clear any cached state."""
[docs]
def write(self, depth: int = 0, writer: TextIOBase = sys.stdout,
include_delegate: bool = False):
self._write_line('split stash splits:', depth, writer)
t = 0
for ks in self.keys_by_split.values():
t += len(ks)
for k, ks in self.keys_by_split.items():
ln = len(ks)
self._write_line(f'{k}: {ln} ({ln/t*100:.1f}%)',
depth + 1, writer)
self._write_line(f'total: {t}', depth + 1, writer)
[docs]
@dataclass
class SplitStashContainer(PrimeableStash, SplitKeyContainer,
metaclass=ABCMeta):
"""An interface like ``SplitKeyContainer``, but whose implementations are of
``Stash`` containing the instance data.
For a default implemetnation, see :class:`.DatasetSplitStash`.
"""
@abstractmethod
def _get_split_name(self) -> str:
pass
@abstractmethod
def _get_splits(self) -> Dict[str, Stash]:
pass
@property
def split_name(self) -> str:
"""Return the name of the split this stash contains. Thus, all
data/items returned by this stash are in the data set given by this name
(i.e. ``train``).
"""
return self._get_split_name()
@property
def splits(self) -> Dict[str, Stash]:
"""Return a dictionary with keys as split names and values as the
stashes represented by that split.
:see: :meth:`split_name`
"""
return self._get_splits()