"""Implementations (some abstract) of split key containers.
"""
__author__ = 'Paul Landes'
from typing import Dict, Tuple, Sequence, Set, List
from dataclasses import dataclass, field
from abc import abstractmethod, ABCMeta
import sys
import logging
import collections
from functools import reduce
import math
from io import TextIOBase
from pathlib import Path
import shutil
import parse
import random
import pandas as pd
from zensols.config import Writable
from zensols.persist import (
Primeable, persisted, PersistedWork, Stash, PersistableContainer
)
from zensols.dataset import SplitKeyContainer
from . import DatasetError
logger = logging.getLogger(__name__)
@dataclass
class AbstractSplitKeyContainer(PersistableContainer, SplitKeyContainer,
Primeable, Writable, metaclass=ABCMeta):
"""A default implementation of a :class:`.SplitKeyContainer`. This
implementation keeps the order of the keys consistent as well, which is
stored at the path given in :obj:`key_path`. Once the keys are generated
for the first time, they will persist on the file system.
This abstract class requires an implementation of :meth:`_create_splits`.
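
    A minimal sketch of a concrete subclass (the class name and keys are
    hypothetical, for illustration only)::

        @dataclass
        class HardCodedSplitKeyContainer(AbstractSplitKeyContainer):
            def _create_splits(self) -> Dict[str, Tuple[str]]:
                # map each split name to its ordered keys
                return {'train': ('1', '2', '3'),
                        'test': ('4', '5')}
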
.. document private functions
.. automethod:: _create_splits
"""
key_path: Path = field()
"""The directory to store the split keys."""
pattern: str = field()
"""The file name pattern to use for the keys file :obj:`key_path` on the
file system, each file is named after the key split. For example, if
``{name}.dat`` is used, ``train.dat`` will be a file with the ordered keys.
"""
def __post_init__(self):
super().__init__()
def prime(self):
self._get_keys_by_split()
@abstractmethod
def _create_splits(self) -> Dict[str, Tuple[str]]:
"""Create the key splits using keys as the split name (i.e. ``train``)
and the values as a list of the keys for the corresponding split.
"""
pass
def _create_splits_and_write(self):
"""Write the keys in order to the file system.
"""
self.key_path.mkdir(parents=True, exist_ok=True)
for name, keys in self._create_splits().items():
fname = self.pattern.format(**{'name': name})
key_path = self.key_path / fname
with open(key_path, 'w') as f:
for k in keys:
f.write(k + '\n')
def _read_splits(self):
"""Read the keys in order from the file system.
"""
by_name = {}
for path in self.key_path.iterdir():
p = parse.parse(self.pattern, path.name)
if p is not None:
p = p.named
if 'name' in p:
with open(path) as f:
by_name[p['name']] = tuple(
map(lambda ln: ln.strip(), f.readlines()))
return by_name
@persisted('_get_keys_by_split_pw')
def _get_keys_by_split(self) -> Dict[str, Tuple[str]]:
if not self.key_path.exists():
if logger.isEnabledFor(logging.INFO):
logger.info(f'creating key splits in {self.key_path}')
self._create_splits_and_write()
return self._read_splits()
def clear(self):
logger.debug('clearing split stash')
if self.key_path.is_dir():
            logger.debug(f'removing key path: {self.key_path}')
shutil.rmtree(self.key_path)
def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
by_name = self.counts_by_key
total = sum(by_name.values())
self._write_line('key splits:', depth, writer)
for name, cnt in by_name.items():
self._write_line(f'{name}: {cnt} ({cnt/total*100:.1f}%)',
depth + 1, writer)
self._write_line(f'total: {total}', depth, writer)
@dataclass
class StashSplitKeyContainer(AbstractSplitKeyContainer):
"""A default implementation of :class:`.AbstractSplitKeyContainer` that uses
a delegate stash for source of the keys.
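
    A usage sketch (the stash contents and paths are assumptions for
    illustration; any :class:`~zensols.persist.Stash` can serve as the
    delegate)::

        from zensols.persist import DictionaryStash

        stash = DictionaryStash()
        for i in range(10):
            stash.dump(str(i), f'item {i}')
        container = StashSplitKeyContainer(
            key_path=Path('target/keys'),
            pattern='{name}.dat',
            stash=stash,
            distribution={'train': 0.8, 'validate': 0.1, 'test': 0.1})
        container.prime()
        # maps each split name to a tuple of keys
        container.keys_by_split
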
"""
stash: Stash = field()
"""The delegate stash from where to get the keys to store."""
distribution: Dict[str, float] = field(
default_factory=lambda: {'train': 0.8, 'validate': 0.1, 'test': 0.1})
"""The distribution as a percent across all key splits. The distribution
values must add to 1.
"""
shuffle: bool = field(default=True)
"""If ``True``, shuffle the keys when creating the key splits.
"""
def __post_init__(self):
super().__post_init__()
sm = float(sum(self.distribution.values()))
err, errm = (1. - sm), 1e-1
if sm < 0 or sm > 1 or err > errm:
            raise DatasetError('Distribution must sum to 1: ' +
                               f'{self.distribution} (err={err} > {errm})')
def prime(self):
super().prime()
if isinstance(self.stash, Primeable):
self.stash.prime()
@persisted('_split_names_pw')
    def _get_split_names(self) -> Set[str]:
return frozenset(self.distribution.keys())
def _create_splits(self) -> Dict[str, Tuple[str]]:
if self.distribution is None:
raise DatasetError('Must either provide `distribution` or ' +
'implement `_create_splits`')
by_name = {}
keys = list(self.stash.keys())
if self.shuffle:
random.shuffle(keys)
        klen = len(keys)
        dists = tuple(self.distribution.items())
        # hold out the last split so it absorbs any rounding remainder
        if len(dists) > 1:
            dists, last = dists[:-1], dists[-1]
        else:
            dists, last = (), dists[0]
        start = 0
        for name, dist in dists:
            end = start + int(klen * dist)
            by_name[name] = tuple(keys[start:end])
            start = end
        by_name[last[0]] = tuple(keys[start:])
        if logger.isEnabledFor(logging.DEBUG):
            for k, v in by_name.items():
                logger.debug(f'split {k}: {len(v)} ({len(v) / klen:.3f})')
        assert sum(map(len, by_name.values())) == klen
return by_name
@dataclass
class StratifiedStashSplitKeyContainer(StashSplitKeyContainer):
"""Like :class:`.StashSplitKeyContainer` but data is stratified by a label
(:obj:`partition_attr`) across each split.
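
    A usage sketch, assuming each stashed item has a ``label`` attribute
    (both the stash and the attribute name are hypothetical)::

        container = StratifiedStashSplitKeyContainer(
            key_path=Path('target/keys'),
            pattern='{name}.dat',
            stash=stash,
            partition_attr='label')
        container.prime()
        # print per label counts and proportions for each split
        container.write()
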
"""
    partition_attr: str = field(default=None)
    """The attribute on the stashed items whose value is the label used to
    stratify each split."""
stratified_write: bool = field(default=True)
"""Whether or not to include the stratified counts when writing with
:meth:`write`.
"""
split_labels_path: Path = field(default=None)
"""If provided, the path is a pickled cache of
:obj:`stratified_count_dataframe`.
"""
def __post_init__(self):
super().__post_init__()
if self.partition_attr is None:
raise DatasetError("Missing 'partition_attr' field")
dfpath = self.split_labels_path
if dfpath is None:
dfpath = '_strat_split_labels'
self._strat_split_labels = PersistedWork(dfpath, self, mkdir=True)
def _create_splits(self) -> Dict[str, Tuple[str]]:
        dist_keys: Sequence[str] = tuple(self.distribution.keys())
        # the first distribution entry absorbs the rounding remainder of
        # each stratum
        dist_last: str = dist_keys[0]
        dists: Set[str] = set(dist_keys) - {dist_last}
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'splitting dataset into {len(dist_keys)} splits')
rows = []
for k, v in self.stash.items():
rows.append((k, getattr(v, self.partition_attr)))
df = pd.DataFrame(rows, columns=['key', self.partition_attr])
        # map each split name to the keys assigned to it across all strata
        lab_splits: Dict[str, Set[str]] = collections.defaultdict(set)
        for lab, dfg in df.groupby(self.partition_attr):
            keys: List[str] = dfg['key'].to_list()
            if self.shuffle:
                random.shuffle(keys)
            count = len(keys)
            # take a proportional sample of this stratum for each split
            for dist in dists:
                prop = self.distribution[dist]
                n_samples = math.ceil(float(count) * prop)
                lab_splits[dist].update(keys[:n_samples])
                keys = keys[n_samples:]
            # whatever remains goes to the first distribution entry
            lab_splits[dist_last].update(keys)
assert sum(map(len, lab_splits.values())) == len(df)
assert reduce(lambda a, b: a | b, lab_splits.values()) == \
set(df['key'].tolist())
        shuf_splits = {}
        for split_name, split_keys in lab_splits.items():
            if self.shuffle:
                split_keys = list(split_keys)
                random.shuffle(split_keys)
            shuf_splits[split_name] = tuple(split_keys)
        return shuf_splits
    def _count_proportions_by_split(self) -> Dict[str, Dict[str, int]]:
        lab_counts = {}
        kbs = self.keys_by_split
        for split_name in sorted(kbs.keys()):
            keys = kbs[split_name]
            counts = collections.defaultdict(int)
for k in keys:
item = self.stash[k]
lab = getattr(item, self.partition_attr)
counts[lab] += 1
lab_counts[split_name] = counts
return lab_counts
@property
@persisted('_strat_split_labels')
def stratified_split_labels(self) -> pd.DataFrame:
"""A dataframe with all keys, their respective labels and split.
"""
kbs = self.keys_by_split
rows = []
for split_name in sorted(kbs.keys()):
keys = kbs[split_name]
for k in keys:
item = self.stash[k]
lab = getattr(item, self.partition_attr)
rows.append((split_name, k, lab))
return pd.DataFrame(rows, columns='split_name id label'.split())
def clear(self):
super().clear()
self._strat_split_labels.clear()
@property
def stratified_count_dataframe(self) -> pd.DataFrame:
"""A count summarization of :obj:`stratified_split_labels`.
"""
df = self.stratified_split_labels
df = df.groupby('split_name label'.split()).size().\
reset_index(name='count')
df['proportion'] = df['count'] / df['count'].sum()
df = df.sort_values('split_name label'.split()).reset_index(drop=True)
return df
def _fmt_prop_by_split(self) -> Dict[str, Dict[str, str]]:
df = self.stratified_count_dataframe
tot = df['count'].sum()
dsets: Dict[str, Dict[str, str]] = collections.OrderedDict()
        for split_name, dfg in df.groupby('split_name'):
            # copy to avoid mutating a view of the grouped frame
            dfg = dfg.copy()
            dfg['fmt'] = dfg['count'].apply(lambda x: f'{x} ({x/tot*100:.2f}%)')
            dsets[split_name] = dict(dfg[['label', 'fmt']].values)
return dsets
def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
if self.stratified_write:
lab_counts: Dict[str, Dict[str, str]] = self._fmt_prop_by_split()
self._write_dict(lab_counts, depth, writer)
self._write_line(f'Total: {len(self.stash)}', depth, writer)
else:
super().write(depth, writer)