"""Stash implementations.
"""
__author__ = 'Paul Landes'
import logging
from typing import Tuple, Dict, Iterable, Optional, Any, Callable
from dataclasses import dataclass, field, InitVar
from abc import ABCMeta
from collections import OrderedDict
from itertools import chain
import parse
import pickle
from pathlib import Path
import zensols.util.time as time
from zensols.util import APIError
from . import (
PersistableError,
Stash,
ReadOnlyStash,
DelegateStash,
PrimablePreemptiveStash,
)
logger = logging.getLogger(__name__)
@dataclass
class OneShotFactoryStash(PrimablePreemptiveStash, metaclass=ABCMeta):
    """A stash that is populated once by a callable or an iterable *worker*.
    The worker generates the data, which is then dumped to the delegate.  The
    worker is either a callable (i.e. function) or an iterable that yields
    tuples or lists of ``(key, object)``.

    To use this class, the worker must be set as attribute ``worker``, or this
    class extended with ``worker`` as a property.

    .. document private functions
    .. automethod:: _get_worker_type

    """
    def _get_worker_type(self) -> str:
        """Return the type of worker.  This default implementation returns
        *unknown*.  When unknown, :meth:`_process_work` has to use
        :func:`callable` on the worker to decide how to use it, which means
        accessing the attribute; for a property that access may trigger
        potentially unnecessary work.

        :return: ``u`` for unknown, ``m`` for method (callable), or ``a`` for
                 attribute or a property

        """
        return 'u'

    def _process_work(self):
        """Invoke the worker to generate the data and dump it to the delegate.

        :raises APIError: if :meth:`_get_worker_type` returns something other
                          than ``u``, ``m`` or ``a``

        """
        kind = self._get_worker_type()
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'processing with {type(self.worker)}: type={kind}')
        if kind == 'u':
            # resolve the unknown type by inspecting the worker itself
            kind = 'm' if callable(self.worker) else 'a'
        if kind == 'm':
            pairs = self.worker()
        elif kind == 'a':
            pairs = self.worker
        else:
            raise APIError(f'Unknown worker type: {kind}')
        for key, obj in pairs:
            self.delegate.dump(key, obj)

    def prime(self):
        """Generate and store the data with the worker, but only when the
        stash reports that it has no data yet.

        """
        populated = self.has_data
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'asserting data: {populated}')
        if populated:
            return
        with time('processing work in OneShotFactoryStash'):
            self._process_work()
        self._reset_has_data()
@dataclass
class SortedStash(DelegateStash):
    """Specify a sorting for how keys in a stash are returned.  This also
    determines the order in which values and items are iterated, since both are
    produced by walking the sorted keys.

    """
    ATTR_EXP_META = ('sort_function',)

    sort_function: Callable = field(default=None)
    """An optional function passed as ``key`` to :func:`sorted`; when ``None``
    the keys are sorted by their natural ordering.

    """
    def __iter__(self):
        """Iterate over ``(key, value)`` pairs in sorted key order."""
        return ((key, self[key]) for key in self.keys())

    def values(self) -> Iterable[Any]:
        """Return the values of the stash in sorted key order."""
        return (self[key] for key in self.keys())

    def items(self) -> Iterable[Tuple[str, Any]]:
        """Return ``(key, value)`` pairs of the stash in sorted key order."""
        return ((key, self[key]) for key in self.keys())

    def keys(self) -> Iterable[str]:
        """Return the keys of the delegate as a sorted list.

        :return: the keys ordered by :obj:`sort_function`, or by natural
                 ordering when it is ``None``

        """
        # ``key=None`` is the documented way to ask for natural ordering
        return sorted(super().keys(), key=self.sort_function)
@dataclass
class DictionaryStash(Stash):
    """A stash backed by an in-memory dictionary.  When no dictionary is given
    to the initializer, a new ``dict`` is created.

    """
    _data: Dict[str, Any] = field(default_factory=dict)
    """The backing dictionary for the stash data."""

    @property
    def data(self):
        """The dictionary that contains the stash data."""
        return self._data

    def load(self, name: str) -> Any:
        """Return the item stored under ``name`` or ``None`` if absent."""
        store = self.data
        return store.get(name)

    def get(self, name: str, default: Any = None) -> Any:
        """Return the item stored under ``name`` or ``default`` if absent."""
        store = self.data
        return store.get(name, default)

    def exists(self, name: str) -> bool:
        """Return whether ``name`` is a key of the backing dictionary."""
        store = self.data
        return name in store

    def dump(self, name: str, inst):
        """Store ``inst`` under key ``name``, replacing any previous value."""
        store = self.data
        store[name] = inst

    def delete(self, name=None):
        """Remove the entry stored under ``name``.

        :raises KeyError: if ``name`` is not in the backing dictionary

        """
        store = self.data
        del store[name]

    def keys(self):
        """Return the keys view of the backing dictionary."""
        store = self.data
        return store.keys()

    def clear(self):
        """Empty the backing dictionary, then let the super class clear."""
        store = self.data
        store.clear()
        super().clear()

    def __getitem__(self, key):
        # unlike ``load``, indexing raises ``KeyError`` for a missing key
        store = self.data
        return store[key]
@dataclass(init=False)
class LRUCacheStash(DictionaryStash):
    """A stash that keeps at most :obj:`maxsize` elements, evicting the least
    recently used entries on :meth:`dump`.

    Entries count as *used* when they are written with :meth:`dump` or read
    with :meth:`load`, :meth:`get` or indexing (``stash[key]``).

    """
    maxsize: InitVar[int] = field()
    """The highest number of items stored in the stash before deletions."""

    def __init__(self, maxsize: int):
        """Create the cache backed by an :class:`~collections.OrderedDict`.

        :param maxsize: the highest number of items kept in the stash

        """
        super().__init__(_data=OrderedDict())
        self._maxsize = maxsize

    def _touch(self, name: str):
        """Mark ``name`` as the most recently used entry if it is present."""
        data: Dict[str, Any] = self.data
        if name in data:
            data.move_to_end(name)

    def load(self, name: str) -> Any:
        """Return the item stored under ``name`` (``None`` if absent) and mark
        it as most recently used.

        """
        self._touch(name)
        return self.data.get(name)

    def get(self, name: str, default: Any = None) -> Any:
        """Return the item stored under ``name`` (``default`` if absent) and
        mark it as most recently used.

        """
        # the inherited implementation reads the dictionary directly, which
        # would leave the recency order untouched
        self._touch(name)
        return self.data.get(name, default)

    def __getitem__(self, key):
        # raises ``KeyError`` for a missing key just as the super class does
        item = self.data[key]
        self._touch(key)
        return item

    def dump(self, name: str, inst):
        """Store ``inst`` under ``name`` as the most recently used entry, then
        evict the least recently used entries until at most :obj:`maxsize`
        remain.

        """
        data: Dict[str, Any] = self.data
        data[name] = inst
        data.move_to_end(name)
        while len(data) > self._maxsize:
            # ``last=False`` pops from the least recently used end
            data.popitem(last=False)
@dataclass
class CacheStash(DelegateStash):
    """A stash that caches what it loads from the delegate in a second stash.

    """
    cache_stash: Stash = field(default_factory=DictionaryStash)
    """A stash used for caching (defaults to :class:`.DictionaryStash`)."""

    def load(self, name: str):
        """Return the item for ``name`` from the cache when the cache has it,
        otherwise load it from the delegate and store the result in the cache.

        Note that whatever the delegate returns is cached, including ``None``.

        """
        cache = self.cache_stash
        if cache.exists(name):
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'loading cached: {name}')
            return cache.load(name)
        loaded = self.delegate.load(name)
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'loading from delegate, dumping to cache: {name}')
        cache.dump(name, loaded)
        return loaded

    def get(self, name: str, default: Any = None) -> Any:
        """Return what :meth:`load` gives for ``name``, or ``default`` when
        that is ``None``.

        """
        found = self.load(name)
        return default if found is None else found

    def delete(self, name: str = None):
        """Delete ``name`` from the cache (if cached) and from the delegate,
        unless the delegate is a :class:`.ReadOnlyStash`.

        """
        cache = self.cache_stash
        if cache.exists(name):
            cache.delete(name)
        read_only = isinstance(self.delegate, ReadOnlyStash)
        if not read_only:
            self.delegate.delete(name)

    def clear(self):
        """Clear the delegate through the super class, unless it is a
        :class:`.ReadOnlyStash`, and then always clear the cache.

        """
        read_only = isinstance(self.delegate, ReadOnlyStash)
        if not read_only:
            super().clear()
        self.cache_stash.clear()
@dataclass
class DirectoryStash(Stash):
    """Stores each item as a pickled file in a directory, with file names
    derived from the key using a pattern shared by all items.

    Files are read with :mod:`pickle`, so the directory should only contain
    data from a trusted source.

    """
    ATTR_EXP_META = ('path', 'pattern')

    path: Path = field()
    """The directory of where to store the files."""

    pattern: str = field(default='{name}.dat')
    """The file name portion with ``name`` populating to the key of the data
    value.

    """
    def __post_init__(self):
        """Validate that :obj:`path` is a :class:`~pathlib.Path`.

        :raises PersistableError: if :obj:`path` is of any other type

        """
        if isinstance(self.path, Path):
            return
        raise PersistableError(
            f'Expecting pathlib.Path but got: {self.path.__class__}')

    def assert_path_dir(self):
        """Create the storage directory, and any parents, if missing."""
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'path {self.path}: {self.path.exists()}')
        self.path.mkdir(parents=True, exist_ok=True)

    def key_to_path(self, name: str) -> Path:
        """Return a path to the pickled data with key ``name``.  As a side
        effect the storage directory is created if it does not exist.

        """
        file_name = self.pattern.format(name=name)
        self.assert_path_dir()
        return Path(self.path, file_name)

    def _load_file(self, path: Path) -> Any:
        """Unpickle and return the contents of ``path``.

        :raises PersistableError: if the file content can not be unpickled

        """
        with open(path, 'rb') as f:
            try:
                return pickle.load(f)
            except Exception as e:
                raise PersistableError(f"Can not read {path}: {e}") from e

    def _dump_file(self, inst: Any, path: Path):
        """Pickle ``inst`` to ``path``, replacing any existing file."""
        with open(path, 'wb') as f:
            pickle.dump(inst, f)

    def load(self, name: str) -> Any:
        """Return the item stored under ``name`` or ``None`` when there is no
        file for that key.

        """
        path = self.key_to_path(name)
        debug = logger.isEnabledFor(logging.DEBUG)
        inst = None
        if path.exists():
            if debug:
                self._debug(f'loading instance from {path}')
            inst = self._load_file(path)
        if debug:
            self._debug(f'loaded instance: {name} -> {type(inst)}')
        return inst

    def exists(self, name) -> bool:
        """Return whether a file exists for key ``name``."""
        return self.key_to_path(name).exists()

    def keys(self) -> Iterable[str]:
        """Return the keys of all files in the directory whose names match
        :obj:`pattern`; an empty tuple is returned when the directory does not
        exist.

        """
        def path_to_key(path):
            # files that don't match the pattern yield ``None``
            parsed = parse.parse(self.pattern, path.name)
            if parsed is None:
                return None
            named = parsed.named
            if 'name' in named:
                return named['name']
            return None

        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'checking path {self.path} ({type(self)})')
        if not self.path.is_dir():
            return ()
        candidates = map(path_to_key, self.path.iterdir())
        return (key for key in candidates if key is not None)

    def dump(self, name: str, inst: Any):
        """Pickle ``inst`` to the file that belongs to key ``name``."""
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'saving instance: {name} -> {type(inst)}')
        self._dump_file(inst, self.key_to_path(name))

    def delete(self, name: str):
        """Remove the file for key ``name``; a warning is logged when there is
        no such file.

        """
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'deleting instance: {name}')
        path = self.key_to_path(name)
        if path.exists():
            path.unlink()
        elif logger.isEnabledFor(logging.WARNING):
            logger.warning(f'does not exist: {name}')
@dataclass
class IncrementKeyDirectoryStash(DirectoryStash):
    """A stash that increments integer value keys in a stash and dumps/loads
    using the last key available in the stash.

    """
    name: InitVar[str] = field(default='data')
    """The name of the :obj:`pattern` to use in the super class."""

    def __post_init__(self, name: str):
        """Set the file name pattern to ``<name>-{name}.dat``.

        :param name: the prefix of every file name written by this stash

        """
        super().__post_init__()
        self.pattern = name + '-{name}.dat'
        # ``name`` is an ``InitVar``, so ``self.name`` would only ever give
        # the class level default; keep the real value for log messages
        self._name = name
        self._last_key = None

    def get_last_key(self, inc: bool = False) -> str:
        """Get the last available (highest number) key in the stash.

        :param inc: if ``True``, advance to and return the next unused key

        :return: the key as a string; ``0`` when the stash is empty and
                 ``inc`` is ``False``

        """
        if self._last_key is None:
            # scan the directory only once; afterwards the value is cached
            self._last_key = max(map(int, self.keys()), default=0)
        if inc:
            self._last_key += 1
        return str(self._last_key)

    def keys(self) -> Iterable[str]:
        """Return only those keys of the super class that parse as integers.

        """
        def is_good_key(data):
            try:
                int(data)
            except ValueError:
                return False
            return True

        return filter(is_good_key, super().keys())

    def _note_explicit_key(self, key):
        """Raise the cached last key to ``key`` when ``key`` is a larger
        integer, so a later keyless :meth:`dump` does not reuse it.

        """
        if self._last_key is None:
            # nothing cached yet: the directory scan will pick the key up
            return
        try:
            int_key = int(key)
        except (TypeError, ValueError):
            # non-integer keys never take part in the increment sequence
            return
        if int_key > self._last_key:
            self._last_key = int_key

    def dump(self, name_or_inst, inst=None):
        """If only one argument is given, it is used as the data and the key
        name is derived from ``get_last_key``.

        Because ``None`` marks the missing second argument, ``None`` itself
        can not be stored under an explicit key.

        :param name_or_inst: the key when ``inst`` is given, otherwise the data

        :param inst: the data to store, or ``None`` to generate the key

        """
        explicit = inst is not None
        if explicit:
            key = name_or_inst
        else:
            key = self.get_last_key(True)
            inst = name_or_inst
        if logger.isEnabledFor(logging.DEBUG):
            path = self.key_to_path(key)
            self._debug(f'dumping result {self._name} to {path}')
        super().dump(key, inst)
        if explicit:
            self._note_explicit_key(key)

    def load(self, name: str = None) -> Any:
        """Just like ``Stash.load``, but if the key is omitted, return the value
        of the last key in the stash.

        :return: the loaded item, or ``None`` when the stash is empty or has
                 no item for the key

        """
        if name is None:
            name = self.get_last_key(False)
        if len(self) > 0:
            return super().load(name)
        return None
@dataclass
class UnionStash(ReadOnlyStash):
    """A stash that joins the data of many other stashes.  Lookups consult the
    constituent stashes in order and the first one holding the key wins.

    Keys and values are the plain concatenation of those of each stash, so a
    key present in several stashes is reported once per stash.

    """
    stashes: Tuple[Stash, ...] = field()
    """The delegate constituent stashes used for each operation."""

    def load(self, name: str) -> Any:
        """Return the first non-``None`` item any stash has for ``name``, or
        ``None`` if no stash has one.

        """
        for stash in self.stashes:
            item = stash.load(name)
            if item is not None:
                return item
        return None

    def get(self, name: str, default: Any = None) -> Any:
        """Return what :meth:`load` gives for ``name``, or ``default`` when no
        stash has an item for it.

        """
        item = self.load(name)
        return default if item is None else item

    def values(self) -> Iterable[Any]:
        """Return the values of every stash, chained in stash order."""
        return chain.from_iterable(map(lambda s: s.values(), self.stashes))

    def keys(self) -> Iterable[str]:
        """Return the keys of every stash, chained in stash order."""
        return chain.from_iterable(map(lambda s: s.keys(), self.stashes))

    def exists(self, name: str) -> bool:
        """Return whether any of the stashes has key ``name``."""
        return any(stash.exists(name) for stash in self.stashes)