Source code for zensols.deeplearn.batch.domain

"""This file contains a stash used to load an embedding layer.  It creates
features in batches of matrices and persists matrix only (sans features) for
efficient retrival.

"""
from __future__ import annotations
__author__ = 'Paul Landes'
from typing import Tuple, List, Any, Dict, Union, Set
from dataclasses import dataclass, field
from abc import ABCMeta, abstractmethod
import sys
import logging
from io import TextIOBase
import torch
from torch import Tensor
import collections
from zensols.util import time
from zensols.config import Writable
from zensols.persist import (
    persisted,
    PersistedWork,
    PersistableContainer,
    Deallocatable,
)
from zensols.deeplearn import DeepLearnError, TorchConfig
from zensols.deeplearn.vectorize import (
    FeatureContext,
    NullFeatureContext,
    FeatureVectorizer,
    FeatureVectorizerManager,
    FeatureVectorizerManagerSet,
    CategoryEncodableFeatureVectorizer,
)
from . import (
    BatchError,
    FieldFeatureMapping,
    ManagerFeatureMapping,
    BatchFeatureMapping,
)

logger = logging.getLogger(__name__)


@dataclass
class DataPoint(Writable, metaclass=ABCMeta):
    """Abstract class that makes up a container class for features created
    from sentences.

    """
    id: int = field()
    """The ID of this data point, which maps back to the ``BatchStash``
    instance's subordinate stash.

    """
    batch_stash: BatchStash = field(repr=False)
    """Ephemeral instance of the stash used during encoding only."""

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout):
        self._write_line(f'id: {self.id}', depth, writer)

    def __getstate__(self):
        raise DeepLearnError('Data points should not be pickled')

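
# --- Illustrative sketch (not part of the original module): a client of this
# API typically subclasses ``DataPoint`` to expose the feature values that the
# batch feature mappings reference by attribute name.  The class name and
# attributes below (``ExampleDataPoint``, ``text``, ``label``) are hypothetical
# and shown only as an example.
@dataclass
class ExampleDataPoint(DataPoint):
    text: str = field(default=None)
    """The natural language text to vectorize (assumed feature)."""

    label: str = field(default=None)
    """The nominal label of the text (assumed feature)."""
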
@dataclass
class Batch(PersistableContainer, Writable, metaclass=ABCMeta):
    """Contains a batch of data used in the first layer of a net.  This class
    holds the labels, but is otherwise useless without at least one embedding
    layer matrix defined.

    The user must subclass, add mapping meta data, and optionally (suggested)
    add getters and/or properties for the specific data so the model can be
    more *Pythonic* in the PyTorch :class:`torch.nn.Module`.

    """
    STATES = {'n': 'nascent',
              'e': 'encoded',
              'd': 'decoded',
              't': 'memory copied',
              'k': 'deallocated'}
    """A human friendly mapping of the encoded states."""

    _PERSITABLE_TRANSIENT_ATTRIBUTES = {'_data_points_val'}

    batch_stash: BatchStash = field(repr=False)
    """Ephemeral instance of the stash used during encoding and decoding."""

    id: int = field()
    """The ID of this batch instance, which is the sequence number of the
    batch given during child processing of the chunked data point ID sets.

    """
    split_name: str = field()
    """The name of the split for this batch (i.e. ``train`` vs ``test``)."""

    data_points: Tuple[DataPoint, ...] = field(repr=False)
    """The list of the data points given on creation for encoding, and
    ``None``'d out after encoding/pickling.

    """
    def __post_init__(self):
        super().__init__()
        if hasattr(self, '_data_points_val') and \
           self._data_points_val is not None:
            self.data_point_ids = tuple(map(lambda d: d.id, self.data_points))
        self._decoded_state = PersistedWork(
            '_decoded_state', self, transient=True)
        self.state = 'n'

    @property
    def _data_points(self) -> Tuple[DataPoint, ...]:
        """The data points used to create this batch.  If the batch does not
        contain the data points, which is the case when it has been decoded,
        then they are retrieved from the :obj:`batch_stash` instance's feature
        stash.

        """
        if not hasattr(self, '_data_points_val') or \
           self._data_points_val is None:
            stash: BatchStash = self.batch_stash
            self._data_points_val = stash._get_data_points_for_batch(self)
        return self._data_points_val

    @_data_points.setter
    def _data_points(self, data_points: Tuple[DataPoint, ...]):
        self._data_points_val = data_points

    @abstractmethod
    def _get_batch_feature_mappings(self) -> BatchFeatureMapping:
        """Return the feature mapping meta data for this batch and its data
        points.  It is best to define a class level instance of the mapping
        and return it here to avoid instantiating one for each batch.

        :see: :class:`.BatchFeatureMapping`

        """
        pass

    def get_labels(self) -> torch.Tensor:
        """Return the label tensor for this batch."""
        bmap: BatchFeatureMapping = self._get_batch_feature_mappings()
        if bmap is None:
            raise DeepLearnError('No batch feature mapping set')
        label_attr = bmap.label_attribute_name
        return self.attributes[label_attr]

    @property
    @persisted('_has_labels', transient=True)
    def has_labels(self) -> bool:
        """Return whether or not this batch has labels.  If it doesn't, it is
        a batch used for prediction.

        """
        return self.get_labels() is not None

    def get_label_classes(self) -> List[str]:
        """Return the labels in this batch in their string form.  This assumes
        the label vectorizer is an instance of
        :class:`~zensols.deeplearn.vectorize.CategoryEncodableFeatureVectorizer`.

        :return: the labels reverse mapped from their nominal values

        """
        vec: FeatureVectorizer = self.get_label_feature_vectorizer()
        if not isinstance(vec, CategoryEncodableFeatureVectorizer):
            raise BatchError(
                'Reverse label decoding is only supported with type of ' +
                f'CategoryEncodableFeatureVectorizer, but got: {vec} ' +
                f'({type(vec)})')
        return vec.get_classes(self.get_labels().cpu())

    def get_label_feature_vectorizer(self) -> FeatureVectorizer:
        """Return the label vectorizer used in the batch.  This assumes
        there's only one vectorizer found in the vectorizer manager.

        """
        mapping: BatchFeatureMapping = self._get_batch_feature_mappings()
        field_name: str = mapping.label_attribute_name
        mng, f = mapping.get_field_map_by_attribute(field_name)
        vec_name: str = mng.vectorizer_manager_name
        vec_mng_set = self.batch_stash.vectorizer_manager_set
        vec: FeatureVectorizerManager = vec_mng_set[vec_name]
        return vec[f.feature_id]

    def size(self) -> int:
        """Return the size of this batch, which is the number of data
        points.

        """
        return len(self.data_point_ids)

    @property
    def attributes(self) -> Dict[str, torch.Tensor]:
        """Return the attribute batched tensors as a dictionary using the
        attribute names as the keys.

        """
        return self._get_decoded_state()

    def write(self, depth: int = 0, writer: TextIOBase = sys.stdout,
              include_data_points: bool = False):
        self._write_line(self.__class__.__name__, depth, writer)
        self._write_line(f'size: {self.size()}', depth + 1, writer)
        for k, v in self.attributes.items():
            shape = None if v is None else v.shape
            self._write_line(f'{k}: {shape}', depth + 2, writer)
        if include_data_points:
            self._write_line('data points:', depth + 1, writer)
            for dp in self.data_points:
                dp.write(depth + 2, writer)

    @property
    def _feature_contexts(self) -> \
            Dict[str, Dict[str, Union[FeatureContext,
                                      Tuple[FeatureContext, ...]]]]:
        has_ctx = hasattr(self, '_feature_context_inst')
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'has feature contexts: {has_ctx}')
        if has_ctx:
            if self._feature_context_inst is None:
                raise BatchError('Bad state transition, null contexts')
        else:
            with time(f'encoded batch {self.id}'):
                self._feature_context_inst = self._encode()
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'access context: (state={self.state}), num keys=' +
                        f'{len(self._feature_context_inst.keys())}')
        return self._feature_context_inst

    @_feature_contexts.setter
    def _feature_contexts(self, contexts:
                          Dict[str, Dict[str,
                                         Union[FeatureContext,
                                               Tuple[FeatureContext, ...]]]]):
        if logger.isEnabledFor(logging.DEBUG):
            obj = 'None' if contexts is None else contexts.keys()
            logger.debug(f'setting context: {obj}')
        self._feature_context_inst = contexts

    def __getstate__(self):
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'create batch state {self.id} (state={self.state})')
        assert self.state == 'n'
        if not hasattr(self, '_feature_context_inst'):
            self._feature_contexts
        self.state = 'e'
        state = super().__getstate__()
        state.pop('batch_stash', None)
        state.pop('data_points', None)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'state keys: {state.keys()}')
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        if logger.isEnabledFor(logging.INFO):
            logger.info(f'unpickling batch: {self.id}')

    @persisted('_decoded_state')
    def _get_decoded_state(self):
        """Decode the pickled attributes after being loaded by the containing
        ``BatchStash`` and remove the context information to save memory.

        """
        assert self.state == 'e'
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'decoding ctxs: {self._feature_context_inst.keys()}')
        assert self._feature_context_inst is not None
        with time(f'decoded batch {self.id}'):
            attribs = self._decode(self._feature_contexts)
        self._feature_contexts = None
        assert self._feature_context_inst is None
        self.state = 'd'
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'return decoded attributes: {attribs.keys()}')
        return attribs

    @property
    def torch_config(self) -> TorchConfig:
        """The torch config used to copy from CPU to GPU memory."""
        return self.batch_stash.model_torch_config

    def _clone(self) -> Batch:
        def to(arr: Tensor) -> Tensor:
            if arr is not None:
                arr = torch_config.to(arr)
            return arr

        torch_config = self.torch_config
        attribs = self._get_decoded_state()
        attribs = {k: to(attribs[k]) for k in attribs.keys()}
        inst = self.__class__(self.batch_stash, self.id, self.split_name, None)
        inst.data_point_ids = self.data_point_ids
        inst._decoded_state.set(attribs)
        inst.state = 't'
        return inst

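    # Note on the state machine implemented above (see ``STATES``): a batch is
    # 'n' (nascent) after construction, becomes 'e' (encoded) when pickled via
    # ``__getstate__``, 'd' (decoded) once ``_get_decoded_state`` materializes
    # the attribute tensors, 't' (memory copied) after ``_clone`` copies them
    # to the model device, and 'k' (deallocated) after ``deallocate``.
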
    def to(self) -> Batch:
        """Clone this instance and copy data to the CUDA device configured in
        the batch stash.

        :return: a clone of this instance with all attribute tensors copied
                 to the given torch configuration device

        """
        if self.state == 't':
            inst = self
        else:
            inst = self._clone()
        return inst

    def deallocate(self):
        with time('deallocated attribute', logging.DEBUG):
            if self.state == 'd' or self.state == 't':
                attrs = self.attributes
                for arr in tuple(attrs.values()):
                    del arr
                attrs.clear()
                del attrs
            self._decoded_state.deallocate()
        if hasattr(self, 'batch_stash'):
            del self.batch_stash
        if hasattr(self, 'data_point_ids'):
            del self.data_point_ids
        if hasattr(self, '_data_points_val'):
            Deallocatable._try_deallocate(self._data_points_val)
            del self._data_points_val
        with time('deallocated feature context', logging.DEBUG):
            if hasattr(self, '_feature_context_inst') and \
               self._feature_context_inst is not None:
                for ctx in self._feature_context_inst.values():
                    self._try_deallocate(ctx)
                self._feature_context_inst.clear()
                del self._feature_context_inst
        self.state = 'k'
        super().deallocate()
        logger.debug(f'deallocated batch: {self.id}')

    def _encode_field(self, vec: FeatureVectorizer, fm: FieldFeatureMapping,
                      vals: List[Any]) -> FeatureContext:
        """Encode a set of features into feature contexts:

        :param vec: the feature vectorizer to use to create the context

        :param fm: the field metadata for the feature values

        :param vals: a list of feature input values used to create the context

        :see: :class:`.BatchFeatureMapping`

        """
        if fm.is_agg:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'encoding aggregate with {vec}')
            ctx = vec.encode(vals)
        else:
            ctx = tuple(map(lambda v: vec.encode(v), vals))
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'encoded: {ctx.__class__}')
        return ctx

    def _decode_context(self, vec: FeatureVectorizer, ctx: FeatureContext,
                        fm: FieldFeatureMapping) -> torch.Tensor:
        """Decode ``ctx`` into a tensor using vectorizer ``vec``."""
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'decode with {fm}')
        if isinstance(ctx, tuple):
            arrs = tuple(map(vec.decode, ctx))
            try:
                arr = torch.cat(arrs)
            except Exception as e:
                raise BatchError(
                    'Batch has inconsistent data point length, e.g. mixed ' +
                    f'embeddings or using combine_sentences for NLP for: {vec}') \
                    from e
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'decoded shape for {fm}: {arr.shape}')
        else:
            arr = vec.decode(ctx)
        return arr

    def _is_missing(self, aval: Union[Any, Tuple[Any, ...]]):
        return (aval is None) or \
            (isinstance(aval, (tuple, list)) and all(v is None for v in aval))

    def _encode(self) -> \
            Dict[str, Dict[str, Union[FeatureContext,
                                      Tuple[FeatureContext, ...]]]]:
        """Called to create all matrices/arrays needed for the layer.  After
        this is called, features in this instance are removed so that pickling
        is fast.

        The returned data structure has the following form:

        - feature vector manager name

          - attribute name -> feature context

        where a feature context can be either a single context or a tuple of
        contexts.  If it is a tuple, then each is decoded in turn and the
        resulting matrices will be concatenated together at decode time with
        ``_decode_context``.

        Note that the feature id is an attribute of the feature context.

        :see: :meth:`_decode_context`

        """
        vms = self.batch_stash.vectorizer_manager_set
        attrib_to_ctx = collections.OrderedDict()
        bmap: BatchFeatureMapping = self._get_batch_feature_mappings()
        label_attr: str = bmap.label_attribute_name
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"encoding with label: '{label_attr}' using {vms}")
        mmap: ManagerFeatureMapping
        for mmap in bmap.manager_mappings:
            vm: FeatureVectorizerManager = vms[mmap.vectorizer_manager_name]
            fm: FieldFeatureMapping
            for fm in mmap.fields:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f'field: {fm}')
                if fm.feature_id in attrib_to_ctx:
                    raise BatchError(f'Duplicate feature: {fm.feature_id}')
                vec: FeatureVectorizer = vm[fm.feature_id]
                avals: List[Any] = []
                ctx: FeatureContext = None
                dp: DataPoint
                for dp in self.data_points:
                    aval = getattr(dp, fm.attribute_accessor)
                    avals.append(aval)
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.debug(f'attr: {fm.attr} => {aval.__class__}')
                try:
                    is_label = fm.is_label or (label_attr == fm.attr)
                    if is_label and self._is_missing(aval):
                        # assume prediction
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug('skipping missing label')
                        ctx = NullFeatureContext(fm.feature_id)
                    else:
                        ctx = self._encode_field(vec, fm, avals)
                except Exception as e:
                    raise BatchError(
                        f'Could not vectorize {fm} using {vec}: {e}') from e
                if ctx is not None:
                    attrib_to_ctx[fm.attr] = ctx
        return attrib_to_ctx

    def _decode(self, ctx: Dict[str, Dict[str,
                                          Union[FeatureContext,
                                                Tuple[FeatureContext, ...]]]]):
        """Called to create all matrices/arrays needed for the layer.  After
        this is called, features in this instance are removed so that pickling
        is fast.

        :param ctx: the context to decode

        """
        attribs = collections.OrderedDict()
        attrib_keeps: Set[str] = self.batch_stash.decoded_attributes
        if attrib_keeps is not None:
            attrib_keeps = set(attrib_keeps)
        bmap: BatchFeatureMapping = self._get_batch_feature_mappings()
        label_attr: str = bmap.label_attribute_name
        vms: FeatureVectorizerManagerSet = \
            self.batch_stash.vectorizer_manager_set
        for attrib, ctx in ctx.items():
            mng_fmap = bmap.get_field_map_by_attribute(attrib)
            if mng_fmap is None:
                raise BatchError(
                    f"Missing mapped attribute '{attrib}' on decode")
            mng, fmap = mng_fmap
            mmap_name = mng.vectorizer_manager_name
            feature_id = fmap.feature_id
            vm: FeatureVectorizerManager = vms[mmap_name]
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'vec manager: {mmap_name} -> {vm}')
            # keep only the desired feature subset for speed up
            if attrib_keeps is not None and attrib not in attrib_keeps:
                continue
            if attrib_keeps is not None:
                attrib_keeps.remove(attrib)
            if isinstance(ctx, tuple):
                feature_id = ctx[0].feature_id
            elif ctx is None:
                feature_id = None
            else:
                feature_id = ctx.feature_id
            vec: FeatureVectorizer = vm.get(feature_id)
            if vec is None:
                raise BatchError(
                    f'No such vectorizer for feature ID: {feature_id}')
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f'decoding {ctx} with {vec}')
            arr: Tensor = self._decode_context(vec, ctx, fmap)
            if arr is None and fmap.attr != label_attr:
                raise BatchError(
                    f'No decoded value for {fmap}, which is not ' +
                    f"the label attribute '{label_attr}'")
            if logger.isEnabledFor(logging.DEBUG):
                shape = '<none>' if arr is None else arr.shape
                logger.debug(f'decoded: {attrib} -> {shape}')
            if attrib in attribs:
                raise BatchError(
                    f'Attribute collision on decode: {attrib}')
            attribs[attrib] = arr
        if attrib_keeps is not None and len(attrib_keeps) > 0:
            raise BatchError(f'Unknown attributes: {attrib_keeps}')
        return attribs

    @property
    def state_name(self):
        return self.STATES.get(self.state, 'unknown: ' + self.state)

    def __len__(self):
        return len(self.data_point_ids)

    def keys(self) -> Tuple[str, ...]:
        return tuple(self.attributes.keys())

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.attributes[key]

    def __str__(self):
        # the data_points property overrides the @dataclass field, so we must
        # eclipse it with our own string representation
        return (f'{self.__class__.__name__}: id={self.id}, ' +
                f'split={self.split_name}, size={self.size()}, ' +
                f'state={self.state}')

    def __repr__(self):
        return self.__str__()

# keep the dataclass semantics, but allow for a setter to add AMR metadata
Batch.data_points = Batch._data_points

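
# --- Illustrative usage sketch (not part of the original module).  It assumes
# ``stash`` is a configured ``BatchStash`` whose values are ``Batch``
# instances; that assumption, and the printed attribute names, are not part of
# this module's contract.
def _example_batch_access(stash: BatchStash) -> None:
    # iterate over the decoded batches provided by the (assumed) batch stash
    for batch in stash.values():
        # copy all attribute tensors to the configured device (state 'd' -> 't')
        gpu_batch: Batch = batch.to()
        # decoded tensors are accessible by attribute name
        labels: Tensor = gpu_batch.get_labels()
        print(f'{gpu_batch}: labels: '
              f'{None if labels is None else labels.shape}, '
              f'attributes: {gpu_batch.keys()}')
        # release tensors and feature contexts when done (state -> 'k')
        gpu_batch.deallocate()
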
@dataclass
class DefaultBatch(Batch):
    """A concrete implementation that uses a :obj:`batch_feature_mappings`
    usually configured with :class:`.ConfigBatchFeatureMapping` and provided
    by :class:`.BatchStash`.

    """
    _PERSITABLE_REMOVE_ATTRIBUTES = {'batch_feature_mappings'}

    batch_feature_mappings: BatchFeatureMapping = field(default=None)
    """The mappings used by this instance."""

    def _get_batch_feature_mappings(self) -> BatchFeatureMapping:
        assert self.batch_feature_mappings is not None
        return self.batch_feature_mappings

    def _clone(self) -> Batch:
        copy = super()._clone()
        copy.batch_feature_mappings = self.batch_feature_mappings
        return copy

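
# --- Illustrative sketch (not part of the original module) of the pattern
# recommended by ``Batch._get_batch_feature_mappings``: define the feature
# mapping once at the class level and return it for every instance.  The
# manager name, attribute names, feature IDs, and constructor arguments below
# are assumptions made only for this example.
@dataclass
class ExampleBatch(Batch):
    MAPPINGS = BatchFeatureMapping(
        'label',
        [ManagerFeatureMapping(
            'example_vectorizer_manager',
            (FieldFeatureMapping('label', 'lblvec', True),
             FieldFeatureMapping('text', 'txtvec', True)))])
    """Class level feature mapping shared by all instances (assumed names)."""

    def _get_batch_feature_mappings(self) -> BatchFeatureMapping:
        return self.MAPPINGS
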