Source code for zensols.persist.stash

"""Stash implementations.

"""
__author__ = 'Paul Landes'

import logging
from typing import Tuple, Dict, Iterable, Optional, Any, Callable
from dataclasses import dataclass, field, InitVar
from abc import ABCMeta
from collections import OrderedDict
from itertools import chain
import parse
import pickle
from pathlib import Path
from zensols.util import APIError, time
from . import (
    PersistableError,
    Stash,
    ReadOnlyStash,
    DelegateStash,
    PrimablePreemptiveStash,
)

logger = logging.getLogger(__name__)


@dataclass
class OneShotFactoryStash(PrimablePreemptiveStash, metaclass=ABCMeta):
    """A stash that is populated by a callable or an iterable *worker*.  The
    data is generated by the worker and dumped to the delegate.

    The worker is either a callable (i.e. a function) or an iterable that
    returns tuples or lists of ``(key, object)`` pairs.

    To use this class, either set the worker as attribute ``worker``, or
    extend this class and define ``worker`` as a property.

    .. document private functions
    .. automethod:: _get_worker_type

    """
    def _get_worker_type(self) -> str:
        """Return the type of worker.  This default implementation returns
        *unknown*.  If not overridden, when :meth:`_process_work` is called,
        the API uses :func:`callable` to determine if the worker is a method
        or a property.  Doing so accesses the property, invoking potentially
        unnecessary work.

        :return: ``u`` for unknown, ``m`` for method (callable), or ``a`` for
                 attribute or a property

        """
        return 'u'

    def _process_work(self):
        """Invoke the worker to generate the data and dump it to the delegate.

        """
        wt: str = self._get_worker_type()
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'processing with {type(self.worker)}: type={wt}')
        if wt == 'u':
            if callable(self.worker):
                wt = 'm'
            else:
                wt = 'a'
        if wt == 'm':
            itr = self.worker()
        elif wt == 'a':
            itr = self.worker
        else:
            raise APIError(f'Unknown worker type: {wt}')
        for id, obj in itr:
            self.delegate.dump(id, obj)

    def prime(self):
        has_data: bool = self.has_data
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'asserting data: {has_data}')
        if not has_data:
            with time('processing work in OneShotFactoryStash'):
                self._process_work()
            self._reset_has_data()
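

# Illustrative sketch (added; not part of the original module): a minimal
# OneShotFactoryStash subclass whose worker is a property yielding
# (key, object) pairs.  The class and function names are made up for
# demonstration, and passing the delegate as the ``delegate`` keyword
# (the DelegateStash dataclass field) is an assumption.
@dataclass
class _SquaresFactoryStash(OneShotFactoryStash):
    @property
    def worker(self) -> Iterable[Tuple[str, int]]:
        # pairs dumped to the delegate the first time the stash is primed
        return ((str(i), i * i) for i in range(5))

    def _get_worker_type(self) -> str:
        # 'a' tells _process_work the worker is an attribute/property
        return 'a'


def _example_one_shot_factory():
    stash = _SquaresFactoryStash(delegate=DictionaryStash())
    stash.prime()
    return stash.get('3')  # -> 9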


@dataclass
class SortedStash(DelegateStash):
    """Specify a sorting of how keys in a stash are returned.  This usually
    also has an impact on the order in which values are iterated since a call
    to get the keys determines it.

    """
    ATTR_EXP_META = ('sort_function',)

    sort_function: Callable = field(default=None)
    """A function used as the ``key`` to :func:`sorted` when sorting keys;
    when ``None``, the natural sort order is used."""

    def __iter__(self):
        return map(lambda x: (x, self.__getitem__(x),), self.keys())

    def values(self) -> Iterable[Any]:
        return map(lambda k: self.__getitem__(k), self.keys())

    def items(self) -> Iterable[Tuple[str, Any]]:
        return map(lambda k: (k, self.__getitem__(k)), self.keys())

    def keys(self) -> Iterable[str]:
        keys = super().keys()
        if self.sort_function is None:
            keys = sorted(keys)
        else:
            keys = sorted(keys, key=self.sort_function)
        return keys
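

def _example_sorted_stash():
    # Illustrative sketch (added; not part of the original module): sort
    # numeric string keys numerically instead of lexically by passing ``int``
    # as the sort function.  The DictionaryStash backing store is an
    # assumption made for a self-contained example.
    backing = DictionaryStash()
    for k in '10 2 1'.split():
        backing.dump(k, int(k))
    stash = SortedStash(delegate=backing, sort_function=int)
    return stash.keys()  # -> ['1', '2', '10'] rather than ['1', '10', '2']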


@dataclass
class DictionaryStash(Stash):
    """Use a dictionary as a backing store to the stash.  If one is not
    provided in the initializer a new ``dict`` is created.

    """
    _data: Dict[str, Any] = field(default_factory=dict)
    """The backing dictionary for the stash data."""

    @property
    def data(self) -> Dict[str, Any]:
        """The dictionary that contains the stash data."""
        return self._data

    def load(self, name: str) -> Any:
        return self.data.get(name)

    def get(self, name: str, default: Any = None) -> Any:
        return self.data.get(name, default)

    def exists(self, name: str) -> bool:
        return name in self.data

    def dump(self, name: str, inst: Any):
        self.data[name] = inst

    def delete(self, name: str = None):
        del self.data[name]

    def keys(self) -> Iterable[str]:
        return self.data.keys()

    def clear(self):
        self.data.clear()
        super().clear()

    def __getitem__(self, key: str) -> Any:
        return self.data[key]
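

def _example_dictionary_stash():
    # Illustrative sketch (added; not part of the original module): a
    # DictionaryStash behaves like an in-memory dict with the Stash API.
    stash = DictionaryStash()
    stash.dump('a', 1)
    assert stash.exists('a')
    assert stash.load('a') == 1
    assert stash.get('b', 42) == 42  # missing keys fall back to the default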


@dataclass(init=False)
class LRUCacheStash(DictionaryStash):
    """A stash that evicts the least recently used elements on :meth:`dump`
    once more than :obj:`maxsize` elements are stored.

    """
    maxsize: InitVar[int] = field()
    """The highest number of items stored in the stash before deletions."""

    def __init__(self, maxsize: int):
        super().__init__(_data=OrderedDict())
        self._maxsize = maxsize

    def load(self, name: str) -> Any:
        data: Dict[str, Any] = self.data
        item: Optional[Any] = data.get(name)
        if item is not None:
            # a cache hit makes the key the most recently used
            data.move_to_end(name)
        return item

    def dump(self, name: str, inst: Any):
        data: Dict[str, Any] = self.data
        data[name] = inst
        data.move_to_end(name)
        while len(data) > self._maxsize:
            # evict the least recently used entry first
            data.popitem(last=False)
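

def _example_lru_cache_stash():
    # Illustrative sketch (added; not part of the original module): with
    # maxsize=2, the least recently used key is evicted on the third dump.
    stash = LRUCacheStash(maxsize=2)
    stash.dump('a', 1)
    stash.dump('b', 2)
    stash.load('a')     # touch 'a' so 'b' becomes least recently used
    stash.dump('c', 3)  # evicts 'b'
    assert set(stash.keys()) == {'a', 'c'}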


@dataclass
class CacheStash(DelegateStash):
    """Provide a dictionary-based caching stash.

    """
    cache_stash: Stash = field(default_factory=DictionaryStash)
    """A stash used for caching (defaults to :class:`.DictionaryStash`)."""

    def load(self, name: str) -> Any:
        if self.cache_stash.exists(name):
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'loading cached: {name}')
            return self.cache_stash.load(name)
        else:
            obj = self.delegate.load(name)
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'loading from delegate, dumping to cache: {name}')
            self.cache_stash.dump(name, obj)
            return obj

    def get(self, name: str, default: Any = None) -> Any:
        item = self.load(name)
        if item is None:
            item = default
        return item

    def delete(self, name: str = None):
        if self.cache_stash.exists(name):
            self.cache_stash.delete(name)
        if not isinstance(self.delegate, ReadOnlyStash):
            self.delegate.delete(name)

    def clear(self):
        if not isinstance(self.delegate, ReadOnlyStash):
            super().clear()
        self.cache_stash.clear()
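

def _example_cache_stash():
    # Illustrative sketch (added; not part of the original module): the first
    # load pulls from the delegate and caches it; subsequent loads hit the
    # cache.  Using DictionaryStash for both backing stores is an assumption
    # made for a self-contained example.
    delegate = DictionaryStash()
    delegate.dump('k', 'value')
    stash = CacheStash(delegate=delegate)
    assert stash.load('k') == 'value'  # populates cache_stash
    assert stash.cache_stash.exists('k')
    assert stash.load('k') == 'value'  # now served from the cache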


@dataclass
class DirectoryStash(Stash):
    """Creates a pickled data file with a file name in a directory with a
    given pattern across all instances.

    """
    ATTR_EXP_META = ('path', 'pattern')

    path: Path = field()
    """The directory of where to store the files."""

    pattern: str = field(default='{name}.dat')
    """The file name portion with ``name`` substituted with the key of the
    data value.

    """
    def __post_init__(self):
        if not isinstance(self.path, Path):
            raise PersistableError(
                f'Expecting pathlib.Path but got: {self.path.__class__}')

    def assert_path_dir(self):
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'path {self.path}: {self.path.exists()}')
        self.path.mkdir(parents=True, exist_ok=True)

    def key_to_path(self, name: str) -> Path:
        """Return a path to the pickled data with key ``name``.

        """
        fname = self.pattern.format(**{'name': name})
        self.assert_path_dir()
        return Path(self.path, fname)

    def _load_file(self, path: Path) -> Any:
        with open(path, 'rb') as f:
            try:
                return pickle.load(f)
            except Exception as e:
                raise PersistableError(f'Can not read {path}: {e}') from e

    def _dump_file(self, inst: Any, path: Path):
        with open(path, 'wb') as f:
            pickle.dump(inst, f)

    def load(self, name: str) -> Any:
        path = self.key_to_path(name)
        inst = None
        if path.exists():
            if logger.isEnabledFor(logging.DEBUG):
                self._debug(f'loading instance from {path}')
            inst = self._load_file(path)
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'loaded instance: {name} -> {type(inst)}')
        return inst

    def exists(self, name: str) -> bool:
        path = self.key_to_path(name)
        return path.exists()

    def keys(self) -> Iterable[str]:
        def path_to_key(path: Path) -> Optional[str]:
            p = parse.parse(self.pattern, path.name)
            # avoid files that don't match the pattern
            if p is not None:
                p = p.named
                if 'name' in p:
                    return p['name']

        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'checking path {self.path} ({type(self)})')
        if not self.path.is_dir():
            keys = ()
        else:
            keys = filter(lambda x: x is not None,
                          map(path_to_key, self.path.iterdir()))
        return keys

    def dump(self, name: str, inst: Any):
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'saving instance: {name} -> {type(inst)}')
        path = self.key_to_path(name)
        self._dump_file(inst, path)

    def delete(self, name: str):
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'deleting instance: {name}')
        path = self.key_to_path(name)
        if path.exists():
            path.unlink()
        else:
            if logger.isEnabledFor(logging.WARNING):
                logger.warning(f'does not exist: {name}')

    def close(self):
        pass
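

def _example_directory_stash():
    # Illustrative sketch (added; not part of the original module): each
    # dumped value is pickled to ``<dir>/<key>.dat`` per the default pattern.
    import tempfile
    with tempfile.TemporaryDirectory() as tmp:
        stash = DirectoryStash(path=Path(tmp))
        stash.dump('doc1', {'text': 'hello'})
        assert (Path(tmp) / 'doc1.dat').is_file()
        assert stash.load('doc1') == {'text': 'hello'}
        assert list(stash.keys()) == ['doc1']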


@dataclass
class IncrementKeyDirectoryStash(DirectoryStash):
    """A stash that increments integer keys and dumps/loads using the last
    (highest) key available in the stash.

    """
    name: InitVar[str] = field(default='data')
    """The name of the :obj:`pattern` to use in the super class."""

    def __post_init__(self, name: str):
        super().__post_init__()
        self.pattern = name + '-{name}.dat'
        self._last_key = None

    def get_last_key(self, inc: bool = False) -> str:
        """Get the last available (highest numbered) key in the stash.

        """
        if self._last_key is None:
            keys = tuple(map(int, self.keys()))
            if len(keys) == 0:
                key = 0
            else:
                key = max(keys)
            self._last_key = key
        if inc:
            self._last_key += 1
        return str(self._last_key)

    def keys(self) -> Iterable[str]:
        def is_good_key(data: str) -> bool:
            try:
                int(data)
                return True
            except ValueError:
                return False

        return filter(is_good_key, super().keys())

    def dump(self, name_or_inst, inst=None):
        """If only one argument is given, it is used as the data and the key
        name is derived from ``get_last_key``.

        """
        if inst is None:
            key = self.get_last_key(True)
            inst = name_or_inst
        else:
            key = name_or_inst
        path = self.key_to_path(key)
        if logger.isEnabledFor(logging.DEBUG):
            self._debug(f'dumping result {key} to {path}')
        super().dump(key, inst)

    def load(self, name: str = None) -> Any:
        """Just like ``Stash.load``, but if the key is omitted, return the
        value of the last key in the stash.

        """
        if name is None:
            name = self.get_last_key(False)
        if len(self) > 0:
            return super().load(name)
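

def _example_increment_key_stash():
    # Illustrative sketch (added; not part of the original module): single
    # argument dumps auto-increment the integer key, and load() with no key
    # returns the value stored under the last key.
    import tempfile
    with tempfile.TemporaryDirectory() as tmp:
        stash = IncrementKeyDirectoryStash(path=Path(tmp))
        stash.dump('first')   # stored under key '1'
        stash.dump('second')  # stored under key '2'
        assert stash.get_last_key() == '2'
        assert stash.load() == 'second'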


@dataclass
class UnionStash(ReadOnlyStash):
    """A stash that joins the data of many other stashes.

    """
    stashes: Tuple[Stash, ...] = field()
    """The delegate constituent stashes used for each operation."""

    def load(self, name: str) -> Any:
        item = None
        for stash in self.stashes:
            item = stash.load(name)
            if item is not None:
                break
        return item

    def get(self, name: str, default: Any = None) -> Any:
        item = self.load(name)
        return default if item is None else item

    def values(self) -> Iterable[Any]:
        return chain.from_iterable(map(lambda s: s.values(), self.stashes))

    def keys(self) -> Iterable[str]:
        return chain.from_iterable(map(lambda s: s.keys(), self.stashes))

    def exists(self, name: str) -> bool:
        for stash in self.stashes:
            if stash.exists(name):
                return True
        return False
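

def _example_union_stash():
    # Illustrative sketch (added; not part of the original module): a
    # UnionStash reads across its constituent stashes, returning the first
    # match on load and chaining keys from all of them.
    s1 = DictionaryStash()
    s1.dump('a', 1)
    s2 = DictionaryStash()
    s2.dump('b', 2)
    stash = UnionStash(stashes=(s1, s2))
    assert stash.load('b') == 2
    assert sorted(stash.keys()) == ['a', 'b']
    assert stash.exists('a') and not stash.exists('c')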