"""Utility stashes useful to common machine learning tasks.
"""
__author__ = 'Paul Landes'
import sys
import logging
from typing import Iterable, Dict, Set, Callable, Tuple, Any, List
from dataclasses import dataclass, field
from itertools import chain
from collections import OrderedDict
from io import TextIOBase
from zensols.util import time
from zensols.config import Writable
from zensols.persist import (
PersistedWork,
PersistableContainer,
persisted,
Stash,
DelegateStash,
PreemptiveStash,
)
from zensols.dataset import SplitStashContainer, SplitKeyContainer
from . import DatasetError
# module level logger; the classes below wrap debug output in
# ``logger.isEnabledFor(logging.DEBUG)`` checks before formatting messages
logger = logging.getLogger(__name__)
@dataclass
class DatasetSplitStash(DelegateStash, SplitStashContainer,
                        PersistableContainer, Writable):
    """A default implementation of :class:`.SplitStashContainer`.  However, it
    needs an instance of a :class:`.SplitKeyContainer`.  This implementation
    generates a separate stash instance for each data set split (i.e. ``train``
    vs ``test``).  Each split instance holds the data (keys and values) for
    each split.

    Stash instances by split are obtained with ``splits``, and will have a
    ``split_name`` attribute that gives the name of the split.

    To maintain reproducibility, key ordering must be considered (see
    :class:`.SortedDatasetSplitStash`).

    :see: :meth:`.SplitStashContainer.splits`

    """
    split_container: SplitKeyContainer = field()
    """The instance that provides the splits in the dataset."""

    def __post_init__(self):
        super().__post_init__()
        PersistableContainer.__init__(self)
        if not isinstance(self.split_container, SplitKeyContainer):
            raise DatasetError('Expecting type SplitKeyContainer but ' +
                               f'got: {type(self.split_container)}')
        # the split this instance is restricted to; ``None`` means all splits
        # (set on the per split clones created by :meth:`_get_splits`)
        self._inst_split_name = None
        self._keys_by_split = PersistedWork('_keys_by_split', self)
        self._splits = PersistedWork('_splits', self)

    def _add_keys(self, split_name: str,
                  to_populate: Dict[str, Tuple[str, ...]], keys: List[str]):
        """Add ``keys`` as a tuple under ``split_name`` in ``to_populate``.
        Subclasses override this to change how each split's keys are ordered.

        """
        to_populate[split_name] = tuple(keys)

    @persisted('_keys_by_split')
    def _get_keys_by_split(self) -> Dict[str, Tuple[str]]:
        """Return keys by split type (i.e. ``train`` vs ``test``) for only those
        keys available by the delegate backing stash.

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('creating in memory available keys data structure')
        with time('created key data structures', logging.DEBUG):
            # build the set once so each membership test below is O(1)
            delegate_keys = set(self.delegate.keys())
            avail_kbs = OrderedDict()
            for split, keys in self.split_container.keys_by_split.items():
                ks = [k for k in keys if k in delegate_keys]
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'{split} has {len(ks)} keys')
                self._add_keys(split, avail_kbs, ks)
            return avail_kbs

    def _get_counts_by_key(self) -> Dict[str, int]:
        """Return the number of available keys for each split."""
        return {split: len(keys) for split, keys in self.keys_by_split.items()}

    def check_key_consistent(self) -> bool:
        """Return whether the :obj:`split_container` has the same key count
        division as this stash's split counts.

        """
        return self.counts_by_key == self.split_container.counts_by_key

    def keys(self) -> Iterable[str]:
        """Return the keys of this instance's split, or the keys of all splits
        chained together when this instance is not restricted to a split.

        """
        self.prime()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'keys for {self.split_name}')
        kbs = self.keys_by_split
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'obtained keys for {self.split_name}')
        if self.split_name is None:
            return chain.from_iterable(kbs.values())
        else:
            return kbs[self.split_name]

    def exists(self, name: str) -> bool:
        """Return whether ``name`` is a key of this instance's split, or of the
        delegate when this instance is not restricted to a split.

        """
        if self.split_name is None:
            return super().exists(name)
        else:
            return name in self.keys_by_split[self.split_name]

    def load(self, name: str) -> Any:
        """Load ``name`` from the delegate, or return ``None`` if this instance
        is restricted to a split that does not contain ``name``.

        """
        if self.split_name is None or \
           name in self.keys_by_split[self.split_name]:
            return super().load(name)
        return None

    def get(self, name: str, default: Any = None) -> Any:
        """Return the item for ``name``, or ``default`` if it is missing from
        the delegate or from this instance's split.

        """
        if self.split_name is None or \
           name in self.keys_by_split[self.split_name]:
            # pass ``default`` through so a key missing in the delegate yields
            # the caller's default rather than ``None``
            return super().get(name, default)
        return default

    def prime(self):
        """Prime the delegate, then compute (and cache) the keys by split."""
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('priming ds split stash')
        super().prime()
        # property access forces the persisted key structure to be created
        self.keys_by_split

    def _delegate_has_data(self):
        """Return ``False`` only for a :class:`.PreemptiveStash` delegate that
        reports it has no data.

        """
        return not isinstance(self.delegate, PreemptiveStash) or \
            self.delegate.has_data

    def deallocate(self):
        """Deallocate the delegate, split container, cached keys and any split
        stash instances created by :meth:`_get_splits`.

        """
        # do not deallocate the same instance twice when the delegate is also
        # the split container
        if id(self.delegate) != id(self.split_container):
            self._try_deallocate(self.delegate)
        self._try_deallocate(self.split_container)
        self._keys_by_split.deallocate()
        if self._splits.is_set():
            splits = tuple(self._splits().values())
            self._splits.clear()
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'deallocating: {len(splits)} stash data splits')
            for v in splits:
                self._try_deallocate(v, recursive=True)
        self._splits.deallocate()
        super().deallocate()

    def clear_keys(self):
        """Clear any cache state for keys, and keys by split.  It does this by
        clearing the key state for stash, and then the :meth:`clear` of the
        :obj:`split_container`.

        """
        self.split_container.clear()
        self._keys_by_split.clear()

    def clear(self):
        """Clear and destroy key and delegate data.

        """
        del_has_data = self._delegate_has_data()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'clearing: {del_has_data}')
        if del_has_data:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug('clearing delegate and split container')
            super().clear()
            self.clear_keys()

    def _get_split_names(self) -> Set[str]:
        return self.split_container.split_names

    def _get_split_name(self) -> str:
        return self._inst_split_name

    @persisted('_splits')
    def _get_splits(self) -> Dict[str, Stash]:
        """Return a stash instance for each split, keyed by split name (i.e.
        ``train``, ``test``), where each contains only that split's data.

        """
        self.prime()
        stashes = OrderedDict()
        for split_name in self.split_names:
            clone = self.__class__(
                delegate=self.delegate, split_container=self.split_container)
            # drop the clone's own cached work; the state copied from this
            # instance below replaces it
            clone._keys_by_split.deallocate()
            clone._splits.deallocate()
            clone.__dict__.update(self.__dict__)
            clone._inst_split_name = split_name
            stashes[split_name] = clone
        return stashes

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout,
              include_delegate: bool = False):
        """Write the key counts and percentages of each split.

        :param depth: the indentation depth of the output

        :param writer: where to write the output

        :param include_delegate: whether to also write the delegate when it is
                                 :class:`.Writable`

        """
        self._write_line('split stash splits:', depth, writer)
        kbs = self.split_container.keys_by_split
        t = sum(len(ks) for ks in kbs.values())
        for k, ks in kbs.items():
            ln = len(ks)
            # avoid dividing by zero when every split is empty
            pct = (ln / t * 100) if t > 0 else 0.
            self._write_line(f'{k}: {ln} ({pct:.1f}%)', depth + 1, writer)
        self._write_line(f'total: {t}', depth + 1, writer)
        ckc = self.check_key_consistent()
        self._write_line(f'total this instance: {len(self)}', depth, writer)
        self._write_line(f'keys consistent: {ckc}', depth, writer)
        if include_delegate and isinstance(self.delegate, Writable):
            self._write_line('delegate:', depth, writer)
            self.delegate.write(depth + 1, writer)
@dataclass
class SortedDatasetSplitStash(DatasetSplitStash):
    """A sorted version of a :class:`DatasetSplitStash`, where keys, values,
    items and iterations are sorted by key.  This is important for
    reproducibility of results.

    An alternative is to use :class:`.DatasetSplitStash` with an instance of
    :class:`.StashSplitKeyContainer` set as the :obj:`delegate` since the key
    container keeps key ordering consistent.

    Any shuffling of the dataset, for the sake of training on non-uniform data,
    needs to come *before* using this class.  This class also sorts the keys in
    each split given in :obj:`splits`.

    """
    ATTR_EXP_META = ('sort_function',)

    sort_function: Callable = field(default=None)
    """A function, such as ``int``, used to sort keys per data set split."""

    def __iter__(self):
        """Iterate over ``(key, value)`` pairs in sorted key order."""
        return map(lambda x: (x, self.__getitem__(x),), self.keys())

    def values(self) -> Iterable[Any]:
        """Return the values in sorted key order."""
        return map(lambda k: self.__getitem__(k), self.keys())

    def items(self) -> Iterable[Tuple[str, Any]]:
        """Return ``(key, value)`` pairs in sorted key order."""
        return map(lambda k: (k, self.__getitem__(k)), self.keys())

    def _add_keys(self, split_name: str,
                  to_populate: Dict[str, Tuple[str, ...]], keys: List[str]):
        """Add the keys of ``split_name``, sorted by :obj:`sort_function`."""
        to_populate[split_name] = tuple(sorted(keys, key=self.sort_function))

    def keys(self) -> Iterable[str]:
        """Return the keys sorted by :obj:`sort_function`, or by natural
        ordering when it is ``None``.

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'sort function: {self.sort_function} ' +
                         f'({self.sort_function})')
        # ``sorted`` treats ``key=None`` as natural ordering; the keys of all
        # splits chained together still need a sort across splits
        return sorted(super().keys(), key=self.sort_function)